diff --git a/.github/ISSUE_TEMPLATE/bug.yml b/.github/ISSUE_TEMPLATE/bug.yml new file mode 100644 index 000000000..ddc09c73a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug.yml @@ -0,0 +1,48 @@ +name: Report a bug +description: Report flagtree failing to compile a kernel, or giving incorrect results +labels: ["bug"] + +body: +- type: markdown + attributes: + value: | + #### Disclaimer + The core flagtree team is small and has very limited capacity. We may not have time to look into your report. + For the best results, please: + - Avoid submitting duplicates. Search first to see if it's been reported previously. + - Check if the issue persists with a build from the latest source. + - Provide all relevant information in the initial report, to prevent unnecessary back and forth discussion. + - If you can, try to diagnose and/or fix the issue yourself. We welcome high quality contributions. +- type: textarea + attributes: + label: Describe the bug + description: | + Please provide a clear and concise description of what the bug is. + + If relevant, add a [minimal complete example](https://stackoverflow.com/help/minimal-reproducible-example) that reproduces the bug. It is very important for the snippet to be as simple as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did, so include both the kernel and launching code as well as any relevant imports. + + If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. + + Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. + placeholder: | + A clear and concise description of what the bug is. + + ```python + # Sample code to reproduce the problem + ``` + + ``` + The error message you got, with the full traceback. + ``` + validations: + required: true +- type: textarea + attributes: + label: Environment details + description: | + Please include any relevant context about how you're running the reproducer e.g. which version of triton, and what GPU you are using. + placeholder: | + Triton: ... + GPU: ... + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000..3ba13e0ce --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: false diff --git a/.github/ISSUE_TEMPLATE/performance.yml b/.github/ISSUE_TEMPLATE/performance.yml new file mode 100644 index 000000000..dc6256971 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/performance.yml @@ -0,0 +1,44 @@ +name: Report a performance issue +description: Report cases where triton is generating sub-optimal (but functionally correct) PTX/LLVM IR +labels: ["performance"] + +body: +- type: markdown + attributes: + value: | + #### Disclaimer + The core flagtree team is small and has very limited capacity. We may not have time to look into your report. + For the best results, please: + - Avoid submitting duplicates. Search first to see if it's been reported previously. + - Check if the issue persists with a build from the latest source. + - Provide all relevant information in the initial report, to prevent unnecessary back and forth discussion. + - If you can, try to diagnose and/or fix the issue yourself. We welcome high quality contributions. +- type: textarea + attributes: + label: Describe the issue + description: | + Please provide a clear and concise description of the issue. + + Include a [minimal complete example](https://stackoverflow.com/help/minimal-reproducible-example) that reproduces the issue. It is very important for the snippet to be as simple as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did. + + A reproducer could be a python program that runs a triton kernel and prints out the relevant suboptimal IR, or an IR file with an accompanying triton-opt command. + + If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. + placeholder: | + A clear and concise description of the issue. + + ```python + # Sample code to reproduce the problem + ``` + validations: + required: true +- type: textarea + attributes: + label: Environment details + description: | + Please include any relevant context about how you're running the reproducer e.g. which version of triton, and what GPU you are using. + placeholder: | + Triton: ... + GPU: ... + validations: + required: true diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..396245e1a --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,16 @@ + diff --git a/.github/workflows/ascend-build-and-test.yml b/.github/workflows/ascend-build-and-test.yml new file mode 100644 index 000000000..2e504494b --- /dev/null +++ b/.github/workflows/ascend-build-and-test.yml @@ -0,0 +1,32 @@ +name: Ascend-Build-And-Test + +on: + push: + branches: [ "triton_v3.2.x" ] + pull_request: + branches: [ "triton_v3.2.x" ] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + ascend-build-and-test: + runs-on: ascend + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: FlagTree Build on Ascend + shell: bash + run: | + export FLAGTREE_BACKEND=ascend + source ~/env.sh + cd python + MAX_JOBS=32 python3.9 -m pip install . --no-build-isolation + + - name: FlagTree Test on Ascend + shell: bash + run: | + source /usr/local/Ascend/ascend-toolkit/set_env.sh + python3.9 third_party/ascend/python/tutorials/01-vector-add.py diff --git a/.github/workflows/cambricon-build-and-test.yml b/.github/workflows/cambricon-build-and-test.yml new file mode 100644 index 000000000..7536a3959 --- /dev/null +++ b/.github/workflows/cambricon-build-and-test.yml @@ -0,0 +1,26 @@ +name: Cambricon-Build-And-Test + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + cambricon-build-and-test: + runs-on: cambricon + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: FlagTree Build on Cambricon + shell: bash + run: | + export FLAGTREE_BACKEND=cambricon + source ~/env.sh + cd python + MAX_JOBS=8 pip3 install . --no-build-isolation diff --git a/.github/workflows/code-format-check.yml b/.github/workflows/code-format-check.yml new file mode 100644 index 000000000..51ced28c7 --- /dev/null +++ b/.github/workflows/code-format-check.yml @@ -0,0 +1,21 @@ +name: Code-Format-Check + +on: + push: + branches: [ "main", "triton_v3.2.x" ] + pull_request: + branches: [ "main", "triton_v3.2.x" ] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + - uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml deleted file mode 100644 index fa4b18645..000000000 --- a/.github/workflows/documentation.yml +++ /dev/null @@ -1,57 +0,0 @@ -name: Documentation -on: - workflow_dispatch: - schedule: - - cron: "0 0 * * *" - -permissions: read-all - -jobs: - Build-Documentation: - runs-on: [self-hosted, A100] - timeout-minutes: 30 - - steps: - - name: Checkout branch - uses: actions/checkout@v4 - with: - token: ${{ secrets.CI_PAT }} - fetch-depth: 0 - - - name: Clear docs - run: | - rm -rf /tmp/triton-docs - continue-on-error: true - - - name: Install dependent packages - run: | - sudo pip3 install tabulate cmake sphinx matplotlib myst_parser sphinx-rtd-theme pandas pytest sphinx-gallery sphinx-multiversion - - #- name: Fetch dependent branches - # run: | - # git fetch origin main:main - - - name: Build docs - run: | - cd docs - export PATH=$(python3 -c "import cmake; print(cmake.CMAKE_BIN_DIR)"):$PATH - sudo python3 -m sphinx . _build/html/main - - - name: Update docs - run: | - sudo mkdir /tmp/triton-docs/ - sudo mv docs/_build/html/* /tmp/triton-docs/ - sudo git checkout gh-pages - sudo cp -r CNAME /tmp/triton-docs/ - sudo cp -r index.html /tmp/triton-docs/ - sudo cp -r .nojekyll /tmp/triton-docs/ - sudo rm -rf * - sudo cp -r /tmp/triton-docs/* . - sudo git add . - sudo git config --global user.email "N/A" - sudo git config --global user.name "gh-actions-bot" - sudo git commit -am "[GH-PAGES] Updated website" - - - name: Publish docs - run: | - sudo git push origin gh-pages diff --git a/.github/workflows/iluvatar-build-and-test.yml b/.github/workflows/iluvatar-build-and-test.yml new file mode 100644 index 000000000..f54cb575b --- /dev/null +++ b/.github/workflows/iluvatar-build-and-test.yml @@ -0,0 +1,59 @@ +name: Iluvatar-Build-And-Test + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + iluvatar-build-and-test: + runs-on: iluvatar + steps: + - name: Checkout code (attempt 1) + id: checkout1 + uses: actions/checkout@v4 + continue-on-error: true + + - name: Sleep before checkout2 + if: steps.checkout1.outcome == 'failure' + run: | + echo "First checkout attempt failed. Sleeping for 120 seconds before retry..." + sleep 120 + + - name: Checkout code (attempt 2) + id: checkout2 + if: steps.checkout1.outcome == 'failure' + uses: actions/checkout@v4 + continue-on-error: true + + - name: Sleep before final checkout + if: steps.checkout1.outcome == 'failure' && steps.checkout2.outcome == 'failure' + run: | + echo "Second checkout attempt failed. Sleeping for 180 seconds before final retry..." + sleep 180 + + - name: Checkout code (final attempt) + if: steps.checkout1.outcome == 'failure' && steps.checkout2.outcome == 'failure' + uses: actions/checkout@v4 + + - name: Verify checkout success + if: success() + run: echo "Checkout completed successfully" + + - name: FlagTree Build on Iluvatar + shell: bash + run: | + export FLAGTREE_BACKEND=iluvatar + source ~/env.sh + cd python + MAX_JOBS=20 pip3 install . --no-build-isolation + + - name: FlagTree Test on Iluvatar + shell: bash + run: | + CUDA_VISIBLE_DEVICES=15 pytest -s third_party/iluvatar/python/test/unit diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml deleted file mode 100644 index 9761b7dac..000000000 --- a/.github/workflows/integration-tests.yml +++ /dev/null @@ -1,403 +0,0 @@ -# AUTOGENERATED by pre-commit, modify the .in file instead. - -# integration-tests.yml.in is used to generate integration-tests.yml by -# expanding yaml anchors, because github actions don't support them -# (https://github.com/actions/runner/issues/1182). pre-commit will do this for -# you automatically. - - -name: Integration Tests -on: - workflow_dispatch: - pull_request: - # You can name your branch dev-foo to get CI runs. - branches: [main, 'dev-**'] - merge_group: - branches: [main, 'dev-**'] - types: [checks_requested] - push: - branches: [main] -concurrency: - group: ${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} -permissions: read-all -env: - TRITON_BUILD_WITH_CLANG_LLD: "TRUE" - TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" - TRITON_DISABLE_LINE_INFO: 1 -jobs: - Runner-Preparation: - runs-on: ubuntu-latest - timeout-minutes: 30 - outputs: - matrix-CUDA: ${{ steps.set-matrix.outputs.matrix-CUDA }} - matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }} - steps: - - name: Decide pre-submit integration test enablement - # Always enable integration tests for pre-submit pull requests. - if: github.event_name == 'pull_request' - run: | - echo "enable_integration=true" >> $GITHUB_ENV - - name: Checkout post-submit commits - if: github.event_name == 'push' - uses: actions/checkout@v4 - with: - # Only fetch two commits to check the latest changed files. - fetch-depth: 2 - - name: Detect if build deps (e.g. LLVM hash) changed - id: detect-change - if: github.event_name == 'push' - uses: tj-actions/changed-files@v44 - with: - files: | - cmake/*.txt - - name: Detect if enough time has passed since last post-submit run - id: detect-time - if: github.event_name == 'push' - run: | - GITHUB_TOKEN=${{ secrets.GITHUB_TOKEN }} - REPO_NAME="${{ github.repository }}" - # ID of integration-tests workflow - WORKFLOW_ID="11678186" - - # Fetch the last run time of this workflow - LAST_RUN=$(curl -s \ - -H "Authorization: token $GITHUB_TOKEN" \ - -H "Accept: application/vnd.github.v3+json" \ - "https://api.github.com/repos/$REPO_NAME/actions/workflows/$WORKFLOW_ID/runs?branch=main&status=success&per_page=1" \ - | jq -r '.workflow_runs[0].updated_at') - - # Convert to timestamp - LAST_RUN_TS=$(date -d "$LAST_RUN" +%s) - NOW_TS=$(date +%s) - DIFF=$(( (NOW_TS - LAST_RUN_TS) / 3600 )) # Difference in hours - - echo "Last run was $DIFF hours ago." - - if [ "$DIFF" -ge 4 ]; then - echo "Will run CI; last build was long enough ago." - echo "n_hours_since_last_run=true" >> $GITHUB_ENV - else - echo "Will not run CI; last build was too recent." - echo "n_hours_since_last_run=false" >> $GITHUB_ENV - fi - # We want to run integration tests on the main branch (i.e. post-submit) - # occasionally, because pre-submit CI caches will only read from caches - # generated from the main branch (or the PR's branch), and we want these - # caches to be recent. - # - # But we also don't want to run the tests on *every* commit, because this - # would compete for resources with pre-commit CI (and the whole point of - # caching is to speed up CI). - # - # As a compromise, run every N hours, or if a build dependency changes - # (e.g. we update the LLVM hash). - - name: Decide whether to run integration tests post-submit - if: | - github.event_name == 'push' && - (steps.detect-change.outputs.any_changed == 'true' || - env.n_hours_since_last_run == 'true') - run: | - echo "enable_integration=true" >> $GITHUB_ENV - - name: Prepare runner matrix - id: set-matrix - if: env.enable_integration == 'true' - run: | - if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then - echo '::set-output name=matrix-CUDA::[["self-hosted", "A100"], ["self-hosted", "H100"]]' - echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]' - else - echo '::set-output name=matrix-CUDA::["ubuntu-latest"]' - echo '::set-output name=matrix-HIP::["ubuntu-latest"]' - fi - pre-commit: - name: pre-commit (code formatting) - needs: Runner-Preparation - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: '3.12' - cache: 'pip' - - name: Compute hash of pre-commit config - id: cache-key - run: | - echo "pre_commit_hash=$(sha256sum .pre-commit-config.yaml)" >> $GITHUB_OUTPUT - shell: bash - - name: Cache pre-commit's cache dir - uses: actions/cache@v4 - with: - # Note that we cannot use environment variables here given there is - # no shell to interpret them in the paths. - path: | - ~/.cache/pre-commit - key: ${{ runner.os }}-${{ steps.cache-key.outputs.pre_commit_hash }} - - name: Check pre-commit - run: | - python3 -m pip install --upgrade pre-commit - # TODO: ignore the first yapf failure until https://github.com/google/yapf/issues/1164 is fixed - python3 -m pre_commit run --all-files --verbose yapf &> /dev/null || true - # If first run of yapf worked and made changes reset the tree to the original state - git reset --hard - python3 -m pre_commit run --all-files --verbose - - name: Print diff of changes if pre-commit failed - if: failure() - run: | - git diff - Integration-Tests: - needs: Runner-Preparation - if: needs.Runner-Preparation.outputs.matrix-CUDA != '' - runs-on: ${{ matrix.runner }} - timeout-minutes: 30 - strategy: - matrix: - runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-CUDA)}} - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - submodules: "true" - - name: Compute cache keys - id: cache-key - run: | - echo "llvm=$(cat cmake/llvm-hash.txt | cut -c 1-8)" >> $GITHUB_OUTPUT - echo "pybind11=$(cat cmake/pybind11-version.txt)" >> $GITHUB_OUTPUT - echo "nvidia=$(cat cmake/nvidia-toolchain-version.txt)" >> $GITHUB_OUTPUT - echo "datetime=$(date -u -Iseconds)" >> $GITHUB_OUTPUT - shell: bash - - name: Cache build dependencies - uses: actions/cache@v4 - with: - # Note that we cannot use environment variables here given there is - # no shell to interpret them in the paths. - path: | - ~/.triton/llvm - ~/.triton/nvidia - ~/.triton/pybind11 - key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-pybind11-${{ steps.cache-key.outputs.pybind11 }} - - # Cache ~/.triton/cache because the vast majority of unit test time is - # spent compiling. Triton won't (well, should not) use these cached files - # if something internal to Triton changes, because Triton's internal - # source code is part of the cache key. - # - # Similarly, cache ~/.cache/ccache to speed up compilation. - # - # On branch `main` we always start from an empty cache, i.e. we skip the - # "restore" step. This is to prevent the caches from accumulating stale - # files over time. - name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' - uses: actions/cache/restore@v4 - with: - path: | - ~/.triton/cache - ~/.cache/ccache - # Restore the most recent cache entry. - restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- - # We expect this cache key never to hit and for us to fall back - # unconditionally to the restore-key, so it doesn't actually matter - # what we put here (so long as it doesn't hit an existing key). - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directory - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - - name: Update PATH - run: | - echo "$HOME/.local/bin" >> $GITHUB_PATH - - name: Install pip dependencies - run: | - python3 -m pip install --upgrade pip - python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit - - name: Install Triton - env: - TRITON_BUILD_WITH_CCACHE: "true" - CUDA_HOME: "/usr/local/cuda" - run: | - echo "PATH is '$PATH'" - cd python - python3 -m pip install '.[tests]' - - name: Run lit tests - run: | - cd python - LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test" - if [ ! -d "${LIT_TEST_DIR}" ]; then - echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1 - fi - lit -v "${LIT_TEST_DIR}" - - name: Run python tests on CUDA - run: | - cd python/test/unit - python3 -m pytest -n 8 --ignore=hopper/test_flashattention.py --ignore=runtime --ignore=language/test_line_info.py --ignore=language/test_subprocess.py - python3 -m pytest -n 8 language/test_subprocess.py - # Run runtime tests serially to avoid race condition with cache handling - python3 -m pytest runtime/ - # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 - TRITON_DISABLE_LINE_INFO=0 python3 -m pytest language/test_line_info.py - # Run hopper/test_flashattention.py separately to avoid out of gpu memory - python3 -m pytest -vs hopper/test_flashattention.py - - name: Run interpreter tests - if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'H100'}} - env: - TRITON_INTERPRET: "1" - run: | - cd python/test/unit - python3 -m pytest -n 16 -m interpreter language/test_core.py language/test_standard.py \ - language/test_random.py language/test_block_pointer.py language/test_subprocess.py \ - operators/test_flash_attention.py::test_op \ - ../../tutorials/06-fused-attention.py::test_op --device cpu - - name: Run C++ unittests - run: | - cd python - cd "build/$(ls build | grep -i cmake)" - ctest -j32 - - name: Run Proton tests - env: - LD_LIBRARY_PATH: "/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" - run: | - cd third_party/proton - python3 -m pytest test - - # If we're on branch `main`, save the ccache Triton compilation artifacts - # to the cache so they can be used by other (non-main) CI runs. - # - # (It wouldn't be a problem to save the cache on every run, because github - # evicts cache entries LRU, but maybe this saves a bit of time in CI.) - name: Save ccache and Triton compilation artifacts to cache - if: github.ref == 'refs/heads/main' - uses: actions/cache/save@v4 - with: - path: ~/.triton/cache ~/.cache/ccache - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directories - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** - - mkdir -p ~/.cache/ccache - ls -alh ~/.cache/ccache - du -sh ~/.cache/ccache - Integration-Tests-AMD: - needs: Runner-Preparation - if: needs.Runner-Preparation.outputs.matrix-HIP != '' - runs-on: ${{ matrix.runner }} - timeout-minutes: 30 - strategy: - matrix: - runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-HIP)}} - container: - image: rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2 - options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - submodules: 'true' - - name: Compute cache keys - id: cache-key - run: | - echo "llvm=$(cat cmake/llvm-hash.txt | cut -c 1-8)" >> $GITHUB_OUTPUT - echo "pybind11=$(cat cmake/pybind11-version.txt)" >> $GITHUB_OUTPUT - echo "nvidia=$(cat cmake/nvidia-toolchain-version.txt)" >> $GITHUB_OUTPUT - echo "datetime=$(date -u -Iseconds)" >> $GITHUB_OUTPUT - shell: bash - - name: Cache build dependencies - uses: actions/cache@v4 - with: - # Note that we cannot use environment variables here given there is - # no shell to interpret them in the paths. - path: | - ~/.triton/llvm - ~/.triton/nvidia - ~/.triton/pybind11 - key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-pybind11-${{ steps.cache-key.outputs.pybind11 }} - - # Cache ~/.triton/cache because the vast majority of unit test time is - # spent compiling. Triton won't (well, should not) use these cached files - # if something internal to Triton changes, because Triton's internal - # source code is part of the cache key. - # - # Similarly, cache ~/.cache/ccache to speed up compilation. - # - # On branch `main` we always start from an empty cache, i.e. we skip the - # "restore" step. This is to prevent the caches from accumulating stale - # files over time. - name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' - uses: actions/cache/restore@v4 - with: - path: | - ~/.triton/cache - ~/.cache/ccache - # Restore the most recent cache entry. - restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- - # We expect this cache key never to hit and for us to fall back - # unconditionally to the restore-key, so it doesn't actually matter - # what we put here (so long as it doesn't hit an existing key). - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directory - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - - name: Update PATH - run: | - echo "/opt/rocm/llvm/bin" >> $GITHUB_PATH - - name: Install pip dependencies - run: | - python3 -m pip install --upgrade pip - python3 -m pip install lit - - name: Install Triton - run: | - echo "PATH is '$PATH'" - pip uninstall -y triton - cd python - pip install -v -e '.[tests]' - - name: Run lit tests - run: | - cd python - LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test" - if [ ! -d "${LIT_TEST_DIR}" ]; then - echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1 - fi - lit -v "${LIT_TEST_DIR}" - - name: Run python tests on HIP - run: | - pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py - cd python/test/unit - pytest --capture=tee-sys -rfs -n 32 language operators \ - hopper/test_mixed_io.py \ - hopper/test_gemm.py \ - hopper/test_tma_store_gemm.py \ - hopper/test_persistent_warp_specialized_fused-attention.py \ - --ignore=language/test_line_info.py - # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 - TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -n 8 language/test_line_info.py - - # Run runtime tests serially to avoid race condition with cache handling - python3 -m pytest runtime - - name: Run C++ unittests - run: | - cd python - cd "build/$(ls build | grep -i cmake)" - ctest -j32 - - # If we're on branch `main`, save the ccache Triton compilation artifacts - # to the cache so they can be used by other (non-main) CI runs. - # - # (It wouldn't be a problem to save the cache on every run, because github - # evicts cache entries LRU, but maybe this saves a bit of time in CI.) - name: Save ccache and Triton compilation artifacts to cache - if: github.ref == 'refs/heads/main' - uses: actions/cache/save@v4 - with: - path: ~/.triton/cache ~/.cache/ccache - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directories - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** - - mkdir -p ~/.cache/ccache - ls -alh ~/.cache/ccache - du -sh ~/.cache/ccache diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in deleted file mode 100644 index 7b748b41e..000000000 --- a/.github/workflows/integration-tests.yml.in +++ /dev/null @@ -1,386 +0,0 @@ -# integration-tests.yml.in is used to generate integration-tests.yml by -# expanding yaml anchors, because github actions don't support them -# (https://github.com/actions/runner/issues/1182). pre-commit will do this for -# you automatically. - - -name: Integration Tests - -on: - workflow_dispatch: - pull_request: - # You can name your branch dev-foo to get CI runs. - branches: [main, 'dev-**'] - merge_group: - branches: [main, 'dev-**'] - types: [checks_requested] - push: - branches: [main] - -concurrency: - group: ${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -permissions: read-all - -env: - TRITON_BUILD_WITH_CLANG_LLD: "TRUE" - TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" - TRITON_DISABLE_LINE_INFO: 1 - -jobs: - Runner-Preparation: - runs-on: ubuntu-latest - timeout-minutes: 30 - outputs: - matrix-CUDA: ${{ steps.set-matrix.outputs.matrix-CUDA }} - matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }} - steps: - - name: Decide pre-submit integration test enablement - # Always enable integration tests for pre-submit pull requests. - if: github.event_name == 'pull_request' - run: | - echo "enable_integration=true" >> $GITHUB_ENV - - - name: Checkout post-submit commits - if: github.event_name == 'push' - uses: actions/checkout@v4 - with: - # Only fetch two commits to check the latest changed files. - fetch-depth: 2 - - - name: Detect if build deps (e.g. LLVM hash) changed - id: detect-change - if: github.event_name == 'push' - uses: tj-actions/changed-files@v44 - with: - files: | - cmake/*.txt - - - name: Detect if enough time has passed since last post-submit run - id: detect-time - if: github.event_name == 'push' - run: | - GITHUB_TOKEN=${{ secrets.GITHUB_TOKEN }} - REPO_NAME="${{ github.repository }}" - # ID of integration-tests workflow - WORKFLOW_ID="11678186" - - # Fetch the last run time of this workflow - LAST_RUN=$(curl -s \ - -H "Authorization: token $GITHUB_TOKEN" \ - -H "Accept: application/vnd.github.v3+json" \ - "https://api.github.com/repos/$REPO_NAME/actions/workflows/$WORKFLOW_ID/runs?branch=main&status=success&per_page=1" \ - | jq -r '.workflow_runs[0].updated_at') - - # Convert to timestamp - LAST_RUN_TS=$(date -d "$LAST_RUN" +%s) - NOW_TS=$(date +%s) - DIFF=$(( (NOW_TS - LAST_RUN_TS) / 3600 )) # Difference in hours - - echo "Last run was $DIFF hours ago." - - if [ "$DIFF" -ge 4 ]; then - echo "Will run CI; last build was long enough ago." - echo "n_hours_since_last_run=true" >> $GITHUB_ENV - else - echo "Will not run CI; last build was too recent." - echo "n_hours_since_last_run=false" >> $GITHUB_ENV - fi - - # We want to run integration tests on the main branch (i.e. post-submit) - # occasionally, because pre-submit CI caches will only read from caches - # generated from the main branch (or the PR's branch), and we want these - # caches to be recent. - # - # But we also don't want to run the tests on *every* commit, because this - # would compete for resources with pre-commit CI (and the whole point of - # caching is to speed up CI). - # - # As a compromise, run every N hours, or if a build dependency changes - # (e.g. we update the LLVM hash). - - name: Decide whether to run integration tests post-submit - if: | - github.event_name == 'push' && - (steps.detect-change.outputs.any_changed == 'true' || - env.n_hours_since_last_run == 'true') - run: | - echo "enable_integration=true" >> $GITHUB_ENV - - - name: Prepare runner matrix - id: set-matrix - if: env.enable_integration == 'true' - run: | - if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then - echo '::set-output name=matrix-CUDA::[["self-hosted", "A100"], ["self-hosted", "H100"]]' - echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]' - else - echo '::set-output name=matrix-CUDA::["ubuntu-latest"]' - echo '::set-output name=matrix-HIP::["ubuntu-latest"]' - fi - - pre-commit: - name: pre-commit (code formatting) - needs: Runner-Preparation - runs-on: ubuntu-latest - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - uses: actions/setup-python@v5 - with: - python-version: '3.12' - cache: 'pip' - - - name: Compute hash of pre-commit config - id: cache-key - run: | - echo "pre_commit_hash=$(sha256sum .pre-commit-config.yaml)" >> $GITHUB_OUTPUT - shell: bash - - - name: Cache pre-commit's cache dir - uses: actions/cache@v4 - with: - # Note that we cannot use environment variables here given there is - # no shell to interpret them in the paths. - path: | - ~/.cache/pre-commit - key: ${{ runner.os }}-${{ steps.cache-key.outputs.pre_commit_hash }} - - - name: Check pre-commit - run: | - python3 -m pip install --upgrade pre-commit - # TODO: ignore the first yapf failure until https://github.com/google/yapf/issues/1164 is fixed - python3 -m pre_commit run --all-files --verbose yapf &> /dev/null || true - # If first run of yapf worked and made changes reset the tree to the original state - git reset --hard - python3 -m pre_commit run --all-files --verbose - - - name: Print diff of changes if pre-commit failed - if: failure() - run: | - git diff - - Integration-Tests: - needs: Runner-Preparation - if: needs.Runner-Preparation.outputs.matrix-CUDA != '' - - runs-on: ${{ matrix.runner }} - timeout-minutes: 30 - - strategy: - matrix: - runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-CUDA)}} - - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - submodules: "true" - - - &compute-cache-keys-step - name: Compute cache keys - id: cache-key - run: | - echo "llvm=$(cat cmake/llvm-hash.txt | cut -c 1-8)" >> $GITHUB_OUTPUT - echo "pybind11=$(cat cmake/pybind11-version.txt)" >> $GITHUB_OUTPUT - echo "nvidia=$(cat cmake/nvidia-toolchain-version.txt)" >> $GITHUB_OUTPUT - echo "datetime=$(date -u -Iseconds)" >> $GITHUB_OUTPUT - shell: bash - - - &cache-build-dependencies-step - name: Cache build dependencies - uses: actions/cache@v4 - with: - # Note that we cannot use environment variables here given there is - # no shell to interpret them in the paths. - path: | - ~/.triton/llvm - ~/.triton/nvidia - ~/.triton/pybind11 - key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-pybind11-${{ steps.cache-key.outputs.pybind11 }} - - # Cache ~/.triton/cache because the vast majority of unit test time is - # spent compiling. Triton won't (well, should not) use these cached files - # if something internal to Triton changes, because Triton's internal - # source code is part of the cache key. - # - # Similarly, cache ~/.cache/ccache to speed up compilation. - # - # On branch `main` we always start from an empty cache, i.e. we skip the - # "restore" step. This is to prevent the caches from accumulating stale - # files over time. - - &restore-build-artifacts-step - name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' - uses: actions/cache/restore@v4 - with: - path: | - ~/.triton/cache - ~/.cache/ccache - # Restore the most recent cache entry. - restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- - # We expect this cache key never to hit and for us to fall back - # unconditionally to the restore-key, so it doesn't actually matter - # what we put here (so long as it doesn't hit an existing key). - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - - &inspect-cache-directory-step - name: Inspect cache directory - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - - - name: Update PATH - run: | - echo "$HOME/.local/bin" >> $GITHUB_PATH - - - name: Install pip dependencies - run: | - python3 -m pip install --upgrade pip - python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit - - - name: Install Triton - env: - TRITON_BUILD_WITH_CCACHE: "true" - CUDA_HOME: "/usr/local/cuda" - run: | - echo "PATH is '$PATH'" - cd python - python3 -m pip install '.[tests]' - - - &run-lit-tests-step - name: Run lit tests - run: | - cd python - LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test" - if [ ! -d "${LIT_TEST_DIR}" ]; then - echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1 - fi - lit -v "${LIT_TEST_DIR}" - - - name: Run python tests on CUDA - run: | - cd python/test/unit - python3 -m pytest -n 8 --ignore=hopper/test_flashattention.py --ignore=runtime --ignore=language/test_line_info.py --ignore=language/test_subprocess.py - python3 -m pytest -n 8 language/test_subprocess.py - # Run runtime tests serially to avoid race condition with cache handling - python3 -m pytest runtime/ - # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 - TRITON_DISABLE_LINE_INFO=0 python3 -m pytest language/test_line_info.py - # Run hopper/test_flashattention.py separately to avoid out of gpu memory - python3 -m pytest -vs hopper/test_flashattention.py - - - name: Run interpreter tests - if: ${{matrix.runner[0] == 'self-hosted' && matrix.runner[1] == 'H100'}} - env: - TRITON_INTERPRET: "1" - run: | - cd python/test/unit - python3 -m pytest -n 16 -m interpreter language/test_core.py language/test_standard.py \ - language/test_random.py language/test_block_pointer.py language/test_subprocess.py \ - operators/test_flash_attention.py::test_op \ - ../../tutorials/06-fused-attention.py::test_op --device cpu - - - &run-cpp-unittests-step - name: Run C++ unittests - run: | - cd python - cd "build/$(ls build | grep -i cmake)" - ctest -j32 - - - name: Run Proton tests - env: - LD_LIBRARY_PATH: "/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" - run: | - cd third_party/proton - python3 -m pytest test - - # If we're on branch `main`, save the ccache Triton compilation artifacts - # to the cache so they can be used by other (non-main) CI runs. - # - # (It wouldn't be a problem to save the cache on every run, because github - # evicts cache entries LRU, but maybe this saves a bit of time in CI.) - - &save-build-artifacts-step - name: Save ccache and Triton compilation artifacts to cache - if: github.ref == 'refs/heads/main' - uses: actions/cache/save@v4 - with: - path: ~/.triton/cache ~/.cache/ccache - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - - &inspect-cache-directories-step - name: Inspect cache directories - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** - - mkdir -p ~/.cache/ccache - ls -alh ~/.cache/ccache - du -sh ~/.cache/ccache - - Integration-Tests-AMD: - needs: Runner-Preparation - if: needs.Runner-Preparation.outputs.matrix-HIP != '' - - runs-on: ${{ matrix.runner }} - timeout-minutes: 30 - - strategy: - matrix: - runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-HIP)}} - - container: - image: rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2 - options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root - - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - submodules: 'true' - - - *compute-cache-keys-step - - *cache-build-dependencies-step - - *restore-build-artifacts-step - - *inspect-cache-directory-step - - - name: Update PATH - run: | - echo "/opt/rocm/llvm/bin" >> $GITHUB_PATH - - - name: Install pip dependencies - run: | - python3 -m pip install --upgrade pip - python3 -m pip install lit - - - name: Install Triton - run: | - echo "PATH is '$PATH'" - pip uninstall -y triton - cd python - pip install -v -e '.[tests]' - - - *run-lit-tests-step - - - name: Run python tests on HIP - run: | - pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py - cd python/test/unit - pytest --capture=tee-sys -rfs -n 32 language operators \ - hopper/test_mixed_io.py \ - hopper/test_gemm.py \ - hopper/test_tma_store_gemm.py \ - hopper/test_persistent_warp_specialized_fused-attention.py \ - --ignore=language/test_line_info.py - # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 - TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -n 8 language/test_line_info.py - - # Run runtime tests serially to avoid race condition with cache handling - python3 -m pytest runtime - - - *run-cpp-unittests-step - - *save-build-artifacts-step - - *inspect-cache-directories-step diff --git a/.github/workflows/kunlun-build-and-test.yml b/.github/workflows/kunlun-build-and-test.yml new file mode 100644 index 000000000..5c5b5887b --- /dev/null +++ b/.github/workflows/kunlun-build-and-test.yml @@ -0,0 +1,32 @@ + +name: Kunlun-Build-And-Test + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + kunlun-build-and-test: + runs-on: kunlun + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: FlagTree Build on Kunlun + shell: bash + run: | + source ~/env.sh + export FLAGTREE_BACKEND=xpu + cd python + pip3 install . --no-build-isolation + + - name: FlagTree Test on Kunlun + shell: bash + run: | + pytest -s third_party/xpu/python/test/unit diff --git a/.github/workflows/llvm-build.yml b/.github/workflows/llvm-build.yml deleted file mode 100644 index a7f8f9783..000000000 --- a/.github/workflows/llvm-build.yml +++ /dev/null @@ -1,313 +0,0 @@ -name: LLVM Build - -on: - push: - branches: - - llvm-head - paths: - - cmake/llvm-hash.txt - workflow_dispatch: - -env: - SCCACHE_DIR: ${{ github.workspace }}/sccache - -permissions: - contents: read - id-token: write - -jobs: - - build: - name: Build on ${{ matrix.config.runner }} - runs-on: ${{ matrix.config.runs_on }} - timeout-minutes: 240 # 4 hours - - strategy: - fail-fast: true - matrix: - config: - - {runner: 'Ubuntu 20.04', runs_on: 'ubuntu-20.04', target-os: 'ubuntu', arch: 'x64'} - - {runner: 'Ubuntu 20.04 ARM64', runs_on: 'ubuntu-20.04', target-os: 'ubuntu', arch: 'arm64'} - - {runner: 'CentOS 7', runs_on: ['self-hosted', 'CPU'], target-os: 'centos', arch: 'x64'} - - {runner: 'AlmaLinux 8', runs_on: ['self-hosted', 'CPU'], target-os: 'almalinux', arch: 'x64'} - - {runner: 'MacOS X64', runs_on: 'macos-12', target-os: 'macos', arch: 'x64'} - - {runner: 'MacOS ARM64', runs_on: 'macos-12', target-os: 'macos', arch: 'arm64'} - # TODO(#2805): add back once the workflow works and runs in comparable time to the other ones - # - {runner: 'Windows Latest', runs_on: 'windows-latest', target-os: 'windows', arch: 'x64'} - - steps: - - - name: Checkout Repo - uses: actions/checkout@v4 - with: - path: llvm-build - - - name: Fetch LLVM Commit Hash - shell: bash - run: | - LLVM_COMMIT_HASH="$(cat llvm-build/cmake/llvm-hash.txt)" - echo "Found LLVM commit hash: ${LLVM_COMMIT_HASH}" - echo "llvm_commit_hash=${LLVM_COMMIT_HASH}" >> ${GITHUB_ENV} - - SHORT_LLVM_COMMIT_HASH="${LLVM_COMMIT_HASH:0:8}" - echo "Short LLVM commit hash: ${SHORT_LLVM_COMMIT_HASH}" - echo "short_llvm_commit_hash=${SHORT_LLVM_COMMIT_HASH}" >> ${GITHUB_ENV} - - INSTALL_DIR="llvm-${SHORT_LLVM_COMMIT_HASH}-${{ matrix.config.target-os }}-${{ matrix.config.arch }}" - echo "LLVM installation directory name: ${INSTALL_DIR}" - echo "llvm_install_dir=${INSTALL_DIR}" >> ${GITHUB_ENV} - - - name: Checkout LLVM - uses: actions/checkout@v4 - with: - repository: llvm/llvm-project - path: llvm-project - ref: ${{ env.llvm_commit_hash }} - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: 3.11 - - - name: Set up MSVC - if: matrix.config.arch == 'x64' && (matrix.config.target-os == 'windows') - uses: ilammy/msvc-dev-cmd@v1.13.0 - with: - arch: amd64 - - - name: Install Prerequisites - shell: bash - run: | - python3 -m pip install cmake ninja sccache - mkdir -p ${{ env.SCCACHE_DIR }} - rm -rf ${{ env.SCCACHE_DIR }}/* - - - name: Enable Cache - uses: actions/cache@v4 - with: - path: ${{ env.SCCACHE_DIR }} - key: ${{ matrix.config.target-os }}-${{ matrix.config.arch }}-${{ env.short_llvm_commit_hash }} - restore-keys: ${{ matrix.config.target-os }}-${{ matrix.config.arch }}- - - - name: Configure, Build, Test, and Install LLVM (Ubuntu and macOS x64) - if: matrix.config.arch == 'x64' && (matrix.config.target-os == 'ubuntu' || matrix.config.target-os == 'macos') - run: > - python3 -m pip install -r llvm-project/mlir/python/requirements.txt - - cmake -GNinja -Bllvm-project/build - -DCMAKE_BUILD_TYPE=Release - -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ - -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache - -DCMAKE_INSTALL_PREFIX="${{ env.llvm_install_dir }}" - -DCMAKE_LINKER=lld - -DLLVM_BUILD_UTILS=ON - -DLLVM_BUILD_TOOLS=ON - -DLLVM_ENABLE_ASSERTIONS=ON - -DMLIR_ENABLE_BINDINGS_PYTHON=ON - -DLLVM_ENABLE_PROJECTS=mlir - -DLLVM_INSTALL_UTILS=ON - -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" - -DLLVM_ENABLE_TERMINFO=OFF - llvm-project/llvm - - ninja -C llvm-project/build check-mlir install - - tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}" - - - name: Configure, Build, Test, and Install LLVM (Windows) - if: matrix.config.arch == 'x64' && (matrix.config.target-os == 'windows') - run: > - python3 -m pip install -r llvm-project/mlir/python/requirements.txt - - cmake -GNinja -Bllvm-project/build - -DCMAKE_BUILD_TYPE=Release - -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=cl - -DCMAKE_INSTALL_PREFIX="${{ env.llvm_install_dir }}" - -DLLVM_BUILD_UTILS=ON - -DLLVM_BUILD_TOOLS=ON - -DLLVM_ENABLE_ASSERTIONS=ON - -DMLIR_ENABLE_BINDINGS_PYTHON=ON - -DLLVM_ENABLE_PROJECTS="clang;mlir" - -DLLVM_INSTALL_UTILS=ON - -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" - -DLLVM_ENABLE_TERMINFO=OFF - llvm-project/llvm - - ninja -C llvm-project/build check-mlir install - - tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}" - - - - name: Configure, Build, and Install LLVM (ubuntu arm64) - if: matrix.config.arch == 'arm64' && matrix.config.target-os == 'ubuntu' - run: | - python3 -m pip install -r llvm-project/mlir/python/requirements.txt - mkdir arm-sysroot - mkdir -p llvm-project/host-tools - cd llvm-project/host-tools - cmake -GNinja ../llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_PROJECTS="mlir;llvm;clang" - ninja mlir-tblgen - ninja llvm-tblgen - ninja clang-tblgen - cd ../.. - mv ./llvm-project/host-tools/bin ./host-tools - HOST_TOOLS="$(pwd)/host-tools" - rm -rf llvm-project/host-tools - sudo apt-get update - sudo apt-get install gcc-arm-linux-gnueabihf g++-arm-linux-gnueabihf qemu-user-static gcc-aarch64-linux-gnu g++-aarch64-linux-gnu - cp -r /usr/aarch64-linux-gnu/lib ./arm-sysroot - cp -r /usr/aarch64-linux-gnu/include ./arm-sysroot - LINKER=$(pwd)/arm-sysroot/lib/ld-linux-aarch64.so.1 - wget http://ftp.de.debian.org/debian/pool/main/g/gcc-defaults/gcc-aarch64-linux-gnu_13.2.0-7_amd64.deb - dpkg-deb -x gcc-aarch64-linux-gnu_13.2.0-7_amd64.deb ./arm-sysroot - export LD_LIBRARY_PATH=$(pwd)/arm-sysroot/lib:$LD_LIBRARY_PATH - sudo ln -s $LINKER /lib/ld-linux-aarch64.so.1 - SYSROOT="$(pwd)/arm-sysroot" - echo $SYSROOT - echo $LINKER - cmake -GNinja -Bllvm-project/build \ - -DCMAKE_BUILD_TYPE=Release \ - -DLLVM_ENABLE_PROJECTS="mlir;llvm" \ - -DLLVM_BUILD_UTILS=ON \ - -DLLVM_TABLEGEN=$HOST_TOOLS/llvm-tblgen \ - -DMLIR_TABLEGEN=$HOST_TOOLS/mlir-tblgen \ - -DCLANG_TABLEGEN=$HOST_TOOLS/clang-tblgen \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DCMAKE_LINKER=$LINKER \ - -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DLLVM_ENABLE_ZSTD=OFF \ - -DLLVM_INSTALL_UTILS=ON \ - -DCMAKE_INSTALL_PREFIX="${{ env.llvm_install_dir }}" \ - -DLLVM_TARGETS_TO_BUILD="AArch64;NVPTX;AMDGPU" \ - -DCMAKE_CROSSCOMPILING=True \ - -DLLVM_TARGET_ARCH=AArch64 \ - -DLLVM_DEFAULT_TARGET_TRIPLE=aarch64-linux-gnu \ - -DLLVM_USE_HOST_TOOLS=OFF \ - -DCMAKE_C_COMPILER="/usr/bin/aarch64-linux-gnu-gcc" \ - -DCMAKE_CXX_COMPILER="/usr/bin/aarch64-linux-gnu-g++" \ - -DCMAKE_ASM_COMPILER="/usr/bin/aarch64-linux-gnu-as" \ - -DCMAKE_AR="/usr/bin/aarch64-linux-gnu-ar" \ - -DCMAKE_NM="/usr/bin/aarch64-linux-gnu-nm" \ - -DCMAKE_OBJCOPY="/usr/bin/aarch64-linux-gnu-objcopy" \ - -DCMAKE_OBJDUMP="/usr/bin/aarch64-linux-gnu-objdump" \ - -DCMAKE_RANLIB="/usr/bin/aarch64-linux-gnu-ranlib" \ - -DCMAKE_STRIP="/usr/bin/aarch64-linux-gnu-strip" \ - -DCMAKE_SYSROOT=$SYSROOT \ - -DLLVM_ENABLE_TERMINFO=OFF \ - llvm-project/llvm - ninja -C llvm-project/build install - CURR_PWD="$(pwd)" - cd "${{ env.llvm_install_dir }}/python_packages/mlir_core/mlir/_mlir_libs/" - for file in *x86_64*; do - mv "$file" "${file/x86_64/aarch64}" - done - cd $CURR_PWD - tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}" - - - name: Configure, Build, and Install LLVM (macOS arm64) - if: matrix.config.arch == 'arm64' && matrix.config.target-os == 'macos' - run: > - python3 -m pip install -r llvm-project/mlir/python/requirements.txt - - cmake -GNinja -Bllvm-project/build - -DCMAKE_BUILD_TYPE=Release - -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ - -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache - -DCMAKE_INSTALL_PREFIX="${{ env.llvm_install_dir }}" - -DCMAKE_LINKER=lld - -DCMAKE_OSX_ARCHITECTURES=arm64 - -DLLVM_BUILD_UTILS=ON - -DLLVM_BUILD_TOOLS=ON - -DLLVM_ENABLE_ASSERTIONS=ON - -DMLIR_ENABLE_BINDINGS_PYTHON=ON - -DLLVM_ENABLE_PROJECTS=mlir - -DLLVM_ENABLE_ZSTD=OFF - -DLLVM_INSTALL_UTILS=ON - -DLLVM_TARGETS_TO_BUILD="AArch64;NVPTX;AMDGPU" - -DLLVM_USE_HOST_TOOLS=ON - -DLLVM_ENABLE_TERMINFO=OFF - llvm-project/llvm - - ninja -C llvm-project/build install - - tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}" - - - - name: Configure, Build, Test, and Install LLVM (CentOS) - if: matrix.config.target-os == 'centos' - run: | - # if this step crashes, it can leave behind a stale docker container - docker container prune -f - docker rmi -f $(docker images -q) - - docker build --tag llvm-build --build-arg llvm_dir=llvm-project \ - -f llvm-build/.github/workflows/llvm-build/centos.Dockerfile . - - # Create temporary container to copy cache and installed artifacts. - CONTAINER_ID=$(docker create llvm-build) - docker cp "${CONTAINER_ID}:/install" "${{ env.llvm_install_dir }}" - tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}" - - # We remove the existing directory, otherwise docker will - # create a subdirectory inside the existing directory. - rm -rf "${{ env.SCCACHE_DIR }}" - docker cp "${CONTAINER_ID}:/sccache" "${{ env.SCCACHE_DIR }}" - sudo chown -R "$(id -u -n):$(id -g -n)" "${{ env.SCCACHE_DIR }}" - - docker rm "${CONTAINER_ID}" - - - name: Configure, Build, Test, and Install LLVM (AlmaLinux) - if: matrix.config.target-os == 'almalinux' - run: | - # if this step crashes, it can leave behind a stale docker container - docker container prune -f - docker rmi -f $(docker images -q) - - docker build --tag llvm-build --build-arg llvm_dir=llvm-project \ - -f llvm-build/.github/workflows/llvm-build/almalinux.Dockerfile . - - # Create temporary container to copy cache and installed artifacts. - CONTAINER_ID=$(docker create llvm-build) - docker cp "${CONTAINER_ID}:/install" "${{ env.llvm_install_dir }}" - tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}" - - # We remove the existing directory, otherwise docker will - # create a subdirectory inside the existing directory. - rm -rf "${{ env.SCCACHE_DIR }}" - docker cp "${CONTAINER_ID}:/sccache" "${{ env.SCCACHE_DIR }}" - sudo chown -R "$(id -u -n):$(id -g -n)" "${{ env.SCCACHE_DIR }}" - - docker rm "${CONTAINER_ID}" - - - name: Upload Build Artifacts - uses: actions/upload-artifact@v4 - with: - name: llvm-${{ matrix.config.target-os }}-${{ matrix.config.arch }} - path: | - ${{ github.workspace }}/llvm-*-${{ matrix.config.target-os }}-${{ matrix.config.arch }}.tar.gz - - - name: Azure Login - if: ${{ (github.repository == 'triton-lang/triton') }} - uses: azure/login@v2 - with: - client-id: ${{ secrets.AZURE_CLIENT_ID }} - tenant-id: ${{ secrets.AZURE_TENANT_ID }} - subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }} - - - name: Upload LLVM Artifacts to Azure - if: ${{ (github.repository == 'triton-lang/triton') }} - run: | - az storage blob upload --account-name tritonlang --auth-mode login --container-name llvm-builds --file "${{ env.llvm_install_dir }}.tar.gz" --name "${{ env.llvm_install_dir }}.tar.gz" --overwrite - - URL=$(az storage blob url --account-name tritonlang --auth-mode login --container-name llvm-builds --name "${{ env.llvm_install_dir }}.tar.gz") - echo "Blob URL: ${URL}" - - - name: Azure Logout - if: ${{ (github.repository == 'triton-lang/triton') }} - run: | - az logout - az cache purge - az account clear - - - name: Dump Sccache Statistics - run: sccache --show-stats diff --git a/.github/workflows/llvm-build/almalinux.Dockerfile b/.github/workflows/llvm-build/almalinux.Dockerfile deleted file mode 100644 index adf8b5cc6..000000000 --- a/.github/workflows/llvm-build/almalinux.Dockerfile +++ /dev/null @@ -1,39 +0,0 @@ -FROM almalinux:8 -ARG llvm_dir=llvm-project -# Add the cache artifacts and the LLVM source tree to the container -ADD sccache /sccache -ADD "${llvm_dir}" /source/llvm-project -ENV SCCACHE_DIR="/sccache" -ENV SCCACHE_CACHE_SIZE="2G" - -RUN dnf install --assumeyes llvm-toolset -RUN dnf install --assumeyes python38-pip python38-devel git - -RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --upgrade cmake ninja sccache lit - -# Install MLIR's Python Dependencies -RUN python3 -m pip install -r /source/llvm-project/mlir/python/requirements.txt - -# Configure, Build, Test, and Install LLVM -RUN cmake -GNinja -Bbuild \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_C_COMPILER=clang \ - -DCMAKE_CXX_COMPILER=clang++ \ - -DCMAKE_ASM_COMPILER=clang \ - -DCMAKE_C_COMPILER_LAUNCHER=sccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=sccache \ - -DCMAKE_CXX_FLAGS="-Wno-everything" \ - -DCMAKE_LINKER=lld \ - -DCMAKE_INSTALL_PREFIX="/install" \ - -DLLVM_BUILD_UTILS=ON \ - -DLLVM_BUILD_TOOLS=ON \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DLLVM_ENABLE_PROJECTS=mlir \ - -DLLVM_ENABLE_TERMINFO=OFF \ - -DLLVM_INSTALL_UTILS=ON \ - -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" \ - /source/llvm-project/llvm - -RUN ninja -C build install diff --git a/.github/workflows/llvm-build/centos.Dockerfile b/.github/workflows/llvm-build/centos.Dockerfile deleted file mode 100644 index cd7e852a8..000000000 --- a/.github/workflows/llvm-build/centos.Dockerfile +++ /dev/null @@ -1,43 +0,0 @@ -FROM centos:7 -ARG llvm_dir=llvm-project -# Add the cache artifacts and the LLVM source tree to the container -ADD sccache /sccache -ADD "${llvm_dir}" /source/llvm-project -ENV SCCACHE_DIR="/sccache" -ENV SCCACHE_CACHE_SIZE="2G" - -RUN echo -e "[llvmtoolset-build]\nname=LLVM Toolset 13.0 - Build\nbaseurl=https://buildlogs.centos.org/c7-llvm-toolset-13.0.x86_64/\ngpgcheck=0\nenabled=1" > /etc/yum.repos.d/llvmtoolset-build.repo -# Install build dependencies -RUN yum install --assumeyes centos-release-scl -RUN yum install --assumeyes --nogpgcheck llvm-toolset-13.0 -RUN yum install --assumeyes rh-python38-python-devel rh-python38-python-pip -SHELL [ "/usr/bin/scl", "enable", "llvm-toolset-13.0", "rh-python38" ] - -RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --upgrade cmake ninja sccache - -# Install MLIR's Python Dependencies -RUN python3 -m pip install -r /source/llvm-project/mlir/python/requirements.txt - -# Configure, Build, Test, and Install LLVM -RUN cmake -GNinja -Bbuild \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_C_COMPILER=clang \ - -DCMAKE_CXX_COMPILER=clang++ \ - -DCMAKE_ASM_COMPILER=clang \ - -DCMAKE_C_COMPILER_LAUNCHER=sccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=sccache \ - -DCMAKE_CXX_FLAGS="-Wno-everything" \ - -DCMAKE_LINKER=lld \ - -DCMAKE_INSTALL_PREFIX="/install" \ - -DLLVM_BUILD_UTILS=ON \ - -DLLVM_BUILD_TOOLS=ON \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DLLVM_ENABLE_PROJECTS=mlir \ - -DLLVM_ENABLE_TERMINFO=OFF \ - -DLLVM_INSTALL_UTILS=ON \ - -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" \ - /source/llvm-project/llvm - -RUN ninja -C build install diff --git a/.github/workflows/metax-build-and-test.yml b/.github/workflows/metax-build-and-test.yml new file mode 100644 index 000000000..c760d19b4 --- /dev/null +++ b/.github/workflows/metax-build-and-test.yml @@ -0,0 +1,28 @@ +name: Metax-Build-And-Test + +on: + workflow_call: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + metax-build-and-test: + runs-on: metax + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: FlagTree Build on Metax + shell: bash + run: | + source ~/env.sh + export FLAGTREE_BACKEND=metax + cd python + MAX_JOBS=20 pip3 install . --no-build-isolation + + - name: FlagTree Test on Metax + shell: bash + run: | + pytest -s python/test/unit diff --git a/.github/workflows/mthreads-build-and-test.yml b/.github/workflows/mthreads-build-and-test.yml new file mode 100644 index 000000000..b3474802e --- /dev/null +++ b/.github/workflows/mthreads-build-and-test.yml @@ -0,0 +1,28 @@ +name: Mthreads-Build-And-Test + +on: + workflow_call: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + mthreads-build-and-test: + runs-on: mthreads + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: FlagTree Build on Mthreads + shell: bash + run: | + source ~/env.sh + export FLAGTREE_BACKEND=mthreads + cd python + MAX_JOBS=20 pip3 install . --no-build-isolation + + - name: FlagTree Test on Mthreads + shell: bash + run: | + pytest -s python/test/unit diff --git a/.github/workflows/nv-build-and-test.yml b/.github/workflows/nv-build-and-test.yml new file mode 100644 index 000000000..f0e14af99 --- /dev/null +++ b/.github/workflows/nv-build-and-test.yml @@ -0,0 +1,60 @@ +name: NV-Build-And-Test + +on: + schedule: + - cron: '0 21 * * *' + push: + branches: [ "main", "triton_v3.2.x", "triton_v3.3.x" ] + pull_request: + branches: [ "main", "triton_v3.2.x", "triton_v3.3.x" ] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + nv-build-and-test: + runs-on: nv-jiuding + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Detect Target Branch + shell: bash + run: | + if [ "${{ github.event_name }}" = "pull_request" ]; then + TARGET_BRANCH="${{ github.base_ref }}" + else + TARGET_BRANCH="${{ github.ref_name }}" + fi + echo "TARGET_BRANCH=$TARGET_BRANCH" >> $GITHUB_ENV + echo "TARGET_BRANCH=$TARGET_BRANCH" + + - name: FlagTree Build (Main branch) + if: ${{ env.TARGET_BRANCH == 'main' }} + shell: bash + run: | + source ~/env.sh + cd python + MAX_JOBS=32 pip3.11 install . --no-build-isolation + + - name: FlagTree Build (triton_v3.2.x branch) + if: ${{ env.TARGET_BRANCH == 'triton_v3.2.x' }} + shell: bash + run: | + source ~/env-3.2.sh + cd python + MAX_JOBS=32 pip3.11 install . --no-build-isolation + + - name: FlagTree Build (triton_v3.3.x branch) + if: ${{ env.TARGET_BRANCH == 'triton_v3.3.x' }} + shell: bash + run: | + source ~/env-3.3.sh + cd python + MAX_JOBS=32 pip3.11 install . --no-build-isolation + + - name: FlagTree Test + shell: bash + run: | + pytest -s python/test/unit diff --git a/.github/workflows/test-backends.yml b/.github/workflows/test-backends.yml deleted file mode 100644 index 5b0d3f0b8..000000000 --- a/.github/workflows/test-backends.yml +++ /dev/null @@ -1,84 +0,0 @@ -name: Backend Tests - -on: - workflow_dispatch: - push: - branches: [main] - -permissions: read-all - -jobs: - Runner-Preparation: - runs-on: ubuntu-latest - outputs: - matrix-optional: ${{ steps.set-matrix.outputs.matrix-optional }} - steps: - - name: Prepare runner matrix - id: set-matrix - run: | - if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then - echo '::set-output name=matrix-optional::[["self-hosted", "gfx90a"], ["self-hosted", "arc770"]]' - else - echo '::set-output name=matrix-optional::["ubuntu-latest"]' - fi - - Integration-Tests-Intel: - needs: Runner-Preparation - timeout-minutes: 30 - if: false && ((github.event_name == 'workflow_dispatch') || (github.event_name == 'push' && github.ref == 'refs/heads/main')) - - runs-on: ${{ matrix.runner }} - - strategy: - matrix: - runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-optional)}} - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Set XPU ENV - if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'arc770')}} - run: | - echo "BACKEND=XPU" >> "${GITHUB_ENV}" - - - name: Clear cache - run: | - rm -rf ~/.triton - - - name: Update PATH - run: | - echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}" - - - name: Check pre-commit arc770 - if: ${{ matrix.runner != 'macos-10.15' && (matrix.runner[1] == 'arc770') }} - run: | - source ${HOME}/triton_vars.sh - source ${HOME}/miniconda3/bin/activate - conda activate triton-xpu-ci - python3 -m pip install --upgrade pre-commit - python3 -m pre_commit run --all-files - - - name: Install Triton on XPU - if: ${{ env.BACKEND == 'XPU'}} - run: | - source ${HOME}/triton_vars.sh - source ${HOME}/miniconda3/bin/activate - conda activate triton-xpu-ci - git submodule update --init --recursive - cd python - python3 -m pip install --upgrade pip - python3 -m pip install cmake==3.24 - export TRITON_CODEGEN_INTEL_XPU_BACKEND=1 - python3 -m pip uninstall -y triton - python3 setup.py build - python3 -m pip install --no-build-isolation -vvv '.[tests]' - - - name: Run python tests on XPU - if: ${{ env.BACKEND == 'XPU'}} - run: | - source ${HOME}/triton_vars.sh - source ${HOME}/miniconda3/bin/activate - conda activate triton-xpu-ci - cd python/test/backend/third_party_backends - python3 -m pytest --capture=tee-sys -rfs --verbose --backend xpu diff --git a/.github/workflows/torch-inductor-tests.yml b/.github/workflows/torch-inductor-tests.yml deleted file mode 100644 index 3d8f98095..000000000 --- a/.github/workflows/torch-inductor-tests.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: Torchinductor - -on: - workflow_run: - workflows: ["Wheels"] - types: [completed] - workflow_dispatch: - -permissions: read-all - -jobs: - Runner-Preparation: - runs-on: ubuntu-latest - outputs: - matrix: ${{ steps.set-matrix.outputs.matrix }} - steps: - - name: Prepare runner matrix - id: set-matrix - run: | - echo '::set-output name=matrix::[["self-hosted", "A100"]]' - - Torch-Inductor-Tests: - needs: Runner-Preparation - timeout-minutes: 240 # 4 hours - runs-on: ${{ matrix.runner }} - strategy: - matrix: - runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix)}} - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Packages - run: | - ./.github/workflows/torch-inductor/scripts/install_torchinductor.sh torchbench - - name: Environment - run: | - source /tmp/torchinductor_venv/bin/activate - ./.github/workflows/torch-inductor/scripts/install_triton.sh - - name: Performance - run: | - ./.github/workflows/torch-inductor/scripts/run_torchinductor_perf.sh torchbench - # Runs too long time - #- name: Accuracy - # run: | - # ./.github/workflows/torch-inductor/scripts/run_torchinductor_acc.sh torchbench diff --git a/.github/workflows/torch-inductor/scripts/check_acc.py b/.github/workflows/torch-inductor/scripts/check_acc.py deleted file mode 100644 index c89976aca..000000000 --- a/.github/workflows/torch-inductor/scripts/check_acc.py +++ /dev/null @@ -1,11 +0,0 @@ -import csv -import sys - -file_path = sys.argv[1] -with open(file_path) as f: - reader = csv.reader(f) - for i, row in enumerate(reader): - if i == 0: - continue - if row[3] != "pass": - print(f"{row[1]} failed on device {row[0]} with batch size {row[2]}") diff --git a/.github/workflows/torch-inductor/scripts/check_perf.py b/.github/workflows/torch-inductor/scripts/check_perf.py deleted file mode 100644 index 212eadad5..000000000 --- a/.github/workflows/torch-inductor/scripts/check_perf.py +++ /dev/null @@ -1,70 +0,0 @@ -import argparse -import csv -from collections import namedtuple - -# Create a named tuple for the output of the benchmark -BenchmarkOutput = namedtuple('BenchmarkOutput', ['dev', 'name', 'batch_size', 'speedup', 'latency']) - - -def parse_output(file_path: str) -> dict: - entries = {} - with open(file_path) as f: - reader = csv.reader(f) - for i, row in enumerate(reader): - if i == 0 or len(row) < 5: - continue - dev = row[0] - name = row[1] - batch_size = row[2] - speedup = float(row[3]) - latency = float(row[4]) - entries[name] = BenchmarkOutput(dev, name, batch_size, speedup, latency) - return entries - - -def compare(baseline: dict, new: dict, threshold: float, geomean_threshold: float) -> bool: - baseline_geomean = 1.0 - new_geomean = 1.0 - for key in new: - if key not in baseline: - print(f"New benchmark {key} not found in baseline") - baseline_latency = baseline[key].latency - new_latency = new[key].latency - if baseline_latency == 0: - print(f"Baseline latency for {key} is 0") - continue - elif new_latency == 0: - print(f"New latency for {key} is 0") - continue - - if new_latency < baseline_latency * (1 - threshold): - print(f"New benchmark {key} is faster than baseline: {new_latency} vs {baseline_latency}") - elif new_latency > baseline_latency * (1 + threshold): - print(f"New benchmark {key} is slower than baseline: {new_latency} vs {baseline_latency}") - else: - print(f"New benchmark {key} is within threshold: {new_latency} vs {baseline_latency}") - baseline_geomean *= baseline[key].speedup - new_geomean *= new[key].speedup - - baseline_geomean = baseline_geomean**(1 / len(baseline)) - new_geomean = new_geomean**(1 / len(new)) - print(f"Baseline geomean: {baseline_geomean}") - print(f"New geomean: {new_geomean}") - assert new_geomean >= baseline_geomean * (1 - geomean_threshold), \ - f"New geomean is slower than baseline: {new_geomean} vs {baseline_geomean}" - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--baseline', required=True) - parser.add_argument('--new', required=True) - parser.add_argument('--threshold', type=float, default=0.1) - parser.add_argument('--geomean-threshold', type=float, default=0.02) - args = parser.parse_args() - baseline = parse_output(args.baseline) - new = parse_output(args.new) - compare(baseline, new, args.threshold, args.geomean_threshold) - - -if __name__ == "__main__": - main() diff --git a/.github/workflows/torch-inductor/scripts/common.sh b/.github/workflows/torch-inductor/scripts/common.sh deleted file mode 100755 index 7e212a06a..000000000 --- a/.github/workflows/torch-inductor/scripts/common.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -TEST_REPORTS_DIR=/tmp/torchinductor_reports -PYTORCH_DIR=/tmp/pytorch -MODELS=(timm_models huggingface torchbench) - -echo "$TEST_REPORTS_DIR" -echo "$PYTORCH_DIR" -echo "${MODELS[@]}" diff --git a/.github/workflows/torch-inductor/scripts/install_torchinductor.sh b/.github/workflows/torch-inductor/scripts/install_torchinductor.sh deleted file mode 100755 index 18bea1f17..000000000 --- a/.github/workflows/torch-inductor/scripts/install_torchinductor.sh +++ /dev/null @@ -1,74 +0,0 @@ -#!/bin/bash - -# remember where we started -ROOT="$(pwd)" -MODEL_SPEC=$1 - -# torchinductor venv -whoami - -sudo apt-get update && sudo apt-get install -y python3-venv libgl1 - -# clean up old venv -rm -rf /tmp/torchinductor_venv -python3 -m venv /tmp/torchinductor_venv -# shellcheck source=/dev/null -source /tmp/torchinductor_venv/bin/activate -# shellcheck source=/dev/null -source ./.github/workflows/torch-inductor/scripts/common.sh - -pip3 install --upgrade pip wheel setuptools - -# Install torchtext stable first. Bundling it in the same install as torch -# nightly forces torch stable release to be installed instead. -# From https://github.com/pytorch/text?tab=readme-ov-file#torchtext, -# "WARNING: TorchText development is stopped and the 0.18 release (April 2024) -# will be the last stable release of the library." -pip3 install --force-reinstall torchtext - -# pytorch nightly -pip3 install --force-reinstall --pre torch torchvision torchaudio torchrec --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -# pytorch source to get torchbench for dynamo -cd /tmp || exit -# cleanup old pytorch -rm -rf pytorch -git clone --recursive https://github.com/pytorch/pytorch -cd pytorch || exit -# if you are updating an existing checkout -git submodule sync -git submodule update --init --recursive -cd .. - -# required packages -# https://github.com/pytorch/benchmark/blob/main/docker/gcp-a100-runner-dind.dockerfile#L17 -sudo apt-get install --yes libpango-1.0-0 libpangoft2-1.0-0 -pip3 install expecttest psutil lightning-utilities pyre_extensions - -# torchbench -if [ "$MODEL_SPEC" == "torchbench" ] || [ "$MODEL_SPEC" != "all" ]; then - # clean up old torchbench - rm -rf benchmark - pip3 install pyyaml - git clone https://github.com/pytorch/benchmark.git - cd benchmark || exit - python3 install.py - cd .. -fi - -# timm -if [ "$MODEL_SPEC" == "timm_models" ] || [ "$MODEL_SPEC" != "all" ]; then - # clean up old timm - rm -rf pytorch-image-models - git clone https://github.com/huggingface/pytorch-image-models.git - cd pytorch-image-models || exit - pip3 install -e . - cd .. -fi - -# clean up cache -rm -rf /tmp/torchinductor_"$(whoami)"/ -rm -rf ~/.triton/cache -rm -rf "$TEST_REPORTS_DIR" - -# go back to where we started -cd "$ROOT" || exit diff --git a/.github/workflows/torch-inductor/scripts/install_triton.sh b/.github/workflows/torch-inductor/scripts/install_triton.sh deleted file mode 100755 index 43367a02f..000000000 --- a/.github/workflows/torch-inductor/scripts/install_triton.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash - -# remember where we started -ROOT="$(pwd)" - -# shellcheck source=/dev/null -source /tmp/torchinductor_venv/bin/activate -# shellcheck source=/dev/null -source ./.github/workflows/torch-inductor/scripts/common.sh - -# Triton build-time dependencies -pip3 install --upgrade cmake ninja lit - -# build our own triton and preserve the wheel build for later re-use in this test run. -cd python || exit -pip3 uninstall pytorch-triton -y -rm -rf build dist -python3 setup.py bdist_wheel -pip3 install dist/triton*.whl - -# clean up cache -rm -rf ~/.triton/cache - -# go back to where we started -cd "$ROOT" || exit diff --git a/.github/workflows/torch-inductor/scripts/run_torchinductor_acc.sh b/.github/workflows/torch-inductor/scripts/run_torchinductor_acc.sh deleted file mode 100755 index aefd798f3..000000000 --- a/.github/workflows/torch-inductor/scripts/run_torchinductor_acc.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash - -# remember where we started -ROOT="$(pwd)" -INDUCTOR="$ROOT"/.github/workflows/torch-inductor -MODEL_SPEC=$1 - -# shellcheck source=/dev/null -source /tmp/torchinductor_venv/bin/activate -# shellcheck source=/dev/null -source "$INDUCTOR"/scripts/common.sh - -# Dependency of 'torch/fx/experimental/validator.py'. -pip3 install --upgrade z3-solver - -# Install our own triton. -pip3 uninstall pytorch-triton -y -cd $ROOT/python || exit -if [ -d "./dist" ]; then - pip3 install dist/triton*.whl -else - rm -rf build - pip3 install -e . -fi - -cd "$PYTORCH_DIR" || exit -TEST_REPORTS_DIR=$TEST_REPORTS_DIR/acc -mkdir -p "$TEST_REPORTS_DIR" - -for model in "${MODELS[@]}"; do - if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then - continue - fi - echo "Running accuracy test for $model" - python3 benchmarks/dynamo/"$model".py --ci --accuracy --timing --explain --inductor --inference --device cuda \ - --output "$TEST_REPORTS_DIR"/inference_"$model".csv - python3 benchmarks/dynamo/"$model".py --ci --accuracy --timing --explain --inductor --training --amp --device cuda \ - --output "$TEST_REPORTS_DIR"/training_"$model".csv - python3 benchmarks/dynamo/"$model".py --ci --accuracy --timing --explain --inductor --training --dynamic-shapes --device cuda \ - --output "$TEST_REPORTS_DIR"/dynamic_shapes_"$model".csv -done - -cd "$ROOT" || exit -for model in "${MODELS[@]}"; do - if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then - continue - fi - echo "Checking accuracy test for $model" - python3 "$INDUCTOR"/scripts/check_acc.py "$TEST_REPORTS_DIR"/inference_"$model".csv - python3 "$INDUCTOR"/scripts/check_acc.py "$TEST_REPORTS_DIR"/training_"$model".csv - python3 "$INDUCTOR"/scripts/check_acc.py "$TEST_REPORTS_DIR"/dynamic_shapes_"$model".csv -done - -# go back to where we started -cd "$ROOT" || exit diff --git a/.github/workflows/torch-inductor/scripts/run_torchinductor_perf.sh b/.github/workflows/torch-inductor/scripts/run_torchinductor_perf.sh deleted file mode 100755 index 35853d97c..000000000 --- a/.github/workflows/torch-inductor/scripts/run_torchinductor_perf.sh +++ /dev/null @@ -1,71 +0,0 @@ -#!/bin/bash - -# remember where we started -ROOT="$(pwd)" -INDUCTOR="$ROOT"/.github/workflows/torch-inductor -MODEL_SPEC=$1 - -# shellcheck source=/dev/null -source /tmp/torchinductor_venv/bin/activate -# shellcheck source=/dev/null -source "$INDUCTOR"/scripts/common.sh - -# lock GPU clocks to 1350 MHz -sudo nvidia-smi -i 0 -pm 1 -sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350 - -cd "$PYTORCH_DIR" || exit -TRITON_TEST_REPORTS_DIR=$TEST_REPORTS_DIR/perf -BASE_TEST_REPORTS_DIR=$TEST_REPORTS_DIR/acc -mkdir -p "$TRITON_TEST_REPORTS_DIR" -mkdir -p "$BASE_TEST_REPORTS_DIR" - -# Dependency of 'pytorch/benchmarks/dynamo/common.py'. -pip3 install pandas scipy - -echo "Running with Triton Nightly" -for model in "${MODELS[@]}"; do - if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then - continue - fi - echo "Running performance test for $model" - python3 benchmarks/dynamo/"$model".py --ci --float32 --training --inductor --performance --device cuda \ - --output "$TRITON_TEST_REPORTS_DIR"/"$model".csv -done - -# install pytorch-triton -pip3 uninstall triton -y -pip3 install --pre pytorch-triton --extra-index-url https://download.pytorch.org/whl/nightly/cu121 - -echo "Running with pytorch-triton" -for model in "${MODELS[@]}"; do - if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then - continue - fi - echo "Running performance test for $model" - python3 benchmarks/dynamo/"$model".py --ci --float32 --training --inductor --performance --device cuda \ - --output "$BASE_TEST_REPORTS_DIR"/"$model".csv -done - -# uninstall pytorch-triton -pip3 uninstall pytorch-triton -y - -cd "$ROOT" || exit -for model in "${MODELS[@]}"; do - if [ "$model" != "$MODEL_SPEC" ] && [ "$MODEL_SPEC" != "all" ]; then - continue - fi - echo "Checking performance test for $model" - python3 "$INDUCTOR"/scripts/check_perf.py --new "$TRITON_TEST_REPORTS_DIR"/"$model".csv --baseline "$BASE_TEST_REPORTS_DIR"/"$model".csv - EXIT_STATUS=$? - if [ "$EXIT_STATUS" -ne 0 ]; then - echo "Performance test for $model failed" - exit "$EXIT_STATUS" - fi -done - -# unlock GPU clocks -sudo nvidia-smi -i 0 -rgc - -# go back to where we started -cd "$ROOT" || exit diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml deleted file mode 100644 index b1b418779..000000000 --- a/.github/workflows/wheels.yml +++ /dev/null @@ -1,72 +0,0 @@ -name: Wheels -on: - workflow_dispatch: - schedule: - - cron: "20 2 * * *" - -jobs: - - Build-Wheels: - timeout-minutes: 60 - - runs-on: [self-hosted, CPU] - permissions: - id-token: write - contents: read - - steps: - - - name: Prune stale docker containers - run: | - # If cibuildwheel crashes (or, say, is OOM-killed), it leaves behind a - # docker container. Eventually these consume all the disk space on - # this machine. - docker container prune -f - - - name: Checkout - uses: actions/checkout@v3 - - # The LATEST_DATE here should be kept in sync with the one in Patch setup.py - - id: check-version - name: Check latest version - run: | - export PACKAGE_DATE=$(python3 -m pip install --user --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ --dry-run triton-nightly== |& grep -oP '(?<=, )[0-9\.]+dev[0-9]+(?=\))' | grep -oP '(?<=dev)[0-9]+') - export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d%H%M%S' --format="%cd") - if cmp -s <(echo $PACKAGE_DATE) <(echo $LATEST_DATE); then - echo "new_commit=false" >> "$GITHUB_OUTPUT" - else - echo "new_commit=true" >> "$GITHUB_OUTPUT" - fi - - - name: Patch setup.py - if: ${{ steps.check-version.outputs.new_commit == 'true' }} - run: | - echo "" >> python/setup.cfg - echo "[build_ext]" >> python/setup.cfg - echo "base-dir=/project" >> python/setup.cfg - - - name: Build wheels - if: ${{ steps.check-version.outputs.new_commit == 'true' }} - run: | - export LATEST_DATE=$(TZ=UTC0 git show --quiet --date='format-local:%Y%m%d%H%M%S' --format="%cd") - # Pass MAX_JOBS=4 because, at time of writing, the VM "only" has 32GB - # of RAM and OOMs while building if we give it the default number of - # workers (2 * NUM_CPUs). - # - # Sadly, I couldn't make TRITON_BUILD_WITH_CLANG_LLD=1 work. The - # manylinux image has a relatively recent gcc (v10, released 2020), - # but its clang is ancient, v3.4, released in 2014 (!). I tried - # installing the prebuilt clang 10 binary distributed by LLVM, and I - # quickly ran into Linux DLL hell. I give up, for now. Perhaps - # manylinux_x_y will save us; I didn't try. - export CIBW_ENVIRONMENT="MAX_JOBS=4 TRITON_WHEEL_NAME=triton" - export CIBW_MANYLINUX_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest" - #export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest" - export CIBW_BEFORE_BUILD="pip install cmake;" - export CIBW_SKIP="cp{35,36,37}-*" - export CIBW_BUILD="cp3*-manylinux_x86_64" - python3 -m cibuildwheel python --output-dir wheelhouse - - - name: Upload wheels to PyPI - run: | - python3 -m twine upload wheelhouse/* -u __token__ -p ${{ secrets.PYPY_API_TOKEN }} diff --git a/.gitignore b/.gitignore index cb802b88c..61ca20637 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ build-*/ python/build/ python/dist/ python/triton*.egg-info/ +python/_deps/ python/triton/_C/*.pyd python/triton/_C/*.so @@ -17,9 +18,19 @@ python/triton/backends/ !python/triton/backends/compiler.py !python/triton/backends/driver.py +# flagtree backend +third_party/cambricon/ +third_party/iluvatar/iluvatarTritonPlugin.so +third_party/triton_shared/ +third_party/xpu/backend/xpu3 +third_party/ascend + # Proton python/triton/profiler +# Instrumentation +python/triton/instrumentation + # Python caches __pycache__/ *.py[cod] @@ -45,6 +56,8 @@ ptxas # Third-party include third_party/nvidia/backend/include +third_party/nvidia/backend/lib/cupti + # Docs docs/_build/ @@ -64,3 +77,6 @@ docs/sg_execution_times.rst # Vim *.swp + +# macOS +.DS_Store diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d4c65f8a1..f2aab636b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,7 @@ repos: - id: ruff files: '^python/.*' args: ["--fix", "--line-length", "120"] - stages: [commit, push, manual] + stages: [pre-commit, pre-push, manual] exclude: | (?x)( ^python/triton/runtime/.*| @@ -35,14 +35,14 @@ repos: hooks: - id: yapf args: ["-p", "-i"] - stages: [commit, push, manual] + stages: [pre-commit, pre-push, manual] exclude: "python/test/unit/language/test_line_info.py" - repo: https://github.com/pre-commit/mirrors-clang-format rev: v16.0.6 hooks: - id: clang-format - stages: [commit, push, manual] + stages: [pre-commit, pre-push, manual] # Expand YAML anchors in files used by github workflows, because github can't # do this itself. This lets us use anchors, which avoids code duplication. diff --git a/CMakeLists.txt b/CMakeLists.txt index 8f53a2602..8509a2c62 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,9 +12,45 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_INCLUDE_CURRENT_DIR ON) +# FLAGTREE Options +set(FLAGTREE_BACKEND "$ENV{FLAGTREE_BACKEND}") +if(NOT FLAGTREE_BACKEND) + add_definitions(-D__NVIDIA__) + add_definitions(-D__AMD__) +elseif(FLAGTREE_BACKEND STREQUAL "iluvatar") + add_definitions(-D__ILUVATAR__) + remove_definitions(-D_GLIBCXX_USE_CXX11_ABI=1) + add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) +elseif(FLAGTREE_BACKEND STREQUAL "mthreads") + set(ENV{PATH} "$ENV{LLVM_SYSPATH}/bin:$ENV{PATH}") + set(CMAKE_C_COMPILER clang) + set(CMAKE_CXX_COMPILER clang++) + set(ENV{FLAGTREE_PLUGIN} $ENV{FLAGTREE_BACKEND}) +elseif(FLAGTREE_BACKEND STREQUAL "ascend") + set(CMAKE_C_COMPILER clang) + set(CMAKE_CXX_COMPILER clang++) +endif() +set(FLAGTREE_PLUGIN "$ENV{FLAGTREE_PLUGIN}") +if(FLAGTREE_PLUGIN) + add_definitions(-D__FLAGTREE_PLUGIN__) +endif() + project(triton) include(CTest) +if (FLAGTREE_BACKEND STREQUAL "ascend") + set(TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}") + set(PATCHED_TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/ascend/triton_patch") + set(PATCHED_TRITON_LIBRARIES + "TritonIR" + ) + set(PATCHED_TRITON_DEPENDS + "TritonTableGen" + ) + include_directories(${PATCHED_TRITON_ROOT_DIR}/include) + include_directories(${PROJECT_BINARY_DIR}/third_party/ascend/triton_patch/include) # Tablegen'd files +endif() + if(NOT WIN32) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") endif() @@ -24,8 +60,13 @@ endif() # Options option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON) option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) -option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON) -option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON) +if(FLAGTREE_BACKEND) + option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" OFF) + option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" OFF) +else() + option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON) + option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON) +endif() set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") # Ensure Python3 vars are set correctly @@ -34,6 +75,8 @@ set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") # Customized release build type with assertions: TritonRelBuildWithAsserts set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") +set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1") +set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1") # Default build type if(NOT CMAKE_BUILD_TYPE) @@ -45,9 +88,18 @@ if(NOT WIN32) find_library(TERMINFO_LIBRARY tinfo) endif() +if(TRITON_BUILD_UT) + include(AddTritonUnitTest) +endif() + # Compiler flags -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) -set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17") +set(BACKEND_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/include) +if(FLAGTREE_BACKEND AND EXISTS "${BACKEND_INCLUDE_DIR}") + include_directories(${BACKEND_INCLUDE_DIR}) +else() + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +endif() +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17") # ######### @@ -76,18 +128,54 @@ function(add_triton_object name) INTERFACE $ ) - - # add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS}) - if(ARG_DEPENDS) - add_dependencies(${name} ${ARG_DEPENDS}) - endif() - if(ARG_LINK_LIBS) - target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS}) + if (FLAGTREE_BACKEND STREQUAL "ascend") + set(patched_depends "") + foreach(dep ${ARG_DEPENDS}) + list(FIND PATCHED_TRITON_DEPENDS "${dep}" index) + if(index GREATER_EQUAL 0) + list(APPEND patched_depends "Patched_${dep}") + message(STATUS "Replace ${dep} by Patched_${dep} as a dependent of ${name}") + else() + list(APPEND patched_depends ${dep}) + endif() + endforeach() + if(patched_depends) + add_dependencies(${name} ${patched_depends}) + endif() + + set(patched_link_libs "") + foreach(lib ${ARG_LINK_LIBS}) + list(FIND PATCHED_TRITON_LIBRARIES "${lib}" index) + if(index GREATER_EQUAL 0) + list(APPEND patched_link_libs "Patched_${lib}") + message(STATUS "Replace ${lib} by Patched_${lib} to be linked by ${name}") + else() + list(APPEND patched_link_libs ${lib}) + endif() + endforeach() + if(patched_link_libs) + target_link_libraries(${name} PUBLIC ${patched_link_libs}) + endif() + else() + #add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS}) + if(ARG_DEPENDS) + add_dependencies(${name} ${ARG_DEPENDS}) + endif() + if(ARG_LINK_LIBS) + target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS}) + endif() endif() endfunction(add_triton_object) set_property(GLOBAL PROPERTY TRITON_LIBS "") function(add_triton_library name) + if (FLAGTREE_BACKEND STREQUAL "ascend") + list(FIND PATCHED_TRITON_LIBRARIES "${name}" index) + if(index GREATER_EQUAL 0) + message(STATUS "Adding Patched_${name} as a lib, instead of ${name}") + return() + endif() + endif() set_property(GLOBAL APPEND PROPERTY TRITON_LIBS ${name}) add_triton_object(${name} ${ARGN}) llvm_update_compile_flags(${name}) @@ -95,25 +183,51 @@ endfunction() set_property(GLOBAL PROPERTY TRITON_PLUGINS "") function(add_triton_plugin name) - set_property(GLOBAL APPEND PROPERTY TRITON_PLUGINS ${name}) - add_triton_object(${name} ${ARGN}) + cmake_parse_arguments(ARG "" "SHARED_LIB" "LINK_LIBS" ${ARGN}) + if(ARG_SHARED_LIB) + set_property(GLOBAL APPEND PROPERTY TRITON_PLUGINS ${ARG_SHARED_LIB}) + else() + set_property(GLOBAL APPEND PROPERTY TRITON_PLUGINS ${name}) + add_triton_object(${name} ${ARGN}) + endif() endfunction() # Disable warnings that show up in external code (gtest;pybind11) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") +if(FLAGTREE_BACKEND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") +endif() include_directories(".") include_directories(${MLIR_INCLUDE_DIRS}) include_directories(${LLVM_INCLUDE_DIRS}) -include_directories(${PROJECT_SOURCE_DIR}/include) -include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files +if(FLAGTREE_BACKEND AND EXISTS ${BACKEND_INCLUDE_DIR}) + include_directories(${PROJECT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/include) + include_directories(${PROJECT_BINARY_DIR}/third_party/${FLAGTREE_BACKEND}/include) # Tablegen'd files +else() + include_directories(${PROJECT_SOURCE_DIR}/include) + include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files +endif() include_directories(${PROJECT_SOURCE_DIR}/third_party) include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files # link_directories(${LLVM_LIBRARY_DIR}) -add_subdirectory(include) -add_subdirectory(lib) +if (FLAGTREE_BACKEND STREQUAL "cambricon" OR FLAGTREE_BACKEND STREQUAL "ascend") + include_directories(${PROJECT_SOURCE_DIR}/include) + include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files + add_subdirectory(include) + add_subdirectory(lib) +elseif(NOT FLAGTREE_BACKEND) + add_subdirectory(include) + add_subdirectory(lib) +endif() + +if (FLAGTREE_BACKEND STREQUAL "ascend") + add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/include) + add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/lib) +endif() # find_package(PythonLibs REQUIRED) set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") @@ -132,7 +246,14 @@ endif() # ------ if(TRITON_BUILD_PYTHON_MODULE) message(STATUS "Adding Python module") - set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) + set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/python/src) + + set(PATCHED_PYTHON_SRC_PATH ${PATCHED_TRITON_ROOT_DIR}/python/src) + include_directories(${PYTHON_SRC_PATH}) + + if(NOT (FLAGTREE_BACKEND AND EXISTS "${PYTHON_SRC_PATH}")) + set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) + endif() include_directories(${PYTHON_SRC_PATH}) if(PYTHON_INCLUDE_DIRS) @@ -204,6 +325,7 @@ if(TRITON_BUILD_PYTHON_MODULE) MLIRSCFToControlFlow MLIRIndexToLLVM MLIRGPUToROCDLTransforms + MLIRUBToLLVM # LLVM LLVMPasses @@ -213,8 +335,64 @@ if(TRITON_BUILD_PYTHON_MODULE) LLVMAMDGPUAsmParser ) + if(FLAGTREE_BACKEND STREQUAL "iluvatar") + set(TRITON_LIBRARIES + ${triton_libs} + ${triton_plugins} + + # mlir + MLIRNVVMDialect + MLIRNVVMToLLVMIRTranslation + MLIRGPUToNVVMTransforms + MLIRGPUToGPURuntimeTransforms + MLIRGPUTransforms + MLIRIR + MLIRControlFlowToLLVM + MLIRBytecodeWriter + MLIRPass + MLIRTransforms + MLIRLLVMDialect + MLIRSupport + MLIRTargetLLVMIRExport + MLIRMathToLLVM + MLIRGPUDialect + MLIRSCFToControlFlow + MLIRIndexToLLVM + + # LLVM + LLVMPasses + LLVMIluvatarCodeGen + LLVMIluvatarAsmParser + ) + elseif(FLAGTREE_BACKEND STREQUAL "xpu") + set(TRITON_LIBRARIES + ${triton_libs} + ${triton_plugins} + + # mlir + MLIRIR + MLIRControlFlowToLLVM + MLIRBytecodeWriter + MLIRPass + MLIRTransforms + MLIRLLVMDialect + MLIRSupport + MLIRTargetLLVMIRExport + MLIRMathToLLVM + MLIRGPUDialect + MLIRSCFToControlFlow + MLIRIndexToLLVM + + # LLVM + LLVMPasses + LLVMXPUCodeGen + LLVMXPUAsmParser + ) + endif() + if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64" OR # Linux arm64 - CMAKE_SYSTEM_PROCESSOR MATCHES "arm64") # macOS arm64 + CMAKE_SYSTEM_PROCESSOR MATCHES "arm64" OR # macOS arm64 + CMAKE_OSX_ARCHITECTURES MATCHES "arm64") # also macOS arm64 list(APPEND TRITON_LIBRARIES LLVMAArch64CodeGen LLVMAArch64AsmParser @@ -224,8 +402,13 @@ if(TRITON_BUILD_PYTHON_MODULE) LLVMX86CodeGen LLVMX86AsmParser ) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "ppc64le") + list(APPEND TRITON_LIBRARIES + LLVMPowerPCAsmParser + LLVMPowerPCCodeGen + ) else() - message(FATAL_ERROR "LLVM codegen/ASM parser libs: This HW architecture is not configured in cmake lib dependencies.") + message(FATAL_ERROR "LLVM codegen/ASM parser libs: This HW architecture (${CMAKE_SYSTEM_PROCESSOR}) is not configured in cmake lib dependencies.") endif() # Define triton library @@ -239,11 +422,20 @@ if(TRITON_BUILD_PYTHON_MODULE) set(TRITON_BACKENDS_TUPLE "(${TRITON_BACKENDS_TUPLE})") add_compile_definitions(TRITON_BACKENDS_TUPLE=${TRITON_BACKENDS_TUPLE}) - add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc - ${PYTHON_SRC_PATH}/ir.cc - ${PYTHON_SRC_PATH}/passes.cc - ${PYTHON_SRC_PATH}/interpreter.cc - ${PYTHON_SRC_PATH}/llvm.cc) + if(FLAGTREE_BACKEND STREQUAL "cambricon") + add_library(triton SHARED) + else() + if(FLAGTREE_BACKEND STREQUAL "ascend") + set(PYTHON_IR_SRC_PATH ${PATCHED_PYTHON_SRC_PATH}) + else() + set(PYTHON_IR_SRC_PATH ${PYTHON_SRC_PATH}) + endif() + add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc + ${PYTHON_IR_SRC_PATH}/ir.cc + ${PYTHON_SRC_PATH}/passes.cc + ${PYTHON_SRC_PATH}/interpreter.cc + ${PYTHON_SRC_PATH}/llvm.cc) + endif() # Link triton with its dependencies target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES}) @@ -264,7 +456,7 @@ if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32) # Check if the platform is MacOS if(APPLE) - set(PYTHON_LDFLAGS "-undefined dynamic_lookup -flto") + set(PYTHON_LDFLAGS "-undefined dynamic_lookup") endif() target_link_libraries(triton PRIVATE ${PYTHON_LDFLAGS}) @@ -277,8 +469,11 @@ if(NOT TRITON_BUILD_PYTHON_MODULE) endif() add_subdirectory(third_party/f2reduce) -add_subdirectory(bin) -add_subdirectory(test) + +if(NOT FLAGTREE_BACKEND) + add_subdirectory(bin) + add_subdirectory(test) +endif() if(TRITON_BUILD_UT) add_subdirectory(unittest) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index fa7d1b057..d23eee318 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,68 +1,44 @@ -# Governance Structure +[中文版](./CONTRIBUTING_cn.md) -Triton adopts the following hierarchical technical governance structure: -* A community of **contributors** who file issues and submit pull requests -* A group of **module maintainers** who own parts of Triton and drive their development -* A body of **core maintainers** who own Triton overall and drive its development -* A **lead core maintainer** who is the catch-all decision maker when consensus cannot be reached by core maintainers +# FlagTree Contributor Guide -All contributions are expected to follow Triton’s design principles, as enforced by module and core maintainers. While high-quality pull requests are appreciated and encouraged, all maintainers reserve the right to prioritize their own work over code reviews at-will, hence contributors should not expect their work to be reviewed promptly. +Thank you for your interest in FlagTree! We use GitHub to host code, manage issues, and handle pull requests. Before contributing, please read the following guidelines. -Contributors can maximize the chances of their work being accepted by maintainers by meeting a high quality bar before sending a PR to maintainers. We encourage maintainers who contribute to Triton on behalf of a company to get reviews from senior developers within their company before sending to maintainers. -Module maintainers -We aim to make the Triton codebase as modular as possible, such that different components (e.g., subdirectories) can be improved in parallel under the supervision of different module maintainers. +## Bug Reports -What constitutes (or not) a module is up to the core maintainers. Core maintainers also reserve the right to decide whether the development of a module should happen – or keep happening – in-tree or not. +Please use GitHub Issues to report bugs. When reporting a bug, include: +- A concise summary +- Steps to reproduce +- Specific and accurate descriptions +- Example code if possible (this is particularly helpful) -**List of in-tree modules (as of 05/12/2024, alphabetical order):** -* AMD backend (Lei Zhang) -* Interpreter (Keren Zhou) -* Profiler (Keren Zhou) +## Code Contributions -Note: Parts of Triton that are not listed above (e.g., Nvidia backend) are assumed to be owned by core maintainers. +When submitting a pull request, contributors should describe the changes made and the rationale behind them. If possible, provide corresponding tests. Pull requests require approval from __ONE__ team member before merging and must pass all continuous integration checks. -Note: Some important parts of the Triton eco-system (e.g., Intel XPU backend) may be maintained out-of-tree and advertised in our repository. The governance rules described in this document do not carry over to these modules. +### Code Formatting -__List of out-of-tree modules (as of 05/12/2024, alphabetical order):__ -* CPU backend (Bert Maher, Ilya Enkovich) -* Intel backend (Ettore Tiotto, Whitney Tsang) +We use pre-commit for code formatting checks: +```shell +python3 -m pip install pre-commit +cd ${YOUR_CODE_DIR}/flagtree +pre-commit install +pre-commit +``` -## Core maintainers -The core maintainers drive the development of Triton at large and set the roadmap for the project. As such, they have the following responsibilities: -* Proposing, implementing and reviewing profound changes to user-facing APIs, IR specifications and/or pass infrastructures -* Enforcing code quality standards and adherence to core design principles -* Drawing module boundaries and resolving disputes between module maintainers +### Unit Tests +After installation, you can run unit tests in the backend directory: +```shell +cd third_party/backendxxx/python/test/unit +python3 -m pytest -s +``` -The core maintainers as a group have the power to veto any decision made at a Module maintainer level. +### Backend Integration -The core maintainers should publicly articulate their decision-making, and share the reasoning behind their decisions, vetoes, and dispute resolution. +Please contact the core development team for backend integration matters. -__List of core maintainers (as of 05/12/2024, alphabetical order):__ -* Justin Lebar -* Keren Zhou -* Pawel Szczerbuk -* Phil Tillet -* Thomas Raoux -* Zahi Moudallal +## License -## Lead core maintainer -When core maintainers cannot come to a consensus, a publicly declared lead maintainer is expected to settle the debate and make executive decisions. - -The Lead Core Maintainer should publicly articulate their decision-making, and give a clear reasoning for their decisions. - -The Lead Core Maintainer is also responsible for confirming or removing core maintainers. - -**Lead maintainer (as of 05/12/2024)** -* Phil Tillet - -# Decision Making - -## Uncontroversial Changes - -We are committed to accepting functional bug fixes that meet our quality standards – and include minimized unit tests to avoid future regressions. Performance improvements generally fall under the same category, with the caveat that they may be rejected if the trade-off between usefulness and complexity is deemed unfavorable by core maintainers (e.g., complex swizzling logic to improve the performance of non-tensor-cores matrix multiplications). Design changes that neither fix known functional nor performance issues are automatically considered controversial. - -## Controversial Changes - -More controversial design changes (e.g., changes in our IRs/APIs/Passes) are evaluated on a case-by-case basis under the subjective judgment of core maintainers. While it is possible for contributors to propose and land deep design changes upstream (see https://github.com/triton-lang/triton/pull/1305), the community should expect such occurrences to be relatively rare. +FlagTree is licensed under the [MIT license](/LICENSE). diff --git a/CONTRIBUTING_cn.md b/CONTRIBUTING_cn.md new file mode 100644 index 000000000..ab3a5e5e9 --- /dev/null +++ b/CONTRIBUTING_cn.md @@ -0,0 +1,44 @@ +[English](./CONTRIBUTING.md) + +# FlagTree 贡献者指南 + +感谢您对 FlagTree 的兴趣!我们使用 GitHub 来托管代码、管理问题和处理拉取请求。在贡献之前,请阅读以下指南。 + +## 错误报告 + +请使用 GitHub 的 Issues 来报告错误。在报告错误时,请提供: +- 简单摘要 +- 复现步骤 +- 确保描述具体且准确 +- 如果可以提供一些示例代码将会很有帮助 + +## 代码贡献 + +在提交拉取请求时,贡献者应描述所做的更改以及更改的原因。如果可以设计测试用例,请提供相应测试。拉取请求在合并前需要 __一位__ 成员的批准,而且需要通过代码的持续集成检查。 + +### 代码格式检查 + +代码格式检查使用 pre-commit。 + +```shell +python3 -m pip install pre-commit +cd ${YOUR_CODE_DIR}/flagtree +pre-commit install +pre-commit +``` + +### 单元测试 + +安装完成后可以在后端目录下运行单元测试: +```shell +cd third_party/backendxxx/python/test/unit +python3 -m pytest -s +``` + +### 后端接入 + +请联系核心开发团队。 + +## 证书 + +FlagTree 使用 [MIT license](/LICENSE)。 diff --git a/LICENSE b/LICENSE index 1d0238e86..08627406b 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,7 @@ /* * Copyright 2018-2020 Philippe Tillet * Copyright 2020-2022 OpenAI +* Copyright 2025- FlagTree Project Management Committee * * Permission is hereby granted, free of charge, to any person obtaining * a copy of this software and associated documentation files diff --git a/README.md b/README.md index 175563849..dec1e0902 100644 --- a/README.md +++ b/README.md @@ -1,224 +1,107 @@ -
- Triton logo -
+[中文版](./README_cn.md) -We're hiring! If you are interested in working on Triton at OpenAI, we have roles open for [Compiler Engineers](https://openai.com/careers/software-engineer-triton-compiler) and [Kernel Engineers](https://openai.com/careers/kernel-engineer). +## FlagTree -| **`Documentation`** | **`Nightly Wheels`** | -|-------------------- | -------------------- | -| [![Documentation](https://github.com/triton-lang/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) | [![Wheels](https://github.com/triton-lang/triton/actions/workflows/wheels.yml/badge.svg?branch=release/2.0.x)](https://github.com/triton-lang/triton/actions/workflows/wheels.yml) | +Flagtree is a multi-backend Triton compiler project dedicated to developing a diverse ecosystem of AI chip compilers and related tooling platforms, thereby fostering and strengthening the upstream and downstream Triton ecosystem. Currently in its initial phase, the project aims to maintain compatibility with existing adaptation solutions while unifying the codebase to rapidly implement single-version multi-backend support. - -# Triton - -This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs. - -The foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing this work if you use Triton! - -The [official documentation](https://triton-lang.org) contains installation instructions and tutorials. See also these third-party [Triton puzzles](https://github.com/srush/Triton-Puzzles), which can all be run using the Triton interpreter -- no GPU required. - -# Quick Installation - -You can install the latest stable release of Triton from pip: - -```bash -pip install triton +## Install from source +Installation dependencies (ensure you use the correct python3.x version): +```shell +apt install zlib1g zlib1g-dev libxml2 libxml2-dev # ubuntu +cd python; python3 -m pip install -r requirements.txt ``` -Binary wheels are available for CPython 3.8-3.12 and PyPy 3.8-3.9. -And the latest nightly release: - -```bash -pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly +Compile and install. Currently supported backends (backendxxx) include iluvatar, xpu, mthreads, and cambricon (limited support): +```shell +cd python +export FLAGTREE_BACKEND=backendxxx +python3 -m pip install . --no-build-isolation -v ``` -# Install from source +## Tips for building -``` -git clone https://github.com/triton-lang/triton.git; -cd triton; +Automatic dependency library downloads may be limited by network conditions. You can manually download to the cache directory ~/.flagtree (modifiable via the FLAGTREE_CACHE_DIR environment variable). No need to manually set LLVM environment variables such as LLVM_BUILD_DIR. +Complete build commands for each backend: -pip install ninja cmake wheel; # build-time dependencies -pip install -e python +[iluvatar](/third_party/iluvatar/) +```shell +# Recommended: Use Ubuntu 20.04 +mkdir -p ~/.flagtree/iluvatar; cd ~/.flagtree/iluvatar +wget https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/iluvatar-llvm18-x86_64.tar.gz +wget https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/iluvatarTritonPlugin-cpython3.10-glibc2.30-glibcxx3.4.28-cxxabi1.3.12-ubuntu-x86_64.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/python +export FLAGTREE_BACKEND=iluvatar +python3 -m pip install . --no-build-isolation -v ``` - -Or with a virtualenv: - +[xpu (klx)](/third_party/xpu/) +```shell +# Recommended: Use the Docker image (22GB) https://su.bcebos.com/klx-sdk-release-public/xpytorch/docker/ubuntu2004_v030/ubuntu_2004_x86_64_v30.tar +# Contact kunlunxin-support@baidu.com for support +mkdir -p ~/.flagtree/xpu; cd ~/.flagtree/xpu +wget https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/XTDK-llvm19-ubuntu2004_x86_64.tar.gz +wget https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/xre-Linux-x86_64.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/python +export FLAGTREE_BACKEND=xpu +python3 -m pip install . --no-build-isolation -v ``` -git clone https://github.com/triton-lang/triton.git; -cd triton; - -python -m venv .venv --prompt triton; -source .venv/bin/activate; - -pip install ninja cmake wheel; # build-time dependencies -pip install -e python +[mthreads](https://github.com/FlagTree/flagtree/tree/main/third_party/mthreads/) +```shell +# Recommended: Use the Dockerfile flagtree/dockerfiles/Dockerfile-ubuntu22.04-python3.10-mthreads +mkdir -p ~/.flagtree/mthreads; cd ~/.flagtree/mthreads +wget https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/mthreads-llvm19-glibc2.34-glibcxx3.4.30-x64.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/python +export FLAGTREE_BACKEND=mthreads +python3 -m pip install . --no-build-isolation -v ``` - -# Building with a custom LLVM - -Triton uses LLVM to generate code for GPUs and CPUs. Normally, the Triton build -downloads a prebuilt LLVM, but you can also build LLVM from source and use that. - -LLVM does not have a stable API, so the Triton build will not work at an -arbitrary LLVM version. - -1. Find the version of LLVM that Triton builds against. Check -`cmake/llvm-hash.txt` to see the current version. For example, if it says: - 49af6502c6dcb4a7f7520178bd14df396f78240c - - This means that the version of Triton you have builds against - [LLVM](https://github.com/llvm/llvm-project) 49af6502. - -2. `git checkout` LLVM at this revision. Optionally, make additional - modifications to LLVM. - -3. [Build LLVM](https://llvm.org/docs/CMake.html). For example, you might run - - $ cd $HOME/llvm-project # your clone of LLVM. - $ mkdir build - $ cd build - $ cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm" -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU" - $ ninja - -4. Grab a snack, this will take a while. - -5. Build Triton as above, but set the following environment variables. - - # Modify as appropriate to point to your LLVM build. - $ export LLVM_BUILD_DIR=$HOME/llvm-project/build - - $ cd - $ LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \ - LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \ - LLVM_SYSPATH=$LLVM_BUILD_DIR \ - pip install -e python - -# Tips for building - -- Set `TRITON_BUILD_WITH_CLANG_LLD=true` as an environment variable to use clang - and lld. lld in particular results in faster builds. - -- Set `TRITON_BUILD_WITH_CCACHE=true` to build with ccache. - -- Pass `--no-build-isolation` to `pip install` to make nop builds faster. - Without this, every invocation of `pip install` uses a different symlink to - cmake, and this forces ninja to rebuild most of the `.a` files. - -- vscode intellisense has some difficulty figuring out how to build Triton's C++ - (probably because, in our build, users don't invoke cmake directly, but - instead use setup.py). Teach vscode how to compile Triton as follows. - - - Do a local build. - - Get the full path to the `compile_commands.json` file produced by the build: - `find python/build -name 'compile_commands.json | xargs readlink -f'` - - In vscode, install the - [C/C++ - extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode.cpptools), - then open the command palette (`Shift + Command + P` on Mac, or `Shift + - Ctrl + P` on Windows/Linux) and open `C/C++: Edit Configurations (UI)`. - - Open "Advanced Settings" and paste the full path to - `compile_commands.json` into the "Compile Commands" textbox. - -# Running tests - -There currently isn't a turnkey way to run all the Triton tests, but you can -follow the following recipe. - +[ascend](https://github.com/FlagTree/flagtree/tree/triton_v3.2.x/third_party/ascend/) ```shell -# One-time setup. Note we have to reinstall local Triton because torch -# overwrites it with the public version. -$ pip install scipy numpy torch pytest lit pandas matplotlib && pip install -e python - -# Run Python tests using your local GPU. -$ python3 -m pytest python/test/unit - -# Move to builddir. Fill in <...> with the full path, e.g. -# `cmake.linux-x86_64-cpython-3.11`. -$ cd python/build/cmake<...> - -# Run C++ unit tests. -$ ninja test - -# Run lit tests. -$ lit test +# Recommended: Use the Dockerfile flagtree/dockerfiles/Dockerfile-ubuntu20.04-python3.9-ascend +# After registering an account at https://www.hiascend.com/developer/download/community/result?module=cann, +# download the cann-toolkit and cann-kernels for the corresponding platform. +# Here we use the A3 processor with AArch64 architecture as an example to demonstrate how to install. +chmod +x Ascend-cann-toolkit_8.2.RC1.alpha002_linux-aarch64.run +./Ascend-cann-toolkit_8.2.RC1.alpha002_linux-aarch64.run --install +chmod +x Atlas-A3-cann-kernels_8.1.RC1_linux-aarch64.run +./Atlas-A3-cann-kernels_8.1.RC1_linux-aarch64.run --install +# build +mkdir -p ~/.flagtree/ascend; cd ~/.flagtree/ascend +wget https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-arm64.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/python +git checkout -b triton_v3.2.x origin/triton_v3.2.x +export FLAGTREE_BACKEND=ascend +python3 -m pip install . --no-build-isolation -v ``` -You may find it helpful to make a symlink to the builddir and tell your local -git to ignore it. - -``` -$ ln -s python/build/cmake<...> build -$ echo build >> .git/info/exclude +To build with default backends (nvidia, amd, triton_shared): +```shell +# manually download LLVM +cd ${YOUR_LLVM_DOWNLOAD_DIR} +wget https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-10dc3a8e-ubuntu-x64.tar.gz +tar -zxvf llvm-10dc3a8e-ubuntu-x64.tar.gz +# build +cd ${YOUR_CODE_DIR}/flagtree/python +export LLVM_BUILD_DIR=${YOUR_LLVM_DOWNLOAD_DIR}/llvm-10dc3a8e-ubuntu-x64 +export LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include +export LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib +export LLVM_SYSPATH=$LLVM_BUILD_DIR +unset FLAGTREE_BACKEND +python3 -m pip install . --no-build-isolation -v +# If you need to build other backends afterward, you should clear LLVM-related environment variables +unset LLVM_BUILD_DIR LLVM_INCLUDE_DIRS LLVM_LIBRARY_DIR LLVM_SYSPATH ``` -Then you can e.g. rebuild and run lit with the following command. +## Running tests +After installation, you can run tests in the backend directory: +```shell +cd third_party/backendxxx/python/test +python3 -m pytest -s ``` -$ ninja -C build && ( cd build ; lit test ) -``` - -# Tips for hacking - -For detailed instructions on how to debug Triton's frontend, please refer to this [tutorial](https://triton-lang.org/main/programming-guide/chapter-3/debugging.html). The following includes additional tips for hacking on Triton's backend. - -**Helpful environment variables** - -- `MLIR_ENABLE_DUMP=1` dumps the IR before every MLIR pass Triton runs. -- `LLVM_IR_ENABLE_DUMP=1` dumps the IR before every pass run over the LLVM IR. -- `TRITON_INTERPRET=1` uses the Triton interpreter instead of running on the - GPU. You can insert Python breakpoints in your kernel code! -- `TRITON_ENABLE_LLVM_DEBUG=1` passes `-debug` to LLVM, printing a lot of - debugging information to stdout. If this is too noisy, run with just - `TRITON_LLVM_DEBUG_ONLY` instead to limit the output. - - An alternative way to reduce output noisiness is running with - `LLVM_IR_ENABLE_DUMP=1`, extract the IR before the LLVM pass of interest, and - then run LLVM's `opt` standalone, perhaps passing `-debug-only=foo` on the - command line. -- `TRITON_LLVM_DEBUG_ONLY=` is the equivalent of LLVM's - `-debug-only` command-line option. This limits the LLVM debug output to - specific pass or component names (which are specified using `#define - DEBUG_TYPE` throughout LLVM and Triton) in order to allow the debug output to - be less noisy. `TRITON_LLVM_DEBUG_ONLY` allows for one or more comma - separated values to be specified (eg - `TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions` or - `TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions,regalloc"`). -- `USE_TTGIR_LOC=1` reparses the ttgir such that the location information will - be the line number of the ttgir instead of line number of the python file. - This can provide a direct mapping from ttgir to llir/ptx. When used with - performance tools, it can provide a breakdown on ttgir instructions. -- `TRITON_PRINT_AUTOTUNING=1` prints out the best autotuning config and total time - spent for each kernel after autotuning is complete. -- `DISABLE_LLVM_OPT` will disable llvm optimizations for make_llir and make_ptx - if its value is true when parsing as Bool. Otherwise, it will be parsed as a list - of flags to disable llvm optimizations. One usage case is - `DISABLE_LLVM_OPT="disable-lsr"` - Loop strength reduction is known to cause up to 10% performance changes for - certain kernels with register pressure. -- `TRITON_ALWAYS_COMPILE=1` forces to compile kernels regardless of cache hit. -- `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass. -- `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass. - -# Changelog - -Version 2.0 is out! New features include: -- Many, many bug fixes -- Performance improvements -- Backend rewritten to use MLIR -- Support for kernels that contain back-to-back matmuls (e.g., flash attention) - -# Contributing - -Community contributions are more than welcome, whether it be to fix bugs or to add new features at [github](https://github.com/triton-lang/triton/). For more detailed instructions, please visit our [contributor's guide](CONTRIBUTING.md). +## Contributing -# Compatibility +Contributions to FlagTree development are welcome. Please refer to [CONTRIBUTING.md](/CONTRIBUTING_cn.md) for details. -Supported Platforms: - * Linux +## License -Supported Hardware: - * NVIDIA GPUs (Compute Capability 7.0+) - * AMD GPUs (ROCm 5.2+) - * Under development: CPUs +FlagTree is licensed under the [MIT license](/LICENSE). diff --git a/README_cn.md b/README_cn.md new file mode 100644 index 000000000..86f0e568f --- /dev/null +++ b/README_cn.md @@ -0,0 +1,105 @@ +[English](./README.md) + +## FlagTree + +FlagTree 是多后端的 Triton 编译器项目。FlagTree 致力于打造多元 AI 芯片编译器及相关工具平台,发展和壮大 Triton 上下游生态。项目当前处于初期,目标是兼容现有适配方案,统一代码仓库,快速实现单版本多后端支持。 + +## 从源代码安装 +安装依赖(注意使用正确的 python3.x 执行): +```shell +apt install zlib1g zlib1g-dev libxml2 libxml2-dev # ubuntu +cd python; python3 -m pip install -r requirements.txt +``` + +编译安装,目前支持的后端 backendxxx 包括 iluvatar、xpu、mthreads、cambricon(有限支持): +```shell +cd python +export FLAGTREE_BACKEND=backendxxx +python3 -m pip install . --no-build-isolation -v +``` + +## 构建技巧 + +自动下载依赖库的速度可能受限于网络环境,编译前可自行下载至缓存目录 ~/.flagtree(可通过环境变量 FLAGTREE_CACHE_DIR 修改),无需自行设置 LLVM_BUILD_DIR 等环境变量。 +各后端完整编译命令如下: + +[iluvatar](/third_party/iluvatar/) +```shell +# 推荐使用镜像 Ubuntu 20.04 +mkdir -p ~/.flagtree/iluvatar; cd ~/.flagtree/iluvatar +wget https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/iluvatar-llvm18-x86_64.tar.gz +wget https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/iluvatarTritonPlugin-cpython3.10-glibc2.30-glibcxx3.4.28-cxxabi1.3.12-ubuntu-x86_64.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/python +export FLAGTREE_BACKEND=iluvatar +python3 -m pip install . --no-build-isolation -v +``` +[xpu (klx)](/third_party/xpu/) +```shell +# 推荐使用镜像(22GB)https://su.bcebos.com/klx-sdk-release-public/xpytorch/docker/ubuntu2004_v030/ubuntu_2004_x86_64_v30.tar +# 联系 kunlunxin-support@baidu.com 可获取进一步支持 +mkdir -p ~/.flagtree/xpu; cd ~/.flagtree/xpu +wget https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/XTDK-llvm19-ubuntu2004_x86_64.tar.gz +wget https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/xre-Linux-x86_64.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/python +export FLAGTREE_BACKEND=xpu +python3 -m pip install . --no-build-isolation -v +``` +[mthreads](https://github.com/FlagTree/flagtree/tree/main/third_party/mthreads/) +```shell +# 推荐使用镜像 flagtree/dockerfiles/Dockerfile-ubuntu22.04-python3.10-mthreads +mkdir -p ~/.flagtree/mthreads; cd ~/.flagtree/mthreads +wget https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/mthreads-llvm19-glibc2.34-glibcxx3.4.30-x64.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/python +export FLAGTREE_BACKEND=mthreads +python3 -m pip install . --no-build-isolation -v +``` +[ascend](https://github.com/FlagTree/flagtree/tree/triton_v3.2.x/third_party/ascend/) +```shell +# 推荐使用镜像 flagtree/dockerfiles/Dockerfile-ubuntu20.04-python3.9-ascend +# 在 https://www.hiascend.com/developer/download/community/result?module=cann +# 注册账号后下载对应平台的 cann-toolkit、cann-kernels,这里以 AArch64 架构的 A3 处理器为例展示如何安装 +chmod +x Ascend-cann-toolkit_8.2.RC1.alpha002_linux-aarch64.run +./Ascend-cann-toolkit_8.2.RC1.alpha002_linux-aarch64.run --install +chmod +x Atlas-A3-cann-kernels_8.1.RC1_linux-aarch64.run +./Atlas-A3-cann-kernels_8.1.RC1_linux-aarch64.run --install +# 编译安装 +mkdir -p ~/.flagtree/ascend; cd ~/.flagtree/ascend +wget https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-arm64.tar.gz +cd ${YOUR_CODE_DIR}/flagtree/python +export FLAGTREE_BACKEND=ascend +python3 -m pip install . --no-build-isolation -v +``` + +使用默认的编译命令,可以编译安装 nvidia、amd、triton_shared 后端: +```shell +# 自行下载 llvm +cd ${YOUR_LLVM_DOWNLOAD_DIR} +wget https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-10dc3a8e-ubuntu-x64.tar.gz +tar -zxvf llvm-10dc3a8e-ubuntu-x64.tar.gz +# 编译安装 +cd ${YOUR_CODE_DIR}/flagtree/python +export LLVM_BUILD_DIR=${YOUR_LLVM_DOWNLOAD_DIR}/llvm-10dc3a8e-ubuntu-x64 +export LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include +export LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib +export LLVM_SYSPATH=$LLVM_BUILD_DIR +unset FLAGTREE_BACKEND +python3 -m pip install . --no-build-isolation -v +# 如果接下来需要编译安装其他后端,应清空 LLVM 相关环境变量 +unset LLVM_BUILD_DIR LLVM_INCLUDE_DIRS LLVM_LIBRARY_DIR LLVM_SYSPATH +``` + +## 运行测试 + +安装完成后可以在后端目录下运行测试: +```shell +cd third_party/backendxxx/python/test +python3 -m pytest -s +``` + +## 关于贡献 + +欢迎参与 FlagTree 的开发并贡献代码,详情请参考[CONTRIBUTING.md](/CONTRIBUTING_cn.md)。 + +## 许可证 + +FlagTree 使用 [MIT license](/LICENSE)。 diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 1c8f45448..fa84e9fd6 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -50,7 +50,6 @@ target_link_libraries(triton-reduce PRIVATE mlir_check_all_link_libraries(triton-reduce) add_llvm_executable(triton-lsp triton-lsp.cpp PARTIAL_SOURCES_INTENDED) -mlir_check_all_link_libraries(triton-lsp) llvm_update_compile_flags(triton-lsp) target_link_libraries(triton-lsp PRIVATE @@ -90,3 +89,14 @@ target_link_libraries(triton-llvm-opt PRIVATE LLVMCodeGen ) export_executable_symbols_for_plugins(triton-llvm-opt) + + +add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED) +target_link_libraries(triton-tensor-layout PRIVATE + TritonGPUIR + TritonNvidiaGPUIR + ${triton_libs} + ${conversion_libs} + ${dialect_libs} + TritonTestAnalysis + ) diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 1bd1db949..7b0f0051a 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -1,20 +1,35 @@ #pragma once + +#ifdef __AMD__ +#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "amd/include/TritonAMDGPUTransforms/Passes.h" +#endif +#ifdef __NVIDIA__ #include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#endif #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#ifdef __NVIDIA__ #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#endif // Below headers will allow registration to ROCm passes +#ifdef __AMD__ #include "TritonAMDGPUToLLVM/Passes.h" #include "TritonAMDGPUTransforms/Passes.h" #include "TritonAMDGPUTransforms/TritonGPUConversion.h" +#endif #include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#ifdef __NVIDIA__ #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#endif +#ifdef __NVIDIA__ #include "nvidia/include/NVGPUToLLVM/Passes.h" #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" +#endif #include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/Passes.h" #include "triton/Target/LLVMIR/Passes.h" @@ -23,6 +38,10 @@ #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/InitAllPasses.h" +#include +#include +#include + namespace mlir { namespace test { void registerTestAliasPass(); @@ -36,35 +55,51 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::registerAllPasses(); mlir::registerTritonPasses(); mlir::triton::gpu::registerTritonGPUPasses(); +#ifdef __NVIDIA__ mlir::registerTritonNvidiaGPUPasses(); mlir::test::registerTestAliasPass(); mlir::test::registerTestAlignmentPass(); mlir::test::registerTestAllocationPass(); mlir::test::registerTestMembarPass(); +#endif mlir::triton::registerConvertTritonToTritonGPUPass(); mlir::triton::registerAllocateSharedMemoryPass(); +#ifdef __NVIDIA__ mlir::triton::registerConvertTritonGPUToLLVMPass(); mlir::triton::registerConvertNVGPUToLLVMPass(); mlir::triton::registerDecomposeUnsupportedNVIDIAConversions(); +#endif mlir::registerLLVMDIScope(); +#ifdef __AMD__ // TritonAMDGPUToLLVM passes mlir::triton::registerConvertTritonAMDGPUToLLVM(); mlir::triton::registerConvertBuiltinFuncToLLVM(); mlir::triton::registerDecomposeUnsupportedAMDConversions(); + mlir::triton::registerOptimizeAMDLDSUsage(); // TritonAMDGPUTransforms passes mlir::registerTritonAMDGPUAccelerateMatmul(); mlir::registerTritonAMDGPUOptimizeEpilogue(); mlir::registerTritonAMDGPUReorderInstructions(); - mlir::registerTritonAMDGPUStreamPipeline(); + mlir::registerTritonAMDGPUStreamPipelineV2(); + mlir::registerTritonAMDGPUCanonicalizePointers(); + mlir::registerTritonAMDGPUConvertToBufferOps(); +#endif // TODO: register Triton & TritonGPU passes - registry.insert(); + registry.insert< + mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect, +#ifdef __NVIDIA__ + mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect, +#endif + mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect, + mlir::arith::ArithDialect, mlir::scf::SCFDialect, mlir::gpu::GPUDialect, +#ifdef __NVIDIA__ + mlir::triton::nvgpu::NVGPUDialect, +#endif +#ifdef __AMD__ + mlir::triton::amdgpu::TritonAMDGPUDialect, mlir::ROCDL::ROCDLDialect, +#endif + mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect>(); } diff --git a/bin/triton-lsp.cpp b/bin/triton-lsp.cpp index b185b0374..f95036dc6 100644 --- a/bin/triton-lsp.cpp +++ b/bin/triton-lsp.cpp @@ -6,6 +6,5 @@ int main(int argc, char **argv) { mlir::DialectRegistry registry; registerTritonDialects(registry); - mlir::MLIRContext context(registry); return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); } diff --git a/bin/triton-tensor-layout.cpp b/bin/triton-tensor-layout.cpp new file mode 100644 index 000000000..4087ac135 --- /dev/null +++ b/bin/triton-tensor-layout.cpp @@ -0,0 +1,231 @@ +#include "RegisterTritonDialects.h" + +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/AsmParser/AsmParserState.h" +#include "mlir/IR/MLIRContext.h" + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using namespace mlir; + +// A CLI tool to print the layout of a tensor. +// +// clang-format off +// Example usage: +// +// triton-tensor-layout -l "#triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>" +// +// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt +// +// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt -alias-names="blocked,mma" -use-hw-view +// +// An input file usually looks like: +// ''' +// #mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}> +// #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> +// ''' +// clang-format on + +//===--------------------------------------------------------------------===// +// CLI options +//===--------------------------------------------------------------------===// + +cl::OptionCategory PrinterCategory("Available Print Options", + "Options for the tensor layout printing."); + +static cl::opt InputFile( + "i", cl::desc("File that contains the tensor data layout attributes"), + cl::init(""), cl::value_desc("filename"), cl::cat(PrinterCategory)); + +static cl::opt + OutputFile("o", cl::desc("Output file to write the layout into"), + cl::init(""), cl::value_desc("filename"), + cl::cat(PrinterCategory)); + +static cl::opt + DataLayoutStr("l", cl::desc("Tensor data layout attribute in string"), + cl::value_desc("layout-string"), cl::init(""), + cl::cat(PrinterCategory)); + +static cl::list + AliasName("alias-names", + cl::desc("A list of alias names (separated by comma) of the " + "layout attributes in the input file"), + cl::value_desc("name1,name2,name3,..."), cl::CommaSeparated, + cl::ZeroOrMore, cl::cat(PrinterCategory)); + +static cl::opt UseHWPointOfView( + "use-hw-view", + llvm::cl::desc( + "Print the layout in hardware point of view. This means the output is " + "from the warp's perspective. Otherwise, the output is from the " + "tensor's perspective (e.g., each element maps to xxx thread)."), + cl::init(false), cl::cat(PrinterCategory)); + +static cl::opt TensorStr( + "t", cl::desc("Tensor shape and element type (e.g., tensor<2x2xf32>)"), + cl::init(""), cl::value_desc("tensor-type"), cl::cat(PrinterCategory)); + +//===--------------------------------------------------------------------===// +// Helper functions +//===--------------------------------------------------------------------===// + +LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) { + StringRef dialectName = tensorType.getEncoding().getDialect().getNamespace(); + + // Dispatch to the corresponding dialect helper function to print the layout. + if (dialectName == "triton_gpu") { + os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); + return success(); + } + + llvm::errs() << "Unsupported tensor layout attribute: " + << tensorType.getEncoding() << "\n"; + return failure(); +} + +LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename, + ArrayRef names, + TensorType tensorTy, raw_string_ostream &ss) { + if (filename.empty()) + return success(); + + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return failure(); + } + + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + ParserConfig config(context); + auto asmState = AsmParserState(); + + Block parsedIR; + if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) { + llvm::errs() << "Fail to parse the input file: " << filename << "\n"; + return failure(); + } + + auto printLambda = [&](StringRef name, mlir::Attribute attr) { + ss << "Print layout attribute: #" << name << " = " << attr << "\n"; + + auto rankedTensorTy = RankedTensorType::get( + tensorTy.getShape(), tensorTy.getElementType(), attr); + + return layoutPrint(rankedTensorTy, ss); + }; + + if (names.empty()) + // If no alias name is given, we print all layout attributes in the file. + for (const auto &def : asmState.getAttributeAliasDefs()) { + if (failed(printLambda(def.name, def.value))) + return failure(); + } + else { + // Print the layout attributes with the given alias names. + for (const auto &alias : names) { + auto def = asmState.getAttributeAliasDef(alias); + if (!def) { + llvm::errs() << "Can't find the layout attribute: " << alias << "\n"; + return failure(); + } + + if (failed(printLambda(alias, def->value))) + return failure(); + + ss << "\n"; + } + } + + return success(); +} + +LogicalResult printLayoutFromString(MLIRContext *context, + StringRef layoutAttrStr, + TensorType tensorTy, + raw_string_ostream &ss) { + if (layoutAttrStr.empty()) + return success(); + + mlir::Attribute layout = parseAttribute(layoutAttrStr, context); + if (!layout) { + llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n"; + return failure(); + } + + auto rankedTensorTy = RankedTensorType::get( + tensorTy.getShape(), tensorTy.getElementType(), layout); + + ss << "Print layout attribute: " << layout << "\n"; + + return layoutPrint(rankedTensorTy, ss); +} + +//===--------------------------------------------------------------------===// +// Main entry point +//===--------------------------------------------------------------------===// + +int main(int argc, char **argv) { + cl::HideUnrelatedOptions(PrinterCategory); + cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n"); + + DialectRegistry registry; + registerTritonDialects(registry); + + MLIRContext ctx(registry); + ctx.loadAllAvailableDialects(); + + if (TensorStr.empty()) { + llvm::errs() << "Must specify the tensor type argument\n"; + return 1; + } + + mlir::Type parsedTy = parseType(TensorStr, &ctx); + if (!parsedTy) { + llvm::errs() << "Fail to parse the tensor type argument: " << TensorStr + << "\n"; + return 1; + } + + TensorType tensorType = dyn_cast(parsedTy); + if (!tensorType) { + llvm::errs() << "Invalid tensor type argument: " << TensorStr << "\n"; + return 1; + } + + std::string storage; + raw_string_ostream ss(storage); + + if (failed(printLayoutFromFile(&ctx, InputFile, AliasName, tensorType, ss))) + return 1; + + if (failed(printLayoutFromString(&ctx, DataLayoutStr, tensorType, ss))) + return 1; + + if (OutputFile.empty()) { + llvm::outs() << ss.str(); + } else { + std::error_code ec; + llvm::raw_fd_ostream outFs(OutputFile, ec, llvm::sys::fs::OF_Text); + if (ec) { + llvm::errs() << "Error: " << ec.message() << " : unable to open " + << OutputFile << " for output\n"; + return 1; + } + outFs << ss.str(); + outFs.close(); + } + + return 0; +} diff --git a/cmake/AddTritonUnitTest.cmake b/cmake/AddTritonUnitTest.cmake new file mode 100644 index 000000000..24fb20a72 --- /dev/null +++ b/cmake/AddTritonUnitTest.cmake @@ -0,0 +1,39 @@ +include(${PROJECT_SOURCE_DIR}/unittest/googletest.cmake) + +include(GoogleTest) +enable_testing() + +function(add_triton_ut) + set(options) + set(oneValueArgs NAME) + set(multiValueArgs SRCS LIBS DEFS) + cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) + get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) + get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) + + add_test(NAME ${__NAME} + COMMAND ${__NAME}) + add_executable( + ${__NAME} + ${__SRCS}) + target_link_libraries( + ${__NAME} + PRIVATE + GTest::gtest_main + ${triton_libs} + ${dialect_libs} + ${conversion_libs} + gmock + ${__LIBS}) + + target_compile_options(${__NAME} PRIVATE -fno-rtti) + + target_compile_definitions(${__NAME} PRIVATE ${__DEFS}) + + # Without the TEST_DISCOVERY_TIMEOUT, the tests randomly time out on my mac + # laptop. I think the issue may be that the very first time you run a program + # it's a bit slow. + gtest_discover_tests(${__NAME} PROPERTIES TEST_DISCOVERY_TIMEOUT 60) +endfunction() diff --git a/cmake/json-version.txt b/cmake/json-version.txt new file mode 100644 index 000000000..c294f65bf --- /dev/null +++ b/cmake/json-version.txt @@ -0,0 +1 @@ +v3.11.3 diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt index f2895130b..50d024794 100644 --- a/cmake/llvm-hash.txt +++ b/cmake/llvm-hash.txt @@ -1 +1 @@ -10dc3a8e916d73291269e5e2b82dd22681489aa1 +86b69c31642e98f8357df62c09d118ad1da4e16a diff --git a/cmake/nvidia-toolchain-version.json b/cmake/nvidia-toolchain-version.json new file mode 100644 index 000000000..baefe3fc8 --- /dev/null +++ b/cmake/nvidia-toolchain-version.json @@ -0,0 +1,8 @@ +{ + "ptxas": "12.4.99", + "cuobjdump": "12.4.99", + "nvdisasm": "12.4.99", + "cudacrt": "12.4.99", + "cudart": "12.4.99", + "cupti": "12.4.99" +} diff --git a/cmake/nvidia-toolchain-version.txt b/cmake/nvidia-toolchain-version.txt deleted file mode 100644 index 0b172fed4..000000000 --- a/cmake/nvidia-toolchain-version.txt +++ /dev/null @@ -1 +0,0 @@ -12.4.99 diff --git a/cmake/pybind11-version.txt b/cmake/pybind11-version.txt deleted file mode 100644 index 6ceb272ee..000000000 --- a/cmake/pybind11-version.txt +++ /dev/null @@ -1 +0,0 @@ -2.11.1 diff --git a/dockerfiles/Dockerfile-ubuntu20.04-python3.9-ascend b/dockerfiles/Dockerfile-ubuntu20.04-python3.9-ascend new file mode 100644 index 000000000..0fde75ffc --- /dev/null +++ b/dockerfiles/Dockerfile-ubuntu20.04-python3.9-ascend @@ -0,0 +1,31 @@ +FROM swr.cn-south-1.myhuaweicloud.com/ascendhub/ascend-pytorch:24.0.0-A1-2.1.0-ubuntu20.04 + +RUN apt-get update && \ + apt-get install zip unzip git vim zstd libzstd-dev && \ + apt-get install zlib1g zlib1g-dev libxml2 libxml2-dev && \ + apt-get install clang lld + +RUN pip3 install -U pip && \ + pip3 install numpy && \ + pip3 install decorator && \ + pip3 install sympy==1.4 && \ + pip3 install cffi==1.12.3 && \ + pip3 install pyyaml && \ + pip3 install pathlib2 && \ + pip3 install protobuf attrs attr && \ + pip3 install scipy && \ + pip3 install requests psutil absl-py && \ + pip3 install ninja cmake wheel pybind11 && \ + pip3 install setuptools==75.1.0 && \ + pip3 install attrs==24.2.0 numpy==1.26.4 scipy==1.13.1 decorator==5.1.1 psutil==6.0.0 && \ + pip3 install pytest==8.3.2 pytest-xdist==3.6.1 pyyaml torch==2.3.1 torchvision==0.18.1 torch-npu==2.3.1.post2 && \ + pip3 install scikit-build==0.18.1 scikit_build_core==0.11.1 && \ + pip3 install pre-commit torch_npu==2.6.0rc1 && \ + rm -rf /root/.cache/pip + +ENV LD_LIBRARY_PATH=/usr/lib/aarch64-linux-gnu/hdf5/serial:$LD_LIBRARY_PATH + +RUN if [ ! -d "/lib64" ]; \ + then \ + mkdir /lib64 && ln -sf /lib/ld-linux-aarch64.so.1 /lib64/ld-linux-aarch64.so.1; \ + fi diff --git a/dockerfiles/Dockerfile-ubuntu22.04-python3.10-mthreads b/dockerfiles/Dockerfile-ubuntu22.04-python3.10-mthreads new file mode 100644 index 000000000..374d5240d --- /dev/null +++ b/dockerfiles/Dockerfile-ubuntu22.04-python3.10-mthreads @@ -0,0 +1,66 @@ +FROM ubuntu:22.04 + +ARG PYTHON_VER=3.10 +ARG PYTHON_NAME=py310 + +ARG WORKSPACE=/root/wksp +ARG CONDA_INSTALL_DIR=/root/miniconda3 +ARG MUSA_ROOT_PATH=/usr/local/musa + +ENV DEBIAN_FRONTEND=noninteractive +ENV TZ=Etc/UTC + +ENV MTHREADS_VISIBLE_DEVICES all +ENV MTHREADS_DRIVER_CAPABILITIES compute,utility +RUN groupadd -o -g 29 audio; exit 0 && \ + groupadd -o -g 109 render; exit 0 + +ENV MUSA_TOOLKIT_PATH=/usr/local/musa +ENV PATH=${MUSA_TOOLKIT_PATH}/bin:$PATH +ENV LD_LIBRARY_PATH=${MUSA_TOOLKIT_PATH}/lib:$LD_LIBRARY_PATH + +RUN sed -i 's/archive.ubuntu.com/mirrors.ustc.edu.cn/g' /etc/apt/sources.list && \ + apt clean && apt update -y --fix-missing && apt install -y --no-install-recommends \ + build-essential gnupg gnupg2 vim curl wget git g++ clang libclang-dev libelf-dev gcc-multilib llvm \ + ssh sudo clang-format clang-tidy libglib2.0-dev libtinfo-dev patch ccache \ + ripgrep libgtest-dev intel-mkl libnuma-dev ca-certificates openssl && \ + rm -rf /var/lib/apt/lists/* + +WORKDIR ${WORKSPACE} + +RUN wget -q --no-check-certificate \ + https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh \ + -O ${WORKSPACE}/miniconda.sh && \ + /bin/bash ${WORKSPACE}/miniconda.sh -b -p ${CONDA_INSTALL_DIR} && \ + rm ${WORKSPACE}/miniconda.sh && \ + ${CONDA_INSTALL_DIR}/bin/conda clean -a && \ + ln -s ${CONDA_INSTALL_DIR}/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ + echo ". ${CONDA_INSTALL_DIR}/etc/profile.d/conda.sh" >> ~/.bashrc && \ + echo "conda activate ${PYTHON_NAME}" >> ~/.bashrc +ENV PATH=${CONDA_INSTALL_DIR}/bin:$PATH + +RUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple/ + +RUN rm /bin/sh && ln -s /bin/bash /bin/sh && \ + conda create -n ${PYTHON_NAME} python=${PYTHON_VER} -y && \ + conda info|egrep "conda version|active environment" + +RUN source activate ${PYTHON_NAME} && \ + pip install --no-cache-dir pyyaml setuptools typing_extensions pyahocorasick \ + future six black pre-commit pytest minio cmake cffi ninja pillow transformers openpyxl packaging \ + scipy==1.13.1 numpy==1.23.1 mkl==2023.0.0 mkl-include==2023.0.0 mkl-devel==2023.0.0 + +RUN ln -s /usr/lib/x86_64-linux-gnu/libmkl_intel_lp64.so /usr/lib/x86_64-linux-gnu/libmkl_intel_lp64.so.1 && \ + ln -s /usr/lib/x86_64-linux-gnu/libmkl_gnu_thread.so /usr/lib/x86_64-linux-gnu/libmkl_gnu_thread.so.1 && \ + ln -s /usr/lib/x86_64-linux-gnu/libmkl_core.so /usr/lib/x86_64-linux-gnu/libmkl_core.so.1 && \ + cp /usr/lib/x86_64-linux-gnu/libstdc++.so.6.0.30 ${CONDA_INSTALL_DIR}/envs/${PYTHON_NAME}/lib/ && \ + ln -s -f ${CONDA_INSTALL_DIR}/envs/py310/lib/libstdc++.so.6.0.30 ${CONDA_INSTALL_DIR}/envs/${PYTHON_NAME}/lib/libstdc++.so && \ + ln -s -f ${CONDA_INSTALL_DIR}/envs/py310/lib/libstdc++.so.6.0.30 ${CONDA_INSTALL_DIR}/envs/${PYTHON_NAME}/lib/libstdc++.so.6 + +RUN rm -rf ${WORKSPACE}/* + +WORKDIR /root + +# COPY m3d-musa-toolkit-installer.sh ./ + +RUN source activate ${PYTHON_NAME} diff --git a/docs/conf.py b/docs/conf.py index 484240ab3..ffaab561a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -93,12 +93,15 @@ def setup_generated_mlir_docs(): def setup(app): """Customize function args retrieving to get args under decorator.""" - import os + import subprocess import sphinx app.connect("autodoc-process-signature", process_sig) - os.system("pip install -e ../python") + max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count())) + print(f"Installing Triton Python package using {max_jobs} threads") + subprocess.run("pip install -e ../python", shell=True, env=os.environ.copy()) + setup_generated_mlir_docs() def forward_jit_fn(func): @@ -142,7 +145,7 @@ def documenter(app, obj, parent): autosummary_generate = True # versioning config -smv_tag_whitelist = r'^(v3.0.0)$' +smv_tag_whitelist = r'^(v3.2.0)$' smv_branch_whitelist = r'^main$' smv_remote_whitelist = None smv_released_pattern = r'^tags/.*$' @@ -156,9 +159,6 @@ def documenter(app, obj, parent): 'examples_dirs': '../python/tutorials/', 'gallery_dirs': 'getting-started/tutorials', 'filename_pattern': '', - # TODO: Re-enable the grouped-gemm tutorial. It currently hits this - # assertion: - # https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp#L127 'ignore_pattern': r'(__init__\.py|11.*.py)', 'within_subsection_order': FileNameSortKey, 'reference_url': { diff --git a/docs/index.rst b/docs/index.rst index b249ecec7..e9cf1e79f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -25,6 +25,7 @@ Python API - :doc:`triton ` - :doc:`triton.language ` - :doc:`triton.testing ` +- :doc:`Triton semantics ` .. toctree:: @@ -35,6 +36,7 @@ Python API python-api/triton python-api/triton.language python-api/triton.testing + python-api/triton-semantics Triton MLIR Dialects and Ops diff --git a/docs/meetups/08-06-2024/notes.md b/docs/meetups/08-06-2024/notes.md new file mode 100644 index 000000000..48762c62a --- /dev/null +++ b/docs/meetups/08-06-2024/notes.md @@ -0,0 +1,13 @@ +#### Agenda: +1. Triton-CPU Update +2. Intel GPU backend update + +##### Items: +Meeting notes: +1. Triton-CPU Update: Intel and Meta jointly presented the work on Triton-CPU, highlighting good progress on coverage and performance improvements. They also covered some of the optimizations they leveraged to get performance comparable to torch-native and torch-inductor. More details are in their slides. +2. Intel GPU Backend: Intel GPU backend shows good performance close to expert-tuned kernels and the use of block pointers for performance gains. There were questions around the future of block pointers and their importance for performance gains. With block-pointer deprecation there is a need for a more generic interface to support various backends including Intel GPU. +3. The 2024 Triton conference is on September 17th 2024 in Fremont California! Please register [here](README.md). +##### Minutes: +Recording link [here](https://youtu.be/dfL3L4_3ujg) + +Presentations repo [here](https://drive.google.com/drive/folders/1fQ3zVrM7DT8W8FGJWKx1wNr2X53tYbeT?usp=sharing) diff --git a/docs/meetups/dev_conference_2024.md b/docs/meetups/dev_conference_2024.md new file mode 100644 index 000000000..6816b4c59 --- /dev/null +++ b/docs/meetups/dev_conference_2024.md @@ -0,0 +1,3 @@ +The conference slides are available [here](https://drive.google.com/drive/folders/1osK9hwcX_lC1EjdZGB-v4w5oKx23UnU2?usp=drive_link) + +The conference videos are available [here](https://www.youtube.com/playlist?list=PLc_vA1r0qoiTjlrINKUuFrI8Ptoopm8Vz). diff --git a/docs/programming-guide/chapter-3/debugging.rst b/docs/programming-guide/chapter-3/debugging.rst index 31e92d282..c470363c6 100644 --- a/docs/programming-guide/chapter-3/debugging.rst +++ b/docs/programming-guide/chapter-3/debugging.rst @@ -70,8 +70,6 @@ The interpreter has several known limitations: ptr = tl.load(ptr) x = tl.load(ptr) -- Unlike the compilation mode, a scalar in interpreter mode is treated as a simple float or integer but not as a 0-d tensor. This means it lacks tensor attributes such as :code:`x.dtype`. A workaround is to explicitly convert the scalar to a tensor using :code:`tl.to_tensor(x)`, where :code:`x` is the scalar. - ---------------------------- Using Third-party Tools ---------------------------- diff --git a/docs/python-api/triton-semantics.rst b/docs/python-api/triton-semantics.rst new file mode 100644 index 000000000..e35a355d3 --- /dev/null +++ b/docs/python-api/triton-semantics.rst @@ -0,0 +1,47 @@ +Triton Semantics +================ + +Triton mostly follows the semantics of NumPy with minor exceptions. In this document, we go over some of the array computing features supported in Triton, and we cover the exceptions where Triton's semantics deviate from that NumPy. + +Type Promotion +-------------- + +**Type Promotion** occurs when tensors of different data types are used in an operation. For binary operations associated to `dunder methods `_ and the ternary function ``tl.where`` on its last two arguments, Triton automatically converts the input tensors to a common data type following a hierarchy of kinds (sets of dtypes): ``{bool} < {integral dypes} < {floating point dtypes}``. + +The algorithm is as follows: + +1. **Kind** If one tensor is of a dtype of a higher kind, the other tensor is promoted to this dtype: ``(int32, bfloat16) -> bfloat16`` + +2. **Width** If both tensors are of dtypes of the same kind, and one of them is of a higher width, the other one is promoted to this dtype: ``(float32, float16) -> float32`` + +3. **Supremum** If both tensors are of the same width and signedness but different dtypes, they are both promoted to the next larger dtype. ``(float16, bfloat16) -> float32`` + + 3.1 If both tensors are of different ``fp8`` dtypes, they are both cast to ``float16``. + +4. **Prefer unsigned** Otherwise (same width, different signedness), they are promoted to the unsigned dtype: ``(int32, uint32) -> uint32`` + +The rules are a bit different when they involve a scalar. By scalar here we mean a numeric literal, a variable marked with `tl.constexpr` or a combination of these. These are represented by NumPy scalars and have types ``bool``, ``int`` and ``float``. + +When an operation involves a tensor and a scalar: + +1. If the scalar is of a kind lower or equal to the tensor, it will not participate in the promotion: ``(uint8, int) -> uint8`` + +2. If the scalar is of a higher kind, we choose the lowest dtype in which it fits among ``int32`` < ``uint32`` < ``int64`` < ``uint64`` for ints and ``float32`` < ``float64`` for floats. Then, both the tensor and the scalar are promoted to this dtype: ``(int16, 4.0) -> float32`` + + +Broadcasting +------------ + +**Broadcasting** allows operations on tensors of different shapes by automatically expanding their shapes to a compatible size without copying the data. This follows the following rules: + +1. If one of the tensor shapes is shorter, pad it on the left with ones until both tensors have the same number of dimensions: ``((3, 4), (5, 3, 4)) -> ((1, 3, 4), (5, 3, 4))`` + +2. Two dimensions are compatible if they are equal, or if one of them is 1. A dimension of 1 will be expanded to match the dimension of the other tensor. ``((1, 3, 4), (5, 3, 4)) -> ((5, 3, 4), (5, 3, 4))`` + + +Differences with NumPy +---------------------- + +**C rounding in integer division** Operators in Triton follow C semantics rather than Python semantics for efficiency. As such, ``int // int`` implements `rounding towards zero as in C `_ for integers of mixed signs, rather than rounding towards minus infinity as in Python. For the same reason, the modulus operator ``int % int`` (which is defined as ``a % b = a - b * (a // b)``) also follows C semantics rather than Python semantics. + +Perhaps confusingly, integer division and modulus follow Python semantics for computations where all the inputs are scalars. diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst index 20b2ce35a..ecd0fb3b9 100644 --- a/docs/python-api/triton.language.rst +++ b/docs/python-api/triton.language.rst @@ -28,6 +28,7 @@ Creation Ops full zeros zeros_like + cast Shape Manipulation Ops diff --git a/docs/python-api/triton.testing.rst b/docs/python-api/triton.testing.rst index 824e10c6f..c89b0ba42 100644 --- a/docs/python-api/triton.testing.rst +++ b/docs/python-api/triton.testing.rst @@ -11,3 +11,4 @@ triton.testing do_bench do_bench_cudagraph perf_report + assert_close diff --git a/include/triton/Analysis/Alias.h b/include/triton/Analysis/Alias.h index a06df5ae2..199238bea 100644 --- a/include/triton/Analysis/Alias.h +++ b/include/triton/Analysis/Alias.h @@ -79,13 +79,13 @@ class SharedMemoryAliasAnalysis ModRefResult getModRef(Operation *op, Value location); void setToEntryState(dataflow::Lattice *lattice) override { - propagateIfChanged( - lattice, lattice->join( - AliasInfo::getPessimisticValueState(lattice->getPoint()))); + propagateIfChanged(lattice, + lattice->join(AliasInfo::getPessimisticValueState( + lattice->getAnchor()))); } /// Computes if the alloc set of the results are changed. - void + LogicalResult visitOperation(Operation *op, ArrayRef *> operands, ArrayRef *> results) override; diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index a9e02b420..044be950f 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -18,10 +18,42 @@ namespace mlir { namespace triton { class AllocationAnalysis; -SmallVector -getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, - unsigned &outVec); -SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op); +// To convert a tensor from one layout to another, we need to allocate a +// temporary buffer (i.e., scratch buffer) in shared memory. The conversion may +// require multiple iterations, with each iteration involving multiple +// vectorized loads/stores. The scratch buffer has a shape (`repShape`) that +// represents the maximum size accessed in each dimension during each iteration. +// It is padded (`paddedRepShape`) to avoid bank conflicts and is accessed in a +// specific `order`. +struct ScratchConfig { + SmallVector repShape; + SmallVector paddedRepShape; + SmallVector order; + unsigned inVec; + unsigned outVec; + + ScratchConfig(SmallVector repShape, + SmallVector paddedRepShape, unsigned inVec = 1, + unsigned outVec = 1) + : repShape(repShape), paddedRepShape(paddedRepShape), inVec(inVec), + outVec(outVec) {} + + void print(llvm::raw_ostream &os) const { + os << "repShape: ["; + llvm::interleaveComma(repShape, os); + os << "]"; + os << ", paddedRepShape: ["; + llvm::interleaveComma(paddedRepShape, os); + os << "]"; + os << ", order: ["; + llvm::interleaveComma(order, os); + os << "]"; + os << ", inVec: " << inVec << ", outVec: " << outVec << "\n"; + } +}; + +ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, + RankedTensorType dstTy); } // namespace triton @@ -135,6 +167,9 @@ class Allocation { /// Returns the size of total shared memory allocated size_t getSharedMemorySize() const { return sharedMemorySize; } + /// Returns mapping from operation to list of live LDS buffers + std::map> getLiveBuffers(); + private: /// A class that represents a shared memory buffer struct BufferT { @@ -151,15 +186,17 @@ class Allocation { size_t size; size_t alignment; size_t offset; + SetVector regionIds; + int sharingGroup; // -1 means not shared bool operator==(const BufferT &other) const { return id == other.id; } bool operator<(const BufferT &other) const { return id < other.id; } BufferT() : BufferT(BufferKind::Explicit, 0) {} BufferT(BufferKind kind, size_t size, size_t alignment = 4, - size_t offset = 0) + size_t offset = 0, int sharingGroup = -1) : kind(kind), id(nextId++), size(size), alignment(alignment), - offset(offset) {} + offset(offset), sharingGroup(sharingGroup) {} size_t setOffsetAligned(size_t newOffset) { return offset = llvm::alignTo(newOffset, alignment); diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h index 22a7ed554..aad4503b4 100644 --- a/include/triton/Analysis/AxisInfo.h +++ b/include/triton/Analysis/AxisInfo.h @@ -180,8 +180,8 @@ class ModuleAxisInfoAnalysis : public CallGraph { for (auto funcOp : llvm::reverse(sortedFuncs)) { initialize(funcOp); funcOp.walk([&](CallOpInterface callOp) { - auto callee = - dyn_cast(callOp.resolveCallable(&symbolTable)); + auto callee = dyn_cast( + callOp.resolveCallableInTable(&symbolTable)); update(callOp, callee); }); } diff --git a/include/triton/Analysis/Membar.h b/include/triton/Analysis/Membar.h index 43bd5d15b..038b0e167 100644 --- a/include/triton/Analysis/Membar.h +++ b/include/triton/Analysis/Membar.h @@ -10,31 +10,38 @@ namespace mlir { class OpBuilder; +/// Callback to allow backend to provide more information on whether a barrier +/// is needed between two operations. Even though two operations access the same +/// shared memory thay may not require a barrier in between them. +using MembarFilterFn = std::function; + struct BlockInfo { - using BufferIdSetT = Allocation::BufferIdSetT; - using IntervalSetT = std::set>; + using IntervalMapT = std::map, std::set>; - IntervalSetT syncReadIntervals; - IntervalSetT syncWriteIntervals; + IntervalMapT syncReadIntervals; + IntervalMapT syncWriteIntervals; BlockInfo() = default; /// Unions two BlockInfo objects. BlockInfo &join(const BlockInfo &other) { - syncReadIntervals.insert(other.syncReadIntervals.begin(), - other.syncReadIntervals.end()); - syncWriteIntervals.insert(other.syncWriteIntervals.begin(), - other.syncWriteIntervals.end()); + for (auto &interval : other.syncReadIntervals) + syncReadIntervals[interval.first].insert(interval.second.begin(), + interval.second.end()); + for (auto &interval : other.syncWriteIntervals) + syncWriteIntervals[interval.first].insert(interval.second.begin(), + interval.second.end()); return *this; } /// Returns true if intervals in two BlockInfo objects are intersected. - bool isIntersected(const BlockInfo &other) const { - return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals) || + bool isIntersected(const BlockInfo &other, MembarFilterFn filter) const { + return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals, + filter) || /*WAR*/ - isIntersected(syncReadIntervals, other.syncWriteIntervals) || + isIntersected(syncReadIntervals, other.syncWriteIntervals, filter) || /*WAW*/ - isIntersected(syncWriteIntervals, other.syncWriteIntervals); + isIntersected(syncWriteIntervals, other.syncWriteIntervals, filter); } /// Clears the intervals because a barrier is inserted. @@ -52,12 +59,17 @@ struct BlockInfo { bool operator!=(const BlockInfo &other) const { return !(*this == other); } private: - bool isIntersected(const IntervalSetT &lhsIntervalSet, - const IntervalSetT &rhsIntervalSet) const { + bool isIntersected(const IntervalMapT &lhsIntervalSet, + const IntervalMapT &rhsIntervalSet, + MembarFilterFn filter) const { for (auto &lhs : lhsIntervalSet) for (auto &rhs : rhsIntervalSet) - if (lhs.intersects(rhs)) - return true; + if (lhs.first.intersects(rhs.first)) + for (auto lhsOp : lhs.second) + for (auto rhsOp : rhs.second) + if (!filter || !filter(lhsOp, rhsOp)) + return true; + return false; } }; @@ -82,7 +94,8 @@ class MembarAnalysis { /// it is considered as the problem of the operation itself but not the membar /// analysis. MembarAnalysis() = default; - explicit MembarAnalysis(Allocation *allocation) : allocation(allocation) {} + explicit MembarAnalysis(Allocation *allocation, MembarFilterFn filter) + : allocation(allocation), filter(filter) {} /// Runs the membar analysis to the given operation, inserts a barrier if /// necessary. @@ -117,6 +130,7 @@ class MembarAnalysis { private: Allocation *allocation = nullptr; + MembarFilterFn filter = nullptr; }; /// Postorder traversal on the callgraph to insert membar instructions @@ -126,9 +140,10 @@ class MembarAnalysis { /// before and after function calls, but might be a bit conservative. class ModuleMembarAnalysis : public CallGraph { public: - ModuleMembarAnalysis(ModuleAllocation *moduleAllocation) + ModuleMembarAnalysis(ModuleAllocation *moduleAllocation, + MembarFilterFn filter = nullptr) : CallGraph(moduleAllocation->getModuleOp()), - moduleAllocation(moduleAllocation) {} + moduleAllocation(moduleAllocation), filter(filter) {} void run() { walk( @@ -139,7 +154,7 @@ class ModuleMembarAnalysis : public CallGraph { auto *allocation = moduleAllocation->getFuncData(funcOp); auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo()); if (inserted) { - MembarAnalysis analysis(allocation); + MembarAnalysis analysis(allocation, filter); analysis.run(funcMap); } }); @@ -147,6 +162,7 @@ class ModuleMembarAnalysis : public CallGraph { private: ModuleAllocation *moduleAllocation; + MembarFilterFn filter; }; } // namespace mlir diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 7b215f267..cb3e3d292 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -6,6 +6,7 @@ #include "mlir/Support/LLVM.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/LinearLayout.h" namespace mlir { @@ -62,7 +63,8 @@ class ReduceOpHelper { unsigned getThreadsReductionAxis(); - SmallVector getScratchConfig(); + // The shape of the shared memory space needed for the reduction. + SmallVector getScratchRepShape(); SmallVector getOrderWithAxisAtBeginning(); @@ -176,7 +178,9 @@ class ScanLoweringHelper { SmallVector, SmallVector>> getReshapeDecomposition(ArrayRef srcShape, ArrayRef dstShape); -bool maybeSharedAllocationOp(Operation *op); +// Returns the number of elements in the scratch space needed. +// If shape is empty, it means no shared memory is needed. +unsigned getNumScratchElements(ArrayRef shape); bool supportMFMA(triton::DotOp op); @@ -186,13 +190,33 @@ bool supportMMA(triton::DotOp op, int version); bool supportMMA(Value value, int version); -bool isSingleValue(Value value); +// Conversion from `srcTy` to `dstTy` involving the minimum amount of data +// transfer provided that both types can be converted to LL (if it can't it'll +// return nullopt). The output will be such that layout.getInDimNames() == +// layout.getOutDimNames() and the conversion will not include kBlock (resp. +// kWarp or kLane) if it can be avoided +std::optional +minimalCvtLayout(RankedTensorType srcTy, RankedTensorType dstTy); -bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); +// Conversion from `srcTy` to `dstTy` only involves reordering of registers. +// There is no need for data exchange across threads, warps, or blocks. +bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy); -bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); +// Conversion from `srcTy` to `dstTy` involves data exchange across threads +// within a warp. No data exchange across warps or blocks is needed. +bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy); + +// Conversion from `srcTy` to `dstTy` involves data exchange across threads, +// warps, and possibly blocks. +bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy); + +bool atomicNeedsSharedMemory(Value result); -bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy); +bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT); + +bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); + +bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); // Return true if the src and dst layout match. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, @@ -303,7 +327,7 @@ template class CallGraph { moduleOp.walk([&](Operation *op) { auto caller = op->getParentOfType(); if (auto callOp = dyn_cast(op)) { - auto *callee = callOp.resolveCallable(&symbolTable); + auto *callee = callOp.resolveCallableInTable(&symbolTable); auto funcOp = dyn_cast_or_null(callee); if (funcOp) { graph[caller].emplace_back( diff --git a/include/triton/Conversion/MLIRTypes.h b/include/triton/Conversion/MLIRTypes.h index fadba413f..a00f9f844 100644 --- a/include/triton/Conversion/MLIRTypes.h +++ b/include/triton/Conversion/MLIRTypes.h @@ -33,6 +33,12 @@ inline bool isFloat(Type type) { type.isFloat8E5M2FNUZ(); } +inline bool isFloat8(Type type) { + return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() || + type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() || + type.isFloat8E5M2FNUZ(); +} + inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); } } // namespace type diff --git a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h index 5203ffff9..22c8f9c8a 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h @@ -88,10 +88,11 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { // encoding not available return resultVals; Attribute baseEncoding = encoding; - if (isa(baseEncoding)) - // TODO: this logic seems incorrect for mfma layout. Skip for now. - // We saw mismatches for some flash-attention tests on AMD backend. - // Note that this logic works for sliced layout whose parent is + if (isa(baseEncoding) || + isa(baseEncoding)) + // TODO: this logic seems incorrect for mfma and wmma layout. Skip for + // now. We saw mismatches for some flash-attention and dot tests on AMD + // backend. Note that this logic works for sliced layout whose parent is // mfma layout. Therefore, this is not combined with the following check. return resultVals; while (auto sliced = dyn_cast(baseEncoding)) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Passes.td b/include/triton/Conversion/TritonGPUToLLVM/Passes.td index 700dcd6b4..04ced1767 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Passes.td +++ b/include/triton/Conversion/TritonGPUToLLVM/Passes.td @@ -5,6 +5,13 @@ include "mlir/Pass/PassBase.td" def AllocateSharedMemory : Pass<"allocate-shared-memory", "mlir::ModuleOp"> { let summary = "Add metadata for shared memory allocation"; + let description = [{ + This pass uses the `ModuleAllocation` analysis to: + - Annotate modules with an attribute with the amount of shared/local + memory used. + - Annotate operations with an offset into the total shared/local memory. + }]; + let constructor = "mlir::triton::gpu::createAllocateSharedMemoryPass()"; } diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index d1494fd7e..4ea6bd150 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -80,8 +80,13 @@ void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); +void populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit); + void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit); void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, @@ -91,6 +96,7 @@ void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, + const TargetInfoBase &targetInfo, PatternBenefit benefit); void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, @@ -98,6 +104,10 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, PatternBenefit benefit); +void populateRegReallocOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + } // namespace triton } // namespace mlir diff --git a/include/triton/Conversion/TritonGPUToLLVM/Patterns.h b/include/triton/Conversion/TritonGPUToLLVM/Patterns.h index 934501ad3..ac13ecc28 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Patterns.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Patterns.h @@ -20,8 +20,7 @@ void decomposeSplatOpToSharedLayoutConversion(ModuleOp module); /// Replaces `mma/mfma -> dot_op` with `mma/mfma -> blocked -> dot_op` in the /// given |module| op, but bypass the decomposition if |shortcutFn| returns /// true. -using ShortcutFn = std::function; -template +using ShortcutFn = std::function; void decomposeTensorCoreToDotLayoutConversion(ModuleOp module, ShortcutFn shortcutFn); diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index d03f6b862..68f430d05 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -10,55 +10,79 @@ class TargetInfoBase { virtual Value getClusterCTAId(RewriterBase &rewriter, Location loc) const = 0; - virtual Value ballot(ConversionPatternRewriter &rewriter, Location loc, - Type type, Value cmp) const = 0; - - virtual void storeShared(ConversionPatternRewriter &rewriter, Location loc, - Value ptr, Value val, Value pred) const = 0; - virtual Value loadShared(ConversionPatternRewriter &rewriter, Location loc, - const TypeConverter *converter, Value ptr, - Type elemTy, Value pred) const = 0; - - virtual Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const = 0; - virtual Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const = 0; - virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const = 0; - virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, Value i) const = 0; - - virtual Value programId(ConversionPatternRewriter &rewriter, Location loc, + virtual Value ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const = 0; + + // Store/load a value from shared memory, either in the same CTA or, if + // `ctaId` is non-nullopt, in another CTA in the same group. + // + // A target that does not support cross-CTA transfers will assert if ctaId is + // non-nullopt. + // + // Assumes the address is aligned to the width of `val`. + virtual void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const = 0; + virtual Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, + Value pred) const = 0; + + void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred) const { + storeDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, val, pred); + } + Value loadShared(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, + Value pred) const { + return loadDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, elemTy, + pred); + } + + virtual void storeMatrixShared(RewriterBase &rewriter, Location loc, + Value ptr, Value val) const = 0; + + virtual Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const = 0; + + virtual Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, int axis) const = 0; - virtual bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, + virtual bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce) const = 0; - - virtual bool processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, - SmallVector &vals, RankedTensorType srcTy, Type elemTy, - ArrayRef paddedRepShape, ArrayRef origRepShape, - ArrayRef outOrd, unsigned accumNumReplicates, - int swizzleByteWidth = 0) const = 0; + unsigned numLaneToReduce, + unsigned interleave) const = 0; virtual std::string getMulhiFuncName(Type resultElementTy) const = 0; // Emits LLVM code with |rewriter| to print a message following the given // format from the device. |formatStrStart| is the pointer to the start of // the format string global variable; |args| are the arguments to fill // placeholders in the format string. - virtual void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + virtual void printf(RewriterBase &rewriter, Value formatStrStart, int formatStrByteCount, ValueRange args) const = 0; + + // Emits LLVM code with |rewriter| to print a message, particularly useful for + // backend debug. |msg| is the message to print, |args| are the arguments to + // fill placeholders in the |msg|. + // NOTE: This function is used for backend debug. DO NOT DELETE. + // Example use: targetInfo.printf(rewriter,"index: %d, value: %f", {index, + // value}); + virtual void printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const = 0; + // Emits LLVM code with |rewriter| to perform assertion failure with the given // |message| from the given |func| in |file|. - virtual void assertFail(ConversionPatternRewriter &rewriter, Location loc, + virtual void assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const = 0; - // Whether to enable linear layout. This is a per-backend temporary escape - // hatch to disable linear layout while figuring out issues. Eventually we - // want to enable linear layout everywhere and delete this control. - virtual bool enableLinearLayout() const { return true; } + virtual int getSharedAddressSpace() const = 0; + + virtual bool supportVectorizedAtomics() const = 0; virtual ~TargetInfoBase() {} }; diff --git a/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h b/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h index ab9d0ebf8..5ae547c39 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h @@ -4,6 +4,7 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" #include "triton/Dialect/TritonGPU/IR/Types.h" using namespace mlir; @@ -14,12 +15,14 @@ class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter { using TypeConverter::convertType; TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, + const TargetInfoBase &targetInfo, const DataLayoutAnalysis *analysis = nullptr); Type getElementTypeForStruct(TensorOrMemDesc type); Type convertTritonPointerType(triton::PointerType type); - Type convertTritonTensorType(RankedTensorType type); - Type convertMemDescType(MemDescType type); + Type convertTritonTensorType(RankedTensorType type, + const TargetInfoBase &targetInfo); + Type convertMemDescType(MemDescType type, const TargetInfoBase &targetInfo); Type convertAsyncToken(triton::gpu::AsyncTokenType type); }; diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 382f60254..111e4f4b8 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -6,12 +6,14 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Tools/LinearLayout.h" #include "triton/Tools/StrUtil.h" @@ -101,7 +103,7 @@ using namespace mlir::triton; #define barrier() rewriter.create(loc) #define undef(...) rewriter.create(loc, __VA_ARGS__) #define null(...) rewriter.create(loc, __VA_ARGS__) -#define call(...) rewriter.create(loc, __VA_ARGS__) +#define call(...) LLVM::createLLVMCallOp(rewriter, loc, __VA_ARGS__) // Types #define int_ty(width) rewriter.getIntegerType(width) @@ -124,13 +126,18 @@ using namespace mlir::triton; #define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count) // Constants +#define int_val(bitwidth, val) \ + LLVM::createLLVMIntegerConstant(rewriter, loc, bitwidth, val) +#define i1_val(val) LLVM::createConstantI1(loc, rewriter, val) +#define true_val() i1_val(true) +#define false_val() i1_val(false) #define f16_val(...) LLVM::createConstantF16(loc, rewriter, __VA_ARGS__) #define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__) #define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__) +#define i8_val(val) int_val(8, val) +#define i16_val(val) int_val(16, val) #define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__) #define i64_val(...) LLVM::createConstantI64(loc, rewriter, __VA_ARGS__) -#define int_val(width, val) \ - LLVM::createLLVMIntegerConstant(rewriter, loc, width, val) #define tid_val() getThreadId(rewriter, loc) // Attributes @@ -141,6 +148,23 @@ using namespace mlir::triton; namespace mlir { namespace triton { +static inline void insertBarrier(OpBuilder &builder, Operation *op) { + auto barrierOp = builder.create(op->getLoc()); + auto asyncTaskIds = getAsyncTaskIds(op); + assert(asyncTaskIds.size() <= 1); + if (asyncTaskIds.size() == 1) { + int asyncTaskId = asyncTaskIds[0]; + int barId = asyncTaskId + nameBarrierIdBegin; + assert(barId < nameBarrierIdEnd); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + int numThreads = numWarps * warpSize; + barrierOp->setAttr("bar_id", builder.getI64IntegerAttr(barId)); + barrierOp->setAttr("num_threads", builder.getI64IntegerAttr(numThreads)); + } +} + // Delinearize supposing order is [0, 1, .. , n] template llvm::SmallVector getMultiDimIndexImpl(T linearIndex, @@ -202,9 +226,9 @@ T getLinearIndex(llvm::ArrayRef multiDimIndex, llvm::ArrayRef shape, namespace gpu { Type getFunctionType(Type resultType, ValueRange operands); -LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter, - Operation *op, StringRef funcName, - Type funcType, StringRef libname = "", +LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op, + StringRef funcName, Type funcType, + StringRef libname = "", StringRef libpath = ""); } // namespace gpu @@ -213,31 +237,27 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter, namespace LLVM { using namespace mlir::triton; +Value createConstantI1(Location loc, OpBuilder &rewriter, bool v); Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v); - -/// Create a 64-bit integer constant. Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v); - -/// Create a 16-bit float constant. Value createConstantF16(Location loc, OpBuilder &rewriter, float v); - -/// Create a 32-bit float constant. Value createConstantF32(Location loc, OpBuilder &rewriter, float v); - -/// Create a 64-bit float constant. Value createConstantF64(Location loc, OpBuilder &rewriter, double v); - -/// Create NaN constant of specified type. Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type); - -/// Create an index type constant. Value createIndexConstant(OpBuilder &builder, Location loc, const TypeConverter *converter, int64_t value); - -/// Create an integer constant of \param width bits. Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, int64_t value); +LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc, + LLVMFuncOp funcOp, ValueRange args); +LLVM::CallIntrinsicOp +createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic, + TypeRange types, ValueRange args); + +// Is v an integer or floating-point scalar constant equal to 0? +bool isConstantZero(Value v); + /// Helper function to get strides from a given shape and its order SmallVector getStridesFromShapeAndOrder(ArrayRef shape, ArrayRef order, @@ -305,7 +325,7 @@ struct SharedMemoryObject { } Value getBaseBeforeSlice(int order, Location loc, - ConversionPatternRewriter &rewriter) const { + RewriterBase &rewriter) const { Value cSwizzleOffset = getCSwizzleOffset(order); Value offset = sub(i32_val(0), cSwizzleOffset); Type type = base.getType(); @@ -313,9 +333,10 @@ struct SharedMemoryObject { } }; -SharedMemoryObject -getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, - ConversionPatternRewriter &rewriter); +SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, + Value llvmStruct, + Type elemTy, + RewriterBase &rewriter); // Convert an \param index to a multi-dim coordinate given \param shape and // \param order. @@ -329,15 +350,14 @@ SmallVector delinearize(RewriterBase &rewriter, Location loc, SmallVector delinearize(RewriterBase &rewriter, Location loc, Value linear, ArrayRef shape); -Value linearize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDim, ArrayRef shape, - ArrayRef order); +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order); -Value linearize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDim, ArrayRef shape); +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape); -Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, - StringRef key, StringRef content); +Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, + StringRef content); // Given an elemId which represents the index of an element from the list of // elements that are in the thread's registers (i.e. total of @@ -346,7 +366,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, // when converting distributed to distributed layout. Also, a replica is the // smallest CTA tile that is common between input and output layouts. SmallVector getMultiDimOffset(Attribute layout, Location loc, - ConversionPatternRewriter &rewriter, + RewriterBase &rewriter, const TargetInfoBase &targetInfo, unsigned elemId, RankedTensorType type, ArrayRef multiDimCTAInRepId, @@ -355,33 +375,30 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, // Given a multiDimOffset, this function wraps around each dimension to be // within shape. SmallVector getWrappedMultiDimOffset( - ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDimOffset, ArrayRef shape, - SmallVector shapePerCTATile, SmallVector shapePerCTA); + RewriterBase &rewriter, Location loc, ArrayRef multiDimOffset, + ArrayRef shape, SmallVector shapePerCTATile, + SmallVector shapePerCTA); inline bool isKernel(FunctionOpInterface funcOp) { return funcOp.getVisibility() == SymbolTable::Visibility::Public; } -inline Value getStackPointer(PatternRewriter &rewriter, +inline Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp) { + if (!isKernel(funcOp)) { + return funcOp.getArgument(funcOp.getNumArguments() - 1); + } + auto mod = funcOp->getParentOfType(); - LLVM::GlobalOp globalBase = nullptr; - mod.walk([&](LLVM::GlobalOp op) { - if (op.getSymName() == "global_smem") - globalBase = op; - }); + auto globalBase = dyn_cast(mod.lookupSymbol("global_smem")); assert(globalBase); - if (isKernel(funcOp)) - return rewriter.create(funcOp.getLoc(), globalBase); - else - return funcOp.getArgument(funcOp.getNumArguments() - 1); + return rewriter.create(funcOp.getLoc(), globalBase); } -inline Value getSharedMemoryBase(Location loc, - ConversionPatternRewriter &rewriter, - Operation *op) { - auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); +inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Operation *op) { + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), + target.getSharedAddressSpace()); FunctionOpInterface func = op->template getParentOfType(); assert(op->hasAttr("allocation.offset")); @@ -396,19 +413,12 @@ inline Value getSharedMemoryBase(Location loc, /* ------------------------------------ */ // Returns CTA level thread idx -inline Value getThreadIdInCTA(RewriterBase &rewriter, Location loc) { +inline Value getThreadId(RewriterBase &rewriter, Location loc) { Value tid = rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x); return rewriter.create(loc, i32_ty, tid); } -// Returns CTA level thread idx. -inline Value getThreadId(RewriterBase &rewriter, Location loc) { - Value tid = getThreadIdInCTA(rewriter, loc); - auto mod = rewriter.getBlock()->getParent()->getParentOfType(); - return tid; -} - // ----------------------------------------------------------------------- // Shared memory utilities // ----------------------------------------------------------------------- @@ -917,10 +927,12 @@ inline void emitWmmaOffsetForCTA(const AMDWmmaEncodingAttr &wmmaLayout, auto rank = shapePerCta.size(); assert(rank == 2 || rank == 3); SmallVector elemOffset(rank, 0); + auto elemStride = wmmaLayout.getVersion() == 1 ? 2 : 1; if (rank == 3) elemOffset[0] = ctaBatchOffset; for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) { - elemOffset[rank - 2] = ctaOffsetX * shapePerCta[rank - 2] + 2 * elem; + elemOffset[rank - 2] = + ctaOffsetX * shapePerCta[rank - 2] + elemStride * elem; elemOffset[rank - 1] = ctaOffsetY * shapePerCta[rank - 1]; offsets.push_back(elemOffset); } @@ -937,7 +949,7 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter, SmallVector warpsPerCTA; for (unsigned i = 0; i < rank; ++i) warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); - auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); + auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerInstr(); Value threadId = getThreadId(rewriter, loc); Value warpSize = i32_val(triton::gpu::getWarpSize(wmmaLayout)); @@ -966,8 +978,17 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter, SmallVector multiDimBase(rank); - multiDimBase[rank - 2] = - add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0); + auto ver = wmmaLayout.getVersion(); + if (ver == 1) { + multiDimBase[rank - 2] = + add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0); + } else { + assert(ver == 2); + multiDimBase[rank - 2] = + add(mul(udiv(threadIdPerWarp, i32_val(mnkDim[2])), + i32_val(wmmaLayout.getSizePerThread()[rank - 2])), + offWarp0); + } multiDimBase[rank - 1] = add(laneId, offWarp1); // TODO: It is assumed when rank = 3, warpsPerCTA is set to @@ -991,7 +1012,7 @@ emitOffsetForWmmaLayout(const AMDWmmaEncodingAttr &wmmaLayout, assert(rank == 2 || rank == 3); SmallVector numWarpsPerDim(rank, 1); - auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); + auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerInstr(); SmallVector shapePerWarp(rank, 1); shapePerWarp[rank - 2] = mnkDim[0]; shapePerWarp[rank - 1] = mnkDim[1]; @@ -1017,8 +1038,8 @@ emitOffsetForWmmaLayout(const AMDWmmaEncodingAttr &wmmaLayout, return offsets; } -inline SmallVector> -emitOffsetForLayout(Attribute layout, RankedTensorType type); +SmallVector> emitOffsetForLayout(Attribute layout, + RankedTensorType type); inline SmallVector> emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout, @@ -1170,78 +1191,36 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, return idx; } -inline SmallVector> -emitOffsetForLayout(Attribute layout, RankedTensorType type) { - if (auto blockedLayout = dyn_cast(layout)) - return emitOffsetForBlockedLayout(blockedLayout, type); - if (auto mmaLayout = dyn_cast(layout)) { - if (mmaLayout.isVolta()) - return emitOffsetForMmaLayoutV1(mmaLayout, type); - if (mmaLayout.isAmpere()) - return emitOffsetForMmaLayoutV2(mmaLayout, type); - if (mmaLayout.isHopper()) - return emitOffsetForMmaLayoutV3(mmaLayout, type); - } - if (auto mfmaLayout = mlir::dyn_cast(layout)) { - return emitOffsetForMfmaLayout(mfmaLayout, type); - } - if (auto wmmaLayout = mlir::dyn_cast(layout)) { - return emitOffsetForWmmaLayout(wmmaLayout, type); - } - if (auto sliceLayout = mlir::dyn_cast(layout)) - return emitOffsetForSliceLayout(sliceLayout, type); - llvm_unreachable("unsupported emitOffsetForLayout"); -} - -// Eventually this will become the only emitIndices function. -std::optional>> -emitIndicesUsingLinearLayouts(Location loc, RewriterBase &rewriter, - const TargetInfoBase &target, Attribute layout, - RankedTensorType type, bool withCTAOffset); - // Emit indices calculation within each ConversionPattern, and returns a // [elemsPerThread X rank] index matrix. -inline SmallVector> +SmallVector> emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, - Attribute layout, RankedTensorType type, bool withCTAOffset, - bool allowLL = true) { - // Eventually the LinearLayout path will be the only one. For now we allow - // both paths so we can test that they produce the same results. - if (allowLL && target.enableLinearLayout()) { - std::optional>> llOffsets = - emitIndicesUsingLinearLayouts(loc, rewriter, target, layout, type, - withCTAOffset); - if (llOffsets.has_value()) - return *llOffsets; - } + Attribute layout, RankedTensorType type, bool withCTAOffset); - // step 1, delinearize threadId to get the base index - auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, target, layout, - type, withCTAOffset); - // step 2, get offset of each element - auto offset = emitOffsetForLayout(layout, type); - // step 3, add offset to base, and reorder the sequence - // of indices to guarantee that elems in the same - // sizePerThread are adjacent in order - auto shape = type.getShape(); - unsigned rank = shape.size(); - unsigned elemsPerThread = offset.size(); - SmallVector> multiDimIdx(elemsPerThread, - SmallVector(rank)); - for (unsigned n = 0; n < elemsPerThread; ++n) - for (unsigned k = 0; k < rank; ++k) - multiDimIdx[n][k] = add(multiDimBase[k], i32_val(offset[n][k])); - - return multiDimIdx; -} +// Emits IR to load data from shared memory into registers, or to store data +// from registers into shared memory. +// +// You supply perVectorCallback, which is called once per group of register +// elements to transfer. You can use this callback to emit IR to load or store +// data from or to shared memory. +// +// elemLlvmTy should be dstTy's element type converted to an LLVM-dialect type. +// +// If maxVecElems is provided, we won't vectorize more than this many elements. +// +// Returns true on success. +[[nodiscard]] bool emitTransferBetweenRegistersAndShared( + RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy, + std::optional maxVecElems, Value shmemBase, + ArrayRef shmemStrides, Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, + std::function perVectorCallback); -/* ---------------- */ -/* ---------------- */ inline DenseMap getSwizzledSharedPtrs( Location loc, const TargetInfoBase &target, unsigned inVec, RankedTensorType srcTy, triton::gpu::SharedEncodingAttr resSharedLayout, Type resElemTy, SharedMemoryObject smemObj, RewriterBase &rewriter, - SmallVectorImpl &offsetVals, SmallVectorImpl &srcStrides) { + ArrayRef offsetVals, ArrayRef srcStrides) { // This utility computes the pointers for accessing the provided swizzled // shared memory layout `resSharedLayout`. More specifically, it computes, // for all indices (row, col) of `srcEncoding` such that idx % inVec = 0, @@ -1268,7 +1247,7 @@ inline DenseMap getSwizzledSharedPtrs( // then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y // This means that we can use some immediate offsets for shared memory // operations. - auto dstPtrTy = ptr_ty(rewriter.getContext(), 3); + auto dstPtrTy = smemObj.base.getType(); auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides); Value dstPtrBase = gep(dstPtrTy, resElemTy, smemObj.base, dstOffset); @@ -1289,7 +1268,7 @@ inline DenseMap getSwizzledSharedPtrs( // Tensor indices held by the current thread, as LLVM values auto srcIndices = emitIndices(loc, rewriter, target, srcEncoding, srcTy, /*withCTAOffset=*/false); - // Swizzling with leading offsets (e.g. Hopper GMMA) + // Swizzling with leading offsets (e.g. Hopper WGMMA) unsigned swizzlingByteWidth = 0; if (resSharedLayout.getHasLeadingOffset()) { if (perPhase == 4 && maxPhase == 2) @@ -1347,9 +1326,8 @@ inline DenseMap getSwizzledSharedPtrs( idxCol = urem(idxCol, numElemsPerSwizzlingRowVal); strideRow = numElemsPerSwizzlingRowVal; } - if (auto add = dyn_cast_or_null(idxCol.getDefiningOp())) { - if (auto _cst = dyn_cast_or_null( - add.getRhs().getDefiningOp())) { + if (auto add = idxCol.getDefiningOp()) { + if (auto _cst = add.getRhs().getDefiningOp()) { unsigned cst = cast(_cst.getValue()).getValue().getSExtValue(); unsigned key = cst % (outVec * maxPhase); @@ -1358,9 +1336,8 @@ inline DenseMap getSwizzledSharedPtrs( immedateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase); } } - if (auto add = dyn_cast_or_null(idxRow.getDefiningOp())) { - if (auto _cst = dyn_cast_or_null( - add.getRhs().getDefiningOp())) { + if (auto add = idxRow.getDefiningOp()) { + if (auto _cst = add.getRhs().getDefiningOp()) { unsigned cst = mlir::cast(_cst.getValue()).getValue().getSExtValue(); unsigned key = cst % (perPhase * maxPhase); @@ -1381,6 +1358,7 @@ inline DenseMap getSwizzledSharedPtrs( colOffOrdered = udiv(colOffOrdered, i32_val(minVec)); colOffOrdered = mul(colOffOrdered, i32_val(minVec)); Value colOff = add(colOffSwizzled, colOffOrdered); + // compute non-immediate offset if (outOrder.size() == 3) offset = add(offset, mul(idx[outOrder[2]], srcStrides[outOrder[2]])); @@ -1400,121 +1378,21 @@ inline DenseMap getSwizzledSharedPtrs( return ret; } -inline SmallVector loadSharedToDistributed( - Value dst, Value src, SharedMemoryObject smemObj, Type elemTy, Location loc, - ConversionPatternRewriter &rewriter, const TargetInfoBase &target) { - auto dstTy = cast(dst.getType()); - auto dstShape = dstTy.getShape(); - assert(dstShape.size() <= 2 && "Unexpected rank of loadSharedToDistributed"); - auto srcTy = cast(src.getType()); - auto dstDistributedLayout = dstTy.getEncoding(); - if (auto mmaLayout = dyn_cast(dstDistributedLayout)) { - assert((!mmaLayout.isVolta()) && - "ConvertLayout Shared->MMAv1 is not supported yet"); - } - auto srcSharedLayout = - cast(srcTy.getEncoding()); - auto srcElemTy = srcTy.getElementType(); - auto dstElemTy = dstTy.getElementType(); - LDBG("loadSharedToDistributed elemTy " << elemTy << " srcElemTy " << srcElemTy - << " dstElemTy " << dstElemTy); - auto inOrd = triton::gpu::getOrder(srcSharedLayout); - auto outOrd = triton::gpu::getOrder(dstDistributedLayout); - unsigned outVec = inOrd == outOrd - ? triton::gpu::getUniqueContigPerThread( - dstDistributedLayout, dstShape)[outOrd[0]] - : 1; - - // If the shmem layout is not swizzled, we can trivially vectorize loads - // across the whole width of the most-minor dimension of the shape, because - // Triton requires all the dims are powers of 2. - unsigned inVec = srcSharedLayout.getMaxPhase() == 1 - ? srcTy.getShape()[inOrd[0]] - : srcSharedLayout.getVec(); - unsigned minVec = std::min(outVec, inVec); - unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy); - SmallVector offsetVals = {smemObj.strides.size(), i32_val(0)}; - - DenseMap sharedPtrs = - getSwizzledSharedPtrs(loc, target, outVec, dstTy, srcSharedLayout, elemTy, - smemObj, rewriter, offsetVals, smemObj.strides); - assert(outElems % minVec == 0 && "Unexpected number of elements"); - unsigned numVecs = outElems / minVec; - auto wordTy = vec_ty(elemTy, minVec); - SmallVector outVals(outElems); - for (unsigned i = 0; i < numVecs; ++i) { - Value smemAddr = sharedPtrs[i * minVec]; - smemAddr = bitcast(smemAddr, ptr_ty(rewriter.getContext(), 3)); - auto valVec = load(wordTy, smemAddr); - valVec.setAlignment(minVec * elemTy.getIntOrFloatBitWidth() / 8); - for (unsigned v = 0; v < minVec; ++v) { - Value currVal = extract_element(elemTy, valVec, i32_val(v)); - outVals[i * minVec + v] = currVal; - } - } - return outVals; -} - -inline void storeDistributedToShared(Value src, ArrayRef inVals, - ArrayRef dstStrides, Value dst, - Value smemBase, Type elemTy, Location loc, - ConversionPatternRewriter &rewriter, - const TargetInfoBase &target) { - auto srcTy = cast(src.getType()); - auto srcShape = srcTy.getShape(); - auto rank = srcShape.size(); - assert(rank <= 3 && "Unexpected rank of storeDistributedToShared"); - auto dstTy = cast(dst.getType()); - auto srcDistributedLayout = srcTy.getEncoding(); - if (auto mmaLayout = dyn_cast(srcDistributedLayout)) { - assert((!mmaLayout.isVolta()) && - "ConvertLayout MMAv1->Shared is not supported yet"); - } - auto dstSharedLayout = - cast(dstTy.getEncoding()); - auto dstElemTy = dstTy.getElementType(); - auto inOrd = triton::gpu::getOrder(srcDistributedLayout); - auto outOrd = dstSharedLayout.getOrder(); - unsigned inVec = inOrd == outOrd - ? triton::gpu::getUniqueContigPerThread( - srcDistributedLayout, srcShape)[inOrd[0]] - : 1; - // If the shmem layout is not swizzled, we can trivially vectorize stores - // across the whole width of the most-minor dimension of the shape, because - // Triton requires all the dims are powers of 2. - unsigned outVec = dstSharedLayout.getMaxPhase() == 1 - ? dstTy.getShape()[inOrd[0]] - : dstSharedLayout.getVec(); - unsigned minVec = std::min(outVec, inVec); - unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); - auto wordTy = vec_ty(elemTy, minVec); - Value word; - - SmallVector srcStrides(dstStrides); - SmallVector offsetVals(rank, i32_val(0)); - SharedMemoryObject smemObj(smemBase, elemTy, srcStrides, offsetVals); - - DenseMap sharedPtrs = - getSwizzledSharedPtrs(loc, target, inVec, srcTy, dstSharedLayout, elemTy, - smemObj, rewriter, offsetVals, srcStrides); - LDBG("storeDistributedToShared: numElems = " << numElems << " minVec = " - << minVec << " " << wordTy); - for (unsigned i = 0; i < numElems; ++i) { - if (i % minVec == 0) - word = undef(wordTy); - word = insert_element(wordTy, word, inVals[i], i32_val(i % minVec)); - if (i % minVec == minVec - 1) { - Value smemAddr = sharedPtrs[i / minVec * minVec]; - smemAddr = bitcast(smemAddr, ptr_ty(rewriter.getContext(), 3)); - store(word, smemAddr) - .setAlignment(minVec * elemTy.getIntOrFloatBitWidth() / 8); - } - } -} - -inline Value -getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, - ConversionPatternRewriter &rewriter) { +SmallVector loadSharedToDistributed(RankedTensorType dstTy, + MemDescType srcTy, Type elemLlvmTy, + SharedMemoryObject smemObj, + Location loc, RewriterBase &rewriter, + const TargetInfoBase &target); + +void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, + Type elemLlvmTy, ArrayRef srcVals, + Value smemBase, ArrayRef dstStrides, + Location loc, RewriterBase &rewriter, + const TargetInfoBase &target); + +inline Value getStructFromSharedMemoryObject(Location loc, + const SharedMemoryObject &smemObj, + RewriterBase &rewriter) { auto elems = smemObj.getElems(); auto types = smemObj.getTypes(); auto structTy = @@ -1528,9 +1406,8 @@ getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, return llvmStruct; } -inline SmallVector -unpackLLElements(Location loc, Value llvmStruct, - ConversionPatternRewriter &rewriter) { +inline SmallVector unpackLLElements(Location loc, Value llvmStruct, + RewriterBase &rewriter) { assert(bool(llvmStruct) && "can not unpack null values"); if (llvmStruct.getType().isIntOrIndexOrFloat() || isa(llvmStruct.getType()) || @@ -1548,8 +1425,8 @@ unpackLLElements(Location loc, Value llvmStruct, inline Value packLLElements(Location loc, const LLVMTypeConverter *typeConverter, - ValueRange resultVals, - ConversionPatternRewriter &rewriter, Type type) { + ValueRange resultVals, RewriterBase &rewriter, + Type type) { auto structType = dyn_cast(typeConverter->convertType(type)); if (!structType) { @@ -1573,7 +1450,7 @@ inline Value packLLElements(Location loc, if (v.value().getType() != elementTypes[v.index()]) { LDBG("type " << type << " structType " << structType); LDBG("value " << v.value()); - emitError(loc) << "invalid element type in packLLEElements. Expected " + emitError(loc) << "invalid element type in packLLElements. Expected " << elementTypes[v.index()] << " but got " << v.value().getType(); } @@ -1582,6 +1459,33 @@ inline Value packLLElements(Location loc, return llvmStruct; } +inline SmallVector unpackLLVector(Location loc, Value llvmVec, + RewriterBase &rewriter) { + assert(bool(llvmVec) && "cannot unpack null value"); + if (llvmVec.getType().isIntOrIndexOrFloat() || + isa(llvmVec.getType()) || + isa(llvmVec.getType())) + return {llvmVec}; + + SmallVector results; + for (int i = 0; i < cast(llvmVec.getType()).getNumElements(); + i++) { + results.push_back(extract_element(llvmVec, i32_val(i))); + } + return results; +} + +inline Value packLLVector(Location loc, ValueRange vals, + RewriterBase &rewriter) { + assert(vals.size() > 0); + auto vecType = vec_ty(vals[0].getType(), vals.size()); + Value vec = undef(vecType); + for (int i = 0; i < vals.size(); i++) { + vec = insert_element(vec, vals[i], i32_val(i)); + } + return vec; +} + inline bool isLayoutMmaV1(Attribute layout) { bool isMmaV1 = false; if (auto mmaLayout = dyn_cast(layout)) { diff --git a/include/triton/Conversion/TritonToTritonGPU/Passes.td b/include/triton/Conversion/TritonToTritonGPU/Passes.td index 84150fe67..f20c36040 100644 --- a/include/triton/Conversion/TritonToTritonGPU/Passes.td +++ b/include/triton/Conversion/TritonToTritonGPU/Passes.td @@ -6,7 +6,13 @@ include "mlir/Pass/PassBase.td" def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> { let summary = "Convert Triton to TritonGPU"; let description = [{ - + This pass converts the Triton Dialect into the TritonGPU Dialect. + This is a partial conversion that also affects other dialects + (namely `Arith`, `Math`, `SCF` and `CF`). + For these dialects, and many Triton dialect operations the conversions + mainly consists of enhancing the tensor type and the `tt.ptr>` + type with an appropriate layout encoding (these encodings generally + include information on `numWarps`, `threadsPerWarp` and `numCTAs`). }]; let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()"; diff --git a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h index d3da1394e..78917fdfd 100644 --- a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h +++ b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h @@ -3,6 +3,7 @@ #include #include +#include namespace mlir { diff --git a/include/triton/Dialect/Triton/IR/Traits.h b/include/triton/Dialect/Triton/IR/Traits.h index f34a0fd59..7f0e5109e 100644 --- a/include/triton/Dialect/Triton/IR/Traits.h +++ b/include/triton/Dialect/Triton/IR/Traits.h @@ -4,8 +4,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Support/LogicalResult.h" - -#include +#include "triton/Dialect/Triton/IR/Types.h" namespace mlir { namespace OpTrait { @@ -58,6 +57,53 @@ class VerifyTensorLayoutsTrait } }; +// Verify if the op is a dot-like operation. +// A dot-like operation should have three operands. +// The first two operands should share a common dimension, and the result +// should have the dimensions of the two operands that are not shared. +// A dot-like operation can be either 2d or 3d. +// In the 3d case, the first dimension of operands is the batch dimension. +template +class DotLike : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + if (op->getNumOperands() < 3) + return op->emitOpError("expected at least 3 operands"); + auto aTy = cast(op->getOperand(0).getType()); + auto bTy = cast(op->getOperand(1).getType()); + auto cTy = cast(op->getOperand(2).getType()); + auto aShape = aTy.getShape(); + auto bShape = bTy.getShape(); + auto cShape = cTy.getShape(); + // Check if all 3d or all 2d + if (aShape.size() != 2 && aShape.size() != 3) + return op->emitOpError("expected operands to be 2d or 3d"); + if (aShape.size() != bShape.size() || aShape.size() != cShape.size()) + return op->emitOpError("expected all operands to have the same rank"); + // Check if the first two operands share a common dimension + // TODO: enable back with an interface to support scaled dot. + // if (aShape[aShape.size() - 1] != bShape[aShape.size() - 2]) + // return op->emitOpError("expected the last dimension of the first + // operand " + // "to be equal to the second-to-last dimension of + // " "the second operand"); + // Check the batch dimension + if (aShape.size() == 3 && + (aShape[0] != cShape[0] || bShape[0] != cShape[0])) + return op->emitOpError("expected the first dimension of the first " + "operand to be equal to the first dimension of " + "the result"); + // Check the output shape + if (cShape[cShape.size() - 2] != aShape[aShape.size() - 2] || + cShape[cShape.size() - 1] != bShape[aShape.size() - 1]) + return op->emitOpError( + "expected the output shape to be the concatenation of the last " + "dimension of the first operand and the last dimension of the " + "second "); + return success(); + } +}; + template class SameOperandsAndResultEncoding : public TraitBase { diff --git a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td index adfeaff6f..f3159338b 100644 --- a/include/triton/Dialect/Triton/IR/TritonAttrDefs.td +++ b/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -13,6 +13,7 @@ def TT_CacheModifierAttr : I32EnumAttr< I32EnumAttrCase<"WB", 4, "wb">, I32EnumAttrCase<"CS", 5, "cs">, I32EnumAttrCase<"WT", 6, "wt">, + I32EnumAttrCase<"CV", 7, "cv">, ]> { let cppNamespace = "::mlir::triton"; } @@ -118,4 +119,18 @@ def TT_InputPrecisionAttr : I32EnumAttr< let cppNamespace = "::mlir::triton"; } +// Type for F8F6F4 kind of floats. +def TT_F8F6F4TypeAttr : I32EnumAttr< + "F8F6F4Type", "", + [ + I32EnumAttrCase<"E4M3", 0, "e4m3">, + I32EnumAttrCase<"E5M2", 1, "e5m2">, + I32EnumAttrCase<"E2M3", 2, "e2m3">, + I32EnumAttrCase<"E3M2", 3, "e3m2">, + I32EnumAttrCase<"E2M1", 4, "e2m1"> + + ]>{ + let cppNamespace = "::mlir::triton"; +} + #endif diff --git a/include/triton/Dialect/Triton/IR/TritonDialect.td b/include/triton/Dialect/Triton/IR/TritonDialect.td index c917538c7..a91b7951a 100644 --- a/include/triton/Dialect/Triton/IR/TritonDialect.td +++ b/include/triton/Dialect/Triton/IR/TritonDialect.td @@ -28,7 +28,8 @@ def Triton_Dialect : Dialect { "arith::ArithDialect", "math::MathDialect", "scf::SCFDialect", - "cf::ControlFlowDialect" + "cf::ControlFlowDialect", + "ub::UBDialect" ]; let extraClassDeclaration = [{ diff --git a/include/triton/Dialect/Triton/IR/TritonInterfaces.td b/include/triton/Dialect/Triton/IR/TritonInterfaces.td index cfc7d0032..f51cca0bc 100644 --- a/include/triton/Dialect/Triton/IR/TritonInterfaces.td +++ b/include/triton/Dialect/Triton/IR/TritonInterfaces.td @@ -5,6 +5,7 @@ include "mlir/IR/OpBase.td" def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">; def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">; +def DotLike : NativeOpTrait<"DotLike">; def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">; def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">; def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">; diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index a8ab6caa2..283dd9165 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -221,6 +221,8 @@ def TT_AdvanceOp : TT_Op<"advance", let results = (outs TT_TensorPtr:$result); let assemblyFormat = "$ptr `,` `[` $offsets `]` attr-dict `:` type($result)"; + + let hasFolder = 1; } // @@ -458,17 +460,12 @@ def TT_ReshapeOp : TT_Op<"reshape", [Pure, If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason. The compiler is still free to change it for better performance. }]; - let arguments = (ins TT_Tensor:$src, BoolAttr:$allow_reorder, OptionalAttr:$efficient_layout); + let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout); let results = (outs TT_Tensor:$result); - let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)"; let hasCanonicalizeMethod = 1; let hasFolder = 1; let hasVerifier = 1; - let builders = [ - OpBuilder<(ins "Type":$type, "Value":$src, "bool":$allow_reorder), - [{ - build($_builder, $_state, type, src, allow_reorder, /*efficient_layout=*/UnitAttr()); - }]>]; } def TT_BroadcastOp : TT_Op<"broadcast", [Pure, @@ -490,6 +487,7 @@ def TT_BroadcastOp : TT_Op<"broadcast", [Pure, let hasCanonicalizeMethod = 1; let hasFolder = 1; + let hasVerifier = 1; } // cat is not `pure` because it may reorder elements @@ -599,6 +597,12 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> { let assemblyFormat = "$axis attr-dict `:` type($result)"; + let builders = [ + OpBuilder<(ins "int":$axis), [{ + build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis))); + }]> + ]; + let extraClassDeclaration = [{ int32_t getAxisAsInt() { return static_cast(getAxis()); @@ -612,6 +616,11 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> { let results = (outs I32:$result); let assemblyFormat = "$axis attr-dict `:` type($result)"; + let builders = [ + OpBuilder<(ins "int":$axis), [{ + build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis))); + }]> + ]; let extraClassDeclaration = [{ int32_t getAxisAsInt() { @@ -625,6 +634,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> { // def TT_DotOp : TT_Op<"dot", [Pure, DeclareOpInterfaceMethods, + DotLike, TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self">]> { let summary = "dot"; @@ -640,8 +650,8 @@ def TT_DotOp : TT_Op<"dot", [Pure, let arguments = ( ins - TT_TensorOrMemDesc:$a, - TT_TensorOrMemDesc:$b, + TT_FpIntTensor:$a, + TT_FpIntTensor:$b, TT_FpIntTensor:$c, DefaultValuedAttr:$inputPrecision, DefaultValuedAttr:$maxNumImpreciseAcc @@ -658,11 +668,49 @@ def TT_DotOp : TT_Op<"dot", [Pure, let hasVerifier = 1; } + +// +// DotScaled Op +// +def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, + DotLike, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot_scaled"; + + let description = [{ + $d = matrix_multiply(scale($lhs, $lhs_scale), scale($rhs, $rhs_scale)) + $c. + Where scale(x, s) is a function that applies the scale per block following microscaling spec. + }]; + + let arguments = ( + ins + // inputs are integer types as they are packed types and we currently + // don't have a representation for those. + TT_IntTensor:$lhs, + TT_IntTensor:$rhs, + TT_FloatTensor:$c, + TT_IntTensor:$lhs_scale, + Optional:$rhs_scale, + TT_F8F6F4TypeAttr:$lhs_type, + TT_F8F6F4TypeAttr:$rhs_type + ); + + let results = (outs TT_FloatTensor:$d); + + // Not sure why I need to fully specify the optional group, but otherwise it complains when loading the mlir file + let assemblyFormat = [{ + $lhs `,` $lhs_scale `,` $rhs (`,`) : (`,` $rhs_scale^ `,`)? $c `lhs` `=` $lhs_type `rhs` `=` $rhs_type attr-dict + `:` type($lhs) `,` type($lhs_scale) `*` type($rhs) (`,` type($rhs_scale)^)? `->` type($d) + }]; +} + // // Reduce Op // def TT_ReduceOp: TT_Op<"reduce", [Pure, + SameOperandsShape, SameOperandsEncoding, SingleBlock, DeclareOpInterfaceMethods]> { @@ -728,7 +776,8 @@ def TT_ScanReturnOp: TT_Op<"scan.return", def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise, SameOperandsAndResultEncoding, SameVariadicOperandSize, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ConditionallySpeculatable]> { let description = [{ call an external function $symbol implemented in $libpath/$libname with $args @@ -740,6 +789,12 @@ def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise, let results = (outs TT_Type:$result); let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; + + let extraClassDeclaration = [{ + // Interface method for ConditionallySpeculatable. + Speculation::Speculatability getSpeculatability(); + }]; + } // @@ -816,8 +871,14 @@ def TT_HistogramOp : TT_Op<"histogram", [Pure]> { // // Print Op // -def TT_PrintOp : TT_Op<"print", [MemoryEffects<[MemWrite]>]>, - Arguments<(ins StrAttr:$prefix, BoolAttr:$hex, Variadic>:$args)> { +def TT_PrintOp : TT_Op<"print", [SameVariadicOperandSize, MemoryEffects<[MemWrite]>]> { + let arguments = ( + ins + StrAttr:$prefix, + BoolAttr:$hex, + Variadic>:$args, + DenseI32ArrayAttr:$isSigned + ); let summary = "Device-side print, as in CUDA for debugging"; let description = [{ `tt.print` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed. @@ -834,11 +895,11 @@ def TT_PrintOp : TT_Op<"print", [MemoryEffects<[MemWrite]>]>, def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> { let summary = "Device-side assert, as in CUDA for correctness checking"; let description = [{ - `tt.assert` takes a condition tensor, a message string, a file string, a function string, and a line number. + `tt.assert` takes a condition tensor and a message string. If the condition is false, the message is printed, and the program is aborted. }]; - let arguments = (ins TT_Tensor:$condition, StrAttr:$message, StrAttr:$file, StrAttr:$func, I32Attr:$line); - let assemblyFormat = "$condition `,` $message `,` $file `,` $func `,` $line attr-dict `:` type($condition)"; + let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message); + let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)"; } // @@ -849,7 +910,7 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr", SameVariadicOperandSize, TypesMatchWith<"infer pointer type from the result type", "result", "base", - "getPointerType(getElementTypeOfTensorPointerType($_self))">]> { + "getPointerType(getElementTypeOfTensorPointerType($_self), getAddressSpace($_self))">]> { let summary = "Make a tensor pointer type with meta information of the parent tensor and the block specified"; let description = [{ @@ -1123,7 +1184,7 @@ def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [ } def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [ - MemoryEffects<[MemWrite]>]> { + MemoryEffects<[MemRead, MemWrite]>]> { let summary = "store value based on descriptor"; let description = [{ This operation will be lowered to Nvidia TMA store operation on targets supporting it. @@ -1146,4 +1207,54 @@ def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [ }]; } +def TT_ExperimentalTensormapCreateOp: TT_Op< + "experimental_tensormap_create", + [ + MemoryEffects<[MemRead, MemWrite]>, + AttrSizedOperandSegments, + ] +> { + let summary = "Create a new TMA descriptor on device"; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + TT_PtrType:$global_address, + Variadic:$box_dim, + Variadic:$global_dim, + Variadic:$global_stride, + Variadic:$element_stride, + ConfinedAttr]>:$elem_type, + ConfinedAttr]>:$interleave_layout, + ConfinedAttr]>:$swizzle_mode, + ConfinedAttr]>:$fill_mode + ); + let extraClassDeclaration = [{ + int32_t getRank() { + return getBoxDim().size(); + } + }]; + let assemblyFormat = [{ + $desc_ptr `,` $global_address `,` + `[` $box_dim `]` `,` + `[` $global_dim `]` `,` + `[` $global_stride `]` `,` + `[` $element_stride `]` + attr-dict `:` functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + +def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op< + "experimental_tensormap_fenceproxy_acquire", + [MemoryEffects<[MemWrite]>] +> { + let summary = "Acquire fence on a tensormap object"; + let arguments = (ins TT_PtrType:$desc_ptr); + let assemblyFormat = [{ + $desc_ptr attr-dict `:` qualified(type($desc_ptr)) + }]; +} + + #endif // Triton_OPS diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td index fd5af9cc8..4c709cd44 100644 --- a/include/triton/Dialect/Triton/IR/TritonTypes.td +++ b/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -15,7 +15,7 @@ class TritonTypeDef traits = []> } // Floating-point Type -def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; +def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; def TT_FloatTensor : RankedTensorOf<[TT_Float]>; def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>; @@ -25,7 +25,8 @@ def TT_BoolTensor : RankedTensorOf<[I1]>; def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>; // Integer Type -def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">; +def I4 : I<4>; +def TT_Int : AnyTypeOf<[I1, I4, I8, I16, I32, I64], "integer">; def TT_IntTensor : RankedTensorOf<[TT_Int]>; def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>; @@ -106,12 +107,13 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> ArrayRefParameter<"int64_t">:$shape, "Type":$elementType, "Attribute":$encoding, + "Attribute":$memorySpace, "bool":$mutable_memory ); let extraClassDeclaration = [{ MemDescType cloneWith(std::optional> shape, Type elementType) const { - return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding()); + return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory()); } bool hasRank() const { return true; } @@ -120,17 +122,19 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> TypeBuilderWithInferredContext<(ins "llvm::ArrayRef":$shape, "Type":$elementType, - "Attribute":$encoding + "Attribute":$encoding, + "Attribute":$memorySpace ), [{ - return $_get(elementType.getContext(), shape, elementType, encoding, /*mutableMemory=*/false); + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false); }]>, TypeBuilderWithInferredContext<(ins "llvm::ArrayRef":$shape, "Type":$elementType, "Attribute":$encoding, + "Attribute":$memorySpace, "bool":$mutableMemory ), [{ - return $_get(elementType.getContext(), shape, elementType, encoding, mutableMemory); + return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory); }]> ]; let hasCustomAssemblyFormat = 1; diff --git a/include/triton/Dialect/Triton/IR/Types.h b/include/triton/Dialect/Triton/IR/Types.h index bf1967f1b..74fa4ba96 100644 --- a/include/triton/Dialect/Triton/IR/Types.h +++ b/include/triton/Dialect/Triton/IR/Types.h @@ -22,7 +22,9 @@ unsigned getPointeeBitWidth(Type type); Type getPointeeType(Type type); -Type getPointerType(Type type); +Type getPointerType(Type type, int addressSpace = 1); + +int getAddressSpace(Type type); Type getElementTypeOfTensorPointerType(Type type); @@ -32,6 +34,8 @@ Type getI32SameShape(Type type); Type getPointerTypeSameShape(Type type); +Type getPointerTypeToElement(Type type); + } // namespace triton } // namespace mlir diff --git a/include/triton/Dialect/Triton/IR/Utility.h b/include/triton/Dialect/Triton/IR/Utility.h index 0ef597147..1ff63697e 100644 --- a/include/triton/Dialect/Triton/IR/Utility.h +++ b/include/triton/Dialect/Triton/IR/Utility.h @@ -31,7 +31,11 @@ template Int ceil(Int m, Int n) { return (m + n - 1) / n; } /// Get the highest power of 2 divisor of an integer. template T highestPowOf2Divisor(T n) { - if (n == 0) { + // When n is 0 or min, return the highest power of 2. The min case is handled + // separately to avoid underflow when T is a signed integer. Technically + // in that case the correct divisor is -n, but this value is outside the + // range of possible values, so we take the next best alternative. + if (n == 0 || n == std::numeric_limits::min()) { return (static_cast(1) << (sizeof(T) * 8 - 2)); } return (n & (~(n - 1))); diff --git a/include/triton/Dialect/Triton/Transforms/Passes.h b/include/triton/Dialect/Triton/Transforms/Passes.h index fde54fe17..29e88fb6d 100644 --- a/include/triton/Dialect/Triton/Transforms/Passes.h +++ b/include/triton/Dialect/Triton/Transforms/Passes.h @@ -10,6 +10,7 @@ std::unique_ptr createCombineOpsPass(); std::unique_ptr createReorderBroadcastPass(); std::unique_ptr createRewriteTensorPointerPass(); +std::unique_ptr createLoopUnrollPass(); } // namespace triton diff --git a/include/triton/Dialect/Triton/Transforms/Passes.td b/include/triton/Dialect/Triton/Transforms/Passes.td index 4ebff63fa..0433204b5 100644 --- a/include/triton/Dialect/Triton/Transforms/Passes.td +++ b/include/triton/Dialect/Triton/Transforms/Passes.td @@ -6,12 +6,17 @@ include "mlir/Pass/PassBase.td" def TritonCombineOps : Pass { let summary = "combine ops"; let description = [{ - dot(a, b, 0) + c => dot(a, b, c) + This pass aims to optimize the five following patterns: + - `dot(a, b, 0) + c => dot(a, b, c)` - addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1)) + - `addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1))` - select(cond, load(ptrs, broadcast(cond), ???), other) => - load(ptrs, broadcast(cond), other) + - `select(cond, load(ptrs, broadcast(cond), ???), other) => + load(ptrs, broadcast(cond), other)` + + - `broadcast(constant) => reshaped_constant` + - `torch.sum(x[:,:,None].expand(-1,-1,n) * y[None,:,:].expand(m,-1,-1),1) + => dot(x,y,splat(0))` }]; let constructor = "mlir::triton::createCombineOpsPass()"; @@ -22,7 +27,11 @@ def TritonCombineOps : Pass def TritonReorderBroadcast : Pass { let summary = "Moves broadcast and splat after elementwise operations"; let description = [{ - elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) + The purpose of this pass is to transform: + - `elementwise(broadcast(a)) => broadcast(elementwise(a))` + - `elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...))` + In the event of a match, the broadcast (or splat) operation is delayed + and performed after the ElementWise operation. }]; let constructor = "mlir::triton::createReorderBroadcastPass()"; let dependentDialects = ["mlir::triton::TritonDialect"]; @@ -41,4 +50,14 @@ def TritonRewriteTensorPointer : Pass { + let summary = "Loop unroller"; + let description = [{ + The pass unrolls a scf loop with tt.loop_unroll_factor attribute. The attribute specialises how many iterations + the loop should be unrolled. + }]; + let constructor = "mlir::triton::createLoopUnrollPass()"; + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + #endif diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 5ae7848a0..74ea99b58 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -75,9 +75,32 @@ getThreadsPerWarpWithUniqueData(Attribute layout, SmallVector getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape); +// Returns the dimensions of the tensor from minor (fast-varying) to +// major (slow-varying). For blocked, mma, and dotOperand layouts, +// though the elements are in registers, the order refers to memory +// layout of the original tensor in global memory. +// For shared Layout, the order refers to which dimension of the original tensor +// is contiguous in shared memory. +SmallVector getOrder(Attribute layout); + +// Returns the dimensions along which warpId's are distributed. +// warpsPerCTA only tells the warp layout in the CTA, e.g. warpsPerCTA = [2, 4] +// tells there are 2 warps along dim0 and 4 warps along dim1. +// warpOrder tells the specific order when distributing warp IDs. +// E.g. warpOrder = [0, 1] means the warp IDs are distributed as follows +// [warp0 warp2 warp4 warp6] +// [warp1 warp3 warp5 warp7] +// Note that in most cases, getWarpOrder and getOrder return the same results. +// But this is not guaranteed. SmallVector getWarpOrder(Attribute layout); -SmallVector getOrder(Attribute layout); +// Returns the dimensions along which threadId's are distributed. +// Similar to warpOrder, threadOrder is necessary to tell the specific thread +// distribution in the warp. +// Note that, in most cases, getThreadOrder and getOrder return the same +// results. But this is not guaranteed. One exception is mfma.transposed layout, +// in which getOrder returns [1, 0] but getThreadOrder returns [0, 1]. +SmallVector getThreadOrder(Attribute layout); CTALayoutAttr getCTALayout(Attribute layout); @@ -107,8 +130,6 @@ unsigned getNumWarpsPerCTA(Attribute layout); unsigned getNumCTAs(Attribute layout); -bool isaDistributedLayout(Attribute layout); - bool isExpensiveCat(CatOp cat, Attribute targetEncoding); // Return true if a view between the two types cannot be implemented as a no-op. @@ -120,6 +141,17 @@ triton::gpu::BlockedEncodingAttr getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, int numWarps, int threadsPerWarp, int numCTAs); +// Dump information about which threads/registers contain each of the tensor +// elements. +void dumpLayout(RankedTensorType tensorType); + +// Dump the layout from HW point of view and prints what tensor element is held +// by each thread and register. +void dumpHWLayout(RankedTensorType tensorType); + +// Return a string representation of the layout of the tensor. +std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView); + } // namespace gpu } // namespace triton } // namespace mlir diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h index d4f274742..5140a03e7 100644 --- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h +++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -1,6 +1,9 @@ // Conversions from TritonGPU layouts (e.g. BlockedEncodingAttr) to // LinearLayout. +#ifndef TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H +#define TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H + #include #include "triton/Tools/LinearLayout.h" @@ -16,7 +19,9 @@ namespace mlir::triton::gpu { // // - An n-dimensional SharedEncodingAttr has the following input dimensions. // -// "offset": the n'th element in the allocation, within a particular block +// "offset": the n'th element in the allocation, within a particular thread +// block (i.e. within a CTA). The offset is measured in elements, not +// bytes. // "block": blocks in a cluster // // All layouts have the following output dimensions. @@ -28,10 +33,219 @@ namespace mlir::triton::gpu { // You can flatten the input or output dimensions into a single dimension using // LinearLayout::flattenIns/Outs(). // +// elemBitWidth is the bit width of one element in the layout. This is required +// to compute the linear layout for MMAv3 (i.e. Hopper) shared layouts (i.e. +// shared layouts with hasLeadingOffset == true) but is otherwise unused. +// // Returns std::nullopt if the given layout can't be converted to an LL. // TODO(jlebar): Remove the std::optional once all layouts are supported. // -std::optional toLinearLayout(ArrayRef shape, - Attribute layout); +std::optional +toLinearLayout(ArrayRef shape, Attribute layout, + std::optional elemBitWidth = std::nullopt); + +// Given a linear layout where the input dimensions contain a "block" dimension, +// this method sets the "block" dimension to 0 and removes the corresponding +// output dimensions. +// +// Note that this behavior differs from calling +// `LinearLayout::sublayout(inDimNames, outDimNames)` when "block" is not in +// `inDimNames`. The latter does not modify the output sizes. +LinearLayout getLayoutWithinBlock(const LinearLayout &layout); + +// In this function, we construct a linear layout representing the +// -> mapping +// for entire `src` and `dst` tensors. We determine the shape of the +// intermediate shared memory buffer needed for a register-to-register +// conversion using the maximum size accessed in each dimension from `src`'s +// layout and `dst`'s layout. See the getRepShapeForCvt function in +// Allocation.cpp for details. Note that the buffer might be smaller than the +// tensor being converted, so we need multiple "iterations" to move a subregion +// of the `src` tensor to the corresponding subregion of the `dst` tensor. The +// pesudo code of layout conversion is as follows: +// +// for iter in 0..numIterations: +// sync threads +// for vecIdx in [0..numRegisters/storeVec]: +// registers <- get registers used in iter +// offsets <- get offsets using the intermediate linear layout +// store registers[vecIdx * storeVec, (vecIdx + 1) * storeVec)] to shared +// memory +// sync threads +// for vecIdx in [0..numRegisters/loadVec]: +// registers <- get registers used in iter +// offsets <- get offsets using the intermediate linear layout +// load registers[vecIdx * loadVec, (vecIdx + 1) * loadVec)] from shared +// memory +LinearLayout chooseShemLayoutForRegToRegConversion( + MLIRContext *ctx, ArrayRef tensorShape, + ArrayRef repShape, ArrayRef order); +// This function constructs a linear layout that maps +// to . +// The primary goal is to efficiently store 2D tiles of a tensor into shared +// memory using the `stmatrix` instruction, with each thread responsible for +// storing `N` elements. If `stmatrix` cannot be used for the given tensor +// encoding, this function returns `std::nullopt`. +// +// Unlike standard vectorized stores, such as `st.shared.v4 [%offset], +// %vec_reg`, where `%vec_reg` contains four consecutive data elements, the +// `stmatrix` instruction allows `N` registers to point to non-contiguous +// locations within a tensor tile. +// +// For instance, the `stmatrix [%offset], %mat_reg` instruction on NVIDIA GPUs +// enables `%mat_reg` to store `N` elements that do not need to be consecutive. +// However, it is crucial that the address (`%offset`) of each row in a tensor +// tile should be aligned to `N` * `elemBitWidth`. The `%offset` of each thread +// is calculated based on the provided tensor encoding. +// +// Currently, we support only the NVIDIA MMAv3 encoding and the `stmatrix.x4` +// instruction. Each `stmatrix.x4` instruction stores eight 16-bit elements per +// thread, resulting in a total of 8 * 32 = 256 elements per warp, or 16 * 16 +// elements per warp when distributed across four 8x8 tiles. Each thread's +// `%offset` points to an address aligned with 8 * 16 bits, denoting a row in +// the 8x8 tile. The values in `%mat_reg` are non-consecutive elements, +// composed of 4 pairs of consecutive elements. These matrix addresses are +// distributed as follows: +// +// col[0-7] col[8-15] +// row[0-7] lane[0-7] lane[16-23] +// row[8-15] lane[8-15] lane[24-31] +// +// The matrix elements of thread 0 are distributed in the following pattern: +// +// col0 col8 +// row0 reg[0-1] reg[4-5] +// row8 reg[2-3] reg[6-7] +// +// When `swizzleByteSize` is non-zero, the layout is constructed +// differently due to leading dimension offset and swizzling. +// There are two key concepts to understand: +// +// 1. Chunks: The leading dimension (i.e., the column dimension) is divided +// into chunks, where each chunk's size is determined by `swizzleByteSize`. +// 2. Swizzling within tiles: Each tile applies a swizzling pattern to its +// rows to optimize memory access. +// +// - Concept 1: Chunks +// +// In the swizzled layout, the leading dimension is strided by +// `swizzleByteSize`. This introduces the concept of a "chunk", where each chunk +// spans a certain number of columns. +// +// For a tile size of `stmatrix.x4` (16x16 elements), with each element being 16 +// bits (2 bytes), each tile occupies 16 rows and 32 bytes per row (since 16 +// elements * 2 bytes per element = 32 bytes per row). +// +// Given a `swizzleByteSize` of 128 bytes, the number of tiles per chunk can be +// calculated as: +// +// Number of tiles per chunk = swizzleByteSize / (bytes per row) = 128 bytes / +// 32 bytes = 4 tiles +// +// Therefore, each chunk contains 4 tiles horizontally, spanning 64 columns +// (since each tile is 16 columns): +// +// col0-15 col16-31 col32-47 col48-63 +// row0-15 tile0 tile1 tile2 tile3 +// +// For a tensor of size 128x128 elements (#rows x #columns), and each element +// being 16 bits, the tensor can be divided into multiple chunks both +// horizontally and vertically. Chunks are stored in memory in a "column-major" +// order based on chunks, meaning chunk1's address follows chunk0's. +// +// Assuming we have 8 warps, and we assign each warp to process a chunk of 16 +// rows (rows per tile) and 128 columns (the width of two chunks). This results +// in each warp handling one horizontal slice of the tensor. +// +// The overall layout can be visualized as: +// +// |<- 128 * 128 bytes ->|<- 128 * 128 bytes ->| +// columns 0-63 columns 64-127 +// warp0 | rows 0-15 chunk0 chunk8 +// warp1 | rows 16-31 chunk1 chunk9 +// warp2 | rows 32-47 chunk2 chunk10 +// warp3 | rows 48-63 chunk3 chunk11 +// warp4 | rows 64-79 chunk4 chunk12 +// warp5 | rows 80-95 chunk5 chunk13 +// warp6 | rows 96-111 chunk6 chunk14 +// warp7 | rows 112-127 chunk7 chunk15 +// +// - Concept 2: Swizzling within tiles +// +// Within each 16x16 tile, rows are swizzled to optimize memory access patterns. +// This swizzling is similar to what's defined in `TritonGPUAttrDefs.td`. at the +// level of each 16x16 tile rather than the entire tensor. +// +// Key parameters for swizzling: +// +// - `perPhase`: The number of rows over which to apply a XOR operation at +// each phase. +// - `maxPhase`: The total number of phases. +// - `vectorWidth`: The number of elements per vector, which is 8 in this case +// because `stmatrix` stores 8 contiguous elements per thread. +// +// The offset of each element within a tile is calculated using the formula: +// +// offset = row * swizzleByteSize + (vectorWidth * ((row / perPhase) % +// maxPhase)) * elementSize +// +// where `elementSize` is the size of each element in bytes (2 bytes for 16-bit +// elements). +// +// For example, consider the element at index `(row=1, col=0)` in chunk0: +// +// Without swizzling: +// +// offset = row * swizzleByteSize + col * elementSize +// = 1 * 128 bytes + 0 * 2 bytes +// = 128 bytes +// +// With swizzling (assuming `perPhase=1`, `maxPhase=8`, `vectorWidth=8`): +// +// offset = row * swizzleByteSize + (vectorWidth * ((row / perPhase) % +// maxPhase)) * elementSize +// = 1 * 128 bytes + (8 * ((1 / 1) % 8)) * 2 bytes +// = 128 bytes + (8 * (1 % 8)) * 2 bytes +// = 128 bytes + 8 * 2 bytes +// = 128 bytes + 16 bytes +// = 144 bytes +// +// This swizzling ensures that elements are stored in a way that optimizes for +// memory bandwidth and reduces bank conflicts. +// +// - Verification through Linear Layout +// +// We can verify the offsets with the following outputs of the corresponding +// linear layout, where each element is 16 bits (2 bytes): +// +// - register=1 -> offset=1 +// register=2 -> offset=2 +// register=4 -> offset=4 +// register=8 -> offset=16 +// register=16 -> offset=32 +// register=32 -> offset=8192 +// - lane=1 -> offset=72 +// lane=2 -> offset=144 +// lane=4 -> offset=288 +// lane=8 -> offset=512 +// lane=16 -> offset=8 +// - warp=1 -> offset=1024 +// warp=2 -> offset=2048 +// warp=4 -> offset=4096 +// +// For index `(row=1, col=0)`, which corresponds to `reg=0` and `lane=1` in +// `warp=0`, the offset is calculated as 72 * 2 bytes = 144 bytes. The result +// matches our earlier calculation. +// +// TODO(Keren): We should replace tensorTy with a LinearLayout and the element +// bit width of the tensor in the future to support more flexible tensor +// encodings +std::optional +chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, int swizzleByteSize); } // namespace mlir::triton::gpu + +#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index ae23f9d13..c8512fce5 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -525,6 +525,10 @@ For the Threads Per Warp and Values Per Thread level, the linear id distribution InterfaceMethod<"Gets the number of contiguous elements per thread.", "SmallVector", "getContigPerThread">, + InterfaceMethod<"Convert to LinearLayout.", + "std::optional", + "toLinearLayout", + (ins "ArrayRef":$shape)> ]; } @@ -576,6 +580,8 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, SmallVector getSizePerThread() const; SmallVector getShapePerCTATile(ArrayRef tensorShape = ArrayRef()) const; + + std::optional toLinearLayout(ArrayRef shape) const; }]; } @@ -704,7 +710,7 @@ for // starting from the contiguous dimension for (unsigned d = 0; d < rank - 1; ++d) { unsigned i = order[d]; - unsigned threadsPerCTA = std::clamp(remainingThreads, 1, shapePerCTA[i] / sizePerThread[i]); + unsigned threadsPerCTA = std::clamp(remainingThreads, 1, std::max(1, shapePerCTA[i] / sizePerThread[i])); threadsPerWarp[i] = std::clamp(threadsPerCTA, 1, remainingLanes); warpsPerCTA[i] = std::clamp(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps); remainingWarps /= warpsPerCTA[i]; @@ -737,7 +743,7 @@ for // starting from the most strided dimension for (int d = rank - 1; d >= 0; --d) { unsigned i = order[d]; - CTAsPerCGA[i] = std::clamp(remainingCTAs, 1, shape[i] / sizePerThread[i]); + CTAsPerCGA[i] = std::clamp(remainingCTAs, 1, std::max(1, shape[i] / sizePerThread[i])); CTASplitNum[i] = CTAsPerCGA[i]; remainingCTAs /= CTAsPerCGA[i]; } @@ -775,22 +781,24 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { InterfaceMethod<"Return shape per CTA.", "SmallVector", - "getShapePerCTATileForDotOperands", + "getShapePerCTATileForOperand", (ins "ArrayRef":$tensorShape, - "unsigned":$opIdx)>, + "int":$kWidth, + "int":$opIdx)>, InterfaceMethod<"Return total element size per thread for dot operands.", "unsigned", - "getTotalElemsPerThreadForOperands", + "getTotalElemsPerThreadForOperand", (ins "ArrayRef":$tensorShape, "Type":$eltTy, - "unsigned":$kWidth, - "unsigned":$opIdx)>, + "int":$kWidth, + "int":$opIdx)>, InterfaceMethod<"Return size per thread for dot operands.", "SmallVector", - "getSizePerThreadForOperands", - (ins "unsigned":$opIdx)>, + "getSizePerThreadForOperand", + (ins "int":$opIdx, + "int":$kWidth)>, ]; } @@ -806,7 +814,7 @@ It is characterized by the following parameters: - 1.0: gfx908, i.e. MI100 - 2.0: gfx90a: i.e. MI200, MI210, MI250 - 3.0: gfx940, gfx941, gfx942: MI300 -- `warpsPerCTA` indicates the wave layout in the workgroup. +- `warpsPerCTA` indicates the warp layout in the block. - `MDim` and `NDim` indicate the dimension of the output of the mfma instruction. - `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout without going to shared memory. This is used in the case of chained dot (E.g. Flash-Attention kernel). @@ -815,7 +823,7 @@ Example 1: Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and MDim=NDim=32. The data will be distributed between threads as follows: - wave 0 wave 1 + warp 0 warp 1 -----------------/\-------------- -----------------/\-------------- [ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] [ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] @@ -854,7 +862,7 @@ Example 2: Suppose we have a tensor with a shape of [16, 32], warpsPerCTA set to [1, 2] and MDim=NDim=16. The data will be distributed between threads as follows: - wave 0 wave 1 + warp 0 warp 1 -----------------/\------------- ------------------/\--------------- [ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] [ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] @@ -879,13 +887,13 @@ The data will be distributed between threads as follows(note that each element i Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and MDim=NDim=4. The data will be distributed between threads as follows(note that each element is duplicated in 16 threads): -M N -> wave 0 wave 2 +M N -> warp 0 warp 2 | --------------------------/\-------------------------- ------------------------------/\------------------------------ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] - wave 1 wave 3 + warp 1 warp 3 --------------------------/\-------------------------- ------------------------------/\------------------------------ [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] @@ -908,11 +916,11 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129, bool supportReduction() const { return true; } - SmallVector getSizePerThreadForOperands(unsigned opIdx) const; - SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; - unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; - SmallVector getMFMAInstrShapeForOperands(int kWidth, int opIdx) const; - SmallVector getMFMARepForOperands(ArrayRef operandShape, int kWidth, int opIdx) const; + SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; + SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; + unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + SmallVector getInstrShapeForOperand(int kWidth, int opIdx) const; + SmallVector getRepForOperand(ArrayRef operandShape, int kWidth, int opIdx) const; SmallVector getContigPerThread() { auto rank = getWarpsPerCTA().size(); @@ -934,36 +942,77 @@ def AMDWmmaEncodingAttr : DistributedEncoding<"AMDWmmaEncoding", "amd_wmma_encod let mnemonic = "amd_wmma"; let description = [{ -An important limitation of WMMA for layout is a shape for tiles proccessed -by a single wave. It is [16, 16]. -This encoding assumes specific access to matrix elements by threads. +An encoding for tensors that have been produced by WMMA matrix core instructions, +available on AMD Radeon GPUs of RDNA architectures. +- A `version` parameter specifies instruction version to lower in. The data + distribution within one warp is also depends on it. Following architectures are + supported: + - 1: gfx11 + - 2: gfx12 +- A `warpsPerCTA` parameter characterizes data distribution between warps. + An important limitation of WMMA for layout is a shape for tiles proccessed + by a single warp. It is [16, 16]. + This encoding assumes specific access to matrix elements by threads. Example: -Suppose we have a tensor with shape [32, 48], `warpsPerCTA` set to [2, 3]. - - wave 0 [16, 16] wave 1 [16, 16] wave 2 [16, 16] ------------/\---------- -----------/\---------- -----------/\---------- -[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] -[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] -[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] -[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] -... ... ... -[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] -[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] - - wave 3 [16, 16] wave 4 [16, 16] wave 5 [16, 16] ------------/\---------- -----------/\---------- -----------/\---------- -[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] -[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] -[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] -[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] -... ... ... -[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] -[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +Suppose we have a tensor with shape [32, 64], `warpsPerCTA` set to [2, 2]. +Matrix elements represent which lane owns the element. Currently only wave32 mode +is supported. + +// ----------------------------------- version = 1 ----------------------------------- // + +Row | warp 0 warp 2 + |/-------------------^-------------------\ /-------------------^-------------------\ +0 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +1 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +2 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +3 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + | ... ... ... ... +14 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +15 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + + | warp 1 warp 3 +16 |/-------------------^-------------------\ /-------------------^-------------------\ +17 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +18 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +19 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +20 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + | ... ... ... ... +30 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +31 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + +// ----------------------------------- version = 2 ----------------------------------- // + +Row | warp 0 warp 2 + |/--------^---------\ /---------^--------\ +0 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +1 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +.. | ... ... +6 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +7 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +8 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +9 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +.. | ... ... +14 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +15 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] + | + | warp 1 warp 3 + |/--------^---------\ /---------^--------\ +16 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +17 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +.. | ... ... +22 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +23 |[0 1 2 ... 14 15] [0 1 2 ... 14 15] +24 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +25 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +.. | ... ... +30 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] +31 |[16 17 18 ... 30 31] [16 17 18 ... 30 31] }]; let parameters = ( ins + "unsigned": $version, ArrayRefParameter<"unsigned">:$warpsPerCTA__, "CTALayoutAttr":$CTALayout ); @@ -974,17 +1023,21 @@ Suppose we have a tensor with shape [32, 48], `warpsPerCTA` set to [2, 3]. bool supportReduction() const { return true; } - SmallVector getSizePerThreadForOperands(unsigned opIdx) const; - SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; - unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; - SmallVector getWMMAElemsPerInstrForOperands() const; - SmallVector getWMMARepForOperands(ArrayRef operandShape, - Type elemType, int kWidth, int opIdx) const; - static SmallVector getMNKDimPerWMMAInstr(); + SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; + SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; + unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + SmallVector getElemsPerInstrForOperands() const; + SmallVector getRepForOperand(ArrayRef operandShape, + Type elemType, int kWidth, int opIdx) const; + static SmallVector getMNKDimPerInstr(); SmallVector getContigPerThread() { auto rank = getWarpsPerCTA().size(); + assert(rank == 2 || rank == 3); SmallVector contigPerThread(rank, 1); + if (getVersion() == 2) { + contigPerThread[rank - 2] = 8; + } return contigPerThread; }; }]; @@ -1171,8 +1224,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: SmallVector getMMAv1Rep(int opIdx) const; SmallVector getMMAv1ShapePerWarp(int opIdx) const; int getMMAv1Vec(int opIdx) const; - SmallVector getMMAv2Rep(ArrayRef shape, - int bitwidth, int opIdx) const; + SmallVector getMMAv2RepForOperand(ArrayRef shape, + int bitwidth, int kWidth, int opIdx) const; bool supportReduction() const { if (isAmpere() || isHopper()) { @@ -1180,9 +1233,9 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: } return false; }; - SmallVector getSizePerThreadForOperands(unsigned opIdx) const; - SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; - unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; + SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; + unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; SmallVector getContigPerThread() { assert(isVolta() || isAmpere() || isHopper()); @@ -1293,9 +1346,24 @@ elements along the K dim, or they use all elements of the tensor along the K dim let genVerifyDecl = 1; let extraClassDeclaration = extraDistributedDeclaration # [{ SmallVector getContigPerThread() { - return getSizePerThread(); + auto rank = getWarpsPerCTA().size(); + assert(rank == 2 || rank == 3); + SmallVector contigPerThread(rank, 1); + auto kWidth = getKWidth(); + assert(kWidth != 0 && "Do not support kWidth=0"); + if (getOpIdx() == 0) + contigPerThread[rank - 1] = kWidth; + else + contigPerThread[rank - 2] = kWidth; + return contigPerThread; }; }]; } +def TTG_SharedMemorySpace : AttrDef { + let mnemonic = "shared_memory"; + let description = [{ + Attribute to indicate that the memory descriptor points to shared memory. + }]; +} #endif diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h b/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h index 0ee2cfeca..9cf2876d2 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h @@ -1,6 +1,6 @@ #ifndef TRITON_GPU_DIALECT_INTERFACES_H #define TRITON_GPU_DIALECT_INTERFACES_H - +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc" #endif // TRITON_GPU_DIALECT_INTERFACES_H diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 2530009cb..a290cb203 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -84,6 +84,7 @@ def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [ ]> { let summary = "copy data from global memory to local memory asynchronously"; + let hasVerifier = 1; let description = [{ This operation copies data from global memory to local memory asynchronously. This is analogue to tt.load except the data are copied to local memory pointed @@ -142,11 +143,33 @@ def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods:$src); + let arguments = ( + ins + Optional:$src, + OptionalAttr:$alignment + ); + + let builders = [ + OpBuilder<(ins "Type":$result), + [{ build($_builder, $_state, result, Value(), IntegerAttr()); }]>, + OpBuilder<(ins "Type":$result, "Value":$src), + [{ build($_builder, $_state, result, src, IntegerAttr()); }]>, + OpBuilder<(ins "Type":$result, "Value":$src, "int32_t":$alignment), + [{ build($_builder, $_state, result, src, $_builder.getI32IntegerAttr(alignment)); }]> + ]; + let extraClassDeclaration = [{ + bool isSharedMemoryAlloc() { + return getType().getMemorySpace() && + isa(getType().getMemorySpace()); + } + int32_t getAlignmentOrDefault(); + }]; let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}]; let results = (outs TT_MemDescType:$result); + let hasFolder = 1; + let hasVerifier = 1; } // Deallocate shared memory @@ -163,7 +186,7 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree" is printed as "". let assemblyFormat = [{ $src `,` $dst attr-dict `:` type($src) `->` qualified(type($dst)) }]; } +def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods]> { + let summary = "Convert an mxfp tensor to bf16"; + + let hasVerifier = 1; + + let description = [{ + Compute the bf16 encoded in the given mxfp number as per + https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + }]; + let arguments = (ins + TT_Tensor:$src, + TT_Tensor:$scale, + TT_F8F6F4TypeAttr:$fp_type); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result) + }]; +} + #endif diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index fdceb2cfe..be27f141d 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -39,8 +39,19 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> { let summary = "prefetch"; let description = [{ + This pass attempts to prefetch from shared memory the operands (A and B) + of a `tt.dot`, when this operation is located in a loop. Decompose `DotOp` instructions in loops into several finer-grained `DotOp` - that may have their operands constructed at the end of the previous iteration + that may have their operands constructed at the end of the previous + iteration. + Transformations are performed in five different places: + 1. The pass emits a prologue to the loop where the data for the first + loop iteration are prefetched. + 2. The loop arguments are extended with the new prefetched values. + 3. The dotOp parameters is updated with the new args. + 4. The prefetch operations for the next iteration are added to the loop. + 5. The yieldOp is updated by adding the prefetched values for the next + iteration. }]; let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", @@ -84,7 +95,11 @@ def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> { let summary = "coalesce"; let description = [{ - TODO + The pass analyses loads/stores with type `tensor>` or + `tt.ptr>` and replaces the layouts of these operations with + coalesced layouts, i.e. cache friendly access patterns. + Layout conversions are inserted before and after the load/store op + to maintain consistency with the rest of the program. }]; let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; @@ -95,6 +110,11 @@ def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions let summary = "remove superfluous layout conversions"; let description = [{ + The purpose of this pass is to rewrite the `ConvertLayoutOps` to reduce + the number of operations and to prefer favorable layouts like + `BlockedEncodingAttr` layout for "expensive" loads and stores + (good for coalescing) and `NvidiaMmaEncodingAttr` otherwise + (good for tensor ops). }]; let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", @@ -106,7 +126,11 @@ def TritonGPUOptimizeThreadLocality : Pass<"tritongpu-optimize-thread-locality", let summary = "Reduce the cost of synchronization between threads in an SM"; let description = [{ - Today, this optimizes reduction yielded by loop to be thread-local until after the loop completes. + The aim of this pass is to reduce cross-thread communication for reduction + operations, by adjusting the reduction size (or layout) to avoid splitting + the reduction operation between multiple threads. Currently, this pass only + optimizes reduction yielded by loop to be thread-local until + after the loop completes. }]; let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", @@ -145,4 +169,100 @@ def TritonGPUCombineTensorSelectAndIf: Pass<"tritongpu-combine-tensor-select-and "mlir::triton::TritonDialect"]; } +def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init", "mlir::ModuleOp"> { + let summary = "Replace accumulater zero-initialization with the flag indicating first use of the accumulator"; + + let description = "For the dot operations that support accumulator-use flag this pass replaces the zero-initialization " + "of the accumulator with the flag indicating the first use of the accumulator."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUWSTaskPartition : Pass<"tritongpu-warp-spec-task-partition", "mlir::ModuleOp"> { + let summary = "Warp specialization task partition"; + + let description = "This pass computes a warp schedule partition by annoating anchor operations with async task ids"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; + let options = [ + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization"> + ]; +} + +def TritonGPUTaskIdPropagate : Pass<"triton-gpu-taskid-propagate", "mlir::ModuleOp"> { + let summary = "Propagate async_task_id annotations based on dependencies"; + + let description = [{ + This pass propagates the `async_task_id` annotation to the dependencies + of any op that has it set. This has the functional effect of partitioning + the graph into multiple async tasks, based on the initial annotation. + }]; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; + + let options = [ + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization"> + ]; +} + +def TritonGPUWSCodePartition: Pass<"tritongpu-warp-spec-code-partition", "mlir::ModuleOp"> { + let summary = "TritonGPU warp specialization code partition"; + + let description = "This pass generates warp specialized code baed on task id attributes."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; + let options = [ + Option<"numBuffers", "num-buffers", + "int32_t", /*default*/"0", + "number of buffering for producer-consumer">, + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization">, + Option<"regDecProducer", "producer-reg-dec", + "int32_t", /*default*/"40", + "register decrement for producer warp group">, + Option<"regIncConsumer", "consumer-reg-inc", + "int32_t", /*default*/"232", + "register indrement for consumer warp group"> + ]; +} + +def TritonGPUWSDataPartition : Pass<"tritongpu-warp-spec-data-partition", "mlir::ModuleOp"> { + let summary = "Warp specialization data partition"; + + let description = "This pass partitions operations into multiple suboperations which operate on smaller data shapes"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; + let options = [ + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization"> + ]; +} + +def TritonGPUWSLowering : Pass<"tritongpu-warp-spec-lowering", "mlir::ModuleOp"> { + let summary = "Warp specialization lowering"; + + let description = "This pass lowers warp specializtion related operations."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; + let options = [ + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization"> + ]; +} #endif diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h b/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h similarity index 100% rename from lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h rename to include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h diff --git a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h new file mode 100644 index 000000000..88f062a01 --- /dev/null +++ b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h @@ -0,0 +1,35 @@ +#ifndef TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_ +#define TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_ + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include + +namespace mlir { +namespace triton { + +static const char *kNumStagesAttrName = "tt.num_stages"; + +/// Function to mask operations during scheduling. +Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred); + +/// Collect ssa dependencies of `op` in `deps`. if `includeArg` is true, +/// continue looking through loop block arguments. +void addDep(Operation *op, DenseSet &deps, bool includeArg = true, + DenseSet *filter = nullptr); + +/// Add operations from `forOp` into a pipeline schedule with the the given +/// `stage` when filter is true. This will add operation in the original loop +/// order. +void addOps(scf::ForOp forOp, int stage, + std::vector> &schedule, + std::function filter); + +/// Replace all uses of `oldUse` with `val` and propagate the type if needed. +/// This is useful when we need to change a memory descriptor from immutable to +/// mutable. +void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse, + Value val); +} // namespace triton +} // namespace mlir + +#endif // TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_ diff --git a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h new file mode 100644 index 000000000..1dd1fc686 --- /dev/null +++ b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h @@ -0,0 +1,107 @@ +#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ +#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include + +namespace mlir { +namespace triton { + +/// This fill out the pipelining options including schedule and annotations +/// for wait ops. This also does pre-processing by converting some of the +/// loads into async loads so that the IR is ready to be pipelined. +bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages, + mlir::triton::PipeliningOption &options); + +/// Fills out pipelining options for an outer loop pipelining case. This +/// schedules async copies to overlap with the epilogue of a loop. +bool getOuterLoopSchedule(scf::ForOp &forOp, int numStages, + mlir::triton::PipeliningOption &options); + +/// Pipeline the TMA stores in the loop. +bool pipelineTMAStores(scf::ForOp forOp); + +/// This does post-processing on the pipelined loop to try to pipeline wgmma +/// ops. +// TODO: this should be included as part of the pipeline but currently the wgmma +// wait modeling is problematic. +void asyncLaunchDots(scf::ForOp forOp); + +/// Post process the pipelined loop by updating the wait ops with the right +/// number of groups in flight. +void updateWaits(ModuleOp module); + +class CoarseSchedule { +public: + class ClusterList { + std::list orderClusters; + + public: + using iterator = decltype(orderClusters)::iterator; + ClusterList() = default; + iterator begin() { return orderClusters.begin(); } + iterator end() { return orderClusters.end(); } + size_t size() { return orderClusters.size(); } + iterator newAtBack() { + orderClusters.push_back(orderClusters.size()); + return std::prev(orderClusters.end()); + } + iterator newAtFront() { + orderClusters.push_front(-1); + for (auto &clusterId : orderClusters) { + clusterId++; + } + return orderClusters.begin(); + } + iterator newBefore(iterator cluster) { + auto ret = orderClusters.insert(cluster, *cluster); + for (auto &clusterId : llvm::make_range(cluster, orderClusters.end())) { + clusterId++; + } + return ret; + } + }; + + CoarseSchedule(int numStages) : numStages(numStages) {} + int numStages; + ClusterList clusters; + using Cluster = decltype(clusters)::iterator; + + DenseMap> opToStageAndCluster; + + void insert(Operation *op, int stage, Cluster cluster) { + opToStageAndCluster[op] = {stage, cluster}; + } + + bool insertIfAbsent(Operation *op, int stage, Cluster cluster) { + if (opToStageAndCluster.count(op)) + return false; + insert(op, stage, cluster); + return true; + } + + void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster, + bool includeArg); + + void erase(Operation *op) { opToStageAndCluster.erase(op); } + + int count(Operation *op) { return opToStageAndCluster.count(op); } + + std::pair operator[](Operation *op) { + return opToStageAndCluster[op]; + } + + SmallVector> + getOpsInOrder(scf::ForOp forOp); + std::vector> + createFinalSchedule(scf::ForOp forOp); + void dump(); +}; + +} // namespace triton +} // namespace mlir +#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 114c18142..41094258a 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -21,12 +21,16 @@ class SharedEncodingAttr; } } // namespace triton +// Return a tuple of two or three entries representing the shape of the +// instruction used to perform a matrix multiplication operation. +// Version = 1: +// Version = 2: <1, m, n> +// Version = 3: SmallVector mmaVersionToInstrShape(int version, const ArrayRef &shape, - TensorOrMemDesc type, - int numWarps); + Type type, int numWarps); -/// Returns true if the Load uses block pointer. +// Return true if the Load uses block pointer. bool isLoadFromTensorPtr(triton::LoadOp op); // Return an array of indices enumerating the elements of 'arr' in descending @@ -129,11 +133,27 @@ scf::ForOp replaceForOpWithNewSignature( scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, ValueRange newIterOperands); +// Replace WhileOp with a new WhileOp with extra operands. The YieldOp is not +// updated and needs to be updated separately for the loop to be correct. +scf::WhileOp replaceWhileOpWithNewSignature( + RewriterBase &rewriter, scf::WhileOp loop, ValueRange newIterOperands, + TypeRange newResultTypes, + SmallVectorImpl> &replacements); +scf::WhileOp replaceWhileOpWithNewSignature(RewriterBase &rewriter, + scf::WhileOp loop, + ValueRange newIterOperands, + TypeRange newResultTypes); + // Replace IfOp with a new IfOp with extra results operands. The YieldOp is not // updated and needs to be updated separately for the bodies to be correct. scf::IfOp replaceIfOpWithNewSignature( RewriterBase &rewriter, scf::IfOp loop, TypeRange newResultTypes, SmallVectorImpl> &replacements); +scf::IfOp replaceIfOpWithNewSignature(RewriterBase &rewriter, scf::IfOp ifOp, + TypeRange newResultTypes); + +// Append the given |newOperands| to the |forOp|'s yield op. +void appendToForOpYield(scf::ForOp forOp, ArrayRef newOperands); Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, IRMapping &mapping); @@ -172,6 +192,102 @@ bool isPureUnaryInlineAsm(Operation *op); // read the compute capability from the module attributes int getNVIDIAComputeCapability(Operation *module); +// 0 is reserved for default sync. +// TODO: comprehensive mechanism to globally manage namedbarrier. +static int const nameBarrierIdBegin = 1; +static int nameBarrierIdEnd = 16; + +/// Helper functions for async task +typedef int AsyncTaskId; +SmallVector getAsyncTaskIds(Operation *op); +bool hasAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId); +void setAsyncTaskIds(Operation *op, ArrayRef asyncTaskIds); +SmallVector getNestedAsyncTaskIds(Operation *op); +void addAsyncTaskIds(Operation *op, ArrayRef asyncTasks); +void removeAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId); +void removeAsyncTaskIds(Operation *op); + +class OpBuilderWithAsyncTaskIds : public OpBuilder { +public: + OpBuilderWithAsyncTaskIds(MLIRContext *context) : OpBuilder(context) {} + + explicit OpBuilderWithAsyncTaskIds(Operation *op) : OpBuilder(op) { + setAsyncTaskIdsFromOp(op); + } + + void setAsynTaskIdsFromArray(ArrayRef newAsyncTaskIds) { + asyncTaskIds = SmallVector(newAsyncTaskIds.begin(), + newAsyncTaskIds.end()); + } + + void setAsyncTaskIdsFromOp(Operation *op) { + setAsynTaskIdsFromArray(getAsyncTaskIds(op)); + } + + void setAsyncTaskIdsFromValueUsers(Value value) { + SetVector asyncTaskIdSet; + for (Operation *user : value.getUsers()) + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(user)) + asyncTaskIdSet.insert(asyncTaskId); + setAsynTaskIdsFromArray(asyncTaskIdSet.getArrayRef()); + } + + template + OpTy createWithAsyncTaskIds(Args &&...args) { + OpTy op = create(std::forward(args)...); + if (!asyncTaskIds.empty()) + setAsyncTaskIds(op, asyncTaskIds); + return op; + } + +private: + SmallVector asyncTaskIds; +}; + +class PatternRewriterWithAsyncTaskIds { +public: + PatternRewriterWithAsyncTaskIds(PatternRewriter &rewriter, Operation *op) + : rewriter(&rewriter) { + setAsyncTaskIdsFromOp(op); + } + + void setAsynTaskIdsFromArray(ArrayRef newAsyncTaskIds) { + asyncTaskIds = SmallVector(newAsyncTaskIds.begin(), + newAsyncTaskIds.end()); + } + + void setAsyncTaskIdsFromOp(Operation *op) { + setAsynTaskIdsFromArray(getAsyncTaskIds(op)); + } + + void setAsyncTaskIdsFromValueUsers(Value value) { + SetVector asyncTaskIdSet; + for (Operation *user : value.getUsers()) + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(user)) + asyncTaskIdSet.insert(asyncTaskId); + setAsynTaskIdsFromArray(asyncTaskIdSet.getArrayRef()); + } + + template + OpTy create(Location location, Args &&...args) { + OpTy op = rewriter->create(location, std::forward(args)...); + if (!asyncTaskIds.empty()) + setAsyncTaskIds(op, asyncTaskIds); + return op; + } + + template + OpTy replaceOpWithNewOp(Operation *op, Args &&...args) { + auto newOp = + rewriter->replaceOpWithNewOp(op, std::forward(args)...); + return newOp; + } + +private: + PatternRewriter *rewriter; + SmallVector asyncTaskIds; +}; + } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index 486bbf553..3ce1d80de 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -43,6 +43,38 @@ class TTNG_Op traits = []> : !listconcat(traits, [VerifyTensorLayoutsTrait])> { } +def TTNG_MBarrierArriveOp : TTNG_Op<"mbarrier_arrive", [AttrSizedOperandSegments, + MemoryEffects<[MemWrite]>]> { + let summary = "mbarrier arrive"; + + let description = [{ + This operation defining the arriving action for a mbarrier. + txCount: + An optional attribute that set tx-count. This Op will be lowered into + mbarrier.arrive.expect_tx if the optional attribute exist. + trackAsyncOp: + If true, this op will be lowered into cp.async.mbarrier.arrive.noinc. + pred: + Only perform arrive action when pred is true. + remoteCtaId: + if set, perform an remote arrive action. + + Example: + + triton_nvidia_gpu.mbarrier_arrive %0 {trackAsyncOp = false} : !tt.ptr + + }]; + + let arguments = (ins TT_MemDescType:$mbarrier, + Optional:$pred, + Optional:$remoteCtaId, + I1Attr: $trackAsyncOp, + DefaultValuedAttr: $txCount + ); + + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> { let arguments = (ins BoolAttr:$bCluster); @@ -57,6 +89,31 @@ def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> { }]; } +def TTNG_GetCanonicalWarpIdOp : TTNG_Op<"get_canonical_warp_id", [Pure]> { + let description = [{ + Returns the one dimensional warpId when it's used for producing warp uniform values. + }]; + + let results = (outs I32:$result); + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TTNG_NamedBarrierArriveOp : TTNG_Op<"bar_arrive", []> { + let summary = "named barrier arrive"; + + let arguments = (ins I32:$bar, I32: $numThreads); + + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + +def TTNG_NamedBarrierWaitOp : TTNG_Op<"bar_wait", []> { + let summary = "named barrier wait"; + + let arguments = (ins I32:$bar, I32: $numThreads); + + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + def TTNG_ClusterArriveOp : TTNG_Op<"cluster_arrive", []> { let arguments = (ins I1Attr:$relaxed); let assemblyFormat = "attr-dict"; @@ -67,13 +124,14 @@ def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> { } // -// DotAsync Op +// WarpGroupDot Op // -def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - TypesMatchWith<"result's type matches accumulator's type", - "d", "c", "$_self">]> { - let summary = "dot async"; +def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DotLike, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "warp group dot"; let description = [{ $d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp @@ -82,17 +140,23 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [DeclareOpInterfaceMethods:$useC, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc, + DefaultValuedAttr:$isAsync); let results = (outs TT_FpIntTensor:$d); - let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)"; + let assemblyFormat = "$a`,` $b`,` $c (`,` $useC^)? attr-dict `:` type($a) `*` type($b) `->` type($d)"; + + let extraClassDeclaration = [{ + bool needsPartialAccumulator(); + }]; } -def TTNG_DotWaitOp : TTNG_Op<"dot_wait", [DeclareOpInterfaceMethods, - AllTypesMatch<["inputs", "outputs"]>]> { - let summary = "dot wait"; +def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterfaceMethods, + AllTypesMatch<["inputs", "outputs"]>]> { + let summary = "warp group dot wait"; let arguments = (ins Variadic:$inputs, I32Attr:$pendings); let results = (outs Variadic:$outputs); let description = [{ @@ -100,7 +164,7 @@ def TTNG_DotWaitOp : TTNG_Op<"dot_wait", [DeclareOpInterfaceMethods { let assemblyFormat = "attr-dict"; } +def TTNG_GetAsyncTaskIdOp : TTNG_Op<"get_async_task_id", [Pure]> { + let results = (outs I32:$result); + + let builders = [OpBuilder<(ins)>]; + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +// +// Token +// + +def TTNG_CreateTokenOp : TTNG_Op<"create_token"> { + let results = (outs TensorOf<[TTNG_TokenType]>:$result); + + let arguments = (ins I32Attr:$num); + + let builders = [OpBuilder<(ins "uint32_t":$num)>]; + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TTNG_ProducerAcquireOp : TTNG_Op<"producer_acquire"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx, I1:$phase); + + let assemblyFormat = "$token `,` $idx `,` $phase attr-dict `:` type(operands)"; +} + +def TTNG_ProducerCommitOp : TTNG_Op<"producer_commit"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx); + + let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)"; +} + +def TTNG_ConsumerWaitOp : TTNG_Op<"consumer_wait"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx, I1: $phase); + + let assemblyFormat = "$token `,` $idx `,` $phase attr-dict `:` type(operands)"; +} + +def TTNG_ConsumerReleaseOp : TTNG_Op<"consumer_release"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx); + + let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)"; +} + +def TTNG_RegAllocOp : TTNG_Op<"reg_alloc", []> { + let summary = "register allocation"; + + let arguments = (ins I32Attr: $regCount); + + let assemblyFormat = "$regCount attr-dict"; +} + +def TTNG_RegDeallocOp : TTNG_Op<"reg_dealloc", []> { + let summary = "register deallocation"; + + let arguments = (ins I32Attr: $regCount); + + let assemblyFormat = "$regCount attr-dict"; +} #endif diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td index 6fe71ade2..c399bc172 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td @@ -28,7 +28,8 @@ def TritonGPUPlanCTAPass : Pass<"triton-nvidia-gpu-plan-cta", "mlir::ModuleOp"> let summary = "plan CTA"; let description = [{ - Plan CTAs in CGA + This pass computes and applies "optimized" CTA tilings to DotOp, ReduceOp + and StoreLikeOps operations. }]; let constructor = "mlir::createTritonNvidiaGPUPlanCTAPass()"; @@ -43,6 +44,8 @@ def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::M let summary = "Insert fences across generic and async proxy"; let description = [{ + This pass is to insert memory fences to ensure that memory operations are + properly ordered across generic and async operations. }]; let constructor = "mlir::createTritonNvidiaGPUFenceInsertionPass()"; diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index fb2680241..c728cfbb3 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -2,6 +2,7 @@ #define TRITON_TOOLS_LINEARLAYOUT_H #include +#include #include #include #include @@ -22,10 +23,10 @@ namespace mlir::triton { // location" to a "logical tensor index". // // For example, suppose we have a 2D tensor T stored in GPU registers. T's -// layout is the function that, given a "hardware location" tuple of (thread-id, -// warp-id), returns an index (x,y) into T. In other words, if L(t,w) = (x,y) -// is our linear layout func, then a register in thread t in warp w contains the -// value T[x,y]. +// layout (i.e., L) is the function that, given a "hardware location" tuple of +// (thread-id, warp-id), returns an index (x,y) into T. In other words, if +// L(t,w) = (x,y) is our linear layout func, then a register in thread t in warp +// w contains the value T[x,y]. // // The key fact about LLs is, the mapping from (t,w) to (x,y) is not arbitrary. // We only need to specify the value of L(t,w) at certain special points @@ -36,11 +37,11 @@ namespace mlir::triton { // tensor T has shape 4x4. We define the function L by choosing the values of // L(0,1), L(0,2), L(1,0), and L(2,0). Our choices are shown below. // -// t/w 0 1 2 3 -// 0 ? (0,1) (0,2) ? -// L(t,w) = 1 (1,1) ? ? ? -// 2 (2,2) ? ? ? -// 3 ? ? ? ? +// t/w 0 1 2 3 +// 0 ? (0,1) (0,2) ? +// L(t,w) = 1 (1,1) ? ? ? +// 2 (2,2) ? ? ? +// 3 ? ? ? ? // // You only need to specify these four values to define the whole linear layout. // These special values are called the "basis vectors" or "bases" of the layout. @@ -151,15 +152,19 @@ namespace mlir::triton { // ## Dimension order // // An LL's input and output dimensions have an order. This order only affects -// the reshapeIns/Outs operations, where the layout is logically flattened -// according to the dimension order and then chopped up again. +// the reshapeIns/Outs and similar operations, where the layout is logically +// flattened according to the dimension order and then chopped up again. // -// ## Surjectivity +// ## Surjectivity and injectivity // -// We require that all output values are covered by some input value, i.e. the -// function L is surjective. But multiple input values can map to the same -// output value. This represents the idea that the same logical tensor element -// can be stored in multiple places in the hardware. +// Most LLs are surjective, i.e. all output values are covered by some input +// value. But occasionally you might create a non-surjective layout, usually +// via invertAndCompose. We aggressively assert that LLs are surjective unless +// you explicitly create one that's not. +// +// LLs are not, in general, injective. There might exist multiple input values +// that map to the same output value. This represents the idea that the same +// logical tensor elements can be stored in multiple places in the hardware. // // ## Why map hardware loc -> tensor index and not the other way around? // @@ -271,7 +276,7 @@ namespace mlir::triton { // // That's all we need in order to define linear layouts mathematically! // -// # Comaprison to Nvidia CuTe +// # Comparison to Nvidia CuTe // // (Note, I'm not an expert on CuTe; this is my best understanding.) // @@ -285,7 +290,7 @@ namespace mlir::triton { // subsume all of these special cases. The CUTLASS folks say this simplified // CUTLASS, in the same way that we hope LLs will simplify Triton. // -// Like CuTe layouts, LLs are also programmable and composible. But there are +// Like CuTe layouts, LLs are also programmable and composable. But there are // also some differences. // // - Dimensions in LLs are named; CuTe dimensions are numbered. @@ -317,7 +322,8 @@ class LinearLayout { /*size=getInDimSizeLog2(inDim)*/> bases; - llvm::SetVector outDimNames; + llvm::MapVector outDims; + bool surjective; public: using BasesT = decltype(bases); @@ -342,8 +348,40 @@ class LinearLayout { // Creates a LinearLayout from a list of bases. These are interpreted // according to the rules written for the member variable `bases`. + // + // Calculates the out-dim sizes according to the bases. Consider the + // following example. + // + // L(in1=1) = (out1=1, out2=0) + // L(in1=2) = (out1=5, out2=1) + // L(in1=4) = (out1=2, out2=2) + // + // To calculate the out-dim sizes, we first find the largest values for out1 + // and out2, namely 5 and 2, then round these up to the next power of 2, + // namely 8 and 4. These are the out-dim sizes. + // + // Assert-fails if the layout is not surjective given these out-dim sizes. + // That is, every possible out-dim in range [0, size) must be produced by + // xor'ing some combination of bases. explicit LinearLayout(BasesT bases, ArrayRef outDimNames); + // Creates a LinearLayout given a list of bases and the explicit out-dimension + // sizes. Allows the layout to be non-surjective. + // + // To see why we need to explicitly pass out-dim sizes when creating a + // non-surjective layout, consider the following example. + // + // L(in1=1) = 1 + // L(in1=2) = 4 + // + // If we naively infer the out-dim sizes from these bases, we'd infer a size + // of nextPow2(4) = 8. But given that the layout is non-surjective, who is to + // say that the codomain is not (say) [0,32)? We can't tell, thus we need to + // be explicit about the sizes. + explicit LinearLayout(BasesT bases, + ArrayRef> outDims, + bool requireSurjective); + // Construct a LinearLayout from an explicit list of bases. (This constructor // is needed because llvm::MapVector does not have a constructor that accepts // an initializer_list.) @@ -363,9 +401,17 @@ class LinearLayout { // {"in2", {/*L(in2=1)=*/{0,4}, /*L(in2=2)=*/{0,8}, /*L(in2=4)=*/{1,1}}}, // }, // {"out1", "out2"}) + // + // The overload that infers out-dim sizes assert-fails if the layout is not + // surjective. explicit LinearLayout( ArrayRef>>> bases, ArrayRef outDimNames); + explicit LinearLayout( + ArrayRef>>> bases, + ArrayRef> outDims, bool requireSurjective); + + bool isSurjective() const { return surjective; } const BasesT &getBases() const { return bases; } @@ -380,27 +426,22 @@ class LinearLayout { int32_t getBasis(StringAttr inDim, int32_t pos, StringAttr outDim) const { return getBasis(inDim, pos)[getOutDimIndex(outDim)]; - ; } // These are in minor-to-major order, although if you don't flatten the dims // (e.g. by reshaping) then the order doesn't really affect anything. auto getInDimNames() const { return llvm::make_first_range(bases); } - ArrayRef getOutDimNames() const { - return outDimNames.getArrayRef(); - } + auto getOutDimNames() const { return llvm::make_first_range(outDims); } // Gets the position that this outDim occupies in getOutDimNames(). Asserts // if the dim is not present. int32_t getOutDimIndex(StringAttr outDim) const; bool hasInDim(StringAttr inDim) const { return bases.contains(inDim); } - bool hasOutDim(StringAttr outDim) const { - return outDimNames.contains(outDim); - } + bool hasOutDim(StringAttr outDim) const { return outDims.contains(outDim); } int32_t getNumInDims() const { return bases.size(); } - int32_t getNumOutDims() const { return outDimNames.size(); } + int32_t getNumOutDims() const { return outDims.size(); } // Asserts if the dimension is not present. int32_t getInDimSizeLog2(StringAttr inDim) const; @@ -408,8 +449,11 @@ class LinearLayout { return 1 << getInDimSizeLog2(inDim); } + int32_t getTotalInDimSizeLog2() const; + int32_t getTotalInDimSize() const { return 1 << getTotalInDimSizeLog2(); } + // getOutDimSize(dim) == s means that there exists an input value that will - // produce each output value in [0,s). + // produce each output value in [0,s) (if the layout is surjective). // // For example, if our bases are // @@ -428,6 +472,30 @@ class LinearLayout { return 1 << getOutDimSizeLog2(outDim); } + int32_t getTotalOutDimSizeLog2() const; + int32_t getTotalOutDimSize() const { return 1 << getTotalOutDimSizeLog2(); } + + // Finds the number of consecutive input elements in the first input dimension + // that map to consecutive output elements in the first output dimension. + // + // Mathematically, finds the maximum value V such that for any a, b, c, and + // for all v in [0,V), + // + // L(a*V + v, b, c, ...) = L(a*V, b, c, ...) + (v, 0, ..., 0) + // + // Note that's +, not ⊕, in the RHS. (Equivalently, we could use binary-or + // instead of +. In other words, we require that L(a*V, b, c, ...) have no + // bits that overlap with v.) + // + // For example, if L maps (register, lane) to (dim1, dim0), then this tells + // you how many consecutive registers map to consecutive elements of dim1. + // + // This only works across the first (i.e. the most-minor) dimension of in/out. + // If you want it to work across more dimensions, flatten the layout. + // + // TODO(jlebar): Replace with divideLeft. + int32_t getNumConsecutiveInOut() const; + // Reorders the in/out dimensions of the layout. This is mostly cosmetic // (affecting e.g. the order of getIn/OutDimNames), but it also affects the // behavior of reshape. @@ -436,6 +504,30 @@ class LinearLayout { [[nodiscard]] LinearLayout transposeOuts(ArrayRef newOutDimOrder) const; + [[nodiscard]] LinearLayout reshapeIns( + ArrayRef> newInDims) + const; + + // Reshapes to a single input dim (named whatever our first in-dim is named). + [[nodiscard]] LinearLayout flattenIns() const { + if (getNumInDims() == 0) { + return reshapeIns({}); + } + return reshapeIns({{*getInDimNames().begin(), getTotalInDimSize()}}); + } + + [[nodiscard]] LinearLayout + reshapeOuts(ArrayRef> + newOutDims) const; + + // Reshapes to a single out dim (named whatever our first out-dim is named). + [[nodiscard]] LinearLayout flattenOuts() const { + if (getNumOutDims() == 0) { + return reshapeOuts({}); + } + return reshapeOuts({{*getOutDimNames().begin(), getTotalOutDimSize()}}); + } + // Creates a new layout which, roughly speaking, is equivalent to one where // every element of the `outer` layout is replaced by a full instance of the // `inner` layout. @@ -474,12 +566,50 @@ class LinearLayout { // // Requires: Any in/out dimensions which are in both outer and inner appear in // the same relative order. + // + // Postcondition: If both inner and outer are surjective, the result is + // surjective. friend LinearLayout operator*(LinearLayout inner, LinearLayout outer); LinearLayout &operator*=(LinearLayout outer) { *this = *this * outer; return *this; } + // Returns true if this layout acts trivially (as the identity) on the given + // dimensions. This means that it's the identity on those dimensions, and it + // does not map other dimensions onto those or these onto other dimensions. + bool isTrivialOver(ArrayRef dimNames) const; + + // For an endomorphism on dimNames (linear map that maps dimNames to dimNames) + // checks whether it is the identity map on these dimensions (i.e + // LinearLayouts::isTrivialOver) and if so, returns the sublayout of the + // remaining dimensions. + // nb. The isTrivialOver condition is more restrictive than the usual + // "leaves the subspace invariant" condition in maths. + // We can always relax it if we know how to take advantage of a conversion + // layout being block-diagonal in the future. + std::optional quotient(ArrayRef dimNames) const; + + // Gets a layout with only these in/out dimensions. + // + // In other words, gets a layout where the in-dims not mentioned in inDimNames + // are set to 0, and the out-dims not mentioned in outDimNames are omitted. + // + // The output-dim sizes are unchanged. The order of the in/out dims in the + // returned layout matches the order of the original layout, not the order of + // the arguments. + LinearLayout sublayout(ArrayRef inDimNames, + ArrayRef outDimNames) const; + + // Is the sublayout restricted to inDimNames + outDimNames all zeros? + bool sublayoutIsZero(ArrayRef inDimNames, + ArrayRef outDimNames) const; + + // Is the sublayout defined from dimNames to dimNames the identity? + // In particular, is the input and output size in these dimensions + // the same, and are the bases the identity? + bool squareSublayoutIsIdentity(ArrayRef dimNames) const; + // Computes and returns L(x, y, z). // // If you want to apply the layout to mlir Values instead of integers, that @@ -494,19 +624,60 @@ class LinearLayout { // - let `outer` be O(x). // - Then compose(outer) returns the layout (O∘L)(x), aka O(L(x)). // - // Requires: The output dimensions of this layout equal the input dimensions - // of outer (order doesn't matter). + // Requires: + // - The output dimensions of this layout equal the input dimensions of + // outer (order doesn't matter). + // - For each output dim d of this layout, this->getOutDimSize(d) <= + // outer.getInDimSize(d). + // + // Postcondition: The result is surjective iff `this` and `outer` are + // surjective and this->getOutDimSize(d) == outer.getInDimSize(d) for each of + // this->getOutDimNames(). + // [[nodiscard]] LinearLayout compose(const LinearLayout &outer) const; - // TODO(jlebar): Not yet implemented. - // [[nodiscard]] LinearLayout reshapeIns( - // std::vector> - // newInDims) const; - - // TODO(jlebar): Not yet implemented. - // [[nodiscard]] LinearLayout reshapeOuts( - // std::vector> - // newOutDims) const; + // Inverts or pseudo-inverts `outer` and composes it with `this`. + // + // Formally, if C = A.invertAndCompose(B), then for all x, C(x) = y implies + // A(x) = B(y), or in other words A(x) = B(C(x)). If B is invertible, then + // C(x) = B^-1(A(x)), which is how this function gets its name. + // + // For example, suppose you have the following two LLs. + // + // - R is an LL representing registers, mapping (lane, warp) to a 2D index. + // - S is an LL representing shared memory, mapping offset to a 2D index. + // + // Suppose you want to store tensor values from registers into shared memory. + // That is, given a (lane, warp), you want to know the corresponding shared + // memory offset to store into. + // + // This is equivalent to converting a (lane, warp) into a 2D index (i.e. + // applying R), then converting a 2D index into a shmem offset (i.e. applying + // the inverse of S). R.invertAndCompose(S) computes this transformation. + // + // Notice the following requirements in order for this to work. + // + // - R and S must have the same output dimension names (different order is + // allowed). + // - S must be surjective, i.e. there must be some offset for each output + // dimension of S. This way when we compose S^-1 with R, every possible + // 2D index that we might get from R has some shmem offset. + // - The codomain of S must be at least as large as the codomain of R. + // Otherwise, R could map some tensor index that is not stored in S. + // + // One requirement we *don't* have is that S is injective; we allow two shmem + // offsets to hold the same 2D index. If S is not injective, there's + // ambiguity in which offset we choose for a given (lane, warp). For now we + // don't place any guarantees on the choices made by this function. + [[nodiscard]] LinearLayout invertAndCompose(const LinearLayout &outer) const; + + // For each in-dim, returns a bitmask of the "free variables" in the layout + // function. + // + // These are the bits in the input that can be changed without changing the + // output. If all of the free variables are 0, then the layout is injective + // (i.e. every input bit affects the output). + llvm::MapVector getFreeVariableMasks() const; std::string toString() const; @@ -514,6 +685,22 @@ class LinearLayout { friend bool operator!=(LinearLayout lhs, LinearLayout rhs) { return !(lhs == rhs); } + bool equalIgnoringOutDimSizes(const LinearLayout &other) const; + +private: + // Factory function that gracefully fails rather than asserts if the layout is + // not well-formed. + static std::optional + tryCreate(BasesT bases, ArrayRef> outDims, + bool requireSurjective); + + // Constructor that does not check invariants. Used by tryCreate. + struct NoCheckInvariants {}; + LinearLayout(BasesT bases, ArrayRef> outDims, + NoCheckInvariants); + + [[nodiscard]] std::optional + checkInvariants(bool requireSurjective); }; inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 12584aa8f..e5132b6d3 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -13,26 +13,32 @@ namespace mlir::triton { inline const std::set CACHE_INVALIDATING_ENV_VARS = { // clang-format off "AMDGCN_ENABLE_DUMP", + "AMDGCN_USE_BUFFER_OPS", "DISABLE_FAST_REDUCTION", "DISABLE_LLVM_OPT", "DISABLE_MMA_V3", "DISABLE_PTXAS_OPT", "LLVM_IR_ENABLE_DUMP", "LLVM_ENABLE_TIMING", + "LLVM_PASS_PLUGIN_PATH", "MLIR_ENABLE_DIAGNOSTICS", "MLIR_ENABLE_DUMP", "MLIR_ENABLE_TIMING", + "TRITON_DEFAULT_FP_FUSION", "TRITON_DISABLE_LINE_INFO", "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE", "TRITON_ENABLE_LLVM_DEBUG", "TRITON_LLVM_DEBUG_ONLY", - "USE_TTGIR_LOC", + "USE_IR_LOC", "NVPTX_ENABLE_DUMP", // clang-format on }; inline const std::set CACHE_NEUTRAL_ENV_VARS = { + // clang-format off "TRITON_REPRODUCER_PATH", + "TRITON_ENABLE_PYTHON_STACKTRACE" + // clang-format on }; namespace tools { diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp index 5b3910013..3840bf419 100644 --- a/lib/Analysis/Alias.cpp +++ b/lib/Analysis/Alias.cpp @@ -21,13 +21,18 @@ AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) { return ret; } -void SharedMemoryAliasAnalysis::visitOperation( +LogicalResult SharedMemoryAliasAnalysis::visitOperation( Operation *op, ArrayRef *> operands, ArrayRef *> results) { AliasInfo aliasInfo; bool pessimistic = true; - // These ops may allocate a new shared memory buffer. auto result = op->getResult(0); + // skip ops that return memdesc in a different memory space. + if (auto memdescTy = dyn_cast(result.getType())) { + if (!isa_and_nonnull( + memdescTy.getMemorySpace())) + return success(); + } // Only LocalAllocOp creates a new buffer. if (isa(op)) { @@ -44,11 +49,14 @@ void SharedMemoryAliasAnalysis::visitOperation( } if (pessimistic) { - return setAllToEntryStates(results); + setAllToEntryStates(results); + return success(); } // Join all lattice elements for (auto *result : results) propagateIfChanged(result, result->join(aliasInfo)); + + return success(); } AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) { diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index a129cb194..918a15e55 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -10,9 +10,14 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Support/LLVM.h" #include "triton/Analysis/Alias.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" using ::mlir::triton::gpu::AMDMfmaEncodingAttr; using ::mlir::triton::gpu::BlockedEncodingAttr; @@ -27,6 +32,10 @@ using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; using ::mlir::triton::gpu::SharedEncodingAttr; using ::mlir::triton::gpu::SliceEncodingAttr; +#define DEBUG_TYPE "allocation-analysis" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + namespace mlir { //===----------------------------------------------------------------------===// @@ -44,7 +53,8 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) { auto dstMmaLayout = mlir::dyn_cast(dstLayout); auto dstDotLayout = mlir::dyn_cast(dstLayout); - assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) && + assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere() && + !srcMmaLayout.isHopper()) && "mma -> mma layout conversion is only supported on Ampere"); // mma or dot layout does not have an order, so the order depends on the @@ -57,35 +67,21 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) { return {inOrd, outOrd}; } -SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) { - auto srcTy = op.getSrc().getType(); - auto dstTy = op.getType(); +static SmallVector getRepShapeForCvt(RankedTensorType srcTy, + RankedTensorType dstTy) { Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); + if (!cvtNeedsSharedMemory(srcTy, dstTy)) { + return {}; + } + if (shouldUseDistSmem(srcLayout, dstLayout)) { // TODO: padding to avoid bank conflicts return convertType(getShapePerCTA(srcTy)); } - if (isMfmaToDotShortcut(srcTy, dstTy)) - return {}; - - // MmaToDotShortcut and MmaToMmaShortcut doesn't use shared mem - if (auto srcMmaLayout = mlir::dyn_cast(srcLayout)) { - if (mlir::isa(dstLayout)) { - if (isMmaToDotShortcut(srcTy, dstTy)) { - return {}; - } - } else if (auto dstMmaLayout = - mlir::dyn_cast(dstLayout)) { - if (isMmaToMmaShortcut(srcTy, dstTy)) { - return {}; - } - } - } - - assert(srcLayout && dstLayout && "Unexpected layout in getRepShape()"); + assert(srcLayout && dstLayout && "Unexpected layout in getRepShapeForCvt()"); auto srcShapePerCTA = getShapePerCTA(srcTy); auto dstShapePerCTA = getShapePerCTA(dstTy); @@ -102,21 +98,39 @@ SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) { return repShape; } -SmallVector -getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, - unsigned &outVec) { - auto repShape = getRepShapeForCvtLayout(op); +// Both `atomic_cas` and `atomic_rmw need a single scratch element if returning +// a scalar value because Triton's block-based programming model ensures that +// all threads in each block see the same return value, even those threads that +// do not participate in the atomic operation +static SmallVector getRepShapeForAtomic(Value result) { + SmallVector smemShape; + if (atomicNeedsSharedMemory(result)) { + smemShape.push_back(1); + } + return smemShape; +} + +ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, + RankedTensorType dstTy) { + // Initialize vector sizes and stride + auto repShape = getRepShapeForCvt(srcTy, dstTy); if (repShape.empty()) - return repShape; + return ScratchConfig({}, {}); + ScratchConfig scratchConfig(repShape, repShape); auto rank = repShape.size(); - auto srcTy = op.getSrc().getType(); - auto dstTy = op.getType(); Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); assert(!isMfmaToDotShortcut(srcTy, dstTy)); - auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout); + // FIXME This is NOT entirely correct + // This should be getElemOrder, but we don't have such a method + // TODO Implement getElemOrder and make sure it's consistent with + // getContigPerThread + auto inOrd = gpu::getThreadOrder(srcLayout); + auto outOrd = gpu::getThreadOrder(dstLayout); + scratchConfig.order = outOrd; + unsigned srcContigPerThread = getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]]; unsigned dstContigPerThread = @@ -124,52 +138,32 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, // TODO: Fix the legacy issue that ourOrd[0] == 0 always means // that we cannot do vectorization. unsigned innerDim = rank - 1; - inVec = outOrd[0] != innerDim ? 1 - : inOrd[0] != innerDim ? 1 - : srcContigPerThread; - outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread; + scratchConfig.inVec = outOrd[0] != innerDim ? 1 + : inOrd[0] != innerDim ? 1 + : srcContigPerThread; + scratchConfig.outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread; - // For conversions to MmaV1 (Nvidia V100), this inVec is hardcoded in the - // codegen. if (auto mma = mlir::dyn_cast(srcLayout)) { if (mma.getVersionMajor() == 1) { - inVec = srcContigPerThread; + // For conversions to MmaV1 (Nvidia V100), this inVec is hardcoded in the + // codegen. + scratchConfig.inVec = srcContigPerThread; } else if (mlir::isa(dstLayout)) { // when storing from mma layout and loading in blocked layout vectorizing // the load back gives better performance even if there is a // transposition. - outVec = dstContigPerThread; + scratchConfig.outVec = dstContigPerThread; } } - if (rank <= 1) - return repShape; - // pad the last dimension - unsigned paddedDim = rank - 1; - if (auto dstBlockedLayout = mlir::dyn_cast(dstLayout)) { - paddedDim = dstBlockedLayout.getOrder()[0]; - } - unsigned pad = std::max(inVec, outVec); - repShape[paddedDim] += pad; - return repShape; -} - -// TODO: extend beyond scalars -SmallVector getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) { - SmallVector smemShape; - if (isa(op.getPtr().getType())) { - // do nothing or just assert because shared memory is not used in tensor up - // to now - } else { - // need only bytes for scalar - // always vec = 1 and elemsPerThread = 1 for scalar? - smemShape.push_back(1); - } - return smemShape; -} + // No padding is required if the tensor is 1-D, or if all dimensions except + // the first accessed dimension have a size of 1. + if (rank <= 1 || product(repShape) == repShape[outOrd[0]]) + return scratchConfig; -SmallVector getScratchConfigForAtomicCAS(triton::AtomicCASOp op) { - return SmallVector{1}; + auto paddedSize = std::max(scratchConfig.inVec, scratchConfig.outVec); + scratchConfig.paddedRepShape[outOrd[0]] += paddedSize; + return scratchConfig; } class AllocationAnalysis { @@ -199,18 +193,9 @@ class AllocationAnalysis { /// Initializes explicitly defined shared memory values for a given operation. void getExplicitValueSize(Operation *op) { - // Values returned from scf.yield will not be allocated even though they - // have the shared encoding. - // For example: %a = scf.if -> yield - // %a must be allocated elsewhere by other operations. - // FIXME(Keren): extract and insert are always alias for now - if (!maybeSharedAllocationOp(op)) - return; - - // XXX(Keren): Why this hard-coded alignment? - size_t kAlignment = 8; for (Value result : op->getResults()) { - if (auto alloc = result.getDefiningOp()) { + auto alloc = result.getDefiningOp(); + if (alloc && alloc.isSharedMemoryAlloc()) { // Bytes could be a different value once we support padding or other // allocation policies. auto allocType = alloc.getType(); @@ -218,15 +203,20 @@ class AllocationAnalysis { auto bytes = product(shapePerCTA) * allocType.getElementTypeBitWidth() / 8; - // XXX(Keren): magic numbers 256 and 1024 - // benzh@maybe alignment should be passed in. - // Software swizzling calculates phase based on offset, while hardware - // swizzling do that based on physical address. Thus only by setting the - // alignment to 1024 can ensure the correctness.  - if (bytes > 256) - kAlignment = 1024; - allocation->addBuffer(result, bytes, - kAlignment); + auto alignment = alloc.getAlignmentOrDefault(); + LLVM_DEBUG({ + llvm::dbgs() << "check localAlloc in getExplicitValueSize: "; + alloc.dump(); + }); + int sharingGroup = -1; + if (alloc->hasAttr("allocation.shareGroup")) { + sharingGroup = + mlir::cast(alloc->getAttr("allocation.shareGroup")) + .getInt(); + LDBG("with shareGroup of " << sharingGroup); + } + allocation->addBuffer( + result, bytes, alignment, 0, sharingGroup); } } } @@ -279,27 +269,23 @@ class AllocationAnalysis { // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's // also possible to realize it with other approaches in restricted // conditions, such as warp-shuffle - unsigned inVec = 0; - unsigned outVec = 0; - auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec); - unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, - std::multiplies{}); + auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy); + auto elems = getNumScratchElements(scratchConfig.paddedRepShape); auto bytes = isa(srcTy.getElementType()) ? elems * kPtrBitWidth / 8 : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; maybeAddScratchBuffer(op, bytes, scratchAlignment); - } else if (auto atomicRMWOp = dyn_cast(op)) { + } else if (isa(op)) { auto value = op->getOperand(0); // only scalar requires scratch memory // make it explicit for readability if (dyn_cast(value.getType())) { // nothing to do } else { - auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp); - unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, - std::multiplies{}); + auto smemShape = getRepShapeForAtomic(op->getResult(0)); + auto elems = getNumScratchElements(smemShape); auto elemTy = cast(value.getType()).getPointeeType(); auto bytes = @@ -309,24 +295,6 @@ class AllocationAnalysis { maybeAddScratchBuffer(op, bytes, scratchAlignment); } - } else if (auto atomicCASOp = dyn_cast(op)) { - // only scalar requires scratch memory - // make it explicit for readability - auto value = op->getOperand(0); - if (dyn_cast(value.getType())) { - // nothing to do - } else { - auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp); - unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, - std::multiplies{}); - auto elemTy = - cast(value.getType()).getPointeeType(); - auto bytes = isa(elemTy) - ? elems * kPtrBitWidth / 8 - : elems * elemTy.getIntOrFloatBitWidth() / 8; - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } } else if (auto callOp = dyn_cast(op)) { auto callable = callOp.resolveCallable(); auto funcOp = dyn_cast(callable); @@ -334,6 +302,12 @@ class AllocationAnalysis { auto bytes = funcAlloc->getSharedMemorySize(); maybeAddScratchBuffer(op, bytes, scratchAlignment); + } else if (auto createTensormap = + dyn_cast(op)) { + constexpr int32_t kTMASize = 128; + constexpr int32_t kTMAAlign = 128; + maybeAddScratchBuffer(op, kTMASize, + kTMAAlign); } } @@ -357,6 +331,13 @@ class AllocationAnalysis { getExplicitValueSize(op); getScratchValueSize(op); }); + LDBG("getValuesAndSizes --"); + for (auto valueBufferIter : allocation->valueBuffer) { + auto *buffer = valueBufferIter.second; + LLVM_DEBUG(llvm::dbgs() + << "-- buffer " << buffer->id << " " << buffer->size << " " + << buffer->offset << " " << buffer->sharingGroup << "\n"); + } // Get the alias values std::unique_ptr solver = createDataFlowSolver(); SharedMemoryAliasAnalysis *aliasAnalysis = @@ -378,11 +359,12 @@ class AllocationAnalysis { /// Computes the liveness range of the allocated value. /// Each buffer is allocated only once. void resolveExplicitBufferLiveness( - function_ref(Value value)> getLiveness) { + function_ref(Value value, BufferT *buffer)> + getLiveness) { for (auto valueBufferIter : allocation->valueBuffer) { auto value = valueBufferIter.first; auto *buffer = valueBufferIter.second; - bufferRange[buffer] = getLiveness(value); + bufferRange[buffer] = getLiveness(value, buffer); } } @@ -390,11 +372,12 @@ class AllocationAnalysis { /// values because each allocated buffer could be an alias of others, if block /// arguments are involved. void resolveAliasBufferLiveness( - function_ref(Value value)> getLiveness) { + function_ref(Value value, BufferT *buffer)> + getLiveness) { for (auto aliasBufferIter : allocation->aliasBuffer) { auto value = aliasBufferIter.first; auto buffers = aliasBufferIter.second; - auto range = getLiveness(value); + auto range = getLiveness(value, buffers.front()); for (auto *buffer : buffers) { auto minId = range.start(); auto maxId = range.end(); @@ -420,8 +403,21 @@ class AllocationAnalysis { // range. auto *op = opScratchIter.first; auto *buffer = opScratchIter.second; - bufferRange.insert({buffer, Interval(operationId.lookup(op), - operationId.lookup(op) + 1)}); + // Extend live range when asyncTaskId is not empty (i.e when we have + // warp spec). + if (getAsyncTaskIds(op).empty()) { + bufferRange.insert({buffer, Interval(operationId.lookup(op), + operationId.lookup(op) + 1)}); + } else { + for (auto tId : getAsyncTaskIds(op)) + buffer->regionIds.insert(tId); + // For warp-specialized code, we can assume each region has its own + // copy of a scratch buffer, i.e each region is for a single taskId. + // In that case, we don't need to extend the liveness of scratch + // buffers. + bufferRange.insert({buffer, Interval(operationId.lookup(op), + operationId.lookup(op) + 1)}); + } } }; processScratchMemory(allocation->opScratch); @@ -452,19 +448,38 @@ class AllocationAnalysis { // Analyze liveness of explicit buffers Liveness liveness(operation); - auto getValueLivenessRange = [&](Value value) { + auto getValueLivenessRange = [&](Value value, BufferT *buffer) { auto liveOperations = liveness.resolveLiveness(value); - auto minId = std::numeric_limits::max(); - auto maxId = std::numeric_limits::min(); + // Update regions for buffer. std::for_each(liveOperations.begin(), liveOperations.end(), [&](Operation *liveOp) { - if (operationId[liveOp] < minId) { - minId = operationId[liveOp]; - } - if ((operationId[liveOp] + 1) > maxId) { - maxId = operationId[liveOp] + 1; + for (auto rId : getAsyncTaskIds(liveOp)) { + buffer->regionIds.insert(rId); } }); + auto minId = std::numeric_limits::max(); + auto maxId = std::numeric_limits::min(); + std::for_each( + liveOperations.begin(), liveOperations.end(), [&](Operation *liveOp) { + if (buffer->regionIds.size() > 1 || buffer->sharingGroup >= 0) { + // For a buffer that is associated with warp + // specialization, due to producer-consumer channel, it + // should have at least two regions, and it will be live + // throughout. For a buffer that is local to a consumer: + // we need to make sure not to overlap with local + // buffers from another consumer. This will be handled + // when building the interference graph. + minId = 0; + maxId = operationId.size(); + return; + } + if (operationId[liveOp] < minId) { + minId = operationId[liveOp]; + } + if ((operationId[liveOp] + 1) > maxId) { + maxId = operationId[liveOp] + 1; + } + }); return Interval(minId, maxId); }; @@ -473,16 +488,66 @@ class AllocationAnalysis { resolveScratchBufferLiveness(operationId); } + void dumpBuffers() { + LDBG("Dump bufferRange: id size offset sharingGroup ---------"); + for (auto bufferIter : bufferRange) { + LLVM_DEBUG({ + llvm::dbgs() << "-- " << bufferIter.first->id << " " + << bufferIter.first->size << " " + << bufferIter.first->offset << " " + << bufferIter.first->sharingGroup << " regions ["; + for (auto tId : bufferIter.first->regionIds) { + llvm::dbgs() << tId << " "; + } + llvm::dbgs() << "] interval " << bufferIter.second.start() << " " + << bufferIter.second.end() << "\n"; + }); + } + } + /// Computes the shared memory offsets for all related values. /// Paper: Algorithms for Compile-Time Memory Optimization /// (https://dl.acm.org/doi/pdf/10.5555/314500.315082) void computeOffsets() { SmallVector buffers; + // Handle sharingGroup here. For allocations with the same sharingGroup + // get the union of the live range, and union of the regionIds. Put + // the + // largest buffer in buffers. + DenseMap> toGroup; for (auto bufferIter : bufferRange) { + if (bufferIter.first->sharingGroup >= 0) + toGroup[bufferIter.first->sharingGroup].push_back(bufferIter.first); + } + DenseMap sharingIdToRep; + for (auto &kv : toGroup) { + size_t bigSize = 0; + BufferT *rep = nullptr; + for (auto *buf : kv.second) { + if (buf->size > bigSize) { + rep = buf; + bigSize = buf->size; + } + } + // FIXME: update live range and regionIds. + sharingIdToRep[kv.first] = rep; + } + for (auto bufferIter : bufferRange) { + if (sharingIdToRep.find(bufferIter.first->sharingGroup) != + sharingIdToRep.end()) { + if (bufferIter.first != + sharingIdToRep[bufferIter.first->sharingGroup]) { + LDBG("-- ignore shared buffer " << bufferIter.first->size << " " + << bufferIter.first->offset << " " + << bufferIter.first->sharingGroup); + continue; + } + } buffers.emplace_back(bufferIter.first); } calculateStarts(buffers); + dumpBuffers(); // NOTE: The original paper doesn't consider interference between // the bumped ranges. Buffers that previously do not interfere with @@ -497,6 +562,18 @@ class AllocationAnalysis { allocate(buffers, interference); buildInterferenceGraph(buffers, interference); } while (!interference.empty()); + // Update allocation for sharingGroup. + for (auto &kv : toGroup) { + auto *rep = sharingIdToRep[kv.first]; + for (auto *buf : kv.second) { + if (buf != rep) { + buf->setOffsetAligned(rep->offset); + LDBG("-- set sharing buffer's offset " + << buf->size << " " << buf->offset << " " << buf->sharingGroup); + } + } + } + dumpBuffers(); } /// Computes the initial shared memory offsets. @@ -564,6 +641,21 @@ class AllocationAnalysis { void buildInterferenceGraph(const SmallVector &buffers, GraphT &interference) { // Reset interference graph + auto inDifferentRegion = [&](BufferT *A, BufferT *B) { + auto tA = A->regionIds; + auto tB = B->regionIds; + if (tA.empty() && tB.empty()) + return false; + if (tA.empty() || tB.empty()) + return true; + for (auto t1 : tA) { + for (auto t2 : tB) { + if (t1 != t2) + return true; + } + } + return false; + }; interference.clear(); for (auto x : buffers) { for (auto y : buffers) { @@ -581,6 +673,9 @@ class AllocationAnalysis { xSizeRange.intersects(ySizeRange)) { interference[x].insert(y); } + // if x and y belong to different regions (ignore producer region). + if (inDifferentRegion(x, y) && xSizeRange.intersects(ySizeRange)) + interference[x].insert(y); } } } @@ -642,4 +737,27 @@ void Allocation::run(FuncAllocMapT &funcAllocMap) { triton::AllocationAnalysis(getOperation(), &funcAllocMap, this); } +std::map> +Allocation::getLiveBuffers() { + std::map> liveBuffers; + + Operation *rootOperation = getOperation(); + mlir::Liveness liveness(rootOperation); + auto analyzeOperation = [&](Operation *op) -> void { + auto scratchBuffer = getBufferId(op); + if (scratchBuffer != InvalidBufferId) + liveBuffers[op].push_back(scratchBuffer); + for (auto result : op->getOpResults()) { + auto bufferId = getBufferId(result); + if (bufferId == Allocation::InvalidBufferId) + continue; + auto liveOperations = liveness.resolveLiveness(result); + for (auto depOp : liveOperations) + liveBuffers[depOp].push_back(bufferId); + } + }; + rootOperation->walk(analyzeOperation); + return liveBuffers; +} + } // namespace mlir diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 49d559618..717df8d1b 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -172,8 +172,8 @@ class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis< void setToEntryState(dataflow::Lattice *lattice) override { propagateIfChanged( - lattice, - lattice->join(AxisInfo::getPessimisticValueState(lattice->getPoint()))); + lattice, lattice->join( + AxisInfo::getPessimisticValueState(lattice->getAnchor()))); } void visitNonControlFlowArguments( @@ -195,9 +195,10 @@ class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis< dataflow::Lattice>::getLatticeElement; using FuncAxisInfoMapT = DenseMap; - void visitOperation(Operation *op, - ArrayRef *> operands, - ArrayRef *> results) override; + LogicalResult + visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; void visitForOpInductionVar(scf::ForOp op, ArrayRef *> argLattices); @@ -277,6 +278,11 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl { private: int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { + // Contiguity assumes an increasing sequence. So for SubIOp contiguous + // RHS doesn't produce a contiguous result. + if (isa(op)) + return gcd(lhs.getContiguity(dim), rhs.getConstancy(dim)); + return std::max(gcd(lhs.getConstancy(dim), rhs.getContiguity(dim)), gcd(lhs.getContiguity(dim), rhs.getConstancy(dim))); } @@ -285,8 +291,7 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl { int dim) override { // lhs = k * d_lhs = k * k' * gcd(d_lhs, d_rhs) // rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs) - // lhs + rhs = k * d_lhs + p * d_rhs = (k * d_lhs + p * d_rhs) * - // gcd(d_lhs, d_rhs) + // lhs + rhs = k * d_lhs + p * d_rhs = (k * k' + p * p') * gcd(d_lhs, d_rhs) auto rhsDivisibility = rhs.getDivisibility(dim); if constexpr (std::is_same_v) { // %ptr = addptr %lhs, %rhs @@ -883,16 +888,14 @@ class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { int64_t getDivisibility(arith::ShLIOp op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { - auto shift = rhs.getConstantValue().has_value() - ? rhs.getConstantValue().value() - : rhs.getDivisibility(dim); + auto shift = rhs.getConstantValue().value_or(0); auto lhsDivisibility = lhs.getDivisibility(dim); if (lhs.getContiguity(dim) > 1 && shift) { // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n lhsDivisibility = 1; } auto numBits = log2Int(lhsDivisibility); - return multiplyDivisor(lhsDivisibility, 1 << shift); + return multiplyDivisor(lhsDivisibility, 1ll << shift); } int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs, @@ -926,9 +929,9 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl { int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { - auto shift = rhs.getConstantValue().has_value() - ? rhs.getConstantValue().value() - : rhs.getDivisibility(dim); + if (!rhs.getConstantValue().has_value()) + return 1; + auto shift = rhs.getConstantValue().value(); auto lhsDivisibility = lhs.getDivisibility(dim); if (lhs.getContiguity(dim) > 1 && shift) { // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n @@ -1042,7 +1045,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) visitors.append(); } -void AxisInfoAnalysis::visitOperation( +LogicalResult AxisInfoAnalysis::visitOperation( Operation *op, ArrayRef *> operands, ArrayRef *> results) { // TODO: For sure not the right way to do this @@ -1051,8 +1054,10 @@ void AxisInfoAnalysis::visitOperation( if (op->getValue().getRank() == 0) setToEntryState((dataflow::Lattice *)op); AxisInfo curr = visitors.apply(op, operands); - if (curr.getRank() == 0) - return setAllToEntryStates(results); + if (curr.getRank() == 0) { + setAllToEntryStates(results); + return success(); + } // override with hint auto newContiguity = curr.getContiguity(); auto newDivisibility = curr.getDivisibility(); @@ -1074,12 +1079,14 @@ void AxisInfoAnalysis::visitOperation( // join all lattice elements for (auto *result : results) propagateIfChanged(result, result->join(curr)); + return success(); } void AxisInfoAnalysis::visitForOpInductionVar( scf::ForOp op, ArrayRef *> argLattices) { - auto lb = getLatticeElementFor(op, op.getLowerBound())->getValue(); - auto step = getLatticeElementFor(op, op.getStep())->getValue(); + ProgramPoint *programPoint = getProgramPointAfter(op); + auto lb = getLatticeElementFor(programPoint, op.getLowerBound())->getValue(); + auto step = getLatticeElementFor(programPoint, op.getStep())->getValue(); AxisInfo::DimVectorT knownContiguity(1, 1); AxisInfo::DimVectorT knownDivisibility(1, 1); @@ -1140,6 +1147,14 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp, initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, &knownContiguity, &knownDivisibility, &knownConstancy); + else if (isa(op)) { + // scf::ForOp, scf::IfOp, scf::WhileOp + // Control flow operations are initialized with "unknown" state: + // the maximum possible divisibility, contiguity, and constancy. + knownDivisibility = DimVectorT(rank, highestPowOf2Divisor(0)); + knownConstancy = DimVectorT(rank, highestPowOf2Divisor(0)); + knownContiguity = DimVectorT(rank, highestPowOf2Divisor(0)); + } } else if (Operation *op = value.getDefiningOp()) { if (isa(op)) { // scf::ForOp, scf::IfOp, scf::WhileOp diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index 407a5ae15..f45cd0c2c 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -6,6 +6,7 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include namespace mlir { @@ -94,7 +95,7 @@ void MembarAnalysis::visitTerminator(Operation *op, void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) { OpBuilder::InsertionGuard g(*builder); - auto barrierOp = builder->create(op->getLoc()); + ::insertBarrier(*builder, op); } void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, @@ -117,6 +118,7 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, } BlockInfo curBlockInfo; + auto scratchBufferId = Allocation::InvalidBufferId; if (isa(op)) { // Inter-function dependencies auto callOpInterface = dyn_cast(op); @@ -135,38 +137,44 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, for (auto bufferId : allocation->getBufferIds(value)) { if (bufferId != Allocation::InvalidBufferId) { if (isa(effectInstance.getEffect())) - curBlockInfo.syncWriteIntervals.insert( - allocation->getAllocatedInterval(bufferId)); + curBlockInfo + .syncWriteIntervals[allocation->getAllocatedInterval( + bufferId)] + .insert(op); else if (isa(effectInstance.getEffect())) - curBlockInfo.syncReadIntervals.insert( - allocation->getAllocatedInterval(bufferId)); + curBlockInfo + .syncReadIntervals[allocation->getAllocatedInterval( + bufferId)] + .insert(op); } } } } } - // XXX(Keren): This is a hack as we cannot set side effects for dot ops, but - // on hopper they do have side effects. Need to clean it up - if (auto dotOp = dyn_cast(op)) { - for (auto value : dotOp.getOperands()) { - for (auto bufferId : allocation->getBufferIds(value)) { - if (bufferId != Allocation::InvalidBufferId) - curBlockInfo.syncReadIntervals.insert( - allocation->getAllocatedInterval(bufferId)); - } - } - } - // Scratch buffer is considered as both shared memory write & read - auto bufferId = allocation->getBufferId(op); - if (bufferId != Allocation::InvalidBufferId) { - curBlockInfo.syncWriteIntervals.insert( - allocation->getAllocatedInterval(bufferId)); - curBlockInfo.syncReadIntervals.insert( - allocation->getAllocatedInterval(bufferId)); - } + scratchBufferId = allocation->getBufferId(op); } - if (blockInfo->isIntersected(curBlockInfo)) { + // Scratch buffer operations consist of a series of shared memory operations + // starting from a shared memory write, followed by a series of shared memory + // read/write operations, and ending with a shared memory read, i.e., shared + // memory write -> ... -> shared memory read. + if (scratchBufferId != Allocation::InvalidBufferId) { + if (!curBlockInfo.syncReadIntervals.empty() || + !curBlockInfo.syncWriteIntervals.empty()) { + llvm::report_fatal_error( + "scratch buffer operations should not have any shared memory " + "dependencies"); + } + auto interval = allocation->getAllocatedInterval(scratchBufferId); + curBlockInfo.syncWriteIntervals[interval].insert(op); + if (blockInfo->isIntersected(curBlockInfo, filter)) { + builder->setInsertionPoint(op); + insertBarrier(op, builder); + } + // Ops with a scratch buffer internally syncs read/write on shared memory + blockInfo->sync(); + curBlockInfo.syncReadIntervals[interval].insert(op); + } else if (blockInfo->isIntersected(curBlockInfo, filter)) { builder->setInsertionPoint(op); insertBarrier(op, builder); blockInfo->sync(); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 32cc43c9d..4915d7b1a 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -13,7 +13,9 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LinearLayout.h" #include "triton/Tools/Sys/GetEnv.hpp" namespace mlir { @@ -34,7 +36,7 @@ SmallVector getParentOrder(Attribute layout) { if (auto sliceEncoding = mlir::dyn_cast(layout)) { return getParentOrder(sliceEncoding.getParent()); } - return getOrder(layout); + return getThreadOrder(layout); } } // namespace @@ -73,7 +75,7 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { threadOffset = threadsPerWarp[sliceLayout.getDim()]; } else { auto threadsPerWarp = getThreadsPerWarp(srcLayout); - auto order = getOrder(srcLayout); + auto order = getThreadOrder(srcLayout); for (unsigned i = 0; i < order.size(); i++) { if (order[i] == axis) break; @@ -170,7 +172,7 @@ bool ReduceOpHelper::isWarpSynchronous() { return getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] == 1; } -SmallVector ReduceOpHelper::getScratchConfig() { +SmallVector ReduceOpHelper::getScratchRepShape() { SmallVector smemShape; // that case doesn't need inter-warp communication if (isWarpSynchronous()) @@ -183,7 +185,7 @@ SmallVector ReduceOpHelper::getScratchConfig() { } unsigned ReduceOpHelper::getScratchSizeInBytes() { - auto smemShape = getScratchConfig(); + auto smemShape = getScratchRepShape(); auto elems = product(smemShape); unsigned bytesPerElem = 0; @@ -399,18 +401,10 @@ unsigned ScanLoweringHelper::getAxisBlockStride() { llvm_unreachable("Axis not found in order"); } -bool maybeSharedAllocationOp(Operation *op) { - // TODO(Keren): This function can be replaced by adding - // MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to - // query the memory effects of the op. - auto *dialect = op->getDialect(); - return dialect && - (dialect->getTypeID() == TypeID::get() || - dialect->getTypeID() == - TypeID::get() || - dialect->getTypeID() == TypeID::get() || - dialect->getTypeID() == TypeID::get() || - dialect->getTypeID() == TypeID::get()); +unsigned getNumScratchElements(ArrayRef shape) { + if (shape.empty()) + return 0; + return product(shape); } static bool supportMFMAGranularity(int m, int n, int k) { @@ -431,6 +425,8 @@ bool supportMFMATypes(Type a, Type b) { if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth()) return false; + auto F8E5M2 = TypeID::get(); + auto F8E4M3FN = TypeID::get(); auto F8E4M3FNUZ = TypeID::get(); auto F8E5M2FNUZ = TypeID::get(); auto F16 = TypeID::get(); @@ -441,6 +437,8 @@ bool supportMFMATypes(Type a, Type b) { {F32, F32}, {F16, F16}, {BF16, BF16}, + {F8E5M2, F8E5M2}, + {F8E4M3FN, F8E4M3FN}, {F8E4M3FNUZ, F8E4M3FNUZ}, {F8E4M3FNUZ, F8E5M2FNUZ}, {F8E5M2FNUZ, F8E4M3FNUZ}, @@ -480,58 +478,6 @@ bool supportMFMA(triton::DotOp op) { return true; } -static bool supportWMMAGranularity(int m, int n, int k) { - return m % 16 == 0 && n % 16 == 0 && k % 16 == 0; -} - -static bool supportWMMATypes(Type a, Type b, Type c, Type d) { - if (a != b || c != d) - return false; - auto aWidth = a.getIntOrFloatBitWidth(); - auto cWidth = c.getIntOrFloatBitWidth(); - if (a.isIntOrIndex()) { - if (!c.isIntOrIndex()) - return false; - bool aValid = aWidth <= 8; - bool cValid = cWidth <= 32; - return aValid && cValid; - } else if (isa(a) && isa(c)) { - if (a.isBF16()) - return c.isBF16() || c.isF32(); - if (a.isF16()) - return c.isF16() || c.isF32(); - return aWidth <= cWidth && aWidth <= 16; - } - return false; -} - -bool supportWMMA(triton::DotOp op) { - auto aTy = cast(op.getA().getType()); - auto bTy = cast(op.getB().getType()); - auto cTy = cast(op.getC().getType()); - auto dTy = cast(op.getResult().getType()); - - auto aElemTy = aTy.getElementType(); - auto bElemTy = bTy.getElementType(); - auto cElemTy = cTy.getElementType(); - auto dElemTy = dTy.getElementType(); - - if (!supportWMMATypes(aElemTy, bElemTy, cElemTy, dElemTy)) - return false; - - auto aShape = aTy.getShape(); - auto bShape = bTy.getShape(); - - auto rank = aShape.size(); - assert(bShape.size() == rank); - assert(aShape[rank - 1] == bShape[rank - 2]); - if (!supportWMMAGranularity(aShape[rank - 2], bShape[rank - 1], - aShape[rank - 1])) - return false; - - return true; -} - bool supportMMA(triton::DotOp op, int version) { // Refer to mma section for the data type supported by Volta and Hopper // Tensor Core in @@ -542,20 +488,28 @@ bool supportMMA(triton::DotOp op, int version) { if (triton::tools::getBoolEnv("DISABLE_MMA_V3")) return false; auto retType = op.getType(); + RankedTensorType typeA = op.getA().getType(); + int k = typeA.getShape().back(); + // If k size is smaller than the native mma size, we cannot use MMA. + if (k < 256 / aElemTy.getIntOrFloatBitWidth()) + return false; auto retShapePerCTA = getShapePerCTA(retType); auto rank = retShapePerCTA.size(); auto mod = op->getParentOfType(); int numWarps = TritonGPUDialect::getNumWarps(mod); + // TODO(Keren): for now, fallback to MMAv2 if handling batch matmul. + if (rank == 3) + return false; if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 && retShapePerCTA[rank - 1] % 8 == 0 && - (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() || + (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() || aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32()))) { return false; } // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op. if (op.getMaxNumImpreciseAcc() < 32 && - (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) && + (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) && cast(op.getType()).getElementType().isF32()) { return false; } @@ -582,11 +536,78 @@ bool supportMMA(Value value, int version) { (elemTy.isInteger(8) && version >= 2); } -bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { - auto srcLayout = srcTy.getEncoding(); - auto dstLayout = dstTy.getEncoding(); - auto mfmaLayout = dyn_cast(srcLayout); - auto dotOperandLayout = dyn_cast(dstLayout); +bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { + auto blockedLayout = dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + if (blockedLayout == nullptr || dotOperandLayout == nullptr) + return false; + auto parentLayout = + dyn_cast(dotOperandLayout.getParent()); + if (parentLayout == nullptr) + return false; + auto opShape = srcTy.getShape(); + auto rank = opShape.size(); + + int kDim = dotOperandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + int nonKDim = dotOperandLayout.getOpIdx() == 0 ? rank - 2 : rank - 1; + auto ctaLayout = blockedLayout.getCTALayout(); + + // The following logic checks that a source blocked layout matches a + // destination dot operand layout. This means that given tensor in source + // layout could be converted into destination layout without any data movement + // between registers or threads. + // + // It is considered a match if + // 1) Each thread in source layout holds a whole copy of all elements along + // the K dimension of a tensor + // 2) Distribution of data along all other non-K dimensions(Batch/M/N) + // matches between source and destination parent layouts. + // + // First condition comes from the property of dot operand layout with Blocked + // parent: size per threads along K dimension equals size of the tensor along + // K. Second condition comes from other property: dot operand layout + // inherits non-K dimensions from it's parent layout. + // + // clang-format off + // + // For example, following conversion is a no op: + // tensor<128x32xf16, #blocked<{sizePerThread = [2, 32], threadsPerWarp = [32, 1]}>> + // -> + // tensor<128x32xf16, #dot_op<{opIdx=0, parent=#blocked<{sizePerThread = [2, 8], threadsPerWarp = [32, 1]}>>> + // + // clang-format on + bool ctaLayoutCompatible = + ctaLayout.getCTASplitNum()[kDim] == 1 && + blockedLayout.getCTALayout() == parentLayout.getCTALayout(); + bool threadHoldsWholeKDim = + blockedLayout.getSizePerThread()[kDim] == opShape[kDim]; + bool nonKDimCompatible = + blockedLayout.getOrder() == parentLayout.getOrder() && + blockedLayout.getSizePerThread()[nonKDim] == + parentLayout.getSizePerThread()[nonKDim] && + blockedLayout.getThreadsPerWarp()[nonKDim] == + parentLayout.getThreadsPerWarp()[nonKDim] && + blockedLayout.getWarpsPerCTA()[nonKDim] == + parentLayout.getWarpsPerCTA()[nonKDim]; + bool matrixDimsCompatible = + ctaLayoutCompatible && threadHoldsWholeKDim && nonKDimCompatible; + if (rank == 2) + return matrixDimsCompatible; + + // additional check for batch dimension if it is present + assert(rank == 3); + bool bDimCompatible = + blockedLayout.getSizePerThread()[0] == + parentLayout.getSizePerThread()[0] && + blockedLayout.getThreadsPerWarp()[0] == + parentLayout.getThreadsPerWarp()[0] && + blockedLayout.getWarpsPerCTA()[0] == parentLayout.getWarpsPerCTA()[0]; + return matrixDimsCompatible && bDimCompatible; +} + +bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { + auto mfmaLayout = dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); if (mfmaLayout == nullptr || dotOperandLayout == nullptr) return false; // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is @@ -600,46 +621,104 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()); } -static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) { - auto src = dyn_cast(srcEncoding); - auto dst = dyn_cast(dstEncoding); - if (!src || !dst) - return false; - // when #mma = MmaEncoding - return src && dst && src.getVersionMajor() == 3 && - src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 && - dst.getWarpsPerCTA()[1] == 1; -} - -bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { - return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding()); -} - // For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy) { - auto srcLayout = srcTy.getEncoding(); - auto dstLayout = dstTy.getEncoding(); - auto mmaLayout = cast(srcLayout); - auto dotOperandLayout = cast(dstLayout); + auto mmaLayout = dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + if (!mmaLayout || !dotOperandLayout) { + return false; + } int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth(); + auto parentTy = RankedTensorType::get( + srcTy.getShape(), srcTy.getElementType(), dotOperandLayout.getParent()); auto ans = mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 && - isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) && + mmaLayout.getWarpsPerCTA()[1] == 1 && + !cvtNeedsSharedMemory(parentTy, srcTy) && (elementTypeSize == 16 || elementTypeSize == 8); return ans; } +// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity +// under kBlock, kWarp or kLane (in that order). The idea here is that if we +// have a transformation that's the identity on kBlock, we don't need to use +// distributed shared memory. If it's also the identity on kWarp, we can +// transfer via warp-shuffles, and if it's the identity on kLane just have to +// reorder the registers +std::optional minimalCvtLayout(RankedTensorType srcTy, + RankedTensorType dstTy) { + MLIRContext *ctx = srcTy.getContext(); + std::optional srcLayout = + toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + std::optional dstLayout = + toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + if (!(srcLayout.has_value() && dstLayout.has_value())) + return std::nullopt; + // comp describes the layout function to create dst from src. + LinearLayout comp = dstLayout->invertAndCompose(*srcLayout); + // We try to quotient by the largest subspace first + auto dims = SmallVector{"block", "warp", "lane", "register"}; + for (auto dim : dims) { + auto quotient = comp.quotient(StringAttr::get(ctx, dim)); + if (!quotient.has_value()) { + break; + } + comp = *quotient; + } + return comp; +} + +bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) { + auto layout = minimalCvtLayout(srcTy, dstTy); + MLIRContext *ctx = srcTy.getContext(); + if (!layout.has_value()) { + return false; + } + auto kRegister = StringAttr::get(ctx, "register"); + auto outDims = llvm::to_vector(layout->getOutDimNames()); + return outDims.empty() || ArrayRef(outDims) == ArrayRef({kRegister}); +} + +bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) { + auto layout = minimalCvtLayout(srcTy, dstTy); + MLIRContext *ctx = srcTy.getContext(); + if (!layout.has_value()) { + return false; + } + auto kRegister = StringAttr::get(ctx, "register"); + auto kLane = StringAttr::get(ctx, "lane"); + return llvm::to_vector(layout->getOutDimNames()) == + llvm::SmallVector{kRegister, kLane}; +} + +bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { + // TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`, + // `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully + // subsumed by the linear-layout checks. + // TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not + // supported yet in Triton's backend. + return !cvtReordersRegisters(srcTy, dstTy) && + !isBlockedToDotShortcut(srcTy, dstTy) && + !isMmaToDotShortcut(srcTy, dstTy) && + !isMfmaToDotShortcut(srcTy, dstTy); +} + +bool atomicNeedsSharedMemory(Value value) { + auto type = value.getType(); + if (isa(type) || value.use_empty()) + return false; + return true; +} + bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) return true; // dot_op = #mma // when #mma = MmaEncoding - auto srcLayout = srcTy.getEncoding(); - auto dstLayout = dstTy.getEncoding(); - auto mmaLayout = mlir::cast(srcLayout); - auto dotOperandLayout = mlir::cast(dstLayout); - return mmaLayout.getVersionMajor() == 2 && + auto mmaLayout = dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 && mmaLayout.getWarpsPerCTA()[1] == 1 && dotOperandLayout.getOpIdx() == 0 && dotOperandLayout.getParent() == mmaLayout && @@ -816,15 +895,16 @@ class ConstantAnalysis : public DataFlowAnalysis { LogicalResult initialize(Operation *top) override { WalkResult result = top->walk([&](Operation *op) { - if (failed(visit(op))) + ProgramPoint programPoint(op); + if (failed(visit(&programPoint))) return WalkResult::interrupt(); return WalkResult::advance(); }); return success(!result.wasInterrupted()); } - LogicalResult visit(ProgramPoint point) override { - Operation *op = point.get(); + LogicalResult visit(ProgramPoint *point) override { + Operation *op = point->getOperation(); Attribute value; if (matchPattern(op, m_Constant(&value))) { auto *constant = getOrCreate>( diff --git a/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp b/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp index aae9faf0e..8d9ad31c0 100644 --- a/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp +++ b/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp @@ -40,6 +40,8 @@ struct AllocateSharedMemory } if (offset == -1) return; + if (op->hasAttr("allocation.offset")) + return; op->setAttr("allocation.offset", IntegerAttr::get(IntegerType::get(ctx, 32), offset)); }); diff --git a/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp index a3f55f1e7..508eb25f9 100644 --- a/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp @@ -34,19 +34,41 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern { return failure(); } } - llAssert(op, condition, adaptor.getMessage(), adaptor.getFile(), - adaptor.getFunc(), adaptor.getLine(), rewriter); + llAssert(op, condition, adaptor.getMessage(), rewriter); + if (isa(op.getCondition().getType())) { + // Add a barrier to avoid a race condition in case an assert is followed + // by an op that may trap if the assert condition is true. Since the + // tensor in those two operations may have different layout we need to + // make sure all the threads are done executing the assert before going to + // the next op. + barrier(); + } rewriter.eraseOp(op); return success(); } // op: the op at which the assert is inserted. Unlike printf, we need to // know about the op to split the block. void llAssert(Operation *op, Value condition, StringRef message, - StringRef file, StringRef func, int line, ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter::InsertionGuard guard(rewriter); + auto ctx = rewriter.getContext(); auto loc = op->getLoc(); + + StringRef file = "unknown"; + StringRef func = "unknown"; + int line = 0; + int col = 0; + + while (auto callLoc = dyn_cast(loc)) + loc = callLoc.getCallee(); + + if (auto fileLineColLoc = dyn_cast(loc)) { + file = fileLineColLoc.getFilename(); + line = fileLineColLoc.getLine(); + col = fileLineColLoc.getColumn(); + } + // #block1 // if (condition) { // #block2 diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index cca2830b0..4d57131d0 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -34,5 +34,4 @@ add_triton_library(TritonGPUToLLVM TritonGPUIR TritonGPUTransforms TritonNvidiaGPUTransforms - NVGPUIR ) diff --git a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp index 9765d7bf0..8d5a63eb1 100644 --- a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp @@ -51,8 +51,10 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern { // CallOpInterfaceLowering is adapted from // https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 struct CallOpConversion : public ConvertOpToLLVMPattern { - CallOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit) {} + CallOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} LogicalResult matchAndRewrite(triton::CallOp callOp, @@ -85,8 +87,8 @@ struct CallOpConversion : public ConvertOpToLLVMPattern { promotedOperands.push_back(base); return promotedOperands; } - promotedOperands.push_back( - LLVM::getSharedMemoryBase(callOp->getLoc(), rewriter, callOp)); + promotedOperands.push_back(LLVM::getSharedMemoryBase( + callOp->getLoc(), rewriter, targetInfo, callOp)); return promotedOperands; } @@ -107,6 +109,10 @@ struct CallOpConversion : public ConvertOpToLLVMPattern { auto newCallOp = rewriter.create( callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), promotedOperands, callOp->getAttrs()); + newCallOp.getProperties().setOpBundleSizes( + rewriter.getDenseI32ArrayAttr({})); + newCallOp.getProperties().setOperandSegmentSizes( + {static_cast(promotedOperands.size()), 0}); return newCallOp; } @@ -129,13 +135,14 @@ struct CallOpConversion : public ConvertOpToLLVMPattern { } return results; } + const TargetInfoBase &targetInfo; }; } // namespace void mlir::triton::populateControlFlowOpToLLVMPattern( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { + const TargetInfoBase &targetInfo, PatternBenefit benefit) { patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 94894ceb1..29dd696e9 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -1,116 +1,36 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Analysis/Utility.h" #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Analysis/Allocation.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" -using mlir::isLayoutMmaV1; -using mlir::LLVM::getMultiDimOffset; +namespace { + +using ::mlir::isLayoutMmaV1; +using ::mlir::LLVM::getMultiDimOffset; using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::LLVM::getStridesFromShapeAndOrder; -using mlir::LLVM::getWrappedMultiDimOffset; +using ::mlir::LLVM::getWrappedMultiDimOffset; using ::mlir::LLVM::linearize; -using ::mlir::triton::gpu::DotOperandEncodingAttr; -using ::mlir::triton::gpu::getOrder; -using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::getShapePerCTATile; -using ::mlir::triton::gpu::getSizePerThread; -using ::mlir::triton::gpu::getTotalElemsPerThread; -using ::mlir::triton::gpu::isaDistributedLayout; -using ::mlir::triton::gpu::SharedEncodingAttr; - -namespace { - -struct LocalLoadOpConversion - : public ConvertOpToLLVMPattern { -public: - LocalLoadOpConversion(LLVMTypeConverter &typeConverter, - const TargetInfoBase &targetInfo, - PatternBenefit benefit = 1) - : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { - } - LogicalResult - matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MemDescType srcTy = op.getSrc().getType(); - RankedTensorType dstTy = op.getType(); - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - // TODO: do we need to check if src is shared ? - if (isa(srcLayout) && isaDistributedLayout(dstLayout)) { - return lowerSharedToDistributed(op, adaptor, getTypeConverter(), - rewriter); - } - if (isa(dstLayout) && - isa( - cast(dstLayout).getParent())) { - return lowerSharedToDotOpFMA(op, adaptor, getTypeConverter(), rewriter); - } - return failure(); - } +using namespace mlir::triton::gpu; -private: - LogicalResult - lowerSharedToDotOpFMA(triton::gpu::LocalLoadOp op, - triton::gpu::LocalLoadOpAdaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - RankedTensorType dstTy = op.getType(); - Attribute dstLayout = dstTy.getEncoding(); - auto dotLayout = cast(dstLayout); - auto blockedLayout = cast( - cast(dstLayout).getParent()); - auto thread = getThreadId(rewriter, loc); - Value res = SharedToDotOperandFMA::convertLayout( - dotLayout.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout, - thread, loc, getTypeConverter(), rewriter); - rewriter.replaceOp(op, res); - return success(); - } - LogicalResult - lowerSharedToDistributed(triton::gpu::LocalLoadOp op, - triton::gpu::LocalLoadOpAdaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - auto srcTy = op.getSrc().getType(); - auto dstTy = op.getResult().getType(); - auto dstShape = dstTy.getShape(); - assert(dstShape.size() <= 2 && - "Unexpected rank of ConvertLayout(shared->blocked)"); - auto srcSharedLayout = cast(srcTy.getEncoding()); - auto dstLayout = dstTy.getEncoding(); - auto inOrd = getOrder(srcSharedLayout); - - auto smemObj = getSharedMemoryObjectFromStruct( - loc, adaptor.getSrc(), - typeConverter->convertType(srcTy.getElementType()), rewriter); - auto elemTy = typeConverter->convertType(dstTy.getElementType()); - - auto srcStrides = - getStridesFromShapeAndOrder(srcTy.getShape(), inOrd, loc, rewriter); - - SmallVector outVals = - loadSharedToDistributed(op.getResult(), op.getSrc(), smemObj, elemTy, - loc, rewriter, targetInfo); - - Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); - rewriter.replaceOp(op, result); - - return success(); - } - -private: - const TargetInfoBase &targetInfo; -}; +// XXX(Keren): A temporary knob to control the use of legacy MMA conversion +// because LinearLayout seems to have some performance issues. +constexpr bool useLegacyMMAConversion = false; struct ConvertLayoutOpConversion - : public ConvertOpToLLVMPattern { + : public ConvertOpToLLVMPattern { public: ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, @@ -119,23 +39,27 @@ struct ConvertLayoutOpConversion } LogicalResult - matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { RankedTensorType srcTy = op.getSrc().getType(); RankedTensorType dstTy = op.getType(); Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); if (isSupported(srcLayout, dstLayout)) { - return lowerDistributedToDistributed(op, adaptor, rewriter); + return lowerDistributedToDistributed(op, adaptor, rewriter, targetInfo); } return failure(); } private: bool isSupported(Attribute srcLayout, Attribute dstLayout) const { - return isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout) && + return isa( + srcLayout) && + isa( + dstLayout) && !isLayoutMmaV1(srcLayout) && !isLayoutMmaV1(dstLayout); } + // shared memory rd/st for blocked or mma layout with data padding void processReplica(Location loc, ConversionPatternRewriter &rewriter, bool stNotRd, RankedTensorType type, @@ -191,10 +115,9 @@ struct ConvertLayoutOpConversion shapePerCTA); Value offset = linearize(rewriter, loc, multiDimOffsetWrapped, paddedRepShape, outOrd); - auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + auto elemPtrTy = smemBase.getType(); Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset); auto vecTy = vec_ty(llvmElemTy, vec); - ptr = bitcast(ptr, ptr_ty(rewriter.getContext(), 3)); if (stNotRd) { Value valVec = undef(vecTy); for (unsigned v = 0; v < vec; ++v) { @@ -225,9 +148,9 @@ struct ConvertLayoutOpConversion // blocked/mma -> blocked/mma. // Data padding in shared memory to avoid bank conflict. LogicalResult - lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + lowerDistributedToDistributed(ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { auto loc = op.getLoc(); auto typeConverter = getTypeConverter(); RankedTensorType srcTy = op.getSrc().getType(); @@ -245,9 +168,7 @@ struct ConvertLayoutOpConversion } Value smemBase = - LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); - auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); - smemBase = bitcast(smemBase, elemPtrTy); + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); auto shape = dstTy.getShape(); unsigned rank = dstTy.getRank(); SmallVector numReplicates(rank); @@ -272,13 +193,15 @@ struct ConvertLayoutOpConversion inNumCTAs[d] = ceil(shapePerCTA[d], inPerCTA); outNumCTAs[d] = ceil(shapePerCTA[d], outPerCTA); } + // Potentially we need to store for multiple CTAs in this replication auto accumNumReplicates = product(numReplicates); auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - unsigned inVec = 0; - unsigned outVec = 0; - auto origRepShape = getRepShapeForCvtLayout(op); - auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec); + auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy); + unsigned inVec = scratchConfig.inVec; + unsigned outVec = scratchConfig.outVec; + const auto &paddedRepShape = scratchConfig.paddedRepShape; + const auto &origRepShape = scratchConfig.repShape; unsigned outElems = getTotalElemsPerThread(dstTy); auto outOrd = getOrder(dstLayout); @@ -288,18 +211,12 @@ struct ConvertLayoutOpConversion auto multiDimRepId = getMultiDimIndex(repId, numReplicates, outOrd); if (repId != 0) { - barrier(); - } - auto successful = targetInfo.processReplicaUsingStMatrix( - rewriter, loc, smemBase, vals, srcTy, - getTypeConverter()->convertType(srcTy.getElementType()), - paddedRepShape, origRepShape, outOrd, accumNumReplicates); - if (!successful) { - processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, - multiDimRepId, inVec, paddedRepShape, origRepShape, - outOrd, vals, smemBase); + insertBarrier(rewriter, op); } - barrier(); + processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, + multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd, + vals, smemBase); + insertBarrier(rewriter, op); processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape, origRepShape, outOrd, outVals, smemBase); @@ -314,11 +231,466 @@ struct ConvertLayoutOpConversion private: const TargetInfoBase &targetInfo; }; + +struct ConvertLayoutOpBlockedToDotOpShortcutConversion + : public ConvertOpToLLVMPattern { + const TargetInfoBase &targetInfo; + explicit ConvertLayoutOpBlockedToDotOpShortcutConversion( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + + const auto &shape = op.getType().getShape(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + auto dstDotEncoding = dyn_cast(dstTy.getEncoding()); + if (!dstDotEncoding) + return failure(); + if (!isa(srcTy.getEncoding()) || + !isa(dstDotEncoding.getParent())) + return failure(); + if (cvtNeedsSharedMemory(srcTy, dstTy)) + return failure(); + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } +}; + +struct ConvertLayoutOpUsingLinearLayoutsConversion + : public ConvertOpToLLVMPattern { + const TargetInfoBase &targetInfo; + + // Set benefit to 2 so that this pattern applies before other convert-layout + // conversions. TODO(jlebar): Eventually we want this to be the only pattern. + explicit ConvertLayoutOpUsingLinearLayoutsConversion( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + PatternBenefit benefit = 2) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + + const auto &shape = op.getType().getShape(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + + auto conversion = minimalCvtLayout(srcTy, dstTy); + if (!conversion.has_value()) { + return rewriter.notifyMatchFailure( + op, "NYI. srcTy and/or dstTy don't implement LLs yet"); + } + + assert(to_vector(conversion->getInDimNames()) == + to_vector(conversion->getOutDimNames())); + auto dims = conversion->getInDimNames(); + if (llvm::is_contained(dims, str_attr("block"))) { + // Case 1: Transfer between values in different CTAs. + // This requires moving values through distributed shared memory. + return rewriter.notifyMatchFailure( + op, "NYI: Transfer between different CTAs"); + } else if (llvm::is_contained(dims, str_attr("warp"))) { + // Case 2: Transfer between values in the same CTA, in which case we move + // values through shared memory. + LinearLayout srcLayout = + *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); + } else if (llvm::is_contained(dims, str_attr("lane"))) { + // Case 3. Transfer between values in the same warp, in which case we try + // to move values using warp shuffles, though if the pattern is + // complicated enough we may fall back to using shared memory + // TODO(Keren): implement warp shuffle instead of using the general + // approach that uses shared memory + LinearLayout srcLayout = + *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); + } else if (llvm::is_contained(dims, str_attr("register"))) { + // Case 4. Transfer between values in the same thread, in which case we + // simply reorder the elements of adaptor.getSrc(). + return transferWithinThread(op, *conversion, adaptor, rewriter); + } else { + // The two layouts are equivalent. We should probably remove these in + // RemoveLayoutConversion. + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } + } + + LogicalResult + transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + StringAttr kRegister = str_attr("register"); + assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); + + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector outVals; + outVals.resize(conversion.getInDimSize(kRegister)); + for (int i = 0; i < conversion.getInDimSize(kRegister); i++) { + auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; + outVals[i] = inVals[srcIdx]; + } + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + + LogicalResult transferWithinBlock(ConvertLayoutOp op, + const LinearLayout &srcLayout, + const LinearLayout &dstLayout, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + + // TODO (Keren): Currently, we handle general mma/blocked/slice -> + // mma/blocked/slice conversions. + // The following tasks must be completed before we can remove the layoutIsOK + // check: + // 1. Support for AMD's MFMA and WMMA + std::function layoutIsOK = [&](Attribute layout) { + if (auto nvidiaMma = dyn_cast(layout)) { + if (useLegacyMMAConversion) { + return false; + } + return true; + } + if (auto dotOperand = dyn_cast(layout)) { + if (auto nvidiaMma = + dyn_cast(dotOperand.getParent())) { + if (product(getCTAsPerCGA(nvidiaMma)) > 1) { + return false; + } + if (useLegacyMMAConversion) { + return false; + } + // FIXME [Dot LL] + // Enabling LL path for buggy kWidth path + bool largeKWidth = + dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64; + return largeKWidth && nvidiaMma.isAmpere(); + } + } + if (isa(layout)) { + return true; + } + if (auto slice = dyn_cast(layout)) { + return layoutIsOK(slice.getParent()); + } + return false; + }; + if (!layoutIsOK(srcTy.getEncoding()) || !layoutIsOK(dstTy.getEncoding())) { + return failure(); + } + + assert(cvtNeedsSharedMemory(srcTy, dstTy)); + + SmallVector inVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + assert(!inVals.empty()); + + // We munge the input values by converting i (n<8) elements to i8 and + // pointers to i64. This is necessary because TargetInfo::loadDShared and + // storeDShared can't handle vectors of pointers or sub-byte elements. + auto elemTy = srcTy.getElementType(); + auto isSubByteInt = + elemTy.isInteger() && elemTy.getIntOrFloatBitWidth() < 8; + auto isPtr = isa(elemTy); + auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy); + if (isSubByteInt) + elemTy = IntegerType::get(elemTy.getContext(), 8); + else if (isPtr) + elemTy = IntegerType::get(elemTy.getContext(), 64); + auto llvmElemTy = getTypeConverter()->convertType(elemTy); + + // Munge input values + for (const auto &it : llvm::enumerate(inVals)) { + if (isSubByteInt) { + inVals[it.index()] = zext(llvmElemTy, it.value()); + } else if (isPtr) { + inVals[it.index()] = ptrtoint(llvmElemTy, it.value()); + } + } + + // Pretty sure this is the identity function ATM + // It'd be better to simply call `quotient({kBlock})` and + // remove kBlock from transferWithinBlockImpl + auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout); + auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout); + SmallVector outVals = + transferWithinBlockImpl(inVals, op, srcLayoutWithinBlock, + dstLayoutWithinBlock, adaptor, rewriter); + + // Unmunge output values + for (const auto &it : llvm::enumerate(outVals)) { + if (isSubByteInt) { + outVals[it.index()] = trunc(llvmElemTyOrig, it.value()); + } else if (isPtr) { + outVals[it.index()] = inttoptr(llvmElemTyOrig, it.value()); + } + } + + // FIXME [Dot LL] + // We know it's just for largeKWidth case in Ampere + // In this case, we need to pack the outputs into i32 + if (isa(dstTy.getEncoding())) { + auto concat = [&](Value a, Value b) { + return or_(zext(i32_ty, bitcast(a, i16_ty)), + shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); + }; + + SmallVector outVals32(outVals.size() / 2); + for (int i = 0; i < outVals32.size(); ++i) { + outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); + } + outVals = outVals32; + } + + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + + SmallVector + transferWithinBlockImpl(ArrayRef inVals, ConvertLayoutOp op, + const LinearLayout &srcLayout, + const LinearLayout &dstLayout, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = op.getContext(); + auto loc = op.getLoc(); + + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + StringAttr kOffset = str_attr("offset"); + StringAttr kIteration = str_attr("iteration"); + + Value threadId = getThreadId(rewriter, loc); + Value threadsPerWarp = i32_val(srcLayout.getInDimSize(kLane)); + Value laneId = urem(threadId, threadsPerWarp); + Value warpId = udiv(threadId, threadsPerWarp); + + auto scratchConfig = + getScratchConfigForCvt(op.getSrc().getType(), op.getType()); + auto tensorShapePerCTA = convertType(getShapePerCTA( + op.getSrc().getType().getEncoding(), op.getType().getShape())); + // Input dims: [offset, iteration, block] + // Output dims: dimN-1, dimN-2, ..., dim0, where N is obtained from repShape + LinearLayout sharedLayout = chooseShemLayoutForRegToRegConversion( + ctx, tensorShapePerCTA, scratchConfig.repShape, scratchConfig.order); + + // Layout for the store from registers to shared memory. + // + // Note: If two threads in the same warp write to the same shmem offset, the + // hardware resolves that without a stall or a bank conflict. Therefore we + // don't need to avoid duplicate writes. + // Input dims: [reg, lane, warp] + // Output dims: [offset, iteration] + std::optional shmemStoreLayout = + chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape, + scratchConfig.paddedRepShape, scratchConfig.order, + /*swizzleByteSize=*/0); + bool isStMatrix = shmemStoreLayout.has_value(); + if (!isStMatrix) { + shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout); + } + assert(shmemStoreLayout.has_value()); + + const int shmemAllocatedNumElems = + getNumScratchElements(scratchConfig.paddedRepShape); + assert(shmemStoreLayout->getOutDimSize(kOffset) <= shmemAllocatedNumElems); + + // Layout for the load from shmem to registers. + LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout); + + // Check that the `register` fully determines the `iteration`. That is, + // each thread does exactly the same reads and writes to shmem on each + // iteration, just with different input/output registers. + assert(shmemStoreLayout->sublayoutIsZero({kLane, kWarp, kBlock}, + {kIteration})); + assert( + shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); + + // iteration -> registers + SmallVector> inRegsForIter = + collectRegsForIter(ctx, *shmemStoreLayout); + SmallVector> outRegsForIter = + collectRegsForIter(ctx, shmemLoadLayout); + + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + auto sharedPtrTy = smemBase.getType(); + Type elemTy = inVals[0].getType(); + auto outSize = shmemLoadLayout.getInDimSize(kRegister); + auto iterations = sharedLayout.getInDimSize(kIteration); + assert(scratchConfig.inVec * iterations <= inVals.size()); + assert(scratchConfig.outVec * iterations <= outSize); + + // Check only one dimension has been padded. + // This means the difference between the padded shape and the original shape + // should only be in one dimension, specifically in + // `scratchConfig.order[0]`. + auto rank = scratchConfig.repShape.size(); + for (auto i = 0; i < rank; i++) { + if (i == scratchConfig.order[0]) { + continue; + } + assert(scratchConfig.repShape[i] == scratchConfig.paddedRepShape[i]); + } + auto paddedStride = scratchConfig.repShape[scratchConfig.order[0]]; + auto paddedSize = + scratchConfig.paddedRepShape[scratchConfig.order[0]] - paddedStride; + + // Linear layout function is split in two parts below: + // + // L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0) + // offset = regBase xor regIdx + // + // It is the same hack as what we've done in the emitIndices function to get + // around performance issues on AMD GPUs + auto getVecAddr = [&](LinearLayout &layout, Value ®Base, + int regSlice) -> Value { + auto regIdx = layout + .apply({{kRegister, regSlice}, + {kLane, 0}, + {kWarp, 0}, + {kBlock, 0}})[0] + .second; + Value offset = xor_(regBase, i32_val(regIdx)); + if (paddedSize > 0) { + assert(llvm::isPowerOf2_32(paddedStride)); + assert(llvm::isPowerOf2_32(paddedSize)); + auto rshiftVal = llvm::Log2_32(paddedStride); + auto lshiftVal = llvm::Log2_32(paddedSize); + offset = add(shl(lshr(offset, i32_val(rshiftVal)), i32_val(lshiftVal)), + offset); + } + auto vecAddr = gep(sharedPtrTy, elemTy, smemBase, offset); + vecAddr.setInbounds(true); + return vecAddr; + }; + + auto storeBase = applyLinearLayout(loc, rewriter, *shmemStoreLayout, + {{kRegister, i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, i32_val(0)}})[0] + .second; + auto loadBase = applyLinearLayout(loc, rewriter, shmemLoadLayout, + {{kRegister, i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, i32_val(0)}})[0] + .second; + // register idx -> Value + llvm::MapVector outVals; + for (int i = 0; i < iterations; i++) { + if (i != 0) + insertBarrier(rewriter, op); + + auto &inRegs = inRegsForIter[i]; + auto &outRegs = outRegsForIter[i]; + + // When using `stmatrix`, we can store `inVec` elements even if they are + // not contiguous + auto inVec = isStMatrix ? shmemStoreLayout->getNumConsecutiveInOut() + : scratchConfig.inVec; + for (int j = 0; j < inVals.size() / iterations; j += inVec) { + auto inRegSlice = inRegs[j]; + Value vecAddr = getVecAddr(*shmemStoreLayout, storeBase, inRegSlice); + SmallVector inValsVec; + for (int k = 0; k < inVec; k++) + inValsVec.push_back(inVals[inRegSlice + k]); + Value valsVec = packLLVector(loc, inValsVec, rewriter); + if (isStMatrix) { + targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec); + } else { + targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt, valsVec, + /*pred=*/true_val()); + } + } + + insertBarrier(rewriter, op); + + for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) { + auto outRegSlice = outRegs[j]; + auto vecAddr = getVecAddr(shmemLoadLayout, loadBase, outRegSlice); + Value valsVec = + targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt, + vec_ty(elemTy, scratchConfig.outVec), + /*pred=*/true_val()); + for (Value v : unpackLLVector(loc, valsVec, rewriter)) + outVals[outRegSlice++] = v; + } + } + + SmallVector outValsVec; + for (size_t i = 0; i < outVals.size(); i++) + outValsVec.push_back(outVals[i]); + return outValsVec; + } + + // Determine which registers are read/written in which iteration of the shmem + // transfer specified by `layout`. + SmallVector /*registers*/> + collectRegsForIter(MLIRContext *ctx, const LinearLayout &layout) const { + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + StringAttr kIteration = str_attr("iteration"); + + // The choice of iteration should be determined only by the register. That + // is, it should be correct to split the register dimension into iterations. + assert(layout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration})); + + LinearLayout sublayout = layout.sublayout({kRegister}, {kIteration}); + SmallVector> ret(sublayout.getOutDimSize(kIteration)); + for (int reg = 0; reg < sublayout.getInDimSize(kRegister); reg++) { + auto idx = sublayout.apply({{kRegister, reg}}); + ret[idx.begin()->second].push_back(reg); + } + return ret; + } +}; + } // namespace +void mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add( + typeConverter, targetInfo, benefit); +} + void mlir::triton::populateConvertLayoutOpToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit) { + // We prefer using the linear layout conversion, so it gets a higher benefit. + // Eventually the LL conversion will subsume all of the others and be the only + // one left. + mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern( + typeConverter, targetInfo, patterns, benefit.getBenefit() + 1); + patterns.add( + typeConverter, targetInfo, benefit); patterns.add(typeConverter, targetInfo, benefit); - patterns.add(typeConverter, targetInfo, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp index b7bd5fbc3..be2e6f584 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -12,7 +12,6 @@ using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getSizePerThread; using ::mlir::triton::gpu::getTotalElemsPerThread; -using ::mlir::triton::gpu::isaDistributedLayout; using ::mlir::triton::gpu::SharedEncodingAttr; SmallVector @@ -132,7 +131,7 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, } auto elemTy = typeConverter->convertType(aTensorTy.getElementType()); - Type ptrTy = ptr_ty(rewriter.getContext(), 3); + Type ptrTy = aSmem.base.getType(); SmallVector aPtrs(aNumPtr); for (int i = 0; i < aNumPtr; ++i) aPtrs[i] = gep(ptrTy, elemTy, aSmem.base, aOff[i]); @@ -198,7 +197,7 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, } auto elemTy = typeConverter->convertType(bTensorTy.getElementType()); - Type ptrTy = ptr_ty(rewriter.getContext(), 3); + Type ptrTy = bSmem.base.getType(); SmallVector bPtrs(bNumPtr); for (int i = 0; i < bNumPtr; ++i) bPtrs[i] = gep(ptrTy, elemTy, bSmem.base, bOff[i]); diff --git a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp index 690155ee5..1346cc143 100644 --- a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -44,7 +44,6 @@ void decomposeSplatOpToSharedLayoutConversion(ModuleOp module) { }); } -template void decomposeTensorCoreToDotLayoutConversion(ModuleOp module, ShortcutFn shortcutFn) { int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module); @@ -55,7 +54,7 @@ void decomposeTensorCoreToDotLayoutConversion(ModuleOp module, OpBuilder builder(cvtOp); auto srcType = cast(cvtOp.getSrc().getType()); auto dstType = cast(cvtOp.getType()); - auto srcMma = dyn_cast(srcType.getEncoding()); + auto srcMma = dyn_cast(srcType.getEncoding()); auto dstDotOp = dyn_cast(dstType.getEncoding()); if (srcMma && dstDotOp && !shortcutFn(srcType, dstType)) { @@ -76,12 +75,6 @@ void decomposeTensorCoreToDotLayoutConversion(ModuleOp module, }); } -template void decomposeTensorCoreToDotLayoutConversion< - triton::gpu::NvidiaMmaEncodingAttr>(ModuleOp, ShortcutFn); -template void - decomposeTensorCoreToDotLayoutConversion( - ModuleOp, ShortcutFn); - void decomposeBlockedToDotLayoutConversion(ModuleOp module) { int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module); int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module); @@ -90,17 +83,22 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) { OpBuilder builder(cvtOp); auto srcType = cast(cvtOp.getSrc().getType()); auto dstType = cast(cvtOp.getType()); + if (!cvtNeedsSharedMemory(srcType, dstType)) + return; auto srcBlocked = dyn_cast(srcType.getEncoding()); auto dstDotOp = dyn_cast(dstType.getEncoding()); if (srcBlocked && dstDotOp) { + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); auto tmpType = MemDescType::get( dstType.getShape(), dstType.getElementType(), triton::gpu::SharedEncodingAttr::get( module.getContext(), dstDotOp, srcType.getShape(), srcBlocked.getOrder(), srcBlocked.getCTALayout(), - srcType.getElementType())); + srcType.getElementType()), + sharedMemorySpace); auto tmp = builder.create( cvtOp.getLoc(), tmpType, cvtOp.getSrc()); addAttrs(tmp, cvtOp->getAttrs()); diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 0287207be..8ee166866 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -41,36 +41,60 @@ SmallVector reorderValues(const SmallVector &values, Type inType, if (inBitWidth == ouBitWidth) return values; if (inBitWidth == 16 && ouBitWidth == 32) { + // Register layout conversion: + // + // [0, 1], [4, 5] ⟶ [0], [1], [4], [5] + // [2, 3], [6, 7] [2], [3], [6], [7] + // + // Original access order: + // + // [0, 1], [2, 3], [4, 5], [6, 7] + // + // Transformed access order: + // + // [0], [2], [1], [3], [4], [6], [5], [7] SmallVector ret; for (unsigned i = 0; i < values.size(); i += 8) { ret.push_back(values[i]); - ret.push_back(values[i + 1]); - ret.push_back(values[i + 4]); - ret.push_back(values[i + 5]); ret.push_back(values[i + 2]); + ret.push_back(values[i + 1]); ret.push_back(values[i + 3]); + ret.push_back(values[i + 4]); ret.push_back(values[i + 6]); + ret.push_back(values[i + 5]); ret.push_back(values[i + 7]); } return ret; } if (inBitWidth == 8 && ouBitWidth == 16) { + // Register layout conversion: + // + // [0, 1, 2, 3], [8, 9, 10, 11] ⟶ [0, 1], [2, 3], [8, 9], [10, 11] + // [4, 5, 6, 7], [12, 13, 14, 15] [4, 5], [6, 7], [12, 13], [14, 15] + // + // Original access order: + // + // [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15] + // + // Transformed access order: + // + // [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15] SmallVector ret; for (unsigned i = 0; i < values.size(); i += 16) { - ret.push_back(values[i + 0]); + ret.push_back(values[i]); ret.push_back(values[i + 1]); - ret.push_back(values[i + 2]); - ret.push_back(values[i + 3]); - ret.push_back(values[i + 8]); - ret.push_back(values[i + 9]); - ret.push_back(values[i + 10]); - ret.push_back(values[i + 11]); ret.push_back(values[i + 4]); ret.push_back(values[i + 5]); + ret.push_back(values[i + 2]); + ret.push_back(values[i + 3]); ret.push_back(values[i + 6]); ret.push_back(values[i + 7]); + ret.push_back(values[i + 8]); + ret.push_back(values[i + 9]); ret.push_back(values[i + 12]); ret.push_back(values[i + 13]); + ret.push_back(values[i + 10]); + ret.push_back(values[i + 11]); ret.push_back(values[i + 14]); ret.push_back(values[i + 15]); } @@ -299,7 +323,7 @@ struct MulhiUIOpConversion LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp(rewriter, op, funcName, funcType); return { - rewriter.create(loc, funcOp, operands[0]).getResult()}; + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()}; } protected: @@ -327,7 +351,7 @@ struct ExternElementwiseOpConversion LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp( rewriter, op, funcName, funcType, op.getLibname(), op.getLibpath()); return { - rewriter.create(loc, funcOp, operands[0]).getResult()}; + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()}; } }; @@ -610,10 +634,9 @@ struct IndexCastOpLowering if (targetBits == sourceBits) return {operands[0][0]}; if (targetBits < sourceBits) - return {rewriter.replaceOpWithNewOp(op, elemTy, - operands[0][0])}; - return { - rewriter.replaceOpWithNewOp(op, elemTy, operands[0][0])}; + return { + rewriter.create(op.getLoc(), elemTy, operands[0][0])}; + return {rewriter.create(op.getLoc(), elemTy, operands[0][0])}; } }; diff --git a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp index 47f40ebec..8ffa9517e 100644 --- a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -18,8 +18,9 @@ using namespace mlir::triton; /// information. struct FuncOpConversion : public ConvertOpToLLVMPattern { FuncOpConversion(LLVMTypeConverter &converter, int numWarps, - PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit), numWarps(numWarps) {} + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), numWarps(numWarps), + targetInfo(targetInfo) {} /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument @@ -38,12 +39,14 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { } triton::FuncOp amendFuncOp(triton::FuncOp funcOp, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { // Push back a variable that indicates the current stack pointer of shared // memory to the function arguments. auto loc = funcOp.getLoc(); auto ctx = funcOp->getContext(); - auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), + targetInfo.getSharedAddressSpace()); // 1. Modify the function type to add the new argument. auto funcTy = funcOp.getFunctionType(); auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); @@ -67,20 +70,59 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { return amendedFuncOp; } + // Map the MLIR attribute `tt.nv_tma_desc` to the appropriate LLVM and NVVM + // attributes. + static void handleByvalTmaDescArgs(LLVM::LLVMFuncOp &llvmFuncOp) { + const bool isKernel = LLVM::isKernel(llvmFuncOp); + for (unsigned i = 0; i < llvmFuncOp.getNumArguments(); ++i) { + const auto attrs = llvmFuncOp.getArgAttrDict(i); + if (!attrs) { + continue; + } + + for (const auto &attr : attrs) { + if (attr.getName() == "tt.nv_tma_desc") { + const auto i32_type = + mlir::IntegerType::get(llvmFuncOp.getContext(), 32); + assert(attr.getValue() == mlir::IntegerAttr::get(i32_type, 1)); + assert(isKernel && + "tt.nv_tma_desc is not supported for device functions"); + + // See + // https://github.com/google/jax/blob/main/jaxlib/mosaic/gpu/passes.cc + mlir::BlockArgument arg = llvmFuncOp.getArgument(i); + const auto byteType = + mlir::IntegerType::get(llvmFuncOp.getContext(), 8); + const auto arrayType = mlir::LLVM::LLVMArrayType::get( + llvmFuncOp.getContext(), byteType, 128); + llvmFuncOp.setArgAttr(i, "llvm.byval", + mlir::TypeAttr::get(arrayType)); + llvmFuncOp.setArgAttr(i, "nvvm.grid_constant", + mlir::UnitAttr::get(llvmFuncOp.getContext())); + llvmFuncOp.setArgAttr(i, "llvm.align", + mlir::IntegerAttr::get(i32_type, 64)); + } + } + } + } + LogicalResult matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Prevent LLVM's inliner to inline this function auto amendedFuncOp = funcOp; if (!LLVM::isKernel(funcOp)) - amendedFuncOp = amendFuncOp(funcOp, rewriter); + amendedFuncOp = amendFuncOp(funcOp, rewriter, targetInfo); - LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( - amendedFuncOp, rewriter, *getTypeConverter()); - if (!newFuncOp) { + FailureOr maybeNewFuncOp = + mlir::convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter, + *getTypeConverter()); + if (failed(maybeNewFuncOp)) { return failure(); } + LLVM::LLVMFuncOp newFuncOp = *maybeNewFuncOp; + auto ctx = funcOp->getContext(); if (LLVM::isKernel(funcOp)) { @@ -97,22 +139,27 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { rewriter.eraseOp(amendedFuncOp); newFuncOp.setLinkage(LLVM::Linkage::Internal); } - // Set an attribute for maxntidx, it could be used in latter LLVM codegen + // Set an attribute for reqntidx, it could be used in latter LLVM codegen // for `nvvm.annotation` metadata. - newFuncOp->setAttr("nvvm.maxntid", + newFuncOp->setAttr("nvvm.reqntid", rewriter.getDenseI32ArrayAttr(32 * numWarps)); rewriter.eraseOp(funcOp); + + // Add attributes for by-value TMA descriptor args (nvidia) + handleByvalTmaDescArgs(newFuncOp); + return success(); } private: int numWarps{0}; + const TargetInfoBase &targetInfo; }; } // namespace void mlir::triton::populateFuncOpConversionPattern( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, - PatternBenefit benefit) { - patterns.add(typeConverter, numWarps, benefit); + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, numWarps, targetInfo, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp index acf940b3e..ed4837fc1 100644 --- a/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp @@ -182,7 +182,7 @@ struct HistogramOpConversion // generate the right layout. Currently the warp level histogram generates // data in the default blocked layout. Value baseSharedMemPtr = - LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); auto dstType = op.getType(); Attribute dstEncoding = dstType.getEncoding(); auto indices = emitIndices(op.getLoc(), rewriter, targetInfo, dstEncoding, diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 12ab6684c..1a0c115a9 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -23,8 +23,6 @@ void lowerDistributedToShared(Location loc, Value src, Value dst, const TargetInfoBase &targetInfo) { auto srcTy = cast(src.getType()); auto dstTy = cast(dst.getType()); - auto dstShapePerCTA = triton::gpu::getShapePerCTA(dstTy); - auto srcLayout = srcTy.getEncoding(); auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); assert(srcTy.getShape().size() <= 2 || (srcTy.getShape().size() == 3 && outOrd[2] == 0) && @@ -32,12 +30,10 @@ void lowerDistributedToShared(Location loc, Value src, Value dst, auto elemTy = typeConverter->convertType(srcTy.getElementType()); auto smemBase = smemObj.getBase(); - int32_t elemSize = elemTy.getIntOrFloatBitWidth(); - unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); auto dstStrides = smemObj.getStrides(); auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); - storeDistributedToShared(src, inVals, dstStrides, dst, smemBase, elemTy, loc, - rewriter, targetInfo); + storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides, + loc, rewriter, targetInfo); } struct LocalAllocOpConversion @@ -51,9 +47,11 @@ struct LocalAllocOpConversion LogicalResult matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (!op.isSharedMemoryAlloc()) + return failure(); Location loc = op->getLoc(); Value smemBase = - LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); auto resultTy = cast(op.getType()); auto typeConverter = getTypeConverter(); auto sharedLayout = @@ -103,6 +101,135 @@ struct LocalDeallocOpConversion } }; +struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { +public: + LocalLoadOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + // FIXME [Dot LL] + // Do for all DotOperandEncodingAttr once we have LLs for all of them + static bool isSupportedDotOpLayout(Attribute layout) { + if (auto dot = dyn_cast(layout)) { + if (auto mma = dyn_cast(dot.getParent())) { + return mma.isAmpere() && dot.getKWidth() == 8; + } + if (isa(dot.getParent())) + return true; + } + return false; + }; + + LogicalResult + matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MemDescType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + if (isa(srcLayout) && + (isa( + dstLayout) || + isSupportedDotOpLayout(dstLayout))) { + return lowerSharedToDistributed(op, adaptor, getTypeConverter(), + rewriter); + } + if (isa(dstLayout) && + isa( + cast(dstLayout).getParent())) { + return lowerSharedToDotOpFMA(op, adaptor, getTypeConverter(), rewriter); + } + return failure(); + } + +private: + LogicalResult + lowerSharedToDotOpFMA(LocalLoadOp op, LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + RankedTensorType dstTy = op.getType(); + Attribute dstLayout = dstTy.getEncoding(); + auto dotLayout = cast(dstLayout); + auto blockedLayout = cast( + cast(dstLayout).getParent()); + auto thread = getThreadId(rewriter, loc); + Value res = SharedToDotOperandFMA::convertLayout( + dotLayout.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout, + thread, loc, getTypeConverter(), rewriter); + rewriter.replaceOp(op, res); + return success(); + } + LogicalResult + lowerSharedToDistributed(LocalLoadOp op, LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getResult().getType(); + auto dstShape = dstTy.getShape(); + auto srcSharedLayout = cast(srcTy.getEncoding()); + auto dstLayout = dstTy.getEncoding(); + assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstLayout)) && + "Unexpected rank of ConvertLayout(shared->distributed)"); + auto inOrd = getOrder(srcSharedLayout); + + auto smemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getSrc(), + typeConverter->convertType(srcTy.getElementType()), rewriter); + auto elemLlvmTy = typeConverter->convertType(dstTy.getElementType()); + + SmallVector outVals = loadSharedToDistributed( + dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo); + + // FIXME [Dot LL] + // Ampere case + // In this case, we need to pack the outputs into i32 + if (auto dotOp = dyn_cast(dstTy.getEncoding())) { + if (auto parent = dyn_cast(dotOp.getParent())) { + if (parent.isAmpere()) { + if (elemLlvmTy.isInteger(8)) { + auto concat = [&](Value a1, Value a2, Value a3, Value a4) { + return or_( + or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))), + or_(shl(zext(i32_ty, a3), i32_val(16)), + shl(zext(i32_ty, a4), i32_val(24)))); + }; + SmallVector outVals32(outVals.size() / 4); + for (int i = 0; i < outVals32.size(); ++i) { + outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1], + outVals[4 * i + 2], outVals[4 * i + 3]); + } + outVals = outVals32; + } else { + assert(elemLlvmTy.isBF16() && "Unexpected element type"); + auto concat = [&](Value a, Value b) { + return or_(zext(i32_ty, bitcast(a, i16_ty)), + shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16))); + }; + + SmallVector outVals32(outVals.size() / 2); + for (int i = 0; i < outVals32.size(); ++i) { + outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]); + } + outVals = outVals32; + } + } + } + } + + Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + struct LocalStoreOpConversion : public ConvertOpToLLVMPattern { public: @@ -141,5 +268,6 @@ void mlir::triton::populateMemoryOpToLLVMPattern( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, targetInfo, benefit); patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); patterns.add(typeConverter, targetInfo, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp index 32c7835c2..5cb27bb48 100644 --- a/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp @@ -44,7 +44,10 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { return success(); } + assert(op.getNumOperands() == op.getIsSigned().size()); + for (size_t i = 0; i < op.getNumOperands(); i++) { + bool isSigned = op.getIsSigned()[i] > 0; // Elements of the tensor that are resident in this GPU thread. auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); @@ -76,7 +79,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { if (!elems.empty()) { printTensor(op.getPrefix(), /*operand=*/i, /*numOperands=*/op.getNumOperands(), elems, pid, indices, - dimWidths, op.getHex(), rewriter); + dimWidths, op.getHex(), rewriter, isSigned); } } rewriter.eraseOp(op); @@ -87,7 +90,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { ArrayRef elems, std::array pid, ArrayRef> indices, ArrayRef dimWidths, bool hex, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter, bool isSigned) const { assert(!elems.empty()); assert(elems.size() == indices.size()); assert(dimWidths.size() == indices.front().size()); @@ -151,7 +154,8 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { } auto elem = elems[i]; - os << getFormatSubstr(elem, hex); + + os << getFormatSubstr(elem, hex, /*width=*/std::nullopt, isSigned); printfOperands.push_back(elem); // It's the same format string each iteration, but it's a lot easier if we @@ -169,8 +173,10 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { } std::string getFormatSubstr(Value value, bool hex = false, - std::optional width = std::nullopt) const { + std::optional width = std::nullopt, + bool isSigned = false) const { Type type = value.getType(); + // If the `value` is a pointer, just return %p. if (isa(type)) { return "%p"; } @@ -190,23 +196,15 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { std::string prefix = "%"; if (width.has_value()) { prefix += std::to_string(*width); - } else if (hex) { - prefix += "0"; - prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4); } if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { return prefix + "f"; - } else if (type.isSignedInteger()) { - if (type.getIntOrFloatBitWidth() == 64) - return prefix + "lli"; - else - return prefix + "i"; - } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { + } else if (type.isInteger()) { if (type.getIntOrFloatBitWidth() == 64) - return prefix + "llu"; + return prefix + (isSigned ? "lli" : "llu"); else - return prefix + "u"; + return prefix + (isSigned ? "i" : "u"); } assert(false && "not supported type"); return ""; diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 4d036c21a..966f6d31c 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -1,10 +1,7 @@ #include "ReduceScanCommon.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Support/LLVM.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include using namespace mlir; using namespace mlir::triton; @@ -12,6 +9,7 @@ using namespace mlir::triton; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getThreadOrder; using ::mlir::triton::gpu::getTotalElemsPerThread; namespace { @@ -49,10 +47,10 @@ struct ReduceOpConversion } // Compute a shared memory base per operand. - auto smemShape = helper.getScratchConfig(); + auto smemShape = helper.getScratchRepShape(); SmallVector smemBases = - getSmemBases(op, product(smemShape), rewriter); + getSmemBases(op, product(smemShape), rewriter, targetInfo); storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter); @@ -80,36 +78,16 @@ struct ReduceOpConversion private: const TargetInfoBase &targetInfo; - void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, - SmallVector &acc, ValueRange cur, bool isFirst) const { - if (isFirst) { - acc = SmallVector(cur.begin(), cur.end()); - return; - } - - // Create a new copy of the reduce block, and inline it - Block *currentBlock = rewriter.getBlock(); - Region &parent = *currentBlock->getParent(); - rewriter.cloneRegionBefore(combineOp, &parent.front()); - auto &newReduce = parent.front(); - auto returnOp = dyn_cast(newReduce.getTerminator()); - - llvm::SmallVector combineArgs(2 * acc.size()); - for (unsigned i = 0; i < acc.size(); ++i) { - combineArgs[i] = acc[i]; - combineArgs[acc.size() + i] = cur[i]; + void accumulate(Location loc, ConversionPatternRewriter &rewriter, + Region &combineOp, SmallVector &acc, ValueRange cur, + Value pred = {}) const { + auto results = applyCombineOp(loc, rewriter, combineOp, acc, cur, pred); + if (acc.size() < results.size()) { + acc.resize(results.size()); } - - rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), - combineArgs); - - auto results = returnOp.getResult(); for (unsigned i = 0; i < acc.size(); ++i) { acc[i] = results[i]; } - - // Delete the terminator, which is no longer used - rewriter.eraseOp(returnOp); } SmallVector> @@ -165,7 +143,7 @@ struct ReduceOpConversion SmallVector key = offsets[i]; key[op.getAxis()] = 0; bool isFirst = accs.find(key) == accs.end(); - accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); + accumulate(op.getLoc(), rewriter, *combineOp, accs[key], srcValues[i]); if (isFirst) indices[key] = srcIndices[i]; } @@ -175,17 +153,22 @@ struct ReduceOpConversion // region and the accumulator values as source. void warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce, unsigned interleave) const { - auto success = - targetInfo.warpReduce(rewriter, loc, acc, op, numLaneToReduce); + unsigned numLaneToReduce, unsigned interleave, + Value pred = {}) const { + auto success = targetInfo.warpReduce(rewriter, loc, acc, op, + numLaneToReduce, interleave); if (success) return; + + auto mod = op->getParentOfType(); + unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) { SmallVector shfl(acc.size()); for (unsigned i = 0; i < acc.size(); ++i) { shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave); } - accumulate(rewriter, op.getCombineOp(), acc, shfl, false); + accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, shfl, pred); } } @@ -278,11 +261,11 @@ struct ReduceOpConversion Value laneId = urem(threadId, warpSize); auto srcShape = helper.getSrcShape(); unsigned axis = op.getAxis(); - auto smemShape = helper.getScratchConfig(); + auto smemShape = helper.getScratchRepShape(); auto threadsPerWarp = triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape); - auto order = getOrder(srcLayout); + auto order = getThreadOrder(srcLayout); SmallVector multiDimLaneId = delinearize(rewriter, loc, laneId, threadsPerWarp, order); Value laneIdAxis = multiDimLaneId[axis]; @@ -304,8 +287,8 @@ struct ReduceOpConversion linearize(rewriter, loc, writeIdx, smemShape, smemOrder); for (unsigned i = 0; i < op.getNumOperands(); ++i) { auto elemTy = getElementType(op, i); - Value writePtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, - smemBases[i], writeOffset); + Value writePtr = + gep(smemBases[i].getType(), elemTy, smemBases[i], writeOffset); targetInfo.storeShared(rewriter, loc, writePtr, acc[i], laneZero); } } @@ -318,7 +301,7 @@ struct ReduceOpConversion ConversionPatternRewriter &rewriter) const { triton::ReduceOp op = helper.getOperation(); auto srcLayout = helper.getSrcLayout(); - auto smemShape = helper.getScratchConfig(); + auto smemShape = helper.getScratchRepShape(); unsigned elems = product(smemShape); unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData(); Location loc = op.getLoc(); @@ -339,19 +322,20 @@ struct ReduceOpConversion SmallVector acc(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { auto elemTy = getElementType(op, i); - Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, - smemBases[i], readOffset); - acc[i] = targetInfo.loadShared(rewriter, loc, getTypeConverter(), - readPtr, elemTy, threadIsNeeded); + Value readPtr = + gep(smemBases[i].getType(), elemTy, smemBases[i], readOffset); + acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy, + threadIsNeeded); } - warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */); + warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */, + threadIsNeeded); // only the first thread in each sizeInterWarps is writing Value writeOffset = readOffset; SmallVector writePtrs(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { auto elemTy = getElementType(op, i); - writePtrs[i] = gep(ptr_ty(rewriter.getContext(), 3), elemTy, - smemBases[i], writeOffset); + writePtrs[i] = + gep(smemBases[i].getType(), elemTy, smemBases[i], writeOffset); } Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); @@ -412,8 +396,8 @@ struct ReduceOpConversion } Value readOffset = linearize(rewriter, loc, readIdx, smemShape, smemOrder); - Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, - smemBases[i], readOffset); + Value readPtr = + gep(smemBases[i].getType(), elemTy, smemBases[i], readOffset); resultVals[j] = load(elemTy, readPtr); } diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h b/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h index 3130001cc..a35d52776 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h +++ b/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h @@ -4,15 +4,14 @@ // TODO: refactor so that it doesn't fail if Allocation.h // is included after utility.h (due to conflict in `store` macro // and -#include "triton/Analysis/Allocation.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" -#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" // #include "mlir/IR/TypeUtilities.h" -#include "triton/Analysis/AxisInfo.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -#include +#include #include #define DEBUG_TYPE "ttgpu_to_llvm" @@ -32,6 +31,91 @@ namespace ttng = ::mlir::triton::nvidia_gpu; namespace mlir::triton { class ReduceOp; class ScanOp; + +inline SmallVector +inlineCombineBlock(ConversionPatternRewriter &rewriter, Block &combineBlock, + Block *insertionBlock, Block::iterator insertionPoint, + ValueRange combineArgs) { + auto returnOp = combineBlock.getTerminator(); + rewriter.inlineBlockBefore(&combineBlock, insertionBlock, insertionPoint, + combineArgs); + + auto results = SmallVector(returnOp->getOperands()); + + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + return results; +} + +inline SmallVector applyCombineOp(Location loc, + ConversionPatternRewriter &rewriter, + Region &combineOp, ValueRange acc, + ValueRange cur, Value pred = {}) { + // Allows for passing an unitialized acc and use cur as the neutral element + if (acc.size() == 0) { + return cur; + } + assert(cur.size() == acc.size()); + + // Create a new copy of the combine block, and try to speculatively inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + + rewriter.cloneRegionBefore(combineOp, parent, + std::next(currentBlock->getIterator())); + Block &newCombine = *currentBlock->getNextNode(); + + llvm::SmallVector combineArgs(2 * acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; + } + + auto isRegionSpeculatable = + std::all_of(newCombine.begin(), newCombine.end(), + [](auto &op) { return isSpeculatable(&op); }); + + if (!pred || isRegionSpeculatable) { + // Fast path, region has no side effects so we can unconditionally execute + return inlineCombineBlock(rewriter, newCombine, currentBlock, + rewriter.getInsertionPoint(), combineArgs); + } + + // Slow case, create an if to only execute region when pred is true + // #currentBlock + // if (pred) { + // #newCombine + // results = combineOp(cur, acc) + // yield results + // } else { + // yield undef + // } + // #thenBlock + Block *thenBlock = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + + auto returnOp = newCombine.getTerminator(); + auto results = SmallVector(returnOp->getOperands()); + + rewriter.setInsertionPointToEnd(currentBlock); + SmallVector thenBlockArgs; + thenBlockArgs.reserve(results.size()); + for (auto result : results) { + auto ty = result.getType(); + auto undef = rewriter.create(loc, ty); + thenBlockArgs.push_back(undef); + thenBlock->addArgument(ty, loc); + } + rewriter.create(loc, pred, &newCombine, combineArgs, + thenBlock, thenBlockArgs); + + // Split a block after the call. + rewriter.setInsertionPointToEnd(&newCombine); + rewriter.replaceOpWithNewOp(returnOp, thenBlock, results); + rewriter.setInsertionPointToStart(thenBlock); + return SmallVector(thenBlock->getArguments()); +} + } // namespace mlir::triton template @@ -53,7 +137,8 @@ class ConvertTritonGPUReduceScanToLLVMPattern // Helper to compute the smem bases in both reductions and scans SmallVector getSmemBases(SourceOp op, unsigned elems, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { auto loc = op.getLoc(); // indices will store the index of the op operands in descending order // of their bitwidths @@ -66,12 +151,13 @@ class ConvertTritonGPUReduceScanToLLVMPattern }); // Assign base index to each operand in their order in indices std::map indexToBase; - indexToBase[indices[0]] = - LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + auto basePtr = + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + indexToBase[indices[0]] = basePtr; for (unsigned i = 1; i < op.getNumOperands(); ++i) { - indexToBase[indices[i]] = gep( - ptr_ty(rewriter.getContext(), 3), getElementType(op, indices[i - 1]), - indexToBase[indices[i - 1]], i32_val(elems)); + indexToBase[indices[i]] = + gep(basePtr.getType(), getElementType(op, indices[i - 1]), + indexToBase[indices[i - 1]], i32_val(elems)); } // smemBases[k] is the base pointer for the k-th operand SmallVector smemBases(op.getNumOperands()); diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp index 675bf5a34..969b227c8 100644 --- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -1,5 +1,3 @@ -#include - #include "ReduceScanCommon.h" #include "mlir/Support/LLVM.h" #include "triton/Analysis/Utility.h" @@ -16,37 +14,13 @@ using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::getTotalElemsPerThread; // apply combine region to acc and cur and accumulate it into acc -// TODO(Lezcano) This is now duplicated with ReduceOpConversion::reduce. -// Deduplicate -static SmallVector accumulate(ConversionPatternRewriter &rewriter, - Region &combineOp, ValueRange acc, - ValueRange cur) { - // Allows for passing an unitialized acc and use cur as the neutral element - if (acc.size() == 0) { - return cur; - } - assert(cur.size() == acc.size()); - // Create a new copy of the reduce block, and inline it - Block *currentBlock = rewriter.getBlock(); - Region &parent = *currentBlock->getParent(); - rewriter.cloneRegionBefore(combineOp, &parent.front()); - auto &newScan = parent.front(); - auto returnOp = dyn_cast(newScan.getTerminator()); - - SmallVector combineArgs(2 * acc.size()); - for (unsigned i = 0; i < acc.size(); ++i) { - combineArgs[i] = acc[i]; - combineArgs[acc.size() + i] = cur[i]; - } - - rewriter.inlineBlockBefore(&newScan, &*rewriter.getInsertionPoint(), - combineArgs); - SmallVector results; - llvm::transform(returnOp.getResult(), std::back_inserter(results), - [&](Value res) { return rewriter.getRemappedValue(res); }); - // Delete the terminator, which is no longer used - rewriter.eraseOp(returnOp); - return results; +static SmallVector accumulate(ScanLoweringHelper &helper, + ConversionPatternRewriter &rewriter, + ValueRange acc, ValueRange cur, + Value pred = {}) { + auto loc = helper.getLoc(); + auto &combineOp = helper.getCombineOp(); + return applyCombineOp(loc, rewriter, combineOp, acc, cur, pred); } // Scan a contiguous elements within a thread and update `srcValues` in place. @@ -66,8 +40,8 @@ scanThreadContiguousElements(SmallVector> &srcValues, unsigned accIndex = (srcIndex % stride) + ((srcIndex / stride) / scanElementsPerThreads) * stride; - accs[accIndex] = accumulate(rewriter, helper.getCombineOp(), accs[accIndex], - srcValues[srcIndex]); + accs[accIndex] = + accumulate(helper, rewriter, accs[accIndex], srcValues[srcIndex]); srcValues[srcIndex] = accs[accIndex]; } } @@ -95,14 +69,14 @@ static void warpScan(SmallVector> &srcValues, for (unsigned j = 0; j < acc.size(); ++j) { shfl[j] = targetInfo.shuffleUp(rewriter, loc, acc[j], i * threadStride); } + Value mask = icmp_sge(laneIdAxis, i32_val(i)); SmallVector tempAcc = - accumulate(rewriter, helper.getCombineOp(), shfl, acc); - Value mask = icmp_slt(laneIdAxis, i32_val(i)); + accumulate(helper, rewriter, shfl, acc, mask); for (unsigned j = 0; j < acc.size(); ++j) { - acc[j] = select(mask, acc[j], tempAcc[j]); + acc[j] = select(mask, tempAcc[j], acc[j]); } } - srcValues[srcIndex] = acc; + srcValues[srcIndex] = std::move(acc); } } @@ -137,8 +111,8 @@ static void storeWarpAccumulator(SmallVector> &srcValues, Value index = add(parallelLaneId, mul(warpId, i32_val(numParallelLane))); index = add(index, i32_val(chunkId * numParallelLane * axisNumWarps)); for (unsigned i = 0; i < lastElement.size(); ++i) { - Value writePtr = gep(ptr_ty(rewriter.getContext(), 3), smemTypes[i], - smemBases[i], index); + Value writePtr = + gep(smemBases[i].getType(), smemTypes[i], smemBases[i], index); targetInfo.storeShared(rewriter, loc, writePtr, lastElement[i], mask); } chunkId++; @@ -154,8 +128,8 @@ static void AddPartialReduce(SmallVector> &srcValues, ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo, ScanLoweringHelper &helper, - SmallVector smemBases, - SmallVector smemTypes, Value warpId, + ArrayRef smemBases, + ArrayRef smemTypes, Value warpId, Value laneIdAxis, Value parallelLaneId) { Location loc = helper.getLoc(); unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); @@ -164,9 +138,9 @@ static void AddPartialReduce(SmallVector> &srcValues, unsigned elementStride = helper.getAxisElementStride(); unsigned threadStride = helper.getAxisThreadStride(); unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); - Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); - Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); - Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); + Value maskNotFirstWarp = icmp_ne(warpId, i32_val(0)); + Value maskNotFirstLane = icmp_ne(laneIdAxis, i32_val(0)); + Value maskNotFirstThread = or_(maskNotFirstWarp, maskNotFirstLane); struct Accumulator { SmallVector acc; SmallVector maskedAcc; @@ -202,8 +176,7 @@ static void AddPartialReduce(SmallVector> &srcValues, SmallVector partialReduce(helper.getNumOperands()); for (unsigned j = 0; j < helper.getNumOperands(); ++j) { auto elemTy = smemTypes[j]; - Value ptr = - gep(ptr_ty(rewriter.getContext(), 3), elemTy, smemBases[j], index); + Value ptr = gep(smemBases[j].getType(), elemTy, smemBases[j], index); partialReduce[j] = load(elemTy, ptr); } @@ -212,22 +185,24 @@ static void AddPartialReduce(SmallVector> &srcValues, accumulator.maskedAcc = partialReduce; continue; } - accumulator.acc = accumulate(rewriter, helper.getCombineOp(), - accumulator.acc, partialReduce); - Value mask = icmp_slt(warpId, i32_val(i + 1)); + Value mask = icmp_sge(warpId, i32_val(i + 1)); + accumulator.acc = + accumulate(helper, rewriter, accumulator.acc, partialReduce); for (unsigned j = 0; j < helper.getNumOperands(); ++j) { accumulator.maskedAcc[j] = - select(mask, accumulator.maskedAcc[j], accumulator.acc[j]); + select(mask, accumulator.acc[j], accumulator.maskedAcc[j]); } } - auto temp = accumulate(rewriter, helper.getCombineOp(), - accumulator.maskedAcc, srcValues[srcIndex]); + + Value pred = axisBlockId == 0 ? maskNotFirstWarp : Value{}; + auto temp = accumulate(helper, rewriter, accumulator.maskedAcc, + srcValues[srcIndex], pred); if (axisBlockId == 0) { // For the first warp and first chunk we don't have anything to // accumulate. auto val = srcValues[srcIndex]; for (unsigned i = 0; i < helper.getNumOperands(); ++i) { - temp[i] = select(maskFirstWarp, val[i], temp[i]); + temp[i] = select(maskNotFirstWarp, temp[i], val[i]); } } srcValues[srcIndex] = temp; @@ -235,22 +210,21 @@ static void AddPartialReduce(SmallVector> &srcValues, SmallVector lastElement(helper.getNumOperands()); for (unsigned i = 0; i < helper.getNumOperands(); ++i) { auto elem = targetInfo.shuffleUp(rewriter, loc, temp[i], threadStride); - lastElement[i] = select(maskFirstLane, accumulator.maskedAcc[i], elem); + lastElement[i] = select(maskNotFirstLane, elem, accumulator.maskedAcc[i]); } for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + pred = axisBlockId == 0 ? maskNotFirstThread : Value{}; auto laneValue = srcValues[srcIndex - i * elementStride]; - laneValue = - accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); + laneValue = accumulate(helper, rewriter, lastElement, laneValue, pred); if (axisBlockId == 0) { // For the first warp and first chunk we don't have anything to // accumulate. for (unsigned j = 0; j < helper.getNumOperands(); ++j) { - laneValue[j] = - select(maskFirstThread, - srcValues[srcIndex - i * elementStride][j], laneValue[j]); + laneValue[j] = select(maskNotFirstThread, laneValue[j], + srcValues[srcIndex - i * elementStride][j]); } } - srcValues[srcIndex - i * elementStride] = laneValue; + srcValues[srcIndex - i * elementStride] = std::move(laneValue); } // For the next chunk start back from the value containing the // accumulated value of all the warps. @@ -300,8 +274,8 @@ static void AddPartialReduceOneWarp(SmallVector> &srcValues, if (axisBlockId == 0) // First chunk and first block accumulator = srcValues[srcIndex]; else - srcValues[srcIndex] = accumulate(rewriter, helper.getCombineOp(), - accumulator, srcValues[srcIndex]); + srcValues[srcIndex] = + accumulate(helper, rewriter, accumulator, srcValues[srcIndex]); // Update the rest of the contiguous elements. auto lastElement = srcValues[srcIndex]; if (scanDim > 1) { @@ -319,8 +293,7 @@ static void AddPartialReduceOneWarp(SmallVector> &srcValues, } for (unsigned i = 1; i < scanElementsPerThreads; ++i) { auto laneValue = srcValues[srcIndex - i * elementStride]; - laneValue = - accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); + laneValue = accumulate(helper, rewriter, lastElement, laneValue); if (axisBlockId == 0) { for (unsigned j = 0; j < helper.getNumOperands(); ++j) { // For the first warp and first chunk we don't have anything to @@ -330,7 +303,7 @@ static void AddPartialReduceOneWarp(SmallVector> &srcValues, srcValues[srcIndex - i * elementStride][j], laneValue[j]); } } - srcValues[srcIndex - i * elementStride] = laneValue; + srcValues[srcIndex - i * elementStride] = std::move(laneValue); } // For the next chunk start back from the value containing the // accumulated value of all the warps. @@ -354,7 +327,7 @@ struct ScanOpConversion LogicalResult matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (succeeded(emitFastScan(op, adaptor, rewriter))) + if (succeeded(emitFastScan(op, adaptor, rewriter, targetInfo))) return success(); return failure(); } @@ -372,7 +345,8 @@ struct ScanOpConversion ScanLoweringHelper &helper, Value laneId, Value warpId) const; LogicalResult emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const; + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const; }; SmallVector @@ -482,7 +456,8 @@ flipSrcValues(Location loc, triton::ScanOp op, // Lowering using warp shuffle operations to do warp level scan. LogicalResult ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { ScanLoweringHelper helper(op); auto loc = helper.getLoc(); if (!helper.isSupported()) @@ -525,7 +500,8 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, // Slow path for the case where there are multiple warps with unique data on // the axis. auto elems = helper.getScratchSizeInElems(); - SmallVector smemBases = getSmemBases(op, elems, rewriter); + SmallVector smemBases = + getSmemBases(op, elems, rewriter, targetInfo); SmallVector smemTypes(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { smemTypes[i] = getElementType(op, i); diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index 908aa1e2b..cc6d8875b 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -17,16 +17,16 @@ using ::mlir::triton::gpu::SliceEncodingAttr; TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( MLIRContext *ctx, LowerToLLVMOptions &option, - const DataLayoutAnalysis *analysis) + const TargetInfoBase &targetInfo, const DataLayoutAnalysis *analysis) : LLVMTypeConverter(ctx, option, analysis) { addConversion([&](triton::PointerType type) -> std::optional { return convertTritonPointerType(type); }); addConversion([&](RankedTensorType type) -> std::optional { - return convertTritonTensorType(type); + return convertTritonTensorType(type, targetInfo); }); addConversion([&](MemDescType type) -> std::optional { - return convertMemDescType(type); + return convertMemDescType(type, targetInfo); }); addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional { return convertAsyncToken(type); @@ -34,16 +34,15 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional { return IntegerType::get(type.getContext(), 8); }); + addConversion([&](mlir::Float8E4M3FNType type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }); addConversion([&](mlir::Float8E5M2Type type) -> std::optional { return IntegerType::get(type.getContext(), 8); }); addConversion([&](mlir::Float8E5M2FNUZType type) -> std::optional { return IntegerType::get(type.getContext(), 8); }); - // Internally store bfloat16 as int16 - addConversion([&](BFloat16Type type) -> std::optional { - return IntegerType::get(type.getContext(), 16); - }); } Type TritonGPUToLLVMTypeConverter::convertTritonPointerType( @@ -89,7 +88,7 @@ Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct( } Type TritonGPUToLLVMTypeConverter::convertTritonTensorType( - RankedTensorType type) { + RankedTensorType type, const TargetInfoBase &targetInfo) { auto ctx = type.getContext(); Attribute layout = type.getEncoding(); SmallVector shape(type.getShape().begin(), type.getShape().end()); @@ -98,7 +97,8 @@ Type TritonGPUToLLVMTypeConverter::convertTritonTensorType( if (auto shared_layout = mlir::dyn_cast(layout)) { SmallVector types; // base ptr - auto ptrType = LLVM::LLVMPointerType::get(ctx, 3); + auto ptrType = + LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace()); types.push_back(ptrType); // shape dims auto rank = type.getRank(); @@ -114,13 +114,15 @@ Type TritonGPUToLLVMTypeConverter::convertTritonTensorType( return LLVM::LLVMStructType::getLiteral(ctx, types); } -Type TritonGPUToLLVMTypeConverter::convertMemDescType(MemDescType type) { +Type TritonGPUToLLVMTypeConverter::convertMemDescType( + MemDescType type, const TargetInfoBase &targetInfo) { auto ctx = type.getContext(); Attribute layout = type.getEncoding(); SmallVector shape(type.getShape().begin(), type.getShape().end()); SmallVector types; // base ptr - auto ptrType = LLVM::LLVMPointerType::get(ctx, 3); + auto ptrType = + LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace()); types.push_back(ptrType); // shape dims auto rank = type.getShape().size(); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index a80158a46..e857dd36f 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -1,8 +1,7 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/Attributes.h" #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" -#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "llvm/ADT/STLExtras.h" @@ -12,7 +11,7 @@ using CoordTy = SmallVector; using ValueTable = std::map, std::pair>; static SmallVector -getMNCoords(Value thread, Location loc, ConversionPatternRewriter &rewriter, +getMNCoords(Value thread, Location loc, RewriterBase &rewriter, ArrayRef wpt, const NvidiaMmaEncodingAttr &mmaLayout, ArrayRef shape, bool isARow, bool isBRow, bool isAVec4, bool isBVec4) { @@ -111,6 +110,7 @@ getMNCoords(Value thread, Location loc, ConversionPatternRewriter &rewriter, return coords; // {M,N} in row-major } } // namespace SharedToDotOperandMMAv1 + namespace mlir { namespace triton::gpu { @@ -119,9 +119,8 @@ Type getFunctionType(Type resultType, ValueRange operands) { return LLVM::LLVMFunctionType::get(resultType, operandTypes); } -LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter, - Operation *op, StringRef funcName, - Type funcType, +LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op, + StringRef funcName, Type funcType, StringRef libname /*= ""*/, StringRef libpath /*= ""*/) { using LLVM::LLVMFuncOp; @@ -164,7 +163,7 @@ applyLinearLayout(Location loc, RewriterBase &rewriter, // Manually constant-fold the layout where possible. SmallVector> constantIns; for (auto [inDimName, idx] : indices) { - if (auto constant = dyn_cast(idx.getDefiningOp())) { + if (auto constant = idx.getDefiningOp()) { constantIns.push_back( {inDimName, cast(constant.getValue()).getInt()}); } else { @@ -184,7 +183,7 @@ applyLinearLayout(Location loc, RewriterBase &rewriter, } for (auto [inDimName, idx] : indices) { - if (isa(idx.getDefiningOp())) { + if (idx.getDefiningOp()) { continue; } @@ -204,17 +203,15 @@ applyLinearLayout(Location loc, RewriterBase &rewriter, return outIndices; } -std::optional>> -emitIndicesUsingLinearLayouts(Location loc, RewriterBase &rewriter, - const TargetInfoBase &target, Attribute layout, - RankedTensorType type, bool withCTAOffset) { +SmallVector> +emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + Attribute layout, RankedTensorType type, bool withCTAOffset) { MLIRContext *ctx = rewriter.getContext(); auto shape = type.getShape(); std::optional ll = triton::gpu::toLinearLayout(shape, layout); - if (!ll.has_value()) { - return std::nullopt; - } + if (!ll.has_value()) + llvm::report_fatal_error("Failed to convert layout to linear layout"); // TODO(jlebar): We could add strong typing if we wanted; for now this is // "stringly typed". @@ -231,12 +228,31 @@ emitIndicesUsingLinearLayouts(Location loc, RewriterBase &rewriter, withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0); unsigned rank = shape.size(); SmallVector> ret; + // Linear layout function is split in two parts below: + // L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0) + // idxs = idxsBase xor idxsReg + // + // L(0, t, w, b) part is the same for all registers, + // so we hoist it out of the main register loop in the below. + // + // This approach produces code with lower register pressure and + // less computations, compared to fused L(r,t,w,b) method. + auto idxsBase = applyLinearLayout(loc, rewriter, *ll, + {{kRegister, i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, blockId}}); for (unsigned reg = 0; reg < ll->getInDimSize(str_attr("register")); reg++) { - auto idxs = applyLinearLayout(loc, rewriter, *ll, - {{kRegister, i32_val(reg)}, - {kLane, laneId}, - {kWarp, warpId}, - {kBlock, blockId}}); + auto idxsReg = + ll->apply({{kRegister, reg}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + SmallVector> idxs; + for (auto [idxBase, idxReg] : llvm::zip(idxsBase, idxsReg)) { + auto dimName = idxBase.first; + assert(dimName == idxReg.first && + "dim names of block+warp+thread and register idx should be equal"); + auto idx = xor_(idxBase.second, i32_val(idxReg.second)); + idxs.emplace_back(dimName, idx); + } assert(idxs.size() == rank); for (unsigned k = 0; k < rank; ++k) { assert(idxs[k].first == str_attr("dim" + std::to_string(k))); @@ -247,11 +263,206 @@ emitIndicesUsingLinearLayouts(Location loc, RewriterBase &rewriter, return ret; } +bool emitTransferBetweenRegistersAndShared( + RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy, + std::optional maxVecElems, Value shmemBase, + ArrayRef shmemStrides, Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, + std::function perVectorCallback) { + MLIRContext *ctx = rewriter.getContext(); + + auto shape = registerTy.getShape(); + int rank = shape.size(); + + StringAttr kBlock = str_attr("block"); + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + + std::optional regLayout = + triton::gpu::toLinearLayout(shape, registerTy.getEncoding()); + std::optional sharedLayout = triton::gpu::toLinearLayout( + shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth()); + if (!regLayout.has_value() || !sharedLayout.has_value()) { + return false; + } + auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding()); + + // sharedLayout's in-dims are currently (offset, block). Reshape to + // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional + // shmem strides. (The offsetX's appear in minor-to-major order.) + auto sharedLegacy = + cast(sharedTy.getEncoding()); + SmallVector> multiDimSharedSize; + for (int i = 0; i < rank; i++) { + int dim = sharedOrder[i]; + int64_t size = std::max( + int64_t{1}, + shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]); + multiDimSharedSize.push_back( + {str_attr("offset" + std::to_string(dim)), size}); + } + multiDimSharedSize.push_back({kBlock, sharedLayout->getInDimSize(kBlock)}); + sharedLayout = sharedLayout->reshapeIns(multiDimSharedSize); + + // regToSharedLayout maps from (register, lane, warp, block) to (offsetX1, + // ..., offsetXN, block), where the offsetX's are in minor-to-major order. + LinearLayout regToSharedLayout = regLayout->invertAndCompose(*sharedLayout); + + // TODO(jlebar): We don't currently support loading from shared memory in a + // different CTA. We'd need to emit `mapa.shared::cluster` instructions. + for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock); + inBlock *= 2) { + auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout.apply( + {{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}}))); + // offsetX1, ..., offsetXN must all be 0. + if (!llvm::all_of(ArrayRef(idx).drop_back(1), + [&](auto offset) { return offset == 0; })) { + return false; + } + // Check if there's any cross CTA load. + int32_t outBlock = idx.back(); + if (outBlock != inBlock) { + return false; + } + } + + // Determine how many consecutive registers map to consecutive shmem elements + // in out-dimension offsetN. This is our load instruction's vector width. + // + // It's OK if the vector width we choose here is wider than the hardware + // supports; LLVM will legalize it. + // + // TODO(jlebar): shmemStrides are Values, but most of them are usually integer + // constants. We could add those constant strides to the LL, and then before + // calling getNumConsecutiveInOut(), we could flatten consecutive out-dims + // which have known strides. This would allow us to vectorize across multiple + // shmem out dimensions where possible. + const int vecElems = + std::min(regToSharedLayout.getNumConsecutiveInOut(), + maxVecElems.value_or(std::numeric_limits::max())); + + Value threadId = getThreadId(rewriter, loc); + Value threadsPerWarp = i32_val(regToSharedLayout.getInDimSize(kLane)); + Value laneId = urem(threadId, threadsPerWarp); + Value warpId = udiv(threadId, threadsPerWarp); + + int numElems = regToSharedLayout.getInDimSize(kRegister); + auto vecTy = vec_ty(elemLlvmTy, vecElems); + auto ptrTy = shmemBase.getType(); + Value zero = i32_val(0); + SmallVector ret; + for (int i = 0; i < numElems / vecElems; i++) { + // Get the address to load/store. The multi-dim address is (offsetX1, ..., + // offsetXN, block), where the offsets appear in minor-to-major order, and + // we drop_end to drop block, which we know from above will be 0. + auto multiDimShmemOffset = + llvm::to_vector(llvm::drop_end(llvm::make_second_range( + applyLinearLayout(loc, rewriter, regToSharedLayout, + {{kRegister, i32_val(i * vecElems)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, zero}})))); + + // Reorder strides according to `order`. This way they match the + // multi-dimensional offsets in regToSharedLayout. + Value shmemOffset = dot(rewriter, loc, multiDimShmemOffset, + applyPermutation(shmemStrides, sharedOrder)); + auto vecAddr = gep(ptrTy, elemLlvmTy, shmemBase, shmemOffset); + vecAddr.setInbounds(true); + + perVectorCallback(vecTy, vecAddr); + } + return true; +} + +SmallVector loadSharedToDistributed(RankedTensorType dstTy, + MemDescType srcTy, Type elemLlvmTy, + SharedMemoryObject smemObj, + Location loc, RewriterBase &rewriter, + const TargetInfoBase &target) { + SmallVector ret; + bool success = emitTransferBetweenRegistersAndShared( + dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj.getBase(), + smemObj.getStrides(), loc, rewriter, target, + [&](VectorType vecTy, Value vecAddr) { + auto vecVal = load(vecTy, vecAddr); + vecVal.setAlignment(vecTy.getNumElements() * + elemLlvmTy.getIntOrFloatBitWidth() / 8); + + for (int v = 0; v < vecTy.getNumElements(); v++) { + ret.push_back(extract_element(elemLlvmTy, vecVal, i32_val(v))); + } + }); + if (!success) + llvm::report_fatal_error("Failed to emit transfer from shared to register"); + + return ret; +} + +void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, + Type elemLlvmTy, ArrayRef srcVals, + Value smemBase, ArrayRef dstStrides, + Location loc, RewriterBase &rewriter, + const TargetInfoBase &target) { + bool success = emitTransferBetweenRegistersAndShared( + srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase, + dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) { + ArrayRef vals = srcVals.take_front(vecTy.getNumElements()); + srcVals = srcVals.drop_front(vecTy.getNumElements()); + + Value vec = undef(vecTy); + for (int i = 0; i < vals.size(); i++) { + vec = insert_element(vec, vals[i], i32_val(i)); + } + store(vec, vecAddr) + .setAlignment(vecTy.getNumElements() * + elemLlvmTy.getIntOrFloatBitWidth() / 8); + }); + if (!success) + llvm::report_fatal_error("Failed to emit transfer from register to shared"); +} + +SmallVector> emitOffsetForLayout(Attribute layout, + RankedTensorType type) { + MLIRContext *ctx = layout.getContext(); + auto shape = type.getShape(); + unsigned rank = shape.size(); + + auto ll = triton::gpu::toLinearLayout(shape, layout); + if (!ll.has_value()) + llvm::report_fatal_error("Unsupported layout"); + + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + + SmallVector> offsets; + for (int i = 0; i < ll->getInDimSize(str_attr("register")); i++) { + auto idxs = + ll->apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + assert(idxs.size() == rank); + for (unsigned k = 0; k < rank; ++k) { + assert(idxs[k].first == str_attr("dim" + std::to_string(k))); + } + offsets.push_back( + llvm::to_vector_of(llvm::make_second_range(idxs))); + } + return offsets; +} + namespace LLVM { using namespace mlir::triton; using mlir::triton::gpu::getOrder; using mlir::triton::gpu::getSizePerThread; +Value createConstantI1(Location loc, OpBuilder &rewriter, bool v) { + auto i1ty = rewriter.getIntegerType(1); + return rewriter.create(loc, i1ty, + IntegerAttr::get(i1ty, v)); +} + Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) { auto i32ty = rewriter.getIntegerType(32); return rewriter.create(loc, i32ty, @@ -306,9 +517,40 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, builder.getIntegerAttr(ty, value)); } -SharedMemoryObject -getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, - ConversionPatternRewriter &rewriter) { +LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc, + LLVMFuncOp funcOp, ValueRange args) { + auto op = builder.create(loc, funcOp, args); + op.getProperties().setOpBundleSizes(builder.getDenseI32ArrayAttr({})); + op.getProperties().setOperandSegmentSizes({static_cast(args.size()), 0}); + return op; +} + +LLVM::CallIntrinsicOp +createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic, + TypeRange types, ValueRange args) { + auto op = builder.create(loc, types, args); + op.getProperties().setIntrin(builder.getStringAttr(intrinsic)); + op.getProperties().setOpBundleSizes(builder.getDenseI32ArrayAttr({})); + op.getProperties().setOperandSegmentSizes({static_cast(args.size()), 0}); + return op; +} + +bool isConstantZero(Value v) { + if (auto constantOp = v.getDefiningOp()) { + if (auto attr = dyn_cast(constantOp.getValue())) { + return attr.getValue().isZero(); + } + if (auto attr = dyn_cast(constantOp.getValue())) { + return attr.getValue().isZero(); + } + } + return false; +} + +SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, + Value llvmStruct, + Type elemTy, + RewriterBase &rewriter) { ArrayRef types = cast(llvmStruct.getType()).getBody(); SmallVector elems(types.size()); @@ -390,15 +632,14 @@ SmallVector delinearize(RewriterBase &rewriter, Location loc, return multiDim; } -Value linearize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDim, ArrayRef shape, - ArrayRef order) { +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order) { return linearize(rewriter, loc, applyPermutation(multiDim, order), applyPermutation(shape, order)); } -Value linearize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDim, ArrayRef shape) { +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape) { auto rank = multiDim.size(); Value linear = i32_val(0); if (rank > 0) { @@ -412,8 +653,8 @@ Value linearize(ConversionPatternRewriter &rewriter, Location loc, return linear; } -Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, - StringRef key, StringRef content) { +Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, + StringRef content) { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); auto ctx = moduleOp.getContext(); unsigned stringNumber = 0; @@ -429,7 +670,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, LLVM::GlobalOp global; { - ConversionPatternRewriter::InsertionGuard guard(rewriter); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); global = rewriter.create( UnknownLoc::get(ctx), globalType, @@ -447,7 +688,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, } SmallVector getMultiDimOffset(Attribute layout, Location loc, - ConversionPatternRewriter &rewriter, + RewriterBase &rewriter, const TargetInfoBase &targetInfo, unsigned elemId, RankedTensorType type, ArrayRef multiDimCTAInRepId, @@ -601,9 +842,9 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, } SmallVector getWrappedMultiDimOffset( - ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDimOffset, ArrayRef shape, - SmallVector shapePerCTATile, SmallVector shapePerCTA) { + RewriterBase &rewriter, Location loc, ArrayRef multiDimOffset, + ArrayRef shape, SmallVector shapePerCTATile, + SmallVector shapePerCTA) { unsigned rank = shape.size(); SmallVector multiDimOffsetWrapped(rank); for (unsigned d = 0; d < rank; ++d) { diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index e0f6e9377..297a94e85 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -83,6 +83,10 @@ struct ArithConstantSplatOpConversion << value.getType() << "\n"; return failure(); } + // Lower FP8 constant to int8 constant since FP8 types are not supported on + // LLVM IR. + if (type::isFloat8(elemType)) + elemType = rewriter.getIntegerType(8); auto constOp = rewriter.create(loc, elemType, val); auto typeConverter = getTypeConverter(); auto llStruct = SplatOpConversion::convertSplatLikeOp( @@ -168,11 +172,21 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern { // verifier): // // - The op has a blocked encoding. - // - The last dimension (the one we're spliting) is also the most minor - // dimension, and has sizePerThread=2. + // - The last dimension (the one we're spliting) has sizePerThread=2, + // threadPerWarp=1 and warpPerBlock=1. // - // With these invariants, split is trivial: Every other value goes into - // return value 0, and every other goes into return value 1. + // With these invariants, split is trivial: We can count how many contiguous + // registers belong to the same chunk then we separate the registers between + // two different chunks. + int numContiguousValues = 1; + auto encoding = cast( + cast(op.getSrc().getType()).getEncoding()); + int splitDim = encoding.getOrder().size() - 1; + for (int i = 0; i < encoding.getOrder().size(); i++) { + if (encoding.getOrder()[i] == splitDim) + break; + numContiguousValues *= encoding.getSizePerThread()[i]; + } Location loc = op->getLoc(); auto typeConverter = getTypeConverter(); SmallVector srcVals = @@ -180,9 +194,11 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern { assert(srcVals.size() % 2 == 0); SmallVector outLhsVals; SmallVector outRhsVals; - for (int i = 0; i < srcVals.size(); i += 2) { - outLhsVals.push_back(srcVals[i]); - outRhsVals.push_back(srcVals[i + 1]); + for (int i = 0; i < srcVals.size(); i += 2 * numContiguousValues) { + for (int j = 0; j < numContiguousValues; j++) { + outLhsVals.push_back(srcVals[i + j]); + outRhsVals.push_back(srcVals[i + numContiguousValues + j]); + } } auto resultTy = cast(op.getResult(0).getType()); Value retLhs = @@ -371,7 +387,7 @@ struct MemDescSubviewOpConversion // Compute the offset based on the original strides of the shared memory // object auto offset = dot(rewriter, loc, opOffsetVals, smemObj.strides); - auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + auto elemPtrTy = smemObj.base.getType(); smemObj = SharedMemoryObject(gep(elemPtrTy, llvmElemTy, smemObj.base, offset), llvmElemTy, strides, offsetVals); diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp index 34fb89954..06e75ee18 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -56,20 +56,19 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // This will create newArg, and map(origArg, newArg) addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, - Location loc) -> std::optional { + Location loc) -> Value { llvm_unreachable("Argument rematerialization should not happen in Triton " "-> TritonGPU conversion"); - return std::nullopt; + return {}; }); // If the origValue still has live user(s), use this to // convert origValue to newValue addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, - ValueRange inputs, - Location loc) -> std::optional { + ValueRange inputs, Location loc) -> Value { llvm_unreachable("Source rematerialization should not happen in Triton -> " "TritonGPU Conversion"); - return std::nullopt; + return {}; }); // This will be called when (desiredType != newOperandType) @@ -79,7 +78,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, ValueRange inputs, Location loc) { auto cast = builder.create(loc, tensorType, inputs); - return std::optional(cast.getResult()); + return cast.getResult(); }); } diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 4aa2712ec..bd17e2d7c 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -551,8 +551,11 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, - GenericOpPattern, TritonFuncOpPattern>(typeConverter, - context); + GenericOpPattern, + GenericOpPattern, + // this assumes the right layout will be set later for dot scaled. + GenericOpPattern, GenericOpPattern, + TritonFuncOpPattern>(typeConverter, context); } // diff --git a/lib/Dialect/Triton/IR/Dialect.cpp b/lib/Dialect/Triton/IR/Dialect.cpp index 8f46e8ca8..dc2417712 100644 --- a/lib/Dialect/Triton/IR/Dialect.cpp +++ b/lib/Dialect/Triton/IR/Dialect.cpp @@ -2,6 +2,7 @@ #include "triton/Dialect/Triton/IR/Types.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index ce4f97336..ffea5f3c6 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -15,7 +15,7 @@ namespace triton { void LoadOp::getEffects( SmallVectorImpl> &effects) { - effects.emplace_back(MemoryEffects::Read::get(), getPtr(), + effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable(), triton::GlobalMemory::get()); if (getIsVolatile()) effects.emplace_back(MemoryEffects::Write::get(), @@ -91,8 +91,7 @@ struct CanonicalizeMaskedLoadPattern : public OpRewritePattern { if (!mask) return failure(); - auto constantMask = - llvm::dyn_cast_or_null(mask.getDefiningOp()); + auto constantMask = mask.getDefiningOp(); if (!constantMask) return failure(); @@ -159,8 +158,7 @@ struct CanonicalizeMaskedStorePattern : public OpRewritePattern { if (!mask) return failure(); - auto constantMask = - llvm::dyn_cast_or_null(mask.getDefiningOp()); + auto constantMask = mask.getDefiningOp(); if (!constantMask) return failure(); @@ -224,9 +222,10 @@ LogicalResult TransOp::inferReturnTypes( return failure(); } } - if (isa(argTy)) { - inferredReturnTypes.push_back( - MemDescType::get(retShape, retEltTy, retEncoding)); + if (auto memDescTy = dyn_cast(argTy)) { + inferredReturnTypes.push_back(MemDescType::get( + retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(), + memDescTy.getMutableMemory())); } else { inferredReturnTypes.push_back( RankedTensorType::get(retShape, retEltTy, retEncoding)); @@ -270,8 +269,8 @@ DotOp::inferReturnTypes(MLIRContext *context, std::optional location, auto bEnc = cast(operands[1].getType()).getEncoding(); auto retEnc = accTy.getEncoding(); if (aEnc) { - assert(bEnc); - Dialect &dialect = aEnc.getDialect(); + assert(bEnc && retEnc); + Dialect &dialect = retEnc.getDialect(); auto interface = dyn_cast(&dialect); if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) return failure(); @@ -295,7 +294,11 @@ LogicalResult DotOp::verify() { // Verify that the encodings are valid. if (!aEncoding || !bEncoding) return emitError("mismatching encoding between A and B operands"); - Dialect &dialect = aEncoding.getDialect(); + auto accTy = getC().getType(); + auto retEnc = accTy.getEncoding(); + if (!retEnc) + return emitError("miss encoding of C operand"); + Dialect &dialect = retEnc.getDialect(); auto interface = cast(&dialect); return interface->verifyDotOpEncodingCompatibility(getOperation(), aEncoding, bEncoding); @@ -335,8 +338,8 @@ LogicalResult MakeRangeOp::verify() { //-- ReduceOp -- static LogicalResult -inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy, - int axis, SmallVectorImpl &inferredReturnTypes) { +inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis, + SmallVectorImpl &inferredReturnTypes) { auto retShape = argTy.getShape().vec(); retShape.erase(retShape.begin() + axis); if (retShape.empty()) { @@ -543,6 +546,8 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { auto value = adaptor.getSrc(); if (!value) return {}; + if (!isa(value)) + return {}; auto shapedType = cast(getType()); auto ret = SplatElementsAttr::get(shapedType, ArrayRef(value)); return ret; @@ -673,7 +678,7 @@ LogicalResult canonicalizeViewOrBroadcast(OpType op, } LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) { - if (!op.getAllowReorder() || op.getEfficientLayout().has_value()) + if (!op.getAllowReorder() || op.getEfficientLayout()) return failure(); return canonicalizeViewOrBroadcast(op, rewriter); } @@ -756,6 +761,26 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { return {}; } +LogicalResult BroadcastOp::verify() { + auto src = getSrc(); + auto srcTensorType = cast(src.getType()); + auto srcShape = srcTensorType.getShape(); + auto result = getResult(); + auto resultTensorType = cast(result.getType()); + auto resultShape = resultTensorType.getShape(); + if (srcShape.size() != resultShape.size()) { + return emitError("rank of source must be same as rank of result"); + } + for (int i = 0; i < srcShape.size(); i++) { + if (srcShape[i] != 1 && srcShape[i] != resultShape[i]) { + return emitError("Different dimensions at index ") + << i << " between source and result. " + << "Broadcast requires the source dimension to be 1."; + } + } + return success(); +} + //-- MakeTensorPtrOp -- void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, Value base, ValueRange shape, ValueRange strides, @@ -775,6 +800,19 @@ void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, builder.getDenseI32ArrayAttr(order)); } +//-- AdvanceOp -- +OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) { + // advance(ptr, 0, 0) -> ptr + SmallVector rawOffsets = getOffsets(); + auto offsets = getConstantIntValues(rawOffsets); + if (!offsets.has_value()) + return {}; + for (int64_t offset : offsets.value()) + if (offset != 0) + return {}; + return getPtr(); +} + // The following ops, including `call`, `func`, and `return` are copied and // modified from // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -978,5 +1016,29 @@ void ExternElementwiseOp::getEffects( SideEffects::DefaultResource::get()); } +Speculation::Speculatability ExternElementwiseOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + +// -- ExperimentalTensormapCreateOp -- +LogicalResult ExperimentalTensormapCreateOp::verify() { + auto rank = getBoxDim().size(); + if (getGlobalDim().size() != rank) { + return emitError("Rank mismatch for global dim. Got") + << getGlobalDim().size() << " but expected " << rank; + } + if (getGlobalStride().size() + 1 != rank) { + return emitError("Rank mismatch for global stride. Got") + << getGlobalStride().size() << " but expected " << rank - 1; + } + if (getElementStride().size() != rank) { + return emitError("Rank mismatch for element stride. Got") + << getElementStride().size() << " but expected " << rank; + } + return success(); +} + } // namespace triton } // namespace mlir diff --git a/lib/Dialect/Triton/IR/Types.cpp b/lib/Dialect/Triton/IR/Types.cpp index 0e1df5b74..6e41e70a8 100644 --- a/lib/Dialect/Triton/IR/Types.cpp +++ b/lib/Dialect/Triton/IR/Types.cpp @@ -1,6 +1,7 @@ #include "triton/Dialect/Triton/IR/Types.h" #include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` @@ -70,16 +71,24 @@ Type MemDescType::parse(AsmParser &parser) { return Type(); } bool mutableMemory = false; + Attribute memorySpace; if (succeeded(parser.parseOptionalComma())) { + if (failed(parser.parseOptionalKeyword(kMutableMemory))) { + if (parser.parseAttribute(memorySpace)) + return Type(); + } else { + mutableMemory = true; + } + } + if (mutableMemory == false && succeeded(parser.parseOptionalComma())) { if (parser.parseOptionalKeyword(kMutableMemory)) return Type(); mutableMemory = true; } if (parser.parseGreater()) return Type(); - return MemDescType::get(parser.getContext(), dimensions, elementType, - encoding, mutableMemory); + encoding, memorySpace, mutableMemory); } void MemDescType::print(AsmPrinter &printer) const { @@ -89,6 +98,8 @@ void MemDescType::print(AsmPrinter &printer) const { printer << getElementType(); if (getEncoding()) printer << ", " << getEncoding(); + if (getMemorySpace()) + printer << ", " << getMemorySpace(); if (getMutableMemory()) printer << ", " << kMutableMemory; printer << ">"; @@ -147,7 +158,22 @@ Type getPointerTypeSameShape(Type type) { } } -Type getPointerType(Type type) { return PointerType::get(type, 1); } +Type getPointerTypeToElement(Type type) { + Type elementType = getElementTypeOrSelf(type); + PointerType ptrType = PointerType::get(elementType, 1); + return ptrType; +} + +// upstream Triton only uses address space 1 for Pointer Type +Type getPointerType(Type type, int addressSpace) { + return PointerType::get(type, addressSpace); +} + +int getAddressSpace(Type type) { + if (auto ptrType = dyn_cast(type)) + return ptrType.getAddressSpace(); + return 1; +} bool isTensorPointerType(Type type) { if (auto ptrType = dyn_cast(type)) diff --git a/lib/Dialect/Triton/Transforms/CMakeLists.txt b/lib/Dialect/Triton/Transforms/CMakeLists.txt index 298398750..cda076d4e 100644 --- a/lib/Dialect/Triton/Transforms/CMakeLists.txt +++ b/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_public_tablegen_target(TritonCombineIncGen) add_triton_library(TritonTransforms Combine.cpp + LoopUnroll.cpp ReorderBroadcast.cpp RewriteTensorPointer.cpp diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index c5d638754..33c4516b4 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -113,8 +113,7 @@ class CombineSelectMaskedLoadPattern : public RewritePattern { Value falseValue = selectOp.getFalseValue(); Value condSelect = selectOp.getCondition(); - auto *loadOpCandidate = trueValue.getDefiningOp(); - auto loadOp = llvm::dyn_cast_or_null(loadOpCandidate); + auto loadOp = trueValue.getDefiningOp(); if (!loadOp) return failure(); @@ -122,8 +121,7 @@ class CombineSelectMaskedLoadPattern : public RewritePattern { if (!mask) return failure(); - auto *splatOpCandidate = mask.getDefiningOp(); - auto splatOp = llvm::dyn_cast_or_null(splatOpCandidate); + auto splatOp = mask.getDefiningOp(); if (!splatOp) return failure(); @@ -175,26 +173,21 @@ class CombineBroadcastMulReducePattern : public RewritePattern { if (!isReduceAdd) return failure(); // operand of reduce has to be mul - auto mulOp = llvm::dyn_cast_or_null( - reduceOp.getOperand(0).getDefiningOp()); + auto mulOp = reduceOp.getOperand(0).getDefiningOp(); if (!mulOp) return failure(); // mul operand has to be broadcast - auto broadcastLhsOp = llvm::dyn_cast_or_null( - mulOp.getOperand(0).getDefiningOp()); + auto broadcastLhsOp = mulOp.getOperand(0).getDefiningOp(); if (!broadcastLhsOp) return failure(); - auto broadcastRhsOp = llvm::dyn_cast_or_null( - mulOp.getOperand(1).getDefiningOp()); + auto broadcastRhsOp = mulOp.getOperand(1).getDefiningOp(); if (!broadcastRhsOp) return failure(); // broadcast operand is expand dims - auto expandLhsOp = llvm::dyn_cast_or_null( - broadcastLhsOp.getSrc().getDefiningOp()); + auto expandLhsOp = broadcastLhsOp.getSrc().getDefiningOp(); if (!expandLhsOp) return failure(); - auto expandRhsOp = llvm::dyn_cast_or_null( - broadcastRhsOp.getSrc().getDefiningOp()); + auto expandRhsOp = broadcastRhsOp.getSrc().getDefiningOp(); if (!expandRhsOp) return failure(); // get not-broadcast dimensions diff --git a/lib/Dialect/Triton/Transforms/LoopUnroll.cpp b/lib/Dialect/Triton/Transforms/LoopUnroll.cpp new file mode 100644 index 000000000..257e734b7 --- /dev/null +++ b/lib/Dialect/Triton/Transforms/LoopUnroll.cpp @@ -0,0 +1,67 @@ +#include + +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "llvm/Support/Debug.h" + +#define GEN_PASS_CLASSES +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "triton-loop-unroll" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir::triton { + +static const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor"; + +namespace { + +class LoopUnrollPass : public TritonLoopUnrollBase { + + int getUnrollFactorOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists otherwise set the + // factor to 1 to suppress the unrolling. + if (auto factor = forOp->getAttrOfType( + mlir::triton::loopUnrollFactorAttrName)) + return factor.getInt(); + return 1; + } + +public: + LoopUnrollPass() = default; + LoopUnrollPass(const LoopUnrollPass &) {} + void runOnOperation() override { + LDBG("Loop unroll pass"); + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with unroll factor <= 1. + if (getUnrollFactorOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + + for (auto loop : loops) { + auto unrollFactor = getUnrollFactorOrDefault(loop); + loop->removeAttr(mlir::triton::loopUnrollFactorAttrName); + LDBG("Unrolling loop by " << unrollFactor << " times\n" << loop); + (void)loopUnrollByFactor(loop, unrollFactor); + } + } +}; + +} // anonymous namespace + +std::unique_ptr createLoopUnrollPass() { + return std::make_unique(); +} + +} // namespace mlir::triton diff --git a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp index 52f4ba0b3..a7028ef20 100644 --- a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -1,7 +1,9 @@ #include #include +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "triton/Analysis/Utility.h" @@ -171,10 +173,7 @@ struct RewritedInfo { auto otherTensorType = RankedTensorType::get(tensorShape, elementType); // Set zero padding value - TypedAttr attr = - elementType.isIntOrIndex() - ? cast(builder.getIntegerAttr(elementType, 0)) - : cast(builder.getFloatAttr(elementType, 0)); + TypedAttr attr = builder.getZeroAttr(elementType); // Float NaN padding case if (padding.value() == triton::PaddingOption::PAD_NAN) { @@ -209,18 +208,20 @@ class RewriteTensorPointerPass }); } - static SmallVector - generateNewOperands(const SmallVector &oldOperands, unsigned index, - const SmallVector &newValues) { - assert(index < oldOperands.size()); - SmallVector newOperands; - for (int i = 0; i < index; ++i) - newOperands.push_back(oldOperands[i]); - for (auto value : newValues) - newOperands.push_back(value); - for (auto i = index + 1; i < oldOperands.size(); ++i) - newOperands.push_back(oldOperands[i]); - return newOperands; + static void generateNewOperands(SmallVector &oldOperands, + unsigned index, ArrayRef newValues) { + size_t size = oldOperands.size(); + assert(index < size); + SmallVector operands = oldOperands; + oldOperands.reserve(size - 1 + newValues.size()); + oldOperands.clear(); + if (index != 0) { + oldOperands.append(operands.begin(), operands.begin() + index); + } + oldOperands.append(newValues.begin(), newValues.end()); + if (index != size - 1) { + oldOperands.append(operands.begin() + index + 1, operands.end()); + } } Operation *rewriteMakeTensorPtrOp(OpBuilder &builder, @@ -313,10 +314,14 @@ class RewriteTensorPointerPass loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); op->getResult(0).replaceAllUsesWith(newResult); + if (op->getAttr("async_task_id")) + newResult->setAttr("async_task_id", op->getAttr("async_task_id")); } else if (auto storeOp = dyn_cast(op)) { - builder.create(storeOp.getLoc(), newPtr, - storeOp.getValue(), newMask, - storeOp.getCache(), storeOp.getEvict()); + auto newOp = builder.create( + storeOp.getLoc(), newPtr, storeOp.getValue(), newMask, + storeOp.getCache(), storeOp.getEvict()); + if (op->getAttr("async_task_id")) + newOp->setAttr("async_task_id", op->getAttr("async_task_id")); } // Erase the original operation @@ -341,7 +346,7 @@ class RewriteTensorPointerPass needRewrite = true; auto makeTensorPtrOp = getMakeTensorPtrOp(results[i]); assert(rewritedInfo.count(makeTensorPtrOp.getResult())); - auto info = rewritedInfo[makeTensorPtrOp.getResult()]; + const auto &info = rewritedInfo[makeTensorPtrOp.getResult()]; for (unsigned j = 0; j < info.length(); ++j) { newRetTypes.push_back(builder.getI64Type()); } @@ -358,7 +363,7 @@ class RewriteTensorPointerPass } auto rematerialize = [&](Block *block) { for (Operation &opInIf : block->getOperations()) { - auto newOp = builder.clone(opInIf, mapping); + builder.clone(opInIf, mapping); } }; builder.setInsertionPointToStart(newOp.thenBlock()); @@ -369,9 +374,11 @@ class RewriteTensorPointerPass } // update rewritedInfo + auto opResults = op.getResults(); unsigned oldResIdx = 0, newResIdx = 0; while (oldResIdx < results.size()) { if (!triton::isTensorPointerType(results[oldResIdx].getType())) { + opResults[oldResIdx].replaceAllUsesWith(newOp.getResult(newResIdx)); oldResIdx++; newResIdx++; } else { @@ -403,8 +410,7 @@ class RewriteTensorPointerPass // Expand the tensor pointer into offsets assert(rewritedInfo.count(newIterOperands[i])); auto info = rewritedInfo[newIterOperands[i]]; - newIterOperands = - generateNewOperands(newIterOperands, i, info.getOffsets()); + generateNewOperands(newIterOperands, i, info.getOffsets()); i += info.length() - 1; size += info.length() - 1; } @@ -413,6 +419,7 @@ class RewriteTensorPointerPass auto newForOp = builder.create(op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), newIterOperands); + newForOp->setAttrs(op->getAttrs()); // Create value mapping. Note that for tensor pointers, we use identity // mapping. It may refer to a value in the old loop, but we will rewrite it @@ -439,9 +446,7 @@ class RewriteTensorPointerPass // Clone body builder.setInsertionPointToStart(newForOp.getBody()); for (auto &opInFor : *op.getBody()) { - auto *newOp = builder.clone(opInFor, mapping); - for (unsigned i = 0; i < opInFor.getNumResults(); ++i) - mapping.map(op->getResult(i), newOp->getResult(i)); + builder.clone(opInFor, mapping); } // Replace later usages @@ -476,7 +481,7 @@ class RewriteTensorPointerPass assert(rewritedInfo.count(newOperands[i])); auto info = rewritedInfo[newOperands[i]]; - newOperands = generateNewOperands(newOperands, i, info.getOffsets()); + generateNewOperands(newOperands, i, info.getOffsets()); i += info.length() - 1; size += info.length() - 1; } @@ -492,15 +497,13 @@ class RewriteTensorPointerPass // Rewrite `make_tensor_ptr` and `advance` and make a tensor of pointers // Rewriting functions return the next operation to visit, if there is no // next one, simply return `nullptr` - std::pair rewrited; if (auto makeTensorPtrOp = dyn_cast(op)) { return rewriteMakeTensorPtrOp(builder, makeTensorPtrOp, eraser); } else if (auto advanceOp = dyn_cast(op)) { return rewriteAdvanceOp(builder, advanceOp, eraser); } else if (isa(op) || isa(op)) { return rewriteLoadStoreOp(builder, op, eraser); - } else if (op->getDialect()->getNamespace() == "scf" || - op->getDialect()->getNamespace() == "cf") { + } else if (isa(op->getDialect())) { if (auto ifOp = dyn_cast(op)) { return rewriteIfOp(builder, ifOp, eraser); } @@ -524,18 +527,12 @@ class RewriteTensorPointerPass } void visitOperation(Operation *op, std::stack &eraser) { - for (auto ®ion : op->getRegions()) { - for (auto &block : region) { - // We need an extra copy because erasing operations may break the - // iterator behavior - SmallVector blockCopy; - for (auto &nestedOp : block) - blockCopy.push_back(&nestedOp); - - // Rewrite and recursively visit - for (auto &nestedOp : blockCopy) { - if (auto newOp = rewriteOp(nestedOp, eraser)) + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Operation &nestedOp : llvm::make_early_inc_range(block)) { + if (auto newOp = rewriteOp(&nestedOp, eraser)) { visitOperation(newOp, eraser); + } } } } diff --git a/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/lib/Dialect/TritonGPU/IR/CMakeLists.txt index b5dcdb5ea..98831f0db 100644 --- a/lib/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TritonGPUIR Dialect.cpp LinearLayoutConversions.cpp + Ops.cpp Types.cpp DEPENDS diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 69067b706..2e822e722 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1,5 +1,6 @@ #include "triton/Dialect/Triton/IR/Dialect.h" +#include #include #include "mlir/IR/DialectImplementation.h" @@ -7,7 +8,11 @@ #include "mlir/Support/LLVM.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/LinearLayout.h" #include "triton/Tools/StrUtil.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/TypeSwitch.h" @@ -230,8 +235,36 @@ static SmallVector eraseOrder(ArrayRef order, return resOrder; } +SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, + bool kMajor) { + // kMajor: if true, the matrix is fastest-running on k, + // otherwise it is on m (resp. n) + // opIdx=0: [batch, m, k] if rank == 3 else [m, k] + // opIdx=1: [batch, k, n] if rank == 3 else [k, n] + // batch (if rank == 3) is always the slowest running dimension + assert(rank == 2 || rank == 3); + assert(opIdx == 0 || opIdx == 1); + SmallVector order(rank); + std::iota(order.rbegin(), order.rend(), 0); + // If opIdx is 1 and kMajor is true, the order is [0, 1] + // (resp. [1, 2, 0] if rank == 3) + // Same if opIdx is 0 and kMajor is false + if (bool(opIdx) == kMajor) { + std::swap(order[0], order[1]); + } + return order; +} + SmallVector getWarpOrder(Attribute layout) { + if (auto dotLayout = dyn_cast(layout)) { + if (isa(dotLayout.getParent())) { + return getWarpOrder(dotLayout.getParent()); + } + } auto order = getOrder(layout); + // FIXME: This mmaLayout if should just return + // getOrderForDotOperand(0, order.size(), kMajor=false) + // as mma has the same order as DotOperand(opIdx=0) if (auto mmaLayout = dyn_cast(layout)) { if (mmaLayout.isHopper()) { // Hopper MMA instructions force a warp order of [0, 1]. See docs: @@ -240,51 +273,51 @@ SmallVector getWarpOrder(Attribute layout) { order.erase(it); order.insert(order.begin(), 0); } + } else if (auto dotOpLayout = dyn_cast(layout)) { + order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(), + /*kMajor*/ false); } return order; } SmallVector getOrder(Attribute layout) { if (auto blockedLayout = dyn_cast(layout)) { - return SmallVector(blockedLayout.getOrder().begin(), - blockedLayout.getOrder().end()); - } else if (auto mmaLayout = dyn_cast(layout)) { + return llvm::to_vector(blockedLayout.getOrder()); + } + if (auto mmaLayout = dyn_cast(layout)) { auto distributedLayout = cast(layout); auto rank = distributedLayout.getWarpsPerCTA().size(); SmallVector order(rank); - for (auto i = 0; i < rank; ++i) - order[i] = rank - 1 - i; - if (auto mfmaLayout = dyn_cast(layout)) { - if (mfmaLayout.getIsTransposed()) { - std::swap(order[rank - 2], order[rank - 1]); - } - } - return order; - } else if (auto dotLayout = dyn_cast(layout)) { - auto rank = getWarpsPerCTA(dotLayout.getParent()).size(); - SmallVector order(rank); - for (auto i = 0; i < rank; ++i) - order[i] = rank - 1 - i; + std::iota(order.rbegin(), order.rend(), 0); return order; - } else if (auto sliceLayout = dyn_cast(layout)) { + } + if (auto dotLayout = dyn_cast(layout)) { + auto rank = dotLayout.getWarpsPerCTA().size(); + return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true); + } + if (auto sliceLayout = dyn_cast(layout)) { SmallVector parentOrder = getOrder(sliceLayout.getParent()); unsigned dim = sliceLayout.getDim(); SmallVector order; for (unsigned d : parentOrder) { - if (d == dim) - continue; - else if (d > dim) - order.push_back(d - 1); - else - order.push_back(d); + if (d != dim) + order.push_back(d > dim ? d - 1 : d); } return order; - } else if (auto sharedLayout = mlir::dyn_cast(layout)) { - return SmallVector(sharedLayout.getOrder().begin(), - sharedLayout.getOrder().end()); - } else { - llvm::report_fatal_error("Unimplemented usage of getOrder"); } + if (auto sharedLayout = mlir::dyn_cast(layout)) { + return llvm::to_vector(sharedLayout.getOrder()); + } + + llvm::report_fatal_error("Unimplemented usage of getOrder"); + return {}; +}; + +SmallVector getThreadOrder(Attribute layout) { + if (auto distributedLayout = mlir::dyn_cast(layout)) + return distributedLayout.getThreadOrder(); + else + llvm::report_fatal_error("Unimplemented usage of getThreadOrder"); return {}; }; @@ -305,8 +338,6 @@ SmallVector getCTAsPerCGA(Attribute layout) { ArrayRef ref; if (auto distributedLayout = mlir::dyn_cast(layout)) return distributedLayout.getCTAsPerCGA(); - else if (mlir::isa(layout)) - return {1, 1}; else if (auto sharedLayout = mlir::dyn_cast(layout)) ref = sharedLayout.getCTALayout().getCTAsPerCGA(); else @@ -319,9 +350,6 @@ SmallVector getCTASplitNum(Attribute layout) { if (auto distributedLayout = mlir::dyn_cast(layout)) { return distributedLayout.getCTASplitNum(); - } else if (mlir::isa(layout)) { - res.resize(2); - res[0] = res[1] = 1; } else if (auto sharedLayout = mlir::dyn_cast(layout)) { res.assign(sharedLayout.getCTALayout().getCTASplitNum().begin(), sharedLayout.getCTALayout().getCTASplitNum().end()); @@ -336,8 +364,6 @@ SmallVector getCTAOrder(Attribute layout) { if (auto distributedLayout = mlir::dyn_cast(layout)) { res = distributedLayout.getCTAOrder(); - } else if (mlir::isa(layout)) { - return {0, 1}; } else if (auto sharedLayout = mlir::dyn_cast(layout)) { res = SmallVector(sharedLayout.getCTALayout().getCTAOrder()); } else { @@ -361,9 +387,9 @@ SmallVector getShapePerCTA(ArrayRef CTASplitNum, SmallVector getShapePerCTA(Attribute layout, ArrayRef shape) { if (auto sharedLayout = mlir::dyn_cast(layout)) { // Special logic for pipeline pass, where shape is 3D and CTALayout is 2D. - // The first dim of shape is numStages. This is a work around, otherwise too - // many places would have to be modified in pipeline pass. Maybe we need to - // refactor this logic in the future. + // The first dim of shape is numStages. This is a work around, otherwise + // too many places would have to be modified in pipeline pass. Maybe we + // need to refactor this logic in the future. auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum(); if (shape.size() == CTASplitNum.size() + 1) { auto res = getShapePerCTA(CTASplitNum, shape.drop_front()); @@ -386,7 +412,8 @@ unsigned getNumWarpsPerCTA(Attribute layout) { else if (auto sliceLayout = dyn_cast(layout)) return getNumWarpsPerCTA(sliceLayout.getParent()); else if (auto mmaLayout = dyn_cast(layout)) { - // Use the distributed layout interface to get the number of warps per CTA. + // Use the distributed layout interface to get the number of warps per + // CTA. auto distributedLayout = cast(layout); warpsPerCTA = distributedLayout.getWarpsPerCTA(); } else if (auto mfmaLayout = dyn_cast(layout)) @@ -406,10 +433,6 @@ unsigned getNumCTAs(Attribute layout) { return product(getCTAsPerCGA(layout)); } -bool isaDistributedLayout(Attribute layout) { - return isa(layout); -} - template bool hasEncoding(Value value) { auto type = value.getType(); if (auto tensorType = dyn_cast(type)) { @@ -424,9 +447,9 @@ bool hasDotOperandEncoding(Value value) { } bool isExpensiveCat(CatOp cat, Attribute targetEncoding) { - // If the new elements per thread is less than the old one, we will need to do - // convert encoding that goes through shared memory anyway. So we consider it - // as expensive. + // If the new elements per thread is less than the old one, we will need to + // do convert encoding that goes through shared memory anyway. So we + // consider it as expensive. RankedTensorType tensorTy = cat.getType(); auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy); auto shape = tensorTy.getShape(); @@ -451,6 +474,15 @@ LogicalResult CTALayoutAttr::verify( << CTAOrder << "]"; } + if (llvm::any_of(CTAsPerCGA, [](unsigned x) { return x == 0; })) { + return emitError() << "Every element in CTAsPerCGA must be greater than 0."; + } + + if (llvm::any_of(CTASplitNum, [](unsigned x) { return x == 0; })) { + return emitError() + << "Every element in CTASplitNum must be greater than 0."; + } + return success(); } @@ -794,7 +826,7 @@ unsigned AMDMfmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, return product(getElemsPerThread(shape, eltTy)); } -// +// Wmma encoding SmallVector AMDWmmaEncodingAttr::getElemsPerThread(ArrayRef shape, @@ -803,7 +835,7 @@ AMDWmmaEncodingAttr::getElemsPerThread(ArrayRef shape, assert((rank == 2 || rank == 3) && "Unexpected rank of wmma layout"); SmallVector elemsPerThread(rank); - auto mnkDim = getMNKDimPerWMMAInstr(); + auto mnkDim = getMNKDimPerInstr(); auto elemsPerThreadPerTile = getSizePerThread(); auto warpsPerCTA = getWarpsPerCTA(); @@ -823,8 +855,6 @@ unsigned AMDWmmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, return product(getElemsPerThread(shape, eltTy)); } -// - SmallVector NvidiaMmaEncodingAttr::getElemsPerThread(ArrayRef shape, Type eltTy) const { @@ -929,6 +959,27 @@ unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, SmallVector DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, Type eltTy) const { + + if (auto parent = mlir::dyn_cast(getParent())) { + auto rank = shape.size(); + assert(rank == 2 || rank == 3); + + auto idx = getOpIdx(); + assert(idx == 0 || idx == 1); + + SmallVector elemsPerThread(rank); + + auto kWidth = getKWidth(); + auto rep = parent.getRepForOperand(shape, kWidth, idx); + + if (rank == 3) + elemsPerThread[0] = rep[0]; + elemsPerThread[rank - 2] = (idx == 0) ? rep[1] : rep[1] * kWidth; + elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2]; + + return elemsPerThread; + } + llvm_unreachable("getElemsPerThread is not supported for dot operand"); return SmallVector(); } @@ -936,8 +987,8 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { if (auto mmaParent = mlir::dyn_cast(getParent())) { - return mmaParent.getTotalElemsPerThreadForOperands(shape, eltTy, - getKWidth(), getOpIdx()); + return mmaParent.getTotalElemsPerThreadForOperand(shape, eltTy, getKWidth(), + getOpIdx()); } if (auto blockedLayout = mlir::dyn_cast(getParent())) { auto shapePerCTA = getShapePerCTA(*this, shape); @@ -983,30 +1034,27 @@ SmallVector DotOperandEncodingAttr::getCTASplitNum() const { return res; } SmallVector DotOperandEncodingAttr::getWarpsPerCTA() const { - auto parentLayout = getParent(); - assert(parentLayout && "DotOperandEncodingAttr must have a parent"); - if (auto distributedLayout = - mlir::dyn_cast(parentLayout)) { - return distributedLayout.getWarpsPerCTA(); - } else { - llvm::report_fatal_error( - "DotOperandEncodingAttr non-DistributedEncodingAttr parent not " - "supported yet"); - } + auto distributedLayout = mlir::cast(getParent()); + auto warps = distributedLayout.getWarpsPerCTA(); + auto rank = warps.size(); + auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2; + warps[kDim] = 1; + return warps; } SmallVector DotOperandEncodingAttr::getWarpOrder() const { return ::getWarpOrder(*this); } SmallVector DotOperandEncodingAttr::getThreadOrder() const { - return ::getOrder(*this); + return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), + /*kMajor*/ true); } SmallVector DotOperandEncodingAttr::getShapePerCTATile( ArrayRef tensorShape) const { auto parentLayout = getParent(); assert(parentLayout && "DotOperandEncodingAttr must have a parent"); if (auto parentMmaLayout = mlir::dyn_cast(parentLayout)) { - return parentMmaLayout.getShapePerCTATileForDotOperands(tensorShape, - getOpIdx()); + return parentMmaLayout.getShapePerCTATileForOperand( + tensorShape, getKWidth(), getOpIdx()); } else { llvm::report_fatal_error( "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not " @@ -1037,10 +1085,10 @@ LogicalResult DotOperandEncodingAttr::verify( } if (auto parentAttr = mlir::dyn_cast(parent)) { - // TODO: remove this condition if new values are supported - if (kWidth != 16) - return emitError() << "triton_gpu.dot_op kWidth parameter supports " - "only 16 for WMMA parent"; + if (kWidth != 16 && parentAttr.getVersion() == 1 || + kWidth != 8 && parentAttr.getVersion() == 2) + return emitError() << "triton_gpu.dot_op kWidth parameter must be 16 for " + "gfx11 and 8 for gfx12"; return success(); } @@ -1356,12 +1404,17 @@ Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) { if (parser.parseGreater().failed()) return {}; + unsigned version = 0; SmallVector warpsPerCTA; std::optional> CTAsPerCGA; std::optional> CTASplitNum; std::optional> CTAOrder; for (const NamedAttribute &attr : dict) { + if (attr.getName() == "version") { + if (parseUInt(parser, attr, version, "version").failed()) + return {}; + } if (attr.getName() == "warpsPerCTA") { if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) return {}; @@ -1388,13 +1441,14 @@ Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) { if (!CTALayout.has_value()) return {}; - return parser.getChecked(parser.getContext(), + return parser.getChecked(parser.getContext(), version, warpsPerCTA, *CTALayout); } void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const { printer << "<{" - << "warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]"; + << "version = " << getVersion() << ", warpsPerCTA = [" + << ArrayRef(getWarpsPerCTA()) << "]"; maybePrintCTALayout(getContext(), printer, getCTALayout(), /*rank=*/getWarpsPerCTA().size()); printer << "}>"; @@ -1534,7 +1588,10 @@ SmallVector AMDMfmaEncodingAttr::getWarpOrder() const { return ::getWarpOrder(*this); } SmallVector AMDMfmaEncodingAttr::getThreadOrder() const { - return ::getOrder(*this); + auto order = ::getOrder(*this); + if (getIsTransposed()) + std::swap(order[0], order[1]); + return order; } SmallVector AMDMfmaEncodingAttr::getThreadsPerWarp() const { unsigned rows, cols; @@ -1582,15 +1639,15 @@ SmallVector AMDMfmaEncodingAttr::getSizePerThread() const { } SmallVector -AMDMfmaEncodingAttr::getMFMAInstrShapeForOperands(int kWidth, int opIdx) const { +AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const { unsigned mDim = getMDim(); unsigned nDim = getNDim(); assert((mDim == nDim) && (mDim == 32 || mDim == 16 || mDim == 4) || (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); - constexpr int waveSize = 64; // MFMA is used on wave64 architectures only + constexpr int warpSize = 64; // MFMA is always based on the 64-wide warps. int kGroups = -1; if (mDim == nDim) - kGroups = waveSize / mDim; + kGroups = warpSize / mDim; if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64) kGroups = 1; int64_t kDim = kWidth * kGroups; @@ -1602,9 +1659,9 @@ AMDMfmaEncodingAttr::getMFMAInstrShapeForOperands(int kWidth, int opIdx) const { } SmallVector -AMDMfmaEncodingAttr::getMFMARepForOperands(ArrayRef operandShape, - int kWidth, int opIdx) const { - auto operandTileShape = getMFMAInstrShapeForOperands(kWidth, opIdx); +AMDMfmaEncodingAttr::getRepForOperand(ArrayRef operandShape, + int kWidth, int opIdx) const { + auto operandTileShape = getInstrShapeForOperand(kWidth, opIdx); auto rank = operandShape.size(); auto warpsPerCTA = getWarpsPerCTA(); int numRepBatch = @@ -1625,27 +1682,31 @@ AMDMfmaEncodingAttr::getMFMARepForOperands(ArrayRef operandShape, } } -unsigned AMDMfmaEncodingAttr::getTotalElemsPerThreadForOperands( +unsigned AMDMfmaEncodingAttr::getTotalElemsPerThreadForOperand( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { - auto rep = getMFMARepForOperands(shape, kWidth, opIdx); + auto rep = getRepForOperand(shape, kWidth, opIdx); return product(rep) * kWidth; } SmallVector -AMDMfmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { +AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { + auto rank = getWarpsPerCTA().size(); + auto sizePerThread = SmallVector(rank, 1); if (opIdx == 0) { - return {4, 1}; + sizePerThread[rank - 2] = 1; + sizePerThread[rank - 1] = kWidth; } else if (opIdx == 1) { - return {1, 4}; + sizePerThread[rank - 2] = kWidth; + sizePerThread[rank - 1] = 1; } else { llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); - return {}; } + return sizePerThread; } SmallVector -AMDMfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, - int opIdx) const { +AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef shape, + int kWidth, int opIdx) const { assert(getMDim() == 32 || getMDim() == 16); auto parentShapePerCTATile = getShapePerCTATile(shape); auto rank = parentShapePerCTATile.size(); @@ -1665,13 +1726,17 @@ AMDMfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1"); } +//===----------------------------------------------------------------------===// +// Wmma encoding +//===----------------------------------------------------------------------===// + SmallVector AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { auto warpsPerCTA = getWarpsPerCTA(); auto rank = warpsPerCTA.size(); SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); - auto mnkDim = getMNKDimPerWMMAInstr(); + auto mnkDim = getMNKDimPerInstr(); shapePerCTATile[rank - 2] *= mnkDim[0]; shapePerCTATile[rank - 1] *= mnkDim[1]; return shapePerCTATile; @@ -1697,7 +1762,7 @@ SmallVector AMDWmmaEncodingAttr::getThreadOrder() const { SmallVector AMDWmmaEncodingAttr::getThreadsPerWarp() const { auto rank = getWarpsPerCTA().size(); SmallVector threads(rank, 1); - auto mnkInstr = getMNKDimPerWMMAInstr(); + auto mnkInstr = getMNKDimPerInstr(); threads[rank - 2] = mnkInstr[0] / getSizePerThread()[rank - 2]; threads[rank - 1] = mnkInstr[1] / getSizePerThread()[rank - 1]; return threads; @@ -1711,14 +1776,17 @@ SmallVector AMDWmmaEncodingAttr::getSizePerThread() const { return sizePerThread; } SmallVector -AMDWmmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { +AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { auto rank = getWarpsPerCTA().size(); SmallVector sizePerThread(rank, 1); + auto numReplicated = getVersion() == 1 ? 2 : 1; + auto elemsPerInstr = numReplicated * product(getElemsPerInstrForOperands()) / + product(getThreadsPerWarp()); if (opIdx == 0) { sizePerThread[rank - 2] = 1; - sizePerThread[rank - 1] = 16; + sizePerThread[rank - 1] = elemsPerInstr; } else if (opIdx == 1) { - sizePerThread[rank - 2] = 16; + sizePerThread[rank - 2] = elemsPerInstr; sizePerThread[rank - 1] = 1; } else { llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); @@ -1727,11 +1795,11 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { } SmallVector -AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, - int opIdx) const { +AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef shape, + int kWidth, int opIdx) const { auto parentShapePerCTA = getShapePerCTATile(shape); auto rank = shape.size(); - assert(rank = 2); + assert(rank == 2); if (opIdx == 0) { return {parentShapePerCTA[0], static_cast(shape[1])}; } else if (opIdx == 1) { @@ -1741,22 +1809,21 @@ AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, } } -unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperands( +unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { - auto rep = getWMMARepForOperands(shape, eltTy, kWidth, opIdx); + auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx); return product(rep) * kWidth; } -SmallVector -AMDWmmaEncodingAttr::getWMMAElemsPerInstrForOperands() const { +SmallVector AMDWmmaEncodingAttr::getElemsPerInstrForOperands() const { return {16, 16}; } SmallVector -AMDWmmaEncodingAttr::getWMMARepForOperands(ArrayRef operandShape, - Type elemType, int kWidth, - int opIdx) const { - auto operandTileShape = getWMMAElemsPerInstrForOperands(); +AMDWmmaEncodingAttr::getRepForOperand(ArrayRef operandShape, + Type elemType, int kWidth, + int opIdx) const { + auto operandTileShape = getElemsPerInstrForOperands(); assert(operandTileShape.size() == 2); auto warpsPerCTA = getWarpsPerCTA(); auto rank = operandShape.size(); @@ -1779,7 +1846,7 @@ AMDWmmaEncodingAttr::getWMMARepForOperands(ArrayRef operandShape, } } -SmallVector AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr() { +SmallVector AMDWmmaEncodingAttr::getMNKDimPerInstr() { // TODO: move magic numbers out of the code return {16, 16, 16}; } @@ -1946,11 +2013,11 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const { int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const { return 2 * getMMAv1Rep(opIdx)[opIdx]; } -SmallVector NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef shape, - int bitwidth, - int opIdx) const { +SmallVector NvidiaMmaEncodingAttr::getMMAv2RepForOperand( + ArrayRef shape, int bitwidth, int kWidth, int opIdx) const { auto rank = shape.size(); auto warpsPerCTA = getWarpsPerCTA(); + SmallVector shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth}; int numRepBatch = rank == 3 @@ -1971,7 +2038,7 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef shape, warpsPerCTA[rank - 1]))}; } } -unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands( +unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { auto shapePerCTA = getShapePerCTA(*this, shape); int warpsPerCTAM = getWarpsPerCTA()[0]; @@ -1982,7 +2049,8 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands( } // A100 if (isAmpere()) { - auto rep = getMMAv2Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth(), opIdx); + auto rep = getMMAv2RepForOperand(shapePerCTA, eltTy.getIntOrFloatBitWidth(), + kWidth, opIdx); if (opIdx == 0) return 4 * rep[0] * rep[1] * rep[2]; if (opIdx == 1) @@ -2050,43 +2118,58 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands( } llvm_unreachable("unknown mma layout"); } -SmallVector -NvidiaMmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, - int opIdx) const { +SmallVector NvidiaMmaEncodingAttr::getShapePerCTATileForOperand( + ArrayRef shape, int kWidth, int opIdx) const { assert(isAmpere() && "mmaLayout version = 1 is not implemented yet"); auto parentShapePerCTATile = getShapePerCTATile(shape); auto rank = parentShapePerCTATile.size(); + // 4 threads * 2 subtiles + unsigned kWidthTile = kWidth * 2 * 4; if (opIdx == 0) { if (rank == 2) - return {parentShapePerCTATile[rank - 2], 16}; + return {parentShapePerCTATile[rank - 2], kWidthTile}; else - return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 16}; + return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], + kWidthTile}; } else if (opIdx == 1) { if (rank == 2) - return {16, parentShapePerCTATile[rank - 1]}; + return {kWidthTile, parentShapePerCTATile[rank - 1]}; else - return {parentShapePerCTATile[0], 16, parentShapePerCTATile[rank - 1]}; + return {parentShapePerCTATile[0], kWidthTile, + parentShapePerCTATile[rank - 1]}; } else { llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); } } SmallVector -NvidiaMmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { +NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { assert(isAmpere() && "mmaLayout version = 1 is not implemented yet"); + auto rank = getWarpsPerCTA().size(); + auto sizePerThread = SmallVector(rank, 1); if (opIdx == 0) { - return {2, 4}; + sizePerThread[rank - 2] = 2; + sizePerThread[rank - 1] = 2 * kWidth; } else if (opIdx == 1) { - return {4, 1}; + sizePerThread[rank - 2] = 2 * kWidth; + sizePerThread[rank - 1] = 1; } else { llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); - return {}; } + return sizePerThread; } //===----------------------------------------------------------------------===// // DotOperand Encoding //===----------------------------------------------------------------------===// SmallVector DotOperandEncodingAttr::getThreadsPerWarp() const { + auto parent = getParent(); + if (auto mma = mlir::dyn_cast(parent)) { + auto threadsPerWarp = mma.getThreadsPerWarp(); + auto rank = threadsPerWarp.size(); + if (getOpIdx() == 1) + std::swap(threadsPerWarp[rank - 2], threadsPerWarp[rank - 1]); + return threadsPerWarp; + } llvm::report_fatal_error( "getThreadsPerWarp not implemented for DotOperandEncodingAttr"); } @@ -2094,7 +2177,7 @@ SmallVector DotOperandEncodingAttr::getSizePerThread() const { auto parentLayout = getParent(); assert(parentLayout && "DotOperandEncodingAttr must have a parent"); if (auto parentMmaLayout = mlir::dyn_cast(parentLayout)) { - return parentMmaLayout.getSizePerThreadForOperands(getOpIdx()); + return parentMmaLayout.getSizePerThreadForOperand(getKWidth(), getOpIdx()); } else { llvm::report_fatal_error( "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not " @@ -2630,22 +2713,21 @@ struct TritonGPUInferLayoutInterface loc, "SplitOp requires threadsPerWarp, warpsPerCTA, " "and CTAsPerCGA = 1 for the last dimension of the input"); } - if (enc.getOrder().front() != enc.getOrder().size() - 1) { - return emitOptionalError( - loc, "SplitOp requires the last dimension to be most-minor in order"); - } if (enc.getCTALayout().getCTAsPerCGA().back() != 1) { return emitOptionalError( loc, "SplitOp requires the last dimension to be most-minor in CTAOrder"); } - + SmallVector newOrder(enc.getOrder()); + int splitDim = newOrder.size() - 1; + // Remove splitDim from order. + newOrder.erase(std::remove(newOrder.begin(), newOrder.end(), splitDim), + newOrder.end()); dstEnc = BlockedEncodingAttr::get( enc.getContext(), // ArrayRef(enc.getSizePerThread()).drop_back(1), ArrayRef(enc.getThreadsPerWarp()).drop_back(1), - ArrayRef(enc.getWarpsPerCTA()).drop_back(1), - ArrayRef(enc.getOrder()).drop_front(1), + ArrayRef(enc.getWarpsPerCTA()).drop_back(1), ArrayRef(newOrder), CTALayoutAttr::get(enc.getContext(), // ArrayRef(enc.getCTAsPerCGA()).drop_back(1), ArrayRef(enc.getCTASplitNum()).drop_back(1), @@ -2671,7 +2753,7 @@ struct CanonicalizeConvertFromReshape return failure(); if (isExpensiveView(convert.getSrc().getType(), op.getType())) return failure(); - if (!op.getAllowReorder() || op.getEfficientLayout().has_value()) + if (!op.getAllowReorder() || op.getEfficientLayout()) return failure(); rewriter.replaceOpWithNewOp( @@ -2710,8 +2792,9 @@ struct CanonicalizeConvertFromAlloc auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); - rewriter.replaceOpWithNewOp( + auto newAlloc = rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), convert.getSrc()); + newAlloc->setAttrs(op->getAttrs()); return mlir::success(); } }; @@ -2727,8 +2810,31 @@ struct CanonicalizeConvertFromLocalStore auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); - rewriter.replaceOpWithNewOp(op, convert.getSrc(), - op.getDst()); + auto store = rewriter.replaceOpWithNewOp( + op, convert.getSrc(), op.getDst()); + store->setAttrs(op->getAttrs()); + return mlir::success(); + } +}; + +struct CanonicalizeConvertFromSplit + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::SplitOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + auto srcEncoding = convert.getSrc().getType().getEncoding(); + // Multiple source layout can give the same output layout, if the source + // layout of the convert gives the same destination layout we can skip the + // convert. + auto dstEncoding = inferDstEncoding(op, srcEncoding); + if (dstEncoding != op.getOutLHS().getType().getEncoding()) + return failure(); + rewriter.replaceOpWithNewOp(op, convert.getSrc()); return mlir::success(); } }; @@ -2757,8 +2863,9 @@ struct CanonicalizeConvertFromConvert // for hopper MMAv3 if (mlir::isa(dstType.getEncoding()) && mlir::isa(srcType.getEncoding()) && - llvm::any_of(op.getResult().getUsers(), - [](Operation *dot) { return isa(dot); })) { + llvm::any_of(op.getResult().getUsers(), [](Operation *dot) { + return dot->hasTrait(); + })) { return failure(); } @@ -2768,8 +2875,7 @@ struct CanonicalizeConvertFromConvert // cvt(reshape) -> reshape if (auto reshape = dyn_cast(arg)) { - if (!reshape.getAllowReorder() || - reshape.getEfficientLayout().has_value() || + if (!reshape.getAllowReorder() || reshape.getEfficientLayout() || isExpensiveView(reshape.getSrc().getType(), op.getType())) return failure(); @@ -2802,8 +2908,12 @@ struct CanonicalizeConvertFromConvert if (auto sharedLoad = dyn_cast(arg)) { // Shared_load can load to any layout so we can always fold convert into // it. + // We insert at the point of the original op as there could be ops with + // memory side-effects between the LocalLoad op and the ConvertLayout op + rewriter.setInsertionPoint(arg); rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), sharedLoad.getSrc()); + return success(); } @@ -2820,8 +2930,10 @@ struct CanonicalizeConvertFromConvert // cvt(cvt(x, type1), type2) -> cvt(x, type2) if (auto cvt = dyn_cast(arg)) { auto srcType = op.getSrc().getType(); - rewriter.replaceOpWithNewOp( + auto origAttrs = op->getAttrs(); + auto newOp = rewriter.replaceOpWithNewOp( op, op->getResultTypes().front(), cvt.getSrc()); + newOp->setAttrs(origAttrs); return success(); } @@ -2859,6 +2971,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add(context); patterns.add(context); patterns.add(context); + patterns.add(context); } // LocalAllocOp @@ -2875,33 +2988,76 @@ void LocalAllocOp::getEffects( effects.emplace_back(MemoryEffects::Allocate::get(), mlir::triton::gpu::SharedMemory::get()); if (getSrc()) - effects.emplace_back(MemoryEffects::Write::get(), getResult(), + effects.emplace_back(MemoryEffects::Write::get(), + getOperation()->getOpResult(0), mlir::triton::gpu::SharedMemory::get()); } +OpFoldResult LocalAllocOp::fold(FoldAdaptor adaptor) { + if (getType().getMutableMemory()) + return {}; + auto src = getSrc(); + if (!src) + return {}; + auto localLoadOp = src.getDefiningOp(); + if (!localLoadOp) + return {}; + auto loadSrc = localLoadOp.getSrc(); + if (loadSrc.getType() != getType()) + return {}; + return loadSrc; +} + +LogicalResult LocalAllocOp::verify() { + if (!getSrc()) { + if (!getType().getMutableMemory()) + return emitError("uninitialized alloc must have a mutable memdesc type"); + return success(); + } + auto srcTy = getSrc().getType(); + auto dstTy = getType(); + + if (srcTy.getElementType() != dstTy.getElementType()) { + return emitError("result element type must match desc element type"); + } + return success(); +} + // LocalLoadOp void LocalLoadOp::getEffects( SmallVectorImpl> &effects) { - effects.emplace_back(MemoryEffects::Read::get(), getSrc(), + effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), mlir::triton::gpu::SharedMemory::get()); } // LocalStoreOp +LogicalResult LocalStoreOp::verify() { + if (!getDst().getType().getMutableMemory()) + return emitOpError("Cannot store into immutable memory"); + return success(); +} + void LocalStoreOp::getEffects( SmallVectorImpl> &effects) { - effects.emplace_back(MemoryEffects::Write::get(), getDst(), + effects.emplace_back(MemoryEffects::Write::get(), &getDstMutable(), mlir::triton::gpu::SharedMemory::get()); } // AsyncCopyGlobalToLocalOp +LogicalResult AsyncCopyGlobalToLocalOp::verify() { + if (!getResult().getType().getMutableMemory()) + return emitOpError("Cannot store into immutable memory"); + return success(); +} + void AsyncCopyGlobalToLocalOp::getEffects( SmallVectorImpl> &effects) { - effects.emplace_back(MemoryEffects::Read::get(), getSrc(), + effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), mlir::triton::GlobalMemory::get()); - effects.emplace_back(MemoryEffects::Write::get(), getResult(), + effects.emplace_back(MemoryEffects::Write::get(), &getResultMutable(), mlir::triton::gpu::SharedMemory::get()); } @@ -2950,6 +3106,325 @@ LogicalResult MemDescSubviewOp::verify() { return success(); } +// -- LocalAllocOp -- + +int32_t LocalAllocOp::getAlignmentOrDefault() { + auto align = getAlignment(); + if (align) { + return *align; + } + + auto ty = getType(); + auto shapePerCTA = triton::gpu::getShapePerCTA(ty); + auto bytes = + product(shapePerCTA) * (ty.getElementTypeBitWidth() / 8); + + // XXX(Keren): magic numbers 256 and 1024 + // Software swizzling calculates phase based on offset, while hardware + // swizzling do that based on physical address. Thus only by setting the + // alignment to 1024 can ensure the correctness. + return bytes > 256 ? 1024 : 8; +} + +//===----------------------------------------------------------------------===// +// Layout debug printing +//===----------------------------------------------------------------------===// + +// Return N-D delinearized indices from a linear index. +static SmallVector delinearizeIndex(int64_t idx, + ArrayRef shape) { + SmallVector ret(shape.size()); + for (int i = shape.size() - 1; i >= 0; i--) { + ret[i] = idx % shape[i]; + idx /= shape[i]; + } + return ret; +} + +// Returns how many padding characters are needed for the string representation +// of value to be the same as max. +static int numCharacterPadding(int value, int max) { + return std::to_string(max).size() - std::to_string(value).size(); +} + +// return the string padded to have the same length as max. +static std::string paddedString(int value, int max) { + int nbChar = numCharacterPadding(value, max); + std::string str; + for (int i = 0; i < nbChar; i++) + str += " "; + str += std::to_string(value); + return str; +} + +std::string getSharedLayoutStr(RankedTensorType tensorType, + bool useHWPointOfView) { + auto layout = tensorType.getEncoding(); + if (!layout) + return ""; + + std::optional ll = + triton::gpu::toLinearLayout(tensorType.getShape(), layout); + if (!ll.has_value()) + llvm::report_fatal_error("Failed to convert layout to linear layout"); + + StringAttr kOffset = StringAttr::get(tensorType.getContext(), "offset"); + StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block"); + int64_t tensorSize = product(tensorType.getShape()); + unsigned numBlocks = getNumCTAs(layout); + int32_t blockSize = tensorSize / numBlocks; + + // elementMapping is for the non-hw layout, offsetMapping for hw-layout + std::vector elementMapping(tensorSize); + std::vector offsetMapping; + + // Shared layouts are a mapping of (block, offset) --> (...) + + // We can just use a single int to index into elementMapping because + // the 'swizzle' operation rearranges the indicies---and we want to keep it + // that way + int32_t idx = 0; + // Enumerate all the offsets for each block + for (int32_t block = 0; block < numBlocks; block++) { + for (int32_t offset = 0; offset < blockSize; offset++) { + SmallVector> inputs = { + {kBlock, block}, + {kOffset, offset}, + }; + + SmallVector> outputs = ll->apply(inputs); + + std::string sharedInfo = "("; + std::string &value = elementMapping[idx]; + + if (!value.empty()) + value += "|"; + + value += "("; + // We can build up both strings (for hw/non-hw layouts) concurrently + for (int i = 0; i < outputs.size(); i++) { + // Based on the formatting from LinearLayout::toString, the format for + // the hw layout is slightly different. HW layouts use "," vs ":". + if (i > 0) { + sharedInfo += ","; + value += ":"; + } + auto index = paddedString(outputs[i].second, tensorType.getDimSize(i)); + sharedInfo += index; + value += index; + } + value += ")"; + sharedInfo += ")"; + + offsetMapping.push_back(sharedInfo); + + idx++; + } + } + + std::string layoutStr; + + if (!useHWPointOfView) { + int rank = tensorType.getRank(); + bool newLine = true; + for (int i = 0; i < tensorSize; i++) { + auto indices = delinearizeIndex(i, tensorType.getShape()); + int numOpenBracket = 0; + for (int j = rank - 1; j >= 0; j--) { + if (indices[j] % tensorType.getDimSize(j) != 0) + break; + layoutStr += "["; + numOpenBracket++; + } + if (newLine) { + for (int j = 0; j < rank - numOpenBracket; j++) + layoutStr += " "; + newLine = false; + } + + layoutStr += elementMapping[i]; + auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape()); + for (int j = rank - 1; j >= 0; j--) { + if (nextIndices[j] % tensorType.getDimSize(j) != 0) + break; + layoutStr += "]"; + } + if (nextIndices.back() % tensorType.getShape().back() == 0) { + layoutStr += "\n"; + newLine = true; + } else { + layoutStr += ","; + } + } + } else { + // For the HW view here, print the (block, offset) --> (r,c) mapping + uint32_t idx = 0; + for (int32_t block = 0; block < numBlocks; block++) { + layoutStr += "Block: " + std::to_string(block) + ":\n"; + for (int32_t offset = 0; offset < (tensorSize / numBlocks); offset++) { + layoutStr += "Offset: " + std::to_string(offset) + " -> "; + layoutStr += offsetMapping[idx]; + layoutStr += "\n"; + idx++; + } + } + } + + return layoutStr; +} + +std::string getDistributedLayoutStr(RankedTensorType tensorType, + bool useHWPointOfView) { + auto layout = tensorType.getEncoding(); + if (!layout) + return ""; + + unsigned threadsPerWarp = getWarpSize(layout); + unsigned numWarpsPerCTA = getNumWarpsPerCTA(layout); + unsigned numBlocks = getNumCTAs(layout); + int numElementsPerThreads = getTotalElemsPerThread(tensorType); + StringAttr kRegister = StringAttr::get(tensorType.getContext(), "register"); + StringAttr kLane = StringAttr::get(tensorType.getContext(), "lane"); + StringAttr kWarp = StringAttr::get(tensorType.getContext(), "warp"); + StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block"); + + std::optional ll = + triton::gpu::toLinearLayout(tensorType.getShape(), layout); + if (!ll.has_value()) + llvm::report_fatal_error("Failed to convert layout to linear layout"); + int64_t tensorSize = product(tensorType.getShape()); + std::vector elementMapping(tensorSize); + std::vector threadMapping; + for (int blockId = 0; blockId < numBlocks; ++blockId) { + for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) { + for (int tid = 0; tid < threadsPerWarp; ++tid) { + for (int idx = 0; idx < numElementsPerThreads; ++idx) { + SmallVector> inputs = { + {kBlock, blockId}, + {kWarp, warpId}, + {kLane, tid}, + {kRegister, idx}}; + SmallVector> outputs = + ll->apply(inputs); + int32_t linearizedIdx = 0; + int stride = 1; + for (int i = outputs.size() - 1; i >= 0; i--) { + linearizedIdx += outputs[i].second * stride; + stride *= tensorType.getDimSize(i); + } + std::string &value = elementMapping[linearizedIdx]; + if (!value.empty()) + value += "|"; + int padding = numCharacterPadding(blockId, numBlocks) + + numCharacterPadding(tid + warpId * threadsPerWarp, + numWarpsPerCTA * threadsPerWarp) + + numCharacterPadding(idx, numElementsPerThreads); + for (int i = 0; i < padding; i++) + value += " "; + if (numBlocks > 1) + value += "B" + std::to_string(blockId) + ":"; + value += "T" + std::to_string(tid + warpId * threadsPerWarp) + ":" + + std::to_string(idx); + // Now also compute the thread mapping. + std::string threadInfo = "("; + for (int i = 0; i < outputs.size(); i++) { + if (i > 0) + threadInfo += ","; + threadInfo += + paddedString(outputs[i].second, tensorType.getDimSize(i)); + } + threadInfo += ")"; + threadMapping.push_back(threadInfo); + } + } + } + } + std::string layoutStr; + if (!useHWPointOfView) { + // Printing the threads containing each elements of the tensor. + int rank = tensorType.getRank(); + bool newLine = true; + for (int i = 0; i < tensorSize; i++) { + auto indices = delinearizeIndex(i, tensorType.getShape()); + int numOpenBracket = 0; + for (int j = rank - 1; j >= 0; j--) { + if (indices[j] % tensorType.getDimSize(j) != 0) + break; + layoutStr += "["; + numOpenBracket++; + } + if (newLine) { + for (int j = 0; j < rank - numOpenBracket; j++) + layoutStr += " "; + newLine = false; + } + + layoutStr += elementMapping[i]; + auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape()); + for (int j = rank - 1; j >= 0; j--) { + if (nextIndices[j] % tensorType.getDimSize(j) != 0) + break; + layoutStr += "]"; + } + if (nextIndices.back() % tensorType.getShape().back() == 0) { + layoutStr += "\n"; + newLine = true; + } else { + layoutStr += ", "; + } + } + } else { + // Printing the elements in each physical reg/warps/threads. + for (int blockId = 0; blockId < numBlocks; blockId++) { + if (numBlocks > 1) + layoutStr += "Block" + std::to_string(blockId) + ":\n"; + for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) { + layoutStr += "Warp" + std::to_string(warpId) + ":\n"; + for (int idx = 0; idx < numElementsPerThreads; ++idx) { + for (int tid = 0; tid < threadsPerWarp; ++tid) { + int linearizedIdx = + blockId * numWarpsPerCTA * threadsPerWarp * + numElementsPerThreads + + warpId * threadsPerWarp * numElementsPerThreads + + tid * numElementsPerThreads + idx; + layoutStr += threadMapping[linearizedIdx]; + if (tid < threadsPerWarp - 1) + layoutStr += ", "; + } + layoutStr += "\n"; + } + } + } + } + return layoutStr; +} + +std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType, + bool useHWPointOfView) { + auto layout = tensorType.getEncoding(); + + // tensorType is needed later on (e.g., getDimSize(j)), so we still have to + // pass it as a param + if (auto sharedLayout = mlir::dyn_cast(layout)) { + return getSharedLayoutStr(tensorType, useHWPointOfView); + } else if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return getDistributedLayoutStr(tensorType, useHWPointOfView); + } + + // else unimplemented, return error + llvm::report_fatal_error("Unimplemented usage of getLayoutStr"); + return ""; +} + +void mlir::triton::gpu::dumpLayout(RankedTensorType tensorType) { + llvm::errs() << getLayoutStr(tensorType, /*useHWPointOfView=*/false); +} + +void mlir::triton::gpu::dumpHWLayout(RankedTensorType tensorType) { + llvm::errs() << getLayoutStr(tensorType, /*useHWPointOfView=*/true); +} + void TritonGPUDialect::initialize() { registerTypes(); @@ -2966,9 +3441,6 @@ void TritonGPUDialect::initialize() { addInterfaces(); } -#define GET_OP_CLASSES -#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" - // verify TritonGPU ops LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index ae34598ae..56af4eaef 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -4,9 +4,11 @@ #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" #include "triton/Tools/LinearLayout.h" #include "triton/Tools/StrUtil.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" @@ -16,16 +18,17 @@ namespace { // We use the following nomenclature in this file. // -// - ctaLayout: A layout for one block, i.e. input dims (register, lane, warp). -// - cgaLayout: Arrangement of multiple blocks, i.e. input dims (block). +// - ctaLayout: A layout for one block, i.e. input dims [register, lane, warp] +// for register layouts, and input dims [offset] for shared layouts. +// - cgaLayout: Arrangement of multiple blocks, i.e. input dims [block]. // // Note that this is inconsistent with the type name CTALayoutAttr. That type // is equivalent to our cgaLayout. // -// IMO the type name is wrong. If we tried to be consistent anyway, then we'd -// have to rename ctaLayout to "warpLayout". I think that's more confusing than -// being inconsistent about "cgaLayout", especially when we have to consider the -// size of the warpLayout (surely that's not the "warpSize"). +// IMO the name CTALayoutAttr is wrong. If we tried to be consistent anyway, +// then we'd have to rename ctaLayout to "warpLayout". I think that's more +// confusing than being inconsistent about "cgaLayout", especially when we have +// to consider the size of the warpLayout (surely that's not the "warpSize"). #define S(v) StringAttr::get(ctx, (v)) @@ -38,6 +41,25 @@ SmallVector standardOutDimNames(MLIRContext *ctx, int rank) { return ret; } +void assertIsRegisterLayout(const LinearLayout &layout) { + assert(layout.getNumInDims() > 0); + MLIRContext *ctx = layout.getInDimNames().begin()->getContext(); + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + StringAttr kBlock = S("block"); + + const auto &ins = layout.getInDimNames(); + assert(llvm::SmallVector(ins.begin(), ins.end()) == + llvm::SmallVector({kRegister, kLane, kWarp, kBlock})); + + const auto &outs = layout.getOutDimNames(); + const auto &expectedOuts = standardOutDimNames(ctx, layout.getNumOutDims()); + assert(llvm::SmallDenseSet(outs.begin(), outs.end()) == + llvm::SmallDenseSet(expectedOuts.begin(), + expectedOuts.end())); +} + // Returns a 1D -> ND layout that's equivalent to creating a 1D -> 1D mapping of // size product(shape) and then reshaping to permute(shape, order). LinearLayout identityND(StringAttr inDimName, ArrayRef shape, @@ -84,124 +106,49 @@ LinearLayout makeCgaLayout(CTALayoutAttr layout) { return ret.transposeOuts(outDimNames); } -// Shrinks the output set of a layout function while leaving the input set -// unchanged, by making high-order inputs in inDimName map to the same output. -// Attempts to shrink down to desiredSize, but this is not always possible just -// by modifying one the specified input dimension. +// For each output dimension d, ensure that the layout's output size (i.e., its +// codomain) does not exceed shape[d]. Do this without changing the size of the +// layout's inputs (i.e., leave its domain unchanged). // -// We do this by making the most-major inputs to the layout map to 0. This -// effectively duplicates data along that input dimension. For example, this -// layout has out-dim size 32: +// This function is invariant to the order of the layout's input and output +// dimensions. +// +// We achieve this by setting the largest value in each output dimension d to 0 +// because bases that map to a location larger than shape[d] +// effectively duplicate along that dimension. For example, consider a layout +// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to +// shrink the output dimension size to 8: // // L(register=1) = 8 // L(register=2) = 4 // L(register=4) = 1 // L(lane=1) = 2 -// L(lane=2) = 16. +// L(lane=2) = 16 // -// If we shrink it to size 16 along the `lane` dimension, we set L(lane=2) to 0: +// In the first step, we shrink the output dimension size to 16 by setting +// L(lane=2) to 0: // // L(register=1) = 8 // L(register=2) = 4 // L(register=4) = 1 // L(lane=1) = 2 -// L(lane=2) = 0. +// L(lane=2) = 0 // // This means that lane=2 has the same data as lane=0. // -// If we shrink to size 8 along the lane dimension, we set L(lane=1) = 0 as -// well. But when we do this, we have to remove bit 1 (the value of L(lane=1)) -// from all other bases: -// -// L(register=1) = 4 -// L(register=2) = 2 -// L(register=1) = 1 -// L(lane=1) = 0 -// L(lane=2) = 0. +// Now the output dimension of this layout has a size of 16, which is still +// larger than 8. We find the current largest value in the output dimension, +// which is L(register=1) = 8, and we set L(register=1) to 0: // -// Note this only works because the bases are powers of two. I don't quite know -// what to do when they're not. -LinearLayout shrinkCodomain(const LinearLayout &layout, StringAttr inDimName, - StringAttr outDimName, int desiredSize) { - assert(llvm::isPowerOf2_32(desiredSize)); - int outDimIdx = layout.getOutDimIndex(outDimName); - int desiredZeros = - llvm::Log2_32(layout.getOutDimSize(outDimName) / desiredSize); - if (desiredZeros == 0) { - return layout; - } - - // Find the desiredZeros most-major basis vectors that are not already zero. - // These are the ones we will set to zero. - SmallVector basesToZero; - for (int i = layout.getInDimSizeLog2(inDimName) - 1; - i >= 0 && basesToZero.size() < desiredZeros; i--) { - int basis = layout.getBasis(inDimName, i, outDimName); - if (basis != 0) { - basesToZero.push_back(basis); - } - } - - // Bail if all the bases are already zero; nothing more we can do. - if (basesToZero.empty()) { - return layout; - } - - // The algorithm below only works because the bases are powers of two. I'm - // not sure what to do otherwise. - assert(llvm::all_of(basesToZero, - [&](int basis) { return llvm::isPowerOf2_32(basis); })); - - // We want to zero out the bases in `basesToZero`, and also "shift out" the - // corresponding bits from all other bases. For example if we remove the - // basis with value 8 = 0b100, then if another basis has value 26 = 0b11010, - // the 1 in its 3rd position gets removed and it becomes 10 = 0b1010. - // - // We could manually alter the bases in `layout` to achieve this, but it's - // perhaps simpler to use the linearity of LLs to our advantage. - // - // Consider the function O which is the identity map from out-dims to - // out-dims. We can easily calculate what happens when we remove the relevant - // bases from O. Call this new function O'. - // - // Because of linearity, removing the bases from L is equivalent to composing - // L with O'. So that's what we do below. - - // Construct the out-dims -> out-dims identity layout O. - LinearLayout outputIdentity = LinearLayout::empty(); - for (StringAttr dim : layout.getOutDimNames()) { - outputIdentity *= - LinearLayout::identity1D(layout.getOutDimSize(dim), dim, dim); - } - - // Modify O to remove the relevant bases. - // - // TODO(jlebar): I don't like manually modifying bases here. Perhaps this - // should be a function on LinearLayout. - LinearLayout::BasesT newBases = outputIdentity.getBases(); - llvm::sort(basesToZero); - for (int basis : basesToZero) { - int idx = llvm::Log2_32(basis); - for (int i = newBases[outDimName].size() - 1; i > idx; i--) { - newBases[outDimName][i][outDimIdx] = - newBases[outDimName][i - 1][outDimIdx]; - } - newBases[outDimName][idx][outDimIdx] = 0; - } - - // Construct O'. - LinearLayout transform(std::move(newBases), layout.getOutDimNames()); - - // Compose O' with L. - return layout.compose(transform); -} - -// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no -// larger than shape[d]. Do this without changing the size of the layout's -// inputs (i.e. leave its domain unchanged). +// L(register=1) = 0 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0 // -// This function is invariant to the order of the layout's input and output -// dimensions. +// Now the output dimension of this layout has a size of 8, which is the desired +// size. Note that this method works only because the bases are powers of two. +// It is unclear what to do when they are not. LinearLayout ensureLayoutNotLargerThan( const LinearLayout &layout, const llvm::SmallDenseMap &shape) { @@ -211,41 +158,46 @@ LinearLayout ensureLayoutNotLargerThan( } MLIRContext *ctx = shape.begin()->first.getContext(); - // For the purposes of this function, "block" is the "most-minor" dimension. - // This is just a consequence of how legacy layouts work: We only put the same - // tensor element into two different blocks as a last resort, only after all - // the registers in all the lanes in all the warps in a block already have the - // same tensor element. - SmallVector inDimNames = { - S("block"), - S("register"), - S("lane"), - S("warp"), - }; - - LinearLayout ret = layout; - for (auto outDimName : layout.getOutDimNames()) { + auto bases = layout.getBases(); + for (auto outDim : llvm::enumerate(layout.getOutDimNames())) { + auto outDimName = outDim.value(); int32_t actualSize = layout.getOutDimSize(outDimName); int32_t desiredSize = shape.lookup(outDimName); if (actualSize <= desiredSize) { continue; } assert(actualSize % desiredSize == 0); - // TODO: We claim this is invariant to the order of dims, so can we get rid - // of llvm::reverse? - for (StringAttr inDimName : llvm::reverse(inDimNames)) { - if (ret.hasInDim(inDimName)) { - ret = shrinkCodomain(ret, inDimName, outDimName, desiredSize); + // + std::vector> sortedBases; + for (auto [inDimName, basis] : bases) { + for (size_t basisIdx = 0; basisIdx < basis.size(); basisIdx++) { + auto outValue = basis[basisIdx][outDim.index()]; + if (outValue == 0) { + continue; + } + assert(llvm::isPowerOf2_32(outValue)); + sortedBases.emplace_back(inDimName, basisIdx, outValue); } } - assert(ret.getOutDimSize(outDimName) == desiredSize); + // From the largest basis to the smallest. + llvm::sort(sortedBases, + [](auto a, auto b) { return std::get<2>(a) > std::get<2>(b); }); + for (auto [inDimName, basisIdx, outValue] : sortedBases) { + if (actualSize <= desiredSize) { + break; + } + bases[inDimName][basisIdx][outDim.index()] = 0; + actualSize >>= 1; + } } - return ret; + return LinearLayout(std::move(bases), + llvm::to_vector(layout.getOutDimNames())); } // For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no // smaller than shape[d]. Do this by increasing the size of the layout's inputs -// along the "register" dimension. +// along its most-minor dimension ("register" for register layouts, "offset" for +// shared layouts). // // This function is invariant to the order of the layout's input dimensions, but // it cares about the order of the output dims, which should be minor-to-major. @@ -258,15 +210,15 @@ LinearLayout ensureLayoutNotSmallerThan( } MLIRContext *ctx = shape.begin()->first.getContext(); - StringAttr kRegister = S("register"); + StringAttr kDim = *layout.getInDimNames().begin(); + assert(kDim == "register" || kDim == "offset"); LinearLayout ret = layout; for (StringAttr outDimName : layout.getOutDimNames()) { int32_t actualSize = layout.getOutDimSize(outDimName); int32_t desiredSize = shape.lookup(outDimName); assert(actualSize > desiredSize || desiredSize % actualSize == 0); - ret *= LinearLayout::identity1D(desiredSize / actualSize, kRegister, - outDimName); + ret *= LinearLayout::identity1D(desiredSize / actualSize, kDim, outDimName); assert(ret.getOutDimSize(outDimName) >= desiredSize); } return ret; @@ -295,12 +247,13 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, LinearLayout cgaLayout = ensureLayoutNotLargerThan(makeCgaLayout(cgaLayoutAttr), labeledShape) - .transposeOuts(ctaLayout.getOutDimNames()); + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); // Calculate the shape of the ctaLayout, which is `shape` divided by the // cgaLayout's size. llvm::SmallDenseMap ctaShape; - assert(ctaLayout.getOutDimNames() == cgaLayout.getOutDimNames()); + assert(llvm::to_vector(ctaLayout.getOutDimNames()) == + llvm::to_vector(cgaLayout.getOutDimNames())); for (auto dim : ctaLayout.getOutDimNames()) { ctaShape[dim] = std::max(int64_t{1}, labeledShape[dim] / cgaLayout.getOutDimSize(dim)); @@ -316,24 +269,6 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, return ret; } -LinearLayout blockedToLinearLayout(ArrayRef shape, - BlockedEncodingAttr blocked) { - assert(shape.size() == blocked.getOrder().size()); - - int rank = shape.size(); - MLIRContext *ctx = blocked.getContext(); - SmallVector outDimNames = standardOutDimNames(ctx, rank); - - const auto &order = blocked.getOrder(); - LinearLayout ctaLayout = - identityND(S("register"), blocked.getSizePerThread(), order, - outDimNames) * - identityND(S("lane"), blocked.getThreadsPerWarp(), order, outDimNames) * - identityND(S("warp"), blocked.getWarpsPerCTA(), order, outDimNames); - - return combineCtaCgaWithShape(ctaLayout, blocked.getCTALayout(), shape); -} - LinearLayout ampereMmaToLinearLayout(ArrayRef shape, NvidiaMmaEncodingAttr mma) { int rank = shape.size(); @@ -374,7 +309,7 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef shape, int n = mma.getInstrShape()[1]; int k = mma.getInstrShape()[2]; assert(m == 16); - assert(n == 16 || n == 32 || n == 64 || n == 128 || n == 256); + assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256); assert(k == 8 || k == 16 || k == 32); MLIRContext *ctx = mma.getContext(); @@ -393,25 +328,466 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef shape, // this really does seem to be correct. ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1}, {S("dim0"), S("dim1")}) - .transposeOuts(ctaLayout.getOutDimNames()); + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); } -std::optional toLinearLayout(ArrayRef shape, - SliceEncodingAttr slice) { - MLIRContext *ctx = slice.getContext(); +LinearLayout sharedToLinearLayoutNoLeadingOffset(ArrayRef shape, + SharedEncodingAttr shared) { + assert(!shared.getHasLeadingOffset()); + + MLIRContext *ctx = shared.getContext(); + int rank = shape.size(); + if (rank == 1) { + return combineCtaCgaWithShape( + LinearLayout::identity1D(shape[0], S("offset"), S("dim0")), + shared.getCTALayout(), shape); + } + + auto outDimNames = standardOutDimNames(ctx, rank); + + // Construct bases for the 2 most minor dimensions of the layout. These are + // the dims that get swizzled. + assert(shape.size() >= 2); + int colDim = shared.getOrder()[0]; + int rowDim = shared.getOrder()[1]; + int numCols = shape[colDim]; + int numRows = shape[rowDim]; + StringAttr colDimName = outDimNames[colDim]; + StringAttr rowDimName = outDimNames[rowDim]; + + std::vector> bases2D; + for (int logCol = 0; logCol < llvm::Log2_32(numCols); logCol++) { + bases2D.push_back({0, 1 << logCol}); + } + for (int logRow = 0; logRow < llvm::Log2_32(numRows); logRow++) { + int row = 1 << logRow; + int vec = shared.getVec(); + int perPhase = shared.getPerPhase(); + int maxPhase = shared.getMaxPhase(); + bases2D.push_back({row, (vec * ((row / perPhase) % maxPhase)) % numCols}); + } + LinearLayout ctaLayout = + LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName}); + + // Add the remaining dimensions. + for (int i = 2; i < rank; i++) { + int dim = shared.getOrder()[i]; + ctaLayout *= + LinearLayout::identity1D(shape[dim], S("offset"), outDimNames[dim]); + } + + return combineCtaCgaWithShape(ctaLayout, shared.getCTALayout(), shape); +} + +LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape, + SharedEncodingAttr shared, + int32_t elemBitWidth) { + assert(shared.getHasLeadingOffset()); + + MLIRContext *ctx = shared.getContext(); + int rank = shape.size(); + if (rank == 1) { + // TODO: Not sure if this is correct. + return combineCtaCgaWithShape( + LinearLayout::identity1D(shape[0], S("offset"), S("dim0")), + shared.getCTALayout(), shape); + } + + int tileWidthBytes; + if (shared.getPerPhase() == 4 && shared.getMaxPhase() == 2) { + tileWidthBytes = 32; + } else if (shared.getPerPhase() == 2 && shared.getMaxPhase() == 4) { + tileWidthBytes = 64; + } else if (shared.getPerPhase() == 1 && shared.getMaxPhase() == 8) { + tileWidthBytes = 128; + } else { + llvm::errs() + << "Illegal shared encoding. If hasLeadingOffset is true, " + "then (perPhase, maxPhase) must be either (4,2), (2,4), or (1,8): " + << shared << "\n"; + llvm_unreachable("Illegal shared encoding"); + } + + auto outDimNames = standardOutDimNames(ctx, rank); + + // Construct bases for a the layout's 2-dimensional tile. + assert(shape.size() >= 2); + int colDim = shared.getOrder()[0]; + int rowDim = shared.getOrder()[1]; + + int tileRows = 8; + int tileCols = 8 * tileWidthBytes / elemBitWidth; + + if (shape[colDim] < tileCols || shape[rowDim] < tileRows) { + llvm::errs() << "Illegal shared layout; expected shape to be at least [" + << tileRows << ", " << tileCols << "], shape: [" + << shape[rowDim] << ", " << shape[colDim] << "]\n"; + llvm::report_fatal_error("Illegal shared layout"); + } + + int vec = 8 * 16 / elemBitWidth; + if (vec != shared.getVec()) { + llvm::errs() << "Illegal shared layout; expected `vec` to be " << vec + << ": " << shared << "\n"; + llvm::report_fatal_error("Illegal shared layout"); + } + + StringAttr colDimName = outDimNames[colDim]; + StringAttr rowDimName = outDimNames[rowDim]; + + std::vector> bases2D; + for (int logCol = 0; logCol < llvm::Log2_32(tileCols); logCol++) { + bases2D.push_back({0, 1 << logCol}); + } + for (int logRow = 0; logRow < llvm::Log2_32(tileRows); logRow++) { + int row = 1 << logRow; + int perPhase = shared.getPerPhase(); + int maxPhase = shared.getMaxPhase(); + bases2D.push_back({row, vec * ((row / perPhase) % maxPhase)}); + } + LinearLayout tileLayout = + LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName}); + + // Add the remaining dimensions. + for (int i = 2; i < rank; i++) { + int dim = shared.getOrder()[i]; + tileLayout *= + LinearLayout::identity1D(shape[dim], S("offset"), outDimNames[dim]); + } + + return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape); +} + +} // anonymous namespace + +std::optional +AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + int rank = shape.size(); + assert(rank == getWarpsPerCTA().size()); + + bool hasBatchDim = rank == 3; + int mIndex = 0 + hasBatchDim; + int nIndex = 1 + hasBatchDim; + (void)mIndex, (void)nIndex; + + assert(((getMDim() == 32 && getNDim() == 32) || + (getMDim() == 16 && getNDim() == 16)) && + "Unsupported mfma type"); + + MLIRContext *ctx = getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + + // https://github.com/ROCm/amd_matrix_instruction_calculator can print the + // register and lane layout for mfma instructions. + + // We use the order from fastest varying to slowest varying. So each base + // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices. + SmallVector order = triton::gpu::getOrder(*this); + auto tileLayout = LinearLayout::empty(); + + if (getMDim() == 32) { + // For mfma with 32x32 output, each of the 64 threads holds 16 elements. + // + // For the register (i.e., element) dimension, these 16 elements are along + // the matrix C's M dimension, with 4 consecutive elements spanning 4 rows + // and then the next 4 rows being a gap. + // + // For the lane (i.e., thread) dimension, these threads are along the + // matrix C's N dimension, with 32 consecutive threads covering a whole + // row and the next 32 threads start after a gap spanning 4 rows. + tileLayout = LinearLayout( + {{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + // For mfma.transposed layout, the element ownership among threads are + // "transposed" within each warp. + if (getIsTransposed()) + tileLayout = LinearLayout( + {{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}}, + {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + } else { + assert(getMDim() == 16); + // For mfma with 16x16 output, each of the 64 threads holds 4 elements. + // + // For the register (i.e., element) dimension, these 4 elements are along + // the matrix C's M dimension, with 4 consecutive elements spanning 4 rows. + // + // For the lane (i.e., thread) dimension, these threads are along the + // matrix C's N dimension, with 16 consecutive threads covering a whole + // row and the next 16 threads start after a gap spanning 4 rows. + tileLayout = LinearLayout( + {{kRegister, {{0, 1}, {0, 2}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + // For mfma.transposed layout, the element ownership among threads are + // "transposed" within each warp. + if (getIsTransposed()) + tileLayout = LinearLayout( + {{kRegister, {{1, 0}, {2, 0}}}, + {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + } + if (hasBatchDim) { + assert(order[2] == 0); + // Extend the base vector with one value to accomodate for the batch + // dimension, which appears at the last. + tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); + } + + // And each warp takes the same register and lane sub-layout. So mulitply with + // an identity layout for the warp. + LinearLayout warpLayout = + identityND(S("warp"), getWarpsPerCTA(), order, outDimNames); + LinearLayout ctaLayout = tileLayout * warpLayout; + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +std::optional +dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, + ArrayRef shape) { + + // Current linear layout conversion for dot operand is only necessary to + // enable LDS bypass for operand B in the MFMA dot path. To achieve + // performance gains from bypassing LDS, the following conditions must be met: + // + // 1) opIdx == 1: Currently, only the B tensor (e.g. weights in moe-like + // kernels) bypasses LDS. This constraint is not strict and support for + // bypassing operand A (e.g. Q tensor in flash attention) will be added in + // the future. + // + // 2) B tensor must be column major: This is required to support vectorized + // global load instructions, as MFMA instructions expect threads to hold B + // operand elements along the K dimension. + // + // 3) kWidth == 8: Ensures maximum global load vectorization for fp16 + // operations. + // TODO: Generalize conversion to handle maximum kWidth for other types + // (i.e. fp8). + // + // 4) warpsPerCTA[mDim] == 1: This guarantees that every B tensor element is + // held by exactly one thread, maintaining the same number of global loads + // as in a blocked layout. + // + // Other use of Linear layout is a support of rare corner cases, + // for example one instruction tile is larger than tensor + auto mfmaLayout = llvm::cast(dotMfmaLayout.getParent()); + + auto rank = shape.size(); + bool hasBatchDim = rank == 3; + int mIndex = 0 + hasBatchDim; + + int32_t kWidth = dotMfmaLayout.getKWidth(); + auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2; + int32_t kSize = shape[kDim]; + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + + MLIRContext *ctx = dotMfmaLayout.getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + + // register order + // operand A: [1, 0] / [2, 1, 0] + // operand B: [0, 1] / [1, 2, 0] + // for both cases it is [k, nonk]/[k, nonk, batch] + SmallVector order = triton::gpu::getOrder(dotMfmaLayout); + // warp order + // common for both operand A and B: [0, 1] / [0, 1, 2] + // in both cases it is [M dim, N dim]/[batch, M dim, N dim] + SmallVector warpOrder = triton::gpu::getWarpOrder(dotMfmaLayout); + + // Lane holds kWidth consecutive elements along k dimension, so + // base register vectors for one tile are initialized in following way: + // {1, 0}, {2, 0} ... {kWidth/2, 0} + std::vector> registerBase; + for (int32_t elem = 1; elem < kWidth; elem *= 2) + registerBase.emplace_back(std::vector{elem, 0}); + + std::vector> laneBase; + int32_t kTileSize = -1; + + if (mfmaLayout.getMDim() == 32) { + // Canonical MFMA linear layout handles 4 consecutive elements along + // the register dimension. Dot operand handles varaible kWidth consecutive + // elements. For lane dim, since the MFMA thread arrangement is {K, N} = {2, + // 32}, this means that mapping of first 5 base (up to thread 16) vectors + // will be an identity along N dim. Thread 32 will be mapped to element + // kWidth in K dimension. + laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {kWidth, 0}}; + kTileSize = kWidth * 2; + } else { + assert(mfmaLayout.getMDim() == 16); + // For lane dim, since the MFMA thread arrangement is {K, N} = {4, 16}, this + // means that mapping of first 4 base (up to thread 16) vectors will be an + // identity along N dim. Thread 16 will be mapped to element kWisth in K + // dimension. Thread 32 is mapped to element 2*kWidth in K dim. + laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {kWidth, 0}, {kWidth * 2, 0}}; + kTileSize = kWidth * 4; + } + assert(kTileSize != -1); + // Add repeats of registers along K dimension to register base vectors + for (int32_t elem = kTileSize; elem < kSize; elem *= 2) + registerBase.emplace_back(std::vector{elem, 0}); + + // Base vectors above are defined in a fixed order [non-k-dim, k-dim]. + // To assign them to actual matrix dimensions `order` array is used. + // For operand A: non-k-dim -> dim0, k-dim -> dim1 + // For operand B: non-k-dim -> dim1, k-dim -> dim0 + LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + + if (hasBatchDim) { + assert(order[2] == 0); + // Extend the base vector with one value to accomodate for the batch + // dimension, which appears at the last. + tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); + } + + LinearLayout warpLayout = + identityND(kWarp, warpsPerCTA, warpOrder, outDimNames); + + LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) * + warpLayout.transposeOuts(outDimNames); + + return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape); +} + +std::optional +AMDWmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + int rank = shape.size(); + assert(rank == getWarpsPerCTA().size()); + + bool hasBatchDim = rank == 3; + int mIndex = 0 + hasBatchDim; + int nIndex = 1 + hasBatchDim; + (void)mIndex, (void)nIndex; + + SmallVector mnkDim = getMNKDimPerInstr(); + unsigned mDim = mnkDim[0], nDim = mnkDim[1]; + (void)mDim, (void)nDim; + + assert(((shape[mIndex] == 1 || shape[mIndex] >= mDim) && + (shape[nIndex] == 1 || shape[nIndex] >= nDim)) && + "Unsupported tensor shape for given wmma layout"); + + MLIRContext *ctx = getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + + // https://github.com/ROCm/amd_matrix_instruction_calculator can print the + // register and lane layout for mfma instructions. + + // We use the order from fastest varying to slowest varying. So each base + // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices. + SmallVector order = triton::gpu::getOrder(*this); + + // For wmma with 16x16 output, each of the 32 threads holds 8 elements. + // + // The first version of WMMA layout has following specific: + // for the register (i.e., element) dimension, these 8 elements are + // along the matrix C's M dimension, with 1 consecutive elements + // spanning 1 row and then the next 1 row being a gap. + // + // For the lane (i.e., thread) dimension, these threads are along the + // matrix C's N dimension, with 16 consecutive threads covering a whole + // row and the next 16 threads start at the next row. + // + // The second version of wmma layout is less tricky: + // for the register dimension 8 elements are along the matrix C's M + // dimension. First 16 lanes take 0-8 elems along M, second 16 take 8-15. + // We have 16 pair of threads in each warp, one pair covers the whole + // column. + // + // Please also check explaining comments in TritonGPUAttrDefs.td at the + // AMDWmmaEncodingAttr section. + unsigned ver = getVersion(); + assert(ver == 1 || ver == 2); + LinearLayout tileLayout = + ver == 1 + ? LinearLayout( + {{kRegister, {/*gap*/ {0, 2}, {0, 4}, {0, 8}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 1}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}) + : LinearLayout( + {{kRegister, {{0, 1}, {0, 2}, {0, 4}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 8}}}}, + {outDimNames[order[0]], outDimNames[order[1]]}); + + if (hasBatchDim) { + assert(order[2] == 0); + // Extend the base vector with one value to accomodate for the batch + // dimension, which appears at the last. + tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]); + tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); + } + + // And each warp takes the same register and lane sub-layout. So mulitply with + // an identity layout for the warp. + LinearLayout warpLayout = + identityND(S("warp"), getWarpsPerCTA(), order, outDimNames); + LinearLayout ctaLayout = tileLayout * warpLayout; + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +std::optional +BlockedEncodingAttr::toLinearLayout(ArrayRef shape) const { + assert(shape.size() == getOrder().size()); + + int rank = shape.size(); + MLIRContext *ctx = getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + const auto &order = getOrder(); + LinearLayout ctaLayout = + identityND(S("register"), getSizePerThread(), order, outDimNames) * + identityND(S("lane"), getThreadsPerWarp(), order, outDimNames) * + identityND(S("warp"), getWarpsPerCTA(), order, outDimNames); + + return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); +} + +std::optional +NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef shape) const { + if (isAmpere()) { + return ampereMmaToLinearLayout(shape, *this); + } + if (isHopper()) { + return hopperMmaToLinearLayout(shape, *this); + } + return std::nullopt; +} + +std::optional +SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { + MLIRContext *ctx = getContext(); // First compute the linear layout for this layout's parent. SmallVector parentShape(shape); - parentShape.insert(parentShape.begin() + slice.getDim(), 1); + parentShape.insert(parentShape.begin() + getDim(), 1); std::optional parentLL = - triton::gpu::toLinearLayout(parentShape, slice.getParent()); - if (!parentLL) { - return std::nullopt; + triton::gpu::toLinearLayout(parentShape, getParent()); + if (!parentLL.has_value()) { + if (mlir::isa(getParent())) + return std::nullopt; + llvm::report_fatal_error( + "Failed to compute parent layout for slice layout."); } - // Remove dimension slice.getDim() from the parent layout. + // Remove dimension getDim() from the parent layout. // // 1. Construct a layout `transform` from parent-out-dims to slice-out-dims // that removes the relevant out-dim. @@ -421,15 +797,15 @@ std::optional toLinearLayout(ArrayRef shape, auto outDimNames = standardOutDimNames(ctx, shape.size() + 1); LinearLayout transform = LinearLayout::empty(); for (auto [idx, outDim] : llvm::enumerate(parentLL->getOutDimNames())) { - if (idx == slice.getDim()) { + if (idx == getDim()) { // Because we're multiplying by all zeros, we could replace outDimNames[0] // with any other valid out-dim; the layout will be the same. transform *= LinearLayout::zeros1D(parentLL->getOutDimSize(outDim), outDim, outDimNames[0]); } else { - transform *= LinearLayout::identity1D( - parentLL->getOutDimSize(outDim), outDim, - outDimNames[idx - (idx < slice.getDim() ? 0 : 1)]); + transform *= + LinearLayout::identity1D(parentLL->getOutDimSize(outDim), outDim, + outDimNames[idx - (idx < getDim() ? 0 : 1)]); } } LinearLayout sliceLL = parentLL->compose(transform); @@ -444,7 +820,8 @@ std::optional toLinearLayout(ArrayRef shape, } bases[S("register")] = newRegBases; - LinearLayout ret = LinearLayout(std::move(bases), sliceLL.getOutDimNames()); + LinearLayout ret = + LinearLayout(std::move(bases), llvm::to_vector(sliceLL.getOutDimNames())); // Match a hack in the legacy code that ensures that the number of registers // matches getTotalElemsPerThread. Yup: We just removed all the zeros, now @@ -452,8 +829,9 @@ std::optional toLinearLayout(ArrayRef shape, // // TODO(jlebar): Once getTotalElemsPerThread uses LLs instead of the existing // legacy code, I think we can remove this. - int expectedNumRegisters = getTotalElemsPerThread(RankedTensorType::get( - shape, IntegerType::get(ctx, 32) /*dummy type*/, slice)); + int expectedNumRegisters = + triton::gpu::getTotalElemsPerThread(RankedTensorType::get( + shape, IntegerType::get(ctx, 32) /*dummy type*/, *this)); if (ret.getInDimSize(S("register")) != expectedNumRegisters) { int extraZeros = expectedNumRegisters / ret.getInDimSize(S("register")); // Our use of "dim0" here is arbitrary; because we're adding zeros, any @@ -463,27 +841,313 @@ std::optional toLinearLayout(ArrayRef shape, return ret; } -} // anonymous namespace +LinearLayout ampereDotToLinearLayout(ArrayRef shape, + DotOperandEncodingAttr dot) { + // TODO,BE. Implement ampereMMA in terms of this one + int rank = shape.size(); + auto mma = cast(dot.getParent()); + int kWidth = dot.getKWidth(); + bool isA = dot.getOpIdx() == 0; + + assert(mma.isAmpere()); + assert((rank == 2 && mma.getInstrShape() == ArrayRef({16, 8})) || + (rank == 3 && mma.getInstrShape() == ArrayRef({1, 16, 8}))); + + MLIRContext *ctx = mma.getContext(); + SmallVector dimNames = standardOutDimNames(ctx, rank); -std::optional toLinearLayout(ArrayRef shape, - Attribute layout) { - if (auto blocked = dyn_cast(layout)) { - return blockedToLinearLayout(shape, blocked); + // Implement A. For B transpose in the end + std::vector> registers; + std::vector> lanes; + int32_t i = 1; + // kWidth contiguous elements + while (i < kWidth) { + registers.push_back({i, 0}); + i *= 2; } - if (auto mma = dyn_cast(layout)) { - if (mma.isAmpere()) { - return ampereMmaToLinearLayout(shape, mma); + // 4 threads per chunk + for (int j = 0; j < 2; j++) { + lanes.push_back({i, 0}); + i *= 2; + } + // 8 threads going down + lanes.push_back({0, 1}); + lanes.push_back({0, 2}); + lanes.push_back({0, 4}); + // 2 tiles in column-major order + // Just one if it's the B operand + if (isA) { + registers.push_back({0, 8}); + } + registers.push_back({i, 0}); + + if (!isA) { + for (auto &r : registers) { + std::swap(r[0], r[1]); } - if (mma.isHopper()) { - return hopperMmaToLinearLayout(shape, mma); + for (auto &l : lanes) { + std::swap(l[0], l[1]); } } - if (auto slice = dyn_cast(layout)) { - return toLinearLayout(shape, slice); + + LinearLayout ctaLayout( + {{S("register"), registers}, {S("lane"), lanes}}, + llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); + + auto order = dot.getCTAOrder(); + assert(order[0] == 1 && order[1] == 0); + ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames); + + return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); +} + +std::optional +DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { + if (auto mfmaLayout = llvm::dyn_cast(getParent())) { + return dotOperandMfmaToLinearLayout(*this, shape); + } else if (auto mma = mlir::dyn_cast(getParent())) { + // FIXME [Dot LL] + // Do this unconditionally + auto largeKWidth = getKWidth() == 8; + if (mma.isAmpere() && largeKWidth) { + return ampereDotToLinearLayout(shape, *this); + } + } + return std::nullopt; +} + +std::optional +toLinearLayout(ArrayRef shape, Attribute layout, + std::optional elemBitWidth /*= std::nullopt*/) { + if (auto distributed = dyn_cast(layout)) { + return distributed.toLinearLayout(shape); + } + if (auto shared = dyn_cast(layout)) { + if (shared.getHasLeadingOffset()) { + assert(elemBitWidth.has_value()); + return sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth); + } else { + return sharedToLinearLayoutNoLeadingOffset(shape, shared); + } } // TODO(jlebar): Other layouts return std::nullopt; } +LinearLayout getLayoutWithinBlock(const LinearLayout &layout) { + assert(!layout.getInDimNames().empty()); + MLIRContext *ctx = layout.getInDimNames().begin()->getContext(); + + StringAttr kBlock = S("block"); + assert(layout.hasInDim(kBlock)); + auto bases = layout.getBases(); + bases[kBlock] = {}; + return LinearLayout(bases, llvm::to_vector<4>(layout.getOutDimNames())); +} + +LinearLayout chooseShemLayoutForRegToRegConversion( + MLIRContext *ctx, ArrayRef tensorShape, + ArrayRef repShape, ArrayRef order) { + auto outDimNames = standardOutDimNames(ctx, tensorShape.size()); + LinearLayout layout = LinearLayout::empty(); + SmallVector kRepDims; + SmallVector kOffsetDims; + auto totalIters = 1; + auto totalOffsets = 1; + for (int i = 0; i < tensorShape.size(); i++) { + int dim = order[i]; + StringAttr kIteration = S("iteration" + std::to_string(dim)); + StringAttr kOffset = S("offset" + std::to_string(dim)); + kRepDims.push_back(kIteration); + kOffsetDims.push_back(kOffset); + assert(llvm::isPowerOf2_32(repShape[dim])); + assert(llvm::isPowerOf2_32(tensorShape[dim])); + auto numIters = tensorShape[dim] / repShape[dim]; + layout *= + LinearLayout::identity1D(repShape[dim], kOffset, outDimNames[dim]); + layout *= LinearLayout::identity1D(numIters, kIteration, outDimNames[dim]); + totalIters *= numIters; + totalOffsets *= repShape[dim]; + } + StringAttr kOffset = S("offset"); + StringAttr kIteration = S("iteration"); + StringAttr kBlock = S("block"); + SmallVector newDims; + newDims.append(kOffsetDims.begin(), kOffsetDims.end()); + newDims.append(kRepDims.begin(), kRepDims.end()); + // Transpose layout from [offset0, rep0, offset1, rep1, ...] to + // [offset0, offset1, ..., rep0, rep1, ...] + auto ret = layout.transposeIns(newDims); + // Reshape layout from [offset0, offset1, ..., rep0, rep1, ...] to + // [offset, rep, block] + return ret.reshapeIns( + {{kOffset, totalOffsets}, {kIteration, totalIters}, {kBlock, 1}}); +} + +namespace { + +// TODO (Keren): Currently, we have more restrictions than necessary when using +// stmatrix. These restrictions are retained from legacy code, and we could +// relax some of them in the future. +bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef repShape, + ArrayRef paddedRepShape, ArrayRef order, + int swizzleByteSize) { + auto mmaLayout = + mlir::dyn_cast(tensorTy.getEncoding()); + if (!mmaLayout || !mmaLayout.isHopper()) + return false; + if (isa(tensorTy.getElementType())) + return false; + if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16) + return false; + if (order[0] != 1) + return false; + + auto tensorShapePerCTA = getShapePerCTA(mmaLayout, tensorTy.getShape()); + if (tensorShapePerCTA.size() != 2) + return false; + auto numIterations = ceil(tensorShapePerCTA[1], repShape[1]) * + ceil(tensorShapePerCTA[0], repShape[0]); + if (numIterations > 1) + return false; + if (paddedRepShape[1] % 8 != 0) + return false; + if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 && + swizzleByteSize != 128) + return false; + return true; +} + +std::optional chooseStMatrixLayoutLeadingOffset( + MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef repShape, + ArrayRef paddedRepShape, ArrayRef order, + int swizzleByteSize) { + StringAttr kReg = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + StringAttr kCol = S("dim1"); + StringAttr kRow = S("dim0"); + StringAttr kOffset = S("offset"); + + int perPhase; + int maxPhase; + if (swizzleByteSize == 32) { + perPhase = 4; + maxPhase = 2; + } else if (swizzleByteSize == 64) { + perPhase = 2; + maxPhase = 4; + } else if (swizzleByteSize == 128) { + perPhase = 1; + maxPhase = 8; + } else { + llvm::errs() << "Illegal swizzleByteSize: " << swizzleByteSize << "\n"; + llvm::report_fatal_error("Illegal swizzleByteSize"); + } + + // stmatrix only supports 16-bit elements, and each vector has 8 elements + int elemBitWidth = 16; + int vecSize = 8; + int numRows = 16; + int numCols = 8 * swizzleByteSize / elemBitWidth; + + // Construct a single stmatrix.x4 (16x16) tile + std::vector> basesReg = {{1, 0}, {2, 0}, {4, 0}}; + std::vector> basesLane; + for (int logRow = 0; logRow < llvm::Log2_32(numRows); logRow++) { + int row = 1 << logRow; + basesLane.push_back({vecSize * ((row / perPhase) % maxPhase), row}); + } + basesLane.push_back({8, 0}); + + // Expand the tile's register dimension to fit swizzleByteSize, which is a + // "chunk" + for (int logChunk = 0; logChunk < llvm::Log2_32(numCols / 16); logChunk++) { + int chunk = 1 << logChunk; + basesReg.push_back({16 * chunk, 0}); + } + + // Construct the layout for a single chunk + LinearLayout layout = + LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow}); + + // Expand the `warp` dimension according to warpsPerCTA. + auto mma = cast(tensorTy.getEncoding()); + layout *= + identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol}) + .transposeOuts(llvm::to_vector(layout.getOutDimNames())); + + // Expand the `register` dimension so the size of columns matches `n`. + int n = mma.getInstrShape()[1]; + int numWarpRows = layout.getOutDimSize(kRow); + layout = (layout.reshapeOuts({{kOffset, layout.getTotalOutDimSize()}}) * + LinearLayout::identity1D(n / numCols, kReg, kOffset)) + .reshapeOuts({{kCol, n}, {kRow, numWarpRows}}); + + auto ret = + combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape()); + return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames())) + .reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}}); +} + +std::optional chooseStMatrixLayoutNoLeadingOffset( + MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef repShape, + ArrayRef paddedRepShape, ArrayRef order) { + StringAttr kReg = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + StringAttr kCol = S("dim1"); + StringAttr kRow = S("dim0"); + StringAttr kBlock = S("block"); + + std::vector> basesReg = {{1, 0}, {2, 0}, {4, 0}}; + std::vector> basesLane = { + {0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}}; + LinearLayout layout = + LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow}); + + // Expand the `register` dimension so the size of columns matches `n`. + auto mma = cast(tensorTy.getEncoding()); + int n = mma.getInstrShape()[1]; + layout *= + LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol); + + // Expand the `warp` dimension according to warpsPerCTA. + layout *= + identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol}) + .transposeOuts(llvm::to_vector(layout.getOutDimNames())); + auto ret = + combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape()); + auto tensorShapePerCTA = getShapePerCTA(mma, tensorTy.getShape()); + llvm::SmallDenseMap namedTensorShape; + namedTensorShape[kRow] = tensorShapePerCTA[0]; + namedTensorShape[kCol] = tensorShapePerCTA[1]; + ret = ensureLayoutNotSmallerThan(ret, namedTensorShape); + ret = ensureLayoutNotLargerThan(ret, namedTensorShape); + return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames())) + .reshapeOuts({{S("offset"), ret.getTotalOutDimSize()}, + {S("iteration"), 1}}) * + identityND(kBlock, {1, 1}, {0, 1}, {S("offset"), S("iteration")}); +} + +} // anonymous namespace + +std::optional +chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, int swizzleByteSize) { + if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order, + swizzleByteSize)) + return std::nullopt; + + if (swizzleByteSize == 0) + return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape, + paddedRepShape, order); + else + return chooseStMatrixLayoutLeadingOffset( + ctx, tensorTy, repShape, paddedRepShape, order, swizzleByteSize); +} + } // namespace mlir::triton::gpu diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp new file mode 100644 index 000000000..e61fe096e --- /dev/null +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -0,0 +1,108 @@ +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h" +#include "llvm/Support/raw_ostream.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" + +namespace mlir::triton::gpu { + +LogicalResult UpcastMXFPOp::verify() { + auto fpType = getFpType(); + + auto xTy = getSrc().getType(); + auto scaleTy = getScale().getType(); + + if (xTy.getElementType() != FloatType::getBF16(getContext()) && + xTy.getElementType() != IntegerType::get(getContext(), 8)) { + return emitOpError("element type of the first operand must be bf16 or i8"); + } + + if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) { + return emitOpError("element type of the second operand must be uint8"); + } + + auto xShape = xTy.getShape(); + auto scaleShape = scaleTy.getShape(); + + if (xShape.size() != scaleShape.size() || xShape.size() < 2) { + return emitOpError( + "operands must have the same number of dimensions, at least 2"); + } + + if (!(fpType == F8F6F4Type::E2M1 || fpType == F8F6F4Type::E4M3 || + fpType == F8F6F4Type::E5M2)) { + return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2"); + } + + // Change to support fp8 types + const auto elems_packed = fpType == F8F6F4Type::E2M1 ? 2 : 1; + + if (xShape.back() != (32 / elems_packed) * scaleShape.back()) { + return emitOpError("last dimension of first operand must be 16 times " + "larger than that of the second operand"); + } + + if (!std::equal(xShape.begin(), xShape.end() - 1, scaleShape.begin())) { + return emitOpError( + "all dimensions except the last must match between operands"); + } + + auto layoutX = xTy.getEncoding(); + if (!layoutX || !isa(layoutX)) { + return emitOpError("Expected a DotOperandEncodingAttr for values"); + } + auto layoutScale = scaleTy.getEncoding(); + if (!layoutScale || !isa(layoutScale)) { + return emitOpError("Expected a BlockOperandEncoding for scales"); + } + auto blockedScale = cast(layoutScale); + + // Necessary to keep all of the scales of a given block of values in the same + // warp + auto threadsPerWarp = blockedScale.getThreadsPerWarp(); + if (threadsPerWarp != ArrayRef({16, 2})) { + return emitOpError("Expected threads per warp to be {16, 2}"); + } + + return success(); +} + +LogicalResult UpcastMXFPOp::inferReturnTypes( + MLIRContext *ctx, std::optional loc, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties opaqueProperties, + RegionRange regions, SmallVectorImpl &inferredReturnTypes) { + auto xTy = cast(operands[0].getType()); + auto properties = opaqueProperties.as(); + auto typeEncoded = properties->fp_type.getValue(); + auto xShape = xTy.getShape(); + + auto encoding = xTy.getEncoding(); + if (!encoding) { + return emitOptionalError(loc, "expected an encoding"); + } + if (!mlir::isa(encoding)) { + return emitOptionalError(loc, "expected a dotOperand encoding"); + } + + if (typeEncoded == F8F6F4Type::E2M1) { + auto oldEncoding = cast(encoding); + auto newVEncoding = DotOperandEncodingAttr::get( + ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(), + oldEncoding.getKWidth() * 2); + auto newShape = SmallVector(xShape); + newShape.back() *= 2; + inferredReturnTypes.push_back( + RankedTensorType::get(newShape, FloatType::getBF16(ctx), newVEncoding)); + } else { + inferredReturnTypes.push_back(xTy); + } + + return success(); +} + +} // namespace mlir::triton::gpu diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index df84c4e62..e26118bdf 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -1,14 +1,19 @@ -#include - +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" namespace mlir { @@ -19,23 +24,23 @@ namespace { // Get the highest version supported for the hardware and the dot. static int getMMAVersionSafe(int computeCapability, DotOp op) { - int baseVersion = 0; + // List supported mma version in order of preference. + SmallVector versionsSupported; if (computeCapability < 75) { - baseVersion = 1; + versionsSupported = {1}; } else if (computeCapability < 90) { - baseVersion = 2; + versionsSupported = {2}; } else if (computeCapability < 100) { - baseVersion = 3; + versionsSupported = {3, 2}; } else { assert(false && "computeCapability not supported"); } - - for (; baseVersion >= 1; baseVersion--) { - if (supportMMA(op, baseVersion)) { + for (int baseVersion : versionsSupported) { + if (supportMMA(op, baseVersion)) return baseVersion; - } + if (baseVersion == 3) + op.emitRemark() << "Warning: can't use MMA V3 for the dot op"; } - return 0; } @@ -103,8 +108,12 @@ warpsPerTileV3(DotOp dotOp, const ArrayRef shape, int numWarps, const SmallVector &instrShape) { SetVector slices; mlir::getForwardSlice(dotOp.getResult(), &slices); - if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != - slices.end()) + // Contains a chained dot. We prefer to assign warps to one axis + // to facilitate use cases like flash attention, allowing reductions within + // the same warp. + if (llvm::find_if(slices, [](Operation *op) { + return op->hasTrait(); + }) != slices.end()) return {(unsigned)numWarps, 1}; // For MMAv3, the smallest indivisible unit of warp shape is (4, 1). @@ -122,7 +131,41 @@ warpsPerTileV3(DotOp dotOp, const ArrayRef shape, int numWarps, return ret; } -class BlockedToMMA : public mlir::RewritePattern { +// Returns a shared memory allocation that can be used by a dotMMA op for the +// given value. +static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, + int opIdx, bool allowTranspose) { + OpBuilder::InsertionGuard g(rewriter); + Value arg = v; + if (auto cvtOp = v.getDefiningOp()) + arg = cvtOp.getSrc(); + auto argType = cast(arg.getType()); + assert(argType.getEncoding() && "unexpected tensor type"); + auto newOrder = getOrder(argType.getEncoding()); + + // If the MMA op doesn't support transpose pick the layout expected by the MMA + // op. + if (!allowTranspose) { + if (opIdx == 1) { + newOrder = {0, 1}; + } else { + newOrder = {1, 0}; + } + } + + Attribute SharedMemorySpace = + SharedMemorySpaceAttr::get(argType.getContext()); + auto CTALayout = getCTALayout(argType.getEncoding()); + auto newLayout = + SharedEncodingAttr::get(argType.getContext(), argType.getShape(), + newOrder, CTALayout, argType.getElementType()); + auto newType = MemDescType::get(argType.getShape(), argType.getElementType(), + newLayout, SharedMemorySpace); + rewriter.setInsertionPointAfterValue(arg); + return rewriter.create(arg.getLoc(), newType, arg); +} + +class BlockedToMMA : public mlir::OpRewritePattern { int computeCapability; mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding mutable llvm::DenseMap dotOpInstNs; @@ -170,8 +213,8 @@ class BlockedToMMA : public mlir::RewritePattern { public: BlockedToMMA(mlir::MLIRContext *context, int computeCapability) - : mlir::RewritePattern(DotOp::getOperationName(), 2, context), - computeCapability(computeCapability) {} + : OpRewritePattern(context), computeCapability(computeCapability) { + } static SmallVector getWarpsPerTile(DotOp dotOp, const ArrayRef shape, int version, @@ -187,44 +230,11 @@ class BlockedToMMA : public mlir::RewritePattern { } } - static Value getMMAv3Operand(Value v, mlir::PatternRewriter &rewriter, - int opIdx) { - OpBuilder::InsertionGuard g(rewriter); - Value arg = v; - if (auto cvtOp = v.getDefiningOp()) - arg = cvtOp.getSrc(); - auto argType = cast(arg.getType()); - auto eltType = argType.getElementType(); - assert(argType.getEncoding() && "unexpected tensor type"); - auto newOrder = getOrder(argType.getEncoding()); - - // MMAv3 with transpose only supports f16 and bf16 data type - // fallback to MMAv3 without transpose for other data types - if (!eltType.isF16() && !eltType.isBF16()) { - if (opIdx == 1) { - newOrder = {0, 1}; - } else { - newOrder = {1, 0}; - } - } - - auto CTALayout = getCTALayout(argType.getEncoding()); - auto newLayout = - SharedEncodingAttr::get(argType.getContext(), argType.getShape(), - newOrder, CTALayout, argType.getElementType()); - auto newType = MemDescType::get(argType.getShape(), - argType.getElementType(), newLayout); - rewriter.setInsertionPointAfterValue(arg); - return rewriter.create(arg.getLoc(), newType, arg); - } - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, + matchAndRewrite(triton::DotOp dotOp, mlir::PatternRewriter &rewriter) const override { if (computeCapability < 70) return failure(); - auto dotOp = cast(op); - auto ctx = op->getContext(); // TODO: Check data-types and SM compatibility RankedTensorType oldRetType = dotOp.getType(); if (!oldRetType.getEncoding() || @@ -233,16 +243,17 @@ class BlockedToMMA : public mlir::RewritePattern { // get MMA encoding for the given number of warps auto retShapePerCTA = getShapePerCTA(oldRetType); - auto mod = op->getParentOfType(); + auto mod = dotOp->getParentOfType(); int numWarps = TritonGPUDialect::getNumWarps(mod); auto CTALayout = getCTALayout(oldRetType.getEncoding()); int versionMajor = getMMAVersionSafe(computeCapability, dotOp); - if (!versionMajor) + if (!(versionMajor >= 1 && versionMajor <= 3)) return failure(); - auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA, - dotOp.getA().getType(), numWarps); + auto instrShape = mmaVersionToInstrShape( + versionMajor, retShapePerCTA, dotOp.getA().getType().getElementType(), + numWarps); // operands Value a = dotOp.getA(); Value b = dotOp.getB(); @@ -281,7 +292,8 @@ class BlockedToMMA : public mlir::RewritePattern { oldRetType.getContext(), versionMajor, numWarps, CTALayout, instrShape, oldAType.getShape(), oldBType.getShape(), retShapePerCTA, isARow, isBRow, mmaV1Counter++); - } else if (versionMajor == 2 || versionMajor == 3) { + } else { + assert(versionMajor == 2 || versionMajor == 3); int versionMinor = computeCapability == 75 ? 1 : 0; auto warpsPerTile = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, numWarps, instrShape); @@ -289,6 +301,7 @@ class BlockedToMMA : public mlir::RewritePattern { versionMinor, warpsPerTile, CTALayout, instrShape); } + PatternRewriterWithAsyncTaskIds taskIdRewriter(rewriter, dotOp); auto newRetType = RankedTensorType::get( oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); // convert accumulator @@ -296,15 +309,21 @@ class BlockedToMMA : public mlir::RewritePattern { auto newAcc = rewriter.create(oldAcc.getLoc(), newRetType, oldAcc); + Operation *newDot = nullptr; if (versionMajor == 3) { - a = getMMAv3Operand(a, rewriter, 0); - b = getMMAv3Operand(b, rewriter, 1); + auto eltType = dotOp.getA().getType().getElementType(); + // In MMAV3 tranpose is only supported for f16 and bf16. + bool allowTranspose = eltType.isF16() || eltType.isBF16(); + a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose); + b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose); + newDot = taskIdRewriter.create( + dotOp.getLoc(), newRetType, a, b, newAcc, nullptr, + dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc(), false); } else { - // convert operands int minBitwidth = std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)); - Type minType = IntegerType::get(ctx, minBitwidth); + Type minType = rewriter.getIntegerType(minBitwidth); // convert A operand auto newAEncoding = DotOperandEncodingAttr::get( oldAType.getContext(), 0, newRetType.getEncoding(), @@ -319,14 +338,13 @@ class BlockedToMMA : public mlir::RewritePattern { auto newBType = RankedTensorType::get( oldBType.getShape(), oldBType.getElementType(), newBEncoding); b = rewriter.create(b.getLoc(), newBType, b); + newDot = taskIdRewriter.create(dotOp.getLoc(), newRetType, a, b, + newAcc, dotOp.getInputPrecision(), + dotOp.getMaxNumImpreciseAcc()); } // convert dot instruction - auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, - newAcc, dotOp.getInputPrecision(), - dotOp.getMaxNumImpreciseAcc()); - - rewriter.replaceOpWithNewOp(op, oldRetType, - newDot.getResult()); + rewriter.replaceOpWithNewOp(dotOp, oldRetType, + newDot->getResult(0)); return success(); } }; @@ -350,7 +368,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) { NvidiaMmaEncodingAttr mmaLayout = dyn_cast(D.getType().getEncoding()); if (mmaLayout) { - bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ(); + bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN(); // promote operands for sm < 89 since fp8 mma is not natively supported // promote operands for sm >= 90 when mma is not v3 if (!isNativeFP8 || @@ -373,6 +391,154 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) { }); } +class ScaledBlockedToMMAv2 + : public mlir::OpRewritePattern { + int computeCapability; + +public: + ScaledBlockedToMMAv2(mlir::MLIRContext *context, int computeCapability) + : mlir::OpRewritePattern(context), + computeCapability(computeCapability) {} + + mlir::LogicalResult + matchAndRewrite(triton::DotScaledOp dotOp, + mlir::PatternRewriter &rewriter) const override { + if (computeCapability >= 100) + return failure(); + + auto oldRetType = dotOp.getType(); + if (!oldRetType.getEncoding() || + mlir::isa(oldRetType.getEncoding())) + return failure(); + auto ctx = dotOp.getContext(); + + // Check that rhs scale is null + assert(dotOp.getRhsScale() == nullptr && "rhs scale NYI"); + + // operands + auto a = dotOp.getLhs(); + auto b = dotOp.getRhs(); + auto scale = dotOp.getLhsScale(); + auto aType = dotOp.getLhsType(); + auto bType = dotOp.getRhsType(); + + auto enumToType = [&rewriter](F8F6F4Type type) { + switch (type) { + case F8F6F4Type::E4M3: + return rewriter.getFloat8E4M3FNType(); + case F8F6F4Type::E5M2: + return rewriter.getFloat8E5M2Type(); + default: + llvm_unreachable("unexpected type"); + } + }; + + assert((aType == F8F6F4Type::E4M3 || aType == F8F6F4Type::E5M2 || + aType == F8F6F4Type::E2M1) && + "NYI: lhs supports fp4 or fp8"); + assert(bType == F8F6F4Type::E4M3 || + bType == F8F6F4Type::E5M2 && "NYI: rhs supports fp8"); + + // TODO run accelerate matmul on A and B first to choose their layouts + // Set return type + auto versionMajor = 2; + auto retShapePerCTA = getShapePerCTA(oldRetType); + auto mod = dotOp->getParentOfType(); + unsigned numWarps = TritonGPUDialect::getNumWarps(mod); + auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA, + rewriter.getBF16Type(), numWarps); + auto CTALayout = getCTALayout(oldRetType.getEncoding()); + // TODO Use warpsPerTileV2 + SmallVector warpsPerCTA = {numWarps, 1}; + auto mmaEnc = NvidiaMmaEncodingAttr::get(ctx, /*versionMajor=*/versionMajor, + /*versionMinor=*/0, warpsPerCTA, + CTALayout, instrShape); + auto newRetType = RankedTensorType::get( + oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); + + // convert accumulator + auto oldAcc = dotOp.getOperand(2); + auto newAcc = + rewriter.create(oldAcc.getLoc(), newRetType, oldAcc); + + auto toMMABf16 = [&newRetType, &rewriter, &ctx, &enumToType]( + TypedValue v, int idx, + F8F6F4Type type) -> TypedValue { + auto vType = v.getType(); + if (type == F8F6F4Type::E2M1) { + // A bit too dynamically typed... + // perhaps return ints in both cases? + + auto retEnc = dyn_cast(newRetType.getEncoding()); + auto newVEncoding = DotOperandEncodingAttr::get( + ctx, idx, newRetType.getEncoding(), /*kWidth=*/4); + auto newVType = RankedTensorType::get( + vType.getShape(), vType.getElementType(), newVEncoding); + return rewriter.create(v.getLoc(), newVType, v); + } else { + assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3); + auto newVEncoding = DotOperandEncodingAttr::get( + ctx, idx, newRetType.getEncoding(), /*kWidth=*/8); + auto newVType = RankedTensorType::get( + vType.getShape(), vType.getElementType(), newVEncoding); + v = rewriter.create(v.getLoc(), newVType, v); + + // Bitcast + auto vTypeFp8 = RankedTensorType::get(vType.getShape(), + enumToType(type), newVEncoding); + v = cast>( + rewriter.create(v.getLoc(), vTypeFp8, v).getResult()); + + // Convert to bf16 + auto vTypeBf16 = RankedTensorType::get( + vType.getShape(), rewriter.getBF16Type(), newVEncoding); + return rewriter.create(v.getLoc(), vTypeBf16, v); + } + }; + a = toMMABf16(a, 0, aType); + b = toMMABf16(b, 1, bType); + + // [Note: A trick to avoid warp shuffles in the lowering] + // FIXME: Implement this when we can set general layouts on a tensor + + // For bf16, we have 4 threads per row + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-a-f16 + // and each of them needs to get every scale in that row. + // It turns out that the layout for the output of type bf16 gives us exactly + // this layout when the number of mxfp vectors is equal to two (K = 64) + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16816-c + // This can be generalised to other K with linear layouts, but the general + // layout cannot cannot be represented with the predefined layouts :( + // With this trick, we could do the full lowering here and remove the + // UpcastMXFPOp altogether + + assert(instrShape == ArrayRef({16, 8}) || + instrShape == ArrayRef({1, 16, 8})); + auto shapeTileA = std::array{instrShape[0], instrShape[0]}; + // Necessary choice to leave all the scales of the tile in that given warp + auto threadsPerWarp = + SmallVector{shapeTileA[0], 32 / shapeTileA[0]}; + + auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( + ctx, {1, 1}, threadsPerWarp, warpsPerCTA, {1, 0}, CTALayout); + + auto newScaleType = RankedTensorType::get(scale.getType().getShape(), + scale.getType().getElementType(), + newScaleEncoding); + scale = + rewriter.create(scale.getLoc(), newScaleType, scale); + + auto scaledA = rewriter.create( + dotOp.getLoc(), a, scale, dotOp.getLhsType()); + + // convert dot instruction + auto newDot = + rewriter.create(dotOp.getLoc(), newRetType, scaledA, b, newAcc); + rewriter.replaceOpWithNewOp(dotOp, oldRetType, newDot); + return success(); + } +}; + #define GEN_PASS_DEF_TRITONGPUACCELERATEMATMUL #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" @@ -390,7 +556,8 @@ class TritonGPUAccelerateMatmulPass auto computeCapability = getNVIDIAComputeCapability(m); mlir::RewritePatternSet patterns(context); - patterns.add(context, computeCapability); + patterns.add(context, + computeCapability); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { signalPassFailure(); } diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 7b2ab63e8..b1094963f 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_triton_library(TritonGPUTransforms F32DotTC.cpp CombineTensorSelectAndIf.cpp ReduceDataDuplication.cpp + OptimizeAccumulatorInit.cpp OptimizeDotOperands.cpp OptimizeThreadLocality.cpp Pipeliner/MatmulLoopPipeline.cpp @@ -12,10 +13,16 @@ add_triton_library(TritonGPUTransforms Pipeliner/SoftwarePipeliner.cpp Pipeliner/TMAStoresPipeline.cpp Pipeliner/PipeliningUtility.cpp + Pipeliner/Schedule.cpp Prefetch.cpp RemoveLayoutConversions.cpp ReorderInstructions.cpp Utility.cpp + TaskIdPropagate.cpp + WSTaskPartition.cpp + WSDataPartition.cpp + WSCodePartition.cpp + WSLowering.cpp DEPENDS TritonGPUTransformsIncGen diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index 06a7d963d..b3814329a 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -92,7 +92,6 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase { // in the memory write at the warp level, resulting in worse performance. // For loads, we can expect that the gaps won't matter due to the L1 // cache. - unsigned elemNumBits = getElementBitWidth(refTensorType); perThread = std::min( perThread, getNumElementsPerThread(op, order, axisInfoAnalysis)); } diff --git a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp index f701634d4..d9fb1d7e1 100644 --- a/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp +++ b/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp @@ -45,6 +45,15 @@ class TF32x3 : public OpRewritePattern { ArrayRef{value}) .getResult()[0]; }; + auto zeroLike = [&](Value c) -> Value { + return rewriter.create( + dotOp->getLoc(), c.getType(), + rewriter.create(dotOp->getLoc(), + rewriter.getF32FloatAttr(0))); + }; + auto add = [&](Value a, Value b) -> Value { + return rewriter.create(dotOp.getLoc(), a, b); + }; auto sub = [&](Value a, Value b) -> Value { return rewriter.create(dotOp.getLoc(), a, b); }; @@ -60,11 +69,15 @@ class TF32x3 : public OpRewritePattern { auto bBig = f32ToTF32(dotOp.getB()); auto bSmall = sub(dotOp.getB(), bBig); - auto dot1 = dot(aSmall, bBig, dotOp.getC()); + auto zero = zeroLike(dotOp.getC()); + + auto dot1 = dot(aSmall, bBig, zero); auto dot2 = dot(aBig, bSmall, dot1); auto dot3 = dot(aBig, bBig, dot2); - rewriter.replaceOp(dotOp, dot3); + auto sum = add(dot3, dotOp.getC()); + + rewriter.replaceOp(dotOp, sum); return success(); } }; diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp new file mode 100644 index 000000000..f73efe61a --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp @@ -0,0 +1,204 @@ +#include "mlir/Transforms/Passes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUOPTIMIZEACCUMULATORINIT +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { +bool dotSupportsAccInitFlag(Operation *op) { + assert(op->hasTrait() && "Expected a dot-like operation"); + if (auto wgDotOp = dyn_cast(op)) { + // Partial accumulation would require a select op to handle the + // initialization that would degrade the performance. + return !wgDotOp.needsPartialAccumulator(); + } + return false; +} + +std::pair getAccumulatorUseAndDef(Operation *op) { + assert(op->hasTrait() && "Expected a dot-like operation"); + if (auto wgDotOp = dyn_cast(op)) { + return std::make_pair(wgDotOp.getC(), wgDotOp); + } + return std::make_pair(nullptr, nullptr); +} + +void setUseAccFlag(Operation *op, Value useAcc) { + assert(op->hasTrait() && "Expected a dot-like operation"); + if (auto wgDotOp = dyn_cast(op)) { + wgDotOp.getUseCMutable().assign(useAcc); + } +} + +bool isConstantZeroTensor(Value v) { + return (matchPattern(v, m_Zero()) || matchPattern(v, m_AnyZeroFloat())); +} + +std::optional> findZeroInitOp(Value accUse, + Operation *accDef, + scf::ForOp forOp, + bool &loopArgIsZero) { + Value v = accUse; + if (auto arg = dyn_cast(v)) { + assert(arg.getOwner() == forOp.getBody()); + if (isConstantZeroTensor(forOp.getInitArgs()[arg.getArgNumber() - 1])) { + loopArgIsZero = true; + } + v = forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); + } + + auto defOp = v.getDefiningOp(); + if (!defOp) { + return std::nullopt; + } + if (auto selOp = dyn_cast(defOp)) { + if (isConstantZeroTensor(selOp.getTrueValue()) || + isConstantZeroTensor(selOp.getFalseValue())) { + return std::make_pair(selOp, 0); + } + } + if (auto ifOp = dyn_cast(defOp)) { + unsigned resultIndex = 0; + for (; resultIndex < ifOp.getNumResults(); ++resultIndex) { + if (ifOp.getResult(resultIndex) == v) + break; + } + Value thenVal = ifOp.thenYield()->getOperand(resultIndex); + Value elseVal = ifOp.elseYield()->getOperand(resultIndex); + if (isConstantZeroTensor(thenVal) || isConstantZeroTensor(elseVal)) { + // Make sure that the other value is not defined in the if itself, but + // passed from outside + if (thenVal.getParentBlock()->getParentOp() == ifOp || + elseVal.getParentBlock()->getParentOp() == ifOp) { + return std::nullopt; + } + return std::make_pair(ifOp, resultIndex); + } + } + return std::nullopt; +} + +} // namespace + +class OptimizeAccumulatorInitPass + : public impl::TritonGPUOptimizeAccumulatorInitBase< + OptimizeAccumulatorInitPass> { +public: + void runOnOperation() override { + ModuleOp m = getOperation(); + SmallVector mmaOps; + m.walk([&](Operation *op) { + if (op->hasTrait() && dotSupportsAccInitFlag(op)) { + mmaOps.push_back(op); + } + }); + + // for each mma op, find where the accumulator is initialized with zero + // It can be: + // 1. A constant zero + // 2. Initialized with zero as the loop argument + // 3. Initialized with zero in the if op or with a select op in current + // or any of the previous loop iterations + for (Operation *mmaOp : mmaOps) { + Location loc = mmaOp->getLoc(); + + scf::ForOp forOp = dyn_cast(mmaOp->getParentOp()); + if (!forOp) { + continue; + } + + IRRewriter rewriter(forOp); + rewriter.setInsertionPoint(forOp); + + Value vTrue = + rewriter.create(loc, rewriter.getBoolAttr(true)); + Value vFalse = + rewriter.create(loc, rewriter.getBoolAttr(false)); + + // Find the accumulator + auto [accUse, accDef] = getAccumulatorUseAndDef(mmaOp); + if (!accUse || !accDef) { + continue; + } + if (isConstantZeroTensor(accUse)) { + setUseAccFlag(mmaOp, vFalse); + continue; + } + + bool loopArgIsZero = false; + std::optional> zeroInitOp = + findZeroInitOp(accUse, accDef, forOp, loopArgIsZero); + if (!zeroInitOp) { + continue; + } + + Value loopArgFlagValue = loopArgIsZero ? vFalse : vTrue; + scf::ForOp newForOp = + replaceForOpWithNewSignature(rewriter, forOp, {loopArgFlagValue}); + forOp.erase(); + forOp = newForOp; + loopArgFlagValue = + forOp.getRegionIterArg(forOp.getNumRegionIterArgs() - 1); + + Value condition = nullptr; + Value oldValue = nullptr; + Value zeroValue = nullptr; + bool thenInitsToZero = false; + if (auto selOp = dyn_cast(zeroInitOp->first)) { + condition = selOp.getCondition(); + oldValue = isConstantZeroTensor(selOp.getTrueValue()) + ? selOp.getFalseValue() + : selOp.getTrueValue(); + zeroValue = isConstantZeroTensor(selOp.getTrueValue()) + ? selOp.getTrueValue() + : selOp.getFalseValue(); + thenInitsToZero = isConstantZeroTensor(selOp.getTrueValue()); + } else { + assert(isa(*zeroInitOp->first) && "Expected an if op"); + auto ifOp = cast(zeroInitOp->first); + unsigned resultIndex = zeroInitOp->second; + condition = ifOp.getCondition(); + Value thenVal = ifOp.thenYield()->getOperand(resultIndex); + Value elseVal = ifOp.elseYield()->getOperand(resultIndex); + oldValue = isConstantZeroTensor(thenVal) ? elseVal : thenVal; + zeroValue = isConstantZeroTensor(thenVal) ? thenVal : elseVal; + thenInitsToZero = isConstantZeroTensor(thenVal); + } + + // Create a select op that updates the flag + rewriter.setInsertionPoint(zeroInitOp->first); + bool zeroingBeforeMMA = zeroInitOp->first->isBeforeInBlock(mmaOp); + Value prevFlagValue = zeroingBeforeMMA ? loopArgFlagValue : vTrue; + auto selectFlagOp = rewriter.create( + loc, condition, thenInitsToZero ? vFalse : prevFlagValue, + thenInitsToZero ? prevFlagValue : vFalse); + setUseAccFlag(mmaOp, zeroingBeforeMMA ? selectFlagOp : loopArgFlagValue); + auto forYield = cast(forOp.getBody()->getTerminator()); + forYield->insertOperands(forYield->getNumOperands(), + {zeroingBeforeMMA ? vTrue : selectFlagOp}); + + // Stop clearing out the accumulator with zero + if (auto selOp = dyn_cast(zeroInitOp->first)) { + rewriter.setInsertionPoint(selOp); + rewriter.replaceOp(selOp, oldValue); + } else { + auto ifOp = cast(zeroInitOp->first); + int resultIndex = zeroInitOp->second; + auto zeroingYield = + thenInitsToZero ? ifOp.thenYield() : ifOp.elseYield(); + zeroingYield.setOperand(resultIndex, oldValue); + } + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 4a30bf9f3..6d8279795 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -59,12 +59,12 @@ class SwizzleShmemConvert : public OpRewritePattern { srcTy.getElementType(), /*needTrans=*/true); if (newInnerCvtEnc == cvtEncoding) return failure(); - rewriter.setInsertionPoint(trans); + auto sharedMemorySpace = SharedMemorySpaceAttr::get(getContext()); auto alloc = rewriter.create( trans.getLoc(), MemDescType::get(srcTy.getShape(), srcTy.getElementType(), - newInnerCvtEnc), + newInnerCvtEnc, sharedMemorySpace), trans.getSrc()); auto newTrans = rewriter.create(trans.getLoc(), alloc, ArrayRef({1, 0})); @@ -210,15 +210,10 @@ class FuseTransHopper : public OpRewritePattern { LogicalResult matchAndRewrite(LocalAllocOp allocOp, PatternRewriter &rewriter) const override { if (!allocOp->hasOneUse() || - !isa(*allocOp->getUsers().begin())) + !allocOp->getUsers().begin()->hasTrait()) return failure(); auto dot = *allocOp->getUsers().begin(); - auto dotEnc = dyn_cast( - cast(dot->getResult(0).getType()).getEncoding()); - if (!dotEnc || dotEnc.getVersionMajor() != 3) - return failure(); - if (!allocOp.getSrc()) return failure(); @@ -254,7 +249,8 @@ class FuseTransHopper : public OpRewritePattern { allocEncoding.getCTALayout(), srcTy.getElementType()); MemDescType innerTy = - MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc); + MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc, + allocType.getMemorySpace()); auto newAlloc = rewriter.create(allocOp.getLoc(), innerTy, trans.getSrc()); rewriter.replaceOpWithNewOp(allocOp, newAlloc, @@ -267,10 +263,11 @@ class FuseTransHopper : public OpRewritePattern { // dot(convert(lhs #mma) #shared, rhs) #mma -> // dot(convert(lhs #mma) #dot_operand, rhs) #mma, // for fp16 or bf16 MMAv3 dots. -struct MMAV3UseRegOperand : public OpRewritePattern { +struct MMAV3UseRegOperand + : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(DotOp dotOp, + LogicalResult matchAndRewrite(triton::nvidia_gpu::WarpGroupDotOp dotOp, PatternRewriter &rewriter) const override { auto alloc = dotOp.getOperand(0).getDefiningOp(); if (!alloc || !alloc.getSrc()) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp index 30211da08..b0e5095ac 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp @@ -120,9 +120,12 @@ class TritonGPUOptimizeThreadLocalityPass // TODO: relax this restriction if (!(isa(srcEncoding) && rank > 1)) return; + // The code currently assumes that the reduction is happening on the most + // inner dim. + if (reduce.getAxis() != rank - 1) + return; for (auto operand : reduce->getOperands()) { - auto def = operand.getDefiningOp(); - if (!isa(def)) + if (!operand.getDefiningOp()) return; } auto elemsPerThread = @@ -148,7 +151,7 @@ class TritonGPUOptimizeThreadLocalityPass return; auto argNum = yieldOpOperand.getOperandNumber(); auto oldAccum = forOp.getInitArgs()[argNum]; - auto cstOp = dyn_cast(oldAccum.getDefiningOp()); + auto cstOp = oldAccum.getDefiningOp(); if (!cstOp) return; reduceOps.insert(reduce); @@ -311,8 +314,8 @@ class TritonGPUOptimizeThreadLocalityPass IRMapping mapping; for (auto operand : reduce.getOperands()) { auto viewOp = builder.create( - reduce.getLoc(), viewOpTensorType, operand, /*allowReorder=*/true); - viewOp.setEfficientLayout(true); + reduce.getLoc(), viewOpTensorType, operand, + /*allowReorder=*/true, /*efficientLayout=*/true); mapping.map(operand, viewOp); } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index f3d5aa00e..e946735e2 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -1,6 +1,3 @@ -#include "PipelineExpander.h" -#include "PipeliningUtility.h" -#include "Schedule.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" @@ -14,6 +11,9 @@ #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "llvm/ADT/MapVector.h" @@ -51,168 +51,10 @@ struct LoadInfo { } // namespace -class CoarseSchedule { -public: - class ClusterList { - std::list orderClusters; - - public: - using iterator = decltype(orderClusters)::iterator; - ClusterList() = default; - iterator begin() { return orderClusters.begin(); } - iterator end() { return orderClusters.end(); } - size_t size() { return orderClusters.size(); } - iterator newAtBack() { - orderClusters.push_back(orderClusters.size()); - return std::prev(orderClusters.end()); - } - iterator newAtFront() { - orderClusters.push_front(-1); - for (auto &clusterId : orderClusters) { - clusterId++; - } - return orderClusters.begin(); - } - iterator newBefore(iterator cluster) { - auto ret = orderClusters.insert(cluster, *cluster); - for (auto &clusterId : llvm::make_range(cluster, orderClusters.end())) { - clusterId++; - } - return ret; - } - }; - - CoarseSchedule(int numStages) : numStages(numStages) {} - int numStages; - ClusterList clusters; - using Cluster = decltype(clusters)::iterator; - - DenseMap> opToStageAndCluster; - - void insert(Operation *op, int stage, Cluster cluster) { - opToStageAndCluster[op] = {stage, cluster}; - } - - bool insertIfAbsent(Operation *op, int stage, Cluster cluster) { - if (opToStageAndCluster.count(op)) - return false; - insert(op, stage, cluster); - return true; - } - - void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster, - bool includeArg) { - for (Value operand : op->getOperands()) { - Value v = operand; - llvm::SmallDenseSet seen; - while (auto arg = dyn_cast(v)) { - if (!includeArg) - break; - if (!seen.insert(v).second) - break; - if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { - auto yieldOp = op->getBlock()->getTerminator(); - v = yieldOp->getOperand(arg.getArgNumber() - 1); - continue; - } - break; - } - Operation *defOp = v.getDefiningOp(); - if (defOp && defOp->getBlock() == op->getBlock()) { - if (insertIfAbsent(defOp, stage, cluster)) { - insertDepsOfOp(defOp, stage, cluster, includeArg); - } - } - } - } - - void erase(Operation *op) { opToStageAndCluster.erase(op); } - - int count(Operation *op) { return opToStageAndCluster.count(op); } - - std::pair operator[](Operation *op) { - return opToStageAndCluster[op]; - } - - SmallVector> - getOpsInOrder(scf::ForOp forOp) { - SmallVector>, 8> - orderClusters(clusters.size()); - for (auto &op : forOp.getBody()->without_terminator()) { - if (opToStageAndCluster.count(&op) == 0) { - continue; - } - assert(opToStageAndCluster[&op].first < numStages && - "Op with invalid stage!"); - int clusterId = *opToStageAndCluster[&op].second; - assert(clusterId == std::distance(clusters.begin(), - opToStageAndCluster[&op].second) && - "Cluster ID mismatch!"); - orderClusters[clusterId].push_back( - make_tuple(&op, opToStageAndCluster[&op].first, - opToStageAndCluster[&op].second)); - } - SmallVector> opsInOrder; - for (int i = 0; i < orderClusters.size(); i++) { - for (auto [op, stage, cluster] : orderClusters[i]) { - opsInOrder.push_back({op, stage, cluster}); - } - } - - return opsInOrder; - } - - std::vector> - createFinalSchedule(scf::ForOp forOp) { - SmallVector> opsInOrder = - getOpsInOrder(forOp); - std::vector> schedule; - for (auto [op, stage, cluster] : opsInOrder) { - LDBG("Adding op to schedule at stage " << stage << " cluster " << *cluster - << ":" << *op); - schedule.push_back({op, stage}); - } - return schedule; - } - - void dump() { - for (int i = 0; i < numStages; i++) { - LDBG("- Ops in stage " << i); - for (auto &[op, stageAndCluster] : opToStageAndCluster) { - if (i == stageAndCluster.first) { - llvm::outs() << " cluster: " << *stageAndCluster.second << " "; - op->dump(); - } - } - } - } -}; - -static bool isMMAv3Dot(Operation *op) { - auto dot = dyn_cast(op); - if (!dot) - return false; - auto enc = - mlir::dyn_cast(dot.getType().getEncoding()); - return enc && enc.isHopper(); -} - -// Replace the ForOp's yield with a new one with the given operands appended. -static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { - // Fix up the yield op. - Operation *yieldOp = forOp.getBody()->getTerminator(); - SmallVector operands(yieldOp->getOperands()); - operands.append(newOperands.begin(), newOperands.end()); - - OpBuilder builder(yieldOp); - builder.create(yieldOp->getLoc(), operands); - yieldOp->erase(); -} - static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, Value insertIdx, Value extractIdx, - CoarseSchedule &schedule, - CoarseSchedule::Cluster prefetchCluster, + tt::CoarseSchedule &schedule, + tt::CoarseSchedule::Cluster prefetchCluster, llvm::MapVector &loadToInfo, int numStages) { OpBuilder builder(forOp); @@ -245,9 +87,11 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, tt::MemDescType allocTy = cast(alloc.getType()); SmallVector copyOffsets(allocTy.getRank(), zero); copyOffsets[0] = insertIdx; + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); tt::MemDescType subviewTy = tt::MemDescType::get( allocTy.getShape().drop_front(), allocTy.getElementType(), - allocTy.getEncoding(), /*mutableMemory=*/true); + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); auto view = builder.create(loc, subviewTy, alloc, copyOffsets); Operation *copy = builder.create( @@ -271,13 +115,13 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, builder.create(loc, subviewTy, alloc, loadOffsets); if (isMMV3Load) { auto alloc = cast((*loadOp->getUsers().begin())); - alloc.replaceAllUsesWith(viewLoad.getResult()); + replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); alloc.erase(); } else { SmallVector allocsToErase; for (Operation *user : loadOp->getUsers()) { if (auto alloc = dyn_cast(user)) { - alloc.replaceAllUsesWith(viewLoad.getResult()); + replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); allocsToErase.push_back(alloc); } } @@ -312,10 +156,12 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, static void createTMAAsyncCopy( scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, Value alloc, Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp, - Value phase, CoarseSchedule &schedule, + Value phase, tt::CoarseSchedule &schedule, llvm::MapVector &loadToInfo, int numStages) { assert(phase && "Phase value is required for TMA async copy."); OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); Value zero = builder.create(forOp.getLoc(), 0, 32); builder.setInsertionPoint(loadOp); Location loc = loadOp.getLoc(); @@ -324,7 +170,7 @@ static void createTMAAsyncCopy( copyOffsets[0] = insertIdx; tt::MemDescType subviewTy = tt::MemDescType::get( allocTy.getShape().drop_front(), allocTy.getElementType(), - allocTy.getEncoding(), /*mutableMemory=*/true); + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); auto view = builder.create(loc, subviewTy, alloc, copyOffsets); @@ -345,13 +191,13 @@ static void createTMAAsyncCopy( builder.create(loc, subviewTy, alloc, loadOffsets); if (isMMV3Load) { auto alloc = cast((*loadOp->getUsers().begin())); - alloc.replaceAllUsesWith(viewLoad.getResult()); + replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); alloc.erase(); } else { SmallVector allocsToErase; for (Operation *user : loadOp->getUsers()) { if (auto alloc = dyn_cast(user)) { - alloc.replaceAllUsesWith(viewLoad.getResult()); + replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); allocsToErase.push_back(alloc); } } @@ -368,11 +214,14 @@ static void createTMAAsyncCopy( } // If all the transitive uses of the given value have are used by a convert to -// the same dot operand encoding, return true and get the shared encoding that -// needs to be used to be compatible with users' layouts. +// the same dot operand encoding, return the shared encoding that needs to be +// used to be compatible with users' layouts. If there are imcompatible shared +// encodings, raise assertion, since incompatible shared encoding has been +// handled in splitLoadsForIncompatible. static std::optional -getSharedEncIfAllUsersAreDotEnc(Value val) { +getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { ttg::SharedEncodingAttr attr; + incompatible = false; for (Operation *user : val.getUsers()) { ttg::SharedEncodingAttr tempAttr; if (user->getNumResults() != 1) @@ -382,7 +231,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val) { // First time we find a shared encoding in the chain, save it and try to // use it if it is compatible with the other users. tempAttr = cast(memDesc.getEncoding()); - if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value()) + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible) + .has_value()) return std::nullopt; } else { if (!isa(user)) @@ -396,14 +246,14 @@ getSharedEncIfAllUsersAreDotEnc(Value val) { auto order = ttg::getOrder(srcTy.getEncoding()); unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); tempAttr = ttg::SharedEncodingAttr::get( - val.getContext(), dotOpEnc, srcTy.getShape(), - ttg::getOrder(srcTy.getEncoding()), - ttg::getCTALayout(srcTy.getEncoding()), - srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); + val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout, + bitWidth, /*needTrans=*/false); } // Check that the shared encodings needed by the users are compatible. - if (!tempAttr || (attr != nullptr && attr != tempAttr)) + if (attr != nullptr && attr != tempAttr) { + incompatible = true; return std::nullopt; + } attr = tempAttr; } return attr; @@ -506,7 +356,7 @@ loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { }; for (Operation &op : forOp.getBody()->without_terminator()) { - if (!isa(op)) + if (!op.hasTrait()) continue; seen.clear(); dfs(&op, 0, &op); @@ -583,7 +433,7 @@ assignMemoryLayouts(llvm::SmallVector> continue; } - if (auto dot = dyn_cast(use)) { + if (use->hasTrait()) { loadInfo.usedByDot = true; if (loadIsMMAv3(op)) { loadInfo.loadIsMMAV3 = true; @@ -592,9 +442,14 @@ assignMemoryLayouts(llvm::SmallVector> } else if (isa(op)) { loadInfo.sharedEncoding = getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); - } else { + } else if (auto dot = dyn_cast(use)) { + bool incompatible = false; loadInfo.sharedEncoding = - getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); + getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible) + .value_or(nullptr); + // If we can't agree on a shared encoding skip pipelinig the load. + if (incompatible) + continue; // HACK: Triton LLVM codegen has a bug where local_loads from #shared to // #mma layout can lead to invalid code if the loaded shape is smaller @@ -642,7 +497,7 @@ assignMemoryLayouts(llvm::SmallVector> // If we still don't have a shared encoding, try a "generic" shared // encoding. - if (!loadInfo.sharedEncoding && !isMMAv3Dot(use)) { + if (!loadInfo.sharedEncoding && !isa(use)) { loadInfo.sharedEncoding = getSharedEncoding(op, /*isMMAV3=*/loadInfo.loadIsMMAV3) .value_or(nullptr); @@ -662,8 +517,9 @@ assignMemoryLayouts(llvm::SmallVector> } static llvm::MapVector -scheduleLoads(scf::ForOp forOp, CoarseSchedule &schedule, +scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, DenseSet &rootUsers, int numStages) { + ModuleOp moduleOp = forOp->getParentOfType(); tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); @@ -682,6 +538,15 @@ scheduleLoads(scf::ForOp forOp, CoarseSchedule &schedule, if (loadOpToIndLevelAndUse.empty()) return {}; + // We assume loads with different dist are assigned to different stages. + // If numStages is 2, we will have no stage available for indirect loads + // with dist >= 1. In general, when dist is equal to numStages - 1, we + // should not pipeline it. + auto it = llvm::remove_if(loadOpToIndLevelAndUse, [=](auto op) { + return std::get<1>(op) >= numStages - 1; + }); + loadOpToIndLevelAndUse.erase(it, loadOpToIndLevelAndUse.end()); + // Check which loads are good for pipelining, and assign them // memory layouts. llvm::MapVector loadToInfo = @@ -700,7 +565,7 @@ scheduleLoads(scf::ForOp forOp, CoarseSchedule &schedule, unsigned stagesBetweenLoads = ceil(numStages - 2, maxIndirectionLevel + 1); - CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); + tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); // Put the root uses of the loads in the last stage. for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { if (loadToInfo.count(loadOp) == 0) @@ -713,7 +578,7 @@ scheduleLoads(scf::ForOp forOp, CoarseSchedule &schedule, } } - SmallVector loadsClusters; + SmallVector loadsClusters; for (int i = 0; i < maxIndirectionLevel + 1; i++) { loadsClusters.push_back(schedule.clusters.newAtBack()); } @@ -738,10 +603,10 @@ scheduleLoads(scf::ForOp forOp, CoarseSchedule &schedule, // Schedule the prologue and epilogue `if` ops in the loop, pushing them as // close to the loop boundaries as possible. Return the cluster after the // prologue (or the beginning of the loop if there is no prologue). -static CoarseSchedule::Cluster -schedulePrologueAndEpilogue(scf::ForOp forOp, CoarseSchedule &schedule, +static tt::CoarseSchedule::Cluster +schedulePrologueAndEpilogue(scf::ForOp forOp, tt::CoarseSchedule &schedule, DenseSet &rootUsers, int numStages) { - CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); // Look for the IfOp that is in the backward slice any of the currently // scheduled ops and put it at the beginning of the loop. @@ -763,14 +628,14 @@ schedulePrologueAndEpilogue(scf::ForOp forOp, CoarseSchedule &schedule, } } } - CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); + tt::CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); for (auto [ifOp, stage] : ifsToStage) { schedule.insert(ifOp, stage, prologueCluster); } // Look for the IfOp that is in the forward slice of the root users and put it // at the end of the loop. - CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); + tt::CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); for (auto rootUser : rootUsers) { SetVector forwardSlice; getForwardSlice(rootUser, &forwardSlice); @@ -797,9 +662,9 @@ schedulePrologueAndEpilogue(scf::ForOp forOp, CoarseSchedule &schedule, // Add dependencies of anchor ops to the coarse schedule. Schedule them to // the same stage and ordering cluster as the anchor op. -static void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule, +static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule, int numStages) { - SmallVector> + SmallVector> opsInOrder = schedule.getOpsInOrder(forOp); // Schedule dependencies stage by stage. for (int stage = 0; stage < numStages; stage++) { @@ -814,7 +679,7 @@ static void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule, // Find dependencies with distance of 1. They will go to the next stage, // but in the cluster before the current op. static void scheduleDistanceOneDependencies(scf::ForOp forOp, - CoarseSchedule &schedule, + tt::CoarseSchedule &schedule, int numStages) { auto getNestedOperands = [](Operation *op) -> SmallVector { SmallVector operands; @@ -828,7 +693,8 @@ static void scheduleDistanceOneDependencies(scf::ForOp forOp, }; // Mapping from the cluster to the cluster before it. - DenseMap dist1Cluster; + DenseMap + dist1Cluster; for (auto &op : forOp.getBody()->without_terminator()) { if (schedule.count(&op) == 0) continue; @@ -863,14 +729,14 @@ static void scheduleDistanceOneDependencies(scf::ForOp forOp, } } -static void scheduleRemainingToLastStage(scf::ForOp forOp, - CoarseSchedule &schedule, - CoarseSchedule::Cluster afterPrologue, - int numStages) { +static void +scheduleRemainingToLastStage(scf::ForOp forOp, tt::CoarseSchedule &schedule, + tt::CoarseSchedule::Cluster afterPrologue, + int numStages) { // Assign the rest of the ops to the last stage. // Take care of the ordering of the ops - uses cannot be scheduled to the // cluster before the definition. - DenseMap opToCluster; + DenseMap opToCluster; for (auto &op : forOp.getBody()->without_terminator()) { if (schedule.count(&op) == 0) { opToCluster[&op] = afterPrologue; @@ -888,8 +754,8 @@ static void scheduleRemainingToLastStage(scf::ForOp forOp, Operation *op = queue.pop_back_val(); for (auto user : op->getUsers()) { if (opToCluster.count(user)) { - CoarseSchedule::Cluster userCluster = opToCluster[user]; - CoarseSchedule::Cluster opCluster; + tt::CoarseSchedule::Cluster userCluster = opToCluster[user]; + tt::CoarseSchedule::Cluster opCluster; if (schedule.count(op)) opCluster = schedule[op].second; else @@ -910,11 +776,14 @@ static void scheduleRemainingToLastStage(scf::ForOp forOp, static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, ttg::SharedEncodingAttr sharedEnc, unsigned distance) { OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); auto ty = cast(loadOp->getResultTypes()[0]); SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); bufferShape.insert(bufferShape.begin(), distance); Type memdescType = mlir::triton::MemDescType::get( - bufferShape, ty.getElementType(), sharedEnc, /*mutableMemory*/ true); + bufferShape, ty.getElementType(), sharedEnc, sharedMemorySpace, + /*mutableMemory*/ true); Value alloc = builder.create( loadOp->getLoc(), memdescType, Value()); return alloc; @@ -923,6 +792,8 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, // Create an allocation to hold the mbarriers. static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) { OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); Location loc = forOp.getLoc(); auto context = forOp.getContext(); auto barrierCTALayout = @@ -930,11 +801,12 @@ static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) { /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); auto barrierEncoding = ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout); - Type barrierMemDescType = - tt::MemDescType::get({distance}, builder.getI64Type(), barrierEncoding, - /*mutableMemory=*/true); - Type singleBarrierMemDescType = tt::MemDescType::get( - {1}, builder.getI64Type(), barrierEncoding, /*mutableMemory=*/true); + Type barrierMemDescType = tt::MemDescType::get( + {distance}, builder.getI64Type(), barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true); + Type singleBarrierMemDescType = + tt::MemDescType::get({1}, builder.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); Value barrierAlloc = builder.create( loc, barrierMemDescType, Value()); for (unsigned i = 0; i < distance; i++) { @@ -959,7 +831,7 @@ struct AsyncLoad { // multiple loads is the schedule allows it. static void createTMABarrierAndWait( scf::ForOp &forOp, SmallVector &asyncLoads, Value insertIdx, - Value extractIdx, Value phase, int numBuffers, CoarseSchedule &schedule, + Value extractIdx, Value phase, int numBuffers, tt::CoarseSchedule &schedule, SmallVector &barriers, const llvm::MapVector &loadToInfo) { llvm::SmallDenseMap loadToAsyncLoad; @@ -1030,9 +902,12 @@ static void createTMABarrierAndWait( barriers.push_back(barrierAlloc); Location loc = forOp.getLoc(); OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(builder.getContext()); tt::MemDescType barrierTy = tt::MemDescType::get( {1}, builder.getI64Type(), cast(barrierAlloc.getType()).getEncoding(), + sharedMemorySpace, /*mutableMemory=*/true); builder.setInsertionPoint(group[0]->loadOp); Value barrier = builder.create( @@ -1059,7 +934,7 @@ static void createTMABarrierAndWait( // Convert load ops into their asyn version and apply multi-buffering based on // the required number of buffers. static SmallVector -createAsyncOps(scf::ForOp &forOp, CoarseSchedule &schedule, +createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, llvm::MapVector &loadToInfo, SmallVector &barriers, int numStages) { // Calculate the number of buffers needed for each load. @@ -1147,7 +1022,7 @@ createAsyncOps(scf::ForOp &forOp, CoarseSchedule &schedule, // Create a cluster for the prefetches. It may end up being empty, but this // is OK. - CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); + tt::CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); for (AsyncLoad &asyncLoad : asyncLoads) { if (auto loadOp = dyn_cast(asyncLoad.loadOp)) { @@ -1164,13 +1039,15 @@ createAsyncOps(scf::ForOp &forOp, CoarseSchedule &schedule, if (phase) newYieldOperands.push_back(phase); // Patch the yield with the updated counters. - appendToYield(forOp, newYieldOperands); + appendToForOpYield(forOp, newYieldOperands); return allocs; } static void invalidateBarriers(OpBuilder &builder, SmallVector &barriers) { + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(builder.getContext()); for (Value barrier : barriers) { int numBarriers = cast(barrier.getType()).getShape()[0]; for (int i = 0; i < numBarriers; i++) { @@ -1178,6 +1055,7 @@ static void invalidateBarriers(OpBuilder &builder, tt::MemDescType barrierTy = tt::MemDescType::get( {1}, builder.getI64Type(), cast(barrier.getType()).getEncoding(), + sharedMemorySpace, /*mutableMemory=*/true); Value barrierView = builder.create( barrier.getLoc(), barrierTy, barrier, idx); @@ -1191,7 +1069,7 @@ bool mlir::triton::preProcessLoopAndGetSchedule( // Schedule the loads and root ops (dot ops) in the loop. This will give us // a scaffold for the final schedule. DenseSet rootUsers; - CoarseSchedule coarseSchedule(numStages); + tt::CoarseSchedule coarseSchedule(numStages); llvm::MapVector loadToInfo = scheduleLoads(forOp, coarseSchedule, rootUsers, numStages); if (loadToInfo.empty()) @@ -1202,6 +1080,13 @@ bool mlir::triton::preProcessLoopAndGetSchedule( coarseSchedule.dump(); }); + tt::CoarseSchedule::Cluster afterPrologue = + schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with prologue and epilogue:"); + coarseSchedule.dump(); + }); + SmallVector barriers; // Convert the loads into async loads and create the allocs. SmallVector allocs = @@ -1212,13 +1097,6 @@ bool mlir::triton::preProcessLoopAndGetSchedule( coarseSchedule.dump(); }); - CoarseSchedule::Cluster afterPrologue = - schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, numStages); - LLVM_DEBUG({ - LDBG("Coarse schedule with prologue and epilogue:"); - coarseSchedule.dump(); - }); - scheduleDependencies(forOp, coarseSchedule, numStages); LLVM_DEBUG({ LDBG("Coarse schedule with dependencies:"); @@ -1397,16 +1275,17 @@ void mlir::triton::updateWaits(ModuleOp module) { // also adds some MemDesc's to the wait. The idea is that if you have // // %alloc = ttg.local_alloc ... -// %a = ttng.dot_async %alloc -// %a1 = ttng.dot_wait %a +// %a = ttng.warp_group_dot %alloc +// %a1 = ttng.warp_group_dot_wait %a // // then we want the wait to depend on %alloc as well as %a. This extends the // live range of %alloc, so that it won't be destroyed until after the dot is // waited on. // -// Specifically, this function finds all dot_async ops that elements of `values` -// depend on. Then it adds the MemDesc operands of those dots to the wait. -static void threadValuesThroughWait(ttng::DotWaitOp wait, +// Specifically, this function finds all warp_group_dot ops that elements of +// `values` depend on. Then it adds the MemDesc operands of those dots to the +// wait. +static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait, MutableArrayRef values) { IRRewriter builder(wait.getContext()); builder.setInsertionPoint(wait); @@ -1423,12 +1302,12 @@ static void threadValuesThroughWait(ttng::DotWaitOp wait, newOperands.insert(values.begin(), values.end()); // Find memdefs depended on by `values` through async dot ops. - SmallVector asyncDots; + SmallVector asyncDots; for (Value v : values) { BackwardSliceOptions options; options.omitBlockArguments = true; options.filter = [&](Operation *op) { - if (auto dot = dyn_cast(op)) { + if (auto dot = dyn_cast(op)) { asyncDots.push_back(dot); return false; } @@ -1438,7 +1317,7 @@ static void threadValuesThroughWait(ttng::DotWaitOp wait, getBackwardSlice(v, &slice, options); } - for (ttng::DotAsyncOp dot : asyncDots) { + for (ttng::WarpGroupDotOp dot : asyncDots) { for (Value operand : dot.getOperands()) { if (isa(operand.getType())) { newOperands.insert(operand); @@ -1448,7 +1327,7 @@ static void threadValuesThroughWait(ttng::DotWaitOp wait, // We can't use replaceWithNewOp because we're changing the number of return // values in the operation. - auto newWait = builder.create( + auto newWait = builder.create( wait.getLoc(), llvm::to_vector(newOperands), wait.getPendings()); auto dominatedByNewWait = [&](OpOperand &operand) { @@ -1469,13 +1348,14 @@ static void threadValuesThroughWait(ttng::DotWaitOp wait, wait->erase(); } -// Determines whether a given MMAv3 dot op, represented as ttng.dot_async, needs -// a wait immediately after it. +// Determines whether a given MMAv3 dot op, represented as ttng.warp_group_dot, +// needs a wait immediately after it. // // In PTX, MMAv3 exists only as an asynchronous op. In Triton, we can represent -// MMAv3 ops as either tt.dot (synchronous) or ttng.dot_async. But even if we -// use ttng.dot_async, the conservative thing is to make a dot "effectively -// synchronous" by inserting a `ttng.dot_wait {pendings=0}` right after it. +// MMAv3 ops as either ttng.warp_group_dot {isAsync=True} or ttng.warp_group_dot +// {isAsync=False}. But even if we use ttng.warp_group_dot {isAsync=True}, the +// conservative thing is to make a dot "effectively synchronous" by inserting a +// `ttng.warp_group_dot_wait {pendings=0}` right after it. // // We can omit the wait and create a "properly async" dot if all of the // following are true. @@ -1487,28 +1367,29 @@ static void threadValuesThroughWait(ttng::DotWaitOp wait, // and will be synced with a `wait 0` at the beginning of the `if` block. // // 3. During iteration i, between the start of the loop up until the first -// `ttng.dot_wait {pendings=0}` op, the result of the dot from iteration i-1 -// is consumed only by other MMAv3 dots as the `c` operand. +// `ttng.warp_group_dot_wait {pendings=0}` op, the result of the dot from +// iteration i-1 is consumed only by other MMAv3 dots as the `c` operand. // // This is safe because the following pseudo-PTX is valid: // -// %accum = dot_async %a1, %b1, %c1 -// %accum = dot_async %a2, %b2, %accum +// %accum = warp_group_dot %a1, %b1, %c1 +// %accum = warp_group_dot %a2, %b2, %accum // // That is, the second async dot can use the result of the first one without // an intervening wait. However, the only operation that can legally read -// %accum before the wait is another dot_async, and this only works for the -// `c` operand, not `a` or `b`. See +// %accum before the wait is another warp_group_dot, and this only works for +// the `c` operand, not `a` or `b`. See // https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence -// (ttng::DotAsyncOp corresponds to wgmma.fence followed by one or more -// wgmma.async ops, so our understanding is that the two ttng::DotAsyncOps -// don't have to correspond to wgmma.async ops with the same shapes as -// specified in the docs, because there's an intervening fence.) +// (ttng::WarpGroupDotOp corresponds to wgmma.fence followed by one or more +// wgmma.async ops, so our understanding is that the two +// ttng::WarpGroupDotOps don't have to correspond to wgmma.async ops with +// the same shapes as specified in the docs, because there's an intervening +// fence.) // // If the op can be properly async, this function returns the index of the dot // in the loop's iter_args. (Rule (2) above ensures this is well-defined.) // -static std::optional dotCanBeProperlyAsync(ttng::DotAsyncOp dotOp, +static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, scf::ForOp forOp) { LDBG("Considering whether to make MMAv3 dot properly async: " << dotOp); @@ -1524,11 +1405,19 @@ static std::optional dotCanBeProperlyAsync(ttng::DotAsyncOp dotOp, // allowed in between. Value transitiveOperand = operand; while (isa_and_nonnull( - transitiveOperand.getDefiningOp())) { - transitiveOperand = transitiveOperand.getDefiningOp()->getOperand(0); + transitiveOperand.getDefiningOp()) || + isa(transitiveOperand)) { + auto blockArg = dyn_cast(transitiveOperand); + if (blockArg && blockArg.getOwner() == forOp.getBody()) { + transitiveOperand = + cast(blockArg.getOwner()->getTerminator()) + .getOperand(blockArg.getArgNumber() - 1); + } else if (Operation *def = transitiveOperand.getDefiningOp()) { + transitiveOperand = def->getOperand(0); + } } return forOp.isDefinedOutsideOfLoop(transitiveOperand) || - isa(transitiveOperand.getDefiningOp()); + transitiveOperand.getDefiningOp(); }; // We don't have to call checkOperand on getC() because it's always in @@ -1582,16 +1471,17 @@ static std::optional dotCanBeProperlyAsync(ttng::DotAsyncOp dotOp, // Rule 3a: Are the only users of the dot's result from iteration i-1 other // MMAv3 dots? If so, we're done, this dot can be properly async. if (llvm::all_of(iterArg.getUses(), [&](OpOperand &use) { - return isa(use.getOwner()) && + return isa(use.getOwner()) && use.getOperandNumber() == 2; })) { return iterArgIdx; } // Rule 3b: Are all users of the dot's result from iteration i-1 after the - // first `dot_wait {pendings=0}` op? If so, the dot can be properly async, - // but we have to thread its result from iteration i-1 through the wait. - auto waitOps = forOp.getBody()->getOps(); + // first `warp_group_dot_wait {pendings=0}` op? If so, the dot can be + // properly async, but we have to thread its result from iteration i-1 through + // the wait. + auto waitOps = forOp.getBody()->getOps(); auto firstWaitOpIter = llvm::find_if( waitOps, [&](auto waitOp) { return waitOp.getPendings() == 0; }); if (firstWaitOpIter != waitOps.end() && @@ -1602,7 +1492,8 @@ static std::optional dotCanBeProperlyAsync(ttng::DotAsyncOp dotOp, } return (*firstWaitOpIter)->isBeforeInBlock(user); })) { - LDBG("MMAv3 dot can be properly async because it follows a dot_wait " + LDBG("MMAv3 dot can be properly async because it follows a " + "warp_group_dot_wait " "{pendings=0}.\n" << " wait: " << *firstWaitOpIter << "\n" << " dot: " << dotOp); @@ -1617,16 +1508,16 @@ static std::optional dotCanBeProperlyAsync(ttng::DotAsyncOp dotOp, // If necessary, insert a dot-wait inside the loop, waiting for the results of // the properly-async dots from iteration i-1 to complete. (We pipeline to -// depth 2, so there are at most 2 copies of each dot_async in flight at a +// depth 2, so there are at most 2 copies of each warp_group_dot in flight at a // time.) // -// We can skip inserting the wait if we have a `dot_wait {pendings=0}` somewhere -// in the loop. To see why, consider: +// We can skip inserting the wait if we have a `warp_group_dot_wait +// {pendings=0}` somewhere in the loop. To see why, consider: // -// dot_async -// dot_async; wait 0 // synchronous dot -// dot_async -// dot_async +// warp_group_dot +// warp_group_dot; wait 0 // synchronous dot +// warp_group_dot +// warp_group_dot // // In this example, there are three properly-async dots, so we'd normally put // `wait 3` at the end of the loop, meaning "wait until there are 3 or fewer @@ -1634,13 +1525,13 @@ static std::optional dotCanBeProperlyAsync(ttng::DotAsyncOp dotOp, // completes, there are only *two* pending async dots from this iteration, so // this wait would do nothing. This is true in general, no matter where the // `wait 0` appears. -static void insertAsyncDotWaitInLoop( +static void insertAsyncWarpGroupDotWaitInLoop( scf::ForOp forOp, const llvm::MapVector &properlyAsyncDots) { if (properlyAsyncDots.empty()) return; - if (llvm::any_of(forOp.getBody()->getOps(), + if (llvm::any_of(forOp.getBody()->getOps(), [](auto wait) { return wait.getPendings() == 0; })) { return; } @@ -1664,8 +1555,8 @@ static void insertAsyncDotWaitInLoop( for (auto [block, users] : blockToUsers) { OpBuilder builder(block, block->begin()); - auto newWait = builder.create(asyncDot->getLoc(), - ArrayRef{}, 0); + auto newWait = builder.create( + asyncDot->getLoc(), ArrayRef{}, 0); threadValuesThroughWait(newWait, users); } @@ -1682,9 +1573,9 @@ static void insertAsyncDotWaitInLoop( IRRewriter builder(forOp.getContext()); auto lastAsyncDot = properlyAsyncDots.back().first; builder.setInsertionPointAfter(lastAsyncDot); - auto wait = builder.create(lastAsyncDot->getLoc(), - /*inputs=*/ArrayRef{}, - properlyAsyncDots.size()); + auto wait = builder.create( + lastAsyncDot->getLoc(), + /*inputs=*/ArrayRef{}, properlyAsyncDots.size()); // Thread the results of the async dots through the wait. SmallVector addlWaitOperands; @@ -1694,49 +1585,40 @@ static void insertAsyncDotWaitInLoop( threadValuesThroughWait(wait, addlWaitOperands); } -// Convert MMAv3 tt::DotOps (i.e. Hopper wgmma) into ttng::DotAsyncOps and -// insert ttng::DotWaitOps as necessary. +// Convert MMAv3 ttng::WarpGroupDotOps {isAsync = False} (i.e. Hopper wgmma) +// into ttng::WarpGroupDotOps {isAsync = True} and insert +// ttng::WarpGroupDotWaitOps as necessary. // // We assume we have space for each dot to be pipelined to depth 2, i.e. each -// dot op in the loop can have at most 2 dot_async ops in flight at once. (Each -// dot_async op usually corresponds to a series of wgmma.async ops.) +// dot op in the loop can have at most 2 warp_group_dot ops in flight at once. +// (Each warp_group_dot op usually corresponds to a series of wgmma.async ops.) void triton::asyncLaunchDots(scf::ForOp forOp) { LDBG("Original loop:\n" << *forOp); - // First, change every MMAv3 tt.dot into ttng.dot_async. The rest of this - // function is concerned with inserting ttng.dot_wait ops in the appropriate - // places. - // - // It's not strictly necessary to convert every dot into dot_async: - // Synchronous MMAv3 dots can be represented equally well as `tt.dot` or - // `ttng.dot_async; wait 0`. But this makes things easier elsewhere. + // First, change every MMAv3 ttng.warp_group_dot {isAsync=false} + // into ttng.warp_group_dot {isAsync=true}. + // The rest of this function is concerned with inserting + // ttng.warp_group_dot_wait ops in the appropriate places. // // We call those dots that don't need to be followed immediately by a `wait 0` // "properly async", or sometimes just "async". - IRRewriter builder(forOp.getContext()); - for (auto dotOp : llvm::to_vector(forOp.getBody()->getOps())) { - if (isMMAv3Dot(dotOp)) { - builder.setInsertionPoint(dotOp); - builder.replaceOpWithNewOp( - dotOp, dotOp.getA(), dotOp.getB(), dotOp.getC(), - dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); - } - } - + // // For each dot, determine whether it can be properly async, or if it needs a // sync immediately after. If it can be properly async, we know its only use // is in the loop's `yield` statement; asyncDots maps the op to its index in // the yield op. + IRRewriter builder(forOp.getContext()); llvm::MapVector properlyAsyncDots; - for (auto dotOp : forOp.getBody()->getOps()) { - if (auto iterArgIdx = dotCanBeProperlyAsync(dotOp, forOp)) { - properlyAsyncDots[dotOp] = *iterArgIdx; + for (auto WarpGroupDotOp : forOp.getBody()->getOps()) { + WarpGroupDotOp.setIsAsync(true); + if (auto iterArgIdx = dotCanBeProperlyAsync(WarpGroupDotOp, forOp)) { + properlyAsyncDots[WarpGroupDotOp] = *iterArgIdx; } else { - builder.setInsertionPointAfter(dotOp); - auto wait = - builder.create(dotOp.getLoc(), ArrayRef{}, - /*pendings=*/0); - SmallVector waitOperands = {dotOp.getResult()}; + builder.setInsertionPointAfter(WarpGroupDotOp); + auto wait = builder.create( + WarpGroupDotOp.getLoc(), ArrayRef{}, + /*pendings=*/0); + SmallVector waitOperands = {WarpGroupDotOp.getResult()}; threadValuesThroughWait(wait, waitOperands); } } @@ -1750,7 +1632,7 @@ void triton::asyncLaunchDots(scf::ForOp forOp) { // iteration's set of asynchronous dots (and their corresponding async copies // from global to shmem) can't start until the first iteration's set has // completed. - insertAsyncDotWaitInLoop(forOp, properlyAsyncDots); + insertAsyncWarpGroupDotWaitInLoop(forOp, properlyAsyncDots); // Finally, insert a wait after the loop, waiting for dots from the final // iteration of the loop. @@ -1760,7 +1642,7 @@ void triton::asyncLaunchDots(scf::ForOp forOp) { } // Wait until there are 0 outstanding async dot ops. builder.setInsertionPointAfter(forOp); - auto dotWaitAfterLoop = - builder.create(forOp.getLoc(), ArrayRef{}, 0); - threadValuesThroughWait(dotWaitAfterLoop, waitOperands); + auto WarpGroupDotWaitAfterLoop = builder.create( + forOp.getLoc(), ArrayRef{}, 0); + threadValuesThroughWait(WarpGroupDotWaitAfterLoop, waitOperands); } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp index 8b3f55bb8..d8a34f694 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp @@ -1,7 +1,7 @@ -#include "PipelineExpander.h" -#include "PipeliningUtility.h" -#include "Schedule.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" using namespace mlir; namespace tt = mlir::triton; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp index 6dfd0e344..20fcba4d7 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -26,12 +26,12 @@ #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Support/MathExtras.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" -#include "PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" #define DEBUG_TYPE "triton-loop-pipelining" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") @@ -89,7 +89,7 @@ struct LoopPipelinerInternal { bool initializeLoopInfo(ForOp op, const triton::PipeliningOption &options); /// Emits the prologue, this creates `maxStage - 1` part which will contain /// operations from stages [0; i], where i is the part index. - void emitPrologue(RewriterBase &rewriter); + LogicalResult emitPrologue(RewriterBase &rewriter); /// Gather liverange information for Values that are used in a different stage /// than its definition. llvm::MapVector analyzeCrossStageValues(); @@ -106,8 +106,8 @@ struct LoopPipelinerInternal { RewriterBase &rewriter); /// Emits the epilogue, this creates `maxStage - 1` part which will contain /// operations from stages [i; maxStage], where i is the part index. - void emitEpilogue(RewriterBase &rewriter, - llvm::SmallVector &returnValues); + LogicalResult emitEpilogue(RewriterBase &rewriter, + llvm::SmallVector &returnValues); }; bool LoopPipelinerInternal::initializeLoopInfo( @@ -131,7 +131,7 @@ bool LoopPipelinerInternal::initializeLoopInfo( int64_t ubImm = upperBoundCst.value(); int64_t lbImm = lowerBoundCst.value(); int64_t stepImm = stepCst.value(); - int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm); + int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm); if (numIteration > maxStage) { dynamicLoop = false; } else if (!options.supportDynamicLoops) { @@ -145,10 +145,6 @@ bool LoopPipelinerInternal::initializeLoopInfo( LDBG("--no epilogue or predicate set -> BAIL"); return false; } - if (dynamicLoop && peelEpilogue) { - LDBG("--dynamic loop doesn't support epilogue yet -> BAIL"); - return false; - } std::vector> schedule; options.getScheduleFn(forOp, schedule); if (schedule.empty()) { @@ -279,7 +275,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, return clone; } -void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { +LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { // Initialize the iteration argument to the loop initiale values. for (auto [arg, operand] : llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) { @@ -289,19 +285,6 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { Location loc = forOp.getLoc(); SmallVector predicates(maxStage); for (int64_t i = 0; i < maxStage; i++) { - if (dynamicLoop) { - Type t = ub.getType(); - // pred = ub > lb + (i * step) - Value iv = rewriter.create( - loc, lb, - rewriter.create( - loc, step, - rewriter.create( - loc, rewriter.getIntegerAttr(t, i)))); - predicates[i] = rewriter.create( - loc, arith::CmpIPredicate::slt, iv, ub); - } - // special handling for induction variable as the increment is implicit. // iv = lb + i * step Type t = lb.getType(); @@ -312,6 +295,13 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { rewriter.create(loc, rewriter.getIntegerAttr(t, i)))); setValueMapping(forOp.getInductionVar(), iv, i); + + if (dynamicLoop) { + // pred = ub > lb + (i * step) + predicates[i] = rewriter.create( + loc, arith::CmpIPredicate::slt, iv, ub); + } + for (Operation *op : opOrder) { if (stages[op] > i) continue; @@ -325,26 +315,38 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { }); int predicateIdx = i - stages[op]; if (predicates[predicateIdx]) { + OpBuilder::InsertionGuard insertGuard(rewriter); newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]); - assert(newOp && "failed to predicate op."); + if (newOp == nullptr) + return failure(); } - rewriter.setInsertionPointAfter(newOp); if (annotateFn) annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Prologue, i); for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { - setValueMapping(op->getResult(destId), newOp->getResult(destId), - i - stages[op]); + Value source = newOp->getResult(destId); // If the value is a loop carried dependency update the loop argument - // mapping. for (OpOperand &operand : yield->getOpOperands()) { if (operand.get() != op->getResult(destId)) continue; + if (predicates[predicateIdx] && + !forOp.getResult(operand.getOperandNumber()).use_empty()) { + // If the value is used outside the loop, we need to make sure we + // return the correct version of it. + Value prevValue = valueMapping + [forOp.getRegionIterArgs()[operand.getOperandNumber()]] + [i - stages[op]]; + source = rewriter.create( + loc, predicates[predicateIdx], source, prevValue); + } setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], - newOp->getResult(destId), i - stages[op] + 1); + source, i - stages[op] + 1); } + setValueMapping(op->getResult(destId), newOp->getResult(destId), + i - stages[op]); } } } + return success(); } llvm::MapVector @@ -453,6 +455,7 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop( auto newForOp = rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, forOp.getStep(), newLoopArg); + newForOp->setAttrs(forOp->getAttrs()); // When there are no iter args, the loop body terminator will be created. // Since we always create it below, remove the terminator if it was created. if (!newForOp.getBody()->empty()) @@ -563,6 +566,7 @@ LogicalResult LoopPipelinerInternal::createKernel( } if (predicates[useStage]) { + OpBuilder::InsertionGuard insertGuard(rewriter); newOp = predicateFn(rewriter, newOp, predicates[useStage]); if (!newOp) return failure(); @@ -570,7 +574,6 @@ LogicalResult LoopPipelinerInternal::createKernel( for (auto values : llvm::zip(op->getResults(), newOp->getResults())) mapping.map(std::get<0>(values), std::get<1>(values)); } - rewriter.setInsertionPointAfter(newOp); if (annotateFn) annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Kernel, 0); } @@ -642,71 +645,131 @@ LogicalResult LoopPipelinerInternal::createKernel( return success(); } -void LoopPipelinerInternal::emitEpilogue( - RewriterBase &rewriter, llvm::SmallVector &returnValues) { +LogicalResult +LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, + llvm::SmallVector &returnValues) { + Location loc = forOp.getLoc(); + Type t = lb.getType(); // Emit different versions of the induction variable. They will be // removed by dead code if not used. - for (int64_t i = 0; i < maxStage; i++) { - Location loc = forOp.getLoc(); - Type t = lb.getType(); - Value minusOne = - rewriter.create(loc, rewriter.getIntegerAttr(t, -1)); - // number of iterations = ((ub - 1) - lb) / step - Value totalNumIteration = rewriter.create( - loc, - rewriter.create( - loc, rewriter.create(loc, ub, minusOne), lb), - step); - // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i) - Value minusI = - rewriter.create(loc, rewriter.getIntegerAttr(t, -i)); + + auto createConst = [&](int v) { + return rewriter.create(loc, + rewriter.getIntegerAttr(t, v)); + }; + + // total_iterations = cdiv(range_diff, step); + // - range_diff = ub - lb + // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step + Value zero = createConst(0); + Value one = createConst(1); + Value stepLessZero = rewriter.create( + loc, arith::CmpIPredicate::slt, step, zero); + Value stepDecr = + rewriter.create(loc, stepLessZero, one, createConst(-1)); + + Value rangeDiff = rewriter.create(loc, ub, lb); + Value rangeIncrStep = rewriter.create(loc, rangeDiff, step); + Value rangeDecr = + rewriter.create(loc, rangeIncrStep, stepDecr); + Value totalIterations = rewriter.create(loc, rangeDecr, step); + + // If total_iters < max_stage, start the epilogue at zero to match the + // ramp-up in the prologue. + // start_iter = max(0, total_iters - max_stage) + Value iterI = rewriter.create(loc, totalIterations, + createConst(maxStage)); + iterI = rewriter.create(loc, zero, iterI); + + // Capture predicates for dynamic loops. + SmallVector predicates(maxStage + 1); + + for (int64_t i = 1; i <= maxStage; i++) { + // newLastIter = lb + step * iterI Value newlastIter = rewriter.create( - loc, lb, - rewriter.create( - loc, step, - rewriter.create(loc, totalNumIteration, minusI))); - setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i); + loc, lb, rewriter.create(loc, step, iterI)); + + setValueMapping(forOp.getInductionVar(), newlastIter, i); + + // increment to next iterI + iterI = rewriter.create(loc, iterI, one); + + if (dynamicLoop) { + // Disable stages when `i` is greater than total_iters. + // pred = total_iters >= i + predicates[i] = rewriter.create( + loc, arith::CmpIPredicate::sge, totalIterations, createConst(i)); + } } + // Emit `maxStage - 1` epilogue part that includes operations from stages // [i; maxStage]. for (int64_t i = 1; i <= maxStage; i++) { + SmallVector> returnMap(returnValues.size()); for (Operation *op : opOrder) { if (stages[op] < i) continue; + unsigned currentVersion = maxStage - stages[op] + i; + unsigned nextVersion = currentVersion + 1; Operation *newOp = cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { auto it = valueMapping.find(newOperand->get()); if (it != valueMapping.end()) { - Value replacement = it->second[maxStage - stages[op] + i]; + Value replacement = it->second[currentVersion]; newOperand->set(replacement); } }); + if (dynamicLoop) { + OpBuilder::InsertionGuard insertGuard(rewriter); + newOp = predicateFn(rewriter, newOp, predicates[currentVersion]); + if (!newOp) + return failure(); + } if (annotateFn) annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Epilogue, i - 1); - for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { - setValueMapping(op->getResult(destId), newOp->getResult(destId), - maxStage - stages[op] + i); + for (auto [opRes, newRes] : + llvm::zip(op->getResults(), newOp->getResults())) { + setValueMapping(opRes, newRes, currentVersion); // If the value is a loop carried dependency update the loop argument // mapping and keep track of the last version to replace the original // forOp uses. for (OpOperand &operand : forOp.getBody()->getTerminator()->getOpOperands()) { - if (operand.get() != op->getResult(destId)) + if (operand.get() != opRes) continue; - unsigned version = maxStage - stages[op] + i + 1; // If the version is greater than maxStage it means it maps to the // original forOp returned value. - if (version > maxStage) { - returnValues[operand.getOperandNumber()] = newOp->getResult(destId); - continue; - } - setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], - newOp->getResult(destId), version); + unsigned ri = operand.getOperandNumber(); + returnValues[ri] = newRes; + Value mapVal = forOp.getRegionIterArgs()[ri]; + returnMap[ri] = std::make_pair(mapVal, currentVersion); + if (nextVersion <= maxStage) + setValueMapping(mapVal, newRes, nextVersion); + } + } + } + if (dynamicLoop) { + // Select return values from this stage (live outs) based on predication. + // If the stage is valid select the peeled value, else use previous stage + // value. + for (auto pair : llvm::enumerate(returnValues)) { + unsigned ri = pair.index(); + auto [mapVal, currentVersion] = returnMap[ri]; + if (mapVal) { + unsigned nextVersion = currentVersion + 1; + Value pred = predicates[currentVersion]; + Value prevValue = valueMapping[mapVal][currentVersion]; + auto selOp = rewriter.create(loc, pred, pair.value(), + prevValue); + returnValues[ri] = selOp; + if (nextVersion <= maxStage) + setValueMapping(mapVal, selOp, nextVersion); } } } } + return success(); } void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { @@ -737,7 +800,8 @@ mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, *modifiedIR = true; // 1. Emit prologue. - pipeliner.emitPrologue(rewriter); + if (failed(pipeliner.emitPrologue(rewriter))) + return failure(); // 2. Track values used across stages. When a value cross stages it will // need to be passed as loop iteration arguments. @@ -764,7 +828,8 @@ mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, if (options.peelEpilogue) { // 4. Emit the epilogue after the new forOp. rewriter.setInsertionPointAfter(newForOp); - pipeliner.emitEpilogue(rewriter, returnValues); + if (failed(pipeliner.emitEpilogue(rewriter, returnValues))) + return failure(); } // 5. Erase the original loop and replace the uses with the epilogue output. if (forOp->getNumResults() > 0) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index c773d808c..a9465652f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -1,4 +1,4 @@ -#include "PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/TypeUtilities.h" @@ -34,11 +34,9 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, OpBuilder::InsertionGuard guard(rewriter); if (mlir::isMemoryEffectFree(op)) return op; - if (isa(op)) + if (isa(op)) return op; - if (isa(op)) - return op; - if (isa(op)) + if (isa(op)) return op; if (auto ifOp = dyn_cast(op)) { rewriter.setInsertionPoint(op); @@ -75,6 +73,13 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, expectOp.getPredMutable().assign(mask); return op; } + if (auto storeOp = dyn_cast(op)) { + rewriter.setInsertionPoint(storeOp); + Value mask = getPredMask(rewriter, storeOp.getPtr().getType(), + storeOp.getMask(), pred); + storeOp.getMaskMutable().assign(mask); + return op; + } assert("don't know how to predicate this op" && false); return op; @@ -121,3 +126,51 @@ void mlir::triton::addOps( schedule.emplace_back(&op, stage); } } + +void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder, + Operation *oldUse, Value val) { + SmallVector opsToDelete; + SmallVector operandsToReplace; + + // Save the operand to replace / delete later (avoid iterator invalidation). + // TODO: can we use an early_inc iterator? + for (OpOperand &use : oldUse->getUses()) { + // Non-subview/trans ops will be replaced by `val`. + if (!isa(use.getOwner())) { + operandsToReplace.push_back(&use); + continue; + } + Operation *user = use.getOwner(); + // `subview(old_op)` is replaced by a new `subview(val)`. + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPoint(user); + Value newVal; + if (auto subview = dyn_cast(user)) { + triton::MemDescType oldType = subview.getType(); + bool isMutable = + cast(val.getType()).getMutableMemory(); + Type newDstType = triton::MemDescType::get( + oldType.getShape(), oldType.getElementType(), oldType.getEncoding(), + oldType.getMemorySpace(), isMutable); + newVal = builder.create( + subview.getLoc(), newDstType, val, subview.getOffsets()); + } else if (auto trans = dyn_cast(user)) { + newVal = builder.create(trans.getLoc(), val, + trans.getOrderAttr()); + } + assert(newVal); + newVal.getDefiningOp()->setAttrs(user->getAttrs()); + replaceUsesAndPropagateType(builder, user, newVal); + opsToDelete.push_back(use.getOwner()); + } + + // Perform late replacement. + for (OpOperand *operand : operandsToReplace) { + Operation *op = operand->getOwner(); + operand->set(val); + } + + // Perform late op erasure. + for (Operation *op : opsToDelete) + op->erase(); +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp new file mode 100644 index 000000000..1116b70a0 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp @@ -0,0 +1,92 @@ +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +void tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage, + tt::CoarseSchedule::Cluster cluster, + bool includeArg) { + for (Value operand : op->getOperands()) { + Value v = operand; + llvm::SmallDenseSet seen; + while (auto arg = dyn_cast(v)) { + if (!includeArg) + break; + if (!seen.insert(v).second) + break; + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + v = yieldOp->getOperand(arg.getArgNumber() - 1); + continue; + } + break; + } + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + if (insertIfAbsent(defOp, stage, cluster)) { + insertDepsOfOp(defOp, stage, cluster, includeArg); + } + } + } +} + +SmallVector> +tt::CoarseSchedule::getOpsInOrder(scf::ForOp forOp) { + SmallVector>, 8> + orderClusters(clusters.size()); + for (auto &op : forOp.getBody()->without_terminator()) { + if (opToStageAndCluster.count(&op) == 0) { + continue; + } + assert(opToStageAndCluster[&op].first < numStages && + "Op with invalid stage!"); + int clusterId = *opToStageAndCluster[&op].second; + assert(clusterId == std::distance(clusters.begin(), + opToStageAndCluster[&op].second) && + "Cluster ID mismatch!"); + orderClusters[clusterId].push_back(make_tuple( + &op, opToStageAndCluster[&op].first, opToStageAndCluster[&op].second)); + } + SmallVector> opsInOrder; + for (int i = 0; i < orderClusters.size(); i++) { + for (auto [op, stage, cluster] : orderClusters[i]) { + opsInOrder.push_back({op, stage, cluster}); + } + } + + return opsInOrder; +} + +std::vector> +tt::CoarseSchedule::createFinalSchedule(scf::ForOp forOp) { + SmallVector> + opsInOrder = getOpsInOrder(forOp); + std::vector> schedule; + for (auto [op, stage, cluster] : opsInOrder) + schedule.push_back({op, stage}); + return schedule; +} + +void tt::CoarseSchedule::dump() { + for (int i = 0; i < numStages; i++) { + llvm::dbgs() << "\n---- Ops in stage " << i << "\n"; + for (auto &[op, stageAndCluster] : opToStageAndCluster) { + if (i == stageAndCluster.first) { + llvm::dbgs() << " cluster: " << *stageAndCluster.second + << ":\n\t" << *op << "\n"; + } + } + } +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp index e5ed6ed37..8766e82b9 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -1,6 +1,3 @@ -#include "PipelineExpander.h" -#include "PipeliningUtility.h" -#include "Schedule.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/TypeUtilities.h" @@ -11,6 +8,9 @@ #include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Tools/Sys/GetEnv.hpp" diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp index 6318b178d..d7c9422c5 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -1,5 +1,6 @@ -#include "Schedule.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" using namespace mlir; @@ -29,7 +30,7 @@ getTMAStores(scf::ForOp forOp) { static Value createAlloc(scf::ForOp &forOp, tt::ExperimentalDescriptorStoreOp storeOp) { - OpBuilder builder(forOp); + OpBuilderWithAsyncTaskIds builder(forOp); auto ty = cast(storeOp.getSrc().getType()); auto order = ttg::getOrder(ty.getEncoding()); auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); @@ -39,18 +40,20 @@ static Value createAlloc(scf::ForOp &forOp, encoding = ttg::SharedEncodingAttr::get( ty.getContext(), ty.getShape(), order, ctaLayout, ty.getElementType()); } - - Type memdescType = tt::MemDescType::get(ty.getShape(), ty.getElementType(), - encoding, /*mutableMemory*/ true); - Value alloc = builder.create(storeOp->getLoc(), - memdescType, Value()); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(ty.getContext()); + Type memdescType = + tt::MemDescType::get(ty.getShape(), ty.getElementType(), encoding, + sharedMemorySpace, /*mutableMemory*/ true); + Value alloc = builder.createWithAsyncTaskIds( + storeOp->getLoc(), memdescType, Value()); return alloc; } static void createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorStoreOp storeOp, Value alloc) { - OpBuilder builder(storeOp); + OpBuilderWithAsyncTaskIds builder(storeOp); auto loc = storeOp.getLoc(); auto ty = cast(storeOp.getSrc().getType()); auto order = ttg::getOrder(ty.getEncoding()); @@ -58,10 +61,11 @@ static void createTMAAsyncCopy(scf::ForOp &forOp, // Put wait before the local_store make the store truly async. We know // that we are the only user of the CopyLocalToGlobal. - builder.create(loc, 0); - builder.create(loc, storeOp.getSrc(), alloc); - builder.create(loc, false); - builder.create( + builder.createWithAsyncTaskIds(loc, 0); + builder.createWithAsyncTaskIds(loc, storeOp.getSrc(), + alloc); + builder.createWithAsyncTaskIds(loc, false); + builder.createWithAsyncTaskIds( loc, storeOp.getDescPtr(), storeOp.getIndices(), alloc); storeOp->erase(); @@ -74,8 +78,21 @@ bool mlir::triton::pipelineTMAStores(scf::ForOp forOp) { return false; DenseMap storeToAlloc; + DenseMap, Type>, Value> allocs; for (tt::ExperimentalDescriptorStoreOp op : tmaStores) { + // Reuse allocations for stores of the same shape and types. This allows + // saving shared memory usage. It is valid since we have a wait 0 before + // every local_store. We could pipeline more aggressively if we didn't + // re-use but there is a tradeoff with shared memory usage. + auto key = std::make_pair(op.getSrc().getType().getShape(), + op.getSrc().getType().getElementType()); + auto it = allocs.find(key); + if (it != allocs.end()) { + storeToAlloc[op] = it->second; + continue; + } storeToAlloc[op] = createAlloc(forOp, op); + allocs[key] = storeToAlloc[op]; } for (tt::ExperimentalDescriptorStoreOp op : tmaStores) { diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 85a95aaa7..2cbc00142 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -136,8 +136,9 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, builder.create(v.getLoc(), off, 32)); Value newSmem = builder.create( v.getLoc(), - triton::MemDescType::get(shape, elementType, type.getEncoding()), v, - offsetsVal); + triton::MemDescType::get(shape, elementType, type.getEncoding(), + type.getMemorySpace()), + v, offsetsVal); auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8); @@ -158,10 +159,13 @@ LogicalResult Prefetcher::initialize() { SmallVector dotsInFor; for (Operation &op : *loop) if (auto dotOp = dyn_cast(op)) { - // bail out if there exist non v2 dots. - auto dstEnc = + // Only accepts dotOps encoded as Nvidia MMA v2 or AMD MFMA + auto dstMmaEnc = dyn_cast(getEncoding(dotOp.getResult())); - if (!dstEnc || dstEnc.getVersionMajor() != 2) + auto dstMfmaEnc = + dyn_cast(getEncoding(dotOp.getResult())); + if (!dstMfmaEnc && (!dstMmaEnc || dstMmaEnc.getVersionMajor() != 2)) + // Don't rewrite if any other type is found. return failure(); dotsInFor.push_back(dotOp); } @@ -174,8 +178,6 @@ LogicalResult Prefetcher::initialize() { if (dotsInFor.size() > 1) return failure(); - // returns source of cvt - // returns source of cvt auto getPrefetchSrc = [](Value v) -> SmallVector { // walk back to conversion @@ -209,7 +211,7 @@ LogicalResult Prefetcher::initialize() { return Value(); }; - auto getYieldOp = [this](Value v) -> Value { + auto getYieldOperand = [this](Value v) -> Value { auto arg = mlir::cast(v); unsigned yieldIdx = arg.getArgNumber() - forOp.getNumInductionVars(); return yieldOp.getOperand(yieldIdx); @@ -255,8 +257,8 @@ LogicalResult Prefetcher::initialize() { dot2bHeaderDef[dot] = bHeaderDef; dot2aLoopArg[dot] = aSmem; dot2bLoopArg[dot] = bSmem; - dot2aYield[dot] = getYieldOp(aSmem); - dot2bYield[dot] = getYieldOp(bSmem); + dot2aYield[dot] = getYieldOperand(aSmem); + dot2bYield[dot] = getYieldOperand(bSmem); } } } @@ -302,7 +304,31 @@ scf::ForOp Prefetcher::createNewForOp() { mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + // The insertion point should be placed before the yield op + auto setInsertionPointBeforeYield = [](OpBuilder &builder, + scf::ForOp newForOp) { + if (newForOp.getBody()->mightHaveTerminator()) { + builder.setInsertionPoint(newForOp.getBody()->getTerminator()); + } else { + builder.setInsertionPointToEnd(newForOp.getBody()); + } + }; + for (Operation &op : forOp.getBody()->without_terminator()) { + // If we're currently trying to sink a prefetched dot, we need to stop + // sinking it (by resetting the insertion point to the end) if we find + // control flow, or anything that depends on the dot op. + if (op.getNumRegions() > 0) { + setInsertionPointBeforeYield(builder, newForOp); + } + for (auto operand : op.getOperands()) { + if (auto def = operand.getDefiningOp()) { + auto dot = dyn_cast(def); + if (dot && dots.contains(dot)) { + setInsertionPointBeforeYield(builder, newForOp); + } + } + } Operation *newOp = builder.clone(op, mapping); auto dot = dyn_cast(&op); if (dot && dots.contains(dot)) { @@ -320,6 +346,14 @@ scf::ForOp Prefetcher::createNewForOp() { int64_t kOff = prefetchWidth; int64_t kRem = dot.getA().getType().getShape()[1] - prefetchWidth; Operation *prevDot = firstDot; + if (kRem == 0) { + // There is only one dot while prefetchWidth == kSize so delay issuing + // it. Meanwhile, newOp should be set to firstDot to make sure the dot + // result is updated to yield. + builder.setInsertionPoint(prevDot); + newOp = firstDot; + } + while (kRem != 0) { // int64_t kShape = largestPow2(kRem); int64_t kShape = prefetchWidth; @@ -341,6 +375,13 @@ scf::ForOp Prefetcher::createNewForOp() { prevDot = newOp; kOff += kShape; kRem -= kShape; + if (kRem == 0) { + // We want to delay issuing the last dot as long as possible, ideally + // until after the prefetch. To accomplish this, set the insertion + // point above the dot. If we find anything dependent on the dot (at + // the top of this loop), we resume inserting after it. + builder.setInsertionPoint(prevDot); + } } } // update mapping of results @@ -365,6 +406,7 @@ scf::ForOp Prefetcher::createNewForOp() { yieldValues.push_back(bToYield); } // Update ops of yield + builder.setInsertionPointToEnd(newForOp.getBody()); if (!yieldValues.empty()) builder.create(yieldOp.getLoc(), yieldValues); return newForOp; diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp index c0b586d60..b1e296c1b 100644 --- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp +++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -42,22 +42,8 @@ class TritonGPUReduceDataDuplicationPass dyn_cast(dstType.getEncoding()); if (!dstDotOp) return; - if (auto srcMmaEncoding = - dyn_cast(srcEncoding)) { - - if (srcMmaEncoding.getVersionMajor() != 2 || - (srcMmaEncoding.getWarpsPerCTA()[1] == 1 && - dstDotOp.getParent() == srcMmaEncoding)) - return; - } - if (auto srcMfmaEncoding = - dyn_cast(srcEncoding)) { - - if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 && - srcMfmaEncoding.getIsTransposed() && - dstDotOp.getParent() == srcMfmaEncoding) - return; - } + if (!cvtNeedsSharedMemory(srcType, dstType)) + return; auto srcOrder = triton::gpu::getOrder(srcEncoding); auto rank = srcOrder.size(); SmallVector sharedOrder; @@ -70,12 +56,14 @@ class TritonGPUReduceDataDuplicationPass } else { sharedOrder = srcOrder; } + auto sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); auto tmpType = triton::MemDescType::get( dstType.getShape(), dstType.getElementType(), triton::gpu::SharedEncodingAttr::get( mod.getContext(), dstDotOp, srcType.getShape(), sharedOrder, - triton::gpu::getCTALayout(srcEncoding), - srcType.getElementType())); + triton::gpu::getCTALayout(srcEncoding), srcType.getElementType()), + sharedMemorySpace); auto tmp = builder.create( cvtOp.getLoc(), tmpType, cvtOp.getSrc()); auto newConvert = builder.create(cvtOp.getLoc(), diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 967d34c8f..7b3b60451 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -20,6 +20,7 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include #include namespace mlir { @@ -39,45 +40,6 @@ namespace { // // ----------------------------------------------------------------------------- -// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0)) -class ConvertDotConvert : public RewritePattern { -public: - ConvertDotConvert(MLIRContext *context) - : RewritePattern(ConvertLayoutOp::getOperationName(), 1, context) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - auto dstOp = cast(op); - auto dotOp = dstOp.getSrc().getDefiningOp(); - if (!dotOp) - return failure(); - if (std::distance(dstOp->user_begin(), dstOp->user_end()) != 1 || - std::distance(dotOp->user_begin(), dotOp->user_end()) != 1) - return failure(); - auto cvtOp = dotOp.getOperand(2).getDefiningOp(); - if (!cvtOp) - return failure(); - if (!cvtOp.getSrc().getDefiningOp()) - return failure(); - RankedTensorType dstTy = dstOp.getType(); - RankedTensorType srcTy = cvtOp.getSrc().getType(); - if (dstTy != srcTy) - return failure(); - - auto _0f = rewriter.create( - op->getLoc(), dstTy.getElementType(), - rewriter.getZeroAttr(dstTy.getElementType())); - auto _0 = rewriter.create(op->getLoc(), dotOp.getType(), _0f); - auto newDot = rewriter.create( - op->getLoc(), dotOp.getType(), dotOp.getOperand(0), dotOp.getOperand(1), - _0, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); - auto newCvt = rewriter.create(op->getLoc(), dstTy, - newDot.getResult()); - rewriter.replaceOpWithNewOp(op, newCvt, cvtOp.getSrc()); - return success(); - } -}; - // The current algorithm works by analyzing the IR and doing a one-shot rewrite // based on the analysis. The algorithm is as follows. // @@ -201,91 +163,12 @@ void LayoutRematerialization::cleanup() { op->erase(); } -// Look ahead to at the transitive uses and see if there is a convert to mma -// operations. -bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { - SmallVector queue = {op->getResult(0)}; - SetVector forwardSlice; - llvm::SmallDenseSet seen; - while (!queue.empty()) { - Value currentValue = queue.back(); - queue.pop_back(); - getForwardSlice(currentValue, &forwardSlice); - for (Operation *op : forwardSlice) { - // HACK: Stop propagation if the ReduceOp is using mma layout but is - // producing tensor smaller than the layout we would like to propagate. - // This is to avoid stepping into the known bug. - if (isa(op)) { - auto tensorType = - dyn_cast(op->getOperand(0).getType()); - if (tensorType && - isa(tensorType.getEncoding())) { - auto mmaInstrShape = - cast(encoding).getInstrShape(); - if (tensorType.getShape()[tensorType.getRank() - 2] < - mmaInstrShape[0] || - tensorType.getShape()[tensorType.getRank() - 1] < - mmaInstrShape[1]) { - return false; - } - } - } - - if (auto convertOp = dyn_cast(op)) { - Attribute dstEncoding = convertOp.getType().getEncoding(); - if (auto mmaLayout = dyn_cast(dstEncoding)) - return (mmaLayout.getVersionMajor() > 1) ? true - : mmaLayout == encoding; - if (isa(dstEncoding)) - return true; - if (isa(dstEncoding)) { - if (auto mmaLayout = dyn_cast(encoding)) { - return mmaLayout.getVersionMajor() > 1; - } else { - assert((mlir::isa(encoding))); - return true; - } - } - } - bool isMMAV3 = - isa(encoding) && - cast(encoding).getVersionMajor() == 3; - if (isMMAV3 && (isa(op) || isa(op))) - return true; - auto yield = dyn_cast(op); - if (!yield) - continue; - if (auto ifOp = dyn_cast(yield->getParentOp())) { - for (OpOperand &operand : yield->getOpOperands()) { - Operation *def = operand.get().getDefiningOp(); - if (def && - (forwardSlice.count(def) || operand.get() == currentValue) && - (seen.insert(operand.get()).second == true)) - queue.push_back(ifOp.getResult(operand.getOperandNumber())); - } - } - auto forOp = dyn_cast(yield.getOperation()->getParentOp()); - if (!forOp) - continue; - for (OpOperand &operand : yield->getOpOperands()) { - Operation *def = operand.get().getDefiningOp(); - if (def && (forwardSlice.count(def) || operand.get() == currentValue) && - (seen.insert(operand.get()).second == true)) - queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber())); - } - } - } - return false; -} - // Return true if the op is an op with a layout we don't want to change. We will // propagate the layout starting from anchor ops. bool isLayoutAnchor(Operation *op) { if (isa(op)) return isExpensiveLoadOrStore(op); - if (isa(op)) + if (isa(op)) return true; // Heuristic: Mark permuting reshape as a layout anchor. Its dst can be @@ -300,18 +183,8 @@ bool isLayoutAnchor(Operation *op) { } void LayoutPropagation::initAnchorLayout() { - auto maybeAddAnchor = [&](Value v) { + auto addAnchor = [&](Value v) { if (auto tensorType = dyn_cast(v.getType())) { - // Workaround, don't popagate MMA layout unless there is a convert - // back to mma further down to avoid generating reduction with MMA - // layout that may have lower performance. - // This can be improved with more aggressive backward propagation. - if (isa(tensorType.getEncoding()) && - v.getDefiningOp() && - !hasConvertToMMATransisitiveUse(v.getDefiningOp(), - tensorType.getEncoding())) { - return; - } layouts.insert({v, LayoutInfo(tensorType.getEncoding())}); } }; @@ -320,13 +193,13 @@ void LayoutPropagation::initAnchorLayout() { // you can pass a tensor with an encoding as an arg, instead of explicitly // calling tt.load. for (auto arg : funcOp.getArguments()) { - maybeAddAnchor(arg); + addAnchor(arg); } funcOp.walk([&](Operation *op) { if (isLayoutAnchor(op)) { for (auto result : op->getResults()) { - maybeAddAnchor(result); + addAnchor(result); } } }); @@ -375,17 +248,14 @@ SmallVector LayoutPropagation::propagateToUsers(Value value, if (auto yieldOp = dyn_cast(user)) { auto parent = yieldOp->getParentOp(); SmallVector valuesToPropagate; - if (isa(parent)) + if (isa(parent)) valuesToPropagate.push_back(parent->getResult(use.getOperandNumber())); if (auto forOp = dyn_cast(parent)) valuesToPropagate.push_back( forOp.getRegionIterArg(use.getOperandNumber())); - if (auto whileOp = dyn_cast(parent)) { + if (auto whileOp = dyn_cast(parent)) valuesToPropagate.push_back( whileOp.getBeforeArguments()[use.getOperandNumber()]); - valuesToPropagate.push_back( - whileOp->getOperand(use.getOperandNumber())); - } if (isa(parent)) setEncoding(valuesToPropagate, info, changed, user); continue; @@ -399,10 +269,16 @@ SmallVector LayoutPropagation::propagateToUsers(Value value, setEncoding({afterArg, result}, info, changed, user); continue; } + if (auto dotWaitOp = dyn_cast(user)) { + unsigned opIndex = use.getOperandNumber(); + Value result = dotWaitOp->getResult(opIndex); + setEncoding(result, info, changed, user); + continue; + } if (user->hasTrait() || user->hasTrait() || isa(user)) { + ConvertLayoutOp>(user)) { setEncoding(user->getResults(), info, changed, user); continue; } @@ -479,10 +355,10 @@ bool reduceToScalar(Operation *op) { } void LayoutPropagation::rewriteRegion(Region ®ion) { - SmallVector queue = {®ion}; + std::deque queue = {®ion}; while (!queue.empty()) { - Region *currentRegion = queue.back(); - queue.pop_back(); + Region *currentRegion = queue.front(); + queue.pop_front(); for (Operation &op : currentRegion->getOps()) { bool needRewrite = false; SmallVector results = op.getResults(); @@ -563,6 +439,8 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) { tensorType.getElementType(), encoding); Value converted = rewriter.create(value.getLoc(), tmpType, rewrittenValue); + if (value.getDefiningOp()) + converted.getDefiningOp()->setAttrs(value.getDefiningOp()->getAttrs()); // TODO: we could cache the conversion. return converted; } @@ -805,6 +683,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) { auto newType = RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(), encoding); auto cvt = rewriter.create(op->getLoc(), newType, src); + cvt->setAttrs(op->getAttrs()); map(op->getResult(0), cvt.getResult()); return cvt.getOperation(); } @@ -821,7 +700,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) { if (op->hasTrait() || op->hasTrait() || isa(op)) { + ConvertLayoutOp, nvidia_gpu::WarpGroupDotWaitOp>(op)) { Operation *newOp = cloneElementwise(rewriter, op, encoding); for (auto [oldResult, newResult] : llvm::zip(op->getResults(), newOp->getResults())) { @@ -1206,6 +1085,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( tensorType.getShape(), tensorType.getElementType(), *srcEncoding); auto newConvertOp = builder.create( convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0)); + newConvertOp->setAttrs(convertOp->getAttrs()); Operation *newExtOrBroadcast = builder.clone(*extOrBroadcatOp); newExtOrBroadcast->setOperand(0, newConvertOp.getResult()); auto oldExtOrBroadcastType = @@ -1288,17 +1168,6 @@ class TritonGPURemoveLayoutConversionsPass m.dump(); }); - RewritePatternSet decomposePatterns(context); - decomposePatterns.add(context); - if (applyPatternsAndFoldGreedily(m, std::move(decomposePatterns)) - .failed()) { - signalPassFailure(); - } - LLVM_DEBUG({ - DBGS() << "Module after decomposing dot-converts:\n"; - m.dump(); - }); - // 4. Apply clean up patterns to remove remove dead convert and dead code // generated by the previous transformations. RewritePatternSet cleanUpPatterns2(context); diff --git a/lib/Dialect/TritonGPU/Transforms/TaskIdPropagate.cpp b/lib/Dialect/TritonGPU/Transforms/TaskIdPropagate.cpp new file mode 100644 index 000000000..dd39a2338 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/TaskIdPropagate.cpp @@ -0,0 +1,483 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "triton-gpu-taskid-propagate" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = ::mlir::triton; +namespace ttg = ::mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUTASKIDPROPAGATE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// Return all Ops that are marked with target task +void getAsyncTaskOps(triton::FuncOp funcOp, DenseSet &asyncTaskOps, + int asyncTaskId) { + funcOp.walk([&](Operation *op) -> void { + if (auto attr = + op->getAttrOfType("async_task_id")) { + for (auto val : attr.getValues()) { + if (val == asyncTaskId) { + asyncTaskOps.insert(op); + break; + } + } + } + }); +} + +void getAllParentOps(DenseSet &parentOps, Operation *targetOp) { + auto op = targetOp; + while (auto parent = op->getParentOp()) { + if (!isa(parent) && !isa(parent)) { + parentOps.insert(parent); + op = parent; + } else { + break; + } + } +} + +void getAllParentOps(triton::FuncOp funcOp, DenseSet &parentOps, + int asyncTaskId) { + DenseSet targetOps; + getAsyncTaskOps(funcOp, targetOps, asyncTaskId); + for (auto op : targetOps) { + getAllParentOps(parentOps, op); + } +} + +void labelByUsers(Operation *op, ArrayRef allAsyncTasks) { + for (Value result : op->getResults()) { + for (Operation *userOp : result.getUsers()) { + if (!userOp->hasAttr("async_task_id")) { + labelByUsers(userOp, allAsyncTasks); + } + addAsyncTaskIds(op, getAsyncTaskIds(userOp)); + } + } + if (!op->hasAttr("async_task_id")) { + addAsyncTaskIds(op, allAsyncTasks); + } +} + +/// Because we set some special filter rules in populateAsyncTaskRegion, +/// there may be unlabeled Ops, e.g. YieldOps, some definingOps of ForOps. +/// or Ops without relations to asyncTaskOps +void populateUnlabledOpsAtLast(triton::FuncOp funcOp, + ArrayRef allAsyncTasks) { + // Label asyncTasks' parentOps + for (int i : allAsyncTasks) { + DenseSet asyncTaskParentOps; + getAllParentOps(funcOp, asyncTaskParentOps, i); + for (auto op : asyncTaskParentOps) { + addAsyncTaskIds(op, {i}); + } + } + + // Get unlabeled Ops + DenseSet unlabeledOps; + funcOp.walk([&](Operation *op) -> void { + if (isa(op) || isa(op) || + isa(op)) { + return; + } + if (!op->hasAttr("async_task_id")) { + unlabeledOps.insert(op); + } + }); + + // Label Ops using its parentOp + for (auto op : unlabeledOps) { + if (auto parent = op->getParentOp()) { + if (!isa(parent)) { + if (!parent->hasAttr("async_task_id")) { + LLVM_DEBUG({ + LDBG("op and parent: "); + op->dump(); + parent->dump(); + }); + continue; + } + assert(parent->hasAttr("async_task_id")); + auto asyncTasks = getAsyncTaskIds(parent); + setAsyncTaskIds(op, asyncTasks); + unlabeledOps.erase(op); + } + } + } + + // Label Ops using dependency + for (auto op : unlabeledOps) { + labelByUsers(op, allAsyncTasks); + unlabeledOps.erase(op); + } + assert(unlabeledOps.size() == 0); +} + +#ifndef NDEBUG +static bool oneVecCoversTheOther(SmallVector &one, + SmallVector &other) { + // Every element of other appears in one. + for (AsyncTaskId t : other) { + // If t doesn't appear in one, return false. + bool found = false; + for (AsyncTaskId t2 : one) { + if (t2 == t) { + found = true; + break; + } + } + if (!found) + return false; + } + return true; +} + +struct AsyncTaskIdsCompare { + static SmallVector getEmptyKey() { + SmallVector V; + V.push_back(reinterpret_cast(-1)); + return V; + } + + static SmallVector getTombstoneKey() { + SmallVector V; + V.push_back(reinterpret_cast(-2)); + return V; + } + + static unsigned getHashValue(const SmallVector &V) { + return static_cast(llvm::hash_combine_range(V.begin(), V.end())); + } + + static bool isEqual(const SmallVector &LHS, + const SmallVector &RHS) { + return LHS == RHS; + } +}; + +// Make sure the def chain contains the right taskId. +bool verifyTaskId(triton::FuncOp &funcOp, + const llvm::DenseSet &anchorOps) { + bool retCode = true; + DenseSet, AsyncTaskIdsCompare> anchorAsyncTasks; + for (auto anchorOp : anchorOps) { + anchorAsyncTasks.insert(getAsyncTaskIds(anchorOp)); + } + + funcOp.walk([&](Operation *op) { + // Skip control ops + if (llvm::isa(op)) + return; + + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.empty()) { + LLVM_DEBUG({ + LDBG("Op does not have task id"); + op->dump(); + }); + llvm_unreachable("Op does not have task id"); + } + + auto partitionShouldBeUsedSpecified = [](Operation *op) { + if (isa(op)) + return true; + if (isa(op)) + return true; + if (op->hasTrait()) + return true; + return false; + }; + + if (!anchorAsyncTasks.contains(asyncTaskIds)) { + if (partitionShouldBeUsedSpecified(op)) { + LLVM_DEBUG({ + LDBG("async tasks not specified by user"); + op->dump(); + }); + llvm_unreachable("async tasks not specified by user"); + } + } + + assert(!asyncTaskIds.empty() && "Op does not have task id"); + + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (!defOp) + continue; + if (llvm::isa(defOp)) + continue; + auto defTaskIds = getAsyncTaskIds(defOp); + // Make sure defTaskIds cover asyncTaskIds. Call addAsyncTaskIds if + // necessary. + LLVM_DEBUG({ + if (!oneVecCoversTheOther(defTaskIds, asyncTaskIds)) { + // print defOp and op + LDBG("Def op does not cover op"); + LDBG("Def op"); + defOp->dump(); + LDBG("op"); + op->dump(); + } + }); + assert(oneVecCoversTheOther(defTaskIds, asyncTaskIds) && + "defTaskIds should cover asyncTaskIds"); + } + }); + return retCode; +} +#endif + +void backwardPropagateTaskIds(Operation *op, + const llvm::DenseSet &anchors) { + SmallVector queue; + auto asyncTasks = getAsyncTaskIds(op); + for (Value operand : op->getOperands()) { + queue.push_back(operand); + } + + DenseSet seen; + for (auto anchor : anchors) { + if (anchor != op) + for (auto result : anchor->getResults()) + seen.insert(result); + } + + while (!queue.empty()) { + auto value = queue.pop_back_val(); + if (!seen.insert(value).second) { + continue; + } + + // Handle BlockArguments of for loops (i.e. loop carried dependences). + if (auto blockArg = dyn_cast(value)) { + auto parent = blockArg.getOwner()->getParentOp(); + if (auto forOp = dyn_cast(parent)) { + // Propagate to the control operands. + auto control = + forOp.getOperands().take_front(forOp.getNumControlOperands()); + queue.insert(queue.end(), control.begin(), control.end()); + // Propagate to the initializer. + if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { + queue.push_back(forOp.getTiedLoopInit(blockArg)->get()); + // Propagate to the yield. + auto idx = blockArg.getArgNumber() - forOp.getNumInductionVars(); + queue.push_back(forOp.getBody()->getTerminator()->getOperand(idx)); + addAsyncTaskIds(forOp, asyncTasks); + } + } + continue; + } + + auto op = value.getDefiningOp(); + addAsyncTaskIds(op, asyncTasks); + + // Handle for loops. + if (auto forOp = dyn_cast(op)) { + // Propagate to control operands. + auto control = + forOp.getOperands().take_front(forOp.getNumControlOperands()); + queue.insert(queue.end(), control.begin(), control.end()); + // Propagate to arguments. + unsigned idx = cast(value).getResultNumber(); + queue.push_back(forOp.getOperand(idx + forOp.getNumControlOperands())); + // Propagate to yield. + queue.push_back(forOp.getBody()->getTerminator()->getOperand(idx)); + continue; + } + + // Handle conditionals. + if (auto ifOp = dyn_cast(op)) { + queue.push_back(ifOp.getCondition()); + unsigned idx = cast(value).getResultNumber(); + if (ifOp.elseBlock()) { + queue.push_back(ifOp.elseYield()->getOperand(idx)); + } + queue.push_back(ifOp.thenYield()->getOperand(idx)); + continue; + } + + // Handle normal ops. + for (Value operand : op->getOperands()) { + queue.push_back(operand); + } + } +} + +void backwardPropagateTaskIds(llvm::DenseSet &rootOps, + llvm::DenseSet &anchorOps) { + for (Operation *op : rootOps) { + backwardPropagateTaskIds(op, anchorOps); + } +} + +void forwardPropagateTaskIds(Operation *root, + const llvm::DenseSet &anchors) { + auto asyncTasks = getAsyncTaskIds(root); + SmallVector queue; + for (Value result : root->getResults()) + queue.push_back(result); + + DenseSet seen; + for (auto anchor : anchors) { + if (anchor != root) + for (auto result : anchor->getResults()) + seen.insert(result); + } + + while (!queue.empty()) { + auto v = queue.back(); + queue.pop_back(); + if (!seen.insert(v).second) + continue; + + for (Operation *depOp : v.getUsers()) { + auto depAsyncTasks = getAsyncTaskIds(depOp); + // Skip depOp that already has task ids. Those could be either anchorOps + // or propagated backward from anchor ops. + if (!depAsyncTasks.empty() && depAsyncTasks != asyncTasks) + continue; + setAsyncTaskIds(depOp, asyncTasks); + // Go through yieldOp to propagate task ids to the result of parentOp. + if (auto yieldOp = dyn_cast(depOp)) { + auto parentOp = yieldOp->getParentOp(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + if (operand.get() == v) { + queue.push_back(parentOp->getResult(operand.getOperandNumber())); + break; + } + } + } else { + for (Value result : depOp->getResults()) + queue.push_back(result); + } + } + } +} + +void forwardPropagateTaskIds(llvm::DenseSet &anchorOps) { + for (Operation *op : anchorOps) { + forwardPropagateTaskIds(op, anchorOps); + } +} + +void populateTaskIdsForControlDependencies( + llvm::DenseSet &anchorOps) { + for (auto op : anchorOps) { + auto asyncTaskIds = getAsyncTaskIds(op); + if (!asyncTaskIds.empty()) { + while (auto parent = op->getParentOp()) { + if (!isa(parent) && !isa(parent)) { + setAsyncTaskIds(parent, asyncTaskIds); + backwardPropagateTaskIds(parent, anchorOps); + op = parent; + } else { + break; + } + } + } + } +} + +class TritonGPUTaskIdPropagatePass + : public impl::TritonGPUTaskIdPropagateBase { +public: + using impl::TritonGPUTaskIdPropagateBase< + TritonGPUTaskIdPropagatePass>::TritonGPUTaskIdPropagateBase; + + void runOnFuncOp(triton::FuncOp funcOp) { + llvm::DenseSet anchorOps; + funcOp.walk([&](mlir::Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + if (asyncTasks.empty()) + return; + std::sort(asyncTasks.begin(), asyncTasks.end()); + setAsyncTaskIds(op, asyncTasks); + if (!isa(op)) + anchorOps.insert(op); + }); + + populateTaskIdsForControlDependencies(anchorOps); + + LLVM_DEBUG({ + LDBG("after populateTaskIdsForControlDependencies "); + funcOp->dump(); + }); + + backwardPropagateTaskIds(anchorOps, anchorOps); + + LLVM_DEBUG({ + LDBG("after backwardPropagateTaskIds "); + funcOp->dump(); + }); + + forwardPropagateTaskIds(anchorOps); + + LLVM_DEBUG({ + LDBG("after forwardPropagateTaskIds "); + funcOp->dump(); + }); + + llvm::DenseSet rootOps; + funcOp.walk([&](mlir::Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + if (!asyncTasks.empty() && + !isa(op)) + rootOps.insert(op); + }); + backwardPropagateTaskIds(rootOps, anchorOps); + LLVM_DEBUG({ + LDBG("after final backwardPropagateTaskIds "); + funcOp->dump(); + }); + + DenseSet allAsyncTasks; + funcOp->walk([&](Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + allAsyncTasks.insert(asyncTasks.begin(), asyncTasks.end()); + }); + SmallVector allAsyncTasksVec(allAsyncTasks.begin(), + allAsyncTasks.end()); + populateUnlabledOpsAtLast(funcOp, allAsyncTasksVec); + + LLVM_DEBUG({ + LDBG("after populateUnlabledOpsAtLast "); + funcOp->dump(); + }); + +#ifndef NDEBUG + verifyTaskId(funcOp, anchorOps); +#endif + } + + void runOnOperation() override { + if (numConsumerGroups == 0) { + getOperation()->walk([&](triton::FuncOp funcOp) { + funcOp.walk([&](mlir::Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + if (!asyncTasks.empty()) + op->removeAttr("async_task_id"); + }); + }); + return; + } + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 9bf61f01e..4ef9d1cd1 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -25,8 +25,7 @@ using namespace triton; SmallVector mmaVersionToInstrShape(int version, const ArrayRef &shape, - TensorOrMemDesc type, - int numWarps) { + Type eltType, int numWarps) { if (version == 1) return {16, 16}; else if (version == 2) { @@ -36,17 +35,17 @@ SmallVector mmaVersionToInstrShape(int version, ret[rank - 2] = 16; return ret; } else if (version == 3) { - unsigned k = 256 / type.getElementTypeBitWidth(); + unsigned k = 256 / eltType.getIntOrFloatBitWidth(); if (shape[0] % 64 != 0 || shape[1] % 8 != 0) { assert(false && "type not supported"); return {0, 0, 0}; } - auto eltType = type.getElementType(); SmallVector validN; // MMAv3 with larger instruction shape is preferred. - if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FNUZ() || - eltType.isF16() || eltType.isBF16() || eltType.isF32()) { + if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() || + eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() || + eltType.isF32()) { validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); @@ -433,6 +432,19 @@ static std::optional inferSrcEncoding(triton::ReshapeOp op, op.getAllowReorder()); } +static bool isSingleValue(Value value) { + // Don't consider load as expensive if it is loading a scalar. + if (auto tensorTy = dyn_cast(value.getType())) + return tensorTy.getNumElements() == 1; + // TODO: Handle other cases. + // For example, when ptr is a tensor of single value. + // It means that ptr is a resultant of broadcast or generated through + // a chain of broadcast and other operations. + // Rematerialize it without considering contiguous memory access pattern is + // fine. + return true; +} + std::optional inferSrcEncoding(Operation *op, Attribute encoding) { if (isa(op)) { // Scan only supports blocked encoding at the moment. @@ -442,8 +454,8 @@ std::optional inferSrcEncoding(Operation *op, Attribute encoding) { if (op->hasTrait() || op->hasTrait() || op->hasTrait() || - isa( - op)) { + isa(op)) { return encoding; } @@ -472,7 +484,7 @@ std::optional inferDstEncoding(Operation *op, Attribute encoding) { op->hasTrait() || op->hasTrait() || isa(op)) + nvidia_gpu::WarpGroupDotWaitOp>(op)) return encoding; if (auto reduceOp = dyn_cast(op)) return inferDstEncoding(reduceOp, encoding); @@ -490,19 +502,6 @@ std::optional inferDstEncoding(Operation *op, Attribute encoding) { return std::nullopt; } -bool isSingleValue(Value value) { - // Don't consider load as expensive if it is loading a scalar. - if (auto tensorTy = dyn_cast(value.getType())) - return tensorTy.getNumElements() == 1; - // TODO: Handle other cases. - // For example, when ptr is a tensor of single value. - // It means that ptr is a resultant of broadcast or generated through - // a chain of broadcast and other operations. - // Rematerialize it without considering contiguous memory access pattern is - // fine. - return true; -} - bool isExpensiveLoadOrStore(Operation *op) { // Case 1: Pointer of tensor is always expensive auto operandType = op->getOperand(0).getType(); @@ -557,8 +556,7 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { RankedTensorType newDstType = RankedTensorType::get(reshapeDstType.getShape(), reshapeDstType.getElementType(), targetEncoding); - return reshape.getAllowReorder() && - !reshape.getEfficientLayout().has_value() && + return reshape.getAllowReorder() && !reshape.getEfficientLayout() && !triton::gpu::isExpensiveView(reshape.getSrc().getType(), newDstType); } @@ -603,6 +601,73 @@ scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, return newForOp; } +scf::WhileOp replaceWhileOpWithNewSignature( + RewriterBase &rewriter, scf::WhileOp loop, ValueRange newIterOperands, + TypeRange newResultTypes, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(loop); + + // Create a new loop before the existing one, with the extra operands. + auto operands = llvm::to_vector<4>(loop.getInits()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + + // Result and operand types + SmallVector resultTypes; + SmallVector argsTypesBefore; + for (auto res : loop.getResults()) + resultTypes.push_back(res.getType()); + for (auto type : newResultTypes) + resultTypes.push_back(type); + for (Value operand : operands) + argsTypesBefore.push_back(operand.getType()); + scf::WhileOp newLoop = + rewriter.create(loop.getLoc(), resultTypes, operands); + newLoop->setAttrs(loop->getAttrs()); + + SmallVector bbArgLocsBefore(argsTypesBefore.size(), loop.getLoc()); + SmallVector bbArgLocsAfter(resultTypes.size(), loop.getLoc()); + rewriter.createBlock(&newLoop.getBefore(), {}, argsTypesBefore, + bbArgLocsBefore); + rewriter.createBlock(&newLoop.getAfter(), {}, resultTypes, bbArgLocsAfter); + + // Copy regions + for (int i = 0; i < loop.getNumRegions(); ++i) + newLoop->getRegion(i).front().getOperations().splice( + newLoop->getRegion(i).front().getOperations().begin(), + loop->getRegion(i).front().getOperations()); + + // Remap arguments + for (auto [oldArg, newArg] : llvm::zip( + loop.getBeforeArguments(), newLoop.getBeforeArguments().take_front( + loop.getBeforeArguments().size()))) + rewriter.replaceAllUsesWith(oldArg, newArg); + for (auto [oldArg, newArg] : llvm::zip(loop.getAfterArguments(), + newLoop.getAfterArguments().take_front( + loop.getAfterArguments().size()))) + rewriter.replaceAllUsesWith(oldArg, newArg); + + // Stack the new results + for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( + loop.getNumResults()))) + replacements.push_back(it); + + return newLoop; +} + +scf::WhileOp replaceWhileOpWithNewSignature(RewriterBase &rewriter, + scf::WhileOp loop, + ValueRange newIterOperands, + TypeRange newResultTypes) { + SmallVector> replacements; + auto newWhileOp = replaceWhileOpWithNewSignature( + rewriter, loop, newIterOperands, newResultTypes, replacements); + for (auto &kv : replacements) { + rewriter.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); + } + return newWhileOp; +} + scf::IfOp replaceIfOpWithNewSignature( RewriterBase &rewriter, scf::IfOp ifOp, TypeRange newResultTypes, SmallVectorImpl> &replacements) { @@ -627,6 +692,26 @@ scf::IfOp replaceIfOpWithNewSignature( return newIf; } +void appendToForOpYield(scf::ForOp forOp, ArrayRef newOperands) { + Operation *yieldOp = forOp.getBody()->getTerminator(); + SmallVector operands(yieldOp->getOperands()); + operands.append(newOperands.begin(), newOperands.end()); + + OpBuilder builder(yieldOp); + builder.create(yieldOp->getLoc(), operands); + yieldOp->erase(); +} + +scf::IfOp replaceIfOpWithNewSignature(RewriterBase &rewriter, scf::IfOp ifOp, + TypeRange newResultTypes) { + SmallVector> replacements; + auto newIfOp = + replaceIfOpWithNewSignature(rewriter, ifOp, newResultTypes, replacements); + for (auto &kv : replacements) + rewriter.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); + return newIfOp; +} + Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, IRMapping &mapping) { Operation *newOp = rewriter.clone(*op, mapping); @@ -663,12 +748,13 @@ Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, return newOp; } -// Check if the convert will be a no-op in codegen. +// Check if the convert will be performed by reordering registers. static bool isFreeConvert(Operation *op) { auto convertOp = dyn_cast(op); if (!convertOp) return false; - return isMmaToMmaShortcut(convertOp.getSrc().getType(), convertOp.getType()); + return cvtReordersRegisters(convertOp.getSrc().getType(), + convertOp.getType()); } LogicalResult @@ -723,8 +809,11 @@ getConvertBackwardSlice(Value root, SetVector &slice, continue; enqueue(result, encoding); } - if (!isFreeConvert(definingOp) && - canFoldIntoConversion(definingOp, encoding)) + if (isFreeConvert(definingOp)) { + enqueue(definingOp->getOperand(0), encoding); + continue; + } + if (canFoldIntoConversion(definingOp, encoding)) continue; if (stopPropagation && stopPropagation(definingOp)) continue; @@ -896,6 +985,8 @@ struct ForOpDeadArgElimination : public OpRewritePattern { } if (auto nestedIf = value.getDefiningOp()) { auto result = mlir::cast(value); + // mark condition as live. + markLive(nestedIf.getCondition()); for (scf::YieldOp nestedYieldOp : {nestedIf.thenYield(), nestedIf.elseYield()}) { Value nestedYieldOperand = diff --git a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp new file mode 100644 index 000000000..44809814d --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp @@ -0,0 +1,2274 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include +#include + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUWSCODEPARTITION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "tritongpu-warp-spec-code-partition" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +std::pair scanRegUsage(Block *block, AsyncTaskId asyncTaskId, + int regDecProducer, int regIncConsumer) { + // TODO: scan ops to estimate register usage + if (asyncTaskId == 0) { + // deallocate registers + return {regDecProducer == 0 ? 40 : regDecProducer, false}; + } else { + // allocate registers + return {regIncConsumer == 0 ? 232 : regIncConsumer, true}; + } +} + +unsigned getNumBuffersOrDefault(scf::ForOp forOp, unsigned numBuffers) { + // Use the attribute attached to the loop if it exists otherwise use the + // global control. + if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName)) + return numBuffers; + return mlir::cast( + forOp->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); +} + +// Collect argument indices that are used by the specific taskId. +static SmallVector collectBlockArgsForTask(scf::ForOp forOp, + int asyncTaskId) { + + // Collect argument indices that can be reached along the definition chain. + SetVector argIndices; + std::function dfs = [&](Value arg, unsigned argIdx) { + for (auto user : arg.getUsers()) { + // Skip ops that are not in the same async task + if (!hasAsyncTaskId(user, asyncTaskId)) + continue; + + // Skip control flow ops that are shared by all async tasks + if (isa(user)) + continue; + + // Found a real user, the arg is needed + if (user->getNumRegions() == 0) { + argIndices.insert(argIdx); + return; + } + + // Iterate through all regions of the user operation + for (auto ®ion : user->getRegions()) { + for (auto regionArg : region.getArguments()) { + if (arg == regionArg) + dfs(regionArg, argIdx); + } + } + } + }; + + // check dependency with DFS traversal for loop args and results. + mlir::Block &block = forOp.getRegion().front(); + for (unsigned i = forOp.getNumInductionVars(); i < block.getNumArguments(); + ++i) { + auto arg = block.getArgument(i); + dfs(arg, i - forOp.getNumInductionVars()); + } + for (unsigned i = 0; i < forOp.getNumResults(); ++i) { + auto result = forOp->getResult(i); + dfs(result, i); + } + + SmallVector args(argIndices.begin(), argIndices.end()); + llvm::sort(args); + return args; +} + +Operation *SpecializeOp(Operation *op, IRMapping &mapping, + OpBuilderWithAsyncTaskIds &builder, + AsyncTaskId asyncTaskId); + +// Return the argument that tracks accumLoopCount if there is an outer +// ForOp. +Value getAccumLoopCountArg(scf::ForOp parentForOp) { + assert(parentForOp); + auto tSize = parentForOp.getBody()->getArguments().size(); + assert(tSize >= 3); // accum, bufferIdx, phase + Value tmpAccumLoopCount = parentForOp.getBody()->getArgument(tSize - 3); + return tmpAccumLoopCount; +} + +// Return true if the IfOp contains a ForOp that is in loopWithBufferReuse. +static bool +needAccumulatedLoopCnt(scf::IfOp ifOp, + SmallVector &loopWithBufferReuse) { + bool needAccum = false; + ifOp.walk([&](Operation *subOp) { + if (auto forOp = dyn_cast(subOp)) + for (auto tLoop : loopWithBufferReuse) + if (forOp.getOperation() == tLoop) { + needAccum = true; + break; + } + }); + return needAccum; +} + +Value updateAccumLoopCount(SmallVector &opList, + unsigned numBuffers, + SmallVector &taskTopOps, + Operation *commonOuterLoop, + SmallVector &loopWithBufferReuse, + Value prevAccum); + +scf::ForOp createNewLoopWrapper(scf::ForOp origForOp, unsigned numBuffers, + SmallVector &taskTopOps, + Operation *commonOuterLoop, + SmallVector &loopWithBufferReuse, + Value prevAccum); + +// For certain cases, we need to add an additional output for +// IfOp to track the accumulatedLoopCount, we may need to add +// a corresponding elseBlock with yieldOp. +scf::IfOp rewriteIfOp(scf::IfOp ifOp, unsigned numBuffers, + SmallVector &taskTopOps, + Operation *commonOuterLoop, + SmallVector &loopWithBufferReuse, + Value prevAccum) { + LLVM_DEBUG({ + LDBG("rewrite ifOp for smem sharing "); + ifOp.dump(); + }); + + OpBuilderWithAsyncTaskIds ifBuilder(ifOp.getContext()); + ifBuilder.setAsynTaskIdsFromArray(getNestedAsyncTaskIds(ifOp)); + ifBuilder.setInsertionPoint(ifOp); + + SmallVector newResultTypes(ifOp->getResultTypes()); + // Add an output for the IfOp for accumulated loop count. + newResultTypes.push_back(ifBuilder.getI64Type()); + // Create else block if we need to generate accumulated loop count. + auto newIfOp = ifBuilder.createWithAsyncTaskIds( + ifOp.getLoc(), newResultTypes, ifOp.getCondition(), true, true); + + // Move the existing blocks to the new if. + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + + ifBuilder.setInsertionPointToEnd(newIfOp.thenBlock()); + SmallVector opList; + for (Operation &op : newIfOp.thenBlock()->getOperations()) { + if (auto tOp = dyn_cast(&op)) + opList.push_back(&op); + if (auto tOp = dyn_cast(&op)) + opList.push_back(&op); + } + + // Update yields + auto loc = ifOp.getLoc(); + auto updateYield = [&](scf::YieldOp yield, SmallVector &operands) { + ifBuilder.setInsertionPoint(yield); + ifBuilder.createWithAsyncTaskIds(loc, operands); + yield.erase(); + }; + + // Add one more operand to then Yield. + Value endAccum = + updateAccumLoopCount(opList, numBuffers, taskTopOps, commonOuterLoop, + loopWithBufferReuse, prevAccum); + + SmallVector ifYieldOperands = newIfOp.thenYield().getOperands(); + ifYieldOperands.push_back(endAccum); + updateYield(newIfOp.thenYield(), ifYieldOperands); + + // Handle elseRegion of the IfOp. + if (ifOp.elseBlock()) { + ifBuilder.setInsertionPointToEnd(newIfOp.elseBlock()); + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + opList.clear(); + for (Operation &op : newIfOp.elseBlock()->getOperations()) { + if (auto tOp = dyn_cast(&op)) + opList.push_back(&op); + if (auto tOp = dyn_cast(&op)) + opList.push_back(&op); + } + endAccum = + updateAccumLoopCount(opList, numBuffers, taskTopOps, commonOuterLoop, + loopWithBufferReuse, prevAccum); + } else { + // Create an empty yield + auto yieldOp = + newIfOp.getElseBodyBuilder().create(ifOp.getLoc()); + endAccum = prevAccum; + } + // Add one more operand to else Yield. + SmallVector elseYieldOperands = newIfOp.elseYield().getOperands(); + elseYieldOperands.push_back(endAccum); + updateYield(newIfOp.elseYield(), elseYieldOperands); + int resultIdx = 0; + // Replace old if with the new one. + for (auto result : ifOp.getResults()) { + result.replaceAllUsesWith(newIfOp->getResult(resultIdx++)); + } + ifOp.erase(); + return newIfOp; +} + +Operation *SpecializeIfOp(scf::IfOp ifOp, IRMapping &mapping, + OpBuilderWithAsyncTaskIds &builder, + AsyncTaskId asyncTaskId) { + LLVM_DEBUG({ + LDBG("specialize ifOp "); + ifOp.dump(); + }); + + // It is possible that we need to reduce the results. One example + // is that the defining op for the yield operation is not for this + // taskId and the defining op is not specialized, thus we should + // remove the result. + // We need to update the result types correctly here. + unsigned resultIdx = 0; + SmallVector keptResultVec; + if (!ifOp->getResultTypes().empty()) { + for (Value yieldV : ifOp.thenYield().getOperands()) { + // Check the defining op for the corresponding result. + if (Operation *def = yieldV.getDefiningOp()) { + bool hasTaskId = hasAsyncTaskId(def, asyncTaskId); + if (hasTaskId) { + keptResultVec.push_back(resultIdx); + } + } else { + keptResultVec.push_back(resultIdx); + } + ++resultIdx; + } + } + + SmallVector newResultTypes; + for (auto idx : keptResultVec) { + newResultTypes.push_back(ifOp->getResultTypes()[idx]); + } + auto newIfOp = builder.createWithAsyncTaskIds( + ifOp.getLoc(), newResultTypes, mapping.lookup(ifOp.getCondition()), true, + ifOp.elseBlock()); + + OpBuilderWithAsyncTaskIds ifBuilder(ifOp.getContext()); + ifBuilder.setAsynTaskIdsFromArray({asyncTaskId}); + + // Handle thenRegion of this IfOp. + ifBuilder.setInsertionPointToEnd(newIfOp.thenBlock()); + for (Operation &thenOp : ifOp.thenBlock()->getOperations()) { + SpecializeOp(&thenOp, mapping, ifBuilder, asyncTaskId); + } + + // Update yields + auto updateYield = [&](scf::YieldOp yield, SmallVector &operands) { + ifBuilder.setInsertionPoint(yield); + ifBuilder.createWithAsyncTaskIds(yield.getLoc(), operands); + yield.erase(); + }; + if (keptResultVec.size() < ifOp->getResultTypes().size()) { + SmallVector ifYieldOperands; + for (auto idx : keptResultVec) { + ifYieldOperands.push_back(newIfOp.thenYield().getOperand(idx)); + } + updateYield(newIfOp.thenYield(), ifYieldOperands); + } + + // Handle elseRegion of the IfOp. + if (ifOp.elseBlock()) { + ifBuilder.setInsertionPointToEnd(newIfOp.elseBlock()); + for (Operation &elseOp : ifOp.elseBlock()->getOperations()) { + SpecializeOp(&elseOp, mapping, ifBuilder, asyncTaskId); + } + if (keptResultVec.size() < ifOp->getResultTypes().size()) { + SmallVector elseYieldOperands; + for (auto idx : keptResultVec) { + elseYieldOperands.push_back(newIfOp.elseYield().getOperand(idx)); + } + updateYield(newIfOp.elseYield(), elseYieldOperands); + } + } + + unsigned newResIdx = 0; + for (auto idx : keptResultVec) { + mapping.map(ifOp.getResult(idx), newIfOp.getResult(newResIdx)); + ++newResIdx; + } + return newIfOp; +} + +Operation *SpecializeForOp(scf::ForOp forOp, IRMapping &mapping, + OpBuilderWithAsyncTaskIds &builder, + AsyncTaskId asyncTaskId) { + // Create newForOp for each task Id. + auto usedArgs = collectBlockArgsForTask(forOp, asyncTaskId); + + // Prepare newLoopArgs. + SmallVector newLoopArgs; + for (unsigned argNumber : usedArgs) { + auto arg = forOp.getInitArgs()[argNumber]; + auto newArg = mapping.lookupOrDefault(arg); + assert(newArg && "Unexpected missing mapping"); + newLoopArgs.push_back(newArg); + } + + // Prepare loop bounds. + auto newLowerBound = mapping.lookupOrDefault(forOp.getLowerBound()); + auto newUpperBound = mapping.lookupOrDefault(forOp.getUpperBound()); + auto newStep = mapping.lookupOrDefault(forOp.getStep()); + + // Create newForOp. + auto newForOp = builder.createWithAsyncTaskIds( + forOp.getLoc(), newLowerBound, newUpperBound, newStep, newLoopArgs); + if (forOp->getAttr("tt.loop_schedule")) + newForOp->setAttr("tt.loop_schedule", forOp->getAttr("tt.loop_schedule")); + + // Initialize Value mapping from forOp to newForOp + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (unsigned i = 0; i < usedArgs.size(); ++i) { + auto oldArg = forOp.getRegionIterArgs()[usedArgs[i]]; + auto newArg = newForOp.getRegionIterArgs()[i]; + mapping.map(oldArg, newArg); + } + + // Recursively clone all operations with this asyncTaskId to newForOp. + OpBuilderWithAsyncTaskIds forBuilder(forOp.getContext()); + forBuilder.setAsynTaskIdsFromArray({asyncTaskId}); + forBuilder.setInsertionPointToStart(newForOp.getBody()); + for (Operation &op : forOp.getBody()->without_terminator()) { + SpecializeOp(&op, mapping, forBuilder, asyncTaskId); + } + + // Create YieldOp for newForOp. + auto yieldOp = llvm::cast(forOp.getBody()->getTerminator()); + SmallVector newYieldOperands; + for (unsigned i : usedArgs) + newYieldOperands.push_back(mapping.lookup(yieldOp.getOperand(i))); + + bool createNewYield = true; + if (newForOp.getBody()->mightHaveTerminator()) { + auto initialYield = + llvm::cast(newForOp.getBody()->getTerminator()); + if (newYieldOperands.size() == 0) { + setAsyncTaskIds(initialYield, {asyncTaskId}); + createNewYield = false; + } + } + if (createNewYield) { + auto newYieldOp = + forBuilder.create(yieldOp.getLoc(), newYieldOperands); + setAsyncTaskIds(newYieldOp, {asyncTaskId}); + } + + // Replace results of forOp with results of newForOp. + for (unsigned i = 0; i < usedArgs.size(); ++i) { + auto oldResult = forOp.getResult(usedArgs[i]); + auto newResult = newForOp.getResult(i); + mapping.map(oldResult, newResult); + } + + return newForOp; +} + +Operation *SpecializeOp(Operation *op, IRMapping &mapping, + OpBuilderWithAsyncTaskIds &builder, + AsyncTaskId asyncTaskId) { + auto taskIds = getAsyncTaskIds(op); + // yieldOp are sometimes implict, meaning they do not necessarily have a task + // id, but they should be shared by all async tasks. + if (!hasAsyncTaskId(op, asyncTaskId) && !isa(op)) + return nullptr; + + if (op->getNumRegions() == 0) { + Operation *newOp = builder.clone(*op, mapping); + setAsyncTaskIds(newOp, asyncTaskId); + for (unsigned i = 0; i < op->getNumResults(); ++i) + mapping.map(op->getResult(i), newOp->getResult(i)); + return newOp; + } else { + if (auto ifOp = dyn_cast(op)) { + return SpecializeIfOp(ifOp, mapping, builder, asyncTaskId); + } else if (auto forOp = dyn_cast(op)) { + return SpecializeForOp(forOp, mapping, builder, asyncTaskId); + } else if (auto reduceOp = dyn_cast(op)) { + Operation *newOp = builder.clone(*op, mapping); + // recursively set async task ids for child ops + newOp->walk( + [&](Operation *childOp) { setAsyncTaskIds(childOp, asyncTaskId); }); + for (unsigned i = 0; i < op->getNumResults(); ++i) + mapping.map(op->getResult(i), newOp->getResult(i)); + return newOp; + } else { + llvm_unreachable("Unexpected Op with regions"); + } + } + + return nullptr; +} + +// Create IfOp for each ayncTaskId. +DenseMap SpecializeRegion(triton::FuncOp funcOp, + int regDecProducer, + int regIncConsumer) { + + LLVM_DEBUG({ + LDBG("\n\n"); + LDBG("Start specializing region"); + }); + + MLIRContext *context = funcOp.getContext(); + OpBuilder builder(context); + auto loc = funcOp.getLoc(); + + // Collect original operations + SmallVector opList; + for (auto &block : funcOp.getBody().getBlocks()) { + for (Operation &op : block.getOperations()) { + auto taskIds = getAsyncTaskIds(&op); + if (!taskIds.empty()) + opList.push_back(&op); + } + } + + LLVM_DEBUG({ + LDBG("ops to be specialized: "); + for (Operation *op : opList) { + op->dump(); + } + }); + + // Create GetAsyncTaskIdOp. + Block *lastBlock = &funcOp.getBody().back(); + auto returnOp = llvm::cast(lastBlock->getTerminator()); + builder.setInsertionPoint(returnOp); + Value curAsyncTaskId = builder.create(loc); + + DenseMap tasksToIfOp; + + // Clone all operations into the corresponding if blocks. If the operation + // has multiple taskIds, it will be cloned for multiple if blocks. + // If the original code has an IfOp, we should only clone its + // body with the right asyncTaskId, instead of cloning the IfOp. + for (AsyncTaskId asyncTaskId : getNestedAsyncTaskIds(funcOp)) { + // Create IfOp for each asyncTaskId. + Value cond = builder.create( + loc, arith::CmpIPredicate::eq, curAsyncTaskId, + builder.create(loc, asyncTaskId, 32)); + + auto ifOp = builder.create(loc, cond); + tasksToIfOp[asyncTaskId] = ifOp; + setAsyncTaskIds(ifOp, {asyncTaskId}); + + OpBuilderWithAsyncTaskIds taskBuilder(context); + taskBuilder.setAsynTaskIdsFromArray({asyncTaskId}); + + // Set insertion point before yieldOp. + auto yieldOp = ifOp.thenYield(); + setAsyncTaskIds(yieldOp, {asyncTaskId}); + taskBuilder.setInsertionPoint(yieldOp); + + IRMapping mapping; + for (Operation *op : opList) { + SpecializeOp(op, mapping, taskBuilder, asyncTaskId); + } + } + + // Decide if this taskId is a producer or a consumer, and create either + // RegAllocOp or RegDeallocOp accordingly. + for (auto ifOps : tasksToIfOp) { + AsyncTaskId asyncTaskId = ifOps.first; + auto ifOp = ifOps.second; + OpBuilderWithAsyncTaskIds taskBuilder(ifOp.getContext()); + taskBuilder.setAsynTaskIdsFromArray({asyncTaskId}); + auto regAlloc = scanRegUsage(ifOp.thenBlock(), asyncTaskId, regDecProducer, + regIncConsumer); + taskBuilder.setInsertionPointToStart(&(ifOp.getThenRegion().front())); + if (regAlloc.second) + taskBuilder.create( + loc, taskBuilder.getI32IntegerAttr(regAlloc.first)); + else + taskBuilder.create( + loc, taskBuilder.getI32IntegerAttr(regAlloc.first)); + } + + LLVM_DEBUG({ + LDBG("\n\nWith task Id checks"); + funcOp.dump(); + }); + + // Remove original operations that have been cloned in reverse order. + for (auto it = opList.rbegin(); it != opList.rend(); ++it) { + Operation *op = *it; + LLVM_DEBUG({ + LDBG("erasing op "); + op->dump(); + }); + // For debugging purposes, check to see if the original op is still in use. + bool hasUse = false; + for (unsigned i = 0; i < op->getNumResults(); ++i) { + for (Operation *user : op->getResult(i).getUsers()) { + hasUse = true; + LLVM_DEBUG({ + LDBG("op has use "); + user->dump(); + }); + } + } + op->erase(); + } + return tasksToIfOp; +} + +struct Channel { +public: + using Relation = std::pair>; + + Channel(int producer, SmallVector &consumers, Operation *op, + unsigned operandIdx, unsigned numBuffers) + : relation(producer, consumers), op(op), operandIdx(operandIdx), + numBuffers(numBuffers) {} + + bool operator==(const Channel &c) { + return relation == c.relation && operandIdx == c.operandIdx && op == c.op; + } + + Operation *getDstOp() { return op; } + unsigned getDstOperandIdx() { return operandIdx; } + Value getSrcOperand() { return op->getOperand(operandIdx); } + Operation *getSrcOp() { return getSrcOperand().getDefiningOp(); } + + Relation relation; // producer task Id, a list of consumer task Ids + Operation *op; + unsigned operandIdx; + unsigned numBuffers; +}; + +// Find transitive users of the root op. Track through control flow ops (such as +// yield) to get to the real users. +void getTransitiveUsers(Value root, + SetVector> &users) { + for (Operation *userOp : root.getUsers()) { + if (auto yieldOp = dyn_cast(userOp)) { + for (OpOperand &operand : yieldOp->getOpOperands()) { + if (operand.get() == root) { + auto result = + yieldOp->getParentOp()->getResult(operand.getOperandNumber()); + getTransitiveUsers(result, users); + } + } + } else { + // find operand index of root + unsigned operandIndex = 0; + for (OpOperand &operand : userOp->getOpOperands()) { + if (operand.get() == root) { + break; + } + operandIndex++; + } + assert(operandIndex < userOp->getNumOperands() && + "root is not an operand of userOp"); + users.insert({userOp, operandIndex}); + } + } +} + +// Loads will be in producer warp groups. For now, we only allow a single +// warp group/task for a producer. For each LoadOp, create a channel from it +// to any direct user which belongs to a different taskId. +void collectAsyncChannels(SmallVector> &channels, + triton::FuncOp &funcOp, unsigned numBuffers) { + funcOp.walk([&](Operation *op) { + if (isa(op) || + op->hasTrait()) { + auto producerTaskIds = getAsyncTaskIds(op); + if (producerTaskIds.empty() || producerTaskIds.size() > 1) { + LLVM_DEBUG({ + LDBG(" ignoring load ops without async task id or with multiple task " + "ids: "); + op->dump(); + }); + return; + } + auto producerTaskId = producerTaskIds.front(); + unsigned producerNumBuffers = numBuffers; + if (auto forOp = op->getParentOfType()) { + producerNumBuffers = getNumBuffersOrDefault(forOp, numBuffers); + } + + for (auto result : op->getResults()) { + if (result.use_empty()) { + continue; + } + + SetVector> users; + getTransitiveUsers(result, users); + for (auto user : users) { + auto userOp = user.first; + auto consumerTaskIds = getAsyncTaskIds(userOp); + if (consumerTaskIds.empty()) + continue; + // Remove producer task id from consumerTaskIds. + auto iter = std::remove(consumerTaskIds.begin(), + consumerTaskIds.end(), producerTaskId); + consumerTaskIds.erase(iter, consumerTaskIds.end()); + // Add a channel from the single producer task to consumerTaskIds. + if (consumerTaskIds.size() > 0) { + channels.push_back(std::make_unique( + producerTaskId, consumerTaskIds, userOp, user.second, + producerNumBuffers)); + } + } + } + } + }); + + LLVM_DEBUG({ + LDBG("Async channels:"); + for (auto &channel : channels) { + LDBG("producer op: " << channel->relation.first); + channel->getSrcOp()->dump(); + for (auto &asyncTaskId : channel->relation.second) + LDBG("consumer: " << asyncTaskId); + channel->getDstOp()->dump(); + LDBG("numBuffers: " << channel->numBuffers); + } + }); +} + +// Group channels in two ways: +// - by producer ops. One producer corresponds to multiple channels. This +// grouping will be used to create buffers per shared producer. +// - by consumer ops. One consumer corresponds to multiple channels. This +// grouping will be used to create barriers per shared consumer. +// Also compute orderedChannels, which will be keyed by getDstOp() of channels, +// to enforce deterministic order for map. +void groupChannels( + SmallVector &channels, + DenseMap> &channelsGroupedByProducers, + DenseMap> &channelsGroupedByConsumers, + SmallVector &orderedChannels) { + + // Group channels by producer op. + DenseMap> producerChannels; + for (auto channel : channels) { + producerChannels[channel->getSrcOp()].push_back(channel); + } + +#ifndef NDEBUG + // Some sanity checks. + for (auto &item : producerChannels) { + auto &channels = item.second; + unsigned numBuffers = channels.front()->numBuffers; + for (auto c : channels) { + assert(c->numBuffers == numBuffers && "Unmatched number of buffers"); + } + } +#endif + + // Group channels by consumer op. + DenseMap> consumerChannels; + + // Two channels can be combined if + // src1 and src2 are in the same block and + // (dst1 == dst2 or + // (dst1 and dst2 are in the same block, both have a single user, and + // dst1User == dst2User and dst1User is in the same block as dst1)) + auto channelCanBeMerged = [](Channel *c1, Channel *c2) -> bool { + if (c1->getSrcOp()->getBlock() != c2->getSrcOp()->getBlock()) + return false; + Operation *dst1 = c1->getDstOp(), *dst2 = c2->getDstOp(); + if (dst1 == dst2) + return true; + if (dst1->getBlock() != dst2->getBlock() || !dst1->hasOneUse() || + !dst2->hasOneUse()) + return false; + Operation *dst1User = *(dst1->getUsers().begin()); + Operation *dst2User = *(dst2->getUsers().begin()); + return dst1User == dst2User && dst1User->getBlock() == dst1->getBlock(); + }; + assert(channels.size() > 0 && "channel size is zero"); + // Compare with existing channels in the consumerChannels to see if + // it can be combined. + for (auto *c0 : channels) { + bool merged = false; + for (auto &kv : consumerChannels) { + if (kv.second.size() > 0 && channelCanBeMerged(c0, kv.second.front())) { + kv.second.push_back(c0); + merged = true; + break; + } + } + if (!merged) { // Create a new entry. + auto *keyOp = c0->getDstOp(); + if (!consumerChannels.count(keyOp)) + orderedChannels.push_back(c0); + consumerChannels[keyOp].push_back(c0); + } + } + + // Reorder channels associated with one entry based on program order of the + // producers. + for (auto &kv : consumerChannels) { + if (kv.second.size() > 1) { + auto &allOps = kv.second.front()->getSrcOp()->getBlock()->getOperations(); + std::sort( + kv.second.begin(), kv.second.end(), [&](Channel *a, Channel *b) { + auto itrA = + std::find_if(allOps.begin(), allOps.end(), [&](Operation &op) { + Operation *opPointer = &op; + return opPointer == a->getSrcOp(); + }); + auto itrB = + std::find_if(allOps.begin(), allOps.end(), [&](Operation &op) { + Operation *opPointer = &op; + return opPointer == b->getSrcOp(); + }); + assert(itrA != allOps.end() && itrB != allOps.end()); + return std::distance(itrA, itrB) < 0; + }); + } + } + + // Switch to using channel as the key instead of ops as ops can be volatile. + for (auto &kv : producerChannels) { + channelsGroupedByProducers[kv.second.front()] = kv.second; + } + for (auto &kv : consumerChannels) { + channelsGroupedByConsumers[kv.second.front()] = kv.second; + } + + LLVM_DEBUG({ + DBGS() << "\n\n"; + LDBG("Grouped channels by producer:"); + unsigned i = 0; + for (auto &kv : channelsGroupedByProducers) { + DBGS() << "Channel " << ++i << ":\n"; + DBGS() << "producer: "; + kv.getFirst()->getSrcOp()->dump(); + for (auto &channel : kv.second) { + DBGS() << "consumer: "; + channel->getDstOp()->dump(); + DBGS() << "] "; + LDBG("numBuffers: " << channel->numBuffers); + DBGS() << "\n"; + } + } + + DBGS() << "\n\n"; + LDBG("Grouped channels by consumer:"); + i = 0; + for (auto &kv : channelsGroupedByConsumers) { + DBGS() << "Channel " << ++i << ":\n"; + DBGS() << "consumer: "; + kv.getFirst()->getDstOp()->dump(); + for (auto &channel : kv.second) { + DBGS() << "producer: "; + channel->getSrcOp()->dump(); + for (auto &asyncTaskId : channel->relation.second) + DBGS() << asyncTaskId << ", "; + DBGS() << "] "; + LDBG("numBuffers: " << channel->numBuffers); + DBGS() << "\n"; + } + DBGS() << "\n"; + } + }); +} + +// Reorder producer ops to unblock consumers interleavingly. +void reorderProducerOps(SmallVector &channels) { + if (channels.size() <= 1) + return; + + // Bail out if channels are not in the same block + auto block = channels.front()->getSrcOp()->getBlock(); + for (auto &channel : channels) { + if (channel->getSrcOp()->getBlock() != block) { + return; + } + } + + // Group channels by the first consumer taskId of each channel. Smaller taskId + // has higher priority. + // TODO: consider consumer priority + std::map> groupedProducerOps; + for (auto &channel : channels) { + auto asyncTaskId = channel->relation.second.front(); + groupedProducerOps[asyncTaskId].push_back(channel); + } + + // No need to reorder if all channels are in the same group. + if (groupedProducerOps.size() <= 1) + return; + + // Sort each group by number of consumers. + for (auto &group : groupedProducerOps) { + std::sort(group.second.begin(), group.second.end(), + [&](Channel *a, Channel *b) { + return a->relation.second.size() < b->relation.second.size(); + }); + } + + // Start from the first producer in channels. Iterate through the groups + // which are ordered by the first consumer taskId. Within each group, channels + // are ordered by number of consumers. + Operation *currOp = channels.front()->getSrcOp(); + for (auto &group : groupedProducerOps) { + for (auto &channel : group.second) { + channel->getSrcOp()->moveAfter(currOp); + currOp = channel->getSrcOp(); + } + } + + // Move backward dependency slice close to producer ops. + // Start from the last producer op backwards and move backward slice to + // before each op. This guarantees that the backward slice of each op is + // scheduled as late as possible. + for (auto &group : reverse(groupedProducerOps)) { + for (auto &channel : reverse(group.second)) { + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + SetVector backwardSlice; + getBackwardSlice(channel->getSrcOp(), &backwardSlice, opt); + for (auto &op : backwardSlice) { + if (op->getBlock() == block) + op->moveBefore(channel->getSrcOp()); + } + } + } + + LLVM_DEBUG({ + LDBG("\n"); + LDBG("after reordering producer ops"); + currOp->getParentOfType().dump(); + LDBG("\n"); + }); +} + +bool isInnermostLoop(scf::ForOp forOp) { + bool isInner = true; + forOp.walk([&](Operation *subOp) { + if (subOp != forOp.getOperation()) + if (auto forOp = dyn_cast(subOp)) + isInner = false; + }); + return isInner; +} + +// Generate code +// numSteps = ((upperBound - lowerBound) + forOpStep - 1) / forOpStep +Value getNumSteps(scf::ForOp forOp, OpBuilderWithAsyncTaskIds &builder) { + auto loc = forOp.getLoc(); + // numSteps = ((upperBound - lowerBound) + forOpStep - 1) / forOpStep + Value numSteps = builder.createWithAsyncTaskIds( + loc, forOp.getUpperBound(), forOp.getLowerBound()); + numSteps = builder.createWithAsyncTaskIds(loc, numSteps, + forOp.getStep()); + if (forOp.getStep().getType() != builder.getI64Type()) + numSteps = builder.createWithAsyncTaskIds( + loc, builder.getI64Type(), numSteps); + + Value one = builder.createWithAsyncTaskIds(loc, 1, 64); + numSteps = builder.createWithAsyncTaskIds(loc, numSteps, one); + Value innerForStep = forOp.getStep(); + if (forOp.getStep().getType() != builder.getI64Type()) + innerForStep = builder.createWithAsyncTaskIds( + loc, builder.getI64Type(), forOp.getStep()); + numSteps = builder.createWithAsyncTaskIds(loc, numSteps, + innerForStep); + return numSteps; +} + +// Add phase and bufferIndex to be used when lowering the producer. +// When hasParallelReuse is true (i.e this is the innermost loop), we pass in +// accumulatedLoopCount, which is used to initialize initBufferIdx. +// When isOuterOfReuse is true, we add an additional arg for accumLoopCount. +scf::ForOp createNewLoop(scf::ForOp forOp, int numBuffers, + scf::ForOp &parentForOp, Value accumulatedLoopCount, + SmallVector &loopWithBufferReuse, + bool hasParallelReuse, bool isOuterOfReuse) { + auto loc = forOp.getLoc(); + Block *body = forOp.getBody(); + + OpBuilderWithAsyncTaskIds builder(forOp.getContext()); + builder.setAsynTaskIdsFromArray(getNestedAsyncTaskIds(forOp)); + builder.setInsertionPoint(forOp); + if (hasParallelReuse) { + LLVM_DEBUG({ + LDBG("createNewLoop hasParallelReuse: "); + accumulatedLoopCount.dump(); + }); + } + + Value numBuffersVal = + builder.createWithAsyncTaskIds(loc, numBuffers, 32); + + // Step 1: Append bufferIdx and phase as forOp arguments. + Value tmpAccumLoopCount; + if (isOuterOfReuse) { + tmpAccumLoopCount = body->insertArgument(body->getNumArguments(), + builder.getI64Type(), loc); + } + Value phase = + body->insertArgument(body->getNumArguments(), builder.getI1Type(), loc); + Value bufferIdx = + body->insertArgument(body->getNumArguments(), builder.getI32Type(), loc); + + // Step 2: Generate bufferIdx and phase for next iteration: + // nextBufferIdx = bufferIdx + 1 + // nextPhase = ((nextBufferIdx < numBuffers && curPhase) || + // (nextBufferIdx >= numBuffers && curPhase^1)) + // nextBufferIdx = nextBufferIdx >= numBuffers ? 0 : nextBufferIdx + auto yieldOp = llvm::cast(body->getTerminator()); + builder.setInsertionPoint(yieldOp); + Value one = builder.createWithAsyncTaskIds(loc, 1, 32); + Value _1_1b = builder.createWithAsyncTaskIds(loc, 1, 1); + // nextBufferIdx = bufferIdx + 1 + Value nextBufferIdx = + builder.createWithAsyncTaskIds(loc, bufferIdx, one); + Value bufferGECond = builder.createWithAsyncTaskIds( + loc, arith::CmpIPredicate::uge, nextBufferIdx, numBuffersVal); + Value bufferLTCond = builder.createWithAsyncTaskIds( + loc, arith::CmpIPredicate::ult, nextBufferIdx, numBuffersVal); + // nextBufferIdx >= numBuffers ? nextBufferIdx - numBuffers : nextBufferIdx + Value moduloBufferIdx = builder.createWithAsyncTaskIds( + loc, nextBufferIdx, numBuffersVal); + nextBufferIdx = builder.createWithAsyncTaskIds( + loc, bufferGECond, moduloBufferIdx, nextBufferIdx); + + // nextPhase = ((nextBufferIdx < numBuffers && curPhase) || + // (nextBufferIdx >= numBuffers && curPhase^1)) + Value flipPhase = + builder.createWithAsyncTaskIds(loc, phase, _1_1b); + Value cond0 = builder.createWithAsyncTaskIds( + loc, bufferGECond, flipPhase); + Value cond1 = builder.createWithAsyncTaskIds( + loc, bufferLTCond, phase); + Value nextPhase = + builder.createWithAsyncTaskIds(loc, cond0, cond1); + + // Step 3: Add nextBufferIdx and nextPhase to yieldOp. + if (isOuterOfReuse) { + // We have not iterated through the body yet, so do not have the right value + // for nextTmpIdx. This will be fixed in the caller. + Value nextTmpIdx = tmpAccumLoopCount; + yieldOp->insertOperands(yieldOp.getNumOperands(), + {nextTmpIdx, nextPhase, nextBufferIdx}); + } else + yieldOp->insertOperands(yieldOp.getNumOperands(), + {nextPhase, nextBufferIdx}); + + // Step 4: Create loop arguments for the new ForOp. + SmallVector newLoopArgs; + for (auto operand : forOp.getInitArgs()) + newLoopArgs.push_back(operand); + + builder.setInsertionPoint(forOp); + Value initBufferIdx, initPhase; + // Set initial values for bufferIdx and phase. + if (parentForOp) { + if (hasParallelReuse) { + // Handling ForOp with an outer loop, use the passed-in value as initial + // value. + initBufferIdx = accumulatedLoopCount; + } else { + // It is possible that parent loop induction variable has different type. + // Here we promote to 64 bit. + // numSteps = ((upperBound - lowerBound) + forOpStep - 1) / forOpStep + Value numSteps = getNumSteps(forOp, builder); + + // TODO: use a global flattened iteration space index for multi-dim loops. + // initBufferIdx = (parentInductionVar - parentLowBound) / parentStep * + // numSteps + Value parentIterIdx = builder.createWithAsyncTaskIds( + loc, parentForOp.getInductionVar(), parentForOp.getLowerBound()); + parentIterIdx = builder.createWithAsyncTaskIds( + loc, parentIterIdx, parentForOp.getStep()); + if (parentForOp.getStep().getType() != builder.getI64Type()) + parentIterIdx = builder.createWithAsyncTaskIds( + loc, builder.getI64Type(), parentIterIdx); + initBufferIdx = builder.createWithAsyncTaskIds( + loc, parentIterIdx, numSteps); + } + + numBuffersVal = builder.createWithAsyncTaskIds( + loc, builder.getI64Type(), numBuffersVal); + // Calculate tmpIdx / numBuffers + // initBufferIdx = tmpIdx - tmpIdx / numBuffers * numBuffers + // initPhase = (tmpIdx / numBuffers) & 1 + Value bufferIdx = builder.createWithAsyncTaskIds( + loc, initBufferIdx, numBuffersVal); + initBufferIdx = builder.createWithAsyncTaskIds( + loc, initBufferIdx, + builder.createWithAsyncTaskIds(loc, bufferIdx, + numBuffersVal)); + initBufferIdx = builder.createWithAsyncTaskIds( + loc, builder.getI32Type(), initBufferIdx); + + Value one = + builder.createWithAsyncTaskIds(loc, 1, 64); + bufferIdx = + builder.createWithAsyncTaskIds(loc, bufferIdx, one); + initPhase = builder.createWithAsyncTaskIds( + loc, builder.getI1Type(), bufferIdx); + } else { + if (hasParallelReuse) { + // Handling ForOp without outer loop. + // tmpIdx = accumulatedLoopCount + initBufferIdx = accumulatedLoopCount; + numBuffersVal = builder.createWithAsyncTaskIds( + loc, builder.getI64Type(), numBuffersVal); + // bufferIdx = tmpIdx / numBuffers + Value bufferIdx = builder.createWithAsyncTaskIds( + loc, initBufferIdx, numBuffersVal); + // initBufferIdx = tmpIdx - tmpIdx/numBuffers * numBuffers (modulo) + initBufferIdx = builder.createWithAsyncTaskIds( + loc, initBufferIdx, + builder.createWithAsyncTaskIds(loc, bufferIdx, + numBuffersVal)); + initBufferIdx = builder.createWithAsyncTaskIds( + loc, builder.getI32Type(), initBufferIdx); + + Value one = + builder.createWithAsyncTaskIds(loc, 1, 64); + // initPhase = (tmpIdx / numBuffers) & 1 + bufferIdx = + builder.createWithAsyncTaskIds(loc, bufferIdx, one); + initPhase = builder.createWithAsyncTaskIds( + loc, builder.getI1Type(), bufferIdx); + } else { + // Set initial phase to false, and initial bufferIdx to 0. + initBufferIdx = + builder.createWithAsyncTaskIds(loc, 0, 32); + initPhase = + builder.createWithAsyncTaskIds(loc, 0, 1); + } + } + if (isOuterOfReuse) { + assert(!hasParallelReuse); + Value initTmpIdx = + builder.createWithAsyncTaskIds(loc, 0, 64); + newLoopArgs.append({initTmpIdx, initPhase, initBufferIdx}); + } else + newLoopArgs.append({initPhase, initBufferIdx}); + + // Step 5: Create newForOp and take the region of the original forOp. + auto newForOp = builder.createWithAsyncTaskIds( + loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), + newLoopArgs); + if (forOp->getAttr("tt.loop_schedule")) + newForOp->setAttr("tt.loop_schedule", forOp->getAttr("tt.loop_schedule")); + newForOp.getRegion().takeBody(forOp.getRegion()); + + // Step 6: Replace forOp with newForOp. + for (unsigned i = 0; i < forOp.getNumResults(); ++i) + forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i)); + forOp.erase(); + + return newForOp; +} + +// Find top-level ops which contain at least one channel. If a channel's +// getSrcOp() and getDstOp() belong to the inner loop, the outer loop will be +// part of asyncTaskOps. +SmallVector +getTaskTopRegion(triton::FuncOp funcOp, + const SmallVector &channels) { + SmallVector asyncTaskOps; + auto isAsyncTaskTopOp = [&](Operation *taskTopOp) -> bool { + for (auto c : channels) { + Operation *producer = c->getSrcOp(), *consumer = c->getDstOp(); + while (producer && !isa(producer->getParentOp())) { + producer = producer->getParentOp(); + } + while (consumer && !isa(consumer->getParentOp())) { + consumer = consumer->getParentOp(); + } + if (producer == taskTopOp && consumer == taskTopOp) + return true; + } + return false; + }; + for (auto &block : funcOp.getBody().getBlocks()) { + for (Operation &bodyOp : block.getOperations()) { + Operation *op = &bodyOp; + if (op->getNumRegions() <= 0) + continue; + // If this op does not contain both a producer taskId and a consumer + // taskId, continue. + if (getAsyncTaskIds(op).size() == 1) + continue; + if (isAsyncTaskTopOp(op)) + asyncTaskOps.push_back(op); + } + } + + LLVM_DEBUG({ + LDBG("\nTop Task Bodies"); + for (auto op : asyncTaskOps) { + LDBG("\nTask Body:"); + op->dump(); + } + }); + return asyncTaskOps; +} + +static unsigned getNumChannelsInLoop(scf::ForOp forOp, + const SmallVector &channels, + SmallVector &channelsInLoop) { + unsigned num = 0; + for (auto *ch : channels) { + scf::ForOp srcParent = ch->getSrcOp()->getParentOfType(); + scf::ForOp dstParent = ch->getSrcOp()->getParentOfType(); + if (srcParent == forOp && dstParent == forOp) { + channelsInLoop.push_back(ch); + } + } + return channelsInLoop.size(); +} + +bool reuseBuffers(SmallVector &taskTopOps, + const SmallVector &channels, + DenseMap &mapToRepresenting, + SmallVector &loopWithBufferReuse) { + // For the case of multiple parallel ForOps with same number of channels, + // we can try reusing the buffers across the parallel ForOps. + // One case is + // ForOp + // ForOp + // The other case is + // ForOp (persistent) + // ForOp + // ForOp + // For the case of + // ForOp (persistent) + // ForOp + // We update loopWithBufferReuse with the inner loop and use accumLoopCount + // via the outer loop to update bufferIdx. This is needed when the loop count + // for the inner loop varies with each outer loop iteration. + // Assume we handle outer ForOp first, then inner ForOp in program order. + SmallVector orderedForOps; + SmallVector innerForOps; + for (auto &op : taskTopOps) { + op->walk([&](Operation *subOp) { + if (auto forOp = dyn_cast(subOp)) { + if (isInnermostLoop(forOp)) + innerForOps.push_back(forOp.getOperation()); + orderedForOps.push_back(forOp); + } + }); + } + LDBG("reuseBuffers number of inner loops: " << innerForOps.size()); + if (innerForOps.empty()) + return false; + if (innerForOps.size() == 1) { + // Persistent with a single inner loop. + scf::ForOp parentForOp = innerForOps[0]->getParentOfType(); + if (parentForOp) { + loopWithBufferReuse = innerForOps; + LDBG("-- loopWithBufferReuse with size 1"); + } + return false; + } + // Check to see if the innermost loops are under one parent. And there are no + // other loops. Make sure the inner loops have same number of channels. + bool firstLoop = true; + Operation *outerLoop = nullptr; + unsigned numChannels = 0, numBuffers = 0; + SmallVector channelsInLoopOne; + for (auto &innerLoop : innerForOps) { + scf::ForOp parentForOp = innerLoop->getParentOfType(); + SmallVector channelsInLoop; + getNumChannelsInLoop(cast(innerLoop), channels, channelsInLoop); + if (firstLoop) { + outerLoop = parentForOp.getOperation(); + numChannels = channelsInLoop.size(); + channelsInLoopOne = channelsInLoop; + numBuffers = channelsInLoop[0]->numBuffers; + if (numChannels == 0) + return false; + } else { + if (outerLoop != parentForOp.getOperation()) + return false; + if (numChannels != channelsInLoop.size()) + return false; + if (numBuffers != channelsInLoop[0]->numBuffers) + return false; + unsigned idx = 0; + for (auto *ch : channelsInLoop) { + // TODO: sort the channels in the loop according to buffer size. + mapToRepresenting[ch] = channelsInLoopOne[idx++]; + } + } + firstLoop = false; + } + LLVM_DEBUG({ + LDBG("reuseBuffers: "); + for (auto &kv : mapToRepresenting) { + llvm::dbgs() << "---- from "; + kv.first->getDstOp()->dump(); + llvm::dbgs() << "---- to "; + kv.second->getDstOp()->dump(); + } + }); + loopWithBufferReuse = innerForOps; + return true; +} + +// Go through a list of operations under one scope. +// prevAccum can be null if there is an outer loop for the reuse loops. +Value updateAccumLoopCount(SmallVector &opList, + unsigned numBuffers, + SmallVector &taskTopOps, + Operation *commonOuterLoop, + SmallVector &loopWithBufferReuse, + Value prevAccum) { + for (Operation *op : opList) { + if (auto forOp = dyn_cast(op)) { + auto newForOp = + createNewLoopWrapper(forOp, numBuffers, taskTopOps, commonOuterLoop, + loopWithBufferReuse, prevAccum); + // Update prevAccum to be after the loop. + // If the loop is in loopWithBufferReuse, generate prevAccum + numSteps. + bool hasReuse = false; + for (auto tLoop : loopWithBufferReuse) + if (newForOp.getOperation() == tLoop) { + hasReuse = true; + break; + } + if (hasReuse) { + // Update accumLoopCount = prevAccum + numSteps. + OpBuilderWithAsyncTaskIds builder(newForOp.getContext()); + builder.setAsynTaskIdsFromArray(getNestedAsyncTaskIds(newForOp)); + builder.setInsertionPointAfter(newForOp); + + Value numSteps = getNumSteps(newForOp, builder); + prevAccum = builder.createWithAsyncTaskIds( + newForOp.getLoc(), prevAccum, numSteps); + } + // If the loop is the outer loop for a reuse loop, we are done. + // At this point, op is no longer valid. + } else if (auto ifOp = dyn_cast(op)) { + if (needAccumulatedLoopCnt(ifOp, loopWithBufferReuse)) { + auto newIfOp = + rewriteIfOp(ifOp, numBuffers, taskTopOps, commonOuterLoop, + loopWithBufferReuse, prevAccum); + // update prevAccum to be result of the new IfOp. + assert(newIfOp.getNumResults() >= 1); + auto numRes = newIfOp.getNumResults(); + LDBG("update prevAccum with result from IfOp"); + prevAccum = newIfOp.getResult(numRes - 1); // last result + } else { + // Still need to process ForOps in pre-order. + ifOp->walk([&](Operation *subOp) { + if (auto forOp = dyn_cast(subOp)) { + // Handle forOp. + createNewLoopWrapper(forOp, numBuffers, taskTopOps, commonOuterLoop, + loopWithBufferReuse, prevAccum); + } + }); + } + } + } + return prevAccum; +} + +scf::ForOp createNewLoopWrapper(scf::ForOp origForOp, unsigned numBuffers, + SmallVector &taskTopOps, + Operation *commonOuterLoop, + SmallVector &loopWithBufferReuse, + Value prevAccum) { + LLVM_DEBUG({ + LDBG("call createNewLoop on"); + origForOp.dump(); + }); + + scf::ForOp parentForOp = origForOp->getParentOfType(); + scf::ForOp newForOp; + // for(...) -> for(..., phase, bufferIdx) + unsigned loopNumBuffers = getNumBuffersOrDefault(origForOp, numBuffers); + + bool isOuterOfReuse = + commonOuterLoop && commonOuterLoop == origForOp.getOperation(); + bool hasReuse = false; + for (auto tLoop : loopWithBufferReuse) + if (origForOp.getOperation() == tLoop) { + hasReuse = true; + break; + } + // Set accumulatedLoopCount when this is a loop in loopWithBufferReuse. If + // this loop has an outer loop, an extra arg for accumLoopCount should have + // been added to the outer loop. + Value accumulatedLoopCount = prevAccum; // Value(); + newForOp = createNewLoop(origForOp, loopNumBuffers, parentForOp, + accumulatedLoopCount, loopWithBufferReuse, hasReuse, + isOuterOfReuse); + LLVM_DEBUG({ + LDBG("after createNewLoop "); + newForOp.dump(); + }); + // origForOp is erased in createNewLoop. If origForOp is a top operation + // (i.e in taskTopOps), make sure taskTopOps is updated with the newForOp. + auto asyncTaskLoopForItr = + std::find(taskTopOps.begin(), taskTopOps.end(), origForOp.getOperation()); + if (asyncTaskLoopForItr != taskTopOps.end()) { + // Update taskTopOps. + *asyncTaskLoopForItr = newForOp.getOperation(); + } + + // origForOp is erased in createNewLoop. If origForOp is in + // loopWithBufferReuse, replace. + auto tmpIter = std::find(loopWithBufferReuse.begin(), + loopWithBufferReuse.end(), origForOp.getOperation()); + if (tmpIter != loopWithBufferReuse.end()) { + *tmpIter = newForOp.getOperation(); + } + + // Handle ops in loop body, only IfOps and ForOps. + SmallVector opList; + for (Operation &op : newForOp.getBody()->without_terminator()) { + if (auto tOp = dyn_cast(&op)) + opList.push_back(&op); + if (auto tOp = dyn_cast(&op)) + opList.push_back(&op); + } + Value endAccum = updateAccumLoopCount( + opList, numBuffers, taskTopOps, commonOuterLoop, loopWithBufferReuse, + isOuterOfReuse ? getAccumLoopCountArg(newForOp) : prevAccum); + + // Update yieldOp. + if (isOuterOfReuse) { + Value arg = getAccumLoopCountArg(newForOp); + Operation *yieldOp = newForOp.getBody()->getTerminator(); + yieldOp->replaceUsesOfWith(arg, endAccum); + } + return newForOp; +} + +// This function takes a list of channels, a mapping from a channel +// to its representing channel if the key shares smem space with the +// representing channel, and a list of loops that are sharing smem spaces. Note +// that every loop in loopWithBufferReuse either has the same outer loop or has +// no outer loop. +// For ForOps in taskTopOps, create new ForOp for each by adding phase, +// bufferIdx to the arguments. In the case of sharing smem, we need to traverse +// and update IfOps when necessary. We call updateAccumLoopCount on the list +// of top level Ops that are ForOps or IfOps enclosing a loop with buffer reuse. +// updateAccumLoopCount calls createNewLoopWrapper on ForOps, and rewriteIfOp on +// IfOps. Both will call updateAccumLoopCount on the list of Ops in the ForOp +// body or the thenBlock, elseBlock for IfOp. +Value appendBufferIdxArgs( + SmallVector &taskTopOps, unsigned numBuffers, + const SmallVector &channels, + const DenseMap &mapToRepresenting, + SmallVector &loopWithBufferReuse) { + // In order to handle sharing smem for a list of loops, we have two cases, + // one is the top-level op containing all loops in loopWithBufferReuse is + // a ForOp. + bool genAccumLoopCount = !loopWithBufferReuse.empty(); + Operation *commonOuterLoop = nullptr; + if (genAccumLoopCount) { + auto oneFor = loopWithBufferReuse[0]; + scf::ForOp parentForOp = oneFor->getParentOfType(); + if (parentForOp) + commonOuterLoop = parentForOp.getOperation(); + } + + // When there is no outer loop, we need to create a place holder for + // tmpAccumLoopCount. Every forOp in loopWithBufferReuse either has the same + // outer loop or has no outer loop. + Value tmpAccumLoopCount; + if (loopWithBufferReuse.size() > 1 && !commonOuterLoop) { + auto oneFor = loopWithBufferReuse[0]; + // Initialize tmpAccumLoopCount to be 0. + OpBuilderWithAsyncTaskIds builder(taskTopOps[0]->getContext()); + builder.setAsynTaskIdsFromArray(getNestedAsyncTaskIds(oneFor)); + builder.setInsertionPoint(taskTopOps[0]); + tmpAccumLoopCount = builder.createWithAsyncTaskIds( + oneFor->getLoc(), 0, 64); + } + + SmallVector opList; + for (auto &op : taskTopOps) { + if (auto origIfOp = dyn_cast(op)) { + opList.push_back(op); + } + if (auto origForOp = dyn_cast(op)) + opList.push_back(op); + } + updateAccumLoopCount(opList, numBuffers, taskTopOps, commonOuterLoop, + loopWithBufferReuse, tmpAccumLoopCount); + + return tmpAccumLoopCount; +} + +// Create an allocation to hold the mbarriers. +static Value createBarrierAlloc(triton::FuncOp funcOp, unsigned distance) { + OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(funcOp.getContext()); + Location loc = funcOp.getLoc(); + auto context = funcOp.getContext(); + auto barrierCTALayout = + ttg::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = + ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout); + Type barrierMemDescType = tt::MemDescType::get( + {distance}, builder.getI64Type(), barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true); + Type singleBarrierMemDescType = + tt::MemDescType::get({1}, builder.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); + Value barrierAlloc = builder.create( + loc, barrierMemDescType, Value()); + for (unsigned i = 0; i < distance; i++) { + Value idx = builder.create(loc, i, 32); + Value barrierView = builder.create( + loc, singleBarrierMemDescType, barrierAlloc, idx); + builder.create(funcOp->getLoc(), barrierView, 1); + } + return barrierAlloc; +} + +// channelsGroupedByConsumers: channels are grouped together. +// Go through each group, check the first channel in the group, create a token +// for each consumer taskId. Return a map that maps each channel + consumer +// taskId to a token. Also update barrierAllocMap that maps each channel + +// consumer taskId to a BarrierAlloc. +DenseMap> +createToken(const DenseMap> + &channelsGroupedByConsumers, + const SmallVector &orderedChannels, + triton::FuncOp funcOp, int numConsumerGroups, + DenseMap> &channelReuse, + DenseMap> &barrierAllocMap) { + DenseMap> ret; + OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + for (auto *key : orderedChannels) { + auto it = channelsGroupedByConsumers.find(key); + Channel *channel = it->second.front(); + if (!channelReuse.count(channel)) + continue; + for (auto consumerAsyncTaskId : channel->relation.second) { + Value v; + if (it->second.front()->getSrcOp()->getParentOfType()) { + v = builder.create(funcOp.getLoc(), + channel->numBuffers); + } else { + v = builder.create(funcOp.getLoc(), 1); + } + // Channels in the group share the same set of tokens. + for (auto &c : it->second) { + ret[c][consumerAsyncTaskId] = v; + } + for (auto *reuse : channelReuse[channel]) { + ret[reuse][consumerAsyncTaskId] = v; + } + + auto producerOp = it->second.front()->getSrcOp(); + if (isa(producerOp)) { + Value bAlloc = createBarrierAlloc(funcOp, channel->numBuffers); + // Channels in the group share the same set of tokens. + for (auto &c : it->second) { + ret[c][consumerAsyncTaskId] = v; + barrierAllocMap[c][consumerAsyncTaskId] = bAlloc; + } + for (auto *reuse : channelReuse[channel]) { + ret[reuse][consumerAsyncTaskId] = v; + barrierAllocMap[reuse][consumerAsyncTaskId] = bAlloc; + } + } + } + } + return ret; +} + +// Create a buffer array for each producer op, if the producer is in a ForOp, +// the buffer array will contain numBuffers. +DenseMap createBuffer( + DenseMap> &channelsGroupedByProducers, + triton::FuncOp funcOp, int numConsumerGroups, + DenseMap &mapToRepresenting, + DenseMap> &channelReuse) { + + DenseMap bufferMap; + MLIRContext *context = funcOp.getContext(); + OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + DenseSet visited; + for (auto &item : channelsGroupedByProducers) { + auto &channels = item.second; + for (auto c : channels) { + assert(!visited.count(c)); + visited.insert(c); + if (mapToRepresenting.count(c)) { + channelReuse[mapToRepresenting[c]].push_back(c); + LDBG("update channelReuse key " << mapToRepresenting[c] << " " << c); + } else { + channelReuse[c].push_back(c); + LDBG("update channelReuse key " << c << " " << c); + } + } + } + for (auto &item : channelsGroupedByProducers) { + auto &channels = item.second; + auto srcValue = item.first->getSrcOperand(); + auto srcOp = item.first->getSrcOp(); + unsigned numBuffers = channels.front()->numBuffers; + + if (auto tensorType = dyn_cast(srcValue.getType())) { + // Get basic information from tensorType + auto order = ttg::getOrder(tensorType.getEncoding()); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::SharedEncodingAttr::get( + context, sliceShape, order, CTALayout, elemType); + auto sliceType = + RankedTensorType::get(sliceShape, elemType, sharedLayout); + + // Get shape, layout and type of the complete buffer + SmallVector bufferShape(sliceShape.begin(), sliceShape.end()); + if (srcOp->getParentOfType()) + bufferShape.insert(bufferShape.begin(), numBuffers); + else + bufferShape.insert(bufferShape.begin(), 1); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + auto bufferType = + RankedTensorType::get(bufferShape, elemType, sharedLayout); + Type memdescType = + tt::MemDescType::get(bufferShape, elemType, sharedLayout, + sharedMemorySpace, /*mutableMemory*/ true); + Value buffer = + builder.create(funcOp.getLoc(), memdescType); + + // Channels in the group share the same buffer. + for (auto c : channels) + bufferMap[c] = buffer; + } else { + llvm_unreachable("Unexpected result type"); + } + } + unsigned groupId = 0; + for (auto &kv : channelReuse) { + if (kv.second.size() <= 1) + continue; + bufferMap[kv.first].getDefiningOp()->setAttr( + "allocation.shareGroup", + IntegerAttr::get(IntegerType::get(context, 32), groupId)); + for (auto *c : kv.second) + bufferMap[c].getDefiningOp()->setAttr( + "allocation.shareGroup", + IntegerAttr::get(IntegerType::get(context, 32), groupId)); + ++groupId; + } + return bufferMap; +} + +static Operation *createAsyncCopy(const DenseMap &bufferMap, + Channel *c, Operation *op, + SmallVector &asyncTasksPC, + Value bufferIdx, Value bufferIdxExtract) { + auto loadOp = cast(op); + auto buffer = bufferMap.find(c)->second; + MLIRContext *context = loadOp->getContext(); + OpBuilderWithAsyncTaskIds builder(context); + builder.setInsertionPoint(loadOp->getParentOp()); + builder.setAsynTaskIdsFromArray(asyncTasksPC); + + builder.setInsertionPoint(loadOp); + Value loadResult = loadOp.getResult(); + auto tensorType = dyn_cast(loadResult.getType()); + if (!tensorType) + return nullptr; + // Get basic information from tensorType + auto order = ttg::getOrder(tensorType.getEncoding()); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::SharedEncodingAttr::get(context, sliceShape, order, + CTALayout, elemType); + auto sliceType = RankedTensorType::get(sliceShape, elemType, sharedLayout); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + tt::MemDescType subviewTy = + tt::MemDescType::get(sliceType.getShape(), sliceType.getElementType(), + sliceType.getEncoding(), sharedMemorySpace, + /*mutableMemory=*/true); + Value zero = builder.createWithAsyncTaskIds( + loadOp.getLoc(), 0, 32); + SmallVector copyOffsets(sliceType.getRank() + 1, zero); + copyOffsets[0] = bufferIdx; + builder.setAsyncTaskIdsFromOp(loadOp); + builder.setInsertionPointAfter(loadOp); + auto view = builder.createWithAsyncTaskIds( + loadOp.getLoc(), subviewTy, buffer, copyOffsets); + // Create cp.async + Operation *copy = + builder.createWithAsyncTaskIds( + loadOp.getLoc(), loadOp.getPtr(), view, loadOp.getMask(), + loadOp.getOther(), loadOp.getCache(), loadOp.getEvict(), + loadOp.getIsVolatile()); + + // Extract part. + builder.setAsyncTaskIdsFromValueUsers(loadResult); + builder.setInsertionPoint(c->getDstOp()); + SmallVector loadOffsets(sliceType.getRank() + 1, zero); + loadOffsets[0] = bufferIdxExtract; + auto viewLoad = builder.createWithAsyncTaskIds( + loadOp.getLoc(), subviewTy, buffer, loadOffsets); + auto sharedLoad = builder.createWithAsyncTaskIds( + loadOp.getLoc(), loadOp.getType(), viewLoad /*,wait->getResult(0)*/); + // Replace all uses of loadResult + loadResult.replaceAllUsesWith(sharedLoad.getResult()); + loadOp.erase(); + return copy; +} + +// Create a local copy for a channel that is populated by the producer and +// accessed by the consumer. +static void createLocalCopy(const DenseMap &bufferMap, + Channel *channel, Value srcBufferIdx, + Value dstBufferIdx) { + Operation *srcOp = channel->getSrcOp(); + Operation *dstOp = channel->getDstOp(); + MLIRContext *context = srcOp->getContext(); + auto buffer = bufferMap.find(channel)->second; + + Value srcValue = channel->getSrcOperand(); + auto tensorType = dyn_cast(srcValue.getType()); + if (!tensorType) + return; + // Get basic information from tensorType + auto order = ttg::getOrder(tensorType.getEncoding()); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::SharedEncodingAttr::get(context, sliceShape, order, + CTALayout, elemType); + auto sliceType = RankedTensorType::get(sliceShape, elemType, sharedLayout); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + tt::MemDescType subviewTy = + tt::MemDescType::get(sliceType.getShape(), sliceType.getElementType(), + sliceType.getEncoding(), sharedMemorySpace, + /*mutableMemory=*/true); + + // Consumer part. + OpBuilderWithAsyncTaskIds builder(dstOp); + builder.setAsyncTaskIdsFromOp(dstOp); + builder.setInsertionPoint(dstOp); + Value zero = builder.createWithAsyncTaskIds( + dstOp->getLoc(), 0, 32); + SmallVector loadOffsets(sliceType.getRank() + 1, zero); + loadOffsets[0] = dstBufferIdx; + auto dstView = builder.createWithAsyncTaskIds( + dstOp->getLoc(), subviewTy, buffer, loadOffsets); + auto sharedLoad = builder.createWithAsyncTaskIds( + dstOp->getLoc(), srcValue.getType(), dstView); + srcValue.replaceAllUsesWith(sharedLoad.getResult()); + + // Producer part. Create local_store for new producers. + builder.setAsynTaskIdsFromArray(channel->relation.first); + builder.setInsertionPoint(srcOp->getParentOp()); + zero = builder.createWithAsyncTaskIds(srcOp->getLoc(), + 0, 32); + SmallVector storeOffsets(sliceType.getRank() + 1, zero); + storeOffsets[0] = srcBufferIdx; + builder.setInsertionPointAfter(srcOp); + auto srcView = builder.createWithAsyncTaskIds( + srcOp->getLoc(), subviewTy, buffer, storeOffsets); + // Create local_alloc + Operation *copy = builder.createWithAsyncTaskIds( + srcOp->getLoc(), srcValue, srcView); +} + +static int getTMALoadSize(tt::ExperimentalDescriptorLoadOp &tmaLoad) { + auto tensorTy = cast(tmaLoad->getResult(0).getType()); + int loadSize = product(tensorTy.getShape()); + return loadSize * tensorTy.getElementType().getIntOrFloatBitWidth() / 8; +} + +Value getBarrierForPipelineStage(OpBuilderWithAsyncTaskIds &builder, + Value barrierAlloc, Value bufferIdx) { + auto context = barrierAlloc.getContext(); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + tt::MemDescType barrierTy = tt::MemDescType::get( + {1}, builder.getI64Type(), + cast(barrierAlloc.getType()).getEncoding(), + sharedMemorySpace, + /*mutableMemory=*/true); + + // Create barrierForTMA from barrierAlloc. + return builder.createWithAsyncTaskIds( + barrierAlloc.getLoc(), barrierTy, barrierAlloc, + ArrayRef({bufferIdx})); +} + +Value getBufferForPipelineStage(OpBuilderWithAsyncTaskIds &builder, + Type loadType, Value buffer, Value bufferIdx, + bool mutableMem) { + auto context = buffer.getContext(); + auto tensorType = dyn_cast(loadType); + assert(tensorType); + + auto order = ttg::getOrder(tensorType.getEncoding()); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::SharedEncodingAttr::get(context, sliceShape, order, + CTALayout, elemType); + auto sliceType = RankedTensorType::get(sliceShape, elemType, sharedLayout); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + tt::MemDescType subviewTy = + tt::MemDescType::get(sliceType.getShape(), sliceType.getElementType(), + sliceType.getEncoding(), sharedMemorySpace, + /*mutableMemOry=*/mutableMem); + + Value zero = builder.createWithAsyncTaskIds( + buffer.getLoc(), 0, 32); + SmallVector copyOffsets(sliceType.getRank() + 1, zero); + copyOffsets[0] = bufferIdx; + + return builder.createWithAsyncTaskIds( + buffer.getLoc(), subviewTy, buffer, copyOffsets); +} + +Operation * +optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder, + SmallVector &tmaLoads, + SmallVector &buffers, Value barrierAlloc, + Value bufferIdx, Value bufferIdxExtract, Value phase, + Operation *headProducer, Operation *headConsumer) { + auto loc = barrierAlloc.getLoc(); + + // Compute the total size of the loads. + int sizeInBytes = 0; + for (auto &tmaLoad : tmaLoads) { + sizeInBytes += getTMALoadSize(tmaLoad); + } + + // For each of the following ops, we will operate on a subview of each value + // according to the pipeline stage. + + // Create a barrier_expect with the appropriate size and insert it before the + // first load. + builder.setInsertionPoint(headProducer); + builder.setAsyncTaskIdsFromOp(headProducer); + auto prodBarrier = + getBarrierForPipelineStage(builder, barrierAlloc, bufferIdx); + auto pred = builder.createWithAsyncTaskIds(loc, 1, 1); + auto expect = builder.createWithAsyncTaskIds( + loc, prodBarrier, sizeInBytes, pred); + + // Convert all the producers to async_tma_copy_global_to_local + Operation *copy = nullptr; + for (auto [tmaLoad, buffer] : zip(tmaLoads, buffers)) { + auto pipelineBuffer = getBufferForPipelineStage(builder, tmaLoad.getType(), + buffer, bufferIdx, true); + copy = builder.createWithAsyncTaskIds( + loc, tmaLoad.getDescPtr(), tmaLoad.getIndices(), prodBarrier, + pipelineBuffer, pred); + } + + // Create a wait_barrier before the first consumer. + builder.setInsertionPoint(headConsumer); + builder.setAsyncTaskIdsFromOp(headConsumer); + auto consBarrier = + getBarrierForPipelineStage(builder, barrierAlloc, bufferIdxExtract); + phase = builder.createWithAsyncTaskIds( + loc, builder.getI32Type(), phase); + auto wait = builder.createWithAsyncTaskIds( + loc, consBarrier, phase); + + // Convert all the consumers to local_load + for (auto [tmaLoad, buffer] : zip(tmaLoads, buffers)) { + auto pipelineBuffer = getBufferForPipelineStage( + builder, tmaLoad.getType(), buffer, bufferIdxExtract, false); + auto sharedLoad = builder.createWithAsyncTaskIds( + loc, tmaLoad.getType(), pipelineBuffer); + + Value loadResult = tmaLoad.getResult(); + tmaLoad.getResult().replaceAllUsesWith(sharedLoad.getResult()); + tmaLoad.erase(); + } + return copy; +} + +// Lower producers for channels. Here channels are grouped in +// "channelsGroupedByConsumers". tokenMap tracks the set of tokens for each +// channel. +void insertAsyncComm( + triton::FuncOp funcOp, + const DenseMap> + &channelsGroupedByConsumers, + const DenseMap> &tokenMap, + const DenseMap> &barrierAllocMap, + const DenseMap &bufferMap, int numConsumerGroups) { + + // Find the operation that is along producer's parent chain, and its parent + // is the same op as producer's parent. Here p is producer, and c is consumer. + auto getSameLevelOp = [](Operation *p, Operation *c) -> Operation * { + while (!isa(c)) { + if (c->getParentOp() == p->getParentOp()) { + return c; + } + c = c->getParentOp(); + } + llvm_unreachable("Failed to find consumer's same level Op with producer"); + }; + + auto consumerReleaseHeutistic = [&](Operation *p, Operation *c, + int consumerAsyncTaskId) -> Operation * { + if (c->getBlock() != p->getBlock()) + return getSameLevelOp(p, c); + + // Find a common place for all users of the consumer, which would be the + // common post dominator. + mlir::PostDominanceInfo dom(funcOp); + std::unordered_set mutuallyNonDominatingUsers; + SmallVector users; + for (auto user : c->getUsers()) { + if (isa(user)) { + // TransOp is not a real consumer. It caculates the shared memory + // address for the real consumer. Continue to find its transitive users + // recursively. + DenseSet visited; + SmallVector transUsers; + transUsers.push_back(user); + while (!transUsers.empty()) { + auto transUser = transUsers.pop_back_val(); + visited.insert(transUser); + if (isa(transUser)) { + for (auto transitiveUser : transUser->getUsers()) { + if (!visited.count(transitiveUser)) + transUsers.push_back(transitiveUser); + } + } else { + users.push_back(transUser); + } + } + } else { + users.push_back(user); + } + } + + for (auto user : users) { + auto it = mutuallyNonDominatingUsers.begin(); + while (it != mutuallyNonDominatingUsers.end()) { + if (dom.properlyPostDominates(user, *it)) { + it = mutuallyNonDominatingUsers.erase(it); + } else if (dom.properlyPostDominates(*it, user)) { + break; + } else { + ++it; + } + } + if (it == mutuallyNonDominatingUsers.end()) + mutuallyNonDominatingUsers.insert(user); + } + + if (mutuallyNonDominatingUsers.size() == 1) { + // Find the common parent of this user and c + auto user = *mutuallyNonDominatingUsers.begin(); + while (user && user->getParentOp() != c->getParentOp()) + user = user->getParentOp(); + assert(user && "Failed to find common parent of this user and c"); + return user; + } + + for (auto &op : reverse(c->getBlock()->getOperations())) { + auto asyncTasks = getAsyncTaskIds(&op); + if (asyncTasks.size() == 1 && asyncTasks[0] == consumerAsyncTaskId) + return &op; + } + + return nullptr; + }; + + // Go through each channel group. + for (auto kv : channelsGroupedByConsumers) { + // Find head and tail ops. + DenseSet producerOps; + DenseSet consumerOps; + for (auto &c : kv.second) { + producerOps.insert(c->getSrcOp()); + consumerOps.insert(c->getDstOp()); + } + + // Find head producer + auto producerBlock = kv.second.front()->getSrcOp()->getBlock(); + Operation *headProducer = nullptr; + for (auto &op : producerBlock->getOperations()) { + if (producerOps.count(&op)) { + headProducer = &op; + break; + } + } + // Find tail producer + Operation *tailProducer = nullptr; + for (auto &op : reverse(producerBlock->getOperations())) { + if (producerOps.count(&op)) { + tailProducer = &op; + break; + } + } + + // Find head consumer and tail consumer + auto consumerBlock = kv.second.front()->getDstOp()->getBlock(); + Operation *headConsumer = nullptr; + for (auto &op : consumerBlock->getOperations()) { + if (consumerOps.count(&op)) { + headConsumer = &op; + break; + } + } + Operation *tailConsumer = nullptr; + for (auto &op : reverse(consumerBlock->getOperations())) { + if (consumerOps.count(&op)) { + tailConsumer = &op; + break; + } + } + + // We have one set of tokens for each channel group. + auto tokens = tokenMap.find(kv.second.front())->second; + auto masterChannel = kv.getFirst(); + + SmallVector asyncTaskP; + asyncTaskP.push_back(masterChannel->relation.first); + SmallVector &asyncTaskC = masterChannel->relation.second; + SmallVector asyncTasksPC = asyncTaskP; + asyncTasksPC.insert(asyncTasksPC.end(), asyncTaskC.begin(), + asyncTaskC.end()); + + OpBuilderWithAsyncTaskIds builder(headProducer->getContext()); + if (auto funcOp = dyn_cast(headProducer->getParentOp())) { + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + } else { + builder.setInsertionPoint(headProducer->getParentOp()); + } + builder.setAsynTaskIdsFromArray(asyncTasksPC); + + Value bufferIdx; + Value phase = Value(); + if (auto forOp = headProducer->getParentOfType()) { + // We already added phase, bufferIdx to the ForOp. + auto tSize = forOp.getBody()->getArguments().size(); + assert(tSize >= 2); + bufferIdx = forOp.getBody()->getArguments().back(); + phase = forOp.getBody()->getArgument(tSize - 2); // next to last argument + } else { + // Producer is not in a ForOp, create phase and bufferIdx here. + bufferIdx = builder.createWithAsyncTaskIds( + headProducer->getLoc(), 0, 32); + phase = builder.createWithAsyncTaskIds( + headProducer->getLoc(), 0, 1); + } + + builder.setAsynTaskIdsFromArray(masterChannel->relation.first); + for (auto token : tokens) { + // Insert ProducerAcquireOp before the producer. + builder.setInsertionPoint(headProducer); + builder.createWithAsyncTaskIds( + headProducer->getLoc(), token.second, bufferIdx, phase); + + // Insert ProducerCommitOp if producer is LoadOp. For TMA, TMA lowering + // will handle the ProducerCommit. + if (!isa(headProducer)) { + builder.setInsertionPointAfter(tailProducer); + builder.createWithAsyncTaskIds( + tailProducer->getLoc(), token.second, bufferIdx); + } + } + + for (auto token : tokens) { + builder.setAsynTaskIdsFromArray(token.first); + // Insert ConsumerWaitOp + if (!isa(headProducer)) { + auto consumerWaitPoint = getSameLevelOp(headProducer, headConsumer); + builder.setInsertionPoint(consumerWaitPoint); + builder.createWithAsyncTaskIds( + headConsumer->getLoc(), token.second, bufferIdx, phase); + } + + // Insert ConsumerReleaseOp. + auto consumerReleasePoint = + consumerReleaseHeutistic(tailProducer, tailConsumer, token.first); + builder.setInsertionPointAfter(consumerReleasePoint); + builder.createWithAsyncTaskIds( + consumerReleasePoint->getLoc(), token.second, bufferIdx); + } + + SmallVector tmaLoads; + SmallVector buffers; + DenseMap producerCopyMap; + // Go through all channels in this channel group. + for (auto &c : kv.second) { + if (auto tmaLoad = + dyn_cast(c->getSrcOp())) { + tmaLoads.push_back(tmaLoad); + buffers.push_back(bufferMap.find(c)->second); + } + } + + // Optimize TMA loads. + if (tmaLoads.size() > 0) { + auto barrierAllocs = barrierAllocMap.find(kv.second.front())->second; + // TODO: we created one Alloc for each consumer taskId, but here, we + // only use the first Alloc. + auto barrierAlloc = barrierAllocs.begin()->second; + optimizeTMALoads(builder, tmaLoads, buffers, barrierAlloc, bufferIdx, + bufferIdx, phase, headProducer, headConsumer); + } + } +} + +// Lower producers for channels. Here channels are grouped in +// "channelsGroupedByProducers" +void insertAsyncCopy(triton::FuncOp funcOp, + const DenseMap> + &channelsGroupedByProducers, + const DenseMap &bufferMap) { + // For each producer op, create a async_copy or local_store from the producer + // to the buffer. Create a local_load from the buffer at the dominating + // consumer. + mlir::DominanceInfo dom(funcOp); + + for (auto kv : channelsGroupedByProducers) { + // Finding the dominating channel if possible. + std::unordered_set mutuallyNonDominatingChannels; + for (auto &c : kv.second) { + // check if c is dominating all other previous channels. + auto it = mutuallyNonDominatingChannels.begin(); + while (it != mutuallyNonDominatingChannels.end()) { + auto channel = *it; + if (dom.properlyDominates(c->getDstOp(), channel->getDstOp())) { + it = mutuallyNonDominatingChannels.erase(it); + } else if (dom.properlyDominates(channel->getDstOp(), c->getDstOp())) { + break; + } else { + ++it; + } + } + if (it == mutuallyNonDominatingChannels.end()) + mutuallyNonDominatingChannels.insert(c); + } + + auto srcOp = kv.getFirst()->getSrcOp(); + Value bufferIdx; + Value phase = Value(); + if (auto forOp = srcOp->getParentOfType()) { + // We already added phase, bufferIdx to the ForOp. + auto tSize = forOp.getBody()->getArguments().size(); + assert(tSize >= 2); + bufferIdx = forOp.getBody()->getArguments().back(); + } else { + // Producer is not in a ForOp, create phase and bufferIdx here which will + // be used by both producer and consumers. + OpBuilderWithAsyncTaskIds builder(srcOp); + SmallVector asyncTasksPC = getAsyncTaskIds(srcOp); + for (auto channel : mutuallyNonDominatingChannels) + asyncTasksPC.append(getAsyncTaskIds(channel->getDstOp())); + builder.setAsynTaskIdsFromArray(asyncTasksPC); + bufferIdx = builder.createWithAsyncTaskIds( + srcOp->getLoc(), 0, 32); + } + + for (auto channel : mutuallyNonDominatingChannels) { + // No need to create async copy for TMA load which is handled in + // insertAsyncComm. + if (isa(srcOp)) { + continue; + } + if (isa(srcOp)) { + SmallVector asyncTasksPC = getAsyncTaskIds(srcOp); + asyncTasksPC.append(getAsyncTaskIds(channel->getDstOp())); + // After createAsyncCopy, c->getSrcOp()/headProducer are no longer + // valid. + createAsyncCopy(bufferMap, channel, channel->getSrcOp(), asyncTasksPC, + bufferIdx, bufferIdx); + } else { + createLocalCopy(bufferMap, channel, bufferIdx, bufferIdx); + } + } + } +} + +void foldLocalLoads(triton::FuncOp funcOp) { + // If loadResult has a single use which is LocalAlloc, we can get rid of + // sharedLoad and replace all uses of LocalAlloc with viewLoad. + DenseMap opsToReplace; + funcOp.walk([&](ttg::LocalAllocOp localAlloc) { + if (auto src = localAlloc.getSrc()) { + if (auto localLoad = dyn_cast(src.getDefiningOp())) { + // Only fold within the same tasks + if (getAsyncTaskIds(localLoad) == getAsyncTaskIds(localAlloc)) { + opsToReplace[localAlloc] = localLoad.getSrc(); + } + } + } + }); + OpBuilderWithAsyncTaskIds builder(funcOp.getContext()); + for (auto kv : opsToReplace) + replaceUsesAndPropagateType(builder, kv.getFirst(), kv.getSecond()); +} + +class TritonGPUWSCodePartitionPass + : public impl::TritonGPUWSCodePartitionBase { +public: + using impl::TritonGPUWSCodePartitionBase< + TritonGPUWSCodePartitionPass>::TritonGPUWSCodePartitionBase; + + void runOnFuncOp(triton::FuncOp funcOp) { + // Disable code partitioning when numBuffers is 0. + if (numBuffers == 0) + return; + + // Step 1: collect all communications between producers and consumers. + SmallVector> channelsOrigin; + collectAsyncChannels(channelsOrigin, funcOp, numBuffers); + SmallVector channels; + for (const auto &c : channelsOrigin) { + channels.push_back(c.get()); + } + if (channels.empty()) { + return; + } + + // Step 2: group channels + // - each entry of the channelsGroupedByProducers is keyed by the srcOp. + // - each entry of the channelsGroupedByConsumers is keyed by the dstOp. + DenseMap> channelsGroupedByProducers; + DenseMap> channelsGroupedByConsumers; + SmallVector orderedChannels; + groupChannels(channels, channelsGroupedByProducers, + channelsGroupedByConsumers, orderedChannels); + + // Step 3: reorder producer ops and the backward slices of the producer ops. + reorderProducerOps(channels); + + // Step 4: find top-level ops that contain a channel, also create new ForOps + // by adding phase and bufferIdx to the original ForOps, erase the original + // ForOps. + SmallVector asyncTaskTopOps = + getTaskTopRegion(funcOp, channels); + // Update mapToRepresenting that maps a channel to the representing channel + // in the sharing group. + DenseMap mapToRepresenting; + SmallVector loopWithBufferReuse; + reuseBuffers(asyncTaskTopOps, channels, mapToRepresenting, + loopWithBufferReuse); + // Use and update loopWithBufferReuse. + appendBufferIdxArgs(asyncTaskTopOps, numBuffers, channels, + mapToRepresenting, loopWithBufferReuse); + LLVM_DEBUG({ + LDBG("\n\nafter appendBufferIdxArgs"); + funcOp.dump(); + }); + + // Step 5: Create tokens, and buffers. A set of tokens for each group of + // channels and an array of buffers for each channel. + // Update channelReuse that maps from a representing channel to the group of + // channels that share buffers. + DenseMap> channelReuse; + DenseMap bufferMap = + createBuffer(channelsGroupedByProducers, funcOp, numConsumerGroups, + mapToRepresenting, channelReuse); + DenseMap> barrierAllocMap; + DenseMap> tokenMap = + createToken(channelsGroupedByConsumers, orderedChannels, funcOp, + numConsumerGroups, channelReuse, barrierAllocMap); + LLVM_DEBUG({ + LDBG("\n\nafter createBuffer"); + funcOp.dump(); + }); + + // Step 6: add async communication ops (ProducerAcquire etc). Also lower the + // loads. + insertAsyncComm(funcOp, channelsGroupedByConsumers, tokenMap, + barrierAllocMap, bufferMap, numConsumerGroups); + LLVM_DEBUG({ + LDBG("\n\nwith SyncOps"); + funcOp.dump(); + }); + + // Step 7: Lower the loads. Also add local copy ops for non-load producers. + insertAsyncCopy(funcOp, channelsGroupedByProducers, bufferMap); + LLVM_DEBUG({ + LDBG("\n\nwith async copy"); + funcOp.dump(); + }); + + // If loadResult has a single use which is LocalAlloc, we can get rid of + // sharedLoad and replace all uses of LocalAlloc with viewLoad. + foldLocalLoads(funcOp); + LLVM_DEBUG({ + LDBG("\n\nsimplify localLoad + localAlloc"); + funcOp.dump(); + }); + + // Assuming there are no changes to loops in loopWithBufferReuse. + auto ret = SpecializeRegion(funcOp, regDecProducer, regIncConsumer); + LLVM_DEBUG({ + LDBG("\n\nwith SpecializeRegion"); + funcOp.dump(); + }); + } + + void runOnOperation() override { + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + LLVM_DEBUG({ + LDBG("post pass"); + getOperation()->dump(); + }); + return; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp new file mode 100644 index 000000000..0275e1b18 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp @@ -0,0 +1,756 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define DEBUG_TYPE "tritongpu-warp-spec-data-partition" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +static bool oneVecCoversTheOther(SmallVector &one, + SmallVector &other) { + // Every element of other appears in one. + for (AsyncTaskId t : other) { + // If t doesn't appear in one, return false. + bool found = false; + for (AsyncTaskId t2 : one) { + if (t2 == t) { + found = true; + break; + } + } + if (!found) + return false; + } + return true; +} + +// Make sure the def chain contains the right taskId. +void fixTaskId(triton::FuncOp &funcOp) { + funcOp.walk([&](Operation *op) { + auto asyncTaskIds = getAsyncTaskIds(op); + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (!defOp) + continue; + // Do not update loads. + if (isa(defOp)) + continue; + auto defTaskIds = getAsyncTaskIds(defOp); + // Make sure defTaskIds cover asyncTaskIds. Call addAsyncTaskIds if + // necessary. + if (!oneVecCoversTheOther(defTaskIds, asyncTaskIds)) { + // Const ops with same value but different task ids can be folded. + if (isa(defOp)) { + LLVM_DEBUG({ + LDBG("backward fixing taskId for"); + defOp->dump(); + }); + addAsyncTaskIds(defOp, asyncTaskIds); + LLVM_DEBUG({ + LDBG("resulting"); + defOp->dump(); + }); + } + } + if (operand.hasOneUse() && + !oneVecCoversTheOther(asyncTaskIds, defTaskIds)) { + // YieldOp may lose task attribute during MLIR canonicalization. + if (isa(op)) { + LLVM_DEBUG({ + LDBG("forward fixing taskId for"); + defOp->dump(); + }); + addAsyncTaskIds(op, defTaskIds); + LLVM_DEBUG({ + LDBG("resulting"); + defOp->dump(); + }); + } + } + } + }); +} + +static SmallVector getShape(Value v) { + auto type = v.getType(); + if (auto type = dyn_cast(v.getType())) { + return {type.getShape().begin(), type.getShape().end()}; + } else if (auto type = dyn_cast(v.getType())) { + return {type.getShape().begin(), type.getShape().end()}; + } + return {}; +} + +bool needToSlice(Value v, int dim, int size) { + auto shape = getShape(v); + return shape.size() > dim && shape[dim] > size; +} + +void getBackwardSliceToPartition(Value root, unsigned dim, int sliceSize, + SetVector &backwardSlice) { + SmallVector queue = {root}; + while (!queue.empty()) { + auto v = queue.back(); + queue.pop_back(); + if (!needToSlice(v, dim, sliceSize)) + continue; + if (auto op = v.getDefiningOp()) { + if (backwardSlice.insert(op)) { + if (op->hasTrait() || + isa(op)) { + for (Value operand : op->getOperands()) + queue.push_back(operand); + } else if (auto dotOp = dyn_cast(op)) { + queue.push_back(dim == 0 ? dotOp.getA() : dotOp.getB()); + queue.push_back(dotOp.getC()); + } else { + llvm_unreachable("Unexpected op"); + } + } + } else { + assert(isa(v) && "value is not an operation or block "); + auto bbArg = cast(v); + Operation *bbAargOwner = bbArg.getOwner()->getParentOp(); + if (auto forOp = dyn_cast(bbAargOwner)) { + // track initial value + auto initArg = forOp.getInitArgs()[bbArg.getArgNumber() - 1]; + queue.push_back(initArg); + // track yield value + auto yieldArg = forOp.getYieldedValues()[bbArg.getArgNumber() - 1]; + queue.push_back(yieldArg); + } + } + } +}; + +void getForwardSliceToPartition(Value root, unsigned dim, int sliceSize, + SetVector &forwardSlice) { + SmallVector queue = {root}; + llvm::SmallDenseSet seen; + while (!queue.empty()) { + auto v = queue.back(); + queue.pop_back(); + if (!seen.insert(v).second) + continue; + if (!needToSlice(v, dim, sliceSize)) + continue; + getForwardSlice(v, &forwardSlice); + for (Operation *op : forwardSlice) { + if (op->getNumResults() > 0) + seen.insert(op->getResult(0)); + if (auto yieldOp = dyn_cast(op)) { + auto parentOp = yieldOp->getParentOp(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + if (seen.count(operand.get())) { + queue.push_back(parentOp->getResult(operand.getOperandNumber())); + forwardSlice.insert(parentOp); + } + } + } + } + } +}; + +// Compute a closure of all ops originated from or being dependent on by the +// root op. +void getSliceToPartition(Value root, unsigned dim, int sliceSize, + SetVector &slice) { + getBackwardSliceToPartition(root, dim, sliceSize, slice); + SetVector forwardSlice; + getForwardSliceToPartition(root, dim, sliceSize, forwardSlice); + slice.insert(forwardSlice.begin(), forwardSlice.end()); + for (auto op : forwardSlice) { + if (op->hasTrait() || + isa(op)) { + for (OpOperand &operand : op->getOpOperands()) { + getBackwardSliceToPartition(operand.get(), dim, sliceSize, slice); + } + } else if (auto dotOp = dyn_cast(op)) { + getBackwardSliceToPartition(dim == 0 ? dotOp.getA() : dotOp.getB(), dim, + sliceSize, slice); + getBackwardSliceToPartition(dotOp.getC(), dim, sliceSize, slice); + } + } +} + +struct DataPartitionScheme { + // Which dimension to partition. For dot, dim 0 means along M dimension, 1 + // means along N dimensiont. + unsigned partitionDim = 0; + unsigned numPartitions = 0; + SetVector ops; +}; + +bool computePartitionScheme(triton::FuncOp &funcOp, + DataPartitionScheme &partitionScheme) { + // Do not partition producer tasks + + // Use dot to drive the partition + SetVector dots; + + // check all dot ops that have more than one async task id + funcOp.walk([&](Operation *op) { + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.size() > 1) { + if (auto dotWaitOp = dyn_cast(op)) { + dots.insert(dotWaitOp); + } + } + }); + + // Checking if all dots can be partitioned in the same way + int numWarps = + TritonGPUDialect::getNumWarps(funcOp->getParentOfType()); + for (auto dotOp : dots) { + // partition along M first, otherwise along N + RankedTensorType dotType = dotOp.getType(); + LLVM_DEBUG({ + LDBG("Computing partition scheme for"); + dotOp.dump(); + LDBG("\n"); + }); + auto shapePerCTA = getShapePerCTA(dotType); + if (shapePerCTA.size() != 2) { + LDBG("partition not possible: shapePerCTA " << shapePerCTA.size()); + return false; + } + auto CTALayout = getCTALayout(dotType.getEncoding()); + auto asyncTaskIds = getAsyncTaskIds(dotOp); + int sliceSizeM = shapePerCTA[0] / asyncTaskIds.size(); + int sliceSizeN = shapePerCTA[1] / asyncTaskIds.size(); + int partitionDim, partitionSize; + Value partitionOperand; + + if (sliceSizeM >= 64) { + LLVM_DEBUG({ LDBG("partition along M\n"); }); + partitionDim = 0; + partitionSize = sliceSizeM; + partitionOperand = dotOp.getA(); + } else if (sliceSizeN >= 256) { + LLVM_DEBUG({ LDBG("partition along N\n"); }); + partitionDim = 1; + partitionSize = sliceSizeN; + partitionOperand = dotOp.getB(); + } else { + LDBG("partition not possible: " << sliceSizeM << " " << sliceSizeN); + return false; + } + + if (partitionScheme.numPartitions == 0) { + partitionScheme.partitionDim = partitionDim; + partitionScheme.numPartitions = asyncTaskIds.size(); + } else { + if (partitionScheme.partitionDim != partitionDim || + partitionScheme.numPartitions != asyncTaskIds.size()) { + LDBG("partition not possible, in conflict with previous partition\n"); + return false; + } + } + + // Partition the slice closure + SetVector &slice = partitionScheme.ops; + getSliceToPartition(dotOp.getD(), partitionDim, partitionSize, slice); + + LLVM_DEBUG({ + partitionOperand.dump(); + LDBG("\n"); + LDBG(" slice:"); + for (auto &op : slice) { + op->dump(); + } + LDBG("\n"); + }); + + for (auto op : partitionScheme.ops) { + auto opTaskIds = getAsyncTaskIds(op); + // skip check for control flow ops + if (isa(op)) + continue; +#if 0 + if (opTaskIds.size() > partitionScheme.numPartitions) { + LLVM_DEBUG({ + LDBG("partition not possible: numPartitions" << opTaskIds.size() << " " << partitionScheme.numPartitions); + op->dump(); + }); + return false; + } +#endif + } + } + + return !partitionScheme.ops.empty(); +} + +Operation *sliceOp(Value v, int offset, OpBuilderWithAsyncTaskIds &builder, + IRMapping &mappings, IRMapping &reverseMappings, + DataPartitionScheme &partitionScheme); + +Operation *sliceOp(Operation *op, int offset, + OpBuilderWithAsyncTaskIds &builder, IRMapping &mappings, + IRMapping &reverseMappings, + DataPartitionScheme &partitionScheme) { + if (!partitionScheme.ops.contains(op)) + return op; + if (mappings.contains(op)) + return mappings.lookupOrNull(op); + if (reverseMappings.contains(op)) + return op; + + LLVM_DEBUG({ + LDBG("slicing:"); + op->dump(); + LDBG("\n"); + }); + + int dim = partitionScheme.partitionDim; + int numOfPartitions = partitionScheme.numPartitions; + + auto asyncTaskIds = getAsyncTaskIds(op); + SmallVector sliceTaskIds; + if (asyncTaskIds.size() == numOfPartitions) { + // We are slicing the op for consumer only + sliceTaskIds.push_back(asyncTaskIds[offset]); + } else if (asyncTaskIds.size() == 1) { + // We are slicing the op for producer only + sliceTaskIds.push_back(asyncTaskIds.front()); + } else if (asyncTaskIds.size() > numOfPartitions) { + // We are slicing the op for both producer and consumer + sliceTaskIds.push_back(asyncTaskIds.front()); + sliceTaskIds.push_back(asyncTaskIds[offset + 1]); + } else { + llvm_unreachable("Unexpected asyncTaskIds.size()"); + } + + builder.setAsynTaskIdsFromArray(sliceTaskIds); + auto cloneAndSetResultType = [&](Operation *op) { + builder.setInsertionPoint(op); + auto newOp = builder.clone(*op, mappings); + setAsyncTaskIds(newOp, sliceTaskIds); + mappings.map(op, newOp); + reverseMappings.map(newOp, op); + // set result shape + if (!op->getResults().empty()) { + auto v = op->getResult(0); + auto newV = newOp->getResult(0); + if (auto type = dyn_cast(v.getType())) { + SmallVector shape{type.getShape().begin(), + type.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + auto newType = + MemDescType::get(shape, type.getElementType(), type.getEncoding(), + type.getMemorySpace(), type.getMutableMemory()); + newV.setType(newType); + } else if (auto type = dyn_cast(v.getType())) { + SmallVector shape{type.getShape().begin(), + type.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + auto newType = RankedTensorType::get(shape, type.getElementType(), + type.getEncoding()); + newV.setType(newType); + } + + mappings.map(v, newV); + reverseMappings.map(newV, v); + } + return newOp; + }; + + // slice operands first + Operation *newOp; + if (op->hasTrait() || + isa( + op)) { + for (Value operand : op->getOperands()) + sliceOp(operand, offset, builder, mappings, reverseMappings, + partitionScheme); + newOp = cloneAndSetResultType(op); + } else if (auto constOp = dyn_cast(op)) { + builder.setInsertionPoint(op); + auto valAttr = cast(constOp.getValueAttr()); + auto valType = cast(valAttr.getType()); + SmallVector shape{valType.getShape().begin(), + valType.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + auto newValType = valType.clone(shape); + auto newValAttr = valAttr.resizeSplat(newValType); + newOp = builder.createWithAsyncTaskIds(op->getLoc(), + newValAttr); + // Do not drop original task id as constant folding may lose one constant. + setAsyncTaskIds(newOp, getAsyncTaskIds(op)); + auto v = op->getResult(0); + auto newV = newOp->getResult(0); + mappings.map(v, newV); + reverseMappings.map(newV, v); + } else if (auto makeRangeOp = dyn_cast(op)) { + builder.setInsertionPoint(op); + int newRangeStart = makeRangeOp.getStart(); + int newRangeEnd = makeRangeOp.getEnd(); + int sliceSize = (newRangeEnd - newRangeStart) / numOfPartitions; + newRangeStart += offset * sliceSize; + newRangeEnd = newRangeStart + sliceSize; + auto v = op->getResult(0); + auto type = cast(v.getType()); + auto newType = RankedTensorType::get({sliceSize}, builder.getI32Type(), + type.getEncoding()); + newOp = builder.createWithAsyncTaskIds( + op->getLoc(), newType, newRangeStart, newRangeEnd); + auto newV = newOp->getResult(0); + mappings.map(v, newV); + reverseMappings.map(newV, v); + } else if (isa(op)) { + for (Value operand : op->getOperands()) + sliceOp(operand, offset, builder, mappings, reverseMappings, + partitionScheme); + // TODO: slice store base ptr + newOp = cloneAndSetResultType(op); + } else if (isa( + op)) { + SmallVector shape; + Value coordVal; + if (auto loadOp = dyn_cast(op)) { + coordVal = loadOp.getIndices()[dim]; + shape = getShape(loadOp.getResult()); + } else if (auto storeOp = dyn_cast(op)) { + coordVal = storeOp.getIndices()[dim]; + shape = getShape(storeOp.getSrc()); + } + auto newCoordVal = coordVal; + if (offset) { + builder.setInsertionPointAfter(coordVal.getDefiningOp()); + Value offsetVal = builder.createWithAsyncTaskIds( + op->getLoc(), offset * shape[dim] / numOfPartitions, 32); + newCoordVal = builder.createWithAsyncTaskIds( + op->getLoc(), coordVal, offsetVal); + mappings.map(coordVal, newCoordVal); + reverseMappings.map(newCoordVal, coordVal); + } + + newOp = cloneAndSetResultType(op); + if (isa(op)) { + // map load result + auto v = op->getResult(0); + auto newV = newOp->getResult(0); + mappings.map(v, newV); + reverseMappings.map(newV, v); + } + } else if (auto dotOp = dyn_cast(op)) { + // Only hanlde A and accumulator + sliceOp(dim == 0 ? dotOp.getA() : dotOp.getB(), offset, builder, mappings, + reverseMappings, partitionScheme); + sliceOp(dotOp.getC(), offset, builder, mappings, reverseMappings, + partitionScheme); + newOp = cloneAndSetResultType(op); + } else if (auto forOp = dyn_cast(op)) { + // Add new loop arguments + SmallVector newLoopArgs; + for (auto initArg : forOp.getInitArgs()) + newLoopArgs.push_back(initArg); + DenseMap newArgIdices; + for (unsigned i = 0; i < forOp.getInitArgs().size(); i++) { + auto initArg = forOp.getInitArgs()[i]; + Value newInitArg; + auto newInitArgOp = sliceOp(initArg, offset, builder, mappings, + reverseMappings, partitionScheme); + if (auto bbArg = dyn_cast(initArg)) { + // find the corresponding new block argument + Block *parentBlock = bbArg.getOwner(); + unsigned argIndex = parentBlock->getNumArguments(); + for (unsigned i = 0; i < parentBlock->getNumArguments(); ++i) { + if (parentBlock->getArgument(i) == bbArg) { + argIndex = i; + break; + } + } + assert(argIndex < parentBlock->getNumArguments() && + "new init argment not found"); + Region *parentRegion = parentBlock->getParent(); + Region &newParentRegion = + newInitArgOp->getRegion(parentRegion->getRegionNumber()); + newInitArg = parentRegion->getArgument(argIndex); + } else { + auto initArgOp = initArg.getDefiningOp(); + unsigned resultIndex = cast(initArg).getResultNumber(); + newInitArg = newInitArgOp->getResult(resultIndex); + } + + if (newInitArg != initArg) { + newLoopArgs.append({newInitArg}); + forOp.getBody()->insertArgument(forOp.getBody()->getNumArguments(), + newInitArg.getType(), forOp.getLoc()); + newArgIdices[i] = newLoopArgs.size() - 1; + } + } + + // Create newForOp and take the region of forOp + builder.setInsertionPoint(op); + auto newForOp = builder.createWithAsyncTaskIds( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newLoopArgs); + assert(newForOp.getRegionIterArgs().size() == + newForOp.getInitArgs().size()); + newForOp->setAttrs(forOp->getAttrs()); + partitionScheme.ops.insert(newForOp); + newOp = newForOp; + + // Replace forOp with newForOp + newForOp.getRegion().takeBody(forOp.getRegion()); + for (unsigned i = 0; i < forOp.getNumResults(); ++i) + forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i)); + op->setAttr("to_be_removed", builder.getUnitAttr()); + + // Map new loop arguments + for (auto argIndex : newArgIdices) { + Value v = newForOp.getResult(argIndex.first); + Value newV = newForOp.getResult(argIndex.second); + mappings.map(v, newV); + reverseMappings.map(newV, v); + + auto regionArg = newForOp.getRegionIterArg(argIndex.first); + auto newRegionArg = newForOp.getRegionIterArg(argIndex.second); + mappings.map(regionArg, newRegionArg); + reverseMappings.map(newRegionArg, regionArg); + } + } else if (auto ifOp = dyn_cast(op)) { + // Slice the yield op and update if results + auto thenYieldOp = ifOp.thenYield(); + auto elseYieldOp = ifOp.elseYield(); + auto newThenYieldOp = sliceOp(thenYieldOp, offset, builder, mappings, + reverseMappings, partitionScheme); + sliceOp(elseYieldOp, offset, builder, mappings, reverseMappings, + partitionScheme); + assert(newThenYieldOp->getNumOperands() > ifOp->getNumResults() && + "no need to slice if op"); + // Clone ifOp with updated results but re-use the original regions. + builder.setInsertionPoint(op); + SmallVector newResultTypes; + for (auto thenResult : thenYieldOp.getResults()) { + newResultTypes.push_back(thenResult.getType()); + } + auto newIfOp = builder.create(ifOp.getLoc(), newResultTypes, + ifOp.getCondition()); + // Move the original regions to the cloned operation. + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + newOp = newIfOp; + newIfOp->setAttrs(ifOp->getAttrs()); + partitionScheme.ops.insert(newIfOp); + ifOp->setAttr("to_be_removed", builder.getUnitAttr()); + + // Replace ifOp with newIfOp + for (unsigned i = 0; i < ifOp.getNumResults(); ++i) + ifOp.getResult(i).replaceAllUsesWith(newIfOp.getResult(i)); + + // Map if results based on the mapping for yield + for (auto &v : thenYieldOp->getOpOperands()) { + auto newV = mappings.lookupOrNull(v.get()); + if (newV) { + int operandIndex = v.getOperandNumber(); + // find the corresponding operand index of newV in newYieldOp + int newOperandIndex = -1; + for (int i = 0; i < newThenYieldOp->getNumOperands(); ++i) { + if (newThenYieldOp->getOperand(i) == newV) { + newOperandIndex = i; + break; + } + } + assert(newOperandIndex >= 0 && "newV not found in newYieldOp"); + auto newResult = newIfOp.getResult(operandIndex); + auto newSlicedResult = newIfOp.getResult(newOperandIndex); + mappings.map(newResult, newSlicedResult); + reverseMappings.map(newSlicedResult, newResult); + } + } + } else if (auto yieldOp = dyn_cast(op)) { + int num = yieldOp.getNumOperands(); + for (int i = 0; i < num; i++) { + auto operand = yieldOp.getOperand(i); + sliceOp(operand, offset, builder, mappings, reverseMappings, + partitionScheme); + if (auto newV = mappings.lookupOrNull(operand)) + yieldOp->insertOperands(op->getNumOperands(), newV); + } + newOp = op; + } else if (auto reduceOp = dyn_cast(op)) { + assert(reduceOp.getAxis() != partitionScheme.partitionDim && + "reduce should not happen on the partitioned dimension"); + for (Value operand : op->getOperands()) + sliceOp(operand, offset, builder, mappings, reverseMappings, + partitionScheme); + newOp = cloneAndSetResultType(op); + // recursively set async task ids for child ops + newOp->walk( + [&](Operation *childOp) { setAsyncTaskIds(childOp, sliceTaskIds); }); + } else { + llvm_unreachable("unsupported op type"); + } + + LLVM_DEBUG({ + LDBG("resulting"); + newOp->dump(); + LDBG("\n"); + }); + mappings.map(op, newOp); + reverseMappings.map(newOp, op); + return newOp; +} + +Operation *sliceOp(Value v, int offset, OpBuilderWithAsyncTaskIds &builder, + IRMapping &mappings, IRMapping &reverseMappings, + DataPartitionScheme &partitionScheme) { + if (auto op = v.getDefiningOp()) { + return sliceOp(op, offset, builder, mappings, reverseMappings, + partitionScheme); + } else { + assert(isa(v) && "value is not an operation or block "); + auto bbArg = cast(v); + Operation *bbAargOwner = bbArg.getOwner()->getParentOp(); + return sliceOp(bbAargOwner, offset, builder, mappings, reverseMappings, + partitionScheme); + } +} + +void partitionTasks(triton::FuncOp &funcOp) { + + // op -> (partition dim, num of partitions) + DataPartitionScheme partitionScheme; + if (!computePartitionScheme(funcOp, partitionScheme)) + return; + + for (int i = 0; i < partitionScheme.numPartitions; i++) { + OpBuilderWithAsyncTaskIds builder(funcOp.getContext()); + IRMapping mappings, reverseMappings; + + LLVM_DEBUG({ LDBG("partitioning op for task " << i << ":\n"); }); + + // TODO: compute a topological order for partitionScheme.ops and + // slice in that order. + int numOps = partitionScheme.ops.size(); + for (int j = 0; j < numOps; j++) { + auto op = partitionScheme.ops[j]; + sliceOp(op, i, builder, mappings, reverseMappings, partitionScheme); + } + + // clean up + LLVM_DEBUG({ + LDBG("prior to clean up:"); + funcOp.dump(); + }); + SmallVector opsToDelete; + for (auto op : partitionScheme.ops) { + if (op->hasAttr("to_be_removed")) + opsToDelete.push_back(op); + } + for (auto op : opsToDelete) { + partitionScheme.ops.remove(op); + op->erase(); + } + } + + // clean up + + SmallVector opsToDelete; + for (auto op : partitionScheme.ops) { + if (isa(op)) + continue; + bool notUsed = true; + for (auto result : op->getResults()) { + if (!result.getUsers().empty()) { + notUsed = false; + break; + } + } + if (notUsed) + opsToDelete.push_back(op); + } + + LLVM_DEBUG({ + LDBG("opsToDelete:\n"); + for (auto op : opsToDelete) { + LDBG("op: "); + op->dump(); + } + LDBG("\n"); + }); + for (auto op : opsToDelete) { + partitionScheme.ops.remove(op); + op->erase(); + } + LLVM_DEBUG({ + LDBG("prior to clean up:"); + funcOp.dump(); + }); + + // delete block arguments + RewritePatternSet cleanUpPatterns(funcOp.getContext()); + populateForOpDeadArgumentElimination(cleanUpPatterns); + scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns, funcOp.getContext()); + scf::IfOp::getCanonicalizationPatterns(cleanUpPatterns, funcOp.getContext()); + if (applyPatternsAndFoldGreedily(funcOp, std::move(cleanUpPatterns)) + .failed()) { + llvm_unreachable("failed to clean up"); + // signalPassFailure(); + } + + // Make sure original ops are not used + LLVM_DEBUG({ + LDBG("after partition"); + funcOp.dump(); + LDBG("\n"); + }); + fixTaskId(funcOp); +} + +#define GEN_PASS_DEF_TRITONGPUWSDATAPARTITION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUWSDataPartitionPass + : public impl::TritonGPUWSDataPartitionBase { +public: + using impl::TritonGPUWSDataPartitionBase< + TritonGPUWSDataPartitionPass>::TritonGPUWSDataPartitionBase; + + void runOnFuncOp(triton::FuncOp funcOp) { + if (numConsumerGroups == 0) + return; + partitionTasks(funcOp); + } + + void runOnOperation() override { + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp b/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp new file mode 100644 index 000000000..d25c21778 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp @@ -0,0 +1,297 @@ +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +#include + +#include "mlir/IR/OperationSupport.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define DEBUG_TYPE "tritongpu-warp-spec-lowering" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +enum class LoadType { + LoadAsyncOp, + LoadTMAOp, +}; + +static Value createThreadIdOp(OpBuilder &builder, Location loc) { + Value threadId = builder.create<::mlir::gpu::ThreadIdOp>( + loc, builder.getIndexType(), ::mlir::gpu::Dimension::x); + auto cast = builder.create( + loc, TypeRange{builder.getIntegerType(32)}, ValueRange{threadId}); + return cast.getResult(0); +} + +// Lower to use GetCanonicalWarpIdOp. +// In Hopper, each task is a warpgroup consisting of 4 warps. +static const int WARPS_PER_TASK = 4; +static const int THREADS_PER_TASK = 128; +void lowerGetAsyncTaskIdOp(Operation *parentOp, int numConsumerGroups) { + DenseSet eraseOps; + parentOp->walk([&](ttng::GetAsyncTaskIdOp op) { + auto loc = op.getLoc(); + OpBuilder builder(op); + Value _4 = builder.create(loc, WARPS_PER_TASK, 32); + Value warpId = builder.create(loc); + Value asyncTaskId = builder.create(loc, warpId, _4); + op.getResult().replaceAllUsesWith(asyncTaskId); + + LLVM_DEBUG({ + LDBG("erasing GetAsyncTask"); + op->dump(); + }); + eraseOps.insert(op); + }); + for (Operation *op : eraseOps) + op->erase(); +} + +//===----------------------------------------------------------------------===// +// Lower token operations +//===----------------------------------------------------------------------===// + +LoadType scanLoadTypes(ttng::CreateTokenOp createTokenOp) { + std::set loadTypes; + createTokenOp->getBlock()->walk([&](Operation *op) { + if (auto asyncCopy = dyn_cast(op)) { + loadTypes.insert(LoadType::LoadAsyncOp); + } else if (auto asyncCopy = + dyn_cast(op)) { + loadTypes.insert(LoadType::LoadTMAOp); + } + }); + assert(loadTypes.size() > 0 && "no async copy in the block"); + assert(loadTypes.size() == 1 && "block contains both async copy and tma"); + return *loadTypes.begin(); +} + +Value getMBarrierPhaseBit(OpBuilder &builder, Operation *op, + bool emptyBarrier) { + auto loc = op->getLoc(); + assert(isa(op) || isa(op)); + Value curPhase; + if (auto acq = dyn_cast(op)) + curPhase = acq.getPhase(); + else if (auto wait = dyn_cast(op)) + curPhase = wait.getPhase(); + if (emptyBarrier) { + // curPhase = curPhase xor True for emptyBarrier. + Value _1_1b = builder.create(loc, 1, 1); + curPhase = builder.create(loc, curPhase, _1_1b); + } + LLVM_DEBUG(curPhase.dump()); + return curPhase; +} + +void processProducerAcquireOp(OpBuilder &builder, ttng::ProducerAcquireOp op, + Value bufferEmpty) { + auto loc = op.getLoc(); + Value phase = getMBarrierPhaseBit(builder, op, true); + auto i32Ty = builder.getIntegerType(32); + phase = builder.create(loc, i32Ty, phase); + auto waitOp = builder.create(loc, bufferEmpty, phase); + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(waitOp, getAsyncTaskIds(op.getOperation())); +} + +void processProducerCommitOp(OpBuilder &builder, ttng::ProducerCommitOp op, + Value bufferFull, LoadType loadType) { + auto loc = op.getLoc(); + int txCnt = 0; + ttng::MBarrierArriveOp arriveOp; + + if (loadType == LoadType::LoadAsyncOp) { + // Each thread arrives. + Value pred = builder.create(loc, 1, 1); + arriveOp = builder.create( + loc, bufferFull, pred, /*remoteCTAId*/ nullptr, /*trackAsyncOp*/ true, + txCnt); + } else { + // Only thread 0 arrives for TMA load. + Value _0 = builder.create(loc, 0, 32); + Value threadId = createThreadIdOp(builder, loc); + Value pred = builder.create(loc, arith::CmpIPredicate::eq, + threadId, _0); + arriveOp = builder.create( + loc, bufferFull, pred, /*remoteCTAId*/ nullptr, /*trackAsyncOp*/ false, + txCnt); + } + + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(arriveOp, getAsyncTaskIds(op.getOperation())); +} + +void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op, + Value bufferFull) { + auto loc = op.getLoc(); + Value phase = getMBarrierPhaseBit(builder, op, false); + auto i32Ty = builder.getIntegerType(32); + phase = builder.create(loc, i32Ty, phase); + auto waitOp = builder.create(loc, bufferFull, phase); + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(waitOp, getAsyncTaskIds(op.getOperation())); +} + +void processConsumerReleaseOp(OpBuilder &builder, ttng::ConsumerReleaseOp op, + Value bufferEmpty, int numCTAs) { + auto loc = op.getLoc(); + auto arriveOp = builder.create( + loc, bufferEmpty, nullptr, nullptr, false, 0); + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(arriveOp, getAsyncTaskIds(op.getOperation())); +} + +void lowerTokenOperations(Operation *parentOp, int numCTAs, + int numConsumerGroups) { + SmallVector deprecatedOps; + parentOp->walk([&](ttng::CreateTokenOp createTokenOp) { + LoadType loadType = scanLoadTypes(createTokenOp); + MLIRContext *context = createTokenOp.getContext(); + OpBuilder builder(createTokenOp); + Location loc = createTokenOp.getLoc(); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + auto barrierCTALayout = + ttg::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = + ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout); + Type barrierMemDescType = + tt::MemDescType::get({createTokenOp.getNum()}, builder.getI64Type(), + barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true); + Type singleBarrierMemDescType = + tt::MemDescType::get({1}, builder.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); + Value bufferFullArray = builder.create( + loc, barrierMemDescType, Value()); + Value bufferEmptyArray = builder.create( + loc, barrierMemDescType, Value()); + + for (unsigned i = 0; i < createTokenOp.getNum(); i++) { + Value idx = builder.create(loc, i, 32); + Value barrierFullView = builder.create( + loc, singleBarrierMemDescType, bufferFullArray, idx); + unsigned bufferFullCount = + loadType == LoadType::LoadTMAOp ? 1 : THREADS_PER_TASK; + builder.create(loc, barrierFullView, + bufferFullCount); + + Value barrierEmptyView = builder.create( + loc, singleBarrierMemDescType, bufferEmptyArray, idx); + builder.create(loc, barrierEmptyView, + THREADS_PER_TASK); + } + + assert(numCTAs == 1 && "remote CTA is not supported yet"); + builder.create(loc); + + // Helper function for extracting one index from bufferFullArray. + auto extractBufferFull = [&](Location loc, Value idx) -> Value { + return builder.create( + loc, singleBarrierMemDescType, bufferFullArray, idx); + }; + + // Helper function for extracting one index from bufferEmptyArray. + auto extractBufferEmpty = [&](Location loc, Value idx) -> Value { + return builder.create( + loc, singleBarrierMemDescType, bufferEmptyArray, idx); + }; + + // Process token users: ProducerAcquireOp, ProducerCommitOp, ConsumerWaitOp, + // and ConsumerReleaseOp. + for (Operation *user : createTokenOp.getResult().getUsers()) { + auto loc = user->getLoc(); + builder.setInsertionPoint(user); + if (auto op = dyn_cast(user)) { + Value bufferEmpty = extractBufferEmpty(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferEmpty.getDefiningOp(), getAsyncTaskIds(user)); + processProducerAcquireOp(builder, op, bufferEmpty); + } else if (auto op = dyn_cast(user)) { + Value bufferFull = extractBufferFull(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferFull.getDefiningOp(), getAsyncTaskIds(user)); + processProducerCommitOp(builder, op, bufferFull, loadType); + } else if (auto op = dyn_cast(user)) { + Value bufferFull = extractBufferFull(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferFull.getDefiningOp(), getAsyncTaskIds(user)); + processConsumerWaitOp(builder, op, bufferFull); + } else if (auto op = dyn_cast(user)) { + Value bufferEmpty = extractBufferEmpty(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferEmpty.getDefiningOp(), getAsyncTaskIds(user)); + processConsumerReleaseOp(builder, op, bufferEmpty, numCTAs); + } else { + llvm_unreachable("Unexpected user of token"); + } + deprecatedOps.push_back(user); + } + + deprecatedOps.push_back(createTokenOp); + }); + for (auto op : deprecatedOps) { + op->erase(); + } + + assert(numCTAs == 1 && "remote CTA is not supported yet"); +} + +#define GEN_PASS_DEF_TRITONGPUWSLOWERING +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// This pass lowers WS-specific operations. +class TritonGPUWSLowering + : public impl::TritonGPUWSLoweringBase { +public: + using impl::TritonGPUWSLoweringBase< + TritonGPUWSLowering>::TritonGPUWSLoweringBase; + + void runOnOperation() override { + // Disable WarpSpec if numConsumerGroups is zero. + if (numConsumerGroups == 0) + return; + ModuleOp mod = getOperation(); + int numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod); + + lowerGetAsyncTaskIdOp(mod, numConsumerGroups); + lowerTokenOperations(mod, numCTAs, numConsumerGroups); + + // We assume number of warps per warp group is 4. + // With Warp Spec, the effective warps per CTA is + // number of warp groups * 4, but within each warp group, layout will use + // num_warps of 4, since tensors are not distributed between the groups. + // + // Loads usually happen in one producer warp groups. num_warps of 4 makes + // sense because only the 4 warps from the producer warp group are + // participating in the load. + // + // But at some point (at least when we launch the kernel!) we really do need + // to know that the CTA has 8 or 12 warps in it. Attribute + // "num-warp-groups-per-cta" can be used to calculate the total number of + // warps. + auto builder = OpBuilder::atBlockBegin(mod.getBody()); + mod->setAttr("triton_gpu.num-warp-groups-per-cta", + builder.getI32IntegerAttr(1 + numConsumerGroups)); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/WSTaskPartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSTaskPartition.cpp new file mode 100644 index 000000000..816ec0917 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/WSTaskPartition.cpp @@ -0,0 +1,167 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define DEBUG_TYPE "tritongpu-warp-task-partition" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +#define GEN_PASS_DEF_TRITONGPUWSTASKPARTITION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +struct TaskSchedule { + unsigned numTasks = 0; + DenseMap opToTaskId; +}; + +// Compute a partition schedule for later passes to actually partition the +// program into async tasks. +void doPartition(triton::FuncOp &funcOp, unsigned numConsumerGroups) { + + // Bail out in the presence of user annotations. + DenseSet allAsyncTasks; + funcOp->walk([&](Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + allAsyncTasks.insert(asyncTasks.begin(), asyncTasks.end()); + }); + + if (!allAsyncTasks.empty()) + return; + + SmallVector loops; + SmallVector loads; + SmallVector dots; + + funcOp.walk([&](Operation *op) { + if (scf::ForOp forOp = dyn_cast(op)) + loops.push_back(forOp); + else if (isa(op)) + dots.push_back(op); + else if (isa(op)) + loads.push_back(op); + }); + + if (loops.empty() || loads.empty() || dots.empty()) + return; + + auto getLoopLevel = [&](Operation *op) { + // Compute loop depth + unsigned depth = 0; + Operation *parent = op->getParentOp(); + while (parent) { + if (isa(parent)) { + ++depth; + } + parent = parent->getParentOp(); + } + return depth; + }; + + // Step 1. Select loads into the first task, which is the producer task by + // default. Place dots into the second task, which is the consumer. + // Only consider loads that are connected to a dot op in a loop. + SmallVector producerOps; + SmallVector consumerOps; + for (auto op : dots) { + if (getLoopLevel(op) == 0) + continue; + consumerOps.push_back(op); + auto dotOp = dyn_cast(op); + if (!dotOp) + continue; + SetVector backwardSlice; + getBackwardSlice(dotOp.getA(), &backwardSlice); + getBackwardSlice(dotOp.getB(), &backwardSlice); + + for (auto depOp : backwardSlice) { + if (isa(depOp)) { + producerOps.push_back(depOp); + } + } + } + + LLVM_DEBUG({ + LDBG("Producer ops:\n"); + for (auto op : producerOps) { + op->dump(); + } + + LDBG("\n"); + LDBG("Consumer ops:\n"); + for (auto op : consumerOps) { + op->dump(); + } + + LDBG("\n"); + }); + + if (consumerOps.empty() || producerOps.empty()) + return; + + // Annoate the program with task ids + SmallVector producerTaskIds{0}; + SmallVector consumerTaskIds; + for (unsigned i = 0; i < numConsumerGroups; ++i) { + consumerTaskIds.push_back(i + producerTaskIds.size()); + } + + for (auto op : producerOps) { + setAsyncTaskIds(op, producerTaskIds); + } + + for (auto op : consumerOps) { + setAsyncTaskIds(op, consumerTaskIds); + } + + LLVM_DEBUG({ + LDBG("After task partition"); + funcOp.dump(); + LDBG("\n"); + }); +} + +class TritonGPUWSTaskPartitionPass + : public impl::TritonGPUWSTaskPartitionBase { +public: + using impl::TritonGPUWSTaskPartitionBase< + TritonGPUWSTaskPartitionPass>::TritonGPUWSTaskPartitionBase; + + void runOnFuncOp(triton::FuncOp funcOp) { + if (numConsumerGroups == 0) + return; + doPartition(funcOp, numConsumerGroups); + } + + void runOnOperation() override { + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 0b06ee643..088ca2663 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -32,8 +32,8 @@ namespace mlir { namespace triton { namespace nvidia_gpu { -// -- DotAsyncOp -- -mlir::LogicalResult DotAsyncOp::inferReturnTypes( +// -- WarpGroupDotOp -- +mlir::LogicalResult WarpGroupDotOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { @@ -57,21 +57,33 @@ mlir::LogicalResult DotAsyncOp::inferReturnTypes( return mlir::success(); } -void DotAsyncOp::getEffects( +void WarpGroupDotOp::getEffects( SmallVectorImpl> &effects) { - auto a = getA(); - auto b = getB(); - if (isa(a.getType())) - effects.emplace_back(MemoryEffects::Read::get(), a, + auto &a = getAMutable(); + auto &b = getBMutable(); + if (isa(a.get().getType())) + effects.emplace_back(MemoryEffects::Read::get(), &a, mlir::triton::gpu::SharedMemory::get()); - if (isa(b.getType())) - effects.emplace_back(MemoryEffects::Read::get(), b, + if (isa(b.get().getType())) + effects.emplace_back(MemoryEffects::Read::get(), &b, mlir::triton::gpu::SharedMemory::get()); } -// -- DotWaitOp -- -LogicalResult DotWaitOp::inferReturnTypes( +bool WarpGroupDotOp::needsPartialAccumulator() { + const auto &a = getA(); + const auto &d = getD(); + auto aTensorTy = cast(a.getType()); + auto aElTy = cast(a.getType()).getElementType(); + bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() || + aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ(); + bool accFP32 = cast(d.getType()).getElementType().isF32(); + uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc(); + return isFP8 && accFP32 && maxNumImpreciseAcc <= aTensorTy.getShape()[1]; +} + +// -- WarpGroupDotWaitOp -- +LogicalResult WarpGroupDotWaitOp::inferReturnTypes( ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, @@ -81,6 +93,19 @@ LogicalResult DotWaitOp::inferReturnTypes( return mlir::success(); } +///--- Async related ops --- +void GetAsyncTaskIdOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state) { + build(builder, state, builder.getI32Type()); +} + +void CreateTokenOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state, uint32_t num) { + auto tokenType = TokenType::get(builder.getContext()); + auto resultType = RankedTensorType::get({num}, tokenType); + build(builder, state, resultType, num); +} + static LogicalResult verifyBarrierType(Operation *op, MemDescType barrierType) { if (!barrierType.getElementType().isInteger(64) || barrierType.getShape() != ArrayRef({1})) @@ -99,7 +124,7 @@ LogicalResult InitBarrierOp::verify() { void InitBarrierOp::getEffects( SmallVectorImpl> &effects) { - effects.emplace_back(MemoryEffects::Write::get(), getAlloc(), + effects.emplace_back(MemoryEffects::Write::get(), &getAllocMutable(), mlir::triton::gpu::SharedMemory::get()); } @@ -113,7 +138,7 @@ LogicalResult InvalBarrierOp::verify() { void InvalBarrierOp::getEffects( SmallVectorImpl> &effects) { - effects.emplace_back(MemoryEffects::Write::get(), getAlloc(), + effects.emplace_back(MemoryEffects::Write::get(), &getAllocMutable(), mlir::triton::gpu::SharedMemory::get()); } @@ -127,7 +152,7 @@ LogicalResult BarrierExpectOp::verify() { void BarrierExpectOp::getEffects( SmallVectorImpl> &effects) { - effects.emplace_back(MemoryEffects::Write::get(), getAlloc(), + effects.emplace_back(MemoryEffects::Write::get(), &getAllocMutable(), mlir::triton::gpu::SharedMemory::get()); } @@ -141,12 +166,11 @@ LogicalResult WaitBarrierOp::verify() { void WaitBarrierOp::getEffects( SmallVectorImpl> &effects) { - effects.emplace_back(MemoryEffects::Read::get(), getAlloc(), + // The wait will flip the phase therefore it reads and writes the barrier. + effects.emplace_back(MemoryEffects::Read::get(), &getAllocMutable(), + mlir::triton::gpu::SharedMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getAllocMutable(), mlir::triton::gpu::SharedMemory::get()); - // Need a side effect to prevent compiler from reordering and removing - // the wait operation. - effects.emplace_back(MemoryEffects::Write::get(), - mlir::SideEffects::DefaultResource::get()); } // -- AsyncTMACopyGlobalToLocalOp -- @@ -155,17 +179,19 @@ LogicalResult AsyncTMACopyGlobalToLocalOp::verify() { return failure(); if (getCoord().size() < 1 || getCoord().size() > 5) return emitOpError("TMA copies must have between 1 and 5 coordinates"); + if (!getResult().getType().getMutableMemory()) + return emitOpError("Cannot store into immutable memory"); return success(); } void AsyncTMACopyGlobalToLocalOp::getEffects( SmallVectorImpl> &effects) { - effects.emplace_back(MemoryEffects::Read::get(), getDescPtr(), + effects.emplace_back(MemoryEffects::Read::get(), &getDescPtrMutable(), mlir::triton::GlobalMemory::get()); - effects.emplace_back(MemoryEffects::Write::get(), getBarrier(), + effects.emplace_back(MemoryEffects::Write::get(), &getBarrierMutable(), mlir::triton::gpu::SharedMemory::get()); - effects.emplace_back(MemoryEffects::Write::get(), getResult(), + effects.emplace_back(MemoryEffects::Write::get(), &getResultMutable(), mlir::triton::gpu::SharedMemory::get()); } @@ -173,9 +199,9 @@ void AsyncTMACopyGlobalToLocalOp::getEffects( void AsyncTMACopyLocalToGlobalOp::getEffects( SmallVectorImpl> &effects) { - effects.emplace_back(MemoryEffects::Write::get(), getDescPtr(), + effects.emplace_back(MemoryEffects::Write::get(), &getDescPtrMutable(), mlir::triton::GlobalMemory::get()); - effects.emplace_back(MemoryEffects::Read::get(), getSrc(), + effects.emplace_back(MemoryEffects::Read::get(), &getSrcMutable(), mlir::triton::gpu::SharedMemory::get()); } diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt index 5adebc352..001d96214 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_triton_library(TritonNvidiaGPUTransforms FenceInsertion.cpp PlanCTA.cpp TMALowering.cpp + Utility.cpp DEPENDS TritonNvidiaGPUTransformsIncGen diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index c7dd8d595..fb0e7f6fd 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -44,7 +44,7 @@ struct FenceInsertionPass return; ModuleOp mod = getOperation(); mod.walk([&](Operation *op) { - if (!isa(op)) + if (!isa(op)) return WalkResult::advance(); OpBuilder builder(op); auto a = op->getOperand(0); @@ -79,7 +79,7 @@ struct FenceInsertionPass static DenseSet> trace; auto op = operand.getDefiningOp(); // avoid redundant insertion - if (op && isa(op)) + if (op && op->hasTrait()) return false; // reach convertlayout if (op && isa(op) && diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 58e2888b7..64a7f3e30 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -1,7 +1,11 @@ +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" @@ -22,7 +26,9 @@ class TMALoadLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExperimentalDescriptorLoadOp op, - PatternRewriter &rewriter) const override { + PatternRewriter &baseRewriter) const override { + MLIRContext *ctx = op.getContext(); + Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx); auto loc = op.getLoc(); auto tensorType = op.getResult().getType(); auto order = getOrder(tensorType.getEncoding()); @@ -36,17 +42,18 @@ class TMALoadLowering : public OpRewritePattern { } MemDescType memDescType = MemDescType::get(tensorType.getShape(), tensorType.getElementType(), - encoding, /*mutableMemory=*/true); - Value alloc = rewriter.create(loc, memDescType, Value()); + encoding, sharedMemorySpace, /*mutableMemory=*/true); + PatternRewriterWithAsyncTaskIds rewriter(baseRewriter, op); + Value alloc = rewriter.create(loc, memDescType); auto barrierCTALayout = CTALayoutAttr::get( /*context=*/tensorType.getContext(), /*CTAsPerCGA=*/{1}, /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); auto barrierEncoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1, 1, {0}, barrierCTALayout); - MemDescType barrierMemDescType = MemDescType::get( - {1}, rewriter.getI64Type(), barrierEncoding, /*mutableMemory=*/true); - Value barrierAlloc = - rewriter.create(loc, barrierMemDescType, Value()); + MemDescType barrierMemDescType = + MemDescType::get({1}, baseRewriter.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); + Value barrierAlloc = rewriter.create(loc, barrierMemDescType); rewriter.create(loc, barrierAlloc, 1); int sizeInBytes = product(tensorType.getShape()) * tensorType.getElementType().getIntOrFloatBitWidth() / 8; @@ -70,6 +77,8 @@ class TMAStoreLowering LogicalResult matchAndRewrite(ExperimentalDescriptorStoreOp op, PatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx); auto loc = op.getLoc(); auto tensorType = op.getSrc().getType(); auto order = getOrder(tensorType.getEncoding()); @@ -83,12 +92,19 @@ class TMAStoreLowering } MemDescType memDescType = MemDescType::get(tensorType.getShape(), tensorType.getElementType(), - encoding, /*mutableMemory=*/true); - Value alloc = rewriter.create(loc, memDescType, op.getSrc()); - rewriter.create(loc, false); - rewriter.create( - loc, op.getDescPtr(), op.getIndices(), alloc); - rewriter.create(loc, 0); + encoding, sharedMemorySpace, /*mutableMemory=*/true); + auto alloc = rewriter.create(loc, memDescType, op.getSrc()); + auto attrs = op->getAttrs(); + alloc->setAttrs(attrs); + auto fence = + rewriter.create(loc, false); + fence->setAttrs(attrs); + auto asyncCopy = + rewriter.create( + loc, op.getDescPtr(), op.getIndices(), alloc); + asyncCopy->setAttrs(attrs); + auto tma_wait = rewriter.create(loc, 0); + tma_wait->setAttrs(attrs); rewriter.eraseOp(op); return success(); } diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp new file mode 100644 index 000000000..ba46d52eb --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp @@ -0,0 +1,88 @@ + +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" +#include + +namespace mlir { + +namespace ttg = triton::gpu; + +//===----------------------------------------------------------------------===// +// Helper functions for async task +//===----------------------------------------------------------------------===// + +SmallVector getAsyncTaskIds(Operation *op) { + SmallVector asyncTaskIds; + if (auto attr = op->getAttrOfType("async_task_id")) + for (AsyncTaskId asyncTaskId : attr.getValues()) + asyncTaskIds.push_back(asyncTaskId); + return asyncTaskIds; +} + +bool hasAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId) { + for (AsyncTaskId candidate : getAsyncTaskIds(op)) + if (candidate == asyncTaskId) + return true; + return false; +} + +void setAsyncTaskIds(Operation *op, ArrayRef asyncTaskIds) { + SmallVector sortedAsyncTaskIds(asyncTaskIds.begin(), + asyncTaskIds.end()); + sort(sortedAsyncTaskIds); + auto i32Ty = IntegerType::get(op->getContext(), 32); + auto size = static_cast(sortedAsyncTaskIds.size()); + auto vecTy = VectorType::get(size, i32Ty); + op->setAttr("async_task_id", + DenseIntElementsAttr::get(vecTy, sortedAsyncTaskIds)); +} + +SmallVector getNestedAsyncTaskIds(Operation *op) { + SetVector asyncTaskIds; + op->walk([&](Operation *curOp) { + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(curOp)) + asyncTaskIds.insert(asyncTaskId); + }); + SmallVector res(asyncTaskIds.begin(), asyncTaskIds.end()); + llvm::sort(res); + return res; +} + +void addAsyncTaskIds(Operation *op, ArrayRef asyncTasks) { + auto asyncTasksVec = getAsyncTaskIds(op); + DenseSet asyncTasksSet(asyncTasksVec.begin(), asyncTasksVec.end()); + for (int a : asyncTasks) { + if (!asyncTasksSet.contains(a)) { + asyncTasksVec.push_back(a); + } + } + if (asyncTasksVec.size() > 0) { + setAsyncTaskIds(op, asyncTasksVec); + } +} + +void removeAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId) { + auto origAsyncTaskIds = getAsyncTaskIds(op); + auto end = std::remove(origAsyncTaskIds.begin(), origAsyncTaskIds.end(), + asyncTaskId); + origAsyncTaskIds.erase(end, origAsyncTaskIds.end()); + if (origAsyncTaskIds.empty()) + op->removeAttr("async_task_id"); + else + setAsyncTaskIds(op, origAsyncTaskIds); +} + +void removeAsyncTaskIds(Operation *op) { op->removeAttr("async_task_id"); } +//===----------------------------------------------------------------------===// +// Implementations for general auto WS +//===----------------------------------------------------------------------===// + +} // namespace mlir diff --git a/lib/Target/LLVMIR/LLVMDIScope.cpp b/lib/Target/LLVMIR/LLVMDIScope.cpp index af7079060..4aa9828cd 100644 --- a/lib/Target/LLVMIR/LLVMDIScope.cpp +++ b/lib/Target/LLVMIR/LLVMDIScope.cpp @@ -103,9 +103,9 @@ struct LLVMDIScopePass : public LLVMDIScopeBase { // the column offset auto subprogramAttr = LLVM::DISubprogramAttr::get( context, distinctId, compileUnitAttr, fileAttr, funcNameAttr, - funcNameAttr, fileAttr, - /*line=*/line, - /*scopeline=*/line, subprogramFlags, subroutineTypeAttr); + funcNameAttr, fileAttr, /*line=*/line, /*scopeline=*/line, + subprogramFlags, subroutineTypeAttr, /*retainNodes=*/{}, + /*annotations=*/{}); funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr)); } diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 75e530db5..bf017f8c6 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -1,18 +1,44 @@ #include "triton/Tools/LinearLayout.h" #include +#include #include #include "mlir/IR/BuiltinAttributes.h" #include "third_party/f2reduce/f2reduce.h" #include "triton/Tools/StrUtil.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" +#define DEBUG_TYPE "linear_layout" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +#if defined(_MSC_VER) && !defined(__clang__) +// from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0 +#include + +static int __builtin_ctz(unsigned x) { + unsigned long r; + _BitScanForward(&r, x); + return static_cast(r); +} + +static int __builtin_ctzll(unsigned long long x) { + unsigned long r; + _BitScanForward64(&r, x); + return static_cast(r); +} + +#endif + namespace mlir::triton { namespace { using BasesT = LinearLayout::BasesT; +using llvm::SmallDenseSet; using llvm::Twine; BasesT makeBasesMap( @@ -24,115 +50,104 @@ BasesT makeBasesMap( return ret; } -std::string stringifyBases(const BasesT &bases, - ArrayRef outDimNames) { - std::string ret; - - if (bases.empty()) - return "(empty layout)\n"; - - // TODO: Add spaces for alignment. - for (const auto &[inDim, inDimBases] : bases) { - if (inDimBases.empty()) { - ret += " - " + inDim.str() + " is a size 1 dimension\n"; - continue; +// Dump the matrix to stderr in a human-readable format for debugging. +void dumpMatrix(uint64_t *m, int numRows, int numCols) { + assert(numCols <= 64); + for (int r = 0; r < numRows; r++) { + llvm::errs() << "0b"; + for (int c = 0; c < numCols; c++) { + llvm::errs() << ((m[r] & (1 << c)) != 0 ? "1" : "0"); } - - ret += " - " + - join(llvm::seq(inDimBases.size()), "\n ", - [&, &inDim = inDim, &inDimBases = inDimBases](int i) { - return inDim.str() + "=" + std::to_string(1 << i) + " -> (" + - join(inDimBases[i], ", ") + ")"; - }) + - "\n"; + llvm::errs() << "\n"; } - ret += "where out dims are: [" + - join(outDimNames, ", ", [](StringAttr s) { return s.str(); }) + "]\n"; - return ret; } -BasesT validateBases(BasesT bases, ArrayRef outDimNames) { - if (bases.empty()) - return bases; +// Build a matrix of size sum(outDimSizeLog2) x sum(inDimSizeLog2) representing +// the bases of the given layout. This can then be used by f2reduce. +// +// This function is called from the constructor of LinearLayout, so be careful +// not to use any functions that create LLs in here. +std::unique_ptr getMatrix(const LinearLayout &layout) { + int numRows = layout.getTotalOutDimSizeLog2(); + int numCols = layout.getTotalInDimSizeLog2(); - for (const auto &[inDim, inDimBases] : bases) { - for (const auto &basis : inDimBases) { - if (llvm::any_of(basis, [](int32_t b) { return b < 0; })) { - llvm::report_fatal_error( - "Invalid bases passed to LinearLayout. Expected all basis " - "values to be non-negative, but found a negative value for " - "in dimension '" + - Twine(inDim) + "'. Full list of bases:\n" + - stringifyBases(bases, outDimNames)); + // Don't handle giant LLs. This makes some things easier; for example, each + // row can be a single uint64_t. + assert(numCols <= 64 && "LinearLayout too large"); + assert(numRows <= 64 && "LinearLayout too large"); + + // Suppose we have a layout specified by the following values. + // + // L(0,1) = (0b01, 0b1) + // L(0,2) = (0b10, 0b0) + // L(1,0) = (0b10, 0b0) + // L(2,0) = (0b11, 0b0) + // + // We will create one column per entry above. The max bit width of the + // codomain is (2,1), so our matrix will have 2+1=3 rows. The final matrix + // will be + // + // | L(0,1)[0] L(0,2)[0] L(1,0)[0] L(2,0)[0] | | 0b1001 | + // | ↓ ↓ ↓ ↓ | | 0b0111 | + // | L(0,1)[1] L(0,2)[1] L(1,0)[1] L(2,0)[1] | = | 0b1000 | + // | ↓ ↓ ↓ ↓ | + // + // Note `new uint64_t[n]()` is zero-initialized, but `new uint64_t[n]` is not. + std::unique_ptr m(new uint64_t[numRows]()); + int r = 0; + for (StringAttr outDim : layout.getOutDimNames()) { + int c = 0; + for (StringAttr inDim : layout.getInDimNames()) { + for (int i = 0; i < layout.getInDimSizeLog2(inDim); i++) { + uint64_t basis = layout.getBasis(inDim, i, outDim); + for (int j = 0; j < layout.getOutDimSizeLog2(outDim); j++) { + m[r + j] |= ((basis >> j) & 1) << c; + } + c++; } } + r += layout.getOutDimSizeLog2(outDim); } - // Check that the bases all have length equal to outDimNames.size(). - for (const auto &[inDim, inDimBases] : bases) { - for (const auto &basis : inDimBases) { - if (basis.size() != outDimNames.size()) { - llvm::report_fatal_error( - "Invalid bases passed to LinearLayout. Expect all bases to have " - "the same size, equal to outDimNames.size() (" + - Twine(outDimNames.size()) + - "). But this failed for in dimension '" + Twine(inDim) + - "'. Full list of bases:\n" + stringifyBases(bases, outDimNames)); - } + return m; +} + +// Get a matrix for `layout` with its codomain expanded so it's injective, i.e. +// each input element maps to a unique output element. We do this by finding +// columns that are equal to 0 and adding a new row with a 1 in that column. +std::tuple, int /*numRows*/, int /*numCols*/> +getInjectiveMat(const LinearLayout &layout) { + int numRows = layout.getTotalOutDimSizeLog2(); + int numCols = layout.getTotalInDimSizeLog2(); + std::unique_ptr mat = getMatrix(layout); + + // Bits of mat or-reduced along the columns (so there's just one row). + uint64_t colBits = 0; + for (int r = 0; r < numRows; r++) { + colBits |= mat[r]; + } + auto expanded = std::unique_ptr(new uint64_t[numRows + numCols]); + std::memcpy(expanded.get(), mat.get(), numRows * sizeof(uint64_t)); + for (int c = 0; c < numCols; c++) { + if ((colBits & (1 << c)) == 0) { + expanded[numRows++] = (1 << c); } } - - return bases; + return std::make_tuple(std::move(expanded), numRows, numCols); } // Compute the rank of the matrix formed by taking the bases for the given // outDim as columns. In other words, finds the number of linearly-independent // bases for this output dimension. -int getMatrixRank(const LinearLayout &layout, StringAttr outDim) { - // Suppose we have a layout specified by the following key values. - // - // L(0,1) = 0b01 - // L(0,2) = 0b10 - // L(1,0) = 0b10 - // L(2,0) = 0b11 - // - // We will create one column per key value. The max bit width of these values - // is 2, so our matrix will have 2 rows. The final matrix will be - // - // | ↑ ↑ ↑ ↑ | | 0b0111 | - // | L(0,1) L(0,2) L(1,0) L(2,0) | = | 0b1001 | - // | ↓ ↓ ↓ ↓ | - int numRows = layout.getOutDimSizeLog2(outDim); - - int numCols = 0; - for (StringAttr inDim : layout.getInDimNames()) { - numCols += layout.getInDimSizeLog2(inDim); - } - - if (numCols == 0 || numRows == 0) +int getMatrixRank(std::unique_ptr m, int numRows, int numCols) { + // f2reduce underflows if the number of cols is 0, return the rank early in + // this case. + if (numCols == 0) { return 0; - - // Don't handle giant LLs. This makes some things easier; for example, each - // row can be a single uint64_t. - assert(numCols <= 64 && "LinearLayout too large"); - assert(numRows <= 64 && "LinearLayout too large"); - - // Note that `new int[n]()` is zero-initialized, whereas `new int[n]` is not. - std::unique_ptr m(new uint64_t[numRows]()); - - // Fill in the matrix. - int c = 0; - for (StringAttr inDim : layout.getInDimNames()) { - for (int i = 0; i < layout.getInDimSizeLog2(inDim); i++) { - uint64_t basis = layout.getBasis(inDim, i, outDim); - for (int j = 0; j < numRows; j++) { - m[j] |= ((basis >> j) & 1) << c; - } - c++; - } } - - // stride is specified in number of 64-bit words per row. + // stride is specified in number of 64-bit words per row, and we pack our + // matrix so that there's only one uint64_t per row. + assert(numCols <= 64); f2reduce::inplace_rref_strided(m.get(), numRows, numCols, /*stride=*/1); // The rank of the reduced matrix is simply the number of nonzero rows. @@ -144,40 +159,10 @@ int getMatrixRank(const LinearLayout &layout, StringAttr outDim) { return rank; } -// Check that the given layout is surjective, i.e. that every `out` coordinate -// can be reached by some `in` coordinate. -// -// It's sufficient to check each output dimension indepedently. Still, -// it's prohibitively slow to calculate this naively. -// -// Thankfully, this is equivalent to checking that the number of -// linearly-independent bases for outDim d is equal to getOutDimSizeLog2(d). -// This can be computed by finding the rank of the matrix whose columns are -// those bases. We can compute the rank of our matrix using Gaussian -// elimination, which runs in O(n^3) for an n x n matrix. Our matrix size is -// log(product(inDimSize)) x log(outDimSize), and we do this numOutDims times, -// so this should be plenty fast overall. -void validateSurjectivity(const LinearLayout &layout) { - for (const auto &outDim : layout.getOutDimNames()) { - unsigned rank = getMatrixRank(layout, outDim); - unsigned expectedRank = layout.getOutDimSizeLog2(outDim); - if (rank != expectedRank) { - llvm::report_fatal_error( - "Invalid bases passed to LinearLayout. Expected bases to be " - "surjective, i.e. all possible output coordinates can be reached " - "by some input coordinates. But this failed for output dimension " + - Twine(outDim) + ", where we got rank " + Twine(rank) + - " instead of expected rank " + Twine(expectedRank) + - ". Full list of bases:\n" + - Twine(stringifyBases(layout.getBases(), layout.getOutDimNames()))); - } - } -} - template void assertDimsEqualIgnoringOrder(T &&a, U &&b) { - llvm::DenseSet as(a.begin(), a.end()); - llvm::DenseSet bs(b.begin(), b.end()); + SmallDenseSet as(a.begin(), a.end()); + SmallDenseSet bs(b.begin(), b.end()); if (as != bs) { llvm::report_fatal_error("Dimensions must match, ignoring order, but they " "don't. Got dims: [" + @@ -186,12 +171,209 @@ void assertDimsEqualIgnoringOrder(T &&a, U &&b) { } } +template +void assertDimsSubsetIgnoringOrder(T &&small, U &&big) { + SmallDenseSet smallSet(small.begin(), small.end()); + SmallDenseSet bigSet(big.begin(), big.end()); + if (!llvm::set_is_subset(smallSet, bigSet)) { + llvm::report_fatal_error("Dimensions must be a subset, ignoring order, but " + "they aren't. Got dims: [" + + Twine(triton::join(small, ", ")) + "] and [" + + triton::join(big, ", ") + "]"); + } +} + +// Check that elements common to both aDims and bDims +// appear in the same relative order. +template +void assertCommonDimsSameOrder(T &&aDims, U &&bDims) { + SmallDenseSet aDimsSet(aDims.begin(), aDims.end()); + SmallDenseSet bDimsSet(bDims.begin(), bDims.end()); + + std::vector aCommonDims; + for (StringAttr dim : aDims) { + if (bDimsSet.contains(dim)) { + aCommonDims.push_back(dim); + } + } + + std::vector bCommonDims; + for (StringAttr dim : bDims) { + if (aDimsSet.contains(dim)) { + bCommonDims.push_back(dim); + } + } + + if (aCommonDims != bCommonDims) { + llvm::report_fatal_error("All a/b dimensions common to both layouts " + "must appear in the same relative order, but they " + "don't.\na:" + + Twine(triton::join(aDims, ", ")) + + "\nb: " + triton::join(bDims, ", ")); + } +} + +void eraseEmptyInOutDims(BasesT &bases, + llvm::MapVector &outDims) { + // Erase empty out-dims. + SmallVector emptyOutDims; + for (auto [i, outDim] : llvm::enumerate( + llvm::to_vector_of(llvm::make_first_range(outDims)))) { + if (outDims[outDim] == 1) { + emptyOutDims.push_back(i); + outDims.erase(outDim); + } + } + if (outDims.empty()) { + bases.clear(); + return; + } + + for (auto &[inDim, inDimBases] : bases) { + for (auto &basis : inDimBases) { + // Erase the basis elements corresponding to the empty out-dims. + for (int i : llvm::reverse(emptyOutDims)) { + basis.erase(basis.begin() + i); + } + } + } + + // Erase empty in-dims. + // TODO: This needs a test-case. + for (StringAttr inDim : + llvm::to_vector_of(llvm::make_first_range(bases))) { + if (bases[inDim].empty()) { + bases.erase(inDim); + } + } +} + } // anonymous namespace +/*static*/ std::optional +LinearLayout::tryCreate(BasesT bases, + ArrayRef> outDims, + bool requireSurjective) { + LinearLayout ll(std::move(bases), std::move(outDims), NoCheckInvariants{}); + std::optional error = ll.checkInvariants(requireSurjective); + if (error) { + return std::nullopt; + } + return ll; +} + +LinearLayout::LinearLayout(BasesT bases, + ArrayRef> outDims, + NoCheckInvariants) + : bases(std::move(bases)) { + for (auto [outDim, size] : outDims) { + this->outDims[outDim] = size; + } +} + LinearLayout::LinearLayout(BasesT bases, ArrayRef outDimNames) - : bases(validateBases(std::move(bases), outDimNames)), - outDimNames(outDimNames.begin(), outDimNames.end()) { - validateSurjectivity(*this); + : bases(std::move(bases)) { + // Infer out-dim sizes. + for (StringAttr outDim : outDimNames) { + outDims[outDim] = 1; + } + for (const auto &[inDim, inDimBases] : this->bases) { + for (const auto &basis : inDimBases) { + for (int i = 0; i < basis.size(); i++) { + int32_t &size = outDims[outDimNames[i]]; + size = std::max(size, llvm::NextPowerOf2(basis[i])); + } + } + } + + std::optional error = + checkInvariants(/*requireSurjective=*/true); + if (error.has_value()) { + llvm::report_fatal_error(StringRef(*error)); + } +} + +LinearLayout::LinearLayout(BasesT bases, + ArrayRef> outDims, + bool requireSurjective) + : LinearLayout(std::move(bases), std::move(outDims), NoCheckInvariants{}) { + std::optional error = checkInvariants(requireSurjective); + if (error.has_value()) { + llvm::report_fatal_error(StringRef(*error)); + } +} + +std::optional +LinearLayout::checkInvariants(bool requireSurjective) { + LDBG("checkInvariants: " << toString()); + // Check that basis values are non-negative. + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + if (llvm::any_of(basis, [](int32_t b) { return b < 0; })) { + return "Invalid bases passed to LinearLayout. Expected all basis " + "values to be non-negative, but found a negative value for " + "in dimension '" + + inDim.str() + "'. Full list of bases:" + toString() + "\n"; + } + } + } + + // Check that the bases all have length equal to outDimNames.size(). + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + if (basis.size() != outDims.size()) { + return "Invalid bases passed to LinearLayout. Expect all bases to " + "have the same size, equal to outDimNames.size() (" + + std::to_string(outDims.size()) + + "). But this failed for in dimension '" + inDim.str() + + "'. Full list of bases:" + toString() + "\n"; + } + } + } + + // Check that the out-dim sizes are powers of 2. + for (const auto &[outDim, size] : outDims) { + if (!llvm::isPowerOf2_32(size)) { + return "Invalid out-dim size " + std::to_string(size) + " for out-dim '" + + outDim.str() + "'. Out-dim sizes must be powers of 2.\n"; + } + } + + // Check that the bases are smaller than the out-dim sizes. + SmallVector outDimNames = llvm::to_vector(getOutDimNames()); + for (const auto &[inDim, inDimBases] : this->bases) { + for (const auto &basis : inDimBases) { + for (int i = 0; i < basis.size(); i++) { + if (basis[i] >= outDims[outDimNames[i]]) { + return "Invalid basis " + std::to_string(basis[i]) + " for in-dim '" + + inDim.str() + "' and out-dim '" + outDimNames[i].str() + + "'. Basis must be less than the out-dim size.\n"; + } + } + } + } + + // Determine whether the this layout is surjective, i.e. that every `out` + // coordinate can be reached by some `in` coordinate. + // + // It's prohibitively slow to calculate this naively, but thankfully, this + // is equivalent to checking that the number of linearly-independent bases + // is equal to sum(getOutDimSizeLog2). This can be computed by finding + // the rank of the matrix whose columns are those bases. We can compute + // the rank of our matrix using Gaussian elimination, which runs in O(n^3) + // for an n x n matrix. Our matrix size is sum(inDimSizeLog2) x + // sum(outDimSizeLog2), so this should be plenty fast. + this->surjective = + getMatrixRank(getMatrix(*this), /*numRows=*/getTotalOutDimSizeLog2(), + /*numCols=*/getTotalInDimSizeLog2()) == + getTotalOutDimSizeLog2(); + + if (requireSurjective && !surjective) { + return "Layout is expected to be surjective, i.e. every `out` coordinate " + "can be reached by some `in` coordinate, but was not:" + + toString(); + } + return std::nullopt; } LinearLayout::LinearLayout( @@ -199,6 +381,11 @@ LinearLayout::LinearLayout( ArrayRef outDimNames) : LinearLayout(makeBasesMap(bases), outDimNames) {} +LinearLayout::LinearLayout( + ArrayRef>>> bases, + ArrayRef> outDims, bool requireSurjective) + : LinearLayout(makeBasesMap(bases), outDims, requireSurjective) {} + /*static*/ LinearLayout LinearLayout::identity1D(int32_t size, StringAttr inDimName, StringAttr outDimName) { @@ -228,13 +415,14 @@ LinearLayout::LinearLayout( } int32_t LinearLayout::getOutDimIndex(StringAttr outDim) const { - // Sadly SetVector doesn't provide an O(1) way to do this. - for (int i = 0; i < outDimNames.size(); ++i) { - if (outDimNames[i] == outDim) { + int i = 0; + for (auto [name, _] : outDims) { + if (name == outDim) { return i; } + i++; } - llvm::report_fatal_error("outDim " + Twine(outDim) + " is not in layout\n" + + llvm::report_fatal_error("outDim " + Twine(outDim) + " is not in layout" + toString()); } @@ -244,16 +432,55 @@ int32_t LinearLayout::getInDimSizeLog2(StringAttr inDim) const { return it->second.size(); } +int32_t LinearLayout::getTotalInDimSizeLog2() const { + return std::accumulate(getInDimNames().begin(), getInDimNames().end(), 0, + [&](int32_t acc, StringAttr inDim) { + return acc + getInDimSizeLog2(inDim); + }); +} + int32_t LinearLayout::getOutDimSizeLog2(StringAttr outDim) const { - // TODO(jlebar): Cache this? - int32_t outDimIdx = getOutDimIndex(outDim); - int32_t max = 0; + auto it = outDims.find(outDim); + assert(it != outDims.end()); + return llvm::Log2_32(it->second); +} + +int32_t LinearLayout::getTotalOutDimSizeLog2() const { + return std::accumulate(getOutDimNames().begin(), getOutDimNames().end(), 0, + [&](int32_t acc, StringAttr outDim) { + return acc + getOutDimSizeLog2(outDim); + }); +} + +int32_t LinearLayout::getNumConsecutiveInOut() const { + if (bases.empty() || getNumOutDims() == 0) + return 1; + + // Count how many of the initial bases for the first in-dim are + // (2^i, 0, ..., 0). + const auto &firstInDimBases = bases.begin()->second; + int consec = 0; + for (; consec < firstInDimBases.size(); consec++) { + const auto &basis = firstInDimBases[consec]; + if (basis[0] != (1 << consec) || + !std::all_of(basis.begin() + 1, basis.end(), + [](int32_t x) { return x == 0; })) { + break; + } + } + + // `or` together all other bases' first out-dim. + int32_t otherBits = 0; for (const auto &[inDim, inDimBases] : bases) { - for (const auto &basis : inDimBases) { - max = std::max(max, basis[outDimIdx]); + for (int i = 0; i < inDimBases.size(); i++) { + if (inDim != bases.begin()->first || i >= consec) { + otherBits |= inDimBases[i][0]; + } } } - return max == 0 ? 0 : llvm::Log2_32(max) + 1; + int32_t trailingZeros = otherBits != 0 ? __builtin_ctz(otherBits) : 31; + + return 1 << std::min(consec, trailingZeros); } LinearLayout LinearLayout::transposeIns(ArrayRef newInDims) const { @@ -263,7 +490,8 @@ LinearLayout LinearLayout::transposeIns(ArrayRef newInDims) const { for (const auto &inDim : newInDims) { newBases[inDim] = bases.find(inDim)->second; } - return LinearLayout(std::move(newBases), outDimNames.getArrayRef()); + return LinearLayout(std::move(newBases), llvm::to_vector(outDims), + surjective); } LinearLayout @@ -286,66 +514,122 @@ LinearLayout::transposeOuts(ArrayRef newOutDims) const { newInDimBases.push_back(std::move(newBasis)); } } - return LinearLayout(std::move(newBases), newOutDims); + + SmallVector> newOutDimSizes; + for (auto outDim : newOutDims) { + newOutDimSizes.push_back({outDim, getOutDimSize(outDim)}); + } + return LinearLayout(std::move(newBases), newOutDimSizes, surjective); } -LinearLayout operator*(LinearLayout inner, LinearLayout outer) { - // Check that elements common to both outerDimsRange and innerDimsRange appear - // in the same relative order. - auto checkCommonDims = [&](auto outerDimsRange, auto innerDimsRange) { - llvm::DenseSet outerDims(outerDimsRange.begin(), - outerDimsRange.end()); - llvm::DenseSet innerDims(innerDimsRange.begin(), - innerDimsRange.end()); - - std::vector outerCommonDims; - for (StringAttr dim : outerDimsRange) { - if (innerDims.contains(dim)) { - outerCommonDims.push_back(dim); - } +LinearLayout LinearLayout::reshapeIns( + ArrayRef> newInDims) const { + assert(llvm::all_of(newInDims, [&](auto &inDim) { + return llvm::isPowerOf2_32(inDim.second); + })); + assert(getTotalInDimSize() == std::accumulate(newInDims.begin(), + newInDims.end(), 1, + [&](int32_t acc, auto &inDim) { + return acc * inDim.second; + })); + + // First flatten into a single in-dimension. Then split it up according + // to `newInDims`. + SmallVector> flatBases; + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + flatBases.push_back(basis); + } + } + + BasesT newBases; + int i = 0; + for (const auto &[inDim, inDimSize] : newInDims) { + auto &newInDimBases = newBases[inDim]; + for (int j = 0; j < llvm::Log2_32(inDimSize); j++) { + newInDimBases.push_back(flatBases[i++]); } + } + return LinearLayout(std::move(newBases), llvm::to_vector(outDims), + surjective); +} - std::vector innerCommonDims; - for (StringAttr dim : innerDimsRange) { - if (outerDims.contains(dim)) { - innerCommonDims.push_back(dim); +LinearLayout LinearLayout::reshapeOuts( + ArrayRef> newOutDims) const { + assert(llvm::all_of(newOutDims, [&](auto &outDim) { + return llvm::isPowerOf2_32(outDim.second); + })); + assert(getTotalOutDimSize() == + std::accumulate( + newOutDims.begin(), newOutDims.end(), 1, + [&](int32_t acc, auto &outDim) { return acc * outDim.second; })); + + SmallVector shifts; + shifts.push_back(0); + for (StringAttr outDim : getOutDimNames()) { + shifts.push_back(shifts.back() + getOutDimSizeLog2(outDim)); + } + + // Flatten into a single out-dimension. Then split it up according to + // `newOutDims`. + llvm::MapVector> flatBases; + for (const auto &[inDim, inDimBases] : bases) { + auto &flatInBases = flatBases[inDim]; + for (const auto &basis : inDimBases) { + int b = 0; + for (int i = 0; i < basis.size(); i++) { + b += basis[i] << shifts[i]; } + flatInBases.push_back(b); } + } - if (outerCommonDims != innerCommonDims) { - llvm::report_fatal_error( - "Cannot multiply layouts. All in/out dimensions common to both " - "layouts must appear in the same relative order, but they " - "don't.\nOuter:\n" + - Twine(outer.toString()) + "\nInner:\n" + inner.toString()); + BasesT newBases; + for (const auto &[inDim, flatInBases] : flatBases) { + std::vector> &newInDimBases = newBases[inDim]; + for (int32_t b : flatInBases) { + std::vector multiDimBasis; + for (int32_t newSize : llvm::make_second_range(newOutDims)) { + multiDimBasis.push_back(b % newSize); + b /= newSize; + } + newInDimBases.push_back(std::move(multiDimBasis)); } - }; + } + + return LinearLayout(std::move(newBases), newOutDims, surjective); +} +LinearLayout operator*(LinearLayout inner, LinearLayout outer) { // Check that dims common to outer and inner have the same relative order. - checkCommonDims(outer.getInDimNames(), inner.getInDimNames()); - checkCommonDims(outer.getOutDimNames(), inner.getOutDimNames()); + assertCommonDimsSameOrder(inner.getOutDimNames(), outer.getOutDimNames()); + assertCommonDimsSameOrder(inner.getInDimNames(), outer.getInDimNames()); // Get the sizeLog2 of all input and output dimensions we're going to - // consider, in order. `inner` is more minor, so its dimensions come first. - llvm::MapVector inDimSizes; - llvm::SetVector outDimNames; + // consider, in order. `inner` is more minor, so its dimensions come + // first. + llvm::MapVector inDimSizesLog2; + llvm::MapVector outDimSizesLog2; for (const auto &layout : {inner, outer}) { for (StringAttr inDim : layout.getInDimNames()) { - inDimSizes[inDim] += layout.getInDimSizeLog2(inDim); + inDimSizesLog2[inDim] += layout.getInDimSizeLog2(inDim); } for (StringAttr outDim : layout.getOutDimNames()) { - outDimNames.insert(outDim); + outDimSizesLog2[outDim] += layout.getOutDimSizeLog2(outDim); } } + BasesT allBases; - for (auto [inDimName, inDimSize] : inDimSizes) { + for (auto [inDimName, inDimSizeLog2] : inDimSizesLog2) { std::vector> &inDimBases = allBases[inDimName]; // Fill with zeros. inDimBases = std::vector>( - inDimSize, std::vector(outDimNames.size(), 0)); + inDimSizeLog2, std::vector(outDimSizesLog2.size(), 0)); - for (auto [outDimIdx, outDimName] : llvm::enumerate(outDimNames)) { + for (auto [outDimIdx, outDimNameAndSize] : + llvm::enumerate(outDimSizesLog2)) { + auto [outDimName, outDimSize] = outDimNameAndSize; if (inner.hasInDim(inDimName) && inner.hasOutDim(outDimName)) { for (int i = 0; i < inner.getInDimSizeLog2(inDimName); i++) { inDimBases[i][outDimIdx] = inner.getBasis(inDimName, i, outDimName); @@ -365,7 +649,147 @@ LinearLayout operator*(LinearLayout inner, LinearLayout outer) { } } - return LinearLayout(std::move(allBases), outDimNames.getArrayRef()); + llvm::SmallVector> outDimSizes; + for (auto [outDim, sizeLog2] : outDimSizesLog2) { + outDimSizes.push_back({outDim, 1 << sizeLog2}); + } + return LinearLayout(std::move(allBases), outDimSizes, + inner.isSurjective() && outer.isSurjective()); +} + +bool LinearLayout::isTrivialOver(ArrayRef dimNames) const { + for (StringAttr dim : dimNames) { + if (!llvm::is_contained(getInDimNames(), dim) && + !llvm::is_contained(getOutDimNames(), dim)) { + return false; + } + } + + auto getRemainingDimNames = [&](auto allDimNames) { + SmallVector remainingDimNames; + for (StringAttr dim : allDimNames) { + if (!llvm::is_contained(dimNames, dim)) { + remainingDimNames.push_back(dim); + } + } + return remainingDimNames; + }; + SmallVector remainingInDimNames = + getRemainingDimNames(getInDimNames()); + SmallVector remainingOutDimNames = + getRemainingDimNames(getOutDimNames()); + + // Think of this as a block-matrix multiplying a vector: + // [[A, B], * [v_1, + // [C, D]] v_2] + // where v_2 is the dimNames and v_1 is the remainingInDimNames + // We can quotient out dimNames iff they don't affect the remainingInDimNames + // in the result. In other words, we want to check that B is zero, and C is + // zero, and D is the identity + return squareSublayoutIsIdentity(dimNames) && + sublayoutIsZero(remainingInDimNames, dimNames) && + sublayoutIsZero(dimNames, remainingOutDimNames); +} + +std::optional +LinearLayout::quotient(ArrayRef dimNames) const { + if (!isTrivialOver(dimNames)) { + return std::nullopt; + } + + // This should probably be even less general, where we ask inDimNames == + // outDimNames + auto getRemainingDimNames = [&](auto allDimNames) { + SmallVector remainingDimNames; + for (StringAttr dim : allDimNames) { + if (!llvm::is_contained(dimNames, dim)) { + remainingDimNames.push_back(dim); + } + } + return remainingDimNames; + }; + + SmallVector inDimNames = getRemainingDimNames(getInDimNames()); + SmallVector outDimNames = getRemainingDimNames(getOutDimNames()); + + return sublayout(inDimNames, outDimNames); +} + +LinearLayout LinearLayout::sublayout(ArrayRef inDimNames, + ArrayRef outDimNames) const { + assertDimsSubsetIgnoringOrder(inDimNames, getInDimNames()); + assertDimsSubsetIgnoringOrder(outDimNames, getOutDimNames()); + SmallDenseSet inDimSet(inDimNames.begin(), inDimNames.end()); + SmallDenseSet outDimSet(outDimNames.begin(), outDimNames.end()); + + SmallVector outDimIndicesToKeep; + for (auto [i, outDim] : llvm::enumerate(getOutDimNames())) { + if (outDimSet.contains(outDim)) { + outDimIndicesToKeep.push_back(i); + } + } + BasesT newBases; + for (auto [inDim, inDimBases] : bases) { + if (!inDimSet.contains(inDim)) { + continue; + } + auto &newInDimBases = newBases[inDim]; + for (auto &basis : inDimBases) { + auto &newBasis = newInDimBases.emplace_back(); + for (int i : outDimIndicesToKeep) { + newBasis.push_back(basis[i]); + } + } + } + + SmallVector> newOutDims; + for (auto [outDim, outDimSize] : outDims) { + if (outDimSet.contains(outDim)) { + newOutDims.push_back({outDim, outDimSize}); + } + } + return LinearLayout(std::move(newBases), std::move(newOutDims), + /*requireSurjective=*/false); +} + +bool LinearLayout::sublayoutIsZero(ArrayRef inDimNames, + ArrayRef outDimNames) const { + LinearLayout ss = sublayout(inDimNames, outDimNames); + for (auto [inDim, inDimBases] : ss.bases) { + for (auto basis : inDimBases) { + if (!llvm::all_of(basis, [](int32_t b) { return b == 0; })) { + return false; + } + } + } + return true; +} + +bool LinearLayout::squareSublayoutIsIdentity( + ArrayRef dimNames) const { + // The empty layout is the identity + if (dimNames.size() == 0) { + return true; + } + // Check that the input-output sizes are the same + LinearLayout sl = sublayout(dimNames, dimNames); + for (StringAttr dim : dimNames) { + if (getInDimSize(dim) != getOutDimSize(dim)) { + return false; + } + } + // Once the inputs and output dimensions are the same, we can just check + // that the basis for the single remaining dimension is the identity. + sl = sl.flattenIns().flattenOuts(); + int b = 0; + const auto &inDimBases = sl.bases.begin()->second; + for (auto basis : inDimBases) { + if (basis[0] != (1 << b)) { + return false; + } + b++; + } + return true; } SmallVector> @@ -388,6 +812,9 @@ LinearLayout::apply(ArrayRef> ins) const { LinearLayout LinearLayout::compose(const LinearLayout &outer) const { assertDimsEqualIgnoringOrder(getOutDimNames(), outer.getInDimNames()); + for (StringAttr outDim : getOutDimNames()) { + assert(getOutDimSize(outDim) <= outer.getInDimSize(outDim)); + } BasesT newBases; for (const auto &[inDim, inDimBases] : bases) { @@ -403,17 +830,186 @@ LinearLayout LinearLayout::compose(const LinearLayout &outer) const { std::vector(newBasesRange.begin(), newBasesRange.end())); } } - return LinearLayout(std::move(newBases), outer.getOutDimNames()); + + bool compositionIsSurjective = + isSurjective() && outer.isSurjective() && + llvm::all_of(getOutDimNames(), [&](StringAttr outDim) { + return getOutDimSize(outDim) == outer.getInDimSize(outDim); + }); + return LinearLayout(std::move(newBases), llvm::to_vector(outer.outDims), + compositionIsSurjective); +} + +LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const { + assertDimsEqualIgnoringOrder(getOutDimNames(), outer.getOutDimNames()); + for (StringAttr outDim : getOutDimNames()) { + assert(getOutDimSize(outDim) <= outer.getOutDimSize(outDim)); + } + assert(outer.isSurjective()); + + // Make both `this` and `outer` injective. We need to do this on the + // `outer` layout because we can't invert a non-injective function. We + // choose to do so on the `this` layout as well. The rest of the comment + // explains why we make that choice. + // + // Recall from the header that C = A.invertAndCompose(B) just means that + // A(x) = B(C(x)). + // + // Sometimes we may have a choice of multiple values for a particular + // C(x). For example, if A(1) = B(0) = B(1) = 0, then C(1) can be either 0 + // or 1. + // + // We want to choose C such that C(x) != 0 where possible. For example, + // suppose we are transferring from registers to registers and we have the + // following layouts. + // + // A(thread=1, block=0) = 1 + // A(thread=2, block=0) = 2 + // A(thread=0, block=1) = 0 + // + // B(thread=1, block=0) = 2 + // B(thread=2, block=0) = 1 + // B(thread=0, block=1) = 0 + // + // Notice that A and B both have the same data in each of their two + // blocks. So if we want to transfer from A to B, we don't need to cross + // blocks, which is expensive. We want A.invertAndCompose(B) to reflect + // that choice. + // + // Let A' be A with the last line changed to "=4", and similarly for B'. + // When transferring from A' to B', we can't cross blocks even if we wanted + // to, because the two blocks now have different data. But also, any + // mapping of thread+block from A' to B' is also valid for mapping from A + // to B. + // + // Thus making A and B injective encodes our desire not to cross blocks, + // or more generally our desire that C(x) != 0 where possible. + auto [matThis, numRowsThis, numColsThis] = getInjectiveMat(*this); + auto [matOuter, numRowsOuter, numColsOuter] = getInjectiveMat( + outer.transposeOuts(llvm::to_vector(this->getOutDimNames()))); + + // Concatenate `matOuter` and `matThis` horizontally (i.e. `matThis` + // is to the right of `matOuter`). + int combinedNumRows = std::max(numRowsThis, numRowsOuter); + int combinedNumCols = numColsThis + numColsOuter; + assert(combinedNumCols <= 64 && "Can't handle huge layouts"); + + std::unique_ptr m(new uint64_t[combinedNumRows]()); + for (int r = 0; r < numRowsOuter; r++) { + m[r] = matOuter[r]; + } + for (int r = 0; r < numRowsThis; r++) { + m[r] |= matThis[r] << numColsOuter; + } + + // Perform Gaussian elimination on `m`. Because `outer` was modified to + // be bijective, the first half of the matrix should be the identity + // matrix. The remaining half are the bases for the combined + // transformation. + // + // `stride` is specified in number of 64-bit words per row, and we pack + // our matrix so that there's only one uint64_t per row. + f2reduce::inplace_rref_strided(m.get(), combinedNumRows, combinedNumCols, + /*stride=*/1); + + // Check that the first half of the matrix is indeed the identity. + for (int r = 0; r < std::min(numRowsOuter, numColsOuter); r++) { + for (int c = 0; c < std::min(numColsOuter, numRowsOuter); c++) { + if (((m[r] >> c) & 1) != (r == c ? 1 : 0)) { + llvm::report_fatal_error("First half of the matrix was not the " + "identity, bug in invertAndCompose"); + } + } + } + + // We need names for the in/out dim of the flattened layout we're going to + // read off from `m`. These could be anything, doesn't matter. + StringAttr inDim1D = *getInDimNames().begin(); + StringAttr outDim1D = *getOutDimNames().begin(); + + // Read off the new bases. These are for a flattened 1D -> 1D + // transformation from `this`'s in-dims to `outer`'s in-dims. + BasesT newBases; + auto &bs = newBases[inDim1D]; + for (int c = 0; c < numColsThis; c++) { + int32_t basis = 0; + for (int r = 0; r < numRowsOuter; r++) { + basis |= (m[r] >> (numColsOuter + c) & 1) << r; + } + bs.push_back({basis}); + } + + LinearLayout flatComposed(std::move(newBases), + {{outDim1D, outer.getTotalInDimSize()}}, + /*requireSurjective=*/false); + + SmallVector> retInDims; + SmallVector> retOutDims; + for (StringAttr dim : getInDimNames()) { + retInDims.push_back({dim, getInDimSize(dim)}); + } + for (StringAttr dim : outer.getInDimNames()) { + retOutDims.push_back({dim, outer.getInDimSize(dim)}); + } + return flatComposed.reshapeIns(retInDims).reshapeOuts(retOutDims); +} + +llvm::MapVector +LinearLayout::getFreeVariableMasks() const { + std::unique_ptr mat = getMatrix(*this); + int numRows = getTotalOutDimSizeLog2(); + int numCols = getTotalInDimSizeLog2(); + + // stride is specified in number of 64-bit words per row, and we pack our + // matrix so that there's only one uint64_t per row. + assert(numCols <= 64); + f2reduce::inplace_rref_strided(mat.get(), numRows, numCols, /*stride=*/1); + + // For each row in the RREF matrix, identify the column with the first "1". + // These columns correspond to the basic (i.e. non-free) variables. + std::set basicVars; + for (int r = 0; r < numRows; r++) { + if (mat[r] == 0) { + continue; + } + basicVars.insert(__builtin_ctzll(mat[r])); + } + + llvm::MapVector ret; + int c = 0; + for (StringAttr dim : getInDimNames()) { + int32_t mask = 0; + for (int i = 0; i < getInDimSizeLog2(dim); i++, c++) { + if (basicVars.count(c) == 0) { + mask |= (1 << i); + } + } + ret[dim] = mask; + } + return ret; } bool operator==(LinearLayout lhs, LinearLayout rhs) { + if (!lhs.equalIgnoringOutDimSizes(rhs)) + return false; + + for (const auto &[lhsOutDimAndSize, rhsOutDimAndSize] : + llvm::zip(lhs.outDims, rhs.outDims)) { + if (lhsOutDimAndSize.second != rhsOutDimAndSize.second) + return false; + } + return true; +} + +bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const { // llvm::MapVector doesn't have an operator== :(. - if (lhs.getOutDimNames() != rhs.getOutDimNames()) + if (llvm::to_vector(this->getOutDimNames()) != + llvm::to_vector(other.getOutDimNames())) return false; - if (lhs.bases.size() != rhs.bases.size()) + if (this->bases.size() != other.bases.size()) return false; - for (auto it1 = lhs.bases.begin(), it2 = rhs.bases.begin(); - it1 != lhs.bases.end(); ++it1, ++it2) { + for (auto it1 = this->bases.begin(), it2 = other.bases.begin(); + it1 != this->bases.end(); ++it1, ++it2) { if (*it1 != *it2) return false; } @@ -421,7 +1017,44 @@ bool operator==(LinearLayout lhs, LinearLayout rhs) { } std::string LinearLayout::toString() const { - return stringifyBases(bases, getOutDimNames()); + // Start with a newline because we print out a bulleted list; it doesn't + // make sense for the first line of this list to be on the same line as + // any previous text. + std::string ret = "\n"; + std::string outDimsStr = + "[" + + join(outDims, ", ", + [](auto dimAndSize) { + auto [outDim, size] = dimAndSize; + return outDim.str() + " (size " + std::to_string(size) + ")"; + }) + + "]"; + + if (bases.empty()) { + if (outDims.empty()) { + return "\n(empty layout)"; + } else { + return "\n(empty layout with out-dims " + outDimsStr + ")"; + } + } + + // TODO: Add spaces for alignment. + for (const auto &[inDim, inDimBases] : bases) { + if (inDimBases.empty()) { + ret += " - " + inDim.str() + " is a size 1 dimension\n"; + continue; + } + + ret += " - " + + join(llvm::seq(inDimBases.size()), "\n ", + [&, &inDim = inDim, &inDimBases = inDimBases](int i) { + return inDim.str() + "=" + std::to_string(1 << i) + " -> (" + + join(inDimBases[i], ", ") + ")"; + }) + + "\n"; + } + ret += "where out dims are: " + outDimsStr; + return ret; } } // namespace mlir::triton diff --git a/python/pyproject.toml b/python/pyproject.toml index 315aa7da9..d96af50a5 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [build-system] -requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18", "ninja>=1.11.1"] +requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18", "ninja>=1.11.1", "pybind11>=2.13.1"] # We're incrementally switching from autopep8 to ruff. [tool.autopep8] diff --git a/python/requirements.txt b/python/requirements.txt new file mode 100644 index 000000000..4f7fffe43 --- /dev/null +++ b/python/requirements.txt @@ -0,0 +1,7 @@ +ninja +cmake +wheel +GitPython +pytest +scipy +pybind11 diff --git a/python/setup.py b/python/setup.py index c60dc6158..c4711357d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -2,6 +2,7 @@ import platform import re import contextlib +import shlex import shutil import subprocess import sys @@ -9,10 +10,11 @@ import tarfile import zipfile import urllib.request +import json from io import BytesIO from distutils.command.clean import clean from pathlib import Path -from typing import NamedTuple +from typing import List, NamedTuple, Optional from setuptools import Extension, setup from setuptools.command.build_ext import build_ext @@ -24,13 +26,19 @@ from setuptools.command.egg_info import egg_info from wheel.bdist_wheel import bdist_wheel +import pybind11 + +from setup_tools import setup_helper as helper + @dataclass class Backend: name: str - package_data: dict + package_data: List[str] + language_package_data: List[str] src_dir: str backend_dir: str + language_dir: Optional[str] install_dir: str is_external: bool @@ -58,12 +66,22 @@ def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool = backend_path = os.path.abspath(os.path.join(backend_src_dir, "backend")) assert os.path.exists(backend_path), f"{backend_path} does not exist!" + language_dir = os.path.abspath(os.path.join(backend_src_dir, "language")) + if not os.path.exists(language_dir): + language_dir = None + for file in ["compiler.py", "driver.py"]: assert os.path.exists(os.path.join(backend_path, file)), f"${file} does not exist in ${backend_path}" install_dir = os.path.join(os.path.dirname(__file__), "triton", "backends", backend_name) package_data = [f"{os.path.relpath(p, backend_path)}/*" for p, _, _, in os.walk(backend_path)] - return Backend(name=backend_name, package_data=package_data, src_dir=backend_src_dir, backend_dir=backend_path, + + language_package_data = [] + if language_dir is not None: + language_package_data = [f"{os.path.relpath(p, language_dir)}/*" for p, _, _, in os.walk(language_dir)] + + return Backend(name=backend_name, package_data=package_data, language_package_data=language_package_data, + src_dir=backend_src_dir, backend_dir=backend_path, language_dir=language_dir, install_dir=install_dir, is_external=is_external) # Copy all in-tree backends under triton/third_party. @@ -99,6 +117,8 @@ def get_build_type(): return "RelWithDebInfo" elif check_env_flag("TRITON_REL_BUILD_WITH_ASSERTS"): return "TritonRelBuildWithAsserts" + elif check_env_flag("TRITON_BUILD_WITH_O1"): + return "TritonBuildWithO1" else: # TODO: change to release when stable enough return "TritonRelBuildWithAsserts" @@ -111,6 +131,22 @@ def get_env_with_keys(key: list): return "" +def is_offline_build() -> bool: + """ + Downstream projects and distributions which bootstrap their own dependencies from scratch + and run builds in offline sandboxes + may set `TRITON_OFFLINE_BUILD` in the build environment to prevent any attempts at downloading + pinned dependencies from the internet or at using dependencies vendored in-tree. + + Dependencies must be defined using respective search paths (cf. `syspath_var_name` in `Package`). + Missing dependencies lead to an early abortion. + Dependencies' compatibility is not verified. + + Note that this flag isn't tested by the CI and does not provide any guarantees. + """ + return check_env_flag("TRITON_OFFLINE_BUILD", "") + + # --- third party packages ----- @@ -123,16 +159,6 @@ class Package(NamedTuple): syspath_var_name: str -# pybind11 -def get_pybind11_package_info(): - pybind11_version_path = os.path.join(get_base_dir(), "cmake", "pybind11-version.txt") - with open(pybind11_version_path, "r") as pybind11_version_file: - version = pybind11_version_file.read().strip() - name = f"pybind11-{version}" - url = f"https://github.com/pybind/pybind11/archive/refs/tags/v{version}.tar.gz" - return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH") - - # json def get_json_package_info(): url = "https://github.com/nlohmann/json/releases/download/v3.11.3/include.zip" @@ -201,7 +227,9 @@ def open_url(url): def get_triton_cache_path(): - user_home = os.getenv("HOME") or os.getenv("USERPROFILE") or os.getenv("HOMEPATH") or None + user_home = os.getenv("TRITON_HOME") + if not user_home: + user_home = os.getenv("HOME") or os.getenv("USERPROFILE") or os.getenv("HOMEPATH") or None if not user_home: raise RuntimeError("Could not find user home directory") return os.path.join(user_home, ".triton") @@ -216,8 +244,14 @@ def get_thirdparty_packages(packages: list): if os.environ.get(p.syspath_var_name): package_dir = os.environ[p.syspath_var_name] version_file_path = os.path.join(package_dir, "version.txt") - if p.syspath_var_name not in os.environ and\ - (not os.path.exists(version_file_path) or Path(version_file_path).read_text() != p.url): + + input_defined = p.syspath_var_name in os.environ + input_exists = os.path.exists(version_file_path) + input_compatible = input_exists and Path(version_file_path).read_text() == p.url + + if is_offline_build() and not input_defined: + raise RuntimeError(f"Requested an offline build but {p.syspath_var_name} is not set") + if not is_offline_build() and not input_defined and not input_compatible: with contextlib.suppress(Exception): shutil.rmtree(package_root_dir) os.makedirs(package_root_dir, exist_ok=True) @@ -240,7 +274,9 @@ def get_thirdparty_packages(packages: list): return thirdparty_cmake_args -def download_and_copy(name, src_path, variable, version, url_func): +def download_and_copy(name, src_path, dst_path, variable, version, url_func): + if is_offline_build(): + return triton_cache_path = get_triton_cache_path() if variable in os.environ: return @@ -250,9 +286,12 @@ def download_and_copy(name, src_path, variable, version, url_func): arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()] except KeyError: arch = platform.machine() - url = url_func(arch, version) + supported = {"Linux": "linux", "Darwin": "linux"} + url = url_func(supported[system], arch, version) tmp_path = os.path.join(triton_cache_path, "nvidia", name) # path to cache the download - dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", src_path) # final binary path + dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path + platform_name = "sbsa-linux" if arch == "aarch64" else "x86_64-linux" + src_path = src_path(platform_name, version) if callable(src_path) else src_path src_path = os.path.join(tmp_path, src_path) download = not os.path.exists(src_path) if os.path.exists(dst_path) and system == "Linux" and shutil.which(dst_path) is not None: @@ -336,12 +375,25 @@ def run(self): for ext in self.extensions: self.build_extension(ext) + def get_pybind11_cmake_args(self): + pybind11_sys_path = get_env_with_keys(["PYBIND11_SYSPATH"]) + if pybind11_sys_path: + pybind11_include_dir = os.path.join(pybind11_sys_path, "include") + else: + pybind11_include_dir = pybind11.get_include() + return [f"-DPYBIND11_INCLUDE_DIR={pybind11_include_dir}"] + def get_proton_cmake_args(self): - cmake_args = get_thirdparty_packages([get_json_package_info(), get_pybind11_package_info()]) - cupti_include_dir = get_env_with_keys(["CUPTI_INCLUDE_PATH"]) + cmake_args = get_thirdparty_packages([get_json_package_info()]) + cmake_args += self.get_pybind11_cmake_args() + cupti_include_dir = get_env_with_keys(["TRITON_CUPTI_INCLUDE_PATH"]) if cupti_include_dir == "": cupti_include_dir = os.path.join(get_base_dir(), "third_party", "nvidia", "backend", "include") cmake_args += ["-DCUPTI_INCLUDE_DIR=" + cupti_include_dir] + cupti_lib_dir = get_env_with_keys(["TRITON_CUPTI_LIB_PATH"]) + if cupti_lib_dir == "": + cupti_lib_dir = os.path.join(get_base_dir(), "third_party", "nvidia", "backend", "lib", "cupti") + cmake_args += ["-DCUPTI_LIB_DIR=" + cupti_lib_dir] roctracer_include_dir = get_env_with_keys(["ROCTRACER_INCLUDE_PATH"]) if roctracer_include_dir == "": roctracer_include_dir = os.path.join(get_base_dir(), "third_party", "amd", "backend", "include") @@ -352,7 +404,8 @@ def build_extension(self, ext): lit_dir = shutil.which('lit') ninja_dir = shutil.which('ninja') # lit is used by the test suite - thirdparty_cmake_args = get_thirdparty_packages([get_pybind11_package_info(), get_llvm_package_info()]) + thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()]) + thirdparty_cmake_args += self.get_pybind11_cmake_args() extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) # create build directories if not os.path.exists(self.build_temp): @@ -421,6 +474,14 @@ def build_extension(self, ext): else: cmake_args += ["-DTRITON_BUILD_PROTON=OFF"] + if is_offline_build(): + # unit test builds fetch googletests from GitHub + cmake_args += ["-DTRITON_BUILD_UT=OFF"] + + cmake_args_append = os.getenv("TRITON_APPEND_CMAKE_ARGS") + if cmake_args_append is not None: + cmake_args += shlex.split(cmake_args_append) + env = os.environ.copy() cmake_dir = get_cmake_dir() subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env) @@ -428,63 +489,86 @@ def build_extension(self, ext): subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir) -nvidia_version_path = os.path.join(get_base_dir(), "cmake", "nvidia-toolchain-version.txt") +nvidia_version_path = os.path.join(get_base_dir(), "cmake", "nvidia-toolchain-version.json") with open(nvidia_version_path, "r") as nvidia_version_file: - NVIDIA_TOOLCHAIN_VERSION = nvidia_version_file.read().strip() + # parse this json file to get the version of the nvidia toolchain + NVIDIA_TOOLCHAIN_VERSION = json.load(nvidia_version_file) + + +def get_platform_dependent_src_path(subdir): + return lambda platform, version: ( + (lambda version_major, version_minor1, version_minor2, : f"targets/{platform}/{subdir}" + if int(version_major) >= 12 and int(version_minor1) >= 5 else subdir)(*version.split('.'))) + download_and_copy( - name="ptxas", - src_path="bin/ptxas", - variable="TRITON_PTXAS_PATH", - version=NVIDIA_TOOLCHAIN_VERSION, - url_func=lambda arch, version: - f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2", -) + name="ptxas", src_path="bin/ptxas", dst_path="bin/ptxas", variable="TRITON_PTXAS_PATH", + version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], url_func=lambda system, arch, version: + ((lambda version_major, version_minor1, version_minor2: + f"https://anaconda.org/nvidia/cuda-nvcc-tools/{version}/download/{system}-{arch}/cuda-nvcc-tools-{version}-0.tar.bz2" + if int(version_major) >= 12 and int(version_minor1) >= 5 else + f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2") + (*version.split('.')))) download_and_copy( name="cuobjdump", src_path="bin/cuobjdump", + dst_path="bin/cuobjdump", variable="TRITON_CUOBJDUMP_PATH", - version=NVIDIA_TOOLCHAIN_VERSION, - url_func=lambda arch, version: - f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2", + version=NVIDIA_TOOLCHAIN_VERSION["cuobjdump"], + url_func=lambda system, arch, version: + f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/{system}-{arch}/cuda-cuobjdump-{version}-0.tar.bz2", ) download_and_copy( name="nvdisasm", src_path="bin/nvdisasm", + dst_path="bin/nvdisasm", variable="TRITON_NVDISASM_PATH", - version=NVIDIA_TOOLCHAIN_VERSION, - url_func=lambda arch, version: - f"https://anaconda.org/nvidia/cuda-nvdisasm/{version}/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", + version=NVIDIA_TOOLCHAIN_VERSION["nvdisasm"], + url_func=lambda system, arch, version: + f"https://anaconda.org/nvidia/cuda-nvdisasm/{version}/download/{system}-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", ) download_and_copy( - name="cudacrt", - src_path="include", - variable="TRITON_CUDACRT_PATH", - version=NVIDIA_TOOLCHAIN_VERSION, - url_func=lambda arch, version: - f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2", -) + name="cudacrt", src_path=get_platform_dependent_src_path("include"), dst_path="include", + variable="TRITON_CUDACRT_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudacrt"], url_func=lambda system, arch, version: + ((lambda version_major, version_minor1, version_minor2: + f"https://anaconda.org/nvidia/cuda-crt-dev_{system}-{arch}/{version}/download/noarch/cuda-crt-dev_{system}-{arch}-{version}-0.tar.bz2" + if int(version_major) >= 12 and int(version_minor1) >= 5 else + f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2") + (*version.split('.')))) download_and_copy( - name="cudart", - src_path="include", - variable="TRITON_CUDART_PATH", - version=NVIDIA_TOOLCHAIN_VERSION, - url_func=lambda arch, version: - f"https://anaconda.org/nvidia/cuda-cudart-dev/{version}/download/linux-{arch}/cuda-cudart-dev-{version}-0.tar.bz2", -) + name="cudart", src_path=get_platform_dependent_src_path("include"), dst_path="include", + variable="TRITON_CUDART_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudart"], url_func=lambda system, arch, version: + ((lambda version_major, version_minor1, version_minor2: + f"https://anaconda.org/nvidia/cuda-cudart-dev_{system}-{arch}/{version}/download/noarch/cuda-cudart-dev_{system}-{arch}-{version}-0.tar.bz2" + if int(version_major) >= 12 and int(version_minor1) >= 5 else + f"https://anaconda.org/nvidia/cuda-cudart-dev/{version}/download/{system}-{arch}/cuda-cudart-dev-{version}-0.tar.bz2" + )(*version.split('.')))) download_and_copy( - name="cupti", - src_path="include", - variable="TRITON_CUPTI_PATH", - version=NVIDIA_TOOLCHAIN_VERSION, - url_func=lambda arch, version: - f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/linux-{arch}/cuda-cupti-{version}-0.tar.bz2", -) + name="cupti", src_path=get_platform_dependent_src_path("include"), dst_path="include", + variable="TRITON_CUPTI_INCLUDE_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"], + url_func=lambda system, arch, version: + ((lambda version_major, version_minor1, version_minor2: + f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/{system}-{arch}/cuda-cupti-dev-{version}-0.tar.bz2" + if int(version_major) >= 12 and int(version_minor1) >= 5 else + f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/{system}-{arch}/cuda-cupti-{version}-0.tar.bz2") + (*version.split('.')))) +download_and_copy( + name="cupti", src_path=get_platform_dependent_src_path("lib"), dst_path="lib/cupti", + variable="TRITON_CUPTI_LIB_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"], url_func=lambda system, arch, version: + ((lambda version_major, version_minor1, version_minor2: + f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/{system}-{arch}/cuda-cupti-dev-{version}-0.tar.bz2" + if int(version_major) >= 12 and int(version_minor1) >= 5 else + f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/{system}-{arch}/cuda-cupti-{version}-0.tar.bz2") + (*version.split('.')))) -backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()] +if helper.flagtree_backend: + backends = [*BackendInstaller.copy(helper.extend_backends), *BackendInstaller.copy_externals()] +else: + backends = [*BackendInstaller.copy(helper.default_backends), *BackendInstaller.copy_externals()] def add_link_to_backends(): + helper.CommonUtils.unlink() for backend in backends: if os.path.islink(backend.install_dir): os.unlink(backend.install_dir) @@ -492,6 +576,19 @@ def add_link_to_backends(): shutil.rmtree(backend.install_dir) os.symlink(backend.backend_dir, backend.install_dir) + if backend.language_dir: + # Link the contents of each backend's `language` directory into + # `triton.language.extra`. + extra_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "triton", "language", "extra")) + for x in os.listdir(backend.language_dir): + src_dir = os.path.join(backend.language_dir, x) + install_dir = os.path.join(extra_dir, x) + if os.path.islink(install_dir): + os.unlink(install_dir) + if os.path.exists(install_dir): + shutil.rmtree(install_dir) + os.symlink(src_dir, install_dir) + def add_link_to_proton(): proton_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, "third_party", "proton", "proton")) @@ -514,6 +611,7 @@ class plugin_install(install): def run(self): add_links() install.run(self) + helper.post_install() class plugin_develop(develop): @@ -521,6 +619,7 @@ class plugin_develop(develop): def run(self): add_links() develop.run(self) + helper.post_install() class plugin_bdist_wheel(bdist_wheel): @@ -528,6 +627,7 @@ class plugin_bdist_wheel(bdist_wheel): def run(self): add_links() bdist_wheel.run(self) + helper.post_install() class plugin_egginfo(egg_info): @@ -535,15 +635,38 @@ class plugin_egginfo(egg_info): def run(self): add_links() egg_info.run(self) + helper.post_install() +package_data_tools = helper.get_package_data_tools() package_data = { - "triton/tools": ["compile.h", "compile.c"], - **{f"triton/backends/{b.name}": b.package_data - for b in backends}, + "triton/tools": package_data_tools, **{f"triton/backends/{b.name}": b.package_data + for b in backends}, "triton/language/extra": sum( + (b.language_package_data for b in backends), []) } +def get_language_extra_packages(): + packages = [] + for backend in backends: + if backend.language_dir is None: + continue + + # Walk the `language` directory of each backend to enumerate + # any subpackages, which will be added to `triton.language.extra`. + for dir, dirs, files in os.walk(backend.language_dir, followlinks=True): + if not any(f for f in files if f.endswith(".py")) or dir == backend.language_dir: + # Ignore directories with no python files. + # Also ignore the root directory which corresponds to + # "triton/language/extra". + continue + subpackage = os.path.relpath(dir, backend.language_dir) + package = os.path.join("triton/language/extra", subpackage) + packages.append(package) + + return list(packages) + + def get_packages(): packages = [ "triton", @@ -551,16 +674,19 @@ def get_packages(): "triton/compiler", "triton/language", "triton/language/extra", - "triton/language/extra/cuda", - "triton/language/extra/hip", - "triton/ops", - "triton/ops/blocksparse", "triton/runtime", "triton/backends", "triton/tools", ] + if helper.flagtree_backend: + packages.append(f"triton/language/extra/{helper.get_device_name()}") + packages += helper.get_extra_packages() + packages += [f'triton/backends/{backend.name}' for backend in backends] - packages += ["triton/profiler"] + packages += get_language_extra_packages() + if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON + packages += ["triton/profiler"] + return packages @@ -574,24 +700,27 @@ def get_entry_points(): return entry_points -def get_install_requires(): - install_requires = ["filelock"] - return install_requires +def get_git_commit_hash(length=8): + try: + cmd = ['git', 'rev-parse', f'--short={length}', 'HEAD'] + return "+git{}".format(subprocess.check_output(cmd).strip().decode('utf-8')) + except Exception: + return "" setup( name=os.environ.get("TRITON_WHEEL_NAME", "triton"), - version="3.1.0" + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", ""), + version="3.2.0" + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", ""), author="Philippe Tillet", author_email="phil@openai.com", description="A language and compiler for custom Deep Learning operations", long_description="", packages=get_packages(), + package_dir=helper.CommonUtils.get_package_dir(get_packages()), entry_points=get_entry_points(), - install_requires=get_install_requires(), package_data=package_data, include_package_data=True, - ext_modules=[CMakeExtension("triton", "triton/_C/")], + ext_modules=[CMakeExtension("triton", helper.ext_sourcedir)], cmdclass={ "build_ext": CMakeBuild, "build_py": CMakeBuildPy, @@ -610,16 +739,17 @@ def get_install_requires(): "Intended Audience :: Developers", "Topic :: Software Development :: Build Tools", "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], test_suite="tests", extras_require={ "build": [ "cmake>=3.20", + "GitPython", "lit", ], "tests": [ diff --git a/python/setup_tools/__init__.py b/python/setup_tools/__init__.py new file mode 100644 index 000000000..c3411313f --- /dev/null +++ b/python/setup_tools/__init__.py @@ -0,0 +1,4 @@ +from . import setup_helper +from . import utils + +__all__ = ["setup_helper", "utils"] diff --git a/python/setup_tools/setup_helper.py b/python/setup_tools/setup_helper.py new file mode 100644 index 000000000..4974ab51b --- /dev/null +++ b/python/setup_tools/setup_helper.py @@ -0,0 +1,401 @@ +import os +import shutil +import sys +import functools +import tarfile +import zipfile +from io import BytesIO +import urllib.request +from pathlib import Path +import hashlib +from distutils.sysconfig import get_python_lib +from . import utils + +extend_backends = [] +default_backends = ["nvidia", "amd"] +plugin_backends = ["cambricon", "ascend"] +ext_sourcedir = "triton/_C/" +flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower() +flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower() +device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"} +flagtree_backends = utils.flagtree_backends +backend_utils = utils.activate(flagtree_backend) + +set_llvm_env = lambda path: set_env({ + 'LLVM_INCLUDE_DIRS': Path(path) / "include", + 'LLVM_LIBRARY_DIR': Path(path) / "lib", + 'LLVM_SYSPATH': path, +}) + + +def get_device_name(): + return device_mapping[flagtree_backend] + + +def get_extra_packages(): + packages = [] + try: + packages = backend_utils.get_extra_install_packages() + except Exception: + packages = [] + return packages + + +def get_package_data_tools(): + package_data = ["compile.h", "compile.c"] + try: + package_data += backend_utils.get_package_data_tools() + except Exception: + package_data + return package_data + + +def git_clone(lib, lib_path): + import git + MAX_RETRY = 4 + print(f"Clone {lib.name} into {lib_path} ...") + retry_count = MAX_RETRY + while (retry_count): + try: + repo = git.Repo.clone_from(lib.url, lib_path) + if lib.tag is not None: + repo.git.checkout(lib.tag) + sub_triton_path = Path(lib_path) / "triton" + if os.path.exists(sub_triton_path): + shutil.rmtree(sub_triton_path) + print(f"successfully clone {lib.name} into {lib_path} ...") + return True + except Exception: + retry_count -= 1 + print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}") + return False + + +def download_flagtree_third_party(name, condition, required=False, hock=None): + if not condition: + return + backend = None + for _backend in flagtree_backends: + if _backend.name is name: + backend = _backend + break + if backend is None: + return backend + third_party_base_dir = Path(os.path.dirname(os.path.dirname(__file__))) / "third_party" + lib_path = Path(third_party_base_dir) / backend.name + if not os.path.exists(lib_path): + succ = git_clone(lib=backend, lib_path=lib_path) + if not succ and required: + raise RuntimeError("Bad network ! ") + else: + print(f'Found third_party {backend.name} at {lib_path}\n') + if callable(hock): + hock(third_party_base_dir=third_party_base_dir, backend=backend) + + +def post_install(): + + backend_utils.post_install() + + +class FlagTreeCache: + + def __init__(self): + self.flagtree_dir = os.path.dirname(os.getcwd()) + self.dir_name = ".flagtree" + self.sub_dirs = {} + self.cache_files = {} + self.dir_path = self._get_cache_dir_path() + self._create_cache_dir() + if flagtree_backend: + self._create_subdir(subdir_name=flagtree_backend) + + @functools.lru_cache(maxsize=None) + def _get_cache_dir_path(self) -> Path: + _cache_dir = os.environ.get("FLAGTREE_CACHE_DIR") + if _cache_dir is None: + _cache_dir = Path.home() / self.dir_name + else: + _cache_dir = Path(_cache_dir) + return _cache_dir + + def _create_cache_dir(self) -> Path: + if not os.path.exists(self.dir_path): + os.makedirs(self.dir_path, exist_ok=True) + + def _create_subdir(self, subdir_name, path=None): + if path is None: + subdir_path = Path(self.dir_path) / subdir_name + else: + subdir_path = Path(path) / subdir_name + + if not os.path.exists(subdir_path): + os.makedirs(subdir_path, exist_ok=True) + self.sub_dirs[subdir_name] = subdir_path + + def _md5(self, file_path): + md5_hash = hashlib.md5() + with open(file_path, "rb") as file: + while chunk := file.read(4096): + md5_hash.update(chunk) + return md5_hash.hexdigest() + + def _download(self, url, path, file_name): + MAX_RETRY_COUNT = 4 + user_agent = 'Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/119.0' + headers = { + 'User-Agent': user_agent, + } + request = urllib.request.Request(url, None, headers) + retry_count = MAX_RETRY_COUNT + content = None + print(f'downloading {url} ...') + while (retry_count): + try: + with urllib.request.urlopen(request, timeout=300) as response: + content = response.read() + break + except Exception: + retry_count -= 1 + print(f"\n[{MAX_RETRY_COUNT - retry_count}] retry to downloading and extracting {url}") + + if retry_count == 0: + raise RuntimeError("The download failed, probably due to network problems") + + print(f'extracting {url} ...') + file_bytes = BytesIO(content) + file_names = [] + if url.endswith(".zip"): + with zipfile.ZipFile(file_bytes, "r") as file: + file.extractall(path=path) + file_names = file.namelist() + else: + with tarfile.open(fileobj=file_bytes, mode="r|*") as file: + file.extractall(path=path) + file_names = file.getnames() + os.rename(Path(path) / file_names[0], Path(path) / file_name) + + def check_file(self, file_name=None, url=None, path=None, md5_digest=None): + origin_file_path = None + if url is not None: + origin_file_name = url.split("/")[-1].split('.')[0] + origin_file_path = self.cache_files.get(origin_file_name, "") + if path is not None: + _path = path + else: + _path = self.cache_files.get(file_name, "") + empty = (not os.path.exists(_path)) or (origin_file_path and not os.path.exists(origin_file_path)) + if empty: + return False + if md5_digest is None: + return True + else: + cur_md5 = self._md5(_path) + return cur_md5[:8] == md5_digest + + def clear(self): + shutil.rmtree(self.dir_path) + + def reverse_copy(self, src_path, cache_file_path, md5_digest): + if src_path is None or not os.path.exists(src_path): + return False + if os.path.exists(cache_file_path): + return False + copy_needed = True + if md5_digest is None or self._md5(src_path) == md5_digest: + copy_needed = False + if copy_needed: + print(f"copying {src_path} to {cache_file_path}") + if os.path.isdir(src_path): + shutil.copytree(src_path, cache_file_path, dirs_exist_ok=True) + else: + shutil.copy(src_path, cache_file_path) + return True + return False + + def store(self, file=None, condition=None, url=None, copy_src_path=None, copy_dst_path=None, files=None, + md5_digest=None, pre_hock=None, post_hock=None): + + if not condition or (pre_hock and pre_hock()): + return + is_url = False if url is None else True + path = self.sub_dirs[flagtree_backend] if flagtree_backend else self.dir_path + + if files is not None: + for single_files in files: + self.cache_files[single_files] = Path(path) / single_files + else: + self.cache_files[file] = Path(path) / file + if url is not None: + origin_file_name = url.split("/")[-1].split('.')[0] + self.cache_files[origin_file_name] = Path(path) / file + if copy_dst_path is not None: + dst_path_root = Path(self.flagtree_dir) / copy_dst_path + dst_path = Path(dst_path_root) / file + if self.reverse_copy(dst_path, self.cache_files[file], md5_digest): + return + + if is_url and not self.check_file(file_name=file, url=url, md5_digest=md5_digest): + self._download(url, path, file_name=file) + + if copy_dst_path is not None: + file_lists = [file] if files is None else list(files) + for single_file in file_lists: + dst_path_root = Path(self.flagtree_dir) / copy_dst_path + os.makedirs(dst_path_root, exist_ok=True) + dst_path = Path(dst_path_root) / single_file + if not self.check_file(path=dst_path, md5_digest=md5_digest): + if copy_src_path: + src_path = Path(copy_src_path) / single_file + else: + src_path = self.cache_files[single_file] + print(f"copying {src_path} to {dst_path}") + if os.path.isdir(src_path): + shutil.copytree(src_path, dst_path, dirs_exist_ok=True) + else: + shutil.copy(src_path, dst_path) + post_hock(self.cache_files[file]) if post_hock else False + + def get(self, file_name) -> Path: + return self.cache_files[file_name] + + +class CommonUtils: + + @staticmethod + def unlink(): + cur_path = os.path.dirname(__file__) + if "editable_wheel" in sys.argv: + installation_dir = cur_path + else: + installation_dir = get_python_lib() + backends_dir_path = Path(installation_dir) / "triton" / "backends" + if not os.path.exists(backends_dir_path): + return + for name in os.listdir(backends_dir_path): + exist_backend_path = os.path.join(backends_dir_path, name) + if not os.path.isdir(exist_backend_path): + continue + if name.startswith('__'): + continue + if os.path.islink(exist_backend_path): + os.unlink(exist_backend_path) + if os.path.exists(exist_backend_path): + shutil.rmtree(exist_backend_path) + + @staticmethod + def skip_package_dir(package): + if 'backends' in package or 'profiler' in package: + return True + if flagtree_backend in ['cambricon']: + if package not in ['triton', 'triton/_C']: + return True + return False + + @staticmethod + def get_package_dir(packages): + package_dict = {} + if flagtree_backend and flagtree_backend not in plugin_backends: + connection = [] + backend_triton_path = f"../third_party/{flagtree_backend}/python/" + for package in packages: + if CommonUtils.skip_package_dir(package): + continue + pair = (package, f"{backend_triton_path}{package}") + connection.append(pair) + package_dict.update(connection) + if flagtree_backend == "ascend": + triton_patch_root_rel_dir = "../third_party/ascend/triton_patch/python/triton_patch" + package_dict["triton/triton_patch"] = f"{triton_patch_root_rel_dir}" + package_dict["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language" + package_dict["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler" + package_dict["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime" + return package_dict + + +def handle_flagtree_backend(): + global ext_sourcedir + if flagtree_backend: + print(f"\033[1;32m[INFO] FlagtreeBackend is {flagtree_backend}\033[0m") + extend_backends.append(flagtree_backend) + if "editable_wheel" in sys.argv and flagtree_backend != "ascend": + ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/" + + +def set_env(env_dict: dict): + for env_k, env_v in env_dict.items(): + os.environ[env_k] = str(env_v) + + +def check_env(env_val): + return os.environ.get(env_val, '') != '' + + +download_flagtree_third_party("triton_shared", condition=(not flagtree_backend)) + +download_flagtree_third_party("triton_ascend", condition=(flagtree_backend == "ascend"), + hock=utils.ascend.precompile_hock, required=True) + +download_flagtree_third_party("cambricon", condition=(flagtree_backend == "cambricon"), required=True) + +handle_flagtree_backend() + +cache = FlagTreeCache() + +# iluvatar +cache.store( + file="iluvatarTritonPlugin.so", condition=("iluvatar" == flagtree_backend) and (flagtree_plugin == ''), url= + "https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/iluvatarTritonPlugin-cpython3.10-glibc2.30-glibcxx3.4.28-cxxabi1.3.12-ubuntu-x86_64.tar.gz", + copy_dst_path="third_party/iluvatar", md5_digest="7d4e136c") + +cache.store( + file="iluvatar-llvm18-x86_64", + condition=("iluvatar" == flagtree_backend), + url="https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/iluvatar-llvm18-x86_64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) + +# xpu(kunlunxin) +cache.store( + file="XTDK-llvm18-ubuntu2004_x86_64", + condition=("xpu" == flagtree_backend), + url="https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/XTDK-llvm18-ubuntu2004_x86_64.tar", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) + +cache.store(file="xre-Linux-x86_64", condition=("xpu" == flagtree_backend), + url="https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/xre-Linux-x86_64.tar.gz", + copy_dst_path='python/_deps/xre3') + +cache.store( + files=("clang", "xpu-xxd", "xpu3-crt.xpu", "xpu-kernel.t", "ld.lld", "llvm-readelf", "llvm-objdump", + "llvm-objcopy"), condition=("xpu" == flagtree_backend), + copy_src_path=f"{os.environ.get('LLVM_SYSPATH','')}/bin", copy_dst_path="third_party/xpu/backend/xpu3/bin") + +cache.store(files=("libclang_rt.builtins-xpu3.a", "libclang_rt.builtins-xpu3s.a"), + condition=("xpu" == flagtree_backend), copy_src_path=f"{os.environ.get('LLVM_SYSPATH','')}/lib/linux", + copy_dst_path="third_party/xpu/backend/xpu3/lib/linux") + +cache.store(files=("include", "so"), condition=("xpu" == flagtree_backend), + copy_src_path=f"{cache.dir_path}/xpu/xre-Linux-x86_64", copy_dst_path="third_party/xpu/backend/xpu3") + +# mthreads +cache.store( + file="mthreads-llvm19-glibc2.34-glibcxx3.4.30-x64", + condition=("mthreads" == flagtree_backend), + url= + "https://github.com/FlagTree/flagtree/releases/download/v0.1.0-build-deps/mthreads-llvm19-glibc2.34-glibcxx3.4.30-x64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) + +# ascend +cache.store( + file="ascend-llvm-b5cc222d-ubuntu-arm64", + condition=("ascend" == flagtree_backend), + url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-arm64.tar.gz", + pre_hock=lambda: check_env('LLVM_SYSPATH'), + post_hock=set_llvm_env, +) diff --git a/python/setup_tools/utils/__init__.py b/python/setup_tools/utils/__init__.py new file mode 100644 index 000000000..ddee3dc76 --- /dev/null +++ b/python/setup_tools/utils/__init__.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from pathlib import Path +import importlib.util +import os +from . import ascend + + +@dataclass +class FlagTreeBackend: + name: str + url: str + tag: str = None + + +flagtree_backends = ( + FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git", + tag="380b87122c88af131530903a702d5318ec59bb33"), + FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git", + tag="00f51c2e48a943922f86f03d58e29f514def646d"), + FlagTreeBackend( + name="triton_ascend", + url="https://gitee.com/ascend/triton-ascend.git", + ), +) + + +def activate(backend, suffix=".py"): + module_path = Path(os.path.dirname(__file__)) / backend + module_path = str(module_path) + suffix + spec = importlib.util.spec_from_file_location("module", module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +__all__ = ["ascend"] diff --git a/python/setup_tools/utils/ascend.py b/python/setup_tools/utils/ascend.py new file mode 100644 index 000000000..bb57fe3fc --- /dev/null +++ b/python/setup_tools/utils/ascend.py @@ -0,0 +1,206 @@ +import os +import shutil +from pathlib import Path + + +def insert_at_file_start(filepath, import_lines): + import tempfile + try: + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + if import_lines in content: + return False + with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_file: + tmp_file.write(import_lines + '\n\n') + with open(filepath, 'r') as original_file: + tmp_file.write(original_file.read()) + backup_path = filepath + '.bak' + if os.path.exists(backup_path): + os.remove(backup_path) + shutil.move(filepath, backup_path) + shutil.move(tmp_file.name, filepath) + print(f"[INFO]: {filepath} is patched") + return True + except PermissionError: + print(f"[ERROR]: No permission to write to {filepath}!") + except FileNotFoundError: + print(f"[ERROR]: {filepath} does not exist!") + except Exception as e: + print(f"[ERROR]: Unknown error: {str(e)}") + return False + + +def append_at_file_end(filepath, import_lines): + try: + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + if import_lines in content: + return False + with open(filepath, 'a', encoding='utf-8') as f: + f.write('\n' + import_lines) + return True + except PermissionError: + print(f"[ERROR]: No permission to write to {filepath}!") + except FileNotFoundError: + print(f"[ERROR]: {filepath} does not exist!") + except Exception as e: + print(f"[ERROR]: Unknown error: {str(e)}") + return False + + +def post_install(): + import site + install_dir = site.getsitepackages()[0] + install_dir = os.path.join(install_dir, "triton") + init_path = os.path.join(install_dir, "__init__.py") + patched_content = """ +import sys +from .triton_patch.language import _utils as ascend_utils +sys.modules['triton.language._utils'] = ascend_utils +from .triton_patch.compiler import compiler as ascend_compiler +sys.modules['triton.compiler.compiler'] = ascend_compiler +from .triton_patch.compiler import code_generator as ascend_code_generator +sys.modules['triton.compiler.code_generator'] = ascend_code_generator +from .triton_patch.compiler import errors as ascend_errors +sys.modules['triton.compiler.errors'] = ascend_errors +from .triton_patch.runtime import autotuner as ascend_autotuner +sys.modules['triton.runtime.autotuner'] = ascend_autotuner +from .triton_patch import testing as ascend_testing +sys.modules['triton.testing'] = ascend_testing +""" + insert_at_file_start(init_path, patched_content) + + content_to_append = """ +from .triton_patch.language.core import dot, gather, insert, subview +from .triton_patch.language.standard import flip +from .triton_patch.language.math import umulhi, exp, exp2, log, log2, cos, sin, sqrt, sqrt_rn, rsqrt, div_rn, erf, tanh, floor, ceil +from . import language + +language.dot = dot +language.flip = flip +language.gather = gather +language.insert = insert +language.subview = subview + +# from .triton_patch.language.core import dtype, pointer_type, block_type, function_type +# language.core.dtype = dtype +# language.core.pointer_type = pointer_type +# language.core.block_type = block_type +# language.core.function_type = function_type + +from .triton_patch.language.semantic import arange, floordiv +language.semantic.arange = arange +language.semantic.floordiv = floordiv + +language.umulhi = umulhi +language.exp = exp +language.exp2 = exp2 +language.log = log +language.log2 = log2 +language.cos = cos +language.sin = sin +language.sqrt = sqrt +language.sqrt_rn = sqrt_rn +language.rsqrt = rsqrt +language.div_rn = div_rn +language.erf = erf +language.tanh = tanh +language.floor = floor +language.ceil = ceil +language.math.umulhi = umulhi +language.math.exp = exp +language.math.exp2 = exp2 +language.math.log = log +language.math.log2 = log2 +language.math.cos = cos +language.math.sin = sin +language.math.sqrt = sqrt +language.math.sqrt_rn = sqrt_rn +language.math.rsqrt = rsqrt +language.math.div_rn = div_rn +language.math.erf = erf +language.math.tanh = tanh +language.math.floor = floor +language.math.ceil = ceil +""" + append_at_file_end(init_path, content_to_append) + + +def get_ascend_patch_packages(backends): + packages = [] + # packages += get_language_extra_packages() + packages += [ + "triton/triton_patch", + "triton/triton_patch/language", + "triton/triton_patch/compiler", + "triton/triton_patch/runtime", + ] + return packages + + +def get_ascend_patch_package_dir(backends): + package_dir = {} + # language_extra_list = get_language_extra_packages() + # for extra_full in language_extra_list: + # extra_name = extra_full.replace("triton/language/extra/", "") + # package_dir[extra_full] = f"{triton_root_rel_dir}/language/extra/{extra_name}" + # + triton_patch_root_rel_dir = "triton_patch/python/triton_patch" + package_dir["triton/triton_patch"] = f"{triton_patch_root_rel_dir}" + package_dir["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language" + package_dir["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler" + package_dir["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime" + return package_dir + + +def get_extra_install_packages(): + return [ + "triton/triton_patch", + "triton/triton_patch/language", + "triton/triton_patch/compiler", + "triton/triton_patch/runtime", + ] + + +def precompile_hock(*args, **kargs): + third_party_base_dir = Path(kargs['third_party_base_dir']) + ascend_path = Path(third_party_base_dir) / "ascend" + patch_path = Path(ascend_path) / "triton_patch" + project_path = Path(third_party_base_dir) / "triton_ascend" + if os.path.exists(ascend_path): + shutil.rmtree(ascend_path) + if not os.path.exists(project_path): + raise RuntimeError(f"{project_path} can't be found. It might be due to a network issue") + ascend_src_path = Path(project_path) / "ascend" + patch_src_path = Path(project_path) / "triton_patch" + shutil.copytree(ascend_src_path, ascend_path, dirs_exist_ok=True) + shutil.copytree(patch_src_path, patch_path, dirs_exist_ok=True) + shutil.rmtree(project_path) + patched_code = """ set(triton_abs_dir "${TRITON_ROOT_DIR}/include/triton/Dialect/Triton/IR") """ + src_code = """set(triton_abs_dir""" + + filepath = Path(patch_path) / "include" / "triton" / "Dialect" / "Triton" / "IR" / "CMakeLists.txt" + try: + import tempfile + with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as tmp_file: + with open(filepath, 'r') as file: + lines = file.readlines() + for line in lines: + if src_code in line: + tmp_file.writelines(patched_code) + else: + tmp_file.writelines(line) + backup_path = str(filepath) + '.bak' + if os.path.exists(backup_path): + os.remove(backup_path) + shutil.move(filepath, backup_path) + shutil.move(tmp_file.name, filepath) + print(f"[INFO]: {filepath} is patched") + return True + except PermissionError: + print(f"[ERROR]: No permission to write to {filepath}!") + except FileNotFoundError: + print(f"[ERROR]: {filepath} does not exist!") + except Exception as e: + print(f"[ERROR]: Unknown error: {str(e)}") + return False diff --git a/python/setup_tools/utils/xpu.py b/python/setup_tools/utils/xpu.py new file mode 100644 index 000000000..92424b1b2 --- /dev/null +++ b/python/setup_tools/utils/xpu.py @@ -0,0 +1,2 @@ +def get_package_data_tools(): + return ["compile_xpu.h", "compile_xpu.c"] diff --git a/python/src/ir.cc b/python/src/ir.cc index 0befdc491..9945c6188 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1,3 +1,4 @@ +#include #include #include #include @@ -6,10 +7,12 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser/Parser.h" @@ -20,13 +23,13 @@ #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Transforms/LocationSnapshot.h" -#include "mlir/Transforms/Passes.h" -#include "triton/Analysis/Allocation.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/SourceMgr.h" namespace { @@ -154,6 +157,7 @@ void init_triton_ir(py::module &&m) { .value("WB", CacheModifier::WB) .value("CS", CacheModifier::CS) .value("WT", CacheModifier::WT) + .value("CV", CacheModifier::CV) .export_values(); py::enum_(m, "MEM_SEMANTIC", py::module_local()) @@ -201,16 +205,40 @@ void init_triton_ir(py::module &&m) { .value("IEEE", InputPrecision::IEEE) .export_values(); - py::class_(m, "context", py::module_local()).def(py::init<>()); + py::enum_(m, "F8F6F4TY", py::module_local()) + .value("E4M3", F8F6F4Type::E4M3) + .value("E5M2", F8F6F4Type::E5M2) + .value("E2M3", F8F6F4Type::E2M3) + .value("E3M2", F8F6F4Type::E3M2) + .value("E2M1", F8F6F4Type::E2M1) + .export_values(); + + py::class_(m, "context", py::module_local()) + .def(py::init<>()) + .def("printOpOnDiagnostic", + [](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); }) + .def("printStackTraceOnDiagnostic", + [](MLIRContext &self, bool v) { + self.printStackTraceOnDiagnostic(v); + }) + .def("disable_multithreading", + [](MLIRContext &self) { self.disableMultithreading(); }); + + py::class_(m, "source_mgr_diag", + py::module_local()) + .def(py::init()); m.def("load_dialects", [](MLIRContext &context) { DialectRegistry registry; registry.insert(); + cf::ControlFlowDialect, LLVM::LLVMDialect, + mlir::ub::UBDialect>(); + mlir::LLVM::registerInlinerInterface(registry); registerBuiltinDialectTranslation(registry); registerLLVMDialectTranslation(registry); + mlir::LLVM::registerInlinerInterface(registry); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); }); @@ -429,6 +457,13 @@ void init_triton_ir(py::module &&m) { return py::none(); return py::str(ret.getValue().str()); }) + .def("get_bool_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::bool_(ret.getValue()); + }) .def("get_flat_symbol_ref_attr", [](Operation &self, const std::string &name) -> py::object { auto ret = self.getAttrOfType(name); @@ -519,6 +554,9 @@ void init_triton_ir(py::module &&m) { .def( "set_arg_attr", [](FuncOp &self, int arg_no, const std::string &name, int val) { + if (arg_no >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); // set arg attributes "name" to value "val" auto attrTy = IntegerType::get(self.getContext(), 32); self.setArgAttr(arg_no, name, IntegerAttr::get(attrTy, val)); @@ -527,29 +565,9 @@ void init_triton_ir(py::module &&m) { // .def("has_attr", &::FuncOp::hasAttr) .def("finalize", [](FuncOp &self) -> void { - // Remove dead code - // 1. Unreachable code after return - self.walk([&](Block *block) { - Operation *retOp = nullptr; - // It's better to not use walk here because we only want to - // check operations in the current block - for (auto &op : block->getOperations()) { - if (isa(op)) - if (retOp == nullptr) { - retOp = &op; - break; - } - } - if (retOp && retOp != &block->back()) { - auto pos = retOp->getIterator(); - pos++; - auto *newBlock = block->splitBlock(pos); - newBlock->erase(); - } - }); - // 2. Check if the result of tl.advance is used - self.walk([&](Operation *op) { - if (isa(op) && op->getResult(0).use_empty()) + // Check if the result of tl.advance is used + self.walk([&](AdvanceOp op) { + if (op->getResult(0).use_empty()) outputWarning(op->getLoc(), "The result of tl.advance is not " "being used. Note that tl.advance " "does not have any side effects. " @@ -720,10 +738,8 @@ void init_triton_ir(py::module &&m) { return self.getBuilder().getI64Type(); }) .def("get_fp8e4nv_ty", - // TODO: fp8e4nv is using Float8E4M3FNUZType, which - // does not seem right. It should use FloatE4M3FNType [](TritonOpBuilder &self) -> Type { - return self.getBuilder().getType(); + return self.getBuilder().getType(); }) .def("get_fp8e4b8_ty", [](TritonOpBuilder &self) -> Type { @@ -1244,7 +1260,7 @@ void init_triton_ir(py::module &&m) { evictionPolicy); }) .def("create_descriptor_load", - [](TritonOpBuilder &self, Value &desc_ptr, + [](TritonOpBuilder &self, Value desc_ptr, std::vector &indices, Type type, CacheModifier cacheModifier, EvictionPolicy evictionPolicy) -> Value { @@ -1252,11 +1268,27 @@ void init_triton_ir(py::module &&m) { type, desc_ptr, indices, cacheModifier, evictionPolicy); }) .def("create_descriptor_store", - [](TritonOpBuilder &self, Value &desc_ptr, Value value, + [](TritonOpBuilder &self, Value desc_ptr, Value value, std::vector &indices) -> void { self.create(desc_ptr, value, indices); }) + .def("create_tensormap_create", + [](TritonOpBuilder &self, Value desc_ptr, Value global_address, + std::vector box_dim, std::vector global_dim, + std::vector global_stride, + std::vector element_stride, int32_t elem_type, + int32_t interleave_layout, int32_t swizzle_mode, + int32_t fill_mode) { + self.create( + desc_ptr, global_address, box_dim, global_dim, global_stride, + element_stride, elem_type, interleave_layout, swizzle_mode, + fill_mode); + }) + .def("create_tensormap_fenceproxy_acquire", + [](TritonOpBuilder &self, Value desc_ptr) { + self.create(desc_ptr); + }) .def("create_reshape", [](TritonOpBuilder &self, Value &arg, std::vector &shape, bool allowReorder) -> Value { @@ -1374,19 +1406,13 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, int axis) -> Value { if (axis < 0 || axis > 3) throw pybind11::index_error("program_id must be in [0,3]"); - return self.create( - self.getBuilder().getI32Type(), - ProgramIDDimAttr::get(self.getBuilder().getContext(), - ProgramIDDim(axis))); + return self.create(axis); }) .def("create_get_num_programs", [](TritonOpBuilder &self, int axis) -> Value { if (axis < 0 || axis > 3) throw pybind11::index_error("program_id must be in [0,3]"); - return self.create( - self.getBuilder().getI32Type(), - ProgramIDDimAttr::get(self.getBuilder().getContext(), - ProgramIDDim(axis))); + return self.create(axis); }) .def("create_dot", [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b, @@ -1395,6 +1421,15 @@ void init_triton_ir(py::module &&m) { return self.create(c.getType(), a, b, c, inputPrecision, maxNumImpreciseAcc); }) + .def("create_dot_scaled", + [](TritonOpBuilder &self, mlir::Value &lhs, mlir::Value &lhs_scale, + F8F6F4Type lhs_format, mlir::Value &rhs, + std::optional &rhs_scale, F8F6F4Type rhs_format, + mlir::Value &c) -> mlir::Value { + return self.create( + c.getType(), lhs, rhs, c, lhs_scale, + rhs_scale.value_or(Value()), lhs_format, rhs_format); + }) .def("create_floor", [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); @@ -1495,30 +1530,26 @@ void init_triton_ir(py::module &&m) { }) .def("create_print", [](TritonOpBuilder &self, const std::string &prefix, bool hex, - const std::vector &values) -> void { - self.create( - StringAttr::get(self.getBuilder().getContext(), - llvm::StringRef(prefix)), - hex, values); + const std::vector &values, + const std::vector &isSigned) -> void { + auto prefixAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(prefix)); + self.create(prefixAttr, hex, values, isSigned); }) .def("create_assert", [](TritonOpBuilder &self, Value &condition, - const std::string &message, const std::string &fileName, - const std::string &funcName, unsigned lineNo) -> void { + const std::string &message) -> void { auto messageAttr = StringAttr::get(self.getBuilder().getContext(), llvm::StringRef(message)); - auto fileNameAttr = StringAttr::get(self.getBuilder().getContext(), - llvm::StringRef(fileName)); - auto funcNameAttr = StringAttr::get(self.getBuilder().getContext(), - llvm::StringRef(funcName)); - auto lineNoAttr = self.getBuilder().getI32IntegerAttr(lineNo); - self.create(condition, messageAttr, fileNameAttr, - funcNameAttr, lineNoAttr); - }) - // Undef - .def("create_undef", + self.create(condition, messageAttr); + }) + .def("create_assume", + [](TritonOpBuilder &self, Value &condition) { + self.create(condition); + }) + .def("create_poison", [](TritonOpBuilder &self, Type &type) -> Value { - return self.create(type); + return self.create(type); }) .def("create_histogram", [](TritonOpBuilder &self, Value operand, int numBins) -> Value { @@ -1555,6 +1586,12 @@ void init_triton_ir(py::module &&m) { bool haveDiagnostics = ::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS"); bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); + std::string funcToDump; + if (!haveDump) { + funcToDump = triton::tools::getStrEnv("MLIR_ENABLE_DUMP"); + if (!funcToDump.empty()) + haveDump = true; + } if (haveDiagnostics || haveDump) { context->disableMultithreading(); } @@ -1570,7 +1607,19 @@ void init_triton_ir(py::module &&m) { auto printingFlags = OpPrintingFlags(); printingFlags.elideLargeElementsAttrs(16); printingFlags.enableDebugInfo(); - auto printAlways = [](Pass *, Operation *) { return true; }; + auto printAlways = [funcToDump](Pass *, Operation *op) -> bool { + if (funcToDump.empty()) + return true; + if (auto mod = dyn_cast(op)) { + return mod.lookupSymbol(funcToDump); + } + if (auto func = dyn_cast(op)) { + return SymbolTable::getSymbolName(func).getValue() == + funcToDump; + } + + return false; + }; self.enableIRPrinting( /*shouldPrintBeforePass=*/printAlways, /*shouldPrintAfterPass=*/printAlways, @@ -1614,7 +1663,8 @@ void init_triton_ir(py::module &&m) { }); ::llvm::DebugFlag = true; - ::llvm::setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); + using namespace llvm; + setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); } bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); diff --git a/python/src/llvm.cc b/python/src/llvm.cc index 0039d1a2f..f9b98a254 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -14,12 +14,17 @@ #include "llvm/Pass.h" #include "llvm/Passes/OptimizationLevel.h" #include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/PassPlugin.h" #include "llvm/Passes/StandardInstrumentations.h" #include "llvm/Support/CodeGen.h" +#include "llvm/Support/Signals.h" +#include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/InstCombine/InstCombine.h" +#include +#include #include #include #include @@ -35,6 +40,30 @@ struct BreakStructPhiNodesPass : PassInfoMixin { using namespace llvm; +std::unique_ptr +createTargetMachine(llvm::Module *module, std::string proc, + bool enable_fp_fusion, const std::string &features) { + std::string error; + auto target = + llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); + llvm::TargetOptions opt; + bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); + if (enable_fp_fusion) + opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; + opt.UnsafeFPMath = false; + opt.NoInfsFPMath = false; + opt.NoNaNsFPMath = true; + opt.TrapUnreachable = true; + opt.MCOptions.AsmVerbose = true; + opt.MCOptions.PreserveAsmComments = true; + std::unique_ptr machine{target->createTargetMachine( + module->getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, + std::nullopt, + disableLLVMOpt ? llvm::CodeGenOptLevel::None + : llvm::CodeGenOptLevel::Aggressive)}; + return machine; +} + std::string translateLLVMIRToASM(llvm::Module &module, const std::string &triple, const std::string &proc, @@ -102,21 +131,7 @@ std::string translateLLVMIRToASM(llvm::Module &module, // create machine module.setTargetTriple(triple); - std::string error; - auto target = - llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error); - llvm::TargetOptions opt; - if (enable_fp_fusion) - opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; - opt.UnsafeFPMath = false; - opt.NoInfsFPMath = false; - opt.NoNaNsFPMath = true; - opt.TrapUnreachable = true; - std::unique_ptr machine{target->createTargetMachine( - module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, - std::nullopt, - disableLLVMOpt ? llvm::CodeGenOptLevel::None - : llvm::CodeGenOptLevel::Aggressive)}; + auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features); // set data layout module.setDataLayout(machine->createDataLayout()); // emit machine code @@ -148,6 +163,8 @@ void init_triton_llvm(py::module &&m) { py::class_(m, "context", py::module_local()) .def(py::init<>()); + py::class_(m, "source_mgr", py::module_local()) + .def(py::init<>()); py::class_(m, "function_list") .def( @@ -242,10 +259,28 @@ void init_triton_llvm(py::module &&m) { }, py::keep_alive<0, 2>()); + m.def("attach_datalayout", [](llvm::Module *mod, const std::string triple, + const std::string proc, + const std::string features) { + std::string error; + auto target = llvm::TargetRegistry::lookupTarget(triple, error); + if (!target) { + throw std::runtime_error("target lookup error: " + error); + } + llvm::TargetOptions opt; + // Target machine is only used to create the data layout. + std::unique_ptr machine{target->createTargetMachine( + triple, proc, features, opt, llvm::Reloc::PIC_, std::nullopt, + llvm::CodeGenOptLevel::None)}; + // set data layout + mod->setDataLayout(machine->createDataLayout()); + }); + m.def( "optimize_module", [](llvm::Module *mod, const llvm::OptimizationLevel &opt, - const std::string triple) { + std::string arch, std::string features, std::vector flags, + bool enable_fp_fusion) { if (mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT")) return; // Check to see if we are passing a list of flags to disable @@ -296,11 +331,38 @@ void init_triton_llvm(py::module &&m) { // regressions with some scheduling solution. tuningOptions.SLPVectorization = true; - if (!triple.empty()) - mod->setTargetTriple(triple.c_str()); - - PassBuilder pb(nullptr /*targetMachine*/, tuningOptions, std::nullopt, - instrCbPtr); + std::string pluginFile = + mlir::triton::tools::getStrEnv("LLVM_PASS_PLUGIN_PATH"); + + // We don't pass the targetMachine to the LLVM-IR pass builder, unless + // `arch` is specified. + // + // Don't set target machine in LLVM pass builder when using LLVM IR + // level plugins. LLVM IR level plugin passes typically want to insert + // calls to externally generated code (i.e. precompile a Cuda/Hip kernel + // with Clang and then insert a call to it within an instrumentation + // pass) setting the targetMachine value here can can cause a mis-match + // in the target machine between the MLIR and Clang generated kernels + // and break the lowering of some target specific intrinsics. + std::unique_ptr targetMachine = nullptr; + if (!arch.empty() && pluginFile.empty()) + targetMachine = + createTargetMachine(mod, arch, enable_fp_fusion, features); + PassBuilder pb(/*targetMachine=*/targetMachine.get(), tuningOptions, + std::nullopt, instrCbPtr); + + if (!pluginFile.empty()) { + // TODO: Add some logging here that we inserted a pass into the LLVM + // pass pipeline + auto passPlugin = llvm::PassPlugin::Load(pluginFile); + if (!passPlugin) { + llvm::Error Err = passPlugin.takeError(); + std::string ErrMsg = + "Pass Plugin Error: " + llvm::toString(std::move(Err)); + throw std::runtime_error(ErrMsg); + } + passPlugin->registerPassBuilderCallbacks(pb); + } pb.registerModuleAnalyses(mam); pb.registerCGSCCAnalyses(cgam); @@ -320,7 +382,13 @@ void init_triton_llvm(py::module &&m) { mpm.addPass(pb.buildPerModuleDefaultPipeline(opt)); mpm.run(*mod, mam); }, - py::arg("mod"), py::arg("opt"), py::arg("triple") = ""); + // Mandatory parameters + py::arg("mod"), py::arg("opt"), + // If we want to specify the target machine, we require additional + // (optional) parameters + py::arg("arch") = "", py::arg("features") = "", + py::arg("flags") = std::vector{}, + py::arg("enable_fp_fusion") = false); m.def( "translate_to_asm", @@ -403,3 +471,14 @@ void init_triton_llvm(py::module &&m) { } }); } + +void triton_stacktrace_signal_handler(void *) { + llvm::sys::PrintStackTrace(llvm::errs()); + raise(SIGABRT); +} + +void init_triton_stacktrace_hook(pybind11::module &m) { + if (mlir::triton::tools::getBoolEnv("TRITON_ENABLE_PYTHON_STACKTRACE")) { + llvm::sys::AddSignalHandler(triton_stacktrace_signal_handler, nullptr); + } +} diff --git a/python/src/main.cc b/python/src/main.cc index 5ad4be7d5..82289edc0 100644 --- a/python/src/main.cc +++ b/python/src/main.cc @@ -1,4 +1,7 @@ +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Signals.h" #include + namespace py = pybind11; #define FOR_EACH_1(MACRO, X) MACRO(X) @@ -37,10 +40,12 @@ void init_triton_ir(pybind11::module &&m); void init_triton_llvm(pybind11::module &&m); void init_triton_interpreter(pybind11::module &&m); void init_triton_passes(pybind11::module &&m); +void init_triton_stacktrace_hook(pybind11::module &m); FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE) PYBIND11_MODULE(libtriton, m) { m.doc() = "Python bindings to the C++ Triton API"; + init_triton_stacktrace_hook(m); init_triton_env_vars(m); init_triton_ir(m.def_submodule("ir")); init_triton_passes(m.def_submodule("passes")); diff --git a/python/src/passes.cc b/python/src/passes.cc index 513e811d2..37bb392da 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -39,6 +39,7 @@ void init_triton_passes_ttir(py::module &&m) { ADD_PASS_WRAPPER_0("add_reorder_broadcast", createReorderBroadcastPass); ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer", createRewriteTensorPointerPass); + ADD_PASS_WRAPPER_0("add_loop_unroll", createLoopUnrollPass); ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", createConvertTritonToTritonGPUPass, const std::string &, int, int, int); @@ -65,6 +66,17 @@ void init_triton_passes_ttgpuir(py::module &&m) { createAllocateSharedMemoryPass); ADD_PASS_WRAPPER_0("add_combine_tensor_select_and_if", createTritonGPUCombineTensorSelectAndIf); + ADD_PASS_WRAPPER_0("add_optimize_accumulator_init", + createTritonGPUOptimizeAccumulatorInit); + ADD_PASS_OPTION_WRAPPER_1("add_ws_task_partition", + createTritonGPUWSTaskPartition, int); + ADD_PASS_OPTION_WRAPPER_1("add_ws_data_partition", + createTritonGPUWSDataPartition, int); + ADD_PASS_OPTION_WRAPPER_1("add_ws_lowering", createTritonGPUWSLowering, int); + ADD_PASS_OPTION_WRAPPER_1("add_taskid_propagate", + createTritonGPUTaskIdPropagate, int); + ADD_PASS_OPTION_WRAPPER_4("add_ws_code_partition", + createTritonGPUWSCodePartition, int, int, int, int); } void init_triton_passes_convert(py::module &&m) { diff --git a/python/src/plugin.h b/python/src/plugin.h new file mode 100644 index 000000000..e01b15893 --- /dev/null +++ b/python/src/plugin.h @@ -0,0 +1,57 @@ +#ifndef FLAGTREE_PLUGIN_H +#define FLAGTREE_PLUGIN_H + +#include +#include +#include +#include + +#define DEFINE_LOAD_FUNC(symbol_name) \ + static symbol_name##Func load_##symbol_name##_func(const char *backend_name, \ + const char *func_name) { \ + void *symbol = load_backend_symbol(backend_name, func_name); \ + return reinterpret_cast(symbol); \ + } + +#define DEFINE_CALL_LOAD_FUNC(backend_name, symbol_name) \ + static auto func = load_##symbol_name##_func(#backend_name, #symbol_name); + +#ifdef _WIN32 +#define PLUGIN_EXPORT __declspec(dllexport) +#else +#define PLUGIN_EXPORT __attribute__((visibility("default"))) +#endif + +static void *load_backend_plugin(const char *backend_name) { + const std::string lib_name = std::string(backend_name) + "TritonPlugin.so"; + void *handle = dlopen(lib_name.c_str(), RTLD_LAZY); + if (!handle) { + std::cerr << "Failed to load plugin: " << std::string(dlerror()); + assert(handle); + } + return handle; +} + +static void *load_backend_symbol(const char *backend_name, + const char *func_name) { + void *handle = load_backend_plugin(backend_name); + void *symbol = dlsym(handle, func_name); + if (!symbol) { + std::cerr << "Failed to load symbol: " << std::string(dlerror()); + assert(symbol); + } + return symbol; +} + +static int load_backend_const_int(const char *backend_name, + const char *const_name) { + void *handle = load_backend_plugin(backend_name); + void *symbol = dlsym(handle, const_name); + if (!symbol) { + std::cerr << "Failed to load symbol: " << std::string(dlerror()); + assert(symbol); + } + return *(const int *)symbol; +} + +#endif diff --git a/python/test/backend/third_party_backends/conftest.py b/python/test/backend/third_party_backends/conftest.py deleted file mode 100644 index d939bc001..000000000 --- a/python/test/backend/third_party_backends/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -# content of conftest.py - -import pytest - - -def pytest_addoption(parser): - parser.addoption("--backend", action="store", default="", help="Codegen backend") - - -@pytest.fixture -def cmdopt(request): - return request.config.getoption("--backend") diff --git a/python/test/backend/third_party_backends/test_xpu_backend.py b/python/test/backend/third_party_backends/test_xpu_backend.py deleted file mode 100644 index ededb0a07..000000000 --- a/python/test/backend/third_party_backends/test_xpu_backend.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch - - -def test_xpu_backend(cmdopt): - if cmdopt == "xpu": - has_ipex = False - try: - # Import IPEX to provide Intel GPU runtime - import intel_extension_for_pytorch # type: ignore # noqa: F401 - has_ipex = True if hasattr(torch, "xpu") else False - except Exception: - has_ipex = False - - import triton - import triton.language as tl - - @triton.jit() - def kernel(x_ptr, y_ptr, out_ptr): - pid = tl.program_id(axis=0) - x = tl.load(x_ptr + pid) - y = tl.load(y_ptr + pid) - out = x + y - tl.store(out_ptr + pid, out) - - if has_ipex: - for _ in range(1000): - x = torch.randn((65536, ), device="xpu", dtype=torch.float32) - y = torch.randn((65536, ), device="xpu", dtype=torch.float32) - z = torch.zeros((65536, ), device="xpu", dtype=torch.float32) - kernel[(65536, )](x, y, z, num_warps=32) - assert torch.all(x + y == z) - else: - return diff --git a/python/test/regression/conftest.py b/python/test/regression/conftest.py new file mode 100644 index 000000000..d88687b45 --- /dev/null +++ b/python/test/regression/conftest.py @@ -0,0 +1,22 @@ +import os +import pytest +import tempfile + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default="cuda") + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") + + +@pytest.fixture +def fresh_triton_cache(): + with tempfile.TemporaryDirectory() as tmpdir: + try: + os.environ["TRITON_CACHE_DIR"] = tmpdir + yield tmpdir + finally: + os.environ.pop("TRITON_CACHE_DIR", None) diff --git a/python/test/regression/test_cast_matmul.py b/python/test/regression/test_cast_matmul.py index 253bfbe89..67c216b4b 100644 --- a/python/test/regression/test_cast_matmul.py +++ b/python/test/regression/test_cast_matmul.py @@ -1,20 +1,68 @@ """ +Mixed precision tests for matmul (tl.dot) with cast (tl.to) + issue: https://github.com/triton-lang/triton/issues/2523 -fused type convert and matmul, base on triton matmul, the different with matmul: -1. force C's dtype=dot_out_dtype to ["float16", "float32"] -2. accept A and B with dtype=["float32", "float64"] +TODO: float8 types """ + import pytest import torch +import triton import triton.language as tl -from triton import cdiv, jit -input_dtypes = ["float32", "float64"] +input_dtypes = ["float16", "float32", "float64"] out_dtypes = ["float16", "float32"] +@triton.jit +def matmul_kernel(A, B, C, M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + dot_out_dtype: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, # + BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr): + # matrix multiplication + pid = tl.program_id(0) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K)): + k_remaining = K - k * BLOCK_K + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + a = a.to(C.dtype.element_ty) + b = b.to(C.dtype.element_ty) + acc += tl.dot(a, b, out_dtype=dot_out_dtype) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C, acc, mask=mask) + + @pytest.mark.parametrize("M, K, N, w_dtype, x_dtype, out_dtype", [(M, K, N, w, x, o) # for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)] # @@ -23,7 +71,7 @@ for o in out_dtypes]) def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype): if x_dtype == w_dtype: - pytest.skip("skip same dtype") + pytest.skip("skip the same input dtype") device = torch.cuda.current_device() x_dtype = getattr(torch, x_dtype) w_dtype = getattr(torch, w_dtype) @@ -36,53 +84,7 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype): # launch kernel BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32 - grid = ((cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)), 1) - - @jit - def matmul_kernel(A, B, C, M, N, K, # - stride_am, stride_ak, # - stride_bk, stride_bn, # - stride_cm, stride_cn, # - dot_out_dtype: tl.constexpr, # - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, # - BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr): - # matrix multiplication - pid = tl.program_id(0) - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) - for k in range(0, tl.cdiv(K, BLOCK_K)): - k_remaining = K - k * BLOCK_K - _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) - a = a.to(C.dtype.element_ty) - b = b.to(C.dtype.element_ty) - acc += tl.dot(a, b, out_dtype=dot_out_dtype) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - acc = acc.to(C.dtype.element_ty) - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) - mask = (rm < M)[:, None] & (rn < N)[None, :] - tl.store(C, acc, mask=mask) + grid = ((triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), 1) matmul_kernel[grid]( a, b, out_triton, M, N, K, # diff --git a/python/test/regression/test_functional_regressions.py b/python/test/regression/test_functional_regressions.py index 43f58715b..82298c41c 100644 --- a/python/test/regression/test_functional_regressions.py +++ b/python/test/regression/test_functional_regressions.py @@ -224,3 +224,18 @@ def grid(META): BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, type=type, # num_stages=num_stages) torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2) + + +def test_reverse_range(device): + + @triton.jit + def kernel(in_ptr, out_ptr): + x0 = tl.arange(0, 512) + tmp0 = tl.load(in_ptr + (512 - x0)) + tl.store(out_ptr + x0, tmp0) + + data = torch.randn((516, ), dtype=torch.float32, device=device) + res = torch.empty((512, ), dtype=torch.float32, device=device) + kernel[(1, )](data, res) + ref = torch.flip(data[1:513], [0]) + assert (res == ref).all() diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py deleted file mode 100644 index 8c50e5ad5..000000000 --- a/python/test/regression/test_performance.py +++ /dev/null @@ -1,267 +0,0 @@ -import pytest -import torch - -import triton -import triton.language as tl -import triton.ops -from triton.testing import get_dram_gbps, get_max_tensorcore_tflops, nvsmi - -DEVICE_NAME = {7: 'v100', 8: 'a100'}[torch.cuda.get_device_capability()[0]] - -####################### -# Utilities -####################### - - -def print_perf(cur_ms, cur_util, ref_util): - # print on the same line cur_ms, cur_util and ref_util with 3 decimal places - print(f'{cur_ms:.3f} ms \t cur: {cur_util:.3f} \t ref: {ref_util:.3f} \t dif={cur_util - ref_util:.3f}', end='\t') - - -####################### -# Matrix Multiplication -####################### - -sm_clocks = {'v100': 1350, 'a100': 1350} -mem_clocks = {'v100': 877, 'a100': 1215} - -matmul_data = { - 'a100': { - # square - (512, 512, 512): {'float16': 0.108, 'float32': 0.097, 'int8': 0.05}, - (1024, 1024, 1024): {'float16': 0.355, 'float32': 0.313, 'int8': 0.169}, - (2048, 2048, 2048): {'float16': 0.653, 'float32': 0.532, 'int8': 0.34}, - (8192, 8192, 8192): {'float16': 0.839, 'float32': 0.754, 'int8': 0.51}, - # tall-skinny - (16, 1024, 1024): {'float16': 0.015, 'float32': 0.009, 'int8': 0.005}, - (16, 4096, 4096): {'float16': 0.080, 'float32': 0.051, 'int8': 0.026}, - (16, 8192, 8192): {'float16': 0.083, 'float32': 0.077, 'int8': 0.043}, - (64, 1024, 1024): {'float16': 0.045, 'float32': 0.023, 'int8': 0.017}, - (64, 4096, 4096): {'float16': 0.170, 'float32': 0.000, 'int8': 0.097}, - (64, 8192, 8192): {'float16': 0.227, 'float32': 0.000, 'int8': 0.174}, - (1024, 64, 1024): {'float16': 0.040, 'float32': 0.046, 'int8': 0.017}, - (4096, 64, 4096): {'float16': 0.160, 'float32': 0.214, 'int8': 0.102}, - (8192, 64, 8192): {'float16': 0.272, 'float32': 0.000, 'int8': 0.177}, - # test EVEN_K==False - (8192, 8192, 8176): {'float16': 0.828, 'float32': 0.743, 'int8': 0.51}, - } -} - - -@pytest.mark.parametrize('M, N, K, dtype_str', [(M, N, K, dtype_str) - for M, N, K in matmul_data[DEVICE_NAME].keys() - for dtype_str in ['float16']]) -def test_matmul(M, N, K, dtype_str): - stream = torch.cuda.Stream() - torch.cuda.set_stream(stream) - if dtype_str in ['float32', 'int8'] and DEVICE_NAME != 'a100': - pytest.skip('Only test float32 & int8 on a100') - if (M, N, K) in [(64, 4096, 4096), (64, 8192, 8192), (8192, 64, 8192)] and dtype_str == 'float32': - pytest.skip('Out of shared memory in float32') - dtype = {'float16': torch.float16, 'float32': torch.float32, 'int8': torch.int8}[dtype_str] - torch.manual_seed(0) - ref_gpu_util = matmul_data[DEVICE_NAME][(M, N, K)][dtype_str] - cur_sm_clock = nvsmi(['clocks.current.sm'])[0] - max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3) - if dtype == torch.int8: - a = torch.randint(-128, 127, (M, K), dtype=dtype, device='cuda') - b = torch.randint(-128, 127, (N, K), dtype=dtype, device='cuda') - b = b.t() # only test row-col layout - else: - a = torch.randn((M, K), dtype=dtype, device='cuda') - b = torch.randn((K, N), dtype=dtype, device='cuda') - fn = lambda: triton.ops.matmul(a, b) - ms = triton.testing.do_bench_cudagraph(fn) - cur_gpu_perf = 2. * M * N * K / ms * 1e-9 - cur_gpu_util = cur_gpu_perf / max_gpu_perf - print_perf(ms, cur_gpu_util, ref_gpu_util) - triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01) - - -####################### -# Element-Wise -####################### - - -@triton.jit -def _add(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.store(output_ptr + offsets, output, mask=mask) - - -elementwise_data = { - 'a100': { - 1024 * 16: {'float16': 0.031, 'float32': 0.060}, - 1024 * 64: {'float16': 0.120, 'float32': 0.224}, - 1024 * 256: {'float16': 0.394, 'float32': 0.691}, - 1024 * 1024: {'float16': 1.06, 'float32': 1.453}, - 1024 * 16384: {'float16': 0.832, 'float32': 0.862}, - 1024 * 65536: {'float16': 0.873, 'float32': 0.882}, - # Non pow 2 - 1020 * 100: {'float16': 0.173, 'float32': 0.327}, - 10003 * 7007: {'float16': 0.522, 'float32': 0.873}, - } -} - - -@pytest.mark.parametrize('N', elementwise_data[DEVICE_NAME].keys()) -@pytest.mark.parametrize("dtype_str", ['float16', 'bfloat16', 'float32']) -def test_elementwise(N, dtype_str): - stream = torch.cuda.Stream() - torch.cuda.set_stream(stream) - torch.manual_seed(0) - if dtype_str in ['bfloat16'] and DEVICE_NAME != 'a100': - pytest.skip('Only test bfloat16 on a100') - dtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}[dtype_str] - ref_dtype_str = 'float16' if dtype_str == 'bfloat16' else dtype_str - ref_gpu_util = elementwise_data[DEVICE_NAME][N][ref_dtype_str] - max_gpu_perf = get_dram_gbps() - z = torch.empty((N, ), dtype=dtype, device='cuda') - x = torch.randn_like(z) - y = torch.randn_like(z) - grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) - fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024) - ms = triton.testing.do_bench_cudagraph(fn) - cur_gpu_perf = 3. * N * z.element_size() / ms * 1e-6 - cur_gpu_util = cur_gpu_perf / max_gpu_perf - print_perf(ms, cur_gpu_util, ref_gpu_util) - triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01) - - -####################### -# Flash-Attention -####################### - -flash_attention_data = { - "a100": { - (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.542, - (4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471, - (4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.155, - (4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.232, - (4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.231, - (4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.138, - (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.306, - (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.266, - (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.098, - (4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.157, - (4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.157, - (4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.092, - (4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.541, - (4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471, - (4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150, - (4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.291, - (4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.255, - (4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.144, - (4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.306, - (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.266, - (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.098, - (4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159, - (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.159, - (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.088, - } -} - - -@pytest.mark.parametrize("dtype_str", ['float16', 'bfloat16', 'float32']) -@pytest.mark.parametrize("mode", ['forward', 'backward']) -@pytest.mark.parametrize("causal", [True, False]) -@pytest.mark.parametrize("seq_par", [True, False]) -@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [[4, 48, 4096, 64]]) -def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str): - stream = torch.cuda.Stream() - torch.cuda.set_stream(stream) - is_backward = mode == 'backward' - capability = torch.cuda.get_device_capability() - if capability[0] < 8: - pytest.skip("Flash attention only supported for compute capability < 80") - torch.manual_seed(20) - dtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}[dtype_str] - # init data - if dtype_str == 'float32': - N_CTX = 1024 - D_HEAD = 16 - q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() - k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() - v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() - sm_scale = 0.2 - # benchmark - fn = lambda: triton.ops.attention(q, k, v, causal, sm_scale, seq_par) - if is_backward: - o = fn() - do = torch.randn_like(o) - fn = lambda: o.backward(do, retain_graph=True) - ms = triton.testing.do_bench_cudagraph(fn) - # compute flops - flops_per_matmul = 2. * Z * H * N_CTX * N_CTX * D_HEAD * 0.5 - total_flops = 2 * flops_per_matmul - if is_backward: - total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) - cur_gpu_perf = total_flops / ms * 1e-9 - # maximum flops - cur_sm_clock = nvsmi(['clocks.current.sm'])[0] - max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3) - cur_gpu_util = cur_gpu_perf / max_gpu_perf - ref_gpu_util = flash_attention_data[DEVICE_NAME][(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str)] - print_perf(ms, cur_gpu_util, ref_gpu_util) - triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01) - - -####################### -# Reduction -####################### - - -@triton.jit -def _sum(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - # run in a loop to only to make it compute bound. - for i in range(100): - x = tl.sum(x, axis=0) + y - - tl.store(output_ptr + offsets, x, mask=mask) - - -reduction_data = { - 'a100': { - 1024 * 16384: {'float16': 0.016, 'float32': 0.031, 'int16': 0.022, 'int32': 0.048}, - 1024 * 65536: {'float16': 0.016, 'float32': 0.032, 'int16': 0.022, 'int32': 0.049}, - } -} - - -@pytest.mark.parametrize('N', reduction_data[DEVICE_NAME].keys()) -@pytest.mark.parametrize("dtype_str", ['float16', 'float32', 'int16', 'int32']) -def test_reductions(N, dtype_str): - stream = torch.cuda.Stream() - torch.cuda.set_stream(stream) - torch.manual_seed(0) - dtype = {'float16': torch.float16, 'float32': torch.float32, 'int16': torch.int16, 'int32': torch.int32}[dtype_str] - ref_gpu_util = reduction_data[DEVICE_NAME][N][dtype_str] - cur_sm_clock = nvsmi(['clocks.current.sm'])[0] - max_gpu_perf = get_max_tensorcore_tflops(dtype, clock_rate=cur_sm_clock * 1e3) - z = torch.empty((N, ), dtype=dtype, device='cuda') - if dtype == torch.float16 or dtype == torch.float32: - x = torch.randn_like(z) - y = torch.randn_like(z) - else: - info = torch.iinfo(dtype) - x = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda') - y = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda') - grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) - fn = lambda: _sum[grid](x, y, z, N, BLOCK_SIZE=1024) - ms = triton.testing.do_bench_cudagraph(fn) - cur_gpu_perf = 100. * 2. * N / ms * 1e-9 - cur_gpu_util = cur_gpu_perf / max_gpu_perf - print_perf(ms, cur_gpu_util, ref_gpu_util) - triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01) diff --git a/python/test/unit/conftest.py b/python/test/unit/conftest.py index 7a02d322b..d88687b45 100644 --- a/python/test/unit/conftest.py +++ b/python/test/unit/conftest.py @@ -1,12 +1,22 @@ -# content of conftest.py - +import os import pytest +import tempfile def pytest_addoption(parser): - parser.addoption("--device", action="store", default='cuda') + parser.addoption("--device", action="store", default="cuda") @pytest.fixture def device(request): return request.config.getoption("--device") + + +@pytest.fixture +def fresh_triton_cache(): + with tempfile.TemporaryDirectory() as tmpdir: + try: + os.environ["TRITON_CACHE_DIR"] = tmpdir + yield tmpdir + finally: + os.environ.pop("TRITON_CACHE_DIR", None) diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index b20f75bc5..9695a5e47 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -1,83 +1,68 @@ -import numpy as np import pytest import torch -import tempfile import triton import triton.language as tl +from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor) +from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma -def test_descriptor_load_ttgir(): - if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: - pytest.skip("Test requires Hopper target.") - return - device = "cuda" - SIZE = 128 +def create_tma_desc_gmem_ptr(ptr, dims, block_dims, element_size): + cpu_desc = torch.empty(128, device="cpu") + if len(dims) == 1: + triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size, + cpu_desc.data_ptr()) + else: + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], block_dims[1], + element_size, cpu_desc.data_ptr()) + return cpu_desc.cuda() - x = torch.randn(SIZE, dtype=torch.float32, device=device) - desc = np.empty(SIZE, dtype=np.int8) - triton.runtime.driver.active.utils.fill_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size(), desc) - size_in_bytes = SIZE * x.element_size() - - ir = f""" - #blocked = #triton_gpu.blocked<{{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}}> - #shared = #triton_gpu.shared<{{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}}> - module attributes {{"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ - tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ - %c0_i32 = arith.constant 0 : i32 - %0 = tt.make_range {{end = {SIZE} : i32, start = 0 : i32}} : tensor<{SIZE}xi32, #blocked> - %1 = triton_gpu.local_alloc : () -> !tt.memdesc<{SIZE}xf32, #shared, mutable> - %2 = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared, mutable> - triton_nvidia_gpu.init_barrier %2, 1 : <1xi64, #shared, mutable> - %true = arith.constant 1 : i1 - triton_nvidia_gpu.barrier_expect %2, {size_in_bytes}, %true : <1xi64, #shared, mutable> - triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0_i32] %1, %2, %true : , <1xi64, #shared, mutable> -> <{SIZE}xf32, #shared, mutable> - triton_nvidia_gpu.wait_barrier %2, %c0_i32 : <1xi64, #shared, mutable> - %3 = triton_gpu.local_load %1 : !tt.memdesc<{SIZE}xf32, #shared, mutable> -> tensor<{SIZE}xf32, #blocked> - %4 = tt.splat %arg0 : !tt.ptr -> tensor<{SIZE}x!tt.ptr, #blocked> - %5 = tt.addptr %4, %0 : tensor<{SIZE}x!tt.ptr, #blocked>, tensor<{SIZE}xi32, #blocked> - tt.store %5, %3 : tensor<{SIZE}x!tt.ptr, #blocked> - tt.return - }} - }} - """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) - - desc = torch.tensor(desc, device=device) - z_tri = torch.empty_like(x) - kernel[(1, 1, 1)](z_tri, desc) - assert torch.equal(x, z_tri) +def unwrap_tensor(t: torch.Tensor | triton.runtime.jit.TensorWrapper): + if isinstance(t, triton.runtime.jit.TensorWrapper): + return t.base + return t -def test_experimetal_descriptor_load(): - if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: - pytest.skip("Test requires Hopper target.") - return + +tma_dtypes = sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"}) + + +@requires_tma +@pytest.mark.parametrize("byval_tma", [True, False]) +def test_experimetal_descriptor_load(byval_tma): device = "cuda" SIZE = 128 @triton.jit - def kernel(Z, desc, SIZE: tl.constexpr): + def kernel(Z, desc, SIZE: tl.constexpr, BYVAL_TMA: tl.constexpr): + if not BYVAL_TMA: + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc) off_desc = 0 off = tl.arange(0, SIZE) x = tl._experimental_descriptor_load(desc, [off_desc], [SIZE], Z.dtype.element_ty) tl.store(Z + off, x) x = torch.randn(SIZE, dtype=torch.float32, device=device) - desc = np.empty(SIZE, dtype=np.int8) - triton.runtime.driver.active.utils.fill_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size(), desc) - desc = torch.tensor(desc, device=device) + if byval_tma: + desc = create_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size()) + else: + desc = create_tma_desc_gmem_ptr(x.data_ptr(), [SIZE], [SIZE], x.element_size()) z_tri = torch.empty_like(x) - kernel[(1, )](z_tri, desc, SIZE=SIZE, num_warps=4) + compiled_kernel = kernel[(1, )](z_tri, desc, SIZE=SIZE, BYVAL_TMA=byval_tma, num_warps=4) assert torch.equal(x, z_tri) + if byval_tma: + assert ".param .align 64 .b8" in compiled_kernel.asm["ptx"] @triton.jit def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # - M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): + M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BYVAL_TMA: tl.constexpr, dtype: tl.constexpr): + if not BYVAL_TMA: + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) pid_m = pid % num_pid_m @@ -87,44 +72,161 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # offs_k = 0 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], tl.float16) - b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], tl.float16) + a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype) + b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], dtype) accumulator = tl.dot(a, b, acc=accumulator) offs_k += BLOCK_SIZE_K - accumulator = accumulator.to(tl.float16) + accumulator = accumulator.to(dtype) tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn]) +@requires_tma @pytest.mark.parametrize("num_stages", [1, 4]) @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)]) -def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K): - if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: - pytest.skip("Test requires Hopper target.") - return +@pytest.mark.parametrize("byval_tma", [True, False]) +def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma): device = "cuda" M, N, K = 8192, 8192, 1024 torch.manual_seed(42) A = torch.randn((M, K), dtype=torch.float16, device=device) B = torch.randn((K, N), dtype=torch.float16, device=device) C = torch.empty((M, N), dtype=torch.float16, device=device) - TMA_SIZE = 128 - desc_a = np.empty(TMA_SIZE, dtype=np.int8) - desc_b = np.empty(TMA_SIZE, dtype=np.int8) - desc_c = np.empty(TMA_SIZE, dtype=np.int8) - triton.runtime.driver.active.utils.fill_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size(), - desc_a) - triton.runtime.driver.active.utils.fill_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size(), - desc_b) - triton.runtime.driver.active.utils.fill_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size(), - desc_c) - - desc_a = torch.tensor(desc_a, device=device) - desc_b = torch.tensor(desc_b, device=device) - desc_c = torch.tensor(desc_c, device=device) + if byval_tma: + desc_a = create_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size()) + desc_b = create_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size()) + desc_c = create_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size()) + else: + desc_a = create_tma_desc_gmem_ptr(A.data_ptr(), [M, K], [BLOCK_M, BLOCK_K], A.element_size()) + desc_b = create_tma_desc_gmem_ptr(B.data_ptr(), [K, N], [BLOCK_K, BLOCK_N], B.element_size()) + desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size()) kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, - 1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps=8, - num_stages=num_stages) + 1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma, + num_warps=8, num_stages=num_stages, dtype=tl.float16) ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16) torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) if BLOCK_M >= 64 and BLOCK_N >= 64: assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"] + if byval_tma: + assert ".param .align 64 .b8" in kernel.asm["ptx"] + + +@triton.jit +def device_tensormap_kernel2d(in_ptr, out_ptr, in_desc, out_desc, ready_flag, M, N, M_BLOCK: tl.constexpr, + N_BLOCK: tl.constexpr): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + if pid_m == 0 and pid_n == 0: + # Write out descriptor + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=in_desc, + global_address=in_ptr, + load_size=[M_BLOCK, N_BLOCK], + global_size=[M, N], + element_ty=in_ptr.dtype.element_ty, + ) + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=out_desc, + global_address=out_ptr, + load_size=[M_BLOCK, N_BLOCK], + global_size=[M, N], + element_ty=out_ptr.dtype.element_ty, + ) + tl.atomic_xchg(ready_flag, 1, sem="release") + else: + # Spin until descriptor is ready + flag = tl.full([], 0, tl.int32) + while flag == 0: + flag = tl.atomic_add(ready_flag, 0, sem="acquire") + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(in_desc) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(out_desc) + + moffset = pid_m * M_BLOCK + noffset = pid_n * N_BLOCK + + x = tl._experimental_descriptor_load(in_desc, [moffset, noffset], [M_BLOCK, N_BLOCK], in_ptr.dtype.element_ty) + tl._experimental_descriptor_store(out_desc, x, [moffset, noffset]) + + +@requires_tma +@pytest.mark.parametrize("dtype_str", tma_dtypes) +def test_device_tensormap2d(dtype_str): + M_BLOCK, N_BLOCK = 32, 64 + M_GRID, N_GRID = 2, 4 + + shape = (M_BLOCK * M_GRID, M_BLOCK * N_GRID) + device = "cuda" + inp = to_triton(numpy_random(shape, dtype_str=dtype_str), device=device, dst_type=dtype_str) + inp_copy = inp.clone() + out = to_triton(numpy_random(shape, dtype_str=dtype_str), device=device, dst_type=dtype_str) + + in_desc = torch.randint(0, 256, size=(128, ), dtype=torch.uint8, device="cuda") + out_desc = torch.randint(0, 256, size=(128, ), dtype=torch.uint8, device="cuda") + ready_flag = torch.zeros((), dtype=torch.int32, device="cuda") + + device_tensormap_kernel2d[M_GRID, N_GRID](inp, out, in_desc, out_desc, ready_flag, *shape, M_BLOCK=M_BLOCK, + N_BLOCK=N_BLOCK) + + # Check results are correct + torch.testing.assert_close(unwrap_tensor(inp), unwrap_tensor(out)) + torch.testing.assert_close(unwrap_tensor(inp), unwrap_tensor(inp_copy)) + + +@triton.jit +def device_tensormap_kernel1d(in_ptr, out_ptr, in_desc, out_desc, ready_flag, numel, BLOCK: tl.constexpr): + pid = tl.program_id(axis=0) + + if pid == 0: + # Write out descriptor + tl.extra.cuda.experimental_device_tensormap_create1d( + desc_ptr=in_desc, + global_address=in_ptr, + load_size=BLOCK, + global_size=numel, + element_ty=in_ptr.dtype.element_ty, + ) + tl.extra.cuda.experimental_device_tensormap_create1d( + desc_ptr=out_desc, + global_address=out_ptr, + load_size=BLOCK, + global_size=numel, + element_ty=out_ptr.dtype.element_ty, + ) + tl.atomic_xchg(ready_flag, 1, sem="release") + else: + # Spin until descriptor is ready + flag = tl.full([], 0, tl.int32) + while flag == 0: + flag = tl.atomic_add(ready_flag, 0, sem="acquire") + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(in_desc) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(out_desc) + + offset = pid * BLOCK + + x = tl._experimental_descriptor_load(in_desc, [offset], [BLOCK], in_ptr.dtype.element_ty) + tl._experimental_descriptor_store(out_desc, x, [offset]) + + +@requires_tma +@pytest.mark.parametrize("dtype_str", tma_dtypes) +def test_device_tensormap1d(dtype_str): + BLOCK = 256 + GRID = 8 + + shape = (BLOCK * GRID, ) + device = "cuda" + inp = to_triton(numpy_random(shape, dtype_str=dtype_str), device=device, dst_type=dtype_str) + inp_copy = inp.clone() + out = to_triton(numpy_random(shape, dtype_str=dtype_str), device=device, dst_type=dtype_str) + + in_desc = torch.randint(0, 256, size=(128, ), dtype=torch.uint8, device="cuda") + out_desc = torch.randint(0, 256, size=(128, ), dtype=torch.uint8, device="cuda") + ready_flag = torch.zeros((), dtype=torch.int32, device="cuda") + + device_tensormap_kernel1d[ + 1, + ](inp, out, in_desc, out_desc, ready_flag, *shape, BLOCK=BLOCK) + + # Check results are correct + torch.testing.assert_close(unwrap_tensor(inp), unwrap_tensor(out)) + torch.testing.assert_close(unwrap_tensor(inp), unwrap_tensor(inp_copy)) diff --git a/python/test/unit/hopper/test_flashattention.py b/python/test/unit/hopper/test_flashattention.py index fc8db664c..5053cfc4b 100644 --- a/python/test/unit/hopper/test_flashattention.py +++ b/python/test/unit/hopper/test_flashattention.py @@ -435,8 +435,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): @triton.testing.perf_report(configs) def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"): assert mode in ['fwd', 'bwd'] - warmup = 25 - rep = 100 if provider == "triton": q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) @@ -447,7 +445,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) - ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + ms = triton.testing.do_bench(fn) return ms if provider == "flash": lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device) @@ -459,7 +457,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) - ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + ms = triton.testing.do_bench(fn) return ms diff --git a/python/test/unit/hopper/test_tma_descriptor.py b/python/test/unit/hopper/test_tma_descriptor.py new file mode 100644 index 000000000..497248b6b --- /dev/null +++ b/python/test/unit/hopper/test_tma_descriptor.py @@ -0,0 +1,49 @@ +import pytest +import torch +from triton.tools.experimental_descriptor import create_1d_tma_descriptor, create_2d_tma_descriptor + + +@pytest.mark.parametrize("M, BLOCK_M, expect_error", [(128, 32, False), (127, 32, False), (128, 31, True)]) +def test_1d_tma_descriptor_exception(M, BLOCK_M, expect_error): + if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: + pytest.skip("Test requires Hopper target.") + return + + device = "cuda" + x = torch.randn(M, dtype=torch.float32, device=device) + # globalAddress in the tma descriptor must be aligned to 16 bytes for CU_TENSOR_MAP_INTERLEAVE_NONE. + # https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY + assert x.data_ptr() % 16 == 0 + is_error = False + + try: + create_1d_tma_descriptor(x.data_ptr(), M, BLOCK_M, x.element_size()) + except RuntimeError as e: + is_error = True + assert e.args[0] == "Triton Error [CUDA]: invalid argument" + + assert is_error == expect_error + + +@pytest.mark.parametrize("M, BLOCK_M", [(128, 32), (125, 33)]) +@pytest.mark.parametrize("N, BLOCK_N, expect_error", [(128, 32, False), (128, 30, True), (127, 32, True)]) +def test_2d_tma_descriptor_exception(M, N, BLOCK_M, BLOCK_N, expect_error): + if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: + pytest.skip("Test requires Hopper target.") + return + + device = "cuda" + torch.manual_seed(42) + A = torch.randn((M, N), dtype=torch.float16, device=device) + # globalAddress in the tma descriptor must be aligned to 16 bytes for CU_TENSOR_MAP_INTERLEAVE_NONE. + # https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY + assert A.data_ptr() % 16 == 0 + is_error = False + + try: + create_2d_tma_descriptor(A.data_ptr(), M, N, BLOCK_M, BLOCK_N, A.element_size()) + except RuntimeError as e: + is_error = True + assert e.args[0] == "Triton Error [CUDA]: invalid argument" + + assert is_error == expect_error diff --git a/python/test/unit/instrumentation/test_gpuhello.py b/python/test/unit/instrumentation/test_gpuhello.py new file mode 100644 index 000000000..413c3f642 --- /dev/null +++ b/python/test/unit/instrumentation/test_gpuhello.py @@ -0,0 +1,49 @@ +import torch + +import pytest +import os + +import triton +import triton.language as tl + +test_stdout = 'Hello From First Instruction of GPU Kernel: kernel1\ttest_gpuhello.py:17:4\n\ +Hello From First Instruction of GPU Kernel: kernel2\ttest_gpuhello.py:23:4\n\ +Hello From First Instruction of GPU Kernel: kernel3\ttest_gpuhello.py:29:4\n' + + +@pytest.mark.parametrize(None, [None]) +@triton.jit +def kernel1(BLOCK_SIZE: tl.constexpr): + return + + +@pytest.mark.parametrize(None, [None]) +@triton.jit +def kernel2(BLOCK_SIZE: tl.constexpr): + return + + +@pytest.mark.parametrize(None, [None]) +@triton.jit +def kernel3(BLOCK_SIZE: tl.constexpr): + return + + +def func(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + assert x.is_cuda and y.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + kernel1[grid](BLOCK_SIZE=1024) + kernel2[grid](BLOCK_SIZE=1024) + kernel3[grid](BLOCK_SIZE=1024) + + +def test_op(capfd): + size = 98432 + x = torch.rand(size, device='cuda') + y = torch.rand(size, device='cuda') + func(x, y) + stdout, stderr = capfd.readouterr() + if 'LLVM_PASS_PLUGIN_PATH' in os.environ: + assert repr(stderr) == repr(test_stdout) diff --git a/python/test/unit/language/print_helper.py b/python/test/unit/language/print_helper.py index e032792f3..30a4745db 100644 --- a/python/test/unit/language/print_helper.py +++ b/python/test/unit/language/print_helper.py @@ -34,6 +34,13 @@ def kernel_print(X, Y, BLOCK: tl.constexpr): tl.store(Y + tl.arange(0, BLOCK), x) +@triton.jit +def kernel_device_print_scalar(SCALAR): + x = tl.load(SCALAR) + # Triton should add a space after this prefix. + print("x:", x) + + @triton.jit def kernel_device_print_large( BLOCK_M: tl.constexpr, @@ -83,17 +90,26 @@ def kernel_print_pointer(X, Y, BLOCK: tl.constexpr): tl.device_print("ptr ", X + tl.arange(0, BLOCK)) -def test_print(func: str, data_type: str): +def test_print(func: str, data_type: str, device: str): N = 128 # This value should match with test_print in test_subprocess.py. # TODO(antiagainst): Currently the warp count is chosen to make sure wedon't have multiple # threads printing duplicated messages due to broadcasting. Improve print op lowering logic # to filter out duplicated data range. num_warps = N // get_current_target_warp_size() - x = torch.arange(0, N, dtype=torch.int32, device='cuda').to(getattr(torch, data_type)) - y = torch.zeros((N, ), dtype=x.dtype, device="cuda") + x = torch.arange(0, N, dtype=torch.int32, device=device).to(getattr(torch, data_type)) + y = torch.zeros((N, ), dtype=x.dtype, device=device) if func == "device_print": kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_scalar": + scalar = torch.tensor(42, dtype=x.dtype, device=device) + kernel_device_print_scalar[(1, )](scalar, num_warps=num_warps) + elif func == "device_print_negative": + x = -x + kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_uint": + x = torch.arange((1 << 31), (1 << 31) + N, device=device).to(getattr(torch, data_type)) + kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) elif func == "print": kernel_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) elif func == "device_print_large": @@ -117,9 +133,14 @@ def test_print(func: str, data_type: str): if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \ func != "print_multiple_args" and func != "device_print_multiple_args" and \ - func != "device_print_pointer": + func != "device_print_pointer" and func != "device_print_scalar": assert_close(y, x) + # Wait until driver complete all the jobs for the device_print, especially test_subprocess + # require this which captures stdout when child exits. + getattr(torch, device).synchronize() + if __name__ == "__main__": - test_print(sys.argv[1], sys.argv[2]) + fn = globals()[sys.argv[1]] + fn(*sys.argv[2:]) diff --git a/python/test/unit/language/test_block_pointer.py b/python/test/unit/language/test_block_pointer.py index c932131c9..8e84a9f82 100644 --- a/python/test/unit/language/test_block_pointer.py +++ b/python/test/unit/language/test_block_pointer.py @@ -3,6 +3,7 @@ import triton import triton.language as tl +from test_core import check_type_supported @triton.jit @@ -13,23 +14,29 @@ def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: block_shape=(BLOCK_SIZE, ), order=(0, )) b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), block_shape=(BLOCK_SIZE, ), order=(0, )) - a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option) + if padding_option is None: + a = tl.load(a_block_ptr, boundary_check=(0, )) + else: + a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option) tl.store(b_block_ptr, a, boundary_check=(0, )) @pytest.mark.interpreter @pytest.mark.parametrize("dtypes_str, n, padding_option", [ # (dtypes_str, n, padding) - for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("float16", "float16"), ("int16", "float16")) + for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"), + ("float32", "float32"), ("bfloat16", "bfloat16")) for n in (64, 128, 256, 512, 1024) - for padding in ("zero", "nan") # + for padding in (None, "zero", "nan") # ]) def test_block_copy(dtypes_str, n, padding_option, device): src_dtype_str = dtypes_str[0] - dst_dtype_str = dtypes_str[0] + dst_dtype_str = dtypes_str[1] src_dtype = getattr(torch, src_dtype_str) dst_dtype = getattr(torch, dst_dtype_str) - if src_dtype_str in ("bool", "int16"): + check_type_supported(src_dtype, device) + check_type_supported(dst_dtype, device) + if src_dtype_str in ("bool", "int16", "int32"): if padding_option == "nan": pytest.skip("Padding with NaN is not supported for integer types") a = torch.randint(0, 2, (n, ), device=device, dtype=src_dtype) @@ -43,7 +50,7 @@ def test_block_copy(dtypes_str, n, padding_option, device): assert torch.all(a[0:n // 2] == b[0:n // 2]) if padding_option == "zero": assert torch.all(b[n // 2:n] == 0) - else: + elif padding_option == "nan": assert torch.all(torch.isnan(b[n // 2:n])) diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index 0531f8ebc..12c3997ec 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -1,11 +1,30 @@ +import contextlib import pytest +import os +import torch import triton import triton.language as tl from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure import traceback +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def is_cuda(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def is_hip(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip" + + +def is_on_mi300(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942') + + def test_err_undefined_variable(): @triton.jit @@ -131,7 +150,8 @@ def kernel(): try: inner = e.value.__cause__ outer = e.value - assert "/core.py" in '\n'.join(traceback.format_tb(inner.__traceback__)), "error should point inside core.py" + assert f"{os.sep}core.py" in '\n'.join(traceback.format_tb( + inner.__traceback__)), "error should point inside core.py" assert "at 2:4:" in str(outer), "error should point to expand_dims call" assert "" not in str(outer) @@ -155,6 +175,15 @@ def kernel(): triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) +def test_not_const_annotate_no_err(): + + @triton.jit + def kernel(N: int = 1): + pass + + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={})) + + @triton.jit def returns_branched_on_constexpr(N: tl.constexpr): if N == 0: @@ -301,4 +330,81 @@ def kernel(a=GLOBAL): pass # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={0: "i32"}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={})) + + +def test_defaults_assign_no_err(): + + @triton.jit + def kernel(a=1, B: tl.constexpr = ""): + pass + + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32'}, constants={'B': ""})) + + +def test_where_warning(fresh_triton_cache): + + @triton.jit + def kernel(): + a = tl.full((64, ), 0, tl.uint32) + b = tl.full((64, ), 1, tl.float32) + c = tl.full((64, ), 2, tl.float32) + tl.where(a, b, c) + + with pytest.warns(UserWarning): + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +@pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15]) +def test_fp8_support(dtype): + warning_dtypes = [] + supported_dtypes = [tl.float8e5] + if is_cuda(): + cc = torch.cuda.get_device_capability(0) + supported_dtypes.append(tl.float8e4b15) + if cc >= (9, 0): + warning_dtypes.append(tl.float8e4b15) + if cc >= (8, 9): + supported_dtypes.append(tl.float8e4nv) + elif is_hip(): + if is_on_mi300(): + supported_dtypes += [tl.float8e4b8, tl.float8e5b16] + elif is_interpreter(): + supported_dtypes = [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15] + + @triton.jit + def dtype_kernel(dtype: tl.constexpr): + _ = tl.full((256, ), 0.0, dtype) + + if dtype in warning_dtypes: + ctx = pytest.warns(UserWarning, match=r"fp8e4b15 is deprecated in this architecture") + elif dtype in supported_dtypes: + ctx = contextlib.nullcontext() + else: + ctx = pytest.raises(CompilationError, match="") + + with ctx as e: + triton.compile(triton.compiler.ASTSource(fn=dtype_kernel, signature={}, constants={"dtype": dtype})) + + if dtype not in supported_dtypes: + try: + assert ("not supported in this architecture" in str(e.value.__cause__)) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_max_num_imprecise_acc_limit(): + + @triton.jit + def dot_kernel(): + SIZE: tl.constexpr = 64 + a = tl.full((SIZE, SIZE), 0.0, tl.float8e5) + b = tl.full((SIZE, SIZE), 0.0, tl.float8e5) + tl.dot(a, b, max_num_imprecise_acc=128) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constants={})) + try: + assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)") + except AssertionError as assertion_err: + raise assertion_err from e.value diff --git a/python/test/unit/language/test_conversions.py b/python/test/unit/language/test_conversions.py index 467b398b7..723a15fe8 100644 --- a/python/test/unit/language/test_conversions.py +++ b/python/test/unit/language/test_conversions.py @@ -251,11 +251,11 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia src = launch_exhaustive_populate(src_dtype, 0, 65536, False, numbits_src, max_repr, device=device) dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device) - dst = launch_type_convert_triton(dst, dst_dtype, tl.float32, device=device) + dst_to_float32 = launch_type_convert_triton(dst, dst_dtype, tl.float32, device=device) - dst2 = launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device=device) + src_emulated_to_float32 = launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device=device) - assert(torch.equal(dst, dst2)) + assert(torch.equal(src_emulated_to_float32, dst_to_float32)) @pytest.mark.parametrize("src_dtype, dst_dtype", [ @@ -281,15 +281,13 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia ('float8e5b16', 'float16'), ]) def test_typeconvert_upcast(src_dtype, dst_dtype, device): - - if src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (9, 0): - pytest.skip("float8e4nv upcast tests only supported on NVGPU with compute capability 9.0+") - - if src_dtype in ('float8e4nv', 'float8e4b15') and is_hip(): - pytest.skip(f"{src_dtype} upcast tests not supported on ROCm") - - if src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_on_mi300()): - pytest.skip("{src_dtype} upcast tests only supported on AMDGPU MI300") + if ((src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (8, 9)) + or (src_dtype in ('float8e4nv', 'float8e4b15') and is_hip()) + or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_on_mi300()))): + # If the dtype should error out in the given device, we assert that and return + with pytest.raises(triton.CompilationError, match="not supported in this architecture"): + launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device) + return # dtype : (exponent_bits, mantissa_bits, exponent_bias, max_repr) stuff = { diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 5972c93d7..569ed16bb 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1,7 +1,8 @@ # flake8: noqa: F821,F841 +import contextlib import itertools import re -from typing import Optional, Union +from typing import Optional import math import textwrap import tempfile @@ -15,30 +16,37 @@ import triton import triton.language as tl -from triton.runtime.jit import TensorWrapper, reinterpret - - -def is_interpreter(): - return os.environ.get('TRITON_INTERPRET', '0') == '1' - - -def is_cuda(): - return not is_interpreter() and \ - triton.runtime.driver.active.get_current_target().backend == "cuda" - - -def is_hip(): - return not is_interpreter() and \ - triton.runtime.driver.active.get_current_target().backend == "hip" - +from triton.language.extra import libdevice + +from triton._internal_testing import ( + integral_dtypes, + int_dtypes, + uint_dtypes, + float_dtypes, + dtypes, + dtypes_with_bfloat16, + is_cuda, + is_interpreter, + is_hip, + get_arch, + torch_float8_dtypes, + torch_dtypes, + numpy_random, + to_triton, + torch_dtype_name, + to_numpy, +) + + +@contextlib.contextmanager +def promotion_numpy_2_0(): + state = np._get_promotion_state() + np._set_promotion_state("weak") + try: + yield + finally: + np._set_promotion_state(state) -int_dtypes = ['int8', 'int16', 'int32', 'int64'] -uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] -float_dtypes = ['float16', 'float32', 'float64'] -dtypes = int_dtypes + uint_dtypes + float_dtypes -dtypes_with_bfloat16 = dtypes + ['bfloat16'] -torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2'] -torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16'] # TODO: enable multiple cta cluster testing. # num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1] @@ -58,77 +66,6 @@ def _bitwidth(dtype: str) -> int: return int(re.search(r'(\d+)$', dtype).group(1)) -def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): - """ - Override `rs` if you're calling this function twice and don't want the same - result for both calls. - """ - if isinstance(shape, int): - shape = (shape, ) - if rs is None: - rs = RandomState(seed=17) - if dtype_str in int_dtypes + uint_dtypes: - iinfo = np.iinfo(getattr(np, dtype_str)) - low = iinfo.min if low is None else max(low, iinfo.min) - high = iinfo.max if high is None else min(high, iinfo.max) - dtype = getattr(np, dtype_str) - x = rs.randint(low, high, shape, dtype=dtype) - x[x == 0] = 1 # Workaround. Never return zero so tests of division don't error out. - return x - elif dtype_str and 'float8' in dtype_str: - x = rs.randint(20, 40, shape, dtype=np.int8) - return x - elif dtype_str in float_dtypes: - return rs.normal(0, 1, shape).astype(dtype_str) - elif dtype_str == 'bfloat16': - return (rs.normal(0, 1, shape).astype('float32').view('uint32') & np.uint32(0xffff0000)).view('float32') - elif dtype_str in ['bool', 'int1', 'bool_']: - return rs.normal(0, 1, shape) > 0.0 - else: - raise RuntimeError(f'Unknown dtype {dtype_str}') - - -def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torch.Tensor]: - ''' - Note: We need dst_type because the type of x can be different from dst_type. - For example: x is of type `float32`, dst_type is `bfloat16`. - If dst_type is None, we infer dst_type from x. - ''' - t = x.dtype.name - if t in uint_dtypes: - signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16" - x_signed = x.astype(getattr(np, signed_type_name)) - return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t)) - else: - if dst_type and 'float8' in dst_type: - return reinterpret(torch.tensor(x, device=device), getattr(tl, dst_type)) - if t == 'float32' and dst_type == 'bfloat16': - return torch.tensor(x, device=device).bfloat16() - return torch.tensor(x, device=device) - - -def torch_dtype_name(dtype) -> str: - if isinstance(dtype, triton.language.dtype): - return dtype.name - elif isinstance(dtype, torch.dtype): - # 'torch.int64' -> 'int64' - m = re.match(r'^torch\.(\w+)$', str(dtype)) - return m.group(1) - else: - raise TypeError(f'not a triton or torch dtype: {type(dtype)}') - - -def to_numpy(x): - if isinstance(x, TensorWrapper): - return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype))) - elif isinstance(x, torch.Tensor): - if x.dtype is torch.bfloat16: - return x.cpu().float().numpy() - return x.cpu().numpy() - else: - raise ValueError(f"Not a triton-compatible tensor: {x}") - - def patch_kernel(template, to_replace): if is_interpreter(): local_namespace = {} @@ -148,7 +85,7 @@ def check_cuda_or_hip(device): # CUDA and HIP both use pytorch device 'cuda'. Other backends like Intel # GPU do not. if device not in ['cuda']: - pytest.skip("Only for cuda") + pytest.skip("Only for cuda or HIP") def check_type_supported(dtype, device): @@ -180,11 +117,12 @@ def __str__(self): class WmmaLayout: - def __init__(self, warps_per_cta): + def __init__(self, version, warps_per_cta): + self.version = version self.warps_per_cta = warps_per_cta def __str__(self): - return f"#{GPU_DIALECT}.amd_wmma<{{warpsPerCTA = {self.warps_per_cta}}}>" + return f"#{GPU_DIALECT}.amd_wmma<{{version = {self.version}, warpsPerCTA = {self.warps_per_cta}}}>" class MmaLayout: @@ -329,7 +267,7 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, - y_low=None, y_high=None, test_broadcast=True): + y_low=None, y_high=None, filter_y=None, test_broadcast=True, test_scalar=True): check_type_supported(dtype_x, device) # early return if dtype_x is not supported check_type_supported(dtype_y, device) SIZE = 128 @@ -359,45 +297,92 @@ def kernel_broadcast_rhs(Z, X, Y, SIZE: tl.constexpr): z = GENERATE_TEST_HERE tl.store(Z + off, z) + @triton.jit + def kernel_scalar_rhs(Z, X, y: tl.constexpr, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + replacements = {'GENERATE_TEST_HERE': expr} kernel = patch_kernel(kernel, replacements) kernel_broadcast_lhs = patch_kernel(kernel_broadcast_lhs, replacements) kernel_broadcast_rhs = patch_kernel(kernel_broadcast_rhs, replacements) + kernel_scalar_rhs = patch_kernel(kernel_scalar_rhs, replacements) # inputs rs = RandomState(17) x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high) + if filter_y: + y[filter_y(y)] = 1 if mode_x == 'nan': x[:] = float('nan') if mode_y == 'nan': y[:] = float('nan') def do_test(x, y, kernel_fn): - # reference result - z_ref = eval(expr if numpy_expr is None else numpy_expr) + x_is_scalar = isinstance(x, (bool, int, float)) + y_is_scalar = isinstance(y, (bool, int, float)) + scalar_test = x_is_scalar or y_is_scalar + + # For scalars, we follow the NumPy 2.0 (and JAX/PyTorch pretty much) casting rules. + if scalar_test: + # We remove any explicit casting + pattern = r'\.astype\(np\.\w+\)' + scalar_expr = expr if numpy_expr is None else re.sub(pattern, '', numpy_expr) + with promotion_numpy_2_0(): + z_ref = eval(scalar_expr) + else: + z_ref = eval(expr if numpy_expr is None else numpy_expr) + dtype_z = _binary_op_dtype_override(dtype_x, dtype_y) - if dtype_z is not None: + if not scalar_test and dtype_z is not None: z_ref = z_ref.astype(dtype_z) + # triton result - x_tri = to_triton(x, device=device, dst_type=dtype_x) - y_tri = to_triton(y, device=device, dst_type=dtype_y) + x_tri = x if x_is_scalar else to_triton(x, device=device, dst_type=dtype_x) + y_tri = y if y_is_scalar else to_triton(y, device=device, dst_type=dtype_y) z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) err_msg = f"{expr}, {kernel_fn.__name__}" - np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=1e-3, rtol=0.01) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=3e-3, rtol=0.01) + + def get_scalar(x, dtype, low, high, filter): + # If dtype is int, don't choose a huge number for the scalar + # as it'll overflow easily when converted to the other dtype + if dtype in integral_dtypes: + # Choose in range [-7, 7] ([0, 7] for uints) + low_x = 0 if dtype in uint_dtypes else -7 + if low is not None: + low_x = max(low_x, low) + high_x = 7 + if high is not None: + high_x = min(high_x, high) + scalar = numpy_random((), dtype_str=dtype, rs=rs, low=low_x, high=high_x).item() + if filter and filter(scalar): + # https://xkcd.com/221/ + scalar = 4 + else: + scalar = x.flat[0].item() + return scalar do_test(x, y, kernel) + if mode_y != 'nan' and test_scalar: + if dtype_x in uint_dtypes: + low = 0 if y_low is None else max(y_low, 0) + else: + low = y_low + y_scalar = get_scalar(y, dtype_y, low, y_high, filter_y) + do_test(x, y_scalar, kernel_scalar_rhs) if test_broadcast: do_test(x[:1].reshape(()), y, kernel_broadcast_lhs) do_test(x, y[:1].reshape(()), kernel_broadcast_rhs) def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: - # The result of x % y is ill-conditioned if x % y is much smaller than x. - # pytorch/CUDA has slightly different (probably better) rounding on - # remainders than stock LLVM. We currently don't expect to match it - # bit-for-bit. + # FIXME For large x, we are casting x to a floating point where it does not fit + # For small y, we are computing floor(div(float(x), y)) which may not fit return (dtype_x, dtype_y) in [ ('int32', 'bfloat16'), ('int32', 'float16'), @@ -439,7 +424,7 @@ def test_dtype_codegen(): ]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): - expr = f' x {op} y' + expr = f'x {op} y' if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. numpy_expr = 'np.fmod(x, y)' @@ -463,11 +448,25 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): with pytest.raises(triton.TritonError, match='Cannot use .* because they have different signedness'): _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) else: + # skip when bfloat16, as NumPy's ref performs the computation in float32 + # while Triton performs it in bfloat16 + # We also skip mod when it is ill-conditioned + skip_scalar_test = ((dtype_x == "bfloat16" and "float" in dtype_y) + or (expr == "x % y" and dtype_x in int_dtypes + uint_dtypes and dtype_y in float_dtypes + and _mod_operation_ill_conditioned(dtype_x, "float32"))) + # can't divide by zero + not_zero = op in ('/', '%') and dtype_x in integral_dtypes and dtype_y in integral_dtypes + # can't represent -int(max) + not_minus_one = op in ('*', '/') and dtype_x in int_dtypes and dtype_y in int_dtypes + if not_zero or not_minus_one: + filter_y = lambda y: not_zero * (y == 0) | not_minus_one * (y == -1) + else: + filter_y = None _test_binary( dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, # fails with values where fmod(x, y) is roughly zero, but happens to # pass with the random values chosen for non-broadcast tests - test_broadcast=(op != "%")) + test_broadcast=(op != "%"), filter_y=filter_y, test_scalar=not skip_scalar_test) @pytest.mark.interpreter @@ -507,7 +506,13 @@ def test_floordiv(dtype_x, dtype_y, num_ctas, device): # reference result for //. expr = 'x // y' numpy_expr = '((x - np.fmod(x, y)) / y)' - _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + # can't represent -int(max) + not_minus_one = dtype_x in int_dtypes and dtype_y in int_dtypes + if not_minus_one: + filter_y = lambda y: y == -1 + else: + filter_y = None + _test_binary(dtype_x, dtype_y, expr, numpy_expr, filter_y=filter_y, device=device, num_ctas=num_ctas) def test_unsigned_name_mangling(device): @@ -572,10 +577,7 @@ def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # - (dtype_x, dtype_y, op) - for op in ['<<', '>>'] - for dtype_x in int_dtypes + uint_dtypes - for dtype_y in int_dtypes + uint_dtypes + (dtype_x, dtype_y, op) for op in ['<<', '>>'] for dtype_x in int_dtypes + uint_dtypes for dtype_y in uint_dtypes ]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_shift_op(dtype_x, dtype_y, op, num_ctas, device): @@ -896,7 +898,7 @@ def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] - mask = 0 + mask = False vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) res = tl.where(mask, vals, 0.) tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) @@ -1055,6 +1057,12 @@ def test_abs(dtype_x, device): def test_abs_fp8(in_dtype, device): if is_hip(): pytest.skip('test_abs_fp8 not supported on HIP.') + elif is_cuda(): + cc = torch.cuda.get_device_capability() + if in_dtype == tl.float8e4b15 and cc >= (9, 0): + pytest.skip("float8e4b15 not supported on CUDA >= 9.0") + if in_dtype == tl.float8e4nv and cc < (8, 9): + pytest.skip("float8e4nv not supported on CUDA < 8.9") @triton.jit def abs_kernel(X, Z, SIZE: tl.constexpr): @@ -1094,6 +1102,9 @@ def kernel(): a = tl.arange(0, 32).reshape(4, 8).permute(1, 0) tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4)]) + a = tl.arange(0, 32).reshape(4, 8).trans() + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4)]) + a = tl.arange(0, 32).reshape(4, 8).reshape(32) tl.static_assert(a.shape == [tl.constexpr(32)]) @@ -1437,35 +1448,44 @@ def kernel(X): @pytest.mark.interpreter -@pytest.mark.parametrize("shape, axis, num_ctas", [(shape, axis, num_ctas) - for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] - for axis in [0, 1] - for num_ctas in num_ctas_list]) -def test_tensor_atomic_rmw(shape, axis, num_ctas, device): +@pytest.mark.parametrize("shape, axis, num_ctas, dtype_x_str", + [(shape, axis, num_ctas, dtype_x_str) + for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] + for axis in [0, 1] + for num_ctas in num_ctas_list + for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']]) +def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device): shape0, shape1 = shape # triton kernel @triton.jit - def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): off0 = tl.arange(0, SHAPE0) off1 = tl.arange(0, SHAPE1) x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :]) z = tl.sum(x, axis=AXIS) if AXIS == 1: - tl.atomic_add(Z + off0, z) + old = tl.atomic_add(Z + off0, z) + tl.store(OLD + off0, old) else: - tl.atomic_add(Z + off1, z) + old = tl.atomic_add(Z + off1, z) + tl.store(OLD + off1, old) rs = RandomState(17) - x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs) - # reference result - z_ref = np.sum(x, axis=axis, keepdims=False) + x = numpy_random((shape0, shape1), dtype_str=dtype_x_str, rs=rs) + z_shape = (shape0, ) if axis == 1 else (shape1, ) + z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs) + old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str)) + # reference results + z_ref = z + np.sum(x, axis=axis, keepdims=False) + old_ref = np.copy(z) # triton result x_tri = to_triton(x, device=device) - z_shape = (shape0, ) if axis == 1 else (shape1, ) - z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device) - kernel[(1, )](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas) + z_tri = to_triton(z, device=device) + old_tri = to_triton(old, device=device) + kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas) np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) + np.testing.assert_equal(old_ref, to_numpy(old_tri)) @pytest.mark.interpreter @@ -1510,6 +1530,10 @@ def serialized_add(data, Lock, SEM: tl.constexpr): tl.store(ptrs, tl.load(ptrs) + 1.0) + # insert barrier to set a fence between tl.store and + # tl.atomic_xchg in a block. + tl.debug_barrier() + # release lock tl.atomic_xchg(Lock, 0) @@ -1697,6 +1721,27 @@ def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): assert torch.all(output == ref) +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_store_constant_default_dtype(num_ctas, device): + """Tests that boolean True is stored as 1""" + + @triton.jit + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + value = 1 + output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype) + tl.store(output_ptr + offsets, output, mask=mask) + + block_size = 128 + ref = torch.ones([block_size], dtype=getattr(torch, 'int32'), device=device) + output = torch.zeros([block_size], dtype=getattr(torch, 'int32'), device=device) + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) + + assert torch.all(output == ref) + + def test_load_store_same_ptr(device): @triton.jit() @@ -2029,7 +2074,7 @@ def kernel(X, Z, BLOCK: tl.constexpr): 'argmax-tie-break-left': np.argmax, }[op] if 'tie-break-left' in op: - x[3:10] = numpy_op(x) + x[3:10] = x[numpy_op(x)] x_tri = to_triton(x, device=device) # numpy result z_dtype_str = 'int32' if op in ('argmin', 'argmax') else dtype_str @@ -2091,19 +2136,20 @@ def kernel(X, Z, BLOCK: tl.constexpr): for op in ['min', 'max', 'sum', 'argmin', 'argmax'] for axis in [0, 1, 2]] + [(op, 'float32', (32, 2, 16), None, True) for op in ['min', 'max', 'sum']] +reduce_bool = [(op, 'bool', shape, axis, False) for op in ['xor_sum'] for shape in reduce2d_shapes for axis in [0, 1]] @pytest.mark.interpreter @pytest.mark.parametrize( "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + - negative_config + keep_dims_2d_configs + keep_dims_3d_configs) + negative_config + keep_dims_2d_configs + keep_dims_3d_configs + reduce_bool) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested @triton.jit def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, - AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr): + AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr, USE_I1: tl.constexpr): range_m = tl.arange(0, BLOCK_M) range_n = tl.arange(0, BLOCK_N) range_k = tl.arange(0, BLOCK_K) @@ -2112,8 +2158,9 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const range_k[None, None, :]) else: x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + if USE_I1: + x = tl.cast(x, tl.int1) z = GENERATE_TEST_HERE - z_ptr = Z if KEEP_DIMS and AXIS is None: if IS_3D: @@ -2142,9 +2189,14 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const # limit the range of integers so that the sum does not overflow x = numpy_random(shape, dtype_str=dtype_str, rs=rs) x_tri = to_triton(x, device=device) - numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax}[op] + numpy_op = { + 'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax, 'xor_sum': + np.bitwise_xor.reduce + }[op] z_dtype_str = get_reduced_dtype(dtype_str, op) z_tri_dtype_str = z_dtype_str + if z_dtype_str == 'bool': + z_dtype_str = 'int8' # numpy result # Silence numpy error on axis out of bounds, to give triton a chance to fail @@ -2163,14 +2215,15 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) BLOCK_K = 1 if len(shape) == 2 else shape[2] IS_3D = bool(len(shape) == 3) + USE_I1 = dtype_str == 'bool' if axis is not None and axis >= len(shape): with pytest.raises(triton.TritonError): kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, - KEEP_DIMS=keep_dims, num_ctas=num_ctas) + KEEP_DIMS=keep_dims, USE_I1=USE_I1, num_ctas=num_ctas) return else: kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, - KEEP_DIMS=keep_dims, num_ctas=num_ctas) + KEEP_DIMS=keep_dims, USE_I1=USE_I1, num_ctas=num_ctas) z_tri = to_numpy(z_tri) @@ -2464,7 +2517,20 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl. @pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) @pytest.mark.parametrize("src_layout", scan_layouts) @pytest.mark.parametrize("axis", [0, 1]) -def test_scan_layouts(M, N, src_layout, axis, device): +@pytest.mark.parametrize("add_overflow_check", [False, True]) +def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device): + + overflow_check = """ + %17 = arith.extsi %arg2 : i32 to i64 + %18 = arith.extsi %arg3 : i32 to i64 + %19 = arith.addi %17, %18 : i64 + %i32.min = arith.constant -2147483648: i64 + %i32.max = arith.constant 2147483647: i64 + %20 = arith.cmpi slt, %19, %i32.max : i64 + %21 = arith.cmpi sge, %19, %i32.min : i64 + %22 = arith.andi %20, %21 : i1 + tt.assert %22, "overflow detected" : i1 + """ ir = f""" #blocked = {src_layout} @@ -2484,7 +2550,7 @@ def test_scan_layouts(M, N, src_layout, axis, device): %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr, #blocked> %11 = "tt.scan"(%10) <{{axis = {axis} : i32, reverse = false}}> ({{ ^bb0(%arg2: i32, %arg3: i32): - %16 = arith.addi %arg2, %arg3 : i32 + %16 = arith.addi %arg2, %arg3 : i32{overflow_check if add_overflow_check else ""} tt.scan.return %16 : i32 }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> %12 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> @@ -2536,9 +2602,9 @@ def test_scan_layouts(M, N, src_layout, axis, device): MfmaLayout(version=(2, 0), warps_per_cta=[2, 2], instr_shape=[32, 32], is_transposed=True), MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=True), MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=True), - WmmaLayout(warps_per_cta=[2, 2]), - WmmaLayout(warps_per_cta=[4, 1]), - WmmaLayout(warps_per_cta=[1, 4]), + WmmaLayout(version=1, warps_per_cta=[2, 2]), + WmmaLayout(version=1, warps_per_cta=[4, 1]), + WmmaLayout(version=1, warps_per_cta=[1, 4]), ] @@ -2546,9 +2612,10 @@ def test_scan_layouts(M, N, src_layout, axis, device): @pytest.mark.parametrize("src_layout", filter_layouts(layouts)) @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d']) -@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"]) +@pytest.mark.parametrize("dtype_str,add_overflow_check", [("int32", False), ("int32", True), ("float32", False), + ("float16", False)]) @pytest.mark.parametrize("reduce_op", ["sum", "max"]) -def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device): +def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_overflow_check, reduce_op, device): if isinstance(src_layout, (MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]): pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape") @@ -2556,13 +2623,22 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce pytest.skip("Skipping test because it runs out of shared memory") if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024: pytest.skip("Skipping sum reduction on float16 due to accuracy issues") - if epilogue_kind == 'expand_reduce2d' and isinstance(src_layout, MmaLayout): - pytest.skip( - "Currently MmaLayout combined with slice encoding and reduce op trigger device illegal memory access") if isinstance(src_layout, MmaLayout) and src_layout.version == 3: src_layout[2] = 16 if dtype_str == "float16" else 8 + overflow_check = """ + %18 = arith.extsi %arg3 : i32 to i64 + %19 = arith.extsi %arg4 : i32 to i64 + %20 = arith.addi %18, %19 : i64 + %i32.min = arith.constant -2147483648: i64 + %i32.max = arith.constant 2147483647: i64 + %21 = arith.cmpi slt, %20, %i32.max : i64 + %22 = arith.cmpi sge, %20, %i32.min : i64 + %23 = arith.andi %21, %22 : i1 + tt.assert %23, "overflow detected" : i1 + """ + ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str] arith_op = { "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"}, # @@ -2576,7 +2652,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce num_warps = src_layout.warps_per_cta[0] * src_layout.warps_per_cta[1] if num_warps == 8: blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, 2], [0, 1], [1, 1], [1, 1], [0, 1]) - one_d_layout = BlockedLayout([1], [THREADS_PER_WARP], [4], [0], [1], [1], [0]) + one_d_layout = BlockedLayout([1], [THREADS_PER_WARP], [num_warps], [0], [1], [1], [0]) expanded_shape = f"1x{N}" if axis == 0 else f"{M}x1" other_axis = 1 - axis @@ -2595,7 +2671,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce f""" %14 = "tt.reduce"(%13) ({{ ^bb0(%arg3: {ty}, %arg4: {ty}): - %17 = {arith_op} %arg3, %arg4 : {ty} + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} tt.reduce.return %17 : {ty} }}) {{axis = 0 : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> {ty} tt.store %arg2, %14 : !tt.ptr<{ty}> @@ -2607,7 +2683,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce %14 = tt.expand_dims %13 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{expanded_shape}x{ty}, #src> %15 = "tt.reduce"(%14) ({{ ^bb0(%arg3: {ty}, %arg4: {ty}): - %17 = {arith_op} %arg3, %arg4 : {ty} + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} tt.reduce.return %17 : {ty} }}) {{axis = {other_axis} : i32}} : (tensor<{expanded_shape}x{ty}, #src>) -> (tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>>) %16 = triton_gpu.convert_layout %15 : tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>> -> tensor<1x{ty}, #one_d_layout> @@ -2640,7 +2716,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce %12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src> %13 = "tt.reduce"(%12) ({{ ^bb0(%arg3: {ty}, %arg4: {ty}): - %17 = {arith_op} %arg3, %arg4 : {ty} + %17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""} tt.reduce.return %17 : {ty} }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> """ + epilogue @@ -2893,8 +2969,11 @@ def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr): @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_permute(dtype_str, shape, perm, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested - if is_hip() and shape == (128, 128) and dtype_str == 'float32': - pytest.skip("TODO Out of LDS for float32 with shape 128x128") + if dtype_str == "float8e4b15" and (is_hip() or (is_cuda() and torch.cuda.get_device_capability() >= (9, 0))): + pytest.skip("float8e4b15 not supported on ROCm or CUDA >= 9.0") + if is_hip(): + if shape == (128, 128) and dtype_str == 'float32': + pytest.skip("TODO Out of LDS for float32 with shape 128x128") # triton kernel @triton.jit @@ -3027,7 +3106,7 @@ def convert_fp8_to_fp32(x, device, dtype_str): [(*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype, kpack) for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4], [32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] - for input_precision in ["ieee" if is_hip() else "tf32"] + for input_precision in ["tf32" if is_cuda() else "ieee"] for col_a in [True, False] for col_b in [True, False] for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', @@ -3036,6 +3115,8 @@ def convert_fp8_to_fp32(x, device, dtype_str): for col_a in [True, False] for col_b in [True, False]] + [(64, 64, 64, 4, False, False, 'chain-dot', 'ieee', 'bfloat16', 'float32', 1)] + + ([(16, 16, 8, 4, False, False, 'None', 'ieee', 'float32', 'float32', 1), + (32, 16, 8, 4, False, False, 'None', 'ieee', 'float16', 'float16', 1)] if "gfx9" in get_arch() else []) + [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1) for float8_type in ["float8e5", "float8e4nv"]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -3167,21 +3248,6 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri, w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), **kern_kwargs) - if epilogue == 'softmax' and (in_dtype != 'float32' or input_precision == "tf32"): - if not is_cuda(): - pass - else: - ptx = pgm.asm["ptx"] - start = ptx.find("shfl.sync.bfly") - end = ptx.find("cvt.rn.f16.f32") - red_code = ptx[start:end] - assert len(red_code) > 0 - - # skip this check on hopper because there are some functions whose name contain "shared" in ptx. - # TODO: we should eliminate these unused functions in ptx code. - if not (capability[0] >= 9): - assert "shared" not in red_code - assert "bar.sync" not in red_code # torch result if in_dtype == 'int8': z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32) @@ -3204,6 +3270,14 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid z_ref = num / denom if epilogue == 'chain-dot': if 'float8' in in_dtype: + # Reduce z_ref's precision to fp8 to match the kernel behavior + if in_dtype == 'float8e4nv': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e4m3fn) + elif in_dtype == 'float8e5': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e5m2) + else: + assert "Unsupported float8 dtype" + z_ref = to_numpy(z_fp8.to(torch.float32)) w = to_numpy(convert_fp8_to_fp32(w, device, in_dtype)) z_ref = np.matmul(z_ref, w) # compare @@ -3252,23 +3326,225 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx +@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps", + [(M, N, K, col_a, col_b, type_a, type_b, 4) + for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128]) + for col_a, col_b in itertools.product([True, False], repeat=2) + for type_a in ["e2m1", "e4m3", "e5m2"] + for type_b in ["e4m3", "e5m2"]]) +def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device): + if not is_cuda(): + pytest.skip("scaled_dot only supported on CUDA") + else: + cc = torch.cuda.get_device_capability() + if cc < (8, 9): + pytest.skip("float8e4nv not supported on CUDA < 8.9") + + @triton.jit + def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr, + type_b: tl.constexpr): + tl.static_assert(type_b == "e4m3" or type_b == "e5m2", "type_b must be fp8") + IS_FP8: tl.constexpr = type_a == "e4m3" or type_a == "e5m2" + DIV_FACTOR: tl.constexpr = 1 if IS_FP8 else 2 + PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR + PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K + a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0, + PACKED_BLOCK_K_A)[None, :] * stride_a1 + b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0, + BLOCK_N)[None, :] * stride_b1 + + SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 + scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] + + a = tl.load(a_ptr) + b = tl.load(b_ptr) + a_scale = tl.load(scale_a_ptr) + c = tl.dot_scaled(a, a_scale, type_a, b, None, type_b) + out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] + tl.store(out_ptr, c.to(tl.bfloat16)) + + @triton.jit + def mxfp_to_bf16_kernel( + x_ptr, + scale_ptr, + mxfp_ptr, + N, + e_bits: tl.constexpr, + m_bits: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + ): + # x.shape == (N, 32) for fp8 or (N, 16) for fp4 + # scale.shape == (N,) + # out.shape == (N, 32) + is_fp8: tl.constexpr = e_bits + m_bits == 7 + # fp8: BLOCK_SIZE -> BLOCK_SIZE // 32, 32 + # fp4: BLOCK_SIZE // 2 -> BLOCK_SIZE // 32 , 16 + PARALLEL_DIM: tl.constexpr = BLOCK_SIZE // 32 + LAST_DIM: tl.constexpr = 32 if is_fp8 else 16 + LOAD_SIZE: tl.constexpr = LAST_DIM * PARALLEL_DIM + + offsets = (tl.program_id(0) * LOAD_SIZE + tl.arange(0, PARALLEL_DIM)[:, None] * LAST_DIM + + tl.arange(0, LAST_DIM)[None, :]) + x = tl.load(x_ptr + offsets, mask=offsets < N * LAST_DIM) + + offsets = tl.program_id(0) * PARALLEL_DIM + tl.arange(0, PARALLEL_DIM)[:, None] + scale = tl.load(scale_ptr + offsets, mask=offsets < N) + tl.static_assert(scale.dtype == tl.uint8) + tl.static_assert(x.dtype == tl.uint8) + + scale_bf16 = (scale.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True) + if is_fp8: + if e_bits == 5 and m_bits == 2: + x_f8 = x.to(tl.float8e5, bitcast=True) + x_bf16 = x_f8.to(tl.bfloat16) + # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them! + non_finite_mask: tl.constexpr = ((1 << e_bits) - 1) << m_bits + non_finite_mask_bf16: tl.constexpr = ((1 << 8) - 1) << 7 + x_bf16 = tl.where( + x & non_finite_mask == non_finite_mask, + (x_bf16.to(tl.uint16, bitcast=True) | non_finite_mask_bf16).to(tl.bfloat16, bitcast=True), + x_bf16, + ) + else: + tl.static_assert(e_bits == 4 and m_bits == 3) + x_f8 = x.to(tl.float8e4nv, bitcast=True) + x_bf16 = x_f8.to(tl.bfloat16) + else: + # e2m1 + em0 = x & 0x70 + em1 = x & 0x7 + x0 = (em0.to(tl.uint16) << 2) | ((x & 0x80).to(tl.uint16) << 8) + x1 = (em1.to(tl.uint16) << (2 + 4)) | ((x & 0x8).to(tl.uint16) << (8 + 4)) + # Three cases: + # 1) x is normal and non-zero: Correct bias + x0 = tl.where((em0 & 0x60) != 0, x0 + ((127 - 1) << 7), x0) + x1 = tl.where((em1 & 0x6) != 0, x1 + ((127 - 1) << 7), x1) + # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in bf16 + x0 = tl.where(em0 == 0x10, 16128 | (x0 & 0x8000), x0) + x1 = tl.where(em1 == 0x1, 16128 | (x1 & 0x8000), x1) + # 3) x is zero, do nothing + x_bf16 = tl.interleave(x0, x1).to(tl.bfloat16, bitcast=True) + # Multiplication preserves infs and NaNs in x_bf16 + mxfp = x_bf16 * scale_bf16 + # If scale is NaN, we encode it as an bf16 inf, so we need to correct for that + mxfp = tl.where(scale == 0xFF, float("nan"), mxfp) + + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32) + + def dot_scale_ref(x, scale, y, type_x, type_y): + e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x] + type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y] + + comp_dtype = torch.bfloat16 + + x = x.contiguous() + x_upcast = x.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype) + + N = x_upcast.numel() + BLOCK_SIZE = 512 + grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) + mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps) + assert x_upcast.isfinite().all() + + y_upcast = y.view(type_fp8_y).to(comp_dtype) + + class AccumulateInFp32: + + def __enter__(self): + self.prev_value = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value + + with AccumulateInFp32(): + return torch.matmul(x_upcast.to(comp_dtype), y_upcast.to(comp_dtype)) + + torch.manual_seed(0) + + def create_uint8(shape, col_major=False, max_val=255): + if col_major: + shape = shape[:-2] + (shape[-1], shape[-2]) + ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device) + if col_major: + ret = ret.mT + return ret + + DIV_FACTOR = 2 if type_a == "e2m1" else 1 + x = create_uint8((M, K // DIV_FACTOR), col_major=col_a) + y = create_uint8((K, N), col_major=col_b) + + # sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright) + # We substract a reasonably high number (64) so that the sum of all the mxfp elements does not overflow + m_bytes = int(type_a[1]) + bias_type_a = 1 << (m_bytes - 1) - 1 + max_exponent_type_a = (1 << m_bytes) - 1 - bias_type_a + scale_x = create_uint8((M, K // 32), max_val=255 - max_exponent_type_a - 64) + + def make_finite(x, dtype): + # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and + # Fp8E5M2_to_Bf16 doesn't preserve NaNs (fixme) + if dtype not in ("e5m2", "e4m3"): + return x + mask = 0x7C if dtype == "e5m2" else 0x7F + finite = torch.arange(x.numel(), device=device, dtype=torch.uint8).reshape_as(x) % mask + x_finite = torch.where(x & mask == mask, finite | (0x80 & x), x) + x.copy_(x_finite) + return x + + x = make_finite(x, type_a) + y = make_finite(y, type_b) + + z = x.new_empty((M, N), dtype=torch.bfloat16) + pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b, + num_warps=num_warps) + + z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b) + + # generous rtol as we are sampling the whole range of floats + torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2) + + # make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + if (max(M, N) * K) // (num_warps * 32) >= 4: + assert 'ld.global.v4' in ptx + if M * N // (num_warps * 32) >= 4: + assert 'st.global.v4' in ptx + assert re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx) + + @pytest.mark.interpreter -@pytest.mark.parametrize("B", [1, 2, 4, 8]) -@pytest.mark.parametrize("num_warps", [1, 2, 4, 8, 16]) -@pytest.mark.parametrize("M, N, K", [(64, 64, 64), (32, 32, 32)]) -@pytest.mark.parametrize("in_dtype_str, out_dtype_str", [('int8', 'int8'), ('float16', 'float16'), - ('float16', 'float32'), ('float32', 'float32')]) -def test_dot3d(B, num_warps, M, N, K, in_dtype_str, out_dtype_str, device): +@pytest.mark.parametrize("B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str", + [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) + for B in [1, 2, 4, 8] + for num_warps in [1, 2, 4, 8, 16] + for BLOCK_M, BLOCK_N in [(32, 32)] + for M, N, K in [(64, 64, 64), (32, 32, 32)] + for in_dtype_str, out_dtype_str in [('int8', 'int8'), ('float16', 'float16'), + ('float16', 'float32'), ('float32', 'float32')]] + + # Large block sizes + [(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')]) +def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str, device): if is_hip(): # hip does not support tf32 precision, so use ieee for all tests input_precision = "ieee" - if "gfx11" in triton.runtime.driver.active.get_current_target().arch: + arch = triton.runtime.driver.active.get_current_target().arch + if "gfx11" in arch or "gfx12" in arch: if in_dtype_str == "float32": pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d") if out_dtype_str == "float16": pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") else: - input_precision = "tf32" if in_dtype_str == 'float32' else "ieee" + input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee" + + if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32": + if not is_interpreter() and triton.runtime.driver.active.utils.get_device_properties( + torch.cuda.current_device())["max_shared_mem"] < 131072: + pytest.skip( + "Skipping tests with B = 8, M = 64, in_type = float32, out_type = float32 due to insufficient shared memory (less than 128 KB per SM) on this GPU." + ) @triton.jit def kernel( @@ -3328,7 +3604,6 @@ def kernel( out_tri = to_triton(out, device=device) BLOCK_B = B - BLOCK_M, BLOCK_N = 32, 32 BLOCK_K = K grid = ( @@ -3364,39 +3639,6 @@ def kernel( np.testing.assert_allclose(out_ref, to_numpy(out_tri), rtol=0.01, atol=1e-2) -@pytest.mark.interpreter -def test_max_num_imprecise_acc(device): - - if not hasattr(torch, 'float8_e5m2'): - pytest.skip(f"torch {torch.__version__} does not support float8_e5m2") - - if is_cuda(): - capability = torch.cuda.get_device_capability() - if capability != (9, 0): - return - - @triton.jit - def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - MAX_NUM_IMPRECISE_ACC: tl.constexpr): - off_m = tl.arange(0, BLOCK_M) - off_n = tl.arange(0, BLOCK_N) - off_k = tl.arange(0, BLOCK_K) - x = tl.load(X + off_m[:, None] * BLOCK_K + off_k[None, :]) - y = tl.load(Y + off_k[:, None] * BLOCK_N + off_n[None, :]) - z = tl.load(Z + off_m[:, None] * BLOCK_N + off_n[None, :]) - z = tl.dot(x, y, acc=z, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC) - tl.store(Z + off_m[:, None] * BLOCK_N + off_n[None, :], z) - - M, N, K, num_warps, MAX_NUM_IMPRECISE_ACC = 128, 128, 128, 4, 64 - x = torch.zeros((M, K), dtype=torch.float8_e5m2, device=device) - y = torch.zeros((K, N), dtype=torch.float8_e5m2, device=device) - z = torch.zeros((M, N), dtype=torch.float32, device=device) - h = kernel[(1, 1)](x, y, z, M, N, K, MAX_NUM_IMPRECISE_ACC, num_warps=num_warps) - if not is_cuda(): - return - assert h.asm["ptx"].count("add.f32") == (M * N) // (32 * num_warps) * (K / MAX_NUM_IMPRECISE_ACC) - - @pytest.mark.parametrize('in_dtype', ['float32']) def test_dot_mulbroadcasted(in_dtype, device): if is_cuda(): @@ -3562,6 +3804,20 @@ def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.co if expect_fail: with pytest.raises(triton.CompilationError) as exc_info: patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) + if constexpr: + error = "Cannot store to a constant pointer" + else: + if mode == "call": + error = "Inconsistent return types" + elif mode == "if": + error = "Mismatched type for final_out" + elif mode == "ternary": + error = "Ternary expression with dynamic condition has inconsistent type" + else: + assert mode == "direct" and choose_const + error = "Cannot store to a constant pointer" + error_msg = exc_info.value.error_message or str(exc_info.value.__cause__) + assert error in error_msg, "Wrong error message!" else: patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) assert torch.all(input == output) @@ -3681,7 +3937,7 @@ def kernel(in_ptr, out_ptr, size: tl.constexpr, mask: tl.constexpr, other: tl.co torch.testing.assert_close(output, reference_out) -# Testing masked loads with an intermate copy to shared memory run. +# Testing masked loads with a copy to shared memory. # FIXME: Shape too small for ldmatrix when num_ctas=4 @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @@ -3727,7 +3983,7 @@ def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_ @pytest.mark.interpreter -@pytest.mark.parametrize("cache", ["", ".ca", ".cg"]) +@pytest.mark.parametrize("cache", ["", ".ca", ".cg", ".cv"]) def test_load_cache_modifier(cache, device): src = torch.empty(128, device=device) dst = torch.empty(128, device=device) @@ -3739,19 +3995,36 @@ def _kernel(dst, src, CACHE: tl.constexpr): tl.store(dst + offsets, x) pgm = _kernel[(1, )](dst, src, CACHE=cache) - if not is_cuda(): - return - ptx = pgm.asm['ptx'] - if cache == '': - assert 'ld.global.ca' not in ptx - assert 'ld.global.cg' not in ptx - if cache == '.cg': - assert 'ld.global.cg' in ptx - assert 'ld.global.ca' not in ptx - if cache == '.ca': - assert 'ld.global.ca' in ptx - assert 'ld.global.cg' not in ptx + if is_hip(): + target_arch = get_arch() + # TODO: support testing for remaining architectures + if 'gfx94' not in target_arch: + return + amdgcn = pgm.asm['amdgcn'] + cg_cache_modifier_str = 'nt' + cv_cache_modifier_str = 'sc0 sc1' + buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line] + global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line] + flat_load_line = [line for line in amdgcn.splitlines() if "flat_load" in line] + if cache == '' or cache == '.ca': + assert cg_cache_modifier_str not in (global_load_line[0] if global_load_line else buffer_load_line[0]) + if cache == '.cg': + assert cg_cache_modifier_str in global_load_line[0] + if cache == '.cv': + assert cv_cache_modifier_str in flat_load_line[0] + + if is_cuda(): + ptx = pgm.asm['ptx'] + if cache == '': + assert 'ld.global.ca' not in ptx + assert 'ld.global.cg' not in ptx + if cache == '.cg': + assert 'ld.global.cg' in ptx + assert 'ld.global.ca' not in ptx + if cache == '.ca': + assert 'ld.global.ca' in ptx + assert 'ld.global.cg' not in ptx @pytest.mark.interpreter @@ -3759,7 +4032,7 @@ def _kernel(dst, src, CACHE: tl.constexpr): @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_vectorization(N, num_ctas, device): block_size = 1024 * num_ctas - src = torch.empty(block_size, device=device) + src = torch.randn(block_size, device=device) dst = torch.empty(block_size, device=device) @triton.jit @@ -3778,7 +4051,7 @@ def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): assert "ld.global.v4.b32" in ptx else: assert "ld.global.b32" in ptx - # np.testing.assert_allclose(dst, src[:N]) + torch.testing.assert_close(dst[:N], src[:N], atol=1e-6, rtol=0) @pytest.mark.interpreter @@ -3808,6 +4081,27 @@ def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr): assert "ld.global.v4.b32" not in ptx +@pytest.mark.interpreter +def test_assume(device): + + @triton.jit + def _kernel(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr): + current_size = N - tl.program_id(0) * BLOCK_N + tl.assume(current_size >= BLOCK_N) + if current_size >= 128: + tl.store(out_ptr + tl.program_id(0), current_size) + else: + tl.store(out_ptr + tl.program_id(0), current_size + 101024) + + output = torch.zeros(1024 // 128, device=device) + pgm = _kernel[(1024 // 128, )](output, N=1024, BLOCK_N=128) + + if is_interpreter(): + return + + assert 'llvm.assume' in pgm.asm['llir'] + + # --------------- # test store # --------------- @@ -3825,35 +4119,83 @@ def _kernel(dst, src, CACHE: tl.constexpr): x = tl.load(src + offsets) tl.store(dst + offsets, x, cache_modifier=CACHE) + pgm = _kernel[(1, )](dst, src, CACHE=cache) + + if is_hip(): + target_arch = get_arch() + # TODO: support testing for remaining architectures + if 'gfx94' not in target_arch: + return + amdgcn = pgm.asm['amdgcn'] + cs_cache_modifier_str = 'nt' + wt_cache_modifier_str = 'sc0 sc1' + global_store_line = [line for line in amdgcn.splitlines() if "global_store" in line] + if not global_store_line: + return + if cache == '' or cache == '.cg': + assert cs_cache_modifier_str not in global_store_line[0] + assert wt_cache_modifier_str not in global_store_line[0] + if cache == '.cs': + assert cs_cache_modifier_str in global_store_line[0] + assert wt_cache_modifier_str not in global_store_line[0] + if cache == '.wt': + assert cs_cache_modifier_str not in global_store_line[0] + assert wt_cache_modifier_str in global_store_line[0] + + if is_cuda(): + ptx = pgm.asm['ptx'] + if cache == '': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.wb': + assert 'st.global.wb' in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cg': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cs': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' in ptx + assert 'st.global.wt' not in ptx + if cache == '.wt': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("eviction_policy", ["", "evict_last", "evict_first"]) +def test_store_eviction_policy(eviction_policy, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, POLICY: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets) + tl.store(dst + offsets, x, eviction_policy=POLICY) + if not is_cuda(): return - pgm = _kernel[(1, )](dst, src, CACHE=cache) + pgm = _kernel[(1, )](dst, src, POLICY=eviction_policy) ptx = pgm.asm['ptx'] - if cache == '': - assert 'st.global.wb' not in ptx - assert 'st.global.cg' not in ptx - assert 'st.global.cs' not in ptx - assert 'st.global.wt' not in ptx - if cache == '.wb': - assert 'st.global.wb' in ptx - assert 'st.global.cg' not in ptx - assert 'st.global.cs' not in ptx - assert 'st.global.wt' not in ptx - if cache == '.cg': - assert 'st.global.wb' not in ptx - assert 'st.global.cg' in ptx - assert 'st.global.cs' not in ptx - assert 'st.global.wt' not in ptx - if cache == '.cs': - assert 'st.global.wb' not in ptx - assert 'st.global.cg' not in ptx - assert 'st.global.cs' in ptx - assert 'st.global.wt' not in ptx - if cache == '.wt': - assert 'st.global.wb' not in ptx - assert 'st.global.cg' not in ptx - assert 'st.global.cs' not in ptx - assert 'st.global.wt' in ptx + if eviction_policy == '': + assert 'evict_last' not in ptx + assert 'evict_first' not in ptx + if eviction_policy == 'evict_last': + assert 'evict_last' in ptx + assert 'evict_first' not in ptx + if eviction_policy == 'evict_first': + assert 'evict_last' not in ptx + assert 'evict_first' in ptx # --------------- @@ -4714,6 +5056,52 @@ def nested_while(data, countPtr): assert data[0] == 40 +def test_constexpr_if_return(device): + # Reproducer for #4883, return statement in an if with a constexpr causes + # errors when combined with non-trivial control flow graphs + + @triton.jit + def kernel(Semaphore, Out, total: tl.constexpr): + if total == 1: + tl.store(Out, tl.program_id(0)) + return + + prev = tl.atomic_add(Semaphore, 1) + if prev + 1 != total: + return + + tl.store(Out, tl.program_id(0) + prev) + + sem = torch.zeros((), device=device, dtype=torch.int32) + out = torch.empty((), device=device, dtype=torch.int32) + kernel[(1, )](sem, out, 1) + assert out.item() == 0 + + sem = torch.zeros((), device=device, dtype=torch.int32) + out = torch.full((), fill_value=-1, device=device, dtype=torch.int32) + kernel[(4, )](sem, out, 4) + assert out.item() >= 0 + + +@triton.jit +def return_poison(x): + a = False + if a: + return x + + +def test_poison_return(device): + + @triton.jit + def kernel(Out): + tl.store(Out, return_poison(0)) + + a = torch.empty((), device=device, dtype=torch.int32) + h = kernel[(1, )](a) + assert "ub.poison" in h.asm["ttir"], h.asm["ttir"] + assert "poison" in h.asm["llir"], h.asm["llir"] + + # ----------------------- # test extra # ----------------------- @@ -4790,6 +5178,7 @@ def kernel(Out): intermediate_layouts = [ None, + SharedLayout(1, 1, 1, [0, 1], [1, 1], [1, 1], [0, 1]), SharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]), SharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), SharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), @@ -4819,11 +5208,6 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape): @pytest.mark.parametrize("interm_layout", intermediate_layouts) @pytest.mark.parametrize("dst_layout", layouts) def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): - if (M == 1 or N == 1) and interm_layout: - # TODO(jlebar): These OOB accesses don't even hit an assert in the - # compiler, and some of them return the wrong result instead of - # crashing! - pytest.skip("Out of bound access when maxPhase > 1") if str(src_layout) == str(dst_layout): pytest.skip() if is_hip(): @@ -4852,10 +5236,10 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> """ if interm_layout is None else f""" - %15 = triton_gpu.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !tt.memdesc<{M}x{N}xi32, #interm> - %16 = triton_gpu.local_load %15 : !tt.memdesc<{M}x{N}xi32, #interm> -> tensor<{M}x{N}xi32, #src> - %17 = triton_gpu.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !tt.memdesc<{M}x{N}xf16, #interm> - %18 = triton_gpu.local_load %17 : !tt.memdesc<{M}x{N}xf16, #interm> -> tensor<{M}x{N}xf16, #src> + %15 = triton_gpu.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !tt.memdesc<{M}x{N}xi32, #interm, #triton_gpu.shared_memory> + %16 = triton_gpu.local_load %15 : !tt.memdesc<{M}x{N}xi32, #interm, #triton_gpu.shared_memory> -> tensor<{M}x{N}xi32, #src> + %17 = triton_gpu.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !tt.memdesc<{M}x{N}xf16, #interm, #triton_gpu.shared_memory> + %18 = triton_gpu.local_load %17 : !tt.memdesc<{M}x{N}xf16, #interm, #triton_gpu.shared_memory> -> tensor<{M}x{N}xf16, #src> %12 = triton_gpu.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> %13 = triton_gpu.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> @@ -4914,23 +5298,30 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): MmaLayout((2, 1), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]), MmaLayout((2, 1), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]), ], - # Mma -> mma support is TODO on Hopper (and Volta) - # [ - # MmaLayout((3, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8, 16]), - # MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8, 16]), - # ], - # [ - # MmaLayout((3, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8, 16]), - # MmaLayout((3, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8, 16]), - # ], - # [ - # MmaLayout((3, 1), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8, 16]), - # MmaLayout((3, 1), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8, 16]), - # ], - # [ - # MmaLayout((3, 1), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8, 16]), - # MmaLayout((3, 1), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8, 16]), - # ], + [ + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 32, 32]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 32]), + ], + [ + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 32]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 32, 32]), + ], + [ + MmaLayout((3, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 32, 32]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 32]), + ], + [ + MmaLayout((3, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 64, 32]), + MmaLayout((3, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 32, 32]), + ], + [ + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 16]), + ], + [ + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 16]), + MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), + ], ] @@ -4942,12 +5333,22 @@ def test_convertmma2mma(M, N, mma_pair, dtype, device): pytest.skip("test_mma2mma is not supported in HIP") src_layout, _ = mma_pair + if is_cuda(): + cc = torch.cuda.get_device_capability() + if cc[0] < 9 and src_layout.version[0] >= 3: + pytest.skip("Skip testing MMAv3 on devices with CC < 9") + num_warps = np.cumprod(src_layout.warps_per_cta)[-1] + # TODO(Keren): Remove the intermediate layout once we have resolved the redundantDataMask issue for WGMMA + warps_per_cta = src_layout.warps_per_cta + interm = BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [warps_per_cta[0], warps_per_cta[1]], [0, 1], [1, 1], + [1, 1], [0, 1]) def do_test(src_layout, dst_layout): layouts = f""" #src = {src_layout} #dst = {dst_layout} + #interm = {interm} """ conversion = f""" @@ -4970,10 +5371,12 @@ def do_test(src_layout, dst_layout): %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> - %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #interm> """ + conversion + f""" - %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> - tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + %15 = triton_gpu.convert_layout %12 : tensor<{M}x{N}xi32, #dst> -> tensor<{M}x{N}xi32, #interm> + %16 = triton_gpu.convert_layout %13 : tensor<{M}x{N}xf16, #dst> -> tensor<{M}x{N}xf16, #interm> + %17 = tt.addptr %3, %15 : tensor<{M}x{N}x!tt.ptr, #interm>, tensor<{M}x{N}xi32, #interm> + tt.store %17, %16 : tensor<{M}x{N}x!tt.ptr, #interm> tt.return }} }} @@ -5106,27 +5509,33 @@ def matmul_kernel( # @pytest.mark.interpreter +@pytest.mark.parametrize("M, N, K", [(128, 256, 256)]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 128), (64, 64, 64)]) @pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv', 'float8e4b15']) @pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128]) -def test_fp8_dot_acc(in_type_str, low_precision_acc, device): - if is_hip(): - pytest.skip('test_fp8_dot_acc for HIP currently broken in upstream.') +def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_str, low_precision_acc, device): + num_stages = 3 if is_cuda(): cc = torch.cuda.get_device_capability() if cc[0] >= 9 and in_type_str == "float8e4b15": pytest.skip("Dot op does not support fp8e4b15 on CUDA arch >= 90") + elif is_hip(): + num_stages = 2 + if in_type_str != 'float8e5': + pytest.skip('test_fp8_dot_acc for HIP currently broken in upstream.') + check_type_supported(in_type_str, device) - M, N, K = 128, 256, 256 - BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 128 A = numpy_random((M, K), dtype_str=in_type_str) B = numpy_random((K, N), dtype_str=in_type_str) C = torch.empty((M, N), dtype=torch.float32, device=device) num_warps = 8 a = to_triton(A, device=device, dst_type=in_type_str) b = to_triton(B, device=device, dst_type=in_type_str) - grid = (triton.cdiv(M, BLOCK_M), 1) - matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), C.stride(1), - BLOCK_M, BLOCK_N, BLOCK_K, low_precision_acc, num_warps=num_warps) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) + max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None + h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), + C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps, + num_pipeline_stages=num_stages) torch_a = torch.from_numpy(A).to(device=device) th_a = f8_to_f16(torch_a, in_type_str) torch_b = torch.from_numpy(B).to(device=device) @@ -5134,10 +5543,10 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device): ref_out = torch.matmul(th_a, th_b).to(torch.float32) if in_type_str == 'float8e4nv': torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01) - elif low_precision_acc > 32: - torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) else: - torch.testing.assert_close(ref_out, C) + torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) + if is_cuda() and low_precision_acc > 0 and torch.cuda.get_device_capability()[0] >= 9: + assert h.asm["ptx"].count("add.f32") == (BLOCK_M * BLOCK_N) // (32 * num_warps) * (BLOCK_K // low_precision_acc) # ----------------------- @@ -5146,7 +5555,8 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device): @pytest.mark.parametrize("enable_fp_fusion", [False, True]) -def test_enable_fp_fusion(enable_fp_fusion, device): +@pytest.mark.parametrize("default_override", [False, True]) +def test_enable_fp_fusion(enable_fp_fusion, default_override, device): if is_hip(): pytest.skip( 'test_enable_fp_fusion for HIP currently broken in https://github.com/triton-lang/triton. Use https://github.com/ROCmSoftwarePlatform/triton' @@ -5159,7 +5569,11 @@ def mul_add(data): tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0) data = torch.randn((128, ), device=device, dtype=torch.float32) - h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion) + if default_override: + os.environ["TRITON_DEFAULT_FP_FUSION"] = "1" if enable_fp_fusion else "0" + h = mul_add[(1, )](data) + else: + h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion) if not is_cuda(): return @@ -5321,12 +5735,12 @@ def test_tl_range(device): torch.testing.assert_close(ref_out, c, rtol=1e-2, atol=1e-1) else: torch.testing.assert_close(ref_out, c, rtol=1e-3, atol=1e-3) - if device in ['cuda']: - capability = torch.cuda.get_device_capability() - if capability[0] >= 8: - ptx = pgm.asm['ptx'] - # check that the loop got pipelined with the right number of stages. - assert 'cp.async.wait_group 0x6' in ptx + if device in ['cuda']: + capability = torch.cuda.get_device_capability() + if capability[0] >= 8: + ptx = pgm.asm['ptx'] + # check that the loop got pipelined with the right number of stages. + assert 'cp.async.wait_group 0x6' in ptx @triton.jit(noinline=True) @@ -5341,7 +5755,7 @@ def maxnreg_noinline2(X): def test_maxnreg(device): assert not is_interpreter(), "this test won't work with the interpreter" - if is_hip(): + if not is_cuda(): pytest.skip('maxnreg only works on CUDA') # triton kernel @@ -5396,3 +5810,158 @@ def temp_in_loop(Z, N: tl.constexpr, BLOCK: tl.constexpr): temp = torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) acc += temp assert (acc == out).all() + + +@pytest.mark.interpreter +def test_num_programs(device): + # Assuming that the kernel is launched with a grid of (11, 21, 31) + grid = (11, 21, 31) + input = torch.empty((3, ), dtype=torch.int32, device=device) + + @triton.jit + def kernel(input): + num_programs_0 = tl.num_programs(0) + num_programs_1 = tl.num_programs(1) + num_programs_2 = tl.num_programs(2) + tl.store(input, num_programs_0) + tl.store(input + 1, num_programs_1) + tl.store(input + 2, num_programs_2) + + kernel[grid](input) + assert torch.all(input == torch.tensor(grid, device=device)) + + +# ----------------------- +# test extern functions +# ----------------------- + + +@pytest.mark.parametrize("dtype_str", ['float32', 'float64']) +def test_math_extern(dtype_str, device): + if is_interpreter(): + pytest.skip('math_extern does not work in the interpreter mode') + + @triton.jit + def kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = libdevice.tanh(x) + tl.store(y_ptr + offsets, y, mask=mask) + + shape = (128, ) + rs = RandomState(17) + + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + y_ref = np.tanh(x) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str=dtype_str, rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, shape[0], BLOCK_SIZE=shape[0]) + # compare + np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) + + +# ----------------------- +# test loop unrolling +# ----------------------- + + +def test_unroll_attr(device): + + @triton.jit + def _kernel(dst, unroll_factor: tl.constexpr): + pid = tl.program_id(axis=0) + for i in tl.range(0, 10, loop_unroll_factor=unroll_factor): + tl.atomic_add(dst + pid, i + pid) + + def check_loop_unroll_count(ir, opStr, loop_unroll_factor): + for line in ir.splitlines(): + if opStr in line: + loop_unroll_factor = loop_unroll_factor - 1 + # Sometimes we get a remainder loop + assert loop_unroll_factor <= 0 + + # Try for all different loop unroll factors: + for unroll_factor in [1, 2, 4, 5, 8]: + h = _kernel[(1, )](torch.empty(1, device=device), unroll_factor) + check_loop_unroll_count(h.asm["ttir"], 'tt.atomic_rmw', unroll_factor) + + +@triton.jit +def sanitize_add(a, b): + a64 = a.to(tl.int64) + b64 = b.to(tl.int64) + r64 = a64 + b64 + tl.device_assert((r64 >= -2**31) & (r64 <= 2**31 - 1)) + return a + b + + +def test_side_effectful_reduction(device): + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_sum_kernel(Z, X, BLOCK: tl.constexpr): + vals = tl.load(X + tl.arange(0, BLOCK)) + z = tl.reduce(vals, 0, sanitize_add) + tl.store(Z, z) + + BLOCK = 512 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK], device="cuda", dtype=torch.int32) + X[:300] = 32 + X[300:] = 0 + Z = torch.zeros((), device="cuda", dtype=torch.int32) + sanitize_sum_kernel[(1, )](Z, X, BLOCK=BLOCK) + torch.testing.assert_close(Z, X.sum().to(torch.int32)) + + +@pytest.mark.parametrize("reduce_dim", [0, 1]) +def test_side_effectful_reduction_2d(device, reduce_dim): + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, reduce_dim: tl.constexpr, + NON_REDUCE_DIM: tl.constexpr): + offsets = tl.arange(0, BLOCK_0)[:, None] * BLOCK_1 + tl.arange(0, BLOCK_1)[None, :] + vals = tl.load(X + offsets) + z = tl.reduce(vals, reduce_dim, sanitize_add) + tl.store(Z + tl.arange(0, NON_REDUCE_DIM), z) + + BLOCK_0 = 16 + BLOCK_1 = 32 + NON_REDUCE_DIM = BLOCK_1 if reduce_dim == 0 else BLOCK_0 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK_0, BLOCK_1], device="cuda", dtype=torch.int32) + Z = torch.zeros([NON_REDUCE_DIM], device="cuda", dtype=torch.int32) + sanitize_sum_2d_kernel[(1, )](Z, X, BLOCK_0=BLOCK_0, BLOCK_1=BLOCK_1, reduce_dim=reduce_dim, + NON_REDUCE_DIM=NON_REDUCE_DIM) + torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32)) + + +def test_side_effectful_scan(device): + if device != "cuda": + pytest.skip() + + @triton.jit(debug=True) + def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr): + vals = tl.load(X + tl.arange(0, BLOCK)) + z = tl.associative_scan(vals, 0, sanitize_add) + tl.store(Z + tl.arange(0, BLOCK), z) + + BLOCK = 512 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK], device="cuda", dtype=torch.int32) + X[:300] = 32 + X[300:] = 0 + Z = torch.zeros_like(X) + sanitize_cumsum_kernel[(1, )](Z, X, BLOCK=BLOCK) + torch.testing.assert_close(Z, X.cumsum(0).to(torch.int32)) diff --git a/python/test/unit/language/test_decorator.py b/python/test/unit/language/test_decorator.py index 66371ba60..fbbfb7144 100644 --- a/python/test/unit/language/test_decorator.py +++ b/python/test/unit/language/test_decorator.py @@ -33,7 +33,9 @@ def test_triton_heuristic(device): src = torch.empty(N, device=device) dst = torch.zeros(N, device=device) - @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], warmup=1, rep=1) + do_bench = lambda kernel, quantiles: triton.testing.do_bench(kernel, quantiles=quantiles, warmup=1, rep=1) + + @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], do_bench=do_bench) @triton.heuristics({'EVEN_N': lambda nargs: nargs['N'] % 2 == 0}) # test kwargs @triton.heuristics({'EVEN_src': lambda nargs: nargs['src'].data_ptr() % 2 == 0}) # test args @triton.jit diff --git a/python/test/unit/language/test_libdevice.py b/python/test/unit/language/test_libdevice.py new file mode 100644 index 000000000..da0d7d49c --- /dev/null +++ b/python/test/unit/language/test_libdevice.py @@ -0,0 +1,23 @@ +import torch + +import triton +import triton.language as tl + +from triton.language.extra.libdevice import fast_dividef as my_fast_dividef + + +def test_libdevice_rename(device): + # mark the import as used by this test + _ = my_fast_dividef + + @triton.jit + def triton_copy(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + data = tl.load(in_ptr + offsets) + tl.store(out_ptr + offsets, data) + + BLOCK_SIZE = 256 + inp = torch.randn(BLOCK_SIZE, device=device) + out = torch.empty_like(inp) + + triton_copy[(1, )](inp, out, BLOCK_SIZE) diff --git a/python/test/unit/language/test_line_info.py b/python/test/unit/language/test_line_info.py index 6421c7309..a53976deb 100644 --- a/python/test/unit/language/test_line_info.py +++ b/python/test/unit/language/test_line_info.py @@ -67,6 +67,14 @@ def kernel_dot_combine(x): tl.device_print("", d) +# Call another jit function (cdiv) not in this file +@triton.jit +def kernel_cdiv(x): + c = tl.full((32, 32), 4, dtype=tl.int8) + d = tl.cdiv(c, 4) + tl.device_print("", d) + + def get_disassembler_command_and_debug_line_format(): """Gets backend specific disassembler information. @@ -118,19 +126,26 @@ def check_file_lines(file_lines, file_name, lineno, should_contain=True): should_contain: whether the file name and line number should be in the file_lines """ for file, line in file_lines: - if lineno == -1: - if file_name in file: - return True + if lineno == -1 and file_name in file: + return True if file_name in file and str(lineno) in line: return should_contain return not should_contain -func_types = ["single", "call", "call_noinline", "autotune", "dot_combine"] +func_types = ["single", "call", "call_noinline", "autotune", "dot_combine", "cdiv"] + + +def is_interpreter(): + import os + return os.environ.get('TRITON_INTERPRET', '0') == '1' @pytest.mark.parametrize("func", func_types) def test_line_info(func: str): + if is_interpreter(): + pytest.skip("interpreter does not support warmup compilation") + try: obj_kind, command, anchor, separator = get_disassembler_command_and_debug_line_format() except BaseException: @@ -148,6 +163,8 @@ def test_line_info(func: str): kernel_info = kernel_autotune.warmup(torch.float32, torch.float32, SIZE=shape[0], grid=(1,))[0] elif func == "dot_combine": kernel_info = kernel_dot_combine.warmup(20, grid=(1,)) + elif func == "cdiv": + kernel_info = kernel_cdiv.warmup(20, grid=(1,)) file_lines = extract_file_lines(command, anchor, separator, kernel_info.asm[obj_kind]) if func == "single": @@ -155,12 +172,10 @@ def test_line_info(func: str): assert (check_file_lines(file_lines, "test_line_info.py", 16)) elif func == "call": assert (check_file_lines(file_lines, "test_line_info.py", 28)) - assert (check_file_lines(file_lines, "test_line_info.py", 21)) assert (check_file_lines(file_lines, "test_line_info.py", 30)) elif func == "call_noinline": assert (check_file_lines(file_lines, "test_line_info.py", 42)) assert (check_file_lines(file_lines, "test_line_info.py", 35)) - assert (check_file_lines(file_lines, "test_line_info.py", 36)) assert (check_file_lines(file_lines, "test_line_info.py", 37)) elif func == "autotune": assert (check_file_lines(file_lines, "test_line_info.py", 53)) @@ -169,3 +184,35 @@ def test_line_info(func: str): elif func == "dot_combine": assert (check_file_lines(file_lines, "test_line_info.py", 65)) assert (check_file_lines(file_lines, "test_line_info.py", 66, should_contain=False)) + elif func == "cdiv": + assert (check_file_lines(file_lines, "test_line_info.py", 75)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func", func_types) +def test_line_info_interpreter(func: str): + if not is_interpreter(): + pytest.skip("interpreter is not enabled") + + kernel = None + expected_def_lineno = 0 + if func == "single": + kernel = kernel_single + expected_def_lineno = 12 + elif func == "call": + kernel = kernel_call + expected_def_lineno = 25 + elif func == "call_noinline": + kernel = kernel_call_noinline + expected_def_lineno = 41 + elif func == "autotune": + kernel = kernel_autotune.fn + expected_def_lineno = 52 + elif func == "dot_combine": + kernel = kernel_dot_combine + expected_def_lineno = 62 + elif func == "cdiv": + kernel = kernel_cdiv + expected_def_lineno = 72 + kernel.rewrite() + assert kernel.rewriter.def_file_lineno == expected_def_lineno diff --git a/python/test/unit/language/test_pipeliner.py b/python/test/unit/language/test_pipeliner.py new file mode 100644 index 000000000..fa5f34290 --- /dev/null +++ b/python/test/unit/language/test_pipeliner.py @@ -0,0 +1,366 @@ +# End-to-end tests to check the correctness of the pipeliner + +import pytest +import torch +import triton +import triton.language as tl +import triton.tools.experimental_descriptor + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def is_hopper(): + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def is_hip_mi200(): + target = triton.runtime.driver.active.get_current_target() + return target.backend == 'hip' and target.arch == 'gfx90a' + + +def check_capabilities(): + if is_cuda(): + cc = torch.cuda.get_device_capability() + if cc[0] < 8: + pytest.skip("CUDA 8.0+ required") + + +@triton.jit +def matmul_kernel( # + a_ptr, scale_ptr, b_ptr, output_ptr, # + M, N, K_MXFP, # K_MXFP is the number of mxfp vectors in a row of a. Otherwise it's just K + stride_am, stride_ak, # + stride_sm, stride_sk, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_STAGES: tl.constexpr, a_type: tl.constexpr, b_type: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + IS_SCALED: tl.constexpr = a_type is not None and b_type is not None + DIV_FACTOR: tl.constexpr = 2 if IS_SCALED and a_type == "e2m1" else 1 + # We pass K_MXFP to make explicit that KB is multiple of 32 and KA is multiple of 16 or 32 + # for the pipeliner divisibility condition + KA = K_MXFP if not IS_SCALED else K_MXFP * (32 // DIV_FACTOR) + KB = K_MXFP if not IS_SCALED else K_MXFP * 32 + BLOCK_AK: tl.constexpr = BLOCK_K // DIV_FACTOR + offs_k = tl.arange(0, BLOCK_K) + offs_ak = tl.arange(0, BLOCK_AK) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + if IS_SCALED: + BLOCK_SK: tl.constexpr = BLOCK_K // 32 + offs_sk = tl.arange(0, BLOCK_SK) + scale_ptrs = scale_ptr + (offs_am[:, None] * stride_sm + offs_sk[None, :] * stride_sk) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in tl.range(0, tl.cdiv(KB, BLOCK_K), num_stages=NUM_STAGES): + mask_a = (offs_am[:, None] < M) & (offs_ak[None, :] + k * BLOCK_AK < KA) + mask_b = ((offs_k[:, None] + k * BLOCK_K) < KB) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=mask_a, other=0) + b = tl.load(b_ptrs, mask=mask_b, other=0) + if IS_SCALED: + # Adapted scale indexing and dot_scaled operation + mask_scale = (offs_am[:, None] < M) & (offs_sk[None, :] + k * BLOCK_SK < K_MXFP) + a_scale = tl.load(scale_ptrs, mask=mask_scale, other=0) + accumulator = tl.dot_scaled(a, a_scale, a_type, b, None, b_type, acc=accumulator) + else: + accumulator = tl.dot(a, b, acc=accumulator) + a_ptrs += BLOCK_AK * stride_ak + b_ptrs += BLOCK_K * stride_bk + if IS_SCALED: + scale_ptrs += BLOCK_SK * stride_sk + OUT_DTYPE = tl.bfloat16 if IS_SCALED else tl.float16 + accumulator = accumulator.to(OUT_DTYPE) + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_c = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(output_ptrs, accumulator, mask=mask_c) + + +@triton.jit +def matmul_kernel_tma( # + a_ptr, b_ptr, output_ptr, # + M, N, K, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_STAGES: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_M) % M + offs_bn = (pid_n * BLOCK_N) % N + offs_am = tl.multiple_of(offs_am, BLOCK_M) + offs_bn = tl.multiple_of(offs_bn, BLOCK_N) + offs_k = 0 + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for _ in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES): + a = tl._experimental_descriptor_load(a_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], tl.float16) + b = tl._experimental_descriptor_load(b_ptr, [offs_k, offs_bn], [BLOCK_K, BLOCK_N], tl.float16) + accumulator = tl.dot(a, b, acc=accumulator) + offs_k += BLOCK_K + accumulator = accumulator.to(tl.float16) + tl._experimental_descriptor_store(output_ptr, accumulator, [offs_am, offs_bn]) + + +@triton.jit +def vecadd_kernel(a_ptr, b_ptr, output_ptr, n_elements, num_blocks, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE * num_blocks + offsets = block_start + tl.arange(0, BLOCK_SIZE) + for _ in tl.range(0, num_blocks, num_stages=NUM_STAGES): + mask = offsets < n_elements + x = tl.load(a_ptr + offsets, mask=mask) + y = tl.load(b_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + offsets += BLOCK_SIZE + + +@triton.jit +def mxfp_to_bf16_kernel( + x_ptr, + scale_ptr, + mxfp_ptr, + N, + e_bits: tl.constexpr, + m_bits: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + # x.shape == (N, 32) for fp8 or (N, 16) for fp4 + # scale.shape == (N,) + # out.shape == (N, 32) + is_fp8: tl.constexpr = e_bits + m_bits == 7 + # fp8: BLOCK_SIZE -> BLOCK_SIZE // 32, 32 + # fp4: BLOCK_SIZE // 2 -> BLOCK_SIZE // 32 , 16 + PARALLEL_DIM: tl.constexpr = BLOCK_SIZE // 32 + LAST_DIM: tl.constexpr = 32 if is_fp8 else 16 + LOAD_SIZE: tl.constexpr = LAST_DIM * PARALLEL_DIM + + offsets = (tl.program_id(0) * LOAD_SIZE + tl.arange(0, PARALLEL_DIM)[:, None] * LAST_DIM + + tl.arange(0, LAST_DIM)[None, :]) + x = tl.load(x_ptr + offsets, mask=offsets < N * LAST_DIM) + + offsets = tl.program_id(0) * PARALLEL_DIM + tl.arange(0, PARALLEL_DIM)[:, None] + scale = tl.load(scale_ptr + offsets, mask=offsets < N) + tl.static_assert(scale.dtype == tl.uint8) + tl.static_assert(x.dtype == tl.uint8) + + scale_bf16 = (scale.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True) + if is_fp8: + if e_bits == 5 and m_bits == 2: + x_f8 = x.to(tl.float8e5, bitcast=True) + x_bf16 = x_f8.to(tl.bfloat16) + # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them! + non_finite_mask: tl.constexpr = ((1 << e_bits) - 1) << m_bits + non_finite_mask_bf16: tl.constexpr = ((1 << 8) - 1) << 7 + x_bf16 = tl.where( + x & non_finite_mask == non_finite_mask, + (x_bf16.to(tl.uint16, bitcast=True) | non_finite_mask_bf16).to(tl.bfloat16, bitcast=True), + x_bf16, + ) + else: + tl.static_assert(e_bits == 4 and m_bits == 3) + x_f8 = x.to(tl.float8e4nv, bitcast=True) + x_bf16 = x_f8.to(tl.bfloat16) + else: + # e2m1 + em0 = x & 0x70 + em1 = x & 0x7 + x0 = (em0.to(tl.uint16) << 2) | ((x & 0x80).to(tl.uint16) << 8) + x1 = (em1.to(tl.uint16) << (2 + 4)) | ((x & 0x8).to(tl.uint16) << (8 + 4)) + # Three cases: + # 1) x is normal and non-zero: Correct bias + x0 = tl.where((em0 & 0x60) != 0, x0 + ((127 - 1) << 7), x0) + x1 = tl.where((em1 & 0x6) != 0, x1 + ((127 - 1) << 7), x1) + # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in bf16 + x0 = tl.where(em0 == 0x10, 16128 | (x0 & 0x8000), x0) + x1 = tl.where(em1 == 0x1, 16128 | (x1 & 0x8000), x1) + # 3) x is zero, do nothing + x_bf16 = tl.interleave(x0, x1).to(tl.bfloat16, bitcast=True) + # Multiplication preserves infs and NaNs in x_bf16 + mxfp = x_bf16 * scale_bf16 + # If scale is NaN, we encode it as an bf16 inf, so we need to correct for that + mxfp = tl.where(scale == 0xFF, float("nan"), mxfp) + + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32) + + +def dot_scale_ref(x, scale, y, type_x, type_y): + e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x] + type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y] + + comp_dtype = torch.float32 + out_dtype = torch.bfloat16 + + x = x.contiguous() + x_upcast = x.new_empty(scale.shape[:-1] + (32 * scale.shape[-1], ), dtype=comp_dtype) + + N = x_upcast.numel() + BLOCK_SIZE = 512 + grid = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, ) + mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=4) + y_upcast = y.view(type_fp8_y) + + class AccumulateInFp32: + + def __enter__(self): + self.prev_value = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value + + with AccumulateInFp32(): + return torch.matmul(x_upcast.to(out_dtype), y_upcast.to(out_dtype)) + + +@pytest.mark.parametrize("scale", [True, False]) +def test_pipeline_matmul(scale, device): + check_capabilities() + if scale and not is_cuda(): + pytest.skip("NYI: scale_dot just implemented in CUDA") + M, N, K = 512, 512, 128 + BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32 + NUM_STAGES = 4 + + if scale: + # TODO Use e5m2 for Ampere, as it does not support fp_to_fp conversions for fp8e4m3 + BLOCK_K = 64 # 32 NYI + K = BLOCK_K * NUM_STAGES + a_type = "e2m1" + DIV_FACTOR = 2 if a_type == "e2m1" else 1 + a = torch.randint(256, (M, K // DIV_FACTOR), device=device, dtype=torch.uint8) + # Sample small-ish scales to avoid overflow + scale_a = torch.randint(74, (M, K // 32), device=device, dtype=torch.uint8) + # Ampere does not support fp8e4m3 + b_type = "e4m3" if is_hopper() else "e5m2" + b = torch.randint(256, (K, N), device=device, dtype=torch.uint8) + # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and + # Fp8E5M2_to_Bf16 doesn't preserve NaNs (fixme) + if b_type == "e5m2": + finite = torch.arange(K * N, device=device, dtype=torch.uint8).reshape(K, N) % 0x7C + b = torch.where(b & 0x7C == 0x7C, finite | (0x80 & b), b) + output = torch.empty((M, N), dtype=torch.bfloat16, device=device) + else: + a = torch.randn(M, K, device=device, dtype=torch.float16) + b = torch.randn(K, N, device=device, dtype=torch.float16) + scale_a = None + a_type, b_type = None, None + output = torch.empty((M, N), dtype=torch.float16, device=device) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) + use_tma = not scale and is_hopper() + + if use_tma: + a_tma = triton.tools.experimental_descriptor.create_2d_tma_descriptor(a.data_ptr(), M, K, BLOCK_M, BLOCK_K, + a.element_size()) + b_tma = triton.tools.experimental_descriptor.create_2d_tma_descriptor(b.data_ptr(), K, N, BLOCK_K, BLOCK_N, + b.element_size()) + output_tma = triton.tools.experimental_descriptor.create_2d_tma_descriptor(output.data_ptr(), M, N, BLOCK_M, + BLOCK_N, output.element_size()) + handler = matmul_kernel_tma[grid](a_tma, b_tma, output_tma, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, + NUM_STAGES=NUM_STAGES) + else: + # Pass K_MXFP to make explicit that KB is multiple of 32 and KA is multiple of 16 or 32º + if scale: + K = scale_a.shape[-1] + stride_sm, stride_sk = scale_a.stride() if scale else (0, 0) + handler = matmul_kernel[grid](a, scale_a, b, output, M, N, K, a.stride(0), a.stride(1), stride_sm, stride_sk, + b.stride(0), b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, + BLOCK_K, NUM_STAGES=NUM_STAGES, a_type=a_type, b_type=b_type) + if scale: + ref_out = dot_scale_ref(a, scale_a, b, a_type, b_type) + else: + ref_out = torch.matmul(a, b) + # Bigger tolerance for AMD MI200 devices. + # MI200 devices use reduced precision fp16 and bf16 and flush input and + # output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + atol = 1e-2 if is_hip_mi200() or scale else None + rtol = 1e-2 if is_hip_mi200() or scale else None + torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol, equal_nan=scale) + if is_cuda(): + ttgir = handler.asm["ttgir"] + if use_tma: + assert ttgir.count("triton_nvidia_gpu.async_tma_copy_global_to_local") != 0, "async tma copy not found" + assert ttgir.count(f"num = {NUM_STAGES} : i32") == 0, "num_stages not match" + # a_tma, b_tma, output_tma, barriar + assert ttgir.count("triton_gpu.local_alloc") == 4, "alloc number not match" + assert ttgir.count("triton_nvidia_gpu.barrier_expect") != 0, "barrier_expect not found" + assert ttgir.count("triton_nvidia_gpu.wait_barrier") != 0, "wait_barrier not found" + assert ttgir.count("triton_nvidia_gpu.warp_group_dot") != 0, "warp_group_dot not found" + else: + # 1. check async + assert ttgir.count("triton_gpu.async_copy_global_to_local") != 0, "async copy not found" + # 2. check number of stages + assert ttgir.count(f"num = {NUM_STAGES} : i32") != 0, "num_stages not match" + # 3. check alloc + assert ttgir.count("triton_gpu.local_alloc") == 2, "alloc number not match" + # 4. check dot + cc = torch.cuda.get_device_capability() + if cc[0] >= 9: + ttgir.count("triton_nvidia_gpu.warp_group_dot") != 0, "warp_group_dot not found" + else: + ttgir.count("triton_gpu.dot") != 0, "dot not found" + + +def test_pipeline_vecadd(device): + check_capabilities() + SIZE = 4096 + NUM_BLOCKS = 4 + BLOCK_SIZE = 256 + NUM_STAGES = 3 + a = torch.randn(SIZE, dtype=torch.float16, device=device) + b = torch.randn(SIZE, dtype=torch.float16, device=device) + output = torch.empty(SIZE, dtype=torch.float16, device=device) + grid = (triton.cdiv(SIZE, NUM_BLOCKS * BLOCK_SIZE), 1) + handler = vecadd_kernel[grid](a, b, output, SIZE, NUM_BLOCKS, BLOCK_SIZE, NUM_STAGES) + ref_out = a + b + torch.testing.assert_close(ref_out, output) + if is_cuda(): + ttgir = handler.asm["ttgir"] + # 1. check async + assert ttgir.count("triton_gpu.async_copy_global_to_local") != 0, "async copy not found" + # 2. check number of stages + assert ttgir.count(f"num = {NUM_STAGES} : i32") != 0, "num_stages not match" + # 3. check alloc + assert ttgir.count("triton_gpu.local_alloc") == 2, "alloc number not match" + + +@pytest.mark.parametrize("ROW_COUNT", [0, 1, 2, 3]) +@pytest.mark.parametrize("NUM_STAGES", [1, 2, 3, 4, 5]) +def test_pipeline_epilogue(ROW_COUNT, NUM_STAGES, device): + + @triton.jit + def kernel_up(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, + NUM_STAGES: tl.constexpr): + row_step = tl.num_programs(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + for row_idx in tl.range(0, n_rows, row_step, num_stages=NUM_STAGES): + row_start_ptr = input_ptr + row_idx * input_row_stride + input_ptrs = row_start_ptr + col_offsets + val = tl.load(input_ptrs, mask=mask, other=-float('inf')) + val += 1.0 + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, val, mask=mask) + + width = ROW_COUNT + depth = 78 + x = torch.zeros(width, depth, device=device) + y0 = torch.rand_like(x) + n_rows, n_cols = x.shape + BLOCK_SIZE = triton.next_power_of_2(n_cols) + kernel_up[(1, )](y0, x, x.stride(0), y0.stride(0), n_rows, n_cols, BLOCK_SIZE, NUM_STAGES) + assert (y0 == torch.ones_like(x)).all() diff --git a/python/test/unit/language/test_random.py b/python/test/unit/language/test_random.py index e0e59b069..614b05deb 100644 --- a/python/test/unit/language/test_random.py +++ b/python/test/unit/language/test_random.py @@ -54,9 +54,10 @@ def _dtype(self): def _into_pieces(self, n, pad=4): res = [] + bits = np.dtype(self._dtype).itemsize * 8 while len(res) < pad: - res.append(np.array(n, dtype=self._dtype)) - n >>= (np.dtype(self._dtype).itemsize * 8) + res.append(np.array((n & ((1 << bits) - 1)), dtype=self._dtype)) + n >>= bits assert n == 0 return tuple(res) diff --git a/python/test/unit/language/test_standard.py b/python/test/unit/language/test_standard.py index 017ff36f8..b3392d475 100644 --- a/python/test/unit/language/test_standard.py +++ b/python/test/unit/language/test_standard.py @@ -28,7 +28,7 @@ def test_maximum_minium(dtype, op, device): @pytest.mark.interpreter @pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) @pytest.mark.parametrize("descending", [False, True]) -@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32']) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16']) def test_sort(M, N, descending, dtype_str, device): @triton.jit @@ -55,7 +55,7 @@ def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr @pytest.mark.interpreter @pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) -@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32']) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16']) def test_flip(M, N, dtype_str, device): @triton.jit @@ -73,3 +73,21 @@ def flip_kernel(X, Z, N: tl.constexpr, M: tl.constexpr): z = torch.empty_like(x, device=device) flip_kernel[(1, )](x, z, N, M, num_warps=8) assert (y == z).all(), (y, z) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("size_i, size_j, size_g", [[5, 7, 3]]) +def test_swizzle2d(size_i, size_j, size_g, device): + + @triton.jit + def swizzle2d_kernel(output, size_i, size_j, size_g): + for i in tl.range(0, size_i, 1): + for j in tl.range(0, size_j, 1): + new_i, new_j = tl.swizzle2d(i, j, size_i, size_j, size_g) + tl.store(output + new_i * size_j + new_j, i * size_j + j) + + output = torch.zeros(size_i, size_j).to(device) + swizzle2d_kernel[(1, )](output, size_i, size_j, size_g) + expected_order = torch.tensor([[0, 3, 6, 9, 12, 15, 18], [1, 4, 7, 10, 13, 16, 19], [2, 5, 8, 11, 14, 17, 20], + [21, 23, 25, 27, 29, 31, 33], [22, 24, 26, 28, 30, 32, 34]]).to(device) + assert (output == expected_order).all(), (output, expected_order) diff --git a/python/test/unit/language/test_subprocess.py b/python/test/unit/language/test_subprocess.py index 683a02a56..76a7d9508 100644 --- a/python/test/unit/language/test_subprocess.py +++ b/python/test/unit/language/test_subprocess.py @@ -8,11 +8,6 @@ dir_path = os.path.dirname(os.path.realpath(__file__)) print_path = os.path.join(dir_path, "print_helper.py") -assert_path = os.path.join(dir_path, "assert_helper.py") - -# TODO: bfloat16 after LLVM-15 -assert_types = ["device_assert", "device_assert_passes", "assert", "static_assert", "no_debug", "double_assert"] -nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]] torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"] @@ -24,44 +19,62 @@ def is_interpreter(): @pytest.mark.interpreter -@pytest.mark.parametrize("func_type, data_type", [("device_print", data_type) for data_type in torch_types] + [ - ("print", "int32"), - ("static_print", "int32"), - ("no_arg_print", "int32"), - ("print_no_arg", "int32"), - ("device_print_large", "int32"), - ("print_multiple_args", "int32"), - ("device_print_multiple_args", "int32"), - ("device_print_hex", "int16"), - ("device_print_hex", "int32"), - ("device_print_hex", "int64"), - ("device_print_pointer", "int32"), -]) -def test_print(func_type: str, data_type: str): - proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, - stderr=subprocess.PIPE, shell=False) - outs, err = proc.communicate() +@pytest.mark.parametrize("func_type, data_type", [(fn, data_type) + for fn in ["device_print", "device_print_scalar"] + for data_type in torch_types] + [ + ("print", "int32"), + ("static_print", "int32"), + ("no_arg_print", "int32"), + ("print_no_arg", "int32"), + ("device_print_large", "int32"), + ("print_multiple_args", "int32"), + ("device_print_multiple_args", "int32"), + ("device_print_hex", "int16"), + ("device_print_hex", "int32"), + ("device_print_hex", "int64"), + ("device_print_pointer", "int32"), + ("device_print_negative", "int32"), + # ("device_print_uint", "uint32"), # TODO: flagtree + ]) +def test_print(func_type: str, data_type: str, device: str): + proc = subprocess.run( + [sys.executable, print_path, "test_print", func_type, data_type, device], + capture_output=True, + ) assert proc.returncode == 0 if is_interpreter() and func_type != "static_assert": # Interpreter uses a different format for device_print # Only check if there's no error - assert err == b'' + assert proc.stderr == b'' return - outs = [line for line in outs.decode("UTF-8").split("\n") if line] + outs = [line for line in proc.stdout.decode("UTF-8").splitlines() if line] # The total number of elements in the 1-D tensor to print. N = 128 + # Constant for testing the printing of scalar values + SCALAR_VAL = 42 + # Format is # pid (, , ) idx (, , ...) (operand ) expected_lines = Counter() - if func_type == "print" or func_type == "device_print": + if func_type in ("print", "device_print", "device_print_uint"): for i in range(N): - line = f"pid (0, 0, 0) idx ({i:3}) x: {i}" + offset = (1 << 31) if data_type == "uint32" else 0 + line = f"pid (0, 0, 0) idx ({i:3}) x: {i + offset}" if data_type.startswith("float"): line += ".000000" expected_lines[line] = 1 + elif func_type == "device_print_scalar": + line = f"pid (0, 0, 0) idx () x: {SCALAR_VAL}" + if data_type.startswith("float"): + line += ".000000" + expected_lines[line] = N + elif func_type == "device_print_negative": + for i in range(N): + line = f"pid (0, 0, 0) idx ({i:3}) x: {-i}" + expected_lines[line] = 1 elif func_type == "device_print_hex": for i in range(N): line = f"pid (0, 0, 0) idx ({i:3}) x: 0x" @@ -102,58 +115,3 @@ def test_print(func_type: str, data_type: str): continue print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)') assert all(delta == 0 for delta in diff.values()) - - -@pytest.mark.parametrize("func_type", assert_types) -def test_assert(func_type: str): - # The total number of elements in the 1-D tensor to assert on. - N = 128 - - os.environ["TRITON_DEBUG"] = "1" - proc = subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, - shell=False) - _, errs = proc.communicate() - errs = errs.splitlines() - num_errs = 0 - for err in errs: - if "x != 0" in err.decode("utf-8", errors="ignore"): - num_errs += 1 - - # Check for segfaults. - assert all("segmentation fault" not in line.decode("utf-8", errors="ignore").lower() for line in errs) - - os.environ["TRITON_DEBUG"] = "0" - if func_type == "static_assert" or func_type == "device_assert_passes": - assert num_errs == 0 - else: - assert num_errs == N - 1 - - -@pytest.mark.parametrize("caller_type, callee_type", nested_types) -def test_assert_nested(caller_type, callee_type): - # The total number of elements in the 1-D tensor to assert on. - N = 128 - - proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE, - stderr=subprocess.PIPE, shell=False) - _, errs = proc.communicate() - errs = errs.splitlines() - num_errs = 0 - for err in errs: - if "x != 0" in err.decode("utf-8", errors="ignore"): - num_errs += 1 - if caller_type == "none": - if callee_type == "true": - assert num_errs == N - 1 - else: - assert num_errs == 0 - elif caller_type == "true": - if callee_type == "false": - assert num_errs == 0 - else: - assert num_errs == N - 1 - elif caller_type == "false": - if callee_type == "true": - assert num_errs == N - 1 - else: - assert num_errs == 0 diff --git a/python/test/unit/runtime/test_autotuner.py b/python/test/unit/runtime/test_autotuner.py index 6bbff1227..4462abcfc 100644 --- a/python/test/unit/runtime/test_autotuner.py +++ b/python/test/unit/runtime/test_autotuner.py @@ -5,33 +5,41 @@ import pytest +def do_bench(kernel_call, quantiles): + return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=1, rep=1) + + @pytest.mark.parametrize('use_cuda_graph', [False, True]) -def test_kwargs(use_cuda_graph: bool): - N = 1024 - src = torch.empty(N, device='cuda') - dst = torch.empty(N, device='cuda') +def test_kwargs(use_cuda_graph: bool, device: str): + M, N = 1024, 16 + src = torch.randn(M * N, device=device) + dst = torch.empty(M * N, device=device) - configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + configs = [triton.Config(kwargs={'BLOCK_SIZE_M': 32}), triton.Config(kwargs={'BLOCK_SIZE_M': 128})] - @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph) + @triton.autotune(configs=configs, key=['M'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph, do_bench=do_bench) @triton.jit - def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): - offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - x = tl.load(src + offsets, mask=offsets < N) - tl.store(dst + offsets, x, mask=offsets < N) + def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr): + offsets_m = tl.program_id(0) * stride_m + tl.arange(0, BLOCK_SIZE_M) + offsets_n = tl.arange(0, BLOCK_SIZE_N) + x = tl.load(src + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :]) + tl.store(dst + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :], x) - grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) - _kernel[grid](dst, src, N) - _kernel[grid](dst=dst, src=src, N=N) + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE_M']), ) + _kernel[grid](dst, src, N, M, N) + # the key word args could be in arbitrary order. + _kernel[grid](dst=dst, src=src, M=M // 2, stride_m=N, BLOCK_SIZE_N=N) + assert len(_kernel.cache) == 2 -def test_restore(): +@pytest.mark.parametrize('pass_kwargs_to_kernel', [False, True]) +def test_restore(pass_kwargs_to_kernel, device): N = 1024 - src = torch.zeros(N, device='cuda') + src = torch.zeros(N, device=device) configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] - @triton.autotune(configs=configs, key=['N'], restore_value=['src'], warmup=1, rep=1) + @triton.autotune(configs=configs, key=['N'], restore_value=['src'], do_bench=do_bench) @triton.jit def _kernel(src, N, BLOCK_SIZE: tl.constexpr): offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -39,14 +47,17 @@ def _kernel(src, N, BLOCK_SIZE: tl.constexpr): tl.store(src + offsets, x, mask=offsets < N) grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) - _kernel[grid](src, N) + if pass_kwargs_to_kernel: + _kernel[grid](src=src, N=N) + else: + _kernel[grid](src, N) triton.testing.assert_close(src, torch.ones_like(src)) -def test_hooks(): +def test_hooks(device): # Autotuner's pre- and post- hooks should be called the same number of times N = 4096 - src = torch.zeros(N, device='cuda') + src = torch.zeros(N, device=device) configs = [triton.Config(kwargs={'BLOCK_SIZE': 4096}), triton.Config(kwargs={'BLOCK_SIZE': 32})] @@ -61,7 +72,7 @@ def _post_hook(*args, exception): values["has_exception"] = True assert values["counter"] == 0 - @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, pre_hook=_pre_hook, post_hook=_post_hook) + @triton.autotune(configs=configs, key=['N'], do_bench=do_bench, pre_hook=_pre_hook, post_hook=_post_hook) @triton.heuristics({"N_STAGES": lambda nargs: 100 if nargs['N'] == 4096 else 4}) @triton.jit def _kernel(src, N, N_STAGES: tl.constexpr, BLOCK_SIZE: tl.constexpr): @@ -87,10 +98,10 @@ def _kernel(src, N, N_STAGES: tl.constexpr, BLOCK_SIZE: tl.constexpr): @pytest.mark.parametrize('with_perf_model', [False, True]) -def test_prune_configs(with_perf_model: bool): +def test_prune_configs(with_perf_model: bool, device: str): N = 1024 - src = torch.empty(N, device='cuda') - dst = torch.empty(N, device='cuda') + src = torch.randn(N, device=device) + dst = torch.empty(N, device=device) records = {} def early_config_prune(configs, named_args, **kwargs): @@ -112,7 +123,7 @@ def perf_model(*args, **kwargs): else: prune_configs_by = {'early_config_prune': early_config_prune} - @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, warmup=1, rep=1) + @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, do_bench=do_bench) @triton.jit def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) diff --git a/python/test/unit/runtime/test_bindings.py b/python/test/unit/runtime/test_bindings.py index c48ba9b4a..206d13230 100644 --- a/python/test/unit/runtime/test_bindings.py +++ b/python/test/unit/runtime/test_bindings.py @@ -2,6 +2,7 @@ import triton.language as tl import torch +import math @triton.jit @@ -27,9 +28,9 @@ def add_kernel( tl.store(out_ptr + offsets, output, mask=mask) -def test_module_walk(): +def test_module_walk(device): """ - Test the MLIR bindings exposed for the out-ot-tree walk. + Test the MLIR bindings exposed for the out-of-tree walk. """ def walk_fn(op): @@ -52,30 +53,56 @@ def walk_fn(op): kernel = add_kernel args = [ - torch.empty((32, 32), device="cuda"), # in_ptr0 - torch.empty((32, 32), device="cuda"), # in_ptr1 + torch.empty((32, 32), device=device), # in_ptr0 + torch.empty((32, 32), device=device), # in_ptr1 1024, # n_elements - torch.empty((32, 32), device="cuda"), # out_ptr + torch.empty((32, 32), device=device), # out_ptr 16, # BLOCK_SIZE ] + target = triton.runtime.driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) src = triton.compiler.compiler.ASTSource( fn=kernel, - signature={i: kernel._type_of(kernel._key_of(arg)) - for i, arg in enumerate(args) - if i not in kernel.constexprs}, - constants={i: arg + signature={ + kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg)) + for i, arg in enumerate(args) + if i not in kernel.constexprs + }, + constants={kernel.arg_names[i]: arg for i, arg in enumerate(args) if not isinstance(arg, torch.Tensor)}, - attrs=kernel._get_config(*args, ), + attrs=backend.get_attrs_descriptor(args, kernel.params), ) context = triton._C.libtriton.ir.context() - target = triton.runtime.driver.active.get_current_target() - backend = triton.compiler.compiler.make_backend(target) options = backend.parse_options(dict()) codegen_fns = dict() + module_map = backend.get_module_map() triton._C.libtriton.ir.load_dialects(context) backend.load_dialects(context) - ttir_module = src.make_ir(options, codegen_fns, context) + ttir_module = src.make_ir(options, codegen_fns, module_map, context) ttir_module.walk(walk_fn) + + +def test_python_func_in_visit_call(device): + + @triton.jit + def test_py_call_const_kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + log2e: tl.constexpr = math.log2(math.e) + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = x * log2e + tl.store(out_ptr + offsets, output, mask=mask) + + x = torch.randn(4, device=device) + out = torch.zeros_like(x) + test_py_call_const_kernel[(4, )](x, out, 4, 4) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index eddfe06e0..a45cb3f88 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -1,6 +1,5 @@ import importlib.util import itertools -import os import shutil import tempfile @@ -10,15 +9,23 @@ import triton import triton.language as tl from triton.runtime.jit import JITFunction +from triton._internal_testing import is_hip -tmpdir = ".tmp" + +@triton.jit +def function_0(i): + return i + 1 @triton.jit def function_1(i): i = i + 1 - i = function_2(i) - return i + cond: tl.constexpr = True + if cond: + FN: tl.constexpr = function_2 + else: + FN: tl.constexpr = function_0 + return FN(i) @triton.jit @@ -46,6 +53,13 @@ def kernel_nospec(X, i, BLOCK: tl.constexpr): tl.store(X, i) +@triton.jit(do_not_specialize_on_alignment=["i"]) +def kernel_nospec_on_alignment(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + @triton.jit def kernel_with_combine_fn(X, BLOCK: tl.constexpr): i = tl.arange(0, BLOCK) @@ -53,32 +67,38 @@ def kernel_with_combine_fn(X, BLOCK: tl.constexpr): tl.store(X, i) -def apply_src_change(target, old, new): +def apply_src_change(target, old, new, to_modify): kernel.hash = None + function_0.hash = None function_1.hash = None function_2.hash = None - function_1.src = function_1.src.replace(old, new) - target.src = target.src.replace(old, new) + to_modify.src = to_modify.src.replace(old, new) ret = target.cache_key - target.src = target.src.replace(new, old) + to_modify.src = to_modify.src.replace(new, old) return ret def test_nochange(): baseline = kernel.cache_key - updated = apply_src_change(kernel, 'i + 1', 'i + 1') + updated = apply_src_change(kernel, 'i + 1', 'i + 1', function_1) assert baseline == updated def test_toplevel_change(): baseline = kernel.cache_key - updated = apply_src_change(kernel, 'i + 1', 'i + 2') + updated = apply_src_change(kernel, 'i + 1', 'i + 2', function_1) assert baseline != updated def test_nested1_change(): baseline = kernel.cache_key - updated = apply_src_change(function_1, 'i + 1', 'i + 2') + updated = apply_src_change(kernel, 'i + 1', 'i + 2', function_2) + assert baseline != updated + + +def test_nested2_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2', function_0) assert baseline != updated @@ -135,14 +155,7 @@ def test_kernel(i): assert orig_cache_key != updated_cache_key -def reset_tmp_dir(): - os.environ["TRITON_CACHE_DIR"] = tmpdir - if os.path.exists(tmpdir): - # https://stackoverflow.com/questions/303200/how-do-i-remove-delete-a-folder-that-is-not-empty - shutil.rmtree(tmpdir, ignore_errors=True) - - -def test_reuse(): +def test_reuse(device, fresh_triton_cache): counter = 0 def inc_counter(*args, **kwargs): @@ -150,15 +163,14 @@ def inc_counter(*args, **kwargs): counter += 1 JITFunction.cache_hook = inc_counter - reset_tmp_dir() - x = torch.empty(1, dtype=torch.int32, device='cuda') + x = torch.empty(1, dtype=torch.int32, device=device) for i in range(10): kernel[(1, )](x, 1, BLOCK=1024) assert counter == 1 -@pytest.mark.parametrize('mode', ['enable', 'disable']) -def test_specialize(mode): +@pytest.mark.parametrize('mode', ['enable', 'disable', 'disable_on_alignment']) +def test_specialize(mode, device, fresh_triton_cache): counter = 0 def inc_counter(*args, **kwargs): @@ -166,24 +178,23 @@ def inc_counter(*args, **kwargs): counter += 1 JITFunction.cache_hook = inc_counter - reset_tmp_dir() - x = torch.empty(1, dtype=torch.int32, device='cuda') - function = {'enable': kernel, 'disable': kernel_nospec}[mode] - target = {'enable': 3, 'disable': 1}[mode] + x = torch.empty(1, dtype=torch.int32, device=device) + function = {'enable': kernel, 'disable': kernel_nospec, 'disable_on_alignment': kernel_nospec_on_alignment}[mode] + target = {'enable': 3, 'disable': 1, 'disable_on_alignment': 2}[mode] for i in [1, 2, 4, 8, 16, 32]: function[(1, )](x, i, BLOCK=512) assert counter == target -def test_annotation(): +def test_annotation(device): @triton.jit def kernel(X, i: tl.int32): tl.store(X, i) - x = torch.empty(1, dtype=torch.int32, device='cuda') + x = torch.empty(1, dtype=torch.int32, device=device) - device = torch.cuda.current_device() + device = getattr(torch, device).current_device() kernel[(1, )](x, 1) kernel[(1, )](x, 8) kernel[(1, )](x, 16) @@ -194,14 +205,14 @@ def kernel(X, i: tl.int32): GLOBAL_DEFAULT_ARG = 1 -def test_kernel_default_arg(): +def test_kernel_default_arg(device): global GLOBAL_DEFAULT_ARG @triton.jit def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG): tl.store(X, i) - x = torch.empty(1, dtype=torch.int32, device='cuda') + x = torch.empty(1, dtype=torch.int32, device=device) kernel[(1, )](x) assert x == torch.ones_like(x) @@ -211,21 +222,21 @@ def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG): kernel[(1, )](x) assert x == torch.ones_like(x) - device = torch.cuda.current_device() + device = getattr(torch, device).current_device() assert len(kernel.cache[device]) == 1 GLOBAL_VAR: tl.constexpr = 1 -def test_kernel_global_var_change(): +def test_kernel_global_var_change(device): global GLOBAL_VAR @triton.jit def kernel(X): tl.store(X, GLOBAL_VAR) - x = torch.empty(1, dtype=torch.int32, device='cuda') + x = torch.empty(1, dtype=torch.int32, device=device) kernel[(1, )](x) assert x == torch.ones_like(x) @@ -370,13 +381,13 @@ def kernel(): assert not kernel.used_global_vals -def test_constexpr_not_callable() -> None: +def test_constexpr_not_callable(device) -> None: @triton.jit def kernel(X, c: tl.constexpr): tl.store(X, 2) - x = torch.empty(1, dtype=torch.int32, device='cuda') + x = torch.empty(1, dtype=torch.int32, device=device) error = False try: kernel[(1, )](x, c="str") @@ -391,7 +402,7 @@ def kernel(X, c: tl.constexpr): assert error is True -def test_jit_warmup_cache() -> None: +def test_jit_warmup_cache(device) -> None: @triton.jit def kernel_add(a, b, o, N: tl.constexpr): @@ -399,12 +410,12 @@ def kernel_add(a, b, o, N: tl.constexpr): tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) args = [ - torch.randn(32, dtype=torch.float32, device="cuda"), - torch.randn(32, dtype=torch.float32, device="cuda"), - torch.randn(32, dtype=torch.float32, device="cuda"), + torch.randn(32, dtype=torch.float32, device=device), + torch.randn(32, dtype=torch.float32, device=device), + torch.randn(32, dtype=torch.float32, device=device), 32, ] - device = torch.cuda.current_device() + device = getattr(torch, device).current_device() assert len(kernel_add.cache[device]) == 0 kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add.cache[device]) == 1 @@ -414,26 +425,21 @@ def kernel_add(a, b, o, N: tl.constexpr): assert len(kernel_add.cache[device]) == 1 -def test_jit_debug() -> None: +def test_jit_debug(device) -> None: @triton.jit - def kernel_add(a, b, o, N: tl.constexpr): - idx = tl.arange(0, N) - tl.device_assert(idx < 32, "idx < 32") - tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + def kernel(tmp): + tl.device_assert(tl.load(tmp) == 1, "tmp == 1") - device = torch.cuda.current_device() - assert len(kernel_add.cache[device]) == 0 - kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) - assert len(kernel_add.cache[device]) == 1 - kernel_add.debug = False - kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) - assert len(kernel_add.cache[device]) == 2 - kernel_add.debug = True - kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) - assert len(kernel_add.cache[device]) == 3 - bins = list(kernel_add.cache[device].values()) - assert bins[2].asm['ttir'] != bins[1].asm['ttir'] + device = getattr(torch, device).current_device() + tmp = torch.tensor([1], dtype=torch.int32, device=device) + assert len(kernel.cache[device]) == 0 + kernel[(1, )](tmp, debug=False) + assert len(kernel.cache[device]) == 1 + kernel[(1, )](tmp, debug=True) + assert len(kernel.cache[device]) == 2 + bins = list(kernel.cache[device].values()) + assert bins[0].asm['ttir'] != bins[1].asm['ttir'] @triton.jit @@ -442,13 +448,13 @@ def add_fn(a, b, o, N: tl.constexpr): tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) -def test_jit_noinline() -> None: +def test_jit_noinline(device) -> None: @triton.jit def kernel_add_device(a, b, o, N: tl.constexpr): add_fn(a, b, o, N) - device = torch.cuda.current_device() + device = getattr(torch, device).current_device() assert len(kernel_add_device.cache[device]) == 0 kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add_device.cache[device]) == 1 @@ -478,7 +484,7 @@ def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) -def test_preload() -> None: +def test_preload(device, fresh_triton_cache) -> None: @triton.jit def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr): @@ -492,7 +498,7 @@ def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr): tl.device_assert(idx < 32, "idx < 32") tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx)) - device = torch.cuda.current_device() + device = getattr(torch, device).current_device() # get the serialized specialization data specialization_data = None @@ -507,7 +513,7 @@ def cache_hook(*args, **kwargs): assert specialization_data is not None # clear the cache - reset_tmp_dir() + shutil.rmtree(fresh_triton_cache) kernel_add.cache[device].clear() # preload the kernel @@ -532,3 +538,64 @@ def inc_counter(*args, **kwargs): # test that we can't preload a mismatched kernel with pytest.raises(RuntimeError, match="Specialization data is for"): kernel_sub.preload(specialization_data) + + +def test_hooks(device, fresh_triton_cache) -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + # get the serialized specialization data + specialization_data = None + is_warmup = False + key = 0 + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + specialization_data = kwargs["compile"]["specialization_data"] + nonlocal is_warmup + is_warmup = kwargs["compile"]["is_warmup"] + nonlocal key + key = kwargs["compile"]["key"] + + specialization_data_compiled = None + + def compiled_hook(*args, **kwargs): + nonlocal specialization_data_compiled + specialization_data_compiled = kwargs["compile"]["specialization_data"] + + JITFunction.cache_hook = cache_hook + JITFunction.compiled_hook = compiled_hook + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + assert specialization_data is not None and specialization_data_compiled == specialization_data + assert is_warmup is True + assert key in kernel_add.cache[getattr(torch, device).current_device()] + + +@pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip()) +def test_within_2gb(device, fresh_triton_cache) -> None: + + @triton.jit + def kernel_add(a): + tl.load(a) + + # This is the attribute we want to test + pointer_range_32 = None + + def cache_hook(*args, **kwargs): + nonlocal pointer_range_32 + pointer_range_32 = kwargs["compile"]["configs"][0].pointer_range_32 + + JITFunction.cache_hook = cache_hook + # In warmup we assume that the pointer range is 32 bits + kernel_add.warmup(torch.float32, grid=(1, )) + assert pointer_range_32 == [0] + # Torch tensor > 2GB + kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device)) + assert len(pointer_range_32) == 0 + # Torch tensor <= 2GB + kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device)) + assert pointer_range_32 == [0] diff --git a/python/test/unit/runtime/test_cublas.py b/python/test/unit/runtime/test_cublas.py new file mode 100644 index 000000000..a4315fc3c --- /dev/null +++ b/python/test/unit/runtime/test_cublas.py @@ -0,0 +1,49 @@ +import pytest +import torch +import triton +import os + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def is_cuda(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cuda" + + +@pytest.mark.parametrize("m, n, k", [(16, 16, 16), (32, 16, 16), (16, 32, 16), (16, 16, 32)]) +@pytest.mark.parametrize("dtype_str", ["float8_e4m3fn", "float16"]) +def test_cublas(m, n, k, dtype_str, device): + dtype = getattr(torch, dtype_str) + if not is_cuda(): + pytest.skip("test_cublas is only supported on CUDA") + if dtype == torch.float8_e4m3fn and torch.cuda.get_device_capability()[0] < 9: + pytest.skip("fp8 is only supported on CUDA with cc >= 90") + + from triton._C.libtriton import nvidia + + torch.manual_seed(123) + workspace_size = 32 * 1024 * 1024 + + def limited_rand(elements, shape): + total_elems = torch.prod(torch.tensor(shape)).item() + indices = torch.randint(0, len(elements), (total_elems, ), device=device) + return elements[indices].view(shape) + + elements = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32, device=device) + a = limited_rand(elements, (m, k)).to(dtype) + b = limited_rand(elements, (k, n)).to(dtype) + c = torch.zeros((m, n), dtype=dtype, device=device) + + b = b.T.contiguous() + + workspace = torch.empty(workspace_size, dtype=torch.int8, device=device) + + cublas = nvidia.cublas.CublasLt(workspace) + cublas.matmul(a, b, c) + + ref = torch.matmul(a.to(torch.float16), b.to(torch.float16).T) + + assert torch.allclose(c.to(torch.float16), ref, atol=2.0) diff --git a/python/test/unit/runtime/test_driver.py b/python/test/unit/runtime/test_driver.py index de00082f5..9bd51cc2b 100644 --- a/python/test/unit/runtime/test_driver.py +++ b/python/test/unit/runtime/test_driver.py @@ -1,6 +1,9 @@ import sys +from concurrent.futures import ThreadPoolExecutor +import torch import triton +import triton.language as tl def test_is_lazy(): @@ -12,3 +15,27 @@ def test_is_lazy(): assert triton.runtime.driver.active._obj is None utils = triton.runtime.driver.active.utils # noqa: F841 assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase")) + + +def test_kernel_in_thread(device): + # Test calling in a new thread sets a valid device context + buf = torch.zeros((38016 * 1024, ), dtype=torch.float32, device=device) + + @triton.jit + def _kernel(P, BLOCK: tl.constexpr): + pid = tl.program_id(0).to(tl.int64) + offset = pid * BLOCK + tl.arange(0, BLOCK) + + p = tl.load(P + offset) + tl.store(P + offset, p) + + def call_triton(): + N = buf.numel() + grid = lambda meta: (triton.cdiv(N, meta["BLOCK"]), ) + _kernel[grid](buf, BLOCK=1024) + getattr(torch, device).synchronize() + + call_triton() + with ThreadPoolExecutor(1) as pool: + future = pool.submit(call_triton) + future.result() diff --git a/python/test/unit/runtime/test_launch.py b/python/test/unit/runtime/test_launch.py index f17c05674..91fc6e19b 100644 --- a/python/test/unit/runtime/test_launch.py +++ b/python/test/unit/runtime/test_launch.py @@ -43,7 +43,7 @@ def kernel(x): assert used_hook -def test_memory_leak() -> None: +def test_memory_leak(device) -> None: @triton.jit def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): @@ -57,8 +57,8 @@ def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): tracemalloc.start() try: - inp = torch.randn(10, device='cuda') - out = torch.randn(10, device='cuda') + inp = torch.randn(10, device=device) + out = torch.randn(10, device=device) kernel[(10, )](inp, out, 10, XBLOCK=16) gc.collect() begin, _ = tracemalloc.get_traced_memory() diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index 333d1f929..027779233 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -1,25 +1,15 @@ import multiprocessing -import os import shutil -import torch - import triton import triton.language as tl +from triton.backends.compiler import AttrsDescriptor from triton.compiler import ASTSource -tmpdir = ".tmp" - target = triton.runtime.driver.active.get_current_target() -def reset_tmp_dir(): - os.environ["TRITON_CACHE_DIR"] = tmpdir - if os.path.exists(tmpdir): - shutil.rmtree(tmpdir, ignore_errors=True) - - -def compile_fn(attrs, capability): +def compile_fn(attrs): @triton.jit def kernel_sub(a, b, o, N: tl.constexpr): @@ -28,26 +18,23 @@ def kernel_sub(a, b, o, N: tl.constexpr): src = ASTSource( fn=kernel_sub, - constants={3: 32}, - signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, + constants={'N': 32}, + signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32"}, attrs=attrs, ) triton.compile(src=src, target=target) def test_compile_in_subproc() -> None: - major, minor = torch.cuda.get_device_capability(0) - cc = major * 10 + minor - config = triton.compiler.AttrsDescriptor(tuple(range(4)), ()) - + config = AttrsDescriptor.from_hints({i: 16 for i in range(4)}) multiprocessing.set_start_method('fork') - proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) + proc = multiprocessing.Process(target=compile_fn, args=(config, )) proc.start() proc.join() assert proc.exitcode == 0 -def compile_fn_dot(attrs, capability): +def compile_fn_dot(attrs): @triton.jit def kernel_dot(Z): @@ -56,18 +43,64 @@ def kernel_dot(Z): z = tl.dot(z, z) tl.store(Z + offs, z) - src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict()) + src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}, attrs=attrs, constants={}) + triton.compile(src=src, target=target) + + +def test_compile_in_forked_subproc(fresh_triton_cache) -> None: + config = AttrsDescriptor.from_hints({0: 16}) + assert multiprocessing.get_start_method() == 'fork' + proc = multiprocessing.Process(target=compile_fn_dot, args=(config, )) + proc.start() + proc.join() + assert proc.exitcode == 0 + + +def compile_empty_kernel_with_gc(attrs): + + @triton.jit + def empty_kernel(): + pass + + import gc + gc.collect() + src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constants={}) triton.compile(src=src, target=target) -def test_compile_in_forked_subproc() -> None: - reset_tmp_dir() - major, minor = torch.cuda.get_device_capability(0) - capability = major * 10 + minor - config = triton.compiler.AttrsDescriptor(tuple(range(1)), ()) +def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None: + ''' + Tests that compilation artifacts can safely live in forked process. + Scenario being tested here ("p" stands for parent process, "c" is child process): + 1. p compiles a kernel 1, and produces compilation artifacts. + 2. p forks the process to create c. + 3. c deletes compilation artifacts inherited from p, compiles kernel 2, and terminates. + 3. p wait for c and join it. + + This is a regression test that ensures thread pool in MLIRContext is released + safely after compilation. + ''' + import gc + old_gc_state = gc.isenabled() + # disable GC to manage resources manually in the manner described in comment above + gc.disable() + + # stage 1.p + config = AttrsDescriptor.from_hints({0: 16}) + compile_empty_kernel_with_gc(config) + + # stage 2.p + shutil.rmtree(fresh_triton_cache) assert multiprocessing.get_start_method() == 'fork' - proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability)) + proc = multiprocessing.Process(target=compile_empty_kernel_with_gc, args=(config, )) + + # stage 3.c proc.start() + # stage 3.p proc.join() + + # restore gc state + if old_gc_state: + gc.enable() assert proc.exitcode == 0 diff --git a/python/test/unit/test_debug.py b/python/test/unit/test_debug.py new file mode 100644 index 000000000..e1c74b677 --- /dev/null +++ b/python/test/unit/test_debug.py @@ -0,0 +1,129 @@ +import os +import pytest +import torch +import triton.language as tl +import triton + + +@pytest.mark.skip(reason="flagtree") +@pytest.mark.parametrize('cond, opt_flag, env_var', [ + (cond, opt_flag, env_var) for cond in [True, False] \ + for opt_flag in [True, False] \ + for env_var in [True, False]\ +]) +@pytest.mark.forked +def test_device_assert(cond, opt_flag, env_var, device): + os.environ['TRITON_DEBUG'] = str(int(env_var)) + torch.zeros([1], dtype=torch.int32, device=device) + + @triton.jit + def _kernel(COND: tl.constexpr): + tl.device_assert(COND, 'test') + + if not cond and (opt_flag or env_var): + with pytest.raises(RuntimeError): + _kernel[(1, )](cond, debug=opt_flag) + getattr(torch, device).synchronize() + return + + _kernel[(1, )](cond, debug=opt_flag) + getattr(torch, device).synchronize() + + +@pytest.mark.skip(reason="flagtree") +@pytest.mark.parametrize("cond", [False, True]) +def test_static_assert(cond): + + @triton.jit + def _kernel(COND: tl.constexpr): + tl.static_assert(COND) + + if not cond: + with pytest.raises(triton.compiler.errors.CompileTimeAssertionFailure): + _kernel[(1, )](cond) + return + + _kernel[(1, )](cond) + + +def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref_func, device): + x = torch.tensor([x], dtype=getattr(torch, x_dtype), device=device) + y = torch.tensor([y], dtype=getattr(torch, y_dtype), device=device) + z = torch.empty_like(x) + if should_overflow and debug: + with pytest.raises(RuntimeError) as exc_info: + tri_func[(1, )](x, y, z, debug=debug) + getattr(torch, device).synchronize() + assert "device-side assert" in str(exc_info.value) + else: + tri_func[(1, )](x, y, z, debug=debug) + getattr(torch, device).synchronize() + assert int(z) == int(ref_func(x, y)) + + +# integer overflow sanitization + + +@pytest.mark.skip(reason="flagtree") +@pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [ + (-2**31, -1, 'int32', 'int32', False, False), + (-2**31, -1, 'int32', 'int32', True, True), + (2**31 - 1, 1, 'int32', 'int32', True, True), + (2**31 - 1, 100, 'int32', 'int32', True, True), + (-2**31, 0, 'int32', 'int32', True, False), + (-2**31, 2, 'int32', 'int32', True, False), + (0, -1, 'int32', 'int32', True, False), + (-2**15, -1, 'int16', 'int16', True, True), + (2**15 - 1, 1, 'int16', 'int16', True, True), +]) +@pytest.mark.forked +def test_sanitize_int_add_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device): + + @triton.jit + def _kernel_add(X, Y, Z): + tl.store(Z, tl.load(X) + tl.load(Y)) + + _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_add, lambda x, y: x + y, device) + + +# mul overflow + + +@pytest.mark.skip(reason="flagtree") +@pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [ + (2**30, 4, 'int32', 'int32', False, False), + (2**30, 4, 'int32', 'int32', True, True), + (2**30, 2, 'int32', 'int32', True, True), + (-2**30, -4, 'int32', 'int32', True, True), + (-2**31, 1, 'int32', 'int32', True, False), + (-2**30, 2, 'int32', 'int32', True, False), +]) +@pytest.mark.forked +def test_sanitize_int_mul_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device): + + @triton.jit + def _kernel_mul(X, Y, Z): + tl.store(Z, tl.load(X) * tl.load(Y)) + + _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_mul, lambda x, y: x * y, device) + + +# sub overflow + + +@pytest.mark.skip(reason="flagtree") +@pytest.mark.parametrize("x, y, x_dtype, y_dtype, debug, should_overflow", [ + (-2**31, 1, 'int32', 'int32', False, False), + (-2**31, 1, 'int32', 'int32', True, True), + (2**31 - 1, -1, 'int32', 'int32', True, True), + (2**31 - 1, 1, 'int32', 'int32', True, False), + (-2**31, -1, 'int32', 'int32', True, False), +]) +@pytest.mark.forked +def test_sanitize_int_sub_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device): + + @triton.jit + def _kernel_sub(X, Y, Z): + tl.store(Z, tl.load(X) - tl.load(Y)) + + _test_overflow(x, y, x_dtype, y_dtype, should_overflow, debug, _kernel_sub, lambda x, y: x - y, device) diff --git a/python/test/unit/test_debug_dump.py b/python/test/unit/test_debug_dump.py new file mode 100644 index 000000000..a387df42d --- /dev/null +++ b/python/test/unit/test_debug_dump.py @@ -0,0 +1,51 @@ +import os +from contextlib import contextmanager + +import torch +import triton +import triton.language as tl + + +@contextmanager +def enable_dump_context(pass_name="1"): + try: + os.environ["MLIR_ENABLE_DUMP"] = pass_name + yield + finally: + os.environ["MLIR_ENABLE_DUMP"] = "0" + + +def test_fn_dump(capfd, device, fresh_triton_cache): + return # TODO: flagtree + + N = 1024 + src = torch.zeros(N, device=device) + + grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]), ) + + @triton.jit + def _kernel(src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + 1 + tl.store(src + offsets, x, mask=offsets < N) + + with enable_dump_context(): + BLOCK_SIZE = 16 + _kernel[grid](src, N, BLOCK_SIZE) + captured = capfd.readouterr() + print(captured.err) + assert "IR Dump Before" in captured.err + assert "tt.func public @_kernel" in captured.err + + with enable_dump_context("_kernel"): + BLOCK_SIZE = 32 + _kernel[grid](src, N, BLOCK_SIZE) + captured = capfd.readouterr() + assert "IR Dump Before" in captured.err + assert "tt.func public @_kernel" in captured.err + + with enable_dump_context("_kernel2"): + BLOCK_SIZE = 64 + _kernel[grid](src, N, BLOCK_SIZE) + captured = capfd.readouterr() + assert "IR Dump Before" not in captured.err diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py new file mode 100644 index 000000000..54072a829 --- /dev/null +++ b/python/test/unit/test_perf_warning.py @@ -0,0 +1,177 @@ +import os +from contextlib import contextmanager + +import pytest +import torch +import triton +import triton.language as tl + + +@contextmanager +def enable_remark_context(): + try: + os.environ["MLIR_ENABLE_REMARK"] = "1" + yield + finally: + os.environ["MLIR_ENABLE_REMARK"] = "0" + + +def is_perf_warning_enabled(): + return os.environ.get("MLIR_ENABLE_REMARK", "0") == "1" + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def test_mma_remark(capfd, fresh_triton_cache): + if is_cuda(): + capability = torch.cuda.get_device_capability() + if capability[0] < 9: + pytest.skip("Requires sm >= 90 to run") + + @triton.jit + def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + ): + a_block_ptr = tl.make_block_ptr( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + offsets=(0, 0), + block_shape=(32, 128), + order=(1, 0), + ) + b_block_ptr = tl.make_block_ptr( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + offsets=(0, 0), + block_shape=(128, 32), + order=(0, 1), + ) + c_block_ptr = tl.make_block_ptr( + base=c_ptr, + shape=(M, N), + strides=(stride_cm, stride_cn), + offsets=(0, 0), + block_shape=(32, 32), + order=(1, 0), + ) + a = tl.load(a_block_ptr) + b = tl.load(b_block_ptr) + c = tl.dot(a, b) + tl.store(c_block_ptr, c) + + with enable_remark_context(): + triton.compile( + triton.compiler.ASTSource( + fn=matmul_kernel, + signature={ + "a_ptr": "*fp32", + "b_ptr": "*fp32", + "c_ptr": "*fp32", + "M": "i32", + "N": "i32", + "K": "i32", + "stride_am": "i32", + "stride_ak": "i32", + "stride_bk": "i32", + "stride_bn": "i32", + "stride_cm": "i32", + "stride_cn": "i32", + }, + constants={}, + )) + captured = capfd.readouterr() + + assert ("remark: Warning: can't use MMA V3 for the dot op" in captured.err), "expect MMA V3 remark" + assert "note: see current operation:" in captured.err + + +def test_remark_vectorization(capfd, fresh_triton_cache): + + @triton.jit + def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + x0 = xindex % 9 + x2 = (xindex // 3456) % 512 + x1 = (xindex // 9) % 384 + x4 = xindex + tmp0 = tl.load(in_ptr0 + (x2 + (512 * x0)), None, eviction_policy="evict_last") + tmp1 = tmp0 + 520 + tmp2 = tmp0 < 0 + tmp3 = tl.where(tmp2, tmp1, tmp0) + tmp9 = (-4) + tmp3 + tmp12 = tl.full([1], 512, tl.int64) + tmp14 = tmp9 < tmp12 + tmp16 = tl.load(in_ptr3 + (x1), tmp14, eviction_policy="evict_last", other=0.0) + tmp18 = tmp16.to(tl.float32) + tmp19 = tmp18.to(tl.float32) + tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype) + tmp21 = tl.where(tmp14, tmp19, tmp20) + tmp22 = tmp21.to(tl.float32) + tl.store(out_ptr0 + (x4), tmp22, None) + + XBLOCK = 1024 + with enable_remark_context(): + triton.compile( + triton.compiler.ASTSource( + fn=ldst_vec, + signature={ + "in_ptr0": "*i64", + "in_ptr1": "*i64", + "in_ptr2": "*fp16", + "in_ptr3": "*fp32", + "out_ptr0": "*fp16", + }, + constants={"XBLOCK": XBLOCK}, + ), + options={"num_warps": 1}, + ) + + _, err = capfd.readouterr() + assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark" + + +def test_remark_swp_op_before_operands(capfd, fresh_triton_cache): + + @triton.jit + def kernel_pipe_error(in_ptr, out_ptr): + SIZE: tl.constexpr = 64 + in_ptrs = in_ptr + tl.arange(0, SIZE) + val = tl.zeros((SIZE, ), dtype=tl.float32) + k = 0 + for i in tl.range(0, 64, num_stages=3): + in_ptrs = in_ptr + tl.arange(0, SIZE) + SIZE * k + val = tl.load(in_ptrs) + out_ptrs = out_ptr + (tl.arange(0, SIZE) + i * SIZE) + tl.store(out_ptrs, val) + if tl.max(val) > 0: + k += 1 + + with enable_remark_context(): + triton.compile( + triton.compiler.ASTSource( + fn=kernel_pipe_error, + signature={"in_ptr": "*fp32", "out_ptr": "*fp32"}, + constants={}, + ), + options={"cluster_dims": (1, 1, 1)}, + ) + + _, err = capfd.readouterr() + + assert "operation scheduled before its operands" in err, "expect swp op remark" diff --git a/python/test/unit/tools/test_disasm.py b/python/test/unit/tools/test_disasm.py new file mode 100644 index 000000000..f2c9bcc0d --- /dev/null +++ b/python/test/unit/tools/test_disasm.py @@ -0,0 +1,22 @@ +import torch + +import triton +import pytest +import triton.language as tl + + +@pytest.mark.skip(reason="flagtree") +def test_disam_cubin(): + if not triton.runtime.driver.active.get_current_target().backend == "cuda": + pytest.skip("Test requires CUDA.") + + @triton.jit + def kernel(X, i: tl.constexpr): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + h = kernel[(1, )](x, i=12) + assert x[0] == 12 + sass = h.asm["sass"] + # check that the sass has a store instruction. + assert "STG.E" in sass diff --git a/python/triton/__init__.py b/python/triton/__init__.py index a5f77f91e..08872dae0 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -1,5 +1,5 @@ """isort:skip_file""" -__version__ = '3.1.0' +__version__ = '3.2.0' # --------------------------------------- # Note: import order is significant here. diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py new file mode 100644 index 000000000..f8909f7c0 --- /dev/null +++ b/python/triton/_internal_testing.py @@ -0,0 +1,123 @@ +import os +import re +import numpy as np +import torch +import triton +import triton.language as tl +import pytest + +from numpy.random import RandomState +from typing import Optional, Union +from triton.runtime.jit import TensorWrapper, reinterpret + +int_dtypes = ['int8', 'int16', 'int32', 'int64'] +uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] +integral_dtypes = int_dtypes + uint_dtypes +float_dtypes = ['float16', 'float32', 'float64'] +dtypes = integral_dtypes + float_dtypes +dtypes_with_bfloat16 = dtypes + ['bfloat16'] +torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2'] +torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16'] + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def get_current_target(): + if is_interpreter(): + return None + return triton.runtime.driver.active.get_current_target() + + +def is_cuda(): + target = get_current_target() + return False if target is None else target.backend == "cuda" + + +def is_hip(): + target = get_current_target() + return False if target is None else target.backend == "hip" + + +def get_arch(): + target = get_current_target() + return "" if target is None else str(target.arch) + + +def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): + """ + Override `rs` if you're calling this function twice and don't want the same + result for both calls. + """ + if isinstance(shape, int): + shape = (shape, ) + if rs is None: + rs = RandomState(seed=17) + if dtype_str in int_dtypes + uint_dtypes: + iinfo = np.iinfo(getattr(np, dtype_str)) + low = iinfo.min if low is None else max(low, iinfo.min) + high = iinfo.max if high is None else min(high, iinfo.max) + dtype = getattr(np, dtype_str) + x = rs.randint(low, high, shape, dtype=dtype) + x[x == 0] = 1 # Workaround. Never return zero so tests of division don't error out. + return x + elif dtype_str and 'float8' in dtype_str: + x = rs.randint(20, 40, shape, dtype=np.int8) + return x + elif dtype_str in float_dtypes: + return rs.normal(0, 1, shape).astype(dtype_str) + elif dtype_str == 'bfloat16': + return (rs.normal(0, 1, shape).astype('float32').view('uint32') & np.uint32(0xffff0000)).view('float32') + elif dtype_str in ['bool', 'int1', 'bool_']: + return rs.normal(0, 1, shape) > 0.0 + else: + raise RuntimeError(f'Unknown dtype {dtype_str}') + + +def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torch.Tensor]: + ''' + Note: We need dst_type because the type of x can be different from dst_type. + For example: x is of type `float32`, dst_type is `bfloat16`. + If dst_type is None, we infer dst_type from x. + ''' + t = x.dtype.name + if t in uint_dtypes: + signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16" + x_signed = x.astype(getattr(np, signed_type_name)) + return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t)) + else: + if dst_type and 'float8' in dst_type: + return reinterpret(torch.tensor(x, device=device), getattr(tl, dst_type)) + if t == 'float32' and dst_type == 'bfloat16': + return torch.tensor(x, device=device).bfloat16() + return torch.tensor(x, device=device) + + +def torch_dtype_name(dtype) -> str: + if isinstance(dtype, triton.language.dtype): + return dtype.name + elif isinstance(dtype, torch.dtype): + # 'torch.int64' -> 'int64' + m = re.match(r'^torch\.(\w+)$', str(dtype)) + return m.group(1) + else: + raise TypeError(f'not a triton or torch dtype: {type(dtype)}') + + +def to_numpy(x): + if isinstance(x, TensorWrapper): + return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype))) + elif isinstance(x, torch.Tensor): + if x.dtype is torch.bfloat16: + return x.cpu().float().numpy() + return x.cpu().numpy() + else: + raise ValueError(f"Not a triton-compatible tensor: {x}") + + +def supports_tma(): + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 + + +requires_tma = pytest.mark.skipif(not supports_tma(), reason="Requires TMA support (NVIDIA Hopper or higher)") diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index fbf65d9e9..92ba144ba 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -7,7 +7,7 @@ def _load_module(name, path): - spec = importlib.util.spec_from_file_location(name[:-3], path) + spec = importlib.util.spec_from_file_location(name, path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index 990690045..cac42a663 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -1,10 +1,217 @@ import os import re +import hashlib import subprocess from abc import ABCMeta, abstractmethod, abstractclassmethod from dataclasses import dataclass -from typing import Union +from typing import Dict, List, Tuple, Union +from types import ModuleType + +# Table that associates strings to AttrsDescriptor (sub)classes. +# In this way we can dynamically select the correct class +# constructor +_descriptor_table = {} + + +def register_descriptor(cls): + """ + Register a descriptor into the descriptor table + """ + _descriptor_table[cls.__name__] = cls + return cls + + +@register_descriptor +class AttrsDescriptor: + """ + This class handles compile-time properties for specific function parameters. + + Different backends can add more properties to the common ones. The class + contains two fields: + + `arg_properties`: a dictionary containing the different compile-time properties for different + parameters. I.e., the dictionary is a map from property names to parameter indices + { + "prop0": (0, 2, 3) + "prop1": (0, 4, 5) + } + Different backends might need different properties on those paraemters to enable + specific optimizations. The common compile time properties contained in this class + are : + - "tt.divisibility", i.e., is the given parameter divisible by 16 + - "tt.equal_to_1", i.e., is the given parameter an integer constant 1 + + `property_values`: a dictionary containing the value of the different compile-time properties, like: + { + "prop0": val0 + "prop1": val1 + } + + `constant_properties`: a set containing the properties that can be used to determine if a parameter is constant + + """ + __slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties') + + def __init__(self, params=None, values=None): + """ + Initialize the compile-time properties + + We can initialize the AttrsDescriptor class by passing the list of params + of the function and their `values`. The function will try to apply the properties + to the values and save the parameters in the `arg_properties` list. If we don't pass + either the `params` or the `values` we should initialize the class via an alternative method + (see `from_dict` or `from_hints`) + """ + # Default initialization + self.arg_properties = {} + self.property_values = {} + self.constant_properties = set() + + self._add_common_properties(params, values) + self._add_backend_properties(params, values) + self._init_slots() + + def _add_common_properties(self, params, values): + """ Add common compile-time properties """ + self.property_values["tt.divisibility"] = 16 + self.property_values["tt.equal_to"] = 1 + self.constant_properties.add("tt.equal_to") + + if (params is None) or (values is None): + return + + # Compile properties deduction + assert (len(params) == len(values)) + + # Divisibility property + self.arg_properties["tt.divisibility"] = [ + param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg) + and not param.do_not_specialize and not param.do_not_specialize_on_alignment + ] + + # Equal to 1 property + self.arg_properties["tt.equal_to"] = [ + param.num + for param, arg in zip(params, values) + if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize + ] + + def _add_backend_properties(self, params=None, values=None): + """ This method is for different subclasses to implement their own compile-time properties """ + pass + + def _init_slots(self): + """ Initialize the slots of this class """ + for name, val in self.arg_properties.items(): + setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val) + + def get_fn_attrs(self) -> Dict: + """ + Get the function attributes as a dictionary. + + The returned dictionary will look like : + { + "arg0" : [(prop_name00, val00), (prop_name01, val01), ...)]} + "arg1" : [(prop_name10, val10), (prop_name11, val11), ...)]} + } + """ + attrs = {} + for prop_name, arg_set in self.arg_properties.items(): + prop_val = self.property_values[prop_name] + for arg in arg_set: + attrs[arg] = attrs.get(arg, []) + [(prop_name, prop_val)] + return attrs + + def get_constants(self) -> Dict: + """ Return a mapping of constant parameters to their values """ + constants = {} + for prop_name in self.constant_properties: + for p in self.arg_properties.get(prop_name, []): + constants[p] = self.property_values[prop_name] + return constants + + def filter_out_constants(self): + """ Return the same object, without properties marked as constants""" + import copy + c = copy.deepcopy(self) + for prop_name in c.constant_properties: + c.arg_properties.pop(prop_name, None) + c.property_values.pop(prop_name, None) + c.constant_properties = {} + return c + + def hash(self): + values = [sorted(self.arg_properties.values())] + values += [sorted(self.property_values.values())] + values += [sorted(self.constant_properties)] + key = str(values) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def to_dict(self): + """ + Store the fields of this class in a serializable dictionary + """ + # We need to only store the `arg_properties` field. To initialize the + # other fields we relay on the class type. We store it as a string in + # the dictionary so that we can use it to invoke the appropriate + # (sub)class constructor in the `from_dict` method. + return {"arg_properties": self.arg_properties, "cls": type(self).__name__} + + @staticmethod + def from_dict(data): + """ + Create the object from a serializable dictionary + """ + attrs_descriptor = _descriptor_table[data["cls"]]() + for prop_name, param_ids in data["arg_properties"].items(): + attrs_descriptor.arg_properties[prop_name] = param_ids + attrs_descriptor._init_slots() + return attrs_descriptor + + @classmethod + def from_hints(cls, hints: List[Tuple[int, int]]): + """ + Create the class from a set of hints that are passed in. + + Instead of deducing the properties from a list of paramaters and values, + the user can pass in a list of `hints=[(param_index, val)]` and if `val` + matches one of the values of the properties (e.g., `prop_val[prop0]`), + then we insert `param_index` into the correct list (e.g., in + `arg_properties[prop0]`) + """ + attrs_descriptor = cls() + for prop_name, prop_val in attrs_descriptor.property_values.items(): + attrs_descriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val] + attrs_descriptor._init_slots() + return attrs_descriptor + + @staticmethod + def is_divisible_by_16(x): + """ Return if the argument is a multiple of 16""" + if hasattr(x, "data_ptr"): + return x.data_ptr() % 16 == 0 + elif isinstance(x, int): + return x % 16 == 0 + if x is None: + return True + return False + + @staticmethod + def is_equal_to_1(x): + """ Return if the argument is a constant 1""" + return True if isinstance(x, int) and not isinstance(x, bool) and x == 1 else False + + @staticmethod + def get_property_key(val, align): + if align and AttrsDescriptor.is_divisible_by_16(val): + return "D" + if AttrsDescriptor.is_equal_to_1(val): + return "1" + return "N" + + def __repr__(self): + return f"AttrsDescriptor.from_dict({self.to_dict()!r})" @dataclass(frozen=True) @@ -74,3 +281,24 @@ def load_dialects(self, context): Load additional MLIR dialects into the provided `context` """ raise NotImplementedError + + @abstractmethod + def get_module_map(self) -> Dict[str, ModuleType]: + """ + Return a map of interface modules to their device-specific implementations + """ + raise NotImplementedError + + def get_attrs_descriptor(self, params, args): + """ + Return an attribute descriptor: given a set of parameters and arguments + the descriptor stores a set of compile time properties that can improve code + generation. Different backends might benefit from different properties + """ + return AttrsDescriptor(params, args) + + def compute_spec_key(self, arg, align): + """ + Return the ascii key for a given argument with a given set of properties + """ + return AttrsDescriptor.get_property_key(arg, align) diff --git a/python/triton/backends/driver.py b/python/triton/backends/driver.py index e66442943..202ae1568 100644 --- a/python/triton/backends/driver.py +++ b/python/triton/backends/driver.py @@ -1,4 +1,11 @@ from abc import ABCMeta, abstractmethod, abstractclassmethod +from typing import Callable, List, Protocol, Sequence + + +class Benchmarker(Protocol): + + def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]: + pass class DriverBase(metaclass=ABCMeta): @@ -11,6 +18,13 @@ def is_active(self): def get_current_target(self): pass + @abstractmethod + def get_benchmarker(self) -> Benchmarker: + """ + Return the benchmarking function that this backend should use by default. + """ + raise NotImplementedError + def __init__(self) -> None: pass diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index ce0cfedfc..bbe8c047c 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,4 +1,4 @@ -from .compiler import CompiledKernel, ASTSource, compile, AttrsDescriptor, make_backend, LazyDict +from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict from .errors import CompilationError __all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"] diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 6903052ca..d8ca58d8d 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -9,7 +9,8 @@ from .. import language from .._C.libtriton import ir from ..language import constexpr, tensor, str_to_ty -from ..runtime.jit import _normalize_ty +from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value +from ..runtime.jit import _normalize_ty, get_jit_fn_file_line # ideally we wouldn't need any runtime component from ..runtime import JITFunction from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) @@ -31,7 +32,7 @@ def mangle_ty(ty): return f'{elt}S{shape}S' if ty.is_void(): return 'V' - assert False, "Unsupported type" + raise TypeError(f'Unsupported type {ty}') def mangle_fn(name, arg_tys, constants): @@ -46,6 +47,10 @@ def mangle_fn(name, arg_tys, constants): return ret +def _is_triton_value(o: Any) -> bool: + return isinstance(o, _value) + + def _is_triton_tensor(o: Any) -> bool: return isinstance(o, tensor) @@ -62,10 +67,6 @@ def _is_list_like(o: Any) -> bool: return isinstance(o, (list, tuple)) -def _unwrap_if_constexpr(o: Any): - return o.value if isinstance(o, constexpr) else o - - def _check_fn_args(node, fn, args): if fn.noinline: for idx, arg in enumerate(args): @@ -76,24 +77,6 @@ def _check_fn_args(node, fn, args): ) -def _get_fn_file_line(fn): - base_fn = fn - while not isinstance(base_fn, JITFunction): - base_fn = base_fn.fn - file_name = base_fn.fn.__code__.co_filename - lines, begin_line = inspect.getsourcelines(base_fn.fn) - # Match the following pattern: - # @triton.autotune(...) <- foo.__code__.co_firstlineno - # @triton.heuristics(...) - # @triton.jit - # def foo(...): <- this line is the first line - for idx, line in enumerate(lines): - if line.strip().startswith("def "): - begin_line += idx - break - return file_name, begin_line - - _condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels @@ -124,10 +107,7 @@ def __init__(self, gscope): self.gscope = gscope def _visit_stmts(self, body) -> bool: - for s in body: - if self.visit(s): - return True - return False + return any(self.visit(s) for s in body) def _visit_function(self, fn) -> bool: # Currently we only support JITFunctions defined in the global scope @@ -163,7 +143,7 @@ def visit_Attribute(self, node: ast.Attribute) -> bool: return self.visit(node.value) def visit_Name(self, node: ast.Name) -> bool: - if type(node.ctx) == ast.Store: + if type(node.ctx) is ast.Store: return False if node.id in self.gscope: fn = self.gscope[node.id] @@ -212,7 +192,7 @@ def visit_Call(self, node: ast.Call) -> bool: class CodeGenerator(ast.NodeVisitor): def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, - codegen_fns, debug=None, module=None, is_kernel=False, function_types: Optional[Dict] = None, + codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False, file_name: Optional[str] = None, begin_line=0): self.context = context self.builder = ir.builder(context) @@ -225,18 +205,30 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n # Convert custom types not natively supported on HW. # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) self.builder.codegen_fns = codegen_fns + self.builder.module_map = {} if module_map is None else module_map self.module = self.builder.create_module() if module is None else module self.function_ret_types = {} if function_types is None else function_types self.prototype = prototype - self.gscope = gscope - self.lscope = dict() + + self.gscope = {} + for k, v in gscope.items(): + if isinstance(v, ModuleType): + self.gscope[k] = module_map.get(v.__name__, v) + continue + + module_name = getattr(v, "__module__", "") + if module_name in module_map: + self.gscope[k] = getattr(module_map[module_name], v.__name__) + else: + self.gscope[k] = v + + self.lscope = {} self.attributes = attributes self.constants = constants self.jit_fn = jit_fn self.function_name = function_name self.is_kernel = is_kernel self.cur_node = None - self.debug = options.debug if debug is None else debug self.noinline = noinline self.scf_stack = [] self.ret_type = None @@ -284,19 +276,20 @@ def global_lookup(name: str, absent): # The high-level rule is that only constexpr globals are allowed. # But actually a bunch of other things, such as module imports, are # technically Python globals. We have to allow these too! - if (val is absent # - or name in self.builtin_namespace # - or type(val) == ModuleType # - or isinstance(val, JITFunction) # - or getattr(val, "__triton_builtin__", False) # - or getattr(val, "__module__", "").startswith("triton.language") # - or isinstance(val, language.dtype) # - or self._is_constexpr_global(name) # + if any([ + val is absent, name in self.builtin_namespace, # + type(val) is ModuleType, # + isinstance(val, JITFunction), # + getattr(val, "__triton_builtin__", False), # + getattr(val, "__module__", "").startswith("triton.language"), # + isinstance(val, language.dtype), # + self._is_constexpr_global(name), # # Allow accesses to globals while visiting an ast.arg # because you should be able to do # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... - or self.visiting_arg_default_value # - or os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"): + self.visiting_arg_default_value, # + os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1" + ]): return val raise NameError( textwrap.dedent(f"""\ @@ -367,36 +360,35 @@ def visit_List(self, node): # By design, only non-kernel functions can return def visit_Return(self, node): ret_value = self.visit(node.value) - # ret_block = self.builder.create_block() - # post_ret_block = self.builder.create_block() - # self.builder.create_branch(ret_block) - # self.builder.set_insertion_point_to_end(ret_block) if ret_value is None: self.builder.ret([]) ret_ty = language.void elif isinstance(ret_value, tuple): - ret_values = [language.core._to_tensor(v, self.builder) for v in ret_value] + ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value] ret_types = [v.type for v in ret_values] self.builder.ret([v.handle for v in ret_values]) ret_ty = tuple(ret_types) else: - ret = language.core._to_tensor(ret_value, self.builder) + ret = language.semantic.to_tensor(ret_value, self.builder) self.builder.ret([ret.handle]) ret_ty = ret.type - # self.builder.create_branch(post_ret_block) - # self.builder.set_insertion_point_to_end(post_ret_block) if self.ret_type is None: self.ret_type = ret_ty elif self.ret_type != ret_ty: raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}') + # A return op must always terminate the basic block, so we create a dead + # basic block in case there are any ops after the return. + post_ret_block = self.builder.create_block() + self.builder.set_insertion_point_to_end(post_ret_block) + def visit_FunctionDef(self, node): arg_names, kwarg_names = self.visit(node.args) if self.fn: raise self._unsupported(node, "nested function definition is not supported.") # initialize defaults - for i, default_value in enumerate(node.args.defaults): + for i, default_value in enumerate(node.args.defaults[::-1]): arg_node = node.args.args[-i - 1] annotation = arg_node.annotation name = arg_node.arg @@ -421,7 +413,7 @@ def visit_FunctionDef(self, node): entry = self.fn.add_entry_block() arg_values = [] idx = 0 - for i, arg_name in enumerate(arg_names): + for i in range(len(arg_names)): if i in self.constants: cst = self.constants[i] if not _is_constexpr(cst): @@ -432,6 +424,11 @@ def visit_FunctionDef(self, node): if i in self.attributes: for name, value in self.attributes[i]: self.fn.set_arg_attr(idx, name, value) + + # Mark this argument as a pass-by-value TMA descriptor (nvidia) + if isinstance(self.prototype.param_types[idx], nv_tma_desc_type): + self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1) + arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) idx += 1 @@ -441,22 +438,24 @@ def visit_FunctionDef(self, node): self.builder.set_insertion_point_to_start(entry) # visit function body self.visit_compound_statement(node.body) + # finalize function + assert not self.builder.get_insertion_block().has_terminator() if self.ret_type is None or self.ret_type == language.void: self.ret_type = language.void self.builder.ret([]) else: - # update return type - if isinstance(self.ret_type, tuple): - self.prototype.ret_types = list(self.ret_type) - self.fn.reset_type(self.prototype.to_ir(self.builder)) - else: - self.prototype.ret_types = [self.ret_type] - self.fn.reset_type(self.prototype.to_ir(self.builder)) + self.prototype.ret_types = list(self.ret_type) if isinstance(self.ret_type, tuple) else [self.ret_type] + self.fn.reset_type(self.prototype.to_ir(self.builder)) + self.builder.ret([ + self.builder.create_poison(ty.to_ir(self.builder)) + for ty in self.prototype.ret_types + if self.ret_type is not None + ]) + self.fn.finalize() + if insert_pt: self.builder.set_insertion_point_to_end(insert_pt) - # Remove dead code - self.fn.finalize() def visit_arguments(self, node): arg_names = [] @@ -488,8 +487,11 @@ def visit_AnnAssign(self, node): def visit_Assign(self, node): _names = [] - for target in node.targets: - _names += [self.visit(target)] + if isinstance(node, ast.AnnAssign): + _names += [self.visit(node.target)] + else: + for target in node.targets: + _names += [self.visit(target)] if len(_names) > 1: raise self._unsupported(node, "simultaneous multiple assignment is not supported.") names = _names[0] @@ -503,9 +505,9 @@ def visit_Assign(self, node): # by default, constexpr are assigned into python variable value = _unwrap_if_constexpr(value) if value is not None and \ - not _is_triton_tensor(value) and \ + not _is_triton_value(value) and \ not isinstance(value, native_nontensor_types): - value = language.core._to_tensor(value, self.builder) + value = language.semantic.to_tensor(value, self.builder) self.set_value(name, value) def visit_AugAssign(self, node): @@ -517,7 +519,7 @@ def visit_AugAssign(self, node): return self.dereference_name(name) def visit_Name(self, node): - if type(node.ctx) == ast.Store: + if type(node.ctx) is ast.Store: return node.id return self.dereference_name(node.id) @@ -605,13 +607,13 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block): then_defs[name] = liveins[name] # variables that are both in then and else but not in liveins # TODO: could probably be cleaned up - for name in then_defs.keys() & else_defs.keys(): + for name in sorted(then_defs.keys() & else_defs.keys()): if name in names: continue then_ty = then_defs[name].type else_ty = else_defs[name].type assert then_ty == else_ty, \ - f'mismatched type for {name} between then block ({then_ty}) '\ + f'Mismatched type for {name} between then block ({then_ty}) '\ f'and else block ({else_ty})' names.append(name) ret_types.append(then_ty) @@ -620,40 +622,35 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block): return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types def visit_if_top_level(self, cond, node): - has_endif_block = True with enter_sub_region(self) as sr: liveins, ip_block = sr then_block = self.builder.create_block() else_block = self.builder.create_block() - # create basic-block after conditional - endif_block = self.builder.create_block() # create branch self.builder.set_insertion_point_to_end(ip_block) self.builder.create_cond_branch(cond.handle, then_block, else_block) # visit then and else blocks then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \ self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create basic-block after conditional + endif_block = self.builder.create_block() # then terminator self.builder.set_insertion_point_to_end(then_block) - if then_block.has_return() and else_block.has_return(): - has_endif_block = False - endif_block.erase() - if not then_block.has_terminator() and has_endif_block: - self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) + assert not then_block.has_terminator(), f"{then_block}" + self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) # else terminator self.builder.set_insertion_point_to_end(else_block) - if not else_block.has_terminator() and has_endif_block: - self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) - if has_endif_block: - for ty in ir_ret_types: - endif_block.add_argument(ty) - if has_endif_block: - # change block - self.builder.set_insertion_point_to_start(endif_block) - # update value - for i, name in enumerate(names): - new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) - self.set_value(name, new_tensor) + assert not else_block.has_terminator(), f"{else_block}" + self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) + for ty in ir_ret_types: + endif_block.add_argument(ty) + + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + for i, name in enumerate(names): + new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) + self.set_value(name, new_tensor) # TODO: refactor def visit_if_scf(self, cond, node): @@ -685,18 +682,19 @@ def visit_if_scf(self, cond, node): def visit_If(self, node): cond = self.visit(node.test) + if _is_triton_tensor(cond): cond = cond.to(language.int1, _builder=self.builder) contains_return = ContainsReturnChecker(self.gscope).visit(node) - if self.scf_stack and contains_return: - raise self._unsupported( - node, "Cannot have `return` statements inside `while` or `for` statements in triton " - "(note that this also applies to `return` statements that are inside functions " - "transitively called from within `while`/`for` statements)") - elif self.scf_stack or not contains_return: - self.visit_if_scf(cond, node) - else: + if contains_return: + if self.scf_stack: + raise self._unsupported( + node, "Cannot have `return` statements inside `while` or `for` statements in triton " + "(note that this also applies to `return` statements that are inside functions " + "transitively called from within `while`/`for` statements)") self.visit_if_top_level(cond, node) + else: + self.visit_if_scf(cond, node) else: cond = _unwrap_if_constexpr(cond) # not isinstance - we insist the real thing, no subclasses and no ducks @@ -705,10 +703,9 @@ def visit_If(self, node): node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( ', '.join(_.__name__ for _ in _condition_types), type(cond).__name__)) - if cond: - self.visit_compound_statement(node.body) - else: - self.visit_compound_statement(node.orelse) + + active_block = node.body if cond else node.orelse + self.visit_compound_statement(active_block) def visit_IfExp(self, node): cond = self.visit(node.test) @@ -720,20 +717,20 @@ def visit_IfExp(self, node): then_block = self.builder.create_block() self.builder.set_insertion_point_to_start(then_block) - then_val = language.core._to_tensor(self.visit(node.body), self.builder) + then_val = language.semantic.to_tensor(self.visit(node.body), self.builder) then_block = self.builder.get_insertion_block() else_block = self.builder.create_block() self.builder.set_insertion_point_to_start(else_block) # do not need to reset lscope since # ternary expressions cannot define new variables - else_val = language.core._to_tensor(self.visit(node.orelse), self.builder) + else_val = language.semantic.to_tensor(self.visit(node.orelse), self.builder) else_block = self.builder.get_insertion_block() self._set_insertion_point_and_loc(ip, last_loc) assert then_val.type == else_val.type, \ - f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' + f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' ret_type = then_val.type ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else [] @@ -773,9 +770,9 @@ def visit_Compare(self, node): rhs = self.visit(node.comparators[0]) lhs_value = _unwrap_if_constexpr(lhs) rhs_value = _unwrap_if_constexpr(rhs) - if type(node.ops[0]) == ast.Is: + if type(node.ops[0]) is ast.Is: return constexpr(lhs_value is rhs_value) - if type(node.ops[0]) == ast.IsNot: + if type(node.ops[0]) is ast.IsNot: return constexpr(lhs_value is not rhs_value) method_name = self._method_name_for_comp_op.get(type(node.ops[0])) if method_name is None: @@ -804,6 +801,15 @@ def visit_UnaryOp(self, node): ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' } + def _verify_loop_carried_variable(self, name, loop_val, live_val): + assert _is_triton_value(loop_val), f'cannot reassign constxpr {name} in the loop' + assert _is_triton_value(live_val), f'cannot reasign constexpr {name} in the loop' + assert type(loop_val) == type(live_val), f'Loop carried variable {name} changed type' + assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \ + f'Loop-carried variable {name} has initial type {live_val.type} '\ + f'but is re-assigned to {loop_val.type} in loop! '\ + f'Please make sure that the type stays consistent.' + def visit_While(self, node): with enter_sub_region(self) as sr: liveins, insert_block = sr @@ -826,17 +832,14 @@ def visit_While(self, node): for name in loop_defs: if name in liveins: # We should not def new constexpr - assert _is_triton_tensor(loop_defs[name]), f'cannot reassign constxpr {name} in the loop' - assert _is_triton_tensor(liveins[name]), f'cannot reasign constexpr {name} in the loop' - assert loop_defs[name].type == liveins[name].type, \ - f'Loop-carried variable {name} has initial type {liveins[name].type} '\ - f'but is re-assigned to {loop_defs[name].type} in loop! '\ - f'Please make sure that the type stays consistent.' + loop_val = loop_defs[name] + live_val = liveins[name] + self._verify_loop_carried_variable(name, loop_val, live_val) # these are loop-carried values names.append(name) - ret_types.append(loop_defs[name].type) - init_args.append(liveins[name]) + ret_types.append(loop_val.type) + init_args.append(live_val) self._set_insertion_point_and_loc(ip, last_loc) while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], @@ -906,6 +909,7 @@ def visit_For(self, node): ast.NodeVisitor.generic_visit(self, stmt) return num_stages = None + loop_unroll_factor = None if IteratorClass is language.range: iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments @@ -915,6 +919,7 @@ def visit_For(self, node): ub = iterator.end step = iterator.step num_stages = iterator.num_stages + loop_unroll_factor = iterator.loop_unroll_factor elif IteratorClass is range: # visit iterator arguments # note: only `range` iterator is supported now @@ -930,9 +935,9 @@ def visit_For(self, node): step = constexpr(-step.value) negative_step = True lb, ub = ub, lb - lb = language.core._to_tensor(lb, self.builder) - ub = language.core._to_tensor(ub, self.builder) - step = language.core._to_tensor(step, self.builder) + lb = language.semantic.to_tensor(lb, self.builder) + ub = language.semantic.to_tensor(ub, self.builder) + step = language.semantic.to_tensor(step, self.builder) # induction variable type if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") @@ -949,7 +954,7 @@ def visit_For(self, node): ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed) step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed) # Create placeholder for the loop induction variable - iv = self.builder.create_undef(iv_ir_type) + iv = self.builder.create_poison(iv_ir_type) self.set_value(node.target.id, language.core.tensor(iv, iv_type)) with enter_sub_region(self) as sr: @@ -972,22 +977,21 @@ def visit_For(self, node): names = [] for name in self.local_defs: if name in liveins: - assert _is_triton_tensor(self.local_defs[name]), f'{name} is not tensor' - assert _is_triton_tensor(liveins[name]) - assert self.local_defs[name].type == liveins[name].type, \ - f'Loop-carried variable {name} has initial type {liveins[name].type} '\ - f'but is re-assigned to {self.local_defs[name].type} in loop! '\ - f'Please make sure that the type stays consistent.' + loop_val = self.local_defs[name] + live_val = liveins[name] + self._verify_loop_carried_variable(name, loop_val, live_val) names.append(name) - init_args.append(language.core._to_tensor(liveins[name], self.builder)) - yields.append(language.core._to_tensor(self.local_defs[name], self.builder)) + init_args.append(live_val) + yields.append(loop_val) # create ForOp self._set_insertion_point_and_loc(ip, last_loc) for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) if num_stages is not None: for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) + if loop_unroll_factor is not None: + for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) self.scf_stack.append(node) self.builder.set_insertion_point_to_start(for_op.get_body(0)) @@ -1001,7 +1005,7 @@ def visit_For(self, node): yields = [] for name in self.local_defs: if name in liveins: - yields.append(language.core._to_tensor(self.local_defs[name], self.builder)) + yields.append(language.semantic.to_tensor(self.local_defs[name], self.builder)) # create YieldOp if len(yields) > 0: @@ -1039,19 +1043,16 @@ def visit_keyword(self, node) -> Tuple[str, Any]: return node.arg, self.visit(node.value) def visit_Assert(self, node) -> Any: - if not self.debug: - return test = self.visit(node.test) msg = self.visit(node.msg) if node.msg is not None else "" - # Convert assert to triton's device_assert which happens on the device return language.core.device_assert(test, msg, _builder=self.builder) def call_JitFunction(self, fn: JITFunction, args, kwargs): args = inspect.getcallargs(fn.fn, *args, **kwargs) args = [args[name] for name in fn.arg_names] - args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args] + args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args] # generate function def - attributes = dict() + attributes = {} constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] constants = {i: args[i] for i in constexprs} # generate call @@ -1064,12 +1065,12 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): prototype = language.function_type([], arg_types) gscope = fn.__globals__ # If the callee is not set, we use the same debug setting as the caller - file_name, begin_line = _get_fn_file_line(fn) - debug = self.debug if fn.debug is None else fn.debug + file_name, begin_line = get_jit_fn_file_line(fn) generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, noinline=fn.noinline, file_name=file_name, begin_line=begin_line, - options=self.builder.options, codegen_fns=self.builder.codegen_fns, debug=debug) + options=self.builder.options, codegen_fns=self.builder.codegen_fns, + module_map=self.builder.module_map) try: generator.visit(fn.parse()) except Exception as e: @@ -1101,14 +1102,11 @@ def visit_Call(self, node): kws = dict(self.visit(keyword) for keyword in node.keywords) args = [self.visit(arg) for arg in node.args] - if fn is language.core.device_assert: # TODO: this should not be so hardcoded - if not self.debug: - return if isinstance(fn, JITFunction): _check_fn_args(node, fn, args) return self.call_JitFunction(fn, args, kws) - if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn): - extra_kwargs = dict(_builder=self.builder) + if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn): + extra_kwargs = {"_builder": self.builder} sig = inspect.signature(fn) if '_generator' in sig.parameters: extra_kwargs['_generator'] = self @@ -1157,9 +1155,8 @@ def visit_Str(self, node): def visit_Attribute(self, node): lhs = self.visit(node.value) - if _is_triton_tensor(lhs): - if node.attr == "T": - return language.semantic.permute(lhs, (1, 0), builder=self.builder) + if _is_triton_tensor(lhs) and node.attr == "T": + return language.semantic.permute(lhs, (1, 0), builder=self.builder) return getattr(lhs, node.attr) def visit_Expr(self, node): @@ -1268,12 +1265,12 @@ def kernel_suffix(signature, specialization): suffix += str(i) if i in specialization.equal_to_1: suffix += 'c' - if i in specialization.divisible_by_16: + if i in specialization.divisibility_16: suffix += 'd' return suffix -def ast_to_ttir(fn, specialization, context, options, codegen_fns): +def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map): attrs = specialization.attrs # create kernel prototype cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i @@ -1282,18 +1279,22 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns): gscope = fn.__globals__.copy() function_name = fn.repr(specialization) tys = list(specialization.signature.values()) - new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1} - new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16} + new_constants = attrs.get_constants() + for k in new_constants: + if k in tys and tys[k] == "i1" and new_constants[k] == 1: + new_constants[k] = True + new_attrs = attrs.filter_out_constants() + fn_attrs = new_attrs.get_fn_attrs() all_constants = constants.copy() all_constants.update(new_constants) arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] - file_name, begin_line = _get_fn_file_line(fn) + file_name, begin_line = get_jit_fn_file_line(fn) prototype = language.function_type([], arg_types) generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, - jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name, - begin_line=begin_line, options=options, codegen_fns=codegen_fns) + jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name, + begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map) generator.visit(fn.parse()) ret = generator.module diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 367aa1b1a..8ca1f8b32 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -3,44 +3,19 @@ import json from .._C.libtriton import get_cache_invalidating_env_vars, ir from ..backends import backends -from ..backends.compiler import GPUTarget +from ..backends.compiler import GPUTarget, AttrsDescriptor from .. import __version__ from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager from ..runtime.driver import driver +from ..tools.disasm import get_sass # TODO: this shouldn't be here -from dataclasses import dataclass from .code_generator import ast_to_ttir from pathlib import Path import re import functools import os - -@dataclass -class AttrsDescriptor: - divisible_by_16: set = None - equal_to_1: set = None - - def __post_init__(self): - if self.divisible_by_16 is None: - self.divisible_by_16 = set() - if self.equal_to_1 is None: - self.equal_to_1 = set() - - def to_dict(self): - return {'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1)} - - @staticmethod - def from_dict(data): - return AttrsDescriptor(divisible_by_16=set(data.get('divisible_by_16', [])), - equal_to_1=set(data.get('equal_to_1', []))) - - def hash(self): - key = str([sorted(x) for x in self.__dict__.values()]) - return hashlib.sha256(key.encode("utf-8")).hexdigest() - - # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, # and any following whitespace # - (public\s+)? : optionally match the keyword public and any following whitespace @@ -57,7 +32,7 @@ def hash(self): "ptx": ptx_prototype_pattern, } -mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+),?' +mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?' ptx_arg_type_pattern = r"\.param\s+\.(\w+)" arg_type_pattern = { "ttir": mlir_arg_type_pattern, @@ -70,6 +45,10 @@ def convert_type_repr(x): # Currently we only capture the pointer type and assume the pointer is on global memory. # TODO: Capture and support shared memory space match = re.search(r'!tt\.ptr<([^,]+)', x) + tma = re.search(r'tt.nv_tma_desc = 1', x) + if tma is not None: + return 'nvTmaDesc' + x = re.sub(r' {[^}]+}', '', x) if match is not None: return '*' + convert_type_repr(match.group(1)) return x @@ -96,8 +75,16 @@ def __init__(self, fn, signature, constants=None, attrs=None) -> None: self.attrs = attrs if isinstance(self.signature, str): self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + else: + for k in self.signature.keys(): + if not isinstance(k, str): + raise TypeError("Signature keys must be string") if self.constants is None: - self.constants = dict() + self.constants = {} + else: + for k in self.constants.keys(): + if not isinstance(k, str): + raise TypeError("Constants keys must be string") if self.attrs is None: self.attrs = AttrsDescriptor() @@ -109,8 +96,9 @@ def hash(self): key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" return hashlib.sha256(key.encode("utf-8")).hexdigest() - def make_ir(self, options, codegen_fns, context): - return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns) + def make_ir(self, options, codegen_fns, module_map, context): + return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns, + module_map=module_map) def parse_options(self): return dict() @@ -132,7 +120,7 @@ def __init__(self, path): def hash(self): return hashlib.sha256(self.src.encode("utf-8")).hexdigest() - def make_ir(self, options, codegen_fns, context): + def make_ir(self, options, codegen_fns, module_map, context): module = ir.parse_mlir_module(self.path, context) module.context = context return module @@ -172,7 +160,7 @@ def triton_key(): contents.append(libtriton_hash.hexdigest()) # language language_path = os.path.join(TRITON_PATH, 'language') - for lib in pkgutil.iter_modules([language_path]): + for lib in pkgutil.walk_packages([language_path], prefix="triton.language."): with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: contents += [hashlib.sha256(f.read()).hexdigest()] return f'{__version__}' + '-'.join(contents) @@ -195,6 +183,9 @@ def filter_traceback(e: BaseException): These are uninteresting to the user -- "just show me *my* code!" """ + if os.getenv("TRITON_FRONT_END_DEBUGGING", "0") == "1": + return + if e.__cause__ is not None: filter_traceback(e.__cause__) if e.__context__ is not None: @@ -246,7 +237,12 @@ def compile(src, target=None, options=None): enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1" fn_override_manager = get_override_manager(src.hash()) if enable_override else None fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None - metadata_filename = f"{src.name}.json" + # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms. + # The final file name in the cache will have a format of f"{filename}.{ext}.tmp.pid_{pid}_{uuid}". + # A PID string can be 5-character long. A UUID string has typically 36 characters. Let's truncate + # the file name to 150 characters to be safe. + file_name = src.name[:150] + metadata_filename = f"{file_name}.json" metadata_group = fn_cache_manager.get_group(metadata_filename) or {} metadata_path = metadata_group.get(metadata_filename) always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1" @@ -272,32 +268,37 @@ def compile(src, target=None, options=None): ir.load_dialects(context) backend.load_dialects(context) codegen_fns = backend.get_codegen_implementation() + module_map = backend.get_module_map() try: - module = src.make_ir(options, codegen_fns, context) + module = src.make_ir(options, codegen_fns, module_map, context) except Exception as e: filter_traceback(e) raise - use_ttgir_loc = os.environ.get("USE_TTGIR_LOC", "0") == "1" + use_ir_loc = os.environ.get("USE_IR_LOC", None) for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) - ir_filename = f"{src.name}.{ext}" + ir_filename = f"{file_name}.{ext}" + if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None): + print(f"\nOverriding kernel with file {full_name}") + next_module = parse(full_name, ext, context) metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) if fn_dump_manager is not None: fn_dump_manager.put(next_module, ir_filename) - if (fn_override_manager is not None and fn_override_manager.has_file(ir_filename)): - print(f"\nOverriding kernel with file {ir_filename}") - full_name = fn_override_manager.get_file(ir_filename) - next_module = parse(full_name, ext, context) - # use an env variable to parse ttgir from file - if use_ttgir_loc and ext == "ttgir": - ttgir_full_name = fn_cache_manager.get_file(ir_filename) - next_module.create_location_snapshot(ttgir_full_name) - print(f"Create new locations for {ttgir_full_name}") + # use an env variable to parse ir from file + if use_ir_loc == ext: + ir_full_name = fn_cache_manager.get_file(ir_filename) + next_module.create_location_snapshot(ir_full_name) + print(f"Creating new locations for {ir_full_name}") module = next_module # write-back metadata metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) + # Compilation completed, disabling multithreading in context. + # This is needed to safely finalize threads pool inside context: if current process forks before + # python GC deletes context object, thread pool in child process will be invalid, which could + # lead to child crash or hang. + context.disable_multithreading() # return handle to compiled kernel return CompiledKernel(src, metadata_group, hash) @@ -326,6 +327,19 @@ def add(self, func, args): self.extras.append((func, args)) +class AsmDict(dict): + + def __missing__(self, key): + + if key == "sass": + value = get_sass(self["cubin"]) + else: + raise KeyError("Unknown key: '%s'" % key) + + self[key] = value + return value + + class CompiledKernel: # Hooks for external tools to monitor the execution of triton kernels @@ -351,10 +365,10 @@ def __init__(self, src, metadata_group, hash): # stores the text of each level of IR that was generated during compilation asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] binary_ext = backend.binary_ext - self.asm = { + self.asm = AsmDict({ file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text() for file in asm_files - } + }) self.kernel = self.asm[binary_ext] # binaries are lazily initialized # because it involves doing runtime things diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 168dccfea..6502a5348 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -28,9 +28,11 @@ TRITON_MAX_TENSOR_NUMEL, _experimental_descriptor_load, _experimental_descriptor_store, + add, advance, arange, associative_scan, + assume, atomic_add, atomic_and, atomic_cas, @@ -47,12 +49,12 @@ cast, clamp, const, - const_pointer_type, constexpr, debug_barrier, device_assert, device_print, dot, + dot_scaled, dtype, expand_dims, float16, @@ -84,6 +86,7 @@ permute, pi32_t, pointer_type, + nv_tma_desc_type, program_id, range, reduce, @@ -124,11 +127,13 @@ "_experimental_descriptor_load", "_experimental_descriptor_store", "abs", + "add", "advance", "arange", "argmax", "argmin", "associative_scan", + "assume", "atomic_add", "atomic_and", "atomic_cas", @@ -148,7 +153,6 @@ "ceil", "clamp", "const", - "const_pointer_type", "constexpr", "cos", "cumprod", @@ -158,6 +162,7 @@ "device_print", "div_rn", "dot", + "dot_scaled", "dtype", "erf", "exp", @@ -207,6 +212,7 @@ "philox_impl", "pi32_t", "pointer_type", + "nv_tma_desc_type", "program_id", "rand", "rand4x", @@ -253,12 +259,16 @@ def str_to_ty(name): if name[0] == "*": name = name[1:] + const = False if name[0] == "k": name = name[1:] - ty = str_to_ty(name) - return const_pointer_type(ty) + const = True ty = str_to_ty(name) - return pointer_type(ty) + return pointer_type(element_ty=ty, const=const) + + if name == "nvTmaDesc": + return nv_tma_desc_type() + tys = { "fp8e4nv": float8e4nv, "fp8e4b8": float8e4b8, diff --git a/python/triton/language/_utils.py b/python/triton/language/_utils.py new file mode 100644 index 000000000..b9aa69071 --- /dev/null +++ b/python/triton/language/_utils.py @@ -0,0 +1,21 @@ +from typing import List + +TRITON_MAX_TENSOR_NUMEL = 1048576 + + +def is_power_of_two(x): + return (x & (x - 1)) == 0 + + +def validate_block_shape(shape: List[int]): + numel = 1 + for i, d in enumerate(shape): + if not isinstance(d, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]") + if not is_power_of_two(d): + raise ValueError(f"Shape element {i} must be a power of 2") + numel *= d + + if numel > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") + return numel diff --git a/python/triton/language/core.py b/python/triton/language/core.py index f2d3266e9..e2c57b388 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -13,11 +13,10 @@ from .._C.libtriton import ir from . import semantic +from ._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape T = TypeVar('T') -TRITON_MAX_TENSOR_NUMEL = 1048576 - TRITON_BUILTIN = "__triton_builtin__" PropagateNan = ir.PROPAGATE_NAN @@ -30,6 +29,7 @@ def builtin(fn: T) -> T: @wraps(fn) def wrapper(*args, **kwargs): if "_builder" not in kwargs or kwargs["_builder"] is None: + print(kwargs) raise ValueError("Did you forget to add @triton.jit ? " "(`_builder` argument must be provided outside of JIT functions.)") return fn(*args, **kwargs) @@ -112,41 +112,177 @@ def is_builtin(fn) -> bool: @builtin def to_tensor(x, _builder=None): - return _to_tensor(x, _builder) - - -def _to_tensor(x, builder): - if isinstance(x, bool): - return tensor(builder.get_int1(x), int1) - # Note: compile-time const integers are represented by unsigned values - elif isinstance(x, int): - if -2**31 <= x < 2**31: - return tensor(builder.get_int32(x), int32) - elif 2**31 <= x < 2**32: - return tensor(builder.get_uint32(x), uint32) - elif -2**63 <= x < 2**63: - return tensor(builder.get_int64(x), int64) - elif 2**63 <= x < 2**64: - return tensor(builder.get_uint64(x), uint64) - else: - raise RuntimeError(f'Nonrepresentable integer {x}.') - elif isinstance(x, float): - min_float32 = 2**-126 - max_float32 = (2 - 2**-23) * 2**127 - abs_x = __builtins__['abs'](x) - if abs_x == float("inf") or\ - abs_x == 0.0 or \ - x != x or \ - min_float32 <= abs_x <= max_float32: - return tensor(builder.get_fp32(x), float32) + return semantic.to_tensor(x, _builder) + + +# ----------------------- +# constexpr +# ----------------------- + + +class const: + """ + This class is used as a type annotation to mark pointers to constant data. + The `store` function cannot be called with a pointer to const. Constness + is part of the pointer type and the usual Triton type consistency rules + apply. For example you cannot have a function that returns constant pointer + in one return statement and non-constant pointer in another. + """ + pass + + +class constexpr: + """ + This class is used to store a value that is known at compile-time. + """ + + def __init__(self, value): + if isinstance(value, constexpr): + self.value = value.value else: - return tensor(builder.get_fp64(x), float64) + self.value = value + + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + def __index__(self): + return self.value + + # In interpreter mode, constant values are not wrapped in constexpr, + # and therefore do not have a .value attribute. + # As a result, from here and below, we need to call the _constexpr_to_value + # function to obtain either constexpr.value or the value itself. + def __add__(self, other): + return constexpr(self.value + _constexpr_to_value(other)) + + def __radd__(self, other): + return constexpr(_constexpr_to_value(other) + self.value) + + def __sub__(self, other): + return constexpr(self.value - _constexpr_to_value(other)) + + def __rsub__(self, other): + return constexpr(_constexpr_to_value(other) - self.value) + + def __mul__(self, other): + return constexpr(self.value * _constexpr_to_value(other)) + + def __mod__(self, other): + return constexpr(self.value % _constexpr_to_value(other)) + + def __rmul__(self, other): + return constexpr(_constexpr_to_value(other) * self.value) + + def __truediv__(self, other): + return constexpr(self.value / _constexpr_to_value(other)) + + def __rtruediv__(self, other): + return constexpr(_constexpr_to_value(other) / self.value) + + def __floordiv__(self, other): + return constexpr(self.value // _constexpr_to_value(other)) + + def __rfloordiv__(self, other): + return constexpr(_constexpr_to_value(other) // self.value) + + def __gt__(self, other): + return constexpr(self.value > _constexpr_to_value(other)) + + def __rgt__(self, other): + return constexpr(_constexpr_to_value(other) > self.value) + + def __ge__(self, other): + return constexpr(self.value >= _constexpr_to_value(other)) + + def __rge__(self, other): + return constexpr(_constexpr_to_value(other) >= self.value) + + def __lt__(self, other): + return constexpr(self.value < _constexpr_to_value(other)) + + def __rlt__(self, other): + return constexpr(_constexpr_to_value(other) < self.value) + + def __le__(self, other): + return constexpr(self.value <= _constexpr_to_value(other)) + + def __rle__(self, other): + return constexpr(_constexpr_to_value(other) <= self.value) + + def __eq__(self, other): + return constexpr(self.value == _constexpr_to_value(other)) - elif isinstance(x, constexpr): - return _to_tensor(x.value, builder) - elif isinstance(x, tensor): - return x - assert False, f"cannot convert {x} of type {type(x)} to tensor" + def __ne__(self, other): + return constexpr(self.value != _constexpr_to_value(other)) + + def __bool__(self): + return bool(self.value) + + def __neg__(self): + return constexpr(-self.value) + + def __and__(self, other): + return constexpr(self.value & _constexpr_to_value(other)) + + def logical_and(self, other): + return constexpr(self.value and _constexpr_to_value(other)) + + def __or__(self, other): + return constexpr(self.value | _constexpr_to_value(other)) + + def __xor__(self, other): + return constexpr(self.value ^ _constexpr_to_value(other)) + + def logical_or(self, other): + return constexpr(self.value or _constexpr_to_value(other)) + + def __pos__(self): + return constexpr(+self.value) + + def __invert__(self): + return constexpr(~self.value) + + def __pow__(self, other): + return constexpr(self.value**_constexpr_to_value(other)) + + def __rpow__(self, other): + return constexpr(_constexpr_to_value(other)**self.value) + + def __rshift__(self, other): + return constexpr(self.value >> _constexpr_to_value(other)) + + def __lshift__(self, other): + return constexpr(self.value << _constexpr_to_value(other)) + + def __not__(self): + return constexpr(not self.value) + + def __iter__(self): + return iter(self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + + +CONSTEXPR_0 = constexpr(0) + + +def _unwrap_if_constexpr(o): + return o.value if isinstance(o, constexpr) else o + + +def check_bit_width(value, shift_value): + if isinstance(value, tensor) and isinstance(shift_value, constexpr): + bitwidth = value.type.scalar.primitive_bitwidth + if shift_value.value >= bitwidth: + warn( + f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type '{value.dtype}'. This may result in undefined behavior." + ) + + +# ----------------------- +# dtype +# ----------------------- class dtype: @@ -160,9 +296,13 @@ class SIGNEDNESS(Enum): SIGNED = 0 UNSIGNED = 1 + class KIND(Enum): + BOOLEAN = 0 + INTEGRAL = 1 + FLOATING = 2 + def __init__(self, name): - if hasattr(name, 'value'): - name = name.value + name = _unwrap_if_constexpr(name) self.name = name assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name if name in dtype.SINT_TYPES: @@ -207,7 +347,7 @@ def __init__(self, name): self.primitive_bitwidth = 32 self.exponent_bias = 127 elif name == 'fp64': - self.fp_mantissa_width = 53 + self.fp_mantissa_width = 52 self.primitive_bitwidth = 64 self.exponent_bias = 1023 else: @@ -290,6 +430,30 @@ def is_int(self): def is_bool(self): return self.is_int1() + def kind(self): + # Return int value following the type ordering bool < integer < fp + if self.is_bool(): + return dtype.KIND.BOOLEAN + elif self.is_int(): + return dtype.KIND.INTEGRAL + else: + assert self.is_floating() + return dtype.KIND.FLOATING + + def get_int_max_value(self): + if self.is_int_signed(): + return 2**(self.int_bitwidth - 1) - 1 + if self.is_int_unsigned(): + return 2**self.int_bitwidth - 1 + assert False + + def get_int_min_value(self): + if self.is_int_signed(): + return -2**(self.int_bitwidth - 1) + if self.is_int_unsigned(): + return 0 + assert False + @staticmethod def is_dtype(type_str): return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES @@ -326,6 +490,13 @@ def scalar(self): return self def to_ir(self, builder: ir.builder) -> ir.type: + if self.name.startswith("fp8"): + if self.name not in builder.options.supported_fp8_dtypes: + raise ValueError(f'type {self} not supported in this architecture. ' + f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}') + if self.name in builder.options.deprecated_fp8_dtypes: + warn(f"{self.name} is deprecated in this architecture and will be removed in a future triton release") + if self.name == 'void': return builder.get_void_ty() elif self.name == 'int1': @@ -387,16 +558,17 @@ def __repr__(self): class pointer_type(dtype): - def __init__(self, element_ty: dtype, address_space: int = 1): + def __init__(self, element_ty: dtype, address_space: int = 1, const: bool = False): + element_ty = _unwrap_if_constexpr(element_ty) if not isinstance(element_ty, dtype): - raise TypeError(f'element_ty is a {type(element_ty).__name__}.') + raise TypeError(f'element_ty has type `{type(element_ty).__name__}`; expected `dtype`.') self.element_ty = element_ty self.address_space = address_space - - self.name = f'pointer<{element_ty}>' + self.const = const + self.name = f'pointer<{element_ty}>' if not const else f'const_pointer<{element_ty}>' def to_ir(self, builder: ir.builder) -> ir.pointer_type: - return builder.get_ptr_ty(self.element_ty.to_ir(builder), 1) + return builder.get_ptr_ty(self.element_ty.to_ir(builder), self.address_space) def __str__(self): return self.name @@ -407,10 +579,13 @@ def __repr__(self): def is_ptr(self): return True + def is_const(self): + return self.const + def __eq__(self, other: pointer_type) -> bool: if not isinstance(other, pointer_type): return False - return self.element_ty == other.element_ty and self.address_space == other.address_space + return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const def __ne__(self, other: pointer_type) -> bool: return not self.__eq__(other) @@ -420,21 +595,11 @@ def scalar(self): return self -class const_pointer_type(pointer_type): - - def __init__(self, element_ty: dtype, address_space: int = 1): - super().__init__(element_ty, address_space) - - def __str__(self): - return f'const_pointer<{self.element_ty}>' - - def is_const(self): - return True +class nv_tma_desc_type(pointer_type): - def __eq__(self, other) -> bool: - if not isinstance(other, const_pointer_type): - return False - return self.element_ty == other.element_ty and self.address_space == other.address_space + def __init__(self, const=True, address_space=0): + super().__init__(uint8, const=const, address_space=address_space) + self.name = 'nv_tma_desc_type' class block_type(dtype): @@ -446,18 +611,11 @@ def __init__(self, element_ty: dtype, shape: List): # while tensor's shape is a list of constexpr. # shape can be empty ([]) when an input is a 0D tensor. - if not shape: + self.shape = _unwrap_shape(shape) + if not self.shape: raise TypeError('0d block_type is forbidden') - if isinstance(shape[0], constexpr): - shape = [s.value for s in shape] - - self.shape = shape - self.numel = 1 - for s in self.shape: - self.numel *= s - if self.numel > TRITON_MAX_TENSOR_NUMEL: - raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") + self.numel = validate_block_shape(self.shape) self.name = f'<{self.shape}, {self.element_ty}>' def to_ir(self, builder: ir.builder) -> ir.block_type: @@ -550,168 +708,20 @@ def get_int_dtype(bitwidth: int, signed: bool) -> dtype: raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}') -# ----------------------- -# constexpr -# ----------------------- - - -class const: - """ - This class is used as a type annotation to mark pointers to constant data. - The `store` function cannot be called with a pointer to const. Constness - is part of the pointer type and the usual Triton type consistency rules - apply. For example you cannot have a function that returns constant pointer - in one return statement and non-constant pointer in another. - """ - pass - - -class constexpr: - """ - This class is used to store a value that is known at compile-time. +class _value: + """Base class of values that exist in the triton IR (i.e. not constexprs). """ - def __init__(self, value): - if isinstance(value, constexpr): - self.value = value.value - else: - self.value = value - - def __repr__(self) -> str: - return f"constexpr[{self.value}]" - - def __index__(self): - return self.value - - # In interpreter mode, constant values are not wrapped in constexpr, - # and therefore do not have a .value attribute. - # As a result, from here and below, we need to call the _constexpr_to_value - # function to obtain either constexpr.value or the value itself. - def __add__(self, other): - return constexpr(self.value + _constexpr_to_value(other)) - - def __radd__(self, other): - return constexpr(_constexpr_to_value(other) + self.value) - - def __sub__(self, other): - return constexpr(self.value - _constexpr_to_value(other)) - - def __rsub__(self, other): - return constexpr(_constexpr_to_value(other) - self.value) - - def __mul__(self, other): - return constexpr(self.value * _constexpr_to_value(other)) - - def __mod__(self, other): - return constexpr(self.value % _constexpr_to_value(other)) - - def __rmul__(self, other): - return constexpr(_constexpr_to_value(other) * self.value) - - def __truediv__(self, other): - return constexpr(self.value / _constexpr_to_value(other)) - - def __rtruediv__(self, other): - return constexpr(_constexpr_to_value(other) / self.value) - - def __floordiv__(self, other): - return constexpr(self.value // _constexpr_to_value(other)) - - def __rfloordiv__(self, other): - return constexpr(_constexpr_to_value(other) // self.value) - - def __gt__(self, other): - return constexpr(self.value > _constexpr_to_value(other)) - - def __rgt__(self, other): - return constexpr(_constexpr_to_value(other) > self.value) - - def __ge__(self, other): - return constexpr(self.value >= _constexpr_to_value(other)) - - def __rge__(self, other): - return constexpr(_constexpr_to_value(other) >= self.value) - - def __lt__(self, other): - return constexpr(self.value < _constexpr_to_value(other)) - - def __rlt__(self, other): - return constexpr(_constexpr_to_value(other) < self.value) - - def __le__(self, other): - return constexpr(self.value <= _constexpr_to_value(other)) - - def __rle__(self, other): - return constexpr(_constexpr_to_value(other) <= self.value) - - def __eq__(self, other): - return constexpr(self.value == _constexpr_to_value(other)) - - def __ne__(self, other): - return constexpr(self.value != _constexpr_to_value(other)) - - def __bool__(self): - return bool(self.value) - - def __neg__(self): - return constexpr(-self.value) - - def __and__(self, other): - return constexpr(self.value & _constexpr_to_value(other)) - - def logical_and(self, other): - return constexpr(self.value and _constexpr_to_value(other)) - - def __or__(self, other): - return constexpr(self.value | _constexpr_to_value(other)) - - def __xor__(self, other): - return constexpr(self.value ^ _constexpr_to_value(other)) - - def logical_or(self, other): - return constexpr(self.value or _constexpr_to_value(other)) - - def __pos__(self): - return constexpr(+self.value) - - def __invert__(self): - return constexpr(~self.value) - - def __pow__(self, other): - return constexpr(self.value**_constexpr_to_value(other)) - - def __rpow__(self, other): - return constexpr(_constexpr_to_value(other)**self.value) - - def __rshift__(self, other): - return constexpr(self.value >> _constexpr_to_value(other)) - - def __lshift__(self, other): - return constexpr(self.value << _constexpr_to_value(other)) - - def __not__(self): - return constexpr(not self.value) - - def __iter__(self): - return iter(self.value) - - def __call__(self, *args, **kwds): - return self.value(*args, **kwds) - - -CONSTEXPR_0 = constexpr(0) + def __init__(self, handle): + self.handle = handle -def check_bit_width(value, shift_value): - if isinstance(value, tensor) and isinstance(shift_value, constexpr): - bitwidth = value.type.scalar.primitive_bitwidth - if shift_value.value >= bitwidth: - warn( - f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type '{value.dtype}'. This may result in undefined behavior." - ) +# ----------------------- +# tensor +# ----------------------- -class tensor: +class tensor(_value): """Represents an N-dimensional array of values or pointers. :code:`tensor` is the fundamental data structure in Triton programs. Most @@ -734,7 +744,7 @@ class tensor: def __init__(self, handle, type: dtype): """Not called by user code.""" # IR handle - self.handle = handle + super().__init__(handle) # Block shape self.shape = type.shape if type.is_block() else () self.numel = 1 @@ -752,60 +762,56 @@ def __str__(self) -> str: @builtin def __add__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.add(self, other, _builder) + return add(self, other, sanitize_overflow=True, _builder=_builder) @builtin def __radd__(self, other, _builder=None): - return self.__add__(other, _builder=_builder) + return add(other, self, sanitize_overflow=True, _builder=_builder) @builtin def __sub__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.sub(self, other, _builder) + return sub(self, other, sanitize_overflow=True, _builder=_builder) @builtin def __rsub__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.sub(other, self, _builder) + return sub(other, self, sanitize_overflow=True, _builder=_builder) @builtin def __mul__(self, other, _builder=None): - other = _to_tensor(other, _builder) - return semantic.mul(self, other, _builder) + return mul(self, other, sanitize_overflow=True, _builder=_builder) @builtin def __rmul__(self, other, _builder=None): - return self.__mul__(other, _builder=_builder) + return mul(other, self, sanitize_overflow=True, _builder=_builder) @builtin def __truediv__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) return semantic.truediv(self, other, _builder) @builtin def __rtruediv__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) return semantic.truediv(other, self, _builder) @builtin def __floordiv__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) return semantic.floordiv(self, other, _builder) @builtin def __rfloordiv__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) return semantic.floordiv(other, self, _builder) @builtin def __mod__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) return semantic.mod(self, other, _builder) @builtin def __rmod__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) return semantic.mod(other, self, _builder) # unary operators @@ -821,50 +827,50 @@ def __invert__(self, _builder=None): @builtin def __and__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) return semantic.and_(self, other, _builder) @builtin def __rand__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) return semantic.and_(other, self, _builder) @builtin def __or__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) return semantic.or_(self, other, _builder) @builtin def __ror__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) return semantic.or_(other, self, _builder) @builtin def __xor__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) return semantic.xor_(self, other, _builder) @builtin def __rxor__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) return semantic.xor_(other, self, _builder) @builtin def __lshift__(self, other, _builder=None): check_bit_width(self, other) - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) return semantic.shl(self, other, _builder) @builtin def __rlshift__(self, other, _builder=None): check_bit_width(other, self) - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) return semantic.shl(other, self, _builder) @builtin def __rshift__(self, other, _builder=None): check_bit_width(self, other) - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) if self.dtype.is_int_signed(): return semantic.ashr(self, other, _builder) else: @@ -873,7 +879,7 @@ def __rshift__(self, other, _builder=None): @builtin def __rrshift__(self, other, _builder=None): check_bit_width(other, self) - other = _to_tensor(other, _builder) + other = _unwrap_if_constexpr(other) if self.dtype.is_int_signed(): return semantic.ashr(other, self, _builder) else: @@ -882,76 +888,76 @@ def __rrshift__(self, other, _builder=None): # > @builtin def __gt__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) return semantic.greater_than(self, other, _builder) @builtin def __rgt__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) return semantic.greater_than(other, self, _builder) # >= @builtin def __ge__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) return semantic.greater_equal(self, other, _builder) @builtin def __rge__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) return semantic.greater_equal(other, self, _builder) # < @builtin def __lt__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) return semantic.less_than(self, other, _builder) @builtin def __rlt__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) return semantic.less_than(other, self, _builder) # <= @builtin def __le__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) return semantic.less_equal(self, other, _builder) @builtin def __rle__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) return semantic.less_equal(other, self, _builder) # == @builtin def __eq__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) return semantic.equal(self, other, _builder) @builtin def __req__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) return semantic.equal(other, self, _builder) @builtin def __ne__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) return semantic.not_equal(self, other, _builder) @builtin def __rne__(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) return semantic.not_equal(other, self, _builder) @builtin def logical_and(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) return semantic.logical_and(self, other, _builder) @builtin def logical_or(self, other, _builder=None): - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) return semantic.logical_or(self, other, _builder) # note: __not__ isn't actually a magic method in python @@ -986,8 +992,8 @@ def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: """ # Triton doesn't like core functions calling other core functions, so we # just copy-paste the implementation of cast here. It's not too bad. - if isinstance(bitcast, constexpr): - bitcast = bitcast.value + dtype = _unwrap_if_constexpr(dtype) + bitcast = _unwrap_if_constexpr(bitcast) if bitcast: return semantic.bitcast(self, dtype, _builder) return semantic.cast(self, dtype, _builder, fp_downcast_rounding) @@ -1151,7 +1157,7 @@ def program_id(axis, _builder=None): # pid1 = program_id(1, _builder) # pid2 = program_id(2, _builder) # npg0 = num_programs(0, _builder) - # npg1 = num_programs(0, _builder) + # npg1 = num_programs(1, _builder) # return pid0 + pid1*npg0 + pid2*npg0*npg1 axis = _constexpr_to_value(axis) return semantic.program_id(axis, _builder) @@ -1176,34 +1182,33 @@ def num_programs(axis, _builder=None): @builtin def arange(start, end, _builder=None): - """ + start = _constexpr_to_value(start) + end = _constexpr_to_value(end) + return semantic.arange(start, end, _builder) + + +arange.__doc__ = f""" Returns contiguous values within the half-open interval :code:`[start, end)`. :code:`end - start` must be less than or equal to - :code:`TRITON_MAX_TENSOR_NUMEL = 131072` + :code:`TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}` :param start: Start of the interval. Must be a power of two. :type start: int32 :param end: End of the interval. Must be a power of two greater than :code:`start`. :type end: int32 - """ - start = _constexpr_to_value(start) - end = _constexpr_to_value(end) - return semantic.arange(start, end, _builder) +""" -def _shape_check_impl(shape): +def _unwrap_shape(shape): shape = _constexpr_to_value(shape) - for i, d in enumerate(shape): - if isinstance(d, int): - d = constexpr(d) - if not isinstance(d, constexpr): - raise TypeError(f"Shape element {i} must have type `constexpr`") - if not isinstance(d.value, int): - raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") - if d.value & (d.value - 1) != 0: - raise ValueError(f"Shape element {i} must be a power of 2") - return [_constexpr_to_value(x) for x in shape] + return [_constexpr_to_value(s) for s in shape] + + +def _shape_check_impl(shape): + shape = _unwrap_shape(shape) + validate_block_shape(shape) + return shape @builtin @@ -1212,10 +1217,11 @@ def full(shape, value, dtype, _builder=None): Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`. :param shape: Shape of the new array, e.g., (8, 16) or (8, ) - :value value: A scalar value to fill the array with :type shape: tuple of ints - :param dtype: Data-type of the new array, e.g., :code:`tl.float16` - :type dtype: DType + :param value: A scalar value to fill the array with + :type value: scalar + :param dtype: Data type of the new array, e.g., :code:`tl.float16` + :type dtype: tl.dtype """ shape = _shape_check_impl(shape) value = _constexpr_to_value(value) @@ -1268,8 +1274,8 @@ def trans(input: tensor, *dims, _builder=None): """ Permutes the dimensions of a tensor. - If no permutation is specified, tries to do a (1,0) permutation, i.e. tries - to transpose a 2D tensor. + If the parameter :code:`dims` is not specified, the function defaults to a (1,0) permutation, + effectively transposing a 2D tensor. :param input: The input tensor. :param dims: The desired ordering of dimensions. For example, @@ -1319,12 +1325,13 @@ def cat(input, other, can_reorder=False, _builder=None): Concatenate the given blocks :param input: The first input tensor. - :type input: + :type input: Tensor :param other: The second input tensor. - :type other: + :type other: Tensor :param reorder: Compiler hint. If true, the compiler is allowed to reorder elements while concatenating inputs. Only use if the - order does not matter (e.g., result is only used in reduction ops) + order does not matter (e.g., result is only used in reduction ops). + Current implementation of `cat` supports only can_reorder=True. """ return semantic.cat(input, other, can_reorder, _builder) @@ -1426,7 +1433,7 @@ def reshape(input, *shape, can_reorder=False, _builder=None): :type input: Block :param shape: The new shape. - :code:`shape ` can be passed as a tuple or as individual parameters: :: + :code:`shape` can be passed as a tuple or as individual parameters: :: # These are equivalent reshape(x, (32, 32)) @@ -1458,7 +1465,7 @@ def expand_dims(input, axis, _builder=None): :type axis: int | Sequence[int] """ - input = _to_tensor(input, _builder) + input = semantic.to_tensor(input, _builder) axis = _constexpr_to_value(axis) axes = list(axis) if isinstance(axis, Sequence) else [axis] new_ndim = len(input.shape) + len(axes) @@ -1480,15 +1487,18 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas Casts a tensor to the given :code:`dtype`. :param dtype: The target data type. + :type dtype: tl.dtype :param fp_downcast_rounding: The rounding mode for downcasting - floating-point values. This parameter is only used when self is a + floating-point values. This parameter is only used when self is a floating-point tensor and dtype is a floating-point type with a smaller bitwidth. Supported values are :code:`"rtne"` (round to nearest, ties to even) and :code:`"rtz"` (round towards zero). + :type fp_downcast_rounding: str, optional :param bitcast: If true, the tensor is bitcasted to the given :code:`dtype`, instead of being numerically casted. + :type bitcast: bool, optional """ - input = _to_tensor(input, _builder) + input = semantic.to_tensor(input, _builder) if isinstance(bitcast, constexpr): bitcast = bitcast.value if bitcast: @@ -1507,18 +1517,22 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i """ Returns the matrix product of two blocks. - The two blocks must be two-dimensional and have compatible inner dimensions. + The two blocks must both be two-dimensional or three-dimensional and have compatible inner dimensions. + For three-dimensional blocks, `tl.dot` performs the batched matrix product, + where the first dimension of each block represents the batch dimension. :param input: The first tensor to be multiplied. - :type input: 2D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} :param other: The second tensor to be multiplied. - :type other: 2D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param acc: The accumulator tensor. If not None, the result is added to this tensor. + :type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`} :param input_precision: How to exercise the Tensor Cores for f32 x f32. If the device does not have Tensor Cores or the inputs are not of dtype f32, - this option is ignored. For devices that do have tensor cores, the + this option is ignored. For devices that do have tensor cores, the default precision is tf32. :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Avaliable options for amd: :code:`"ieee"`. - :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32". + :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32". Only one of :code:`input_precision` and :code:`allow_tf32` can be specified (i.e. at least one must be :code:`None`). """ @@ -1534,6 +1548,29 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder) +@builtin +def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, out_dtype=float32, _builder=None): + """ + Returns the matrix product of two blocks in microscaling format. + lhs and rhs use microscaling formats described here: + https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + :param lhs: The first tensor to be multiplied. + :type lhs: 2D tensor of f8, f6 or f4 format packed in int32 format. + :param lhs_scale: Scale factor for lhs tensor. + :type lhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). + :param lhs_format: format of the lhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. + :param rhs: The second tensor to be multiplied. + :type rhs: 2D tensor of f8, f6 or f4 format packed in int32 format. + :param rhs_scale: Scale factor for rhs tensor. + :type rhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). + :param rhs_format: format of the rhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. + :param acc: The accumulator tensor. If not None, the result is added to this tensor. + """ + out_dtype = _constexpr_to_value(out_dtype) + assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment" + return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, out_dtype, _builder) + + # ----------------------- # Non-Atomic Memory Operations # ----------------------- @@ -1562,9 +1599,8 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c (3) If `pointer` is a block pointer defined by `make_block_ptr`, a tensor is loaded. In this case: - - `mask` and `other` must be None, and - - `boundary_check` and `padding_option` can be specified to control - the behavior of out-of-bound access. + - `mask` and `other` must be `None`, and + - `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access. :param pointer: Pointer to the data to be loaded :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` @@ -1575,9 +1611,11 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c :type other: Block, optional :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check :type boundary_check: tuple of ints, optional - :param padding_option: should be one of {"", "zero", "nan"}, do padding while out of bound + :param padding_option: should be one of {"", "zero", "nan"}, the padding value to use while out of bounds. "" means an undefined value. :param cache_modifier: changes cache option in NVIDIA PTX - :type cache_modifier: str, optional + :type cache_modifier: str, optional, should be one of {"", "ca", "cg"}, where "ca" stands for + cache at all levels and "cg" stands for cache at global level (cache in L2 and below, not L1), see + `cache operator `_ for more details. :param eviction_policy: changes eviction policy in NVIDIA PTX :type eviction_policy: str, optional :param volatile: changes volatile option in NVIDIA PTX @@ -1587,9 +1625,9 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c mask = _constexpr_to_value(mask) other = _constexpr_to_value(other) if mask is not None: - mask = _to_tensor(mask, _builder) + mask = semantic.to_tensor(mask, _builder) if other is not None: - other = _to_tensor(other, _builder) + other = semantic.to_tensor(other, _builder) padding_option = _constexpr_to_value(padding_option) cache_modifier = _constexpr_to_value(cache_modifier) eviction_policy = _constexpr_to_value(eviction_policy) @@ -1606,7 +1644,7 @@ def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder= This loads a tensor of data based on the descriptor and offsets. """ - type = block_type(dtype, shape) + type = block_type(_constexpr_to_value(dtype), shape) return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder) @@ -1656,15 +1694,17 @@ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", evict :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check :type boundary_check: tuple of ints, optional :param cache_modifier: changes cache option in NVIDIA PTX - :type cache_modifier: str, optional + :type cache_modifier: str, optional, should be one of {"", ".wb", ".cg", ".cs", ".wt"}, where ".wb" stands for + cache write-back all coherent levels, ".cg" stands for cache global, ".cs" stands for cache streaming, ".wt" + stands for cache write-through, see `cache operator `_ for more details. :param eviction_policy: changes eviction policy in NVIDIA PTX - :type eviction_policy: str, optional + :type eviction_policy: str, optional, should be one of {"", "evict_first", "evict_last"} """ # `value` can be constexpr - value = _to_tensor(value, _builder) + value = semantic.to_tensor(value, _builder) mask = _constexpr_to_value(mask) if mask is not None: - mask = _to_tensor(mask, _builder) + mask = semantic.to_tensor(mask, _builder) cache_modifier = _constexpr_to_value(cache_modifier) eviction_policy = _constexpr_to_value(eviction_policy) return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder) @@ -1719,12 +1759,13 @@ def _decorator(func: T) -> T: docstr += """ :param val: The values with which to perform the atomic operation :type val: Block of dtype=pointer.dtype.element_ty - :param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default), - "ACQUIRE", "RELEASE", or "RELAXED") - :type sem: str - :param scope: Scope of threads that observe synchronizing effect of the - atomic operation ("GPU" (default), "CTA", or "SYSTEM") - :type scope: str + :param sem: Specifies the memory semantics for the operation. Acceptable values are "acquire", + "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, + the function defaults to using "acq_rel" semantics. + :type sem: str, optional + :param scope: Defines the scope of threads that observe the synchronizing effect of the atomic operation. + Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + :type scope: str, optional """ func.__doc__ = docstr return func @@ -1736,8 +1777,8 @@ def _decorator(func: T) -> T: @builtin @_add_atomic_docstr("compare-and-swap", has_cmp=True) def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None): - cmp = _to_tensor(cmp, _builder) - val = _to_tensor(val, _builder) + cmp = semantic.to_tensor(cmp, _builder) + val = semantic.to_tensor(val, _builder) sem = _constexpr_to_value(sem) scope = _constexpr_to_value(scope) return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder) @@ -1747,7 +1788,7 @@ def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None): @builtin @_add_atomic_docstr("exchange") def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None): - val = _to_tensor(val, _builder) + val = semantic.to_tensor(val, _builder) sem = _constexpr_to_value(sem) scope = _constexpr_to_value(scope) mask = _constexpr_to_value(mask) @@ -1758,7 +1799,7 @@ def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None): @builtin @_add_atomic_docstr("add") def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None): - val = _to_tensor(val, _builder) + val = semantic.to_tensor(val, _builder) sem = _constexpr_to_value(sem) scope = _constexpr_to_value(scope) mask = _constexpr_to_value(mask) @@ -1769,7 +1810,7 @@ def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None): @builtin @_add_atomic_docstr("max") def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None): - val = _to_tensor(val, _builder) + val = semantic.to_tensor(val, _builder) sem = _constexpr_to_value(sem) scope = _constexpr_to_value(scope) mask = _constexpr_to_value(mask) @@ -1780,7 +1821,7 @@ def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None): @builtin @_add_atomic_docstr("min") def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None): - val = _to_tensor(val, _builder) + val = semantic.to_tensor(val, _builder) sem = _constexpr_to_value(sem) scope = _constexpr_to_value(scope) mask = _constexpr_to_value(mask) @@ -1791,7 +1832,7 @@ def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None): @builtin @_add_atomic_docstr("logical and") def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None): - val = _to_tensor(val, _builder) + val = semantic.to_tensor(val, _builder) sem = _constexpr_to_value(sem) scope = _constexpr_to_value(scope) mask = _constexpr_to_value(mask) @@ -1802,7 +1843,7 @@ def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None): @builtin @_add_atomic_docstr("logical or") def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None): - val = _to_tensor(val, _builder) + val = semantic.to_tensor(val, _builder) sem = _constexpr_to_value(sem) scope = _constexpr_to_value(scope) mask = _constexpr_to_value(mask) @@ -1813,7 +1854,7 @@ def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None): @builtin @_add_atomic_docstr("logical xor") def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None): - val = _to_tensor(val, _builder) + val = semantic.to_tensor(val, _builder) sem = _constexpr_to_value(sem) scope = _constexpr_to_value(scope) mask = _constexpr_to_value(mask) @@ -1842,9 +1883,9 @@ def where(condition, x, y, _builder=None): :param x: values selected at indices where condition is True. :param y: values selected at indices where condition is False. """ - condition = _to_tensor(condition, _builder) - x = _to_tensor(x, _builder) - y = _to_tensor(y, _builder) + condition = semantic.to_tensor(condition, _builder) + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) return semantic.where(condition, x, y, _builder) @@ -1853,6 +1894,27 @@ def where(condition, x, y, _builder=None): # ----------------------- +@builtin +def add(x, y, sanitize_overflow: constexpr = True, _builder=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return semantic.add(x, y, sanitize_overflow, _builder) + + +@builtin +def sub(x, y, sanitize_overflow: constexpr = True, _builder=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return semantic.sub(x, y, sanitize_overflow, _builder) + + +@builtin +def mul(x, y, sanitize_overflow: constexpr = True, _builder=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return semantic.mul(x, y, sanitize_overflow, _builder) + + @builtin def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): """ @@ -1867,8 +1929,8 @@ def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): .. seealso:: :class:`tl.PropagateNan` """ - x = _to_tensor(x, _builder) - y = _to_tensor(y, _builder) + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) x = _promote_bfloat16_to_float32(x, _builder=_builder) y = _promote_bfloat16_to_float32(y, _builder=_builder) propagate_nan = _constexpr_to_value(propagate_nan) @@ -1889,8 +1951,8 @@ def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): .. seealso:: :class:`tl.PropagateNan` """ - x = _to_tensor(x, _builder) - y = _to_tensor(y, _builder) + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) x = _promote_bfloat16_to_float32(x, _builder=_builder) y = _promote_bfloat16_to_float32(y, _builder=_builder) propagate_nan = _constexpr_to_value(propagate_nan) @@ -1915,9 +1977,9 @@ def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=No .. seealso:: :class:`tl.PropagateNan` """ - x = _to_tensor(x, _builder) - min = _to_tensor(min, _builder) - max = _to_tensor(max, _builder) + x = semantic.to_tensor(x, _builder) + min = semantic.to_tensor(min, _builder) + max = semantic.to_tensor(max, _builder) x = _promote_bfloat16_to_float32(x, _builder=_builder) min = _promote_bfloat16_to_float32(min, _builder=_builder) max = _promote_bfloat16_to_float32(max, _builder=_builder) @@ -1939,14 +2001,19 @@ def _decorator(func: T) -> T: Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` :param input: the input values - :param axis: the dimension along which the reduction should be done - :param keep_dims: if true, keep the reduced dimensions with length 1""" + :type input: Tensor + :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions + :type axis: int + :param keep_dims: if true, keep the reduced dimensions with length 1 + :type keep_dims: bool""" if return_indices_arg is not None: docstr += f""" - :param {return_indices_arg}: if true, return index corresponding to the {name} value""" + :param {return_indices_arg}: if true, return index corresponding to the {name} value + :type {return_indices_arg}: bool""" if tie_break_arg is not None: docstr += f""" - :param {tie_break_arg}: if true, return the left-most indices in case of ties for values that aren't NaN""" + :param {tie_break_arg}: if true, in case of a tie (i.e., multiple elements have the same {name} value), return the left-most index for values that aren't NaN + :type {tie_break_arg}: bool""" func.__doc__ = docstr.format(name=name) return func @@ -1967,9 +2034,13 @@ def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=N """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis` :param input: the input tensor, or tuple of tensors + :type input: Tensor :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions + :type axis: int | None :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :type combine_fn: Callable :param keep_dims: if true, keep the reduced dimensions with length 1 + :type keep_dims: bool """ if isinstance(input, tensor): @@ -2049,7 +2120,9 @@ def _decorator(func: T) -> T: Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` :param input: the input values - :param axis: the dimension along which the scan should be done""" + :type input: Tensor + :param axis: the dimension along which the scan should be done + :type axis: int""" func.__doc__ = docstr.format(name=name) return func @@ -2062,9 +2135,13 @@ def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _gen """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry :param input: the input tensor, or tuple of tensors + :type input: Tensor :param axis: the dimension along which the reduction should be done + :type axis: int :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) - :param reverse: apply the associative scan in the reverse direction along axis. + :type combine_fn: Callable + :param reverse: whether to apply the associative scan in the reverse direction along axis + :type reverse: bool """ if isinstance(input, tensor): @@ -2098,7 +2175,9 @@ def histogram(input, num_bins, _builder=None, _generator=None): """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0. :param input: the input tensor + :type input: Tensor :param num_bins: number of histogram bins + :type num_bins: int """ num_bins = _constexpr_to_value(num_bins) @@ -2169,6 +2248,14 @@ def max_constancy(input, values, _builder=None): return semantic.max_constancy(input, values) +@builtin +def assume(cond, _builder=None): + ''' + Allow compiler to assume the :code:`cond` is True. + ''' + return semantic.assume(semantic.to_tensor(cond, _builder), _builder) + + # ----------------------- # Debugging functions # ----------------------- @@ -2185,7 +2272,7 @@ def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=Fals .. highlight:: python .. code-block:: python - tl.static_print(f"{BLOCK_SIZE=}") + tl.static_print(f"BLOCK_SIZE={BLOCK_SIZE}") ''' pass @@ -2249,7 +2336,7 @@ def device_print(prefix, *args, hex=False, _builder=None): assert b_ascii, f"{prefix} is not an ascii string" new_args = [] for arg in args: - new_args.append(_to_tensor(arg, _builder)) + new_args.append(semantic.to_tensor(arg, _builder)) return semantic.device_print(prefix, new_args, hex, _builder) @@ -2273,25 +2360,7 @@ def device_assert(cond, msg="", _builder=None): :param msg: the message to print if the assertion fails. This is required to be a string literal. ''' msg = _constexpr_to_value(msg) - import inspect - frame = inspect.currentframe() - module = inspect.getmodule(frame) - # The triton function module doesn't have the name attribute. - # We use this trick to find the caller. - while hasattr(module, "__name__"): - frame = frame.f_back - module = inspect.getmodule(frame) - lineno = 0 - func_name = 'unknown' - file_name = 'unknown' - if frame is not None and frame.f_back is not None: - func_name = frame.f_code.co_name - file_name = frame.f_back.f_code.co_filename - # TODO: The line number currently indicates the line - # where the triton function is called but not where the - # device_assert is called. Need to enhance this. - lineno = frame.f_back.f_lineno - return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder) + return semantic.device_assert(semantic.to_tensor(cond, _builder), msg, _builder) @builtin @@ -2317,66 +2386,66 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Un cost you anything if you don't use it. Example using - [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + `PTX `_ assembly: .. highlight:: python .. code-block:: python - @triton.jit - def kernel(A, B, C, D, BLOCK: tl.constexpr): - a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor - b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor - - # For each (a,b) in zip(a,b), perform the following: - # - Let ai be `a` converted to int32. - # - Let af be `a` converted to float. - # - Let m be the max of ai and b. - # - Return ai and mi. - # Do the above 4 elements at a time. - (c, d) = tl.inline_asm_elementwise( - asm=""" - { - // Unpack `a` into `ai`. - .reg .b8 tmp<4>; - mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; - cvt.u32.u8 $0, tmp0; - cvt.u32.u8 $1, tmp1; - cvt.u32.u8 $2, tmp2; - cvt.u32.u8 $3, tmp3; - } - // Convert `ai` to float. - cvt.rn.f32.s32 $4, $0; - cvt.rn.f32.s32 $5, $1; - cvt.rn.f32.s32 $6, $2; - cvt.rn.f32.s32 $7, $3; - // Take max of `ai` and `b`. - max.f32 $4, $4, $9; - max.f32 $5, $5, $10; - max.f32 $6, $6, $11; - max.f32 $7, $7, $12; - """, - constraints=( - # 8 output registers, namely - # $0=ai0, $1=ai1, $2=ai2, $3=ai3, - # $4=m0, $5=m1, $6=m2, $7=m3. - "=r,=r,=r,=r,=r,=r,=r,=r," - # 5 input registers, namely - # $8=ai, - # $9=b0, $10=b1, $11=b2, $12=b3. - # The four elements from `a` are all packed into one register. - "r,r,r,r,r"), - args=[a, b], - dtype=(tl.int32, tl.float32), - is_pure=True, - pack=4, - ) - tl.store(C + tl.arange(0, BLOCK), c) - tl.store(D + tl.arange(0, BLOCK), d) + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor + b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) :param asm: assembly to run. Must match target's assembly format. :param constraints: asm constraints in - [LLVM format](https://llvm.org/docs/LangRef.html#inline-asm-constraint-string) + `LLVM format `_ :param args: the input tensors, whose values are passed to the asm block :param dtype: the element type(s) of the returned tensor(s) :param is_pure: if true, the compiler assumes the asm block has no side-effects @@ -2400,7 +2469,7 @@ def kernel(A, B, C, D, BLOCK: tl.constexpr): dtype = typing.cast(Sequence[_DtypeClass], dtype) res_tys = dtype - if dispatch_args := [_to_tensor(arg, _builder) for arg in args]: + if dispatch_args := [semantic.to_tensor(arg, _builder) for arg in args]: bin_op_type_checking = partial( semantic.binary_op_type_checking_impl, builder=_builder, @@ -2449,17 +2518,17 @@ def kernel(...): """ def __init__(self, arg1, arg2=None, step=None): - assert isinstance(arg1, constexpr) + assert isinstance(arg1, constexpr), f"{arg1} used as tl.static_range start value is not a constexpr" if step is None: self.step = constexpr(1) else: - assert isinstance(step, constexpr) + assert isinstance(step, constexpr), f"{step} used as tl.static_range step value is not a constexpr" self.step = step if arg2 is None: self.start = constexpr(0) self.end = arg1 else: - assert isinstance(arg2, constexpr) + assert isinstance(arg2, constexpr), f"{arg2} used as tl.static_range end value is not a constexpr" self.start = arg1 self.end = arg2 @@ -2493,9 +2562,12 @@ def kernel(...): kernel argument. The kernel argument only pipelines loads that feed into :code:`dot` operations, while this attribute tries to pipeline most (though not all) loads in this loop. + :param loop_unroll_factor: Tells the Triton IR level loop unroller how many + times to unroll a for loop that this range is used with. Less than 2 for + this value implies no unrolling. """ - def __init__(self, arg1, arg2=None, step=None, num_stages=None): + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None): if step is None: self.step = constexpr(1) else: @@ -2507,6 +2579,7 @@ def __init__(self, arg1, arg2=None, step=None, num_stages=None): self.start = arg1 self.end = arg2 self.num_stages = num_stages + self.loop_unroll_factor = loop_unroll_factor def __iter__(self): raise RuntimeError("tl.range can only be used in @triton.jit'd functions") @@ -2581,7 +2654,7 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol ret_shape = None arg_types = [] for i in builtins.range(len(dispatch_args)): - dispatch_args[i] = _to_tensor(dispatch_args[i], _builder) + dispatch_args[i] = semantic.to_tensor(dispatch_args[i], _builder) arg_types.append(dispatch_args[i].dtype) if dispatch_args[i].type.is_block(): all_scalar = False diff --git a/python/triton/language/extra/__init__.py b/python/triton/language/extra/__init__.py index 14e1778d2..3f8c70a71 100644 --- a/python/triton/language/extra/__init__.py +++ b/python/triton/language/extra/__init__.py @@ -1,4 +1,26 @@ -from . import cuda -from . import hip +import pkgutil +from importlib.util import module_from_spec +from sys import modules -__all__ = ['cuda', 'hip'] +_backends = [] +for module_finder, module_name, is_pkg in pkgutil.iter_modules( + __path__, + prefix=__name__ + ".", +): + # skip .py files (like libdevice.py) + if not is_pkg: + continue + + # import backends (like cuda and hip) that are included during setup.py + spec = module_finder.find_spec(module_name) + if spec is None or spec.loader is None: + continue + module = module_from_spec(spec) + spec.loader.exec_module(module) + + _backends.append(module_name) + modules[module_name] = module + +__all__ = _backends + +del _backends diff --git a/python/triton/language/extra/libdevice.py b/python/triton/language/extra/libdevice.py index 625cf3957..76627035d 100644 --- a/python/triton/language/extra/libdevice.py +++ b/python/triton/language/extra/libdevice.py @@ -1,1213 +1,786 @@ -from .cuda import libdevice as cuda_libdevice -from .hip import libdevice as hip_libdevice -from triton.language import core -from functools import wraps -from typing import TypeVar - -T = TypeVar('T') - - -def dispatch(fn: T) -> T: - """Dispatch a function to a correct implementation.""" - assert callable(fn) - - @wraps(fn) - def wrapper(*args, **kwargs): - _backend = kwargs["_builder"].options.backend_name - if _backend == 'cuda': - _curr_libdevice_module = cuda_libdevice - elif _backend == 'hip': - _curr_libdevice_module = hip_libdevice - else: - raise RuntimeError('unknown backend') - - try: - _impl = getattr(_curr_libdevice_module, fn.__name__) - except AttributeError: - raise RuntimeError(f'`{_backend}` does not provide support for `{fn.__name__}` extra function') - - return _impl(*args, **kwargs) - - return wrapper - - -@core.extern -@dispatch -def clz(arg0, _builder=None): +def clz(arg0): ... -@core.extern -@dispatch -def popc(arg0, _builder=None): +def popc(arg0): ... -@core.extern -@dispatch -def byte_perm(arg0, arg1, arg2, _builder=None): +def byte_perm(arg0, arg1, arg2): ... -@core.extern -@dispatch -def mulhi(arg0, arg1, _builder=None): +def mulhi(arg0, arg1): ... -@core.extern -@dispatch -def mul24(arg0, arg1, _builder=None): +def mul24(arg0, arg1): ... -@core.extern -@dispatch -def brev(arg0, _builder=None): +def brev(arg0): ... -@core.extern -@dispatch -def sad(arg0, arg1, arg2, _builder=None): +def sad(arg0, arg1, arg2): ... -@core.extern -@dispatch -def abs(arg0, _builder=None): +def abs(arg0): ... -@core.extern -@dispatch -def floor(arg0, _builder=None): +def floor(arg0): ... -@core.extern -@dispatch -def rcp64h(arg0, _builder=None): +def rcp64h(arg0): ... -@core.extern -@dispatch -def rsqrt(arg0, _builder=None): +def rsqrt(arg0): ... -@core.extern -@dispatch -def ceil(arg0, _builder=None): +def ceil(arg0): ... -@core.extern -@dispatch -def trunc(arg0, _builder=None): +def trunc(arg0): ... -@core.extern -@dispatch -def exp2(arg0, _builder=None): +def exp2(arg0): ... -@core.extern -@dispatch -def saturatef(arg0, _builder=None): +def saturatef(arg0): ... -@core.extern -@dispatch -def fma_rn(arg0, arg1, arg2, _builder=None): +def fma_rn(arg0, arg1, arg2): ... -@core.extern -@dispatch -def fma_rz(arg0, arg1, arg2, _builder=None): +def fma_rz(arg0, arg1, arg2): ... -@core.extern -@dispatch -def fma_rd(arg0, arg1, arg2, _builder=None): +def fma_rd(arg0, arg1, arg2): ... -@core.extern -@dispatch -def fma_ru(arg0, arg1, arg2, _builder=None): +def fma_ru(arg0, arg1, arg2): ... -@core.extern -@dispatch -def fast_dividef(arg0, arg1, _builder=None): +def fast_dividef(arg0, arg1): ... -@core.extern -@dispatch -def div_rn(arg0, arg1, _builder=None): +def div_rn(arg0, arg1): ... -@core.extern -@dispatch -def div_rz(arg0, arg1, _builder=None): +def div_rz(arg0, arg1): ... -@core.extern -@dispatch -def div_rd(arg0, arg1, _builder=None): +def div_rd(arg0, arg1): ... -@core.extern -@dispatch -def div_ru(arg0, arg1, _builder=None): +def div_ru(arg0, arg1): ... -@core.extern -@dispatch -def rcp_rn(arg0, _builder=None): +def rcp_rn(arg0): ... -@core.extern -@dispatch -def rcp_rz(arg0, _builder=None): +def rcp_rz(arg0): ... -@core.extern -@dispatch -def rcp_rd(arg0, _builder=None): +def rcp_rd(arg0): ... -@core.extern -@dispatch -def rcp_ru(arg0, _builder=None): +def rcp_ru(arg0): ... -@core.extern -@dispatch -def sqrt_rn(arg0, _builder=None): +def sqrt_rn(arg0): ... -@core.extern -@dispatch -def sqrt_rz(arg0, _builder=None): +def sqrt_rz(arg0): ... -@core.extern -@dispatch -def sqrt_rd(arg0, _builder=None): +def sqrt_rd(arg0): ... -@core.extern -@dispatch -def sqrt_ru(arg0, _builder=None): +def sqrt_ru(arg0): ... -@core.extern -@dispatch -def sqrt(arg0, _builder=None): +def sqrt(arg0): ... -@core.extern -@dispatch -def add_rn(arg0, arg1, _builder=None): +def add_rn(arg0, arg1): ... -@core.extern -@dispatch -def add_rz(arg0, arg1, _builder=None): +def add_rz(arg0, arg1): ... -@core.extern -@dispatch -def add_rd(arg0, arg1, _builder=None): +def add_rd(arg0, arg1): ... -@core.extern -@dispatch -def add_ru(arg0, arg1, _builder=None): +def add_ru(arg0, arg1): ... -@core.extern -@dispatch -def mul_rn(arg0, arg1, _builder=None): +def mul_rn(arg0, arg1): ... -@core.extern -@dispatch -def mul_rz(arg0, arg1, _builder=None): +def mul_rz(arg0, arg1): ... -@core.extern -@dispatch -def mul_rd(arg0, arg1, _builder=None): +def mul_rd(arg0, arg1): ... -@core.extern -@dispatch -def mul_ru(arg0, arg1, _builder=None): +def mul_ru(arg0, arg1): ... -@core.extern -@dispatch -def double2float_rn(arg0, _builder=None): +def double2float_rn(arg0): ... -@core.extern -@dispatch -def double2float_rz(arg0, _builder=None): +def double2float_rz(arg0): ... -@core.extern -@dispatch -def double2float_rd(arg0, _builder=None): +def double2float_rd(arg0): ... -@core.extern -@dispatch -def double2float_ru(arg0, _builder=None): +def double2float_ru(arg0): ... -@core.extern -@dispatch -def double2int_rn(arg0, _builder=None): +def double2int_rn(arg0): ... -@core.extern -@dispatch -def double2int_rz(arg0, _builder=None): +def double2int_rz(arg0): ... -@core.extern -@dispatch -def double2int_rd(arg0, _builder=None): +def double2int_rd(arg0): ... -@core.extern -@dispatch -def double2int_ru(arg0, _builder=None): +def double2int_ru(arg0): ... -@core.extern -@dispatch -def double2uint_rn(arg0, _builder=None): +def double2uint_rn(arg0): ... -@core.extern -@dispatch -def double2uint_rz(arg0, _builder=None): +def double2uint_rz(arg0): ... -@core.extern -@dispatch -def double2uint_rd(arg0, _builder=None): +def double2uint_rd(arg0): ... -@core.extern -@dispatch -def double2uint_ru(arg0, _builder=None): +def double2uint_ru(arg0): ... -@core.extern -@dispatch -def int2double_rn(arg0, _builder=None): +def int2double_rn(arg0): ... -@core.extern -@dispatch -def uint2double_rn(arg0, _builder=None): +def uint2double_rn(arg0): ... -@core.extern -@dispatch -def float2int_rn(arg0, _builder=None): +def float2int_rn(arg0): ... -@core.extern -@dispatch -def float2int_rz(arg0, _builder=None): +def float2int_rz(arg0): ... -@core.extern -@dispatch -def float2int_rd(arg0, _builder=None): +def float2int_rd(arg0): ... -@core.extern -@dispatch -def float2int_ru(arg0, _builder=None): +def float2int_ru(arg0): ... -@core.extern -@dispatch -def float2uint_rn(arg0, _builder=None): +def float2uint_rn(arg0): ... -@core.extern -@dispatch -def float2uint_rz(arg0, _builder=None): +def float2uint_rz(arg0): ... -@core.extern -@dispatch -def float2uint_rd(arg0, _builder=None): +def float2uint_rd(arg0): ... -@core.extern -@dispatch -def float2uint_ru(arg0, _builder=None): +def float2uint_ru(arg0): ... -@core.extern -@dispatch -def int2float_rn(arg0, _builder=None): +def int2float_rn(arg0): ... -@core.extern -@dispatch -def int2float_rz(arg0, _builder=None): +def int2float_rz(arg0): ... -@core.extern -@dispatch -def int2float_rd(arg0, _builder=None): +def int2float_rd(arg0): ... -@core.extern -@dispatch -def int2float_ru(arg0, _builder=None): +def int2float_ru(arg0): ... -@core.extern -@dispatch -def uint2float_rn(arg0, _builder=None): +def uint2float_rn(arg0): ... -@core.extern -@dispatch -def uint2float_rz(arg0, _builder=None): +def uint2float_rz(arg0): ... -@core.extern -@dispatch -def uint2float_rd(arg0, _builder=None): +def uint2float_rd(arg0): ... -@core.extern -@dispatch -def uint2float_ru(arg0, _builder=None): +def uint2float_ru(arg0): ... -@core.extern -@dispatch -def hiloint2double(arg0, arg1, _builder=None): +def hiloint2double(arg0, arg1): ... -@core.extern -@dispatch -def double2loint(arg0, _builder=None): +def double2loint(arg0): ... -@core.extern -@dispatch -def double2hiint(arg0, _builder=None): +def double2hiint(arg0): ... -@core.extern -@dispatch -def float2ll_rn(arg0, _builder=None): +def float2ll_rn(arg0): ... -@core.extern -@dispatch -def float2ll_rz(arg0, _builder=None): +def float2ll_rz(arg0): ... -@core.extern -@dispatch -def float2ll_rd(arg0, _builder=None): +def float2ll_rd(arg0): ... -@core.extern -@dispatch -def float2ll_ru(arg0, _builder=None): +def float2ll_ru(arg0): ... -@core.extern -@dispatch -def float2ull_rn(arg0, _builder=None): +def float2ull_rn(arg0): ... -@core.extern -@dispatch -def float2ull_rz(arg0, _builder=None): +def float2ull_rz(arg0): ... -@core.extern -@dispatch -def float2ull_rd(arg0, _builder=None): +def float2ull_rd(arg0): ... -@core.extern -@dispatch -def float2ull_ru(arg0, _builder=None): +def float2ull_ru(arg0): ... -@core.extern -@dispatch -def double2ll_rn(arg0, _builder=None): +def double2ll_rn(arg0): ... -@core.extern -@dispatch -def double2ll_rz(arg0, _builder=None): +def double2ll_rz(arg0): ... -@core.extern -@dispatch -def double2ll_rd(arg0, _builder=None): +def double2ll_rd(arg0): ... -@core.extern -@dispatch -def double2ll_ru(arg0, _builder=None): +def double2ll_ru(arg0): ... -@core.extern -@dispatch -def double2ull_rn(arg0, _builder=None): +def double2ull_rn(arg0): ... -@core.extern -@dispatch -def double2ull_rz(arg0, _builder=None): +def double2ull_rz(arg0): ... -@core.extern -@dispatch -def double2ull_rd(arg0, _builder=None): +def double2ull_rd(arg0): ... -@core.extern -@dispatch -def double2ull_ru(arg0, _builder=None): +def double2ull_ru(arg0): ... -@core.extern -@dispatch -def ll2float_rn(arg0, _builder=None): +def ll2float_rn(arg0): ... -@core.extern -@dispatch -def ll2float_rz(arg0, _builder=None): +def ll2float_rz(arg0): ... -@core.extern -@dispatch -def ll2float_rd(arg0, _builder=None): +def ll2float_rd(arg0): ... -@core.extern -@dispatch -def ll2float_ru(arg0, _builder=None): +def ll2float_ru(arg0): ... -@core.extern -@dispatch -def ull2float_rn(arg0, _builder=None): +def ull2float_rn(arg0): ... -@core.extern -@dispatch -def ull2float_rz(arg0, _builder=None): +def ull2float_rz(arg0): ... -@core.extern -@dispatch -def ull2float_rd(arg0, _builder=None): +def ull2float_rd(arg0): ... -@core.extern -@dispatch -def ull2float_ru(arg0, _builder=None): +def ull2float_ru(arg0): ... -@core.extern -@dispatch -def ll2double_rn(arg0, _builder=None): +def ll2double_rn(arg0): ... -@core.extern -@dispatch -def ll2double_rz(arg0, _builder=None): +def ll2double_rz(arg0): ... -@core.extern -@dispatch -def ll2double_rd(arg0, _builder=None): +def ll2double_rd(arg0): ... -@core.extern -@dispatch -def ll2double_ru(arg0, _builder=None): +def ll2double_ru(arg0): ... -@core.extern -@dispatch -def ull2double_rn(arg0, _builder=None): +def ull2double_rn(arg0): ... -@core.extern -@dispatch -def ull2double_rz(arg0, _builder=None): +def ull2double_rz(arg0): ... -@core.extern -@dispatch -def ull2double_rd(arg0, _builder=None): +def ull2double_rd(arg0): ... -@core.extern -@dispatch -def ull2double_ru(arg0, _builder=None): +def ull2double_ru(arg0): ... -@core.extern -@dispatch -def int_as_float(arg0, _builder=None): +def int_as_float(arg0): ... -@core.extern -@dispatch -def float_as_int(arg0, _builder=None): +def float_as_int(arg0): ... -@core.extern -@dispatch -def uint_as_float(arg0, _builder=None): +def uint_as_float(arg0): ... -@core.extern -@dispatch -def float_as_uint(arg0, _builder=None): +def float_as_uint(arg0): ... -@core.extern -@dispatch -def longlong_as_double(arg0, _builder=None): +def longlong_as_double(arg0): ... -@core.extern -@dispatch -def double_as_longlong(arg0, _builder=None): +def double_as_longlong(arg0): ... -@core.extern -@dispatch -def fast_sinf(arg0, _builder=None): +def fast_sinf(arg0): ... -@core.extern -@dispatch -def fast_cosf(arg0, _builder=None): +def fast_cosf(arg0): ... -@core.extern -@dispatch -def fast_log2f(arg0, _builder=None): +def fast_log2f(arg0): ... -@core.extern -@dispatch -def fast_logf(arg0, _builder=None): +def fast_logf(arg0): ... -@core.extern -@dispatch -def fast_expf(arg0, _builder=None): +def fast_expf(arg0): ... -@core.extern -@dispatch -def fast_tanf(arg0, _builder=None): +def fast_tanf(arg0): ... -@core.extern -@dispatch -def fast_exp10f(arg0, _builder=None): +def fast_exp10f(arg0): ... -@core.extern -@dispatch -def fast_log10f(arg0, _builder=None): +def fast_log10f(arg0): ... -@core.extern -@dispatch -def fast_powf(arg0, arg1, _builder=None): +def fast_powf(arg0, arg1): ... -@core.extern -@dispatch -def hadd(arg0, arg1, _builder=None): +def hadd(arg0, arg1): ... -@core.extern -@dispatch -def rhadd(arg0, arg1, _builder=None): +def rhadd(arg0, arg1): ... -@core.extern -@dispatch -def sub_rn(arg0, arg1, _builder=None): +def sub_rn(arg0, arg1): ... -@core.extern -@dispatch -def sub_rz(arg0, arg1, _builder=None): +def sub_rz(arg0, arg1): ... -@core.extern -@dispatch -def sub_rd(arg0, arg1, _builder=None): +def sub_rd(arg0, arg1): ... -@core.extern -@dispatch -def sub_ru(arg0, arg1, _builder=None): +def sub_ru(arg0, arg1): ... -@core.extern -@dispatch -def rsqrt_rn(arg0, _builder=None): +def rsqrt_rn(arg0): ... -@core.extern -@dispatch -def ffs(arg0, _builder=None): +def ffs(arg0): ... -@core.extern -@dispatch -def rint(arg0, _builder=None): +def rint(arg0): ... -@core.extern -@dispatch -def llrint(arg0, _builder=None): +def llrint(arg0): ... -@core.extern -@dispatch -def nearbyint(arg0, _builder=None): +def nearbyint(arg0): ... -@core.extern -@dispatch -def isnan(arg0, _builder=None): +def isnan(arg0): ... -@core.extern -@dispatch -def signbit(arg0, _builder=None): +def signbit(arg0): ... -@core.extern -@dispatch -def copysign(arg0, arg1, _builder=None): +def copysign(arg0, arg1): ... -@core.extern -@dispatch -def finitef(arg0, _builder=None): +def finitef(arg0): ... -@core.extern -@dispatch -def isinf(arg0, _builder=None): +def isinf(arg0): ... -@core.extern -@dispatch -def nextafter(arg0, arg1, _builder=None): +def nextafter(arg0, arg1): ... -@core.extern -@dispatch -def sin(arg0, _builder=None): +def sin(arg0): ... -@core.extern -@dispatch -def cos(arg0, _builder=None): +def cos(arg0): ... -@core.extern -@dispatch -def sinpi(arg0, _builder=None): +def sinpi(arg0): ... -@core.extern -@dispatch -def cospi(arg0, _builder=None): +def cospi(arg0): ... -@core.extern -@dispatch -def tan(arg0, _builder=None): +def tan(arg0): ... -@core.extern -@dispatch -def log2(arg0, _builder=None): +def log2(arg0): ... -@core.extern -@dispatch -def exp(arg0, _builder=None): +def exp(arg0): ... -@core.extern -@dispatch -def exp10(arg0, _builder=None): +def exp10(arg0): ... -@core.extern -@dispatch -def cosh(arg0, _builder=None): +def cosh(arg0): ... -@core.extern -@dispatch -def sinh(arg0, _builder=None): +def sinh(arg0): ... -@core.extern -@dispatch -def tanh(arg0, _builder=None): +def tanh(arg0): ... -@core.extern -@dispatch -def atan2(arg0, arg1, _builder=None): +def atan2(arg0, arg1): ... -@core.extern -@dispatch -def atan(arg0, _builder=None): +def atan(arg0): ... -@core.extern -@dispatch -def asin(arg0, _builder=None): +def asin(arg0): ... -@core.extern -@dispatch -def acos(arg0, _builder=None): +def acos(arg0): ... -@core.extern -@dispatch -def log(arg0, _builder=None): +def log(arg0): ... -@core.extern -@dispatch -def log10(arg0, _builder=None): +def log10(arg0): ... -@core.extern -@dispatch -def log1p(arg0, _builder=None): +def log1p(arg0): ... -@core.extern -@dispatch -def acosh(arg0, _builder=None): +def acosh(arg0): ... -@core.extern -@dispatch -def asinh(arg0, _builder=None): +def asinh(arg0): ... -@core.extern -@dispatch -def atanh(arg0, _builder=None): +def atanh(arg0): ... -@core.extern -@dispatch -def expm1(arg0, _builder=None): +def expm1(arg0): ... -@core.extern -@dispatch -def hypot(arg0, arg1, _builder=None): +def hypot(arg0, arg1): ... -@core.extern -@dispatch -def rhypot(arg0, arg1, _builder=None): +def rhypot(arg0, arg1): ... -@core.extern -@dispatch -def norm3d(arg0, arg1, arg2, _builder=None): +def norm3d(arg0, arg1, arg2): ... -@core.extern -@dispatch -def rnorm3d(arg0, arg1, arg2, _builder=None): +def rnorm3d(arg0, arg1, arg2): ... -@core.extern -@dispatch -def norm4d(arg0, arg1, arg2, arg3, _builder=None): +def norm4d(arg0, arg1, arg2, arg3): ... -@core.extern -@dispatch -def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): +def rnorm4d(arg0, arg1, arg2, arg3): ... -@core.extern -@dispatch -def cbrt(arg0, _builder=None): +def cbrt(arg0): ... -@core.extern -@dispatch -def rcbrt(arg0, _builder=None): +def rcbrt(arg0): ... -@core.extern -@dispatch -def j0(arg0, _builder=None): +def j0(arg0): ... -@core.extern -@dispatch -def j1(arg0, _builder=None): +def j1(arg0): ... -@core.extern -@dispatch -def y0(arg0, _builder=None): +def y0(arg0): ... -@core.extern -@dispatch -def y1(arg0, _builder=None): +def y1(arg0): ... -@core.extern -@dispatch -def yn(arg0, arg1, _builder=None): +def yn(arg0, arg1): ... -@core.extern -@dispatch -def jn(arg0, arg1, _builder=None): +def jn(arg0, arg1): ... -@core.extern -@dispatch -def cyl_bessel_i0(arg0, _builder=None): +def cyl_bessel_i0(arg0): ... -@core.extern -@dispatch -def cyl_bessel_i1(arg0, _builder=None): +def cyl_bessel_i1(arg0): ... -@core.extern -@dispatch -def erf(arg0, _builder=None): +def erf(arg0): ... -@core.extern -@dispatch -def erfinv(arg0, _builder=None): +def erfinv(arg0): ... -@core.extern -@dispatch -def erfc(arg0, _builder=None): +def erfc(arg0): ... -@core.extern -@dispatch -def erfcx(arg0, _builder=None): +def erfcx(arg0): ... -@core.extern -@dispatch -def erfcinv(arg0, _builder=None): +def erfcinv(arg0): ... -@core.extern -@dispatch -def normcdfinv(arg0, _builder=None): +def normcdfinv(arg0): ... -@core.extern -@dispatch -def normcdf(arg0, _builder=None): +def normcdf(arg0): ... -@core.extern -@dispatch -def lgamma(arg0, _builder=None): +def lgamma(arg0): ... -@core.extern -@dispatch -def ldexp(arg0, arg1, _builder=None): +def ldexp(arg0, arg1): ... -@core.extern -@dispatch -def scalbn(arg0, arg1, _builder=None): +def scalbn(arg0, arg1): ... -@core.extern -@dispatch -def fmod(arg0, arg1, _builder=None): +def fmod(arg0, arg1): ... -@core.extern -@dispatch -def remainder(arg0, arg1, _builder=None): +def remainder(arg0, arg1): ... -@core.extern -@dispatch -def fma(arg0, arg1, arg2, _builder=None): +def fma(arg0, arg1, arg2): ... -@core.extern -@dispatch -def pow(arg0, arg1, _builder=None): +def pow(arg0, arg1): ... -@core.extern -@dispatch -def tgamma(arg0, _builder=None): +def tgamma(arg0): ... -@core.extern -@dispatch -def round(arg0, _builder=None): +def round(arg0): ... -@core.extern -@dispatch -def llround(arg0, _builder=None): +def llround(arg0): ... -@core.extern -@dispatch -def fdim(arg0, arg1, _builder=None): +def fdim(arg0, arg1): ... -@core.extern -@dispatch -def ilogb(arg0, _builder=None): +def ilogb(arg0): ... -@core.extern -@dispatch -def logb(arg0, _builder=None): +def logb(arg0): ... -@core.extern -@dispatch -def isfinited(arg0, _builder=None): +def isfinited(arg0): ... diff --git a/python/triton/language/extra/hip/__init__.py b/python/triton/language/extra/xpu/__init__.py similarity index 100% rename from python/triton/language/extra/hip/__init__.py rename to python/triton/language/extra/xpu/__init__.py diff --git a/python/triton/language/extra/xpu/libdevice.py b/python/triton/language/extra/xpu/libdevice.py new file mode 100644 index 000000000..66006339c --- /dev/null +++ b/python/triton/language/extra/xpu/libdevice.py @@ -0,0 +1,1650 @@ +from triton.language import core + + +@core.extern +def clz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("int32")), + (core.dtype("int64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def popc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("int32")), + (core.dtype("int64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def byte_perm(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("int32")): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mulhi(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("Unsupported", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("_ZN3xpu6umulhiEjj", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64")): ("Unsupported", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64")): ("Unsupported", core.dtype("uint64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul24(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("Unsupported", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("Unsupported", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def brev(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("int32")), + (core.dtype("int64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sad(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("uint32")): ("Unsupported", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32")): ("Unsupported", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("int32")), + (core.dtype("int64"), ): ("Unsupported", core.dtype("int64")), + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu9xpu_floorEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("_ZN3xpu9xpu_floorEd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp64h(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("_ZN3xpu6hrsqrtEDF16_", core.dtype("fp16")), + (core.dtype("fp32"), ): ("_ZN3xpu6rsqrtfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("_ZN3xpu6rsqrtfEf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("_ZN3xpu8xpu_ceilEd", core.dtype("fp64")), + (core.dtype("fp32"), ): ("_ZN3xpu8xpu_ceilEf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), ): ("_ZN3xpu6truncfEf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu5exp2fEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def saturatef(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rn(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rz(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rd(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_ru(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_dividef(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("_ZN3xpu9__fdiv_rnEff", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("_ZN3xpu9__fdiv_rzEff", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rn(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rd(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_ru(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rn(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu10__fsqrt_rnEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("_ZN3xpu10__dsqrt_rnEd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rd(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_ru(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu8xpu_sqrtEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("_ZN3xpu8xpu_sqrtEd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("Unsupported", core.dtype("fp64")), + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hiloint2double(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2loint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2hiint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_int(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_uint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def longlong_as_double(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double_as_longlong(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_sinf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_cosf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log2f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_logf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_expf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_tanf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_exp10f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log10f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_powf(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("Unsupported", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("Unsupported", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("Unsupported", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("Unsupported", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ffs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("int32")), + (core.dtype("int64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("_ZN3xpu4rintEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nearbyint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("_ZN3xpu9nearbyintEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp16"), ): ("_ZN3xpu6hisnanEDF16_", core.dtype("int32")), + (core.dtype("fp32"), ): ("_ZN3xpu5isnanEf", core.dtype("int32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("_ZN3xpu10__signbitfEf", core.dtype("int32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def finitef(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("_ZN3xpu7hfiniteEDF16_", core.dtype("int16")), + (core.dtype("fp32"), ): ("_ZN3xpu7finitefEf", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("_ZN3xpu4hsinEDF16_", core.dtype("int32")), + (core.dtype("fp32"), ): ("_ZN3xpu5isinfEf", core.dtype("int32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nextafter(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu4sinfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu4cosfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinpi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cospi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu4tanfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("_ZN3xpu5hlog2EDF16_", core.dtype("fp16")), + (core.dtype("fp32"), ): ("_ZN3xpu5log2fEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu5coshfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu5sinhfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("_ZN3xpu5htanhEDF16_", core.dtype("fp16")), + (core.dtype("fp32"), ): ("_ZN3xpu5tanhfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan2(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("_ZN3xpu6atan2fEff", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu5atanfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu5asinfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu5acosfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu6log10fEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu6log1pfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu6acoshfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu6asinhfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu6atanhfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu6expm1fEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def yn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def jn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu3erfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu6erfinvEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu4erfcEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcx(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def scalbn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("_ZN3xpu5fmodfEff", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def remainder(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("_ZN3xpu3fmaEfff", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp16"), core.dtype("fp16")): ("_ZN3xpu4hpowEDF16_DF16_", core.dtype("fp16")), + (core.dtype("fp32"), core.dtype("fp32")): ("_ZN3xpu3powEff", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def round(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llround(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fdim(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def logb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isfinited(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def xpu_trunc_div(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("_ZN3xpu9xpu_truncEff", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) diff --git a/python/triton/language/math.py b/python/triton/language/math.py index de5b5be6b..60ba170c7 100644 --- a/python/triton/language/math.py +++ b/python/triton/language/math.py @@ -86,8 +86,8 @@ def _decorator(func: T) -> T: @_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"]) @_add_math_2arg_docstr("most significant N bits of the 2N-bit product") def umulhi(x, y, _builder=None): - x = core._to_tensor(x, _builder) - y = core._to_tensor(y, _builder) + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) x, y = core.binary_op_type_legalization(x, y, _builder) return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type) @@ -97,7 +97,7 @@ def umulhi(x, y, _builder=None): @_add_math_1arg_docstr("exponential") @core._tensor_member_fn def exp(x, _builder=None): - x = core._to_tensor(x, _builder) + x = semantic.to_tensor(x, _builder) return core.tensor(_builder.create_exp(x.handle), x.type) @@ -106,7 +106,7 @@ def exp(x, _builder=None): @_add_math_1arg_docstr("exponential (base 2)") @core._tensor_member_fn def exp2(x, _builder=None): - x = core._to_tensor(x, _builder) + x = semantic.to_tensor(x, _builder) return core.tensor(_builder.create_exp2(x.handle), x.type) @@ -115,7 +115,7 @@ def exp2(x, _builder=None): @_add_math_1arg_docstr("natural logarithm") @core._tensor_member_fn def log(x, _builder=None): - x = core._to_tensor(x, _builder) + x = semantic.to_tensor(x, _builder) return core.tensor(_builder.create_log(x.handle), x.type) @@ -124,7 +124,7 @@ def log(x, _builder=None): @_add_math_1arg_docstr("logarithm (base 2)") @core._tensor_member_fn def log2(x, _builder=None): - x = core._to_tensor(x, _builder) + x = semantic.to_tensor(x, _builder) return core.tensor(_builder.create_log2(x.handle), x.type) @@ -133,7 +133,7 @@ def log2(x, _builder=None): @_add_math_1arg_docstr("cosine") @core._tensor_member_fn def cos(x, _builder=None): - x = core._to_tensor(x, _builder) + x = semantic.to_tensor(x, _builder) return core.tensor(_builder.create_cos(x.handle), x.type) @@ -142,7 +142,7 @@ def cos(x, _builder=None): @_add_math_1arg_docstr("sine") @core._tensor_member_fn def sin(x, _builder=None): - x = core._to_tensor(x, _builder) + x = semantic.to_tensor(x, _builder) return core.tensor(_builder.create_sin(x.handle), x.type) @@ -151,16 +151,16 @@ def sin(x, _builder=None): @_add_math_1arg_docstr("fast square root") @core._tensor_member_fn def sqrt(x, _builder=None): - x = core._to_tensor(x, _builder) + x = semantic.to_tensor(x, _builder) return core.tensor(_builder.create_sqrt(x.handle), x.type) @core.builtin @_check_dtype(dtypes=["fp32"]) -@_add_math_1arg_docstr("precise square root (rounding to nearest)") +@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)") @core._tensor_member_fn def sqrt_rn(x, _builder=None): - x = core._to_tensor(x, _builder) + x = semantic.to_tensor(x, _builder) return core.tensor(_builder.create_precise_sqrt(x.handle), x.type) @@ -169,7 +169,7 @@ def sqrt_rn(x, _builder=None): @_add_math_1arg_docstr("inverse square root") @core._tensor_member_fn def rsqrt(x, _builder=None): - x = core._to_tensor(x, _builder) + x = semantic.to_tensor(x, _builder) return core.tensor(_builder.create_rsqrt(x.handle), x.type) @@ -177,7 +177,7 @@ def rsqrt(x, _builder=None): @_add_math_1arg_docstr("absolute value") @core._tensor_member_fn def abs(x, _builder=None): - x = core._to_tensor(x, _builder) + x = semantic.to_tensor(x, _builder) dtype = x.dtype if dtype.is_fp8e4b15(): mask = core.full(x.shape, 0x7F, core.int8, _builder=_builder) @@ -196,17 +196,17 @@ def abs(x, _builder=None): @_add_math_2arg_docstr("fast division") def fdiv(x, y, ieee_rounding=False, _builder=None): ieee_rounding = core._constexpr_to_value(ieee_rounding) - x = core._to_tensor(x, _builder) - y = core._to_tensor(y, _builder) + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) return semantic.fdiv(x, y, ieee_rounding, _builder) @core.builtin @_check_dtype(dtypes=["fp32"]) -@_add_math_2arg_docstr("precise division (rounding to nearest)") +@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)") def div_rn(x, y, _builder=None): - x = core._to_tensor(x, _builder) - y = core._to_tensor(y, _builder) + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) x, y = core.binary_op_type_legalization(x, y, _builder) return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type) @@ -216,7 +216,7 @@ def div_rn(x, y, _builder=None): @_add_math_1arg_docstr("error function") @core._tensor_member_fn def erf(x, _builder=None): - x = core._to_tensor(x, _builder) + x = semantic.to_tensor(x, _builder) return core.tensor(_builder.create_erf(x.handle), x.type) @@ -225,7 +225,7 @@ def erf(x, _builder=None): @_add_math_1arg_docstr("floor") @core._tensor_member_fn def floor(x, _builder=None): - x = core._to_tensor(x, _builder) + x = semantic.to_tensor(x, _builder) return core.tensor(_builder.create_floor(x.handle), x.type) @@ -234,16 +234,16 @@ def floor(x, _builder=None): @_add_math_1arg_docstr("ceil") @core._tensor_member_fn def ceil(x, _builder=None): - x = core._to_tensor(x, _builder) + x = semantic.to_tensor(x, _builder) return core.tensor(_builder.create_ceil(x.handle), x.type) @core.builtin @_add_math_3arg_docstr("fused multiply-add") def fma(x, y, z, _builder=None): - x = core._to_tensor(x, _builder) - y = core._to_tensor(y, _builder) - z = core._to_tensor(z, _builder) + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + z = semantic.to_tensor(z, _builder) x, y = core.binary_op_type_legalization(x, y, _builder) z, x = core.binary_op_type_legalization(z, x, _builder) z, y = core.binary_op_type_legalization(z, y, _builder) diff --git a/python/triton/language/random.py b/python/triton/language/random.py index 430aeb09e..1c001695e 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -34,11 +34,11 @@ def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAUL _c0, _c2 = c0, c2 c0 = math.umulhi(B, _c2) ^ c1 ^ k0 c2 = math.umulhi(A, _c0) ^ c3 ^ k1 - c1 = B * _c2 - c3 = A * _c0 + c1 = tl.mul(B, _c2, sanitize_overflow=False) + c3 = tl.mul(A, _c0, sanitize_overflow=False) # raise key - k0 = k0 + PHILOX_KEY_A - k1 = k1 + PHILOX_KEY_B + k0 = tl.add(k0, PHILOX_KEY_A, sanitize_overflow=False) + k1 = tl.add(k1, PHILOX_KEY_B, sanitize_overflow=False) return c0, c1, c2, c3 diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 83d4dfc8c..8e9f87b5e 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1,6 +1,8 @@ from __future__ import annotations # remove after python 3.11 +import warnings from typing import List, Optional, Sequence, Tuple, TypeVar +import numbers from .._C.libtriton import ir from . import core as tl @@ -56,7 +58,19 @@ def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: raise TypeError(f"unexpected signedness {a_sn} and {b_sn}") -def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype: +def computation_type_impl(a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_is_scalar: bool, + div_or_mod: bool) -> tl.dtype: + # 0) For scalars we follow semantics similar to PyTorch, namely: + # - If the scalar is of a lower or equal kind (bool < uint < int < fp), + # it doesn't participate in the pomotion + if a_is_scalar != b_is_scalar: + scalar_ty, tensor_ty = (a_ty, b_ty) if a_is_scalar else (b_ty, a_ty) + if scalar_ty.kind().value <= tensor_ty.kind().value: + # Upcast because of 3) and 4) below! + if div_or_mod and (tensor_ty in (tl.float16, tl.bfloat16)): + return tl.float32 + return tensor_ty + # 1) if one operand is double, the other is implicitly # converted to double if a_ty.is_fp64() or b_ty.is_fp64(): @@ -80,9 +94,12 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t if a_ty.is_bf16() and b_ty.is_bf16(): return tl.bfloat16 return tl.float32 + # 5) return fp16 if operands are different fp8 + if a_ty.is_fp8() and b_ty.is_fp8(): + return a_ty if a_ty == b_ty else tl.float16 if not a_ty.is_int() or not b_ty.is_int(): raise TypeError(f"unexpected type {a_ty} and {b_ty}") - # 5 ) both operands are integer and undergo + # 6 ) both operands are integer and undergo # integer promotion if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + @@ -91,6 +108,44 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t return integer_promote_impl(a_ty, b_ty) +def to_tensor(x, builder, check_type: bool = True): + if isinstance(x, bool): + return tl.tensor(builder.get_int1(x), tl.int1) + # Note: compile-time const integers are represented by unsigned values + elif isinstance(x, int): + if -2**31 <= x < 2**31: + dtype = tl.int32 + elif 2**31 <= x < 2**32: + dtype = tl.uint32 + elif -2**63 <= x < 2**63: + dtype = tl.int64 + elif 2**63 <= x < 2**64: + dtype = tl.uint64 + else: + raise ValueError(f'Nonrepresentable integer {x}.') + return full((), x, dtype=dtype, builder=builder) + elif isinstance(x, float): + min_float32 = 2**-126 + max_float32 = (2 - 2**-23) * 2**127 + abs_x = __builtins__['abs'](x) + if abs_x == float("inf") or\ + abs_x == 0.0 or \ + x != x or \ + min_float32 <= abs_x <= max_float32: + dtype = tl.float32 + else: + dtype = tl.float64 + return full((), x, dtype=dtype, builder=builder) + + elif isinstance(x, tl.constexpr): + return to_tensor(x.value, builder) + elif isinstance(x, tl.tensor): + return x + if check_type: + raise TypeError(f"cannot convert {x} of type {type(x)} to tensor") + return x + + # ===----------------------------------------------------------------------===// # Binary Operators # ===----------------------------------------------------------------------===// @@ -108,24 +163,60 @@ def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) - raise IncompatibleTypeErrorImpl(type_a, type_b) -def binary_op_type_checking_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, allow_lhs_ptr=False, - allow_rhs_ptr=False, arithmetic_check=True, +def binary_op_type_checking_impl(lhs: tl.tensor | numbers.Number, rhs: tl.tensor | numbers.Number, builder: ir.builder, + allow_lhs_ptr=False, allow_rhs_ptr=False, arithmetic_check=True, div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]: - # implicit broadcasting - lhs, rhs = broadcast_impl_value(lhs, rhs, builder) + lhs_is_scalar = isinstance(lhs, numbers.Number) + rhs_is_scalar = isinstance(rhs, numbers.Number) + if lhs_is_scalar: + lhs_scalar = lhs + lhs = to_tensor(lhs, builder) + if rhs_is_scalar: + rhs_scalar = rhs + rhs = to_tensor(rhs, builder) + # implicit typecasting lhs_sca_ty = lhs.type.scalar rhs_sca_ty = rhs.type.scalar check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): - ret_sca_ty = computation_type_impl(lhs_sca_ty, rhs_sca_ty, div_or_mod) - lhs = cast(lhs, ret_sca_ty, builder) - rhs = cast(rhs, ret_sca_ty, builder) + ret_sca_ty = computation_type_impl(lhs_sca_ty, lhs_is_scalar, rhs_sca_ty, rhs_is_scalar, div_or_mod) + if (lhs_is_scalar and lhs_scalar < 0 and ret_sca_ty.is_int_unsigned() + or rhs_is_scalar and rhs_scalar < 0 and ret_sca_ty.is_int_unsigned()): + raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. " + "Perform a explicit cast on one of them.") + lhs = full( + (), lhs_scalar, dtype=ret_sca_ty, builder=builder) if lhs_is_scalar else cast(lhs, ret_sca_ty, builder) + rhs = full( + (), rhs_scalar, dtype=ret_sca_ty, builder=builder) if rhs_is_scalar else cast(rhs, ret_sca_ty, builder) + + # implicit broadcasting + lhs, rhs = broadcast_impl_value(lhs, rhs, builder) return lhs, rhs -def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: +def binary_op_sanitize_overflow_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, binary_op: callable): + if lhs.type.scalar.int_bitwidth >= 64 or not builder.options.sanitize_overflow: + return + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + assert lhs_sca_ty == rhs_sca_ty + assert lhs_sca_ty.is_int() + lhs = cast(lhs, tl.int64, builder) + rhs = cast(rhs, tl.int64, builder) + ret = binary_op(lhs, rhs, False, builder) + max_value = lhs_sca_ty.get_int_max_value() + max_value = tl.tensor(builder.get_int64(max_value), tl.int64) + min_value = lhs_sca_ty.get_int_min_value() + min_value = tl.tensor(builder.get_int64(min_value), tl.int64) + cond = and_(less_equal(ret, max_value, builder), greater_equal(ret, min_value, builder), builder) + msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}" + device_assert(cond, msg, builder) + + +def add(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool, + builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar @@ -145,11 +236,14 @@ def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) # int + int elif input_scalar_ty.is_int(): + if sanitize_overflow: + binary_op_sanitize_overflow_impl(input, other, builder, add) return tl.tensor(builder.create_add(input.handle, other.handle), input.type) raise TypeError(f"unexpected type {input_scalar_ty}") -def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: +def sub(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool, + builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, True, False) scalar_ty = input.type.scalar # ptr - offset @@ -160,23 +254,28 @@ def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) # int - int elif scalar_ty.is_int(): + if sanitize_overflow: + binary_op_sanitize_overflow_impl(input, other, builder, sub) return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) raise TypeError(f"unexpected type {scalar_ty}") -def mul(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: +def mul(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool, + builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) scalar_ty = input.type.scalar # float * float if scalar_ty.is_floating(): return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type) - # * int + # int * int elif scalar_ty.is_int(): + if sanitize_overflow: + binary_op_sanitize_overflow_impl(input, other, builder, mul) return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) raise TypeError(f"unexpected type {scalar_ty}") -def truediv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: +def truediv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar @@ -202,7 +301,7 @@ def truediv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tenso return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) -def floordiv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: +def floordiv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar @@ -217,7 +316,8 @@ def floordiv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tens raise TypeError(f"unexpected type {input_scalar_ty}") -def fdiv(input: tl.tensor, other: tl.tensor, ieee_rounding: bool, builder: ir.builder) -> tl.tensor: +def fdiv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, ieee_rounding: bool, + builder: ir.builder) -> tl.tensor: input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): @@ -227,14 +327,15 @@ def fdiv(input: tl.tensor, other: tl.tensor, ieee_rounding: bool, builder: ir.bu return tl.tensor(ret, input.type) -def mod(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: +def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar # float % float if scalar_ty.is_floating(): # input - input.div(other, rounding_mode="floor") * other - ret = sub(input, mul(math.floor(fdiv(input, other, False, builder), _builder=builder), other, builder), builder) + floor = math.floor(fdiv(input, other, False, builder), _builder=builder) + ret = sub(input, mul(floor, other, True, builder), True, builder) return ret # % int elif scalar_ty.is_int(): @@ -309,7 +410,7 @@ def clamp(x: tl.tensor, min: tl.tensor, max: tl.tensor, propagate_nan: tl.Propag def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: - input, other = binary_op_type_checking_impl(input, other, builder, False, False, False) + input, other = binary_op_type_checking_impl(input, other, builder) input_sca_ty = input.type.scalar other_sca_ty = other.type.scalar if not input_sca_ty.is_int() or not other_sca_ty.is_int(): @@ -388,7 +489,7 @@ def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: if input_sca_ty.is_ptr(): raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) - return sub(_0, input, builder) + return sub(_0, input, True, builder) def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor: @@ -755,9 +856,6 @@ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) - if (src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()): - assert builder.options.allow_fp8e4nv, "fp8e4nv data type is not supported on CUDA arch < 89" - if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): assert builder.codegen_fns.get( "convert_custom_types") is not None, "target doesn't provide conversion for this type." @@ -853,6 +951,8 @@ def _str_to_load_cache_modifier(cache_modifier): cache = ir.CACHE_MODIFIER.CA elif cache_modifier == ".cg": cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cv": + cache = ir.CACHE_MODIFIER.CV else: raise ValueError(f"Cache modifier {cache_modifier} not supported") return cache @@ -995,12 +1095,13 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ elt_ty = ptr_ty.element_ty # Treat `pointer_type` as `pointer_type` - if elt_ty == tl.int1: + is_bool = elt_ty == tl.int1 + if is_bool: elt_ty = tl.int8 ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) ptr = cast(ptr, ptr_ty, builder) - # Cast `other` into `ele_ty` type + # Cast `other` into `elt_ty` type if other is not None: other = cast(other, elt_ty, builder) @@ -1014,11 +1115,14 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ # Build IR if mask is None: - return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) + ret = tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) else: - return tl.tensor( + ret = tl.tensor( builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, is_volatile), dst_ty) + if is_bool: + ret = cast(ret, tl.int1, builder) + return ret def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple, @@ -1051,6 +1155,41 @@ def descriptor_store(desc_ptr: tl.tensor, value: tl.tensor, offsets, builder: ir return tl.tensor(builder.create_descriptor_store(desc_ptr.handle, value.handle, offsets), tl.void) +def tensormap_create( + desc_ptr: tl.tensor, + global_address: tl.tensor, + box_dim: List[tl.tensor], + global_dim: List[tl.tensor], + global_stride: List[tl.tensor], + element_stride: List[tl.tensor], + elem_type: int, + interleave_layout: int, + swizzle_mode: int, + fill_mode: int, + builder: ir.builder, +) -> tl.tensor: + assert not global_stride or global_stride[0].dtype == tl.int64 + return tl.tensor( + builder.create_tensormap_create( + desc_ptr.handle, + global_address.handle, + [x.handle for x in box_dim], + [x.handle for x in global_dim], + [x.handle for x in global_stride], + [x.handle for x in element_stride], + elem_type, + interleave_layout, + swizzle_mode, + fill_mode, + ), + tl.void, + ) + + +def tensormap_fenceproxy_acquire(desc_ptr: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_tensormap_fenceproxy_acquire(desc_ptr.handle), tl.void) + + def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder): # Store by a block pointer: `pointer_type>` # Block pointers can not have the `mask` argument @@ -1318,41 +1457,18 @@ def _str_to_dot_input_precision(input_precision, builder): def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() - def assert_dtypes_valid(lhs_dtype, rhs_dtype, options): - if not options.allow_fp8e4nv: - assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv( - ), "Dot op does not support fp8e4nv on CUDA arch < 90" - if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): - return - assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" - else: - if lhs_dtype.is_int() or rhs_dtype.is_int(): - assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})" - assert lhs_dtype.is_int8() or lhs_dtype.is_uint8( - ), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" - elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8(): - if options.allow_fp8e4b15: - allowed_types = ['fp8e4nv', 'fp8e5', 'fp8e4b15'] - else: - allowed_types = ['fp8e4nv', 'fp8e5'] - - def _validate_dtype(dtype, allowed_types, operand_name): - if not any(getattr(dtype, f'is_{dtype_name}')() for dtype_name in allowed_types): - supported_types = ', '.join(allowed_types) - raise AssertionError(f"Only supports {supported_types}. {operand_name} ({dtype})") - - _validate_dtype(lhs_dtype, allowed_types, "First operand") - _validate_dtype(rhs_dtype, allowed_types, "Second operand") - else: - assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1( - ), f"Unsupported dtype {lhs_dtype}" - assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1( - ), f"Unsupported dtype {rhs_dtype}" - assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + # All combinations of supported fp8 x fp8 are permitted + pass + else: + assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, + tl.float32), f"Unsupported lhs dtype {lhs.dtype}" + assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, + tl.float32), f"Unsupported rhs dtype {rhs.dtype}" + assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}" - assert lhs.type.is_block() and rhs.type.is_block() - assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options) if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): lhs = cast(lhs, tl.float16, builder) rhs = cast(rhs, tl.float16, builder) @@ -1367,13 +1483,13 @@ def _validate_dtype(dtype, allowed_types, operand_name): assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" assert lhs.shape[-1].value == rhs.shape[ -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})" - assert lhs.shape[-2].value >= 16 and lhs.shape[-1].value >= 16 \ - and rhs.shape[-1].value >= 16, \ - f"All non-batch values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!" + assert builder.codegen_fns.get("min_dot_size") is not None, "target doesn't provide lower shape bounds for dot." + min_dot_size = builder.codegen_fns["min_dot_size"](lhs.type, rhs.type) + assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \ + and rhs.shape[-1].value >= min_dot_size[1], \ + f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}" if lhs.type.scalar.is_int(): assert lhs.type.scalar == tl.int8, "only int8 supported!" - # TODO: This is CUDA specific, check if ROCm has the same limitation - assert lhs.shape[1].value >= 32, "small blocks not supported!" _0 = builder.get_int32(0) ret_scalar_ty = tl.int32 elif out_dtype.is_bf16(): @@ -1388,6 +1504,7 @@ def _validate_dtype(dtype, allowed_types, operand_name): M = lhs.type.shape[-2] N = rhs.type.shape[-1] + K = lhs.type.shape[-1] B = lhs.type.shape[0] if lhs_rank == 3 else None ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N]) if acc is None: @@ -1402,25 +1519,80 @@ def _validate_dtype(dtype, allowed_types, operand_name): max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default else: max_num_imprecise_acc = 0 + else: + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc > K: + raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})") return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), ret_ty) +def _str_to_fp_type(float_format: Optional[str]): + if float_format == 'e4m3': + return ir.F8F6F4TY.E4M3 + if float_format == 'e5m2': + return ir.F8F6F4TY.E5M2 + if float_format == 'e2m3': + return ir.F8F6F4TY.E2M3 + if float_format == 'e3m2': + return ir.F8F6F4TY.E3M2 + if float_format == 'e2m1': + return ir.F8F6F4TY.E2M1 + raise ValueError(f"Invalid float format: {float_format}.") + + +def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, rhs_scale: Optional[tl.tensor], + rhs_format, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() + #TODO: validate types. + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + lhs_format_enum = _str_to_fp_type(lhs_format) + rhs_format_enum = _str_to_fp_type(rhs_format) + assert lhs_format in ("e2m1", "e4m3", "e5m2"), f"NYI: lhs_format {lhs_format}" + assert rhs_format in ("e4m3", "e5m2"), f"NYI: rhs_format {rhs_format}" + rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None + assert rhs_scale_is_none, "NYI: rhs_scale not supported" + + M = lhs.type.shape[-2] + K, N = rhs.type.shape[-2:] + PACKED = 2 if lhs_format == "e2m1" else 1 + assert K == PACKED * lhs.type.shape[ + -1], f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + assert K >= 64, f"scaled_dot NYI for K < 64. Got {K=}" + B = lhs.type.shape[0] if lhs_rank == 3 else None + + ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N]) + _0 = builder.get_fp32(0) + if acc is None: + acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle + return tl.tensor( + builder.create_dot_scaled(lhs.handle, lhs_scale.handle, lhs_format_enum, rhs.handle, rhs_scale_handle, + rhs_format_enum, acc_handle), ret_ty) + + # ===----------------------------------------------------------------------===// # Indexing # ===----------------------------------------------------------------------===// def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: + if condition.dtype != tl.int1: + warnings.warn( + f"tl.where with a non-boolean condition is deprecated and will error out in a future triton release. Got {condition.dtype}" + ) condition = cast(condition, tl.int1, builder) + x, y = binary_op_type_checking_impl(x, y, builder, True, True) + # x, y are broadcasted if condition.type.is_block(): condition, x = broadcast_impl_value(condition, x, builder) x, y = broadcast_impl_value(x, y, builder) - condition, x = broadcast_impl_value(condition, x, builder) - - x, y = binary_op_type_checking_impl(x, y, builder, True, True) - if not condition.type.is_block(): + else: condition, _ = broadcast_impl_value(condition, x, builder) ret_ty = x.type return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) @@ -1533,15 +1705,18 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.buil prefix = " " + prefix new_args = [arg.handle for arg in args] - return tl.tensor(builder.create_print(prefix, hex, new_args), tl.void) + is_signed = [arg.dtype in (tl.int1, tl.int8, tl.int16, tl.int32, tl.int64) for arg in args] + return tl.tensor(builder.create_print(prefix, hex, new_args, is_signed), tl.void) + + +def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor: + if not builder.options.debug: + return + return tl.tensor(builder.create_assert(cond.handle, msg), tl.void) -def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor: - cond_ty = cond.type - if not cond_ty.is_block(): - cond_ty = tl.block_type(cond_ty.scalar, (1, )) - cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty) - return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void) +def assume(cond, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_assume(cond.handle), tl.void) def _convert_elem_to_ir_value(builder, elem, require_i64): diff --git a/python/triton/language/standard.py b/python/triton/language/standard.py index de30cf260..eaaeffb68 100644 --- a/python/triton/language/standard.py +++ b/python/triton/language/standard.py @@ -4,11 +4,7 @@ from . import core from . import math -# constexpr utilities (triton metaprogramming sucks) - - -def _unwrap_if_constexpr(o): - return o.value if isinstance(o, core.constexpr) else o +# constexpr utilities def _log2(i: core.constexpr): @@ -39,7 +35,7 @@ def cdiv(x, div): :param x: the input number :type x: Block :param div: the divisor - :param div: Block + :type div: Block """ return (x + div - 1) // div @@ -76,9 +72,8 @@ def ravel(x): @jit def swizzle2d(i, j, size_i, size_j, size_g): """ - Transforms indices of a row-major :code:`size_i * size_j` matrix into those - of one where the indices are col-major for each group of :code:`size_g` - rows. + Transforms the indices of a row-major `size_i * size_j` matrix into + the indices of a column-major matrix for each group of `size_g` rows. For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will transform :: @@ -106,9 +101,11 @@ def swizzle2d(i, j, size_i, size_j, size_g): off_i = group_id * size_g # last group may have fewer rows size_g = core.minimum(size_i - off_i, size_g) + # linear index with respect to the first element in this group + ij = ij % size_gj # new row and column indices - new_i = off_i + (ij % size_g) - new_j = (ij % size_gj) // size_g + new_i = off_i + ij % size_g + new_j = ij // size_g return new_i, new_j @@ -128,7 +125,10 @@ def zeros(shape, dtype): @jit def zeros_like(input): """ - Creates a tensor of zeros with the same shape and type as a given tensor. + Returns a tensor of zeros with the same shape and type as a given tensor. + + :param input: input tensor + :type input: Tensor """ return zeros(input.shape, input.dtype) @@ -326,8 +326,8 @@ def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr): y = core.reshape(x, shape) # slice left/right with 'stride' 2**(n_dims - i - 1) mask = core.arange(0, 2)[None, :, None] - left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape) - right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape) + left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype) + right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape).to(y.dtype) left = core.reshape(left, x.shape) right = core.reshape(right, x.shape) # actual compare-and-swap @@ -335,7 +335,7 @@ def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr): ileft = left.to(idtype, bitcast=True) iright = right.to(idtype, bitcast=True) ix = x.to(idtype, bitcast=True) - ret = ix ^ core.where((left > right) ^ flip, ileft ^ iright, zeros_like(ix)) + ret = ix ^ core.where((left > right) != flip, ileft ^ iright, zeros_like(ix)) return ret.to(x.dtype, bitcast=True) @@ -367,6 +367,16 @@ def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core @core._tensor_member_fn @jit def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0): + """ + Sorts a tensor along a specified dimension. + + :param x: The input tensor to be sorted. + :type x: Tensor + :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported. + :type dim: int, optional + :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order. + :type descending: bool, optional + """ # handle default dimension or check that it is the most minor dim _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") @@ -381,8 +391,8 @@ def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTE def _get_flip_dim(dim, shape): - dim = _unwrap_if_constexpr(dim) - shape = _unwrap_if_constexpr(shape) + dim = core._unwrap_if_constexpr(dim) + shape = core._unwrap_if_constexpr(shape) if dim is None: dim = len(shape) - 1 assert dim == len(shape) - 1, "Currently only support flipping the last dimension" @@ -422,15 +432,16 @@ def flip(x, dim=None): @jit def interleave(a, b): """ - Interleaves the values of two tensors along their last dimension. - - The two tensors must have the same shape. + Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape. + Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])` - Equivalent to `tl.join(a, b).reshape(a.shape[-1:] + [2 * a.shape[-1]])` + :param a: The first input tensor. + :type a: Tensor + :param b: The second input tensor. + :type b: Tensor """ c = core.join(a, b) - assert isinstance(c.shape, list) if len(c.shape) == 1: # We must have interleaved two scalars. return c diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 73e618662..9f494a062 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -6,9 +6,9 @@ import inspect from typing import Dict -from ..testing import do_bench, do_bench_cudagraph from .jit import KernelInterface from .errors import OutOfResources +from .driver import driver class Autotuner(KernelInterface): @@ -24,9 +24,10 @@ def __init__( pre_hook=None, post_hook=None, prune_configs_by: Dict = None, - warmup=25, - rep=100, + warmup=None, + rep=None, use_cuda_graph=False, + do_bench=None, ): """ :param prune_configs_by: a dict of functions that are used to prune configs, fields: @@ -35,44 +36,51 @@ def __init__( 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. """ if not configs: - self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] + self.configs = [ + Config({}, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, + reg_dec_producer=0, reg_inc_consumer=0) + ] else: self.configs = configs - self.key_idx = [arg_names.index(k) for k in key] + self.keys = key self.cache = {} self.arg_names = arg_names # Reset to zero or restore values - self.reset_idx = [] + self.reset_to_zero = [] if reset_to_zero is not None: - self.reset_idx = [arg_names.index(k) for k in reset_to_zero] - self.restore_idx = [] + self.reset_to_zero = list(reset_to_zero) + self.restore_value = [] if restore_value is not None: - self.restore_idx = [arg_names.index(k) for k in restore_value] + self.restore_value = list(restore_value) # Hook to reset or restore for required tensors - self.pre_hook = lambda args, reset_only=False: 0 - self.post_hook = lambda args, exception: 0 + self.pre_hook = lambda kwargs, reset_only=False: 0 + self.post_hook = lambda kwargs, exception: 0 + self.user_defined_pre_hook = False + self.user_defined_post_hook = False if pre_hook: self.pre_hook = pre_hook - elif (len(self.reset_idx) > 0 or len(self.restore_idx) > 0): + self.user_defined_pre_hook = True + elif (len(self.reset_to_zero) > 0 or len(self.restore_value) > 0): - def _pre_hook(args, reset_only=False): - for i in self.reset_idx: - args[i].zero_() + def _pre_hook(kwargs, reset_only=False): + for name in self.reset_to_zero: + kwargs[name].zero_() if not reset_only: - self.restore_copies = [args[i].clone() for i in self.restore_idx] + self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value} self.pre_hook = _pre_hook if post_hook: self.post_hook = post_hook - elif len(self.restore_idx) > 0: + self.user_defined_post_hook = True + elif len(self.restore_value) > 0: - def _post_hook(args, exception): - for i, j in enumerate(self.restore_idx): - args[j].copy_(self.restore_copies[i]) - self.restore_copies = [] + def _post_hook(kwargs, exception): + for name in self.restore_value: + kwargs[name].copy_(self.restore_copies[name]) + self.restore_copies = {} self.post_hook = _post_hook @@ -88,10 +96,40 @@ def _post_hook(args, exception): self.base_fn = fn while not inspect.isfunction(self.base_fn): self.base_fn = self.base_fn.fn + self.num_warmups = warmup self.num_reps = rep - import torch - self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available() + self.use_cuda_graph = use_cuda_graph + + # If we got explicitly called via the old interface, raise a warning + # and proceed with the old behavior. + if warmup is not None or rep is not None or use_cuda_graph: + import warnings + warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See " + "https://github.com/triton-lang/triton/pull/4496 for details."), DeprecationWarning, + stacklevel=1) + if use_cuda_graph: + from ..testing import do_bench_cudagraph + self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph( + kernel_call, + rep=rep if rep is not None else 100, + quantiles=quantiles, + ) + return + + import triton.testing + self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench( + kernel_call, + warmup=warmup if warmup is not None else 25, + rep=rep if rep is not None else 100, + quantiles=quantiles, + ) + return + + if do_bench is None: + self.do_bench = driver.active.get_benchmarker() + else: + self.do_bench = do_bench def _bench(self, *args, config, **meta): from ..compiler.errors import CompileTimeAssertionFailure @@ -109,7 +147,7 @@ def _bench(self, *args, config, **meta): def kernel_call(): if config.pre_hook: config.pre_hook(full_nargs) - self.pre_hook(args) + self.pre_hook(full_nargs) try: self.fn.run( *args, @@ -117,34 +155,26 @@ def kernel_call(): ) except Exception as e: try: - self.post_hook(args, exception=e) + self.post_hook(full_nargs, exception=e) finally: # Throw exception raised by `self.fn.run` raise - self.post_hook(args, exception=None) + self.post_hook(full_nargs, exception=None) try: - if self.use_cuda_graph: - import torch - with torch.cuda.stream(torch.cuda.Stream()): - bench_res = do_bench_cudagraph(kernel_call, rep=self.num_reps, return_mode="median") - return bench_res - return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8)) + return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) except (OutOfResources, CompileTimeAssertionFailure): - return float("inf") if self.use_cuda_graph else [float("inf"), float("inf"), float("inf")] + return [float("inf"), float("inf"), float("inf")] def run(self, *args, **kwargs): self.nargs = dict(zip(self.arg_names, args)) used_cached_result = True if len(self.configs) > 1: all_args = {**self.nargs, **kwargs} - _args = [] - for name in self.arg_names: - if name in all_args: - _args.append(all_args[name]) - key = [_args[i] for i in self.key_idx] - for arg in _args: + _args = {k: v for (k, v) in all_args.items() if k in self.arg_names} + key = [_args[key] for key in self.keys if key in _args] + for _, arg in _args.items(): if hasattr(arg, "dtype"): key.append(str(arg.dtype)) key = tuple(key) @@ -157,7 +187,8 @@ def run(self, *args, **kwargs): bench_end = time.time() self.bench_time = bench_end - bench_start self.cache[key] = builtins.min(timings, key=timings.get) - self.pre_hook(args, reset_only=True) + full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()} + self.pre_hook(full_nargs, reset_only=True) self.configs_timings = timings config = self.cache[key] else: @@ -167,7 +198,8 @@ def run(self, *args, **kwargs): print(f"Triton autotuning for function {self.base_fn.__name__} finished after " f"{self.bench_time:.2f}s; best config selected: {self.best_config};") if config.pre_hook is not None: - config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()}) + full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} + config.pre_hook(full_nargs) ret = self.fn.run( *args, **kwargs, @@ -230,11 +262,16 @@ class Config: function are args. """ - def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, maxnreg=None, pre_hook=None): + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, + reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None): self.kwargs = kwargs self.num_warps = num_warps self.num_ctas = num_ctas self.num_stages = num_stages + self.num_buffers_warp_spec = num_buffers_warp_spec + self.num_consumer_groups = num_consumer_groups + self.reg_dec_producer = reg_dec_producer + self.reg_inc_consumer = reg_inc_consumer self.maxnreg = maxnreg self.pre_hook = pre_hook @@ -246,6 +283,10 @@ def all_kwargs(self): ("num_warps", self.num_warps), ("num_ctas", self.num_ctas), ("num_stages", self.num_stages), + ("num_buffers_warp_spec", self.num_buffers_warp_spec), + ("num_consumer_groups", self.num_consumer_groups), + ("reg_dec_producer", self.reg_dec_producer), + ("reg_inc_consumer", self.reg_inc_consumer), ("maxnreg", self.maxnreg), ) if v is not None } @@ -258,12 +299,16 @@ def __str__(self): res.append(f"num_warps: {self.num_warps}") res.append(f"num_ctas: {self.num_ctas}") res.append(f"num_stages: {self.num_stages}") + res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}") + res.append(f"num_consumer_groups: {self.num_consumer_groups}") + res.append(f"reg_dec_producer: {self.reg_dec_producer}") + res.append(f"reg_inc_consumer: {self.reg_inc_consumer}") res.append(f"maxnreg: {self.maxnreg}") return ", ".join(res) def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, - warmup=25, rep=100, use_cuda_graph=False): + warmup=None, rep=None, use_cuda_graph=False, do_bench=None): """ Decorator for auto-tuning a :code:`triton.jit`'d function. @@ -303,18 +348,20 @@ def kernel(x_ptr, x_size, **META): :type restore_value: list[str] :param pre_hook: a function that will be called before the kernel is called. This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. - 'args': a list of arguments passed to the kernel. + 'kwargs': a dict of all arguments passed to the kernel. 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. :type pre_hook: lambda args, reset_only :param post_hook: a function that will be called after the kernel is called. This overrides the default post_hook used for 'restore_value'. - 'args': a list of arguments passed to the kernel. + 'kwargs': a dict of all arguments passed to the kernel. 'exception': the exception raised by the kernel in case of a compilation or runtime error. :type post_hook: lambda args, exception - :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25. + :param warmup: warmup time (in ms) to pass to benchmarking (deprecated). :type warmup: int - :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100. + :param rep: repetition time (in ms) to pass to benchmarking (deprecated). :type rep: int + :param do_bench: a benchmark function to measure the time of each run. + :type do_bench: lambda fn, quantiles """ def decorator(fn): diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index d7baeb286..20da2bc25 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -40,11 +40,13 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): if scheme == 'posix_local': scheme = 'posix_prefix' py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] - include_dirs = include_dirs + [srcdir, py_include_dir] - cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so] + custom_backend_dirs = set(os.getenv(var) for var in ('TRITON_CUDACRT_PATH', 'TRITON_CUDART_PATH')) + include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] + # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] cc_cmd += [f'-l{lib}' for lib in libraries] cc_cmd += [f"-L{dir}" for dir in library_dirs] - cc_cmd += [f"-I{dir}" for dir in include_dirs] + cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] ret = subprocess.check_call(cc_cmd) if ret == 0: return so diff --git a/python/triton/runtime/cache.py b/python/triton/runtime/cache.py index bd3c29b99..82b2fea37 100644 --- a/python/triton/runtime/cache.py +++ b/python/triton/runtime/cache.py @@ -5,19 +5,24 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import Dict, List, Optional +import base64 import hashlib +def get_home_dir(): + return os.getenv("TRITON_HOME", Path.home()) + + def default_cache_dir(): - return os.path.join(Path.home(), ".triton", "cache") + return os.path.join(get_home_dir(), ".triton", "cache") def default_override_dir(): - return os.path.join(Path.home(), ".triton", "override") + return os.path.join(get_home_dir(), ".triton", "override") def default_dump_dir(): - return os.path.join(Path.home(), ".triton", "dump") + return os.path.join(get_home_dir(), ".triton", "dump") class CacheManager(ABC): @@ -48,12 +53,12 @@ def __init__(self, key, override=False, dump=False): self.key = key self.lock_path = None if dump: - self.cache_dir = default_dump_dir() + self.cache_dir = os.getenv("TRITON_DUMP_DIR", "").strip() or default_dump_dir() self.cache_dir = os.path.join(self.cache_dir, self.key) self.lock_path = os.path.join(self.cache_dir, "lock") os.makedirs(self.cache_dir, exist_ok=True) elif override: - self.cache_dir = default_override_dir() + self.cache_dir = os.getenv("TRITON_OVERRIDE_DIR", "").strip() or default_override_dir() self.cache_dir = os.path.join(self.cache_dir, self.key) else: # create cache directory if it doesn't exist @@ -116,14 +121,18 @@ def put(self, data, filename, binary=True) -> str: rnd_id = str(uuid.uuid4()) # we use the PID in case a bunch of these around so we can see what PID made it pid = os.getpid() - # use tempfile to be robust against program interruptions - temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}" + # use temp dir to be robust against program interruptions + temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}") + os.makedirs(temp_dir, exist_ok=True) + temp_path = os.path.join(temp_dir, filename) + mode = "wb" if binary else "w" with open(temp_path, mode) as f: f.write(data) # Replace is guaranteed to be atomic on POSIX systems if it succeeds # so filepath cannot see a partial write os.replace(temp_path, filepath) + os.removedirs(temp_dir) return filepath @@ -247,6 +256,11 @@ def put_group(self, filename: str, group: Dict[str, str]): __cache_cls_nme = "DEFAULT" +def _base64(key): + # Assume key is a hex string. + return base64.urlsafe_b64encode(bytes.fromhex(key)).decode("utf-8").rstrip("=") + + def get_cache_manager(key) -> CacheManager: import os @@ -260,15 +274,15 @@ def get_cache_manager(key) -> CacheManager: __cache_cls = getattr(module, clz_nme) __cache_cls_nme = user_cache_manager - return __cache_cls(key) + return __cache_cls(_base64(key)) def get_override_manager(key) -> CacheManager: - return __cache_cls(key, override=True) + return __cache_cls(_base64(key), override=True) def get_dump_manager(key) -> CacheManager: - return __cache_cls(key, dump=True) + return __cache_cls(_base64(key), dump=True) def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): @@ -278,4 +292,4 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): for kw in kwargs: key = f"{key}-{kwargs.get(kw)}" key = hashlib.sha256(key.encode("utf-8")).hexdigest() - return key + return _base64(key) diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index a82832ecf..0aeaff73a 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -1,3 +1,5 @@ +import ast +import textwrap import inspect from typing import Tuple @@ -73,12 +75,14 @@ def materialize_pointers(self, boundary_check): class InterpreterOptions: extern_libs: dict = None debug: bool = False + sanitize_overflow: bool = True arch: str = None - allow_fp8e4nv: bool = True - allow_fp8e4b15: bool = True + supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15") + deprecated_fp8_dtypes: Tuple[str] = () default_dot_input_precision: str = "tf32" allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") max_num_imprecise_acc_default: int = 0 + backend_name: str = "interpreter" def _get_signed_np_dtype(dtype): @@ -238,6 +242,7 @@ def __init__(self) -> None: self.options = InterpreterOptions() self.codegen_fns = {} self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types + self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (16, 16, 16) def set_grid_idx(self, x, y, z): if not x < self.grid_dim[0]: @@ -460,6 +465,8 @@ def binary_op(self, lhs, rhs, op): create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and) create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor) create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or) + create_int_to_ptr = create_bitcast + create_ptr_to_int = create_bitcast def create_idiv(self, lhs, rhs): # Triton has IEEE, not numpy/torch, semantics for %, and those carry @@ -584,12 +591,6 @@ def create_expand_dims(self, arg, axis): def create_broadcast(self, arg, shape): return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype.scalar) - def create_int_to_ptr(self, val, dst_ty): - return TensorHandle(val.data.astype(np.uint64), dst_ty.scalar) - - def create_ptr_to_int(self, val, dst_ty): - return TensorHandle(val.data.astype(np.uint64), dst_ty.scalar) - def create_cat(self, lhs, rhs): return TensorHandle(np.concatenate([lhs.data, rhs.data]), lhs.dtype.scalar) @@ -628,7 +629,10 @@ def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): raise NotImplementedError("inline_asm not supported in interpreter mode") - def create_print(self, prefix, hex, values): + def create_print(self, prefix, hex, values, isSigned): + # NOTE: the `isSigned` variable is not really used here; because Signness is already known + # by `values` themselves in python interpreter, thus not really needed here; + # it is only used for triton PrintOpToLLVM to correctly construct the format specifier. # Interpreter's device_print function has a different format than Triton's device_print msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})" if prefix: @@ -640,9 +644,12 @@ def create_print(self, prefix, hex, values): if hex: np.set_printoptions(formatter=None) - def create_assert(self, condition, message, fileName, funcName, lineNo): + def create_assert(self, condition, message): # Interpreter's device_assert function has a different format than Triton's device_assert - assert condition, f"{message} in {fileName}:{funcName}:{lineNo}" + assert condition, f"{message}" + + def create_assume(self, condition): + assert condition, "Assume failed" def create_barrier(self): # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter @@ -723,7 +730,7 @@ def to_tensor(self, ret, dtype): if hasattr(ret, "shape") and ret.shape: ret_type = tl.block_type(dtype, ret.shape) else: - ret = np.array([ret], dtype=_get_np_dtype(dtype)) + ret = np.array([ret]).astype(_get_np_dtype(dtype)) ret_type = dtype return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type) @@ -990,14 +997,15 @@ def _set_attr(input, values, name): def _patch_lang(fn): - lang = [value for _, value in fn.__globals__.items() if value in [tl, tl.core]] - assert len(lang) == 1, "triton.language must be visible from within jit'd function" - _patch_builtin(lang[0], interpreter_builder) - _patch_builtin(lang[0].tensor, interpreter_builder) - if lang[0] == tl: - _patch_builtin(lang[0].math, interpreter_builder) - _patch_lang_tensor(lang[0].tensor) - _patch_lang_core(lang[0]) + langs = [value for _, value in fn.__globals__.items() if value in [tl, tl.core]] + assert len(langs) >= 1, "triton.language must be visible from within jit'd function" + for lang in langs: + _patch_builtin(lang, interpreter_builder) + _patch_builtin(lang.tensor, interpreter_builder) + if lang == tl: + _patch_builtin(lang.math, interpreter_builder) + _patch_lang_tensor(lang.tensor) + _patch_lang_core(lang) # TODO: wrap everything in triton tensors @@ -1098,30 +1106,130 @@ def __call__(self, *args_dev, **kwargs): self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst) +class ASTTransformer(ast.NodeTransformer): + + def visit_Assign(self, node): + names = [] + for target in node.targets: + names += [self.visit(target)] + if len(names) > 1: + raise ValueError("Multiple assignments are not supported") + # Modify the assignment x = value to + # triton.language.semantic.to_tensor(value, interpreter_builder, False) + node.value = ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Attribute(value=ast.Name(id='triton', ctx=ast.Load()), attr='language', ctx=ast.Load()), + attr='semantic', ctx=ast.Load()), attr='to_tensor', ctx=ast.Load()), + args=[node.value, ast.Name(id='interpreter_builder', ctx=ast.Load()), + ast.Constant(value=False)], keywords=[]) + return node + + +class FunctionRewriter: + ast_transformer = ASTTransformer() + + def __init__(self, fn, **kwargs): + self.fn = fn + self.kwargs = kwargs + self.filename: str = "" + # Absolute line number in the file + self.def_file_lineno: int = 0 + + def rewrite_ast(self): + # If exception is raise, it means the function does not have source code available, + # e.g., dynamically generated functions, we cannot rewrite it so just return the original function + try: + lines, _ = inspect.getsourcelines(self.fn) + except Exception: + return self.fn + + # truncate lines before def + # @triton.autotune(...) + # ... + # @triton.jit + # ... + # def foo(...): <- this line is the function definition + self.filename, self.def_file_lineno = self._get_jit_fn_file_line() + self.def_lineno = self._find_def(lines) + src = self._prepare_source(lines) + transformed_ast = self._transform_ast(src) + return self._compile_and_exec(transformed_ast) + + def _get_jit_fn_file_line(self): + from .jit import get_jit_fn_file_line, JITFunction + return get_jit_fn_file_line(JITFunction(self.fn)) + + def _find_def(self, lines): + def_lineno = 0 + # Line numbers start from 1 + for i, line in enumerate(lines): + if line.strip().startswith("def "): + def_lineno = i + 1 + return def_lineno + + def _prepare_source(self, lines): + lines = lines[self.def_lineno - 1:] + src = ''.join(lines) + return textwrap.dedent(src) + + def _transform_ast(self, src): + # src is like: + # 1: def foo(...): + # 2: ... + parsed_ast = ast.parse(src) + transformed_ast = self.ast_transformer.visit(parsed_ast) + ast.fix_missing_locations(transformed_ast) + inc_lineno = self.def_file_lineno - 1 + ast.increment_lineno(transformed_ast, inc_lineno) + return transformed_ast + + def _compile_and_exec(self, transformed_ast): + compiled_code = compile(transformed_ast, filename=self.filename, mode='exec') + local_namespace = {**self.kwargs} + fn_globals = self.fn.__globals__ + for key, value in globals().items(): + if key not in fn_globals: + fn_globals[key] = value + exec(compiled_code, fn_globals, local_namespace) + return local_namespace[self.fn.__name__] + + class InterpretedFunction: + # Cache all rewritten functions + rewritten_fn = {} - def __init__(self, fn) -> None: + def __init__(self, fn, **kwargs) -> None: self.fn = fn + self.rewriter = FunctionRewriter(fn, **kwargs) def run(*args, **kwargs): grid = kwargs["grid"] - return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs) + fn = self.rewrite() + return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs) self.run = run signature = inspect.signature(fn) self.arg_names = [v.name for v in signature.parameters.values()] + def rewrite(self): + if self.fn not in self.rewritten_fn: + self.rewritten_fn[self.fn] = self.rewriter.rewrite_ast() + return self.rewritten_fn[self.fn] + @property def __name__(self): return self.fn.__name__ def __getitem__(self, grid): - return GridExecutor(self.fn, self.arg_names, grid) + fn = self.rewrite() + return GridExecutor(fn, self.arg_names, grid) def __call__(self, *args, **kwargs): # This is a device function call _patch_lang(self.fn) + fn = self.rewrite() try: - return self.fn(*args, **kwargs) + return fn(*args, **kwargs) except Exception as e: raise InterpreterError(repr(e)) from e diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index a12b1d235..45178a40b 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -73,8 +73,32 @@ def __init__(self, name, globals, src) -> None: def ret(self): return self.hasher.hexdigest() + def _is_triton_builtin(self, node, func): + if inspect.isbuiltin(node.func): + return True + module = getattr(func, "__module__", "") + return module.startswith(TRITON_MODULE) + + def _update_hash(self, func): + if isinstance(func, JITFunction): + # Merge our used_global_vals with those of the called function, + # after checking that all overlapping values are consistent. + for k in self.used_global_vals.keys() & func.used_global_vals.keys(): + var_name, _ = k + v1, _ = self.used_global_vals[k] + v2, _ = func.used_global_vals[k] + if v1 != v2: + raise RuntimeError( + f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed." + ) + self.used_global_vals.update(func.used_global_vals) + # update hash + func_key = func.cache_key + func_key += str(getattr(func, "noinline", False)) + self.hasher.update(func_key.encode("utf-8")) + def visit_Name(self, node): - if type(node.ctx) == ast.Store: + if type(node.ctx) is ast.Store: return node.id if node.id in self.local_names: @@ -93,14 +117,14 @@ def visit_Name(self, node): and not self.visiting_arg_default_value # It would be pretty evil if someone did `import x` and then # `x = blah`. - and type(val) != ModuleType + and type(val) is not ModuleType # It would be pretty evil if we used function `foo` inside of # `bar` and then someone did `foo = baz`. and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) # - and node.id not in self.supported_python_builtins # - ): + and node.id not in self.supported_python_builtins): self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals) + self._update_hash(val) return val def visit_Tuple(self, node): @@ -114,52 +138,9 @@ def visit_Attribute(self, node): lhs = self.visit(lhs.value) if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE): return None - return getattr(lhs, node.attr) - - def visit_Call(self, node): - - def is_triton_builtin(func): - if inspect.isbuiltin(node.func): - return True - module = getattr(func, "__module__", "") - return module.startswith(TRITON_MODULE) - - func = self.visit(node.func) - assert func is None or is_triton_builtin(func) or isinstance( - func, JITFunction - ), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this' - - # Traverse arguments as well as node.func so we can find JITFunctions - # passed to tl.reduce or tl.associative_scan as the combine_fn - for obj in itertools.chain( - (func, ), - map(self.visit, node.args), - (self.visit(kw.value) for kw in node.keywords), - ): - if not isinstance(obj, JITFunction): - continue - if is_triton_builtin(obj): - continue - - func_cache_key = obj.cache_key - - # Merge our used_global_vals with those of the called function, - # after checking that all overlapping values are consistent. - for k in self.used_global_vals.keys() & obj.used_global_vals.keys(): - var_name, _ = k - v1, _ = self.used_global_vals[k] - v2, _ = obj.used_global_vals[k] - if v1 != v2: - raise RuntimeError( - f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed." - ) - - self.used_global_vals.update(obj.used_global_vals) - - noinline = str(getattr(obj, "noinline", False)) - - key = func_cache_key + noinline - self.hasher.update(key.encode("utf-8")) + ret = getattr(lhs, node.attr) + self._update_hash(ret) + return ret def visit_FunctionDef(self, node): # Save the local name, which may hide the global name. @@ -249,10 +230,12 @@ def _normalize_ty(ty) -> str: class KernelParam: """Represents a parameter (name plus metadata) to a @jit'ed function.""" - def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool): + def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool, + do_not_specialize_on_alignment: bool): self.num = num self._param = param self.do_not_specialize = do_not_specialize + self.do_not_specialize_on_alignment = do_not_specialize_on_alignment @cached_property def name(self): @@ -292,13 +275,13 @@ def has_default(self): return self._param.default != inspect.Parameter.empty -def compute_spec_key(v): +def compute_spec_key(v, align): - if hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0): + if align and hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0): return "D" elif isinstance(v, int): # bool is a subclass of int, so we don't check explicitly above. - if (v % 16 == 0): + if align and (v % 16 == 0): return "D" elif v == 1: return "1" @@ -323,6 +306,8 @@ def mangle_type(arg, is_const=False): return "i64" elif isinstance(arg, float): return "fp32" + elif hasattr(arg, "tma_desc_cpu_ptr"): + return "nvTmaDesc" else: # dtypes are hashable so we can memoize this mapping: dsk = (arg.dtype, is_const) @@ -357,7 +342,7 @@ def serialize_specialization_data(name, signature, constants, attrs, options, ke return serialized_obj -def create_function_from_signature(sig, kparams): +def create_function_from_signature(sig, kparams, backend): """ Equivalent to sig.bind followed by apply_defaults. This generates a native Python function (using exec) which can be memoized on a per-kernel @@ -387,7 +372,10 @@ def create_function_from_signature(sig, kparams): else: non_constexpr_vals.append(name) if not kp.do_not_specialize: - specialisations.append('compute_spec_key(%s)' % name) + if not kp.do_not_specialize_on_alignment: + specialisations.append('compute_spec_key(%s, align=True)' % name) + else: + specialisations.append('compute_spec_key(%s, align=False)' % name) if kp.annotation_type: signature_types.append('"%s"' % kp.annotation_type) else: @@ -413,7 +401,7 @@ def create_function_from_signature(sig, kparams): } func_namespace['mangle_type'] = mangle_type - func_namespace['compute_spec_key'] = compute_spec_key + func_namespace['compute_spec_key'] = backend.compute_spec_key # Execute the function string in func_namespace to create the function exec(func_body, func_namespace) @@ -454,7 +442,9 @@ def create_function_from_signature(sig, kparams): class JITFunction(KernelInterface[T]): # Hook for inspecting compiled functions and modules cache_hook = None - divisibility = 16 + # Hook to signal that a kernel is done compiling and inspect compiled function. + # cache_hook will always be called before compilation and compiled_hook after. + compiled_hook = None @staticmethod def _key_of(arg): @@ -476,42 +466,6 @@ def _key_of(arg): else: raise TypeError(f"Unsupported type {type(arg)} for {arg}") - @staticmethod - def _spec_of(arg): - if hasattr(arg, "data_ptr"): - return arg.data_ptr() % JITFunction.divisibility == 0 - elif isinstance(arg, int): - return (arg % 16 == 0, arg == 1) - return (arg is None, ) - - def _get_config(self, *args): - from ..compiler import AttrsDescriptor - - def is_divisible_by_16(x): - if hasattr(x, "data_ptr"): - return x.data_ptr() % JITFunction.divisibility == 0 - elif isinstance(x, int): - return x % JITFunction.divisibility == 0 - if x is None: - return True - return False - - divisible_by_16 = { - param.num - for param, arg in zip(self.params, args) - if is_divisible_by_16(arg) and not param.do_not_specialize - } - equal_to_1 = { - param.num - for param, arg in zip(self.params, args) - if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize - } - # folded equal_to_1 and None - # TODO: method to collect all folded args - return AttrsDescriptor(tuple(divisible_by_16), tuple(equal_to_1)) - # return _triton.code_gen.instance_descriptor(divisible_by_16, - # equal_to_1) - @staticmethod def _type_of(key, is_const=False): # `None` is nullptr. Implicitly convert to *i8. @@ -537,8 +491,11 @@ def _call_hook( constants, options, configs, + is_warmup, + before, ): - if JITFunction.cache_hook is None: + hook = JITFunction.cache_hook if before else JITFunction.compiled_hook + if hook is None: return False name = self.fn.__name__ @@ -567,14 +524,15 @@ def __init__(self, module, name, jit_function): 'extern_libs': options.extern_libs, 'configs': configs, 'specialization_data': specialization_data, + 'is_warmup': is_warmup, } - return JITFunction.cache_hook( + return hook( key=key, repr=repr, fn=JitFunctionInfo(module, name, self), compile={"key": key, **kwargs}, - is_manual_warmup=False, + is_manual_warmup=is_warmup, already_compiled=False, ) @@ -586,7 +544,7 @@ def add_pre_run_hook(self, hook): assert callable(hook) self.pre_run_hooks.append(hook) - def create_binder(self): + def create_binder(self, backend): """ Precompute as much as possible. """ @@ -595,7 +553,7 @@ def create_binder(self): self.compile = compile self.ASTSource = ASTSource self.make_backend = make_backend - self.binder = create_function_from_signature(self.signature, self.params) + self.binder = create_function_from_signature(self.signature, self.params, backend) self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr] self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr] self.specialised_indices = [ @@ -603,17 +561,21 @@ def create_binder(self): ] def run(self, *args, grid, warmup, **kwargs): + kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1" + # parse options + from ..compiler import make_backend device = driver.active.get_current_device() stream = driver.active.get_current_stream(device) - kwargs["debug"] = self.debug + target = driver.active.get_current_target() + backend = make_backend(target) # Execute pre run hooks with args and kwargs for hook in self.pre_run_hooks: hook(*args, **kwargs) if self.binder is None: - self.create_binder() + self.create_binder(backend) bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs) @@ -623,8 +585,6 @@ def run(self, *args, grid, warmup, **kwargs): if kernel is None: # Kernel is not cached; we have to compile. - target = driver.active.get_current_target() - backend = self.make_backend(target) options = backend.parse_options(kwargs) # deprecated arguments @@ -645,17 +605,18 @@ def run(self, *args, grid, warmup, **kwargs): sigvals = sig_and_spec[:len(sigkeys)] signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} - configs = (self._get_config(*bound_vals), ) + configs = (backend.get_attrs_descriptor(self.params, bound_vals), ) + constant_params = configs[0].get_constants() constants = { p.name: v for (v, p) in zip(bound_vals, self.params) - if p.is_constexpr or p.num in configs[0].equal_to_1 or v is None + if p.is_constexpr or (p.num in constant_params) or v is None } for i, arg in constants.items(): if callable(arg): raise TypeError(f"Callable constexpr at index {i} is not supported") - if self._call_hook(key, signature, device, constants, options, configs): + if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True): return None # compile the kernel src = self.ASTSource(self, signature, constants, configs[0]) @@ -665,10 +626,11 @@ def run(self, *args, grid, warmup, **kwargs): options=options.__dict__, ) self.cache[device][key] = kernel + self._call_hook(key, signature, device, constants, options, configs, warmup, before=False) # Check that used global values have not changed. not_present = object() - for (name, globals_dict_id), (val, globals_dict) in self.used_global_vals.items(): + for (name, _), (val, globals_dict) in self.used_global_vals.items(): if (newVal := globals_dict.get(name, not_present)) != val: raise RuntimeError( f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}") @@ -692,15 +654,17 @@ def run(self, *args, grid, warmup, **kwargs): self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals) return kernel - def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None, repr=None, - launch_metadata=None): + def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None, + noinline=None, repr=None, launch_metadata=None): do_not_specialize = do_not_specialize if do_not_specialize else [] + do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else [] self.fn = fn self.module = fn.__module__ self.version = version self.signature = inspect.signature(fn) self.do_not_specialize = do_not_specialize + self.do_not_specialize_on_alignment = do_not_specialize_on_alignment self.starting_line_number = inspect.getsourcelines(fn)[1] self.repr = lambda _: fn.__name__ if repr is None else repr(_) self.launch_metadata = launch_metadata @@ -709,8 +673,9 @@ def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinlin self.params = [] for i, param in enumerate(self.signature.parameters.values()): - dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize) - self.params.append(KernelParam(i, param, dns)) + dns = i in do_not_specialize or param.name in do_not_specialize + dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment + self.params.append(KernelParam(i, param, dns, dns_oa)) # function source code (without decorators) self.src = textwrap.dedent(inspect.getsource(fn)) @@ -733,7 +698,6 @@ def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinlin # JITFunction can be instantiated as kernel # when called with a grid using __getitem__ self.kernel = None - self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug self.noinline = noinline # TODO(jlebar): Remove uses of these fields outside this file, then @@ -764,7 +728,8 @@ def warmup(self, *args, grid, **kwargs): return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) def preload(self, specialization_data): - from ..compiler import AttrsDescriptor, compile, ASTSource + from ..compiler import compile, ASTSource + from triton.backends.compiler import AttrsDescriptor import json import triton.language as tl device = driver.active.get_current_device() @@ -828,6 +793,7 @@ def jit( repr: Optional[Callable] = None, launch_metadata: Optional[Callable] = None, do_not_specialize: Optional[Iterable[int]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int]] = None, debug: Optional[bool] = None, noinline: Optional[bool] = None, ) -> Callable[[T], JITFunction[T]]: @@ -841,6 +807,7 @@ def jit( repr: Optional[Callable] = None, launch_metadata: Optional[Callable] = None, do_not_specialize: Optional[Iterable[int]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int]] = None, debug: Optional[bool] = None, noinline: Optional[bool] = None, ) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: @@ -866,12 +833,15 @@ def decorator(fn: T) -> JITFunction[T]: assert callable(fn) if os.getenv("TRITON_INTERPRET", "0") == "1": from .interpreter import InterpretedFunction - return InterpretedFunction(fn) + return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug, + noinline=noinline, repr=repr, launch_metadata=launch_metadata) else: return JITFunction( fn, version=version, do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug, noinline=noinline, repr=repr, @@ -909,6 +879,10 @@ def __init__(self, dtype): def data_ptr(): return 0 # optimistically assumes multiple of 16 + @staticmethod + def ptr_range(): + return 0 # optimistically assumes 32 bit pointer range + class TensorWrapper: @@ -937,6 +911,9 @@ def cpu(self): def copy_(self, other): self.base.copy_(other.base) + def clone(self): + return TensorWrapper(self.base.clone(), self.dtype) + def to(self, device): return TensorWrapper(self.base.to(device), self.dtype) @@ -954,3 +931,21 @@ def reinterpret(tensor, dtype): return TensorWrapper(tensor, dtype) else: raise TypeError(f"Cannot reinterpret a {type(tensor)}.") + + +def get_jit_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITFunction): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + lines, begin_line = inspect.getsourcelines(base_fn.fn) + # Match the following pattern: + # @triton.autotune(...) <- foo.__code__.co_firstlineno + # @triton.heuristics(...) + # @triton.jit + # def foo(...): <- this line is the first line + for idx, line in enumerate(lines): + if line.strip().startswith("def "): + begin_line += idx + break + return file_name, begin_line diff --git a/python/triton/testing.py b/python/triton/testing.py index 0c8d4bcea..71cb8ab1e 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -5,6 +5,7 @@ from contextlib import contextmanager from typing import Any, Dict, List from . import language as tl +from . import runtime def nvsmi(attrs): @@ -16,7 +17,19 @@ def nvsmi(attrs): return ret -def do_bench_cudagraph(fn, rep=20, grad_to_none=None, return_mode="mean"): +def _summarize_statistics(times, quantiles, return_mode): + import torch + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + if return_mode == "all": + return times.tolist() + return getattr(torch, return_mode)(times).item() + + +def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"): """ Benchmark the runtime of the provided function. @@ -26,61 +39,60 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, return_mode="mean"): :type rep: int :param grad_to_none: Reset the gradient of the provided tensor to None :type grad_to_none: torch.tensor, optional + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". + :type return_mode: str """ import torch - assert return_mode in ["min", "max", "mean", "median"] + assert return_mode in ["min", "max", "mean", "median", "all"] - if torch.cuda.current_stream() == torch.cuda.default_stream(): - raise RuntimeError("Cannot capture graph in default stream. Please use side stream in benchmark code.") - # warmup - fn() - # step 1 - we estimate the amount of time the kernel call takes - # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point - # but it is probably good enough - if grad_to_none is not None: - for x in grad_to_none: - x.detach_() - x.requires_grad_(True) - x.grad = None - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): + with torch.cuda.stream(torch.cuda.Stream()): + # warmup fn() - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - g.replay() - end_event.record() - torch.cuda.synchronize() - estimate_ms = start_event.elapsed_time(end_event) - n_repeat = max(1, int(rep / estimate_ms)) - # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize - # host overhead - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - for i in range(n_repeat): - if grad_to_none is not None: - for x in grad_to_none: - x.grad = None - fn() - torch.cuda.synchronize() - # measure time and return - ret = [] - n_retries = 10 - for i in range(n_retries): + if grad_to_none is not None: + for x in grad_to_none: + x.detach_() + x.requires_grad_(True) + x.grad = None + # step 1 - we estimate the amount of time the kernel call takes + # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point + # but it is probably good enough + # NOTE: we don't use a graph to estimate the runtime because creating a graph is expensive, + # ~300ms on A100, so we default to the same method used in `do_bench` (minus the L2 + # cache flush). start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() - g.replay() + for _ in range(5): + fn() end_event.record() torch.cuda.synchronize() - ret += [start_event.elapsed_time(end_event) / n_repeat] - times = torch.tensor(ret) - return getattr(torch, return_mode)(times).item() - - -def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean", - device_type="cuda"): + estimate_ms = start_event.elapsed_time(end_event) / 5 + n_repeat = max(1, int(rep / estimate_ms)) + # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize + # host overhead + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(n_repeat): + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + fn() + torch.cuda.synchronize() + # measure time and return + ret = [] + n_retries = 10 + for _ in range(n_retries): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + ret += [start_event.elapsed_time(end_event) / n_repeat] + return _summarize_statistics(torch.tensor(ret), quantiles, return_mode) + + +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -94,25 +106,18 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu :param grad_to_none: Reset the gradient of the provided tensor to None :type grad_to_none: torch.tensor, optional :param quantiles: Performance percentile to return in addition to the median. - :type quantiles: list[float] - :param fast_flush: Use faster kernel to flush L2 between measurements - :type fast_flush: bool + :type quantiles: list[float], optional + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". :type return_mode: str """ - assert return_mode in ["min", "max", "mean", "median"] + assert return_mode in ["min", "max", "mean", "median", "all"] import torch - di = torch._dynamo.device_interface.get_interface_for_device(device_type) + di = runtime.driver.active.get_device_interface() fn() di.synchronize() - # We maintain a buffer of 256 MB that we clear - # before each kernel call to make sure that the L2 - # doesn't contain any input data before the run - if fast_flush: - cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=device_type) - else: - cache = torch.empty(int(256e6), dtype=torch.int8, device=device_type) + cache = runtime.driver.active.get_empty_cache_for_benchmark() # Estimate the runtime of the function start_event = di.Event(enable_timing=True) @@ -150,15 +155,24 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu # Record clocks di.synchronize() times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) - if quantiles is not None: - ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() - if len(ret) == 1: - ret = ret[0] - return ret - return getattr(torch, return_mode)(times).item() + return _summarize_statistics(times, quantiles, return_mode) def assert_close(x, y, atol=None, rtol=None, err_msg=''): + """ + Asserts that two inputs are close within a certain tolerance. + + :param x: The first input. + :type x: scala, list, numpy.ndarray, or torch.Tensor + :param y: The second input. + :type y: scala, list, numpy.ndarray, or torch.Tensor + :param atol: The absolute tolerance. Default value is 1e-2. + :type atol: float, optional + :param rtol: The relative tolerance. Default value is 0. + :type rtol: float, optional + :param err_msg: The error message to use if the assertion fails. + :type err_msg: str + """ import numpy as np import torch @@ -213,7 +227,6 @@ def __init__( ylabel: str = '', x_log: bool = False, y_log: bool = False, - color=None, styles=None, ): """ @@ -245,6 +258,8 @@ def __init__( :type x_log: bool, optional :param y_log: Whether the y axis should be log scale. :type y_log: bool, optional + :param styles: A list of tuples, where each tuple contains two elements: a color and a linestyle. + :type styles: list[tuple[str, str]] """ self.x_names = x_names self.x_vals = x_vals diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 872332b03..443341fa0 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -7,6 +7,7 @@ from typing import List import triton +import triton.backends from triton.compiler.code_generator import kernel_suffix from triton.backends.nvidia.driver import ty_to_cpp @@ -92,30 +93,39 @@ def constexpr(s): hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} hints = {k: v for k, v in hints.items() if v is not None} - constants = {i: constexpr(s) for i, s in enumerate(signature)} + constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)} constants = {k: v for k, v in constants.items() if v is not None} - signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constants} + signature = { + kernel.arg_names[i]: s.split(":")[0] + for i, s in enumerate(signature) + if kernel.arg_names[i] not in constants + } const_sig = 'x'.join([str(v) for v in constants.values()]) - doc_string = [f"{kernel.arg_names[i]}={constants[i]}" for i in constants.keys()] + doc_string = [f"{k}={v}" for k, v in constants.items()] doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] # compile ast into cubin for h in hints.values(): assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" - divisible_by_16 = [i for i, h in hints.items() if h == 16] - equal_to_1 = [i for i, h in hints.items() if h == 1] - attrs = triton.compiler.AttrsDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) - for i in equal_to_1: - constants.update({i: 1}) + attrs = triton.backends.compiler.AttrsDescriptor.from_hints(hints) + for p, v in attrs.get_constants().items(): + constants.update({kernel.arg_names[p]: v}) src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} ccinfo = triton.compile(src, options=opts) arg_names = [] arg_types = [] - for i in signature.keys(): - if i not in equal_to_1: - arg_names += [kernel.arg_names[i]] - arg_types += [signature[i]] + arg_names_not_1 = [] + arg_types_not_1 = [] + for i, arg_name in enumerate(kernel.arg_names): + if arg_name not in constants: + arg_names.append(arg_name) + arg_types.append(signature[arg_name]) + arg_names_not_1.append(arg_name) + arg_types_not_1.append(signature[arg_name]) + elif i in attrs.equal_to_1: + arg_names.append(arg_name) + arg_types.append(signature[arg_name]) # dump C stub code suffix = kernel_suffix(signature.values(), attrs) @@ -126,10 +136,10 @@ def constexpr(s): "triton_kernel_name": args.kernel_name, "bin_size": len(hex_), "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), - "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), - "full_signature": ", ".join([f"{ty_to_cpp(signature[i])} {kernel.arg_names[i]}" for i in signature.keys()]), - "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]), - "num_args": len(arg_names), + "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]), + "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), + "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1]), + "num_args": len(arg_names_not_1), "kernel_docstring": doc_string, "shared": ccinfo.metadata.shared, "num_warps": args.num_warps, diff --git a/python/triton/tools/disasm.py b/python/triton/tools/disasm.py index 1e309a2e4..002c4e9b5 100644 --- a/python/triton/tools/disasm.py +++ b/python/triton/tools/disasm.py @@ -26,8 +26,6 @@ import subprocess import tempfile -from ..common.backend import path_to_cuobjdump, path_to_nvdisasm - FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') FNAME_RE = re.compile(r'\s*Function : (\w+)\s*') @@ -77,10 +75,14 @@ def get_sass(cubin_asm, fun=None): return sass +@functools.lru_cache() +def path_to_cuobjdump(): + from triton.backends.nvidia.compiler import _path_to_binary + return _path_to_binary("cuobjdump") + + def extract(file_path, fun): cuobjdump, _ = path_to_cuobjdump() - nvdisasm, _ = path_to_nvdisasm() - os.environ["NVDISASM_PATH"] = nvdisasm if fun is None: sass_str = subprocess.check_output([cuobjdump, "-sass", file_path]) else: diff --git a/python/triton/tools/experimental_descriptor.py b/python/triton/tools/experimental_descriptor.py new file mode 100644 index 000000000..6077cab6f --- /dev/null +++ b/python/triton/tools/experimental_descriptor.py @@ -0,0 +1,32 @@ +import torch + +import triton + + +class TmaDescKernelParam: + TMA_DESC_SIZE = 128 + + def __init__(self, ptr, dims, block_dims, element_size): + self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.uint8, device="cpu") + assert len(dims) == len(block_dims) + assert 1 <= len(dims) <= 2 + assert self.desc.data_ptr() % 64 == 0 + + if len(dims) == 1: + triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size, + self.desc.data_ptr()) + else: + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], + block_dims[1], element_size, self.desc.data_ptr()) + + # Return a CUtensorMap* pointer in host memory + def tma_desc_cpu_ptr(self): + return self.desc.data_ptr() + + +def create_1d_tma_descriptor(ptr, dim, block_dim, element_size): + return TmaDescKernelParam(ptr, [dim], [block_dim], element_size) + + +def create_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size): + return TmaDescKernelParam(ptr, [dim1, dim0], [block_dim1, block_dim0], element_size) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 5ada4d4d5..e0220a45c 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -123,7 +123,7 @@ def benchmark(size, provider): ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) if provider == 'triton': ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles) - gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6 + gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index 04873cd51..d08afb1e5 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -28,6 +28,15 @@ from triton.runtime import driver +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def is_cdna(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942', + 'gfx90a', 'gfx908') + + def naive_softmax(x): """Compute row-wise softmax of X using native pytorch @@ -123,7 +132,7 @@ def softmax(x): # way so you don't have to come up with manual heuristics yourself. num_warps = 8 - # Number of software piepling stages. + # Number of software pipelining stages. num_stages = 4 if SIZE_SMEM > 200000 else 2 # Allocate output @@ -137,7 +146,25 @@ def softmax(x): kernel._init_handles() n_regs = kernel.n_regs size_smem = kernel.metadata.shared - occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) + if is_hip(): + # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available. + # However, this is not always the case. In most cases all registers can be used as regular purpose registers. + # ISA SECTION (3.6.4 for CDNA3) + # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used + # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total + # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is + # not required to be equal numbers of both types. + if is_cdna(): + NUM_GPRS = NUM_REGS * 2 + + # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor. + # When we divide this number with WARP_SIZE we get maximum number of waves that can + # execute on a CU (multi-processor) in parallel. + MAX_NUM_THREADS = properties["max_threads_per_sm"] + max_num_waves = MAX_NUM_THREADS // WARP_SIZE + occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps + else: + occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) occupancy = min(occupancy, SIZE_SMEM // size_smem) num_programs = NUM_SM * occupancy kernels[BLOCK_SIZE] = (kernel, num_programs) @@ -204,7 +231,7 @@ def benchmark(M, N, provider): ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1)) if provider == 'triton': ms = triton.testing.do_bench(lambda: softmax(x)) - gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) + gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 69522f787..815350905 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -130,7 +130,7 @@ # group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) # # *Within groups*, programs are ordered in a column-major order # # Row-id of the program in the *launch grid* -# pid_m = first_pid_m + (pid % group_size_m) +# pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) # # Col-id of the program in the *launch grid* # pid_n = (pid % num_pid_in_group) // group_size_m # @@ -206,19 +206,19 @@ def get_hip_autotune_config(): return [ triton.Config( {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, - num_warps=4, num_stages=0), + num_warps=4, num_stages=2), triton.Config( {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, - num_warps=8, num_stages=0), + num_warps=8, num_stages=2), triton.Config( {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, - num_warps=8, num_stages=0), + num_warps=8, num_stages=2), triton.Config( {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, - num_warps=4, num_stages=0), + num_warps=4, num_stages=2), triton.Config( {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, - num_warps=4, num_stages=0), + num_warps=4, num_stages=2), ] @@ -269,7 +269,7 @@ def matmul_kernel( group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m # ---------------------------------------------------------- diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index 457cef466..a234153a0 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -232,8 +232,8 @@ def forward(ctx, x, normalized_shape, weight, bias, eps): # reshape input data into 2D tensor x_arg = x.reshape(-1, x.shape[-1]) M, N = x_arg.shape - mean = torch.empty((M, ), dtype=torch.float32, device='cuda') - rstd = torch.empty((M, ), dtype=torch.float32, device='cuda') + mean = torch.empty((M, ), dtype=torch.float32, device=x.device) + rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) @@ -262,9 +262,9 @@ def backward(ctx, dy): if N <= 4096: GROUP_SIZE_M = 128 if N <= 1024: GROUP_SIZE_M = 256 # allocate output - locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda') - _dw = torch.empty((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device) - _db = torch.empty((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device) + locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device) + _dw = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device) + _db = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device) dw = torch.empty((N, ), dtype=w.dtype, device=w.device) db = torch.empty((N, ), dtype=w.dtype, device=w.device) dx = torch.empty_like(dy) @@ -353,12 +353,12 @@ def y_fwd(): # forward pass if mode == 'forward': - gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6 + gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500) # backward pass if mode == 'backward': y = y_fwd() - gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6 # noqa: F811, E704 + gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) # noqa: F811, E704 ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), quantiles=quantiles, grad_to_none=[x], rep=500) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 09b1f4ba0..09efc06de 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -3,11 +3,13 @@ =============== This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + Credits: OpenAI kernel team Extra Credits: -- Original flash attention paper (https://arxiv.org/abs/2205.14135) -- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) + +* Original flash attention paper (https://arxiv.org/abs/2205.14135) +* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) """ @@ -75,13 +77,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q, # return acc, l_i, m_i -# We don't run auto-tuning every time to keep the tutorial fast. Uncommenting +# We don't run auto-tuning every time to keep the tutorial fast. Keeping # the code below and commenting out the equivalent parameters is convenient for # re-tuning. configs = [ triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ for BM in [64, 128]\ - for BN in [64, 128]\ + for BN in [32, 64]\ for s in ([1] if is_hip() else [3, 4, 7])\ for w in [4, 8]\ ] @@ -95,7 +97,7 @@ def keep(conf): return True -@triton.autotune(list(filter(keep, configs)), key=["N_CTX"]) +@triton.autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"]) @triton.jit def _attn_fwd(Q, K, V, sm_scale, M, Out, # stride_qz, stride_qh, stride_qm, stride_qk, # @@ -103,9 +105,9 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, # stride_vz, stride_vh, stride_vk, stride_vn, # stride_oz, stride_oh, stride_om, stride_on, # Z, H, N_CTX, # + HEAD_DIM: tl.constexpr, # BLOCK_M: tl.constexpr, # BLOCK_N: tl.constexpr, # - HEAD_DIM: tl.constexpr, # STAGE: tl.constexpr # ): tl.static_assert(BLOCK_N <= HEAD_DIM) @@ -442,7 +444,7 @@ def forward(ctx, q, k, v, causal, sm_scale): # shape constraints HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] # when v is in float8_e5m2 it is transposed. - HEAD_DIM_V = v.shape[-2] if v.dtype == torch.float8_e5m2 else v.shape[-1] + HEAD_DIM_V = v.shape[-1] assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V assert HEAD_DIM_K in {16, 32, 64, 128, 256} o = torch.empty_like(q) @@ -551,7 +553,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) rtol = 0.0 # Relative tolerance workaround for known hardware limitation of MI200 GPU. - # For detailss see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": rtol = 1e-2 assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol) @@ -583,8 +585,8 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): (["flash"] if HAS_FLASH else []), line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + (["Flash-2"] if HAS_FLASH else []), - styles=[("red", "-"), ("blue", "-")], - ylabel="ms", + styles=[("red", "-"), ("blue", "-"), ("green", "-")], + ylabel="TFLOPS", plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}", args={ "H": N_HEADS, @@ -599,16 +601,15 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): @triton.testing.perf_report(configs) def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"): assert mode in ["fwd", "bwd"] - warmup = 25 - rep = 100 dtype = torch.float16 if "triton" in provider: - q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda", requires_grad=True) + q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) if mode == "fwd" and "fp8" in provider: q = q.to(torch.float8_e5m2) k = k.to(torch.float8_e5m2) + v = v.permute(0, 1, 3, 2).contiguous() v = v.permute(0, 1, 3, 2) v = v.to(torch.float8_e5m2) sm_scale = 1.3 @@ -617,7 +618,7 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, dev o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) - ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + ms = triton.testing.do_bench(fn) if provider == "flash": qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) fn = lambda: flash_attn_func(qkv, causal=causal) @@ -625,14 +626,14 @@ def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, dev o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) - ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + ms = triton.testing.do_bench(fn) flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM total_flops = 2 * flops_per_matmul if causal: total_flops *= 0.5 if mode == "bwd": total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) - return total_flops / ms * 1e-9 + return total_flops * 1e-12 / (ms * 1e-3) if __name__ == "__main__": diff --git a/python/tutorials/07-extern-functions.py b/python/tutorials/07-extern-functions.py index 74e4d6c20..bf5f0acf9 100644 --- a/python/tutorials/07-extern-functions.py +++ b/python/tutorials/07-extern-functions.py @@ -3,10 +3,11 @@ ============================== Triton can invoke a custom function from an external library. In this example, we will use the `libdevice` library to apply `asin` on a tensor. -Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html (CUDA) and/or https://github.com/ROCm/llvm-project/tree/amd-staging/amd/device-libs/ocml/src (HIP) regarding the semantics of all available libdevice functions. + +Please refer to `CUDA libdevice-users-guide `_ and/or `HIP device-lib source code `_ regarding the semantics of all available libdevice functions. + In `libdevice.py`, we try to aggregate functions with the same computation but different data types together. For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`. -Using triton, you can simply call `tl.math.asin`. Triton automatically selects the correct underlying device function to invoke based on input and output types. """ @@ -18,8 +19,12 @@ import triton import triton.language as tl +import inspect +import os from triton.language.extra import libdevice +from pathlib import Path + @triton.jit def asin_kernel( @@ -56,13 +61,36 @@ def asin_kernel( print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(output_torch - output_triton))}') + # %% # Customize the libdevice library path # ------------------------------------- # We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel. +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +current_file = inspect.getfile(inspect.currentframe()) +current_dir = Path(os.path.dirname(os.path.abspath(current_file))) + +if is_cuda(): + libdir = current_dir.parent.parent / 'third_party/nvidia/backend/lib' + extern_libs = {'libdevice': str(libdir / 'libdevice.10.bc')} +elif is_hip(): + libdir = current_dir.parent.parent / 'third_party/amd/backend/lib' + extern_libs = {} + libs = ["ocml", "ockl"] + for lib in libs: + extern_libs[lib] = str(libdir / f'{lib}.bc') +else: + raise RuntimeError('unknown backend') output_triton = torch.empty_like(x) -asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024) +asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024, extern_libs=extern_libs) print(output_torch) print(output_triton) print(f'The maximum difference between torch and triton is ' diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py new file mode 100644 index 000000000..ba255dd76 --- /dev/null +++ b/python/tutorials/09-persistent-matmul.py @@ -0,0 +1,888 @@ +""" +Persistent Matmul +===================== +This script demonstrates persistent kernel implementations of matrix multiplication using Triton. +Various matmul methods are included, such as naive, persistent, and TMA (Tensor Memory Accelerator) based approaches. +The kernels support both FP16 and FP8 data types but the FP8 implementation is only available on CUDA devices with compute capability >= 9.0. + +Triton and cuBLAS implementations are benchmarked under different configurations and evaluated using the proton profiler. +Users can pass command-line arguments to specify matrix dimensions and iteration steps flexibly. + +.. code-block:: bash + + # FP8 + python 09-persistent-matmul.py --prec fp8 --K_range 128 1024 --K_step 128 + + # FP16 + python 09-persistent-matmul.py --prec fp16 --K_range 128 1024 --K_step 128 + +Note that currently this tutorial will fail on devices with a small shared memory size, such as RTX-4090. +""" + +import argparse +import time + +import torch +import triton +import triton.language as tl +import triton.tools.experimental_descriptor +import triton.profiler as proton + +if torch.cuda.is_available(): + from triton._C.libtriton import nvidia + cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8) + cublas = nvidia.cublas.CublasLt(cublas_workspace) +else: + cublas = None + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def supports_tma(): + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 + + +def _matmul_launch_metadata(grid, kernel, args): + ret = {} + M, N, K = args["M"], args["N"], args["K"] + ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" + if "c_ptr" in args: + bytes_per_elem = args["c_ptr"].element_size() + else: + bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 + ret[f"flops{bytes_per_elem * 8}"] = 2. * M * N * K + ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N) + return ret + + +HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) + +if HAS_TMA_DESC: + print("TMA benchmarks will be running with experimental grid constant TMA descriptor.", ) +else: + print("TMA benchmarks will be running without grid constant TMA descriptor.", ) + + +class TmaAutoTuneHelper: + + # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 + class KernelParamWrapper: + + def __init__(self, desc): + self.desc = desc + + def tma_desc_cpu_ptr(self): + return self.desc.data_ptr() + + TMA_SIZE = 128 + + def __init__(self): + self.fill_1d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_1d_tma_descriptor) + self.fill_2d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_2d_tma_descriptor) + if HAS_TMA_DESC: + self.descriptors = {} + else: + self.cuda_descriptors = {} + + # Call this method outside of the lambda function for grid size + def init_tma_descriptor(self, name): + if HAS_TMA_DESC: + self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8) + else: + self.cuda_descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8) + + # Call this method inside the lambda function for grid size + def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): + if HAS_TMA_DESC: + desc_x = self.descriptors[name] + assert desc_x.data_ptr() % 64 == 0 + self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, desc_x.data_ptr()) + else: + desc_x = self.cuda_descriptors[name] + buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) + self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.data_ptr()) + desc_x.copy_(buf_x, non_blocking=True) + + # Call this method inside the lambda function for grid size + def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size): + if HAS_TMA_DESC: + desc_x = self.descriptors[name] + assert desc_x.data_ptr() % 64 == 0 + self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()) + else: + desc_x = self.cuda_descriptors[name] + buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) + self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()) + desc_x.copy_(buf_x, non_blocking=True) + + def get_tma_descriptor_kernel_param(self, name): + if HAS_TMA_DESC: + assert self.descriptors[name] is not None + return self.KernelParamWrapper(self.descriptors[name]) + else: + assert self.cuda_descriptors[name] is not None + return self.cuda_descriptors[name] + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if (c_ptr.dtype.element_ty == tl.float8e4nv): + c = accumulator.to(tl.float8e4nv) + else: + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b): + configs = { + torch.float8_e4m3fn: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4, + "num_warps": 8 + }, torch.float16: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3, + "num_warps": 8 + } + } + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + M, K = a.shape + K, N = b.shape + dtype = a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # + BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # + BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # + GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # + num_stages=configs[dtype]["num_stages"], # + num_warps=configs[dtype]["num_warps"], # + ) + return c + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + NUM_SMS: tl.constexpr, # + ): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + pid_m = 0 + pid_n = 0 + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + + if ki == k_tiles - 1: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if (c_ptr.dtype.element_ty == tl.float8e4nv): + c = accumulator.to(tl.float8e4nv) + else: + c = accumulator.to(tl.float16) + tl.store(c_ptrs, c, mask=c_mask) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + +def matmul_persistent(a, b): + configs = { + torch.float8_e4m3fn: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4, + "num_warps": 8 + }, torch.float16: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3, + "num_warps": 8 + } + } + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + M, K = a.shape + K, N = b.shape + dtype = a.dtype + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=dtype) + # 1D launch kernel where each block gets its own program. + grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) + matmul_kernel_persistent[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # + BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # + BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # + GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # + NUM_SMS=NUM_SMS, # + num_stages=configs[dtype]["num_stages"], # + num_warps=configs[dtype]["num_warps"], # + ) + return c + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, # + M, N, K, # + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + FP8_OUTPUT: tl.constexpr, # + NUM_SMS: tl.constexpr): # + dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + offs_k = ki * BLOCK_SIZE_K + + a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype) + b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype) + accumulator = tl.dot(a, b.T, accumulator) + + if ki == k_tiles - 1: + c = accumulator.to(dtype) + + tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn]) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + +def matmul_tma_persistent(a, b): + # Autotuner does not work with TMA. Use manual config. + configs = { + torch.float8_e4m3fn: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4, + "num_warps": 8 + }, torch.float16: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3, + "num_warps": 8 + } + } + + # Check constraints. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed + assert a.dtype == b.dtype, "Incompatible dtypes" + + M, K = a.shape + N, K = b.shape + dtype = a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + desc_a = triton.tools.experimental_descriptor.create_2d_tma_descriptor(a.data_ptr(), M, K, + configs[dtype]["BLOCK_SIZE_M"], + configs[dtype]["BLOCK_SIZE_K"], + a.element_size()) + desc_b = triton.tools.experimental_descriptor.create_2d_tma_descriptor(b.data_ptr(), N, K, + configs[dtype]["BLOCK_SIZE_N"], + configs[dtype]["BLOCK_SIZE_K"], + b.element_size()) + desc_c = triton.tools.experimental_descriptor.create_2d_tma_descriptor(c.data_ptr(), M, N, + configs[dtype]["BLOCK_SIZE_M"], + configs[dtype]["BLOCK_SIZE_N"], + c.element_size()) + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) + matmul_kernel_tma_persistent[grid]( + desc_a, desc_b, desc_c, # + M, N, K, # + BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # + BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # + BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # + GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # + FP8_OUTPUT=dtype == torch.float8_e4m3fn, # + NUM_SMS=NUM_SMS, # + num_stages=configs[dtype]["num_stages"], # + num_warps=configs[dtype]["num_warps"], # + ) + return c + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_device_tma_persistent(workspace_ptr, # + tiles_per_update: tl.constexpr, # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + NUM_SMS: tl.constexpr): # + # Matmul using TMA and device-side descriptor creation + dtype = c_ptr.dtype.element_ty + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + TMA_SIZE: tl.constexpr = 128 + workspace_base = workspace_ptr + start_pid * 3 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + c_desc_ptr = workspace_base + 2 * TMA_SIZE + + tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K], + element_ty=a_ptr.dtype.element_ty) + tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr, + load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K], + element_ty=b_ptr.dtype.element_ty) + tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N], + element_ty=c_ptr.dtype.element_ty) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + ni = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + ni += 1 + + # Simulate a grouped gemm + if ni == tiles_per_update: + tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr, + load_size=[BLOCK_SIZE_M, + BLOCK_SIZE_K], global_size=[M, K], + element_ty=a_ptr.dtype.element_ty) + tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr, + load_size=[BLOCK_SIZE_N, + BLOCK_SIZE_K], global_size=[N, K], + element_ty=b_ptr.dtype.element_ty) + tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr, + load_size=[BLOCK_SIZE_M, + BLOCK_SIZE_N], global_size=[M, N], + element_ty=c_ptr.dtype.element_ty) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + ni = 0 + + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + offs_k = ki * BLOCK_SIZE_K + + a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype) + b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype) + accumulator = tl.dot(a, b.T, accumulator) + + if ki == k_tiles - 1: + c = accumulator.to(dtype) + + tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn]) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + +def matmul_device_tma_persistent(a, b, tiles_per_update): + # Autotuner does not work with TMA. Use manual config. + configs = { + torch.float8_e4m3fn: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4, + "num_warps": 8 + }, torch.float16: { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3, + "num_warps": 8 + } + } + + # Check constraints. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed + assert a.dtype == b.dtype, "Incompatible dtypes" + + M, K = a.shape + N, K = b.shape + dtype = a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + tma_size = 128 + workspace = torch.empty(NUM_SMS * 3 * tma_size, dtype=torch.uint8, device="cuda") + + grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) + matmul_kernel_device_tma_persistent[grid]( + workspace, # + tiles_per_update, # + a, b, c, # + M, N, K, # + BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # + BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # + BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # + GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # + NUM_SMS=NUM_SMS, # + num_stages=configs[dtype]["num_stages"], # + num_warps=configs[dtype]["num_warps"], # + ) + return c + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "NUM_CONSUMER_GROUPS": 2, + }, + num_stages=2, + num_warps=4, + num_consumer_groups=2, + num_buffers_warp_spec=3, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "NUM_CONSUMER_GROUPS": 1, + }, + num_stages=3, + num_warps=4, + num_consumer_groups=0, # disable warp specialization + num_buffers_warp_spec=3, + ), + ], + key=["M", "N", "K"], + use_cuda_graph=True, +) +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_persistent_tma_ws_cooperative_kernel( + a_desc_ptr, + b_desc_ptr, + c_desc_ptr, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + FP8_OUTPUT: tl.constexpr, # + NUM_CONSUMER_GROUPS: tl.constexpr, +): + dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16 + num_tiles = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) + for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)): + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + offs_k = 0 + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl._experimental_descriptor_load( + a_desc_ptr, + [offs_am, offs_k], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + dtype, + ) + b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype) + + accumulator = tl.dot(a, b.T, accumulator) + offs_k += BLOCK_SIZE_K + + c = accumulator.to(dtype) + tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn]) + + +def matmul_persistent_tma_ws_cooperative(a, b): + # Check constraints. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + M, K = a.shape + N, K = b.shape + dtype = a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + + desc_helper = TmaAutoTuneHelper() + desc_helper.init_tma_descriptor("a") + desc_helper.init_tma_descriptor("b") + desc_helper.init_tma_descriptor("c") + + def grid(META): + nonlocal desc_helper + desc_helper.fill_2d_tma_descriptor( + "a", + a.data_ptr(), + M, + K, + META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"], + META["BLOCK_SIZE_K"], + a.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "b", + b.data_ptr(), + N, + K, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + b.element_size(), + ) + desc_helper.fill_2d_tma_descriptor( + "c", + c.data_ptr(), + M, + N, + META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"], + META["BLOCK_SIZE_N"], + c.element_size(), + ) + return (min( + NUM_SMS, + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ), ) + + desc_a = desc_helper.get_tma_descriptor_kernel_param("a") + desc_b = desc_helper.get_tma_descriptor_kernel_param("b") + desc_c = desc_helper.get_tma_descriptor_kernel_param("c") + + matmul_persistent_tma_ws_cooperative_kernel[grid]( + desc_a, desc_b, desc_c, # + M, N, K, # + FP8_OUTPUT=dtype == torch.float8_e4m3fn, # + ) + return c + + +def cublas_matmul(a, b): + # Check constraints. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed + M, K = a.shape + N, K = b.shape + dtype = a.dtype + c = torch.empty((M, N), device=a.device, dtype=dtype) + bytes_per_elem = a.element_size() + flops_str = f"flops{bytes_per_elem * 8}" + with proton.scope(f"cublas [M={M}, N={N}, K={K}]", + {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}): + cublas.matmul(a, b, c) + return c + + +def torch_matmul(a, b): + M, K = a.shape + N, K = b.shape + bytes_per_elem = a.element_size() + flops_str = f"flops{bytes_per_elem * 8}" + with proton.scope(f"torch [M={M}, N={N}, K={K}]", + {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}): + c = torch.matmul(a, b.T) + return c + + +def bench(K, dtype, tiles_per_update, reps=10): + M = 8192 + N = 8192 + a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) + b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype) + + b = b.T.contiguous() + + proton.activate(0) + + if cublas is not None: + for _ in range(reps): + cublas_matmul(a, b) + time.sleep(0.01) + if dtype == torch.float16: + for _ in range(reps): + torch_matmul(a, b) + time.sleep(0.01) + for _ in range(reps): + matmul(a, b.T) + time.sleep(0.01) + for _ in range(reps): + matmul_persistent(a, b.T) + time.sleep(0.01) + if supports_tma(): + for _ in range(reps): + matmul_tma_persistent(a, b) + time.sleep(0.01) + for _ in range(reps): + matmul_persistent_tma_ws_cooperative(a, b) + time.sleep(0.01) + with proton.scope( + f"matmul_kernel_device_tma_persistent [M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}]"): + for _ in range(reps): + matmul_device_tma_persistent(a, b, tiles_per_update) + time.sleep(0.01) + + proton.deactivate(0) + + +def validate(M, N, K, dtype, tiles_per_update): + a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) + b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype) + b = b.T.contiguous() + + torch_result = torch_matmul(a, b) if dtype == torch.float16 else None + cublas_result = cublas_matmul(a, b) if cublas is not None else None + naive_result = matmul(a, b.T) + persistent_result = matmul_persistent(a, b.T) + tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None + device_tma_persistent_result = matmul_device_tma_persistent(a, b, tiles_per_update) if supports_tma() else None + matmul_persistent_tma_ws_cooperative_result = matmul_persistent_tma_ws_cooperative(a, b) if supports_tma() else None + + if torch_result is not None: + naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16), + atol=1.0) else "❌" + if cublas_result is not None: + naive_vs_cublas = "✅" if torch.allclose(naive_result.to(torch.float16), cublas_result.to(torch.float16), + atol=1.0) else "❌" + naive_vs_persistent = "✅" if torch.allclose(naive_result.to(torch.float16), persistent_result.to(torch.float16), + atol=1.0) else "❌" + if tma_persistent_result is not None: + naive_vs_tma_persistent = "✅" if torch.allclose(cublas_result.to(torch.float16), + tma_persistent_result.to(torch.float16), atol=1.0) else "❌" + if device_tma_persistent_result is not None: + naive_vs_device_tma_persistent = "✅" if torch.allclose(cublas_result.to( + torch.float16), device_tma_persistent_result.to(torch.float16), atol=1.0) else "❌" + if matmul_persistent_tma_ws_cooperative_result is not None: + naive_vs_matmul_persistent_tma_ws_cooperative = "✅" if torch.allclose( + cublas_result.to(torch.float16), matmul_persistent_tma_ws_cooperative_result.to(torch.float16), + atol=1.0) else "❌" + print(f"M={M}, N={N}, K={K} verification naive vs: ", end="") + if torch_result is not None: + print(f"torch: {naive_vs_torch} ", end="") + if cublas_result is not None: + print(f"cublas: {naive_vs_cublas} ", end="") + print(f"persistent: {naive_vs_persistent} ", end="") + if tma_persistent_result is not None: + print(f"TMA persistent: {naive_vs_tma_persistent} ", end="") + if device_tma_persistent_result is not None: + print(f"Device TMA persistent: {naive_vs_device_tma_persistent} ", end="") + if matmul_persistent_tma_ws_cooperative_result is not None: + print(f"TMA persistent with warp specialization: {naive_vs_matmul_persistent_tma_ws_cooperative} ", end="") + print() + + +def show_profile(precision, profile_name): + import triton.profiler.viewer as proton_viewer + metrics = ["time/ms"] + if precision == 'fp8': + metrics = ["tflop8/s"] + metrics + elif precision == 'fp16': + metrics = ["tflop16/s"] + metrics + file_name = f"{profile_name}.hatchet" + proton_viewer.parse(metrics, file_name, depth=100) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-K", type=int, required=False, default=512) + parser.add_argument("--K_range", type=int, nargs=2) + parser.add_argument("--K_step", type=int, default=512) + parser.add_argument( + "--tiles_per_update", + type=int, + default=1, + help= + "Number of output tiles calculated for each update of the tma descriptor in matmul_device_tma_persistent_kernel", + ) + parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16") + args = parser.parse_args() + + if args.prec == 'fp8' and (not hasattr(torch, "float8_e4m3fn") or not is_cuda()): + print("This example requires CUDA with fp8 support.") + exit(1) + + dtype = torch.float8_e4m3fn if args.prec == 'fp8' else torch.float16 + + if args.K and args.K_range is None: + args.K_range = [args.K, args.K] + args.K_step = 1 # doesn't matter as long as it's not 0 + + torch.manual_seed(0) + + validate(32, 32, 32, dtype, args.tiles_per_update) + validate(8192, 8192, 512, dtype, args.tiles_per_update) + + proton.start("matmul", hook="triton") + for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): + bench(K, dtype, args.tiles_per_update) + proton.finalize() + show_profile(args.prec, "matmul") diff --git a/reports/v0.1.0/release_notes_v0.1.0.md b/reports/v0.1.0/release_notes_v0.1.0.md new file mode 100644 index 000000000..20abc9e15 --- /dev/null +++ b/reports/v0.1.0/release_notes_v0.1.0.md @@ -0,0 +1,35 @@ +[中文版](./release_notes_v0.1.0_cn.md) + +## FlagTree 0.1.0 Release + +### Highlights + +FlagTree's initial release is built on Triton 3.1, introducing support for diverse AI chip backends. In its early stage, the project aims to maintain compatibility with existing backend adaptation solutions while unifying the codebase to enable rapid single-version multi-backend support. + +### New features + +* Multi-Backend Support +Currently supported backends include iluvatar, xpu (klx), mthreads, and cambricon. + +* Dual Compilation Path Support +In this initial phase, the project provides basic compatibility for both TritonGPU dialect and Linalg dialect compilation paths. + +* Pluggable High-Variance Module Architecture +Enables chip-specific backend customization through a plugin architecture. These non-generic modules are maintained by respective chip vendors and maintain structural consistency with the FlagTree main repository through engineering practices. + +* Cross-Compilation and Rapid Validation Capabilities +For developer convenience, FlagTree supports compilation on any hardware platform and Python3 import functionality. Cross-compilation is possible when build and runtime environments are compatible (specifically matching or compatible versions of cpython, glibc, glibcxx, and cxxabi), allowing compiled artifacts to run across platforms with corresponding chip deployments. + +* CI/CD Integration +The project implements comprehensive CI/CD pipelines for iluvatar, xpu, mthreads, nvidia, and other backends, enabling end-to-end validation from compilation to testing correctness. + +* Quality Management Framework +Beyond CI/CD coverage for multiple backend chips, FlagTree implements quality and compliance assurance mechanisms including Contributor License Agreement (CLA) signing and security compliance scanning. + +### Known issues + +* Current lack of support for triton-opt, proton, and related tools. + +### Looking ahead + +FlagTree will continue investing in the Triton ecosystem, focusing on tracking Triton version updates, integrating AI chip backends, improving compilation efficiency, and enhancing cross-platform compatibility. Additionally, FlagTree will explore balancing general usability with chip-specific optimization requirements, providing compatible language-level unified abstractions and explicit specifications for hardware storage hierarchies, parallelism levels, and acceleration units. diff --git a/reports/v0.1.0/release_notes_v0.1.0_cn.md b/reports/v0.1.0/release_notes_v0.1.0_cn.md new file mode 100644 index 000000000..4cd86926c --- /dev/null +++ b/reports/v0.1.0/release_notes_v0.1.0_cn.md @@ -0,0 +1,36 @@ +[English](./release_notes_v0.1.0.md) + +## FlagTree 0.1.0 Release + +### Highlights + +FlagTree 首次发布,基于 Triton 3.1 版本接入多元 AI 芯片后端。项目当前处于初期,目标是兼容各芯片后端现有适配方案,统一代码仓库,打造代码共建平台,快速实现单版本多后端支持。 + +### New features + +* 多后端支持 +目前支持的后端包括 iluvatar、xpu (klx)、mthreads、cambricon。 + +* 两种编译路径支持 +项目初期,对 TritonGPU dialect 或 Linalg dialect 两种编译路径作简单快速兼容。 + +* 高差异度模块插件化能力 +支持芯片后端定制化的高差异度模块以插件形式提供,这些非通用模块的代码由对应的芯片提供商自行维护,并通过工程化手段与 FlagTree 主仓库可保持同构设计。 + +* 交叉编译与快速验证能力 +为方便开发者简单快速验证,FlagTree 可以实现在任意硬件平台上编译及在 python3 中导入。如果编译环境和运行环境一致(一般指 cpython、glibc、glibcxx、cxxabi 版本对齐或兼容),可以实现交叉编译,即编译结果能够在实际搭载对应芯片的环 +境中跨平台运行。 + +* CI/CD 能力 +项目为 iluvatar、xpu、mthreads、nvidia 等后端搭建了 CI/CD,可以完整验证从编译到测试正确性的全流程。 + +* 质量管理能力 +FlagTree 除了建设 CI/CD 覆盖多后端芯片外,还搭建了贡献者许可协议(CLA)签署、安全合规扫描等机制做质量与合规保障。 + +### Known issues + +* triton-opt、proton 等工具目前不支持。 + +### Looking ahead + +FlagTree 将持续投入发展 Triton 生态,包括跟进 Triton 版本更迭,接入 AI 芯片后端,提升编译效率,优化跨平台兼容性。同时,FlagTree 将对兼顾通用性和芯片极致优化需求进行探索,兼容式地在语言层提供硬件的存储层次、并行层次、加速单>元等关键特征的统一抽象表达和显示指定能力。 diff --git a/reports/v0.1.0/report_tests.md b/reports/v0.1.0/report_tests.md new file mode 100644 index 000000000..4de2f98a4 --- /dev/null +++ b/reports/v0.1.0/report_tests.md @@ -0,0 +1,14 @@ + + +## FlagTree Test-Report + +FlagTree tests are validated on different backends, but currently the tests consist of only unit tests, which we will refine in the future for smaller or larger scale tests. + +### 1. Python unit test: + +| | default | xpu (klx) | iluvatar | mthreads | +|----------------------|---------------------------|-------------------------------------------|------------------------------------------------|------------------------------------------------| +| Number of unit tests | 11353 items | 12623 items | 14808 items | 10392 items | +| Script location | flagtree/python/test/unit | flagtree/third_party/xpu/python/test/unit | flagtree/third_party/iluvatar/python/test/unit | flagtree/third_party/mthreads/python/test/unit | +| command | python3 -m pytest -s | python3 -m pytest -s | python3 -m pytest -s | python3 -m pytest -s | +| passing rate | 100% | 100% | 100% | 100% | diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 2f73e0880..4f3af58e8 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -41,7 +41,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-LABEL: alloc tt.func @alloc(%A : !tt.ptr) { // CHECK: %0 -> %0 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, mutable> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return } @@ -49,40 +49,40 @@ tt.func @alloc(%A : !tt.ptr) { tt.func @alloc_init(%A : !tt.ptr) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> // CHECK: %0 -> %0 - %cst1 = triton_gpu.local_alloc %cst0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } // CHECK-LABEL: trans tt.func @trans(%A : !tt.ptr) { // CHECK: %0 -> %0 - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED> + %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK: %1 -> %0 - %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED> -> !tt.memdesc<32x16xf16, #A_SHARED_T> + %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> tt.return } // CHECK-LABEL: subview -tt.func @subview(%A : !tt.memdesc<1x16x16xf16, #A_SHARED>) { +tt.func @subview(%A : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { %index = arith.constant 0 : i32 // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: %1 -> %0 - %cst1 = triton_gpu.memdesc_subview %a[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED> -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.memdesc_subview %a[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return } // CHECK-LABEL: if_alias tt.func @if_alias(%i1 : i1) { // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK: %1 -> %1 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: %2 -> %0,%1 - %cst2 = scf.if %i1 -> !tt.memdesc<16x16xf16, #A_SHARED> { - scf.yield %a : !tt.memdesc<16x16xf16, #A_SHARED> + %cst2 = scf.if %i1 -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> { + scf.yield %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } else { - scf.yield %b : !tt.memdesc<16x16xf16, #A_SHARED> + scf.yield %b : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } tt.return } @@ -90,11 +90,11 @@ tt.func @if_alias(%i1 : i1) { // CHECK-LABEL: for tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK: %1 -> %1 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK: %2 -> %2 - %c = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %c = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: %arg6 -> %0 // CHECK-NEXT: %arg7 -> %1 // CHECK-NEXT: %arg8 -> %2 @@ -102,8 +102,8 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-NEXT: %3#1 -> %0,%1 // CHECK-NEXT: %3#2 -> %0,%1,%2 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a, %b_shared = %b, %c_shared = %c) -> - (!tt.memdesc<16x16xf16, #A_SHARED>, !tt.memdesc<16x16xf16, #A_SHARED>, !tt.memdesc<16x16xf16, #A_SHARED>) { - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<16x16xf16, #A_SHARED>, !tt.memdesc<16x16xf16, #A_SHARED>, !tt.memdesc<16x16xf16, #A_SHARED> + (!tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { + scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } tt.return } @@ -111,11 +111,11 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-LABEL: for_if tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %0 -> %0 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: %1 -> %1 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: %2 -> %2 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: %arg7 -> %0 // CHECK-NEXT: %arg8 -> %1 // CHECK-NEXT: %arg9 -> %2 @@ -123,14 +123,14 @@ tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : // CHECK-NEXT: %3#1 -> %0,%1 // CHECK-NEXT: %3#2 -> %0,%1,%2 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> - (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { scf.if %i1 { %index = arith.constant 8 : i32 // CHECK-NEXT: %4 -> %0,%1 - %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<32xf16, #A_SHARED> + %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> scf.yield } - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } tt.return } @@ -138,11 +138,11 @@ tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : // CHECK-LABEL: for_for_if tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %0 -> %0 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: %1 -> %1 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: %2 -> %2 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: %arg7 -> %0 // CHECK-NEXT: %arg8 -> %1 // CHECK-NEXT: %arg9 -> %2 @@ -150,23 +150,23 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-NEXT: %3#1 -> %1 // CHECK-NEXT: %3#2 -> %2,%6,%6 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> - (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { // CHECK-NEXT: %arg11 -> %2,%6,%6 // CHECK-NEXT: %4 -> %2,%6,%6 - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { // CHECK-NEXT: %5 -> %6,%6 - %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED> { + %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> { // CHECK-NEXT: %6 -> %6 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } else { // CHECK-NEXT: %6 -> %6 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } - scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } - scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } tt.return } @@ -175,29 +175,29 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr, %arg4: !tt.ptr) { %idx = arith.constant 0 : i32 // CHECK: %0 -> %0 - %cst = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %cst = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: %1 -> %1 - %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: %2 -> %0 - %0 = triton_gpu.memdesc_subview %cst[%idx, %idx] : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<128x32xf16, #A_SHARED> + %0 = triton_gpu.memdesc_subview %cst[%idx, %idx] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> gpu.barrier // CHECK-NEXT: %3 -> %3 - %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: %5 -> %0,%1,%3 // CHECK-NEXT: %6 -> %0,%1,%3 // CHECK-NEXT: %7 -> %0,%1,%3 - cf.br ^bb1(%arg0, %cst, %cst_0, %cst_1 : index, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) -^bb1(%1: index, %2: !tt.memdesc<128x32xf16, #A_SHARED>, %3: !tt.memdesc<128x32xf16, #A_SHARED>, %4: !tt.memdesc<128x32xf16, #A_SHARED>): // 2 preds: ^bb0, ^bb2 + cf.br ^bb1(%arg0, %cst, %cst_0, %cst_1 : index, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) +^bb1(%1: index, %2: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, %3: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, %4: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>): // 2 preds: ^bb0, ^bb2 %5 = arith.cmpi slt, %1, %arg1 : index cf.cond_br %5, ^bb2, ^bb3 ^bb2: // pred: ^bb1 gpu.barrier %8 = arith.addi %1, %arg2 : index - cf.br ^bb1(%8, %4, %2, %3 : index, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) + cf.br ^bb1(%8, %4, %2, %3 : index, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) ^bb3: // pred: ^bb1 gpu.barrier // CHECK-NEXT: %10 -> %0 - %9 = triton_gpu.memdesc_subview %0[%idx, %idx] : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<128x32xf16, #A_SHARED> + %9 = triton_gpu.memdesc_subview %0[%idx, %idx] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return } diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index f157f7965..ebc5383b9 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -97,10 +97,12 @@ tt.func @sub() { %1 = arith.constant dense<1> : tensor<128xi32> // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = %2 = arith.subi %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %3 = arith.subi %1, %0 : tensor<128xi32> // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 129 - %3 = arith.constant dense<129> : tensor<128xi32> + %4 = arith.constant dense<129> : tensor<128xi32> // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 - %4 = arith.subi %3, %1 : tensor<128xi32> + %5 = arith.subi %4, %1 : tensor<128xi32> tt.return } @@ -399,8 +401,10 @@ tt.func @select(%arg0 : i1, %arg1 : tensor<4xi1>) { // ----- -tt.func @shift() { - // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = +tt.func @shift(%arg0: i32 {tt.divisibility = 4 : i32}) { + // CHECK: contiguity = [1], divisibility = [4], constancy = [128], constant_value = + %s = tt.splat %arg0 : i32 -> tensor<128xi32> + // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 %1 = arith.constant dense<8> : tensor<128xi32> @@ -412,6 +416,10 @@ tt.func @shift() { %4 = arith.shrsi %0, %2 : tensor<128xi32> // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 %5 = arith.shli %1, %2 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = + %6 = arith.shli %1, %s : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = + %7 = arith.shrsi %0, %s : tensor<128xi32> tt.return } @@ -827,3 +835,44 @@ tt.func @tensor_ptr(%arg0: !tt.ptr, 1>) { %0 = tt.load %arg0 : !tt.ptr, 1> tt.return } + + +// ----- + +// CHECK-LABEL: @chained_for +tt.func public @chained_for(%8: tensor<128x64x!tt.ptr> {tt.divisibility = 16 : i32}) { + // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xbf16> + // CHECK: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16 + %c16_i32 = arith.constant 16 : i32 + // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 + %c1_i32 = arith.constant 1 : i32 + // CHECK: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0 + %c0_i32 = arith.constant 0 : i32 + // CHECK: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 64], constant_value = 64 + %cst_0 = arith.constant dense<64> : tensor<128x64xi32> + // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = + %9 = scf.for %arg7 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg8 = %8) -> (tensor<128x64x!tt.ptr>) : i32 { + %11 = tt.addptr %arg8, %cst_0 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + scf.yield %11 : tensor<128x64x!tt.ptr> + } + // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = + // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = + %10 = scf.for %arg7 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg8 = %9) -> (tensor<128x64x!tt.ptr>) : i32 { + tt.store %arg8, %cst : tensor<128x64x!tt.ptr> + %11 = tt.addptr %arg8, %cst_0 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + scf.yield %11 : tensor<128x64x!tt.ptr> + } + tt.return +} + +// ----- + +// CHECK-LABEL: @int_min_does_not_underflow_in_analysis +module { + tt.func @int_min_does_not_underflow_in_analysis() -> i64 { + // CHECK: divisibility = [4611686018427387904] + %int_min = arith.constant -9223372036854775808 : i64 + tt.return %int_min : i64 + } +} diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index f7926de3a..a0719c974 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -12,6 +12,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +// CHECK-LABEL: empty +tt.func @empty(%A : !tt.ptr) { + %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> + tt.return + // CHECK: size = 0 +} + // CHECK-LABEL: matmul_loop tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> @@ -31,7 +39,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK: offset = 0, size = 4608 %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - // CHECK-NEXT: offset = 0, size = 4224 + // CHECK-NEXT: offset = 0, size = 4352 %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> @@ -59,14 +67,14 @@ tt.func @reusable(%A : !tt.ptr) { // CHECK-NEXT: offset = 0, size = 4608 %a1 = triton_gpu.convert_layout %a1_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %a2_ = tt.load %b_ptr, %cst3, %cst4 : tensor<32x128x!tt.ptr, #AL> - // CHECK-NEXT: offset = 0, size = 1152 + // CHECK-NEXT: offset = 0, size = 1088 %a2 = triton_gpu.convert_layout %a2_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT> %a3_ = tt.load %a_ptr, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> // CHECK-NEXT: offset = 0, size = 4608 %a3 = triton_gpu.convert_layout %a3_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> %c = tt.dot %a1, %a2, %c_init : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %a4_ = tt.load %b_ptr, %cst3, %cst4 : tensor<32x128x!tt.ptr, #AL> - // CHECK-NEXT: offset = 0, size = 1152 + // CHECK-NEXT: offset = 0, size = 1088 %a4 = triton_gpu.convert_layout %a4_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT> %c1 = tt.dot %a3, %a4, %c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> tt.return @@ -80,47 +88,47 @@ tt.func @reusable(%A : !tt.ptr) { // CHECK-LABEL: preallocate tt.func @preallocate(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 4096, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 0, size = 1024 - %c = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %c = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED> + triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 1024 - %cst4 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst4 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 6144, size = 2048 - %e = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED> - triton_gpu.local_dealloc %a : !tt.memdesc<32x16xf16, #A_SHARED> + %e = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %a : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 2048 - %d = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED> - triton_gpu.local_dealloc %b : !tt.memdesc<32x16xf16, #A_SHARED> + %d = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %b : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 10240, size = 2048 - %f = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst4 : !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %c : !tt.memdesc<32x16xf16, #A_SHARED> + %f = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst4 : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %c : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 0, size = 2048 - %cst5 = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED> + %cst5 = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 4096 - %g = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED> - triton_gpu.local_dealloc %e : !tt.memdesc<64x16xf16, #A_SHARED> + %g = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %e : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 4096 - %h = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED> - triton_gpu.local_dealloc %d : !tt.memdesc<64x16xf16, #A_SHARED> + %h = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %d : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 4096 - %i = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED> - triton_gpu.local_dealloc %f : !tt.memdesc<64x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst5 : !tt.memdesc<64x16xf16, #A_SHARED> + %i = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %f : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst5 : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return // CHECK-NEXT: size = 12288 } @@ -130,11 +138,11 @@ tt.func @preallocate(%A : !tt.ptr) { tt.func @unused(%A : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #AL> // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc %cst : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 0, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return // CHECK: size = 1024 } @@ -143,33 +151,33 @@ tt.func @unused(%A : !tt.ptr) { // CHECK-LABEL: longlive tt.func @longlive(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 512 - %cst4 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst4 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 512 - %cst5 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst5 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 512 - %cst6 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst6 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 1024 - %c = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst4 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %c = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst4 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 1024 - %d = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %d = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return // CHECK-NEXT: size = 4096 } @@ -178,43 +186,43 @@ tt.func @longlive(%A : !tt.ptr) { // CHECK-LABEL: multi_color tt.func @multi_color(%A : !tt.ptr) { // CHECK: offset = 0, size = 64 - %cst = triton_gpu.local_alloc : () -> !tt.memdesc<4x8xf16, #A_SHARED> + %cst = triton_gpu.local_alloc : () -> !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 1536, size = 32 - %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED> + %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 1664, size = 128 - %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<16x4xf16, #A_SHARED> + %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: scratch offset = 128, size = 1152 - %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> - %1 = triton_gpu.local_load %cst : !tt.memdesc<4x8xf16, #A_SHARED> -> tensor<4x8xf16, #AL> + %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %1 = triton_gpu.local_load %cst : !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x8xf16, #AL> // CHECK-NEXT: offset = 0, size = 128 - %cst_3 = triton_gpu.local_alloc : () -> !tt.memdesc<4x16xf16, #A_SHARED> - %2 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED> -> tensor<4x4xf16, #AL> + %cst_3 = triton_gpu.local_alloc : () -> !tt.memdesc<4x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %2 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x4xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 - %3 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> + %3 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> // CHECK-NEXT: offset = 0, size = 256 - %cst_4 = triton_gpu.local_alloc : () -> !tt.memdesc<4x32xf16, #A_SHARED> + %cst_4 = triton_gpu.local_alloc : () -> !tt.memdesc<4x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 256, size = 64 - %cst_5 = triton_gpu.local_alloc : () -> !tt.memdesc<4x8xf16, #A_SHARED> - %4 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED> -> tensor<4x8xf16, #AL> - %5 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED> -> tensor<4x8xf16, #AL> + %cst_5 = triton_gpu.local_alloc : () -> !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %4 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x8xf16, #AL> + %5 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x8xf16, #AL> // CHECK-NEXT: offset = 1024, size = 512 - %cst_6 = triton_gpu.local_alloc : () -> !tt.memdesc<8x32xf16, #A_SHARED> + %cst_6 = triton_gpu.local_alloc : () -> !tt.memdesc<8x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 1792, size = 128 - %cst_7 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED> - %6 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED> -> tensor<4x4xf16, #AL> + %cst_7 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %6 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 1024, size = 512 - %cst_8 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst_8 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 256, size = 32 - %cst_9 = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED> + %cst_9 = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst_10 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> - %7 = triton_gpu.local_load %cst_1 : !tt.memdesc<16x4xf16, #A_SHARED> -> tensor<16x4xf16, #AL> - %8 = triton_gpu.local_load %cst_4 : !tt.memdesc<4x32xf16, #A_SHARED> -> tensor<4x32xf16, #AL> + %cst_10 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %7 = triton_gpu.local_load %cst_1 : !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x4xf16, #AL> + %8 = triton_gpu.local_load %cst_4 : !tt.memdesc<4x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x32xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 - %9 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> + %9 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> %cst_11 = arith.constant dense<0.000000e+00> : tensor<4x4xf16, #AL> - %10 = triton_gpu.local_load %cst_7 : !tt.memdesc<2x32xf16, #A_SHARED> -> tensor<2x32xf16, #AL> + %10 = triton_gpu.local_load %cst_7 : !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<2x32xf16, #AL> %cst_12 = arith.constant dense<0.000000e+00> : tensor<4x16xf16, #AL> %cst_13 = arith.constant dense<0.000000e+00> : tensor<8x32xf16, #AL> // CHECK-NEXT: size = 1920 @@ -225,25 +233,25 @@ tt.func @multi_color(%A : !tt.ptr) { // CHECK-LABEL: multi_color_multi_rounds tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { // CHECK: offset = 0, size = 32 - %cst = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED> + %cst = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 1280, size = 128 - %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x4xf16, #A_SHARED> + %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 8192 - %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<1024x4xf16, #A_SHARED> + %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<1024x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: scratch offset = 128, size = 1152 - %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> - %1 = triton_gpu.local_load %cst : !tt.memdesc<4x4xf16, #A_SHARED> -> tensor<4x4xf16, #AL> + %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %1 = triton_gpu.local_load %cst : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 1152, size = 128 - %cst_3 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED> - %2 = triton_gpu.local_load %cst : !tt.memdesc<4x4xf16, #A_SHARED> -> tensor<4x4xf16, #AL> + %cst_3 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %2 = triton_gpu.local_load %cst : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 0, size = 512 - %cst_4 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> - %3 = triton_gpu.local_load %cst_0 : !tt.memdesc<16x4xf16, #A_SHARED> -> tensor<16x4xf16, #AL> - %4 = triton_gpu.local_load %cst_1 : !tt.memdesc<1024x4xf16, #A_SHARED> -> tensor<1024x4xf16, #AL> + %cst_4 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %3 = triton_gpu.local_load %cst_0 : !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x4xf16, #AL> + %4 = triton_gpu.local_load %cst_1 : !tt.memdesc<1024x4xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<1024x4xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 - %5 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> - %6 = triton_gpu.local_load %cst_3 : !tt.memdesc<2x32xf16, #A_SHARED> -> tensor<2x32xf16, #AL> + %5 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL> + %6 = triton_gpu.local_load %cst_3 : !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<2x32xf16, #AL> // CHECK-NEXT: size = 10240 tt.return } @@ -252,10 +260,10 @@ tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { // CHECK-LABEL: alloc tt.func @alloc(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return // CHECK-NEXT: size = 512 } @@ -264,10 +272,10 @@ tt.func @alloc(%A : !tt.ptr) { // CHECK-LABEL: dealloc tt.func @dealloc(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK: offset = 1024, size = 1024 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<32x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return // CHECK-NEXT: size = 2048 } @@ -288,8 +296,8 @@ tt.func @scratch() { // CHECK-LABEL: trans tt.func @trans(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED> - %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED> -> !tt.memdesc<32x16xf16, #A_SHARED_T> + %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> tt.return } @@ -297,37 +305,62 @@ tt.func @trans(%A : !tt.ptr) { // CHECK-LABEL: extract_slice tt.func @extract_slice(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> %index = arith.constant 0 : i32 - %cst1 = triton_gpu.memdesc_subview %cst0[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED> -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.memdesc_subview %cst0[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return // CHECK-NEXT: size = 512 } +// CHECK-LABEL: atomic_scalar +tt.func @atomic_scalar(%arg3: !tt.ptr) -> i32 { + // CHECK: offset = 0, size = 8192 + // CHECK: scratch offset = 8192, size = 4 + // CHECK: size = 8196 + %c0_i32 = arith.constant 0 : i32 + %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> + %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 + %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + tt.return %4 : i32 +} + +// CHECK-LABEL: atomic_scalar_no_use +tt.func @atomic_scalar_no_use(%arg3: !tt.ptr) { + // CHECK: offset = 0, size = 8192 + // CHECK: size = 8192 + %c0_i32 = arith.constant 0 : i32 + %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> + %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 + %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + tt.return +} + // B0 -> (B1) -> B0 // Memory used by B1 can be reused by B0. // CHECK-LABEL: if tt.func @if(%i1 : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> scf.if %i1 { // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return // CHECK-NEXT: size = 3072 } @@ -337,28 +370,28 @@ tt.func @if(%i1 : i1) { // CHECK-LABEL: if_else tt.func @if_else(%i1 : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> scf.if %i1 { // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 2048, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } else { // CHECK-NEXT: offset = 2048, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 3072, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 4096, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return // CHECK-NEXT: size = 5120 } @@ -368,13 +401,13 @@ tt.func @if_else(%i1 : i1) { // CHECK-LABEL: for tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { + scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } tt.return // CHECK-NEXT: size = 24576 @@ -383,18 +416,18 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-LABEL: for_if_slice tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { scf.if %i1 { %index = arith.constant 8 : i32 - %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<32xf16, #A_SHARED> + %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> scf.yield } - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } tt.return // CHECK-NEXT: size = 24576 @@ -404,16 +437,16 @@ tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr // CHECK-LABEL: for_use_ancestor tt.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { - %c0 = tt.trans %c_shared_init {order=array} : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<32x128xf16, #A_SHARED_T> + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { + %c0 = tt.trans %c_shared_init {order=array} : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 24576, size = 8192 - %c1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %b_shared, %a_shared: !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %c1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %b_shared, %a_shared: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } tt.return // CHECK-NEXT: size = 32768 @@ -424,28 +457,28 @@ tt.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr< // CHECK-LABEL: for_for_if tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED>) { - %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED> { + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { + %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> { // CHECK-NEXT: offset = 24576, size = 8192 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } else { // CHECK-NEXT: offset = 32768, size = 8192 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst1 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + scf.yield %cst1 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } - scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } - scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } // CHECK-NEXT: offset = 0, size = 8192 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return // CHECK-NEXT: size = 40960 } @@ -457,7 +490,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: alloc1 tt.func @alloc1(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return // CHECK-NEXT: size = 512 } @@ -465,7 +498,7 @@ tt.func @alloc1(%A : !tt.ptr) { // CHECK-LABEL: alloc2 tt.func @alloc2(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return // CHECK-NEXT: size = 1024 } @@ -474,10 +507,10 @@ tt.func @alloc2(%A : !tt.ptr) { tt.func @alloc3(%cond : i1) { scf.if %cond { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } else { // CHECK-NEXT: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } tt.return // CHECK-NEXT: size = 1024 @@ -499,7 +532,7 @@ tt.func @alloc4(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: single_call tt.func @single_call(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () @@ -510,7 +543,7 @@ tt.func @single_call(%A : !tt.ptr) { // CHECK-LABEL: multiple_calls tt.func @multiple_calls(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> @@ -525,9 +558,9 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> scf.if %cond { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 0, size = 1024 - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x32xf16, #AL>) -> !tt.memdesc<16x32xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst : (tensor<16x32xf16, #AL>) -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () } else { @@ -542,7 +575,7 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: for_calls tt.func @for_calls(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> %lb = arith.constant 0 : index %ub = arith.constant 10 : index @@ -558,7 +591,7 @@ tt.func @for_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: call_graph_1 tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: virtual offset = 0, size = 1024 tt.call @alloc3(%cond) : (i1) -> () tt.return @@ -568,7 +601,7 @@ tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: call_graph_2 tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: virtual offset = 0, size = 1024 tt.call @alloc4(%A, %cond) : (!tt.ptr, i1) -> () tt.return diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 747a63959..2054853b3 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -5,7 +5,6 @@ #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> @@ -47,10 +46,10 @@ tt.func @raw_single_block(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -60,14 +59,14 @@ tt.func @war_single_block(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: triton_gpu.local_alloc // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> // CHECK: gpu.barrier // CHECK-NEXT: %4 = triton_gpu.local_alloc - %4 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %4 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } @@ -77,25 +76,25 @@ tt.func @war_single_block_local_store(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK: triton_gpu.local_alloc // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #AL> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_store - triton_gpu.local_store %1, %2 : tensor<128x32xf16, #AL> -> !tt.memdesc<128x32xf16, #A_SHARED> + triton_gpu.local_store %1, %2 : tensor<128x32xf16, #AL> -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return } // CHECK-LABEL: scratch tt.func @scratch(%arg: tensor<16x16xf16, #AL>) { - %cst0 = triton_gpu.local_alloc %arg : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %arg : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load // CHECK: gpu.barrier // CHECK: tt.reduce - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> %2 = "tt.reduce" (%1) ({ ^bb0(%arg1: f16, %arg2: f16): %add = arith.addf %arg1, %arg2 : f16 @@ -106,34 +105,34 @@ tt.func @scratch(%arg: tensor<16x16xf16, #AL>) { // CHECK-LABEL: async_wait tt.func @async_wait(%arg: tensor<32x16xf16, #AL>) { - %cst0 = triton_gpu.local_alloc %arg : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %arg : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: triton_gpu.async_wait triton_gpu.async_wait {num = 4 : i32} // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<32x16xf16, #A_SHARED> -> tensor<32x16xf16, #AL> + %1 = triton_gpu.local_load %cst0 : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<32x16xf16, #AL> tt.return } // CHECK-LABEL: subview tt.func @subview() { %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #AL> - %a = triton_gpu.local_alloc %cst0 : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc %cst0 : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> %index = arith.constant 0 : i32 - %0 = triton_gpu.memdesc_subview %a[%index, %index] : !tt.memdesc<32x16xf16, #A_SHARED> -> !tt.memdesc<16x16xf16, #A_SHARED> + %0 = triton_gpu.memdesc_subview %a[%index, %index] : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } // CHECK-LABEL: trans -tt.func @trans(%a: !tt.memdesc<16x32xf16, #A_SHARED>) { +tt.func @trans(%a: !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK-NOT: gpu.barrier - %b = tt.trans %a {order=array} : !tt.memdesc<16x32xf16, #A_SHARED> -> !tt.memdesc<32x16xf16, #A_SHARED_T> + %b = tt.trans %a {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory> tt.return } @@ -143,31 +142,31 @@ tt.func @async_copy_global_to_local(%A : !tt.ptr, %i1 : i1) { %a_ptr = tt.splat %A : !tt.ptr -> tensor<16x16x!tt.ptr, #AL> %mask = tt.splat %i1 : i1 -> tensor<16x16xi1, #AL> %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %alloc = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, mutable> - %subview = triton_gpu.memdesc_subview %alloc[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED, mutable> -> !tt.memdesc<16x16xf16, #A_SHARED, mutable> - %1 = triton_gpu.async_copy_global_to_local %a_ptr, %subview : tensor<16x16x!tt.ptr, #AL> -> !tt.memdesc<16x16xf16, #A_SHARED, mutable> + %alloc = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %subview = triton_gpu.memdesc_subview %alloc[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %1 = triton_gpu.async_copy_global_to_local %a_ptr, %subview : tensor<16x16x!tt.ptr, #AL> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %4 = triton_gpu.local_load %subview : !tt.memdesc<16x16xf16, #A_SHARED, mutable> -> tensor<16x16xf16, #AL> + %4 = triton_gpu.local_load %subview : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> tt.return } // If branch inserted a barrier for %cst0, but else didn't, then the barrier should be inserted in the parent region // CHECK-LABEL: multi_blocks tt.func @multi_blocks(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } else { - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.yield } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %2 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -175,21 +174,21 @@ tt.func @multi_blocks(%i1 : i1) { // CHECK-LABEL: multi_blocks_join_barrier tt.func @multi_blocks_join_barrier(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } else { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } // CHECK-NOT: gpu.barrier // CHECK: tt.return - %a_ = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %a_ = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -197,25 +196,25 @@ tt.func @multi_blocks_join_barrier(%i1 : i1) { // CHECK-LABEL: multi_blocks_yield tt.func @multi_blocks_yield(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - %a = scf.if %i1 -> (!tt.memdesc<16x16xf16, #A_SHARED>) { + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %a = scf.if %i1 -> (!tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - scf.yield %1 : !tt.memdesc<16x16xf16, #A_SHARED> + %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - scf.yield %3 : !tt.memdesc<16x16xf16, #A_SHARED> + %2 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %3 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } - %a_ = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %a_ = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> // CHECK: triton_gpu.local_load // CHECK-NEXT: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %4 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %4 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -223,27 +222,27 @@ tt.func @multi_blocks_yield(%i1 : i1) { // CHECK-LABEL: multi_blocks_entry_no_shared tt.func @multi_blocks_entry_no_shared(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - %a = scf.if %i1 -> (!tt.memdesc<16x16xf16, #A_SHARED>) { + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %a = scf.if %i1 -> (!tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc // CHECK-NEXT: gpu.barrier // CHECK-NEXT: triton_gpu.local_load // CHECK-NEXT: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - %0 = triton_gpu.local_load %cst1 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - scf.yield %1 : !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %0 = triton_gpu.local_load %cst1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK-NOT: gpu.barrier // CHECK: triton_gpu.local_alloc - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - scf.yield %cst1 : !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -251,16 +250,16 @@ tt.func @multi_blocks_entry_no_shared(%i1 : i1) { // CHECK-LABEL: multi_blocks_noelse tt.func @multi_blocks_noelse(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -268,39 +267,39 @@ tt.func @multi_blocks_noelse(%i1 : i1) { // CHECK-LABEL: multi_blocks_nested_scf tt.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { scf.if %i2 { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %0 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> scf.yield } scf.yield } else { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %1 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> scf.yield } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %2 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> tt.return } // CHECK-LABEL: for tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %a0 = triton_gpu.local_load %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = triton_gpu.local_load %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -310,24 +309,24 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-LABEL: for_alias tt.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %a1 = triton_gpu.local_load %a_shared : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b1 = triton_gpu.local_load %b_shared : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %a1 = triton_gpu.local_load %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = triton_gpu.local_load %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -336,63 +335,63 @@ tt.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, % // CHECK-LABEL: for_reuse tt.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %a1 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b1 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %a1 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %a2 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b2 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %a2 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b2 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %2 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> tt.return } // CHECK-LABEL: for_reuse_nested tt.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %a1 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b1 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + %a1 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %a2 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b2 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %a2 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %a2 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b2 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %2 = triton_gpu.local_alloc %a2 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -400,25 +399,25 @@ tt.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr< // CHECK-LABEL: for_for_if tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED>) { - %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED> { + %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -427,30 +426,30 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-LABEL: for_if_for tt.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier - %c_blocked = triton_gpu.local_load %c_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %c_blocked = triton_gpu.local_load %c_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { - %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED> { + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } else { - %c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED>) { + %c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %c_blocked_next = triton_gpu.local_load %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - scf.yield %c_shared : !tt.memdesc<128x32xf16, #A_SHARED> + %c_blocked_next = triton_gpu.local_load %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %c_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %c_shared_ : !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %c_shared_ : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK-NOT: gpu.barrier - %b_blocked_next = triton_gpu.local_load %b_shared: !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - scf.yield %a_shared, %b_shared, %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %b_blocked_next = triton_gpu.local_load %b_shared: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %a_shared, %b_shared, %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -458,63 +457,88 @@ tt.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-LABEL: cf_if tt.func @cf_if(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> cf.br ^bb2 ^bb2: // 2 preds: ^bb0, ^bb1 // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } +// CHECK-LABEL: cf_if_else tt.func @cf_if_else(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - cf.br ^bb3(%1 : !tt.memdesc<16x16xf16, #A_SHARED>) + %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + cf.br ^bb3(%1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) ^bb2: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - cf.br ^bb3(%3 : !tt.memdesc<16x16xf16, #A_SHARED>) -^bb3(%arg: !tt.memdesc<16x16xf16, #A_SHARED>): // 2 preds: ^bb1, ^bb2 + %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + cf.br ^bb3(%3 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) +^bb3(%arg: !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>): // 2 preds: ^bb1, ^bb2 cf.br ^bb4 ^bb4: // pred: ^bb3 // CHECK: triton_gpu.local_load - %4 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %4 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %5 = triton_gpu.local_load %arg : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %5 = triton_gpu.local_load %arg : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } +// CHECK-LABEL: cf_if_else_return tt.func @cf_if_else_return(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - %b = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %b = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_load %b : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %b : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return ^bb2: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %3 = triton_gpu.local_load %b : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %3 = triton_gpu.local_load %b : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + tt.return +} + +// CHECK-LABEL: atomic_scalar +tt.func @atomic_scalar(%arg3: !tt.ptr) -> i32 { + // CHECK-NOT: gpu.barrier + %c0_i32 = arith.constant 0 : i32 + %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> + %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 + %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + tt.return %4 : i32 +} + +// CHECK-LABEL: atomic_scalar_no_use +tt.func @atomic_scalar_no_use(%arg3: !tt.ptr) { + %c0_i32 = arith.constant 0 : i32 + %1 = arith.constant dense<1.0> : tensor<128x32xf16, #AL> + %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %4 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 + // CHECK: gpu.barrier + // CHECK-NEXT: triton_gpu.local_load + %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -525,38 +549,38 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : // CHECK-LABEL: convert_layout1 tt.func @convert_layout1(%A : !tt.ptr) { // CHECK-NOT: gpu.barrier - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> tt.return } // CHECK-LABEL: convert_layout2 tt.func @convert_layout2(%A : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> - %1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK: triton_gpu.local_load // CHECK-NEXT: gpu.barrier // CHECK: triton_gpu.local_load - %3 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %4 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %3 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> + %4 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> tt.return } // CHECK-LABEL: convert_layout3 tt.func @convert_layout3(%cond : i1) { scf.if %cond { - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x64xf16, #A_SHARED> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x64xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK: triton_gpu.local_load // CHECK-NOT: gpu.barrier - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x64xf16, #A_SHARED> -> tensor<16x64xf16, #AL> + %1 = triton_gpu.local_load %0 : !tt.memdesc<16x64xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x64xf16, #AL> } else { - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK: triton_gpu.local_load // CHECK-NEXT: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> + %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> } tt.return } @@ -595,7 +619,7 @@ tt.func @single_call_no_sync(%A : !tt.ptr) { // CHECK-LABEL: multiple_calls tt.func @multiple_calls(%A : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.call @convert_layout1(%A) : (!tt.ptr) -> () %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> tt.call @convert_layout2(%A) : (!tt.ptr) -> () @@ -607,12 +631,12 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { scf.if %cond { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> %cst_ = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: tt.call // CHECK-NEXT: gpu.barrier tt.call @convert_layout1(%A) : (!tt.ptr) -> () - %cst1 = triton_gpu.local_alloc %cst_ : (tensor<16x32xf16, #AL>) -> !tt.memdesc<16x32xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst_ : (tensor<16x32xf16, #AL>) -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> } else { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK: tt.call @@ -625,7 +649,7 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: for_calls tt.func @for_calls(%A : !tt.ptr, %cond : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> %lb = arith.constant 0 : index %ub = arith.constant 10 : index @@ -641,7 +665,7 @@ tt.func @for_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: call_graph_1 tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> // CHECK: gpu.barrier + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: tt.call tt.call @convert_layout3(%cond) : (i1) -> () tt.return @@ -653,8 +677,132 @@ tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) { tt.call @convert_layout4(%A, %cond) : (!tt.ptr, i1) -> () // CHECK: tt.call // CHECK-NEXT: gpu.barrier - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - tt.return + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + tt.return +} + +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} { + tt.func public @kernel(%arg3: !tt.ptr, %arg4: !tt.ptr, %arg12: tensor<32x128xf16, #blocked>, %arg13: tensor<32x128xf32, #blocked>, %arg14: tensor<32x32xf16, #blocked1>) { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #blocked> + %37 = triton_gpu.local_alloc %arg14 {allocation.offset = 0 : i32} : (tensor<32x32xf16, #blocked1>) -> !tt.memdesc<32x32xf16, #shared, #triton_gpu.shared_memory> + %58 = triton_gpu.local_alloc %arg12 : (tensor<32x128xf16, #blocked>) -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory> + cf.br ^bb1 + ^bb1: // 2 preds: ^bb0, ^bb1 + %59 = tt.atomic_cas acq_rel, gpu, %arg3, %c0_i32, %c0_i32 : (!tt.ptr, i32, i32) -> i32 + %60 = arith.cmpi eq, %59, %c0_i32 : i32 + cf.cond_br %60, ^bb1, ^bb2 + ^bb2: // pred: ^bb1 + %72 = triton_gpu.convert_layout %arg13 : tensor<32x128xf32, #blocked> -> tensor<32x128xf32, #mma> + %73 = triton_gpu.local_load %37 : !tt.memdesc<32x32xf16, #shared, #triton_gpu.shared_memory> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %74 = triton_gpu.local_load %58 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %75 = tt.dot %73, %74, %72, inputPrecision = tf32 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x128xf32, #mma> + %76 = triton_gpu.convert_layout %75 {allocation.offset = 0 : i32} : tensor<32x128xf32, #mma> -> tensor<32x128xf32, #blocked> + %77 = arith.truncf %76 : tensor<32x128xf32, #blocked> to tensor<32x128xf16, #blocked> + %78 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + tt.store %78, %77 : tensor<32x128x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} { +// CHECK-LABEL: tma_special_cases +tt.func @tma_special_cases(%arg1: !tt.ptr) -> (tensor<256x64xf16, #blocked>){ + %true = arith.constant 1 : i1 + %c0 = arith.constant 0 : i32 + %barrier = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared1, #triton_gpu.shared_memory, mutable> + %alloc = triton_gpu.local_alloc : () -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + // CHECK: triton_nvidia_gpu.init_barrier + // CHECK-NEXT: triton_nvidia_gpu.init_barrier + triton_nvidia_gpu.init_barrier %barrier, 1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.init_barrier %barrier, 1 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: triton_nvidia_gpu.barrier_expect + // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local + // CHECK-NEXT: triton_nvidia_gpu.wait_barrier + triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + + // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local + // CHECK-NEXT: triton_nvidia_gpu.barrier_expect + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: triton_nvidia_gpu.wait_barrier + triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + + // CHECK-NEXT: triton_gpu.local_load + %t = triton_gpu.local_load %alloc : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #blocked> + + // CHECK-NEXT: triton_nvidia_gpu.barrier_expect + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_global_to_local + // CHECK-NEXT: triton_nvidia_gpu.wait_barrier + triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: triton_nvidia_gpu.inval_barrier + // CHECK-NEXT: triton_nvidia_gpu.inval_barrier + triton_nvidia_gpu.inval_barrier %barrier : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.inval_barrier %barrier : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + + tt.return %t : tensor<256x64xf16, #blocked> } +} + +// ----- +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 18944 : i32} { +// CHECK-LABEL: tma_special_cases_cf +tt.func @tma_special_cases_cf(%arg1: !tt.ptr, %i1 : i1, %arg2: tensor<256x64xf16, #blocked>) -> (tensor<256x64xf16, #blocked>){ + %true = arith.constant 1 : i1 + %c0 = arith.constant 0 : i32 + %barrier = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared1, #triton_gpu.shared_memory, mutable> + %alloc = triton_gpu.local_alloc : () -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + // CHECK: cf.cond_br + scf.if %i1 { + // CHECK-NOT: gpu.barrier + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local + // CHECK-NEXT: triton_nvidia_gpu.barrier_expect + // CHECK-NEXT: triton_nvidia_gpu.wait_barrier + // CHECK-NEXT: cf.br + triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : , <1xi64, #shared1, #triton_gpu.shared_memory, mutable> -> <256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.barrier_expect %barrier, 49152, %true : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.wait_barrier %barrier, %c0 : <1xi64, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield + } else { + // CHECK-NOT: gpu.barrier + // CHECK: triton_gpu.local_store + // CHECK-NEXT: cf.br + triton_gpu.local_store %arg2, %alloc : tensor<256x64xf16, #blocked> -> !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> + scf.yield + } + // CHECK: gpu.barrier + // CHECK-NEXT: triton_gpu.local_load + %t = triton_gpu.local_load %alloc : !tt.memdesc<256x64xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x64xf16, #blocked> + tt.return %t : tensor<256x64xf16, #blocked> +} } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 6ec6fc0ab..8028d099f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -17,6 +17,7 @@ set(TRITON_TEST_DEPENDS set(FILECHECK_PATH "${LLVM_LIBRARY_DIR}/../bin/FileCheck") set(LIT_ARGS "-Dfilecheck=${FILECHECK_PATH}") + add_lit_testsuite(check-triton-lit-tests "Running the triton regression tests" ${CMAKE_CURRENT_BINARY_DIR} ARGS ${LIT_ARGS} diff --git a/test/Conversion/amd/amd-convert-builtin-func.mlir b/test/Conversion/amd/amd-convert-builtin-func.mlir deleted file mode 100644 index 9df0059a5..000000000 --- a/test/Conversion/amd/amd-convert-builtin-func.mlir +++ /dev/null @@ -1,63 +0,0 @@ -// RUN: triton-opt --convert-builtin-func-to-llvm %s | FileCheck %s - -// Trying to merge those blocks will cause a lot of duplication in the block arguments, which will cause -// an exponential growth of the argument length. Make sure we don't try to merge those blocks. -module { - llvm.func @rand() -> i1 - llvm.func @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(!llvm.ptr<1>, i32, i1) attributes {libname = "", libpath = ""} - - llvm.func @top(%arg0: i64, %1 : !llvm.ptr<1>, %2 : !llvm.ptr<1>, %3 : !llvm.ptr<1>, %4 : !llvm.ptr<1>) { - %0 = llvm.mlir.constant(0 : i64) : i64 - %10 = llvm.icmp "eq" %arg0, %0 : i64 - %true = llvm.mlir.constant(1 : i1) : i1 - %c = llvm.mlir.constant(1 : i32) : i32 - // CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}} - llvm.cond_br %10, ^bb1, ^bb14 - ^bb1: // pred: ^bb0 - %11 = llvm.call @rand() : () -> i1 - // CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}} - llvm.cond_br %11, ^bb2, ^bb3 - ^bb2: // pred: ^bb1 - llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%1, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () - llvm.br ^bb4 - ^bb3: // pred: ^bb1 - llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%2, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () - llvm.br ^bb4 - ^bb4: // 2 preds: ^bb2, ^bb3 - %14 = llvm.call @rand() : () -> i1 - // CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}} - llvm.cond_br %14, ^bb5, ^bb6 - ^bb5: // pred: ^bb4 - llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%3, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () - llvm.br ^bb13 - ^bb6: // pred: ^bb4 - llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%4, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () - llvm.br ^bb13 - ^bb13: // 2 preds: ^bb11, ^bb12 - llvm.br ^bb27 - ^bb14: // pred: ^bb0 - %23 = llvm.call @rand() : () -> i1 - // CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}} - llvm.cond_br %23, ^bb15, ^bb16 - ^bb15: // pred: ^bb14 - llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%4, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () - llvm.br ^bb17 - ^bb16: // pred: ^bb14 - llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%3, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () - llvm.br ^bb17 - ^bb17: // 2 preds: ^bb15, ^bb16 - %26 = llvm.call @rand() : () -> i1 - // CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}} - llvm.cond_br %26, ^bb18, ^bb19 - ^bb18: // pred: ^bb17 - llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%2, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () - llvm.br ^bb26 - ^bb19: // pred: ^bb17 - llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%1, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () - llvm.br ^bb26 - ^bb26: // 2 preds: ^bb24, ^bb25 - llvm.br ^bb27 - ^bb27: // 2 preds: ^bb13, ^bb26 - llvm.return - } -} diff --git a/test/Conversion/amd/buffer_load_store.mlir b/test/Conversion/amd/buffer_load_store.mlir new file mode 100644 index 000000000..209c7065d --- /dev/null +++ b/test/Conversion/amd/buffer_load_store.mlir @@ -0,0 +1,178 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load + tt.func @buffer_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}) { + // CHECK: %[[c_mask:.*]] = llvm.mlir.constant(true) : i1 + // CHECK: %[[offset:.*]] = llvm.select %[[c_mask]] + // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[offset]] + %ret = amdgpu.buffer_load %arg0[%offset] : tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_mask + tt.func @buffer_load_mask(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0> + %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0> + // CHECK: %[[mask:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)> + // CHECK: %[[offset:.*]] = llvm.select %[[mask]] + // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[offset]] + %ret = amdgpu.buffer_load %arg0[%offset], %7: tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_mask_other + tt.func @buffer_load_mask_other(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0> + %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0> + %other = arith.constant dense<0.00e+00> : tensor<128xf32, #blocked0> + // CHECK: %[[mask:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)> + // CHECK: %[[offset:.*]] = llvm.select %[[mask]] + // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[offset]] + // CHECK: llvm.select + %ret = amdgpu.buffer_load %arg0[%offset], %7, %other: tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_store + tt.func @buffer_store(%value : tensor<128xf32, #blocked0>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}) { + // CHECK: %[[c_mask:.*]] = llvm.mlir.constant(true) : i1 + // CHECK: %[[offset:.*]] = llvm.select %[[c_mask]] + // CHECK: rocdl.raw.ptr.buffer.store {{.*}}, {{.*}}, %[[offset]] + amdgpu.buffer_store %value, %arg0[%offset] : tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_store_mask + tt.func @buffer_store_mask(%value : tensor<128xf32, #blocked0>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0> {tt.divisibility=16:i32}, %N : i32 {tt.divisibility = 16 : i32}) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<128xi32, #blocked0> + %5 = tt.splat %N: i32 -> tensor<128xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<128xi32, #blocked0> + // CHECK: %[[mask0:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i1, i1, i1, i1)> + // CHECK: %[[mask1:.*]] = llvm.and %[[mask0]], {{.*}} + // CHECK: %[[offset:.*]] = llvm.select %[[mask1]] + // CHECK: rocdl.raw.ptr.buffer.store {{.*}}, {{.*}}, %[[offset]] + amdgpu.buffer_store %value, %arg0[%offset], %7: tensor<128xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_store_vec4 + tt.func @buffer_load_store_vec4(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + // Load 8 elements from A with two vectorized load instructions + // CHECK-COUNT-2: rocdl.raw.ptr.buffer.load {{.*}} : vector<4xf32> + %9 = amdgpu.buffer_load %arg0[%4] : tensor<256xf32, #blocked0> + // Load 8 elements from B with two vectorized load instructions + // CHECK-COUNT-2: rocdl.raw.ptr.buffer.load {{.*}} : vector<4xf32> + %10 = amdgpu.buffer_load %arg1[%4] : tensor<256xf32, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + // Store 8 elements into C with two vectorized store instructions + // CHECK-COUNT-2: rocdl.raw.ptr.buffer.store {{.*}} : vector<4xf32> + amdgpu.buffer_store %11, %arg2[%4]: tensor<256xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_store_vec1 + tt.func @buffer_load_store_vec1(%arg0: !tt.ptr , %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg3 : i32 -> tensor<256xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<256xi32, #blocked0> + // Load 8 elements from A with eight scalar load instructions + // CHECK-COUNT-8: rocdl.raw.ptr.buffer.load {{.*}} : f32 + %9 = amdgpu.buffer_load %arg0[%4], %7 : tensor<256xf32, #blocked0> + // Load 8 elements from B with two scalar load instructions + // CHECK-COUNT-8: rocdl.raw.ptr.buffer.load {{.*}} : f32 + %10 = amdgpu.buffer_load %arg1[%4], %7 : tensor<256xf32, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + // Store 8 elements into C with two scalar store instructions + // CHECK-COUNT-8: rocdl.raw.ptr.buffer.store {{.*}} : f32 + amdgpu.buffer_store %11, %arg2[%4], %7 : tensor<256xf32, #blocked0> + tt.return + } +} + +// ----- + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_store_vec2 + tt.func @buffer_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr{tt.divisibility = 4 : i32}, %arg2: !tt.ptr{tt.divisibility = 4: i32}, %arg3: i32{tt.divisibility = 4: i32}) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg3 : i32 -> tensor<256xi32, #blocked0> + %7 = arith.cmpi slt, %4, %5: tensor<256xi32, #blocked0> + // Load 8 fp16 elements from A with four i32 scalar load instructions + // CHECK-COUNT-4: rocdl.raw.ptr.buffer.load {{.*}} : i32 + %9 = amdgpu.buffer_load %arg0[%4], %7 : tensor<256xf16, #blocked0> + // Load 8 fp16 elements from B with four i32 scalar load instructions + // CHECK-COUNT-4: rocdl.raw.ptr.buffer.load {{.*}} : i32 + %10 = amdgpu.buffer_load %arg1[%4], %7 : tensor<256xf16, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf16, #blocked0> + // Store 8 fp16 elements into C with four i32 scalar store instructionss + // CHECK-COUNT-4: rocdl.raw.ptr.buffer.store {{.*}} : i32 + amdgpu.buffer_store %11, %arg2[%4], %7 : tensor<256xf16, #blocked0> + tt.return + } +} diff --git a/test/Conversion/amd/builtin_func_to_llvm.mlir b/test/Conversion/amd/builtin_func_to_llvm.mlir new file mode 100644 index 000000000..06ef06c54 --- /dev/null +++ b/test/Conversion/amd/builtin_func_to_llvm.mlir @@ -0,0 +1,12 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=True" --convert-builtin-func-to-llvm="ftz=True" | FileCheck %s --check-prefix=LLVM_FTZ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" --convert-builtin-func-to-llvm="ftz=False" | FileCheck %s --check-prefix=LLVM_NO_FTZ + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @test_fast_expf(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} { + // LLVM_FTZ: llvm.amdgcn.exp2.f32 + // LLVM_NO_FTZ: llvm.exp2.f32 + %0 = tt.extern_elementwise %arg0 {libname = "libdevice", libpath = "", pure = true, symbol = "__triton_hip_fast_expf"} : (tensor<64xf32, #blocked>) -> tensor<64xf32, #blocked> + tt.return + } +} diff --git a/test/Conversion/amd/compute-base-ptr.mlir b/test/Conversion/amd/compute-base-ptr.mlir new file mode 100644 index 000000000..809e5a869 --- /dev/null +++ b/test/Conversion/amd/compute-base-ptr.mlir @@ -0,0 +1,19 @@ +// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 --mlir-print-debuginfo --mlir-pretty-debuginfo| FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}> +#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 544 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @local_load_offset + tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) { + %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1) + %1 = triton_gpu.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> loc(#loc2) + // This catches base ptr calculation in the computeBasePtr, checks if the gep has correct element type. + // CHECK: llvm.getelementptr {{.*}} (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 local_load:3:0 + %2 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> loc(#loc3) + tt.return + } +} +#loc1 = loc("conert_layout":1:0) +#loc2 = loc("local_alloc":2:0) +#loc3 = loc("local_load":3:0) diff --git a/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir b/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir new file mode 100644 index 000000000..f30e0aa6d --- /dev/null +++ b/test/Conversion/amd/decompose-unsupported-conversions-cdna.mlir @@ -0,0 +1,33 @@ +// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx942 | FileCheck %s + +// CHECK-DAG: #[[DST_ENC:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK-DAG: #[[SRC_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> +// CHECK-DAG: #[[TMP_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> +// CHECK: large_tensor_conversion +#src = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = false}> +#dst = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @large_tensor_conversion(%arg0: tensor<128x128xf32, #src>) { + // CHECK: %[[TMP:.*]] = triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, #[[SRC_ENC]]> -> tensor<128x128xf32, #[[TMP_ENC]]> + // CHECK: {{.*}} = triton_gpu.convert_layout %[[TMP]] : tensor<128x128xf32, #[[TMP_ENC]]> -> tensor<128x128xf32, #[[DST_ENC]]> + %0 = triton_gpu.convert_layout %arg0 : tensor<128x128xf32, #src> -> tensor<128x128xf32, #dst> + tt.return + } +} + +// ----- + +// CHECK-DAG: #[[DST_ENC:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK-DAG: #[[SRC_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> +// CHECK-DAG: #[[TMP_ENC:.+]] = #triton_gpu.amd_mfma<{{.*}}> +// CHECK: large_tensor_3d_conversion +#src = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 1, 2], instrShape = [32, 32], isTransposed = false}> +#dst = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 64, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @large_tensor_3d_conversion(%arg0: tensor<2x128x64xf32, #src>) { + // CHECK: %[[TMP:.*]] = triton_gpu.convert_layout {{.*}} : tensor<2x128x64xf32, #[[SRC_ENC]]> -> tensor<2x128x64xf32, #[[TMP_ENC]]> + // CHECK: {{.*}} = triton_gpu.convert_layout %[[TMP]] : tensor<2x128x64xf32, #[[TMP_ENC]]> -> tensor<2x128x64xf32, #[[DST_ENC]]> + %0 = triton_gpu.convert_layout %arg0 : tensor<2x128x64xf32, #src> -> tensor<2x128x64xf32, #dst> + tt.return + } +} diff --git a/test/Conversion/amd/decompose-unsupported-conversions.mlir b/test/Conversion/amd/decompose-unsupported-conversions.mlir index b5fc5b72c..1bd288449 100644 --- a/test/Conversion/amd/decompose-unsupported-conversions.mlir +++ b/test/Conversion/amd/decompose-unsupported-conversions.mlir @@ -1,15 +1,105 @@ -// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx942 | FileCheck %s +// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions | FileCheck %s -// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> -// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}> -#mma = #triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> +// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}> +// CHECK-LABEL: wmma_to_wmma_dot_op +#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} { tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) { - // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[WMMA]]> -> tensor<16x16xf16, #[[BLOCKED]]> - // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[SHARED]]> - // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>> + // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]> + // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory> + // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> tt.return } } + +// ----- + +// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}> +// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}> +// CHECK-LABEL: wmma_to_wmma_dot3d_op +#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) { + // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]> + // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory> + // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> + %0 = triton_gpu.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx1130 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @blocked_to_dot_op_shortcut_gfx1130(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.local_alloc + // CHECK: triton_gpu.convert_layout + // CHECK-NOT: triton_gpu.local_alloc + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx940 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @blocked_to_dot_op_shortcut_gfx940(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.local_alloc + // CHECK: triton_gpu.convert_layout + // CHECK-NOT: triton_gpu.local_alloc + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_elems_gfx940 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @neg_blocked_to_dot_op_incompatible_elems_gfx940(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: triton_gpu.local_alloc + // CHECK: triton_gpu.local_load + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_threads_gfx940 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @neg_blocked_to_dot_op_incompatible_threads_gfx940(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: triton_gpu.local_alloc + // CHECK: triton_gpu.local_load + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_warp_gfx940 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) { + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: triton_gpu.local_alloc + // CHECK: triton_gpu.local_load + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + tt.return + } +} diff --git a/test/Conversion/amd/fp_to_fp.mlir b/test/Conversion/amd/fp_to_fp.mlir index ce27b6fbb..aaa70564f 100644 --- a/test/Conversion/amd/fp_to_fp.mlir +++ b/test/Conversion/amd/fp_to_fp.mlir @@ -9,3 +9,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +// CHECK-LABEL: bf16_to_f32 +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @bf16_to_f32(%arg0: tensor<8x8xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>>) { + // CHECK-COUNT-8: llvm.bitcast + %0 = tt.fp_to_fp %arg0 : tensor<8x8xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> + tt.return + } +} diff --git a/test/Conversion/amd/load_store.mlir b/test/Conversion/amd/load_store.mlir index c71aa56c5..93796439b 100644 --- a/test/Conversion/amd/load_store.mlir +++ b/test/Conversion/amd/load_store.mlir @@ -15,10 +15,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : %7 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> // Load 8 elements from A with two vectorized load instruction - // CHECK-COUNT-2: llvm.load {{.*}} : !llvm.ptr -> vector<4xf32> + // CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32> %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #blocked0> // Load 8 elements from B with two vectorized load instruction - // CHECK-COUNT-2: llvm.load {{.*}} : !llvm.ptr -> vector<4xf32> + // CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32> %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #blocked0> %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index f3cffa707..ef6733845 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -9,8 +9,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.br // CHECK: rocdl.barrier // CHECK: llvm.load - // CHECK: rocdl.barrier - // CHECK: llvm.store + // CHECK: llvm.intr.masked.store %0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (!tt.ptr, f32, i1) -> f32 tt.store %arg0, %0 : !tt.ptr tt.return @@ -26,12 +25,40 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.cond_br // CHECK: llvm.atomicrmw // CHECK: llvm.atomicrmw - // CHECK: %[[ADDR1:.*]] = llvm.extractvalue - // CHECK: %[[ADDR2:.*]] = llvm.extractvalue - // CHECK: llvm.store %{{.*}}, %[[ADDR1]] - // CHECK: llvm.store %{{.*}}, %[[ADDR2]] + // CHECK: %[[ADDR1:.*]] = llvm.addrspacecast + // CHECK: llvm.intr.masked.store %{{.*}}, %[[ADDR1]] + // CHECK: %[[ADDR2:.*]] = llvm.addrspacecast + // CHECK: llvm.intr.masked.store %{{.*}}, %[[ADDR2]] %0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0> tt.store %arg0, %0 : tensor<256x!tt.ptr, #blocked0> tt.return } } + +// ----- + +// Smoke test to check that mfma 32 and dot operand layouts can work with small tensors, for example with shape 16x16 +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}> +#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> +#dotop1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: small_mfma_tensor_conversions + tt.func public @small_mfma_tensor_conversions(%arg0: tensor<16x16xf16, #mfma>, %arg1: tensor<16x16x!tt.ptr, #mfma>) { + // CHECK-NOT: triton_gpu.convert_layout + %0 = triton_gpu.local_alloc %arg0 : (tensor<16x16xf16, #mfma>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> + // CHECK-4: store {{.*}} vector<4xf16> + %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dotop0> + // CHECK-2: load {{.*}} vector<4xf16> + %2 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dotop1> + // CHECK-8: load {{.*}} vector<1xf16> + %3 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #mfma> + // CHECK-4: load {{.*}} vector<4xf16> + %4 = tt.fp_to_fp %3 : tensor<16x16xf16, #mfma> -> tensor<16x16xf32, #mfma> + + %5 = tt.dot %1, %2, %4 : tensor<16x16xf16, #dotop0> * tensor<16x16xf16, #dotop1> -> tensor<16x16xf32, #mfma> + // Store result to prevent DCE from removing all conversion related code + %6 = triton_gpu.local_alloc %5 : (tensor<16x16xf32, #mfma>) -> !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> + tt.return + } +} diff --git a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir index 5a4ada339..5eb856bb9 100644 --- a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir @@ -1,34 +1,45 @@ // RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1100 | FileCheck %s #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> -#mma = #triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}> +#mma1 = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#mma2 = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: wmma_dot_operand - tt.func @wmma_dot_operand(%arg0: !tt.memdesc<64x64xf16, #shared>) { + // CHECK-LABEL: wmma1_dot_operand + tt.func @wmma1_dot_operand(%arg0: !tt.memdesc<64x64xf16, #shared>) { // 2 CTA * 4 rep * load_per_thread_per_instr // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> - %0 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + %0 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> // CHECK-COUNT-128: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> - %1 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + %1 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> tt.return } - // CHECK-LABEL: wmma_dot - tt.func @wmma_dot(%arg0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma>) { + // CHECK-LABEL: wmma2_dot_operand + tt.func @wmma2_dot_operand(%arg0: !tt.memdesc<64x64xf16, #shared>) { + // 2 CTA * 4 rep * load_per_thread_per_instr + // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16> + %0 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> + // CHECK-COUNT-64: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> + %1 = triton_gpu.local_load %arg0 : !tt.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> + tt.return + } + + // CHECK-LABEL: wmma1_dot + tt.func @wmma1_dot(%arg0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma1>) { // CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK: llvm.mlir.undef : vector<16xf16> // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xf16> // CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xf16, #mma> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xf16, #mma1> // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xf16> // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> tt.return } - // CHECK-LABEL: wmma_dot_bf16 - tt.func @wmma_dot_bf16(%arg0: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xbf16, #mma>) { + // CHECK-LABEL: wmma1_dot_bf16 + tt.func @wmma1_dot_bf16(%arg0: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xbf16, #mma1>) { // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)> // CHECK: llvm.bitcast %{{.*}} : vector<16xbf16> to vector<16xi16> // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)> @@ -37,12 +48,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.mlir.undef : vector<16xbf16> // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xbf16> // CHECK: rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16> - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xbf16, #mma> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xbf16, #mma1> tt.return } - // CHECK-LABEL: wmma_dot_int8_32 - tt.func @wmma_dot_int8_32(%arg0: tensor<16x16xui8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xui8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma>) { + // CHECK-LABEL: wmma1_dot_int8_32 + tt.func @wmma1_dot_int8_32(%arg0: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) { // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi8> // CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32> @@ -51,13 +62,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32> // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK: rocdl.wmma.i32.16x16x16.iu8 {{.*}} : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32> - %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xui8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xui8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xi32, #mma> + %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> tt.return } - // CHECK-LABEL: wmma_dot_int4_32 - tt.func @wmma_dot_int4_32(%arg0: tensor<16x16xui4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<16x16xui4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma>) { + // CHECK-LABEL: wmma1_dot_int4_32 + tt.func @wmma1_dot_int4_32(%arg0: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) { // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4, i4)> // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi4> // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32> @@ -66,28 +77,44 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32> // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK: rocdl.wmma.i32.16x16x16.iu4 {{.*}} : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> - %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xui4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<16x16xui4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<16x16xi32, #mma> + %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi4, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1> // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> tt.return } + + // CHECK-LABEL: wmma2_dot + tt.func @wmma2_dot(%arg0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>>, %arg1: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>>, %arg2: tensor<16x16xf16, #mma2>) { + // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> + // CHECK: llvm.mlir.undef : vector<8xf16> + // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> + // CHECK: llvm.mlir.undef : vector<8xf16> + // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> + // CHECK: llvm.mlir.undef : vector<8xf16> + // CHECK: llvm.call_intrinsic "llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16.v8f16"{{.*}} : (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> -> tensor<16x16xf16, #mma2> + // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<8xf16> + // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> + // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)> + tt.return + } } // ----- #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0], hasLeadingOffset = false}> -#mma = #triton_gpu.amd_wmma<{warpsPerCTA = [2, 1, 4]}> +#mma1 = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 1, 4]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK-LABEL: wmma_dot_operand3d tt.func @wmma_dot_operand3d(%arg0: !tt.memdesc<4x16x32xf16, #shared>) { // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> - %0 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + %0 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> // CHECK-COUNT-32: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> - %1 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + %1 = triton_gpu.local_load %arg0 : !tt.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> tt.return } // CHECK-LABEL: wmma_dot3d - tt.func @wmma_dot3d(%arg0: tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg2: tensor<2x16x16xf16, #mma>) { + tt.func @wmma_dot3d(%arg0: tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<2x16x16xf16, #mma1>) { // CHECK-COUNT-32: llvm.extractvalue %arg0 // CHECK-COUNT-32: llvm.insertelement // CHECK-COUNT-32: llvm.extractvalue %arg1 @@ -95,7 +122,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-COUNT-8: llvm.extractvalue %arg2 // CHECK-COUNT-8: llvm.insertelement // CHECK-COUNT-2: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> - %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<2x16x16xf16, #mma> + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<2x32x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<2x16x16xf16, #mma1> // CHECK-COUNT-8: llvm.extractelement // CHECK-COUNT-8: llvm.insertvalue tt.return diff --git a/test/Conversion/nvgpu_to_llvm.mlir b/test/Conversion/nvgpu_to_llvm.mlir new file mode 100644 index 000000000..beaa4c952 --- /dev/null +++ b/test/Conversion/nvgpu_to_llvm.mlir @@ -0,0 +1,90 @@ +// RUN: triton-opt %s --convert-nv-gpu-to-llvm -split-input-file | FileCheck %s + +// CHECK-LABEL: @nvvm_syncs +llvm.func @nvvm_syncs() { + // CHECK: wgmma.fence.sync.aligned; + nvgpu.wgmma_fence + + // CHECK: wgmma.commit_group.sync.aligned; + nvgpu.wgmma_commit_group + + // CHECK: barrier.cluster.wait.aligned; + nvgpu.cluster_wait + + // CHECK: fence.proxy.async.shared::cta; + nvgpu.fence_async_shared {bCluster = false} + // CHECK: fence.proxy.async.shared::cluster; + nvgpu.fence_async_shared {bCluster = true} + + // CHECK: barrier.cluster.arrive.aligned; + nvgpu.cluster_arrive {relaxed = false} + // CHECK: barrier.cluster.arrive.relaxed.aligned; + nvgpu.cluster_arrive {relaxed = true} + + llvm.return +} + +// CHECK-LABEL: @cluster_id +llvm.func @cluster_id() -> i32 { + // CHECK: %cluster_ctaid.x; + // CHECK-SAME: %cluster_ctaid.y; + // CHECK-SAME: %cluster_ctaid.z; + // CHECK-SAME: %cluster_nctaid.x; + // CHECK-SAME: %cluster_nctaid.y; + %id = nvgpu.cluster_id + llvm.return %id : i32 +} + +// ----- + +// CHECK-LABEL: @st_matrix +llvm.func @st_matrix(%i: i32, %ptr: !llvm.ptr<3>) { + // CHECK: stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4}; + nvgpu.stmatrix %ptr, %i, %i, %i, %i : !llvm.ptr<3>, i32, i32, i32, i32 + llvm.return +} + +// ----- + +!struct_128xf32 = !llvm.struct<( + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32 +)> + +!struct_64xf32 = !llvm.struct<( + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32 +)> + +// CHECK-LABEL: @wgmma +llvm.func @wgmma(%desc: i64, %in: !struct_64xf32) { +// CHECK: wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 +%false = llvm.mlir.constant(false) : i1 +%acc0 = nvgpu.wgmma %desc, %desc, %false { + eltTypeA = 3 : i32, + eltTypeB = 3 : i32, + eltTypeC = 7 : i32, + layoutA = 0 : i32, + layoutB = 1 : i32, + m = 64 : i32, + n = 256 : i32, + k = 32 : i32 +} : (i64, i64, i1) -> !struct_128xf32 + + // CHECK: // wait for regs: $0,$1,$2,{{.*}},$127 + // CHECK: wgmma.wait_group.sync.aligned 0; + %out = nvgpu.wgmma_wait_group %in {pendings = 0 : i32} : !struct_64xf32 + llvm.return +} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 3bf06a836..34573f773 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -3,7 +3,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>) // Here the 128 comes from the 4 in module attribute multiples 32 - // CHECK: nvvm.kernel = 1 : ui1, nvvm.maxntid = array + // CHECK: nvvm.kernel = 1 : ui1, nvvm.reqntid = array tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr) { // CHECK: llvm.return tt.return @@ -355,7 +355,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.mlir.undef // CHECK: %[[T0:.*]] = llvm.extractvalue // CHECK: %[[T1:.*]] = llvm.extractvalue - %0 = tt.reshape %arg {allow_reorder = true} : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2> + %0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2> // CHECK: llvm.mlir.undef // CHECK: llvm.insertvalue %[[T0]] // CHECK: llvm.insertvalue %[[T1]] @@ -445,7 +445,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.mlir.addressof @global_smem // CHECK-NEXT: llvm.getelementptr // CHECK-NEXT: llvm.mlir.constant - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #shared0> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory, mutable> tt.return } } @@ -475,8 +475,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: llvm.getelementptr %index = arith.constant 1 : i32 %zero = arith.constant 0 : i32 - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x16x32xf32, #shared0> - %1 = triton_gpu.memdesc_subview %0[%index, %zero, %zero] : !tt.memdesc<128x16x32xf32, #shared0> -> !tt.memdesc<16x32xf32, #shared0> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x16x32xf32, #shared0, #triton_gpu.shared_memory, mutable> + %1 = triton_gpu.memdesc_subview %0[%index, %zero, %zero] : !tt.memdesc<128x16x32xf32, #shared0, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x32xf32, #shared0, #triton_gpu.shared_memory, mutable> tt.return } } @@ -496,7 +496,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 8], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #slice1d0 = #triton_gpu.slice<{dim = 0, parent = #blocked1}> -#shared = #triton_gpu.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0], hasLeadingOffset = true}> +#shared1D = #triton_gpu.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0], hasLeadingOffset = true}> +#shared2D = #triton_gpu.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset = true}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { // CHECK-LABEL: basic_insert_slice_async_1d tt.func @basic_insert_slice_async_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { @@ -506,7 +507,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %24 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1d0> %59 = tt.addptr %58, %24 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> %66 = tt.addptr %59, %cst_2 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> - %71 = triton_gpu.local_alloc : () -> !tt.memdesc<2x64xi64, #shared> + %71 = triton_gpu.local_alloc : () -> !tt.memdesc<2x64xi64, #shared2D, #triton_gpu.shared_memory, mutable> + %subview = triton_gpu.memdesc_subview %71[%c0_i32, %c0_i32] : + !tt.memdesc<2x64xi64, #shared2D, #triton_gpu.shared_memory, mutable> -> + !tt.memdesc<64xi64, #shared1D, #triton_gpu.shared_memory, mutable> // CHECK: llvm.inline_asm has_side_effects asm_dialect = att // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 @@ -517,7 +521,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.commit_group - %73 = triton_gpu.async_copy_global_to_local %66, %71 : tensor<64x!tt.ptr, #slice1d0> -> !tt.memdesc<2x64xi64, #shared> + %73 = triton_gpu.async_copy_global_to_local %66, %subview : tensor<64x!tt.ptr, #slice1d0> -> !tt.memdesc<64xi64, #shared1D, #triton_gpu.shared_memory, mutable> triton_gpu.async_commit_group %73 tt.return } @@ -550,16 +554,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<16x64x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr, #AL>, tensor<16x64xi32, #AL> - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x64xf32, #A> + %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x64xf32, #A, #triton_gpu.shared_memory, mutable> %index = arith.constant 1 : i32 - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10 - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10;" + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10;" // CHECK: llvm.inline_asm has_side_effects asm_dialect = att // CHECK-SAME: cp.async.commit_group - %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr, #AL> -> !tt.memdesc<16x64xf32, #A> + %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr, #AL> -> !tt.memdesc<16x64xf32, #A, #triton_gpu.shared_memory, mutable> triton_gpu.async_commit_group tt.return } @@ -592,7 +594,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x32xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<16x32x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr, #AL>, tensor<16x32xi32, #AL> - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf32, #A> + %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf32, #A, #triton_gpu.shared_memory, mutable> %index = arith.constant 1 : i32 // CHECK: llvm.inline_asm @@ -605,7 +607,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.commit_group - %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x32x!tt.ptr, #AL> -> !tt.memdesc<16x32xf32, #A> + %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x32x!tt.ptr, #AL> -> !tt.memdesc<16x32xf32, #A, #triton_gpu.shared_memory, mutable> triton_gpu.async_commit_group tt.return } @@ -637,15 +639,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<32x32xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<32x32xf32, #A> + %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<32x32xf32, #A, #triton_gpu.shared_memory, mutable> %index = arith.constant 1 : i32 // CHECK: llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.mlir.constant(16 : i32) : i32 // CHECK: llvm.mul // CHECK: llvm.add - // CHECK: llvm.inline_asm - // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4;" // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm @@ -662,7 +663,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.commit_group - %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<32x32x!tt.ptr, #AL> -> !tt.memdesc<32x32xf32, #A> + %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<32x32x!tt.ptr, #AL> -> !tt.memdesc<32x32xf32, #A, #triton_gpu.shared_memory, mutable> triton_gpu.async_commit_group tt.return } @@ -706,39 +707,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: convert_layout_blocked_blocked tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { // CHECK: llvm.mlir.addressof @global_smem - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: nvvm.barrier0 - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> + // CHECK-COUNT-8: llvm.inline_asm {{.*}} st.shared + // CHECK-: nvvm.barrier0 + // CHECK-COUNT-8: llvm.load %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> tt.return } @@ -753,15 +724,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: convert_layout_blocked_blocked_vec tt.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { // CHECK: llvm.mlir.addressof @global_smem - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> + // CHECK: llvm.inline_asm + // CHECK: st.shared + // CHECK: llvm.inline_asm + // CHECK: st.shared // CHECK: nvvm.barrier0 // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> tt.return } @@ -776,21 +745,17 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep tt.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { // CHECK: llvm.mlir.addressof @global_smem - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> + // CHECK: llvm.inline_asm + // CHECK: st.shared // CHECK: nvvm.barrier0 // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> // CHECK: nvvm.barrier0 - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> + // CHECK: llvm.inline_asm + // CHECK: st.shared // CHECK: nvvm.barrier0 // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1> tt.return } @@ -806,14 +771,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_dot tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { - %AA = triton_gpu.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0> - %BB = triton_gpu.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0> + %AA = triton_gpu.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> + %BB = triton_gpu.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> // CHECK: llvm.inline_asm // CHECK: ldmatrix.sync.aligned.m8n8.x4 // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 - %AA_DOT = triton_gpu.local_load %AA : !tt.memdesc<16x16xf16, #shared0> -> tensor<16x16xf16, #dot_operand_a> - %BB_DOT = triton_gpu.local_load %BB : !tt.memdesc<16x16xf16, #shared0> -> tensor<16x16xf16, #dot_operand_b> + %AA_DOT = triton_gpu.local_load %AA : !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dot_operand_a> + %BB_DOT = triton_gpu.local_load %BB : !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> // CHECK: llvm.inline_asm @@ -843,13 +808,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_mmav2_block tt.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) { - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> - // CHECK: llvm.store - // CHECK-SAME: !llvm.ptr<3> + // CHECK: llvm.inline_asm + // CHECK-SAME: st.shared + // CHECK: llvm.inline_asm + // CHECK-SAME: st.shared // CHECK: nvvm.barrier0 // CHECK: llvm.load - // CHECK-SAME: !llvm.ptr<3> %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0> tt.return } @@ -887,9 +851,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_mmav3_transpose tt.func @convert_layout_mmav3_transpose(%arg0: tensor<128x256xf8E5M2, #mma>) { - // CHECK-COUNT-128: llvm.store %{{.*}} : vector<1xi8>, !llvm.ptr<3> + // CHECK-COUNT-128: st.shared.b8 // CHECK: nvvm.barrier0 - // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xi8> + // CHECK-COUNT-8: llvm.load {{.*}} -> vector<4xi32> %0 = triton_gpu.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked> tt.return } @@ -906,7 +870,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-SAME: !llvm.ptr<3> // CHECK: llvm.store // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0> + %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> tt.return } } @@ -918,7 +882,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice0 tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { - // CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xi32> + // CHECK: llvm.load {{.*}} -> vector<4xi32> %cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> tt.return } @@ -931,7 +895,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice1 tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { - // CHECK-COUNT-8: llvm.load {{.*}} : !llvm.ptr<3> + // CHECK-COUNT-8: llvm.load {{.*}} -> i32 %cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> tt.return } @@ -945,7 +909,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: convert_blocked_to_blocked_ptr tt.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr, #blocked0>) { // CHECK: llvm.ptrtoint - // CHECK: llvm.store + // CHECK: inline_asm{{.*}}st.shared // CHECK: nvvm.barrier0 // CHECK: llvm.inttoptr // CHECK-COUNT-4: llvm.insertvalue @@ -963,11 +927,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!tt.memdesc<128x32xf16, #shared>, %b:!tt.memdesc<32x256xf16, #shared>) { + %a:!tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory>, %b:!tt.memdesc<32x256xf16, #shared, #triton_gpu.shared_memory>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 - %a_mat = triton_gpu.local_load %a : !tt.memdesc<128x32xf16, #shared> -> tensor<128x32xf16, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !tt.memdesc<32x256xf16, #shared> -> tensor<32x256xf16, #dot_operand_b> + %a_mat = triton_gpu.local_load %a : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x32xf16, #dot_operand_a> + %b_mat = triton_gpu.local_load %b : !tt.memdesc<32x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<32x256xf16, #dot_operand_b> %28 = tt.dot %a_mat, %b_mat, %cst : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma> %38 = triton_gpu.convert_layout %28 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked> @@ -989,11 +953,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!tt.memdesc<32x64xf16, #shared0>, %b:!tt.memdesc<64x64xf16, #shared1>) { + %a:!tt.memdesc<32x64xf16, #shared0, #triton_gpu.shared_memory>, %b:!tt.memdesc<64x64xf16, #shared1, #triton_gpu.shared_memory>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma> // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 - %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x64xf16, #shared0> -> tensor<32x64xf16, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !tt.memdesc<64x64xf16, #shared1> -> tensor<64x64xf16, #dot_operand_b> + %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x64xf16, #shared0, #triton_gpu.shared_memory> -> tensor<32x64xf16, #dot_operand_a> + %b_mat = triton_gpu.local_load %b : !tt.memdesc<64x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x64xf16, #dot_operand_b> %28 = tt.dot %a_mat, %b_mat, %cst : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma> %38 = triton_gpu.convert_layout %28 : tensor<32x64xf32, #mma> -> tensor<32x64xf32, #blocked> @@ -1012,11 +976,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul_fmadot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!tt.memdesc<32x16xf32, #shared>, %b:!tt.memdesc<16x32xf32, #shared>) { + %a:!tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory>, %b:!tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> // CHECK: llvm.intr.fmuladd - %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared> -> tensor<32x16xf32, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared> -> tensor<16x32xf32, #dot_operand_b> + %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory> -> tensor<32x16xf32, #dot_operand_a> + %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory> -> tensor<16x32xf32, #dot_operand_b> %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = ieee : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> @@ -1036,7 +1000,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: matmul_tf32dot tt.func @matmul_tf32dot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!tt.memdesc<32x16xf32, #shared>, %b:!tt.memdesc<16x32xf32, #shared>) { + %a:!tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory>, %b:!tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16 @@ -1044,8 +1008,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16 // CHECK-SAME: (i32, i32, i32, i32) - %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared> -> tensor<32x16xf32, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared> -> tensor<16x32xf32, #dot_operand_b> + %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory> -> tensor<32x16xf32, #dot_operand_a> + %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory> -> tensor<16x32xf32, #dot_operand_b> // CHECK: llvm.inline_asm // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 @@ -1068,7 +1032,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -1082,7 +1046,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32_scalar tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { // CHECK: llvm.icmp "eq" @@ -1096,7 +1060,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm @@ -1110,6 +1074,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_nomask + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_withmask + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} + +// ----- + #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: store_f32 @@ -1240,8 +1232,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-LABEL: test_base_index_cache tt.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) { // CHECK: nvvm.read.ptx.sreg.tid.x - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0> - %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0> + %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> + %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> tt.return } } @@ -1253,10 +1245,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-LABEL: test_index_cache_different_block tt.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) { // CHECK: nvvm.read.ptx.sreg.tid.x - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0> + %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> cf.cond_br %arg1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 - %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0> + %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> cf.br ^bb2 ^bb2: // 2 preds: ^bb0, ^bb1 tt.return @@ -1507,6 +1499,44 @@ module attributes {"triton_gpu.target" = "cuda:70", "triton_gpu.num-ctas" = 1 : // ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @i16_mma_layout(%f16_inp: tensor<16x16xf16, #blocked0>, %i16_inp: tensor<16x16xi16, #blocked0>) { + // CHECK-LABEL: @i16_mma_layout + + %f16_shared = triton_gpu.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> + %i16_shared = triton_gpu.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !tt.memdesc<16x16xi16, #shared0, #triton_gpu.shared_memory> + + // CHECK: llvm.inline_asm + // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 + // CHECK: llvm.inline_asm + // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 + + %f16_dot = triton_gpu.local_load %f16_shared : !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dot_operand_a> + %i16_dot = triton_gpu.local_load %i16_shared : !tt.memdesc<16x16xi16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xi16, #dot_operand_b> + + // CHECK: llvm.sitofp %{{.*}} : i16 to f16 + + %converted_i16 = arith.sitofp %i16_dot : tensor<16x16xi16, #dot_operand_b> to tensor<16x16xf16, #dot_operand_b> + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + + %out = tt.dot %f16_dot, %converted_i16, %cst0 : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma> + + tt.return + } +} + +// ----- + #blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> module attributes {"triton_gpu.target" = "cuda:75", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { @@ -1550,8 +1580,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: llvm.load // CHECK-SAME: {alignment = 8 : i64} : !llvm.ptr<3> -> vector<8xi8> // CHECK-NOT: llvm.load - tt.func public @vectorize_shmem_load(%shmem : !tt.memdesc<16x16xi8, #shared>) { - %0 = triton_gpu.local_load %shmem : !tt.memdesc<16x16xi8, #shared> -> tensor<16x16xi8, #blocked> + tt.func public @vectorize_shmem_load(%shmem : !tt.memdesc<16x16xi8, #shared, #triton_gpu.shared_memory>) { + %0 = triton_gpu.local_load %shmem : !tt.memdesc<16x16xi8, #shared, #triton_gpu.shared_memory> -> tensor<16x16xi8, #blocked> tt.return } } @@ -1566,7 +1596,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK-SAME: {alignment = 64 : i64} : vector<16xi32>, !llvm.ptr<3> // CHECK-NOT: llvm.store tt.func public @vectorize_shmem_store(%block : tensor<64x64xi32, #blocked>) { - %0 = triton_gpu.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !tt.memdesc<64x64xi32, #shared> + %0 = triton_gpu.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !tt.memdesc<64x64xi32, #shared, #triton_gpu.shared_memory> tt.return } } @@ -1588,12 +1618,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK-LABEL: test_local_load_bf16 - // CHECK: llvm.extractelement {{.*}} : vector<8xi16> + // CHECK: llvm.extractelement {{.*}} : vector<8xbf16> tt.func public @test_local_load_bf16() { %c0_i32 = arith.constant 0 : i32 - %19 = triton_gpu.local_alloc : () -> !tt.memdesc<1x1x2048xbf16, #shared, mutable> - %22 = triton_gpu.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x1x2048xbf16, #shared, mutable> -> !tt.memdesc<1x2048xbf16, #shared, mutable> - %39 = triton_gpu.local_load %22 : !tt.memdesc<1x2048xbf16, #shared, mutable> -> tensor<1x2048xbf16, #blocked> + %19 = triton_gpu.local_alloc : () -> !tt.memdesc<1x1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> + %22 = triton_gpu.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> + %39 = triton_gpu.local_load %22 : !tt.memdesc<1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<1x2048xbf16, #blocked> %40 = arith.extf %39 : tensor<1x2048xbf16, #blocked> to tensor<1x2048xf32, #blocked> tt.return } @@ -1607,8 +1637,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.store tt.func public @test_local_store(%arg0: tensor<1xf32, #blocked>) { %c0_i32 = arith.constant 0 : i32 - %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !tt.memdesc<1xf32, #shared, mutable> - triton_gpu.local_store %arg0, %0 : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, mutable> + %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %arg0, %0 : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> tt.return } } @@ -1621,9 +1651,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.store tt.func public @test_local_store_subview(%arg0: tensor<1xf32, #blocked>) { %c0_i32 = arith.constant 0 : i32 - %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !tt.memdesc<1xf32, #shared, mutable> - %sv = triton_gpu.memdesc_subview %0[%c0_i32] : !tt.memdesc<1xf32, #shared, mutable> -> !tt.memdesc<1xf32, #shared, mutable> - triton_gpu.local_store %arg0, %sv : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, mutable> + %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> + %sv = triton_gpu.memdesc_subview %0[%c0_i32] : !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %arg0, %sv : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> tt.return } } @@ -1635,7 +1665,95 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: print_ptr // CHECK: llvm.call @vprintf(%{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32 tt.func @print_ptr(%arg0 : tensor<256x!tt.ptr, #blocked0>) { - tt.print "ptr: " {hex = false} : %arg0 : tensor<256x!tt.ptr, #blocked0> + tt.print "ptr: " {hex = false, isSigned = array} : %arg0 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // Test that %u format specifier is used if isSigned is false + // CHECK: llvm.mlir.global internal constant @printfFormat_0("{{.*}}int32 tensor: %u{{.*}}") + // CHECK-LABEL: print_int32_tensor_issigned_off + // CHECK: llvm.call @vprintf(%{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32 + tt.func @print_int32_tensor_issigned_off(%arg0 : i32) { + tt.print "int32 tensor: " {hex = false, isSigned = array} : %arg0 : i32 + tt.return + } +} + +// ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // Test that %i format specifier is used if isSigned is true + // CHECK: llvm.mlir.global internal constant @printfFormat_0("{{.*}}int32 tensor: %i{{.*}}") + // CHECK-LABEL: print_int32_tensor_issigned_on + // CHECK: llvm.call @vprintf(%{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32 + tt.func @print_int32_tensor_issigned_on(%arg0 : i32) { + tt.print "int32 tensor: " {hex = false, isSigned = array} : %arg0 : i32 + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + tt.func @int32_to_bf16(%arg0: tensor<256xi32, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: @int32_to_bf16 + // CHECK: llvm.sitofp %{{.*}} : i32 to bf16 + %a = arith.sitofp %arg0 : tensor<256xi32, #blocked> to tensor<256xbf16, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + tt.func @bf16_to_int32(%arg0: tensor<256xbf16, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: @bf16_to_int32 + // CHECK: llvm.fptosi %{{.*}} : bf16 to i32 + %a = arith.fptosi %arg0 : tensor<256xbf16, #blocked> to tensor<256xi32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK-DAG: llvm.mlir.global internal constant @assertFunc_0("unknown\00") {addr_space = 0 : i32} +// CHECK-DAG: llvm.mlir.global internal constant @assertFile_0("inner_call\00") {addr_space = 0 : i32} +// CHECK-DAG: llvm.mlir.global internal constant @assertMessage_0("assert text\00") {addr_space = 0 : i32} +// CHECK: llvm.call @__assertfail +// CHECK: nvvm.barrier0 +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @add_kernel(%arg0: tensor<1xi1, #blocked>) { + tt.assert %arg0, "assert text" : tensor<1xi1, #blocked> loc(#loc5) + tt.return + } +} +#loc1 = loc("outer_call":33:8) +#loc2 = loc("top_func":47:8) +#loc3 = loc("inner_call":29:28) +#loc4 = loc(callsite(#loc3 at #loc1)) +#loc5 = loc(callsite(#loc4 at #loc2)) + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @log1pf_scan(%39: tensor<32x16xf32, #blocked>) attributes {noinline = false} { + // CHECK: log1pf_scan + // non-speculatable ops will introduce a cond_br; extern_elementwise with pure = true should be considered speculatable. + // CHECK-NOT: llvm.cond_br + %40 = "tt.scan"(%39) <{axis = 1 : i32, reverse = false}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %43 = tt.extern_elementwise %arg5 {libname = "", libpath = "", pure = true, symbol = "__nv_log1pf"} : (f32) -> f32 + %44 = arith.addf %43, %43 : f32 + tt.scan.return %44 : f32 + }) : (tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked> tt.return } } diff --git a/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir b/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir new file mode 100644 index 000000000..49128064a --- /dev/null +++ b/test/Conversion/tritongpu_to_llvm_block_dot_shortcut.mlir @@ -0,0 +1,47 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s + +// CHECK-LABEL: blocked_to_dot_op_shortcut_warp32 +#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @blocked_to_dot_op_shortcut_warp32(%arg0: tensor<32x32xf16, #blocked>, %arg1: tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) { + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK-NOT: load + tt.return + } +} + +// ----- + +// CHECK-LABEL: blocked_to_dot_op_shortcut_warp64 +#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [2, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @blocked_to_dot_op_shortcut_warp64(%arg0: tensor<32x32xf16, #blocked>) { + %0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK-NOT: load + tt.return + } +} + +// ----- + +// CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp32 +#blocked = #triton_gpu.blocked<{sizePerThread = [2, 32, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 1, 2], order = [1, 2, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @blocked_to_dot3d_op_shortcut_warp32(%arg0: tensor<8x32x32xf16, #blocked>) { + %0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK-NOT: load + tt.return + } +} + +// ----- + +// CHECK-LABEL: blocked_to_dot3d_op_shortcut_warp64 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32, 1], threadsPerWarp = [1, 2, 32], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @blocked_to_dot3d_op_shortcut_warp64(%arg0: tensor<8x32x32xf16, #blocked>) { + %0 = triton_gpu.convert_layout %arg0 : tensor<8x32x32xf16, #blocked> -> tensor<8x32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK-NOT: load + tt.return + } +} diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 743d554a3..113ec3cf6 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' 2>&1 | FileCheck %s #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> #shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> @@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-COUNT-128: llvm.fadd // CHECK: nvgpu.wgmma // CHECK-COUNT-128: llvm.fadd - %m = triton_nvidia_gpu.dot_async %a, %b, %c + %m = triton_nvidia_gpu.warp_group_dot %a, %b, %c {maxNumImpreciseAcc = 32 : i32, inputPrecision = 0 : i32} : !tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> tt.return @@ -38,7 +38,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: nvgpu.wgmma // CHECK-NOT: llvm.fadd // CHECK: llvm.return - %m = triton_nvidia_gpu.dot_async %a, %b, %c + %m = triton_nvidia_gpu.warp_group_dot %a, %b, %c {maxNumImpreciseAcc = 129 : i32, inputPrecision = 0 : i32} : !tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> tt.return @@ -62,7 +62,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: nvgpu.wgmma // CHECK-COUNT-128: llvm.fadd // CHECK: llvm.return - %m = triton_nvidia_gpu.dot_async %a, %b, %c + %m = triton_nvidia_gpu.warp_group_dot %a, %b, %c {maxNumImpreciseAcc = 64 : i32, inputPrecision = 0 : i32} : !tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> tt.return @@ -80,7 +80,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: nvgpu.wgmma %{{.*}}, %{{.*}} { tt.func @dot_zero_acc(%a: !tt.memdesc<128x64xf16, #shared>, %b: !tt.memdesc<64x64xf16, #shared1>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %m = triton_nvidia_gpu.dot_async %a, %b, %cst {inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32} : + %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst {inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32} : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x64xf16, #shared1> -> tensor<128x64xf32, #mma> tt.return } @@ -93,12 +93,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: @dot_reg_operand_A // Generate a wgmma where the first operand is a struct. - // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !tt.memdesc<64x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> %opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> - %m = tt.dot %opA, %b, %cst, inputPrecision = tf32 : + %m = triton_nvidia_gpu.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return } @@ -112,11 +112,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: @dot_reg_operand_A_fp8 // Generate a wgmma where the first operand is a struct. - // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1> - %m = tt.dot %a, %b, %cst, inputPrecision = tf32 {maxNumImpreciseAcc = 1073741824 : i32} : + %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst { maxNumImpreciseAcc = 1073741824 : i32, inputPrecision = 0 : i32 } : tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1> tt.return } @@ -129,24 +129,24 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: test_fp8_to_f16_conversion tt.func @test_fp8_to_f16_conversion( - %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FNUZ, #blocked>, + %in0: tensor<128xf8E5M2, #blocked>, %in1: tensor<128xf8E4M3FN, #blocked>, %in2: tensor<128xf16, #blocked>, %in3: tensor<128xf32, #blocked>) { // CHECK-COUNT-2: cvt.rn.f16x2.e5m2x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> %out0 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xf16, #blocked> // CHECK-COUNT-2: cvt.rn.f16x2.e4m3x2 {{.*}} "=r,h" %{{.*}} : (i16) -> vector<2xf16> - %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FNUZ, #blocked> -> tensor<128xf16, #blocked> + %out1 = tt.fp_to_fp %in1 : tensor<128xf8E4M3FN, #blocked> -> tensor<128xf16, #blocked> // CHECK-COUNT-2: mul.rn.bf16x2 %out2 = tt.fp_to_fp %in0 : tensor<128xf8E5M2, #blocked> -> tensor<128xbf16, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> %out3 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E5M2, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f16x2 {{.*}} "=h,r" %{{.*}} : (i32) -> vector<2xi8> - %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> + %out4 = tt.fp_to_fp %in2, rounding = rtne : tensor<128xf16, #blocked> -> tensor<128xf8E4M3FN, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e5m2x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> %out5 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E5M2, #blocked> // CHECK-COUNT-2: cvt.rn.satfinite.e4m3x2.f32 {{.*}} "=h,r,r" %{{.*}}, %{{.*}} : (i32, i32) -> vector<2xi8> - %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FNUZ, #blocked> + %out6 = tt.fp_to_fp %in3, rounding = rtne : tensor<128xf32, #blocked> -> tensor<128xf8E4M3FN, #blocked> tt.return } } @@ -208,7 +208,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-COUNT-128: llvm.fadd tt.func @dot_zero_acc_operand(%a: !tt.memdesc<128x128xf8E5M2, #shared>, %b: !tt.memdesc<128x128xf8E5M2, #shared1>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> - %m = tt.dot %a, %b, %cst, inputPrecision = tf32 {maxNumImpreciseAcc = 64 : i32} : + %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst {maxNumImpreciseAcc = 64 : i32, inputPrecision = 0 : i32} : !tt.memdesc<128x128xf8E5M2, #shared> * !tt.memdesc<128x128xf8E5M2, #shared1> -> tensor<128x128xf32, #mma> tt.return } @@ -228,3 +228,54 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : tt.return } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @fp8_const(%arg0: tensor<1024xi1, #blocked>, %arg1: tensor<1024xf8E4M3FNUZ, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: @fp8_const + // CHECK: llvm.mlir.constant(0.000000e+00 : f8E4M3FNUZ) : i8 + %cst = arith.constant dense<0.000000e+00> : tensor<1024xf8E4M3FNUZ, #blocked> + %a = arith.select %arg0, %arg1, %cst : tensor<1024xi1, #blocked>, tensor<1024xf8E4M3FNUZ, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_nomask + // CHECK: atom.global.gpu.acq_rel.add.v4.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_withmask + // CHECK: atom.global.gpu.acq_rel.add.v2.f32 + // CHECK: atom.global.gpu.acq_rel.add.v2.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_withmask + // CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16 + // CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} diff --git a/test/Conversion/tritongpu_to_llvm_hopper_ptx80.mlir b/test/Conversion/tritongpu_to_llvm_hopper_ptx80.mlir new file mode 100644 index 000000000..906c61002 --- /dev/null +++ b/test/Conversion/tritongpu_to_llvm_hopper_ptx80.mlir @@ -0,0 +1,44 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=80' 2>&1 | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_nomask + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f32_withmask + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + // CHECK: atom.global.gpu.acq_rel.add.f32 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf32, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} { + // CHECK-LABEL: atomic_add_f16_withmask + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + // CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2 + %0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked> + tt.return + } +} diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index 6ba8add3b..0bcab369f 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -35,8 +35,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: "@$0 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4}], [$5];", "b,r,l,r,r,r" {{.*}} : (i1, !llvm.ptr<3>, !llvm.ptr<1>, i32, i32, !llvm.ptr<3>) -> !llvm.void // CHECK-NOT: cp.async.bulk.tensor.2d.shared // CHECK: return - tt.func @tma_copy_global_to_local(%tma: !tt.ptr, %alloc: !tt.memdesc<128x128xf32, #shared1>, %x: i32, %barrier: !tt.memdesc<1xi64, #shared0>, %pred: i1) { - triton_nvidia_gpu.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.ptr, !tt.memdesc<1xi64, #shared0> -> !tt.memdesc<128x128xf32, #shared1> + tt.func @tma_copy_global_to_local(%tma: !tt.ptr, %alloc: !tt.memdesc<128x128xf32, #shared1, mutable>, %x: i32, %barrier: !tt.memdesc<1xi64, #shared0>, %pred: i1) { + triton_nvidia_gpu.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.ptr, !tt.memdesc<1xi64, #shared0> -> !tt.memdesc<128x128xf32, #shared1, mutable> tt.return } } @@ -79,3 +79,81 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: byval_tma_desc + // CHECK: llvm.align = 64 + // CHECK: llvm.byval = !llvm.array<128 x i8> + // CHECK: nvvm.grid_constant + tt.func @byval_tma_desc(%desc: !tt.ptr {tt.nv_tma_desc = 1 : i32}) { + tt.return + } +} + +// ----- + +// CHECK-LABEL: device_tensormap_create1d +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @device_tensormap_create1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c256_i32 = arith.constant 256 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + // CHECK: st.shared.b32 + // CHECK: bar.warp.sync + // CHECK: tensormap.replace.tile.global_address.shared::cta.b1024.b64 [ $0 + 0 ], $1; + // CHECK: tensormap.replace.tile.rank.shared::cta.b1024.b32 [ $0 + 0 ], 0x0; + // CHECK: tensormap.replace.tile.box_dim.shared::cta.b1024.b32 [ $0 + 0 ], 0x0, $1; + // CHECK: tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [ $0 + 0 ], 0x0, $1; + // CHECK: tensormap.replace.tile.element_stride.shared::cta.b1024.b32 [ $0 + 0 ], 0x0, $1; + // CHECK: tensormap.replace.tile.elemtype.shared::cta.b1024.b32 [ $0 + 0 ], 0x3; + // CHECK: tensormap.replace.tile.interleave_layout.shared::cta.b1024.b32 [ $0 + 0 ], 0x0; + // CHECK: tensormap.replace.tile.swizzle_mode.shared::cta.b1024.b32 [ $0 + 0 ], 0x2; + // CHECK: tensormap.replace.tile.fill_mode.shared::cta.b1024.b32 [ $0 + 0 ], 0x1; + // CHECK: tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [ $0 + 0 ], [ $1 + 0 ], 0x80; + tt.experimental_tensormap_create %arg1, %arg0, [%c256_i32], [%arg2], [], [%c1_i32] {elem_type = 3 : i32, fill_mode = 1 : i32, interleave_layout = 0 : i32, swizzle_mode = 2 : i32, allocation.offset = 0 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32) -> () + tt.return + } +} + +// ----- + +// CHECK-LABEL: device_tensormap_create2d +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @device_tensormap_create2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c256_i32 = arith.constant 256 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1024_i64 = arith.constant 1024 : i64 + // CHECK: st.shared.b32 + // CHECK: bar.warp.sync + // CHECK: tensormap.replace.tile.global_address.shared::cta.b1024.b64 [ $0 + 0 ], $1; + // CHECK: tensormap.replace.tile.rank.shared::cta.b1024.b32 [ $0 + 0 ], 0x1; + // CHECK: tensormap.replace.tile.box_dim.shared::cta.b1024.b32 [ $0 + 0 ], 0x0, $1; + // CHECK: tensormap.replace.tile.box_dim.shared::cta.b1024.b32 [ $0 + 0 ], 0x1, $1; + // CHECK: tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [ $0 + 0 ], 0x0, $1; + // CHECK: tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [ $0 + 0 ], 0x1, $1; + // CHECK: tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [ $0 + 0 ], 0x0, $1; + // CHECK: tensormap.replace.tile.element_stride.shared::cta.b1024.b32 [ $0 + 0 ], 0x0, $1; + // CHECK: tensormap.replace.tile.element_stride.shared::cta.b1024.b32 [ $0 + 0 ], 0x1, $1; + // CHECK: tensormap.replace.tile.elemtype.shared::cta.b1024.b32 [ $0 + 0 ], 0x3; + // CHECK: tensormap.replace.tile.interleave_layout.shared::cta.b1024.b32 [ $0 + 0 ], 0x0; + // CHECK: tensormap.replace.tile.swizzle_mode.shared::cta.b1024.b32 [ $0 + 0 ], 0x2; + // CHECK: tensormap.replace.tile.fill_mode.shared::cta.b1024.b32 [ $0 + 0 ], 0x1; + // CHECK: tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [ $0 + 0 ], [ $1 + 0 ], 0x80; + tt.experimental_tensormap_create %arg1, %arg0, [%c256_i32, %c256_i32], [%arg2, %arg2], [%c1024_i64], [%c1_i32, %c1_i32] {elem_type = 3 : i32, fill_mode = 1 : i32, interleave_layout = 0 : i32, swizzle_mode = 2 : i32, allocation.offset = 0 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () + tt.return + } +} + +// ----- + +// CHECK-LABEL: tensormap_fenceproxy_acquire +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @tensormap_fenceproxy_acquire(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + // CHECK: fence.proxy.tensormap::generic.acquire.gpu [ $0 + 0 ], 0x80; + tt.experimental_tensormap_fenceproxy_acquire %arg0 : !tt.ptr + tt.return + } +} diff --git a/test/NVGPU/test_cga.mlir b/test/NVGPU/test_cga.mlir deleted file mode 100644 index f67e49fcc..000000000 --- a/test/NVGPU/test_cga.mlir +++ /dev/null @@ -1,13 +0,0 @@ -// RUN: triton-opt %s -split-input-file --convert-nv-gpu-to-llvm | FileCheck %s -#SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { - tt.func @test_mbarrier() { - %ptr = llvm.mlir.zero : !llvm.ptr<3> - - // CHECK: llvm.inline_asm - %v = nvgpu.cluster_id - llvm.store %v, %ptr : i32, !llvm.ptr<3> - - tt.return - } -} // end module diff --git a/test/NVGPU/test_wgmma.mlir b/test/NVGPU/test_wgmma.mlir deleted file mode 100644 index 05c6ad597..000000000 --- a/test/NVGPU/test_wgmma.mlir +++ /dev/null @@ -1,24 +0,0 @@ -// RUN: triton-opt %s -split-input-file --convert-nv-gpu-to-llvm | FileCheck %s - -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { - tt.func @wgmma_no_acc(%descA: i64, %descB: i64) { - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80,$81,$82,$83,$84,$85,$86,$87,$88,$89,$90,$91,$92,$93,$94,$95,$96,$97,$98,$99,$100,$101,$102,$103,$104,$105,$106,$107,$108,$109,$110,$111,$112,$113,$114,$115,$116,$117,$118,$119,$120,$121,$122,$123,$124,$125,$126,$127}, $128, $129, 0, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l" %{{.*}}, %{{.*}} : (i64, i64) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> - %acc0 = nvgpu.wgmma %descA, %descB - {eltTypeA = 3 : i32, eltTypeB = 3 : i32, eltTypeC = 7 : i32, k = 32 : i32, layoutA = 0 : i32, layoutB = 1 : i32, m = 64 : i32, n = 256 : i32} : - (i64, i64) -> - !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> - tt.return - } -} - -// ----- - -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { - tt.func @wgmma_wait(%in: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) { - // CHECK: // wait for regs: $0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63 - // CHECK: wgmma.wait_group.sync.aligned 0; - %out = nvgpu.wgmma_wait_group %in {pendings = 0 : i32} : - !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> - tt.return - } -} diff --git a/test/Tools/tensor_layout_print.mlir b/test/Tools/tensor_layout_print.mlir new file mode 100644 index 000000000..80c019593 --- /dev/null +++ b/test/Tools/tensor_layout_print.mlir @@ -0,0 +1,58 @@ +// RUN: triton-tensor-layout -i %s -alias-names="blocked" -t "tensor<16x16xf16>" | FileCheck %s --check-prefix=CHECK-BLOCKED + +// RUN: triton-tensor-layout -i %s -alias-names="mfma" -t "tensor<16x16xf16>" | FileCheck %s --check-prefix=CHECK-MFMA + +// RUN: triton-tensor-layout -l "#triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>" -t "tensor<16x16xf16>" | FileCheck %s --check-prefix=CHECK-MFMA + +// RUN: triton-tensor-layout -i %s -alias-names="mfma" -t "tensor<16x16xf16>" -use-hw-view | FileCheck %s --check-prefix=CHECK-HW + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +tt.func @print(%A : !tt.ptr) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #blocked> + %cst1 = arith.constant dense<0.00e+00> : tensor<16x16xf16, #mfma> + tt.return +} + +// CHECK-BLOCKED: Print layout attribute: #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-BLOCKED: T0:0| T4:0, T0:1| T4:1, T0:2| T4:2, T0:3| T4:3, T1:0| T5:0, T1:1| T5:1, T1:2| T5:2, T1:3| T5:3, T2:0| T6:0, T2:1| T6:1, T2:2| T6:2, T2:3| T6:3, T3:0| T7:0, T3:1| T7:1, T3:2| T7:2, T3:3| T7:3 +// CHECK-BLOCKED: T8:0| T12:0, T8:1| T12:1, T8:2| T12:2, T8:3| T12:3, T9:0| T13:0, T9:1| T13:1, T9:2| T13:2, T9:3| T13:3, T10:0| T14:0, T10:1| T14:1, T10:2| T14:2, T10:3| T14:3, T11:0| T15:0, T11:1| T15:1, T11:2| T15:2, T11:3| T15:3 +// CHECK-BLOCKED: T16:0| T20:0, T16:1| T20:1, T16:2| T20:2, T16:3| T20:3, T17:0| T21:0, T17:1| T21:1, T17:2| T21:2, T17:3| T21:3, T18:0| T22:0, T18:1| T22:1, T18:2| T22:2, T18:3| T22:3, T19:0| T23:0, T19:1| T23:1, T19:2| T23:2, T19:3| T23:3 +// CHECK-BLOCKED: T24:0| T28:0, T24:1| T28:1, T24:2| T28:2, T24:3| T28:3, T25:0| T29:0, T25:1| T29:1, T25:2| T29:2, T25:3| T29:3, T26:0| T30:0, T26:1| T30:1, T26:2| T30:2, T26:3| T30:3, T27:0| T31:0, T27:1| T31:1, T27:2| T31:2, T27:3| T31:3 +// CHECK-BLOCKED: T32:0| T36:0, T32:1| T36:1, T32:2| T36:2, T32:3| T36:3, T33:0| T37:0, T33:1| T37:1, T33:2| T37:2, T33:3| T37:3, T34:0| T38:0, T34:1| T38:1, T34:2| T38:2, T34:3| T38:3, T35:0| T39:0, T35:1| T39:1, T35:2| T39:2, T35:3| T39:3 +// CHECK-BLOCKED: T40:0| T44:0, T40:1| T44:1, T40:2| T44:2, T40:3| T44:3, T41:0| T45:0, T41:1| T45:1, T41:2| T45:2, T41:3| T45:3, T42:0| T46:0, T42:1| T46:1, T42:2| T46:2, T42:3| T46:3, T43:0| T47:0, T43:1| T47:1, T43:2| T47:2, T43:3| T47:3 +// CHECK-BLOCKED: T48:0| T52:0, T48:1| T52:1, T48:2| T52:2, T48:3| T52:3, T49:0| T53:0, T49:1| T53:1, T49:2| T53:2, T49:3| T53:3, T50:0| T54:0, T50:1| T54:1, T50:2| T54:2, T50:3| T54:3, T51:0| T55:0, T51:1| T55:1, T51:2| T55:2, T51:3| T55:3 +// CHECK-BLOCKED: T56:0| T60:0, T56:1| T60:1, T56:2| T60:2, T56:3| T60:3, T57:0| T61:0, T57:1| T61:1, T57:2| T61:2, T57:3| T61:3, T58:0| T62:0, T58:1| T62:1, T58:2| T62:2, T58:3| T62:3, T59:0| T63:0, T59:1| T63:1, T59:2| T63:2, T59:3| T63:3 +// CHECK-BLOCKED: T64:0| T68:0, T64:1| T68:1, T64:2| T68:2, T64:3| T68:3, T65:0| T69:0, T65:1| T69:1, T65:2| T69:2, T65:3| T69:3, T66:0| T70:0, T66:1| T70:1, T66:2| T70:2, T66:3| T70:3, T67:0| T71:0, T67:1| T71:1, T67:2| T71:2, T67:3| T71:3 +// CHECK-BLOCKED: T72:0| T76:0, T72:1| T76:1, T72:2| T76:2, T72:3| T76:3, T73:0| T77:0, T73:1| T77:1, T73:2| T77:2, T73:3| T77:3, T74:0| T78:0, T74:1| T78:1, T74:2| T78:2, T74:3| T78:3, T75:0| T79:0, T75:1| T79:1, T75:2| T79:2, T75:3| T79:3 +// CHECK-BLOCKED: T80:0| T84:0, T80:1| T84:1, T80:2| T84:2, T80:3| T84:3, T81:0| T85:0, T81:1| T85:1, T81:2| T85:2, T81:3| T85:3, T82:0| T86:0, T82:1| T86:1, T82:2| T86:2, T82:3| T86:3, T83:0| T87:0, T83:1| T87:1, T83:2| T87:2, T83:3| T87:3 +// CHECK-BLOCKED: T88:0| T92:0, T88:1| T92:1, T88:2| T92:2, T88:3| T92:3, T89:0| T93:0, T89:1| T93:1, T89:2| T93:2, T89:3| T93:3, T90:0| T94:0, T90:1| T94:1, T90:2| T94:2, T90:3| T94:3, T91:0| T95:0, T91:1| T95:1, T91:2| T95:2, T91:3| T95:3 +// CHECK-BLOCKED: T96:0|T100:0, T96:1|T100:1, T96:2|T100:2, T96:3|T100:3, T97:0|T101:0, T97:1|T101:1, T97:2|T101:2, T97:3|T101:3, T98:0|T102:0, T98:1|T102:1, T98:2|T102:2, T98:3|T102:3, T99:0|T103:0, T99:1|T103:1, T99:2|T103:2, T99:3|T103:3 +// CHECK-BLOCKED: T104:0|T108:0, T104:1|T108:1, T104:2|T108:2, T104:3|T108:3, T105:0|T109:0, T105:1|T109:1, T105:2|T109:2, T105:3|T109:3, T106:0|T110:0, T106:1|T110:1, T106:2|T110:2, T106:3|T110:3, T107:0|T111:0, T107:1|T111:1, T107:2|T111:2, T107:3|T111:3 +// CHECK-BLOCKED: T112:0|T116:0, T112:1|T116:1, T112:2|T116:2, T112:3|T116:3, T113:0|T117:0, T113:1|T117:1, T113:2|T117:2, T113:3|T117:3, T114:0|T118:0, T114:1|T118:1, T114:2|T118:2, T114:3|T118:3, T115:0|T119:0, T115:1|T119:1, T115:2|T119:2, T115:3|T119:3 +// CHECK-BLOCKED: T120:0|T124:0, T120:1|T124:1, T120:2|T124:2, T120:3|T124:3, T121:0|T125:0, T121:1|T125:1, T121:2|T125:2, T121:3|T125:3, T122:0|T126:0, T122:1|T126:1, T122:2|T126:2, T122:3|T126:3, T123:0|T127:0, T123:1|T127:1, T123:2|T127:2, T123:3|T127:3 + + +// CHECK-MFMA: Print layout attribute: {{.*}}#triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +// CHECK-MFMA: T0:0| T64:0|T128:0|T192:0, T0:1| T64:1|T128:1|T192:1, T0:2| T64:2|T128:2|T192:2, T0:3| T64:3|T128:3|T192:3, T16:0| T80:0|T144:0|T208:0, T16:1| T80:1|T144:1|T208:1, T16:2| T80:2|T144:2|T208:2, T16:3| T80:3|T144:3|T208:3, T32:0| T96:0|T160:0|T224:0, T32:1| T96:1|T160:1|T224:1, T32:2| T96:2|T160:2|T224:2, T32:3| T96:3|T160:3|T224:3, T48:0|T112:0|T176:0|T240:0, T48:1|T112:1|T176:1|T240:1, T48:2|T112:2|T176:2|T240:2, T48:3|T112:3|T176:3|T240:3 +// CHECK-MFMA: T1:0| T65:0|T129:0|T193:0, T1:1| T65:1|T129:1|T193:1, T1:2| T65:2|T129:2|T193:2, T1:3| T65:3|T129:3|T193:3, T17:0| T81:0|T145:0|T209:0, T17:1| T81:1|T145:1|T209:1, T17:2| T81:2|T145:2|T209:2, T17:3| T81:3|T145:3|T209:3, T33:0| T97:0|T161:0|T225:0, T33:1| T97:1|T161:1|T225:1, T33:2| T97:2|T161:2|T225:2, T33:3| T97:3|T161:3|T225:3, T49:0|T113:0|T177:0|T241:0, T49:1|T113:1|T177:1|T241:1, T49:2|T113:2|T177:2|T241:2, T49:3|T113:3|T177:3|T241:3 +// CHECK-MFMA: T2:0| T66:0|T130:0|T194:0, T2:1| T66:1|T130:1|T194:1, T2:2| T66:2|T130:2|T194:2, T2:3| T66:3|T130:3|T194:3, T18:0| T82:0|T146:0|T210:0, T18:1| T82:1|T146:1|T210:1, T18:2| T82:2|T146:2|T210:2, T18:3| T82:3|T146:3|T210:3, T34:0| T98:0|T162:0|T226:0, T34:1| T98:1|T162:1|T226:1, T34:2| T98:2|T162:2|T226:2, T34:3| T98:3|T162:3|T226:3, T50:0|T114:0|T178:0|T242:0, T50:1|T114:1|T178:1|T242:1, T50:2|T114:2|T178:2|T242:2, T50:3|T114:3|T178:3|T242:3 +// CHECK-MFMA: T3:0| T67:0|T131:0|T195:0, T3:1| T67:1|T131:1|T195:1, T3:2| T67:2|T131:2|T195:2, T3:3| T67:3|T131:3|T195:3, T19:0| T83:0|T147:0|T211:0, T19:1| T83:1|T147:1|T211:1, T19:2| T83:2|T147:2|T211:2, T19:3| T83:3|T147:3|T211:3, T35:0| T99:0|T163:0|T227:0, T35:1| T99:1|T163:1|T227:1, T35:2| T99:2|T163:2|T227:2, T35:3| T99:3|T163:3|T227:3, T51:0|T115:0|T179:0|T243:0, T51:1|T115:1|T179:1|T243:1, T51:2|T115:2|T179:2|T243:2, T51:3|T115:3|T179:3|T243:3 +// CHECK-MFMA: T4:0| T68:0|T132:0|T196:0, T4:1| T68:1|T132:1|T196:1, T4:2| T68:2|T132:2|T196:2, T4:3| T68:3|T132:3|T196:3, T20:0| T84:0|T148:0|T212:0, T20:1| T84:1|T148:1|T212:1, T20:2| T84:2|T148:2|T212:2, T20:3| T84:3|T148:3|T212:3, T36:0|T100:0|T164:0|T228:0, T36:1|T100:1|T164:1|T228:1, T36:2|T100:2|T164:2|T228:2, T36:3|T100:3|T164:3|T228:3, T52:0|T116:0|T180:0|T244:0, T52:1|T116:1|T180:1|T244:1, T52:2|T116:2|T180:2|T244:2, T52:3|T116:3|T180:3|T244:3 +// CHECK-MFMA: T5:0| T69:0|T133:0|T197:0, T5:1| T69:1|T133:1|T197:1, T5:2| T69:2|T133:2|T197:2, T5:3| T69:3|T133:3|T197:3, T21:0| T85:0|T149:0|T213:0, T21:1| T85:1|T149:1|T213:1, T21:2| T85:2|T149:2|T213:2, T21:3| T85:3|T149:3|T213:3, T37:0|T101:0|T165:0|T229:0, T37:1|T101:1|T165:1|T229:1, T37:2|T101:2|T165:2|T229:2, T37:3|T101:3|T165:3|T229:3, T53:0|T117:0|T181:0|T245:0, T53:1|T117:1|T181:1|T245:1, T53:2|T117:2|T181:2|T245:2, T53:3|T117:3|T181:3|T245:3 +// CHECK-MFMA: T6:0| T70:0|T134:0|T198:0, T6:1| T70:1|T134:1|T198:1, T6:2| T70:2|T134:2|T198:2, T6:3| T70:3|T134:3|T198:3, T22:0| T86:0|T150:0|T214:0, T22:1| T86:1|T150:1|T214:1, T22:2| T86:2|T150:2|T214:2, T22:3| T86:3|T150:3|T214:3, T38:0|T102:0|T166:0|T230:0, T38:1|T102:1|T166:1|T230:1, T38:2|T102:2|T166:2|T230:2, T38:3|T102:3|T166:3|T230:3, T54:0|T118:0|T182:0|T246:0, T54:1|T118:1|T182:1|T246:1, T54:2|T118:2|T182:2|T246:2, T54:3|T118:3|T182:3|T246:3 +// CHECK-MFMA: T7:0| T71:0|T135:0|T199:0, T7:1| T71:1|T135:1|T199:1, T7:2| T71:2|T135:2|T199:2, T7:3| T71:3|T135:3|T199:3, T23:0| T87:0|T151:0|T215:0, T23:1| T87:1|T151:1|T215:1, T23:2| T87:2|T151:2|T215:2, T23:3| T87:3|T151:3|T215:3, T39:0|T103:0|T167:0|T231:0, T39:1|T103:1|T167:1|T231:1, T39:2|T103:2|T167:2|T231:2, T39:3|T103:3|T167:3|T231:3, T55:0|T119:0|T183:0|T247:0, T55:1|T119:1|T183:1|T247:1, T55:2|T119:2|T183:2|T247:2, T55:3|T119:3|T183:3|T247:3 +// CHECK-MFMA: T8:0| T72:0|T136:0|T200:0, T8:1| T72:1|T136:1|T200:1, T8:2| T72:2|T136:2|T200:2, T8:3| T72:3|T136:3|T200:3, T24:0| T88:0|T152:0|T216:0, T24:1| T88:1|T152:1|T216:1, T24:2| T88:2|T152:2|T216:2, T24:3| T88:3|T152:3|T216:3, T40:0|T104:0|T168:0|T232:0, T40:1|T104:1|T168:1|T232:1, T40:2|T104:2|T168:2|T232:2, T40:3|T104:3|T168:3|T232:3, T56:0|T120:0|T184:0|T248:0, T56:1|T120:1|T184:1|T248:1, T56:2|T120:2|T184:2|T248:2, T56:3|T120:3|T184:3|T248:3 +// CHECK-MFMA: T9:0| T73:0|T137:0|T201:0, T9:1| T73:1|T137:1|T201:1, T9:2| T73:2|T137:2|T201:2, T9:3| T73:3|T137:3|T201:3, T25:0| T89:0|T153:0|T217:0, T25:1| T89:1|T153:1|T217:1, T25:2| T89:2|T153:2|T217:2, T25:3| T89:3|T153:3|T217:3, T41:0|T105:0|T169:0|T233:0, T41:1|T105:1|T169:1|T233:1, T41:2|T105:2|T169:2|T233:2, T41:3|T105:3|T169:3|T233:3, T57:0|T121:0|T185:0|T249:0, T57:1|T121:1|T185:1|T249:1, T57:2|T121:2|T185:2|T249:2, T57:3|T121:3|T185:3|T249:3 +// CHECK-MFMA: T10:0| T74:0|T138:0|T202:0, T10:1| T74:1|T138:1|T202:1, T10:2| T74:2|T138:2|T202:2, T10:3| T74:3|T138:3|T202:3, T26:0| T90:0|T154:0|T218:0, T26:1| T90:1|T154:1|T218:1, T26:2| T90:2|T154:2|T218:2, T26:3| T90:3|T154:3|T218:3, T42:0|T106:0|T170:0|T234:0, T42:1|T106:1|T170:1|T234:1, T42:2|T106:2|T170:2|T234:2, T42:3|T106:3|T170:3|T234:3, T58:0|T122:0|T186:0|T250:0, T58:1|T122:1|T186:1|T250:1, T58:2|T122:2|T186:2|T250:2, T58:3|T122:3|T186:3|T250:3 +// CHECK-MFMA: T11:0| T75:0|T139:0|T203:0, T11:1| T75:1|T139:1|T203:1, T11:2| T75:2|T139:2|T203:2, T11:3| T75:3|T139:3|T203:3, T27:0| T91:0|T155:0|T219:0, T27:1| T91:1|T155:1|T219:1, T27:2| T91:2|T155:2|T219:2, T27:3| T91:3|T155:3|T219:3, T43:0|T107:0|T171:0|T235:0, T43:1|T107:1|T171:1|T235:1, T43:2|T107:2|T171:2|T235:2, T43:3|T107:3|T171:3|T235:3, T59:0|T123:0|T187:0|T251:0, T59:1|T123:1|T187:1|T251:1, T59:2|T123:2|T187:2|T251:2, T59:3|T123:3|T187:3|T251:3 +// CHECK-MFMA: T12:0| T76:0|T140:0|T204:0, T12:1| T76:1|T140:1|T204:1, T12:2| T76:2|T140:2|T204:2, T12:3| T76:3|T140:3|T204:3, T28:0| T92:0|T156:0|T220:0, T28:1| T92:1|T156:1|T220:1, T28:2| T92:2|T156:2|T220:2, T28:3| T92:3|T156:3|T220:3, T44:0|T108:0|T172:0|T236:0, T44:1|T108:1|T172:1|T236:1, T44:2|T108:2|T172:2|T236:2, T44:3|T108:3|T172:3|T236:3, T60:0|T124:0|T188:0|T252:0, T60:1|T124:1|T188:1|T252:1, T60:2|T124:2|T188:2|T252:2, T60:3|T124:3|T188:3|T252:3 +// CHECK-MFMA: T13:0| T77:0|T141:0|T205:0, T13:1| T77:1|T141:1|T205:1, T13:2| T77:2|T141:2|T205:2, T13:3| T77:3|T141:3|T205:3, T29:0| T93:0|T157:0|T221:0, T29:1| T93:1|T157:1|T221:1, T29:2| T93:2|T157:2|T221:2, T29:3| T93:3|T157:3|T221:3, T45:0|T109:0|T173:0|T237:0, T45:1|T109:1|T173:1|T237:1, T45:2|T109:2|T173:2|T237:2, T45:3|T109:3|T173:3|T237:3, T61:0|T125:0|T189:0|T253:0, T61:1|T125:1|T189:1|T253:1, T61:2|T125:2|T189:2|T253:2, T61:3|T125:3|T189:3|T253:3 +// CHECK-MFMA: T14:0| T78:0|T142:0|T206:0, T14:1| T78:1|T142:1|T206:1, T14:2| T78:2|T142:2|T206:2, T14:3| T78:3|T142:3|T206:3, T30:0| T94:0|T158:0|T222:0, T30:1| T94:1|T158:1|T222:1, T30:2| T94:2|T158:2|T222:2, T30:3| T94:3|T158:3|T222:3, T46:0|T110:0|T174:0|T238:0, T46:1|T110:1|T174:1|T238:1, T46:2|T110:2|T174:2|T238:2, T46:3|T110:3|T174:3|T238:3, T62:0|T126:0|T190:0|T254:0, T62:1|T126:1|T190:1|T254:1, T62:2|T126:2|T190:2|T254:2, T62:3|T126:3|T190:3|T254:3 +// CHECK-MFMA: T15:0| T79:0|T143:0|T207:0, T15:1| T79:1|T143:1|T207:1, T15:2| T79:2|T143:2|T207:2, T15:3| T79:3|T143:3|T207:3, T31:0| T95:0|T159:0|T223:0, T31:1| T95:1|T159:1|T223:1, T31:2| T95:2|T159:2|T223:2, T31:3| T95:3|T159:3|T223:3, T47:0|T111:0|T175:0|T239:0, T47:1|T111:1|T175:1|T239:1, T47:2|T111:2|T175:2|T239:2, T47:3|T111:3|T175:3|T239:3, T63:0|T127:0|T191:0|T255:0, T63:1|T127:1|T191:1|T255:1, T63:2|T127:2|T191:2|T255:2, T63:3|T127:3|T191:3|T255:3 + + +// CHECK-HW: Warp0: +// CHECK-HW: Warp1: +// CHECK-HW: Warp2: +// CHECK-HW: Warp3: diff --git a/test/Triton/canonicalize.mlir b/test/Triton/canonicalize.mlir index f04f03e4b..8888271e3 100644 --- a/test/Triton/canonicalize.mlir +++ b/test/Triton/canonicalize.mlir @@ -25,6 +25,16 @@ tt.func @make_range() -> (tensor<128x1xi32>, tensor<1xi32>) { tt.return %c, %d : tensor<128x1xi32>, tensor<1xi32> } +// CHECK-LABEL: fold_advance +tt.func @fold_advance(%arg: !tt.ptr>) -> (!tt.ptr>) { + %c0_i32 = arith.constant 0 : i32 + %0 = tt.advance %arg, [%c0_i32, %c0_i32] : > + // CHECK-NOT: tt.advance + // CHECK: tt.return %arg + tt.return %0 : !tt.ptr> +} + + // ----- #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index 2eb87772f..41a3ba15a 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -211,8 +211,8 @@ tt.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr tensor<8x2xf32> { // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32> - %const = arith.constant dense<1.0> : tensor<8xf32> - %bst_out = tt.broadcast %const : tensor<8xf32> -> tensor<8x2xf32> + %const = arith.constant dense<1.0> : tensor<8x1xf32> + %bst_out = tt.broadcast %const : tensor<8x1xf32> -> tensor<8x2xf32> // CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32> tt.return %bst_out : tensor<8x2xf32> @@ -292,15 +292,15 @@ tt.func @test_canonicalize_expand_dims(%arg0: tensor, %arg1: tensor<1xf32>) // CHECK-LABEL: @test_canonicalize_view tt.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>) { - %view0 = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<2x4xf32> - // CHECK: %{{.*}} = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<4x2xf32> - %view1 = tt.reshape %view0 {allow_reorder = true} : tensor<2x4xf32> -> tensor<4x2xf32> + %view0 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<2x4xf32> + // CHECK: %{{.*}} = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<4x2xf32> + %view1 = tt.reshape %view0 allow_reorder : tensor<2x4xf32> -> tensor<4x2xf32> %splat = tt.splat %arg1 : tensor -> tensor<8xf32> // CHECK: %{{.*}} = tt.splat %arg1 : tensor -> tensor<2x2x2xf32> - %view2 = tt.reshape %splat {allow_reorder = true} : tensor<8xf32> -> tensor<2x2x2xf32> + %view2 = tt.reshape %splat allow_reorder : tensor<8xf32> -> tensor<2x2x2xf32> - %view3 = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<8xf32> + %view3 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<8xf32> // CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<8xf32> %add = arith.addf %view3, %arg0 : tensor<8xf32> @@ -329,7 +329,7 @@ tt.func @test_fold_views() -> (tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x %a = arith.constant dense<1.0> : tensor<1x128xf32> // CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x8xf32> - %b = tt.reshape %a {allow_reorder = true} : tensor<1x128xf32> -> tensor<16x8xf32> + %b = tt.reshape %a allow_reorder : tensor<1x128xf32> -> tensor<16x8xf32> // CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x128xf32> %c = tt.broadcast %a : tensor<1x128xf32> -> tensor<16x128xf32> diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index 82258b5bc..c7df02322 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -1,5 +1,23 @@ // RUN: triton-opt --split-input-file %s --verify-diagnostics +tt.func @fn(%v: i32) { + %b = tt.splat %v : i32 -> tensor<128xi32> + // expected-error @+1 {{rank of source must be same as rank of result}} + %c = tt.broadcast %b : tensor<128xi32> -> tensor<128x32xi32> + tt.return +} + +// ----- + +tt.func @fn(%v: i32) { + %b = tt.splat %v : i32 -> tensor<2x32xi32> + // expected-error @+1 {{Different dimensions at index 0 between source and result. Broadcast requires the source dimension to be 1.}} + %c = tt.broadcast %b : tensor<2x32xi32> -> tensor<128x32xi32> + tt.return +} + +// ----- + tt.func public @fn(%arg0: tensor<128xf32>) { // expected-error @+1 {{packed_element}} %a = tt.elementwise_inline_asm "" @@ -20,7 +38,7 @@ tt.func public @fn(%arg0: tensor<128xf32>, %arg1: tensor<64xf32>) { tt.func public @reshape_different_num_elements(%arg0: tensor<32x128xf16>) { // expected-error @+1 {{number of src and dst elements of reshape must be the same}} - %a = tt.reshape %arg0 {allow_reorder = false} : tensor<32x128xf16> -> tensor<64x32xf16> + %a = tt.reshape %arg0 : tensor<32x128xf16> -> tensor<64x32xf16> tt.return } @@ -90,6 +108,19 @@ tt.func public @fn(%v: tensor<4x128xf64>) { // ----- +tt.func @reduce_different_input_shapes(%arg0: tensor<32x32x64xf32>, %arg1: tensor<16x32x64xf32>) -> (tensor<32x64xf32>, tensor<16x64xf32>) { + // expected-error @below {{op requires the same shape for all operands}} + %0:2 = "tt.reduce" (%arg0, %arg1) <{axis = 1 : i32}> ({ + ^bb0(%acc0: f32, %acc1: f32, %cur0: f32, %cur1: f32): + %1 = arith.addf %acc0, %cur0 : f32 + %2 = arith.addf %acc1, %cur1 : f32 + tt.reduce.return %1, %2 : f32, f32 + }) : (tensor<32x32x64xf32>, tensor<16x32x64xf32>) -> (tensor<32x64xf32>, tensor<16x64xf32>) + tt.return %0#0, %0#1 : tensor<32x64xf32>, tensor<16x64xf32> +} + +// ----- + tt.func public @fn(%v: tensor<4x128xf32>) { // expected-error @+1 {{requires the same shape}} %a = "tt.scan" (%v) ({ @@ -184,21 +215,6 @@ tt.func public @fn(%arg0: tensor<2xf32>) { // ----- -// Bad order; should start with 2. -#blocked = #triton_gpu.blocked<{sizePerThread = [1,1,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [1,2,0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1,1], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [1,0]}> - -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { -tt.func public @fn(%arg0: tensor<2x2x2xf32, #blocked>) { - // expected-error @+2 {{last dimension}} - // expected-error @+1 {{op failed to infer returned types}} - %a, %b = tt.split %arg0 : tensor<2x2x2xf32, #blocked> -> tensor<2x2xf32, #blocked1> - tt.return -} -} // end module - -// ----- - #blocked = #triton_gpu.blocked<{sizePerThread = [1,1,2], threadsPerWarp = [1,32,1], warpsPerCTA = [1,1,1], order = [2,0,1]}> // Bad order, should be [1,0]. #blocked1 = #triton_gpu.blocked<{sizePerThread = [1,1], threadsPerWarp = [1,32], warpsPerCTA = [1,1], order = [1,0]}> diff --git a/test/Triton/loop-unroll.mlir b/test/Triton/loop-unroll.mlir new file mode 100644 index 000000000..916663028 --- /dev/null +++ b/test/Triton/loop-unroll.mlir @@ -0,0 +1,45 @@ +// RUN: triton-opt --split-input-file %s -triton-loop-unroll | FileCheck %s + +tt.func @add_kernel_unroll(%arg0: tensor<256x!tt.ptr>, %arg1: i32) { + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %0 = tt.splat %c1_i32 : i32 -> tensor<256xi32> + %1 = tt.splat %cst : f32 -> tensor<256xf32> + // Check the loop is unrolled by factor of 2 and is followed by a reminder loop. + // CHECK-LABEL: add_kernel_unroll + // CHECK: scf.for + // CHECK-COUNT-2: tt.load + // CHECK-NOT: tt.load + // CHECK: scf.for + // CHECK: tt.load + // CHECK-NOT: tt.load + %2:2 = scf.for %arg3 = %c1_i32 to %arg1 step %c1_i32 iter_args(%arg4 = %1, %arg5 = %arg0) -> (tensor<256xf32>, tensor<256x!tt.ptr>) : i32 { + %3 = tt.load %arg5 : tensor<256x!tt.ptr> + %4 = arith.addf %arg4, %3 : tensor<256xf32> + %5 = tt.addptr %arg5, %0 : tensor<256x!tt.ptr>, tensor<256xi32> + scf.yield %4, %5 : tensor<256xf32>, tensor<256x!tt.ptr> + } {tt.loop_unroll_factor = 2 : i32} + tt.return +} + +// ----- + +tt.func @add_kernel_nounroll(%arg0: tensor<256x!tt.ptr>, %arg1: i32) { + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %0 = tt.splat %c1_i32 : i32 -> tensor<256xi32> + %1 = tt.splat %cst : f32 -> tensor<256xf32> + // Check the loop is not unrolled. + // CHECK-LABEL: add_kernel_nounroll + // CHECK: scf.for + // CHECK-COUNT-1: tt.load + // CHECK-NOT: tt.load + // CHECK-NOT: scf.for + %2:2 = scf.for %arg3 = %c1_i32 to %arg1 step %c1_i32 iter_args(%arg4 = %1, %arg5 = %arg0) -> (tensor<256xf32>, tensor<256x!tt.ptr>) : i32 { + %3 = tt.load %arg5 : tensor<256x!tt.ptr> + %4 = arith.addf %arg4, %3 : tensor<256xf32> + %5 = tt.addptr %arg5, %0 : tensor<256x!tt.ptr>, tensor<256xi32> + scf.yield %4, %5 : tensor<256xf32>, tensor<256x!tt.ptr> + } + tt.return +} diff --git a/test/Triton/ops.mlir b/test/Triton/ops.mlir index bd5627145..c3b92b7ee 100644 --- a/test/Triton/ops.mlir +++ b/test/Triton/ops.mlir @@ -186,7 +186,7 @@ tt.func @dot_ops_infer(%ptr: !tt.ptr, %v : f32) { // CHECK-LABEL: @print_no_arg tt.func @print_no_arg(%arg0: !tt.ptr) { // CHECK: tt.print "test" - tt.print "test" { hex = false } + tt.print "test" { hex = false, isSigned = array} %0 = tt.load %arg0 : !tt.ptr tt.store %arg0, %0 : !tt.ptr tt.return @@ -225,8 +225,14 @@ tt.func @inline_asm_scalar(%0: i32) { // CHECK-LABEL: reshape tt.func @reshape(%0: tensor<512xi32>) { - // CHECK: tt.reshape %{{.+}} {allow_reorder = false} : tensor<512xi32> -> tensor<16x32xi32> - %1 = tt.reshape %0 {allow_reorder = false} : tensor<512xi32> -> tensor<16x32xi32> + // CHECK: tt.reshape %{{.+}} : tensor<512xi32> -> tensor<16x32xi32> + %1 = tt.reshape %0 : tensor<512xi32> -> tensor<16x32xi32> + // CHECK: tt.reshape %{{.+}} allow_reorder : tensor<512xi32> -> tensor<16x32xi32> + %2 = tt.reshape %0 allow_reorder : tensor<512xi32> -> tensor<16x32xi32> + // CHECK: tt.reshape %{{.+}} allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32> + %3 = tt.reshape %0 allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32> + // CHECK: tt.reshape %{{.+}} efficient_layout : tensor<512xi32> -> tensor<16x32xi32> + %4 = tt.reshape %0 efficient_layout : tensor<512xi32> -> tensor<16x32xi32> tt.return } diff --git a/test/Triton/reproducer.mlir b/test/Triton/reproducer.mlir index ea4579e3e..f2c3a0f8e 100644 --- a/test/Triton/reproducer.mlir +++ b/test/Triton/reproducer.mlir @@ -9,7 +9,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : {-# external_resources: { mlir_reproducer: { - pipeline: "builtin.module(any(convert-scf-to-cf,convert-index-to-llvm{index-bitwidth=0},convert-triton-gpu-to-llvm{compute-capability=90},convert-nv-gpu-to-llvm,convert-arith-to-llvm{index-bitwidth=0},canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true},cse,symbol-dce,enable-line-info))", + pipeline: "builtin.module(any(convert-scf-to-cf,convert-index-to-llvm{index-bitwidth=0},convert-triton-gpu-to-llvm{compute-capability=90},convert-nv-gpu-to-llvm,convert-arith-to-llvm{index-bitwidth=0},canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},cse,symbol-dce,enable-line-info))", disable_threading: false, verify_each: false } diff --git a/test/Triton/rewrite-tensor-pointer.mlir b/test/Triton/rewrite-tensor-pointer.mlir index f48211885..eb39dcac0 100644 --- a/test/Triton/rewrite-tensor-pointer.mlir +++ b/test/Triton/rewrite-tensor-pointer.mlir @@ -1,83 +1,218 @@ -// RUN: triton-opt %s -triton-rewrite-tensor-pointer | FileCheck %s -tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) { - %c31_i32 = arith.constant 31 : i32 - %c127_i32 = arith.constant 127 : i32 - %c1 = arith.constant 1 : index +// RUN: triton-opt %s -triton-rewrite-tensor-pointer -split-input-file | FileCheck %s + +tt.func public @rewrite_load(%arg0: !tt.ptr) { + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16> + %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : !tt.ptr> + %load = tt.load %0 {boundaryCheck = array, padding = 2 : i32} : !tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @rewrite_load( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: !tt.ptr +// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[C32_I64:.*]] = arith.constant 32 : i64 +// CHECK-DAG: %[[C128_I64:.*]] = arith.constant 128 : i64 +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<128x32xf16> +// CHECK: %[[EXTSI0:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[EXTSI1:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[SPLAT0:.*]] = tt.splat %[[ARG0]] : !tt.ptr -> tensor<128x32x!tt.ptr> +// CHECK: %[[SPLAT1:.*]] = tt.splat %[[EXTSI0]] : i64 -> tensor<128xi64> +// CHECK: %[[MAKE_RANGE0:.*]] = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +// CHECK: %[[EXTSI2:.*]] = arith.extsi %[[MAKE_RANGE0]] : tensor<128xi32> to tensor<128xi64> +// CHECK: %[[ADDI0:.*]] = arith.addi %[[SPLAT1]], %[[EXTSI2]] : tensor<128xi64> +// CHECK: %[[EXPAND_DIMS0:.*]] = tt.expand_dims %[[ADDI0]] {axis = 1 : i32} : tensor<128xi64> -> tensor<128x1xi64> +// CHECK: %[[SPLAT2:.*]] = tt.splat %[[C1_I64]] : i64 -> tensor<128x1xi64> +// CHECK: %[[MULI0:.*]] = arith.muli %[[EXPAND_DIMS0]], %[[SPLAT2]] : tensor<128x1xi64> +// CHECK: %[[BROADCAST0:.*]] = tt.broadcast %[[MULI0]] : tensor<128x1xi64> -> tensor<128x32xi64> +// CHECK: %[[ADDPTR0:.*]] = tt.addptr %[[SPLAT0]], %[[BROADCAST0]] : tensor<128x32x!tt.ptr>, tensor<128x32xi64> +// CHECK: %[[SPLAT3:.*]] = tt.splat %[[EXTSI1]] : i64 -> tensor<32xi64> +// CHECK: %[[MAKE_RANGE1:.*]] = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> +// CHECK: %[[EXTSI3:.*]] = arith.extsi %[[MAKE_RANGE1]] : tensor<32xi32> to tensor<32xi64> +// CHECK: %[[ADDI1:.*]] = arith.addi %[[SPLAT3]], %[[EXTSI3]] : tensor<32xi64> +// CHECK: %[[EXPAND_DIMS1:.*]] = tt.expand_dims %[[ADDI1]] {axis = 0 : i32} : tensor<32xi64> -> tensor<1x32xi64> +// CHECK: %[[SPLAT4:.*]] = tt.splat %[[C1_I64]] : i64 -> tensor<1x32xi64> +// CHECK: %[[MULI1:.*]] = arith.muli %[[EXPAND_DIMS1]], %[[SPLAT4]] : tensor<1x32xi64> +// CHECK: %[[BROADCAST1:.*]] = tt.broadcast %[[MULI1]] : tensor<1x32xi64> -> tensor<128x32xi64> +// CHECK: %[[ADDPTR1:.*]] = tt.addptr %[[ADDPTR0]], %[[BROADCAST1]] : tensor<128x32x!tt.ptr>, tensor<128x32xi64> +// CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK: %[[SPLAT5:.*]] = tt.splat %[[C0_I64]] : i64 -> tensor<1x32xi64> +// CHECK: %[[CMP0:.*]] = arith.cmpi sge, %[[EXPAND_DIMS1]], %[[SPLAT5]] : tensor<1x32xi64> +// CHECK: %[[SPLAT6:.*]] = tt.splat %[[C32_I64]] : i64 -> tensor<1x32xi64> +// CHECK: %[[CMPI:.*]] = arith.cmpi slt, %[[EXPAND_DIMS1]], %[[SPLAT6]] : tensor<1x32xi64> +// CHECK: %[[ANDI:.*]] = arith.andi %[[CMP0]], %[[CMPI]] : tensor<1x32xi1> +// CHECK: %[[BROADCAST2:.*]] = tt.broadcast %[[ANDI]] : tensor<1x32xi1> -> tensor<128x32xi1> +// CHECK: %[[OTHER:.*]] = arith.constant 0x7E00 : f16 +// CHECK: %[[SPLAT7:.*]] = tt.splat %[[OTHER]] : f16 -> tensor<128x32xf16> +// CHECK: %[[LOAD:.*]] = tt.load %[[ADDPTR1]], %[[BROADCAST2]], %[[SPLAT7]] : tensor<128x32x!tt.ptr> +// CHECK: tt.return + +// ----- +tt.func public @rewrite_store(%arg0: !tt.ptr) { + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16> + %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : !tt.ptr> + tt.store %0, %cst: !tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @rewrite_store( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: !tt.ptr +// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[C32_I64:.*]] = arith.constant 32 : i64 +// CHECK-DAG: %[[C128_I64:.*]] = arith.constant 128 : i64 +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<128x32xf16> +// CHECK: %[[EXTSI0:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[EXTSI1:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[SPLAT0:.*]] = tt.splat %[[ARG0]] : !tt.ptr -> tensor<128x32x!tt.ptr> +// CHECK: %[[SPLAT1:.*]] = tt.splat %[[EXTSI0]] : i64 -> tensor<128xi64> +// CHECK: %[[MAKE_RANGE0:.*]] = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> +// CHECK: %[[EXTSI2:.*]] = arith.extsi %[[MAKE_RANGE0]] : tensor<128xi32> to tensor<128xi64> +// CHECK: %[[ADDI0:.*]] = arith.addi %[[SPLAT1]], %[[EXTSI2]] : tensor<128xi64> +// CHECK: %[[EXPAND_DIMS0:.*]] = tt.expand_dims %[[ADDI0]] {axis = 1 : i32} : tensor<128xi64> -> tensor<128x1xi64> +// CHECK: %[[SPLAT2:.*]] = tt.splat %[[C1_I64]] : i64 -> tensor<128x1xi64> +// CHECK: %[[MULI0:.*]] = arith.muli %[[EXPAND_DIMS0]], %[[SPLAT2]] : tensor<128x1xi64> +// CHECK: %[[BROADCAST0:.*]] = tt.broadcast %[[MULI0]] : tensor<128x1xi64> -> tensor<128x32xi64> +// CHECK: %[[ADDPTR0:.*]] = tt.addptr %[[SPLAT0]], %[[BROADCAST0]] : tensor<128x32x!tt.ptr>, tensor<128x32xi64> +// CHECK: %[[SPLAT3:.*]] = tt.splat %[[EXTSI1]] : i64 -> tensor<32xi64> +// CHECK: %[[MAKE_RANGE1:.*]] = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> +// CHECK: %[[EXTSI3:.*]] = arith.extsi %[[MAKE_RANGE1]] : tensor<32xi32> to tensor<32xi64> +// CHECK: %[[ADDI1:.*]] = arith.addi %[[SPLAT3]], %[[EXTSI3]] : tensor<32xi64> +// CHECK: %[[EXPAND_DIMS1:.*]] = tt.expand_dims %[[ADDI1]] {axis = 0 : i32} : tensor<32xi64> -> tensor<1x32xi64> +// CHECK: %[[SPLAT4:.*]] = tt.splat %[[C1_I64]] : i64 -> tensor<1x32xi64> +// CHECK: %[[MULI1:.*]] = arith.muli %[[EXPAND_DIMS1]], %[[SPLAT4]] : tensor<1x32xi64> +// CHECK: %[[BROADCAST1:.*]] = tt.broadcast %[[MULI1]] : tensor<1x32xi64> -> tensor<128x32xi64> +// CHECK: %[[ADDPTR1:.*]] = tt.addptr %[[ADDPTR0]], %[[BROADCAST1]] : tensor<128x32x!tt.ptr>, tensor<128x32xi64> +// CHECK: tt.store %[[ADDPTR1]], %[[CST]] : tensor<128x32x!tt.ptr> +// CHECK: tt.return + +// ----- +tt.func public @rewrite_for(%arg0: !tt.ptr, %arg1: !tt.ptr) { %c0 = arith.constant 0 : index - %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32> + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16> + %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : !tt.ptr> + %1:2 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %cst, %arg4 = %0) -> (tensor<128x32xf16>, !tt.ptr>) { + %3 = tt.load %arg4 {boundaryCheck = array, padding = 2 : i32} : !tt.ptr> + %4 = arith.addf %arg3, %3 : tensor<128x32xf16> + %5 = tt.advance %arg4, [%c32_i32, %c0_i32] : !tt.ptr> + scf.yield %4, %5 : tensor<128x32xf16>, !tt.ptr> + } {tt.num_stages = 3 : i32} + %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x32x!tt.ptr> + tt.store %2, %1#0 : tensor<128x32x!tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @rewrite_for( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: !tt.ptr +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: !tt.ptr +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index +// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C32_I32:.*]] = arith.constant 32 : i32 +// CHECK-DAG: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[C32_I64:.*]] = arith.constant 32 : i64 +// CHECK-DAG: %[[C128_I64:.*]] = arith.constant 128 : i64 +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<128x32xf16> +// CHECK: %[[EXTSI0:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[EXTSI1:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[FOR:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C32]] step %[[C1]] +// CHECK-SAME: iter_args(%[[ARG3:.*]] = %[[CST]], %[[ARG4:.*]] = %[[EXTSI0]], %[[ARG5:.*]] = %[[EXTSI1]]) -> (tensor<128x32xf16>, i64, i64) +// CHECK: %[[EXTSI2:.*]] = arith.extsi %[[C32_I32]] : i32 to i64 +// CHECK: %[[ADDI0:.*]] = arith.addi %[[ARG4]], %[[EXTSI2]] : i64 +// CHECK: %[[EXTSI3:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[ADDI1:.*]] = arith.addi %[[ARG5]], %[[EXTSI3]] : i64 +// CHECK: scf.yield %{{.*}}, %[[ADDI0]], %[[ADDI1]] : tensor<128x32xf16>, i64, i64 +// CHECK: tt.num_stages = 3 + +// ----- +tt.func public @rewrite_if(%arg0: !tt.ptr, %arg1: i1, %arg2: tensor<128x32xf32>) -> tensor<128x32xf16> { + %c0_i32 = arith.constant 0 : i32 %c32_i32 = arith.constant 32 : i32 - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %0 = tt.get_program_id x : i32 - %1 = tt.get_program_id y : i32 - %2 = arith.addi %arg3, %c127_i32 : i32 - %3 = arith.divsi %2, %c128_i32 : i32 - %4 = arith.addi %arg4, %c31_i32 : i32 - %5 = arith.divsi %4, %c32_i32 : i32 - %6 = arith.muli %5, %c8_i32 : i32 - %7 = arith.divsi %0, %6 : i32 - %8 = arith.muli %7, %c8_i32 : i32 - %9 = arith.subi %3, %8 : i32 - %10 = arith.cmpi slt, %9, %c8_i32 : i32 - %11 = arith.select %10, %9, %c8_i32 : i32 - %12 = arith.remsi %0, %11 : i32 - %13 = arith.addi %8, %12 : i32 - %14 = arith.remsi %0, %6 : i32 - %15 = arith.divsi %14, %11 : i32 - %16 = arith.muli %13, %c128_i32 : i32 - %17 = arith.muli %1, %c32_i32 : i32 - %18 = arith.extsi %arg3 : i32 to i64 - %19 = arith.extsi %arg5 : i32 to i64 - %20 = arith.extsi %arg6 : i32 to i64 - // CHECK-NOT: tt.make_tensor_ptr - %21 = tt.make_tensor_ptr %arg0, [%18, %19], [%20, %c1_i64], [%16, %17] {order = array} : !tt.ptr> - %22 = arith.muli %15, %c32_i32 : i32 - %23 = arith.extsi %arg4 : i32 to i64 - %24 = arith.extsi %arg7 : i32 to i64 - // CHECK-NOT: tt.make_tensor_ptr - %25 = tt.make_tensor_ptr %arg1, [%19, %23], [%24, %c1_i64], [%17, %22] {order = array} : !tt.ptr> - %26 = arith.addi %arg5, %c31_i32 : i32 - %27 = arith.divsi %26, %c32_i32 : i32 - %28 = arith.index_cast %27 : i32 to index - %29:3 = scf.for %arg9 = %c0 to %28 step %c1 iter_args(%arg10 = %cst, %arg11 = %21, %arg12 = %25) -> (tensor<128x32xf32>, !tt.ptr>, !tt.ptr>) { - // CHECK: tt.load %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}} : tensor<128x32x!tt.ptr> - %55 = tt.load %arg11 {boundaryCheck = array, padding = 2 : i32} : !tt.ptr> - // CHECK: tt.load %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}} : tensor<32x32x!tt.ptr> - %56 = tt.load %arg12 {boundaryCheck = array, padding = 2 : i32} : !tt.ptr> - %57 = tt.dot %55, %56, %arg10 : tensor<128x32xf16> * tensor<32x32xf16> -> tensor<128x32xf32> - // CHECK-NOT: tt.advance - %58 = tt.advance %arg11, [%c0_i32, %c32_i32] : !tt.ptr> - // CHECK-NOT: tt.advance - %59 = tt.advance %arg12, [%c32_i32, %c0_i32] : !tt.ptr> - scf.yield %57, %58, %59 : tensor<128x32xf32>, !tt.ptr>, !tt.ptr> + %c1_i64 = arith.constant 1 : i64 + %c32_i64 = arith.constant 32 : i64 + %c128_i64 = arith.constant 128 : i64 + %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c32_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : !tt.ptr> + %1:2 = scf.if %arg1 -> (tensor<128x32xf16>, !tt.ptr>) { + %2 = tt.advance %0, [%c32_i32, %c0_i32] : !tt.ptr> + %3 = arith.truncf %arg2 : tensor<128x32xf32> to tensor<128x32xf16> + scf.yield %3, %2 : tensor<128x32xf16>, !tt.ptr> + } else { + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16> + scf.yield %cst, %0 : tensor<128x32xf16>, !tt.ptr> + } + %4 = tt.load %1#1 {boundaryCheck = array, padding = 2 : i32} : !tt.ptr> + %5 = arith.addf %1#0, %4 : tensor<128x32xf16> + tt.return %5 : tensor<128x32xf16> +} + +// CHECK-LABEL: tt.func public @rewrite_if( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: !tt.ptr +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: i1 +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<128x32xf32> +// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C32_I32:.*]] = arith.constant 32 : i32 +// CHECK-DAG: %[[C1_I64:.*]] = arith.constant 1 : i64 +// CHECK-DAG: %[[C32_I64:.*]] = arith.constant 32 : i64 +// CHECK-DAG: %[[C128_I64:.*]] = arith.constant 128 : i64 +// CHECK: %[[EXTSI0:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[EXTSI1:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[IF:.*]]:3 = scf.if %[[ARG1]] -> (tensor<128x32xf16>, i64, i64) { +// CHECK: %[[EXTSI2:.*]] = arith.extsi %[[C32_I32]] : i32 to i64 +// CHECK: %[[ADDI0:.*]] = arith.addi %[[EXTSI0]], %[[EXTSI2]] : i64 +// CHECK: %[[EXTSI3:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[ADDI1:.*]] = arith.addi %[[EXTSI1]], %[[EXTSI3]] : i64 +// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[ARG2]] : tensor<128x32xf32> to tensor<128x32xf16> +// CHECK: scf.yield %[[TRUNCF]], %[[ADDI0]], %[[ADDI1]] : tensor<128x32xf16>, i64, i64 +// CHECK: } else { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor<128x32xf16> +// CHECK: scf.yield %[[CST]], %[[EXTSI0]], %[[EXTSI1]] : tensor<128x32xf16>, i64, i64 +// CHECK: } +// CHECK: %{{.*}} = tt.splat %[[IF]]#1 : i64 -> tensor<128xi64> +// CHECK: %{{.*}} = tt.splat %[[IF]]#2 : i64 -> tensor<32xi64> +// CHECK: %{{.*}} = arith.addf %[[IF]]#0, %{{.*}} : tensor<128x32xf16> + + +// ----- +tt.func public @asm_in_loop(%arg0: !tt.ptr) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i64 = arith.constant 0 : i64 + %c128_i64 = arith.constant 128 : i64 + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %1 = tt.make_tensor_ptr %arg0, [%c128_i64, %c128_i64], [%c128_i64, %c0_i64], [%c0_i32, %c0_i32] {order = array} : !tt.ptr> + %2:1 = scf.for %arg1 = %c0_i32 to %c1_i32 step %c1_i32 iter_args(%arg2 = %1) -> (!tt.ptr>) : i32 { + %3:2 = tt.elementwise_inline_asm "asm_multiple_results" {constraints = "=r,=r,r", packed_element = 1 : i32, pure = true} %0 : tensor<16xi32> -> tensor<16xi16>, tensor<16xi16> + %4 = tt.advance %arg2, [%c0_i32, %c0_i32] : !tt.ptr> + scf.yield %4 : !tt.ptr> } - %30 = arith.truncf %29#0 : tensor<128x32xf32> to tensor<128x32xf16> - %31 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %32 = tt.splat %16 : i32 -> tensor<128xi32> - %33 = arith.addi %32, %31 : tensor<128xi32> - %34 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %35 = tt.splat %22 : i32 -> tensor<32xi32> - %36 = arith.addi %35, %34 : tensor<32xi32> - %37 = tt.expand_dims %33 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %38 = tt.splat %arg8 : i32 -> tensor<128x1xi32> - %39 = arith.muli %37, %38 : tensor<128x1xi32> - %40 = tt.expand_dims %36 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> - %41 = tt.broadcast %39 : tensor<128x1xi32> -> tensor<128x32xi32> - %42 = tt.broadcast %40 : tensor<1x32xi32> -> tensor<128x32xi32> - %43 = arith.addi %41, %42 : tensor<128x32xi32> - %44 = tt.splat %arg2 : !tt.ptr -> tensor<128x32x!tt.ptr> - %45 = tt.addptr %44, %43 : tensor<128x32x!tt.ptr>, tensor<128x32xi32> - %46 = tt.splat %arg3 : i32 -> tensor<128xi32> - %47 = arith.cmpi slt, %33, %46 : tensor<128xi32> - %48 = tt.expand_dims %47 {axis = 1 : i32} : tensor<128xi1> -> tensor<128x1xi1> - %49 = tt.splat %arg4 : i32 -> tensor<32xi32> - %50 = arith.cmpi slt, %36, %49 : tensor<32xi32> - %51 = tt.expand_dims %50 {axis = 0 : i32} : tensor<32xi1> -> tensor<1x32xi1> - %52 = tt.broadcast %48 : tensor<128x1xi1> -> tensor<128x32xi1> - %53 = tt.broadcast %51 : tensor<1x32xi1> -> tensor<128x32xi1> - %54 = arith.andi %52, %53 : tensor<128x32xi1> - tt.store %45, %30, %54 : tensor<128x32x!tt.ptr> tt.return } + +// CHECK-LABEL: tt.func public @asm_in_loop( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: !tt.ptr +// CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C1_I32:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C0_I64:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[C128_I64:.*]] = arith.constant 128 : i64 +// CHECK: %[[RANGE:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> +// CHECK: %[[EXTSI0:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[EXTSI1:.*]] = arith.extsi %[[C0_I32]] : i32 to i64 +// CHECK: %[[FOR:.*]]:2 = scf.for %[[ARG1:.*]] = %[[C0_I32]] to %[[C1_I32]] step %[[C1_I32]] +// CHECK-SAME: iter_args(%[[ARG2:.*]] = %[[EXTSI0]], %[[ARG3:.*]] = %[[EXTSI1]]) -> (i64, i64) +// CHECK: %[[ASM:.*]]:2 = tt.elementwise_inline_asm "asm_multiple_results" {{.*}} %[[RANGE]] : tensor<16xi32> -> tensor<16xi16>, tensor<16xi16> diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 8c4e85aa0..2f9793e52 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -23,8 +23,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked2> // CHECK: scf.for - // CHECK: tt.dot {{.*}} -> tensor<128x16xf16, #[[MMA]]> - // CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> + // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x16xf16, #[[MMA]]> + // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> %115 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_0) -> (tensor<128x64xf16, #blocked1>) : i32 { %172 = tt.dot %170, %171, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked> %178 = triton_gpu.convert_layout %172 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> @@ -32,8 +32,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : scf.yield %180 : tensor<128x64xf16, #blocked1> } // CHECK: scf.for - // CHECK: tt.dot {{.*}} -> tensor<128x32xf16, #[[MMA2]]> - // CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> + // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x32xf16, #[[MMA2]]> + // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> %149 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %115) -> (tensor<128x64xf16, #blocked1>) : i32 { %166 = tt.dot %164, %165, %cst_2 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2> %172 = triton_gpu.convert_layout %166 : tensor<128x32xf16, #blocked2> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> @@ -73,6 +73,33 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- +// CHECK: #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}> +// CHECK: #mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: chained_dot + tt.func public @chained_dot_wgmma( + %arg0: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1> + // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<64x64xf32, #mma> + %d = tt.dot %arg0, %arg1, %cst_0 : + tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked> + %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> + %c = triton_gpu.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<64x128xf32, #mma1> + %r = tt.dot %c, %arg2, %cst_1 : + tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1> + tt.return %r : tensor<64x128xf32, #blocked1> + } +} + +// ----- + // CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}> #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}> @@ -101,6 +128,7 @@ module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 : #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [0, 1, 2]}> #blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 2, 2], threadsPerWarp = [1, 4, 8], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK: kernel_ tt.func public @kernel_() attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<2x16x16xf32, #blocked> %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked1> @@ -119,7 +147,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: tt.dot {{.*}} -> tensor<2x16x16xf32, #[[MMA1]]> %11 = tt.dot %8, %9, %10, inputPrecision = tf32 : tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> * tensor<2x16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<2x16x16xf32, #blocked3> %12 = triton_gpu.convert_layout %11 : tensor<2x16x16xf32, #blocked3> -> tensor<2x16x16xf32, #blocked> - tt.print ": " {hex = false} : %12 : tensor<2x16x16xf32, #blocked> + tt.print ": " {hex = false, isSigned = array} : %12 : tensor<2x16x16xf32, #blocked> tt.return } } @@ -129,8 +157,8 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, {{.*}}, instrShape = [16, 32, 16]}> #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: check_instrShape_per_warps tt.func @check_instrShape_per_warps(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - // CHECK-LABEL: check_instrShape_per_warps %mask = arith.constant dense : tensor<128x128xi1, #blocked> %zero_f32 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %a = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> @@ -142,3 +170,42 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : tt.return } } + + +// ----- + +// Verify that we use mmav2 when the k dim is too small for mmav3. +// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 4], instrShape = [16, 8]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: small_k_size + tt.func @small_k_size( + %a: tensor<128x16xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + %b: tensor<16x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) + -> tensor<128x128xf32, #blocked> { + %zero_f32 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %result = tt.dot %a, %b, %zero_f32 : tensor<128x16xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> + tt.return %result : tensor<128x128xf32, #blocked> + } +} + +// ----- + +// Verify that dot_scaled (mxfp8 x fp8) decomposes as expected +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: dot_scaled + tt.func @dot_scaled( + %a: tensor<128x64xi8, #blocked2>, + %scale: tensor<128x2xi8, #blocked1>, + %b: tensor<64x128xi8, #blocked>) + -> tensor<128x128xf32, #blocked> { + // CHECK: triton_gpu.upcast_mxfp + // CHECK: tt.dot + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %result = tt.dot_scaled %a, %scale, %b, %cst lhs = e4m3 rhs = e4m3 : tensor<128x64xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xi8, #blocked> -> tensor<128x128xf32, #blocked> + tt.return %result : tensor<128x128xf32, #blocked> + } +} diff --git a/test/TritonGPU/accumulator-init.mlir b/test/TritonGPU/accumulator-init.mlir new file mode 100644 index 000000000..ef5dd9165 --- /dev/null +++ b/test/TritonGPU/accumulator-init.mlir @@ -0,0 +1,367 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-optimize-accumulator-init | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + +// CHECK-LABEL: @constant_init +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]] + tt.func @constant_init(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %cst_2 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @constant_init_integer +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]] + tt.func @constant_init_integer(%A: !tt.memdesc<128x64xi8, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xi8, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xi32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0> : tensor<128x16xi32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xi32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %cst_2 : !tt.memdesc<128x64xi8, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xi8, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xi32, #mma1> + scf.yield %acc: tensor<128x16xi32, #mma1> + } + tt.return %17 : tensor<128x16xi32, #mma1> + } + +// CHECK-LABEL: @if_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @if_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_after_mma_invert +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[TRUE]], %[[FALSE]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @if_after_mma_invert(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %acc : tensor<128x16xf32, #mma1> + } else { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_before_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[ACC]] +// CHECK: else +// CHECK: scf.yield %[[ACC]] +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @if_before_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %arg4 : tensor<128x16xf32, #mma1> + } + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_before_mma_invert +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[USE_ACC]], %[[FALSE]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[ACC]] +// CHECK: else +// CHECK: scf.yield %[[ACC]] +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @if_before_mma_invert(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %arg4 : tensor<128x16xf32, #mma1> + } else { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @sel_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @sel_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @sel_before_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @sel_before_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xf32, #mma1> + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + + +// Check that we look only at the zeroing directly preceding the mma + +// CHECK-LABEL: @if_before_and_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[ACC]] +// CHECK: else +// CHECK: scf.yield %[[ACC]] +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[C0_TENSOR]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @if_before_and_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_0 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %arg4 : tensor<128x16xf32, #mma1> + } + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_0 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_1 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + scf.yield %acc_1: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @two_ifs_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[C0_TENSOR]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[ACC_CND]] +// CHECK: else +// CHECK: scf.yield %[[ACC_CND]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @two_ifs_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_0 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + %acc_1 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc_0 : tensor<128x16xf32, #mma1> + } + scf.yield %acc_1: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// Check that we bail out in unsupported cases + +// CHECK-LABEL: @non_zero_init +// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc + tt.func @non_zero_init(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @zero_init_dist_2 +// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc + tt.func @zero_init_dist_2(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %cst_2) -> (tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg5 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> + scf.yield %acc_, %arg4: tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_defines_alternative +// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc + tt.func @if_defines_alternative(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_3 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + %acc_alt = arith.addf %acc, %cst_3 : tensor<128x16xf32, #mma1> + scf.yield %acc_alt : tensor<128x16xf32, #mma1> + } + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @non_cond_override +// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc + tt.func @non_cond_override(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_3 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = arith.addf %acc, %cst_3 : tensor<128x16xf32, #mma1> + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } +} diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir new file mode 100644 index 000000000..7854a4eed --- /dev/null +++ b/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir @@ -0,0 +1,20 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx940 matrix-instruction-size=0' | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}> +// CHECK-LABEL: mfma_dot_fp8e5m2 +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @mfma_dot_fp8e5m2( + %arg0: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x256xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<128x256x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked> + // CHECK: %[[A0:.+]] = triton_gpu.convert_layout %arg0 : {{.*}} -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + // CHECK: %[[A1:.+]] = tt.fp_to_fp %[[A0]] : {{.*}} -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + // CHECK: %[[B0:.+]] = triton_gpu.convert_layout %arg1 : {{.*}} -> tensor<64x256xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + // CHECK: %[[B1:.+]] = tt.fp_to_fp %[[B0]] : tensor<64x256xf8E5M2, {{.*}} -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + // CHECK: tt.dot %[[A1]], %[[B1]] + %1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + tt.store %arg2, %1 : tensor<128x256x!tt.ptr, #blocked> + tt.return + } +} diff --git a/test/TritonGPU/amd/accelerate-amd-matmul.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir similarity index 84% rename from test/TritonGPU/amd/accelerate-amd-matmul.mlir rename to test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir index 591632f2d..7d3e8c23b 100644 --- a/test/TritonGPU/amd/accelerate-amd-matmul.mlir +++ b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir @@ -1,8 +1,7 @@ // RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx1100 matrix-instruction-size=0' | FileCheck %s // CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> -// CHECK: #[[WMMA_0:.+]] = #triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}> -// CHECK: #[[WMMA_1:.+]] = #triton_gpu.amd_wmma<{warpsPerCTA = [2, 2]}> +// CHECK: #[[WMMA_0:.+]] = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_cf32( @@ -27,6 +26,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.store %2, %4 : tensor<128x256x!tt.ptr, #blocked> tt.return } +} + +// ----- + +// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_cf16( // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> %0: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, @@ -49,6 +56,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> tt.return } +} + +// ----- + +// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_ab8_cf16( // CHECK: %[[DOT2_ARG_A:.+]]: tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> %0: tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, @@ -75,6 +90,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.store %2, %4 : tensor<32x64x!tt.ptr, #blocked> tt.return } +} + +// ----- + +// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func public @wmma_dot_i8_i32( // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> %0: tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, @@ -97,6 +120,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> tt.return } +} + +// ----- + +// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func public @fma_dot_i16_i16( // CHECK: %[[DOT3_ARG_A:.+]]: tensor<128x64xi16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> %0: tensor<128x64xi16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir new file mode 100644 index 000000000..a8683a5d3 --- /dev/null +++ b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir @@ -0,0 +1,123 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx1200 matrix-instruction-size=0' | FileCheck %s + +// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_cf32( + // CHECK: %[[DOT0_ARG_A:.+]]: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT0_ARG_B:.+]]: tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %2: tensor<128x256x!tt.ptr, #blocked>) { + // CHECK: %[[DOT0_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #[[DOT_OP_PARENT]]> + // CHECK: %[[DOT0_OP_C:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_C]] + // CHECK-SAME: -> tensor<128x256xf32, #[[WMMA_0]] + %3 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked> + // CHECK: %[[DOT0_OP_A:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_A]] + // CHECK-SAME: -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]] + // CHECK: %[[DOT0_OP_B:.+]] = triton_gpu.convert_layout %[[DOT0_ARG_B]] + // CHECK-SAME: -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]] + // CHECK: %[[DOT0_WMMA_RES:.+]] = tt.dot %[[DOT0_OP_A]], %[[DOT0_OP_B]], %[[DOT0_OP_C]] + // CHECK-SAME: -> tensor<128x256xf32, #[[WMMA_0]] + %4 = tt.dot %0, %1, %3 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> + // CHECK: triton_gpu.convert_layout %[[DOT0_WMMA_RES]] + // CHECK-SAME: -> tensor<128x256xf32, #[[DOT_OP_PARENT]]> + tt.store %2, %4 : tensor<128x256x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_cf16( + // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %2: tensor<32x32x!tt.ptr, #blocked>) { + // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #[[DOT_OP_PARENT]]> + // CHECK: %[[DOT1_OP_C:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_C]] + // CHECK-SAME: -> tensor<32x32xf16, #[[WMMA_1]] + %3 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked> + // CHECK: %[[DOT1_OP_A:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_A]] + // CHECK-SAME: -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_B:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_B]] + // CHECK-SAME: -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]] + // CHECK-SAME: -> tensor<32x32xf16, #[[WMMA_1]] + %4 = tt.dot %0, %1, %3 : tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf16, #blocked> + // CHECK: triton_gpu.convert_layout %[[DOT1_WMMA_RES]] + // CHECK-SAME: -> tensor<32x32xf16, #[[DOT_OP_PARENT]]> + tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK: #[[WMMA_0:.+]] = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_ab8_cf16( + // CHECK: %[[DOT2_ARG_A:.+]]: tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT2_ARG_B:.+]]: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %2: tensor<32x64x!tt.ptr, #blocked>) { + // CHECK: %[[DOT2_ARG_C:.+]] = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #[[DOT_OP_PARENT]]> + // CHECK: %[[DOT2_OP_C:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_C]] + // CHECK-SAME: -> tensor<32x64xf16, #[[WMMA_0]] + %3 = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #blocked> + // CHECK: %[[DOT2_OP_A_F8:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_A]] + // CHECK-SAME: -> tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]] + // CHECK: %[[DOT2_OP_A_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_A_F8]] + // CHECK-SAME: -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_0]], kWidth = 8}>> + // CHECK: %[[DOT2_OP_B_F8:.+]] = triton_gpu.convert_layout %[[DOT2_ARG_B]] + // CHECK-SAME: -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]] + // CHECK: %[[DOT2_OP_B_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_B_F8]] + // CHECK-SAME: -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_0]], kWidth = 8}>> + // CHECK: %[[DOT2_WMMA_RES:.+]] = tt.dot %[[DOT2_OP_A_F16]], %[[DOT2_OP_B_F16]], %[[DOT2_OP_C]] + // CHECK-SAME: -> tensor<32x64xf16, #[[WMMA_0]] + %4 = tt.dot %0, %1, %3 : tensor<32x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x64xf16, #blocked> + // CHECK: triton_gpu.convert_layout %[[DOT2_WMMA_RES]] + // CHECK-SAME: -> tensor<32x64xf16, #[[DOT_OP_PARENT]]> + tt.store %2, %4 : tensor<32x64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[DOT_OP_PARENT:.+]] = #triton_gpu.blocked<{{.*}}> +// CHECK: #[[WMMA_1:.+]] = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_i8_i32( + // CHECK: %[[DOT1_ARG_A:.+]]: tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]]}>> + %0: tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + // CHECK-SAME: %[[DOT1_ARG_B:.+]]: tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> + %1: tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %2: tensor<32x32x!tt.ptr, #blocked>) { + // CHECK: %[[DOT1_ARG_C:.+]] = arith.constant dense<0> : tensor<32x32xi32, #[[DOT_OP_PARENT]]> + // CHECK: %[[DOT1_OP_C:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_C]] + // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]] + %3 = arith.constant dense<0> : tensor<32x32xi32, #blocked> + // CHECK: %[[DOT1_OP_A:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_A]] + // CHECK-SAME: -> tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_OP_B:.+]] = triton_gpu.convert_layout %[[DOT1_ARG_B]] + // CHECK-SAME: -> tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #[[WMMA_1]] + // CHECK: %[[DOT1_WMMA_RES:.+]] = tt.dot %[[DOT1_OP_A]], %[[DOT1_OP_B]], %[[DOT1_OP_C]] + // CHECK-SAME: -> tensor<32x32xi32, #[[WMMA_1]] + %4 = tt.dot %0, %1, %3 : tensor<32x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xi32, #blocked> + // CHECK: triton_gpu.convert_layout %[[DOT1_WMMA_RES]] + // CHECK-SAME: -> tensor<32x32xi32, #[[DOT_OP_PARENT]]> + tt.store %2, %4 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir new file mode 100644 index 000000000..6c3e2ac42 --- /dev/null +++ b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir @@ -0,0 +1,735 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-canonicalize-pointers | FileCheck %s +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @conversion1 + tt.func @conversion1(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.muli{{.*}} : i32 + // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] : !tt.ptr, i32 + // CHECK: %[[offset_32bit:.*]] = arith.trunci %{{.*}} : tensor<1024xi64, #blocked> to tensor<1024xi32, #blocked> + // CHECK: %[[basePtr:.*]] = tt.splat %[[scalarPtr]] + // CHECK: %[[newPtr:.*]] = tt.addptr %[[basePtr]], %[[offset_32bit]] + // CHECK: tt.load %[[newPtr]] + %6 = tt.addptr %5, %3 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %7 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> + tt.return %7 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @conversion2 + tt.func @conversion2(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[baseOffset064bit:.*]] = tt.splat {{.*}} : i64 + // CHECK: %[[newScalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] + // CHECK: %[[offset064bit:.*]] = arith.extsi {{.*}} + // CHECK: %[[offset164bit:.*]] = arith.addi %[[offset064bit]], %[[baseOffset064bit]] + // CHECK: %[[offset132bit:.*]] = arith.trunci %[[offset164bit]] : tensor<1024xi64, #blocked> to tensor<1024xi32, #blocked> + // CHECK: %[[basePtr:.*]] = tt.splat %[[newScalarPtr]] + // CHECK: %[[newPtr:.*]] = tt.addptr %[[basePtr]], %[[offset132bit]] + // CHECK: tt.load %[[newPtr]] + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %7 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> + tt.return %7 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @conversion3 + tt.func @conversion3(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + + //CHECK: %0 = tt.get_program_id x : i32 + //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 + //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 + //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> + //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 + //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> + //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i64 -> tensor<1024xi64, #blocked> + //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 + //CHECK: %[[tensorOffset3ext:.*]] = arith.extsi %[[tensorOffset3]] : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3ext]], %[[zero]] : tensor<1024xi64, #blocked> + //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 + //CHECK: %[[tensorOffset1ext:.*]] = arith.extsi %[[tensorOffset1]] : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1ext]], %[[tensorOffset0]]: tensor<1024xi64, #blocked> + //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi64, #blocked> + //CHECK: tt.load %[[newPtr]] + + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> + tt.return %8 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // + // This is the same as conversion3, but now the `arith.extsi` operations + // disappeared and all the offsets are 32 bits. + // + // CHECK-LABEL: tt.func @conversion4 + tt.func @conversion4(%arg0: !tt.ptr{tt.pointer_range = 32 : i32})-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + + //CHECK: %0 = tt.get_program_id x : i32 + //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 + //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 + //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> + //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 + //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> + //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i32 -> tensor<1024xi32, #blocked> + //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 + //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3]], %[[zero]] : tensor<1024xi32, #blocked> + //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 + //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1]], %[[tensorOffset0]]: tensor<1024xi32, #blocked> + //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + //CHECK: tt.load %[[newPtr]] + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> + tt.return %8 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @forOp + tt.func @forOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffsetLoop:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[variableOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor + // CHECK: %[[scalarOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 + // CHECK: %[[scalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 + // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor + // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %arg0, %[[scalarOffset]] + // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] + // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %{{.*}} : tensor<1024xi64, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: %[[loop:.*]]:4 = scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[loopScalarPtr:.*]] = %{{.*}}, %[[loopOffset:.*]] = %[[offset1]]) -> {{.*}} { + %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %6, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ + // CHECK: %[[scalarPtrUpdateLoop:.*]] = tt.addptr %[[loopScalarPtr]], %[[scalarOffsetLoop]] + // CHECK: %[[ext_offset0i:.*]] = arith.extsi %[[variableOffset1]] + // CHECK: %[[offset_i:.*]] = arith.addi %[[ext_offset0i]], %[[loopOffset]] + // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdateLoop]] + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[offset_i]] + // CHECK: tt.load %[[newPtr]] + // CHECK: scf.yield {{.*}}, {{.*}}, %[[scalarPtrUpdateLoop]], %[[offset_i]] + %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> + %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> + scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> + } + // CHECK: tt.addptr %[[loop]]#2, %[[scalarOffset1]] : !tt.ptr, i32 + %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @forOp2 + tt.func @forOp2(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 + // CHECK: %[[variableOffset0:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> + // CHECK: %[[finalScalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 + // CHECK: %[[variableOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> + // CHECK: %[[base_offset:.*]] = tt.splat {{.*}} : i64 + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[forOut:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) + %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %5, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ + // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %[[scalarPtr]], %[[scalarOffset]] + // CHECK: %[[ext_offset0i:.*]] = arith.extsi %[[variableOffset0]] + // CHECK: %[[ext_offset_i:.*]] = arith.addi %[[ext_offset0i]], %[[loopOffset]] + // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdate]] + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[ext_offset_i]] + // CHECK: tt.load %[[newPtr]] + %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> + %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> + scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> + } + // CHECK: %[[scalarPtrFinalUpdate:.*]] = tt.addptr %[[forOut]]#2, %[[finalScalarOffset]] + // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset1]] + // CHECK: %[[tailOffset:.*]] = arith.addi %[[ext_offset0]], %[[forOut]]#3 + // CHECK: %[[tail_base_ptr:.*]] = tt.splat %[[scalarPtrFinalUpdate]] + // CHECK: %[[tailPtr:.*]] = tt.addptr %[[tail_base_ptr]], %[[tailOffset]] + // CHECK: tt.load %[[tailPtr]] + %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @forNested + tt.func @forNested(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> + // CHECK: %[[base_offset:.*]] = tt.splat {{.*}} : i64 + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + + // CHECK: %[[forOut0:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr0:.*]] = %arg0, %[[loopOffset0:.*]] = %[[base_offset]]){{.*}}{ + // CHECK: %[[forOut1:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr1:.*]] = %[[scalarPtr0]], %[[loopOffset1:.*]] = %[[loopOffset0]]){{.*}}{ + // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %[[scalarPtr1]], %{{.*}} + // CHECK: %[[ext_loop_offset1:.*]] = arith.extsi %[[variableOffset]] + // CHECK: %[[offset_i:.*]] = arith.addi %[[ext_loop_offset1]], %[[loopOffset1]] + // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdate]] + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[offset_i]] + // CHECK: tt.load %[[newPtr]] + // CHECK: scf.yield %{{.*}}, {{.*}}, %[[scalarPtrUpdate]], %[[offset_i]] + // CHECK: scf.yield %{{.*}}, {{.*}}, %[[forOut1]]#2, %[[forOut1]]#3 + + %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %5, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ + %53:2 = scf.for %arg10 = %c0 to %c128 step %c1 iter_args(%arg3 = %arg1, %arg4 = %arg2) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ + %11 = tt.addptr %arg3, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> + %10 = arith.addf %9, %arg4 : tensor<1024xf32, #blocked> + scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> + } + scf.yield %53#0, %53#1: tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> + } + %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @ifOp + tt.func @ifOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[variableOffset:.*]] = arith.addi + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[baseOffsetVariable:.*]] = tt.splat {{.*}} : i64 -> tensor<1024xi64, #blocked> + // CHECK: %[[ifOut:.*]]:3 = scf.if {{.*}} -> (tensor<1024x!tt.ptr, #blocked>, !tt.ptr, tensor<1024xi64, #blocked>) + %6 = scf.if %cond -> (tensor<1024x!tt.ptr, #blocked>){ + // CHECK: %[[scalarOffsetUpdate:.*]] = tt.addptr %arg0, %[[scalarOffset]] + // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] + // CHECK: %[[if_offset:.*]] = arith.addi %[[ext_offset0]], %[[baseOffsetVariable]] + // CHECK: scf.yield %{{.*}}, %[[scalarOffsetUpdate]], %[[if_offset]] + %true = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + scf.yield %true : tensor<1024x!tt.ptr, #blocked> + } else { + // CHECK: %[[new_scalar_ptr:.*]] = tt.addptr %arg0, {{.*}} + // CHECK: scf.yield %{{.*}}, %[[new_scalar_ptr]], %[[baseOffsetVariable]] + %false = tt.addptr %5, %3 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + scf.yield %false : tensor<1024x!tt.ptr, #blocked> + } + // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[ifOut]]#2 + // CHECK: %[[base_ptr:.*]] = tt.splat %[[ifOut]]#1 + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]] + // CHECK: tt.load %[[newPtr]] + %11 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @whileOp + tt.func @whileOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[whileOut:.*]]:3 = scf.while ({{.*}}, %[[loopPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) + %6 = scf.while (%arg1 = %5, %arg2 = %cond) : (tensor<1024x!tt.ptr, #blocked>, i1) -> (tensor<1024x!tt.ptr, #blocked>) { + // CHECK: scf.condition({{.*}}) %{{.*}}, %[[loopPtr]], %[[loopOffset]] + scf.condition(%arg2) %arg1 : tensor<1024x!tt.ptr, #blocked> + } do { + // CHECK: ^bb{{.*}}(%{{.*}}, %[[blockPtr:.*]]: !tt.ptr, %[[blockOffset:.*]]: tensor<1024xi64, #blocked>): + ^bb0(%arg1: tensor<1024x!tt.ptr, #blocked>): + // CHECK: scf.yield {{.*}}, %[[blockPtr]], %[[blockOffset]] + scf.yield %arg1, %cond : tensor<1024x!tt.ptr, #blocked>, i1 + } + // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[whileOut]]#2 + // CHECK: %[[base_ptr:.*]] = tt.splat %[[whileOut]]#1 + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]] + // CHECK: tt.load %[[newPtr]] + %11 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @condBranch + tt.func @condBranch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> + // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 + // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] + // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %[[base_offset]] + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: cf.cond_br {{.*}}, ^bb1(%{{.*}}, %arg0, %[[base_offset]] : {{.*}}), ^bb2(%{{.*}}, %[[scalarPtr]], %[[offset1]] : {{.*}}) + cf.cond_br %i1, ^bb1(%5 : tensor<1024x!tt.ptr, #blocked>), ^bb2(%6 : tensor<1024x!tt.ptr, #blocked>) + // CHECK: ^bb1({{.*}}, %[[block1ScalarPtr:.*]]: !tt.ptr, %[[block1Offset:.*]]: tensor<1024xi64, #blocked>) + ^bb1(%arg1 : tensor<1024x!tt.ptr, #blocked>): + // CHECK: %[[trunc_offset_1:.*]] = arith.trunci %[[block1Offset]] + // CHECK: %[[basePtr1:.*]] = tt.splat %[[block1ScalarPtr]] + // CHECK: %[[newPtr1:.*]] = tt.addptr %[[basePtr1]], %[[trunc_offset_1]] + // CHECK: tt.load %[[newPtr1]] + %out1 = tt.load %arg1 : tensor<1024x!tt.ptr, #blocked> + tt.return %out1 : tensor<1024xf32, #blocked> + // CHECK: ^bb2({{.*}}, %[[block2ScalarPtr:.*]]: !tt.ptr, %[[block2Offset:.*]]: tensor<1024xi64, #blocked>) + ^bb2(%arg2 : tensor<1024x!tt.ptr, #blocked>): // 2 preds: ^bb0, ^bb1 + // CHECK: %[[trunc_offset_2:.*]] = arith.trunci %[[block2Offset]] + // CHECK: %[[basePtr2:.*]] = tt.splat %[[block2ScalarPtr]] + // CHECK: %[[newPtr2:.*]] = tt.addptr %[[basePtr2]], %[[trunc_offset_2]] + // CHECK: tt.load %[[newPtr2]] + %out2 = tt.load %arg2 : tensor<1024x!tt.ptr, #blocked> + tt.return %out2 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @branch + tt.func @branch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> + // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 + // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] + // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %[[base_offset]] + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: cf.br ^bb1(%{{.*}}, %[[scalarPtr]], %[[offset1]] : {{.*}}) + // CHECK: ^bb1({{.*}}, %[[block1ScalarPtr:.*]]: {{.*}}, %[[block1Offset:.*]]: {{.*}}) + cf.br ^bb1(%6 : tensor<1024x!tt.ptr, #blocked>) + ^bb1(%arg1 : tensor<1024x!tt.ptr, #blocked>): + // CHECK: %[[trunc_offset_1:.*]] = arith.trunci %[[block1Offset]] + // CHECK: %[[basePtr1:.*]] = tt.splat %[[block1ScalarPtr]] + // CHECK: %[[newPtr1:.*]] = tt.addptr %[[basePtr1]], %[[trunc_offset_1]] + // CHECK: tt.load %[[newPtr1]] + %out1 = tt.load %arg1 : tensor<1024x!tt.ptr, #blocked> + tt.return %out1 : tensor<1024xf32, #blocked> + } +} + +// ----- + +// The following is a simple case of a tile offset like: (A*B + C + D) where B,C are Uniform and A,D are not. So +// we expect that the Uniform offset (which can be added to the scalar pointer) will be simply C and the NonUniform +// offset will be A*B+D +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @tile_offset + tt.func @tile_offset(%arg1: !tt.ptr, %arg5: i32 , %arg7: i32 ) { + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %1 = tt.get_program_id x : i32 + %20 = arith.muli %1, %c256_i32 : i32 + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %24 = tt.splat %20 : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %26 = arith.addi %24, %22 : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %36 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %37 = tt.expand_dims %36 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> + %38 = tt.splat %arg7 : i32 -> tensor<16x1xi32, #blocked> + %39 = arith.muli %37, %38 : tensor<16x1xi32, #blocked> + %41 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %42 = tt.broadcast %39 : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> + %43 = tt.broadcast %41 : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> + %44 = arith.addi %42, %43 : tensor<16x256xi32, #blocked> + %45 = tt.splat %arg1 : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> + %46 = tt.addptr %45, %44 : tensor<16x256x!tt.ptr, #blocked>, tensor<16x256xi32, #blocked> + // CHECK: %[[uniformOffset1:.*]] = arith.muli %c0_i32_0, %arg2 : i32 + // CHECK: {{.*}} = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + // CHECK: {{.*}} = tt.broadcast %{{.*}} : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> + // CHECK: %[[tensorOffset3:.*]] = tt.broadcast %{{.*}} : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> + // CHECK: %[[tensorOffset4:.*]] = tt.broadcast %{{.*}} : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> + // CHECK: %[[tensorOffset5:.*]] = tt.broadcast %[[tensorOffset6]] : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> + // CHECK: %[[uniformOffset:.*]] = arith.addi %[[uniformOffset1]], %{{.*}}: i32 + // CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset3]], %[[tensorOffset5]] : tensor<16x256xi32, #blocked> + // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[uniformOffset]] : !tt.ptr, i32 + // CHECK: %[[tensorOffset2ext:.*]] = arith.extsi %[[tensorOffset2]] : tensor<16x256xi32, #blocked> to tensor<16x256xi64, #blocked> + // CHECK: %[[tensorOffset1:.*]] = arith.addi %[[tensorOffset2ext]], %{{.*}} : tensor<16x256xi64, #blocked> + // CHECK: %[[tensorOffset:.*]] = arith.trunci %[[tensorOffset1:.*]] : tensor<16x256xi64, #blocked> to tensor<16x256xi32, #blocked> + // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr]] : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> + // CHECK: tt.addptr %[[ptr]], %[[tensorOffset]] : tensor<16x256x!tt.ptr, #block + %61 = tt.load %46 : tensor<16x256x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// The following is a more complex case where also a multiplication is involved. It's useful to walk through the case. +// We have that the offset to the pointer is the following: +// %12 = %10 + 11 +// This can be transformed in: +// = %7 + %9 +// = %5*%6 + %8 +// = %4*%arg1 + %8 +// = (%3+%2)*%arg1 + %8 +// = (%1 + %2) * %arg1 + %8 +// = (U + N)*U + N +// Where U means uniform (e.g., a splat) and N means NonUniform (e.g., a make_range) +// The scalar offset we want is (%1*%arg1), while the variable offset should be (%2*%arg1 + %8) +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func public @matmul_kernel + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) { + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c128_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %4 = arith.addi %3, %2 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %6 = tt.splat %arg1 : i32 -> tensor<128x1xi32, #blocked> + %7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked> + %8 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + %10 = tt.broadcast %7 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> + %11 = tt.broadcast %9 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> + %12 = arith.addi %10, %11 : tensor<128x16xi32, #blocked> + %13 = tt.splat %arg0 : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> + %14 = tt.addptr %13, %12 : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> + %15 = tt.load %14 : tensor<128x16x!tt.ptr, #blocked> + // CHECK: %[[pid:.*]] = tt.get_program_id x : i32 + // CHECK: %[[uniformOffset3:.*]] = arith.muli %[[pid]], %{{.*}} : i32 + // CHECK: %[[uniformOffset2:.*]] = arith.addi %[[uniformOffset3]], %{{.*}} : i32 + // CHECK: %[[uniformOffset1:.*]] = arith.muli %[[uniformOffset2]], %arg1 : i32 + // CHECK: %[[makerange:.*]] = tt.make_range + // CHECK: %{{.*}} = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + // CHECK: %{{.*}} = tt.broadcast %{{.*}} : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> + // CHECK: %[[tensorOffset3:.*]] = tt.broadcast %{{.*}} : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> + // CHECK: %{{.*}} = tt.broadcast %{{.*}} : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> + // CHECK: %[[tensorOffset4:.*]] = tt.broadcast %[[tensorOffset6]] : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> + // CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : tensor<128x16xi32, #blocked> + // CHECK: %[[uniformOffset:.*]] = arith.addi %[[uniformOffset1]], %{{.*}} : i32 + // CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset3]], %[[tensorOffset4]] : tensor<128x16xi32, #blocked> + // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[uniformOffset]] : !tt.ptr, i32 + // CHECK: %[[tensorOffset1Ext:.*]] = arith.extsi %[[tensorOffset2]] : tensor<128x16xi32, #blocked> to tensor<128x16xi64, #blocked> + // CHECK: %[[tensorOffset:.*]] = arith.addi %[[tensorOffset1Ext]], %{{.*}} : tensor<128x16xi64, #blocked> + // CHECK: %[[tensorOffsetTrunc:.*]] = arith.trunci %[[tensorOffset]] : tensor<128x16xi64, #blocked> to tensor<128x16xi32, #blocked> + // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr]] : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> + // CHECK: tt.addptr %[[ptr]], %[[tensorOffsetTrunc]] : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> + tt.return + } +} + + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @select + tt.func @select(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> + // CHECK: %[[baseOffset:.*]] = tt.splat %{{.*}} : i64 + // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] + // CHECK: %[[extVariableOffset:.*]] = arith.extsi %[[variableOffset]] + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: %[[offset2:.*]] = arith.addi %[[extVariableOffset]], %[[baseOffset]] + // CHECK: %[[scalarPtr1:.*]] = arith.select %arg1, %arg0, %[[scalarPtr]] + // CHECK: %[[offset0:.*]] = arith.select %arg1, {{.*}}, %[[offset2]] + // CHECK: %[[offset1:.*]] = arith.trunci %[[offset0]] + // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr1]] + // CHECK: tt.addptr %[[ptr]], %[[offset1]] + %7 = arith.select %i1, %5 , %6 : tensor<1024x!tt.ptr, #blocked> + %out = tt.load %7: tensor<1024x!tt.ptr, #blocked> + tt.return %out : tensor<1024xf32, #blocked> + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1100", "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @where_kernel + tt.func @where_kernel(%arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}){ + %c0_i8 = arith.constant 0 : i8 + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %9 = arith.cmpi ne, %c0_i8, %c0_i8 : i8 + %10 = arith.select %9, %arg1, %arg2 : !tt.ptr + // CHECK: %[[selectPtr:.*]] = arith.select {{.*}} : !tt.ptr + %11 = tt.splat %10: !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %13 = tt.addptr %11, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: %[[selectPtr0:.*]] = tt.addptr %[[selectPtr]] + // CHECK: %[[tensorPtr:.*]] = tt.splat %[[selectPtr0]] + // CHECK: tt.addptr %[[tensorPtr]] + %14 = tt.load %13 : tensor<1024x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @forOpWithHints + tt.func @forOpWithHints(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ + %c0 = arith.constant 0: index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128: index + %0 = tt.get_program_id x : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %0 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %6, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ + %9 = tt.load %arg1: tensor<1024x!tt.ptr, #blocked> + // CHECK: tt.addptr {{.*}}, {{.*}} {tt.divisibility = dense<16> : tensor<1xi32>} + %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.addptr %11, %3 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> + scf.yield %12, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> + } {"tt.divisibility_arg1"=dense<[16]> : tensor<1xi32>} + // CHECK: tt.divisibility_arg1 + // CHECK-SAME: tt.divisibility_arg4 + %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: scalar_pointers + tt.func public @scalar_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.get_program_id x : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i64 = arith.constant 0 : i64 + %c10_i64 = arith.constant 10 : i64 + %c100_i32 = arith.constant 100 : i32 + %5 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 + // CHECK: arith.constant 0 : i64 + // CHECK: arith.constant 0 : i64 + // CHECK: %[[offset0:.*]] = arith.constant 0 : i64 + // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 + // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[ptr1:.*]] = %[[ptr0]], %[[offset1:.*]] = %[[offset0]]) + %10:1 = scf.for %arg3 = %c1_i32 to %c100_i32 step %c1_i32 iter_args(%arg4 = %5) -> (!tt.ptr) : i32 { + // CHECK: tt.store %[[ptr1]] + tt.store %arg4, %c0_i64 : !tt.ptr + // CHECK: tt.addptr %[[ptr1]] + %11 = tt.addptr %arg4, %c1_i32 : !tt.ptr, i32 + scf.yield %11 : !tt.ptr + } + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: @scalar_if + tt.func @scalar_if(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)->f32{ + %0 = tt.get_program_id x : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i64 = arith.constant 0 : i64 + %c10_i64 = arith.constant 10 : i64 + %c100_i32 = arith.constant 100 : i32 + %5 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 + // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}} + // CHECK: scf.if {{.*}} -> ({{.*}}, !tt.ptr, i64) + %6 = scf.if %cond -> (!tt.ptr){ + %true = tt.addptr %5, %c1_i32 : !tt.ptr, i32 + // CHECK: %[[ptr1:.*]] = tt.addptr %[[ptr0]] + // CHECK: scf.yield {{.*}}, %[[ptr1]] + scf.yield %true : !tt.ptr + } else { + %false = tt.addptr %5, %c100_i32 : !tt.ptr, i32 + // CHECK: %[[ptr2:.*]] = tt.addptr %[[ptr0]] + // CHECK: scf.yield {{.*}}, %[[ptr2]] + scf.yield %false : !tt.ptr + } + %11 = tt.load %6 : !tt.ptr + tt.return %11 : f32 + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @scalar_while + tt.func @scalar_while(%arg0: !tt.ptr, %init : f32, %cond : i1)->f32{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}} + // CHECK: scf.while ({{.*}}, {{.*}} = %arg2, %[[ptr1:.*]] = %[[ptr0]], {{.*}}) + %2 = tt.addptr %arg0, %0 : !tt.ptr, i32 + %6 = scf.while (%arg1 = %2, %arg2 = %cond) : (!tt.ptr, i1) -> (!tt.ptr) { + // CHECK: scf.condition({{.*}}) {{.*}}, %[[ptr1]] + scf.condition(%arg2) %arg1 : !tt.ptr + } do { + // CHECK: ^bb0({{.*}}: !tt.ptr, %[[ptr2:.*]]: !tt.ptr, {{.*}}) + // CHECK: scf.yield %{{.*}}, {{.*}} %[[ptr2]], {{.*}}, {{.*}} + ^bb0(%arg1: !tt.ptr): + scf.yield %arg1, %cond : !tt.ptr, i1 + } + %11 = tt.load %6 : !tt.ptr + tt.return %11 : f32 + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @scalar_cond_branch + tt.func @scalar_cond_branch(%arg0 : !tt.ptr, %i1 : i1) -> f32{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %6 = tt.addptr %arg0, %0 : !tt.ptr, i32 + // CHECK: %[[ptr0:.*]] = tt.addptr %arg0 + // CHECK: cf.cond_br %arg1, ^bb1(%{{.*}}, %[[ptr0]], {{.*}}), ^bb2(%{{.*}}, %arg0, {{.*}}) + cf.cond_br %i1, ^bb1(%6 : !tt.ptr), ^bb2(%arg0 : !tt.ptr) + // CHECK: ^bb1({{.*}}, %[[ptr1:.*]]: !tt.ptr, {{.*}}): + ^bb1(%arg1 : !tt.ptr): + // CHECK: tt.load %[[ptr1]] + %out1 = tt.load %arg1 : !tt.ptr + tt.return %out1 : f32 + // CHECK: ^bb2({{.*}}, %[[ptr2:.*]]: !tt.ptr, {{.*}}): + ^bb2(%arg2 : !tt.ptr): // 2 preds: ^bb0, ^bb1 + // CHECK: tt.load %[[ptr2]] + %out2 = tt.load %arg2 : !tt.ptr + tt.return %out2 : f32 + } +} diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir new file mode 100644 index 000000000..25897f2a9 --- /dev/null +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -0,0 +1,124 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops | FileCheck %s + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: simple + tt.func @simple(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 :i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + // CHECK: %[[offset:.*]] = arith.addi + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + // CHECK: buffer_load %arg0[%[[offset]]] + %9 = tt.load %6 : tensor<256x!tt.ptr, #blocked0> + // CHECK: buffer_load %arg1[%[[offset]]] + %10 = tt.load %8 : tensor<256x!tt.ptr, #blocked0> + // CHECK: %[[data:.*]] = arith.addf + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + // CHECK: buffer_store %[[data]], %arg2[%[[offset]]] + tt.store %13, %11 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: assume_positive_offset + tt.func @assume_positive_offset(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %sub = arith.subi %1, %c128_i32 : i32 + %cmp = arith.cmpi sgt, %sub, %c0_i32 : i32 + llvm.intr.assume %cmp : i1 + %2 = tt.splat %sub : i32 -> tensor<1024xi32, #blocked> + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[offset:.*]] = arith.addi + %4 = arith.addi %2, %3 : tensor<1024xi32, #blocked> + // CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0 + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: buffer_load %[[scalar_ptr]][%[[offset]]] + %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> + tt.return %10 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: offset_64_bits + tt.func @offset_64_bits(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked> { + %c1024_i32 = arith.constant 1024 : i32 + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %sub = arith.subi %1, %c128_i32 : i32 + %2 = tt.splat %sub : i32 -> tensor<1024xi32, #blocked> + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %ext2 = arith.extsi %2 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + %ext3 = arith.extsi %3 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + %4 = arith.addi %ext2, %ext3 : tensor<1024xi64, #blocked> + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi64, #blocked> + // CHECK: tt.load + %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> + tt.return %10 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: offset_64_bits_narrow + tt.func public @offset_64_bits_narrow(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) -> tensor<1024xf32, #blocked> { + %c1024_i32 = arith.constant 1024 : i32 + %c128_i32 = arith.constant 128 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.splat %1: i32 -> tensor<1024xi32, #blocked> + %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %ext2 = arith.extsi %2 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + %ext3 = arith.extsi %3 : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + %4 = arith.addi %ext2, %ext3 : tensor<1024xi64, #blocked> + // CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0 + %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[offset_32_bit:.*]] = arith.trunci + %narrow4 = arith.trunci %4 : tensor<1024xi64, #blocked> to tensor <1024xi32, #blocked> + %9 = tt.addptr %8, %narrow4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]] + %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> + tt.return %10 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: non_canonical_ptr + tt.func @non_canonical_ptr(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: tensor<1024xi32, #blocked>) -> tensor<1024xf32, #blocked>{ + %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %9 = tt.addptr %8, %arg1: tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: tt.load + %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> + tt.return %10 : tensor<1024xf32, #blocked> + } +} diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir index cb565d1f0..686e5a24e 100644 --- a/test/TritonGPU/amd/amd-reorder-instructions.mlir +++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir @@ -1,25 +1,924 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s -// Check that we order load, local_alloc and local_load one after another. This is useful -// for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers +// Check that we place local_alloc, local_store (optional) and local_load right after definition of their operands +// in cases where local_alloc is in the loop but it's operand is not. +// This is useful for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers // throughout the computation. -// CHECK-LABEL: order_load_alloc_local_load -// CHECK: %[[LOAD:.+]] = tt.load -// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[LOAD]] -// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]] + +// CHECK-LABEL: hoist_q_out_of_the_loop +// CHECK: %[[TRUNCF:.+]] = arith.truncf +// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[TRUNCF]] +// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]] +// CHECK: scf.for +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma> + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %12 = tt.splat %3 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %41 = tt.load %12 : tensor<256x128x!tt.ptr, #blocked1> + %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1> + %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1> + %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1> + %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> + %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 { + %73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2> + %74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2> + %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> + %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> + %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %107 = arith.addi %arg26, %c128_i64 : i64 + scf.yield %107 : i64 + } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} + tt.return + } +} + + +// ----- +// Check that reordering described in hoist_q_out_of_the_loop is not done in the case where both +// local_alloc and it's src tensor defining op are in the loop. +// CHECK-LABEL: no_hoist_q_type_reordering +// CHECK: scf.for +// CHECK: %[[TRUNCF:.+]] = arith.truncf +// CHECK-NEXT: arith.constant +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 1.44269502 : f32 + %c128_i32 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c0_i64 = arith.constant 0 : i64 + %1 = tt.get_program_id y : i32 + %2 = arith.muli %1, %arg7 : i32 + %3 = tt.addptr %arg0, %2 : !tt.ptr, i32 + %12 = tt.splat %3 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %41 = tt.load %12 : tensor<256x128x!tt.ptr, #blocked1> + %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1> + %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1> + %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1> + %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 { + %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma> + %73 = tt.splat %3 : !tt.ptr -> tensor<128x128x!tt.ptr, #blocked2> + %74 = tt.load %73 : tensor<128x128x!tt.ptr, #blocked2> + %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> + %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> + %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> + %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> + %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma> + %107 = arith.addi %arg26, %c128_i64 : i64 + scf.yield %107 : i64 + } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} + tt.return + } +} + +// ----- #blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> #mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> #shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> + +// CHECK-LABEL: order_load_alloc_local_load_local_store +// CHECK: %[[LOAD:.+]] = tt.load +// CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc +// CHECK: triton_gpu.local_store %[[LOAD]], %[[ALLOC]] +// CHECK: triton_gpu.local_load %[[ALLOC]] module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @order_load_alloc_local_load(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { + tt.func public @order_load_alloc_local_load_local_store(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %9 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %10 = triton_gpu.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared> + %10 = triton_gpu.local_alloc : () -> !tt.memdesc<32x32xf32, #shared, mutable> + triton_gpu.local_store %9, %10 : tensor<32x32xf32, #blocked> -> !tt.memdesc<32x32xf32, #shared, mutable> %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared, mutable> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> tt.return } } + +// ----- +// Move loads (and independent local_stores) as early as possible. +// For example in the matmul_loop below, the scf.for loop looks like this after pipeliner: +// scf.for ... { +// // stage 1 +// %a = tt.local_load %a_tile +// %b = tt.local_load %b_tile +// tt.dot %c, %a, %b +// // stage 0 +// %aptr = tt.addptr %aptr, %k +// %a_next = tt.load %aptr +// %bptr = tt.addptr %bptr, %k +// %b_next = tt.load %bptr +// tt.local_store %a_next +// tt.local_store %b_next +// yield +// } +// +// Solution for num_stages=2 : +// scf.for ... { +// // stage 0.a +// %aptr = tt.addptr %aptr, %k +// %a_next = tt.load %aptr +// %bptr = tt.addptr %bptr, %k +// %b_next = tt.load %bptr +// // stage 1 +// %a = tt.local_load %a_tile +// %b = tt.local_load %b_tile +// tt.dot %c, %a, %b +// // stage 0.b +// tt.local_store %a_next +// tt.local_store %b_next +// yield +// } +// +// Solution for num_stages=3 (double-buffered) : +// scf.for ... { +// // stage 1 +// tt.local_store %a_next_1 +// tt.local_store %b_next_1 +// // stage 0 +// %aptr = tt.addptr %aptr, %k +// %a_next_2 = tt.load %aptr +// %bptr = tt.addptr %bptr, %k +// %b_next_2 = tt.load %bptr +// // stage 2 +// %a = tt.local_load %a_tile +// %b = tt.local_load %b_tile +// tt.dot %c, %a, %b +// yield +// } + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> +#shared3 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#shared4 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32, triton_gpu.target = "hip:gfx942"} { + +// CHECK-LABEL: tt.func @matmul_loop +// CHECK: %{{.*}}:6 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) +// Stage 0.a +// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_21]] +// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_23]] +// CHECK: %[[ADDPTR_25:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// CHECK: %[[SPLAT_26:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_27:.*]] = tt.load %[[ADDPTR_25]], %[[SPLAT_26]] +// Stage 1 +// CHECK: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[ARG10]] +// CHECK: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[MULF_30:.*]] = arith.mulf %[[LOCAL_LOAD_29]], %{{.*}} +// CHECK: %[[DOT_31:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[MULF_30]], %[[ARG8]] +// Stage 0.b +// CHECK: %[[ADDI_32:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ADDI_32]], %{{.*}} +// CHECK: %[[SELECT_34:.*]] = arith.select %[[CMPI_33]], %[[ADDI_32]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_35]] +// CHECK: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_36]] +// CHECK: scf.yield %[[ADDPTR_20]], %[[ADDPTR_25]], %[[DOT_31]], %[[SELECT_34]], %[[MEMDESC_SUBVIEW_35]], %[[MEMDESC_SUBVIEW_36]] +// CHECK: } + + tt.func @matmul_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> + %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> + %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %10 = triton_gpu.local_alloc : () -> !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %11 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %12 = arith.cmpi slt, %arg0, %arg1 : index + %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> + %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> + %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> + %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %17 = triton_gpu.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %18 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %19:6 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) { + %20 = arith.subi %arg1, %arg2 : index + %21 = arith.cmpi slt, %arg5, %20 : index + %22 = triton_gpu.local_load %arg10 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %23 = triton_gpu.local_load %arg11 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %24 = arith.mulf %23, %cst : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %25 = tt.dot %22, %24, %arg8 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %26 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %27 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %28 = tt.splat %21 : i1 -> tensor<128x32xi1, #blocked1> + %29 = tt.load %26, %28 : tensor<128x32x!tt.ptr, #blocked1> + %30 = tt.splat %21 : i1 -> tensor<32x128xi1, #blocked> + %31 = tt.load %27, %30, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %32 = arith.addi %arg9, %c1_i32 : i32 + %33 = arith.cmpi slt, %32, %c1_i32 : i32 + %34 = arith.select %33, %32, %c0_i32 : i32 + %35 = triton_gpu.memdesc_subview %10[%34, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %29, %35 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %36 = triton_gpu.memdesc_subview %11[%34, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %31, %36 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %26, %27, %25, %34, %35, %36 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + } + triton_gpu.local_dealloc %10 : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %11 : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + tt.return %19#2 : tensor<128x128xf32, #mma> + } + + +// This example tests that tt.load overlaps with independent ttg.local_store which +// overlaps with independent tt.dot. +// num_stages == 3, double buffered + +// CHECK-LABEL: tt.func @matmul_loop_mb +// CHECK: %{{.*}}:8 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) +// Stage 0 +// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// CHECK: %[[MULI_29:.*]] = arith.muli %{{.*}}, %{{.*}} +// CHECK: %[[SUBI_30:.*]] = arith.subi %{{.*}}, %[[MULI_29]] +// CHECK: %[[CMPI_31:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_30]] +// CHECK: %[[SPLAT_32:.*]] = tt.splat %[[CMPI_31]] +// CHECK: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_32]] +// CHECK: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// CHECK: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_31]] +// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]], %[[SPLAT_35]] +// Stage 1 +// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG9]], %{{.*}} +// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}} +// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_40]] +// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_41]] +// Stage 2 +// CHECK: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[ARG10]] +// CHECK: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[MULF_44:.*]] = arith.mulf %[[LOCAL_LOAD_43]], %{{.*}} +// CHECK: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_42]], %[[MULF_44]], %[[ARG8]] +// CHECK: scf.yield %[[ADDPTR_28]], %[[ADDPTR_34]], %[[DOT_45]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]], %[[LOAD_33]], %[[LOAD_36]] +// CHECK: } + + tt.func @matmul_loop_mb(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { + %c2 = arith.constant 2 : index + %c2_i32 = arith.constant 2 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> + %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> + %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %10 = triton_gpu.local_alloc : () -> !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %11 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %12 = arith.cmpi slt, %arg0, %arg1 : index + %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> + %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr, #blocked1> + %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> + %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %17 = arith.addi %arg0, %arg2 : index + %18 = arith.cmpi slt, %17, %arg1 : index + %19 = tt.addptr %4, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %20 = tt.addptr %9, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %21 = tt.splat %18 : i1 -> tensor<128x32xi1, #blocked1> + %22 = tt.load %19, %21 : tensor<128x32x!tt.ptr, #blocked1> + %23 = tt.splat %18 : i1 -> tensor<32x128xi1, #blocked> + %24 = tt.load %20, %23, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %25 = triton_gpu.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %14, %25 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %26 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %16, %26 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + %27:8 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %19, %arg7 = %20, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %25, %arg11 = %26, %arg12 = %22, %arg13 = %24) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>) { + %28 = arith.muli %arg2, %c2 : index + %29 = arith.subi %arg1, %28 : index + %30 = arith.cmpi slt, %arg5, %29 : index + %31 = triton_gpu.local_load %arg10 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %32 = triton_gpu.local_load %arg11 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %33 = arith.mulf %32, %cst : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %34 = tt.dot %31, %33, %arg8 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %35 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %36 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %37 = tt.splat %30 : i1 -> tensor<128x32xi1, #blocked1> + %38 = tt.load %35, %37 : tensor<128x32x!tt.ptr, #blocked1> + %39 = tt.splat %30 : i1 -> tensor<32x128xi1, #blocked> + %40 = tt.load %36, %39, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %41 = arith.addi %arg9, %c1_i32 : i32 + %42 = arith.cmpi slt, %41, %c2_i32 : i32 + %43 = arith.select %42, %41, %c0_i32 : i32 + %44 = triton_gpu.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %arg12, %44 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + %45 = triton_gpu.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %arg13, %45 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + scf.yield %35, %36, %34, %43, %44, %45, %38, %40 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked> + } + triton_gpu.local_dealloc %10 : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %11 : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> + tt.return %27#2 : tensor<128x128xf32, #mma> + } + +// This example shows dependent loads and verifies all are moved early. +// CHECK-LABEL: tt.func @indirect_bmm_vector +// CHECK: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) +// Stage 0 +// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_21]] +// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_23]] +// Stage 1.a +// CHECK: %[[EXPAND_DIMS_25:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32} +// CHECK: %[[BROADCAST_26:.*]] = tt.broadcast %[[EXPAND_DIMS_25]] +// CHECK: %[[MULI_27:.*]] = arith.muli %{{.*}}, %[[BROADCAST_26]] +// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %{{.*}}, %[[MULI_27]] +// CHECK: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_22]] +// CHECK: %[[LOAD_30:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_29]] +// CHECK: %[[ADDPTR_31:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// CHECK: %[[SUBI_32:.*]] = arith.subi %{{.*}}, %{{.*}} +// CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_32]] +// CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_33]] +// CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_31]], %[[SPLAT_34]] +// Stage 2 +// CHECK: %[[LOCAL_LOAD_36:.*]] = triton_gpu.local_load %[[ARG11]] +// CHECK: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[ARG12]] +// CHECK: %[[DOT_38:.*]] = tt.dot %[[LOCAL_LOAD_36]], %[[LOCAL_LOAD_37]], %[[ARG7]] +// Stage 1.b +// CHECK: %[[ADDI_39:.*]] = arith.addi %[[ARG10]], %{{.*}} +// CHECK: %[[CMPI_40:.*]] = arith.cmpi slt, %[[ADDI_39]], %{{.*}} +// CHECK: %[[SELECT_41:.*]] = arith.select %[[CMPI_40]], %[[ADDI_39]], %{{.*}} +// CHECK: %[[MEMDESC_SUBVIEW_42:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_42]] +// CHECK: %[[MEMDESC_SUBVIEW_43:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}] +// CHECK: triton_gpu.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_43]] +// CHECK: scf.yield %[[DOT_38]], %[[ADDPTR_20]], %[[ADDPTR_31]], %[[SELECT_41]], %[[MEMDESC_SUBVIEW_42]], %[[MEMDESC_SUBVIEW_43]], %[[LOAD_35]] +// CHECK: } + + tt.func @indirect_bmm_vector(%arg0: tensor<16x16xi64, #blocked> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> { + %c2 = arith.constant 2 : index + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<1> : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %2 = arith.cmpi sgt, %arg1, %c0 : index + %3 = tt.splat %2 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %4 = tt.load %arg3, %3 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %5 = arith.cmpi sgt, %arg1, %c1 : index + %6 = tt.addptr %arg3, %cst_0 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %7 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked1> + %8 = tt.load %arg2, %7 : tensor<16x16x!tt.ptr, #blocked1> + %9 = tt.expand_dims %4 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %10 = tt.broadcast %9 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> + %11 = arith.muli %arg0, %10 : tensor<16x16xi64, #blocked> + %12 = tt.addptr %arg5, %11 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %13 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked> + %14 = tt.load %12, %13 : tensor<16x16x!tt.ptr, #blocked> + %15 = tt.splat %5 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %16 = tt.load %6, %15 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %17 = triton_gpu.memdesc_subview %0[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %8, %17 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %18 = triton_gpu.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %14, %18 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %19:7 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %6, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18, %arg13 = %16) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) { + %20 = arith.subi %arg1, %c2 : index + %21 = arith.cmpi slt, %arg6, %20 : index + %22 = arith.subi %arg1, %c1 : index + %23 = arith.cmpi slt, %arg6, %22 : index + %24 = triton_gpu.local_load %arg11 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %25 = triton_gpu.local_load %arg12 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %26 = tt.dot %24, %25, %arg7 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %27 = tt.addptr %arg8, %arg4 : tensor<16x16x!tt.ptr, #blocked1>, tensor<16x16xi32, #blocked1> + %28 = tt.addptr %arg9, %cst_0 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %29 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked1> + %30 = tt.load %27, %29 : tensor<16x16x!tt.ptr, #blocked1> + %31 = tt.expand_dims %arg13 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked> + %32 = tt.broadcast %31 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked> + %33 = arith.muli %arg0, %32 : tensor<16x16xi64, #blocked> + %34 = tt.addptr %arg5, %33 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi64, #blocked> + %35 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked> + %36 = tt.load %34, %35 : tensor<16x16x!tt.ptr, #blocked> + %37 = tt.splat %21 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %38 = tt.load %28, %37 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %39 = arith.addi %arg10, %c1_i32 : i32 + %40 = arith.cmpi slt, %39, %c1_i32 : i32 + %41 = arith.select %40, %39, %c0_i32 : i32 + %42 = triton_gpu.memdesc_subview %0[%41, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %30, %42 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + %43 = triton_gpu.memdesc_subview %1[%41, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %36, %43 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + scf.yield %26, %27, %28, %41, %42, %43, %38 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr, #blocked1>, tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + } + triton_gpu.local_dealloc %0 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %1 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> + tt.return %19#0 : tensor<16x16xf32, #mma> + } +} + +// ----- +// This test ensures that loads will not be moved across `for` loops. + +// CHECK-LABEL: tt.func public @_attn_bwd +// CHECK: tt.load +// CHECK: tt.load +// CHECK: scf.for +// CHECK: } +// CHECK: scf.for +// CHECK: } +// Moved before the independent `tt.store` ops but not before the `for` ops. +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.load +// CHECK: tt.store +// CHECK: tt.store +// CHECK: scf.for +// CHECK: } +// CHECK: scf.for +// CHECK: } +// CHECK: tt.store + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#mma1 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +#shared2 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> +#shared3 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @_attn_bwd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma> + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %c32_i32 = arith.constant 32 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c16_i32 = arith.constant 16 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %cst_2 = arith.constant dense<0.693147182> : tensor<128x64xf32, #mma> + %0 = tt.get_program_id z : i32 + %1 = arith.muli %0, %arg14 : i32 + %2 = arith.extsi %1 : i32 to i64 + %3 = arith.remsi %0, %arg13 : i32 + %4 = arith.muli %arg11, %3 : i32 + %5 = arith.divsi %0, %arg13 : i32 + %6 = arith.muli %arg10, %5 : i32 + %7 = arith.addi %4, %6 : i32 + %8 = arith.extsi %7 : i32 to i64 + %9 = tt.get_program_id x : i32 + %10 = tt.addptr %arg0, %8 : !tt.ptr, i64 + %11 = tt.addptr %arg1, %8 : !tt.ptr, i64 + %12 = tt.addptr %arg2, %8 : !tt.ptr, i64 + %13 = tt.addptr %arg4, %8 : !tt.ptr, i64 + %14 = tt.addptr %arg5, %8 : !tt.ptr, i64 + %15 = tt.addptr %arg6, %8 : !tt.ptr, i64 + %16 = tt.addptr %arg7, %8 : !tt.ptr, i64 + %17 = tt.addptr %arg8, %2 : !tt.ptr, i64 + %18 = tt.addptr %arg9, %2 : !tt.ptr, i64 + %19 = arith.muli %9, %c128_i32 : i32 + %20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %23 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %25 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %26 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %27 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %28 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %29 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %30 = arith.addi %25, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %31 = arith.addi %26, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %32 = arith.addi %27, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %33 = arith.addi %28, %23 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %34 = arith.addi %29, %24 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %35 = tt.expand_dims %30 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma> + %36 = tt.expand_dims %31 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %37 = tt.expand_dims %32 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xi32, #mma1> + %38 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #mma> + %39 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #blocked> + %40 = arith.muli %35, %38 : tensor<128x1xi32, #mma> + %41 = arith.muli %36, %39 : tensor<128x1xi32, #blocked> + %42 = tt.splat %11 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %43 = tt.addptr %42, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %45 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %46 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %47 = tt.expand_dims %44 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma> + %48 = tt.expand_dims %45 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %49 = tt.expand_dims %46 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %50 = tt.broadcast %43 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> + %51 = tt.broadcast %47 : tensor<1x64xi32, #mma> -> tensor<128x64xi32, #mma> + %52 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> + %53 = tt.addptr %50, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %54 = tt.load %53 : tensor<128x64x!tt.ptr, #blocked> + %55 = tt.splat %12 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %56 = tt.addptr %55, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %57 = tt.broadcast %56 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> + %58 = tt.addptr %57, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %59 = tt.load %58 : tensor<128x64x!tt.ptr, #blocked> + %60 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %61 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %62 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %63 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %64 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %65 = arith.addi %63, %60 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %66 = arith.addi %64, %62 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %67 = tt.expand_dims %65 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16xi32, #blocked2> + %68 = tt.splat %arg12 : i32 -> tensor<1x16xi32, #blocked2> + %69 = arith.muli %67, %68 : tensor<1x16xi32, #blocked2> + %70 = tt.splat %10 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> + %71 = tt.addptr %70, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> + %72 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %73 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %74 = tt.expand_dims %72 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1xi32, #blocked2> + %75 = tt.expand_dims %73 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xi32, #blocked3> + %76 = tt.broadcast %71 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> + %77 = tt.broadcast %74 : tensor<64x1xi32, #blocked2> -> tensor<64x16xi32, #blocked2> + %78 = tt.addptr %76, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %79 = tt.expand_dims %66 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> + %80 = tt.splat %arg12 : i32 -> tensor<16x1xi32, #blocked1> + %81 = arith.muli %79, %80 : tensor<16x1xi32, #blocked1> + %82 = tt.splat %13 : !tt.ptr -> tensor<16x1x!tt.ptr, #blocked1> + %83 = tt.addptr %82, %81 : tensor<16x1x!tt.ptr, #blocked1>, tensor<16x1xi32, #blocked1> + %84 = tt.broadcast %83 : tensor<16x1x!tt.ptr, #blocked1> -> tensor<16x64x!tt.ptr, #blocked1> + %85 = tt.broadcast %49 : tensor<1x64xi32, #blocked1> -> tensor<16x64xi32, #blocked1> + %86 = tt.addptr %84, %85 : tensor<16x64x!tt.ptr, #blocked1>, tensor<16x64xi32, #blocked1> + %87 = tt.splat %17 : !tt.ptr -> tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %88 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1> + %89 = tt.splat %18 : !tt.ptr -> tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %90 = arith.muli %arg12, %c16_i32 : i32 + %91 = tt.splat %90 : i32 -> tensor<64x16xi32, #blocked2> + %92 = tt.splat %90 : i32 -> tensor<16x64xi32, #blocked1> + %93:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %cst_1, %arg18 = %19, %arg19 = %78, %arg20 = %86) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<16x64x!tt.ptr, #blocked1>) : i32 { + %206 = tt.load %arg19 : tensor<64x16x!tt.ptr, #blocked2> + %207 = tt.splat %arg18 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %208 = arith.addi %207, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %209 = tt.addptr %87, %208 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %210 = tt.load %209 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %213 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %214 = triton_gpu.local_load %213 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %215 = tt.dot %212, %214, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> + %216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1> + %217 = tt.broadcast %216 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1> + %218 = arith.subf %215, %217 : tensor<128x16xf32, #mma1> + %219 = math.exp2 %218 : tensor<128x16xf32, #mma1> + %220 = tt.expand_dims %208 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1> + %221 = tt.broadcast %220 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1> + %222 = arith.cmpi sge, %221, %88 : tensor<128x16xi32, #mma1> + %223 = arith.select %222, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> + %224 = tt.load %arg20 : tensor<16x64x!tt.ptr, #blocked1> + %225 = arith.truncf %223 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> + %226 = triton_gpu.local_alloc %225 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> + %227 = triton_gpu.local_load %226 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %228 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> + %229 = triton_gpu.local_load %228 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %230 = tt.dot %227, %229, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %231 = tt.addptr %89, %208 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %232 = tt.load %231 : tensor<16x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %233 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> + %234 = tt.trans %233 {order = array} : !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %235 = triton_gpu.local_load %234 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %236 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %237 = triton_gpu.local_load %236 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %238 = tt.dot %237, %235, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> + %239 = tt.expand_dims %232 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1> + %240 = tt.broadcast %239 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1> + %241 = arith.subf %238, %240 : tensor<128x16xf32, #mma1> + %242 = arith.mulf %223, %241 : tensor<128x16xf32, #mma1> + %243 = arith.truncf %242 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> + %244 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> + %245 = tt.trans %244 {order = array} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> + %246 = triton_gpu.local_load %245 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %247 = triton_gpu.local_alloc %243 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> + %248 = triton_gpu.local_load %247 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %249 = tt.dot %248, %246, %arg17 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %250 = arith.addi %arg18, %c16_i32 : i32 + %251 = tt.addptr %arg19, %91 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %252 = tt.addptr %arg20, %92 : tensor<16x64x!tt.ptr, #blocked1>, tensor<16x64xi32, #blocked1> + scf.yield %230, %249, %250, %251, %252 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<16x64x!tt.ptr, #blocked1> + } + %94 = arith.addi %19, %c128_i32 : i32 + %95 = arith.subi %arg14, %94 : i32 + %96 = arith.divsi %95, %c32_i32 : i32 + %97 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %98 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %99 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %100 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %101 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %102 = arith.addi %100, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %103 = arith.addi %101, %99 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %104 = tt.expand_dims %102 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> + %105 = tt.splat %arg12 : i32 -> tensor<1x32xi32, #blocked3> + %106 = arith.muli %104, %105 : tensor<1x32xi32, #blocked3> + %107 = tt.splat %10 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> + %108 = tt.addptr %107, %106 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> + %109 = tt.broadcast %108 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> + %110 = tt.broadcast %75 : tensor<64x1xi32, #blocked3> -> tensor<64x32xi32, #blocked3> + %111 = tt.addptr %109, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %112 = tt.expand_dims %103 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %113 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #blocked> + %114 = arith.muli %112, %113 : tensor<32x1xi32, #blocked> + %115 = tt.splat %13 : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> + %116 = tt.addptr %115, %114 : tensor<32x1x!tt.ptr, #blocked>, tensor<32x1xi32, #blocked> + %117 = tt.broadcast %116 : tensor<32x1x!tt.ptr, #blocked> -> tensor<32x64x!tt.ptr, #blocked> + %118 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked> + %119 = tt.addptr %117, %118 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> + %120 = tt.splat %17 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %121 = tt.splat %18 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %122 = arith.muli %arg12, %c32_i32 : i32 + %123 = tt.splat %122 : i32 -> tensor<64x32xi32, #blocked3> + %124 = tt.splat %122 : i32 -> tensor<32x64xi32, #blocked> + %125:5 = scf.for %arg15 = %c0_i32 to %96 step %c1_i32 iter_args(%arg16 = %93#0, %arg17 = %93#1, %arg18 = %94, %arg19 = %111, %arg20 = %119) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr, #blocked3>, tensor<32x64x!tt.ptr, #blocked>) : i32 { + %206 = tt.load %arg19 : tensor<64x32x!tt.ptr, #blocked3> + %207 = tt.splat %arg18 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %208 = arith.addi %207, %98 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %209 = tt.addptr %120, %208 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %210 = tt.load %209 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %213 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> + %214 = triton_gpu.local_load %213 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %215 = tt.dot %212, %214, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> + %216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma> + %217 = tt.broadcast %216 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma> + %218 = arith.subf %215, %217 : tensor<128x32xf32, #mma> + %219 = math.exp2 %218 : tensor<128x32xf32, #mma> + %220 = tt.load %arg20 : tensor<32x64x!tt.ptr, #blocked> + %221 = arith.truncf %219 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> + %222 = triton_gpu.convert_layout %221 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %223 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> + %224 = triton_gpu.local_load %223 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %225 = tt.dot %222, %224, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %226 = tt.addptr %121, %208 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %227 = tt.load %226 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %228 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> + %229 = tt.trans %228 {order = array} : !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> + %230 = triton_gpu.local_load %229 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %231 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %232 = triton_gpu.local_load %231 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %233 = tt.dot %232, %230, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> + %234 = tt.expand_dims %227 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma> + %235 = tt.broadcast %234 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma> + %236 = arith.subf %233, %235 : tensor<128x32xf32, #mma> + %237 = arith.mulf %219, %236 : tensor<128x32xf32, #mma> + %238 = arith.truncf %237 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> + %239 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> + %240 = tt.trans %239 {order = array} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> + %241 = triton_gpu.local_load %240 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %242 = triton_gpu.convert_layout %238 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %243 = tt.dot %242, %241, %arg17 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %244 = arith.addi %arg18, %c32_i32 : i32 + %245 = tt.addptr %arg19, %123 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %246 = tt.addptr %arg20, %124 : tensor<32x64x!tt.ptr, #blocked>, tensor<32x64xi32, #blocked> + scf.yield %225, %243, %244, %245, %246 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr, #blocked3>, tensor<32x64x!tt.ptr, #blocked> + } + %126 = tt.splat %16 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> + %127 = tt.addptr %126, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> + %128 = tt.broadcast %127 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> + %129 = tt.addptr %128, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> + %130 = arith.truncf %125#0 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + tt.store %129, %130 : tensor<128x64x!tt.ptr, #mma> + %131 = tt.splat %arg3 : f32 -> tensor<128x64xf32, #mma> + %132 = arith.mulf %125#1, %131 : tensor<128x64xf32, #mma> + %133 = tt.splat %15 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> + %134 = tt.addptr %133, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> + %135 = tt.broadcast %134 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> + %136 = tt.addptr %135, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> + %137 = arith.truncf %132 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + tt.store %136, %137 : tensor<128x64x!tt.ptr, #mma> + %138 = tt.splat %10 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %139 = tt.addptr %138, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %140 = tt.broadcast %139 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> + %141 = tt.addptr %140, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %142 = tt.load %141 : tensor<128x64x!tt.ptr, #blocked> + %143 = tt.splat %13 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %144 = tt.addptr %143, %41 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %145 = tt.broadcast %144 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x64x!tt.ptr, #blocked> + %146 = tt.addptr %145, %52 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %147 = tt.load %146 : tensor<128x64x!tt.ptr, #blocked> + %148 = tt.splat %17 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %149 = tt.splat %17 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %150 = tt.addptr %148, %33 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %151 = tt.addptr %149, %34 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %152 = tt.load %150 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %153 = tt.load %151 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %154 = tt.expand_dims %152 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> + %155 = tt.expand_dims %153 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %156 = tt.splat %11 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> + %157 = tt.addptr %156, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> + %158 = tt.broadcast %157 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> + %159 = tt.addptr %158, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %160 = tt.splat %12 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked2> + %161 = tt.addptr %160, %69 : tensor<1x16x!tt.ptr, #blocked2>, tensor<1x16xi32, #blocked2> + %162 = tt.broadcast %161 : tensor<1x16x!tt.ptr, #blocked2> -> tensor<64x16x!tt.ptr, #blocked2> + %163 = tt.addptr %162, %77 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %164 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %165 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %166 = tt.addptr %164, %33 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %167 = tt.addptr %165, %34 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %168 = tt.load %166 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma1}>> + %169 = tt.load %167 : tensor<128x!tt.ptr, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %170 = tt.broadcast %154 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1> + %171 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1> + %172 = tt.expand_dims %168 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1> + %173 = tt.broadcast %172 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1> + %174 = arith.muli %arg12, %c16_i32 : i32 + %175 = tt.splat %174 : i32 -> tensor<64x16xi32, #blocked2> + %176 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> + %177:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %19, %arg18 = %159, %arg19 = %163, %arg20 = %c-1_i32) -> (tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16x!tt.ptr, #blocked2>, i32) : i32 { + %206 = arith.addi %arg20, %c1_i32 : i32 + %207 = arith.cmpi slt, %206, %c1_i32 : i32 + %208 = arith.select %207, %206, %c0_i32 : i32 + %209 = tt.load %arg18 : tensor<64x16x!tt.ptr, #blocked2> + %210 = tt.load %arg19 : tensor<64x16x!tt.ptr, #blocked2> + %211 = triton_gpu.memdesc_subview %176[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %210, %211 : tensor<64x16xf16, #blocked2> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> + %212 = triton_gpu.local_load %211 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %215 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %216 = triton_gpu.local_load %215 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %217 = tt.dot %214, %216, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> + %218 = arith.subf %217, %170 : tensor<128x16xf32, #mma1> + %219 = math.exp2 %218 : tensor<128x16xf32, #mma1> + %220 = tt.splat %arg17 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %221 = arith.addi %220, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> + %222 = tt.expand_dims %221 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1> + %223 = tt.broadcast %222 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1> + %224 = arith.cmpi sge, %171, %223 : tensor<128x16xi32, #mma1> + %225 = arith.select %224, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> + %226 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %227 = triton_gpu.local_load %226 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %228 = tt.dot %227, %212, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1> + %229 = arith.subf %228, %173 : tensor<128x16xf32, #mma1> + %230 = arith.mulf %225, %229 : tensor<128x16xf32, #mma1> + %231 = arith.truncf %230 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> + %232 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> + %233 = tt.trans %232 {order = array} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> + %234 = triton_gpu.local_load %233 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %235 = triton_gpu.local_alloc %231 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> + %236 = triton_gpu.local_load %235 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %237 = tt.dot %236, %234, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %238 = arith.addi %arg17, %c16_i32 : i32 + %239 = tt.addptr %arg18, %175 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + %240 = tt.addptr %arg19, %175 : tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16xi32, #blocked2> + scf.yield %237, %238, %239, %240, %208 : tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr, #blocked2>, tensor<64x16x!tt.ptr, #blocked2>, i32 + } + triton_gpu.local_dealloc %176 : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> + %178 = arith.divsi %19, %c32_i32 : i32 + %179 = arith.muli %178, %c32_i32 : i32 + %180 = arith.subi %19, %179 : i32 + %181 = tt.splat %180 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %182 = arith.addi %181, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %183 = tt.expand_dims %182 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3> + %184 = arith.muli %183, %105 : tensor<1x32xi32, #blocked3> + %185 = tt.splat %11 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> + %186 = tt.addptr %185, %184 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> + %187 = tt.broadcast %186 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> + %188 = tt.addptr %187, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %189 = tt.splat %12 : !tt.ptr -> tensor<1x32x!tt.ptr, #blocked3> + %190 = tt.addptr %189, %184 : tensor<1x32x!tt.ptr, #blocked3>, tensor<1x32xi32, #blocked3> + %191 = tt.broadcast %190 : tensor<1x32x!tt.ptr, #blocked3> -> tensor<64x32x!tt.ptr, #blocked3> + %192 = tt.addptr %191, %110 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %193 = tt.broadcast %155 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma> + %194 = tt.expand_dims %169 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %195 = tt.broadcast %194 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma> + %196 = arith.muli %arg12, %c32_i32 : i32 + %197 = tt.splat %196 : i32 -> tensor<64x32xi32, #blocked3> + %198 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> + %199:4 = scf.for %arg15 = %c0_i32 to %178 step %c1_i32 iter_args(%arg16 = %177#0, %arg17 = %188, %arg18 = %192, %arg19 = %c-1_i32) -> (tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32x!tt.ptr, #blocked3>, i32) : i32 { + %206 = arith.addi %arg19, %c1_i32 : i32 + %207 = arith.cmpi slt, %206, %c1_i32 : i32 + %208 = arith.select %207, %206, %c0_i32 : i32 + %209 = tt.load %arg17 : tensor<64x32x!tt.ptr, #blocked3> + %210 = tt.load %arg18 : tensor<64x32x!tt.ptr, #blocked3> + %211 = triton_gpu.memdesc_subview %198[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %210, %211 : tensor<64x32xf16, #blocked3> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> + %212 = triton_gpu.local_load %211 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %215 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> + %216 = triton_gpu.local_load %215 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %217 = tt.dot %214, %216, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> + %218 = arith.subf %217, %193 : tensor<128x32xf32, #mma> + %219 = math.exp2 %218 : tensor<128x32xf32, #mma> + %220 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %221 = triton_gpu.local_load %220 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %222 = tt.dot %221, %212, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma> + %223 = arith.subf %222, %195 : tensor<128x32xf32, #mma> + %224 = arith.mulf %219, %223 : tensor<128x32xf32, #mma> + %225 = arith.truncf %224 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma> + %226 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> + %227 = tt.trans %226 {order = array} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> + %228 = triton_gpu.local_load %227 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %229 = triton_gpu.convert_layout %225 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %230 = tt.dot %229, %228, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma> + %231 = tt.addptr %arg17, %197 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + %232 = tt.addptr %arg18, %197 : tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32xi32, #blocked3> + scf.yield %230, %231, %232, %208 : tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked3>, tensor<64x32x!tt.ptr, #blocked3>, i32 + } + triton_gpu.local_dealloc %198 : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> + %200 = tt.splat %14 : !tt.ptr -> tensor<128x1x!tt.ptr, #mma> + %201 = tt.addptr %200, %40 : tensor<128x1x!tt.ptr, #mma>, tensor<128x1xi32, #mma> + %202 = tt.broadcast %201 : tensor<128x1x!tt.ptr, #mma> -> tensor<128x64x!tt.ptr, #mma> + %203 = tt.addptr %202, %51 : tensor<128x64x!tt.ptr, #mma>, tensor<128x64xi32, #mma> + %204 = arith.mulf %199#0, %cst_2 : tensor<128x64xf32, #mma> + %205 = arith.truncf %204 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> + tt.store %203, %205 : tensor<128x64x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +// CHECK-LABEL: sink_convert_dealloc +// CHECK-COUNT-2: triton_gpu.local_dealloc %{{.+}} : !tt.memdesc<4x128x64xf16, #shared, mutable> +// CHECK: triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @sink_convert_dealloc(%arg0: tensor<32x32xf32, #blocked>) attributes {noinline = false} { + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable> + %1 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable> + %2 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> + triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared, mutable> + triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared, mutable> + %3 = arith.addf %2, %2 : tensor<32x32xf32, #blocked1> + tt.return + } +} + +// ----- + +// CHECK-LABEL: anchor_barrier +// CHECK: gpu.barrier +// CHECK: tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @anchor_barrier(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable> + gpu.barrier + %2 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> + %1 = triton_gpu.local_alloc %2 : (tensor<32x32xf16, #blocked>) -> !tt.memdesc<4x128x64xf16, #shared, mutable> + triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared, mutable> + triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared, mutable> + tt.return + } +} diff --git a/test/TritonGPU/amd/amd-stream-pipeline.mlir b/test/TritonGPU/amd/amd-stream-pipeline.mlir deleted file mode 100644 index 4b2de3336..000000000 --- a/test/TritonGPU/amd/amd-stream-pipeline.mlir +++ /dev/null @@ -1,44 +0,0 @@ -// RUN: triton-opt %s -split-input-file --tritonamdgpu-stream-pipeline | FileCheck %s - -// CHECK-LABEL: @check_stream_pipeline_epilogue -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [1, 0]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = false}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { - tt.func public @check_stream_pipeline_epilogue(%Aptr: tensor<32x32x!tt.ptr, #blocked>, %Bptr : tensor<32x32x!tt.ptr, #blocked>, %arg4 : i32, %arg5 : i1) { - %cst_0 = arith.constant dense<16> : tensor<32x32xi32, #blocked> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> - %cst_5 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - // CHECK: scf.for {{.*}} = %[[LB:.*]] to %[[UB:.*]] step %[[STEP:.*]] iter_args({{.*}}) - %36:3 = scf.for %arg9 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg10 = %cst_5, %arg12 = %Aptr, %arg13 = %Bptr) -> (tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr, #blocked>, tensor<32x32x!tt.ptr, #blocked>) : i32 { - %61 = arith.muli %arg9, %arg4 : i32 - %62 = arith.cmpi slt, %arg4, %61 : i32 - %63 = tt.splat %62 : i1 -> tensor<32x32xi1, #blocked> - // This load will not be pipelined - %66 = tt.load %arg12, %63 : tensor<32x32x!tt.ptr, #blocked> - // This load will be pipelined - %70 = tt.load %arg13 : tensor<32x32x!tt.ptr, #blocked> - %71 = triton_gpu.convert_layout %66 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %72 = triton_gpu.convert_layout %70 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %73 = tt.dot %71, %72, %arg10 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> - // This scf.if will make load at %66 non-pipelineable - %74 = scf.if %arg5 -> (tensor<32x32xf32, #blocked>){ - scf.yield %66 : tensor<32x32xf32, #blocked> - } else { - scf.yield %cst_2: tensor<32x32xf32, #blocked> - } - %75 = tt.addptr %arg12, %cst_0 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> - %76 = tt.addptr %arg13, %cst_0 : tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xi32, #blocked> - scf.yield %73, %75, %76 : tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr, #blocked>, tensor<32x32x!tt.ptr, #blocked> - } - // CHECK: %[[C1:.*]] = arith.constant 1 : i32 - // CHECK: %[[t0:.*]] = arith.subi %[[UB:.*]], %[[C1]] - // CHECK: %[[t1:.*]] = arith.subi %[[t0]], %[[LB]] - // CHECK: %[[t2:.*]] = arith.divui %[[t1]], %[[STEP]] - // CHECK: %[[t3:.*]] = arith.muli %[[t2]], %[[STEP]] - // CHECK: %[[PPLUB:.*]] = arith.addi %[[LB]], %[[t3]] - // CHECK: arith.muli %[[PPLUB]], {{.*}} - tt.return - } -} diff --git a/test/TritonGPU/amd/optimize-lds-usage.mlir b/test/TritonGPU/amd/optimize-lds-usage.mlir new file mode 100644 index 000000000..5cd34aab2 --- /dev/null +++ b/test/TritonGPU/amd/optimize-lds-usage.mlir @@ -0,0 +1,139 @@ +// RUN: triton-opt %s -split-input-file -optimize-amd-lds-usage=target-arch=gfx90a | FileCheck %s +// RUN: triton-opt %s -split-input-file -optimize-amd-lds-usage=target-arch=gfx90a -optimize-amd-lds-usage=lds-limit=32768 | FileCheck %s --check-prefix=CHECK-32KLIMIT + +// Check that optimization detects overflow of LDS and decomposes layout convert so kernel fits into LDS +// CHECK-LABEL: alloc_convert_load +// CHECK-32KLIMIT-LABEL: alloc_convert_load +// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 +// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma +// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @alloc_convert_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>) attributes {noinline = false} { + %1 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> + %2 = triton_gpu.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma> + %3 = triton_gpu.local_load %1 : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + tt.return + } +} + +// ----- + +// Check that optimization detects overflow of LDS and decomposes layout convert so kernel fits into LDS +// in case of relatively small scratch buffer +// CHECK-LABEL: alloc_convert_small_load +// CHECK-32KLIMIT-LABEL: alloc_convert_small_load +// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 +// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma +// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @alloc_convert_small_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf16, #blocked>) attributes {noinline = false} { + %1 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> + %2 = triton_gpu.convert_layout %arg1 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #mma> + %3 = triton_gpu.local_load %1 : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + tt.return + } +} + +// ----- + +// Check that optimization works with 3d tensors +// in case of relatively small scratch buffer +// CHECK-LABEL: alloc_convert_3d_load +// CHECK-32KLIMIT-LABEL: alloc_convert_3d_load +// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma +// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#mma{{.*}}#mma1 +// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @alloc_convert_3d_load(%arg0: tensor<1x128x128xf16, #blocked>, %arg1: tensor<1x128x128xf16, #blocked>) attributes {noinline = false} { + %1 = triton_gpu.local_alloc %arg0 : (tensor<1x128x128xf16, #blocked>) -> !tt.memdesc<1x128x128xf16, #shared, #triton_gpu.shared_memory> + %2 = triton_gpu.convert_layout %arg1 : tensor<1x128x128xf16, #blocked> -> tensor<1x128x128xf16, #mma> + %3 = triton_gpu.local_load %1 : !tt.memdesc<1x128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<1x128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + tt.return + } +} + +// ----- + +// Check that optimization triggers with custom LDS limit and do not triggers with default one +// CHECK-LABEL: alloc_convert_32k_limit +// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma +// CHECK: %2 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK-32KLIMIT-LABEL: alloc_convert_32k_limit +// CHECK-32KLIMIT: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared +// CHECK-32KLIMIT: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1 +// CHECK-32KLIMIT: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma +// CHECK-32KLIMIT: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @alloc_convert_32k_limit(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<64x128xf16, #blocked>) attributes {noinline = false} { + %1 = triton_gpu.local_alloc %arg0 : (tensor<64x128xf16, #blocked>) -> !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> + %2 = triton_gpu.convert_layout %arg1 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #mma> + %3 = triton_gpu.local_load %1 : !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, kWidth = 4, parent = #mma}>> + tt.return + } +} + +// ----- + +// Check that optimization correctly handles LDS shortcut (see #mma2 -> #dotop2 conversion) +// CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +// CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}> +// CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +// CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +// CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> + +// CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}}) +// CHECK: [[ALLOC:%[0-9]+]] = triton_gpu.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !tt.memdesc<128x128xf16, [[SHARED]], #triton_gpu.shared_memory> +// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = triton_gpu.convert_layout [[ARG_1]] : tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]> +// CHECK: [[CONVERT_1:%[0-9]+]] = triton_gpu.convert_layout [[INTERMEDIATE_CONV]] : tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]> +// CHECK: [[CONVERT_2:%[0-9]+]] = triton_gpu.convert_layout [[ARG_2]] : tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>> +// CHECK: [[LOAD:%[0-9]+]] = triton_gpu.local_load [[ALLOC]] : !tt.memdesc<128x128xf16, [[SHARED]], #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>> +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma1 = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}> +#mma2 = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}> +#dotop1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma1, kWidth=4}> +#dotop2 = #triton_gpu.dot_op<{opIdx=0, parent=#mma2, kWidth=4}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @mfma_dot_shortcut(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>, %arg2: tensor<256x128xf16, #mma2>) attributes {noinline = false} { + %alloc = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> + %convert_1 = triton_gpu.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma1> + %convert_2 = triton_gpu.convert_layout %arg2 : tensor<256x128xf16, #mma2> -> tensor<256x128xf16, #dotop2> + %load = triton_gpu.local_load %alloc : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf16, #dotop1> + tt.return + } +} + +// ----- + +// Checks that optimization do not crash on 1d tensor +// CHECK-LABEL: convert_1d +// CHECK: triton_gpu.local_alloc +// CHECK-NEXT: triton_gpu.convert_layout +// CHECK-NEXT: triton_gpu.local_load +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @convert_1d(%arg0: tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>, %arg1: tensor<128x128xf32, #mma>) attributes {noinline = false} { + %alloc = triton_gpu.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory> + %1 = triton_gpu.convert_layout %arg0 : tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked> + %load = triton_gpu.local_load %alloc : !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma> + tt.return + } +} diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index 565fe0aba..9422bb0f8 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -4,7 +4,7 @@ // CHECK-LABEL: @test_canonicalize_convert_view // CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32 // CHECK-NOT: triton_gpu.convert_layout -// CHECK: %[[V:.+]] = tt.reshape %[[ARG]] {allow_reorder = true} +// CHECK: %[[V:.+]] = tt.reshape %[[ARG]] allow_reorder // CHECK: tt.return %[[V]] #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> @@ -13,7 +13,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { %c = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked2> - %r = tt.reshape %c {allow_reorder = true} : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1> + %r = tt.reshape %c allow_reorder : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1> tt.return %r : tensor<4096xf32, #blocked1> } } // end module @@ -25,7 +25,7 @@ tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> // CHECK-LABEL: @test_canonicalize_convert_expensive_view // CHECK-SAME: (%[[ARG:.+]]: tensor<256x16xf32 // CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[ARG]] -// CHECK: %[[V:.+]] = tt.reshape %[[C]] {allow_reorder = true} +// CHECK: %[[V:.+]] = tt.reshape %[[C]] allow_reorder // CHECK: tt.return %[[V]] #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> @@ -33,7 +33,7 @@ tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { %c = triton_gpu.convert_layout %arg0 : tensor<256x16xf32, #blocked0> -> tensor<256x16xf32, #blocked2> - %r = tt.reshape %c {allow_reorder = true} : tensor<256x16xf32, #blocked2> -> tensor<4096xf32, #blocked1> + %r = tt.reshape %c allow_reorder : tensor<256x16xf32, #blocked2> -> tensor<4096xf32, #blocked1> tt.return %r : tensor<4096xf32, #blocked1> } } // end module @@ -57,3 +57,79 @@ tt.func @test_canonicalize_convert_histogram(%arg0: tensor<256xi32, #blocked1>) tt.return %2 : tensor<512xi32, #blocked2> } } // end module + +// ----- + +// CHECK-LABEL: @test_canonicalize_convert_local_load +// CHECK-NOT: gpu.barrier +// CHECK: %[[V:.+]] = triton_gpu.local_load +// CHECK-NEXT: gpu.barrier +// CHECK-NEXT: tt.return %[[V]] + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.compute-capability" = 80} { +tt.func @test_canonicalize_convert_local_load() -> tensor<256xi32, #blocked1> { + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<256xi32, #shared, mutable> + %1 = triton_gpu.local_load %0 : !tt.memdesc<256xi32, #shared, mutable> -> tensor<256xi32, #blocked> + gpu.barrier + %2 = triton_gpu.convert_layout %1 : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked1> + tt.return %2 : tensor<256xi32, #blocked1> +} +} // end module + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: local_alloc_nofold1 + tt.func @local_alloc_nofold1(%arg0: tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> { + // CHECK: %[[ARG:.+]] = triton_gpu.local_alloc + // CHECK-NEXT: %[[ARG2:.+]] = triton_gpu.local_load %[[ARG]] + // CHECK-NEXT: %[[ARG3:.+]] = triton_gpu.local_alloc %[[ARG2]] + // CHECK-NEXT: tt.return %[[ARG3]] + %0 = triton_gpu.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #blocked> + %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> + tt.return %2 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> + } +} // end module + + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared1 = #triton_gpu.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: local_alloc_nofold2 + tt.func @local_alloc_nofold2(%arg0: tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared1, #triton_gpu.shared_memory> { + // CHECK: %[[ARG:.+]] = triton_gpu.local_alloc + // CHECK-NEXT: %[[ARG2:.+]] = triton_gpu.local_load %[[ARG]] + // CHECK-NEXT: %[[ARG3:.+]] = triton_gpu.local_alloc %[[ARG2]] + // CHECK-NEXT: tt.return %[[ARG3]] + %0 = triton_gpu.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> + %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #blocked> + %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared1, #triton_gpu.shared_memory> + tt.return %2 : !tt.memdesc<16x16xf16, #shared1, #triton_gpu.shared_memory> + } +} // end module + + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + tt.func @local_alloc_fold(%arg0: tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> { + // CHECK-LABEL: local_alloc_fold + // CHECK-NEXT: %[[ARG:.+]] = triton_gpu.local_alloc + // CHECK-NEXT: tt.return %[[ARG]] + %0 = triton_gpu.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> + %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #blocked> + %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> + tt.return %2 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> + } +} // end module diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index cf93c37b8..5d35f43e9 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -131,3 +131,32 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war tt.return } } + +// ----- + +// COM: Reproducer for issue #5122 +// CHECK-LABEL: @test_5122 +module { + tt.func public @test_5122(%arg0: i32) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %0 = arith.cmpi sgt, %arg0, %c1_i32 : i32 + scf.if %0 { + %1 = scf.if %0 -> (i32) { + scf.yield %c1_i32 : i32 + } else { + scf.yield %c1_i32 : i32 + } + %2 = arith.cmpi sgt, %1, %c1_i32 : i32 + %3 = scf.if %2 -> (i32) { + scf.yield %c1_i32 : i32 + } else { + scf.yield %c1_i32 : i32 + } + %4 = scf.for %arg1 = %1 to %1 step %c1_i32 iter_args(%arg2 = %3) -> (i32) : i32 { + %5 = arith.addi %arg2, %c1_i32 : i32 + scf.yield %5 : i32 + } + } + tt.return + } +} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index b951f5041..3b727b4e9 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -1475,7 +1475,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %23 = tt.splat %arg0 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> %24 = tt.splat %arg1 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> %25:3 = scf.for %arg7 = %c0_i32 to %22 step %c1_i32 iter_args(%arg8 = %cst_1, %arg9 = %11, %arg10 = %20) -> (f32, tensor<128x32xi32, #blocked1>, tensor<32x128xi32, #blocked>) : i32 { - tt.print "a_offsets: " { hex = false } : %arg9 : tensor<128x32xi32, #blocked1> + tt.print "a_offsets: " { hex = false, isSigned = array } : %arg9 : tensor<128x32xi32, #blocked1> %27 = tt.addptr %23, %arg9 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> %28 = triton_gpu.convert_layout %27 : tensor<128x32x!tt.ptr, #blocked1> -> tensor<128x32x!tt.ptr, #blocked4> %29 = tt.load %28 : tensor<128x32x!tt.ptr, #blocked4> @@ -1936,7 +1936,8 @@ tt.func public @yield_outside_loop2(%arg0: i32, %arg1: i32) -> (i32, i32) { // ----- -// Check that we handle corner cases when hoisting convert on top of extf. For complex slices we may hoist convert on top of extf while the source of extf has multiple uses in the slice. +// Check that we handle corner cases when hoisting conversions on top of extf because conversion operations on a smaller type are faster. +// For complex slices we may hoist convert on top of extf while the source of extf has multiple uses in the slice. // In this case we want to make sure we don't replace other uses of extf source. #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -1946,7 +1947,7 @@ tt.func public @yield_outside_loop2(%arg0: i32, %arg1: i32) -> (i32, i32) { #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> #shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { -// CHECK: [[$BLOCKED:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: [[$BLOCKED:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> // CHECK: [[$MMA:#.*]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> // CHECK-LABEL: @hoist_convert_above_extf_and_remat @@ -2014,8 +2015,10 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %27 = tt.load %26 : tensor<1x256x!tt.ptr, #blocked2> %28 = tt.broadcast %27 : tensor<1x256xf16, #blocked2> -> tensor<32x256xf16, #blocked2> %29 = arith.addf %20, %28 : tensor<32x256xf16, #blocked2> -// CHECK: %[[C:.+]] = triton_gpu.convert_layout %29 : tensor<32x256xf16, [[$MMA]]> -> tensor<32x256xf16, [[$BLOCKED]]> -// CHECK: arith.extf %[[C]] : tensor<32x256xf16, [[$BLOCKED]]> to tensor<32x256xf32, [[$BLOCKED]]> +// CHECK: %[[A:.+]] = triton_gpu.convert_layout {{.*}} : tensor<1x256xf16, [[$BLOCKED]]> -> tensor<1x256xf16, [[$MMA]]> +// CHECK: %[[B:.+]] = tt.broadcast %[[A]] +// CHECK: %[[C:.+]] = arith.addf %[[B:.+]], {{.*}} +// CHECK: arith.extf %[[C]] : tensor<32x256xf16, [[$MMA]]> to tensor<32x256xf32, [[$MMA]]> %30 = arith.extf %29 : tensor<32x256xf16, #blocked2> to tensor<32x256xf32, #blocked2> %31 = "tt.reduce"(%30) <{axis = 1 : i32}> ({ ^bb0(%arg7: f32, %arg8: f32): @@ -2094,7 +2097,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.func public @reshape_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<32xf32, #blocked3> { // CHECK-NOT: triton_gpu.convert_layout %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> - %b = tt.reshape %a {allow_reorder = false} : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> + %b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> %c = triton_gpu.convert_layout %b : tensor<32xf32, #blocked2> -> tensor<32xf32, #blocked3> tt.return %c : tensor<32xf32, #blocked3> } @@ -2113,7 +2116,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: tt.reshape // CHECK: triton_gpu.convert_layout %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> - %b = tt.reshape %a {allow_reorder = false} : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> + %b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> tt.return %b : tensor<32xf32, #blocked2> } } @@ -2130,7 +2133,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-NOT: triton_gpu.convert_layout // CHECK: arith.truncf // CHECK: triton_gpu.convert_layout - %a = tt.reshape %arg0 {allow_reorder = true, efficient_layout} : tensor<16x2xf32, #blocked> -> tensor<32xf32, #blocked1> + %a = tt.reshape %arg0 allow_reorder efficient_layout : tensor<16x2xf32, #blocked> -> tensor<32xf32, #blocked1> %b = triton_gpu.convert_layout %a : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked2> %c = arith.truncf %b : tensor<32xf32, #blocked2> to tensor<32xf16, #blocked2> tt.return %c : tensor<32xf16, #blocked2> @@ -2304,7 +2307,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : tt.func @assertop(%ptr: tensor<1024x!tt.ptr, #blocked>) { %0 = tt.load %ptr : tensor<1024x!tt.ptr, #blocked> %1 = triton_gpu.convert_layout %0 : tensor<1024xi1, #blocked> -> tensor<1024xi1, #blocked1> - tt.assert %1, "cond must be true ", "unknown", "unknown", 0 : tensor<1024xi1, #blocked1> + tt.assert %1, "cond must be true " : tensor<1024xi1, #blocked1> tt.return } } @@ -2316,11 +2319,11 @@ tt.func @assertop(%ptr: tensor<1024x!tt.ptr, #blocked>) { #blocked3 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @dot_wait_propagate - tt.func public @dot_wait_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<16x2xf32, #blocked> { + // CHECK-LABEL: @warp_group_dot_wait_propagate + tt.func public @warp_group_dot_wait_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<16x2xf32, #blocked> { // CHECK-NOT: triton_gpu.convert_layout %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> - %b = triton_nvidia_gpu.dot_wait %a {pendings = 0 : i32} : tensor<16x2xf32, #blocked1> + %b = triton_nvidia_gpu.warp_group_dot_wait %a {pendings = 0 : i32} : tensor<16x2xf32, #blocked1> %c = triton_gpu.convert_layout %b : tensor<16x2xf32, #blocked1> -> tensor<16x2xf32, #blocked> tt.return %c : tensor<16x2xf32, #blocked> } @@ -2468,3 +2471,217 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return %5#1, %5#2 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> } } + +// ----- + +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked7 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // Regression test: + // The while loop use the result of the for loop as an argument. + // When propagating the layout, we should only "forward" propagate the layout to the argument and the result of the while loop + // CHECK-LABEL: @while_use_for + tt.func public @while_use_for(%arg0: !tt.ptr, %arg3: !tt.ptr, %arg6: i32) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %c0_i1 = arith.constant 1 : i1 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #blocked1> + %1000 = tt.splat %arg0 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked2> + %1001 = tt.splat %arg0 : !tt.ptr -> tensor<64x128x!tt.ptr, #blocked1> + %1002 = tt.splat %arg0 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %1003 = tt.splat %arg3 : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %74 = tt.load %1000 : tensor<256x64x!tt.ptr, #blocked2> + %67:2 = scf.for %arg11 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg12 = %cst_0, %arg14 = %1001) -> (tensor<256x128xf32, #blocked1>, tensor<64x128x!tt.ptr, #blocked1>) : i32 { + %76 = tt.load %arg14 : tensor<64x128x!tt.ptr, #blocked1> + %78 = triton_gpu.convert_layout %74 : tensor<256x64xf16, #blocked2> -> tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked7}>> + %79 = triton_gpu.convert_layout %76 : tensor<64x128xf16, #blocked1> -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked7}>> + %80 = triton_gpu.convert_layout %arg12 : tensor<256x128xf32, #blocked1> -> tensor<256x128xf32, #blocked7> + %81 = tt.dot %78, %79, %80, inputPrecision = tf32 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked7}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked7}>> -> tensor<256x128xf32, #blocked7> + %82 = triton_gpu.convert_layout %81 : tensor<256x128xf32, #blocked7> -> tensor<256x128xf32, #blocked1> + scf.yield %82, %arg14 : tensor<256x128xf32, #blocked1>, tensor<64x128x!tt.ptr, #blocked1> + } + %68:2 = scf.while (%arg11 = %67#0, %arg12 = %c1_i32) : (tensor<256x128xf32, #blocked1>, i32) -> (tensor<256x128xf32, #blocked1>, i32) { + scf.condition(%c0_i1) %arg11, %arg12 : tensor<256x128xf32, #blocked1>, i32 + } do { + ^bb0(%arg11: tensor<256x128xf32, #blocked1>, %arg12: i32): + %80 = triton_gpu.convert_layout %1003 : tensor<256x128x!tt.ptr, #blocked1> -> tensor<256x128x!tt.ptr, #blocked1> + %81 = tt.load %80 : tensor<256x128x!tt.ptr, #blocked1> + %82 = arith.addf %arg11, %81 : tensor<256x128xf32, #blocked1> + %83 = arith.addi %arg12, %c1_i32 : i32 + scf.yield %82, %83 : tensor<256x128xf32, #blocked1>, i32 + } + %69 = arith.truncf %68#0 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1> + %71 = triton_gpu.convert_layout %69 : tensor<256x128xf16, #blocked1> -> tensor<256x128xf16, #blocked1> + tt.store %1002, %71 : tensor<256x128x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- +// Minimized reproducer for https://github.com/pytorch/pytorch/issues/130101 +// Check that backward rematerialization bails out when the same tensor requires two different layouts + +// CHECK-LABEL: double_remat +// CHECK: %[[res:.*]] = triton_gpu.convert_layout +// CHECK-NEXT: tt.return %[[res]] +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 2], order = [2, 1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:86", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @double_remat() -> tensor<1x256xi32, #blocked> attributes {noinline = false} { + %cst = arith.constant dense<0> : tensor<1x256xi32, #blocked1> + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked2}>}>> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked2}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked2}>> + %2 = tt.expand_dims %1 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked2}>> -> tensor<1x2x1xi32, #blocked2> + %3 = tt.broadcast %2 : tensor<1x2x1xi32, #blocked2> -> tensor<1x2x128xi32, #blocked2> + %4 = tt.reshape %3 : tensor<1x2x128xi32, #blocked2> -> tensor<1x256xi32, #blocked1> + %5 = tt.broadcast %2 : tensor<1x2x1xi32, #blocked2> -> tensor<2x2x64xi32, #blocked2> + %6 = tt.reshape %5 : tensor<2x2x64xi32, #blocked2> -> tensor<1x256xi32, #blocked1> + %7 = arith.cmpi ne, %4, %cst : tensor<1x256xi32, #blocked1> + %8 = arith.select %7, %6, %cst : tensor<1x256xi1, #blocked1>, tensor<1x256xi32, #blocked1> + %9 = triton_gpu.convert_layout %8 : tensor<1x256xi32, #blocked1> -> tensor<1x256xi32, #blocked> + tt.return %9 : tensor<1x256xi32, #blocked> + } +} + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @if_condition_not_dead_inside_loop + // CHECK: scf.if + // CHECK-NOT: convert_layout + tt.func public @if_condition_not_dead_inside_loop(%arg0: i32) -> (tensor<32xf32, #blocked>, tensor<32xf32, #blocked>) { + %true = arith.constant true + %cst = arith.constant dense<1.000000e+00> : tensor<32xf32, #blocked1> + %cst_0 = arith.constant dense<2.000000e+00> : tensor<32xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %c4096_i32 = arith.constant 4096 : i32 + %1:3 = scf.for %arg10 = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%arg1 = %cst, %arg3 = %cst_0, %arg4 = %true) -> (tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, i1) : i32 { + %3:2 = scf.if %arg4 -> (tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>) { + scf.yield %cst, %cst_0 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> + } else { + %4 = arith.addf %arg1, %cst : tensor<32xf32, #blocked1> + %5 = triton_gpu.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + %6 = arith.mulf %arg3, %5 : tensor<32xf32, #blocked> + scf.yield %4, %6 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked> + } + %119 = arith.cmpi eq, %arg10, %arg0 : i32 + scf.yield %3#0, %3#1, %119 : tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, i1 + } + %7 = triton_gpu.convert_layout %1#0 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> + tt.return %7, %1#1 : tensor<32xf32, #blocked>, tensor<32xf32, #blocked> + } +} + +// ----- +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @dot_wait + tt.func public @dot_wait(%arg0: tensor<64x64xf32, #mma>, %arg1: tensor<64x128xf32, #mma1>) -> (tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1>) { + %0:2 = triton_nvidia_gpu.warp_group_dot_wait %arg0, %arg1 {pendings = 0 : i32} : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1> + tt.return %0#0, %0#1 : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1> + // CHECK: %[[W:.+]]:2 = triton_nvidia_gpu.warp_group_dot_wait + // CHECK: tt.return %[[W]]#0, %[[W]]#1 : tensor<64x64xf32, #mma>, tensor<64x128xf32, #mma1> + } +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 64, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [1, 32, 1], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @split_propagation + // CHECK-SAME: (%[[ARG:.+]]: tensor<128x64x2xf32 + // CHECK: %[[S:.+]], %{{.+}} = tt.split %[[ARG]] + // CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[S]] + // CHECK: tt.return %[[C]] + tt.func public @split_propagation(%arg0: tensor<128x64x2xf32, #blocked>) -> tensor<128x64xf32, #blocked1> { + %0 = triton_gpu.convert_layout %arg0 : tensor<128x64x2xf32, #blocked> -> tensor<128x64x2xf32, #blocked2> + %outLHS, %outRHS = tt.split %0 : tensor<128x64x2xf32, #blocked2> -> tensor<128x64xf32, #blocked1> + tt.return %outLHS : tensor<128x64xf32, #blocked1> + } +} + +// ----- + +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#CL = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> + +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { + // CHECK-LABEL: matmul_add + tt.func @matmul_add(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %C : !tt.ptr) { + %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %c_ptr_init = tt.splat %C : !tt.ptr -> tensor<128x128x!tt.ptr, #CL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #CL> + %cst = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %100:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #CL>) { + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + %c = tt.dot %a, %b, %cst : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %t = triton_gpu.convert_layout %c : tensor<128x128xf32, #C> -> tensor<128x128xf32, #CL> + // CHECK: %[[T0:.*]] = tt.dot + // CHECK: arith.addf %{{.*}}, %[[T0]] : tensor<128x128xf32, #mma> + %t2 = arith.addf %prev_c, %t : tensor<128x128xf32, #CL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: scf.yield + scf.yield %next_a_ptr, %next_b_ptr, %t2 : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #CL> + } + + // CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked + tt.store %c_ptr_init, %100#2 : tensor<128x128x!tt.ptr, #CL> + tt.return + } +} + +// ----- + +// Minimized reproducer for compiler crash during remove layouts conversions pass: +// If dot result transformed into tensor with shape smaller than one MFMA instruction size, it triggers various asserts. +// This is a smoke test that checks that compiler do not crash. +// +// CHECK-LABEL: small_tensor_mfma + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 1], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}> +#mma1 = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}> +module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @small_tensor_mfma(%arg0: !tt.ptr) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_1 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_2 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + %cst_3 = arith.constant dense<1.230000e+02> : tensor<32x16xf32, #mma1> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %1 = triton_gpu.convert_layout %0 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> + %2 = "tt.reduce" (%1) ({ + ^bb0(%arg1: f32, %arg2: f32): + %3 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %3 : f32 + }) {axis = 1 : i32} : (tensor<32x32xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %4 = tt.expand_dims %2 {axis = 1 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xf32, #blocked> + %5 = tt.broadcast %4 : tensor<32x1xf32, #blocked> -> tensor<32x16xf32, #blocked> + %6 = triton_gpu.convert_layout %5 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> + %7 = tt.dot %cst_2, %6, %cst_3 : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<32x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<32x16xf32, #mma1> + %addr = tt.splat %arg0 : !tt.ptr -> tensor<32x16x!tt.ptr, #blocked> + %8 = triton_gpu.convert_layout %7 : tensor<32x16xf32, #mma1> -> tensor<32x16xf32, #blocked> + tt.store %addr, %8 : tensor<32x16x!tt.ptr, #blocked> + tt.return + } +} diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index ed24e5f58..82fc1ddf7 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -165,10 +165,10 @@ tt.func @update_kwidth_slice( module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A // CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: tt.dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> +// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdesc<64x64xf16, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf16, #mma>) -> !tt.memdesc<128x64xf16, #shared1> - %r = tt.dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } @@ -181,10 +181,10 @@ tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdes module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A_fp8 // CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: tt.dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> +// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !tt.memdesc<64x64xf8E5M2, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !tt.memdesc<128x64xf8E5M2, #shared1> - %r = tt.dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf8E5M2, #shared1> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> + %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf8E5M2, #shared1> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir index bf15adbdb..9ed3646d9 100644 --- a/test/TritonGPU/fence-inserstion.mlir +++ b/test/TritonGPU/fence-inserstion.mlir @@ -12,7 +12,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %0 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared> %1 = triton_gpu.local_alloc %arg1 : (tensor<128x64xf16, #blocked2>) -> !tt.memdesc<128x64xf16, #shared1> // CHECK: triton_nvidia_gpu.fence_async_shared - %2 = tt.dot %0, %1, %cst : !tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %2 = triton_nvidia_gpu.warp_group_dot %0, %1, %cst : !tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> tt.return } } @@ -36,10 +36,10 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: triton_nvidia_gpu.fence_async_shared // CHECK: scf.for // CHECK-NOT: triton_nvidia_gpu.fence_async_shared - // CHECK: tt.dot + // CHECK: triton_nvidia_gpu.warp_group_dot scf.for %iv0 = %c0_i32 to %c64_i32 step %c32_i32 : i32 { scf.for %iv1 = %c0_i32 to %c64_i32 step %c32_i32 : i32 { - %2 = tt.dot %0, %1, %cst : !tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %2 = triton_nvidia_gpu.warp_group_dot %0, %1, %cst : !tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> } } tt.return diff --git a/test/TritonGPU/invalid-attributes.mlir b/test/TritonGPU/invalid-attributes.mlir index abf18381f..c8b3c2ef6 100644 --- a/test/TritonGPU/invalid-attributes.mlir +++ b/test/TritonGPU/invalid-attributes.mlir @@ -36,16 +36,26 @@ // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter supports only 16 for WMMA parent}} -#wmma = #triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}> +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} +#wmma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> #dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #wmma}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter supports only 16 for WMMA parent}} -#wmma = #triton_gpu.amd_wmma<{warpsPerCTA = [1, 4]}> +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} +#wmma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}> #dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #wmma, kWidth = 8}> +// ----- +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} +#wmma = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> +#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #wmma, kWidth = 16}> + +// ----- +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}} +#wmma = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}> +#dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #wmma, kWidth = 4}> + // ----- // expected-error@+1 {{major version must be in the [0, 3] range}} diff --git a/test/TritonGPU/invalid.mlir b/test/TritonGPU/invalid.mlir index 53cfd3d90..f9e265f3e 100644 --- a/test/TritonGPU/invalid.mlir +++ b/test/TritonGPU/invalid.mlir @@ -71,6 +71,19 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // ----- +#mma0 = #triton_gpu.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> +#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> +#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +module attributes {"triton_gpu.num-warps" = 1 : i32} { + tt.func @convert_dot(%A: tensor<16x16xf16, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32>) { + // expected-error@+1 {{miss encoding of C operand}} + %D = tt.dot %A, %B, %C : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32> + tt.return + } +} + +// ----- + #mma0 = #triton_gpu.nvidia_mma<{versionMajor=2, warpsPerCTA=[1,1]}> #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=1}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> diff --git a/test/TritonGPU/loop-pipeline-cuda.mlir b/test/TritonGPU/loop-pipeline-cuda.mlir new file mode 100644 index 000000000..b6610c0a6 --- /dev/null +++ b/test/TritonGPU/loop-pipeline-cuda.mlir @@ -0,0 +1,199 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +// CHECK-LABEL: tt.func @load_two_users + tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { + %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} + // CHECK: scf.for + // CHECK: tt.dot + // CHECK: tt.dot + // CHECK: triton_gpu.async_copy_global_to_local + // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} + // CHECK: scf.yield + // CHECK: triton_gpu.async_wait {num = 0 : i32} + + %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { + %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> + %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> + %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> + %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> + %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } + tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } +} + +// ----- + +// CHECK-NOT: triton_gpu.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c64_i32 : i32 + %2 = tt.get_program_id y : i32 + %3 = tt.load %arg3 : !tt.ptr + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %5 = tt.splat %1 : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %6 = arith.addi %5, %4 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %8 = tt.splat %3 : i64 -> tensor<64x1xi64, #blocked> + %9 = arith.extsi %7 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %10 = arith.addi %8, %9 : tensor<64x1xi64, #blocked> + %11 = arith.extsi %arg5 : i32 to i64 + %12 = tt.splat %11 : i64 -> tensor<64x1xi64, #blocked> + %13 = arith.muli %10, %12 : tensor<64x1xi64, #blocked> + %14 = arith.muli %2, %arg5 : i32 + %15 = arith.extsi %14 : i32 to i64 + %16 = tt.splat %15 : i64 -> tensor<64x1xi64, #blocked> + %17 = arith.addi %13, %16 : tensor<64x1xi64, #blocked> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %22 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked> + %23 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> + %24 = arith.muli %20, %22 : tensor<1x64xi32, #blocked> + %25 = arith.muli %21, %23 : tensor<1x64xi32, #blocked1> + %26 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x64xi64, #blocked> + %27 = arith.extsi %24 : tensor<1x64xi32, #blocked> to tensor<1x64xi64, #blocked> + %28 = arith.extsi %25 : tensor<1x64xi32, #blocked1> to tensor<1x64xi64, #blocked1> + %29 = tt.broadcast %27 : tensor<1x64xi64, #blocked> -> tensor<64x64xi64, #blocked> + %30 = arith.addi %26, %29 : tensor<64x64xi64, #blocked> + %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %33 = tt.splat %3 : i64 -> tensor<32x1xi64, #blocked1> + %34 = arith.extsi %32 : tensor<32x1xi32, #blocked1> to tensor<32x1xi64, #blocked1> + %35 = arith.addi %33, %34 : tensor<32x1xi64, #blocked1> + %36 = tt.splat %11 : i64 -> tensor<32x1xi64, #blocked1> + %37 = arith.muli %35, %36 : tensor<32x1xi64, #blocked1> + %38 = tt.splat %15 : i64 -> tensor<32x1xi64, #blocked1> + %39 = arith.addi %37, %38 : tensor<32x1xi64, #blocked1> + %40 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x64xi64, #blocked1> + %41 = tt.broadcast %28 : tensor<1x64xi64, #blocked1> -> tensor<32x64xi64, #blocked1> + %42 = arith.addi %40, %41 : tensor<32x64xi64, #blocked1> + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %47 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked1> + %48 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked> + %49 = arith.muli %45, %47 : tensor<1x32xi32, #blocked1> + %50 = arith.muli %46, %48 : tensor<1x32xi32, #blocked> + %51 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x32xi64, #blocked1> + %52 = arith.extsi %49 : tensor<1x32xi32, #blocked1> to tensor<1x32xi64, #blocked1> + %53 = arith.extsi %50 : tensor<1x32xi32, #blocked> to tensor<1x32xi64, #blocked> + %54 = tt.broadcast %52 : tensor<1x32xi64, #blocked1> -> tensor<32x32xi64, #blocked1> + %55 = arith.addi %51, %54 : tensor<32x32xi64, #blocked1> + %56 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked> + %57 = tt.addptr %56, %30 : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi64, #blocked> + %58 = tt.splat %arg1 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked1> + %59 = tt.addptr %58, %42 : tensor<32x64x!tt.ptr, #blocked1>, tensor<32x64xi64, #blocked1> + %60 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> + %61 = tt.addptr %60, %55 : tensor<32x32x!tt.ptr, #blocked1>, tensor<32x32xi64, #blocked1> + %62 = tt.load %57 : tensor<64x64x!tt.ptr, #blocked> + %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { + %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> + %71 = triton_gpu.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> + %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> + %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> + %77 = triton_gpu.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %78 = triton_gpu.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + scf.yield %79 : tensor<64x32xf32, #mma> + } + %64 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x32xi64, #blocked> + %65 = tt.broadcast %53 : tensor<1x32xi64, #blocked> -> tensor<64x32xi64, #blocked> + %66 = arith.addi %64, %65 : tensor<64x32xi64, #blocked> + %67 = tt.splat %arg4 : !tt.ptr -> tensor<64x32x!tt.ptr, #blocked> + %68 = tt.addptr %67, %66 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi64, #blocked> + %69 = triton_gpu.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> + tt.store %68, %69 : tensor<64x32x!tt.ptr, #blocked> + tt.return + } +} // end module + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +// CHECK-LABEL: @matmul_tma +// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3x128x64xf16, #{{.+}}, #triton_gpu.shared_memory, mutable> +// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3x64x256xf16, #{{.+}}, #triton_gpu.shared_memory, mutable> +// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3xi64, #{{.+}}, #triton_gpu.shared_memory, mutable> +// CHECK-COUNT-3: triton_nvidia_gpu.init_barrier +// CHECK-COUNT-4: triton_nvidia_gpu.async_tma_copy_global_to_local +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.wait_barrier +// CHECK-NOT: triton_nvidia_gpu.wait_barrier +// CHECK-COUNT-2: triton_nvidia_gpu.async_tma_copy_global_to_local +// CHECK: scf.yield + tt.func public @matmul_tma(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x256xf32, #mma> { + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0:2 = scf.for %arg3 = %c0_i32 to %c256_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 { + %1 = tt.experimental_descriptor_load %arg0[%c0_i32, %arg5] : !tt.ptr -> tensor<128x64xf16, #blocked> + %2 = triton_gpu.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %3 = tt.experimental_descriptor_load %arg1[%arg5, %c0_i32] : !tt.ptr -> tensor<64x256xf16, #blocked1> + %4 = triton_gpu.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> + %5 = triton_nvidia_gpu.warp_group_dot %2, %4, %arg4 { inputPrecision = 0 : i32 } : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> + %6 = arith.addi %arg5, %c64_i32 : i32 + scf.yield %5, %6 : tensor<128x256xf32, #mma>, i32 + } + tt.return %0#0 : tensor<128x256xf32, #mma> + } +} diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir new file mode 100644 index 000000000..3abcc581b --- /dev/null +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -0,0 +1,265 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @load_two_users + tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { + %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %c0_i32 = arith.constant 0 : i32 + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: triton_gpu.local_store + // CHECK: scf.for + // CHECK: tt.dot + // CHECK: tt.dot + // CHECK: tt.load + // CHECK: triton_gpu.local_store + // CHECK: scf.yield + %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { + %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> + %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> + %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> + %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> + %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory, mutable> + %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> + scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } + tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> + } +} + +// ----- + +// CHECK-LABEL: tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de +// CHECK-NOT: triton_gpu.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> + %c64_i32 = arith.constant 64 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c64_i32 : i32 + %2 = tt.get_program_id y : i32 + %3 = tt.load %arg3 : !tt.ptr + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %5 = tt.splat %1 : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %6 = arith.addi %5, %4 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %8 = tt.splat %3 : i64 -> tensor<64x1xi64, #blocked> + %9 = arith.extsi %7 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %10 = arith.addi %8, %9 : tensor<64x1xi64, #blocked> + %11 = arith.extsi %arg5 : i32 to i64 + %12 = tt.splat %11 : i64 -> tensor<64x1xi64, #blocked> + %13 = arith.muli %10, %12 : tensor<64x1xi64, #blocked> + %14 = arith.muli %2, %arg5 : i32 + %15 = arith.extsi %14 : i32 to i64 + %16 = tt.splat %15 : i64 -> tensor<64x1xi64, #blocked> + %17 = arith.addi %13, %16 : tensor<64x1xi64, #blocked> + %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %22 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked> + %23 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> + %24 = arith.muli %20, %22 : tensor<1x64xi32, #blocked> + %25 = arith.muli %21, %23 : tensor<1x64xi32, #blocked1> + %26 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x64xi64, #blocked> + %27 = arith.extsi %24 : tensor<1x64xi32, #blocked> to tensor<1x64xi64, #blocked> + %28 = arith.extsi %25 : tensor<1x64xi32, #blocked1> to tensor<1x64xi64, #blocked1> + %29 = tt.broadcast %27 : tensor<1x64xi64, #blocked> -> tensor<64x64xi64, #blocked> + %30 = arith.addi %26, %29 : tensor<64x64xi64, #blocked> + %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> + %33 = tt.splat %3 : i64 -> tensor<32x1xi64, #blocked1> + %34 = arith.extsi %32 : tensor<32x1xi32, #blocked1> to tensor<32x1xi64, #blocked1> + %35 = arith.addi %33, %34 : tensor<32x1xi64, #blocked1> + %36 = tt.splat %11 : i64 -> tensor<32x1xi64, #blocked1> + %37 = arith.muli %35, %36 : tensor<32x1xi64, #blocked1> + %38 = tt.splat %15 : i64 -> tensor<32x1xi64, #blocked1> + %39 = arith.addi %37, %38 : tensor<32x1xi64, #blocked1> + %40 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x64xi64, #blocked1> + %41 = tt.broadcast %28 : tensor<1x64xi64, #blocked1> -> tensor<32x64xi64, #blocked1> + %42 = arith.addi %40, %41 : tensor<32x64xi64, #blocked1> + %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %47 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked1> + %48 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked> + %49 = arith.muli %45, %47 : tensor<1x32xi32, #blocked1> + %50 = arith.muli %46, %48 : tensor<1x32xi32, #blocked> + %51 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x32xi64, #blocked1> + %52 = arith.extsi %49 : tensor<1x32xi32, #blocked1> to tensor<1x32xi64, #blocked1> + %53 = arith.extsi %50 : tensor<1x32xi32, #blocked> to tensor<1x32xi64, #blocked> + %54 = tt.broadcast %52 : tensor<1x32xi64, #blocked1> -> tensor<32x32xi64, #blocked1> + %55 = arith.addi %51, %54 : tensor<32x32xi64, #blocked1> + %56 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked> + %57 = tt.addptr %56, %30 : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi64, #blocked> + %58 = tt.splat %arg1 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked1> + %59 = tt.addptr %58, %42 : tensor<32x64x!tt.ptr, #blocked1>, tensor<32x64xi64, #blocked1> + %60 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> + %61 = tt.addptr %60, %55 : tensor<32x32x!tt.ptr, #blocked1>, tensor<32x32xi64, #blocked1> + %62 = tt.load %57 : tensor<64x64x!tt.ptr, #blocked> + %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { + %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> + %71 = triton_gpu.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> + %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory, mutable> + %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> + %77 = triton_gpu.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %78 = triton_gpu.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> + scf.yield %79 : tensor<64x32xf32, #mma> + } + %64 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x32xi64, #blocked> + %65 = tt.broadcast %53 : tensor<1x32xi64, #blocked> -> tensor<64x32xi64, #blocked> + %66 = arith.addi %64, %65 : tensor<64x32xi64, #blocked> + %67 = tt.splat %arg4 : !tt.ptr -> tensor<64x32x!tt.ptr, #blocked> + %68 = tt.addptr %67, %66 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi64, #blocked> + %69 = triton_gpu.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> + tt.store %68, %69 : tensor<64x32x!tt.ptr, #blocked> + tt.return + } +} // end module + +// ----- + +// CHECK-LABEL: tt.func public @add_barrier_kernel +// CHECK: tt.load +// CHECK: scf.for +// CHECK: gpu.barrier +// CHECK: tt.store +// CHECK: tt.load +// CHECK: scf.yield +// CHECK: gpu.barrier +// CHECK: tt.store + +#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + tt.func public @add_barrier_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %cval_f32 = arith.constant dense<0.3> : tensor<1024xf32, #blocked> + %c1016800_i32 = arith.constant 1016800 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + scf.for %arg4 = %c0_i32 to %arg3 step %c1024_i32 : i32 { + %7 = arith.addi %1, %arg4 : i32 + %8 = tt.splat %7 : i32 -> tensor<1024xi32, #blocked> + %9 = arith.addi %8, %2 : tensor<1024xi32, #blocked> + %11 = tt.addptr %4, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> + gpu.barrier + %16 = tt.addptr %6, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %15 = arith.addf %12, %cval_f32 : tensor<1024xf32, #blocked> + tt.store %16, %15 : tensor<1024x!tt.ptr, #blocked> + } {tt.num_stages = 2 : i32} + tt.return + } +} // end module + +// ----- + +// CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1] +// CHECK: #triton_gpu.shared<{{.*}} order = [2, 1, 0] +// CHECK-NOT: #triton_gpu.shared<{{.*}} order = [2, 0, 1] + +// CHECK-LABEL: tt.func public @slowest_dim_is_batch +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [4, 1, 16], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @slowest_dim_is_batch(%arg0: tensor<1x512x!tt.ptr, #blocked2>, %arg1: tensor<64x8x32x!tt.ptr, #blocked1>, %arg2: tensor<64x1x32x!tt.ptr, #blocked>) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<64x1x32xf32, #blocked> + %cst_0 = arith.constant dense<512> : tensor<1x512xi32, #blocked2> + %cst_1 = arith.constant dense<128> : tensor<64x8x32xi32, #blocked1> + %c1_i32 = arith.constant 1 : i32 + %c5_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %33:3 = scf.for %arg7 = %c0_i32 to %c5_i32 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %arg0, %arg10 = %arg1) -> (tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr, #blocked2>, tensor<64x8x32x!tt.ptr, #blocked1>) : i32 { + %39 = tt.load %arg9 : tensor<1x512x!tt.ptr, #blocked2> + %40 = tt.load %arg10 : tensor<64x8x32x!tt.ptr, #blocked1> + %41 = tt.reshape %39 allow_reorder : tensor<1x512xf32, #blocked2> -> tensor<64x1x8xf32, #blocked5> + %43 = triton_gpu.convert_layout %41 : tensor<64x1x8xf32, #blocked5> -> tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %44 = triton_gpu.convert_layout %40 : tensor<64x8x32xf32, #blocked1> -> tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %45 = tt.dot %43, %44, %arg8 : tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x1x32xf32, #blocked> + %46 = tt.addptr %arg9, %cst_0 : tensor<1x512x!tt.ptr, #blocked2>, tensor<1x512xi32, #blocked2> + %47 = tt.addptr %arg10, %cst_1 : tensor<64x8x32x!tt.ptr, #blocked1>, tensor<64x8x32xi32, #blocked1> + scf.yield %45, %46, %47 : tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr, #blocked2>, tensor<64x8x32x!tt.ptr, #blocked1> + } + tt.store %arg2, %33#0 : tensor<64x1x32x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// Check that the stream pipeliner updates the resulting memory layout of transpose ops to mutable if immutable local buffers are replaced +// CHECK-LABEL: loop_with_dot_and_transpose +// CHECK: triton_gpu.local_alloc {{.*}}, mutable> +// CHECK: tt.trans {{.*}}, mutable> -> {{.*}}, mutable> + +#blocked = #triton_gpu.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1201", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @loop_with_dot_and_transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: i32, %arg4: tensor<32x32x!tt.ptr, #blocked1>, %arg5: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> + %0 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %cst) -> (tensor<32x32xf32, #blocked>) : i32 { + %2 = tt.load %arg4 : tensor<32x32x!tt.ptr, #blocked1> + %3 = triton_gpu.local_alloc %2 : (tensor<32x32xf32, #blocked1>) -> !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory> + %4 = tt.trans %3 {order = array} : !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<32x32xf32, #shared1, #triton_gpu.shared_memory> + %5 = triton_gpu.local_load %4 : !tt.memdesc<32x32xf32, #shared1, #triton_gpu.shared_memory> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %6 = triton_gpu.convert_layout %2 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %7 = tt.dot %6, %5, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf32, #blocked> + scf.yield %7 : tensor<32x32xf32, #blocked> + } + tt.store %arg5, %0 : tensor<32x32x!tt.ptr, #blocked> + tt.return + } +} diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 48fd5f22e..d391be688 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -1,4 +1,5 @@ // RUN: triton-opt %s -split-input-file -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 | FileCheck %s --check-prefix=CHECK-NOCANON // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 @@ -18,11 +19,11 @@ // CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc // CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[ASUB:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] : !tt.memdesc<2x128x32xf16, #shared, mutable> -> !tt.memdesc<128x32xf16, #shared, mutable> -// CHECK: %[[T_A0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, #shared, mutable> +// CHECK-DAG: %[[ASUB:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> +// CHECK: %[[T_A0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, #shared, #triton_gpu.shared_memory, mutable> // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK-DAG: %[[BSUB:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_B0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, mutable> +// CHECK: %[[T_B0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] // CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]] @@ -303,8 +304,8 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, //// C-HECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_0]] // // C-HECK: %[[MBARRIER_AB_ITER:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][{{.*}}] // // C-HECK: triton_nvidia_gpu.mbarrier_wait %[[MBARRIER_AB_ITER]], {{.*}} -// // C-HECK: triton_nvidia_gpu.dot_async %[[arg_a0]], %[[arg_b0]], {{.*}} -// // C-HECK: triton_nvidia_gpu.dot_wait {{.*}} +// // C-HECK: triton_nvidia_gpu.warp_group_dot %[[arg_a0]], %[[arg_b0]], {{.*}} +// // C-HECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} // // C-HECK: %[[EMPTY_BARRIER_B_ITER_ARRIVE:.*]] = triton_nvidia_gpu.extract_mbarrier %[[EMPTY_BARRIER_B]][{{.*}}] // // C-HECK: triton_nvidia_gpu.mbarrier_arrive %[[EMPTY_BARRIER_B_ITER_ARRIVE]] // // C-HECK: %[[MBARRIER_AB_NEXT_ITER:.*]] = triton_nvidia_gpu.extract_mbarrier %[[MBARRIER_AB]][{{.*}}] @@ -332,9 +333,9 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // %a = tt.load %a_tileptr : !tt.ptr, 1> // %b = tt.load %b_tileptr : !tt.ptr, 1> // -// %sa = triton_gpu.local_alloc %a : (tensor<128x32xf16, #BA>) -> !tt.memdesc<128x32xf16, #SA> -// %sb = triton_gpu.local_alloc %b : (tensor<32x128xf16, #BB>) -> !tt.memdesc<32x128xf16, #SB> -// %c = triton_gpu_nvidia.dot_async %sa, %sb, %prev_c : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> +// %sa = triton_gpu.local_alloc %a : (tensor<128x32xf16, #BA>) -> !tt.memdesc<128x32xf16, #SA, #triton_gpu.shared_memory> +// %sb = triton_gpu.local_alloc %b : (tensor<32x128xf16, #BB>) -> !tt.memdesc<32x128xf16, #SB, #triton_gpu.shared_memory> +// %c = triton_nvidia_gpu.warp_group_dot %sa, %sb, %prev_c : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> // // %a_tileptr_next = tt.advance %a_tileptr, [%c0, %c32_i32] : !tt.ptr, 1> // %b_tileptr_next = tt.advance %b_tileptr, [%c32_i32, %c0] : !tt.ptr, 1> @@ -384,21 +385,21 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> // CHECK: scf.for // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} - // CHECK: triton_nvidia_gpu.dot_async - // CHECK-NEXT: triton_nvidia_gpu.dot_wait {{.*}} {pendings = 0 : i32} - // CHECK: triton_nvidia_gpu.dot_async + // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: triton_nvidia_gpu.warp_group_dot // CHECK: triton_gpu.async_copy_global_to_local // CHECK: triton_gpu.async_commit_group // CHECK: scf.yield %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>) : i32 { %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %21 = tt.dot %19, %20, %cst_2 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %21 = triton_nvidia_gpu.warp_group_dot %19, %20, %cst_2 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %23 = tt.trans %20 {order=array} : !tt.memdesc<64x16xf16, #shared1> -> !tt.memdesc<16x64xf16, #shared> + %23 = tt.trans %20 {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %25 = tt.dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> + %25 = triton_nvidia_gpu.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked> } @@ -431,22 +432,22 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> // CHECK: scf.for // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} - // CHECK: triton_nvidia_gpu.dot_async - // CHECK-NEXT: triton_nvidia_gpu.dot_wait {{.*}} {pendings = 1 : i32} + // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 1 : i32} // CHECK: triton_gpu.async_copy_global_to_local // CHECK: triton_gpu.async_commit_group // CHECK: scf.if - // CHECK: triton_nvidia_gpu.dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} // CHECK: arith.mulf // CHECK: scf.yield // CHECK: scf.yield - // CHECK: triton_nvidia_gpu.dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %acc = tt.dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %acc = triton_nvidia_gpu.warp_group_dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %cnd = arith.cmpi slt, %arg3, %ext : i32 %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> @@ -501,24 +502,24 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> // CHECK: %[[ALLOC1:.+]] = triton_gpu.local_alloc // CHECK: %[[ALLOC2:.+]] = triton_gpu.local_alloc // CHECK: %[[R:.+]]:{{.+}} = scf.for - // CHECK: %[[DOT1:.+]] = triton_nvidia_gpu.dot_async{{.*}} + // CHECK: %[[DOT1:.+]] = triton_nvidia_gpu.warp_group_dot{{.*}} // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} // CHECK: %[[TRANS:.+]] = tt.trans{{.*}} : !tt.memdesc - // CHECK: %[[DOT2:.+]] = triton_nvidia_gpu.dot_async{{.*}} %[[TRANS]] - // CHECK: triton_nvidia_gpu.dot_wait %[[DOT1]], %[[DOT2]], %[[ALLOC1]], %[[ALLOC2]], %[[TRANS]] {pendings = 2 : i32} + // CHECK: %[[DOT2:.+]] = triton_nvidia_gpu.warp_group_dot{{.*}} %[[TRANS]] + // CHECK: triton_nvidia_gpu.warp_group_dot_wait %[[DOT1]], %[[DOT2]], %[[ALLOC1]], %[[ALLOC2]], %[[TRANS]] {pendings = 2 : i32} // CHECK: scf.yield - // CHECK: %{{.*}}:2 = triton_nvidia_gpu.dot_wait %[[R]]#{{.+}}, %[[R]]#{{.+}} {pendings = 0 : i32} : tensor<128x16xf32, #{{.*}}>, tensor<128x64xf32, #{{.*}}> + // CHECK: %{{.*}}:2 = triton_nvidia_gpu.warp_group_dot_wait %[[R]]#{{.+}}, %[[R]]#{{.+}} {pendings = 0 : i32} : tensor<128x16xf32, #{{.*}}>, tensor<128x64xf32, #{{.*}}> %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16, %arg6 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>) : i32 { - %21 = tt.dot %19, %20, %arg6 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %21 = triton_nvidia_gpu.warp_group_dot %19, %20, %arg6 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1> -> !tt.memdesc<16x64xf16, #shared> - %25 = tt.dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> + %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> + %25 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1> } @@ -576,13 +577,13 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %22:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %12, %arg6 = %21) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1>) : i32 { %35 = tt.load %arg5 : tensor<128x64x!tt.ptr, #blocked> %36 = tt.load %arg6 : tensor<64x256x!tt.ptr, #blocked1> - %37 = triton_gpu.local_alloc %35 : (tensor<128x64xf8E5M2, #blocked>) -> !tt.memdesc<128x64xf8E5M2, #shared> - %38 = triton_gpu.local_alloc %36 : (tensor<64x256xf8E5M2, #blocked1>) -> !tt.memdesc<64x256xf8E5M2, #shared1> + %37 = triton_gpu.local_alloc %35 : (tensor<128x64xf8E5M2, #blocked>) -> !tt.memdesc<128x64xf8E5M2, #shared, #triton_gpu.shared_memory> + %38 = triton_gpu.local_alloc %36 : (tensor<64x256xf8E5M2, #blocked1>) -> !tt.memdesc<64x256xf8E5M2, #shared1, #triton_gpu.shared_memory> // CHECK: triton_gpu.local_alloc // CHECK: scf.for - // CHECK: triton_nvidia_gpu.dot_async - // CHECK-NEXT: triton_nvidia_gpu.dot_wait - %39 = tt.dot %37, %38, %arg4 {maxNumImpreciseAcc = 1073741824 : i32} : !tt.memdesc<128x64xf8E5M2, #shared> * !tt.memdesc<64x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait + %39 = triton_nvidia_gpu.warp_group_dot %37, %38, %arg4 {maxNumImpreciseAcc = 1073741824 : i32} : !tt.memdesc<128x64xf8E5M2, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf8E5M2, #shared1, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> %40 = tt.addptr %arg5, %cst_6 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> %41 = tt.addptr %arg6, %cst_5 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> scf.yield %39, %40, %41 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1> @@ -656,35 +657,35 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> // CHECK: %[[LOOP:[^ :]+]]{{.*}} scf.for {{.*}} iter_args(%[[PREV_DOT2:[^ ]+]] - // CHECK-NOT: triton_nvidia_gpu.dot_wait - // CHECK: %[[DOT0:.+]] = triton_nvidia_gpu.dot_async - // CHECK-NOT: triton_nvidia_gpu.dot_wait - // CHECK: %[[DOT1:.+]] = triton_nvidia_gpu.dot_async - // CHECK-NEXT: triton_nvidia_gpu.dot_wait + // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait + // CHECK: %[[DOT0:.+]] = triton_nvidia_gpu.warp_group_dot + // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait + // CHECK: %[[DOT1:.+]] = triton_nvidia_gpu.warp_group_dot + // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait // CHECK-DAG-SAME: %[[DOT0]] // CHECK-DAG-SAME: %[[DOT1]] // CHECK-DAG-SAME: %[[PREV_DOT2]] // CHECK-SAME: {pendings = 0 : i32} - // CHECK: %[[DOT2:.+]] = triton_nvidia_gpu.dot_async - // CHECK-NOT: triton_nvidia_gpu.dot_wait + // CHECK: %[[DOT2:.+]] = triton_nvidia_gpu.warp_group_dot + // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait // CHECK: scf.yield %[[DOT2]] - // CHECK: triton_nvidia_gpu.dot_wait %[[LOOP]]#3, %[[LOOP]]#0 {pendings = 0 : i32} + // CHECK: triton_nvidia_gpu.warp_group_dot_wait %[[LOOP]]#3, %[[LOOP]]#0 {pendings = 0 : i32} %17:4 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%prev_dot2 = %cst_3, %arg5 = %16, %prev_dot1 = %cst_2, %prev_dot0 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { // This one can be async. - %dot0 = tt.dot %19, %20, %prev_dot1 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %dot0 = triton_nvidia_gpu.warp_group_dot %19, %20, %prev_dot1 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> // This can't be async because its result is modified before it's yielded. - %dot1 = tt.dot %19, %20, %prev_dot1 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %dot1 = triton_nvidia_gpu.warp_group_dot %19, %20, %prev_dot1 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %dot1.1 = arith.addf %dot1, %dot1 : tensor<128x16xf32, #mma1> %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1> -> !tt.memdesc<16x64xf16, #shared> + %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> // This dot can be async even though %prev_dot2 is not used directly by an // async dot, because that use follows the synchronous dot above. %prev_dot2.1 = arith.addf %prev_dot2, %prev_dot2 : tensor<128x64xf32, #mma> - %dot2 = tt.dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> + %dot2 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %dot2, %26, %dot1.1, %dot0 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> } @@ -711,6 +712,32 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } } +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tma_multiple_store_pipeline + tt.func public @tma_multiple_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { + %c0_i32 = arith.constant 0 : i32 + // CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc : () -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> + // CHECK: scf.for + scf.for %arg4 = %c0_i32 to %arg3 step %arg2 : i32 { + %1 = arith.divsi %arg4, %arg2 : i32 + %2 = arith.divsi %arg2, %arg4 : i32 + // CHECK: triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32} + // CHECK-NEXT: triton_gpu.local_store %{{.+}}, %[[ALLOC]] + // CHECK-NEXT: triton_nvidia_gpu.fence_async_shared + // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] + // CHECK: triton_nvidia_gpu.async_tma_store_wait {pendings = 0 : i32} + // CHECK-NEXT: triton_gpu.local_store %{{.+}}, %[[ALLOC]] + // CHECK-NEXT: triton_nvidia_gpu.fence_async_shared + // CHECK-NEXT: triton_nvidia_gpu.async_tma_copy_local_to_global %{{.*}} %[[ALLOC]] + tt.experimental_descriptor_store %arg1[%1], %arg0 : !tt.ptr, tensor<1xf32, #blocked> + tt.experimental_descriptor_store %arg1[%2], %arg0 : !tt.ptr, tensor<1xf32, #blocked> + } + tt.return + } +} + // ----- @@ -766,3 +793,141 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : tt.return } } + +// ----- + +// Pipeline the if ops at the beginning and the end of the loop +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // COMMON-LABEL: dot_prologue_epilogue + // COMMON: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} + tt.func @dot_prologue_epilogue(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %cst2 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %2 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // COMMON: %[[C0:.*]] = arith.constant 0 : i32 + // COMMON: scf.for %[[IND_VAR:.*]] = %[[C0]] + // COMMON-NOT: load + // COMMON: %[[CND:.*]] = arith.cmpi slt, %[[IND_VAR]], %[[EXT]] + // COMMON: scf.if %[[CND]] + // COMMON: dot + // COMMON: scf.if %[[CND]] + // COMMON: arith.mulf + // COMMON: scf.yield + // COMMON-NOT: tt.addptr + // COMMON: scf.yield + %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { + %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %inc_ptr = scf.if %cnd -> tensor<64x16x!tt.ptr, #blocked> { + %ptr = tt.addptr %arg5, %inc : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + scf.yield %ptr : tensor<64x16x!tt.ptr, #blocked> + } else { + scf.yield %arg5 : tensor<64x16x!tt.ptr, #blocked> + } + %18 = tt.load %inc_ptr : tensor<64x16x!tt.ptr, #blocked> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %acc = triton_nvidia_gpu.warp_group_dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> + scf.yield %acc_zero : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + %22 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + %23 = tt.addptr %arg6, %cst2 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + scf.yield %acc_, %22, %23 : tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1> + } + tt.return %17#0 : tensor<128x16xf32, #mma1> + } +} + +// ----- + +// Verify that uses of the ops scheduled in partucular place of the loop (like epilogue if) are correctly scheduled too. +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-NOCANON-LABEL: pipeline_downstream_dependencies + // CHECK-NOCANON: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} + tt.func @pipeline_downstream_dependencies(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %cst1 = arith.constant dense<1> : tensor<64x16xi32, #blocked> + %cst2 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %2 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK-NOCANON: %[[C0:.*]] = arith.constant 0 : i32 + // CHECK-NOCANON: scf.for %[[IND_VAR:.*]] = %[[C0]] + // CHECK-NOCANON-NOT load + // CHECK-NOCANON: dot + // CHECK-NOCANON: %[[CND:.*]] = arith.cmpi slt, %[[IND_VAR]], %[[EXT]] + // CHECK-NOCANON: %[[IFRET:.*]]:2 = scf.if %[[CND]] + // CHECK-NOCANON: arith.mulf + // CHECK-NOCANON: scf.yield + // CHECK-NOCANON: tt.addptr {{.*}}, %[[IFRET]]#1 + // CHECK-NOCANON: scf.yield + %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { + %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> + %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %acc = triton_nvidia_gpu.warp_group_dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %if_ret:2 = scf.if %cnd -> (tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked>) { + %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> + scf.yield %acc_zero, %cst : tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked> + } else { + scf.yield %acc, %cst1 : tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked> + } + %22 = tt.addptr %arg5, %if_ret#1 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + %23 = tt.addptr %arg6, %cst2 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + scf.yield %if_ret#0, %22, %23 : tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1> + } + tt.return %17#0 : tensor<128x16xf32, #mma1> + } +} diff --git a/test/TritonGPU/loop-pipeline-indirect-load.mlir b/test/TritonGPU/loop-pipeline-indirect-load.mlir new file mode 100644 index 000000000..74794b949 --- /dev/null +++ b/test/TritonGPU/loop-pipeline-indirect-load.mlir @@ -0,0 +1,90 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=2 | FileCheck %s +// CHECK-LABEL: @indirect_load_two_stages +// CHECK: scf.for +// CHECK: tt.dot +// CHECK: tt.load +// CHECK: async_copy_global_to_local +// CHECK: async_copy_global_to_local + +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @indirect_load_two_stages(%arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32, %arg19: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %c16_i32 = arith.constant 16 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<16x128xf32, #blocked> + + %0 = tt.get_program_id y : i32 + %1 = tt.addptr %arg3, %0 : !tt.ptr, i32 + %2 = tt.load %1 : !tt.ptr + + %7 = tt.get_program_id x : i32 + %8 = arith.muli %7, %c16_i32 : i32 + %10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %15 = tt.splat %8 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %18 = arith.addi %15, %10 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + + %20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %22 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %34 = arith.extsi %arg12 : i32 to i64 + %35 = arith.muli %2, %34 : i64 + %36 = tt.addptr %arg2, %35 : !tt.ptr, i64 + + %47 = tt.splat %arg4 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %48 = tt.addptr %47, %20 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + + %59 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %61 = arith.extsi %59 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %63 = tt.expand_dims %61 {axis = 0 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi64, #blocked3> + + %85 = arith.extsi %22 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> to tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %107 = tt.splat %36 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked3> + %108 = tt.splat %34 : i64 -> tensor<32x1xi64, #blocked3> + %109 = tt.broadcast %63 : tensor<1x128xi64, #blocked3> -> tensor<32x128xi64, #blocked3> + + %101 = tt.splat %arg5 : !tt.ptr -> tensor<16x32x!tt.ptr, #blocked1> + %111:1 = scf.for %arg28 = %arg18 to %arg19 step %c32_i32 iter_args(%arg29 = %cst) -> (tensor<16x128xf32, #blocked>) : i32 { + %129 = tt.splat %arg28 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %160 = tt.addptr %48, %129 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %161 = tt.load %160 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %162 = tt.expand_dims %161 {axis = 0 : i32} : tensor<32xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi64, #blocked1> + %163 = tt.broadcast %162 : tensor<1x32xi64, #blocked1> -> tensor<16x32xi64, #blocked1> + %182 = tt.addptr %101, %163 : tensor<16x32x!tt.ptr, #blocked1>, tensor<16x32xi64, #blocked1> + %183 = tt.load %182 : tensor<16x32x!tt.ptr, #blocked1> + + %197 = arith.extsi %arg28 : i32 to i64 + %198 = tt.splat %197 : i64 -> tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %199 = arith.addi %198, %85 : tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %200 = tt.expand_dims %199 {axis = 1 : i32} : tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1xi64, #blocked3> + %201 = arith.muli %200, %108 : tensor<32x1xi64, #blocked3> + %202 = tt.broadcast %201 : tensor<32x1xi64, #blocked3> -> tensor<32x128xi64, #blocked3> + %203 = arith.addi %202, %109 : tensor<32x128xi64, #blocked3> + %204 = tt.addptr %107, %203 : tensor<32x128x!tt.ptr, #blocked3>, tensor<32x128xi64, #blocked3> + %209 = tt.load %204 : tensor<32x128x!tt.ptr, #blocked3> + + %210 = triton_gpu.convert_layout %183 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %211 = triton_gpu.convert_layout %209 : tensor<32x128xf32, #blocked3> -> tensor<32x128xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %212 = tt.dot %210, %211, %arg29 : tensor<16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x128xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x128xf32, #blocked> + scf.yield %212 : tensor<16x128xf32, #blocked> + } + %112 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<16x1xi32, #blocked3> + %113 = tt.splat %2 : i64 -> tensor<16x1xi64, #blocked3> + %114 = arith.extsi %112 : tensor<16x1xi32, #blocked3> to tensor<16x1xi64, #blocked3> + %115 = arith.addi %113, %114 : tensor<16x1xi64, #blocked3> + %116 = arith.extsi %arg17 : i32 to i64 + %117 = tt.splat %116 : i64 -> tensor<16x1xi64, #blocked3> + %118 = arith.muli %115, %117 : tensor<16x1xi64, #blocked3> + %119 = tt.expand_dims %59 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi32, #blocked3> + %120 = tt.broadcast %118 : tensor<16x1xi64, #blocked3> -> tensor<16x128xi64, #blocked3> + %121 = arith.extsi %119 : tensor<1x128xi32, #blocked3> to tensor<1x128xi64, #blocked3> + %122 = tt.broadcast %121 : tensor<1x128xi64, #blocked3> -> tensor<16x128xi64, #blocked3> + %123 = arith.addi %120, %122 : tensor<16x128xi64, #blocked3> + %124 = tt.splat %arg7 : !tt.ptr -> tensor<16x128x!tt.ptr, #blocked3> + %125 = tt.addptr %124, %123 : tensor<16x128x!tt.ptr, #blocked3>, tensor<16x128xi64, #blocked3> + %128 = triton_gpu.convert_layout %111#0 : tensor<16x128xf32, #blocked> -> tensor<16x128xf32, #blocked3> + tt.store %125, %128 : tensor<16x128x!tt.ptr, #blocked3> + tt.return + } +} diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 20d8093b0..3d215a635 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -1,5 +1,5 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s -// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 | FileCheck %s --check-prefix=CHECK-NOCANON +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s --check-prefixes=COMMON,CHECK +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD // 4 warps // matmul: 128x32 @ 32x128 -> 128x128 @@ -55,7 +55,51 @@ // CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]], // CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { + +// AMD-LABEL: tt.func @matmul_loop +// AMD-DAG: %[[CM1:.*]] = arith.constant -1 : index +// AMD-DAG: %[[C1:.*]] = arith.constant 1 : index +// AMD-DAG: %[[C0:.*]] = arith.constant 0 : index +// AMD: %[[UB1:.*]] = arith.subi %[[UB:.*]], %arg2 : index +// AMD: %[[FOR:.*]]:6 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UB1]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}) +// AMD: %[[LOCAL_LOAD_32:.*]] = triton_gpu.local_load %[[ARG10]] +// AMD: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[MULF_34:.*]] = arith.mulf %[[LOCAL_LOAD_33]], %{{.*}} +// AMD: %[[DOT_35:.*]] = tt.dot %[[LOCAL_LOAD_32]], %[[MULF_34]], %[[ARG8]] +// AMD: %[[ADDPTR_36:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// AMD: %[[ADDPTR_37:.*]] = tt.addptr %[[ARG7]], %{{.*}} +// AMD: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_36]] +// AMD: %[[LOAD_39:.*]] = tt.load %[[ADDPTR_37]] +// AMD: %[[ADDI_40:.*]] = arith.addi %[[ARG9]], %{{.*}} +// AMD: %[[CMPI_41:.*]] = arith.cmpi slt, %[[ADDI_40]], %{{.*}} +// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_41]], %[[ADDI_40]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_43:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_42]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_38]], %[[MEMDESC_SUBVIEW_43]] +// AMD: %[[MEMDESC_SUBVIEW_44:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_42]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_39]], %[[MEMDESC_SUBVIEW_44]] +// AMD: scf.yield %[[ADDPTR_36]], %[[ADDPTR_37]], %[[DOT_35]], %[[SELECT_42]], %[[MEMDESC_SUBVIEW_43]], %[[MEMDESC_SUBVIEW_44]] +// AMD: } +// AMD: %[[CMPI_21:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]] +// AMD: %[[SELECT_22:.*]] = arith.select %[[CMPI_21]], %[[C1]], %[[CM1]] +// AMD: %[[SUBI_23:.*]] = arith.subi %[[UB]], %[[LB]] +// AMD: %[[ADDI_24:.*]] = arith.addi %[[SUBI_23]], %[[STEP]] +// AMD: %[[ADDI_25:.*]] = arith.addi %[[ADDI_24]], %[[SELECT_22]] +// AMD: %[[DIVSI_26:.*]] = arith.divsi %[[ADDI_25]], %[[STEP]] +// AMD: %[[CMPI_27:.*]] = arith.cmpi sge, %[[DIVSI_26]], %{{.*}} +// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#4 +// AMD: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %{{.*}}#5 +// AMD: %[[MULF_30:.*]] = arith.mulf %[[LOCAL_LOAD_29]], %{{.*}} +// AMD: %[[IF_31:.*]] = scf.if %[[CMPI_27]] +// AMD: %[[DOT_33:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[MULF_30]], %{{.*}}#2 +// AMD: scf.yield %[[DOT_33]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#2 +// AMD: } +// AMD: %[[SELECT_32:.*]] = arith.select %[[CMPI_27]], %[[IF_31]], %{{.*}}#2 +// AMD: triton_gpu.local_dealloc %{{.*}} +// AMD: triton_gpu.local_dealloc %{{.*}} + +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -146,6 +190,33 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK: triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] // CHECK: triton_gpu.async_copy_global_to_local // CHECK scf.yield + +// AMD-LABEL: tt.func @matmul_loop_nested +// AMD: scf.for +// AMD-COUNT-2: triton_gpu.local_alloc +// AMD-COUNT-2: tt.load +// AMD: %[[SUBVIEW0:.*]] = triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW0]] +// AMD: %[[SUBVIEW1:.*]] = triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]] +// AMD: %[[FOR:.*]]:6 = scf.for +// AMD-COUNT-2: triton_gpu.local_load +// AMD: tt.dot +// AMD-COUNT-2: tt.addptr +// AMD-COUNT-2: tt.load +// AMD: %[[SUBVIEW0:.*]] = triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW0]] +// AMD: %[[SUBVIEW1:.*]] = triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]] +// AMD: scf.yield +// AMD-COUNT-2: triton_gpu.local_load +// AMD: %[[IF1:.*]] = scf.if +// AMD: %[[DOT1:.*]] = tt.dot +// AMD: scf.yield %[[DOT1]] +// AMD: %[[SEL1:.*]] = arith.select %{{.*}}, %[[IF1]], %[[FOR]]#2 +// AMD-COUNT-2: triton_gpu.local_dealloc +// AMD: scf.yield %[[SEL1]] + tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C>{ @@ -216,6 +287,30 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // CHECK-DAG: %[[NEXT_B:.*]] = triton_gpu.memdesc_subview %{{.+}}[%[[EXT_IDX_3]] // CHECK-DAG: triton_gpu.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_B]] + +// AMD-LABEL: tt.func @matmul_loop_single_pipeline +// AMD: %[[LOAD_10:.*]] = tt.load %{{.*}} +// AMD: %[[CONVERT_LAYOUT_11:.*]] = triton_gpu.convert_layout %[[LOAD_10]] +// AMD: %[[LOCAL_ALLOC_12:.*]] = triton_gpu.local_alloc +// AMD: %[[CMPI_13:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_13]] +// AMD: %[[LOAD_15:.*]] = tt.load %{{.*}}, %[[SPLAT_14]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_16:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_12]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_15]], %[[MEMDESC_SUBVIEW_16]] +// AMD: %[[SUBI_17:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %{{.*}}:4 = scf.for %[[ARG5:.*]] = %{{.*}} to %[[SUBI_17]] step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[MEMDESC_SUBVIEW_16]]) +// AMD: %[[LOCAL_LOAD_30:.*]] = triton_gpu.local_load %[[ARG9]] +// AMD: %[[DOT_31:.*]] = tt.dot %[[CONVERT_LAYOUT_11]], %[[LOCAL_LOAD_30]], %[[ARG7]] +// AMD: %[[ADDPTR_32:.*]] = tt.addptr %[[ARG6]], %{{.*}} +// AMD: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_32]] +// AMD: %[[ADDI_34:.*]] = arith.addi %[[ARG8]], %{{.*}} +// AMD: %[[CMPI_35:.*]] = arith.cmpi slt, %[[ADDI_34]], %{{.*}} +// AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_35]], %[[ADDI_34]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_12]][%[[SELECT_36]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_33]], %[[MEMDESC_SUBVIEW_37]] +// AMD: scf.yield %[[ADDPTR_32]], %[[DOT_31]], %[[SELECT_36]], %[[MEMDESC_SUBVIEW_37]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_12]] + tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -268,6 +363,86 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_2]] // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} + +// AMD-LABEL: tt.func @indirect_bmm_scalar +// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] +// AMD: %[[LOAD_5:.*]] = tt.load %{{.*}}, %[[CMPI_2]] +// AMD: %[[MULI_6:.*]] = arith.muli %{{.*}}, %[[LOAD_5]] +// AMD: %[[SPLAT_7:.*]] = tt.splat %[[MULI_6]] +// AMD: %[[ADDPTR_8:.*]] = tt.addptr %{{.*}}, %[[SPLAT_7]] +// AMD: %[[SPLAT_9:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_10:.*]] = tt.load %[[ADDPTR_8]], %[[SPLAT_9]] +// AMD: %[[CMPI_11:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_13:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_14:.*]] = tt.splat %[[CMPI_11]] +// AMD: %[[LOAD_15:.*]] = tt.load %[[ADDPTR_12]], %[[SPLAT_14]] +// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_13]], %[[CMPI_11]] +// AMD: %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[LOAD_16]] +// AMD: %[[SPLAT_18:.*]] = tt.splat %[[MULI_17]] +// AMD: %[[ADDPTR_19:.*]] = tt.addptr %{{.*}}, %[[SPLAT_18]] +// AMD: %[[SPLAT_20:.*]] = tt.splat %[[CMPI_11]] +// AMD: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_19]], %[[SPLAT_20]] +// AMD: %[[MEMDESC_SUBVIEW_22:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_4]], %[[MEMDESC_SUBVIEW_22]] +// AMD: %[[MEMDESC_SUBVIEW_23:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_10]], %[[MEMDESC_SUBVIEW_23]] +// AMD: %[[SUBI_24:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %{{.*}}:8 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_24]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %[[ADDPTR_12]], %[[ARG9:.*]] = %[[ADDPTR_13]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_22]], %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_23]], %[[ARG13:.*]] = %[[LOAD_15]], %[[ARG14:.*]] = %[[LOAD_21]]) +// AMD: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[LOCAL_LOAD_44:.*]] = triton_gpu.local_load %[[ARG12]] +// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_43]], %[[LOCAL_LOAD_44]], %[[ARG7]] +// AMD: %[[ADDPTR_46:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_47:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[LOAD_48:.*]] = tt.load %[[ADDPTR_46]] +// AMD: %[[LOAD_49:.*]] = tt.load %[[ADDPTR_47]] +// AMD: %[[MULI_50:.*]] = arith.muli %{{.*}}, %[[LOAD_49]] +// AMD: %[[SPLAT_51:.*]] = tt.splat %[[MULI_50]] +// AMD: %[[ADDPTR_52:.*]] = tt.addptr %{{.*}}, %[[SPLAT_51]] +// AMD: %[[LOAD_53:.*]] = tt.load %[[ADDPTR_52]] +// AMD: %[[ADDI_54:.*]] = arith.addi %[[ARG10]], %{{.*}} +// AMD: %[[CMPI_55:.*]] = arith.cmpi slt, %[[ADDI_54]], %{{.*}} +// AMD: %[[SELECT_56:.*]] = arith.select %[[CMPI_55]], %[[ADDI_54]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_57:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_56]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_57]] +// AMD: %[[MEMDESC_SUBVIEW_58:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_56]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[ARG14]], %[[MEMDESC_SUBVIEW_58]] +// AMD: scf.yield %[[DOT_45]], %[[ADDPTR_46]], %[[ADDPTR_47]], %[[SELECT_56]], %[[MEMDESC_SUBVIEW_57]], %[[MEMDESC_SUBVIEW_58]], %[[LOAD_48]], %[[LOAD_53]] +// AMD: } +// AMD: %[[CMPI_26:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[CMPI_27:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#4 +// AMD: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %{{.*}}#5 +// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_26]] +// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[LOCAL_LOAD_29]], %{{.*}}#0 +// AMD: scf.yield %[[DOT_41]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#0 +// AMD: } +// AMD: %[[ADDI_31:.*]] = arith.addi %{{.*}}#3, %{{.*}} +// AMD: %[[CMPI_32:.*]] = arith.cmpi slt, %[[ADDI_31]], %{{.*}} +// AMD: %[[SELECT_33:.*]] = arith.select %[[CMPI_32]], %[[ADDI_31]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_34:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_33]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %{{.*}}#6, %[[MEMDESC_SUBVIEW_34]] +// AMD: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_33]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %{{.*}}#7, %[[MEMDESC_SUBVIEW_35]] +// AMD: %[[SELECT_36:.*]] = arith.select %[[CMPI_26]], %[[IF_30]], %{{.*}}#0 +// AMD: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_34]] +// AMD: %[[LOCAL_LOAD_38:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_35]] +// AMD: %[[IF_39:.*]] = scf.if %[[CMPI_27]] +// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_37]], %[[LOCAL_LOAD_38]], %[[SELECT_36]] +// AMD: scf.yield %[[DOT_41]] +// AMD: } else { +// AMD: scf.yield %[[SELECT_36]] +// AMD: } +// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_27]], %[[IF_39]], %[[SELECT_36]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]] +// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]] + tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -293,7 +468,7 @@ tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr - } + } {tt.num_stages = 3 : i32} tt.return %79#0 : tensor<16x16xf32, #C> } @@ -313,6 +488,15 @@ tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[IND_BUFFER_0]] + +// AMD-LABEL: tt.func @indirect_bmm_scalar_dist_one +// AMD-COUNT-4: tt.load +// AMD: scf.for +// AMD: tt.dot +// AMD: tt.load +// AMD: triton_gpu.local_store +// AMD: scf.yield + tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -365,6 +549,52 @@ tt.func @indirect_bmm_scalar_dist_one(%77: i64 {tt.divisibility=16: i32}, // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} // CHECK: scf.yield + +// AMD-LABEL: tt.func @indirect_bmm_vector +// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// AMD: %[[CMPI_2:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_3:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_4:.*]] = tt.load %{{.*}}, %[[SPLAT_3]] +// AMD: %[[CMPI_5:.*]] = arith.cmpi sgt, %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_6:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_7:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_8:.*]] = tt.load %{{.*}}, %[[SPLAT_7]] +// AMD: %[[EXPAND_DIMS_9:.*]] = tt.expand_dims %[[LOAD_4]] {axis = 1 : i32} +// AMD: %[[BROADCAST_10:.*]] = tt.broadcast %[[EXPAND_DIMS_9]] +// AMD: %[[MULI_11:.*]] = arith.muli %{{.*}}, %[[BROADCAST_10]] +// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %[[MULI_11]] +// AMD: %[[SPLAT_13:.*]] = tt.splat %[[CMPI_2]] +// AMD: %[[LOAD_14:.*]] = tt.load %[[ADDPTR_12]], %[[SPLAT_13]] +// AMD: %[[SPLAT_15:.*]] = tt.splat %[[CMPI_5]] +// AMD: %[[LOAD_16:.*]] = tt.load %[[ADDPTR_6]], %[[SPLAT_15]] +// AMD: %[[MEMDESC_SUBVIEW_17:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_8]], %[[MEMDESC_SUBVIEW_17]] +// AMD: %[[MEMDESC_SUBVIEW_18:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%{{.*}}, %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_14]], %[[MEMDESC_SUBVIEW_18]] +// AMD: %[[SUBI_19:.*]] = arith.subi %{{.*}}, %{{.*}} +// AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %[[SUBI_19]] step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %[[ADDPTR_6]], %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %[[MEMDESC_SUBVIEW_17]], %[[ARG12:.*]] = %[[MEMDESC_SUBVIEW_18]], %[[ARG13:.*]] = %[[LOAD_16]]) +// AMD: %[[LOCAL_LOAD_47:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[LOCAL_LOAD_48:.*]] = triton_gpu.local_load %[[ARG12]] +// AMD: %[[DOT_49:.*]] = tt.dot %[[LOCAL_LOAD_47]], %[[LOCAL_LOAD_48]], %[[ARG7]] +// AMD: %[[ADDPTR_50:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_51:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[LOAD_52:.*]] = tt.load %[[ADDPTR_50]] +// AMD: %[[EXPAND_DIMS_53:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32} +// AMD: %[[BROADCAST_54:.*]] = tt.broadcast %[[EXPAND_DIMS_53]] +// AMD: %[[MULI_55:.*]] = arith.muli %{{.*}}, %[[BROADCAST_54]] +// AMD: %[[ADDPTR_56:.*]] = tt.addptr %{{.*}}, %[[MULI_55]] +// AMD: %[[LOAD_57:.*]] = tt.load %[[ADDPTR_56]] +// AMD: %[[LOAD_58:.*]] = tt.load %[[ADDPTR_51]] +// AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} +// AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} +// AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_52]], %[[MEMDESC_SUBVIEW_62]] +// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_57]], %[[MEMDESC_SUBVIEW_63]] +// AMD: scf.yield %[[DOT_49]], %[[ADDPTR_50]], %[[ADDPTR_51]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[MEMDESC_SUBVIEW_63]], %[[LOAD_58]] + tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -392,16 +622,16 @@ tt.func @indirect_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> - } + } {tt.num_stages = 3 : i32} tt.return %79#0 : tensor<16x16xf32, #C> } -// CHECK-LABEL: tt.func @post_load_inv -// CHECK: scf.for -// CHECK-DAG: %[[IV:.*]] = arith.index_cast -// CHECK: %[[NEXT_IV:.*]] = arith.addi %[[IV]], %c1_i32 : i32 -// CHECK: arith.index_cast -// CHECK-NOT: arith.addi %[[NEXT_IV]] +// COMMON-LABEL: tt.func @post_load_inv +// COMMON: scf.for +// COMMON-DAG: %[[IV:.*]] = arith.index_cast +// COMMON: %[[NEXT_IV:.*]] = arith.addi %[[IV]], %c1_i32 : i32 +// COMMON: arith.index_cast +// COMMON-NOT: arith.addi %[[NEXT_IV]] tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, @@ -452,11 +682,12 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return %85#0 : tensor<32x32xf32, #C> } -// CHECK-LABEL: tt.func @cross_iter_dep +// COMMON-LABEL: tt.func @cross_iter_dep // TODO: enable pipelining with distance of 2 -// CHECK-NOT: triton_gpu.async_commit_group -// CHECK: scf.for -// CHECK: scf.yield +// COMMON-NOT: triton_gpu.async_commit_group +// COMMON: scf.for +// COMMON: scf.yield + tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, @@ -509,14 +740,14 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return %119#0 : tensor<32x32xf32, #C> } -// CHECK-LABEL: tt.func @dep_arg_two_uses -// CHECK: tt.expand_dims -// CHECK: tt.expand_dims -// CHECK: tt.expand_dims %arg5 -// CHECK-NEXT: tt.expand_dims %arg5 -// CHECK: %[[PTR0:.*]] = tt.splat %arg6 -// CHECK: %[[PTR1:.*]] = tt.addptr %[[PTR0]] -// CHECK-NEXT: tt.load %[[PTR1]] +// COMMON-LABEL: tt.func @dep_arg_two_uses +// COMMON: tt.expand_dims +// COMMON: tt.expand_dims +// COMMON: tt.expand_dims %arg5 +// COMMON-NEXT: tt.expand_dims %arg5 +// COMMON: %[[PTR0:.*]] = tt.splat %arg6 +// COMMON: %[[PTR1:.*]] = tt.addptr %[[PTR0]] +// COMMON-NEXT: tt.load %[[PTR1]] tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> { @@ -577,74 +808,13 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, // ----- -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { -// CHECK-LABEL: tt.func @load_two_users - tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { - %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> - %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> - %c0_i64 = arith.constant 0 : i64 - %c0_i32 = arith.constant 0 : i32 - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %c1_i32 = arith.constant 1 : i32 - %c8_i32 = arith.constant 8 : i32 - %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 - %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 - %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> - %3 = tt.addptr %2, %cst_0 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> - %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> - %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> - %11 = tt.addptr %10, %cst : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> - %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> - %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} - // CHECK: scf.for - // CHECK: tt.dot - // CHECK: tt.dot - // CHECK: triton_gpu.async_copy_global_to_local - // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} - // CHECK: scf.yield - // CHECK: triton_gpu.async_wait {num = 0 : i32} - - %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { - %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %20 = triton_gpu.convert_layout %18 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> - %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> - %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared> - %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared> -> !tt.memdesc<16x64xf16, #shared1> - %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> - scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> - } - tt.return %17#0, %17#1 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> - } -} - -// ----- - #blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}> #shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { -// CHECK-LABEL: tt.func @load_two_users_incompatible_layouts +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { +// COMMON-LABEL: tt.func @load_two_users_incompatible_layouts tt.func @load_two_users_incompatible_layouts(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> @@ -671,8 +841,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // CHECK-NOT: triton_gpu.insert_slice_async - // CHECK: scf.for + // check that the load didn't get pipelined. + // COMMON-NOT: alloc + // COMMON: scf.for %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> @@ -680,9 +851,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared> - %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared> -> !tt.memdesc<16x64xf16, #shared1> - %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> + %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> + %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } @@ -704,6 +875,15 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: triton_gpu.async_copy_global_to_local // CHECK: triton_gpu.async_commit_group // CHECK: scf.yield + +// AMD-LABEL: tt.func public @nested_loops +// AMD: scf.for +// AMD: triton_gpu.local_alloc +// AMD-NOT: triton_gpu.local_alloc +// AMD: scf.for +// AMD: scf.yield +// AMD-DIS: scf.yield + // // The following code has the structure: // @@ -717,17 +897,12 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // } // ``` // -// Only the outer for should be pipelined. The regression this tests -// causes an assertion to fail while pipelining the outer `for`, in -// particular while predicating the operations scheduled to be emitted -// in the prologue. -// -// We check that there is no allocation before the first occurrence of -// scf.for because that would mean that the first load `%a = load()` -// would be pipelined. +// For CUDA, we pipeline the inner loop first then pipeline the outer +// loop to prefetch the async copy after the inner loop. +// For HIP, we only pipeline the inner loop for now. #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %cst_0 = arith.constant dense<320> : tensor<32x1xi32, #blocked> @@ -780,107 +955,6 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : } } // end module -// ----- - -// CHECK-NOT: triton_gpu.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1> - -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { - %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> - %c64_i32 = arith.constant 64 : i32 - %c0_i32 = arith.constant 0 : i32 - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c64_i32 : i32 - %2 = tt.get_program_id y : i32 - %3 = tt.load %arg3 : !tt.ptr - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %5 = tt.splat %1 : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %6 = arith.addi %5, %4 : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %8 = tt.splat %3 : i64 -> tensor<64x1xi64, #blocked> - %9 = arith.extsi %7 : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> - %10 = arith.addi %8, %9 : tensor<64x1xi64, #blocked> - %11 = arith.extsi %arg5 : i32 to i64 - %12 = tt.splat %11 : i64 -> tensor<64x1xi64, #blocked> - %13 = arith.muli %10, %12 : tensor<64x1xi64, #blocked> - %14 = arith.muli %2, %arg5 : i32 - %15 = arith.extsi %14 : i32 to i64 - %16 = tt.splat %15 : i64 -> tensor<64x1xi64, #blocked> - %17 = arith.addi %13, %16 : tensor<64x1xi64, #blocked> - %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %19 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - %21 = tt.expand_dims %19 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %22 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked> - %23 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1> - %24 = arith.muli %20, %22 : tensor<1x64xi32, #blocked> - %25 = arith.muli %21, %23 : tensor<1x64xi32, #blocked1> - %26 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x64xi64, #blocked> - %27 = arith.extsi %24 : tensor<1x64xi32, #blocked> to tensor<1x64xi64, #blocked> - %28 = arith.extsi %25 : tensor<1x64xi32, #blocked1> to tensor<1x64xi64, #blocked1> - %29 = tt.broadcast %27 : tensor<1x64xi64, #blocked> -> tensor<64x64xi64, #blocked> - %30 = arith.addi %26, %29 : tensor<64x64xi64, #blocked> - %31 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> - %32 = tt.expand_dims %31 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> - %33 = tt.splat %3 : i64 -> tensor<32x1xi64, #blocked1> - %34 = arith.extsi %32 : tensor<32x1xi32, #blocked1> to tensor<32x1xi64, #blocked1> - %35 = arith.addi %33, %34 : tensor<32x1xi64, #blocked1> - %36 = tt.splat %11 : i64 -> tensor<32x1xi64, #blocked1> - %37 = arith.muli %35, %36 : tensor<32x1xi64, #blocked1> - %38 = tt.splat %15 : i64 -> tensor<32x1xi64, #blocked1> - %39 = arith.addi %37, %38 : tensor<32x1xi64, #blocked1> - %40 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x64xi64, #blocked1> - %41 = tt.broadcast %28 : tensor<1x64xi64, #blocked1> -> tensor<32x64xi64, #blocked1> - %42 = arith.addi %40, %41 : tensor<32x64xi64, #blocked1> - %43 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %44 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %45 = tt.expand_dims %43 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> - %46 = tt.expand_dims %44 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> - %47 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked1> - %48 = tt.splat %arg5 : i32 -> tensor<1x32xi32, #blocked> - %49 = arith.muli %45, %47 : tensor<1x32xi32, #blocked1> - %50 = arith.muli %46, %48 : tensor<1x32xi32, #blocked> - %51 = tt.broadcast %39 : tensor<32x1xi64, #blocked1> -> tensor<32x32xi64, #blocked1> - %52 = arith.extsi %49 : tensor<1x32xi32, #blocked1> to tensor<1x32xi64, #blocked1> - %53 = arith.extsi %50 : tensor<1x32xi32, #blocked> to tensor<1x32xi64, #blocked> - %54 = tt.broadcast %52 : tensor<1x32xi64, #blocked1> -> tensor<32x32xi64, #blocked1> - %55 = arith.addi %51, %54 : tensor<32x32xi64, #blocked1> - %56 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked> - %57 = tt.addptr %56, %30 : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi64, #blocked> - %58 = tt.splat %arg1 : !tt.ptr -> tensor<32x64x!tt.ptr, #blocked1> - %59 = tt.addptr %58, %42 : tensor<32x64x!tt.ptr, #blocked1>, tensor<32x64xi64, #blocked1> - %60 = tt.splat %arg2 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> - %61 = tt.addptr %60, %55 : tensor<32x32x!tt.ptr, #blocked1>, tensor<32x32xi64, #blocked1> - %62 = tt.load %57 : tensor<64x64x!tt.ptr, #blocked> - %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { - %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> - %71 = triton_gpu.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared> - %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared> -> !tt.memdesc<64x32xf32, #shared1> - %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> - %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> - %77 = triton_gpu.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %78 = triton_gpu.convert_layout %76 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> - %79 = tt.dot %77, %78, %arg7 : tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> - scf.yield %79 : tensor<64x32xf32, #mma> - } - %64 = tt.broadcast %17 : tensor<64x1xi64, #blocked> -> tensor<64x32xi64, #blocked> - %65 = tt.broadcast %53 : tensor<1x32xi64, #blocked> -> tensor<64x32xi64, #blocked> - %66 = arith.addi %64, %65 : tensor<64x32xi64, #blocked> - %67 = tt.splat %arg4 : !tt.ptr -> tensor<64x32x!tt.ptr, #blocked> - %68 = tt.addptr %67, %66 : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi64, #blocked> - %69 = triton_gpu.convert_layout %63 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #blocked> - tt.store %68, %69 : tensor<64x32x!tt.ptr, #blocked> - tt.return - } -} // end module // ----- // CHECK: #[[$SHARED_LAYOUT:shared.*]] = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> @@ -888,7 +962,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: scf.for // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]] -// CHECK: %[[IND_BUFFER_0:.*]] = triton_gpu.memdesc_subview {{.*}} : !tt.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], mutable> -> !tt.memdesc<16xi64, #[[$SHARED_LAYOUT]], mutable> +// CHECK: %[[IND_BUFFER_0:.*]] = triton_gpu.memdesc_subview {{.*}} : !tt.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16xi64, #[[$SHARED_LAYOUT]], #triton_gpu.shared_memory, mutable> // CHECK: %[[IND_BUFFER_1:.*]] = triton_gpu.local_load %[[IND_BUFFER_0]] // CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32} // CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]] @@ -896,13 +970,79 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: %[[NEXT_BUFFER_0:.*]] = tt.addptr {{.*}}, %[[IND_BUFFER_4]] // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_0]] // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} + +// AMD-DIS: #[[$SHARED_LAYOUT:shared.*]] = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +// AMD-LABEL: tt.func @indirect_load_shared_layout +// AMD: %[[LOCAL_ALLOC_0:.*]] = triton_gpu.local_alloc +// AMD: %[[LOCAL_ALLOC_1:.*]] = triton_gpu.local_alloc +// AMD: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}}) +// AMD: %[[LOCAL_LOAD_47:.*]] = triton_gpu.local_load %[[ARG11]] +// AMD: %[[LOCAL_LOAD_48:.*]] = triton_gpu.local_load %[[ARG12]] +// AMD: %[[DOT_49:.*]] = tt.dot %[[LOCAL_LOAD_47]], %[[LOCAL_LOAD_48]], %[[ARG7]] +// AMD: %[[ADDPTR_50:.*]] = tt.addptr %[[ARG8]], %{{.*}} +// AMD: %[[ADDPTR_51:.*]] = tt.addptr %[[ARG9]], %{{.*}} +// AMD: %[[LOAD_52:.*]] = tt.load %[[ADDPTR_50]] +// AMD: %[[EXPAND_DIMS_53:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32} +// AMD: %[[BROADCAST_54:.*]] = tt.broadcast %[[EXPAND_DIMS_53]] +// AMD: %[[MULI_55:.*]] = arith.muli %{{.*}}, %[[BROADCAST_54]] +// AMD: %[[ADDPTR_56:.*]] = tt.addptr %{{.*}}, %[[MULI_55]] +// AMD: %[[LOAD_57:.*]] = tt.load %[[ADDPTR_56]] +// AMD: %[[LOAD_58:.*]] = tt.load %[[ADDPTR_51]] +// AMD: %[[ADDI_59:.*]] = arith.addi %[[ARG10]], %{{.*}} +// AMD: %[[CMPI_60:.*]] = arith.cmpi slt, %[[ADDI_59]], %{{.*}} +// AMD: %[[SELECT_61:.*]] = arith.select %[[CMPI_60]], %[[ADDI_59]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_62:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_52]], %[[MEMDESC_SUBVIEW_62]] +// AMD: %[[MEMDESC_SUBVIEW_63:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_61]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_57]], %[[MEMDESC_SUBVIEW_63]] +// AMD: scf.yield %[[DOT_49]], %[[ADDPTR_50]], %[[ADDPTR_51]], %[[SELECT_61]], %[[MEMDESC_SUBVIEW_62]], %[[MEMDESC_SUBVIEW_63]], %[[LOAD_58]] +// AMD: } +// AMD: %[[CMPI_21:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[CMPI_22:.*]] = arith.cmpi sge, %{{.*}}, %{{.*}} +// AMD: %[[LOCAL_LOAD_23:.*]] = triton_gpu.local_load %{{.*}}#4 +// AMD: %[[LOCAL_LOAD_24:.*]] = triton_gpu.local_load %{{.*}}#5 +// AMD: %[[IF_25:.*]] = scf.if %[[CMPI_21]] +// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_23]], %[[LOCAL_LOAD_24]], %{{.*}}#0 +// AMD: scf.yield %[[DOT_45]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#0 +// AMD: } +// AMD: %[[ADDPTR_26:.*]] = tt.addptr %{{.*}}#1, %{{.*}} +// AMD: %[[SPLAT_27:.*]] = tt.splat %[[CMPI_22]] +// AMD: %[[LOAD_28:.*]] = tt.load %[[ADDPTR_26]], %[[SPLAT_27]] +// AMD: %[[EXPAND_DIMS_29:.*]] = tt.expand_dims %{{.*}}#6 {axis = 1 : i32} +// AMD: %[[BROADCAST_30:.*]] = tt.broadcast %[[EXPAND_DIMS_29]] +// AMD: %[[MULI_31:.*]] = arith.muli %{{.*}}, %[[BROADCAST_30]] +// AMD: %[[ADDPTR_32:.*]] = tt.addptr %{{.*}}, %[[MULI_31]] +// AMD: %[[SPLAT_33:.*]] = tt.splat %[[CMPI_22]] +// AMD: %[[LOAD_34:.*]] = tt.load %[[ADDPTR_32]], %[[SPLAT_33]] +// AMD: %[[ADDI_35:.*]] = arith.addi %{{.*}}#3, %{{.*}} +// AMD: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ADDI_35]], %{{.*}} +// AMD: %[[SELECT_37:.*]] = arith.select %[[CMPI_36]], %[[ADDI_35]], %{{.*}} +// AMD: %[[MEMDESC_SUBVIEW_38:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_37]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_28]], %[[MEMDESC_SUBVIEW_38]] +// AMD: %[[MEMDESC_SUBVIEW_39:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_37]], %{{.*}}, %{{.*}}] +// AMD: triton_gpu.local_store %[[LOAD_34]], %[[MEMDESC_SUBVIEW_39]] +// AMD: %[[SELECT_40:.*]] = arith.select %[[CMPI_21]], %[[IF_25]], %{{.*}}#0 +// AMD: %[[LOCAL_LOAD_41:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_38]] +// AMD: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_39]] +// AMD: %[[IF_43:.*]] = scf.if %[[CMPI_22]] +// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_41]], %[[LOCAL_LOAD_42]], %[[SELECT_40]] +// AMD: scf.yield %[[DOT_45]] +// AMD: } else { +// AMD: scf.yield %[[SELECT_40]] +// AMD: } +// AMD: %[[SELECT_44:.*]] = arith.select %[[CMPI_22]], %[[IF_43]], %[[SELECT_40]] +// AMD: triton_gpu.local_dealloc %{{.*}} +// AMD: triton_gpu.local_dealloc %{{.*}} + #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}> #C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -module attributes {"triton_gpu.target" = "cuda:86", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -930,7 +1070,7 @@ tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibilit %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> - } + } {tt.num_stages = 3 : i32} tt.return %79#0 : tensor<16x16xf32, #C> } } @@ -945,9 +1085,19 @@ tt.func @indirect_load_shared_layout(%77: tensor<16x16xi64, #BL> {tt.divisibilit // CHECK: triton_gpu.async_copy_global_to_local // CHECK: triton_gpu.memdesc_subview // CHECK: tt.return + +// AMD-LABEL: @kernel_yield_constant +// AMD: tt.load +// AMD: triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store +// AMD: scf.for +// AMD: tt.load +// AMD: triton_gpu.memdesc_subview +// AMD: triton_gpu.local_store +// AMD: tt.return #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:86", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func public @kernel_yield_constant(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %cst1 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #mma> @@ -1001,8 +1151,22 @@ module attributes {"triton_gpu.target" = "cuda:86", "triton_gpu.num-ctas" = 1 : // CHECK: %[[B1BUFFER:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]]] // CHECK: triton_gpu.async_copy_global_to_local {{.*}}, %[[B1BUFFER]] // CHECK: scf.for + +// AMD-LABEL: tt.func public @add_kernel +// AMD: %[[LOAD_11:.*]] = tt.load %{{.*}}, %{{.*}} +// AMD: %[[ADDPTR_12:.*]] = tt.addptr %{{.*}}, %{{.*}} +// AMD: %[[LOAD_13:.*]] = tt.load %[[ADDPTR_12]], %{{.*}} +// AMD: %[[ADDI_14:.*]] = arith.addi %{{.*}}, %{{.*}} +// AMD: %[[SPLAT_15:.*]] = tt.splat %[[ADDI_14]] +// AMD: %[[ADDI_16:.*]] = arith.addi %[[SPLAT_15]], %{{.*}} +// AMD: %[[CMPI_17:.*]] = arith.cmpi slt, %[[ADDI_16]], %{{.*}} +// AMD: %[[ADDPTR_18:.*]] = tt.addptr %{{.*}}, %[[ADDI_16]] +// AMD: %[[LOAD_19:.*]] = tt.load %[[ADDPTR_18]], %[[CMPI_17]] +// AMD: %[[ADDPTR_20:.*]] = tt.addptr %{{.*}}, %[[ADDI_16]] +// AMD: %[[LOAD_21:.*]] = tt.load %[[ADDPTR_20]], %[[CMPI_17]] +// AMD: scf.for #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { %c1024_i32 = arith.constant 1024 : i32 %c0_i32 = arith.constant 0 : i32 @@ -1026,7 +1190,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %15 = arith.addf %12, %14 : tensor<1024xf32, #blocked> %16 = tt.addptr %6, %9 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> tt.store %16, %15, %10 : tensor<1024x!tt.ptr, #blocked> - }{tt.num_stages = 3 : i32} + } {tt.num_stages = 3 : i32} tt.return } } @@ -1067,11 +1231,22 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: %[[COMMIT_2:.*]] = triton_gpu.async_commit_group %[[ASYNC_COPY_5]] // CHECK: scf.yield %[[COMMIT_1]], %[[COMMIT_2]] // CHECK: triton_gpu.local_dealloc %[[BUFFER_1]] + +// AMD-LABEL: tt.func public @nested_loops +// AMD-NOT: triton_gpu.local_alloc +// AMD: scf.for +// AMD: triton_gpu.local_alloc +// AMD: scf.for +// AMD: triton_gpu.local_load +// AMD: tt.dot +// AMD: triton_gpu.local_store +// AMD: scf.yield +// AMD: triton_gpu.local_dealloc #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}> #shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> #shared1 = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> %c1_i32 = arith.constant 1 : i32 @@ -1090,9 +1265,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %9 = tt.addptr %7, %8 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked> scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { %10 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> - %11 = triton_gpu.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !tt.memdesc<16x16xf32, #shared> - %12 = tt.trans %11 {order = array} : !tt.memdesc<16x16xf32, #shared> -> !tt.memdesc<16x16xf32, #shared1> - %13 = triton_gpu.local_load %12 : !tt.memdesc<16x16xf32, #shared1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %11 = triton_gpu.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> + %12 = tt.trans %11 {order = array} : !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> + %13 = triton_gpu.local_load %12 : !tt.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { %14 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> %15 = triton_gpu.convert_layout %14 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> @@ -1115,7 +1290,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : #blocked4 = #triton_gpu.blocked<{sizePerThread = [16, 2, 1], threadsPerWarp = [4, 1, 8], warpsPerCTA = [1, 1, 8], order = [1, 0, 2]}> #blocked5 = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 8], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { tt.func public @int4_matmul_ampere( %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32} @@ -1146,14 +1321,12 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %51 = tt.addptr %50, %47 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> // Check that both loads in the loop are pipelined. - // TODO(jlebar): https://github.com/triton-lang/triton/pull/3472 disables the - // relevant optimization. Once we've reenabled it, we can uncomment this test. // CHECK: scf.for - // COM: CHECK-NOT: tt.load + // CHECK-NOT: tt.load + // CHECK: triton_gpu.async_copy_global_to_local + // CHECK-NOT: tt.load // CHECK: triton_gpu.async_copy_global_to_local - // COM: CHECK-NOT: tt.load - // COM: CHECK: triton_gpu.async_copy_global_to_local - // COM: CHECK-NOT: tt.load + // CHECK-NOT: tt.load // CHECK: scf.yield %54:3 = scf.for %arg9 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg10 = %cst_3, %arg11 = %41, %arg12 = %51) -> (tensor<16x256xf32, #mma>, tensor<16x128x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>) : i32 { %78 = tt.load %arg11 : tensor<16x128x!tt.ptr, #blocked1> @@ -1165,7 +1338,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %84 = arith.sitofp %82 : tensor<64x256xi8, #blocked> to tensor<64x256xf16, #blocked> %85 = tt.join %83, %84 : tensor<64x256xf16, #blocked> -> tensor<64x256x2xf16, #blocked3> %86 = tt.trans %85 {order = array} : tensor<64x256x2xf16, #blocked3> -> tensor<64x2x256xf16, #blocked4> - %87 = tt.reshape %86 {allow_reorder = false} : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5> + %87 = tt.reshape %86 : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5> %88 = triton_gpu.convert_layout %78 : tensor<16x128xf16, #blocked1> -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %89 = triton_gpu.convert_layout %87 : tensor<128x256xf16, #blocked5> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %90 = tt.dot %88, %89, %arg10 : tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x256xf32, #mma> @@ -1182,7 +1355,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // This test triggered some failure in the verifier, so we only // included a simple check for the kernel name. -// CHECK-LABEL: @load_convert_layout +// COMMON-LABEL: @load_convert_layout #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #ALs0 = #triton_gpu.slice<{parent=#AL, dim=0}> @@ -1192,7 +1365,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt.constancy=16: i32}, %76: index, %49: tensor<16x16x!tt.ptr, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -1223,7 +1396,7 @@ tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> - } + } {tt.num_stages = 3 : i32} tt.return %79#0 : tensor<16x16xf32, #C> } } @@ -1233,10 +1406,10 @@ tt.func @load_convert_layout(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i3 // This test captured some ICE in MatmulLoopPipeline pass, so we only // included a simple check for the kernel name. -// CHECK-LABEL: @matmul_indirect_pipeline +// COMMON-LABEL: @matmul_indirect_pipeline #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 2], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { tt.func public @matmul_indirect_pipeline(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %c1_i32 = arith.constant 1 : i32 @@ -1269,18 +1442,18 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %23 = tt.dot %21, %22, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %24 = triton_gpu.convert_layout %23 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %11, %24 : tensor<32x32x!tt.ptr, #blocked> - } + } {tt.num_stages = 3 : i32} tt.return } } // ----- -// CHECK-LABEL: @dont_pipeline_128x1 -// CHECK-NOT: local_load{{.*}}128x1 +// COMMON-LABEL: @dont_pipeline_128x1 +// COMMON-NOT: local_load{{.*}}128x1 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func public @dont_pipeline_128x1(%arg6: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> %c128_i32 = arith.constant 128 : i32 @@ -1319,8 +1492,8 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // Check that the dependencies across ops of different nesting does not cause crash or // incorrect schedule that fails to pipeline. -// CHECK-LABEL: @matmul_nested_ops -// CHECK: triton_gpu.local_load +// COMMON-LABEL: @matmul_nested_ops +// COMMON: triton_gpu.local_load #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -1331,7 +1504,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, %B : !tt.ptr {tt.divisibility = 16 : i32}, @@ -1382,144 +1555,6 @@ tt.func @matmul_nested_ops(%lb : index, %ub : index, %step : index, // ----- -// Pipeline the if ops at the beginning and the end of the loop -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: dot_prologue_epilogue - // CHECK: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} - tt.func @dot_prologue_epilogue(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { - %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> - %cst2 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> - %c0_i32 = arith.constant 0 : i32 - %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> - %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> - %c0_i64 = arith.constant 0 : i64 - %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> - %c1_i32 = arith.constant 1 : i32 - %c8_i32 = arith.constant 8 : i32 - %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %6 = tt.broadcast %2 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> - %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> - %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> - %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // CHECK: %[[C0:.*]] = arith.constant 0 : i32 - // CHECK: scf.for %[[IND_VAR:.*]] = %[[C0]] - // CHECK-NOT load - // CHECK: %[[CND:.*]] = arith.cmpi slt, %[[IND_VAR]], %[[EXT]] - // CHECK: scf.if %[[CND]] - // CHECK: dot - // CHECK: scf.if %[[CND]] - // CHECK: arith.mulf - // CHECK: scf.yield - // CHECK-NOT: tt.addptr - // CHECK: scf.yield - %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { - %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> - %cnd = arith.cmpi slt, %arg3, %ext : i32 - %inc_ptr = scf.if %cnd -> tensor<64x16x!tt.ptr, #blocked> { - %ptr = tt.addptr %arg5, %inc : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - scf.yield %ptr : tensor<64x16x!tt.ptr, #blocked> - } else { - scf.yield %arg5 : tensor<64x16x!tt.ptr, #blocked> - } - %18 = tt.load %inc_ptr : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %acc = tt.dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> - %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { - %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> - scf.yield %acc_zero : tensor<128x16xf32, #mma1> - } else { - scf.yield %acc : tensor<128x16xf32, #mma1> - } - %22 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - %23 = tt.addptr %arg6, %cst2 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - scf.yield %acc_, %22, %23 : tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1> - } - tt.return %17#0 : tensor<128x16xf32, #mma1> - } -} - -// ----- - -// Verify that uses of the ops scheduled in partucular place of the loop (like epilogue if) are correctly scheduled too. -#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> -module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-NOCANON-LABEL: pipeline_downstream_dependencies - // CHECK-NOCANON: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} - tt.func @pipeline_downstream_dependencies(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { - %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> - %cst1 = arith.constant dense<1> : tensor<64x16xi32, #blocked> - %cst2 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> - %c0_i32 = arith.constant 0 : i32 - %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> - %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> - %c0_i64 = arith.constant 0 : i64 - %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> - %c1_i32 = arith.constant 1 : i32 - %c8_i32 = arith.constant 8 : i32 - %2 = tt.splat %arg1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> - %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> - %6 = tt.broadcast %2 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> - %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> - %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - %10 = tt.splat %arg0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> - %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> - %14 = tt.broadcast %10 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> - %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> - %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // CHECK-NOCANON: %[[C0:.*]] = arith.constant 0 : i32 - // CHECK-NOCANON: scf.for %[[IND_VAR:.*]] = %[[C0]] - // CHECK-NOCANON-NOT load - // CHECK-NOCANON: dot - // CHECK-NOCANON: %[[CND:.*]] = arith.cmpi slt, %[[IND_VAR]], %[[EXT]] - // CHECK-NOCANON: %[[IFRET:.*]]:2 = scf.if %[[CND]] - // CHECK-NOCANON: arith.mulf - // CHECK-NOCANON: scf.yield - // CHECK-NOCANON: tt.addptr {{.*}}, %[[IFRET]]#1 - // CHECK-NOCANON: scf.yield - %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { - %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> - %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %acc = tt.dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> - %cnd = arith.cmpi slt, %arg3, %ext : i32 - %if_ret:2 = scf.if %cnd -> (tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked>) { - %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> - scf.yield %acc_zero, %cst : tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked> - } else { - scf.yield %acc, %cst1 : tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked> - } - %22 = tt.addptr %arg5, %if_ret#1 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - %23 = tt.addptr %arg6, %cst2 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> - scf.yield %if_ret#0, %22, %23 : tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1> - } - tt.return %17#0 : tensor<128x16xf32, #mma1> - } -} - -// ----- - // CHECK-LABEL: @masked_add_kernel // CHECK: %[[CONSTANT:.*]] = arith.constant dense<0xFF800000> // CHECK: scf.for @@ -1528,8 +1563,20 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: %[[B:.*]] = triton_gpu.local_load // CHECK: arith.select {{.*}}, %[[B]], %[[CONSTANT]] +// AMD-LABEL: @masked_add_kernel +// AMD: %[[CONSTANT:.*]] = arith.constant dense<0xFF800000> +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: scf.for +// AMD: arith.select +// AMD: arith.addf +// AMD: %[[A:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] +// AMD: %[[B:.*]] = tt.load {{.*}}, %{{.*}}, %[[CONSTANT]] + #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func public @masked_add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} { %c1024_i32 = arith.constant 1024 : i32 %c0_i32 = arith.constant 0 : i32 @@ -1558,41 +1605,3 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : tt.return } } - - -// ----- - -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> -#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> -#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { -// CHECK-LABEL: @matmul_tma -// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3x128x64xf16, #{{.+}}, mutable> -// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3x64x256xf16, #{{.+}}, mutable> -// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3xi64, #{{.+}}, mutable> -// CHECK-COUNT-3: triton_nvidia_gpu.init_barrier -// CHECK-COUNT-4: triton_nvidia_gpu.async_tma_copy_global_to_local -// CHECK: scf.for -// CHECK: triton_nvidia_gpu.wait_barrier -// CHECK-NOT: triton_nvidia_gpu.wait_barrier -// CHECK-COUNT-2: triton_nvidia_gpu.async_tma_copy_global_to_local -// CHECK: scf.yield - tt.func public @matmul_tma(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x256xf32, #mma> { - %c256_i32 = arith.constant 256 : i32 - %c0_i32 = arith.constant 0 : i32 - %c64_i32 = arith.constant 64 : i32 - %c1_i32 = arith.constant 1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> - %0:2 = scf.for %arg3 = %c0_i32 to %c256_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 { - %1 = tt.experimental_descriptor_load %arg0[%c0_i32, %arg5] : !tt.ptr -> tensor<128x64xf16, #blocked> - %2 = triton_gpu.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared> - %3 = tt.experimental_descriptor_load %arg1[%arg5, %c0_i32] : !tt.ptr -> tensor<64x256xf16, #blocked1> - %4 = triton_gpu.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared> - %5 = tt.dot %2, %4, %arg4, inputPrecision = tf32 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x256xf16, #shared> -> tensor<128x256xf32, #mma> - %6 = arith.addi %arg5, %c64_i32 : i32 - scf.yield %5, %6 : tensor<128x256xf32, #mma>, i32 - } - tt.return %0#0 : tensor<128x256xf32, #mma> - } -} diff --git a/test/TritonGPU/ops.mlir b/test/TritonGPU/ops.mlir index e3fd4d2a3..9184a5312 100644 --- a/test/TritonGPU/ops.mlir +++ b/test/TritonGPU/ops.mlir @@ -1,20 +1,35 @@ // RUN: triton-opt --split-input-file %s | FileCheck %s -// CHECK: #[[$WMMA:.*]] = #triton_gpu.amd_wmma +// CHECK: #[[$WMMA_GEN1:.*]] = #triton_gpu.amd_wmma<{{.*}}version = 1{{.*}}> +// CHECK: #[[$WMMA_GEN2:.*]] = #triton_gpu.amd_wmma<{{.*}}version = 2{{.*}}> #blocked = #triton_gpu.blocked<{sizePerThread = [2, 2], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> module attributes {"triton_gpu.target" = "cuda:0", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK-LABEL: wmma_layout tt.func @wmma_layout(%0: tensor<16x16xf16, #blocked>) { - %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #triton_gpu.amd_wmma<{warpsPerCTA = [1, 1]}>> - // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #{{.+}}> -> tensor<16x16xf16, #[[$WMMA]]> + %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 1]}>> + // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #{{.+}}> -> tensor<16x16xf16, #[[$WMMA_GEN1]]> tt.return } // CHECK-LABEL: wmma_dot_op_layout tt.func @wmma_dot_op_layout(%0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) { - %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.amd_wmma<{warpsPerCTA = [1, 1]}>, kWidth = 16}>> - // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #{{.+}}}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$WMMA]], kWidth = 16}>> + %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [1, 1]}>, kWidth = 16}>> + // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #{{.+}}}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$WMMA_GEN1]], kWidth = 16}>> + tt.return + } + + // CHECK-LABEL: wmma_gen2_layout + tt.func @wmma_gen2_layout(%0: tensor<16x16xf16, #blocked>) { + %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 1]}>> + // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #{{.+}}> -> tensor<16x16xf16, #[[$WMMA_GEN2]]> + tt.return + } + + // CHECK-LABEL: wmma_gen2_dot_op_layout + tt.func @wmma_gen2_dot_op_layout(%0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) { + %1 = triton_gpu.convert_layout %0 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [1, 1]}>, kWidth = 8}>> + // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #{{.+}}}>> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$WMMA_GEN2]], kWidth = 8}>> tt.return } } diff --git a/test/TritonGPU/optimize-locality.mlir b/test/TritonGPU/optimize-locality.mlir index 9504fe20e..544299867 100644 --- a/test/TritonGPU/optimize-locality.mlir +++ b/test/TritonGPU/optimize-locality.mlir @@ -4,7 +4,7 @@ // CHECK: %[[INIT_ARG:.*]] = arith.constant dense<0.000000e+00> // CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[INIT_ARG]]) -> {{.*}} // CHECK: %[[LOAD:.*]] = tt.load -// CHECK: tt.reshape %[[LOAD]] {allow_reorder = true, efficient_layout} : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} // CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> // CHECK: arith.addf // CHECK: arith.addf %[[FOR_ARG]], %[[REDUCE]] @@ -207,7 +207,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: %[[INIT_ARG:.*]] = arith.constant dense<0xFF800000> // CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[INIT_ARG]]) -> {{.*}} // CHECK: %[[LOAD:.*]] = tt.load -// CHECK: tt.reshape %[[LOAD]] {allow_reorder = true, efficient_layout} : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} // CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> // CHECK: arith.maximumf // CHECK: arith.maximumf %[[FOR_ARG]], %[[REDUCE]] @@ -314,7 +314,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: %[[CST:.*]] = arith.constant dense<0x7F800000> // CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST]]) -> {{.*}} // CHECK: %[[LOAD:.*]] = tt.load -// CHECK: tt.reshape %[[LOAD]] {allow_reorder = true, efficient_layout} : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} // CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> // CHECK: arith.minimumf // CHECK: arith.minimumf %[[FOR_ARG]], %[[REDUCE]] @@ -421,7 +421,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> // CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST]]) -> {{.*}} // CHECK: %[[LOAD:.*]] = tt.load -// CHECK: tt.reshape %[[LOAD]] {allow_reorder = true, efficient_layout} : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} // CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> // CHECK: arith.mulf // CHECK: arith.mulf %[[FOR_ARG]], %[[REDUCE]] @@ -579,14 +579,14 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-DAG: #[[$BLOCK1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> // CHECK-DAG: #[[$BLOCK2:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> // CHECK-LABEL: optimize_view_layout -// CHECK: %[[R:.+]] = tt.reshape {{.*}} {allow_reorder = true, efficient_layout} : tensor<8x128xf32, #[[$BLOCK0]]> -> tensor<64x16xf32, #[[$BLOCK2]]> +// CHECK: %[[R:.+]] = tt.reshape {{.*}} allow_reorder efficient_layout : tensor<8x128xf32, #[[$BLOCK0]]> -> tensor<64x16xf32, #[[$BLOCK2]]> // CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[R]] : tensor<64x16xf32, #[[$BLOCK2]]> -> tensor<64x16xf32, #[[$BLOCK1]]> // CHECK: "tt.reduce"(%[[C]]) #blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func public @optimize_view_layout(%arg0: tensor<8x128xf32, #blocked>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> { - %0 = tt.reshape %arg0 {allow_reorder = true} : tensor<8x128xf32, #blocked> -> tensor<64x16xf32, #blocked1> + %0 = tt.reshape %arg0 allow_reorder : tensor<8x128xf32, #blocked> -> tensor<64x16xf32, #blocked1> %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): %2 = arith.maximumf %arg1, %arg2 : f32 @@ -595,3 +595,31 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : tt.return %1 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> } } + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#slice = #triton_gpu.slice<{dim = 1, parent = #blocked}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + tt.func public @reduce_for_arg(%arg: tensor<64x128xf32, #blocked>, %arg1: !tt.ptr) { + %c0_i32 = arith.constant 0 : i32 + %c128_i32 = arith.constant 128 : i32 + %c4096_i32 = arith.constant 4096 : i32 + %cst_1 = arith.constant dense<1.000000e+00> : tensor<64x128xf32, #blocked> + %64:1 = scf.for %arg22 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg29 = %arg) -> (tensor<64x128xf32, #blocked>) : i32 { + %129 = "tt.reduce"(%arg29) <{axis = 1 : i32}> ({ + ^bb0(%arg31: f32, %arg32: f32): + %160 = arith.maxnumf %arg31, %arg32 : f32 + tt.reduce.return %160 : f32 + }) : (tensor<64x128xf32, #blocked>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %75 = triton_gpu.convert_layout %129 : tensor<64xf32, #slice> -> tensor<64xf32, #blocked1> + %79 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked1> + %80 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr, #blocked1> + %81 = tt.addptr %80, %79 : tensor<64x!tt.ptr, #blocked1>, tensor<64xi32, #blocked1> + tt.store %81, %75 : tensor<64x!tt.ptr, #blocked1> + %141 = arith.addf %arg29, %cst_1 : tensor<64x128xf32, #blocked> + scf.yield %141 : tensor<64x128xf32, #blocked> + } + tt.return + } +} diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/pipeline-hopper-remove-wait.mlir index 1e3d4d967..74fd2e055 100644 --- a/test/TritonGPU/pipeline-hopper-remove-wait.mlir +++ b/test/TritonGPU/pipeline-hopper-remove-wait.mlir @@ -110,18 +110,18 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %112 = tt.load %111 : tensor<64x128x!tt.ptr, #blocked> %113 = triton_gpu.local_alloc %38 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared> %114 = triton_gpu.local_alloc %90 : (tensor<128x64xf16, #blocked2>) -> !tt.memdesc<128x64xf16, #shared1> - %115 = tt.dot %113, %114, %cst :!tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %115 = triton_nvidia_gpu.warp_group_dot %113, %114, %cst :!tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> %116 = arith.truncf %115 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> %117 = triton_gpu.local_alloc %112 : (tensor<64x128xf16, #blocked>) -> !tt.memdesc<64x128xf16, #shared> %118 = triton_gpu.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> // The first dot gets converted to dot-async + wait. The second one // doesn't have a wait because the first wait is sufficient. - // CHECK: triton_nvidia_gpu.dot_async - // CHECK: triton_nvidia_gpu.dot_wait {{.*}} {pendings = 0 : i32} - // CHECK: triton_nvidia_gpu.dot_async - // CHECK-NOT: triton_nvidia_gpu.dot_wait + // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait // CHECK: scf.yield - %119 = tt.dot %118, %117, %arg23 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> + %119 = triton_nvidia_gpu.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> %120 = arith.mulf %arg24, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %121 = arith.addf %120, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %122 = arith.extsi %c0_i32 : i32 to i64 diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index e6b4d1188..9fbc540b9 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -26,12 +26,12 @@ // CHECK-DAG: %[[B_REM_SMEM:.*]] = triton_gpu.memdesc_subview %[[arg_b0]][%[[C16]], %[[C0]]] // CHECK-DAG: %[[B_REM:.*]] = triton_gpu.local_load %[[B_REM_SMEM]] // CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}} -// CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]] // CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] // CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_A_PREFETCH_SMEM]] // CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] // CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] // CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_B_PREFETCH_SMEM]] +// CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]] // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] module attributes { "triton_gpu.num-warps" = 4 : i32 } { tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ @@ -70,3 +70,178 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr tt.return %loop#4 : tensor<128x128xf32, #C> } } // end module + +// 4 warps +// matmul: 128x16 @ 16x128 -> 128x128 +// CHECK: tt.func @matmul_loop_mixed +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.local_load %[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.local_load %[[B0_PREFETCH_SMEM]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_B_PREFETCH_SMEM]] +// CHECK: tt.dot %[[a0_prefetch]], %[[b0_prefetch]], {{.*}} +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] +module attributes { "triton_gpu.num-warps" = 4 : i32 } { +tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ + %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x16x!tt.ptr, #AL> + %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<16x128x!tt.ptr, #BL> + + %a_mask = arith.constant dense : tensor<128x16xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x16xf8E5M2, #AL> + %b_mask = arith.constant dense : tensor<16x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<16x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x16xi32, #AL> + %b_off = arith.constant dense<4> : tensor<16x128xi32, #BL> + + %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x16x!tt.ptr, #AL> + %a_init = triton_gpu.local_alloc %a_ : (tensor<128x16xf8E5M2, #AL>) -> !tt.memdesc<128x16xf8E5M2, #A> + %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<16x128x!tt.ptr, #BL> + %b_init = triton_gpu.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !tt.memdesc<16x128xf16, #B> + + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !tt.memdesc<128x16xf8E5M2, #A>, !tt.memdesc<16x128xf16, #B>, tensor<128x128xf32, #C>) { + %a_op_ = triton_gpu.local_load %a : !tt.memdesc<128x16xf8E5M2, #A> -> tensor<128x16xf8E5M2, #A_OP> + %a_op = tt.fp_to_fp %a_op_ : tensor<128x16xf8E5M2, #A_OP> -> tensor<128x16xf16, #A_OP> + %b_op = triton_gpu.local_load %b : !tt.memdesc<16x128xf16, #B> -> tensor<16x128xf16, #B_OP> + %c = tt.dot %a_op, %b_op, %prev_c : tensor<128x16xf16, #A_OP> * tensor<16x128xf16, #B_OP> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x16x!tt.ptr, #AL>, tensor<128x16xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<16x128x!tt.ptr, #BL>, tensor<16x128xi32, #BL> + %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<128x16x!tt.ptr, #AL> + %next_a = triton_gpu.local_alloc %next_a_ : (tensor<128x16xf8E5M2, #AL>) -> !tt.memdesc<128x16xf8E5M2, #A> + %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<16x128x!tt.ptr, #BL> + %next_b = triton_gpu.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !tt.memdesc<16x128xf16, #B> + + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !tt.memdesc<128x16xf8E5M2, #A>, !tt.memdesc<16x128xf16, #B>, tensor<128x128xf32, #C> + } + tt.return %loop#4 : tensor<128x128xf32, #C> +} +} // end module + + +// CHECK: tt.func @matmul_loop_yield_no_operand +// CHECK: scf.for +// CHECK: scf.if +// CHECK: tt.store +// CHECK-NOT: scf.yield +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:86", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @matmul_loop_yield_no_operand(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %c32_i32 = arith.constant 32 : i32 + %c31_i32 = arith.constant 31 : i32 + %c0_i32 = arith.constant 0 : i32 + %0 = arith.muli %arg9, %arg10 : i32 + %1 = arith.addi %arg8, %c31_i32 : i32 + %2 = arith.divsi %1, %c32_i32 : i32 + %3 = arith.addi %0, %c31_i32 : i32 + %4 = arith.divsi %3, %c32_i32 : i32 + %5 = arith.muli %1, %4 : i32 + %6 = tt.get_program_id x : i32 + %7 = tt.get_num_programs x : i32 + %8 = tt.splat %arg3 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + scf.for %arg11 = %6 to %5 step %7 : i32 { + %9 = arith.divsi %arg11, %4 : i32 + %10 = arith.remsi %9, %2 : i32 + %11 = tt.load %8 : tensor<32x32x!tt.ptr, #blocked> + %12 = tt.load %8 : tensor<32x32x!tt.ptr, #blocked> + %13 = triton_gpu.convert_layout %12 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %14 = triton_gpu.convert_layout %11 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %15 = tt.dot %13, %14, %cst, inputPrecision = tf32 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %16 = arith.cmpi sgt, %10, %c0_i32 : i32 + %17 = scf.if %16 -> (tensor<32x32xf32, #mma>) { + %21 = tt.dot %13, %14, %15, inputPrecision = tf32 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + scf.yield %21 : tensor<32x32xf32, #mma> + } else { + scf.yield %15 : tensor<32x32xf32, #mma> + } + %18 = tt.splat %arg5 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked1> + %19 = arith.truncf %17 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> + %20 = triton_gpu.convert_layout %19 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked1> + tt.store %18, %20 : tensor<32x32x!tt.ptr, #blocked1> + } + tt.return + } +} + +// ----- + +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#C = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = false}> +#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> + +// CHECK: tt.func @matmul_loop_mixed_amd +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32 +// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[A0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.local_load %[[A0_PREFETCH_SMEM]] +// CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]] +// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview %[[B0:.*]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.local_load %[[B0_PREFETCH_SMEM]] +// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]] +// CHECK-DAG: %[[A_REM_SMEM:.*]] = triton_gpu.memdesc_subview %[[arg_a0]][%[[C0]], %[[C16]]] +// CHECK-DAG: %[[A_REM:.*]] = triton_gpu.local_load %[[A_REM_SMEM]] +// CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]] +// CHECK-DAG: %[[B_REM_SMEM:.*]] = triton_gpu.memdesc_subview %[[arg_b0]][%[[C16]], %[[C0]]] +// CHECK-DAG: %[[B_REM:.*]] = triton_gpu.local_load %[[B_REM_SMEM]] +// CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}} +// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_A_PREFETCH_SMEM]] +// CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]] +// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.memdesc_subview {{.*}}[%[[C0]], %[[C0]]] +// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.local_load %[[NEXT_B_PREFETCH_SMEM]] +// CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]] +// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]] +module attributes { "triton_gpu.num-warps" = 4 : i32 } { +tt.func @matmul_loop_mixed_amd(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) -> tensor<128x128xf32, #C>{ + %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + + %a_mask = arith.constant dense : tensor<128x32xi1, #AL> + %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf8E5M2, #AL> + %b_mask = arith.constant dense : tensor<32x128xi1, #BL> + %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> + %a_init = triton_gpu.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !tt.memdesc<128x32xf8E5M2, #A> + %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %b_init = triton_gpu.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !tt.memdesc<32x128xf16, #B> + + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !tt.memdesc<128x32xf8E5M2, #A>, !tt.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C>) { + %a_op_ = triton_gpu.local_load %a : !tt.memdesc<128x32xf8E5M2, #A> -> tensor<128x32xf8E5M2, #A_OP> + %a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP> + %b_op = triton_gpu.local_load %b : !tt.memdesc<32x128xf16, #B> -> tensor<32x128xf16, #B_OP> + %c = tt.dot %a_op, %b_op, %prev_c : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> + + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> + %next_a = triton_gpu.local_alloc %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> !tt.memdesc<128x32xf8E5M2, #A> + %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> + %next_b = triton_gpu.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !tt.memdesc<32x128xf16, #B> + + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !tt.memdesc<128x32xf8E5M2, #A>, !tt.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C> + } + tt.return %loop#4 : tensor<128x128xf32, #C> +} +} // end module + +// ----- diff --git a/test/TritonGPU/reduce-data-duplication.mlir b/test/TritonGPU/reduce-data-duplication.mlir index 7dd91df04..9fca92c9b 100644 --- a/test/TritonGPU/reduce-data-duplication.mlir +++ b/test/TritonGPU/reduce-data-duplication.mlir @@ -1,8 +1,8 @@ // RUN: triton-opt %s -split-input-file -tritongpu-reduce-data-duplication | FileCheck %s -// CHECK: #[[SHARED:.*]] = #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false} -// CHECK: apply_swizzle -// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !tt.memdesc<16x256xf16, #[[SHARED]]> +// CHECK: #[[$SHARED:.*]] = #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false} +// CHECK-LABEL: apply_swizzle +// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !tt.memdesc<16x256xf16, #[[$SHARED]], #triton_gpu.shared_memory> #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> @@ -12,3 +12,31 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : tt.return } } + +// ----- + +// CHECK-LABEL: conversion_shortcut_blocked_dotop_warp32 +// CHECK-NOT: triton_gpu.local_alloc +// CHECK: triton_gpu.convert_layout +// CHECK-NOT: triton_gpu.local_alloc +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 64], threadsPerWarp = [16, 2], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @conversion_shortcut_blocked_dotop_warp32(%arg0: tensor<64x64xf16, #blocked>) { + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} + +// ----- + +// CHECK-LABEL: conversion_shortcut_blocked_dotop_warp64 +// CHECK-NOT: triton_gpu.local_alloc +// CHECK: triton_gpu.convert_layout +// CHECK-NOT: triton_gpu.local_alloc +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [0, 1]}> +module attributes {"triton_gpu.target" = "hip:gfx940", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @conversion_shortcut_blocked_dotop_warp64(%arg0: tensor<64x64xf16, #blocked>) { + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + tt.return + } +} diff --git a/test/TritonGPU/reorder-instructions.mlir b/test/TritonGPU/reorder-instructions.mlir index 0499f4b44..dff1e6b60 100644 --- a/test/TritonGPU/reorder-instructions.mlir +++ b/test/TritonGPU/reorder-instructions.mlir @@ -26,20 +26,20 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // CHECK-LABEL: sink_convert_dealloc // CHECK: triton_gpu.async_wait {num = 0 : i32} -// CHECK: triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared> -// CHECK: triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared> +// CHECK: triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared, mutable> +// CHECK: triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared, mutable> // CHECK: %3 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func public @sink_convert_dealloc(%arg0: tensor<32x32xf32, #blocked>) attributes {noinline = false} { - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared> - %1 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable> + %1 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable> %2 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> triton_gpu.async_wait {num = 0 : i32} - triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared> - triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared> + triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared, mutable> + triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared, mutable> %3 = arith.addf %2, %2 : tensor<32x32xf32, #blocked1> tt.return } diff --git a/test/TritonGPU/repro/sort_with_index.mlir b/test/TritonGPU/repro/sort_with_index.mlir deleted file mode 100644 index a155ee800..000000000 --- a/test/TritonGPU/repro/sort_with_index.mlir +++ /dev/null @@ -1,1911 +0,0 @@ -// RUN: triton-opt %s -tritongpu-remove-layout-conversions | FileCheck %s -// Reproducer for https://github.com/pytorch/pytorch/issues/130101 -// This is difficult to minimize as it specifically happens when a long slice of -// operations are being rematerialized that reuses the same node with two different -// layouts. - -// CHECK: tt.return -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}> -#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [8, 2, 2], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}> -#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [4, 2, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}> -#blocked4 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 2, 8], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}> -#blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}> -#blocked6 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 2, 1], order = [2, 1, 0]}> -#blocked7 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 2], order = [2, 1, 0]}> -#blocked8 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}> -#blocked9 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [16, 2, 1], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:86", "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func public @triton_(%arg0: tensor<1x256xi8, #blocked>, %arg1: tensor<1x256xi8, #blocked>, %arg2: tensor<1x256xi8, #blocked>, %arg3: tensor<1x256xi8, #blocked>, %arg4: tensor<1x256xi8, #blocked>, %arg5: tensor<1x256xi8, #blocked>, %arg6: tensor<1x256xi8, #blocked>, %arg7: tensor<1x256xi8, #blocked>, %arg8: tensor<1x256xi8, #blocked>, %arg9: tensor<1x256xi8, #blocked>, %arg10: tensor<1x256xi8, #blocked>, %arg11: tensor<1x256xi8, #blocked>, %arg12: tensor<1x256xi8, #blocked>, %arg13: tensor<1x256xi8, #blocked>, %arg14: tensor<1x256xi8, #blocked>, %arg15: tensor<1x256xi8, #blocked>, %arg16: tensor<1x256xi8, #blocked>, %arg17: tensor<1x256xi8, #blocked>, %arg18: tensor<1x256xi8, #blocked>, %arg19: tensor<1x256xi8, #blocked>, %arg20: tensor<1x256xi8, #blocked>, %arg21: tensor<1x256xi8, #blocked>, %arg22: tensor<1x256xi8, #blocked>, %arg23: tensor<1x256xi8, #blocked>, %arg24: tensor<1x256xi8, #blocked>, %arg25: tensor<1x256xi8, #blocked>, %arg26: tensor<1x256xi8, #blocked>, %arg27: tensor<1x256xi8, #blocked>, %arg28: tensor<1x256xi8, #blocked>, %arg29: tensor<1x256xi8, #blocked>, %arg30: tensor<1x256xi8, #blocked>, %arg31: tensor<1x256xi8, #blocked>, %arg32: tensor<1x256xi8, #blocked>, %arg33: tensor<1x256xi8, #blocked>, %arg34: tensor<1x256xi8, #blocked>, %arg35: tensor<1x256xi8, #blocked>) -> tensor<1x256xi32, #blocked1> attributes {noinline = false} { - %cst = arith.constant dense<1> : tensor<1x2x1xi32, #blocked2> - %cst_0 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked3> - %cst_1 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked4> - %cst_2 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked5> - %cst_3 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked6> - %cst_4 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked7> - %cst_5 = arith.constant dense<0> : tensor<1x256xi32, #blocked8> - %cst_6 = arith.constant dense<1> : tensor<1x2x1xi32, #blocked9> - %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked8}>> - %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked8}>> -> tensor<1x256xi32, #blocked8> - %2 = tt.reshape %1 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<128x2x1xi32, #blocked9> - %3 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked2}>}>> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked2}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked2}>> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked2}>> -> tensor<1x2x1xi32, #blocked2> - %6 = tt.broadcast %5 : tensor<1x2x1xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %7 = tt.reshape %6 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %8 = arith.trunci %5 : tensor<1x2x1xi32, #blocked2> to tensor<1x2x1xi8, #blocked2> - %9 = arith.extsi %8 : tensor<1x2x1xi8, #blocked2> to tensor<1x2x1xi32, #blocked2> - %10 = tt.broadcast %9 : tensor<1x2x1xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %11 = tt.broadcast %8 : tensor<1x2x1xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %12 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked3}>}>> - %13 = tt.expand_dims %12 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked3}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked3}>> - %14 = tt.expand_dims %13 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked3}>> -> tensor<1x2x1xi32, #blocked3> - %15 = tt.broadcast %14 : tensor<1x2x1xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %16 = tt.reshape %15 {allow_reorder = false} : tensor<32x2x4xi32, #blocked3> -> tensor<1x256xi32, #blocked8> - %17 = arith.trunci %14 : tensor<1x2x1xi32, #blocked3> to tensor<1x2x1xi8, #blocked3> - %18 = arith.extsi %17 : tensor<1x2x1xi8, #blocked3> to tensor<1x2x1xi32, #blocked3> - %19 = tt.broadcast %18 : tensor<1x2x1xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %20 = tt.broadcast %17 : tensor<1x2x1xi8, #blocked3> -> tensor<32x2x4xi8, #blocked3> - %21 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked4}>}>> - %22 = tt.expand_dims %21 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked4}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked4}>> - %23 = tt.expand_dims %22 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked4}>> -> tensor<1x2x1xi32, #blocked4> - %24 = tt.broadcast %23 : tensor<1x2x1xi32, #blocked4> -> tensor<16x2x8xi32, #blocked4> - %25 = tt.reshape %24 {allow_reorder = false} : tensor<16x2x8xi32, #blocked4> -> tensor<1x256xi32, #blocked8> - %26 = arith.trunci %23 : tensor<1x2x1xi32, #blocked4> to tensor<1x2x1xi8, #blocked4> - %27 = arith.extsi %26 : tensor<1x2x1xi8, #blocked4> to tensor<1x2x1xi32, #blocked4> - %28 = tt.broadcast %27 : tensor<1x2x1xi32, #blocked4> -> tensor<16x2x8xi32, #blocked4> - %29 = tt.broadcast %26 : tensor<1x2x1xi8, #blocked4> -> tensor<16x2x8xi8, #blocked4> - %30 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked5}>}>> - %31 = tt.expand_dims %30 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked5}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked5}>> - %32 = tt.expand_dims %31 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked5}>> -> tensor<1x2x1xi32, #blocked5> - %33 = tt.broadcast %32 : tensor<1x2x1xi32, #blocked5> -> tensor<8x2x16xi32, #blocked5> - %34 = tt.reshape %33 {allow_reorder = false} : tensor<8x2x16xi32, #blocked5> -> tensor<1x256xi32, #blocked8> - %35 = arith.trunci %32 : tensor<1x2x1xi32, #blocked5> to tensor<1x2x1xi8, #blocked5> - %36 = arith.extsi %35 : tensor<1x2x1xi8, #blocked5> to tensor<1x2x1xi32, #blocked5> - %37 = tt.broadcast %36 : tensor<1x2x1xi32, #blocked5> -> tensor<8x2x16xi32, #blocked5> - %38 = tt.broadcast %35 : tensor<1x2x1xi8, #blocked5> -> tensor<8x2x16xi8, #blocked5> - %39 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked6}>}>> - %40 = tt.expand_dims %39 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked6}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked6}>> - %41 = tt.expand_dims %40 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked6}>> -> tensor<1x2x1xi32, #blocked6> - %42 = tt.broadcast %41 : tensor<1x2x1xi32, #blocked6> -> tensor<4x2x32xi32, #blocked6> - %43 = tt.reshape %42 {allow_reorder = false} : tensor<4x2x32xi32, #blocked6> -> tensor<1x256xi32, #blocked8> - %44 = arith.trunci %41 : tensor<1x2x1xi32, #blocked6> to tensor<1x2x1xi8, #blocked6> - %45 = arith.extsi %44 : tensor<1x2x1xi8, #blocked6> to tensor<1x2x1xi32, #blocked6> - %46 = tt.broadcast %45 : tensor<1x2x1xi32, #blocked6> -> tensor<4x2x32xi32, #blocked6> - %47 = tt.broadcast %44 : tensor<1x2x1xi8, #blocked6> -> tensor<4x2x32xi8, #blocked6> - %48 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked7}>}>> - %49 = tt.expand_dims %48 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked7}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked7}>> - %50 = tt.expand_dims %49 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked7}>> -> tensor<1x2x1xi32, #blocked7> - %51 = tt.broadcast %50 : tensor<1x2x1xi32, #blocked7> -> tensor<2x2x64xi32, #blocked7> - %52 = tt.reshape %51 {allow_reorder = false} : tensor<2x2x64xi32, #blocked7> -> tensor<1x256xi32, #blocked8> - %53 = tt.broadcast %50 : tensor<1x2x1xi32, #blocked7> -> tensor<1x2x128xi32, #blocked7> - %54 = tt.reshape %53 {allow_reorder = false} : tensor<1x2x128xi32, #blocked7> -> tensor<1x256xi32, #blocked8> - %55 = arith.trunci %50 : tensor<1x2x1xi32, #blocked7> to tensor<1x2x1xi8, #blocked7> - %56 = arith.extsi %55 : tensor<1x2x1xi8, #blocked7> to tensor<1x2x1xi32, #blocked7> - %57 = tt.broadcast %56 : tensor<1x2x1xi32, #blocked7> -> tensor<2x2x64xi32, #blocked7> - %58 = tt.broadcast %56 : tensor<1x2x1xi32, #blocked7> -> tensor<1x2x128xi32, #blocked7> - %59 = tt.broadcast %55 : tensor<1x2x1xi8, #blocked7> -> tensor<2x2x64xi8, #blocked7> - %60 = tt.broadcast %55 : tensor<1x2x1xi8, #blocked7> -> tensor<1x2x128xi8, #blocked7> - %61 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked9}>}>> - %62 = tt.expand_dims %61 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked9}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked9}>> - %63 = tt.expand_dims %62 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked9}>> -> tensor<1x2x1xi32, #blocked9> - %64 = arith.trunci %63 : tensor<1x2x1xi32, #blocked9> to tensor<1x2x1xi8, #blocked9> - %65 = arith.extsi %64 : tensor<1x2x1xi8, #blocked9> to tensor<1x2x1xi32, #blocked9> - %66 = tt.broadcast %65 : tensor<1x2x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %67 = arith.muli %2, %66 : tensor<128x2x1xi32, #blocked9> - %68 = "tt.reduce"(%67) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %69 = tt.expand_dims %68 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %70 = tt.broadcast %69 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %71 = tt.reshape %70 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %72 = tt.broadcast %64 : tensor<1x2x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %73 = arith.subi %cst_6, %65 : tensor<1x2x1xi32, #blocked9> - %74 = arith.trunci %73 : tensor<1x2x1xi32, #blocked9> to tensor<1x2x1xi8, #blocked9> - %75 = tt.broadcast %74 : tensor<1x2x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %76 = arith.extsi %74 : tensor<1x2x1xi8, #blocked9> to tensor<1x2x1xi32, #blocked9> - %77 = tt.broadcast %76 : tensor<1x2x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %78 = arith.muli %2, %77 : tensor<128x2x1xi32, #blocked9> - %79 = "tt.reduce"(%78) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %80 = tt.expand_dims %79 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %81 = tt.broadcast %80 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %82 = tt.reshape %81 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %83 = arith.cmpi sgt, %82, %71 : tensor<1x256xi32, #blocked8> - %84 = arith.xori %82, %71 : tensor<1x256xi32, #blocked8> - %85 = arith.subi %cst_4, %56 : tensor<1x2x1xi32, #blocked7> - %86 = arith.trunci %85 : tensor<1x2x1xi32, #blocked7> to tensor<1x2x1xi8, #blocked7> - %87 = tt.broadcast %86 : tensor<1x2x1xi8, #blocked7> -> tensor<2x2x64xi8, #blocked7> - %88 = arith.extsi %86 : tensor<1x2x1xi8, #blocked7> to tensor<1x2x1xi32, #blocked7> - %89 = tt.broadcast %88 : tensor<1x2x1xi32, #blocked7> -> tensor<2x2x64xi32, #blocked7> - %90 = tt.broadcast %88 : tensor<1x2x1xi32, #blocked7> -> tensor<1x2x128xi32, #blocked7> - %91 = tt.broadcast %86 : tensor<1x2x1xi8, #blocked7> -> tensor<1x2x128xi8, #blocked7> - %92 = arith.subi %cst_3, %45 : tensor<1x2x1xi32, #blocked6> - %93 = arith.trunci %92 : tensor<1x2x1xi32, #blocked6> to tensor<1x2x1xi8, #blocked6> - %94 = tt.broadcast %93 : tensor<1x2x1xi8, #blocked6> -> tensor<4x2x32xi8, #blocked6> - %95 = arith.extsi %93 : tensor<1x2x1xi8, #blocked6> to tensor<1x2x1xi32, #blocked6> - %96 = tt.broadcast %95 : tensor<1x2x1xi32, #blocked6> -> tensor<4x2x32xi32, #blocked6> - %97 = arith.subi %cst_2, %36 : tensor<1x2x1xi32, #blocked5> - %98 = arith.trunci %97 : tensor<1x2x1xi32, #blocked5> to tensor<1x2x1xi8, #blocked5> - %99 = tt.broadcast %98 : tensor<1x2x1xi8, #blocked5> -> tensor<8x2x16xi8, #blocked5> - %100 = arith.extsi %98 : tensor<1x2x1xi8, #blocked5> to tensor<1x2x1xi32, #blocked5> - %101 = tt.broadcast %100 : tensor<1x2x1xi32, #blocked5> -> tensor<8x2x16xi32, #blocked5> - %102 = arith.subi %cst_1, %27 : tensor<1x2x1xi32, #blocked4> - %103 = arith.trunci %102 : tensor<1x2x1xi32, #blocked4> to tensor<1x2x1xi8, #blocked4> - %104 = tt.broadcast %103 : tensor<1x2x1xi8, #blocked4> -> tensor<16x2x8xi8, #blocked4> - %105 = arith.extsi %103 : tensor<1x2x1xi8, #blocked4> to tensor<1x2x1xi32, #blocked4> - %106 = tt.broadcast %105 : tensor<1x2x1xi32, #blocked4> -> tensor<16x2x8xi32, #blocked4> - %107 = arith.subi %cst_0, %18 : tensor<1x2x1xi32, #blocked3> - %108 = arith.trunci %107 : tensor<1x2x1xi32, #blocked3> to tensor<1x2x1xi8, #blocked3> - %109 = tt.broadcast %108 : tensor<1x2x1xi8, #blocked3> -> tensor<32x2x4xi8, #blocked3> - %110 = arith.extsi %108 : tensor<1x2x1xi8, #blocked3> to tensor<1x2x1xi32, #blocked3> - %111 = tt.broadcast %110 : tensor<1x2x1xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %112 = arith.subi %cst, %9 : tensor<1x2x1xi32, #blocked2> - %113 = arith.trunci %112 : tensor<1x2x1xi32, #blocked2> to tensor<1x2x1xi8, #blocked2> - %114 = tt.broadcast %113 : tensor<1x2x1xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %115 = arith.extsi %113 : tensor<1x2x1xi8, #blocked2> to tensor<1x2x1xi32, #blocked2> - %116 = tt.broadcast %115 : tensor<1x2x1xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %117 = triton_gpu.convert_layout %arg0 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %118 = tt.reshape %117 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<128x2x1xi8, #blocked9> - %119 = arith.muli %118, %75 : tensor<128x2x1xi8, #blocked9> - %120 = "tt.reduce"(%119) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %121 = tt.expand_dims %120 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %122 = tt.broadcast %121 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %123 = tt.reshape %122 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %124 = arith.muli %118, %72 : tensor<128x2x1xi8, #blocked9> - %125 = "tt.reduce"(%124) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %126 = tt.expand_dims %125 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %127 = tt.broadcast %126 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %128 = tt.reshape %127 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %129 = arith.cmpi slt, %123, %128 : tensor<1x256xi8, #blocked8> - %130 = arith.cmpi eq, %123, %128 : tensor<1x256xi8, #blocked8> - %131 = arith.andi %130, %83 : tensor<1x256xi1, #blocked8> - %132 = arith.ori %129, %131 : tensor<1x256xi1, #blocked8> - %133 = arith.extui %132 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %134 = arith.xori %133, %7 : tensor<1x256xi32, #blocked8> - %135 = arith.cmpi ne, %134, %cst_5 : tensor<1x256xi32, #blocked8> - %136 = arith.select %135, %84, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %137 = arith.xori %1, %136 : tensor<1x256xi32, #blocked8> - %138 = tt.reshape %137 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<64x2x2xi32, #blocked2> - %139 = arith.muli %138, %116 : tensor<64x2x2xi32, #blocked2> - %140 = "tt.reduce"(%139) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<64x2x2xi32, #blocked2>) -> tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %141 = tt.expand_dims %140 {axis = 1 : i32} : tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi32, #blocked2> - %142 = tt.broadcast %141 : tensor<64x1x2xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %143 = tt.reshape %142 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %144 = arith.muli %138, %10 : tensor<64x2x2xi32, #blocked2> - %145 = "tt.reduce"(%144) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<64x2x2xi32, #blocked2>) -> tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %146 = tt.expand_dims %145 {axis = 1 : i32} : tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi32, #blocked2> - %147 = tt.broadcast %146 : tensor<64x1x2xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %148 = tt.reshape %147 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %149 = arith.cmpi sgt, %143, %148 : tensor<1x256xi32, #blocked8> - %150 = arith.xori %143, %148 : tensor<1x256xi32, #blocked8> - %151 = triton_gpu.convert_layout %arg1 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %152 = tt.reshape %151 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<64x2x2xi8, #blocked2> - %153 = arith.muli %152, %114 : tensor<64x2x2xi8, #blocked2> - %154 = "tt.reduce"(%153) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<64x2x2xi8, #blocked2>) -> tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %155 = tt.expand_dims %154 {axis = 1 : i32} : tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi8, #blocked2> - %156 = tt.broadcast %155 : tensor<64x1x2xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %157 = tt.reshape %156 {allow_reorder = false} : tensor<64x2x2xi8, #blocked2> -> tensor<1x256xi8, #blocked8> - %158 = arith.muli %152, %11 : tensor<64x2x2xi8, #blocked2> - %159 = "tt.reduce"(%158) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<64x2x2xi8, #blocked2>) -> tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %160 = tt.expand_dims %159 {axis = 1 : i32} : tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi8, #blocked2> - %161 = tt.broadcast %160 : tensor<64x1x2xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %162 = tt.reshape %161 {allow_reorder = false} : tensor<64x2x2xi8, #blocked2> -> tensor<1x256xi8, #blocked8> - %163 = arith.cmpi slt, %157, %162 : tensor<1x256xi8, #blocked8> - %164 = arith.cmpi eq, %157, %162 : tensor<1x256xi8, #blocked8> - %165 = arith.andi %164, %149 : tensor<1x256xi1, #blocked8> - %166 = arith.ori %163, %165 : tensor<1x256xi1, #blocked8> - %167 = arith.extui %166 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %168 = arith.xori %167, %16 : tensor<1x256xi32, #blocked8> - %169 = arith.cmpi ne, %168, %cst_5 : tensor<1x256xi32, #blocked8> - %170 = arith.select %169, %150, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %171 = arith.xori %137, %170 : tensor<1x256xi32, #blocked8> - %172 = tt.reshape %171 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<128x2x1xi32, #blocked9> - %173 = arith.muli %172, %77 : tensor<128x2x1xi32, #blocked9> - %174 = "tt.reduce"(%173) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %175 = tt.expand_dims %174 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %176 = tt.broadcast %175 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %177 = tt.reshape %176 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %178 = arith.muli %172, %66 : tensor<128x2x1xi32, #blocked9> - %179 = "tt.reduce"(%178) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %180 = tt.expand_dims %179 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %181 = tt.broadcast %180 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %182 = tt.reshape %181 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %183 = arith.cmpi sgt, %177, %182 : tensor<1x256xi32, #blocked8> - %184 = arith.xori %177, %182 : tensor<1x256xi32, #blocked8> - %185 = triton_gpu.convert_layout %arg2 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %186 = tt.reshape %185 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<128x2x1xi8, #blocked9> - %187 = arith.muli %186, %75 : tensor<128x2x1xi8, #blocked9> - %188 = "tt.reduce"(%187) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %189 = tt.expand_dims %188 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %190 = tt.broadcast %189 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %191 = tt.reshape %190 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %192 = arith.muli %186, %72 : tensor<128x2x1xi8, #blocked9> - %193 = "tt.reduce"(%192) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %194 = tt.expand_dims %193 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %195 = tt.broadcast %194 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %196 = tt.reshape %195 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %197 = arith.cmpi slt, %191, %196 : tensor<1x256xi8, #blocked8> - %198 = arith.cmpi eq, %191, %196 : tensor<1x256xi8, #blocked8> - %199 = arith.andi %198, %183 : tensor<1x256xi1, #blocked8> - %200 = arith.ori %197, %199 : tensor<1x256xi1, #blocked8> - %201 = arith.extui %200 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %202 = arith.xori %201, %16 : tensor<1x256xi32, #blocked8> - %203 = arith.cmpi ne, %202, %cst_5 : tensor<1x256xi32, #blocked8> - %204 = arith.select %203, %184, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %205 = arith.xori %171, %204 : tensor<1x256xi32, #blocked8> - %206 = tt.reshape %205 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<32x2x4xi32, #blocked3> - %207 = arith.muli %206, %111 : tensor<32x2x4xi32, #blocked3> - %208 = "tt.reduce"(%207) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<32x2x4xi32, #blocked3>) -> tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %209 = tt.expand_dims %208 {axis = 1 : i32} : tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi32, #blocked3> - %210 = tt.broadcast %209 : tensor<32x1x4xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %211 = tt.reshape %210 {allow_reorder = false} : tensor<32x2x4xi32, #blocked3> -> tensor<1x256xi32, #blocked8> - %212 = arith.muli %206, %19 : tensor<32x2x4xi32, #blocked3> - %213 = "tt.reduce"(%212) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<32x2x4xi32, #blocked3>) -> tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %214 = tt.expand_dims %213 {axis = 1 : i32} : tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi32, #blocked3> - %215 = tt.broadcast %214 : tensor<32x1x4xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %216 = tt.reshape %215 {allow_reorder = false} : tensor<32x2x4xi32, #blocked3> -> tensor<1x256xi32, #blocked8> - %217 = arith.cmpi sgt, %211, %216 : tensor<1x256xi32, #blocked8> - %218 = arith.xori %211, %216 : tensor<1x256xi32, #blocked8> - %219 = triton_gpu.convert_layout %arg3 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %220 = tt.reshape %219 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<32x2x4xi8, #blocked3> - %221 = arith.muli %220, %109 : tensor<32x2x4xi8, #blocked3> - %222 = "tt.reduce"(%221) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<32x2x4xi8, #blocked3>) -> tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %223 = tt.expand_dims %222 {axis = 1 : i32} : tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi8, #blocked3> - %224 = tt.broadcast %223 : tensor<32x1x4xi8, #blocked3> -> tensor<32x2x4xi8, #blocked3> - %225 = tt.reshape %224 {allow_reorder = false} : tensor<32x2x4xi8, #blocked3> -> tensor<1x256xi8, #blocked8> - %226 = arith.muli %220, %20 : tensor<32x2x4xi8, #blocked3> - %227 = "tt.reduce"(%226) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<32x2x4xi8, #blocked3>) -> tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %228 = tt.expand_dims %227 {axis = 1 : i32} : tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi8, #blocked3> - %229 = tt.broadcast %228 : tensor<32x1x4xi8, #blocked3> -> tensor<32x2x4xi8, #blocked3> - %230 = tt.reshape %229 {allow_reorder = false} : tensor<32x2x4xi8, #blocked3> -> tensor<1x256xi8, #blocked8> - %231 = arith.cmpi slt, %225, %230 : tensor<1x256xi8, #blocked8> - %232 = arith.cmpi eq, %225, %230 : tensor<1x256xi8, #blocked8> - %233 = arith.andi %232, %217 : tensor<1x256xi1, #blocked8> - %234 = arith.ori %231, %233 : tensor<1x256xi1, #blocked8> - %235 = arith.extui %234 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %236 = arith.xori %235, %25 : tensor<1x256xi32, #blocked8> - %237 = arith.cmpi ne, %236, %cst_5 : tensor<1x256xi32, #blocked8> - %238 = arith.select %237, %218, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %239 = arith.xori %205, %238 : tensor<1x256xi32, #blocked8> - %240 = tt.reshape %239 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<64x2x2xi32, #blocked2> - %241 = arith.muli %240, %116 : tensor<64x2x2xi32, #blocked2> - %242 = "tt.reduce"(%241) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<64x2x2xi32, #blocked2>) -> tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %243 = tt.expand_dims %242 {axis = 1 : i32} : tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi32, #blocked2> - %244 = tt.broadcast %243 : tensor<64x1x2xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %245 = tt.reshape %244 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %246 = arith.muli %240, %10 : tensor<64x2x2xi32, #blocked2> - %247 = "tt.reduce"(%246) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<64x2x2xi32, #blocked2>) -> tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %248 = tt.expand_dims %247 {axis = 1 : i32} : tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi32, #blocked2> - %249 = tt.broadcast %248 : tensor<64x1x2xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %250 = tt.reshape %249 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %251 = arith.cmpi sgt, %245, %250 : tensor<1x256xi32, #blocked8> - %252 = arith.xori %245, %250 : tensor<1x256xi32, #blocked8> - %253 = triton_gpu.convert_layout %arg4 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %254 = tt.reshape %253 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<64x2x2xi8, #blocked2> - %255 = arith.muli %254, %114 : tensor<64x2x2xi8, #blocked2> - %256 = "tt.reduce"(%255) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<64x2x2xi8, #blocked2>) -> tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %257 = tt.expand_dims %256 {axis = 1 : i32} : tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi8, #blocked2> - %258 = tt.broadcast %257 : tensor<64x1x2xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %259 = tt.reshape %258 {allow_reorder = false} : tensor<64x2x2xi8, #blocked2> -> tensor<1x256xi8, #blocked8> - %260 = arith.muli %254, %11 : tensor<64x2x2xi8, #blocked2> - %261 = "tt.reduce"(%260) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<64x2x2xi8, #blocked2>) -> tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %262 = tt.expand_dims %261 {axis = 1 : i32} : tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi8, #blocked2> - %263 = tt.broadcast %262 : tensor<64x1x2xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %264 = tt.reshape %263 {allow_reorder = false} : tensor<64x2x2xi8, #blocked2> -> tensor<1x256xi8, #blocked8> - %265 = arith.cmpi slt, %259, %264 : tensor<1x256xi8, #blocked8> - %266 = arith.cmpi eq, %259, %264 : tensor<1x256xi8, #blocked8> - %267 = arith.andi %266, %251 : tensor<1x256xi1, #blocked8> - %268 = arith.ori %265, %267 : tensor<1x256xi1, #blocked8> - %269 = arith.extui %268 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %270 = arith.xori %269, %25 : tensor<1x256xi32, #blocked8> - %271 = arith.cmpi ne, %270, %cst_5 : tensor<1x256xi32, #blocked8> - %272 = arith.select %271, %252, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %273 = arith.xori %239, %272 : tensor<1x256xi32, #blocked8> - %274 = tt.reshape %273 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<128x2x1xi32, #blocked9> - %275 = arith.muli %274, %77 : tensor<128x2x1xi32, #blocked9> - %276 = "tt.reduce"(%275) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %277 = tt.expand_dims %276 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %278 = tt.broadcast %277 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %279 = tt.reshape %278 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %280 = arith.muli %274, %66 : tensor<128x2x1xi32, #blocked9> - %281 = "tt.reduce"(%280) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %282 = tt.expand_dims %281 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %283 = tt.broadcast %282 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %284 = tt.reshape %283 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %285 = arith.cmpi sgt, %279, %284 : tensor<1x256xi32, #blocked8> - %286 = arith.xori %279, %284 : tensor<1x256xi32, #blocked8> - %287 = triton_gpu.convert_layout %arg5 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %288 = tt.reshape %287 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<128x2x1xi8, #blocked9> - %289 = arith.muli %288, %75 : tensor<128x2x1xi8, #blocked9> - %290 = "tt.reduce"(%289) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %291 = tt.expand_dims %290 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %292 = tt.broadcast %291 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %293 = tt.reshape %292 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %294 = arith.muli %288, %72 : tensor<128x2x1xi8, #blocked9> - %295 = "tt.reduce"(%294) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %296 = tt.expand_dims %295 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %297 = tt.broadcast %296 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %298 = tt.reshape %297 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %299 = arith.cmpi slt, %293, %298 : tensor<1x256xi8, #blocked8> - %300 = arith.cmpi eq, %293, %298 : tensor<1x256xi8, #blocked8> - %301 = arith.andi %300, %285 : tensor<1x256xi1, #blocked8> - %302 = arith.ori %299, %301 : tensor<1x256xi1, #blocked8> - %303 = arith.extui %302 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %304 = arith.xori %303, %25 : tensor<1x256xi32, #blocked8> - %305 = arith.cmpi ne, %304, %cst_5 : tensor<1x256xi32, #blocked8> - %306 = arith.select %305, %286, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %307 = arith.xori %273, %306 : tensor<1x256xi32, #blocked8> - %308 = tt.reshape %307 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<16x2x8xi32, #blocked4> - %309 = arith.muli %308, %106 : tensor<16x2x8xi32, #blocked4> - %310 = "tt.reduce"(%309) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<16x2x8xi32, #blocked4>) -> tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %311 = tt.expand_dims %310 {axis = 1 : i32} : tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi32, #blocked4> - %312 = tt.broadcast %311 : tensor<16x1x8xi32, #blocked4> -> tensor<16x2x8xi32, #blocked4> - %313 = tt.reshape %312 {allow_reorder = false} : tensor<16x2x8xi32, #blocked4> -> tensor<1x256xi32, #blocked8> - %314 = arith.muli %308, %28 : tensor<16x2x8xi32, #blocked4> - %315 = "tt.reduce"(%314) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<16x2x8xi32, #blocked4>) -> tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %316 = tt.expand_dims %315 {axis = 1 : i32} : tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi32, #blocked4> - %317 = tt.broadcast %316 : tensor<16x1x8xi32, #blocked4> -> tensor<16x2x8xi32, #blocked4> - %318 = tt.reshape %317 {allow_reorder = false} : tensor<16x2x8xi32, #blocked4> -> tensor<1x256xi32, #blocked8> - %319 = arith.cmpi sgt, %313, %318 : tensor<1x256xi32, #blocked8> - %320 = arith.xori %313, %318 : tensor<1x256xi32, #blocked8> - %321 = triton_gpu.convert_layout %arg6 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %322 = tt.reshape %321 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<16x2x8xi8, #blocked4> - %323 = arith.muli %322, %104 : tensor<16x2x8xi8, #blocked4> - %324 = "tt.reduce"(%323) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<16x2x8xi8, #blocked4>) -> tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %325 = tt.expand_dims %324 {axis = 1 : i32} : tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi8, #blocked4> - %326 = tt.broadcast %325 : tensor<16x1x8xi8, #blocked4> -> tensor<16x2x8xi8, #blocked4> - %327 = tt.reshape %326 {allow_reorder = false} : tensor<16x2x8xi8, #blocked4> -> tensor<1x256xi8, #blocked8> - %328 = arith.muli %322, %29 : tensor<16x2x8xi8, #blocked4> - %329 = "tt.reduce"(%328) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<16x2x8xi8, #blocked4>) -> tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %330 = tt.expand_dims %329 {axis = 1 : i32} : tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi8, #blocked4> - %331 = tt.broadcast %330 : tensor<16x1x8xi8, #blocked4> -> tensor<16x2x8xi8, #blocked4> - %332 = tt.reshape %331 {allow_reorder = false} : tensor<16x2x8xi8, #blocked4> -> tensor<1x256xi8, #blocked8> - %333 = arith.cmpi slt, %327, %332 : tensor<1x256xi8, #blocked8> - %334 = arith.cmpi eq, %327, %332 : tensor<1x256xi8, #blocked8> - %335 = arith.andi %334, %319 : tensor<1x256xi1, #blocked8> - %336 = arith.ori %333, %335 : tensor<1x256xi1, #blocked8> - %337 = arith.extui %336 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %338 = arith.xori %337, %34 : tensor<1x256xi32, #blocked8> - %339 = arith.cmpi ne, %338, %cst_5 : tensor<1x256xi32, #blocked8> - %340 = arith.select %339, %320, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %341 = arith.xori %307, %340 : tensor<1x256xi32, #blocked8> - %342 = tt.reshape %341 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<32x2x4xi32, #blocked3> - %343 = arith.muli %342, %111 : tensor<32x2x4xi32, #blocked3> - %344 = "tt.reduce"(%343) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<32x2x4xi32, #blocked3>) -> tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %345 = tt.expand_dims %344 {axis = 1 : i32} : tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi32, #blocked3> - %346 = tt.broadcast %345 : tensor<32x1x4xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %347 = tt.reshape %346 {allow_reorder = false} : tensor<32x2x4xi32, #blocked3> -> tensor<1x256xi32, #blocked8> - %348 = arith.muli %342, %19 : tensor<32x2x4xi32, #blocked3> - %349 = "tt.reduce"(%348) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<32x2x4xi32, #blocked3>) -> tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %350 = tt.expand_dims %349 {axis = 1 : i32} : tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi32, #blocked3> - %351 = tt.broadcast %350 : tensor<32x1x4xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %352 = tt.reshape %351 {allow_reorder = false} : tensor<32x2x4xi32, #blocked3> -> tensor<1x256xi32, #blocked8> - %353 = arith.cmpi sgt, %347, %352 : tensor<1x256xi32, #blocked8> - %354 = arith.xori %347, %352 : tensor<1x256xi32, #blocked8> - %355 = triton_gpu.convert_layout %arg7 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %356 = tt.reshape %355 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<32x2x4xi8, #blocked3> - %357 = arith.muli %356, %109 : tensor<32x2x4xi8, #blocked3> - %358 = "tt.reduce"(%357) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<32x2x4xi8, #blocked3>) -> tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %359 = tt.expand_dims %358 {axis = 1 : i32} : tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi8, #blocked3> - %360 = tt.broadcast %359 : tensor<32x1x4xi8, #blocked3> -> tensor<32x2x4xi8, #blocked3> - %361 = tt.reshape %360 {allow_reorder = false} : tensor<32x2x4xi8, #blocked3> -> tensor<1x256xi8, #blocked8> - %362 = arith.muli %356, %20 : tensor<32x2x4xi8, #blocked3> - %363 = "tt.reduce"(%362) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<32x2x4xi8, #blocked3>) -> tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %364 = tt.expand_dims %363 {axis = 1 : i32} : tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi8, #blocked3> - %365 = tt.broadcast %364 : tensor<32x1x4xi8, #blocked3> -> tensor<32x2x4xi8, #blocked3> - %366 = tt.reshape %365 {allow_reorder = false} : tensor<32x2x4xi8, #blocked3> -> tensor<1x256xi8, #blocked8> - %367 = arith.cmpi slt, %361, %366 : tensor<1x256xi8, #blocked8> - %368 = arith.cmpi eq, %361, %366 : tensor<1x256xi8, #blocked8> - %369 = arith.andi %368, %353 : tensor<1x256xi1, #blocked8> - %370 = arith.ori %367, %369 : tensor<1x256xi1, #blocked8> - %371 = arith.extui %370 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %372 = arith.xori %371, %34 : tensor<1x256xi32, #blocked8> - %373 = arith.cmpi ne, %372, %cst_5 : tensor<1x256xi32, #blocked8> - %374 = arith.select %373, %354, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %375 = arith.xori %341, %374 : tensor<1x256xi32, #blocked8> - %376 = tt.reshape %375 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<64x2x2xi32, #blocked2> - %377 = arith.muli %376, %116 : tensor<64x2x2xi32, #blocked2> - %378 = "tt.reduce"(%377) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<64x2x2xi32, #blocked2>) -> tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %379 = tt.expand_dims %378 {axis = 1 : i32} : tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi32, #blocked2> - %380 = tt.broadcast %379 : tensor<64x1x2xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %381 = tt.reshape %380 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %382 = arith.muli %376, %10 : tensor<64x2x2xi32, #blocked2> - %383 = "tt.reduce"(%382) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<64x2x2xi32, #blocked2>) -> tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %384 = tt.expand_dims %383 {axis = 1 : i32} : tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi32, #blocked2> - %385 = tt.broadcast %384 : tensor<64x1x2xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %386 = tt.reshape %385 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %387 = arith.cmpi sgt, %381, %386 : tensor<1x256xi32, #blocked8> - %388 = arith.xori %381, %386 : tensor<1x256xi32, #blocked8> - %389 = triton_gpu.convert_layout %arg8 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %390 = tt.reshape %389 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<64x2x2xi8, #blocked2> - %391 = arith.muli %390, %114 : tensor<64x2x2xi8, #blocked2> - %392 = "tt.reduce"(%391) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<64x2x2xi8, #blocked2>) -> tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %393 = tt.expand_dims %392 {axis = 1 : i32} : tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi8, #blocked2> - %394 = tt.broadcast %393 : tensor<64x1x2xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %395 = tt.reshape %394 {allow_reorder = false} : tensor<64x2x2xi8, #blocked2> -> tensor<1x256xi8, #blocked8> - %396 = arith.muli %390, %11 : tensor<64x2x2xi8, #blocked2> - %397 = "tt.reduce"(%396) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<64x2x2xi8, #blocked2>) -> tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %398 = tt.expand_dims %397 {axis = 1 : i32} : tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi8, #blocked2> - %399 = tt.broadcast %398 : tensor<64x1x2xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %400 = tt.reshape %399 {allow_reorder = false} : tensor<64x2x2xi8, #blocked2> -> tensor<1x256xi8, #blocked8> - %401 = arith.cmpi slt, %395, %400 : tensor<1x256xi8, #blocked8> - %402 = arith.cmpi eq, %395, %400 : tensor<1x256xi8, #blocked8> - %403 = arith.andi %402, %387 : tensor<1x256xi1, #blocked8> - %404 = arith.ori %401, %403 : tensor<1x256xi1, #blocked8> - %405 = arith.extui %404 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %406 = arith.xori %405, %34 : tensor<1x256xi32, #blocked8> - %407 = arith.cmpi ne, %406, %cst_5 : tensor<1x256xi32, #blocked8> - %408 = arith.select %407, %388, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %409 = arith.xori %375, %408 : tensor<1x256xi32, #blocked8> - %410 = tt.reshape %409 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<128x2x1xi32, #blocked9> - %411 = arith.muli %410, %77 : tensor<128x2x1xi32, #blocked9> - %412 = "tt.reduce"(%411) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %413 = tt.expand_dims %412 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %414 = tt.broadcast %413 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %415 = tt.reshape %414 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %416 = arith.muli %410, %66 : tensor<128x2x1xi32, #blocked9> - %417 = "tt.reduce"(%416) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %418 = tt.expand_dims %417 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %419 = tt.broadcast %418 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %420 = tt.reshape %419 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %421 = arith.cmpi sgt, %415, %420 : tensor<1x256xi32, #blocked8> - %422 = arith.xori %415, %420 : tensor<1x256xi32, #blocked8> - %423 = triton_gpu.convert_layout %arg9 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %424 = tt.reshape %423 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<128x2x1xi8, #blocked9> - %425 = arith.muli %424, %75 : tensor<128x2x1xi8, #blocked9> - %426 = "tt.reduce"(%425) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %427 = tt.expand_dims %426 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %428 = tt.broadcast %427 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %429 = tt.reshape %428 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %430 = arith.muli %424, %72 : tensor<128x2x1xi8, #blocked9> - %431 = "tt.reduce"(%430) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %432 = tt.expand_dims %431 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %433 = tt.broadcast %432 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %434 = tt.reshape %433 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %435 = arith.cmpi slt, %429, %434 : tensor<1x256xi8, #blocked8> - %436 = arith.cmpi eq, %429, %434 : tensor<1x256xi8, #blocked8> - %437 = arith.andi %436, %421 : tensor<1x256xi1, #blocked8> - %438 = arith.ori %435, %437 : tensor<1x256xi1, #blocked8> - %439 = arith.extui %438 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %440 = arith.xori %439, %34 : tensor<1x256xi32, #blocked8> - %441 = arith.cmpi ne, %440, %cst_5 : tensor<1x256xi32, #blocked8> - %442 = arith.select %441, %422, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %443 = arith.xori %409, %442 : tensor<1x256xi32, #blocked8> - %444 = tt.reshape %443 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<8x2x16xi32, #blocked5> - %445 = arith.muli %444, %101 : tensor<8x2x16xi32, #blocked5> - %446 = "tt.reduce"(%445) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<8x2x16xi32, #blocked5>) -> tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %447 = tt.expand_dims %446 {axis = 1 : i32} : tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi32, #blocked5> - %448 = tt.broadcast %447 : tensor<8x1x16xi32, #blocked5> -> tensor<8x2x16xi32, #blocked5> - %449 = tt.reshape %448 {allow_reorder = false} : tensor<8x2x16xi32, #blocked5> -> tensor<1x256xi32, #blocked8> - %450 = arith.muli %444, %37 : tensor<8x2x16xi32, #blocked5> - %451 = "tt.reduce"(%450) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<8x2x16xi32, #blocked5>) -> tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %452 = tt.expand_dims %451 {axis = 1 : i32} : tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi32, #blocked5> - %453 = tt.broadcast %452 : tensor<8x1x16xi32, #blocked5> -> tensor<8x2x16xi32, #blocked5> - %454 = tt.reshape %453 {allow_reorder = false} : tensor<8x2x16xi32, #blocked5> -> tensor<1x256xi32, #blocked8> - %455 = arith.cmpi sgt, %449, %454 : tensor<1x256xi32, #blocked8> - %456 = arith.xori %449, %454 : tensor<1x256xi32, #blocked8> - %457 = triton_gpu.convert_layout %arg10 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %458 = tt.reshape %457 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<8x2x16xi8, #blocked5> - %459 = arith.muli %458, %99 : tensor<8x2x16xi8, #blocked5> - %460 = "tt.reduce"(%459) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<8x2x16xi8, #blocked5>) -> tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %461 = tt.expand_dims %460 {axis = 1 : i32} : tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi8, #blocked5> - %462 = tt.broadcast %461 : tensor<8x1x16xi8, #blocked5> -> tensor<8x2x16xi8, #blocked5> - %463 = tt.reshape %462 {allow_reorder = false} : tensor<8x2x16xi8, #blocked5> -> tensor<1x256xi8, #blocked8> - %464 = arith.muli %458, %38 : tensor<8x2x16xi8, #blocked5> - %465 = "tt.reduce"(%464) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<8x2x16xi8, #blocked5>) -> tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %466 = tt.expand_dims %465 {axis = 1 : i32} : tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi8, #blocked5> - %467 = tt.broadcast %466 : tensor<8x1x16xi8, #blocked5> -> tensor<8x2x16xi8, #blocked5> - %468 = tt.reshape %467 {allow_reorder = false} : tensor<8x2x16xi8, #blocked5> -> tensor<1x256xi8, #blocked8> - %469 = arith.cmpi slt, %463, %468 : tensor<1x256xi8, #blocked8> - %470 = arith.cmpi eq, %463, %468 : tensor<1x256xi8, #blocked8> - %471 = arith.andi %470, %455 : tensor<1x256xi1, #blocked8> - %472 = arith.ori %469, %471 : tensor<1x256xi1, #blocked8> - %473 = arith.extui %472 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %474 = arith.xori %473, %43 : tensor<1x256xi32, #blocked8> - %475 = arith.cmpi ne, %474, %cst_5 : tensor<1x256xi32, #blocked8> - %476 = arith.select %475, %456, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %477 = arith.xori %443, %476 : tensor<1x256xi32, #blocked8> - %478 = tt.reshape %477 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<16x2x8xi32, #blocked4> - %479 = arith.muli %478, %106 : tensor<16x2x8xi32, #blocked4> - %480 = "tt.reduce"(%479) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<16x2x8xi32, #blocked4>) -> tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %481 = tt.expand_dims %480 {axis = 1 : i32} : tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi32, #blocked4> - %482 = tt.broadcast %481 : tensor<16x1x8xi32, #blocked4> -> tensor<16x2x8xi32, #blocked4> - %483 = tt.reshape %482 {allow_reorder = false} : tensor<16x2x8xi32, #blocked4> -> tensor<1x256xi32, #blocked8> - %484 = arith.muli %478, %28 : tensor<16x2x8xi32, #blocked4> - %485 = "tt.reduce"(%484) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<16x2x8xi32, #blocked4>) -> tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %486 = tt.expand_dims %485 {axis = 1 : i32} : tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi32, #blocked4> - %487 = tt.broadcast %486 : tensor<16x1x8xi32, #blocked4> -> tensor<16x2x8xi32, #blocked4> - %488 = tt.reshape %487 {allow_reorder = false} : tensor<16x2x8xi32, #blocked4> -> tensor<1x256xi32, #blocked8> - %489 = arith.cmpi sgt, %483, %488 : tensor<1x256xi32, #blocked8> - %490 = arith.xori %483, %488 : tensor<1x256xi32, #blocked8> - %491 = triton_gpu.convert_layout %arg11 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %492 = tt.reshape %491 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<16x2x8xi8, #blocked4> - %493 = arith.muli %492, %104 : tensor<16x2x8xi8, #blocked4> - %494 = "tt.reduce"(%493) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<16x2x8xi8, #blocked4>) -> tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %495 = tt.expand_dims %494 {axis = 1 : i32} : tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi8, #blocked4> - %496 = tt.broadcast %495 : tensor<16x1x8xi8, #blocked4> -> tensor<16x2x8xi8, #blocked4> - %497 = tt.reshape %496 {allow_reorder = false} : tensor<16x2x8xi8, #blocked4> -> tensor<1x256xi8, #blocked8> - %498 = arith.muli %492, %29 : tensor<16x2x8xi8, #blocked4> - %499 = "tt.reduce"(%498) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<16x2x8xi8, #blocked4>) -> tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %500 = tt.expand_dims %499 {axis = 1 : i32} : tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi8, #blocked4> - %501 = tt.broadcast %500 : tensor<16x1x8xi8, #blocked4> -> tensor<16x2x8xi8, #blocked4> - %502 = tt.reshape %501 {allow_reorder = false} : tensor<16x2x8xi8, #blocked4> -> tensor<1x256xi8, #blocked8> - %503 = arith.cmpi slt, %497, %502 : tensor<1x256xi8, #blocked8> - %504 = arith.cmpi eq, %497, %502 : tensor<1x256xi8, #blocked8> - %505 = arith.andi %504, %489 : tensor<1x256xi1, #blocked8> - %506 = arith.ori %503, %505 : tensor<1x256xi1, #blocked8> - %507 = arith.extui %506 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %508 = arith.xori %507, %43 : tensor<1x256xi32, #blocked8> - %509 = arith.cmpi ne, %508, %cst_5 : tensor<1x256xi32, #blocked8> - %510 = arith.select %509, %490, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %511 = arith.xori %477, %510 : tensor<1x256xi32, #blocked8> - %512 = tt.reshape %511 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<32x2x4xi32, #blocked3> - %513 = arith.muli %512, %111 : tensor<32x2x4xi32, #blocked3> - %514 = "tt.reduce"(%513) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<32x2x4xi32, #blocked3>) -> tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %515 = tt.expand_dims %514 {axis = 1 : i32} : tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi32, #blocked3> - %516 = tt.broadcast %515 : tensor<32x1x4xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %517 = tt.reshape %516 {allow_reorder = false} : tensor<32x2x4xi32, #blocked3> -> tensor<1x256xi32, #blocked8> - %518 = arith.muli %512, %19 : tensor<32x2x4xi32, #blocked3> - %519 = "tt.reduce"(%518) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<32x2x4xi32, #blocked3>) -> tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %520 = tt.expand_dims %519 {axis = 1 : i32} : tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi32, #blocked3> - %521 = tt.broadcast %520 : tensor<32x1x4xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %522 = tt.reshape %521 {allow_reorder = false} : tensor<32x2x4xi32, #blocked3> -> tensor<1x256xi32, #blocked8> - %523 = arith.cmpi sgt, %517, %522 : tensor<1x256xi32, #blocked8> - %524 = arith.xori %517, %522 : tensor<1x256xi32, #blocked8> - %525 = triton_gpu.convert_layout %arg12 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %526 = tt.reshape %525 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<32x2x4xi8, #blocked3> - %527 = arith.muli %526, %109 : tensor<32x2x4xi8, #blocked3> - %528 = "tt.reduce"(%527) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<32x2x4xi8, #blocked3>) -> tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %529 = tt.expand_dims %528 {axis = 1 : i32} : tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi8, #blocked3> - %530 = tt.broadcast %529 : tensor<32x1x4xi8, #blocked3> -> tensor<32x2x4xi8, #blocked3> - %531 = tt.reshape %530 {allow_reorder = false} : tensor<32x2x4xi8, #blocked3> -> tensor<1x256xi8, #blocked8> - %532 = arith.muli %526, %20 : tensor<32x2x4xi8, #blocked3> - %533 = "tt.reduce"(%532) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<32x2x4xi8, #blocked3>) -> tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %534 = tt.expand_dims %533 {axis = 1 : i32} : tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi8, #blocked3> - %535 = tt.broadcast %534 : tensor<32x1x4xi8, #blocked3> -> tensor<32x2x4xi8, #blocked3> - %536 = tt.reshape %535 {allow_reorder = false} : tensor<32x2x4xi8, #blocked3> -> tensor<1x256xi8, #blocked8> - %537 = arith.cmpi slt, %531, %536 : tensor<1x256xi8, #blocked8> - %538 = arith.cmpi eq, %531, %536 : tensor<1x256xi8, #blocked8> - %539 = arith.andi %538, %523 : tensor<1x256xi1, #blocked8> - %540 = arith.ori %537, %539 : tensor<1x256xi1, #blocked8> - %541 = arith.extui %540 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %542 = arith.xori %541, %43 : tensor<1x256xi32, #blocked8> - %543 = arith.cmpi ne, %542, %cst_5 : tensor<1x256xi32, #blocked8> - %544 = arith.select %543, %524, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %545 = arith.xori %511, %544 : tensor<1x256xi32, #blocked8> - %546 = tt.reshape %545 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<64x2x2xi32, #blocked2> - %547 = arith.muli %546, %116 : tensor<64x2x2xi32, #blocked2> - %548 = "tt.reduce"(%547) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<64x2x2xi32, #blocked2>) -> tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %549 = tt.expand_dims %548 {axis = 1 : i32} : tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi32, #blocked2> - %550 = tt.broadcast %549 : tensor<64x1x2xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %551 = tt.reshape %550 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %552 = arith.muli %546, %10 : tensor<64x2x2xi32, #blocked2> - %553 = "tt.reduce"(%552) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<64x2x2xi32, #blocked2>) -> tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %554 = tt.expand_dims %553 {axis = 1 : i32} : tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi32, #blocked2> - %555 = tt.broadcast %554 : tensor<64x1x2xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %556 = tt.reshape %555 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %557 = arith.cmpi sgt, %551, %556 : tensor<1x256xi32, #blocked8> - %558 = arith.xori %551, %556 : tensor<1x256xi32, #blocked8> - %559 = triton_gpu.convert_layout %arg13 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %560 = tt.reshape %559 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<64x2x2xi8, #blocked2> - %561 = arith.muli %560, %114 : tensor<64x2x2xi8, #blocked2> - %562 = "tt.reduce"(%561) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<64x2x2xi8, #blocked2>) -> tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %563 = tt.expand_dims %562 {axis = 1 : i32} : tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi8, #blocked2> - %564 = tt.broadcast %563 : tensor<64x1x2xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %565 = tt.reshape %564 {allow_reorder = false} : tensor<64x2x2xi8, #blocked2> -> tensor<1x256xi8, #blocked8> - %566 = arith.muli %560, %11 : tensor<64x2x2xi8, #blocked2> - %567 = "tt.reduce"(%566) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<64x2x2xi8, #blocked2>) -> tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %568 = tt.expand_dims %567 {axis = 1 : i32} : tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi8, #blocked2> - %569 = tt.broadcast %568 : tensor<64x1x2xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %570 = tt.reshape %569 {allow_reorder = false} : tensor<64x2x2xi8, #blocked2> -> tensor<1x256xi8, #blocked8> - %571 = arith.cmpi slt, %565, %570 : tensor<1x256xi8, #blocked8> - %572 = arith.cmpi eq, %565, %570 : tensor<1x256xi8, #blocked8> - %573 = arith.andi %572, %557 : tensor<1x256xi1, #blocked8> - %574 = arith.ori %571, %573 : tensor<1x256xi1, #blocked8> - %575 = arith.extui %574 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %576 = arith.xori %575, %43 : tensor<1x256xi32, #blocked8> - %577 = arith.cmpi ne, %576, %cst_5 : tensor<1x256xi32, #blocked8> - %578 = arith.select %577, %558, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %579 = arith.xori %545, %578 : tensor<1x256xi32, #blocked8> - %580 = tt.reshape %579 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<128x2x1xi32, #blocked9> - %581 = arith.muli %580, %77 : tensor<128x2x1xi32, #blocked9> - %582 = "tt.reduce"(%581) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %583 = tt.expand_dims %582 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %584 = tt.broadcast %583 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %585 = tt.reshape %584 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %586 = arith.muli %580, %66 : tensor<128x2x1xi32, #blocked9> - %587 = "tt.reduce"(%586) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %588 = tt.expand_dims %587 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %589 = tt.broadcast %588 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %590 = tt.reshape %589 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %591 = arith.cmpi sgt, %585, %590 : tensor<1x256xi32, #blocked8> - %592 = arith.xori %585, %590 : tensor<1x256xi32, #blocked8> - %593 = triton_gpu.convert_layout %arg14 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %594 = tt.reshape %593 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<128x2x1xi8, #blocked9> - %595 = arith.muli %594, %75 : tensor<128x2x1xi8, #blocked9> - %596 = "tt.reduce"(%595) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %597 = tt.expand_dims %596 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %598 = tt.broadcast %597 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %599 = tt.reshape %598 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %600 = arith.muli %594, %72 : tensor<128x2x1xi8, #blocked9> - %601 = "tt.reduce"(%600) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %602 = tt.expand_dims %601 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %603 = tt.broadcast %602 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %604 = tt.reshape %603 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %605 = arith.cmpi slt, %599, %604 : tensor<1x256xi8, #blocked8> - %606 = arith.cmpi eq, %599, %604 : tensor<1x256xi8, #blocked8> - %607 = arith.andi %606, %591 : tensor<1x256xi1, #blocked8> - %608 = arith.ori %605, %607 : tensor<1x256xi1, #blocked8> - %609 = arith.extui %608 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %610 = arith.xori %609, %43 : tensor<1x256xi32, #blocked8> - %611 = arith.cmpi ne, %610, %cst_5 : tensor<1x256xi32, #blocked8> - %612 = arith.select %611, %592, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %613 = arith.xori %579, %612 : tensor<1x256xi32, #blocked8> - %614 = tt.reshape %613 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<4x2x32xi32, #blocked6> - %615 = arith.muli %614, %96 : tensor<4x2x32xi32, #blocked6> - %616 = "tt.reduce"(%615) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<4x2x32xi32, #blocked6>) -> tensor<4x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> - %617 = tt.expand_dims %616 {axis = 1 : i32} : tensor<4x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> -> tensor<4x1x32xi32, #blocked6> - %618 = tt.broadcast %617 : tensor<4x1x32xi32, #blocked6> -> tensor<4x2x32xi32, #blocked6> - %619 = tt.reshape %618 {allow_reorder = false} : tensor<4x2x32xi32, #blocked6> -> tensor<1x256xi32, #blocked8> - %620 = arith.muli %614, %46 : tensor<4x2x32xi32, #blocked6> - %621 = "tt.reduce"(%620) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<4x2x32xi32, #blocked6>) -> tensor<4x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> - %622 = tt.expand_dims %621 {axis = 1 : i32} : tensor<4x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> -> tensor<4x1x32xi32, #blocked6> - %623 = tt.broadcast %622 : tensor<4x1x32xi32, #blocked6> -> tensor<4x2x32xi32, #blocked6> - %624 = tt.reshape %623 {allow_reorder = false} : tensor<4x2x32xi32, #blocked6> -> tensor<1x256xi32, #blocked8> - %625 = arith.cmpi sgt, %619, %624 : tensor<1x256xi32, #blocked8> - %626 = arith.xori %619, %624 : tensor<1x256xi32, #blocked8> - %627 = triton_gpu.convert_layout %arg15 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %628 = tt.reshape %627 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<4x2x32xi8, #blocked6> - %629 = arith.muli %628, %94 : tensor<4x2x32xi8, #blocked6> - %630 = "tt.reduce"(%629) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<4x2x32xi8, #blocked6>) -> tensor<4x32xi8, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> - %631 = tt.expand_dims %630 {axis = 1 : i32} : tensor<4x32xi8, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> -> tensor<4x1x32xi8, #blocked6> - %632 = tt.broadcast %631 : tensor<4x1x32xi8, #blocked6> -> tensor<4x2x32xi8, #blocked6> - %633 = tt.reshape %632 {allow_reorder = false} : tensor<4x2x32xi8, #blocked6> -> tensor<1x256xi8, #blocked8> - %634 = arith.muli %628, %47 : tensor<4x2x32xi8, #blocked6> - %635 = "tt.reduce"(%634) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<4x2x32xi8, #blocked6>) -> tensor<4x32xi8, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> - %636 = tt.expand_dims %635 {axis = 1 : i32} : tensor<4x32xi8, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> -> tensor<4x1x32xi8, #blocked6> - %637 = tt.broadcast %636 : tensor<4x1x32xi8, #blocked6> -> tensor<4x2x32xi8, #blocked6> - %638 = tt.reshape %637 {allow_reorder = false} : tensor<4x2x32xi8, #blocked6> -> tensor<1x256xi8, #blocked8> - %639 = arith.cmpi slt, %633, %638 : tensor<1x256xi8, #blocked8> - %640 = arith.cmpi eq, %633, %638 : tensor<1x256xi8, #blocked8> - %641 = arith.andi %640, %625 : tensor<1x256xi1, #blocked8> - %642 = arith.ori %639, %641 : tensor<1x256xi1, #blocked8> - %643 = arith.extui %642 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %644 = arith.xori %643, %52 : tensor<1x256xi32, #blocked8> - %645 = arith.cmpi ne, %644, %cst_5 : tensor<1x256xi32, #blocked8> - %646 = arith.select %645, %626, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %647 = arith.xori %613, %646 : tensor<1x256xi32, #blocked8> - %648 = tt.reshape %647 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<8x2x16xi32, #blocked5> - %649 = arith.muli %648, %101 : tensor<8x2x16xi32, #blocked5> - %650 = "tt.reduce"(%649) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<8x2x16xi32, #blocked5>) -> tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %651 = tt.expand_dims %650 {axis = 1 : i32} : tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi32, #blocked5> - %652 = tt.broadcast %651 : tensor<8x1x16xi32, #blocked5> -> tensor<8x2x16xi32, #blocked5> - %653 = tt.reshape %652 {allow_reorder = false} : tensor<8x2x16xi32, #blocked5> -> tensor<1x256xi32, #blocked8> - %654 = arith.muli %648, %37 : tensor<8x2x16xi32, #blocked5> - %655 = "tt.reduce"(%654) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<8x2x16xi32, #blocked5>) -> tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %656 = tt.expand_dims %655 {axis = 1 : i32} : tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi32, #blocked5> - %657 = tt.broadcast %656 : tensor<8x1x16xi32, #blocked5> -> tensor<8x2x16xi32, #blocked5> - %658 = tt.reshape %657 {allow_reorder = false} : tensor<8x2x16xi32, #blocked5> -> tensor<1x256xi32, #blocked8> - %659 = arith.cmpi sgt, %653, %658 : tensor<1x256xi32, #blocked8> - %660 = arith.xori %653, %658 : tensor<1x256xi32, #blocked8> - %661 = triton_gpu.convert_layout %arg16 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %662 = tt.reshape %661 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<8x2x16xi8, #blocked5> - %663 = arith.muli %662, %99 : tensor<8x2x16xi8, #blocked5> - %664 = "tt.reduce"(%663) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<8x2x16xi8, #blocked5>) -> tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %665 = tt.expand_dims %664 {axis = 1 : i32} : tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi8, #blocked5> - %666 = tt.broadcast %665 : tensor<8x1x16xi8, #blocked5> -> tensor<8x2x16xi8, #blocked5> - %667 = tt.reshape %666 {allow_reorder = false} : tensor<8x2x16xi8, #blocked5> -> tensor<1x256xi8, #blocked8> - %668 = arith.muli %662, %38 : tensor<8x2x16xi8, #blocked5> - %669 = "tt.reduce"(%668) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<8x2x16xi8, #blocked5>) -> tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %670 = tt.expand_dims %669 {axis = 1 : i32} : tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi8, #blocked5> - %671 = tt.broadcast %670 : tensor<8x1x16xi8, #blocked5> -> tensor<8x2x16xi8, #blocked5> - %672 = tt.reshape %671 {allow_reorder = false} : tensor<8x2x16xi8, #blocked5> -> tensor<1x256xi8, #blocked8> - %673 = arith.cmpi slt, %667, %672 : tensor<1x256xi8, #blocked8> - %674 = arith.cmpi eq, %667, %672 : tensor<1x256xi8, #blocked8> - %675 = arith.andi %674, %659 : tensor<1x256xi1, #blocked8> - %676 = arith.ori %673, %675 : tensor<1x256xi1, #blocked8> - %677 = arith.extui %676 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %678 = arith.xori %677, %52 : tensor<1x256xi32, #blocked8> - %679 = arith.cmpi ne, %678, %cst_5 : tensor<1x256xi32, #blocked8> - %680 = arith.select %679, %660, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %681 = arith.xori %647, %680 : tensor<1x256xi32, #blocked8> - %682 = tt.reshape %681 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<16x2x8xi32, #blocked4> - %683 = arith.muli %682, %106 : tensor<16x2x8xi32, #blocked4> - %684 = "tt.reduce"(%683) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<16x2x8xi32, #blocked4>) -> tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %685 = tt.expand_dims %684 {axis = 1 : i32} : tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi32, #blocked4> - %686 = tt.broadcast %685 : tensor<16x1x8xi32, #blocked4> -> tensor<16x2x8xi32, #blocked4> - %687 = tt.reshape %686 {allow_reorder = false} : tensor<16x2x8xi32, #blocked4> -> tensor<1x256xi32, #blocked8> - %688 = arith.muli %682, %28 : tensor<16x2x8xi32, #blocked4> - %689 = "tt.reduce"(%688) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<16x2x8xi32, #blocked4>) -> tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %690 = tt.expand_dims %689 {axis = 1 : i32} : tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi32, #blocked4> - %691 = tt.broadcast %690 : tensor<16x1x8xi32, #blocked4> -> tensor<16x2x8xi32, #blocked4> - %692 = tt.reshape %691 {allow_reorder = false} : tensor<16x2x8xi32, #blocked4> -> tensor<1x256xi32, #blocked8> - %693 = arith.cmpi sgt, %687, %692 : tensor<1x256xi32, #blocked8> - %694 = arith.xori %687, %692 : tensor<1x256xi32, #blocked8> - %695 = triton_gpu.convert_layout %arg17 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %696 = tt.reshape %695 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<16x2x8xi8, #blocked4> - %697 = arith.muli %696, %104 : tensor<16x2x8xi8, #blocked4> - %698 = "tt.reduce"(%697) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<16x2x8xi8, #blocked4>) -> tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %699 = tt.expand_dims %698 {axis = 1 : i32} : tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi8, #blocked4> - %700 = tt.broadcast %699 : tensor<16x1x8xi8, #blocked4> -> tensor<16x2x8xi8, #blocked4> - %701 = tt.reshape %700 {allow_reorder = false} : tensor<16x2x8xi8, #blocked4> -> tensor<1x256xi8, #blocked8> - %702 = arith.muli %696, %29 : tensor<16x2x8xi8, #blocked4> - %703 = "tt.reduce"(%702) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<16x2x8xi8, #blocked4>) -> tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %704 = tt.expand_dims %703 {axis = 1 : i32} : tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi8, #blocked4> - %705 = tt.broadcast %704 : tensor<16x1x8xi8, #blocked4> -> tensor<16x2x8xi8, #blocked4> - %706 = tt.reshape %705 {allow_reorder = false} : tensor<16x2x8xi8, #blocked4> -> tensor<1x256xi8, #blocked8> - %707 = arith.cmpi slt, %701, %706 : tensor<1x256xi8, #blocked8> - %708 = arith.cmpi eq, %701, %706 : tensor<1x256xi8, #blocked8> - %709 = arith.andi %708, %693 : tensor<1x256xi1, #blocked8> - %710 = arith.ori %707, %709 : tensor<1x256xi1, #blocked8> - %711 = arith.extui %710 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %712 = arith.xori %711, %52 : tensor<1x256xi32, #blocked8> - %713 = arith.cmpi ne, %712, %cst_5 : tensor<1x256xi32, #blocked8> - %714 = arith.select %713, %694, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %715 = arith.xori %681, %714 : tensor<1x256xi32, #blocked8> - %716 = tt.reshape %715 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<32x2x4xi32, #blocked3> - %717 = arith.muli %716, %111 : tensor<32x2x4xi32, #blocked3> - %718 = "tt.reduce"(%717) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<32x2x4xi32, #blocked3>) -> tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %719 = tt.expand_dims %718 {axis = 1 : i32} : tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi32, #blocked3> - %720 = tt.broadcast %719 : tensor<32x1x4xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %721 = tt.reshape %720 {allow_reorder = false} : tensor<32x2x4xi32, #blocked3> -> tensor<1x256xi32, #blocked8> - %722 = arith.muli %716, %19 : tensor<32x2x4xi32, #blocked3> - %723 = "tt.reduce"(%722) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<32x2x4xi32, #blocked3>) -> tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %724 = tt.expand_dims %723 {axis = 1 : i32} : tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi32, #blocked3> - %725 = tt.broadcast %724 : tensor<32x1x4xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %726 = tt.reshape %725 {allow_reorder = false} : tensor<32x2x4xi32, #blocked3> -> tensor<1x256xi32, #blocked8> - %727 = arith.cmpi sgt, %721, %726 : tensor<1x256xi32, #blocked8> - %728 = arith.xori %721, %726 : tensor<1x256xi32, #blocked8> - %729 = triton_gpu.convert_layout %arg18 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %730 = tt.reshape %729 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<32x2x4xi8, #blocked3> - %731 = arith.muli %730, %109 : tensor<32x2x4xi8, #blocked3> - %732 = "tt.reduce"(%731) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<32x2x4xi8, #blocked3>) -> tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %733 = tt.expand_dims %732 {axis = 1 : i32} : tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi8, #blocked3> - %734 = tt.broadcast %733 : tensor<32x1x4xi8, #blocked3> -> tensor<32x2x4xi8, #blocked3> - %735 = tt.reshape %734 {allow_reorder = false} : tensor<32x2x4xi8, #blocked3> -> tensor<1x256xi8, #blocked8> - %736 = arith.muli %730, %20 : tensor<32x2x4xi8, #blocked3> - %737 = "tt.reduce"(%736) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<32x2x4xi8, #blocked3>) -> tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %738 = tt.expand_dims %737 {axis = 1 : i32} : tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi8, #blocked3> - %739 = tt.broadcast %738 : tensor<32x1x4xi8, #blocked3> -> tensor<32x2x4xi8, #blocked3> - %740 = tt.reshape %739 {allow_reorder = false} : tensor<32x2x4xi8, #blocked3> -> tensor<1x256xi8, #blocked8> - %741 = arith.cmpi slt, %735, %740 : tensor<1x256xi8, #blocked8> - %742 = arith.cmpi eq, %735, %740 : tensor<1x256xi8, #blocked8> - %743 = arith.andi %742, %727 : tensor<1x256xi1, #blocked8> - %744 = arith.ori %741, %743 : tensor<1x256xi1, #blocked8> - %745 = arith.extui %744 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %746 = arith.xori %745, %52 : tensor<1x256xi32, #blocked8> - %747 = arith.cmpi ne, %746, %cst_5 : tensor<1x256xi32, #blocked8> - %748 = arith.select %747, %728, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %749 = arith.xori %715, %748 : tensor<1x256xi32, #blocked8> - %750 = tt.reshape %749 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<64x2x2xi32, #blocked2> - %751 = arith.muli %750, %116 : tensor<64x2x2xi32, #blocked2> - %752 = "tt.reduce"(%751) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<64x2x2xi32, #blocked2>) -> tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %753 = tt.expand_dims %752 {axis = 1 : i32} : tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi32, #blocked2> - %754 = tt.broadcast %753 : tensor<64x1x2xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %755 = tt.reshape %754 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %756 = arith.muli %750, %10 : tensor<64x2x2xi32, #blocked2> - %757 = "tt.reduce"(%756) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<64x2x2xi32, #blocked2>) -> tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %758 = tt.expand_dims %757 {axis = 1 : i32} : tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi32, #blocked2> - %759 = tt.broadcast %758 : tensor<64x1x2xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %760 = tt.reshape %759 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %761 = arith.cmpi sgt, %755, %760 : tensor<1x256xi32, #blocked8> - %762 = arith.xori %755, %760 : tensor<1x256xi32, #blocked8> - %763 = triton_gpu.convert_layout %arg19 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %764 = tt.reshape %763 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<64x2x2xi8, #blocked2> - %765 = arith.muli %764, %114 : tensor<64x2x2xi8, #blocked2> - %766 = "tt.reduce"(%765) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<64x2x2xi8, #blocked2>) -> tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %767 = tt.expand_dims %766 {axis = 1 : i32} : tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi8, #blocked2> - %768 = tt.broadcast %767 : tensor<64x1x2xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %769 = tt.reshape %768 {allow_reorder = false} : tensor<64x2x2xi8, #blocked2> -> tensor<1x256xi8, #blocked8> - %770 = arith.muli %764, %11 : tensor<64x2x2xi8, #blocked2> - %771 = "tt.reduce"(%770) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<64x2x2xi8, #blocked2>) -> tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %772 = tt.expand_dims %771 {axis = 1 : i32} : tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi8, #blocked2> - %773 = tt.broadcast %772 : tensor<64x1x2xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %774 = tt.reshape %773 {allow_reorder = false} : tensor<64x2x2xi8, #blocked2> -> tensor<1x256xi8, #blocked8> - %775 = arith.cmpi slt, %769, %774 : tensor<1x256xi8, #blocked8> - %776 = arith.cmpi eq, %769, %774 : tensor<1x256xi8, #blocked8> - %777 = arith.andi %776, %761 : tensor<1x256xi1, #blocked8> - %778 = arith.ori %775, %777 : tensor<1x256xi1, #blocked8> - %779 = arith.extui %778 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %780 = arith.xori %779, %52 : tensor<1x256xi32, #blocked8> - %781 = arith.cmpi ne, %780, %cst_5 : tensor<1x256xi32, #blocked8> - %782 = arith.select %781, %762, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %783 = arith.xori %749, %782 : tensor<1x256xi32, #blocked8> - %784 = tt.reshape %783 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<128x2x1xi32, #blocked9> - %785 = arith.muli %784, %77 : tensor<128x2x1xi32, #blocked9> - %786 = "tt.reduce"(%785) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %787 = tt.expand_dims %786 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %788 = tt.broadcast %787 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %789 = tt.reshape %788 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %790 = arith.muli %784, %66 : tensor<128x2x1xi32, #blocked9> - %791 = "tt.reduce"(%790) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %792 = tt.expand_dims %791 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %793 = tt.broadcast %792 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %794 = tt.reshape %793 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %795 = arith.cmpi sgt, %789, %794 : tensor<1x256xi32, #blocked8> - %796 = arith.xori %789, %794 : tensor<1x256xi32, #blocked8> - %797 = triton_gpu.convert_layout %arg20 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %798 = tt.reshape %797 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<128x2x1xi8, #blocked9> - %799 = arith.muli %798, %75 : tensor<128x2x1xi8, #blocked9> - %800 = "tt.reduce"(%799) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %801 = tt.expand_dims %800 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %802 = tt.broadcast %801 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %803 = tt.reshape %802 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %804 = arith.muli %798, %72 : tensor<128x2x1xi8, #blocked9> - %805 = "tt.reduce"(%804) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %806 = tt.expand_dims %805 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %807 = tt.broadcast %806 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %808 = tt.reshape %807 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %809 = arith.cmpi slt, %803, %808 : tensor<1x256xi8, #blocked8> - %810 = arith.cmpi eq, %803, %808 : tensor<1x256xi8, #blocked8> - %811 = arith.andi %810, %795 : tensor<1x256xi1, #blocked8> - %812 = arith.ori %809, %811 : tensor<1x256xi1, #blocked8> - %813 = arith.extui %812 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %814 = arith.xori %813, %52 : tensor<1x256xi32, #blocked8> - %815 = arith.cmpi ne, %814, %cst_5 : tensor<1x256xi32, #blocked8> - %816 = arith.select %815, %796, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %817 = arith.xori %783, %816 : tensor<1x256xi32, #blocked8> - %818 = tt.reshape %817 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<2x2x64xi32, #blocked7> - %819 = arith.muli %818, %89 : tensor<2x2x64xi32, #blocked7> - %820 = "tt.reduce"(%819) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<2x2x64xi32, #blocked7>) -> tensor<2x64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> - %821 = tt.expand_dims %820 {axis = 1 : i32} : tensor<2x64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> -> tensor<2x1x64xi32, #blocked7> - %822 = tt.broadcast %821 : tensor<2x1x64xi32, #blocked7> -> tensor<2x2x64xi32, #blocked7> - %823 = tt.reshape %822 {allow_reorder = false} : tensor<2x2x64xi32, #blocked7> -> tensor<1x256xi32, #blocked8> - %824 = arith.muli %818, %57 : tensor<2x2x64xi32, #blocked7> - %825 = "tt.reduce"(%824) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<2x2x64xi32, #blocked7>) -> tensor<2x64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> - %826 = tt.expand_dims %825 {axis = 1 : i32} : tensor<2x64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> -> tensor<2x1x64xi32, #blocked7> - %827 = tt.broadcast %826 : tensor<2x1x64xi32, #blocked7> -> tensor<2x2x64xi32, #blocked7> - %828 = tt.reshape %827 {allow_reorder = false} : tensor<2x2x64xi32, #blocked7> -> tensor<1x256xi32, #blocked8> - %829 = arith.cmpi sgt, %823, %828 : tensor<1x256xi32, #blocked8> - %830 = arith.xori %823, %828 : tensor<1x256xi32, #blocked8> - %831 = triton_gpu.convert_layout %arg21 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %832 = tt.reshape %831 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<2x2x64xi8, #blocked7> - %833 = arith.muli %832, %87 : tensor<2x2x64xi8, #blocked7> - %834 = "tt.reduce"(%833) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<2x2x64xi8, #blocked7>) -> tensor<2x64xi8, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> - %835 = tt.expand_dims %834 {axis = 1 : i32} : tensor<2x64xi8, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> -> tensor<2x1x64xi8, #blocked7> - %836 = tt.broadcast %835 : tensor<2x1x64xi8, #blocked7> -> tensor<2x2x64xi8, #blocked7> - %837 = tt.reshape %836 {allow_reorder = false} : tensor<2x2x64xi8, #blocked7> -> tensor<1x256xi8, #blocked8> - %838 = arith.muli %832, %59 : tensor<2x2x64xi8, #blocked7> - %839 = "tt.reduce"(%838) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<2x2x64xi8, #blocked7>) -> tensor<2x64xi8, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> - %840 = tt.expand_dims %839 {axis = 1 : i32} : tensor<2x64xi8, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> -> tensor<2x1x64xi8, #blocked7> - %841 = tt.broadcast %840 : tensor<2x1x64xi8, #blocked7> -> tensor<2x2x64xi8, #blocked7> - %842 = tt.reshape %841 {allow_reorder = false} : tensor<2x2x64xi8, #blocked7> -> tensor<1x256xi8, #blocked8> - %843 = arith.cmpi slt, %837, %842 : tensor<1x256xi8, #blocked8> - %844 = arith.cmpi eq, %837, %842 : tensor<1x256xi8, #blocked8> - %845 = arith.andi %844, %829 : tensor<1x256xi1, #blocked8> - %846 = arith.ori %843, %845 : tensor<1x256xi1, #blocked8> - %847 = arith.extui %846 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %848 = arith.xori %847, %54 : tensor<1x256xi32, #blocked8> - %849 = arith.cmpi ne, %848, %cst_5 : tensor<1x256xi32, #blocked8> - %850 = arith.select %849, %830, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %851 = arith.xori %817, %850 : tensor<1x256xi32, #blocked8> - %852 = tt.reshape %851 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<4x2x32xi32, #blocked6> - %853 = arith.muli %852, %96 : tensor<4x2x32xi32, #blocked6> - %854 = "tt.reduce"(%853) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<4x2x32xi32, #blocked6>) -> tensor<4x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> - %855 = tt.expand_dims %854 {axis = 1 : i32} : tensor<4x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> -> tensor<4x1x32xi32, #blocked6> - %856 = tt.broadcast %855 : tensor<4x1x32xi32, #blocked6> -> tensor<4x2x32xi32, #blocked6> - %857 = tt.reshape %856 {allow_reorder = false} : tensor<4x2x32xi32, #blocked6> -> tensor<1x256xi32, #blocked8> - %858 = arith.muli %852, %46 : tensor<4x2x32xi32, #blocked6> - %859 = "tt.reduce"(%858) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<4x2x32xi32, #blocked6>) -> tensor<4x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> - %860 = tt.expand_dims %859 {axis = 1 : i32} : tensor<4x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> -> tensor<4x1x32xi32, #blocked6> - %861 = tt.broadcast %860 : tensor<4x1x32xi32, #blocked6> -> tensor<4x2x32xi32, #blocked6> - %862 = tt.reshape %861 {allow_reorder = false} : tensor<4x2x32xi32, #blocked6> -> tensor<1x256xi32, #blocked8> - %863 = arith.cmpi sgt, %857, %862 : tensor<1x256xi32, #blocked8> - %864 = arith.xori %857, %862 : tensor<1x256xi32, #blocked8> - %865 = triton_gpu.convert_layout %arg22 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %866 = tt.reshape %865 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<4x2x32xi8, #blocked6> - %867 = arith.muli %866, %94 : tensor<4x2x32xi8, #blocked6> - %868 = "tt.reduce"(%867) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<4x2x32xi8, #blocked6>) -> tensor<4x32xi8, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> - %869 = tt.expand_dims %868 {axis = 1 : i32} : tensor<4x32xi8, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> -> tensor<4x1x32xi8, #blocked6> - %870 = tt.broadcast %869 : tensor<4x1x32xi8, #blocked6> -> tensor<4x2x32xi8, #blocked6> - %871 = tt.reshape %870 {allow_reorder = false} : tensor<4x2x32xi8, #blocked6> -> tensor<1x256xi8, #blocked8> - %872 = arith.muli %866, %47 : tensor<4x2x32xi8, #blocked6> - %873 = "tt.reduce"(%872) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<4x2x32xi8, #blocked6>) -> tensor<4x32xi8, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> - %874 = tt.expand_dims %873 {axis = 1 : i32} : tensor<4x32xi8, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> -> tensor<4x1x32xi8, #blocked6> - %875 = tt.broadcast %874 : tensor<4x1x32xi8, #blocked6> -> tensor<4x2x32xi8, #blocked6> - %876 = tt.reshape %875 {allow_reorder = false} : tensor<4x2x32xi8, #blocked6> -> tensor<1x256xi8, #blocked8> - %877 = arith.cmpi slt, %871, %876 : tensor<1x256xi8, #blocked8> - %878 = arith.cmpi eq, %871, %876 : tensor<1x256xi8, #blocked8> - %879 = arith.andi %878, %863 : tensor<1x256xi1, #blocked8> - %880 = arith.ori %877, %879 : tensor<1x256xi1, #blocked8> - %881 = arith.extui %880 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %882 = arith.xori %881, %54 : tensor<1x256xi32, #blocked8> - %883 = arith.cmpi ne, %882, %cst_5 : tensor<1x256xi32, #blocked8> - %884 = arith.select %883, %864, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %885 = arith.xori %851, %884 : tensor<1x256xi32, #blocked8> - %886 = tt.reshape %885 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<8x2x16xi32, #blocked5> - %887 = arith.muli %886, %101 : tensor<8x2x16xi32, #blocked5> - %888 = "tt.reduce"(%887) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<8x2x16xi32, #blocked5>) -> tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %889 = tt.expand_dims %888 {axis = 1 : i32} : tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi32, #blocked5> - %890 = tt.broadcast %889 : tensor<8x1x16xi32, #blocked5> -> tensor<8x2x16xi32, #blocked5> - %891 = tt.reshape %890 {allow_reorder = false} : tensor<8x2x16xi32, #blocked5> -> tensor<1x256xi32, #blocked8> - %892 = arith.muli %886, %37 : tensor<8x2x16xi32, #blocked5> - %893 = "tt.reduce"(%892) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<8x2x16xi32, #blocked5>) -> tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %894 = tt.expand_dims %893 {axis = 1 : i32} : tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi32, #blocked5> - %895 = tt.broadcast %894 : tensor<8x1x16xi32, #blocked5> -> tensor<8x2x16xi32, #blocked5> - %896 = tt.reshape %895 {allow_reorder = false} : tensor<8x2x16xi32, #blocked5> -> tensor<1x256xi32, #blocked8> - %897 = arith.cmpi sgt, %891, %896 : tensor<1x256xi32, #blocked8> - %898 = arith.xori %891, %896 : tensor<1x256xi32, #blocked8> - %899 = triton_gpu.convert_layout %arg23 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %900 = tt.reshape %899 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<8x2x16xi8, #blocked5> - %901 = arith.muli %900, %99 : tensor<8x2x16xi8, #blocked5> - %902 = "tt.reduce"(%901) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<8x2x16xi8, #blocked5>) -> tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %903 = tt.expand_dims %902 {axis = 1 : i32} : tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi8, #blocked5> - %904 = tt.broadcast %903 : tensor<8x1x16xi8, #blocked5> -> tensor<8x2x16xi8, #blocked5> - %905 = tt.reshape %904 {allow_reorder = false} : tensor<8x2x16xi8, #blocked5> -> tensor<1x256xi8, #blocked8> - %906 = arith.muli %900, %38 : tensor<8x2x16xi8, #blocked5> - %907 = "tt.reduce"(%906) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<8x2x16xi8, #blocked5>) -> tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %908 = tt.expand_dims %907 {axis = 1 : i32} : tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi8, #blocked5> - %909 = tt.broadcast %908 : tensor<8x1x16xi8, #blocked5> -> tensor<8x2x16xi8, #blocked5> - %910 = tt.reshape %909 {allow_reorder = false} : tensor<8x2x16xi8, #blocked5> -> tensor<1x256xi8, #blocked8> - %911 = arith.cmpi slt, %905, %910 : tensor<1x256xi8, #blocked8> - %912 = arith.cmpi eq, %905, %910 : tensor<1x256xi8, #blocked8> - %913 = arith.andi %912, %897 : tensor<1x256xi1, #blocked8> - %914 = arith.ori %911, %913 : tensor<1x256xi1, #blocked8> - %915 = arith.extui %914 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %916 = arith.xori %915, %54 : tensor<1x256xi32, #blocked8> - %917 = arith.cmpi ne, %916, %cst_5 : tensor<1x256xi32, #blocked8> - %918 = arith.select %917, %898, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %919 = arith.xori %885, %918 : tensor<1x256xi32, #blocked8> - %920 = tt.reshape %919 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<16x2x8xi32, #blocked4> - %921 = arith.muli %920, %106 : tensor<16x2x8xi32, #blocked4> - %922 = "tt.reduce"(%921) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<16x2x8xi32, #blocked4>) -> tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %923 = tt.expand_dims %922 {axis = 1 : i32} : tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi32, #blocked4> - %924 = tt.broadcast %923 : tensor<16x1x8xi32, #blocked4> -> tensor<16x2x8xi32, #blocked4> - %925 = tt.reshape %924 {allow_reorder = false} : tensor<16x2x8xi32, #blocked4> -> tensor<1x256xi32, #blocked8> - %926 = arith.muli %920, %28 : tensor<16x2x8xi32, #blocked4> - %927 = "tt.reduce"(%926) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<16x2x8xi32, #blocked4>) -> tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %928 = tt.expand_dims %927 {axis = 1 : i32} : tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi32, #blocked4> - %929 = tt.broadcast %928 : tensor<16x1x8xi32, #blocked4> -> tensor<16x2x8xi32, #blocked4> - %930 = tt.reshape %929 {allow_reorder = false} : tensor<16x2x8xi32, #blocked4> -> tensor<1x256xi32, #blocked8> - %931 = arith.cmpi sgt, %925, %930 : tensor<1x256xi32, #blocked8> - %932 = arith.xori %925, %930 : tensor<1x256xi32, #blocked8> - %933 = triton_gpu.convert_layout %arg24 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %934 = tt.reshape %933 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<16x2x8xi8, #blocked4> - %935 = arith.muli %934, %104 : tensor<16x2x8xi8, #blocked4> - %936 = "tt.reduce"(%935) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<16x2x8xi8, #blocked4>) -> tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %937 = tt.expand_dims %936 {axis = 1 : i32} : tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi8, #blocked4> - %938 = tt.broadcast %937 : tensor<16x1x8xi8, #blocked4> -> tensor<16x2x8xi8, #blocked4> - %939 = tt.reshape %938 {allow_reorder = false} : tensor<16x2x8xi8, #blocked4> -> tensor<1x256xi8, #blocked8> - %940 = arith.muli %934, %29 : tensor<16x2x8xi8, #blocked4> - %941 = "tt.reduce"(%940) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<16x2x8xi8, #blocked4>) -> tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %942 = tt.expand_dims %941 {axis = 1 : i32} : tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi8, #blocked4> - %943 = tt.broadcast %942 : tensor<16x1x8xi8, #blocked4> -> tensor<16x2x8xi8, #blocked4> - %944 = tt.reshape %943 {allow_reorder = false} : tensor<16x2x8xi8, #blocked4> -> tensor<1x256xi8, #blocked8> - %945 = arith.cmpi slt, %939, %944 : tensor<1x256xi8, #blocked8> - %946 = arith.cmpi eq, %939, %944 : tensor<1x256xi8, #blocked8> - %947 = arith.andi %946, %931 : tensor<1x256xi1, #blocked8> - %948 = arith.ori %945, %947 : tensor<1x256xi1, #blocked8> - %949 = arith.extui %948 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %950 = arith.xori %949, %54 : tensor<1x256xi32, #blocked8> - %951 = arith.cmpi ne, %950, %cst_5 : tensor<1x256xi32, #blocked8> - %952 = arith.select %951, %932, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %953 = arith.xori %919, %952 : tensor<1x256xi32, #blocked8> - %954 = tt.reshape %953 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<32x2x4xi32, #blocked3> - %955 = arith.muli %954, %111 : tensor<32x2x4xi32, #blocked3> - %956 = "tt.reduce"(%955) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<32x2x4xi32, #blocked3>) -> tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %957 = tt.expand_dims %956 {axis = 1 : i32} : tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi32, #blocked3> - %958 = tt.broadcast %957 : tensor<32x1x4xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %959 = tt.reshape %958 {allow_reorder = false} : tensor<32x2x4xi32, #blocked3> -> tensor<1x256xi32, #blocked8> - %960 = arith.muli %954, %19 : tensor<32x2x4xi32, #blocked3> - %961 = "tt.reduce"(%960) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<32x2x4xi32, #blocked3>) -> tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %962 = tt.expand_dims %961 {axis = 1 : i32} : tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi32, #blocked3> - %963 = tt.broadcast %962 : tensor<32x1x4xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %964 = tt.reshape %963 {allow_reorder = false} : tensor<32x2x4xi32, #blocked3> -> tensor<1x256xi32, #blocked8> - %965 = arith.cmpi sgt, %959, %964 : tensor<1x256xi32, #blocked8> - %966 = arith.xori %959, %964 : tensor<1x256xi32, #blocked8> - %967 = triton_gpu.convert_layout %arg25 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %968 = tt.reshape %967 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<32x2x4xi8, #blocked3> - %969 = arith.muli %968, %109 : tensor<32x2x4xi8, #blocked3> - %970 = "tt.reduce"(%969) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<32x2x4xi8, #blocked3>) -> tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %971 = tt.expand_dims %970 {axis = 1 : i32} : tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi8, #blocked3> - %972 = tt.broadcast %971 : tensor<32x1x4xi8, #blocked3> -> tensor<32x2x4xi8, #blocked3> - %973 = tt.reshape %972 {allow_reorder = false} : tensor<32x2x4xi8, #blocked3> -> tensor<1x256xi8, #blocked8> - %974 = arith.muli %968, %20 : tensor<32x2x4xi8, #blocked3> - %975 = "tt.reduce"(%974) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<32x2x4xi8, #blocked3>) -> tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %976 = tt.expand_dims %975 {axis = 1 : i32} : tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi8, #blocked3> - %977 = tt.broadcast %976 : tensor<32x1x4xi8, #blocked3> -> tensor<32x2x4xi8, #blocked3> - %978 = tt.reshape %977 {allow_reorder = false} : tensor<32x2x4xi8, #blocked3> -> tensor<1x256xi8, #blocked8> - %979 = arith.cmpi slt, %973, %978 : tensor<1x256xi8, #blocked8> - %980 = arith.cmpi eq, %973, %978 : tensor<1x256xi8, #blocked8> - %981 = arith.andi %980, %965 : tensor<1x256xi1, #blocked8> - %982 = arith.ori %979, %981 : tensor<1x256xi1, #blocked8> - %983 = arith.extui %982 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %984 = arith.xori %983, %54 : tensor<1x256xi32, #blocked8> - %985 = arith.cmpi ne, %984, %cst_5 : tensor<1x256xi32, #blocked8> - %986 = arith.select %985, %966, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %987 = arith.xori %953, %986 : tensor<1x256xi32, #blocked8> - %988 = tt.reshape %987 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<64x2x2xi32, #blocked2> - %989 = arith.muli %988, %116 : tensor<64x2x2xi32, #blocked2> - %990 = "tt.reduce"(%989) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<64x2x2xi32, #blocked2>) -> tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %991 = tt.expand_dims %990 {axis = 1 : i32} : tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi32, #blocked2> - %992 = tt.broadcast %991 : tensor<64x1x2xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %993 = tt.reshape %992 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %994 = arith.muli %988, %10 : tensor<64x2x2xi32, #blocked2> - %995 = "tt.reduce"(%994) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<64x2x2xi32, #blocked2>) -> tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %996 = tt.expand_dims %995 {axis = 1 : i32} : tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi32, #blocked2> - %997 = tt.broadcast %996 : tensor<64x1x2xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %998 = tt.reshape %997 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %999 = arith.cmpi sgt, %993, %998 : tensor<1x256xi32, #blocked8> - %1000 = arith.xori %993, %998 : tensor<1x256xi32, #blocked8> - %1001 = triton_gpu.convert_layout %arg26 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %1002 = tt.reshape %1001 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<64x2x2xi8, #blocked2> - %1003 = arith.muli %1002, %114 : tensor<64x2x2xi8, #blocked2> - %1004 = "tt.reduce"(%1003) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<64x2x2xi8, #blocked2>) -> tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %1005 = tt.expand_dims %1004 {axis = 1 : i32} : tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi8, #blocked2> - %1006 = tt.broadcast %1005 : tensor<64x1x2xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %1007 = tt.reshape %1006 {allow_reorder = false} : tensor<64x2x2xi8, #blocked2> -> tensor<1x256xi8, #blocked8> - %1008 = arith.muli %1002, %11 : tensor<64x2x2xi8, #blocked2> - %1009 = "tt.reduce"(%1008) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<64x2x2xi8, #blocked2>) -> tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %1010 = tt.expand_dims %1009 {axis = 1 : i32} : tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi8, #blocked2> - %1011 = tt.broadcast %1010 : tensor<64x1x2xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %1012 = tt.reshape %1011 {allow_reorder = false} : tensor<64x2x2xi8, #blocked2> -> tensor<1x256xi8, #blocked8> - %1013 = arith.cmpi slt, %1007, %1012 : tensor<1x256xi8, #blocked8> - %1014 = arith.cmpi eq, %1007, %1012 : tensor<1x256xi8, #blocked8> - %1015 = arith.andi %1014, %999 : tensor<1x256xi1, #blocked8> - %1016 = arith.ori %1013, %1015 : tensor<1x256xi1, #blocked8> - %1017 = arith.extui %1016 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %1018 = arith.xori %1017, %54 : tensor<1x256xi32, #blocked8> - %1019 = arith.cmpi ne, %1018, %cst_5 : tensor<1x256xi32, #blocked8> - %1020 = arith.select %1019, %1000, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %1021 = arith.xori %987, %1020 : tensor<1x256xi32, #blocked8> - %1022 = tt.reshape %1021 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<128x2x1xi32, #blocked9> - %1023 = arith.muli %1022, %77 : tensor<128x2x1xi32, #blocked9> - %1024 = "tt.reduce"(%1023) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %1025 = tt.expand_dims %1024 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %1026 = tt.broadcast %1025 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %1027 = tt.reshape %1026 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %1028 = arith.muli %1022, %66 : tensor<128x2x1xi32, #blocked9> - %1029 = "tt.reduce"(%1028) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %1030 = tt.expand_dims %1029 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %1031 = tt.broadcast %1030 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %1032 = tt.reshape %1031 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %1033 = arith.cmpi sgt, %1027, %1032 : tensor<1x256xi32, #blocked8> - %1034 = arith.xori %1027, %1032 : tensor<1x256xi32, #blocked8> - %1035 = triton_gpu.convert_layout %arg27 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %1036 = tt.reshape %1035 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<128x2x1xi8, #blocked9> - %1037 = arith.muli %1036, %75 : tensor<128x2x1xi8, #blocked9> - %1038 = "tt.reduce"(%1037) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %1039 = tt.expand_dims %1038 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %1040 = tt.broadcast %1039 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %1041 = tt.reshape %1040 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %1042 = arith.muli %1036, %72 : tensor<128x2x1xi8, #blocked9> - %1043 = "tt.reduce"(%1042) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %1044 = tt.expand_dims %1043 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %1045 = tt.broadcast %1044 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %1046 = tt.reshape %1045 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %1047 = arith.cmpi slt, %1041, %1046 : tensor<1x256xi8, #blocked8> - %1048 = arith.cmpi eq, %1041, %1046 : tensor<1x256xi8, #blocked8> - %1049 = arith.andi %1048, %1033 : tensor<1x256xi1, #blocked8> - %1050 = arith.ori %1047, %1049 : tensor<1x256xi1, #blocked8> - %1051 = arith.extui %1050 : tensor<1x256xi1, #blocked8> to tensor<1x256xi32, #blocked8> - %1052 = arith.xori %1051, %54 : tensor<1x256xi32, #blocked8> - %1053 = arith.cmpi ne, %1052, %cst_5 : tensor<1x256xi32, #blocked8> - %1054 = arith.select %1053, %1034, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %1055 = arith.xori %1021, %1054 : tensor<1x256xi32, #blocked8> - %1056 = tt.reshape %1055 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<1x2x128xi32, #blocked7> - %1057 = arith.muli %1056, %90 : tensor<1x2x128xi32, #blocked7> - %1058 = "tt.reduce"(%1057) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<1x2x128xi32, #blocked7>) -> tensor<1x128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> - %1059 = tt.expand_dims %1058 {axis = 1 : i32} : tensor<1x128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> -> tensor<1x1x128xi32, #blocked7> - %1060 = tt.broadcast %1059 : tensor<1x1x128xi32, #blocked7> -> tensor<1x2x128xi32, #blocked7> - %1061 = tt.reshape %1060 {allow_reorder = false} : tensor<1x2x128xi32, #blocked7> -> tensor<1x256xi32, #blocked8> - %1062 = arith.muli %1056, %58 : tensor<1x2x128xi32, #blocked7> - %1063 = "tt.reduce"(%1062) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<1x2x128xi32, #blocked7>) -> tensor<1x128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> - %1064 = tt.expand_dims %1063 {axis = 1 : i32} : tensor<1x128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> -> tensor<1x1x128xi32, #blocked7> - %1065 = tt.broadcast %1064 : tensor<1x1x128xi32, #blocked7> -> tensor<1x2x128xi32, #blocked7> - %1066 = tt.reshape %1065 {allow_reorder = false} : tensor<1x2x128xi32, #blocked7> -> tensor<1x256xi32, #blocked8> - %1067 = arith.cmpi sgt, %1061, %1066 : tensor<1x256xi32, #blocked8> - %1068 = arith.xori %1061, %1066 : tensor<1x256xi32, #blocked8> - %1069 = triton_gpu.convert_layout %arg28 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %1070 = tt.reshape %1069 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<1x2x128xi8, #blocked7> - %1071 = arith.muli %1070, %91 : tensor<1x2x128xi8, #blocked7> - %1072 = "tt.reduce"(%1071) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<1x2x128xi8, #blocked7>) -> tensor<1x128xi8, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> - %1073 = tt.expand_dims %1072 {axis = 1 : i32} : tensor<1x128xi8, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> -> tensor<1x1x128xi8, #blocked7> - %1074 = tt.broadcast %1073 : tensor<1x1x128xi8, #blocked7> -> tensor<1x2x128xi8, #blocked7> - %1075 = tt.reshape %1074 {allow_reorder = false} : tensor<1x2x128xi8, #blocked7> -> tensor<1x256xi8, #blocked8> - %1076 = arith.muli %1070, %60 : tensor<1x2x128xi8, #blocked7> - %1077 = "tt.reduce"(%1076) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<1x2x128xi8, #blocked7>) -> tensor<1x128xi8, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> - %1078 = tt.expand_dims %1077 {axis = 1 : i32} : tensor<1x128xi8, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> -> tensor<1x1x128xi8, #blocked7> - %1079 = tt.broadcast %1078 : tensor<1x1x128xi8, #blocked7> -> tensor<1x2x128xi8, #blocked7> - %1080 = tt.reshape %1079 {allow_reorder = false} : tensor<1x2x128xi8, #blocked7> -> tensor<1x256xi8, #blocked8> - %1081 = arith.cmpi slt, %1075, %1080 : tensor<1x256xi8, #blocked8> - %1082 = arith.cmpi eq, %1075, %1080 : tensor<1x256xi8, #blocked8> - %1083 = arith.andi %1082, %1067 : tensor<1x256xi1, #blocked8> - %1084 = arith.ori %1081, %1083 : tensor<1x256xi1, #blocked8> - %1085 = arith.select %1084, %1068, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %1086 = arith.xori %1055, %1085 : tensor<1x256xi32, #blocked8> - %1087 = tt.reshape %1086 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<2x2x64xi32, #blocked7> - %1088 = arith.muli %1087, %89 : tensor<2x2x64xi32, #blocked7> - %1089 = "tt.reduce"(%1088) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<2x2x64xi32, #blocked7>) -> tensor<2x64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> - %1090 = tt.expand_dims %1089 {axis = 1 : i32} : tensor<2x64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> -> tensor<2x1x64xi32, #blocked7> - %1091 = tt.broadcast %1090 : tensor<2x1x64xi32, #blocked7> -> tensor<2x2x64xi32, #blocked7> - %1092 = tt.reshape %1091 {allow_reorder = false} : tensor<2x2x64xi32, #blocked7> -> tensor<1x256xi32, #blocked8> - %1093 = arith.muli %1087, %57 : tensor<2x2x64xi32, #blocked7> - %1094 = "tt.reduce"(%1093) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<2x2x64xi32, #blocked7>) -> tensor<2x64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> - %1095 = tt.expand_dims %1094 {axis = 1 : i32} : tensor<2x64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> -> tensor<2x1x64xi32, #blocked7> - %1096 = tt.broadcast %1095 : tensor<2x1x64xi32, #blocked7> -> tensor<2x2x64xi32, #blocked7> - %1097 = tt.reshape %1096 {allow_reorder = false} : tensor<2x2x64xi32, #blocked7> -> tensor<1x256xi32, #blocked8> - %1098 = arith.cmpi sgt, %1092, %1097 : tensor<1x256xi32, #blocked8> - %1099 = arith.xori %1092, %1097 : tensor<1x256xi32, #blocked8> - %1100 = triton_gpu.convert_layout %arg29 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %1101 = tt.reshape %1100 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<2x2x64xi8, #blocked7> - %1102 = arith.muli %1101, %87 : tensor<2x2x64xi8, #blocked7> - %1103 = "tt.reduce"(%1102) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<2x2x64xi8, #blocked7>) -> tensor<2x64xi8, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> - %1104 = tt.expand_dims %1103 {axis = 1 : i32} : tensor<2x64xi8, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> -> tensor<2x1x64xi8, #blocked7> - %1105 = tt.broadcast %1104 : tensor<2x1x64xi8, #blocked7> -> tensor<2x2x64xi8, #blocked7> - %1106 = tt.reshape %1105 {allow_reorder = false} : tensor<2x2x64xi8, #blocked7> -> tensor<1x256xi8, #blocked8> - %1107 = arith.muli %1101, %59 : tensor<2x2x64xi8, #blocked7> - %1108 = "tt.reduce"(%1107) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<2x2x64xi8, #blocked7>) -> tensor<2x64xi8, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> - %1109 = tt.expand_dims %1108 {axis = 1 : i32} : tensor<2x64xi8, #triton_gpu.slice<{dim = 1, parent = #blocked7}>> -> tensor<2x1x64xi8, #blocked7> - %1110 = tt.broadcast %1109 : tensor<2x1x64xi8, #blocked7> -> tensor<2x2x64xi8, #blocked7> - %1111 = tt.reshape %1110 {allow_reorder = false} : tensor<2x2x64xi8, #blocked7> -> tensor<1x256xi8, #blocked8> - %1112 = arith.cmpi slt, %1106, %1111 : tensor<1x256xi8, #blocked8> - %1113 = arith.cmpi eq, %1106, %1111 : tensor<1x256xi8, #blocked8> - %1114 = arith.andi %1113, %1098 : tensor<1x256xi1, #blocked8> - %1115 = arith.ori %1112, %1114 : tensor<1x256xi1, #blocked8> - %1116 = arith.select %1115, %1099, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %1117 = arith.xori %1086, %1116 : tensor<1x256xi32, #blocked8> - %1118 = tt.reshape %1117 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<4x2x32xi32, #blocked6> - %1119 = arith.muli %1118, %96 : tensor<4x2x32xi32, #blocked6> - %1120 = "tt.reduce"(%1119) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<4x2x32xi32, #blocked6>) -> tensor<4x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> - %1121 = tt.expand_dims %1120 {axis = 1 : i32} : tensor<4x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> -> tensor<4x1x32xi32, #blocked6> - %1122 = tt.broadcast %1121 : tensor<4x1x32xi32, #blocked6> -> tensor<4x2x32xi32, #blocked6> - %1123 = tt.reshape %1122 {allow_reorder = false} : tensor<4x2x32xi32, #blocked6> -> tensor<1x256xi32, #blocked8> - %1124 = arith.muli %1118, %46 : tensor<4x2x32xi32, #blocked6> - %1125 = "tt.reduce"(%1124) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<4x2x32xi32, #blocked6>) -> tensor<4x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> - %1126 = tt.expand_dims %1125 {axis = 1 : i32} : tensor<4x32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> -> tensor<4x1x32xi32, #blocked6> - %1127 = tt.broadcast %1126 : tensor<4x1x32xi32, #blocked6> -> tensor<4x2x32xi32, #blocked6> - %1128 = tt.reshape %1127 {allow_reorder = false} : tensor<4x2x32xi32, #blocked6> -> tensor<1x256xi32, #blocked8> - %1129 = arith.cmpi sgt, %1123, %1128 : tensor<1x256xi32, #blocked8> - %1130 = arith.xori %1123, %1128 : tensor<1x256xi32, #blocked8> - %1131 = triton_gpu.convert_layout %arg30 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %1132 = tt.reshape %1131 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<4x2x32xi8, #blocked6> - %1133 = arith.muli %1132, %94 : tensor<4x2x32xi8, #blocked6> - %1134 = "tt.reduce"(%1133) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<4x2x32xi8, #blocked6>) -> tensor<4x32xi8, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> - %1135 = tt.expand_dims %1134 {axis = 1 : i32} : tensor<4x32xi8, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> -> tensor<4x1x32xi8, #blocked6> - %1136 = tt.broadcast %1135 : tensor<4x1x32xi8, #blocked6> -> tensor<4x2x32xi8, #blocked6> - %1137 = tt.reshape %1136 {allow_reorder = false} : tensor<4x2x32xi8, #blocked6> -> tensor<1x256xi8, #blocked8> - %1138 = arith.muli %1132, %47 : tensor<4x2x32xi8, #blocked6> - %1139 = "tt.reduce"(%1138) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<4x2x32xi8, #blocked6>) -> tensor<4x32xi8, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> - %1140 = tt.expand_dims %1139 {axis = 1 : i32} : tensor<4x32xi8, #triton_gpu.slice<{dim = 1, parent = #blocked6}>> -> tensor<4x1x32xi8, #blocked6> - %1141 = tt.broadcast %1140 : tensor<4x1x32xi8, #blocked6> -> tensor<4x2x32xi8, #blocked6> - %1142 = tt.reshape %1141 {allow_reorder = false} : tensor<4x2x32xi8, #blocked6> -> tensor<1x256xi8, #blocked8> - %1143 = arith.cmpi slt, %1137, %1142 : tensor<1x256xi8, #blocked8> - %1144 = arith.cmpi eq, %1137, %1142 : tensor<1x256xi8, #blocked8> - %1145 = arith.andi %1144, %1129 : tensor<1x256xi1, #blocked8> - %1146 = arith.ori %1143, %1145 : tensor<1x256xi1, #blocked8> - %1147 = arith.select %1146, %1130, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %1148 = arith.xori %1117, %1147 : tensor<1x256xi32, #blocked8> - %1149 = tt.reshape %1148 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<8x2x16xi32, #blocked5> - %1150 = arith.muli %1149, %101 : tensor<8x2x16xi32, #blocked5> - %1151 = "tt.reduce"(%1150) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<8x2x16xi32, #blocked5>) -> tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %1152 = tt.expand_dims %1151 {axis = 1 : i32} : tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi32, #blocked5> - %1153 = tt.broadcast %1152 : tensor<8x1x16xi32, #blocked5> -> tensor<8x2x16xi32, #blocked5> - %1154 = tt.reshape %1153 {allow_reorder = false} : tensor<8x2x16xi32, #blocked5> -> tensor<1x256xi32, #blocked8> - %1155 = arith.muli %1149, %37 : tensor<8x2x16xi32, #blocked5> - %1156 = "tt.reduce"(%1155) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<8x2x16xi32, #blocked5>) -> tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %1157 = tt.expand_dims %1156 {axis = 1 : i32} : tensor<8x16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi32, #blocked5> - %1158 = tt.broadcast %1157 : tensor<8x1x16xi32, #blocked5> -> tensor<8x2x16xi32, #blocked5> - %1159 = tt.reshape %1158 {allow_reorder = false} : tensor<8x2x16xi32, #blocked5> -> tensor<1x256xi32, #blocked8> - %1160 = arith.cmpi sgt, %1154, %1159 : tensor<1x256xi32, #blocked8> - %1161 = arith.xori %1154, %1159 : tensor<1x256xi32, #blocked8> - %1162 = triton_gpu.convert_layout %arg31 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %1163 = tt.reshape %1162 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<8x2x16xi8, #blocked5> - %1164 = arith.muli %1163, %99 : tensor<8x2x16xi8, #blocked5> - %1165 = "tt.reduce"(%1164) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<8x2x16xi8, #blocked5>) -> tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %1166 = tt.expand_dims %1165 {axis = 1 : i32} : tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi8, #blocked5> - %1167 = tt.broadcast %1166 : tensor<8x1x16xi8, #blocked5> -> tensor<8x2x16xi8, #blocked5> - %1168 = tt.reshape %1167 {allow_reorder = false} : tensor<8x2x16xi8, #blocked5> -> tensor<1x256xi8, #blocked8> - %1169 = arith.muli %1163, %38 : tensor<8x2x16xi8, #blocked5> - %1170 = "tt.reduce"(%1169) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<8x2x16xi8, #blocked5>) -> tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> - %1171 = tt.expand_dims %1170 {axis = 1 : i32} : tensor<8x16xi8, #triton_gpu.slice<{dim = 1, parent = #blocked5}>> -> tensor<8x1x16xi8, #blocked5> - %1172 = tt.broadcast %1171 : tensor<8x1x16xi8, #blocked5> -> tensor<8x2x16xi8, #blocked5> - %1173 = tt.reshape %1172 {allow_reorder = false} : tensor<8x2x16xi8, #blocked5> -> tensor<1x256xi8, #blocked8> - %1174 = arith.cmpi slt, %1168, %1173 : tensor<1x256xi8, #blocked8> - %1175 = arith.cmpi eq, %1168, %1173 : tensor<1x256xi8, #blocked8> - %1176 = arith.andi %1175, %1160 : tensor<1x256xi1, #blocked8> - %1177 = arith.ori %1174, %1176 : tensor<1x256xi1, #blocked8> - %1178 = arith.select %1177, %1161, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %1179 = arith.xori %1148, %1178 : tensor<1x256xi32, #blocked8> - %1180 = tt.reshape %1179 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<16x2x8xi32, #blocked4> - %1181 = arith.muli %1180, %106 : tensor<16x2x8xi32, #blocked4> - %1182 = "tt.reduce"(%1181) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<16x2x8xi32, #blocked4>) -> tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %1183 = tt.expand_dims %1182 {axis = 1 : i32} : tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi32, #blocked4> - %1184 = tt.broadcast %1183 : tensor<16x1x8xi32, #blocked4> -> tensor<16x2x8xi32, #blocked4> - %1185 = tt.reshape %1184 {allow_reorder = false} : tensor<16x2x8xi32, #blocked4> -> tensor<1x256xi32, #blocked8> - %1186 = arith.muli %1180, %28 : tensor<16x2x8xi32, #blocked4> - %1187 = "tt.reduce"(%1186) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<16x2x8xi32, #blocked4>) -> tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %1188 = tt.expand_dims %1187 {axis = 1 : i32} : tensor<16x8xi32, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi32, #blocked4> - %1189 = tt.broadcast %1188 : tensor<16x1x8xi32, #blocked4> -> tensor<16x2x8xi32, #blocked4> - %1190 = tt.reshape %1189 {allow_reorder = false} : tensor<16x2x8xi32, #blocked4> -> tensor<1x256xi32, #blocked8> - %1191 = arith.cmpi sgt, %1185, %1190 : tensor<1x256xi32, #blocked8> - %1192 = arith.xori %1185, %1190 : tensor<1x256xi32, #blocked8> - %1193 = triton_gpu.convert_layout %arg32 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %1194 = tt.reshape %1193 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<16x2x8xi8, #blocked4> - %1195 = arith.muli %1194, %104 : tensor<16x2x8xi8, #blocked4> - %1196 = "tt.reduce"(%1195) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<16x2x8xi8, #blocked4>) -> tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %1197 = tt.expand_dims %1196 {axis = 1 : i32} : tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi8, #blocked4> - %1198 = tt.broadcast %1197 : tensor<16x1x8xi8, #blocked4> -> tensor<16x2x8xi8, #blocked4> - %1199 = tt.reshape %1198 {allow_reorder = false} : tensor<16x2x8xi8, #blocked4> -> tensor<1x256xi8, #blocked8> - %1200 = arith.muli %1194, %29 : tensor<16x2x8xi8, #blocked4> - %1201 = "tt.reduce"(%1200) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<16x2x8xi8, #blocked4>) -> tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> - %1202 = tt.expand_dims %1201 {axis = 1 : i32} : tensor<16x8xi8, #triton_gpu.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1x8xi8, #blocked4> - %1203 = tt.broadcast %1202 : tensor<16x1x8xi8, #blocked4> -> tensor<16x2x8xi8, #blocked4> - %1204 = tt.reshape %1203 {allow_reorder = false} : tensor<16x2x8xi8, #blocked4> -> tensor<1x256xi8, #blocked8> - %1205 = arith.cmpi slt, %1199, %1204 : tensor<1x256xi8, #blocked8> - %1206 = arith.cmpi eq, %1199, %1204 : tensor<1x256xi8, #blocked8> - %1207 = arith.andi %1206, %1191 : tensor<1x256xi1, #blocked8> - %1208 = arith.ori %1205, %1207 : tensor<1x256xi1, #blocked8> - %1209 = arith.select %1208, %1192, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %1210 = arith.xori %1179, %1209 : tensor<1x256xi32, #blocked8> - %1211 = tt.reshape %1210 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<32x2x4xi32, #blocked3> - %1212 = arith.muli %1211, %111 : tensor<32x2x4xi32, #blocked3> - %1213 = "tt.reduce"(%1212) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<32x2x4xi32, #blocked3>) -> tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %1214 = tt.expand_dims %1213 {axis = 1 : i32} : tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi32, #blocked3> - %1215 = tt.broadcast %1214 : tensor<32x1x4xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %1216 = tt.reshape %1215 {allow_reorder = false} : tensor<32x2x4xi32, #blocked3> -> tensor<1x256xi32, #blocked8> - %1217 = arith.muli %1211, %19 : tensor<32x2x4xi32, #blocked3> - %1218 = "tt.reduce"(%1217) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<32x2x4xi32, #blocked3>) -> tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %1219 = tt.expand_dims %1218 {axis = 1 : i32} : tensor<32x4xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi32, #blocked3> - %1220 = tt.broadcast %1219 : tensor<32x1x4xi32, #blocked3> -> tensor<32x2x4xi32, #blocked3> - %1221 = tt.reshape %1220 {allow_reorder = false} : tensor<32x2x4xi32, #blocked3> -> tensor<1x256xi32, #blocked8> - %1222 = arith.cmpi sgt, %1216, %1221 : tensor<1x256xi32, #blocked8> - %1223 = arith.xori %1216, %1221 : tensor<1x256xi32, #blocked8> - %1224 = triton_gpu.convert_layout %arg33 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %1225 = tt.reshape %1224 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<32x2x4xi8, #blocked3> - %1226 = arith.muli %1225, %109 : tensor<32x2x4xi8, #blocked3> - %1227 = "tt.reduce"(%1226) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<32x2x4xi8, #blocked3>) -> tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %1228 = tt.expand_dims %1227 {axis = 1 : i32} : tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi8, #blocked3> - %1229 = tt.broadcast %1228 : tensor<32x1x4xi8, #blocked3> -> tensor<32x2x4xi8, #blocked3> - %1230 = tt.reshape %1229 {allow_reorder = false} : tensor<32x2x4xi8, #blocked3> -> tensor<1x256xi8, #blocked8> - %1231 = arith.muli %1225, %20 : tensor<32x2x4xi8, #blocked3> - %1232 = "tt.reduce"(%1231) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<32x2x4xi8, #blocked3>) -> tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> - %1233 = tt.expand_dims %1232 {axis = 1 : i32} : tensor<32x4xi8, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1x4xi8, #blocked3> - %1234 = tt.broadcast %1233 : tensor<32x1x4xi8, #blocked3> -> tensor<32x2x4xi8, #blocked3> - %1235 = tt.reshape %1234 {allow_reorder = false} : tensor<32x2x4xi8, #blocked3> -> tensor<1x256xi8, #blocked8> - %1236 = arith.cmpi slt, %1230, %1235 : tensor<1x256xi8, #blocked8> - %1237 = arith.cmpi eq, %1230, %1235 : tensor<1x256xi8, #blocked8> - %1238 = arith.andi %1237, %1222 : tensor<1x256xi1, #blocked8> - %1239 = arith.ori %1236, %1238 : tensor<1x256xi1, #blocked8> - %1240 = arith.select %1239, %1223, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %1241 = arith.xori %1210, %1240 : tensor<1x256xi32, #blocked8> - %1242 = tt.reshape %1241 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<64x2x2xi32, #blocked2> - %1243 = arith.muli %1242, %116 : tensor<64x2x2xi32, #blocked2> - %1244 = "tt.reduce"(%1243) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<64x2x2xi32, #blocked2>) -> tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %1245 = tt.expand_dims %1244 {axis = 1 : i32} : tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi32, #blocked2> - %1246 = tt.broadcast %1245 : tensor<64x1x2xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %1247 = tt.reshape %1246 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %1248 = arith.muli %1242, %10 : tensor<64x2x2xi32, #blocked2> - %1249 = "tt.reduce"(%1248) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<64x2x2xi32, #blocked2>) -> tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %1250 = tt.expand_dims %1249 {axis = 1 : i32} : tensor<64x2xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi32, #blocked2> - %1251 = tt.broadcast %1250 : tensor<64x1x2xi32, #blocked2> -> tensor<64x2x2xi32, #blocked2> - %1252 = tt.reshape %1251 {allow_reorder = false} : tensor<64x2x2xi32, #blocked2> -> tensor<1x256xi32, #blocked8> - %1253 = arith.cmpi sgt, %1247, %1252 : tensor<1x256xi32, #blocked8> - %1254 = arith.xori %1247, %1252 : tensor<1x256xi32, #blocked8> - %1255 = triton_gpu.convert_layout %arg34 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %1256 = tt.reshape %1255 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<64x2x2xi8, #blocked2> - %1257 = arith.muli %1256, %114 : tensor<64x2x2xi8, #blocked2> - %1258 = "tt.reduce"(%1257) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<64x2x2xi8, #blocked2>) -> tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %1259 = tt.expand_dims %1258 {axis = 1 : i32} : tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi8, #blocked2> - %1260 = tt.broadcast %1259 : tensor<64x1x2xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %1261 = tt.reshape %1260 {allow_reorder = false} : tensor<64x2x2xi8, #blocked2> -> tensor<1x256xi8, #blocked8> - %1262 = arith.muli %1256, %11 : tensor<64x2x2xi8, #blocked2> - %1263 = "tt.reduce"(%1262) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<64x2x2xi8, #blocked2>) -> tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> - %1264 = tt.expand_dims %1263 {axis = 1 : i32} : tensor<64x2xi8, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1x2xi8, #blocked2> - %1265 = tt.broadcast %1264 : tensor<64x1x2xi8, #blocked2> -> tensor<64x2x2xi8, #blocked2> - %1266 = tt.reshape %1265 {allow_reorder = false} : tensor<64x2x2xi8, #blocked2> -> tensor<1x256xi8, #blocked8> - %1267 = arith.cmpi slt, %1261, %1266 : tensor<1x256xi8, #blocked8> - %1268 = arith.cmpi eq, %1261, %1266 : tensor<1x256xi8, #blocked8> - %1269 = arith.andi %1268, %1253 : tensor<1x256xi1, #blocked8> - %1270 = arith.ori %1267, %1269 : tensor<1x256xi1, #blocked8> - %1271 = arith.select %1270, %1254, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %1272 = arith.xori %1241, %1271 : tensor<1x256xi32, #blocked8> - %1273 = tt.reshape %1272 {allow_reorder = false} : tensor<1x256xi32, #blocked8> -> tensor<128x2x1xi32, #blocked9> - %1274 = arith.muli %1273, %77 : tensor<128x2x1xi32, #blocked9> - %1275 = "tt.reduce"(%1274) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %1276 = tt.expand_dims %1275 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %1277 = tt.broadcast %1276 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %1278 = tt.reshape %1277 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %1279 = arith.muli %1273, %66 : tensor<128x2x1xi32, #blocked9> - %1280 = "tt.reduce"(%1279) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i32, %arg37: i32): - %1305 = arith.addi %arg36, %arg37 : i32 - tt.reduce.return %1305 : i32 - }) : (tensor<128x2x1xi32, #blocked9>) -> tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %1281 = tt.expand_dims %1280 {axis = 1 : i32} : tensor<128x1xi32, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi32, #blocked9> - %1282 = tt.broadcast %1281 : tensor<128x1x1xi32, #blocked9> -> tensor<128x2x1xi32, #blocked9> - %1283 = tt.reshape %1282 {allow_reorder = false} : tensor<128x2x1xi32, #blocked9> -> tensor<1x256xi32, #blocked8> - %1284 = arith.cmpi sgt, %1278, %1283 : tensor<1x256xi32, #blocked8> - %1285 = arith.xori %1278, %1283 : tensor<1x256xi32, #blocked8> - %1286 = triton_gpu.convert_layout %arg35 : tensor<1x256xi8, #blocked> -> tensor<1x256xi8, #blocked8> - %1287 = tt.reshape %1286 {allow_reorder = false} : tensor<1x256xi8, #blocked8> -> tensor<128x2x1xi8, #blocked9> - %1288 = arith.muli %1287, %75 : tensor<128x2x1xi8, #blocked9> - %1289 = "tt.reduce"(%1288) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %1290 = tt.expand_dims %1289 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %1291 = tt.broadcast %1290 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %1292 = tt.reshape %1291 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %1293 = arith.muli %1287, %72 : tensor<128x2x1xi8, #blocked9> - %1294 = "tt.reduce"(%1293) <{axis = 1 : i32}> ({ - ^bb0(%arg36: i8, %arg37: i8): - %1305 = arith.addi %arg36, %arg37 : i8 - tt.reduce.return %1305 : i8 - }) : (tensor<128x2x1xi8, #blocked9>) -> tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> - %1295 = tt.expand_dims %1294 {axis = 1 : i32} : tensor<128x1xi8, #triton_gpu.slice<{dim = 1, parent = #blocked9}>> -> tensor<128x1x1xi8, #blocked9> - %1296 = tt.broadcast %1295 : tensor<128x1x1xi8, #blocked9> -> tensor<128x2x1xi8, #blocked9> - %1297 = tt.reshape %1296 {allow_reorder = false} : tensor<128x2x1xi8, #blocked9> -> tensor<1x256xi8, #blocked8> - %1298 = arith.cmpi slt, %1292, %1297 : tensor<1x256xi8, #blocked8> - %1299 = arith.cmpi eq, %1292, %1297 : tensor<1x256xi8, #blocked8> - %1300 = arith.andi %1299, %1284 : tensor<1x256xi1, #blocked8> - %1301 = arith.ori %1298, %1300 : tensor<1x256xi1, #blocked8> - %1302 = arith.select %1301, %1285, %cst_5 : tensor<1x256xi1, #blocked8>, tensor<1x256xi32, #blocked8> - %1303 = arith.xori %1272, %1302 : tensor<1x256xi32, #blocked8> - %1304 = triton_gpu.convert_layout %1303 : tensor<1x256xi32, #blocked8> -> tensor<1x256xi32, #blocked1> - tt.return %1304 : tensor<1x256xi32, #blocked1> - } -} diff --git a/test/TritonNvidiaGPU/WarpSpecialization/async_propagate.mlir b/test/TritonNvidiaGPU/WarpSpecialization/async_propagate.mlir new file mode 100644 index 000000000..1cca80d21 --- /dev/null +++ b/test/TritonNvidiaGPU/WarpSpecialization/async_propagate.mlir @@ -0,0 +1,63 @@ +// RUN: triton-opt %s -split-input-file --triton-gpu-taskid-propagate=num-consumer-groups=1 | FileCheck %s + +// CHECK-LABEL: @async_kernel +// CHECK: %0 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 +// CHECK: %5 = tt.splat %arg2 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1024xi32> +// CHECK: %9 = tt.load %8, %6 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> +// CHECK: %10 = tt.splat %arg1 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: tt.store %11, %9 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr> + +module { + tt.func public @async_kernel(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg2 : i32 -> tensor<1024xi32> + %6 = arith.cmpi slt, %4, %5 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %4 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.load %8, %6 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %11 = tt.addptr %10, %4 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %11, %9 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK-LABEL: @two_consumers +// CHECK: tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 +// CHECK: tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} +// CHECK: tt.load {{.*}} {async_task_id = dense<0> : vector<1xi32>} +// CHECK: tt.load {{.*}} {async_task_id = dense<0> : vector<1xi32>} +// CHECK: tt.splat %arg1 {async_task_id = dense<[1, 2]> : vector<2xi32>} +// CHECK: tt.store {{.*}} {async_task_id = dense<1> : vector<1xi32>} +// CHECK: tt.store {{.*}} {async_task_id = dense<2> : vector<1xi32>} + +module { + tt.func public @two_consumers(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.make_range {end = 2048 : i32, start = 1024 : i32} : tensor<1024xi32> + %4 = tt.splat %1 : i32 -> tensor<1024xi32> + %5 = arith.addi %4, %2 : tensor<1024xi32> + %6 = arith.addi %4, %3 : tensor<1024xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %5 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.addptr %7, %6 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %8 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> + %11 = tt.load %9 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %13 = tt.addptr %12, %5 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %14 = tt.addptr %12, %6 {async_task_id = dense<2> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %13, %10 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr> + tt.store %14, %11 {async_task_id = dense<2> : vector<1xi32>} : tensor<1024x!tt.ptr> + tt.return + } +} diff --git a/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir b/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir new file mode 100644 index 000000000..0aa75a88f --- /dev/null +++ b/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir @@ -0,0 +1,857 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-code-partition=num-buffers=1 | FileCheck %s + +// CHECK-LABEL: @matmul_kernel_one_consumer +// CHECK: %[[#TASKID:]] = triton_nvidia_gpu.get_async_task_id : i32 +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %[[#WG0:]] = arith.cmpi eq, %[[#TASKID]], %c0_i32 : i32 +// CHECK: scf.if %[[#WG0]] +// CHECK: triton_nvidia_gpu.reg_dealloc 40 +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: %c1_i32 = arith.constant 1 : i32 +// CHECK: %[[#WG1:]] = arith.cmpi eq, %[[#TASKID]], %c1_i32 : i32 +// CHECK: scf.if %[[#WG1]] +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_gpu.local_load +// CHECK: triton_gpu.local_load +// CHECK: tt.dot +// CHECK: triton_nvidia_gpu.consumer_release + + +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_one_consumer(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant {async_task_id = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c255_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 255 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 127 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %cst_0 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<256x128xf16, #blocked1> + %cst_1 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<128x256xf16, #blocked2> + %c8_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 8 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 128 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 256 : i32 + %cst_2 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<256> : tensor<128x256xi32, #blocked2> + %0 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %1 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %2 = arith.divsi %1, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %3 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %4 = arith.divsi %3, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %5 = arith.muli %4, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %6 = arith.divsi %0, %5 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %7 = arith.muli %6, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.subi %2, %7 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.minsi %8, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.remsi %0, %5 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.remsi %10, %9 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.addi %7, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.divsi %10, %9 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %14 = arith.muli %12, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %15 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %16 = tt.make_range {async_task_id = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %17 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %18 = tt.splat %14 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %19 = tt.splat %14 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %20 = arith.addi %18, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %21 = arith.addi %19, %16 {async_task_id = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %22 = tt.splat %arg3 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %23 = arith.remsi %20, %22 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %24 = arith.muli %13, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %25 = tt.splat %24 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %26 = arith.addi %25, %17 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %27 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %28 = arith.remsi %26, %27 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %29 = tt.expand_dims %23 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %30 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked2> + %31 = arith.muli %29, %30 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked2> + %32 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %33 = tt.expand_dims %32 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %34 = tt.broadcast %31 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %35 = tt.broadcast %33 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %36 = arith.addi %34, %35 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256xi32, #blocked2> + %37 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<128x256x!tt.ptr, #blocked2> + %38 = tt.addptr %37, %36 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %39 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %40 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %41 = tt.expand_dims %39 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %42 = tt.expand_dims %40 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %43 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked1> + %44 = arith.muli %41, %43 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked1> + %45 = tt.expand_dims %28 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %46 = tt.broadcast %44 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked1> -> tensor<256x128xi32, #blocked1> + %47 = tt.broadcast %45 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked1> -> tensor<256x128xi32, #blocked1> + %48 = arith.addi %46, %47 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128xi32, #blocked1> + %49 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %50 = tt.addptr %49, %48 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked1>, tensor<256x128xi32, #blocked1> + %51 = arith.addi %arg5, %c255_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %52 = arith.divsi %51, %c256_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %53 = arith.muli %arg7, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %54 = tt.splat %53 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x128xi32, #blocked1> + %55:3 = scf.for %arg9 = %c0_i32 to %52 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %38, %arg12 = %50) -> (tensor<128x128xf32, #blocked>, tensor<128x256x!tt.ptr, #blocked2>, tensor<256x128x!tt.ptr, #blocked1>) : i32 { + %74 = arith.muli %arg9, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %75 = arith.subi %arg5, %74 {async_task_id = dense<0> : vector<1xi32>} : i32 + %76 = tt.splat %75 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x256xi32, #blocked2> + %77 = arith.cmpi slt, %33, %76 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked2> + %78 = tt.broadcast %77 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %79 = tt.load %arg11, %78, %cst_1 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked2> + %80 = tt.splat %75 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked1> + %81 = arith.cmpi slt, %42, %80 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked1> + %82 = tt.broadcast %81 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi1, #blocked1> -> tensor<256x128xi1, #blocked1> + %83 = tt.load %arg12, %82, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked1> + %84 = triton_gpu.convert_layout %79 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #blocked2> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %85 = triton_gpu.convert_layout %83 {async_task_id = dense<1> : vector<1xi32>} : tensor<256x128xf16, #blocked1> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %86 = tt.dot %84, %85, %arg10, inputPrecision = tf32 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> + %87 = tt.addptr %arg11, %cst_2 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %88 = tt.addptr %arg12, %54 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked1>, tensor<256x128xi32, #blocked1> + scf.yield {async_task_id = dense<[0, 1]> : vector<2xi32>} %86, %87, %88 : tensor<128x128xf32, #blocked>, tensor<128x256x!tt.ptr, #blocked2>, tensor<256x128x!tt.ptr, #blocked1> + } {async_task_id = dense<[0, 1]> : vector<2xi32>} + %56 = arith.truncf %55#0 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + %57 = tt.expand_dims %21 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %58 = tt.splat %arg8 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %59 = arith.muli %58, %57 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + %60 = tt.splat %arg2 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %61 = tt.addptr %60, %59 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %62 = tt.expand_dims %26 {async_task_id = dense<1> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %63 = tt.broadcast %61 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x128x!tt.ptr, #blocked1> + %64 = tt.broadcast %62 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked1> -> tensor<128x128xi32, #blocked1> + %65 = tt.addptr %63, %64 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked1>, tensor<128x128xi32, #blocked1> + %66 = tt.splat %arg3 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %67 = arith.cmpi slt, %57, %66 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + %68 = tt.splat %arg4 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<1x128xi32, #blocked1> + %69 = arith.cmpi slt, %62, %68 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked1> + %70 = tt.broadcast %67 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi1, #blocked1> -> tensor<128x128xi1, #blocked1> + %71 = tt.broadcast %69 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi1, #blocked1> -> tensor<128x128xi1, #blocked1> + %72 = arith.andi %70, %71 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked1> + %73 = triton_gpu.convert_layout %56 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #blocked1> + tt.store %65, %73, %72 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + + +// CHECK-LABEL: @matmul_kernel_two_consumers +// CHECK: scf.if +// CHECK: triton_nvidia_gpu.reg_dealloc 40 +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: scf.if +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_nvidia_gpu.warp_group_dot +// CHECK: triton_nvidia_gpu.consumer_release +// CHECK: triton_nvidia_gpu.consumer_release +// CHECK: scf.if +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_nvidia_gpu.warp_group_dot +// CHECK: triton_nvidia_gpu.consumer_release +// CHECK: triton_nvidia_gpu.consumer_release + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_two_consumers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<64> : tensor<64x64xi32, #blocked> + %c64_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 128 : i32 + %c8_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 8 : i32 + %cst_0 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<64x64xf16, #blocked> + %cst_1 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<64x128xf16, #blocked1> + %c0_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 127 : i32 + %c63_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 63 : i32 + %cst_2 = arith.constant {async_task_id = dense<[1, 2]> : vector<2xi32>} dense<0.000000e+00> : tensor<64x128xf32, #mma> + %0 = tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %1 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %2 = arith.divsi %1, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %3 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %4 = arith.divsi %3, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %5 = arith.muli %4, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %6 = arith.divsi %0, %5 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %7 = arith.muli %6, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %8 = arith.subi %2, %7 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %9 = arith.minsi %8, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %10 = arith.remsi %0, %5 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %11 = arith.remsi %10, %9 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %12 = arith.addi %7, %11 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %13 = arith.divsi %10, %9 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %14 = arith.muli %12, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %15 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %16 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %17 = tt.splat %14 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %18 = tt.splat %14 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %19 = arith.addi %17, %15 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %20 = arith.addi %18, %16 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %21 = tt.splat %arg3 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %22 = arith.remsi %19, %21 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %23 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 128 : i32, start = 64 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %24 = tt.make_range {async_task_id = dense<2> : vector<1xi32>, end = 128 : i32, start = 64 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %25 = arith.addi %17, %23 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %26 = arith.addi %18, %24 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.remsi %25, %21 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %28 = arith.muli %13, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %29 = tt.make_range {async_task_id = dense<[0, 1, 2]> : vector<3xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %30 = tt.splat %28 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %31 = arith.addi %30, %29 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %32 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %33 = arith.remsi %31, %32 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %34 = tt.expand_dims %22 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %35 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked> + %36 = arith.muli %34, %35 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %37 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %38 = tt.expand_dims %37 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %39 = tt.broadcast %36 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked> + %40 = tt.broadcast %38 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> -> tensor<64x64xi32, #blocked> + %41 = arith.addi %39, %40 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64xi32, #blocked> + %42 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked> + %43 = tt.addptr %42, %41 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %44 = tt.expand_dims %27 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %45 = arith.muli %44, %35 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %46 = tt.broadcast %45 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked> + %47 = arith.addi %46, %40 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64xi32, #blocked> + %48 = tt.addptr %42, %47 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %49 = tt.expand_dims %16 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %50 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %51 = arith.muli %49, %50 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %52 = tt.expand_dims %33 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %53 = tt.broadcast %51 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> -> tensor<64x128xi32, #blocked1> + %54 = tt.broadcast %52 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked1> -> tensor<64x128xi32, #blocked1> + %55 = arith.addi %53, %54 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128xi32, #blocked1> + %56 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128x!tt.ptr, #blocked1> + %57 = tt.addptr %56, %55 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + %58 = arith.addi %arg5, %c63_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %59 = arith.divsi %58, %c64_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %60 = tt.expand_dims %37 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %61 = tt.expand_dims %16 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %62 = arith.muli %arg7, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %63 = tt.splat %62 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x128xi32, #blocked1> + %true = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %true_3 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false_4 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %64:5 = scf.for %arg9 = %c0_i32 to %59 step %c1_i32 iter_args(%arg10 = %cst_2, %arg11 = %cst_2, %arg12 = %43, %arg13 = %57, %arg14 = %48) -> (tensor<64x128xf32, #mma>, tensor<64x128xf32, #mma>, tensor<64x64x!tt.ptr, #blocked>, tensor<64x128x!tt.ptr, #blocked1>, tensor<64x64x!tt.ptr, #blocked>) : i32 { + %93 = arith.muli %arg9, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %94 = arith.subi %arg5, %93 {async_task_id = dense<0> : vector<1xi32>} : i32 + %95 = tt.splat %94 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x64xi32, #blocked> + %96 = arith.cmpi slt, %60, %95 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> + %97 = tt.broadcast %96 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi1, #blocked> -> tensor<64x64xi1, #blocked> + %98 = tt.load %arg12, %97, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked> + %99 = triton_gpu.local_alloc %98 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x64xf16, #blocked>) -> !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> + %100 = tt.splat %94 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %101 = arith.cmpi slt, %61, %100 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %102 = tt.broadcast %101 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %103 = tt.load %arg13, %102, %cst_1 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1> + %104 = triton_gpu.local_alloc %103 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x128xf16, #blocked1>) -> !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> + %105 = tt.load %arg14, %97, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked> + %106 = triton_gpu.local_alloc %105 {async_task_id = dense<2> : vector<1xi32>} : (tensor<64x64xf16, #blocked>) -> !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> + %107 = triton_nvidia_gpu.warp_group_dot %99, %104, %arg10 {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf32, #mma> + %108 = triton_nvidia_gpu.warp_group_dot %106, %104, %arg11 {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf32, #mma> + %109 = tt.addptr %arg12, %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %110 = tt.addptr %arg14, %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %111 = tt.addptr %arg13, %63 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %107, %108, %109, %111, %110 : tensor<64x128xf32, #mma>, tensor<64x128xf32, #mma>, tensor<64x64x!tt.ptr, #blocked>, tensor<64x128x!tt.ptr, #blocked1>, tensor<64x64x!tt.ptr, #blocked> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %65 = arith.truncf %64#0 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf32, #mma> to tensor<64x128xf16, #mma> + %66 = arith.truncf %64#1 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf32, #mma> to tensor<64x128xf16, #mma> + %67 = tt.expand_dims %20 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %68 = tt.splat %arg8 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %69 = arith.muli %68, %67 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %70 = tt.splat %arg2 {async_task_id = dense<[1, 2]> : vector<2xi32>} : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> + %71 = tt.addptr %70, %69 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %72 = tt.expand_dims %31 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %73 = tt.broadcast %71 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x128x!tt.ptr, #blocked1> + %74 = tt.broadcast %72 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi32, #blocked1> -> tensor<64x128xi32, #blocked1> + %75 = tt.addptr %73, %74 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + %76 = tt.expand_dims %26 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %77 = arith.muli %68, %76 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %78 = tt.addptr %70, %77 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %79 = tt.broadcast %78 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x128x!tt.ptr, #blocked1> + %80 = tt.addptr %79, %74 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + %81 = tt.splat %arg3 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %82 = arith.cmpi slt, %67, %81 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %83 = tt.splat %arg4 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<1x128xi32, #blocked1> + %84 = arith.cmpi slt, %72, %83 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi32, #blocked1> + %85 = tt.broadcast %82 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %86 = tt.broadcast %84 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %87 = arith.andi %85, %86 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xi1, #blocked1> + %88 = arith.cmpi slt, %76, %81 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %89 = tt.broadcast %88 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %90 = arith.andi %89, %86 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xi1, #blocked1> + %91 = triton_gpu.convert_layout %65 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf16, #mma> -> tensor<64x128xf16, #blocked1> + tt.store %75, %91, %87 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1> + %92 = triton_gpu.convert_layout %66 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf16, #mma> -> tensor<64x128xf16, #blocked1> + tt.store %80, %92, %90 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1> + tt.return + } +} + + +// ----- + +// CHECK-LABEL: @_matmul_layernorm_persistent_one_producer_one_consumer_one_epilog +// CHECK: %[[#TASKID:]] = triton_nvidia_gpu.get_async_task_id : i32 +// CHECK: %c0_i32_0 = arith.constant 0 : i32 +// CHECK: %[[#WG0:]] = arith.cmpi eq, %[[#TASKID]], %c0_i32_0 : i32 +// CHECK: scf.if %[[#WG0]] +// CHECK: triton_nvidia_gpu.reg_dealloc 40 +// CHECK: scf.for +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_nvidia_gpu.barrier_expect +// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local +// CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local +// CHECK: %c1_i32 = arith.constant 1 : i32 +// CHECK: %[[#WG1:]] = arith.cmpi eq, %[[#TASKID]], %c1_i32 : i32 +// CHECK: scf.if %[[#WG1]] +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.wait_barrier +// CHECK: triton_gpu.local_load +// CHECK: triton_gpu.local_load +// CHECK: triton_nvidia_gpu.warp_group_dot +// CHECK: triton_nvidia_gpu.consumer_release +// CHECK: triton_gpu.local_store +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: %c2_i32 = arith.constant 2 : i32 +// CHECK: %[[#WG2:]] = arith.cmpi eq, %[[#TASKID]], %c2_i32 : i32 +// CHECK: scf.if %[[#WG2]] +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: scf.for +// CHECK: scf.for +// CHECK: triton_gpu.local_load +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_nvidia_gpu.consumer_release +// CHECK: tt.experimental_descriptor_store + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @_matmul_layernorm_persistent_one_producer_one_consumer_one_epilog(%arg0: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg1: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg2: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg3: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg4: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: f32) attributes {noinline = false} { + %c63_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 63 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 128 : i32 + %c0_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i32 + %c64_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i32 + %c132_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 132 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 127 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 256 : i32 + %c255_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 255 : i32 + %cst = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} dense<0.000000e+00> : tensor<128x256xf32, #mma> + %cst_0 = arith.constant {async_task_id = dense<2> : vector<1xi32>} dense<1.000000e+00> : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %0 = arith.addi %arg7, %c63_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %1 = arith.divsi %0, %c64_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %2 = arith.addi %arg5, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %3 = arith.divsi %2, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %4 = arith.addi %arg6, %c255_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %5 = arith.divsi %4, %c256_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %6 = arith.muli %3, %5 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %7 = tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %8 = arith.sitofp %arg6 {async_task_id = dense<2> : vector<1xi32>} : i32 to f32 + %9 = tt.splat %8 {async_task_id = dense<2> : vector<1xi32>} : f32 -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %10 = tt.splat %arg11 {async_task_id = dense<2> : vector<1xi32>} : f32 -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + scf.for %arg12 = %7 to %6 step %c132_i32 : i32 { + %11 = arith.muli %arg12, %c128_i32 {async_task_id = dense<[0, 2]> : vector<2xi32>} : i32 + %true = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %12 = scf.for %arg13 = %c0_i32 to %1 step %c1_i32 iter_args(%arg14 = %cst) -> (tensor<128x256xf32, #mma>) : i32 { + %45 = arith.muli %arg13, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %46 = tt.experimental_descriptor_load %arg0[%11, %45] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<128x64xf16, #blocked> + %47 = triton_gpu.local_alloc %46 {async_task_id = dense<1> : vector<1xi32>} : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %48 = tt.experimental_descriptor_load %arg1[%45, %c0_i32] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x256xf16, #blocked1> + %49 = triton_gpu.local_alloc %48 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> + %50 = triton_nvidia_gpu.warp_group_dot %47, %49, %arg14 {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %50 : tensor<128x256xf32, #mma> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %13 = "tt.reduce"(%12) <{axis = 1 : i32}> ({ + ^bb0(%arg13: f32, %arg14: f32): + %45 = arith.addf %arg13, %arg14 {async_task_id = dense<2> : vector<1xi32>} : f32 + tt.reduce.return %45 {async_task_id = dense<2> : vector<1xi32>} : f32 + }) {async_task_id = dense<2> : vector<1xi32>} : (tensor<128x256xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %14 = arith.divf %13, %9 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %15 = tt.expand_dims %14 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %16 = tt.broadcast %15 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x1xf32, #mma> -> tensor<128x256xf32, #mma> + %17 = arith.subf %12, %16 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %18 = arith.mulf %17, %17 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %19 = "tt.reduce"(%18) <{axis = 1 : i32}> ({ + ^bb0(%arg13: f32, %arg14: f32): + %45 = arith.addf %arg13, %arg14 {async_task_id = dense<2> : vector<1xi32>} : f32 + tt.reduce.return %45 {async_task_id = dense<2> : vector<1xi32>} : f32 + }) {async_task_id = dense<2> : vector<1xi32>} : (tensor<128x256xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %20 = arith.divf %19, %9 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %21 = arith.addf %20, %10 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %22 = math.sqrt %21 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %23 = arith.divf %cst_0, %22 {async_task_id = dense<2> : vector<1xi32>} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %24 = tt.experimental_descriptor_load %arg3[%c0_i32] {async_task_id = dense<2> : vector<1xi32>} : !tt.ptr -> tensor<256xf16, #blocked2> + %25 = tt.experimental_descriptor_load %arg4[%c0_i32] {async_task_id = dense<2> : vector<1xi32>} : !tt.ptr -> tensor<256xf16, #blocked2> + %26 = tt.expand_dims %23 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %27 = tt.broadcast %26 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x1xf32, #mma> -> tensor<128x256xf32, #mma> + %28 = arith.mulf %17, %27 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %29 = triton_gpu.convert_layout %24 {async_task_id = dense<2> : vector<1xi32>} : tensor<256xf16, #blocked2> -> tensor<256xf16, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %30 = tt.expand_dims %29 {async_task_id = dense<2> : vector<1xi32>, axis = 0 : i32} : tensor<256xf16, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xf16, #blocked1> + %31 = triton_gpu.convert_layout %30 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf16, #blocked1> -> tensor<1x256xf16, #blocked3> + %32 = arith.extf %31 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf16, #blocked3> to tensor<1x256xf32, #blocked3> + %33 = triton_gpu.convert_layout %32 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf32, #blocked3> -> tensor<1x256xf32, #mma> + %34 = tt.broadcast %33 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf32, #mma> -> tensor<128x256xf32, #mma> + %35 = arith.mulf %28, %34 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %36 = triton_gpu.convert_layout %25 {async_task_id = dense<2> : vector<1xi32>} : tensor<256xf16, #blocked2> -> tensor<256xf16, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %37 = tt.expand_dims %36 {async_task_id = dense<2> : vector<1xi32>, axis = 0 : i32} : tensor<256xf16, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xf16, #blocked1> + %38 = triton_gpu.convert_layout %37 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf16, #blocked1> -> tensor<1x256xf16, #blocked3> + %39 = arith.extf %38 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf16, #blocked3> to tensor<1x256xf32, #blocked3> + %40 = triton_gpu.convert_layout %39 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf32, #blocked3> -> tensor<1x256xf32, #mma> + %41 = tt.broadcast %40 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x256xf32, #mma> -> tensor<128x256xf32, #mma> + %42 = arith.addf %35, %41 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> + %43 = arith.truncf %42 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %44 = triton_gpu.convert_layout %43 {async_task_id = dense<2> : vector<1xi32>} : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> + tt.experimental_descriptor_store %arg2[%11, %c0_i32], %44 {async_task_id = dense<2> : vector<1xi32>} : !tt.ptr, tensor<128x256xf16, #blocked1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + tt.return + } +} + +// ----- + +// Verify that we can reuse buffers between two for loops +// CHECK-LABEL: @_attn_bwd_ws +// CHECK-DAG: triton_gpu.local_alloc {allocation.shareGroup = 0 : i32} : () -> !tt.memdesc<2x64x128xbf16 +// CHECK-DAG: triton_gpu.local_alloc {allocation.shareGroup = 1 : i32} : () -> !tt.memdesc<2x64x128xbf16 +// CHECK-DAG: triton_gpu.local_alloc {allocation.shareGroup = 0 : i32} : () -> !tt.memdesc<2x64x128xbf16 +// CHECK-DAG: triton_gpu.local_alloc {allocation.shareGroup = 1 : i32} : () -> !tt.memdesc<2x64x128xbf16 + +// CHECK: %[[TID:.*]] = triton_nvidia_gpu.get_async_task_id : i32 +// CHECK: %[[ZERO:.*]] = arith.constant 0 : i32 +// CHECK: %[[TWG0:.*]] = arith.cmpi eq, %[[TID]], %[[ZERO]] : i32 +// CHECK: scf.if %[[TWG0]] +// CHECK: triton_nvidia_gpu.reg_dealloc 40 +// CHECK: scf.if +// CHECK: scf.yield + +// CHECK: %[[IF_IDX:.*]] = scf.if +// CHECK: arith.divui %c0{{.*}} +// CHECK: arith.subi %c0{{.*}} +// CHECK: scf.for +// CHECK: scf.yield +// CHECK: arith.addi +// CHECK: %[[NEW_IDX:.*]] = arith.addi %c0 +// CHECK: scf.yield {{.*}} %[[NEW_IDX]] +// CHECK: scf.yield {{.*}} %c0_ + +// CHECK: scf.if +// CHECK: arith.divui %[[IF_IDX]] +// CHECK: arith.subi %[[IF_IDX]] +// CHECK: scf.for +// CHECK: scf.yield +// CHECK: arith.addi +// CHECK: %[[NEW_IDX2:.*]] = arith.addi %[[IF_IDX]] +// CHECK: scf.yield {{.*}} %[[NEW_IDX2]] +// CHECK: scf.yield {{.*}} %[[IF_IDX]] + +// CHECK: %[[ONE:.*]] = arith.constant 1 : i32 +// CHECK: %[[TWG1:.*]] = arith.cmpi eq, %[[TID]], %[[ONE]] : i32 +// CHECK: scf.if %[[TWG1]] +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: scf.if +// CHECK: scf.yield + +// CHECK: %[[IF_IDX_WG1:.*]] = scf.if +// CHECK: arith.divui %c0{{.*}} +// CHECK: arith.subi %c0{{.*}} +// CHECK: scf.for +// CHECK: scf.yield +// CHECK: arith.addi +// CHECK: %[[NEW_IDX_WG1:.*]] = arith.addi %c0 +// CHECK: scf.yield {{.*}} %[[NEW_IDX_WG1]] +// CHECK: scf.yield {{.*}} %c0_ + +// CHECK: scf.if +// CHECK: arith.divui %[[IF_IDX_WG1]] +// CHECK: arith.subi %[[IF_IDX_WG1]] +// CHECK: scf.for +// CHECK: scf.yield +// CHECK: arith.addi +// CHECK: %[[NEW_IDX2_WG1:.*]] = arith.addi %[[IF_IDX_WG1]] +// CHECK: scf.yield {{.*}} %[[NEW_IDX2_WG1]] +// CHECK: scf.yield {{.*}} %[[IF_IDX_WG1]] + +// CHECK: %[[TWO:.*]] = arith.constant 2 : i32 +// CHECK: %[[TWG2:.*]] = arith.cmpi eq, %[[TID]], %[[TWO]] : i32 +// CHECK: scf.if %[[TWG2]] +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: scf.if +// CHECK: scf.yield + +// CHECK: %[[IF_IDX_WG2:.*]] = scf.if +// CHECK: arith.divui %c0{{.*}} +// CHECK: arith.subi %c0{{.*}} +// CHECK: scf.for +// CHECK: scf.yield +// CHECK: arith.addi +// CHECK: %[[NEW_IDX_WG2:.*]] = arith.addi %c0 +// CHECK: scf.yield {{.*}} %[[NEW_IDX_WG2]] +// CHECK: scf.yield {{.*}} %c0_ + +// CHECK: scf.if +// CHECK: arith.divui %[[IF_IDX_WG2]] +// CHECK: arith.subi %[[IF_IDX_WG2]] +// CHECK: scf.for +// CHECK: scf.yield +// CHECK: arith.addi +// CHECK: %[[NEW_IDX2_WG2:.*]] = arith.addi %[[IF_IDX_WG2]] +// CHECK: scf.yield {{.*}} %[[NEW_IDX2_WG2]] +// CHECK: scf.yield {{.*}} %[[IF_IDX_WG2]] + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @_attn_bwd_ws(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg6: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg7: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg8: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg9: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg10: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg11: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg12: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg14: f32, %arg15: !tt.ptr {tt.divisibility = 16 : i32}, %arg16: !tt.ptr {tt.divisibility = 16 : i32}, %arg17: !tt.ptr {tt.divisibility = 16 : i32}, %arg18: !tt.ptr {tt.divisibility = 16 : i32}, %arg19: !tt.ptr {tt.divisibility = 16 : i32}, %arg20: !tt.ptr {tt.divisibility = 16 : i32}, %arg21: !tt.ptr {tt.divisibility = 16 : i32}, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}, %arg24: i32 {tt.divisibility = 16 : i32}, %arg25: i32 {tt.divisibility = 16 : i32}, %arg26: i32 {tt.divisibility = 16 : i32}, %arg27: i32 {tt.divisibility = 16 : i32}, %arg28: i32 {tt.divisibility = 16 : i32}, %arg29: i32, %arg30: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %false = arith.constant {async_task_id = dense<[1, 2]> : vector<2xi32>} false + %cst = arith.constant {async_task_id = dense<[1, 2]> : vector<2xi32>} dense<0.000000e+00> : tensor<64x64xf32, #mma> + %cst_0 = arith.constant {async_task_id = dense<[1, 2]> : vector<2xi32>} dense<128> : tensor<1x128xi32, #blocked> + %c0_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 128 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i32 + %c64_i64 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i64 + %c63_i64 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 63 : i64 + %c64_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i32 + %c0_i64 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i64 + %c1_i64 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i64 + %cst_1 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} dense<0.000000e+00> : tensor<64x128xf32, #mma1> + %cst_2 = arith.constant {async_task_id = dense<[1, 2]> : vector<2xi32>} dense<0.693147182> : tensor<64x128xf32, #mma1> + %0 = tt.get_program_id z {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %1 = arith.divsi %0, %arg29 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %2 = arith.remsi %0, %arg29 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %3 = tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %4 = tt.addptr %arg1, %1 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i32 + %5 = tt.load %4 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr + %6 = tt.addptr %4, %c1_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i32 + %7 = tt.load %6 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr + %8 = arith.subi %7, %5 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %9 = tt.addptr %arg3, %1 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i32 + %10 = tt.load %9 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr + %11 = tt.addptr %9, %c1_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i32 + %12 = tt.load %11 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr + %13 = arith.subi %12, %10 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %14 = arith.muli %3, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %15 = tt.make_range {async_task_id = dense<1> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %16 = tt.make_range {async_task_id = dense<2> : vector<1xi32>, end = 128 : i32, start = 64 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %17 = tt.make_range {async_task_id = dense<1> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked1> + %18 = tt.make_range {async_task_id = dense<2> : vector<1xi32>, end = 128 : i32, start = 64 : i32} : tensor<64xi32, #blocked1> + %19 = arith.extsi %14 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %20 = arith.cmpi sle, %19, %13 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %21 = arith.cmpi sle, %19, %8 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %22 = arith.ori %20, %21 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i1 + %23:5 = scf.if %22 -> (!tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr) { + %27 = tt.addptr %arg16, %1 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i32 + %28 = tt.load %27 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr + %29 = arith.extsi %2 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %30 = arith.extsi %arg26 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %31 = arith.muli %29, %30 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %32 = arith.addi %31, %28 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %33 = arith.extsi %arg24 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %34 = arith.muli %29, %33 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %35 = arith.extsi %arg22 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %36 = arith.muli %5, %35 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %37 = arith.addi %34, %36 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %38 = arith.muli %2, %arg25 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %39 = arith.extsi %arg23 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %40 = arith.muli %10, %39 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %41 = arith.extsi %38 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %42 = arith.addi %41, %40 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %43 = tt.addptr %arg17, %37 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i64 + %44 = tt.addptr %arg18, %42 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i64 + %45 = tt.addptr %arg19, %42 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i64 + %46 = tt.addptr %arg20, %32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i64 + %47 = tt.addptr %arg21, %32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : !tt.ptr, i64 + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %43, %44, %45, %46, %47 : !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr + } else { + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %arg17, %arg18, %arg19, %arg20, %arg21 : !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr, !tt.ptr + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %24 = arith.extsi %14 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 to i64 + %25 = arith.cmpi slt, %24, %13 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + scf.if %25 { + %27 = tt.splat %14 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %28 = tt.splat %14 {async_task_id = dense<2> : vector<1xi32>} : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %29 = arith.addi %27, %15 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %30 = arith.addi %28, %16 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %31 = arith.extsi %14 {async_task_id = dense<0> : vector<1xi32>} : i32 to i64 + %32 = arith.addi %10, %31 {async_task_id = dense<0> : vector<1xi32>} : i64 + %33 = arith.trunci %32 {async_task_id = dense<0> : vector<1xi32>} : i64 to i32 + %34 = arith.addi %33, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %35 = arith.addi %33, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %36 = arith.muli %2, %arg25 {async_task_id = dense<0> : vector<1xi32>} : i32 + %37 = tt.experimental_descriptor_load %arg6[%33, %36] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128xbf16, #blocked2> + %38 = tt.experimental_descriptor_load %arg6[%35, %36] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128xbf16, #blocked2> + %39 = triton_gpu.local_alloc %37 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> + %40 = triton_gpu.local_alloc %38 {async_task_id = dense<2> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> + %41 = tt.experimental_descriptor_load %arg7[%33, %36] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128xbf16, #blocked2> + %42 = tt.experimental_descriptor_load %arg7[%34, %36] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128xbf16, #blocked2> + %43 = triton_gpu.local_alloc %41 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> + %44 = triton_gpu.local_alloc %42 {async_task_id = dense<2> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> + %45 = arith.addi %8, %c63_i64 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %46 = arith.divsi %45, %c64_i64 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %47 = arith.extsi %2 {async_task_id = dense<0> : vector<1xi32>} : i32 to i64 + %48 = tt.make_range {async_task_id = dense<[1, 2]> : vector<2xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %49 = arith.addi %5, %c64_i64 {async_task_id = dense<0> : vector<1xi32>} : i64 + %50 = arith.trunci %49 {async_task_id = dense<0> : vector<1xi32>} : i64 to i32 + %51 = arith.extsi %arg24 {async_task_id = dense<0> : vector<1xi32>} : i32 to i64 + %52 = arith.muli %47, %51 {async_task_id = dense<0> : vector<1xi32>} : i64 + %53 = arith.trunci %52 {async_task_id = dense<0> : vector<1xi32>} : i64 to i32 + %54 = tt.splat %8 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i64 -> tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %55 = tt.splat %23#3 {async_task_id = dense<[1, 2]> : vector<2xi32>} : !tt.ptr -> tensor<64x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %56 = tt.splat %arg14 {async_task_id = dense<1> : vector<1xi32>} : f32 -> tensor<64x64xf32, #mma> + %57 = tt.splat %arg14 {async_task_id = dense<2> : vector<1xi32>} : f32 -> tensor<64x64xf32, #mma> + %58 = tt.splat %23#4 {async_task_id = dense<[1, 2]> : vector<2xi32>} : !tt.ptr -> tensor<64x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %59:3 = scf.for %arg31 = %c0_i64 to %46 step %c1_i64 iter_args(%arg32 = %c0_i32, %arg33 = %cst_1, %arg35 = %cst_1) -> (i32, tensor<64x128xf32, #mma1>, tensor<64x128xf32, #mma1>) : i64 { + %111 = tt.splat %arg32 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %112 = arith.addi %111, %48 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %113 = tt.experimental_descriptor_load %arg5[%50, %53] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128xbf16, #blocked2> + %114 = triton_gpu.local_alloc %113 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x128xbf16, #blocked2>) -> !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> + %115 = tt.trans %114 {async_task_id = dense<[1, 2]> : vector<2xi32>, order = array} : !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<128x64xbf16, #shared1, #triton_gpu.shared_memory> + %116 = arith.extsi %112 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> to tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %117 = arith.cmpi slt, %116, %54 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<64xi64, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %118 = tt.addptr %55, %112 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<64x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %119 = tt.load %118, %117 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<64x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %120 = triton_nvidia_gpu.warp_group_dot %39, %115, %cst, %false {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<128x64xbf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x64xf32, #mma> + %121 = triton_nvidia_gpu.warp_group_dot %40, %115, %cst, %false {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<128x64xbf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x64xf32, #mma> + %122 = arith.mulf %120, %56 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %123 = arith.mulf %121, %57 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %124 = tt.experimental_descriptor_load %arg8[%50, %53] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128xbf16, #blocked2> + %125 = triton_gpu.local_alloc %124 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x128xbf16, #blocked2>) -> !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> + %126 = tt.trans %125 {async_task_id = dense<[1, 2]> : vector<2xi32>, order = array} : !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<128x64xbf16, #shared1, #triton_gpu.shared_memory> + %127 = triton_nvidia_gpu.warp_group_dot %43, %126, %cst, %false {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<128x64xbf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x64xf32, #mma> + %128 = triton_nvidia_gpu.warp_group_dot %44, %126, %cst, %false {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<128x64xbf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x64xf32, #mma> + %129 = tt.expand_dims %119 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 0 : i32} : tensor<64xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xf32, #mma> + %130 = tt.broadcast %129 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x64xf32, #mma> -> tensor<64x64xf32, #mma> + %131 = tt.broadcast %129 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x64xf32, #mma> -> tensor<64x64xf32, #mma> + %132 = arith.subf %122, %130 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %133 = arith.subf %123, %131 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %134 = math.exp2 %132 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %135 = math.exp2 %133 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %136 = arith.truncf %134 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> to tensor<64x64xbf16, #mma> + %137 = arith.truncf %135 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> to tensor<64x64xbf16, #mma> + %138 = tt.addptr %58, %112 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<64x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %139 = tt.load %138, %117 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<64x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #mma}>> + %140 = triton_gpu.convert_layout %136 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xbf16, #mma> -> tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %141 = triton_gpu.convert_layout %137 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xbf16, #mma> -> tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %142 = triton_nvidia_gpu.warp_group_dot %140, %125, %arg33 {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf32, #mma1> + %143 = triton_nvidia_gpu.warp_group_dot %141, %125, %arg35 {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf32, #mma1> + %157 = arith.addi %arg32, %c64_i32 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 + scf.yield {async_task_id = dense<[1, 2]> : vector<2xi32>} %157, %142, %143 : i32, tensor<64x128xf32, #mma1>, tensor<64x128xf32, #mma1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>, tt.num_stages = 2 : i32} + %60 = tt.make_range {async_task_id = dense<[1, 2]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %61 = tt.expand_dims %60 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %62 = arith.cmpi slt, %61, %cst_0 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi32, #blocked> + %63 = tt.expand_dims %29 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %64 = tt.expand_dims %30 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %65 = arith.extsi %63 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %66 = arith.extsi %64 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %67 = tt.splat %13 {async_task_id = dense<1> : vector<1xi32>} : i64 -> tensor<64x1xi64, #blocked> + %68 = tt.splat %13 {async_task_id = dense<2> : vector<1xi32>} : i64 -> tensor<64x1xi64, #blocked> + %69 = arith.cmpi slt, %65, %67 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi64, #blocked> + %70 = arith.cmpi slt, %66, %68 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi64, #blocked> + %71 = tt.broadcast %62 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi1, #blocked> -> tensor<64x128xi1, #blocked> + %72 = tt.broadcast %62 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x128xi1, #blocked> -> tensor<64x128xi1, #blocked> + %73 = tt.broadcast %69 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> + %74 = tt.broadcast %70 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> + %75 = arith.andi %71, %73 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xi1, #blocked> + %76 = arith.andi %72, %74 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xi1, #blocked> + %77 = tt.splat %arg23 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked> + %78 = tt.splat %arg23 {async_task_id = dense<2> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked> + %79 = arith.muli %63, %77 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %80 = arith.muli %64, %78 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %81 = tt.splat %23#2 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %82 = tt.splat %23#2 {async_task_id = dense<2> : vector<1xi32>} : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %83 = tt.addptr %81, %79 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %84 = tt.addptr %82, %80 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %85 = tt.broadcast %83 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %86 = tt.broadcast %84 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %87 = tt.broadcast %61 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<64x128xi32, #blocked> + %88 = tt.broadcast %61 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<64x128xi32, #blocked> + %89 = tt.addptr %85, %87 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %90 = tt.addptr %86, %88 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %91 = arith.truncf %59#1 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf32, #mma1> to tensor<64x128xbf16, #mma1> + %92 = arith.truncf %59#2 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf32, #mma1> to tensor<64x128xbf16, #mma1> + %93 = triton_gpu.convert_layout %91 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xbf16, #mma1> -> tensor<64x128xbf16, #blocked> + %94 = triton_gpu.convert_layout %92 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xbf16, #mma1> -> tensor<64x128xbf16, #blocked> + tt.store %89, %93, %75 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked> + tt.store %90, %94, %76 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %26 = arith.cmpi slt, %24, %8 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + scf.if %26 { + %27 = tt.splat %14 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %28 = tt.splat %14 {async_task_id = dense<2> : vector<1xi32>} : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %29 = tt.splat %14 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<64xi32, #blocked1> + %30 = tt.splat %14 {async_task_id = dense<2> : vector<1xi32>} : i32 -> tensor<64xi32, #blocked1> + %31 = arith.addi %27, %15 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %32 = arith.addi %28, %16 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %33 = arith.addi %29, %17 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xi32, #blocked1> + %34 = arith.addi %30, %18 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xi32, #blocked1> + %35 = arith.extsi %2 {async_task_id = dense<0> : vector<1xi32>} : i32 to i64 + %36 = arith.extsi %14 {async_task_id = dense<0> : vector<1xi32>} : i32 to i64 + %37 = arith.addi %5, %36 {async_task_id = dense<0> : vector<1xi32>} : i64 + %38 = arith.trunci %37 {async_task_id = dense<0> : vector<1xi32>} : i64 to i32 + %39 = arith.addi %38, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %40 = arith.addi %38, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %41 = arith.extsi %arg24 {async_task_id = dense<0> : vector<1xi32>} : i32 to i64 + %42 = arith.muli %35, %41 {async_task_id = dense<0> : vector<1xi32>} : i64 + %43 = arith.trunci %42 {async_task_id = dense<0> : vector<1xi32>} : i64 to i32 + %44 = tt.experimental_descriptor_load %arg9[%38, %43] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128xbf16, #blocked2> + %45 = tt.experimental_descriptor_load %arg9[%40, %43] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128xbf16, #blocked2> + %46 = triton_gpu.local_alloc %44 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> + %47 = triton_gpu.local_alloc %45 {async_task_id = dense<2> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> + %48 = arith.extsi %arg28 {async_task_id = dense<0> : vector<1xi32>} : i32 to i64 + %49 = arith.muli %35, %48 {async_task_id = dense<0> : vector<1xi32>} : i64 + %50 = arith.trunci %49 {async_task_id = dense<0> : vector<1xi32>} : i64 to i32 + %51 = tt.experimental_descriptor_load %arg12[%38, %50] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128xbf16, #blocked2> + %52 = tt.experimental_descriptor_load %arg12[%39, %50] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128xbf16, #blocked2> + %53 = triton_gpu.local_alloc %51 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> + %54 = triton_gpu.local_alloc %52 {async_task_id = dense<2> : vector<1xi32>} : (tensor<64x128xbf16, #blocked2>) -> !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> + %55 = arith.extsi %33 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xi32, #blocked1> to tensor<64xi64, #blocked1> + %56 = arith.extsi %34 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xi32, #blocked1> to tensor<64xi64, #blocked1> + %57 = tt.splat %8 {async_task_id = dense<1> : vector<1xi32>} : i64 -> tensor<64xi64, #blocked1> + %58 = tt.splat %8 {async_task_id = dense<2> : vector<1xi32>} : i64 -> tensor<64xi64, #blocked1> + %59 = arith.cmpi slt, %55, %57 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xi64, #blocked1> + %60 = arith.cmpi slt, %56, %58 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xi64, #blocked1> + %61 = tt.splat %23#3 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<64x!tt.ptr, #blocked1> + %62 = tt.splat %23#3 {async_task_id = dense<2> : vector<1xi32>} : !tt.ptr -> tensor<64x!tt.ptr, #blocked1> + %63 = tt.addptr %61, %33 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1>, tensor<64xi32, #blocked1> + %64 = tt.addptr %62, %34 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1>, tensor<64xi32, #blocked1> + %65 = tt.load %63, %59 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1> + %66 = tt.load %64, %60 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1> + %67 = triton_gpu.convert_layout %65 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xf32, #blocked1> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %68 = triton_gpu.convert_layout %66 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xf32, #blocked1> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %69 = tt.expand_dims %67 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xf32, #blocked3> + %70 = tt.expand_dims %68 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xf32, #blocked3> + %71 = arith.addi %13, %c63_i64 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %72 = arith.divsi %71, %c64_i64 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i64 + %73 = tt.splat %23#4 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<64x!tt.ptr, #blocked1> + %74 = tt.splat %23#4 {async_task_id = dense<2> : vector<1xi32>} : !tt.ptr -> tensor<64x!tt.ptr, #blocked1> + %75 = tt.addptr %73, %33 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1>, tensor<64xi32, #blocked1> + %76 = tt.addptr %74, %34 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1>, tensor<64xi32, #blocked1> + %77 = tt.load %75, %59 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1> + %78 = tt.load %76, %60 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x!tt.ptr, #blocked1> + %79 = arith.trunci %10 {async_task_id = dense<0> : vector<1xi32>} : i64 to i32 + %80 = arith.muli %2, %arg25 {async_task_id = dense<0> : vector<1xi32>} : i32 + %81 = tt.splat %arg14 {async_task_id = dense<1> : vector<1xi32>} : f32 -> tensor<64x64xf32, #mma> + %82 = tt.splat %arg14 {async_task_id = dense<2> : vector<1xi32>} : f32 -> tensor<64x64xf32, #mma> + %83 = tt.broadcast %69 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xf32, #blocked3> -> tensor<64x64xf32, #blocked3> + %84 = tt.broadcast %70 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xf32, #blocked3> -> tensor<64x64xf32, #blocked3> + %85 = triton_gpu.convert_layout %83 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #mma> + %86 = triton_gpu.convert_layout %84 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #mma> + %87 = triton_gpu.convert_layout %77 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xf32, #blocked1> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %88 = triton_gpu.convert_layout %78 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xf32, #blocked1> -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %89 = tt.expand_dims %87 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xf32, #blocked3> + %90 = tt.expand_dims %88 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xf32, #blocked3> + %91 = tt.broadcast %89 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xf32, #blocked3> -> tensor<64x64xf32, #blocked3> + %92 = tt.broadcast %90 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xf32, #blocked3> -> tensor<64x64xf32, #blocked3> + %93 = triton_gpu.convert_layout %91 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #mma> + %94 = triton_gpu.convert_layout %92 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #mma> + %95 = tt.splat %arg14 {async_task_id = dense<1> : vector<1xi32>} : f32 -> tensor<64x128xf32, #mma1> + %96 = tt.splat %arg14 {async_task_id = dense<2> : vector<1xi32>} : f32 -> tensor<64x128xf32, #mma1> + %97:2 = scf.for %arg31 = %c0_i64 to %72 step %c1_i64 iter_args(%arg32 = %cst_1, %arg33 = %cst_1) -> (tensor<64x128xf32, #mma1>, tensor<64x128xf32, #mma1>) : i64 { + %135 = tt.experimental_descriptor_load %arg10[%79, %80] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128xbf16, #blocked2> + %136 = triton_gpu.local_alloc %135 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x128xbf16, #blocked2>) -> !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> + %137 = tt.trans %136 {async_task_id = dense<[1, 2]> : vector<2xi32>, order = array} : !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<128x64xbf16, #shared1, #triton_gpu.shared_memory> + %138 = tt.experimental_descriptor_load %arg11[%79, %80] {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128xbf16, #blocked2> + %139 = triton_gpu.local_alloc %138 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x128xbf16, #blocked2>) -> !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> + %140 = tt.trans %139 {async_task_id = dense<[1, 2]> : vector<2xi32>, order = array} : !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<128x64xbf16, #shared1, #triton_gpu.shared_memory> + %141 = triton_nvidia_gpu.warp_group_dot %46, %137, %cst, %false {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<128x64xbf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x64xf32, #mma> + %142 = triton_nvidia_gpu.warp_group_dot %47, %137, %cst, %false {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<128x64xbf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x64xf32, #mma> + %143 = arith.mulf %141, %81 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %144 = arith.mulf %142, %82 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %145 = triton_nvidia_gpu.warp_group_dot %53, %140, %cst, %false {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<128x64xbf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x64xf32, #mma> + %146 = triton_nvidia_gpu.warp_group_dot %54, %140, %cst, %false {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<128x64xbf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x64xf32, #mma> + %147 = arith.subf %143, %85 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %148 = arith.subf %144, %86 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %149 = math.exp2 %147 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %150 = math.exp2 %148 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %151 = arith.subf %145, %93 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %152 = arith.subf %146, %94 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %153 = arith.mulf %149, %151 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> + %154 = arith.mulf %150, %152 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> + %155 = arith.truncf %153 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xf32, #mma> to tensor<64x64xbf16, #mma> + %156 = arith.truncf %154 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xf32, #mma> to tensor<64x64xbf16, #mma> + %157 = triton_gpu.convert_layout %155 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x64xbf16, #mma> -> tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %158 = triton_gpu.convert_layout %156 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x64xbf16, #mma> -> tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %159 = triton_nvidia_gpu.warp_group_dot %157, %136, %cst_1, %false {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf32, #mma1> + %160 = triton_nvidia_gpu.warp_group_dot %158, %136, %cst_1, %false {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf32, #mma1> + %161 = arith.mulf %159, %95 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf32, #mma1> + %162 = arith.mulf %160, %96 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf32, #mma1> + %163 = arith.addf %arg32, %161 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf32, #mma1> + %164 = arith.addf %arg33, %162 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf32, #mma1> + scf.yield {async_task_id = dense<[1, 2]> : vector<2xi32>} %163, %164 : tensor<64x128xf32, #mma1>, tensor<64x128xf32, #mma1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>, tt.num_stages = 2 : i32} + %98 = tt.make_range {async_task_id = dense<[1, 2]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %99 = tt.expand_dims %98 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %100 = arith.cmpi slt, %99, %cst_0 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi32, #blocked> + %101 = tt.expand_dims %31 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %102 = tt.expand_dims %32 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %103 = arith.extsi %101 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %104 = arith.extsi %102 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked> to tensor<64x1xi64, #blocked> + %105 = tt.splat %8 {async_task_id = dense<1> : vector<1xi32>} : i64 -> tensor<64x1xi64, #blocked> + %106 = tt.splat %8 {async_task_id = dense<2> : vector<1xi32>} : i64 -> tensor<64x1xi64, #blocked> + %107 = arith.cmpi slt, %103, %105 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi64, #blocked> + %108 = arith.cmpi slt, %104, %106 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi64, #blocked> + %109 = tt.broadcast %100 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi1, #blocked> -> tensor<64x128xi1, #blocked> + %110 = tt.broadcast %100 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x128xi1, #blocked> -> tensor<64x128xi1, #blocked> + %111 = tt.broadcast %107 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> + %112 = tt.broadcast %108 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi1, #blocked> -> tensor<64x128xi1, #blocked> + %113 = arith.andi %109, %111 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xi1, #blocked> + %114 = arith.andi %110, %112 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xi1, #blocked> + %115 = tt.splat %arg22 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked> + %116 = tt.splat %arg22 {async_task_id = dense<2> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked> + %117 = arith.muli %101, %115 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %118 = arith.muli %102, %116 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %119 = tt.splat %23#0 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %120 = tt.splat %23#0 {async_task_id = dense<2> : vector<1xi32>} : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %121 = tt.addptr %119, %117 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %122 = tt.addptr %120, %118 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %123 = tt.broadcast %121 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %124 = tt.broadcast %122 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %125 = tt.broadcast %99 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<64x128xi32, #blocked> + %126 = tt.broadcast %99 {async_task_id = dense<2> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<64x128xi32, #blocked> + %127 = tt.addptr %123, %125 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %128 = tt.addptr %124, %126 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %129 = arith.mulf %97#0, %cst_2 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf32, #mma1> + %130 = arith.mulf %97#1, %cst_2 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf32, #mma1> + %131 = arith.truncf %129 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf32, #mma1> to tensor<64x128xbf16, #mma1> + %132 = arith.truncf %130 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf32, #mma1> to tensor<64x128xbf16, #mma1> + %133 = triton_gpu.convert_layout %131 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xbf16, #mma1> -> tensor<64x128xbf16, #blocked> + %134 = triton_gpu.convert_layout %132 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xbf16, #mma1> -> tensor<64x128xbf16, #blocked> + tt.store %127, %133, %113 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked> + tt.store %128, %134, %114 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + tt.return + } +} diff --git a/test/TritonNvidiaGPU/WarpSpecialization/ws_data_partition.mlir b/test/TritonNvidiaGPU/WarpSpecialization/ws_data_partition.mlir new file mode 100644 index 000000000..3816f5bc4 --- /dev/null +++ b/test/TritonNvidiaGPU/WarpSpecialization/ws_data_partition.mlir @@ -0,0 +1,136 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-data-partition=num-consumer-groups=2 | FileCheck %s + +// CHECK-LABEL: @matmul_persistent_ws_cooperative_kernel +// CHECK: %[[#GA1:]] = tt.load {{.*}} : tensor<64x64x!tt.ptr +// CHECK: %[[#GA2:]] = tt.load {{.*}} : tensor<64x64x!tt.ptr +// CHECK: %[[#LA1:]] = triton_gpu.local_alloc %[[#GA1]] +// CHECK: %[[#LA2:]] = triton_gpu.local_alloc %[[#GA2]] +// CHECK: %[[#GB:]] = tt.load {{.*}} : tensor<64x256x!tt.ptr +// CHECK: %[[#LB:]] = triton_gpu.local_alloc %[[#GB]] +// CHECK: %[[#C1:]] = triton_nvidia_gpu.warp_group_dot %[[#LA1]], %[[#LB]], {{.*}} : !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x256xf32, #mma> +// CHECK: %[[#C2:]] = triton_nvidia_gpu.warp_group_dot %[[#LA2]], %[[#LB]], {{.*}} : !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x256xf32, #mma> +// CHECK: tt.store {{.*}} : tensor<64x256x!tt.ptr, #blocked1> +// CHECK: tt.store {{.*}} : tensor<64x256x!tt.ptr, #blocked1> + + + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_persistent_ws_cooperative_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<64> : tensor<128x64xi32, #blocked> + %c0_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i32 + %c255_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 255 : i32 + %c63_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 63 : i32 + %c64_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 256 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 128 : i32 + %c8_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 8 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 127 : i32 + %cst_0 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<128x64xf16, #blocked> + %cst_1 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<64x256xf16, #blocked1> + %cst_2 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %1 = arith.divsi %0, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %2 = arith.addi %arg4, %c255_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %3 = arith.divsi %2, %c256_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %4 = arith.muli %1, %3 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %5 = tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %6 = tt.get_num_programs x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %7 = arith.muli %3, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %8 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %9 = tt.make_range {async_task_id = dense<[1, 2]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %10 = tt.splat %arg3 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %11 = tt.make_range {async_task_id = dense<[0, 1, 2]> : vector<3xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %12 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %13 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked> + %14 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %15 = tt.expand_dims %14 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %16 = tt.broadcast %15 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> + %17 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> + %18 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %19 = tt.expand_dims %18 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %20 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %21 = arith.muli %19, %20 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %22 = tt.broadcast %21 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %23 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked1> + %24 = arith.addi %arg5, %c63_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %25 = arith.divsi %24, %c64_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %26 = tt.expand_dims %14 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %27 = tt.expand_dims %18 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %28 = arith.muli %arg7, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %29 = tt.splat %28 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x256xi32, #blocked1> + %30 = tt.splat %arg8 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %31 = tt.splat %arg2 {async_task_id = dense<[1, 2]> : vector<2xi32>} : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %32 = tt.splat %arg3 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %33 = tt.splat %arg4 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<1x256xi32, #blocked1> + scf.for %arg9 = %5 to %4 step %6 : i32 { + %34 = arith.divsi %arg9, %7 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %35 = arith.muli %34, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %36 = arith.subi %1, %35 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %37 = arith.minsi %36, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %38 = arith.remsi %arg9, %7 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %39 = arith.remsi %38, %37 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %40 = arith.addi %35, %39 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %41 = arith.divsi %38, %37 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %42 = arith.muli %40, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %43 = tt.splat %42 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %44 = tt.splat %42 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %45 = arith.addi %43, %8 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %46 = arith.addi %44, %9 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %47 = arith.remsi %45, %10 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %48 = arith.muli %41, %c256_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %49 = tt.splat %48 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %50 = arith.addi %49, %11 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %51 = arith.remsi %50, %12 {async_task_id = dense<0> : vector<1xi32>} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %52 = tt.expand_dims %47 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %53 = arith.muli %52, %13 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked> + %54 = tt.broadcast %53 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked> -> tensor<128x64xi32, #blocked> + %55 = arith.addi %54, %16 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64xi32, #blocked> + %56 = tt.addptr %17, %55 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %57 = tt.expand_dims %51 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %58 = tt.broadcast %57 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %59 = arith.addi %22, %58 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256xi32, #blocked1> + %60 = tt.addptr %23, %59 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + %true = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %61:3 = scf.for %arg10 = %c0_i32 to %25 step %c1_i32 iter_args(%arg11 = %cst_2, %arg12 = %56, %arg13 = %60) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1>) : i32 { + %76 = arith.muli %arg10, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %77 = arith.subi %arg5, %76 {async_task_id = dense<0> : vector<1xi32>} : i32 + %78 = tt.splat %77 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x64xi32, #blocked> + %79 = arith.cmpi slt, %26, %78 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> + %80 = tt.broadcast %79 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi1, #blocked> -> tensor<128x64xi1, #blocked> + %81 = tt.load %arg12, %80, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64x!tt.ptr, #blocked> + %82 = triton_gpu.local_alloc %81 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %83 = tt.splat %77 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %84 = arith.cmpi slt, %27, %83 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %85 = tt.broadcast %84 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x256xi1, #blocked1> + %86 = tt.load %arg13, %85, %cst_1 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256x!tt.ptr, #blocked1> + %87 = triton_gpu.local_alloc %86 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> + %88 = triton_nvidia_gpu.warp_group_dot %82, %87, %arg11 {async_task_id = dense<[1, 2]> : vector<2xi32>, inputPrecision = 0 : i32} : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> + %89 = tt.addptr %arg12, %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %90 = tt.addptr %arg13, %29 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %88, %89, %90 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %62 = arith.truncf %61#0 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %63 = tt.expand_dims %46 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %64 = arith.muli %30, %63 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1xi32, #blocked1> + %65 = tt.addptr %31, %64 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %66 = tt.expand_dims %50 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %67 = tt.broadcast %65 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x256x!tt.ptr, #blocked1> + %68 = tt.broadcast %66 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %69 = tt.addptr %67, %68 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %70 = arith.cmpi slt, %63, %32 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1xi32, #blocked1> + %71 = arith.cmpi slt, %66, %33 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x256xi32, #blocked1> + %72 = tt.broadcast %70 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + %73 = tt.broadcast %71 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + %74 = arith.andi %72, %73 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256xi1, #blocked1> + %75 = triton_gpu.convert_layout %62 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> + tt.store %69, %75, %74 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256x!tt.ptr, #blocked1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + tt.return + } +} diff --git a/test/TritonNvidiaGPU/WarpSpecialization/ws_lowering.mlir b/test/TritonNvidiaGPU/WarpSpecialization/ws_lowering.mlir new file mode 100644 index 000000000..de69a59b8 --- /dev/null +++ b/test/TritonNvidiaGPU/WarpSpecialization/ws_lowering.mlir @@ -0,0 +1,237 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-lowering=num-consumer-groups=1 | FileCheck %s + +// CHECK: %[[#PBARRIER:]] = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64 +// CHECK: %[[#CBARRIER:]] = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64 +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#PBARRIER]][%c0_i32] +// CHECK: triton_nvidia_gpu.init_barrier %[[#]], 128 +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#CBARRIER]][%c0_i32] +// CHECK: triton_nvidia_gpu.init_barrier %[[#]], 1 +// CHECK: scf.for +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#CBARRIER]] +// CHECK: triton_nvidia_gpu.wait_barrier %[[#]] +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#PBARRIER]] +// CHECK: triton_nvidia_gpu.mbarrier_arrive %[[#]] +// CHECK: scf.for +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#PBARRIER]] +// CHECK: triton_nvidia_gpu.wait_barrier %[[#]] +// CHECK: triton_gpu.local_load +// CHECK: triton_gpu.local_load +// CHECK: tt.dot +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#CBARRIER]] +// CHECK: triton_nvidia_gpu.mbarrier_arrive %[[#]] + + + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x128x256xf16, #shared, #triton_gpu.shared_memory, mutable> + %1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + %2 = triton_nvidia_gpu.create_token {num = 1 : i32} : tensor<1x!triton_nvidia_gpu.token> + %3 = triton_nvidia_gpu.get_async_task_id : i32 + %c0_i32 = arith.constant 0 : i32 + %4 = arith.cmpi eq, %3, %c0_i32 : i32 + scf.if %4 { + %c255_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 255 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 127 : i32 + %c1_i32_0 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_1 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %cst = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<256x128xf16, #blocked> + %cst_2 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<128x256xf16, #blocked1> + %c8_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 8 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 128 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 256 : i32 + %cst_3 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<256> : tensor<128x256xi32, #blocked1> + %6 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %7 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.divsi %7, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.divsi %9, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.muli %10, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.divsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.muli %12, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %14 = arith.subi %8, %13 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %15 = arith.minsi %14, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %16 = arith.remsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %17 = arith.remsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %18 = arith.addi %13, %17 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %19 = arith.divsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %20 = arith.muli %18, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %21 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %22 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %23 = tt.splat %20 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %24 = arith.addi %23, %21 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %25 = tt.splat %arg3 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %26 = arith.remsi %24, %25 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.muli %19, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %28 = tt.splat %27 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %22 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %30 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %31 = arith.remsi %29, %30 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %32 = tt.expand_dims %26 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %33 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %34 = arith.muli %32, %33 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + %35 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %36 = tt.expand_dims %35 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %37 = tt.broadcast %34 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %38 = tt.broadcast %36 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %39 = arith.addi %37, %38 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256xi32, #blocked1> + %40 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<128x256x!tt.ptr, #blocked1> + %41 = tt.addptr %40, %39 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %42 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %43 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %44 = tt.expand_dims %42 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> + %45 = tt.expand_dims %43 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> + %46 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked> + %47 = arith.muli %44, %46 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked> + %48 = tt.expand_dims %31 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %49 = tt.broadcast %47 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked> -> tensor<256x128xi32, #blocked> + %50 = tt.broadcast %48 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<256x128xi32, #blocked> + %51 = arith.addi %49, %50 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128xi32, #blocked> + %52 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked> + %53 = tt.addptr %52, %51 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked>, tensor<256x128xi32, #blocked> + %54 = arith.addi %arg5, %c255_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %55 = arith.divsi %54, %c256_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %56 = arith.muli %arg7, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %57 = tt.splat %56 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x128xi32, #blocked> + %c1_i32_4 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_5 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %false = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} false + %58:4 = scf.for %arg9 = %c0_i32_1 to %55 step %c1_i32_0 iter_args(%arg10 = %41, %arg11 = %53, %arg12 = %false, %arg13 = %c0_i32_5) -> (tensor<128x256x!tt.ptr, #blocked1>, tensor<256x128x!tt.ptr, #blocked>, i1, i32) : i32 { + %59 = arith.muli %arg9, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %60 = arith.subi %arg5, %59 {async_task_id = dense<0> : vector<1xi32>} : i32 + %61 = tt.splat %60 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x256xi32, #blocked1> + %62 = arith.cmpi slt, %36, %61 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked1> + %63 = tt.broadcast %62 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + triton_nvidia_gpu.producer_acquire %2, %arg13, %false {async_task_id = dense<0> : vector<1xi32>} : tensor<1x!triton_nvidia_gpu.token>, i32, i1 + %c0_i32_6 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 0 : i32 + %c1_i32_7 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 1 : i32 + %64 = triton_gpu.memdesc_subview %0[%arg13, %c0_i32_6, %c0_i32_6] {async_task_id = dense<0> : vector<1xi32>} : !tt.memdesc<1x128x256xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x256xf16, #shared, #triton_gpu.shared_memory, mutable> + %65 = triton_gpu.async_copy_global_to_local %arg10, %64 mask %63 other %cst_2 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked1> -> <128x256xf16, #shared, #triton_gpu.shared_memory, mutable> + %66 = tt.splat %60 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked> + %67 = arith.cmpi slt, %45, %66 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked> + %68 = tt.broadcast %67 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi1, #blocked> -> tensor<256x128xi1, #blocked> + %c0_i32_8 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 0 : i32 + %c1_i32_9 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 1 : i32 + %69 = triton_gpu.memdesc_subview %1[%arg13, %c0_i32_8, %c0_i32_8] {async_task_id = dense<0> : vector<1xi32>} : !tt.memdesc<1x256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + %70 = triton_gpu.async_copy_global_to_local %arg11, %69 mask %68 other %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked> -> <256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.producer_commit %2, %arg13 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x!triton_nvidia_gpu.token>, i32 + %71 = tt.addptr %arg10, %cst_3 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %72 = tt.addptr %arg11, %57 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked>, tensor<256x128xi32, #blocked> + %c1_i32_10 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 1 : i32 + %c0_i32_11 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 0 : i32 + %true = arith.constant {async_task_id = dense<0> : vector<1xi32>} true + %73 = arith.addi %arg13, %c1_i32_10 {async_task_id = dense<0> : vector<1xi32>} : i32 + %74 = arith.cmpi uge, %73, %c1_i32_4 {async_task_id = dense<0> : vector<1xi32>} : i32 + %75 = arith.cmpi ult, %73, %c1_i32_4 {async_task_id = dense<0> : vector<1xi32>} : i32 + %76 = arith.subi %73, %c1_i32_4 {async_task_id = dense<0> : vector<1xi32>} : i32 + %77 = arith.select %74, %76, %73 {async_task_id = dense<0> : vector<1xi32>} : i32 + %78 = arith.xori %arg12, %true {async_task_id = dense<0> : vector<1xi32>} : i1 + %79 = arith.andi %74, %78 {async_task_id = dense<0> : vector<1xi32>} : i1 + %80 = arith.andi %75, %arg12 {async_task_id = dense<0> : vector<1xi32>} : i1 + %81 = arith.ori %79, %80 {async_task_id = dense<0> : vector<1xi32>} : i1 + scf.yield {async_task_id = dense<0> : vector<1xi32>} %71, %72, %81, %77 : tensor<128x256x!tt.ptr, #blocked1>, tensor<256x128x!tt.ptr, #blocked>, i1, i32 + } {async_task_id = dense<0> : vector<1xi32>} + } {async_task_id = dense<0> : vector<1xi32>} + %c1_i32 = arith.constant 1 : i32 + %5 = arith.cmpi eq, %3, %c1_i32 : i32 + scf.if %5 { + %cst = arith.constant {async_task_id = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #blocked2> + %c255_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 255 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 127 : i32 + %c1_i32_0 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_1 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %cst_2 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<256x128xf16, #blocked> + %cst_3 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<128x256xf16, #blocked1> + %c8_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 8 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 128 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 256 : i32 + %cst_4 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<256> : tensor<128x256xi32, #blocked1> + %6 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %7 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.divsi %7, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.divsi %9, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.muli %10, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.divsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.muli %12, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %14 = arith.subi %8, %13 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %15 = arith.minsi %14, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %16 = arith.remsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %17 = arith.remsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %18 = arith.addi %13, %17 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %19 = arith.divsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %20 = arith.muli %18, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %21 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %22 = tt.make_range {async_task_id = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %23 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %24 = tt.splat %20 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %25 = tt.splat %20 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %26 = arith.addi %24, %21 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.addi %25, %22 {async_task_id = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %28 = tt.splat %arg3 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %29 = arith.remsi %26, %28 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %30 = arith.muli %19, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %31 = tt.splat %30 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %32 = arith.addi %31, %23 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %33 = arith.addi %arg5, %c255_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %34 = arith.divsi %33, %c256_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %c1_i32_5 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_6 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %false = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} false + %35:3 = scf.for %arg9 = %c0_i32_1 to %34 step %c1_i32_0 iter_args(%arg10 = %cst, %arg11 = %false, %arg12 = %c0_i32_6) -> (tensor<128x128xf32, #blocked2>, i1, i32) : i32 { + %c0_i32_7 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 0 : i32 + %c1_i32_8 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 1 : i32 + %c0_i32_9 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 0 : i32 + %c1_i32_10 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 1 : i32 + triton_nvidia_gpu.consumer_wait %2, %arg12, %false {async_task_id = dense<1> : vector<1xi32>} : tensor<1x!triton_nvidia_gpu.token>, i32, i1 + %54 = triton_gpu.memdesc_subview %0[%arg12, %c0_i32_7, %c0_i32_7] {async_task_id = dense<1> : vector<1xi32>} : !tt.memdesc<1x128x256xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x256xf16, #shared, #triton_gpu.shared_memory, mutable> + %55 = triton_gpu.local_load %54 {async_task_id = dense<1> : vector<1xi32>} : !tt.memdesc<128x256xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #blocked1> + %56 = triton_gpu.convert_layout %55 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #blocked1> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> + %57 = triton_gpu.memdesc_subview %1[%arg12, %c0_i32_9, %c0_i32_9] {async_task_id = dense<1> : vector<1xi32>} : !tt.memdesc<1x256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + %58 = triton_gpu.local_load %57 {async_task_id = dense<1> : vector<1xi32>} : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #blocked> + %59 = triton_gpu.convert_layout %58 {async_task_id = dense<1> : vector<1xi32>} : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> + %60 = tt.dot %56, %59, %arg10, inputPrecision = tf32 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x128xf32, #blocked2> + triton_nvidia_gpu.consumer_release %2, %arg12 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x!triton_nvidia_gpu.token>, i32 + %c1_i32_11 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 1 : i32 + %c0_i32_12 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 0 : i32 + %true = arith.constant {async_task_id = dense<1> : vector<1xi32>} true + %61 = arith.addi %arg12, %c1_i32_11 {async_task_id = dense<1> : vector<1xi32>} : i32 + %62 = arith.cmpi uge, %61, %c1_i32_5 {async_task_id = dense<1> : vector<1xi32>} : i32 + %63 = arith.cmpi ult, %61, %c1_i32_5 {async_task_id = dense<1> : vector<1xi32>} : i32 + %64 = arith.subi %61, %c1_i32_5 {async_task_id = dense<1> : vector<1xi32>} : i32 + %65 = arith.select %62, %64, %61 {async_task_id = dense<1> : vector<1xi32>} : i32 + %66 = arith.xori %arg11, %true {async_task_id = dense<1> : vector<1xi32>} : i1 + %67 = arith.andi %62, %66 {async_task_id = dense<1> : vector<1xi32>} : i1 + %68 = arith.andi %63, %arg11 {async_task_id = dense<1> : vector<1xi32>} : i1 + %69 = arith.ori %67, %68 {async_task_id = dense<1> : vector<1xi32>} : i1 + scf.yield {async_task_id = dense<1> : vector<1xi32>} %60, %69, %65 : tensor<128x128xf32, #blocked2>, i1, i32 + } {async_task_id = dense<1> : vector<1xi32>} + %36 = arith.truncf %35#0 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf32, #blocked2> to tensor<128x128xf16, #blocked2> + %37 = tt.expand_dims %27 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %38 = tt.splat %arg8 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked> + %39 = arith.muli %38, %37 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked> + %40 = tt.splat %arg2 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %41 = tt.addptr %40, %39 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %42 = tt.expand_dims %32 {async_task_id = dense<1> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %43 = tt.broadcast %41 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> + %44 = tt.broadcast %42 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> + %45 = tt.addptr %43, %44 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> + %46 = tt.splat %arg3 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked> + %47 = arith.cmpi slt, %37, %46 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked> + %48 = tt.splat %arg4 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<1x128xi32, #blocked> + %49 = arith.cmpi slt, %42, %48 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked> + %50 = tt.broadcast %47 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi1, #blocked> -> tensor<128x128xi1, #blocked> + %51 = tt.broadcast %49 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi1, #blocked> -> tensor<128x128xi1, #blocked> + %52 = arith.andi %50, %51 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked> + %53 = triton_gpu.convert_layout %36 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf16, #blocked2> -> tensor<128x128xf16, #blocked> + tt.store %45, %53, %52 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked> + } {async_task_id = dense<1> : vector<1xi32>} + tt.return + } +} diff --git a/test/TritonNvidiaGPU/WarpSpecialization/ws_task_partition.mlir b/test/TritonNvidiaGPU/WarpSpecialization/ws_task_partition.mlir new file mode 100644 index 000000000..75df00b1a --- /dev/null +++ b/test/TritonNvidiaGPU/WarpSpecialization/ws_task_partition.mlir @@ -0,0 +1,64 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-task-partition=num-consumer-groups=2 | FileCheck %s + +// CHECK-LABEL: @matmul_persistent_tma_ws_cooperative_kernel +// CHECK: %[[#GA:]] = tt.experimental_descriptor_load {{.*}} {async_task_id = dense<0> : vector<1xi32>} +// CHECK: %[[#LA:]] = triton_gpu.local_alloc %[[#GA]] +// CHECK: %[[#GB:]] = tt.experimental_descriptor_load {{.*}} {async_task_id = dense<0> : vector<1xi32>} +// CHECK: %[[#LB:]] = triton_gpu.local_alloc %[[#GB]] +// CHECK: %[[#C:]] = triton_nvidia_gpu.warp_group_dot %[[#LA]], %[[#LB]], {{.*}} {async_task_id = dense<[1, 2]> : vector<2xi32> + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_persistent_tma_ws_cooperative_kernel(%arg0: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg1: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg2: !tt.ptr {tt.nv_tma_desc = 1 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c127_i32 = arith.constant 127 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %c63_i32 = arith.constant 63 : i32 + %c255_i32 = arith.constant 255 : i32 + %c1_i32 = arith.constant 1 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0 = arith.addi %arg3, %c127_i32 : i32 + %1 = arith.divsi %0, %c128_i32 : i32 + %2 = arith.addi %arg4, %c255_i32 : i32 + %3 = arith.divsi %2, %c256_i32 : i32 + %4 = arith.muli %1, %3 : i32 + %5 = tt.get_program_id x : i32 + %6 = tt.get_num_programs x : i32 + %7 = arith.muli %3, %c8_i32 : i32 + %8 = arith.addi %arg5, %c63_i32 : i32 + %9 = arith.divsi %8, %c64_i32 : i32 + scf.for %arg6 = %5 to %4 step %6 : i32 { + %10 = arith.divsi %arg6, %7 : i32 + %11 = arith.muli %10, %c8_i32 : i32 + %12 = arith.subi %1, %11 : i32 + %13 = arith.minsi %12, %c8_i32 : i32 + %14 = arith.remsi %arg6, %7 : i32 + %15 = arith.remsi %14, %13 : i32 + %16 = arith.addi %11, %15 : i32 + %17 = arith.divsi %14, %13 : i32 + %18 = arith.muli %16, %c128_i32 : i32 + %19 = arith.muli %17, %c256_i32 : i32 + %true = arith.constant true + %false = arith.constant false + %20:2 = scf.for %arg7 = %c0_i32 to %9 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 { + %23 = tt.experimental_descriptor_load %arg0[%18, %arg9] : !tt.ptr -> tensor<128x64xf16, #blocked> + %24 = triton_gpu.local_alloc %23 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %25 = tt.experimental_descriptor_load %arg1[%arg9, %19] : !tt.ptr -> tensor<64x256xf16, #blocked1> + %26 = triton_gpu.local_alloc %25 : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> + %27 = triton_nvidia_gpu.warp_group_dot %24, %26, %arg8 {inputPrecision = 0 : i32} : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> + %28 = arith.addi %arg9, %c64_i32 : i32 + scf.yield %27, %28 : tensor<128x256xf32, #mma>, i32 + } + %21 = arith.truncf %20#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %22 = triton_gpu.convert_layout %21 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> + tt.experimental_descriptor_store %arg2[%18, %19], %22 : !tt.ptr, tensor<128x256xf16, #blocked1> + } + tt.return + } +} diff --git a/test/TritonNvidiaGPU/membar.mlir b/test/TritonNvidiaGPU/membar.mlir index 95202f12e..358f53fd7 100644 --- a/test/TritonNvidiaGPU/membar.mlir +++ b/test/TritonNvidiaGPU/membar.mlir @@ -9,8 +9,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: init_barrier tt.func @init_barrier() { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, mutable> + %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> tt.return } } @@ -28,9 +28,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: inval_barrier tt.func @inval_barrier() { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, mutable> - triton_nvidia_gpu.inval_barrier %alloc : !tt.memdesc<1xi64, #shared0, mutable> + %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.inval_barrier %alloc : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> tt.return } } @@ -48,9 +48,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: barrier_expect tt.func @barrier_expect(%pred : i1) { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, mutable> - triton_nvidia_gpu.barrier_expect %alloc, 16384, %pred : <1xi64, #shared0, mutable> + %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.barrier_expect %alloc, 16384, %pred : <1xi64, #shared0, #triton_gpu.shared_memory, mutable> tt.return } } @@ -68,9 +68,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: wait_barrier tt.func @wait_barrier(%phase : i32) { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, mutable> - triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, mutable> - triton_nvidia_gpu.wait_barrier %alloc, %phase : <1xi64, #shared0, mutable> + %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.wait_barrier %alloc, %phase : <1xi64, #shared0, #triton_gpu.shared_memory, mutable> tt.return } } @@ -89,8 +89,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: gpu.barrier // CHECK-NEXT: init_barrier %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, mutable> - triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, mutable> + %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.ptr -> tensor<128x64xf16, #blocked0> tt.return %l : tensor<128x64xf16, #blocked0> } @@ -108,8 +108,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: triton_gpu.local_alloc tt.func public @tma_store(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked0>) { %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> - %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, mutable> - triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, mutable> + %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> + triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable> tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.ptr, tensor<128x256xf32, #blocked0> tt.return } diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp index 5e7bbb0c8..25e8e2d19 100644 --- a/test/lib/Analysis/TestMembar.cpp +++ b/test/lib/Analysis/TestMembar.cpp @@ -1,3 +1,4 @@ +#include "../third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/IR/Dialect.h" @@ -25,7 +26,8 @@ struct TestMembarPass ModuleOp moduleOp = cast(operation); // Print all ops after membar pass ModuleAllocation allocation(moduleOp); - ModuleMembarAnalysis membarPass(&allocation); + ModuleMembarAnalysis membarPass(&allocation, + mlir::triton::NVIDIA::canSkipBarSync); membarPass.run(); } }; diff --git a/test/lib/CMakeLists.txt b/test/lib/CMakeLists.txt index fc6ef10fa..d9e58999c 100644 --- a/test/lib/CMakeLists.txt +++ b/test/lib/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(Analysis) +add_subdirectory(Instrumentation) diff --git a/test/lib/Instrumentation/CMakeLists.txt b/test/lib/Instrumentation/CMakeLists.txt new file mode 100644 index 000000000..90311bb86 --- /dev/null +++ b/test/lib/Instrumentation/CMakeLists.txt @@ -0,0 +1,38 @@ +set(GPU_INSTRUMENTATION_PASSES + GPUInstrumentationTestLib + ) + +set(GPUInstrumentationTestLib_SOURCES + GPUHello.cpp + ) + + +foreach( plugin ${GPU_INSTRUMENTATION_PASSES} ) + add_library( + ${plugin} + SHARED + ${${plugin}_SOURCES} + ) + + target_link_libraries( + ${plugin} + PRIVATE + LLVMCore + "$<$:-undefined dynamic_lookup>" + ) + # CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python + # build. It is empty if building directly from the root + # CMakeLists.txt file. Therefore if not building from Python just + # use the default CMake shared lib path otherwise this causes a hard + # build error + if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) + set_target_properties(${plugin} PROPERTIES + LIBRARY_OUTPUT_DIRECTORY + "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation") + endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) + + # This is set to -fvisibility=hidden in the top level CMake file + # which causes the llvmGetPassPluginInfo symbol to be hidden and + # an "entry point not found" error. Reset it just for this target + target_compile_options(${plugin} PRIVATE -fvisibility=default) +endforeach() diff --git a/test/lib/Instrumentation/GPUHello.cpp b/test/lib/Instrumentation/GPUHello.cpp new file mode 100644 index 000000000..5c71857c8 --- /dev/null +++ b/test/lib/Instrumentation/GPUHello.cpp @@ -0,0 +1,76 @@ +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/PassPlugin.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +using namespace llvm; +using namespace std; + +namespace { + +struct GpuHello : public PassInfoMixin { + PreservedAnalyses run(Module &module, ModuleAnalysisManager &) { + bool modifiedCodeGen = runOnModule(module); + + return (modifiedCodeGen ? llvm::PreservedAnalyses::none() + : llvm::PreservedAnalyses::all()); + } + bool runOnModule(llvm::Module &module); + // isRequired being set to true keeps this pass from being skipped + // if it has the optnone LLVM attribute + static bool isRequired() { return true; } +}; + +} // end anonymous namespace + +bool GpuHello::runOnModule(Module &module) { + bool modifiedCodeGen = false; + + for (auto &function : module) { + if (function.isIntrinsic()) + continue; + StringRef functionName = function.getName(); + if (function.getCallingConv() == CallingConv::AMDGPU_KERNEL || + function.getCallingConv() == CallingConv::PTX_Kernel || + functionName.contains("kernel")) { + for (Function::iterator basicBlock = function.begin(); + basicBlock != function.end(); basicBlock++) { + for (BasicBlock::iterator inst = basicBlock->begin(); + inst != basicBlock->end(); inst++) { + DILocation *debugLocation = + dyn_cast(inst)->getDebugLoc(); + std::string sourceInfo = + (function.getName() + "\t" + debugLocation->getFilename() + ":" + + Twine(debugLocation->getLine()) + ":" + + Twine(debugLocation->getColumn())) + .str(); + + errs() << "Hello From First Instruction of GPU Kernel: " << sourceInfo + << "\n"; + return modifiedCodeGen; + } + } + } + } + return modifiedCodeGen; +} + +PassPluginLibraryInfo getPassPluginInfo() { + const auto callback = [](PassBuilder &pb) { + pb.registerOptimizerLastEPCallback([&](ModulePassManager &mpm, auto, auto) { + mpm.addPass(GpuHello()); + return true; + }); + }; + + return {LLVM_PLUGIN_API_VERSION, "gpu-hello", LLVM_VERSION_STRING, callback}; +}; + +extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo +llvmGetPassPluginInfo() { + return getPassPluginInfo(); +} diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 20e3ac608..a406eefa3 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -26,9 +26,9 @@ # test_exec_root: The root path where tests should be run. config.test_exec_root = os.path.join(config.triton_obj_root, 'test') - config.substitutions.append(('%PATH%', config.environment['PATH'])) -config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) +config.substitutions.append(("%shlibdir", config.llvm_shlib_dir)) +config.substitutions.append(("%shlibext", config.llvm_shlib_ext)) llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in index 1118ed36b..4053a8e7d 100644 --- a/test/lit.site.cfg.py.in +++ b/test/lit.site.cfg.py.in @@ -7,8 +7,8 @@ config.llvm_src_root = "@LLVM_SOURCE_DIR@" config.llvm_obj_root = "@LLVM_BINARY_DIR@" config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" config.llvm_lib_dir = "@LLVM_LIBS_DIR@" -config.llvm_shlib_dir = "@SHLIBDIR@" -config.llvm_shlib_ext = "@SHLIBEXT@" +config.llvm_shlib_dir = "@CMAKE_LIBRARY_OUTPUT_DIRECTORY@" +config.llvm_shlib_ext = "@CMAKE_SHARED_LIBRARY_SUFFIX@" config.llvm_exe_ext = "@EXEEXT@" config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" config.mlir_binary_dir = "@MLIR_BINARY_DIR@" diff --git a/third_party/amd/CMakeLists.txt b/third_party/amd/CMakeLists.txt index f75ff4b15..8228c3d39 100644 --- a/third_party/amd/CMakeLists.txt +++ b/third_party/amd/CMakeLists.txt @@ -3,5 +3,8 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) - add_triton_plugin(TritonAMD ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_amd.cc LINK_LIBS TritonAMDGPUToLLVM TritonAMDGPUTransforms) + add_triton_plugin(TritonAMD ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_amd.cc LINK_LIBS TritonAMDGPUToLLVM TritonAMDGPUTransforms TritonAMDGPUDialectToLLVM) +endif() +if(TRITON_BUILD_UT) + add_subdirectory(unittest) endif() diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index bdf2f863b..7a2ab1c53 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -1,7 +1,8 @@ -from triton.backends.compiler import BaseBackend, GPUTarget +from triton.backends.compiler import BaseBackend, GPUTarget, AttrsDescriptor, register_descriptor from triton._C.libtriton import ir, passes, llvm, amd from dataclasses import dataclass -from typing import Any, Tuple +from typing import Any, Dict, Tuple +from types import ModuleType import hashlib import tempfile import os @@ -11,18 +12,36 @@ from pathlib import Path +def min_dot_size(target: GPUTarget): + arch_str = target.arch + # CDNA 3.0 supports k==8 in all mfma variants except for int8 + # (where the smallest `k` supported is 16) + if "gfx94" in arch_str: + return lambda lhsType, rhsType: (16, 16, 16) if (lhsType.is_int8() or rhsType.is_int8()) else (16, 16, 8) + # CDNA 2.0 always supports `k==8` + if "gfx9" in arch_str: + return lambda lhsType, rhsType: (16, 16, 8) + # Other architectures will only support 16,16,16 + return lambda lhsType, rhsType: (16, 16, 16) + + @dataclass(frozen=True) class HIPOptions: num_warps: int = 4 waves_per_eu: int = 1 - num_stages: int = 0 + num_stages: int = 2 num_ctas: int = 1 + num_buffers_warp_spec: int = 0 + num_consumer_groups: int = 0 + reg_dec_producer: int = 0 + reg_inc_consumer: int = 0 extern_libs: dict = None cluster_dims: tuple = (1, 1, 1) debug: bool = False + sanitize_overflow: bool = True arch: str = None - allow_fp8e4nv: bool = False - allow_fp8e4b15: bool = False + supported_fp8_dtypes: Tuple[str] = ("fp8e5", ) + deprecated_fp8_dtypes: Tuple[str] = () default_dot_input_precision: str = "ieee" allowed_dot_input_precisions: Tuple[str] = ("ieee", ) enable_fp_fusion: bool = True @@ -32,11 +51,18 @@ class HIPOptions: max_num_imprecise_acc_default: int = 0 backend_name: str = 'hip' + # The following option provides hints to the AMDGPU backend regarding instruction scheduling + # for all `tt.dot` operations in a kernel. The "default" variant preserves the default + # instruction scheduling of the AMDGPU backend which aims at maximizing occupancy. + # The option is experimental and may change at any time regarding its semantics and/or may + # be gone entirely anytime. + instruction_sched_variant: str = 'default' + def __post_init__(self): default_libdir = Path(__file__).parent / 'lib' extern_libs = {} if self.extern_libs is None else dict(self.extern_libs) # Ignore user-defined warp size for gfx9 - warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch else 64 + warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch or 'gfx12' in self.arch else 64 object.__setattr__(self, 'warp_size', warp_size) libs = ["ocml", "ockl"] for lib in libs: @@ -50,6 +76,44 @@ def hash(self): return hashlib.sha256(key.encode("utf-8")).hexdigest() +@register_descriptor +class HIPAttrsDescriptor(AttrsDescriptor): + # This property asserts if the underlying storage area of a given pointer + # can be resepresented as a 32 bit integer. When this is true, we can be + # sure that all indices into the tensor behind that pointer can use 32-bit + # indexing. That opens the door for the AMD backend to use buffer load/store + # instrinsics, which requires this property. Buffer load/store intrinsics + # gives direct out-of-bound support and simplifies index calculation for + # lower register pressure. + __slots__ = ("pointer_range_32") + + def _add_backend_properties(self, params=None, values=None): + self.property_values["tt.pointer_range"] = 32 + if params is None or values is None: + return + + self.arg_properties["tt.pointer_range"] = [ + param.num for param, arg in zip(params, values) if HIPAttrsDescriptor.is_within2gb(arg) + and not param.do_not_specialize and not param.do_not_specialize_on_alignment + ] + + @staticmethod + def is_within2gb(arg): + if hasattr(arg, "ptr_range"): + return arg.ptr_range() <= 2**31 - 1 + if "torch.Tensor" in str(type(arg)) and hasattr(arg, "untyped_storage"): + # Please note that 2**31-1 is the max int32 positive limit + return arg.untyped_storage().size() <= 2**31 - 1 + return False + + @staticmethod + def get_property_key(val, align): + generic_key = AttrsDescriptor.get_property_key(val, align) + hip_key = "S" if HIPAttrsDescriptor.is_within2gb(val) else "N" + key = (generic_key + hip_key).replace("N", "") + return key if key else "N" + + class HIPBackend(BaseBackend): @staticmethod @@ -63,6 +127,15 @@ def __init__(self, target: GPUTarget) -> None: def parse_options(self, opts) -> Any: args = {'arch': self.target.arch} + + if "supported_fp8_dtypes" not in opts: + supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes) + if self.target.arch in ('gfx940', 'gfx941', 'gfx942'): + supported_fp8_dtypes.update({'fp8e4b8', 'fp8e5b16'}) + args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes)) + + if "enable_fp_fusion" not in opts: + args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1" args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts}) return HIPOptions(**args) @@ -77,12 +150,23 @@ def pack_metadata(self, metadata): ) def get_codegen_implementation(self): - codegen_fns = dict() + codegen_fns = {"min_dot_size": min_dot_size(self.target)} return codegen_fns + def get_module_map(self) -> Dict[str, ModuleType]: + from triton.language.extra.hip import libdevice + return {"triton.language.extra.libdevice": libdevice} + def load_dialects(self, ctx): amd.load_dialects(ctx) + def get_attrs_descriptor(self, params, args): + return HIPAttrsDescriptor(params, args) + + @staticmethod + def compute_spec_key(arg, align): + return HIPAttrsDescriptor.get_property_key(arg, align) + @staticmethod def path_to_rocm_lld(): # Check env path for ld.lld @@ -101,7 +185,7 @@ def path_to_rocm_lld(): lld = Path("/usr/bin/ld.lld") if lld.is_file(): return lld - raise Exception("ROCm linker /opt/rocm/llvm/bin/ld.lld not found") + raise Exception("ROCm linker /opt/rocm/llvm/bin/ld.lld not found. Set 'TRITON_HIP_LLD_PATH' to its path.") @staticmethod def make_ttir(mod, metadata, options): @@ -115,6 +199,7 @@ def make_ttir(mod, metadata, options): passes.common.add_cse(pm) passes.common.add_licm(pm) passes.common.add_symbol_dce(pm) + passes.ttir.add_loop_unroll(pm) pm.run(mod) return mod @@ -134,14 +219,25 @@ def make_ttgir(mod, metadata, options): passes.ttgpuir.add_remove_layout_conversions(pm) amd.passes.ttgpuir.add_optimize_epilogue(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) - if options.num_stages == 0 and amd.has_matrix_core_feature(options.arch): - amd.passes.ttgpuir.add_stream_pipeline(pm) + if amd.has_matrix_core_feature(options.arch): + assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. " + "We used to trigger software pipelining with " + "num_stages == 0. Now it will not happen anymore; " + "please update to use num_stages == 2 for " + "equivalent behavior in the past.") + amd.passes.ttgpuir.add_stream_pipelinev2(pm, options.num_stages) passes.common.add_canonicalizer(pm) + amd.passes.ttgpuir.insert_instruction_sched_hints(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_reduce_data_duplication(pm) - if options.num_stages != 0: + if amd.has_matrix_core_feature(options.arch): amd.passes.ttgpuir.add_reorder_instructions(pm) + if os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1": + amd.passes.ttgpuir.add_canonicalize_pointers(pm) + passes.common.add_canonicalizer(pm) + amd.passes.ttgpuir.add_convert_to_buffer_ops(pm) + passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) pm.run(mod) @@ -154,6 +250,13 @@ def make_llir(src, metadata, options): pm = ir.pass_manager(mod.context) pm.enable_debug() amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm, options.arch) + # custom_lds_size is an experimental parameter that defines amount of LDS available + # for one thread block. Measured in bytes. + # + # If custom_lds_size = 0, pass will consider all LDS is available for one threads block, + # LDS size is determined by provided arch name. + custom_lds_size = 0 + amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size) passes.convert.add_scf_to_cf(pm) passes.convert.add_index_to_llvmir(pm) @@ -175,20 +278,18 @@ def make_llir(src, metadata, options): passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) + amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.instruction_sched_variant) if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": passes.llvmir.add_di_scope(pm) - # This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block - # count caused by predicated loads/stores. In certain kernels, the addition of these blocks can cause the MLIR - # canonicalizer to never finish when attempting to merge blocks. The permanent solution under consideration - # involves using MUBUF instructions that have built-in out-of-bounds checks, which would eliminate the need - # for conditional branching around memory accesses. - amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm) + amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ) pm.run(mod) # LLVM-IR (MLIR) -> LLVM-IR (LLVM) llvm.init_targets() context = llvm.context() llvm_mod = llvm.to_module(mod, context) + amd.attach_target_triple(llvm_mod) + llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, '') # Set various control constants on the LLVM module so that device # libraries can resolve references to them. @@ -208,11 +309,16 @@ def make_llir(src, metadata, options): denormal_mode = "preserve-sign" if options.allow_flush_denorm else "ieee" fns[0].add_fn_attr("denormal-fp-math-f32", denormal_mode) + # Hint the compiler that we'd like the firmware to set the kernel arguments + # to user SGPRs so that the kernel does not need to s_load its arguments + # from memory. + amd.set_all_fn_arg_inreg(fns[0]) + if options.extern_libs: paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)] llvm.link_extern_libs(llvm_mod, paths) - llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, amd.TARGET_TRIPLE) + llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion) # Get some metadata metadata["shared"] = src.get_int_attr("triton_gpu.shared") diff --git a/third_party/amd/backend/driver.c b/third_party/amd/backend/driver.c index 233613a55..62eee09e7 100644 --- a/third_party/amd/backend/driver.c +++ b/third_party/amd/backend/driver.c @@ -132,12 +132,12 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { // create a struct to hold device properties return Py_BuildValue( - "{s:i, s:i, s:i, s:i, s:i, s:i, s:s, s:i}", "max_shared_mem", + "{s:i, s:i, s:i, s:i, s:i, s:i, s:s, s:i, s:i}", "max_shared_mem", props.sharedMemPerBlock, "max_num_regs", props.regsPerBlock, "multiprocessor_count", props.multiProcessorCount, "sm_clock_rate", props.clockRate, "mem_clock_rate", props.memoryClockRate, "mem_bus_width", props.memoryBusWidth, "arch", props.gcnArchName, "warpSize", - props.warpSize); + props.warpSize, "max_threads_per_sm", props.maxThreadsPerMultiProcessor); } static PyObject *loadBinary(PyObject *self, PyObject *args) { diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index c1ff6e1d6..6e1a368bf 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -484,6 +484,10 @@ def __init__(self): self.utils = HIPUtils() self.launcher_cls = HIPLauncher + def get_device_interface(self): + import torch + return torch.cuda + @staticmethod def is_active(): import torch @@ -495,3 +499,14 @@ def get_current_target(self): arch = device_properties['arch'] warp_size = device_properties['warpSize'] return GPUTarget("hip", arch.split(':')[0], warp_size) + + def get_benchmarker(self): + from triton.testing import do_bench + return do_bench + + def get_empty_cache_for_benchmark(self): + import torch + + # It's the same as the Nvidia backend. + cache_size = 256 * 1024 * 1024 + return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda') diff --git a/third_party/amd/backend/include/hsa/amd_hsa_elf.h b/third_party/amd/backend/include/hsa/amd_hsa_elf.h index 51aa389a0..0656c9d99 100644 --- a/third_party/amd/backend/include/hsa/amd_hsa_elf.h +++ b/third_party/amd/backend/include/hsa/amd_hsa_elf.h @@ -130,6 +130,7 @@ enum : unsigned { EF_AMDGPU_MACH_AMDGCN_GFX1151 = 0x04a, EF_AMDGPU_MACH_AMDGCN_GFX941 = 0x04b, EF_AMDGPU_MACH_AMDGCN_GFX942 = 0x04c, + EF_AMDGPU_MACH_AMDGCN_GFX950 = 0x04f, // First/last AMDGCN-based processors. EF_AMDGPU_MACH_AMDGCN_FIRST = EF_AMDGPU_MACH_AMDGCN_GFX600, diff --git a/third_party/amd/include/CMakeLists.txt b/third_party/amd/include/CMakeLists.txt index b2a802f75..08707d601 100644 --- a/third_party/amd/include/CMakeLists.txt +++ b/third_party/amd/include/CMakeLists.txt @@ -1,2 +1,3 @@ +add_subdirectory(Dialect) add_subdirectory(TritonAMDGPUToLLVM) add_subdirectory(TritonAMDGPUTransforms) diff --git a/third_party/amd/include/Dialect/CMakeLists.txt b/third_party/amd/include/Dialect/CMakeLists.txt new file mode 100644 index 000000000..4f9163bdf --- /dev/null +++ b/third_party/amd/include/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TritonAMDGPU) diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/CMakeLists.txt b/third_party/amd/include/Dialect/TritonAMDGPU/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/amd/include/Dialect/TritonAMDGPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt b/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..25a57075b --- /dev/null +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt @@ -0,0 +1,16 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonAMDGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=amdgpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=amdgpu) +mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +add_mlir_doc(TritonAMDGPUDialect TritonAMDGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonAMDGPUOps TritonAMDGPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(TritonAMDGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS TritonAMDGPUAttrDefs.td) +mlir_tablegen(TritonAMDGPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonAMDGPUAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(TritonAMDGPUAttrDefsIncGen) diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h new file mode 100644 index 000000000..a7395f86d --- /dev/null +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_DIALECT_AMDGPU_IR_DIALECT_H_ +#define TRITON_DIALECT_AMDGPU_IR_DIALECT_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Dialect/Triton/IR/Traits.h" +// clang-format off +#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc" +// clang-format on + +#define GET_ATTRDEF_CLASSES +#include "amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "amd/include/Dialect/TritonAMDGPU/IR/Ops.h.inc" + +namespace mlir { +namespace triton { +namespace amdgpu {} // namespace amdgpu +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td new file mode 100644 index 000000000..31a43acd2 --- /dev/null +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_AMDGPU_ATTRDEFS +#define TRITON_AMDGPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "TritonAMDGPUDialect.td" + +class TritonAMDGPU_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { +} + +#endif diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td new file mode 100644 index 000000000..d5956cf7a --- /dev/null +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_AMDGPU_DIALECT +#define TRITON_AMDGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonAMDGPU_Dialect : Dialect { + let name = "amdgpu"; + let cppNamespace = "::mlir::triton::amdgpu"; + + let description = [{ + TritonAMDGPU Dialect hosts AMD specific ops at TritonGPU abstraction level. + }]; + + let dependentDialects = []; +} + +#endif diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td new file mode 100644 index 000000000..538e31378 --- /dev/null +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + + +#ifndef TRITON_AMDGPU_OPS +#define TRITON_AMDGPU_OPS + +include "mlir/IR/OpBase.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "mlir/IR/EnumAttr.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "TritonAMDGPUDialect.td" +include "TritonAMDGPUAttrDefs.td" + +class TT_AMDGPU_Op traits = []> : + Op { +} + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + +def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { + let summary = "A placeholder op for instruction scheduling hints within a basic block"; + let description = [{ + A placeholder op for instruction scheduling hints applied to instructions within + a basic block where the placeholder op is located. This op is primarily intended + to be used to adjust instruction scheduling inside the resulting main loop + of a `tt.dot` operation. It's easier to identify dot ops at a high level and, thus, + to mark intended scheduling regions. The hint ops are eventually lowered + into LLVM AMDGPU instruction scheduling primitives, which are meant to control + how different kinds of instructions (valu/mfma, global/shared memory, etc.) should + interleave for better instruction level parallelism. + }]; + + let assemblyFormat = [{attr-dict}]; +} + +// +// AMD Buffer operations. +// +def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [ + SameLoadStoreOperandsAndResultEncoding, + AttrSizedOperandSegments, + MemoryEffects<[MemRead]>, + TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">, + TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">, + TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()">, + TypesMatchWith<"result and other have the same type", "result", "other", "$_self", + "($_op.getOperands().size() <= 3) || std::equal_to<>()">, +]>{ + let summary = "Load from a scalar base pointer and a tensor offset"; + let description = [{ + AMD Buffer load operation. Buffer store is similar to + a normal store but it accesses global memory via a scalar base pointer + and a tensor of offsets instead of a tensor of pointers. The other fields + are similar to a normal load, i.e., the `mask` is a boolean vector that + determines if a given element should be read from memory, and `other` is the + element that should be returned on lane `i` when `mask[i] == 0`. + }]; + let arguments = ( + ins + TT_Ptr:$ptr, + I32Tensor:$offsets, + Optional:$mask, + Optional:$other + ); + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $ptr `[` $offsets `]` (`,` $mask^)? (`,` $other^)? + attr-dict `:` type($result) + }]; +} + +def BufferStoreOp : TT_AMDGPU_Op<"buffer_store", [ + SameLoadStoreOperandsEncoding, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">, + TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">, + TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 3) || std::equal_to<>()">, +]>{ + let summary = "Store into scalar base pointer and a tensor offset"; + let description = [{ + AMD Buffer store operation. Buffer store is similar to + normal store but it accesses global memory via a scalar base pointer + and a tensor of offsets instead of a tensor of pointers. The other fields + are similar to a normal store , i.e., the `mask` is a boolean vector that + determines if a given element should be written to memory, and `value` is the + tensor of elements that should be written on lane `i` when `mask[i] == 1`. + }]; + let arguments = ( + ins + TT_Tensor:$value, + TT_Ptr:$ptr, + I32Tensor:$offsets, + Optional:$mask + ); + + let assemblyFormat = [{ + $value `,` $ptr `[` $offsets `]` (`,` $mask^)? + attr-dict `:` type($value) + }]; +} + +#endif diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h b/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h index 003acbf7b..ac37aab81 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h @@ -116,7 +116,7 @@ struct GCNBuilder { Operand() = default; Operand(const Operation &) = delete; Operand(Value value, StringRef constraint) - : value(value), constraint(constraint) {} + : constraint(constraint), value(value) {} bool isList() const { return !value && constraint.empty(); } @@ -342,7 +342,7 @@ struct GCNInstrExecution { explicit GCNInstrExecution(GCNInstrCommon *instr, llvm::ArrayRef oprs, llvm::ArrayRef modifiers) - : instr(instr), argsInOrder(oprs.begin(), oprs.end()), + : argsInOrder(oprs.begin(), oprs.end()), instr(instr), mods(modifiers.begin(), modifiers.end()) {} std::string dump() const; diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h index df5ad7849..bd726bd84 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h @@ -22,11 +22,23 @@ namespace AMD { std::unique_ptr> createDecomposeUnsupportedConversionsPass(StringRef targetArch); +/// @brief Creates pass that keep LDS consumption within specified limits. +/// @param arch target architecture name, for example "gfx940" +/// @param customLDSLimit defines LDS size available for one thread block +/// zero value tells pass that whole LDS is available on a device +/// @return created pass +std::unique_ptr> +createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0); } // namespace AMD std::unique_ptr> createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz); -std::unique_ptr> createConvertBuiltinFuncToLLVMPass(); +std::unique_ptr> +createConvertBuiltinFuncToLLVMPass(bool ftz); +std::unique_ptr> +createInsertInstructionSchedHintsPass(); +std::unique_ptr> +createLowerInstructionSchedHintsPass(std::string variant); #define GEN_PASS_REGISTRATION #include "TritonAMDGPUToLLVM/Passes.h.inc" diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td index 986c6763b..9f4665aef 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -13,6 +13,18 @@ def DecomposeUnsupportedAMDConversions : Pass<"decompose-unsupported-amd-convers ]; } +def OptimizeAMDLDSUsage : Pass<"optimize-amd-lds-usage", "mlir::ModuleOp"> { + let summary = "Minimize LDS usage"; + let constructor = "mlir::triton::AMD::createOptimizeLDSUsagePass(\"\")"; + + let options = [ + Option<"targetArch", "target-arch", "std::string", /*default*/"", + "gfx target device architecture, e.g., gfx942">, + Option<"customLDSLimit", "lds-limit", "int", /*default*/"0", + "custom limit of LDS consumption, if not provided, maximum LDS size is used">, + ]; +} + def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::ModuleOp"> { let summary = "Convert TritonGPU to LLVM"; let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true)"; @@ -37,10 +49,34 @@ def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::Mod def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::ModuleOp"> { let summary = "Convert Builtin Func to LLVM"; - let constructor = "mlir::triton::createConvertBuiltinFuncToLLVMPass()"; + let constructor = "mlir::triton::createConvertBuiltinFuncToLLVMPass(/*ftz=*/true)"; + + let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + + let options = [ + Option<"ftz", "ftz", "bool", /*default*/"true", + "flush denorms for math functions">, + ]; +} + +def InsertInstructionSchedHints : Pass<"insert-instruction-sched-hints", "mlir::ModuleOp"> { + let summary = "Insert instruction scheduling hints after the dot ops in the main loop"; + let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()"; let dependentDialects = ["mlir::LLVM::LLVMDialect"]; +} + +def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { + let summary = "Lower instruction scheduling hints to LLVM intrinsics"; + let constructor = "mlir::triton::createLowerInstructionSchedHintsPass(\"\")"; + let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + + let options = [ + Option<"variant", "variant", "std::string", /*default*/"\"default\"", + "instruction scheduling variant">, + ]; } + #endif diff --git a/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h b/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h index 4127d85dc..121bb6172 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h @@ -86,6 +86,8 @@ class MfmaInsn { unsigned getNDim(); StringRef getInsnName(); unsigned getKBase(); + Type getElementTypeA(); + Type getElementTypeB(); }; } // namespace mlir diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index e7a9753b2..d0ffdae28 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -2,11 +2,12 @@ #define TRITON_DIALECT_TRITONAMDGPU_TRANSFORMS_PASSES_H_ #include "mlir/Pass/Pass.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" namespace mlir { -std::unique_ptr createTritonAMDGPUStreamPipelinePass(); +std::unique_ptr createTritonAMDGPUStreamPipelineV2Pass(int numStages = 2); std::unique_ptr createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(), @@ -21,6 +22,10 @@ std::unique_ptr createTritonAMDGPUVerifier(); std::unique_ptr createTritonAMDGPUOptimizeEpiloguePass(); +std::unique_ptr createTritonAMDGPUCanonicalizePointersPass(); + +std::unique_ptr createTritonAMDGPUConvertToBufferOpsPass(); + /// Generate the code for registering passes. #define GEN_PASS_REGISTRATION #include "TritonAMDGPUTransforms/Passes.h.inc" diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index a818b1ac9..433e60be6 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -3,7 +3,7 @@ include "mlir/Pass/PassBase.td" -def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::ModuleOp"> { +def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir::ModuleOp"> { let summary = "pipeline"; let description = [{ @@ -11,9 +11,15 @@ def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::Mod tile }]; - let constructor = "mlir::createTritonAMDGPUStreamPipelinePass()"; + let constructor = "mlir::createTritonAMDGPUStreamPipelineV2Pass()"; let dependentDialects = []; + + let options = [ + Option<"numStages", "num_stages", + "int32_t", /*default*/"2", + "Number of Pipeline stages"> + ]; } def TritonAMDGPUAccelerateMatmul : Pass<"tritonamdgpu-accelerate-matmul", "mlir::ModuleOp"> { @@ -53,6 +59,46 @@ def TritonAMDGPUOptimizeEpilogue : Pass<"tritonamdgpu-optimize-epilogue", "mlir: } +def TritonAMDGPUCanonicalizePointers : Pass<"tritonamdgpu-canonicalize-pointers", "mlir::ModuleOp"> { + let summary = "Canonicalize pointers: rewrite pointers passed to load/store operation as a `` pair."; + + let description = [{ + This pass pushes all the constant pointer arithmetic on a scalar basePtr, while all the vector + pointer arithmetic to a vector offset. I.e., if we consider the following IR: + ``` + %v_ptr = tt.splat %s_ptr + %c_offset = tt.splat %s_offset + %v_offset0 = tt.make_range + %v_offset1 = tt.make_range + %v_ptr0 = tt.addptr %v_ptr, %c_offset + %v_ptr1 = tt.addptr %v_ptr0, %v_offset0 + %v_ptr2 = tt.addptr %v_ptr0, %v_offset1 + %data = tt.load(%v_ptr2) + ``` + We transform this into: + ``` + %s_ptr0 = tt.addptr %s_ptr, %s_offset + %v_offset = %zero + %v_offset = arith.addi %v_offset, %v_offset0 + %v_offset = arith.addi %v_offset, %v_offset1 + %c_ptr = tt.splat %s_ptr0 + %v_ptr = tt.addptr %c_ptr, %v_offset + %data = tt.load(%v_ptr) + ``` + In the above IR: + - `v_` means "variable vector across the program" + - `c_` means "constant vector across the program" + - `s_` means "scalar" + So we transform the IR such that the constant updates become scalar updates, and the variable updates happen on the offset. Note that + when we have to load the data, we splat the scalar pointer, add the "variable" offset and then issue the load. + }]; + + let constructor = "mlir::createTritonAMDGPUCanonicalizePointersPass()"; + + let dependentDialects = []; + +} + def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", "mlir::ModuleOp"> { let summary = "Reorder instructions"; @@ -65,4 +111,14 @@ def TritonAMDGPUReorderInstructions: Pass<"tritonamdgpu-reorder-instructions", " let dependentDialects = []; } +def TritonAMDGPUConvertToBufferOps : Pass<"tritonamdgpu-convert-buffer-ops", "mlir::ModuleOp"> { + let summary = "Convert memory operations to buffer operations"; + + let description = "This pass converts memory operations (e.g., tt.load/tt.store) to amdgpu buffer operations, if possible"; + + let constructor = "mlir::createTritonAMDGPUConvertToBufferOpsPass()"; + + let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"]; +} + #endif diff --git a/third_party/amd/language/hip/__init__.py b/third_party/amd/language/hip/__init__.py new file mode 100644 index 000000000..229b57d87 --- /dev/null +++ b/third_party/amd/language/hip/__init__.py @@ -0,0 +1,3 @@ +from . import libdevice + +__all__ = ["libdevice"] diff --git a/third_party/amd/language/hip/libdevice.py b/third_party/amd/language/hip/libdevice.py new file mode 100644 index 000000000..a69d4406c --- /dev/null +++ b/third_party/amd/language/hip/libdevice.py @@ -0,0 +1,475 @@ +from triton.language import core + + +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__triton_hip_iabs", core.dtype("int32")), + (core.dtype("int64"), ): ("__triton_hip_iabs", core.dtype("int64")), + (core.dtype("fp32"), ): ("__triton_hip_fabs", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__triton_hip_fabs", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_floor_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_floor_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_rsqrt_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_rsqrt_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_ceil_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_ceil_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_trunc_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_trunc_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_exp2_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_exp2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_exp_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_exp_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_expf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__triton_hip_fast_expf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_dividef(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__triton_hip_fast_fdividef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sqrt_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sqrt_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__triton_hip_llrint", core.dtype("int64")), + (core.dtype("fp64"), ): ("__triton_hip_llrint", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nearbyint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_nearbyint_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_nearbyint_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_isnan_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_isnan_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_signbit_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_signbit_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_copysign_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_copysign_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_isinf_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_isinf_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) + + +@core.extern +def nextafter(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_nextafter_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_nextafter_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sin_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sin_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_cos_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_cos_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_tan_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_tan_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log2_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_cosh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_cosh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sinh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sinh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_tanh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_tanh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan2(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_atan2_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_atan2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_atan_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_atan_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_asin_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_asin_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_acos_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_acos_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log10_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log10_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log1p_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log1p_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_acosh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_acosh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_asinh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_asinh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_atanh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_atanh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_expm1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_expm1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_hypot_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_hypot_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_j0_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_j0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_j1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_j1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_y0_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_y0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_y1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_y1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_i0_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_i0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_i1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_i1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erf_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erf_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfinv_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfinv_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfc_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfc_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcx(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfcx_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfcx_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_lgamma_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_lgamma_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__ocml_ldexp_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__ocml_ldexp_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_fmod_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_fmod_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__ocml_fma_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__ocml_fma_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__ocml_pown_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__ocml_pown_f64", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_pow_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_pow_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_ilogb_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_ilogb_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) diff --git a/third_party/amd/lib/CMakeLists.txt b/third_party/amd/lib/CMakeLists.txt index b2a802f75..15c000ab8 100644 --- a/third_party/amd/lib/CMakeLists.txt +++ b/third_party/amd/lib/CMakeLists.txt @@ -1,2 +1,4 @@ +add_subdirectory(Dialect) add_subdirectory(TritonAMDGPUToLLVM) +add_subdirectory(TritonAMDGPUDialectToLLVM) add_subdirectory(TritonAMDGPUTransforms) diff --git a/third_party/amd/lib/Dialect/CMakeLists.txt b/third_party/amd/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..4f9163bdf --- /dev/null +++ b/third_party/amd/lib/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TritonAMDGPU) diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/CMakeLists.txt b/third_party/amd/lib/Dialect/TritonAMDGPU/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/CMakeLists.txt b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..f550b6e20 --- /dev/null +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/CMakeLists.txt @@ -0,0 +1,12 @@ +add_triton_library(TritonAMDGPUIR + Dialect.cpp + + DEPENDS + TritonAMDGPUTableGen + TritonAMDGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRLLVMDialect + TritonIR + TritonGPUIR +) diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp new file mode 100644 index 000000000..a82a77e9f --- /dev/null +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" + +// clang-format off +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "Dialect/TritonAMDGPU/IR/Dialect.cpp.inc" +// clang-format on + +using namespace mlir; +using namespace mlir::triton::amdgpu; + +void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc" diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt new file mode 100644 index 000000000..e6da8f287 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt @@ -0,0 +1,6 @@ +add_triton_library(TritonAMDGPUDialectToLLVM + TritonAMDGPUToLLVMPatterns.cpp + + DEPENDS + TritonAMDGPUIR +) diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp new file mode 100644 index 000000000..5d172fea9 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp @@ -0,0 +1,9 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" + +namespace mlir::triton::AMD { +void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit) { + // TODO: Insert TrtionAMDGPU dialect patterns. +} +} // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp new file mode 100644 index 000000000..37bdb8fe9 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp @@ -0,0 +1,175 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "TargetInfo.h" +#include "Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "BufferOpsEmitter.h" + +using mlir::triton::gpu::appendOrGetExternFuncOp; +using mlir::triton::gpu::getFunctionType; +using namespace triton::AMD; + +namespace { + +// Utility function to determine if a scalar/tensor value is zero +bool isZero(Value v) { + if (auto constantOp = v.getDefiningOp()) { + if (auto attr = dyn_cast(constantOp.getValue())) + return attr.getValue().isZero(); + if (auto attr = dyn_cast(constantOp.getValue())) + return attr.getValue().isZero(); + if (auto denseAttr = + dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + if (auto denseAttr = + dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + } + return false; +} +} // namespace + +namespace mlir::LLVM::AMD { +BufferEmitter::BufferEmitter(RewriterBase &rw, Location loc, TargetInfo ti) + : rewriter(rw), loc(loc), targetInfo(ti) {} + +Value BufferEmitter::createResourceDescriptor(Value basePtr) { + // 1. Create the resource descriptor + // bits 0-11: dst sel, ignored by these intrinsics + // bits 12-14: data format (ignored, must be nonzero, 7=float) + // bits 15-18: data format (ignored, must be nonzero, 4=32bit) + // bit 19: In nested heap (0 here) + // bit 20: Behavior on unmap (0 means "return 0 / ignore") + // bits 21-22: Index stride for swizzles (N/A) + // bit 23: Add thread ID (0) + // bit 24: Reserved to 1 (RDNA) or 0 (CDNA) + // bits 25-26: Reserved (0) + // bit 27: Buffer is non-volatile (CDNA only) + // bits 28-29: Out of bounds select (RDNA only) + // (0 = structured, + // 1 = check index, + // 2 = none, + // 3 = either swizzles or testing against offset field) + // bits 30-31: Type (must be 0) + uint32_t flags = (7 << 12) | (4 << 15); + if (targetInfo.getISAFamily() == ISAFamily::RDNA2 || + targetInfo.getISAFamily() == ISAFamily::RDNA3) { + flags |= (1 << 24); + uint32_t oob = 3; + flags |= (oob << 28); + } + Value stride = int_val(16, 0); + Value flagsConst = int_val(32, flags); + Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8); + Value numRecordsByte = int_val(32, std::numeric_limits::max() - 1); + + Value resource = rewriter.createOrFold( + loc, rsrcType, basePtr, stride, numRecordsByte, flagsConst); + return resource; +} + +Value BufferEmitter::emitLoad(Type type, Value rsrcDesc, Value offset, + Value pred, Value falseVal) { + SmallVector args; + fillCommonArgs(type, rsrcDesc, offset, pred, args); + Type bufferType = getBufferOpType(type); + Value data = rewriter.create( + loc, bufferType, args, ArrayRef()); + data = bitcast(data, type); + if (!isZero(falseVal)) + data = select(pred, data, falseVal); + return data; +} + +void BufferEmitter::emitStore(Value rsrcDesc, Value offset, Value data, + Value pred) { + VectorType vecTy = cast(data.getType()); + Type bufferType = getBufferOpType(vecTy); + if (vecTy != bufferType) + data = bitcast(data, bufferType); + SmallVector args{data}; + fillCommonArgs(vecTy, rsrcDesc, offset, pred, args); + rewriter.create(loc, TypeRange{}, args, + ArrayRef()); +} + +Type BufferEmitter::getBufferOpType(Type type) { + int64_t vecSize = 1; + Type elementType = type; + if (auto vecType = dyn_cast(type)) { + vecSize = vecType.getNumElements(); + elementType = vecType.getElementType(); + } + + const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); + const size_t totalWidthBits = valueElemNBits * vecSize; + + // For bf16, always convert to i16 + Type bufferElementType = elementType; + if (elementType.isBF16()) + bufferElementType = rewriter.getI16Type(); + + // If we are dealing with a subword type (e.g., i8 or f16) but we + // still need multiple words, then pack the subwords into 32bit integers + // and update the vector length and the type + int64_t bufferVecSize = vecSize; + if (valueElemNBits < 32) { + if (totalWidthBits > 32) { + bufferElementType = rewriter.getI32Type(); + bufferVecSize = totalWidthBits / 32; + } else { + bufferElementType = rewriter.getIntegerType(totalWidthBits); + bufferVecSize = 1; + } + } + + // This is the buffer type that the buffer operation will use. It + // will be bitcast-able to the original type. So if the types + // ended up different, we simply have to emit a `bitcastOp` to convert + Type bufferType = type; + if (bufferVecSize != vecSize || bufferElementType != elementType) + bufferType = VectorType::get(bufferVecSize, bufferElementType); + if (bufferVecSize == 1) + bufferType = getElementTypeOrSelf(bufferType); + + return bufferType; +} + +void BufferEmitter::fillCommonArgs(Type type, Value rsrcDesc, + Value vOffsetElems, Value pred, + SmallVector &args) { + + // 1. Create the (masked) offset + Type elementType = getElementTypeOrSelf(type); + const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); + const int elementByteWidth = valueElemNBits / 8; + // Please note: the index passed is not in bytes, but in number of elements + // In order to pass the index to the buffer operation, we need to convert in + // bytes (i.e., we need to multiply by `elementByteWidth`) + Value vOffsetOutOfBunds = int_val( + 32, static_cast(std::numeric_limits::max() + int64_t(1))); + Value vOffsetBytes = mul(int_val(32, elementByteWidth), vOffsetElems); + Value maskedOffsetBytes = select(pred, vOffsetBytes, vOffsetOutOfBunds); + + // 2. Set the sgprOffset to 0 + Value sgprOffset = int_val(32, 0); + + // 3. Create the cache modifiers word + // bit 0: GLC = 0 (atomics drop value, less coherency) + // bits 1-2: SLC, DLC = 0 (similarly) + // bit 3: swizzled (0 for raw) + Value cacheModifiers = int_val(32, 0); + + // 5. Add the arguments + args.push_back(rsrcDesc); + args.push_back(maskedOffsetBytes); + args.push_back(sgprOffset); + args.push_back(cacheModifiers); +} +} // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h new file mode 100644 index 000000000..ad6d46ff7 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h @@ -0,0 +1,93 @@ +#ifndef TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_BUFFER_OPS_EMITTER_H +#define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_BUFFER_OPS_EMITTER_H + +#include "TargetInfo.h" +#include "TritonAMDGPUToLLVM/GCNAsmFormat.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir::LLVM::AMD { +// Utility class to take care of buffer operation emission. We may add more +// emitters into this as needed. Buffer operations accept a memory descriptor +// and an offset. +// +// The memory descriptor is stored in s_gprs and hence needs to +// be uniform across the wave. It contains two fields (among many others): +// +// - `base_pointer`: represents the (scalar) pointer to the memory area +// - `num_records`: represents the size of the memory region. This is a +// 32 bit unsigned integer +// +// The offset can be non-uniform across the wave (and hence stored in vgprs). +// +// The high level behaviour of a buffer operation can be described as: +// ``` +// def buffer_op(mem_desc, offset): +// address = splat(mem_desc.base_pointer) +// address += offset +// return buffer_op(address) +// ``` +// This means we don't need to store the addresses in vgprs and we need less +// VALU operations to compute the final address. +// +// Also note that buffer operations support out-of-boundary memory access. +// I.e., if offset[i] > mem_desc.num_records the operation is a nop for the i-th +// thread. +// +// This can be exploited to support masked operations, like in the following +// snippet: +// ``` +// def masked_op(base_ptr, offset, pred) +// mem_desc.base_ptr = base_ptr +// mem_desc.num_records = max_int_32 +// oob_offset = max_int_32+1 +// masked_offset = (pred ? offset : oob_offset) +// buffer_op(mem_desc, masked_offset) +// ``` +// To use buffer operations three main requirements need to be met: +// +// 1. The buffer pointer needs to be a scalar, it cannot be non-uniform across +// threads of the given wave +// 2. The offset needs to be expressed in 32 bits +// 3. The offset needs to be non-negative +// +// Failure to meet 1) will result in a scalarized loop (very poor performance). +// Failure to meet 2) and 3) will result in incorrect memory access. +struct BufferEmitter { + BufferEmitter(RewriterBase &rw, Location loc, + mlir::triton::AMD::TargetInfo ti); + + // Create a resource descriptor that points to the area of memory we want to + // load from + Value createResourceDescriptor(Value basePtr); + + // Emit a predicated rocdl.raw.ptr.buffer.load + Value emitLoad(Type type, Value rsrcDesc, Value offset, Value pred, + Value falseVal); + + // Emit a predicated rocdl.raw.ptr.buffer.store + void emitStore(Value rsrcDesc, Value offset, Value data, Value pred); + +private: + // Fill common buffer operation arguments. + void fillCommonArgs(Type type, Value rsrcDesc, Value vOffsetElems, Value pred, + SmallVector &args); + + // Given a type, the buffer type can be either the same type + // or a packed version. E.g., a vector of 8xfp16 can be bitcasted to + // a vector of 4xi32. This usually makes the life of the backend easier + Type getBufferOpType(Type type); + + // Rewriter utilities + RewriterBase &rewriter; + Location loc; + mlir::triton::AMD::TargetInfo targetInfo; +}; + +} // namespace mlir::LLVM::AMD + +#endif // TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_BUFFER_OPS_EMITTER_H diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp index 73af042e0..409d14774 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp @@ -4,6 +4,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" namespace mlir { namespace triton { @@ -16,15 +17,14 @@ using namespace mlir; namespace { -class CallOpConversion : public mlir::RewritePattern { +class CallOpConversion : public OpRewritePattern { public: - CallOpConversion(mlir::MLIRContext *context) - : mlir::RewritePattern(LLVM::CallOp::getOperationName(), 1, context) {} + CallOpConversion(mlir::MLIRContext *context, bool ftz) + : OpRewritePattern(context, 1), ftz(ftz) {} LogicalResult - matchAndRewrite(mlir::Operation *op, + matchAndRewrite(LLVM::CallOp callOp, mlir::PatternRewriter &rewriter) const override { - auto callOp = cast(op); if (isPredicatedLoad(callOp)) { return convertPredicatedLoad(callOp, rewriter); } else if (isPredicatedStore(callOp)) { @@ -38,13 +38,42 @@ class CallOpConversion : public mlir::RewritePattern { private: bool isPredicatedLoad(LLVM::CallOp callOp) const { - return callOp.getCallee().value().find(mlir::LLVM::AMD::Predicated_Load) != - llvm::StringRef::npos; + return callOp.getCallee().value().contains(mlir::LLVM::AMD::predicatedLoad); + } + + bool isPredicatedLoadCA(LLVM::CallOp callOp) const { + return callOp.getCallee().value().contains( + mlir::LLVM::AMD::predicatedLoadCA); + } + + bool isPredicatedLoadCG(LLVM::CallOp callOp) const { + return callOp.getCallee().value().contains( + mlir::LLVM::AMD::predicatedLoadCG); + } + + bool isPredicatedLoadCV(LLVM::CallOp callOp) const { + return callOp.getCallee().value().contains( + mlir::LLVM::AMD::predicatedLoadCV); } bool isPredicatedStore(LLVM::CallOp callOp) const { - return callOp.getCallee().value().find(mlir::LLVM::AMD::Predicated_Store) != - llvm::StringRef::npos; + return callOp.getCallee().value().contains( + mlir::LLVM::AMD::predicatedStore); + } + + bool isPredicatedStoreCS(LLVM::CallOp callOp) const { + return callOp.getCallee().value().contains( + mlir::LLVM::AMD::predicatedStoreCS); + } + + bool isPredicatedStoreCG(LLVM::CallOp callOp) const { + return callOp.getCallee().value().contains( + mlir::LLVM::AMD::predicatedStoreCG); + } + + bool isPredicatedStoreWT(LLVM::CallOp callOp) const { + return callOp.getCallee().value().contains( + mlir::LLVM::AMD::predicatedStoreWT); } bool isWrappedLLVMIntrinsic(LLVM::CallOp callOp) const { @@ -72,7 +101,16 @@ class CallOpConversion : public mlir::RewritePattern { rewriter.setInsertionPointToEnd(currentBlock); rewriter.create(loc, pred, trueBlock, afterStore); rewriter.setInsertionPointToStart(trueBlock); - auto storeOp = rewriter.create(loc, val, ptr); + /* + | vialatile | non-tmp | gcn instr gfx94 + LLVM::StoreOp | 0 | 0 | (cg) global store + | 0 | 1 | (cs) global store nt + | 1 | 0/1 | (wt) global store sc0 sc1 + */ + bool vialatileFlag = isPredicatedStoreWT(callOp); + bool nonTmpFlag = isPredicatedStoreCS(callOp); + auto storeOp = rewriter.create( + loc, val, ptr, /*alignment=*/0, vialatileFlag, nonTmpFlag); rewriter.create(loc, afterStore); rewriter.setInsertionPointToStart(afterStore); rewriter.eraseOp(callOp); @@ -100,7 +138,16 @@ class CallOpConversion : public mlir::RewritePattern { rewriter.setInsertionPointToEnd(currentBlock); rewriter.create(loc, pred, trueBlock, falseBlock); rewriter.setInsertionPointToStart(trueBlock); - auto loadOp = rewriter.create(loc, elemTy, ptr); + /* + | vialatile | non-tmp | gcn instr gfx94 + LLVM::LoadOp | 0 | 0 | (ca) global load + | 0/1 | 1 | (cg) global load nt + | 1 | 0 | (cv) flat load sc0 sc1 + */ + bool vialatileFlag = isPredicatedLoadCV(callOp); + bool nonTmpFlag = isPredicatedLoadCG(callOp); + auto loadOp = rewriter.create( + loc, elemTy, ptr, /*alignment=*/0, vialatileFlag, nonTmpFlag); rewriter.create(loc, loadOp->getResult(0), afterLoad); rewriter.setInsertionPointToStart(falseBlock); rewriter.create(loc, falseVal, afterLoad); @@ -117,7 +164,7 @@ class CallOpConversion : public mlir::RewritePattern { auto operands = callOp.getOperands(); auto result = callOp.getResult(); - LLVM::LLVMFunctionType calleeType = callOp.getCalleeType().value(); + LLVM::LLVMFunctionType calleeType = callOp.getCalleeFunctionType(); Type returnType = calleeType.getReturnType(); auto loc = callOp.getLoc(); @@ -140,13 +187,25 @@ class CallOpConversion : public mlir::RewritePattern { rewriter.create(loc, returnType, op->getResult(0)); } else if (calleeName == "__triton_hip_fast_fdividef") { assert(operands.size() == 2); - auto name = StringAttr::get(callOp.getContext(), "llvm.amdgcn.rcp.f32"); - LLVM::FastmathFlagsAttr defaultFlags{}; - auto rcpOp = rewriter.create( - loc, returnType, name, operands[1], defaultFlags); + const char *intrinsic = "llvm.amdgcn.rcp.f32"; + auto rcpOp = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, + returnType, operands[1]); + LLVM::FastmathFlagsAttr defaultFlags{}; replacementOp = rewriter.create( loc, returnType, operands[0], rcpOp->getResult(0), defaultFlags); + } else if (calleeName == "__triton_hip_fast_expf") { + assert(operands.size() == 1); + assert(operands[0].getType().getIntOrFloatBitWidth() == 32); + const double log2e = 1.4426950408889634; + LLVM::FastmathFlagsAttr defaultFlags{}; + auto mulOp = rewriter.create( + loc, rewriter.getF32Type(), operands[0], + LLVM::createConstantF32(loc, rewriter, log2e), defaultFlags); + const char *intrinsic = ftz ? "llvm.amdgcn.exp2.f32" : "llvm.exp2.f32"; + + replacementOp = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, intrinsic, returnType, mulOp->getResult(0)); } if (replacementOp) { @@ -156,23 +215,25 @@ class CallOpConversion : public mlir::RewritePattern { return mlir::failure(); } + +private: + bool ftz; }; struct ConvertBuiltinFuncToLLVM : public triton::impl::ConvertBuiltinFuncToLLVMBase< ConvertBuiltinFuncToLLVM> { + explicit ConvertBuiltinFuncToLLVM(bool ftz) { this->ftz = ftz; } + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); - // Disable block merging because of: - // https://github.com/llvm/llvm-project/issues/63230 - // TODO(giuseros): enable block merging once the above ticket is completed GreedyRewriteConfig config; - config.enableRegionSimplification = false; + config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; RewritePatternSet patterns(context); - patterns.add(context); + patterns.add(context, this->ftz); if (mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns), config) .failed()) { @@ -186,8 +247,9 @@ struct ConvertBuiltinFuncToLLVM namespace mlir { namespace triton { -std::unique_ptr> createConvertBuiltinFuncToLLVMPass() { - return std::make_unique(); +std::unique_ptr> +createConvertBuiltinFuncToLLVMPass(bool ftz) { + return std::make_unique(ftz); } } // namespace triton diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index 1a8719822..b6a514f45 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -1,4 +1,5 @@ add_triton_library(TritonAMDGPUToLLVM + BufferOpsEmitter.cpp ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -15,11 +16,15 @@ add_triton_library(TritonAMDGPUToLLVM TargetInfo.cpp TargetUtils.cpp DecomposeUnsupportedConversions.cpp + OptimizeLDSUsage.cpp + OptimizeLDSUtility.cpp SPMDOpToLLVM.cpp + SchedInstructions.cpp DEPENDS TritonAMDGPUConversionPassIncGen LINK_LIBS PUBLIC TritonGPUToLLVM + TritonAMDGPUIR ) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 953b01dab..b7ee4efc7 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -50,7 +50,11 @@ struct LocalLoadOpConversion } private: - // shared -> dot_operand if the result layout is mfma + /// Lower ttg.local_load in dot operand layout if the operand parent layout is + /// MFMA or WMMA. + /// + /// \returns value with packed loaded values or empty value if this local_load + /// is not supproted. Value lowerSharedToDotOperandMMA( triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor, const LLVMTypeConverter *typeConverter, @@ -104,6 +108,8 @@ struct LocalLoadOpConversion isOuter = K == 1; Value res = lowerSharedToDotOperandMMA(op, adaptor, typeConverter, rewriter, dotOperandLayout, isOuter); + if (!res) + return failure(); rewriter.replaceOp(op, res); return success(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp index 740e106f4..03b7c56b7 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp @@ -68,9 +68,10 @@ Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc, const SharedMemoryObject &smemObj) { Value base = smemObj.base; Type type = base.getType(); + Type elemType = smemObj.getBaseElemType(); for (int i = 0; i < smemObj.strides.size(); ++i) { Value offset = sub(i32_val(0), mul(smemObj.offsets[i], smemObj.strides[i])); - base = gep(ptr_ty(rewriter.getContext(), 3), type, base, offset); + base = gep(type, elemType, base, offset); } return base; } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index 72e11d593..b832d985b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -36,29 +36,29 @@ namespace SharedToDotOperandMFMA { * @brief This function maps particular load of mfma dot operand to element * indexes(row, col) * - * Whole tensor is broken into "blocks" of waves along "non-K" axis. - * One block could be processed by multiple waves. - * One wave works on a piece of tensor size elemsPerInstr[0] x K. + * Whole tensor is broken into "blocks" of warps along "non-K" axis. + * One block could be processed by multiple warps. + * One warp works on a piece of tensor size elemsPerInstr[0] x K. * Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x * elemsPerInstr[1]. * * Total offset of element is a sum of following values: - * 1. Offset of wave-block in tensor - * 2. Offset of wave inside one wave-block - * 3. Offset of tile in one wave + * 1. Offset of warp-block in tensor + * 2. Offset of warp inside one warp-block + * 3. Offset of tile in one warp * 4. Offset of one lane data in a tile * 5. Offset of particular element of tensor processed by one lane * * This function computes these offsets for axies independently * Note that this function returns the offsets of elements in the first - * wave-block. The offsets of elements in later wave-blocks can be computed + * warp-block. The offsets of elements in later warp-blocks can be computed * by adding a constant stride to the xor-ed offsets of elements in the - * first wave-block. + * first warp-block. * * @param rewriter * @param loc * @param elemsPerInstr operand tile shape consumed by one MFMA instruction - * @param waveId id component of 2d wave grid along non-K axis + * @param warpId id component of 2d warp grid along non-K axis * @param laneId lane id in warp [0..63] * @param numOfElems number of elements accessed by thread per repetition * @param reps number of instructions repetition to fully cover dot operand @@ -71,7 +71,7 @@ namespace SharedToDotOperandMFMA { */ llvm::SmallVector> computeTensorElemMappingInBlock( ConversionPatternRewriter &rewriter, Location loc, - const ArrayRef &elemsPerInstr, Value waveId, Value laneId, + const ArrayRef &elemsPerInstr, Value warpId, Value laneId, int numOfElems, ArrayRef reps, ArrayRef smemOffsets, int loadVecSize, unsigned iNonKDim, unsigned iKDim) { auto numM = reps[1]; @@ -82,7 +82,7 @@ llvm::SmallVector> computeTensorElemMappingInBlock( Value _0 = i32_val(0); Value _32 = i32_val(32); Value nonKDim = i32_val(iNonKDim); - Value waveVOffset = mul(waveId, i32_val(elemsPerInstr[0])); + Value warpVOffset = mul(warpId, i32_val(elemsPerInstr[0])); auto rank = smemOffsets.size(); @@ -95,12 +95,12 @@ llvm::SmallVector> computeTensorElemMappingInBlock( if (iNonKDim == 32) laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0); else { - // In this configuration wave contains 16 copies of same data + // In this configuration warp contains 16 copies of same data if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) { laneHOffset = i32_val(0); } else { assert(iKDim * iNonKDim / numOfElems == 64 && - "seems no all threads in wave contain unique elements"); + "seems no all threads in warp contain unique elements"); laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems)); } } @@ -110,7 +110,7 @@ llvm::SmallVector> computeTensorElemMappingInBlock( Value elemHOffset = i32_val(loadId * loadVecSize); Value sliceVOffset = - add(add(add(tileVOffset, laneVOffset), elemVOffset), waveVOffset); + add(add(add(tileVOffset, laneVOffset), elemVOffset), warpVOffset); Value sliceHOffset = add(add(tileHOffset, laneHOffset), elemHOffset); Value row = add(sliceVOffset, smemOffsets[rank - 2]); @@ -131,7 +131,7 @@ bool hasSwizzleEnabled(const SharedEncodingAttr &srcEncoding) { // @param loc // @param elemsPerInstr operand tile shape [K, nonK] consumed by one MFMA // instruction -// @param waveId wave id for the "non K" axis +// @param warpId warp id for the "non K" axis // @param laneId lane id in warp [0..63] // @param warpsPerBlock number of warps per horizontal axis // @param numOfElems number of elements accessed by threads per repetition @@ -139,7 +139,7 @@ bool hasSwizzleEnabled(const SharedEncodingAttr &srcEncoding) { // @param cSwizzleOffset llvm::SmallVector fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, - const ArrayRef &elemsPerInstr, Value waveId, + const ArrayRef &elemsPerInstr, Value warpId, Value laneId, int warpsPerBlock, int numOfElems, ArrayRef reps, Value cSwizzleOffset) { auto numK = reps[1]; @@ -150,7 +150,7 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, auto iNonKDim = elemsPerInstr[1]; int lineSize = warpsPerBlock * iNonKDim * numN; Value _nonKDim = i32_val(iNonKDim); - Value waveOffset = mul(waveId, i32_val(iNonKDim)); + Value warpOffset = mul(warpId, i32_val(iNonKDim)); Value colOffset = urem(laneId, _nonKDim); for (int block = 0; block < numN; ++block) { @@ -158,15 +158,15 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, for (int tile = 0; tile < numK; ++tile) { Value tileOffset = i32_val(tile * iKDim * lineSize); for (int elem = 0; elem < numOfElems; ++elem) { - // halfOffset is an offset related to wrapping of wave in the tile. + // halfOffset is an offset related to wrapping of warp in the tile. // for example, mfma 32 case (mapping of tensor elements to lane ids in - // wave): + // warp): // // 0 1 2 3 ... 31 // 0 1 2 3 ... 31 // 0 1 2 3 ... 31 // 0 1 2 3 ... 31 - // 32 33 34 35 ... 63 <- at this point wave is wrapping + // 32 33 34 35 ... 63 <- at this point warp is wrapping // 32 33 34 35 ... 63 // 32 33 34 35 ... 63 // 32 33 34 35 ... 63 @@ -179,7 +179,7 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, Value rowOffset = add(i32_val(elem * lineSize), halfOffset); Value elemOffset = add(rowOffset, colOffset); Value offset = - add(add(add(waveOffset, blockOffset), tileOffset), elemOffset); + add(add(add(warpOffset, blockOffset), tileOffset), elemOffset); offsets[numK * numOfElems * block + numOfElems * tile + elem] = offset; } } @@ -217,12 +217,12 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, auto elemTy = aTensorTy.getElementType(); auto kWidth = encoding.getKWidth(); - auto elemsPerInstr = mfmaLayout.getMFMAInstrShapeForOperands(kWidth, opIdx); + auto elemsPerInstr = mfmaLayout.getInstrShapeForOperand(kWidth, opIdx); int64_t mfmaInstrNonK; int64_t mfmaInstrK; // TODO(Lixun): make it simpler - // getMFMAInstrShapeForOperands always returns a 2D vector + // getInstrShapeForOperand always returns a 2D vector if (rank == 3) { mfmaInstrNonK = elemsPerInstr[nonKDimIdx - 1]; mfmaInstrK = elemsPerInstr[kDimIdx - 1]; @@ -231,41 +231,47 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, mfmaInstrK = elemsPerInstr[kDimIdx]; } - auto numReps = mfmaLayout.getMFMARepForOperands(shape, kWidth, opIdx); + if (mfmaInstrNonK > shape[nonKDimIdx] || mfmaInstrK > shape[kDimIdx]) { + // This pattern does not support cases tensor shape is smaller than + // one instruction size, it will be processed by LinearLayout converter + return Value(); + } + + auto numReps = mfmaLayout.getRepForOperand(shape, kWidth, opIdx); auto numRepNonK = numReps[nonKDimIdx]; auto numRepK = numReps[kDimIdx]; auto repB = numReps[0]; // TODO(Lixun): make it simpler - // getMFMARepForOperands always returns a 3D vector + // getRepForOperand always returns a 3D vector if (rank == 2) { numRepNonK = numReps[nonKDimIdx + 1]; numRepK = numReps[kDimIdx + 1]; } - unsigned iWaveSize = triton::gpu::getWarpSize(mfmaLayout); - assert(iWaveSize == 64); - Value waveSize = i32_val(iWaveSize); - Value linearWaveId = udiv(thread, waveSize); - Value lane = urem(thread, waveSize); + unsigned iWarpSize = triton::gpu::getWarpSize(mfmaLayout); + assert(iWarpSize == 64); + Value warpSize = i32_val(iWarpSize); + Value linearWarpId = udiv(thread, warpSize); + Value lane = urem(thread, warpSize); - Value spatialWaveId = AMD::getWarpIdInBlock( - rewriter, loc, linearWaveId, warpsPerCTA, mfmaInstrNonK, + Value spatialWarpId = AMD::getWarpIdInBlock( + rewriter, loc, linearWarpId, warpsPerCTA, mfmaInstrNonK, shape[nonKDimIdx], nonKDimIdx, triton::gpu::getOrder(mfmaLayout)); - // number of duplicates of elements in wave + // number of duplicates of elements in warp // In case of 64x4 x 4x4 multiplication, 4x4 B operand is duplicated 16 times int numSubBlocks = 1; if ((mfmaInstrK == 4 || mfmaInstrK == 1) && mfmaInstrNonK == 4) numSubBlocks = 16; // numOfElemsPerThreadPerMfmaInstr - int numOfElems = mfmaInstrNonK * mfmaInstrK * numSubBlocks / iWaveSize; + int numOfElems = mfmaInstrNonK * mfmaInstrK * numSubBlocks / iWarpSize; assert(numOfElems >= 1); unsigned int maxNumWarps = shape[nonKDimIdx] / mfmaInstrNonK; int warpsPerBlockNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps); int warpsPerBatch = rank == 3 ? std::min(shape[0], warpsPerCTA[0]) : 1; - Value waveIdInBatch = urem(linearWaveId, i32_val(warpsPerBatch)); + Value warpIdInBatch = urem(linearWarpId, i32_val(warpsPerBatch)); elemTy = typeConverter->convertType(elemTy); SmallVector loadedValues; @@ -284,7 +290,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, SmallVector elemsPerInstr{mfmaInstrK, mfmaInstrNonK}; SmallVector reps{numReps[0], numReps[2], numReps[1]}; offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr, - spatialWaveId, lane, warpsPerBlockNonK, + spatialWarpId, lane, warpsPerBlockNonK, numOfElems, reps, cSwizzleOffset); } else { llvm_unreachable( @@ -296,7 +302,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, "col major operand B should be handled in the normal path"); } else { offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr, - spatialWaveId, lane, warpsPerBlockNonK, + spatialWarpId, lane, warpsPerBlockNonK, numOfElems, numReps, cSwizzleOffset); } } @@ -311,13 +317,13 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, if (opIdx == 0) { offsets = AMD::computeOffsetsAType( rewriter, loc, computeTensorElemMappingInBlock, elemsPerInstr, - spatialWaveId, lane, warpsPerBlockNonK, numOfElems, numReps, smemObj, + spatialWarpId, lane, warpsPerBlockNonK, numOfElems, numReps, smemObj, sharedLayout, mDim, mfmaInstrK); } else { assert(opIdx == 1); offsets = AMD::computeOffsetsBType( rewriter, loc, computeTensorElemMappingInBlock, elemsPerInstr, - spatialWaveId, lane, warpsPerBlockNonK, numOfElems, numReps, smemObj, + spatialWarpId, lane, warpsPerBlockNonK, numOfElems, numReps, smemObj, sharedLayout, nDim, mfmaInstrK); } smemBase = AMD::computeBasePtr(rewriter, loc, smemObj); @@ -333,10 +339,10 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, for (int b = 0; b < repB; ++b) { int operandSize = shape[rank - 1] * shape[rank - 2]; Value batchOffset = mul(i32_val(operandSize), - add(waveIdInBatch, i32_val(b * warpsPerBatch))); + add(warpIdInBatch, i32_val(b * warpsPerBatch))); for (int nonK = 0; nonK < numRepNonK; ++nonK) { int blockNonKOffset = nonK * mfmaInstrNonK * warpsPerBlockNonK; - Value waveBlockOffAdjust = i32_val(blockNonKOffset * shape[order[0]]); + Value warpBlockOffAdjust = i32_val(blockNonKOffset * shape[order[0]]); for (int k = 0; k < numRepK; ++k) { auto vecTy = vec_ty(resElemTy, numOfElems); for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp index 950e2926a..b60c86e1a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp @@ -33,8 +33,9 @@ using ::mlir::triton::gpu::SharedEncodingAttr; namespace SharedToDotOperandWMMA { /** - * @brief This function maps particular load of wmma dot operand to element - * indexes(row, col) + * @brief Following functions maps particular load of wmma dot operand to + * element indexes(row, col). For each WMMA generation separate function is + * used. * * Whole tensor is broken into "blocks" of warps along "non-K" axis. * One block could be processed by multiple warps. @@ -64,7 +65,8 @@ namespace SharedToDotOperandWMMA { * @return vector (i-th element corresponds to i-th load instruction) of * 2-element vectors(tensor row and col). */ -llvm::SmallVector> computeTensorElemMappingInBlock( +llvm::SmallVector> +computeTensorElemMappingInBlockWmma1( ConversionPatternRewriter &rewriter, Location loc, const ArrayRef &elemsPerInstr, Value warpId, Value laneId, int numOfElems, ArrayRef reps, ArrayRef smemOffsets, @@ -75,28 +77,55 @@ llvm::SmallVector> computeTensorElemMappingInBlock( const int loadsPerThread = numOfElems / loadVecSize; llvm::SmallVector> mapping(numK * loadsPerThread); - Value _0 = i32_val(0); - Value nonKDim = i32_val(iNonKDim); - Value warpVOffset = mul(warpId, i32_val(elemsPerInstr[0])); - + Value elemsPerInstrV = i32_val(elemsPerInstr[0]); + Value warpVOffset = mul(warpId, elemsPerInstrV); + Value sliceVOffset = add(urem(laneId, elemsPerInstrV), warpVOffset); auto rank = smemOffsets.size(); + Value row = add(sliceVOffset, smemOffsets[rank - 2]); for (int tile = 0; tile < numK; ++tile) { - Value tileVOffset = _0; Value tileHOffset = i32_val(tile * elemsPerInstr[1]); - Value laneVOffset = laneId; - Value laneHOffset = _0; - for (int loadId = 0; loadId < loadsPerThread; ++loadId) { - Value elemVOffset = _0; Value elemHOffset = i32_val(loadId * loadVecSize); + Value sliceHOffset = add(tileHOffset, elemHOffset); + + Value col = add(sliceHOffset, smemOffsets[rank - 1]); + mapping[loadsPerThread * tile + loadId] = {row, col}; + } + } + + return mapping; +} + +llvm::SmallVector> +computeTensorElemMappingInBlockWmma2( + ConversionPatternRewriter &rewriter, Location loc, + const ArrayRef &elemsPerInstr, Value warpId, Value laneId, + int numOfElems, ArrayRef reps, ArrayRef smemOffsets, + int loadVecSize, unsigned iNonKDim, [[maybe_unused]] unsigned iKDim) { + assert(reps.size() == 3); + assert(elemsPerInstr.size() == 2); + auto numK = reps[2]; + const int loadsPerThread = numOfElems / loadVecSize; + llvm::SmallVector> mapping(numK * loadsPerThread); - Value sliceVOffset = - add(add(add(tileVOffset, laneVOffset), elemVOffset), warpVOffset); - Value sliceHOffset = add(add(tileHOffset, laneHOffset), elemHOffset); + Value rowsPerInstr = i32_val(elemsPerInstr[0]); + Value colsPerInstr = i32_val(elemsPerInstr[1]); + Value elemsPerThread = i32_val(elemsPerInstr[1] / 2); + Value warpVOffset = mul(warpId, rowsPerInstr); + Value sliceVOffset = add(urem(laneId, rowsPerInstr), warpVOffset); + + auto rank = smemOffsets.size(); + Value row = add(sliceVOffset, smemOffsets[rank - 2]); + Value laneHOffset = mul(udiv(laneId, colsPerInstr), elemsPerThread); + + for (int tile = 0; tile < numK; ++tile) { + Value tileHOffset = add(laneHOffset, i32_val(tile * elemsPerInstr[1])); + for (int loadId = 0; loadId < loadsPerThread; ++loadId) { + Value elemHOffset = i32_val(loadId * loadVecSize); + Value sliceHOffset = add(tileHOffset, elemHOffset); - Value row = add(sliceVOffset, smemOffsets[rank - 2]); Value col = add(sliceHOffset, smemOffsets[rank - 1]); mapping[loadsPerThread * tile + loadId] = {row, col}; @@ -116,7 +145,10 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int nonKDimIdx = opIdx == 0 ? rank - 2 : rank - 1; auto wmmaLayout = cast(encoding.getParent()); - assert(wmmaLayout.getMNKDimPerWMMAInstr()[nonKDimIdx] == 16); + auto computeTensorElemMappingInBlock = + wmmaLayout.getVersion() == 1 ? computeTensorElemMappingInBlockWmma1 + : computeTensorElemMappingInBlockWmma2; + assert(wmmaLayout.getMNKDimPerInstr()[nonKDimIdx] == 16); auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); auto aTensorTy = cast(tensor.getType()); @@ -128,27 +160,25 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, auto elemTy = aTensorTy.getElementType(); int kWidth = encoding.getKWidth(); - auto elemsPerInstr = wmmaLayout.getWMMAElemsPerInstrForOperands(); + auto elemsPerInstr = wmmaLayout.getElemsPerInstrForOperands(); auto wmmaInstrK = elemsPerInstr[opIdx == 0 ? 1 : 0]; auto wmmaInstrNonK = elemsPerInstr[opIdx == 0 ? 0 : 1]; assert(wmmaInstrNonK == 16); - auto numReps = wmmaLayout.getWMMARepForOperands(shape, elemTy, kWidth, opIdx); + auto numReps = wmmaLayout.getRepForOperand(shape, elemTy, kWidth, opIdx); auto numRepNonK = numReps[opIdx == 0 ? 1 : 2]; auto numRepK = numReps[opIdx == 0 ? 2 : 1]; auto repB = numReps[0]; unsigned iWaveSize = triton::gpu::getWarpSize(wmmaLayout); - unsigned iNumLanes = iWaveSize / 2; assert(iWaveSize == 32); Value waveSize = i32_val(iWaveSize); - Value numLanes = i32_val(iNumLanes); Value linearWaveId = udiv(thread, waveSize); - Value lane = urem(thread, numLanes); // share elem between two threads - unsigned numElemsPerThreadPerRep = wmmaInstrK; + unsigned numElemsPerThreadPerRep = + wmmaLayout.getSizePerThreadForOperand(kWidth, opIdx)[kDimIdx]; - Value warp = udiv(thread, waveSize); + Value lane = urem(thread, waveSize); unsigned int maxNumWarps = shape[nonKDimIdx] / wmmaInstrNonK; int warpsPerBlockNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps); int warpsPerBatch = diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp index 68ebe9499..cece47227 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -1,3 +1,4 @@ +#include "OptimizeLDSUtility.h" #include "TargetInfo.h" #include "TritonAMDGPUToLLVM/Passes.h" #include "mlir/Pass/Pass.h" @@ -19,113 +20,6 @@ namespace triton { namespace { -constexpr int kPtrBitWidth = 64; - -static void addAttrs(Operation *op, ArrayRef attrs) { - for (const NamedAttribute attr : attrs) - op->setAttr(attr.getName(), attr.getValue()); -} - -static void promoteReduceOpResult(OpBuilder &builder, triton::ReduceOp op, - Value result, Type promotedType) { - // save original type - auto originalType = result.getType(); - auto elemType = isa(originalType) - ? cast(originalType).getElementType() - : originalType; - - // promote result type - result.setType(promotedType); - - // set insertion point after reduce op - builder.setInsertionPointAfter(op); - - // truncate result back to original type - mlir::Operation *truncResult = nullptr; - if (elemType.isInteger(16) || elemType.isInteger(8)) { - truncResult = builder.create(result.getLoc(), - originalType, result); - } else if (elemType.isF16()) { - truncResult = builder.create(result.getLoc(), - originalType, result); - } - - // replace all uses except for the truncOp above - if (truncResult != nullptr) { - result.replaceAllUsesWith(truncResult->getResult(0)); - truncResult->setOperand(0, result); - } -} - -static int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp &cvtOp) { - unsigned inVec = 0; - unsigned outVec = 0; - auto smemShape = triton::getScratchConfigForCvtLayout(cvtOp, inVec, outVec); - unsigned elems = - std::accumulate(smemShape.begin(), smemShape.end(), 1, std::multiplies{}); - auto srcType = cvtOp.getSrc().getType(); - auto bytes = - isa(srcType.getElementType()) - ? elems * kPtrBitWidth / 8 - : elems * std::max(8, srcType.getElementTypeBitWidth()) / 8; - - return bytes; -} - -bool isPowerOfTwo(unsigned x) { return x && (x & (x - 1)) == 0; } - -static std::vector> factorizePowerOf2(int n) { - assert(isPowerOfTwo(n)); - int x = log2(n); - std::vector> pairs; - - for (int i = 0; i <= x / 2; ++i) { - int j = x - i; - pairs.push_back({pow(2, i), pow(2, j)}); - pairs.push_back({pow(2, j), pow(2, i)}); - } - - return pairs; -} - -static std::pair -createNewConvertOps(ModuleOp &mod, OpBuilder &builder, - triton::gpu::ConvertLayoutOp &cvtOp, - std::pair warpsPerCta) { - unsigned warpsPerCtaX = warpsPerCta.first; - unsigned warpsPerCtaY = warpsPerCta.second; - auto srcType = cvtOp.getSrc().getType(); - auto dstType = cvtOp.getType(); - - auto newDstType = RankedTensorType::get( - dstType.getShape(), dstType.getElementType(), dstType.getEncoding()); - RankedTensorType newSrcType; - if (auto srcMfma = - dyn_cast(srcType.getEncoding())) { - auto newMfmaEnc = triton::gpu::AMDMfmaEncodingAttr::get( - mod.getContext(), srcMfma.getVersionMajor(), srcMfma.getVersionMinor(), - {warpsPerCtaX, warpsPerCtaY}, srcMfma.getMDim(), srcMfma.getNDim(), - srcMfma.getIsTransposed(), srcMfma.getCTALayout()); - - newSrcType = RankedTensorType::get(srcType.getShape(), - srcType.getElementType(), newMfmaEnc); - } else if (auto srcWmma = dyn_cast( - srcType.getEncoding())) { - auto newWmmaEnc = triton::gpu::AMDWmmaEncodingAttr::get( - mod.getContext(), {warpsPerCtaX, warpsPerCtaY}, srcWmma.getCTALayout()); - - newSrcType = RankedTensorType::get(srcType.getShape(), - srcType.getElementType(), newWmmaEnc); - } - - auto tmpCvt = builder.create( - cvtOp.getLoc(), newSrcType, cvtOp.getSrc()); - auto newEpilogueCvt = builder.create( - cvtOp.getLoc(), newDstType, tmpCvt); - - return std::make_pair(tmpCvt, newEpilogueCvt); -} - struct DecomposeUnsupportedAMDConversions : public mlir::triton::impl::DecomposeUnsupportedAMDConversionsBase< DecomposeUnsupportedAMDConversions> { @@ -144,8 +38,8 @@ struct DecomposeUnsupportedAMDConversions triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); - triton::gpu::decomposeTensorCoreToDotLayoutConversion< - triton::gpu::AMDMfmaEncodingAttr>(mod, isMfmaToDotShortcut); + triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, + isMfmaToDotShortcut); /* -------------------------------- */ // Replace `wmma -> dot_op` with `wmma -> blocked -> dot_op` @@ -205,143 +99,52 @@ struct DecomposeUnsupportedAMDConversions return; } - auto currLDSUsage = getCvtOpLDSUsage(cvtOp); + auto currLDSUsage = triton::AMD::getCvtOpLDSUsage(cvtOp); if (currLDSUsage <= sharedMemoryLimit) { return; } unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc); - triton::gpu::ConvertLayoutOp tmpCvt; - triton::gpu::ConvertLayoutOp newEpilogueCvt; - // Find all possible shapes of WarpsPerCTA by finding all possible // factorizations of numWarps. Pick shape for which both conversions in - // decomposition use LDS less than limit and for which sum of LDS usage - // is minimal. If no such shape exists, do not decompose. + // decomposition use LDS less than sharedMemoryLimit and for which sum of + // LDS usage is minimal. If no such shape exists, do not decompose. unsigned minLDSUsage = 2 * sharedMemoryLimit; int minIdx = -1; - auto factorizedNumWarps = factorizePowerOf2(numWarps); + int rank = dstBlocked.getWarpsPerCTA().size(); + auto factorizedNumWarps = + mlir::triton::AMD::factorizePowerOf2(numWarps, rank); + SmallVector tmpLayouts; for (int i = 0; i < factorizedNumWarps.size(); i++) { - auto warpsPerCTAPair = factorizedNumWarps[i]; - std::tie(tmpCvt, newEpilogueCvt) = - createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair); + auto warpsPerCTA = factorizedNumWarps[i]; + tmpLayouts.push_back( + mlir::triton::AMD::createTmpLayout(srcEnc, warpsPerCTA)); + } - int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt); - int newCvtLDS = getCvtOpLDSUsage(newEpilogueCvt); - if (tmpCvtLDS <= sharedMemoryLimit && newCvtLDS <= sharedMemoryLimit) { - int LDSUsage = tmpCvtLDS + newCvtLDS; - if (LDSUsage < minLDSUsage) { - minLDSUsage = LDSUsage; - minIdx = i; - } + for (int i = 0; i < tmpLayouts.size(); i++) { + auto resources = mlir::triton::AMD::estimateResourcesForReplacement( + builder, cvtOp, tmpLayouts[i]); + if (resources.LDS <= sharedMemoryLimit && resources.LDS < minLDSUsage) { + minLDSUsage = resources.LDS; + minIdx = i; } - newEpilogueCvt.erase(); - tmpCvt.erase(); } - if (minIdx == -1) { + if (minIdx == -1 || minLDSUsage > sharedMemoryLimit) { return; } - assert(minIdx >= 0 && minIdx < factorizedNumWarps.size()); - auto warpsPerCTAPair = factorizedNumWarps[minIdx]; - std::tie(tmpCvt, newEpilogueCvt) = - createNewConvertOps(mod, builder, cvtOp, warpsPerCTAPair); + assert(minIdx >= 0 && minIdx < tmpLayouts.size()); + auto replacementCvts = mlir::triton::AMD::createNewConvertOps( + builder, cvtOp, tmpLayouts[minIdx]); - cvtOp.replaceAllUsesWith(newEpilogueCvt.getResult()); + cvtOp.replaceAllUsesWith(replacementCvts.second.getResult()); cvtOp.erase(); }); triton::gpu::decomposeBlockedToDotLayoutConversion(mod); - - // promote reduce ops - mod.walk([&](triton::ReduceOp op) -> void { - OpBuilder builder(op); - - // promote operands - SmallVector newOperands; - for (OpOperand &operand : op->getOpOperands()) { - auto val = operand.get(); - auto oldType = cast(val.getType()); - auto elemType = oldType.getElementType(); - if (elemType.isInteger(16) || elemType.isInteger(8)) { - auto newType = - oldType.cloneWith(std::nullopt, builder.getIntegerType(32)); - auto promotedVal = - builder.create(op->getLoc(), newType, val); - newOperands.push_back(promotedVal); - } else if (elemType.isF16()) { - auto newType = oldType.cloneWith(std::nullopt, builder.getF32Type()); - auto promotedVal = - builder.create(op->getLoc(), newType, val); - newOperands.push_back(promotedVal); - } else { - newOperands.push_back(val); - } - } - op->setOperands(newOperands); - - // promote results - for (Value result : op.getResults()) { - auto type = result.getType(); - if (type.isInteger(16) || type.isInteger(8)) { - promoteReduceOpResult(builder, op, result, - builder.getIntegerType(32)); - } else if (type.isF16()) { - promoteReduceOpResult(builder, op, result, builder.getF32Type()); - } else if (isa(type)) { - auto oldType = cast(type); - auto elemType = oldType.getElementType(); - if (elemType.isInteger(16) || elemType.isInteger(8)) { - promoteReduceOpResult( - builder, op, result, - oldType.cloneWith(std::nullopt, builder.getIntegerType(32))); - } else if (elemType.isF16()) { - promoteReduceOpResult( - builder, op, result, - oldType.cloneWith(std::nullopt, builder.getF32Type())); - } - } - } - - // promote combine op - for (Block &oldBlock : op.getCombineOp().getBlocks()) { - // update block args - for (auto arg : oldBlock.getArguments()) { - auto type = arg.getType(); - if (type.isInteger(16) || type.isInteger(8)) { - arg.setType(builder.getIntegerType(32)); - } else if (type.isF16()) { - arg.setType(builder.getF32Type()); - } - } - - for (Operation &oldOp : oldBlock.getOperations()) { - // update operands - for (OpOperand &operand : oldOp.getOpOperands()) { - auto val = operand.get(); - auto type = val.getType(); - if (type.isInteger(16) || type.isInteger(8)) { - val.setType(builder.getIntegerType(32)); - } else if (type.isF16()) { - val.setType(builder.getF32Type()); - } - } - - // update results - for (Value result : oldOp.getResults()) { - auto type = result.getType(); - if (type.isInteger(16) || type.isInteger(8)) { - result.setType(builder.getIntegerType(32)); - } else if (type.isF16()) { - result.setType(builder.getF32Type()); - } - } - } - } - }); } }; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp index f8288024d..204d54894 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -71,7 +71,7 @@ struct DotOpMFMAConversionHelper { } int getNumSubmatrices(Type elementType, int mDim, int nDim) const { - if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64) + if ((mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)) return 1; assert(mDim == nDim); switch (mDim) { @@ -98,10 +98,10 @@ struct DotOpMFMAConversionHelper { "numSubBlocks in not pow 2!"); if (numSubBlocks == 1) return acc; - constexpr int waveSize = 64; - int subBlockSize = waveSize / numSubBlocks; + constexpr int warpSize = 64; + int subBlockSize = warpSize / numSubBlocks; Value laneId = getThreadId(); - laneId = and_(laneId, i32_val(waveSize - 1)); + laneId = and_(laneId, i32_val(warpSize - 1)); auto vecTy = dyn_cast(acc.getType()); auto elemType = vecTy.getElementType(); assert(elemType.getIntOrFloatBitWidth() == 32); @@ -111,7 +111,7 @@ struct DotOpMFMAConversionHelper { accScalar[i] = extract_element(elemType, acc, i32_val(i)); if (reduceSubBlocks) { - while (subBlockSize < waveSize) { + while (subBlockSize < warpSize) { for (int i = 0; i < numScalars; ++i) { Value other_acc = shuffleXor(loc, rewriter, accScalar[i], subBlockSize); @@ -151,9 +151,9 @@ struct DotOpMFMAConversionHelper { /// @brief Zeroes out redundant values in all sub-blocks except first one /// - /// Every wave in mfma 4x4 layout holds only 4 unique values(scalar or + /// Every warp in mfma 4x4 layout holds only 4 unique values(scalar or /// vectors) in blocks of 4 consecutive threads, There are 16 copies of these - /// 4 values across all threads of the wave. Need to zero out 15 copies to use + /// 4 values across all threads of the warp. Need to zero out 15 copies to use /// accumulator between dot operations. /// @param numSubBlocks /// @param acc @@ -194,10 +194,8 @@ struct DotOpMFMAConversionHelper { int kWidth = aEncoding.getKWidth(); auto rank = aTensorTy.getShape().size(); - auto repA = - mfmaLayout.getMFMARepForOperands(aTensorTy.getShape(), kWidth, 0); - auto repB = - mfmaLayout.getMFMARepForOperands(bTensorTy.getShape(), kWidth, 1); + auto repA = mfmaLayout.getRepForOperand(aTensorTy.getShape(), kWidth, 0); + auto repB = mfmaLayout.getRepForOperand(bTensorTy.getShape(), kWidth, 1); assert(repA[2] == repB[1]); @@ -278,11 +276,18 @@ struct DotOpMFMAConversionHelper { int kpack = kWidth / kBase; SmallVector results; auto vecTy = vec_ty(type, kBase); + if (type.isBF16()) + vecTy = vec_ty(i16_ty, kBase); for (int k = 0; k < kpack; ++k) { Value vec = undef(vecTy); for (int elemId = 0; elemId < kBase; ++elemId) { auto val = extract_element(type, rawElems, i32_val(elemId + k * kBase)); - vec = insert_element(vecTy, vec, val, i32_val(elemId)); + if (type.isBF16()) { + // rocdl.mfma.f32.32x32x8bf16.1k calls for input of i16 type + auto cast = bitcast(val, i16_ty); + vec = insert_element(vecTy, vec, cast, i32_val(elemId)); + } else + vec = insert_element(vecTy, vec, val, i32_val(elemId)); } if (type.getIntOrFloatBitWidth() == 8) { if (4 == kBase) @@ -329,7 +334,7 @@ struct DotOpMFMAConversionHelper { if (type.getIntOrFloatBitWidth() == 8) { vals = extractOperands(rawElems, kWidth, kBase, i8_ty); } else if (type.isBF16()) { - vals = extractOperands(rawElems, kWidth, kBase, i16_ty); + vals = extractOperands(rawElems, kWidth, kBase, bf16_ty); } else { assert(type.isF16() && "Unsupported data type"); vals = extractOperands(rawElems, kWidth, kBase, f16_ty); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp index 3843159fc..5a003f768 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp @@ -24,6 +24,7 @@ #include "../PatternTritonGPUOpToLLVM.h" #include "Utility.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" namespace mlir::triton::AMD { namespace { @@ -40,8 +41,8 @@ enum class WMMAInstrType : uint8_t { FP32_BF16, FP16_FP16, BF16_BF16, - INT32_IU8, - INT32_IU4, + I32_I8, + I32_I4, NOT_APPLICABLE, }; @@ -109,16 +110,16 @@ static WMMAInstrType getWMMAInstrTypeFromDot(DotOp op) { if (dElemTy.isBF16() && aElemTy.isBF16()) return WMMAInstrType::BF16_BF16; if (dElemTy.isInteger(32) && aElemTy.isInteger(8)) - return WMMAInstrType::INT32_IU8; + return WMMAInstrType::I32_I8; if (dElemTy.isInteger(32) && aElemTy.isInteger(4)) - return WMMAInstrType::INT32_IU4; + return WMMAInstrType::I32_I4; return WMMAInstrType::NOT_APPLICABLE; } -Value generateWMMAOp(ConversionPatternRewriter &rewriter, Location loc, - WMMAInstrType wmmaType, Value valA, Value valB, Value valC, - Type aElType, Type bElType) { +Value generateROCDLOp(ConversionPatternRewriter &rewriter, Location loc, + WMMAInstrType wmmaType, Value valA, Value valB, + Value valC, Type aElType, Type bElType) { auto resType = valC.getType(); Value falseFlag = int_val(1, false); switch (wmmaType) { @@ -134,13 +135,13 @@ Value generateWMMAOp(ConversionPatternRewriter &rewriter, Location loc, case WMMAInstrType::BF16_BF16: return rewriter.create( loc, TypeRange{resType}, ValueRange{valA, valB, valC, falseFlag}); - case WMMAInstrType::INT32_IU8: + case WMMAInstrType::I32_I8: return rewriter.create( loc, TypeRange{resType}, ValueRange{int_val(1, !aElType.isUnsignedInteger()), valA, int_val(1, !bElType.isUnsignedInteger()), valB, valC, falseFlag}); - case WMMAInstrType::INT32_IU4: + case WMMAInstrType::I32_I4: return rewriter.create( loc, TypeRange{resType}, ValueRange{int_val(1, !aElType.isUnsignedInteger()), valA, @@ -152,14 +153,100 @@ Value generateWMMAOp(ConversionPatternRewriter &rewriter, Location loc, return Value(); } +std::string getTypeStr(Type ty) { + std::string scalarName; + if (ty.isF32()) { + scalarName = "f32"; + } else if (ty.isF16()) { + scalarName = "f16"; + } else if (ty.isBF16()) { + scalarName = "bf16"; + } else if (ty.isInteger(32)) { + scalarName = "i32"; + } else if (ty.isInteger(16)) { + scalarName = "i16"; + } else if (ty.isInteger(8)) { + scalarName = "iu8"; + } else if (ty.isInteger(4)) { + scalarName = "iu4"; + } else if (auto vecTy = dyn_cast(ty)) { + auto elemType = vecTy.getElementType(); + auto numElems = vecTy.getNumElements(); + scalarName = "v" + std::to_string(numElems) + getTypeStr(elemType); + } else { + llvm::report_fatal_error("WMMA data type not supported"); + } + return scalarName; +} + +StringRef getWmmaIntrinsicName(Type aElTy, Type bElTy, Type dElTy, Type valATy, + Type valCTy) { + static llvm::SmallDenseMap intrinsics; + using MapInfo = llvm::DenseMapInfo; + llvm::hash_code h = llvm::hash_combine( + MapInfo::getHashValue(aElTy), MapInfo::getHashValue(bElTy), + MapInfo::getHashValue(dElTy), MapInfo::getHashValue(valATy), + MapInfo::getHashValue(valCTy)); + if (!intrinsics.contains(h)) { + std::string name = "llvm.amdgcn.wmma."; + name += getTypeStr(dElTy); + name += ".16x16x16."; // TODO support 16x16x32 for i4 operands + name += getTypeStr(aElTy); + if (isa(aElTy) && aElTy.getIntOrFloatBitWidth() == 8) + name += '.' + getTypeStr(bElTy); + name += '.' + getTypeStr(valCTy) + "." + getTypeStr(valATy); + intrinsics[h] = name; + } + return intrinsics[h]; +} + +Value generateWMMAIntrinsic(ConversionPatternRewriter &rewriter, Location loc, + WMMAInstrType wmmaType, Value valA, Value valB, + Value valC, Type aElType, Type bElType, + Type dElType) { + auto name = getWmmaIntrinsicName(aElType, bElType, dElType, valA.getType(), + valC.getType()); + LLVM::FastmathFlagsAttr defaultFlags{}; + SmallVector operands; + if (aElType.isInteger()) + operands.push_back(int_val(1, !aElType.isUnsignedInteger())); + operands.push_back(valA); + if (bElType.isInteger()) + operands.push_back(int_val(1, !bElType.isUnsignedInteger())); + operands.push_back(valB); + operands.push_back(valC); + // Flag for using low bits in registers. Result could be already packed to + // int32. Set low bits by default for now. + if (32 / dElType.getIntOrFloatBitWidth() > 1 || dElType.isInteger(32)) { + operands.push_back(int_val(1, false)); + } + auto wmmaIntrinsic = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, name, valC.getType(), operands); + return wmmaIntrinsic.getResult(0); +} + +Value generateWMMAOp(ConversionPatternRewriter &rewriter, Location loc, + WMMAInstrType wmmaType, Value valA, Value valB, Value valC, + Type aElType, Type bElType, Type dElType, int version) { + if (version == 1) { + return generateROCDLOp(rewriter, loc, wmmaType, valA, valB, valC, aElType, + bElType); + } else { + assert(version == 2); + return generateWMMAIntrinsic(rewriter, loc, wmmaType, valA, valB, valC, + aElType, bElType, dElType); + } +} + // Conduct the Dot conversion. LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter) { auto wmmaLayout = cast( cast(op.getResult().getType()).getEncoding()); + int wmmaVer = wmmaLayout.getVersion(); auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); - auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); + auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerInstr(); auto wmmaInstrType = getWMMAInstrTypeFromDot(op); auto loc = op.getLoc(); @@ -176,9 +263,9 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, int kWidth = aEncoding.getKWidth(); auto repA = - wmmaLayout.getWMMARepForOperands(aTensorTy.getShape(), elemTy, kWidth, 0); + wmmaLayout.getRepForOperand(aTensorTy.getShape(), elemTy, kWidth, 0); auto repB = - wmmaLayout.getWMMARepForOperands(bTensorTy.getShape(), elemTy, kWidth, 1); + wmmaLayout.getRepForOperand(bTensorTy.getShape(), elemTy, kWidth, 1); assert(repA[2] == repB[1]); @@ -202,7 +289,7 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, unsigned warpSize = triton::gpu::getWarpSize(wmmaLayout); constexpr unsigned vgprElemBitWidth = 32; unsigned paddedOutputElemSize = - vgprElemBitWidth / dstElemTy.getIntOrFloatBitWidth(); + wmmaVer == 1 ? vgprElemBitWidth / dstElemTy.getIntOrFloatBitWidth() : 1; // compute number of output elements that each thread holds for one WMMA // instruction. auto elemsPerVec = mnkDim[0] * mnkDim[1] * paddedOutputElemSize / warpSize; @@ -224,7 +311,7 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, for (size_t k = 0; k < numRepK; k++) { acc = generateWMMAOp(rewriter, loc, wmmaInstrType, ha[{b, m, k}], hb[{b, n, k}], acc, aTensorTy.getElementType(), - bTensorTy.getElementType()); + bTensorTy.getElementType(), dstElemTy, wmmaVer); } for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { fc[fcThreadOffIdx + v] = extract_element( diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index dd082d25d..47d5fbb35 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -314,7 +314,7 @@ static Value convertFp32ToBf16(Location loc, auto as_int32 = bitcast(v, i32_ty); auto shifted = lshr(i32_ty, as_int32, i32_val(16)); auto truncated = trunc(i16_ty, shifted); - return bitcast(truncated, i16_ty); + return bitcast(truncated, bf16_ty); } // Otherwise it is (rounding == RoundingMode::RTNE) auto as_uint32 = bitcast(v, i32_ty); @@ -337,7 +337,7 @@ static Value convertFp32ToBf16(Location loc, auto shifted = lshr(i32_ty, res, i32_val(16)); auto truncated = trunc(i16_ty, shifted); - return truncated; + return bitcast(truncated, bf16_ty); } static Value Fp8E5M2FNUZ_to_Fp16_oneValue(Location loc, @@ -447,20 +447,20 @@ static SmallVector Fp8E5M2_to_Bf16(Location loc, out0 = or_(i32_ty, out0, sign0); out1 = or_(i32_ty, out1, sign1); - auto bf16x2VecTy = vec_ty(i16_ty, 2); + auto bf16x2VecTy = vec_ty(bf16_ty, 2); out0 = bitcast(out0, bf16x2VecTy); out1 = bitcast(out1, bf16x2VecTy); - return {extract_element(i16_ty, out0, i32_val(0)), - extract_element(i16_ty, out0, i32_val(1)), - extract_element(i16_ty, out1, i32_val(0)), - extract_element(i16_ty, out1, i32_val(1))}; + return {extract_element(bf16_ty, out0, i32_val(0)), + extract_element(bf16_ty, out0, i32_val(1)), + extract_element(bf16_ty, out1, i32_val(0)), + extract_element(bf16_ty, out1, i32_val(1))}; } static SmallVector Bf16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { - auto bf16x2VecTy = vec_ty(i16_ty, 2); + auto bf16x2VecTy = vec_ty(bf16_ty, 2); Value bf16x2Vec0 = undef(bf16x2VecTy); Value bf16x2Vec1 = undef(bf16x2VecTy); bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[0], i32_val(0)); @@ -716,22 +716,22 @@ static SmallVector Fp8E4M3_to_Bf16(Location loc, Value sign0 = and_(i32_ty, a0, i32_val(0x80008000)); Value sign1 = and_(i32_ty, a1, i32_val(0x80008000)); - auto bf16x2VecTy = vec_ty(i16_ty, 2); + auto bf16x2VecTy = vec_ty(bf16_ty, 2); Value bf16x2Vec0 = or_(i32_ty, sign0, b0); Value bf16x2Vec1 = or_(i32_ty, sign1, b1); bf16x2Vec0 = bitcast(bf16x2Vec0, bf16x2VecTy); bf16x2Vec1 = bitcast(bf16x2Vec1, bf16x2VecTy); - return {extract_element(i16_ty, bf16x2Vec0, i32_val(0)), - extract_element(i16_ty, bf16x2Vec0, i32_val(1)), - extract_element(i16_ty, bf16x2Vec1, i32_val(0)), - extract_element(i16_ty, bf16x2Vec1, i32_val(1))}; + return {extract_element(bf16_ty, bf16x2Vec0, i32_val(0)), + extract_element(bf16_ty, bf16x2Vec0, i32_val(1)), + extract_element(bf16_ty, bf16x2Vec1, i32_val(0)), + extract_element(bf16_ty, bf16x2Vec1, i32_val(1))}; } static SmallVector Bf16_to_Fp8E4M3(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { - auto bf16x2VecTy = vec_ty(i16_ty, 2); + auto bf16x2VecTy = vec_ty(bf16_ty, 2); Value bf16x2Vec0 = undef(bf16x2VecTy); Value bf16x2Vec1 = undef(bf16x2VecTy); bf16x2Vec0 = insert_element(bf16x2VecTy, bf16x2Vec0, v[0], i32_val(0)); @@ -902,7 +902,7 @@ struct FpToFpOpConversion if (srcMap.count(key) == 0) { return mlir::failure(); } - return mlir::FailureOr(srcMap.lookup(key)); + return srcMap.lookup(key); } SmallVector createDestOps(triton::FpToFpOp op, OpAdaptor adaptor, @@ -956,6 +956,19 @@ struct FpToFpOpConversion for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) { inVals.push_back(operands[i][0]); } + bool isSrcFP16 = srcElementType.isF16(); + bool isSrcBF16 = srcElementType.isBF16(); + + if ((isSrcFP16 || isSrcBF16) && isDstFP32) { + SmallVector outVals; + for (Value &v : inVals) { + if (isSrcFP16) + outVals.push_back(convertFp16ToFp32(loc, rewriter, v)); + else + outVals.push_back(convertBf16ToFp32(loc, rewriter, v)); + } + return outVals; + } if (useFP16IntermediateSrc) for (Value &v : inVals) v = cvtFp32ToFp16(loc, rewriter, v, @@ -1104,7 +1117,7 @@ static SmallVector S8_to_Bf16(Location loc, f32Val = bitcast(f32Val, i32_ty); auto shifted = lshr(i32_ty, f32Val, i32_val(16)); auto truncated = trunc(i16_ty, shifted); - outValues.push_back(truncated); + outValues.push_back(bitcast(truncated, bf16_ty)); } return outValues; } @@ -1230,7 +1243,7 @@ struct ExpOpConversionApprox LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp(rewriter, op, funcName, funcType); - return {rewriter.create(loc, funcOp, prod).getResult()}; + return {LLVM::createLLVMCallOp(rewriter, loc, funcOp, prod).getResult()}; } }; @@ -1263,7 +1276,7 @@ struct Exp2OpConversion appendOrGetExternFuncOp(rewriter, op, funcName, funcType); return { - rewriter.create(loc, funcOp, operands[0]).getResult()}; + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()}; } private: diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index d47818a2e..a45efd4a7 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1,6 +1,17 @@ +#include "BufferOpsEmitter.h" +#include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "PatternTritonGPUOpToLLVM.h" #include "TargetInfo.h" #include "Utility.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" using namespace mlir; @@ -86,12 +97,51 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, } return mask; } + // Contains some helper functions for both Load and Store conversions. struct LoadStoreConversionBase { explicit LoadStoreConversionBase(const AMD::TargetInfo &targetInfo, ModuleAxisInfoAnalysis &axisAnalysisPass) : targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass) {} + // Createa a LLVM vector of type `vecTy` containing all zeros + Value createZeroVector(OpBuilder &builder, Location loc, + VectorType vecTy) const { + mlir::Attribute zeroAttr = builder.getZeroAttr(vecTy.getElementType()); + auto denseValue = + DenseElementsAttr::get(cast(vecTy), zeroAttr); + Value zeroVal = builder.create(loc, vecTy, denseValue); + return zeroVal; + } + + // Given a vector of values `elems` and a starting point `start`, create a + // LLVM vector of length `vec` whose elements are `elems[start, ..., + // elems+vec-1]` + Value packElementRangeIntoVector(ConversionPatternRewriter &rewriter, + const LLVMTypeConverter *typeConverter, + Location loc, VectorType vecTy, + ArrayRef elems, int64_t start) const { + int64_t vec = vecTy.getNumElements(); + // If we need to mask the loaded value with other elements + Value v = undef(vecTy); + for (size_t s = 0; s < vec; ++s) { + Value otherElem = elems[start + s]; + Value indexVal = + LLVM::createIndexConstant(rewriter, loc, typeConverter, s); + v = insert_element(vecTy, v, otherElem, indexVal); + } + return v; + } + + // Return a tensor of pointers with the same type of `basePtr` and the same + // shape of `offset` + Type getPointerTypeWithShape(Value basePtr, Value offset) const { + Type basePtrType = basePtr.getType(); + auto offsetType = cast(offset.getType()); + return offsetType.cloneWith(std::nullopt, basePtrType); + } + + // Get contiguity for a tensor pointer `ptr` unsigned getContiguity(Value ptr) const { auto tensorTy = dyn_cast(ptr.getType()); if (!tensorTy) @@ -99,23 +149,74 @@ struct LoadStoreConversionBase { return axisAnalysisPass.getPtrContiguity(ptr); } + // Get contiguity for a scalar pointer `ptr` and a tensor `offset` + unsigned getContiguity(Value ptr, Value offset) const { + // Get contiguity from the offset + Type type = getPointerTypeWithShape(ptr, offset); + RankedTensorType tensorTy = cast(type); + auto layout = tensorTy.getEncoding(); + auto order = triton::gpu::getOrder(layout); + auto uniqueContigPerThread = + triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape()); + assert(order[0] < uniqueContigPerThread.size() && + "Unexpected uniqueContigPerThread size"); + unsigned contiguity = uniqueContigPerThread[order[0]]; + + // Get alignment from the pointer. Since this is a scalar pointer + // we should not take the pointer contiguity to consider alignment + auto *axisInfo = axisAnalysisPass.getAxisInfo(ptr); + auto maxMultipleBytes = axisInfo->getDivisibility(0); + auto elemNumBits = triton::getPointeeBitWidth(tensorTy); + auto elemNumBytes = std::max(elemNumBits / 8, 1); + auto align = std::max(maxMultipleBytes / elemNumBytes, 1); + + // Final contiguity is a min of the offset contiguity and pointer alignment + contiguity = std::min(align, contiguity); + return contiguity; + } + + // Determine the vector size of a tensor of pointers unsigned getVectorSize(Value ptr) const { auto tensorTy = dyn_cast(ptr.getType()); if (!tensorTy) return 1; auto contiguity = getContiguity(ptr); auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy); - // The maximum vector size is 128 bits on NVIDIA GPUs. return std::min(128 / pointeeBitWidth, contiguity); } + // Given a scalar pointer and a tensor of offsets, determine the vector size + unsigned getVectorSize(Value ptr, Value offset) const { + auto contiguity = getContiguity(ptr, offset); + auto pointeeBitWidth = triton::getPointeeBitWidth(ptr.getType()); + return std::min(128 / pointeeBitWidth, contiguity); + } + + // Unpack the elements contained in a `llvmStruct` into a `SmallVector` of + // `Value`s. While you do that, check also the alignment of the mask and + // update the vector length `vec` accordingly + SmallVector + getMaskElemsAndUpdateVeclen(ConversionPatternRewriter &rewriter, Location loc, + Value llMask, Value mask, unsigned &vec) const { + SmallVector maskElems; + if (llMask) { + vec = std::min(vec, getMaskAlignment(mask)); + maskElems = unpackLLElements(loc, llMask, rewriter); + } + return maskElems; + } + unsigned getMaskAlignment(Value mask) const { return axisAnalysisPass.getMaskAlignment(mask); } + unsigned getPtrAlignment(Value ptr) const { + return axisAnalysisPass.getPtrAlignment(ptr); + } + protected: - ModuleAxisInfoAnalysis &axisAnalysisPass; const AMD::TargetInfo &targetInfo; + ModuleAxisInfoAnalysis &axisAnalysisPass; }; struct LoadOpConversion : public ConvertOpToLLVMPattern, @@ -153,47 +254,29 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, typeConverter->convertType(getElementTypeOrSelf(valueTy)); unsigned vec = getVectorSize(ptr); unsigned numElems = getTotalElemsPerThread(ptr.getType()); - if (llMask) - vec = std::min(vec, getMaskAlignment(mask)); // Get the LLVM values for pointers auto ptrElems = unpackLLElements(loc, llPtr, rewriter); assert(ptrElems.size() == numElems); // Get the LLVM values for mask - SmallVector maskElems; - if (llMask) { - maskElems = unpackLLElements(loc, llMask, rewriter); - assert(maskElems.size() == numElems); - } + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); - // Get the LLVM values for `other` - // TODO: (goostavz) handle when other is const but not splat, which - // should be rarely seen - bool otherIsSplatConstInt = false; - DenseElementsAttr constAttr; - int64_t splatVal = 0; - if (other && isa(valueElemTy) && - matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() && - isa(constAttr.getElementType())) { - otherIsSplatConstInt = true; - splatVal = constAttr.getSplatValue().getSExtValue(); - } SmallVector otherElems; - if (other) { + if (other) otherElems = unpackLLElements(loc, llOther, rewriter); - } // vectorized iteration through all the pointer/mask/other elements const int valueElemNBits = std::max(8u, valueElemTy.getIntOrFloatBitWidth()); + const size_t valueElemNBytes = valueElemNBits / 8; const int numVecs = numElems / vec; + int64_t ptrAlignmentBytes = getPtrAlignment(ptr) * valueElemNBytes; + auto cacheMod = op.getCache(); SmallVector loadedVals; for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { - // TODO: optimization when ptr is GEP with constant offset - size_t in_off = 0; - const size_t maxWordWidth = std::max(32, valueElemNBits); const size_t totalWidth = valueElemNBits * vec; const size_t width = std::min(totalWidth, maxWordWidth); @@ -206,28 +289,100 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]); - mlir::Attribute zeroAttr = rewriter.getZeroAttr(valueElemTy); - auto denseValue = - DenseElementsAttr::get(cast(vecTy), zeroAttr); - Value zeroVal = rewriter.create(loc, vecTy, denseValue); - - Value falseVal = zeroVal; + Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); // If we need to mask the loaded value with other elements - if (otherElems.size() != 0) { - Value v = undef(vecTy); - for (size_t s = 0; s < vec; ++s) { - Value otherElem = otherElems[vecStart + s]; - Value indexVal = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), s); - v = insert_element(vecTy, v, otherElem, indexVal); - } - falseVal = v; + if (otherElems.size() != 0) + falseVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + otherElems, vecStart); + + Value loadVal = llLoad(rewriter, loc, ptr, vecTy, pred, falseVal, + ptrAlignmentBytes, cacheMod); + for (size_t ii = 0; ii < vec; ++ii) { + Value vecIdx = createIndexAttrConstant( + rewriter, loc, this->getTypeConverter()->getIndexType(), ii); + Value loaded = extract_element(valueElemTy, loadVal, vecIdx); + loadedVals.push_back(loaded); } + } // end vec + + Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct BufferLoadOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertOpToLLVMPattern< + triton::amdgpu::BufferLoadOp>::ConvertOpToLLVMPattern; + + BufferLoadOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::amdgpu::BufferLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + LLVM::AMD::BufferEmitter bufferEmitter(rewriter, loc, targetInfo); + + // original values + Value ptr = op.getPtr(); + Value offset = op.getOffsets(); + Value mask = op.getMask(); + Value other = op.getOther(); + + // Converted values + Value llPtr = adaptor.getPtr(); + Value llOffset = adaptor.getOffsets(); + Value llMask = adaptor.getMask(); + Value llOther = adaptor.getOther(); - auto loadVal = llLoad(rewriter, loc, ptr, vecTy, pred, falseVal); + // Determine the vectorization size + Type valueTy = op.getType(); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + Type ptrType = getPointerTypeWithShape(ptr, offset); + unsigned numElems = getTotalElemsPerThread(ptrType); + unsigned vec = getVectorSize(ptr, offset); + + // Get the offset + SmallVector offsetElems = unpackLLElements(loc, llOffset, rewriter); + assert(offsetElems.size() == numElems); + + // Get the mask + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); + + // Get the `other` value (if any) + SmallVector otherElems; + if (llOther) + otherElems = unpackLLElements(loc, llOther, rewriter); + + // Create the resource descriptor and then emit the buffer_load intrinsic(s) + Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr); + SmallVector loadedVals; + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + Value pred = mask ? maskElems[vecStart] : int_val(1, 1); + Value falseVal = createZeroVector(rewriter, loc, cast(vecTy)); + if (otherElems.size() != 0) + falseVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + otherElems, vecStart); + Value loadVal = bufferEmitter.emitLoad( + vecTy, rsrcDesc, offsetElems[vecStart], pred, falseVal); for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( - rewriter, loc, this->getTypeConverter()->getIndexType(), ii % vec); + rewriter, loc, this->getTypeConverter()->getIndexType(), ii); Value loaded = extract_element(valueElemTy, loadVal, vecIdx); loadedVals.push_back(loaded); } @@ -257,6 +412,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, ConversionPatternRewriter &rewriter) const override { Value ptr = op.getPtr(); Value value = op.getValue(); + Value mask = op.getMask(); Value llPtr = adaptor.getPtr(); Value llMask = adaptor.getMask(); @@ -269,6 +425,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(valueTy)); + // Determine the vectorization size unsigned vec = getVectorSize(ptr); unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType()); @@ -276,26 +433,20 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, auto valueElems = unpackLLElements(loc, llValue, rewriter); assert(ptrElems.size() == valueElems.size()); - // Determine the vectorization size - SmallVector maskElems; - if (llMask) { - Value mask = op.getMask(); - maskElems = unpackLLElements(loc, llMask, rewriter); - assert(valueElems.size() == maskElems.size()); - - unsigned maskAlign = getMaskAlignment(mask); - vec = std::min(vec, maskAlign); - } + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); - Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); - const size_t dtsize = - std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); - const size_t valueElemNBits = dtsize * 8; + const size_t valueElemNBits = + std::max(8, valueElemTy.getIntOrFloatBitWidth()); + const size_t valueElemNBytes = valueElemNBits / 8; + int64_t ptrAlignmentBytes = getPtrAlignment(ptr) * valueElemNBytes; + auto cacheMod = op.getCache(); const int numVecs = elemsPerThread / vec; + Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo); for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) { - // TODO: optimization when ptr is AddPtr with constant offset - size_t in_off = 0; + Value pred = mask ? and_(maskElems[vecStart], rDataMask) : rDataMask; + auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); const size_t maxWordWidth = std::max(32, valueElemNBits); const size_t totalWidth = valueElemNBits * vec; @@ -304,33 +455,81 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, const size_t wordNElems = width / valueElemNBits; assert(wordNElems * nWords * numVecs == elemsPerThread); - // TODO(Superjomn) Add cache policy fields to StoreOp. - // TODO(Superjomn) Deal with cache policy here. + SmallVector> asmArgs; + Value elem = valueElems[vecStart]; + Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]); - Type valArgTy = IntegerType::get(ctx, width); - auto wordTy = vec_ty(valueElemTy, wordNElems); + // Create the store val + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + valueElems, vecStart); + llStore(rewriter, loc, ptr, storeVal, pred, ptrAlignmentBytes, cacheMod); + } // end vec + rewriter.eraseOp(op); + return success(); + } +}; + +struct BufferStoreOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + using ConvertOpToLLVMPattern< + triton::amdgpu::BufferStoreOp>::ConvertOpToLLVMPattern; + + BufferStoreOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::amdgpu::BufferStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + LLVM::AMD::BufferEmitter bufferEmitter(rewriter, loc, targetInfo); + + // original values + Value ptr = op.getPtr(); + Value offset = op.getOffsets(); + Value mask = op.getMask(); + Value data = op.getValue(); + + Value llPtr = adaptor.getPtr(); + Value llOffset = adaptor.getOffsets(); + Value llMask = adaptor.getMask(); + Value llData = adaptor.getValue(); + + // Determine the vectorization size + Type valueTy = data.getType(); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + Type ptrType = getPointerTypeWithShape(ptr, offset); + + unsigned numElems = getTotalElemsPerThread(ptrType); + unsigned vec = getVectorSize(ptr, offset); + + // Get the offsets and value + SmallVector offsetElems = unpackLLElements(loc, llOffset, rewriter); + SmallVector valueElems = unpackLLElements(loc, llData, rewriter); + + // Get the mask + SmallVector maskElems = + getMaskElemsAndUpdateVeclen(rewriter, loc, llMask, mask, vec); + + Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr); + Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo); + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + Type vecTy = LLVM::getFixedVectorType(valueElemTy, vec); + Value pred = mask ? and_(maskElems[vecStart], rDataMask) : rDataMask; + // Create the store val + Value storeVal = packElementRangeIntoVector( + rewriter, this->getTypeConverter(), loc, cast(vecTy), + valueElems, vecStart); + bufferEmitter.emitStore(rsrcDesc, offsetElems[vecStart], storeVal, pred); + } // end vec - SmallVector> asmArgs; - for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) { - // llWord is a width-len composition - Value llWord = undef(wordTy); - // Insert each value element to the composition - for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) { - const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx; - assert(elemOffset < valueElems.size()); - Value elem = valueElems[elemOffset]; - if (elem.getType().isInteger(1)) - elem = sext(i8_ty, elem); - elem = bitcast(elem, valueElemTy); - - llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx)); - } - llWord = bitcast(llWord, valArgTy); - Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask; - auto address = ptrElems[vecStart + wordIdx * wordNElems]; - llStore(rewriter, loc, address, llWord, maskVal); - } - } rewriter.eraseOp(op); return success(); } @@ -442,8 +641,6 @@ struct AtomicCASOpConversion // Fill entry block with global memory barrier and conditional branch. rewriter.setInsertionPointToEnd(curBlock); - Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation()); - atomPtr = bitcast(atomPtr, ptr_ty(rewriter.getContext(), 3)); auto tid = tid_val(); Value pred = icmp_eq(tid, i32_val(i)); rewriter.create(loc, pred, atomicBlock, endBlock); @@ -456,22 +653,32 @@ struct AtomicCASOpConversion auto cmpxchg = rewriter.create( loc, casPtr, casCmp, casVal, successOrdering, failureOrdering, StringRef("agent")); - // Extract the new_loaded value from the pair. - Value newLoaded = extract_val(valueElemTy, cmpxchg, 0); - store(newLoaded, atomPtr); + if (atomicNeedsSharedMemory(op.getResult())) { + // Extract the new_loaded value from the pair. + Value newLoaded = extract_val(valueElemTy, cmpxchg, 0); + Value atomPtr = + getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + store(newLoaded, atomPtr); + } rewriter.create(loc, ValueRange(), endBlock); // Build the last block: synced load from shared memory, exit. rewriter.setInsertionPointToStart(endBlock); + if (!atomicNeedsSharedMemory(op.getResult())) { + rewriter.eraseOp(op); + return success(); + } + GCNBuilder BuilderMemfenceLDS; BuilderMemfenceLDS.create<>("s_waitcnt lgkmcnt(0)")->operator()(); BuilderMemfenceLDS.launch(rewriter, loc, void_ty(ctx)); barrier(); + Value atomPtr = + getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); Value ret = load(valueElemTy, atomPtr); - barrier(); rewriter.replaceOp(op, {ret}); } } @@ -619,8 +826,11 @@ struct AtomicRMWOpConversion atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult(); } if (!tensorTy) { - Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation()); - store(atom, atomPtr); + if (atomicNeedsSharedMemory(op.getResult())) { + Value atomPtr = + getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); + store(atom, atomPtr); + } } rewriter.create(loc, atom, endBlock); @@ -633,10 +843,14 @@ struct AtomicRMWOpConversion : extract_element(valueElemTy, retVal, i32_val(ii)); } } else { - Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation()); + if (!atomicNeedsSharedMemory(op.getResult())) { + rewriter.eraseOp(op); + return success(); + } + Value atomPtr = + getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); barrier(); Value ret = load(valueElemTy, atomPtr); - barrier(); rewriter.replaceOp(op, {ret}); } } @@ -658,8 +872,9 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { - patterns.add(typeConverter, targetInfo, axisInfoAnalysis, - benefit); + patterns + .add( + typeConverter, targetInfo, axisInfoAnalysis, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp new file mode 100644 index 000000000..4a0a7fed2 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp @@ -0,0 +1,258 @@ +/* + * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "OptimizeLDSUtility.h" +#include "TargetInfo.h" +#include "TritonAMDGPUToLLVM/Passes.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Conversion/TritonGPUToLLVM/Patterns.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; + +namespace mlir::triton { +#define GEN_PASS_DEF_OPTIMIZEAMDLDSUSAGE +#include "TritonAMDGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton + +namespace { + +class OptimizeAMDLDSUsage + : public mlir::triton::impl::OptimizeAMDLDSUsageBase { + + int LDSLimit; + + // Try to reduce LDS usage of convert op by adding tmp layout in conversion: + // + // %1 = convert %0 (src layout -> dst layout) + // -> + // %1 = convert %0 (src layout -> tmp) + // %2 = convert %1 (tmp -> dst layout) + // + // The implicit LDS usage of convert op depends on src and dst layouts + // + // Consider mfma->blocked conversion as an example. + // + // tensor shape: [128, 128] + // mfma layout: warpsPerCTA = [1, 4], instrShape = [32, 32] + // blocked layout: sizePerThread = [1, 4], threadsPerWarp = [32, 2], + // warpsPerCTA = [4, 1] + // + // minimal mfma tile is: [1*32, 4*32] = [32, 128] + // minimal blocked tile is: [1*32*4, 4*2*1] = [128, 8] + // + // Roughtly scratch buffer shape for conversion is: + // [max(32, 128), max(128, 16)] = [128, 128]. + // + // This shape could be reduces by introducing intermediate + // layout and replacing old convert operations with two new conversions: + // + // %1 = convert %0 (mfma -> blocked) + // -> + // %1 = convert %0 (mfma -> tmp) + // %2 = convert %1 (tmp -> blocked) + // + // Let's consider tmp as blocked layout: + // sizePerThread = [1, 4], threadsPerWarp = [32, 2], warpsPerCTA = [1, 4] + // Tmp layout scratch buffer has shape: [1*32*1, 4*2*4] = [32, 32] + // + // With intermediate layout we have two scratch buffers: + // + // %1 = convert %0 (mfma -> tmp): [max(32, 32), max(128, 32)] = [32, 128] + // %2 = convert %1 (tmp -> blocked): [max(32, 128), max(32, 32)] = [128, 32] + // + // Both of these buffers are 4x times smaller than original one and their live + // times do not intersect, therefore this transformation lowers LDS + // consumption. + void tryFitCvtIntoLDS(triton::gpu::ConvertLayoutOp cvtOp, int targetLDSSize) { + OpBuilder builder(cvtOp); + + auto srcType = cvtOp.getSrc().getType(); + auto dstType = cvtOp.getType(); + + auto srcEnc = srcType.getEncoding(); + auto dstEnc = dstType.getEncoding(); + + auto ctx = srcEnc.getContext(); + auto rank = srcType.getRank(); + + unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc); + auto warpSize = triton::gpu::getWarpSize(srcEnc); + + // Find all possible shapes of WarpsPerCTA by finding all possible + // factorizations of numWarps. Pick shape for which both conversions in + // decomposition use LDS less than LDSLimit and for which sum of LDS usage + // is minimal. If no such shape exists, do not decompose. + auto factorizedNumWarps = + mlir::triton::AMD::factorizePowerOf2(numWarps, rank); + // Create a list of temporary layouts + SmallVector elemsPerThread(rank, 1); + SmallVector threadsPerWarp(rank, 1); + + // Special case for rank == 1 + if (rank == 1) { + threadsPerWarp[0] = warpSize; + } else { + assert(rank > 1); + threadsPerWarp[rank - 1] = warpSize / 8; + threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1]; + } + + auto layoutCTA = triton::gpu::getCTALayout(srcEnc); + auto order = triton::gpu::getOrder(srcEnc); + SmallVector dummyWarpsPerCTA(rank, 1); + + auto baseFallbackLayout = triton::gpu::BlockedEncodingAttr::get( + ctx, elemsPerThread, threadsPerWarp, dummyWarpsPerCTA, order, + layoutCTA); + SmallVector tmpLayouts; + for (int i = 0; i < factorizedNumWarps.size(); i++) { + auto warpsPerCTA = factorizedNumWarps[i]; + tmpLayouts.push_back( + mlir::triton::AMD::createTmpLayout(srcEnc, warpsPerCTA)); + tmpLayouts.push_back( + mlir::triton::AMD::createTmpLayout(dstEnc, warpsPerCTA)); + tmpLayouts.push_back( + mlir::triton::AMD::createTmpLayout(baseFallbackLayout, warpsPerCTA)); + } + + unsigned minLDSUsage = 2 * LDSLimit; + int minIdx = -1; + for (int i = 0; i < tmpLayouts.size(); i++) { + auto resources = mlir::triton::AMD::estimateResourcesForReplacement( + builder, cvtOp, tmpLayouts[i]); + // TODO analyze performance along with LDS consumption + if (resources.LDS < minLDSUsage) { + minLDSUsage = resources.LDS; + minIdx = i; + } + } + + if (minIdx == -1 || minLDSUsage > targetLDSSize) { + return; + } + + assert(minIdx >= 0 && minIdx < tmpLayouts.size()); + auto tmpLayout = tmpLayouts[minIdx]; + auto replacementCvts = + mlir::triton::AMD::createNewConvertOps(builder, cvtOp, tmpLayout); + + cvtOp.replaceAllUsesWith(replacementCvts.second.getResult()); + cvtOp.erase(); + } + + struct LDSBottleneckOperation { + triton::gpu::ConvertLayoutOp op; + int64_t LDSSizeTarget; + }; + + // Assuming that all buffer above scratch buffer in memory space can be + // shifted down in memory, gives an optimistic estimation of memory space + // available for scratch buffer. + int64_t + computeTargetScratchBufferSize(triton::gpu::ConvertLayoutOp op, + Allocation *allocation, + ArrayRef liveBuffers) { + int totalSize = 0; + auto scratchBufferId = allocation->getBufferId(op.getOperation()); + int64_t scratchBufferSize = allocation->getAllocatedSize(scratchBufferId); + size_t totalLDSConsumption = 0; + for (auto buf : liveBuffers) { + totalLDSConsumption = std::max( + totalLDSConsumption, allocation->getAllocatedInterval(buf).end()); + } + int64_t freeRequired = totalLDSConsumption - LDSLimit; + return std::max(static_cast(0), scratchBufferSize - freeRequired); + } + + SmallVector + findLDSBottleneckLayoutConvert(ModuleAllocation &allocAnalysis, + FunctionOpInterface func) { + SmallVector candidates; + auto funcAnalysis = allocAnalysis.getFuncData(func); + auto liveBuffers = funcAnalysis->getLiveBuffers(); + + func.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + auto srcTy = cvtOp.getSrc().getType(); + auto dstTy = cvtOp.getResult().getType(); + if (!cvtNeedsSharedMemory(srcTy, dstTy)) + return; + auto cvtBuffer = funcAnalysis->getBufferId(cvtOp.getOperation()); + assert(cvtBuffer != Allocation::InvalidBufferId); + + auto targetScratchBufferSize = computeTargetScratchBufferSize( + cvtOp, funcAnalysis, liveBuffers[cvtOp]); + auto currentLDSConsumption = funcAnalysis->getAllocatedSize(cvtBuffer); + if (currentLDSConsumption > targetScratchBufferSize) + candidates.push_back({cvtOp, targetScratchBufferSize}); + }); + return candidates; + } + +public: + OptimizeAMDLDSUsage(StringRef targetArch, int customLDSLimit) + : OptimizeAMDLDSUsageBase() { + this->targetArch = targetArch.str(); + this->customLDSLimit = customLDSLimit; + } + + void runOnOperation() override { + ModuleOp mod = getOperation(); + + if ((this->LDSLimit = this->customLDSLimit) == 0) { + if (this->targetArch.empty()) { + mod->emitError("missing gfx* target for pass ") + << this->getName().str(); + return signalPassFailure(); + } + triton::AMD::TargetInfo targetInfo(this->targetArch.c_str()); + LDSLimit = targetInfo.getSharedMemorySize(); + } + + ModuleAllocation allocAnalysis(mod); + if (allocAnalysis.getSharedMemorySize() <= LDSLimit) + return; + + auto rootFunctions = allocAnalysis.getRoots(); + for (auto rootFunc : rootFunctions) { + // Find operations with peak LDS consumption + auto candidates = findLDSBottleneckLayoutConvert(allocAnalysis, rootFunc); + // Try to transform candidate operations to fit them into LDS + for (auto candidate : candidates) + tryFitCvtIntoLDS(candidate.op, candidate.LDSSizeTarget); + } + } +}; + +} // namespace + +namespace mlir::triton::AMD { + +std::unique_ptr> +createOptimizeLDSUsagePass(StringRef targetArch, int customLDSLimit) { + return std::make_unique(targetArch, customLDSLimit); +} + +} // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp new file mode 100644 index 000000000..fb0bfb656 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp @@ -0,0 +1,117 @@ +#include "OptimizeLDSUtility.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Conversion/TritonGPUToLLVM/Patterns.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir::triton::AMD { + +constexpr int kPtrBitWidth = 64; + +int getCvtOpLDSUsage(RankedTensorType srcTy, RankedTensorType dstTy) { + auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy); + unsigned elems = getNumScratchElements(scratchConfig.paddedRepShape); + auto bytes = + isa(srcTy.getElementType()) + ? elems * kPtrBitWidth / 8 + : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; + + return bytes; +} + +int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp op) { + return getCvtOpLDSUsage(op.getSrc().getType(), op.getType()); +} + +static void stepFactorizationPow2(std::vector> &factors, + SmallVector &curFactor, + int restTwos, int dim) { + if (dim == curFactor.size()) { + if (restTwos == 0) + factors.push_back(curFactor); + return; + } + curFactor[dim] = 1; + for (int i = 0; i <= restTwos; ++i) { + stepFactorizationPow2(factors, curFactor, restTwos - i, dim + 1); + curFactor[dim] *= 2; + } +} + +std::vector> factorizePowerOf2(int n, int rank) { + assert(llvm::isPowerOf2_32(n)); + int x = log2(n); + std::vector> factors; + SmallVector curFactor(rank, 1); + stepFactorizationPow2(factors, curFactor, x, 0); + return factors; +} + +Attribute createTmpLayout(Attribute layout, ArrayRef warpsPerCTA) { + auto ctx = layout.getContext(); + if (auto src = dyn_cast(layout)) + return triton::gpu::AMDMfmaEncodingAttr::get( + ctx, src.getVersionMajor(), src.getVersionMinor(), warpsPerCTA, + src.getMDim(), src.getNDim(), src.getIsTransposed(), + src.getCTALayout()); + if (auto src = dyn_cast(layout)) + return triton::gpu::AMDWmmaEncodingAttr::get( + ctx, /*version=*/1, warpsPerCTA, src.getCTALayout()); + if (auto src = dyn_cast(layout)) + return triton::gpu::BlockedEncodingAttr::get( + ctx, src.getSizePerThread(), src.getThreadsPerWarp(), warpsPerCTA, + src.getOrder(), src.getCTALayout()); + if (auto src = dyn_cast(layout)) { + return triton::gpu::DotOperandEncodingAttr::get( + ctx, src.getOpIdx(), createTmpLayout(src.getParent(), warpsPerCTA), + src.getKWidth()); + } + if (auto src = dyn_cast(layout)) { + // TODO: think of a way to construct slice layouts based on warpsPerCTA + // argument + auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent()); + return triton::gpu::SliceEncodingAttr::get( + ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA)); + } + assert("Encountered unsupported layout"); + return Attribute(); +} + +std::pair +createNewConvertOps(OpBuilder &builder, triton::gpu::ConvertLayoutOp &cvtOp, + Attribute tmpLayout) { + auto srcType = cvtOp.getSrc().getType(); + auto dstType = cvtOp.getType(); + + auto newDstType = RankedTensorType::get( + dstType.getShape(), dstType.getElementType(), dstType.getEncoding()); + RankedTensorType newSrcType = RankedTensorType::get( + srcType.getShape(), srcType.getElementType(), tmpLayout); + + auto tmpCvt = builder.create( + cvtOp.getLoc(), newSrcType, cvtOp.getSrc()); + auto newEpilogueCvt = builder.create( + cvtOp.getLoc(), newDstType, tmpCvt); + + return std::make_pair(tmpCvt, newEpilogueCvt); +} + +Resources +estimateResourcesForReplacement(OpBuilder builder, + mlir::triton::gpu::ConvertLayoutOp cvtOp, + Attribute tmpLayout) { + Resources res; + RankedTensorType srcTy = cvtOp.getSrc().getType(); + RankedTensorType dstTy = cvtOp.getType(); + RankedTensorType intermediateTy = RankedTensorType::get( + srcTy.getShape(), srcTy.getElementType(), tmpLayout); + + int tmpCvtLDS = mlir::triton::AMD::getCvtOpLDSUsage(srcTy, intermediateTy); + int newCvtLDS = mlir::triton::AMD::getCvtOpLDSUsage(intermediateTy, dstTy); + res.LDS = std::max(tmpCvtLDS, newCvtLDS); + return res; +} + +} // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h new file mode 100644 index 000000000..2bd2a977f --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h @@ -0,0 +1,50 @@ +#ifndef TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_OPTIMIZE_LDS_UTILITY_H +#define TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_OPTIMIZE_LDS_UTILITY_H + +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton::AMD { + +int getCvtOpLDSUsage(RankedTensorType srcTy, RankedTensorType dstTy); + +int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp op); + +std::vector> factorizePowerOf2(int n, int rank); + +/** + * @brief Copy given layout with different warpsPerCTA parameter + * @param layout original layout + * @param warpsPerCTA new warpsPerCTA + * @return create layout + */ +Attribute createTmpLayout(Attribute layout, ArrayRef warpsPerCTA); + +/** + * Creates two chained convert layout operations + * + * %1 = cvtOp %0 (srcLayout -> dstLayout) // original operation + * -> + * %2 = cvtOp %0 (srcLayout -> tmpLayout) // .first + * %3 = cvtOp %2 (tmpLayout -> dstLayout) // .second + * + * @param builder + * @param cvtOp original operation + * @param tmpLayout + * @return pair of created operations + */ +std::pair +createNewConvertOps(OpBuilder &builder, triton::gpu::ConvertLayoutOp &cvtOp, + Attribute tmpLayout); + +struct Resources { + int LDS; +}; + +Resources +estimateResourcesForReplacement(OpBuilder builder, + mlir::triton::gpu::ConvertLayoutOp cvtOp, + Attribute tmpLayout); + +} // namespace mlir::triton::AMD + +#endif // TRITON_CONVERSION_TRITONAMDGPU_TO_LLVM_OPTIMIZE_LDS_UTILITY_H diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h index 67e5369b8..764f31a61 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -30,6 +30,9 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); +void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp new file mode 100644 index 000000000..9bed87961 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -0,0 +1,203 @@ +#include "TritonAMDGPUToLLVM/Passes.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir::triton { +#define GEN_PASS_DEF_INSERTINSTRUCTIONSCHEDHINTS +#define GEN_PASS_DEF_LOWERINSTRUCTIONSCHEDHINTS +#include "TritonAMDGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton + +using namespace mlir; + +namespace { + +// The bitmask that encodes kinds of the instructions from AMD ISA. +// The bitmask is used for providing instruction scheduling hints. +enum InstructionKindMask { + NONE = 0x0000000, + ALL_ALU = 0x00000001, + VALU = 0x00000002, + SALU = 0x00000004, + MFMA = 0x00000008, + ALL_VMEM = 0x00000010, + VMEM_READ = 0x00000020, + VMEM_WRITE = 0x00000040, + ALL_DS = 0x00000080, + DS_READ = 0x00000100, + DS_WRITE = 0x00000200 +}; + +// Create an intrinsic to control how different instruction kinds should +// interleave for better ILP. +void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc, + InstructionKindMask maskValue, int sizeValue, + int groupIdValue) { + MLIRContext *ctx = rewriter.getContext(); + const char *intrinsicName = "llvm.amdgcn.sched.group.barrier"; + + Value mask = + LLVM::createConstantI32(loc, rewriter, static_cast(maskValue)); + Value size = + LLVM::createConstantI32(loc, rewriter, static_cast(sizeValue)); + Value groupId = LLVM::createConstantI32(loc, rewriter, + static_cast(groupIdValue)); + + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, TypeRange{}, + ValueRange{mask, size, groupId}); +} + +// Insert intrinsic that controls the types of instructions that may be +// allowed to cross the intrinsic during instruction scheduling +Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc, + int64_t maskValue) { + MLIRContext *ctx = rewriter.getContext(); + const char *intrinsicName = "llvm.amdgcn.sched.barrier"; + LLVM::FastmathFlagsAttr defaultFlags{}; + + Value mask = + LLVM::createConstantI32(loc, rewriter, static_cast(maskValue)); + return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, + TypeRange{}, ValueRange{mask}); +} + +// Insert an experimental intrinsic for instruction group level parallelism. +// The intrinsic takes a value that specifies the strategy. +Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) { + MLIRContext *ctx = rewriter.getContext(); + const char *intrinsicName = "llvm.amdgcn.iglp.opt"; + LLVM::FastmathFlagsAttr defaultFlags{}; + Value iglpValue = + LLVM::createConstantI32(loc, rewriter, static_cast(value)); + return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, + TypeRange{}, ValueRange{iglpValue}); +} + +struct InstructionSchedHintsRewriter + : public OpRewritePattern { + + InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, std::string variant) + : OpRewritePattern(ctx) { + std::transform(variant.begin(), variant.end(), variant.begin(), + [](unsigned char c) { return std::tolower(c); }); + + this->schedulingType = llvm::StringSwitch(variant) + .Case("default", SchedulingType::NONE) + .Case("iglp0", SchedulingType::IGLP0) + .Case("iglp1", SchedulingType::IGLP1) + .Default(SchedulingType::UNKNOWN); + } + + enum class SchedulingType : uint32_t { NONE = 0, IGLP0, IGLP1, UNKNOWN }; + + LogicalResult + matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint, + PatternRewriter &rewriter) const override { + + if (this->schedulingType == SchedulingType::UNKNOWN) { + llvm::dbgs() + << "[" << getDebugName() << "]: " + << "unknown instruction scheduling variant has been provided\n"; + return mlir::failure(); + } + + // The switch controls whether instructions are allowed to cross the basic + // block boundaries at the very top and at the very bottom. Note, this is + // not supposed to be used together with IGLP OPT according to the AMDGPU + // backend documentation. + const bool limitSchedulingRange = + !(schedulingType == SchedulingType::IGLP0 || + schedulingType == SchedulingType::IGLP1); + Location loc = instructionSchedHint->getLoc(); + Block *block = instructionSchedHint->getBlock(); + if (limitSchedulingRange) { + rewriter.setInsertionPointToStart(block); + createSchedBarrier(rewriter, loc, InstructionKindMask::NONE); + } + + rewriter.setInsertionPoint(block, std::prev(block->end())); + + switch (schedulingType) { + case SchedulingType::IGLP0: + [[fallthrough]]; + case SchedulingType::IGLP1: { + createIglpOpt(rewriter, loc, static_cast(schedulingType) - 1); + break; + } + case SchedulingType::NONE: + [[fallthrough]]; + default: { + break; + } + } + + if (limitSchedulingRange) + createSchedBarrier(rewriter, loc, InstructionKindMask::NONE); + + rewriter.eraseOp(instructionSchedHint); + return mlir::success(); + } + +private: + SchedulingType schedulingType; +}; + +struct LowerInstructionSchedHints + : public triton::impl::LowerInstructionSchedHintsBase< + LowerInstructionSchedHints> { + + explicit LowerInstructionSchedHints(std::string variant) { + this->variant = variant; + } + + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + ModuleOp mod = getOperation(); + + ConversionTarget target(*ctx); + target.addLegalDialect(); + target.addIllegalOp(); + + RewritePatternSet patterns(ctx); + patterns.add(ctx, this->variant); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +struct InsertInstructionSchedHints + : public triton::impl::InsertInstructionSchedHintsBase< + InsertInstructionSchedHints> { + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + ModuleOp mod = getOperation(); + + mod->walk([ctx](triton::DotOp dot) { + if (dyn_cast(dot->getParentOp())) { + mlir::OpBuilder rewriter(ctx); + rewriter.setInsertionPointAfter(dot); + rewriter.create(dot->getLoc()); + } + }); + } +}; +} // namespace + +namespace mlir::triton { +std::unique_ptr> +createLowerInstructionSchedHintsPass(std::string variant) { + return std::make_unique(variant); +} + +std::unique_ptr> +createInsertInstructionSchedHintsPass() { + return std::make_unique(); +} +} // namespace mlir::triton diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 1e8f33c3f..3a40d73c2 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -1,6 +1,6 @@ #include "TargetInfo.h" +#include "TritonAMDGPUToLLVM/GCNAsmFormat.h" #include "Utility.h" -#include "amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -10,12 +10,11 @@ namespace mlir::triton::AMD { namespace { template LLVM::LLVMFuncOp getOrInsertFunction(T &moduleOp, const Location loc, - ConversionPatternRewriter &rewriter, - StringRef name, + RewriterBase &rewriter, StringRef name, LLVM::LLVMFunctionType type) { LLVM::LLVMFuncOp ret; if (!(ret = moduleOp.template lookupSymbol(name))) { - ConversionPatternRewriter::InsertionGuard guard(rewriter); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); ret = rewriter.create(loc, name, type, LLVM::Linkage::External); @@ -24,7 +23,7 @@ LLVM::LLVMFuncOp getOrInsertFunction(T &moduleOp, const Location loc, } // Extend all values to 64-bit per printf call requirements. -Value printfPromoteValue(ConversionPatternRewriter &rewriter, Value value) { +Value printfPromoteValue(RewriterBase &rewriter, Value value) { auto *context = rewriter.getContext(); auto loc = UnknownLoc::get(context); auto type = value.getType(); @@ -68,72 +67,74 @@ Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { return rewriter.create(loc, 0, 32); } -Value TargetInfo::ballot(ConversionPatternRewriter &rewriter, Location loc, - Type type, Value cmp) const { - auto stringAttr = rewriter.getStringAttr("llvm.amdgcn.ballot"); - SmallVector operands = {cmp}; - Value asmResult = - rewriter.create(loc, type, stringAttr, operands) - ->getResult(0); - return asmResult; +Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const { + return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.ballot", + type, cmp) + ->getResult(0); } -void TargetInfo::storeShared(ConversionPatternRewriter &rewriter, Location loc, - Value ptr, Value val, Value pred) const { +void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const { + if (ctaId.has_value()) { + llvm::report_fatal_error( + "AMDGPU does not support cross-CTA shared memory transfers"); + } mlir::LLVM::AMD::llStore(rewriter, loc, ptr, val, pred); } -Value TargetInfo::loadShared(ConversionPatternRewriter &rewriter, Location loc, - const TypeConverter *converter, Value ptr, - Type elemTy, Value pred) const { - Value falseVal = rewriter.create( +void TargetInfo::storeMatrixShared(RewriterBase &rewriter, Location loc, + Value ptr, Value val) const { + llvm::report_fatal_error("AMDGPU does not support stmatrix"); +} + +Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, + Value pred) const { + if (ctaId.has_value()) { + llvm::report_fatal_error( + "AMDGPU does not support cross-CTA shared memory transfers"); + } + Value falseVal = rewriter.create( loc, elemTy, rewriter.getZeroAttr(elemTy)); return mlir::LLVM::AMD::llLoad(rewriter, loc, ptr, elemTy, pred, falseVal); } -Value TargetInfo::shuffleXor(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::AMD::shuffleXor(loc, rewriter, val, i); } -Value TargetInfo::shuffleUp(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::AMD::shuffleUp(loc, rewriter, val, i); } -Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); } -Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, Value i) const { +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const { return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); } -Value TargetInfo::programId(ConversionPatternRewriter &rewriter, Location loc, +Value TargetInfo::programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, int axis) const { return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis); } -bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, +bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce) const { - return false; -} - -bool TargetInfo::processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, - SmallVector &vals, RankedTensorType srcTy, Type elemTy, - ArrayRef paddedRepShape, ArrayRef origRepShape, - ArrayRef outOrd, unsigned accumNumReplicates, - int swizzleByteWidth) const { + unsigned numLaneToReduce, + unsigned interleave) const { return false; } void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, - ValueRange args, - ConversionPatternRewriter &rewriter, + ValueRange args, RewriterBase &rewriter, bool useStdErr) const { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); auto *ctx = rewriter.getContext(); @@ -204,14 +205,25 @@ std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { return funcName; } -void TargetInfo::printf(ConversionPatternRewriter &rewriter, - Value formatStrStart, int formatStrByteCount, - ValueRange args) const { +void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args) const { return printfImpl(formatStrStart, formatStrByteCount, args, rewriter, /*useStdError=*/false); } -void TargetInfo::assertFail(ConversionPatternRewriter &rewriter, Location loc, +void TargetInfo::printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), rewriter, + "printfFormat_", msgNewline); + printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); +} + +void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const { // Compose and print an assert message. @@ -224,8 +236,19 @@ void TargetInfo::assertFail(ConversionPatternRewriter &rewriter, Location loc, printfImpl(msgValue, msgBuffer.size_in_bytes(), /*args=*/ValueRange(), rewriter, /*useStdError=*/true); + // Set block barrrier before aborting kernel, give a chance for all + // the threads in a block to check/print the assert failure. + barrier(); // Perform the trap to abort the kernel. rewriter.create(loc); } +int TargetInfo::getSharedAddressSpace() const { return 3; } + +bool TargetInfo::supportVectorizedAtomics() const { + // Note: not currently tested or used, but AMD generally supports vectorized + // atomics. + return true; +} + } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 2312c9ed6..0ce38d4d7 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -18,51 +18,51 @@ class TargetInfo : public mlir::triton::TargetInfoBase { Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; - Value ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, + Value ballot(RewriterBase &rewriter, Location loc, Type type, Value cmp) const override; - void storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Value val, Value pred) const override; - Value loadShared(ConversionPatternRewriter &rewriter, Location loc, - const TypeConverter *converter, Value ptr, Type elemTy, - Value pred) const override; + void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const override; + Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, + Value pred) const override; + void storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, + Value val) const override; - Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, Value i) const override; - Value programId(ConversionPatternRewriter &rewriter, Location loc, - ModuleOp moduleOp, int axis) const override; + Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, + int axis) const override; - bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, - SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce) const override; - - bool processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, - SmallVector &vals, RankedTensorType srcTy, Type elemTy, - ArrayRef paddedRepShape, ArrayRef origRepShape, - ArrayRef outOrd, unsigned accumNumReplicates, - int swizzleByteWidth) const override; + bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, + triton::ReduceOp op, unsigned numLaneToReduce, + unsigned interleave) const override; std::string getMulhiFuncName(Type resultElementTy) const override; - void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + void printf(RewriterBase &rewriter, Value formatStrStart, int formatStrByteCount, ValueRange args) const override; - void assertFail(ConversionPatternRewriter &rewriter, Location loc, - StringRef message, StringRef file, StringRef func, - int line) const override; - bool enableLinearLayout() const override { return false; } + void printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const override; + + void assertFail(RewriterBase &rewriter, Location loc, StringRef message, + StringRef file, StringRef func, int line) const override; + int getSharedAddressSpace() const override; + + bool supportVectorizedAtomics() const override; private: void printfImpl(Value formatStrStart, int formatStrByteCount, ValueRange args, - ConversionPatternRewriter &rewriter, bool useStdErr) const; + RewriterBase &rewriter, bool useStdErr) const; std::string arch; }; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp index 63fb972f7..7ab6fd68a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp @@ -11,6 +11,7 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) { // CDNA ISA cases switch (kind) { + case llvm::AMDGPU::GK_GFX950: case llvm::AMDGPU::GK_GFX942: case llvm::AMDGPU::GK_GFX941: case llvm::AMDGPU::GK_GFX940: diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 8649911a7..94e277494 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -8,11 +8,13 @@ #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Pass/Pass.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" @@ -57,6 +59,7 @@ class TritonLLVMConversionTarget : public ConversionTarget { addIllegalDialect(); addIllegalDialect(); addLegalOp(); + addLegalOp(); } }; @@ -86,14 +89,14 @@ struct ConvertTritonAMDGPUToLLVM mlir::LowerToLLVMOptions option(context); option.overrideIndexBitwidth(32); - TritonGPUToLLVMTypeConverter typeConverter(context, option); + TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo); TritonLLVMConversionTarget convTarget(*context); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - // Hack: WSMaterialization may have changed the effective number of warps, + // Hack: WSLowering may have changed the effective number of warps, // in a way that isn't reflected in triton_gpu.num-warps. If so, we have to // respect that here. if (Attribute attr = mod->getAttr("triton_gpu.num-warp-groups-per-cta")) { @@ -108,11 +111,12 @@ struct ConvertTritonAMDGPUToLLVM // Lower functions { mlir::LowerToLLVMOptions option(context); - TritonGPUToLLVMTypeConverter typeConverter(context, option); + TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo); TritonLLVMFunctionConversionTarget funcTarget(*context); RewritePatternSet funcPatterns(context); - mlir::triton::populateFuncOpConversionPattern( - typeConverter, funcPatterns, numWarps, patternBenefitDefault); + mlir::triton::populateFuncOpConversionPattern(typeConverter, funcPatterns, + numWarps, targetInfo, + patternBenefitDefault); mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, funcPatterns); if (failed( @@ -128,7 +132,7 @@ struct ConvertTritonAMDGPUToLLVM // Convert call and ret ops { mlir::LowerToLLVMOptions option(context); - TritonGPUToLLVMTypeConverter typeConverter(context, option); + TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo); TritonLLVMFunctionConversionTarget funcTarget(*context); RewritePatternSet funcPatterns(context); if (failed( @@ -196,10 +200,14 @@ struct ConvertTritonAMDGPUToLLVM mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, targetInfo, commonBenefit); mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, - commonBenefit); + targetInfo, commonBenefit); mlir::triton::populateSPMDOpToLLVMPattern(typeConverter, patterns, targetInfo, commonBenefit); AMD::populateSPMDOpToLLVMPattern(typeConverter, patterns, AMDBenefit); + + mlir::triton::AMD::populateTritonAMDGPUToLLVMPatterns(typeConverter, + patterns, AMDBenefit); + // TODO(thomas): this should probably be done in a separate step to not // interfere with our own lowering of arith ops. Add arith/math's patterns // to help convert scalar expression to LLVM. @@ -214,6 +222,7 @@ struct ConvertTritonAMDGPUToLLVM patterns); mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, targetInfo, commonBenefit); + mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { return signalPassFailure(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 111045d13..542b1ecbb 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -1,8 +1,12 @@ #include "Utility.h" #include "PatternTritonGPUOpToLLVM.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/PatternMatch.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" using mlir::triton::gpu::appendOrGetExternFuncOp; using mlir::triton::gpu::getFunctionType; @@ -35,12 +39,40 @@ std::string mangleFunc(std::string name, Type type) { } return mangled; } + +// Utility function to create a constant vector mask of length `vecSize` with +// the same `pred` value +Value createVectorMaskFromPredicate(RewriterBase &rewriter, Location loc, + Value pred, int64_t vecSize) { + auto vecMaskTy = LLVM::getFixedVectorType(rewriter.getI1Type(), vecSize); + Value maskVal = undef(vecMaskTy); + for (size_t s = 0; s < vecSize; ++s) { + Value indexVal = + rewriter.create(loc, rewriter.getI64IntegerAttr(s)); + maskVal = insert_element(vecMaskTy, maskVal, pred, indexVal); + } + return maskVal; +} + +// Utility function to get the number of elements of a vector or a scalar +int64_t getNumElements(Type ty) { + if (auto vecType = dyn_cast(ty)) + return vecType.getNumElements(); + return 1; +} + +// Utility function to cast the given scalar or vector type to a vector type +Type castToVectorType(Type ty) { + if (isa(ty)) + return ty; + return LLVM::getFixedVectorType(ty, 1); +} + } // namespace namespace mlir::LLVM::AMD { -static Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter, - Value val, Value i, int strideInt, ShflKind mode, - Value clamp) { +static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, + Value i, int strideInt, ShflKind mode, Value clamp) { unsigned bits = val.getType().getIntOrFloatBitWidth(); // On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on @@ -126,30 +158,26 @@ static Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter, return Value(); } -Value shuffleXor(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::bfly, i32_val(0x1f)); } -Value shuffleUp(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::up, i32_val(0x0)); } -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleIdx(loc, rewriter, val, i32_val(i)); } -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - Value i) { +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) { return shuffleCommon(loc, rewriter, val, i, 0, ShflKind::idx, i32_val(0x1f)); } -Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, - ModuleOp moduleOp, int axis) { +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis) { assert(axis >= 0); assert(axis < 3); assert(moduleOp); @@ -160,29 +188,93 @@ Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, return rewriter.create(loc, i32_ty, blockId); } -Value llLoad(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Type elemTy, Value pred, Value falseVal) { +Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, + Value pred, Value falseVal, int64_t alignmentBytes, + triton::CacheModifier cm) { + + // Try to emit llvm.intr.masked.load if we can. In theory the backend should + // be happier because we emit less branchy code to optimize. The backend will + // lower it down however it wants at some point. + if (alignmentBytes && + (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE)) { + // `llvm.intr.masked.load` only accepts vectors. If we see a scalar we need + // to bitcast to `vector<1xelemTy>` (and back) + int64_t vecSize = getNumElements(elemTy); + Type vecType = castToVectorType(elemTy); + falseVal = bitcast(falseVal, vecType); + Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize); + bool nt = (cm == triton::CacheModifier::CG); + Value vecData = rewriter.create( + loc, vecType, ptr, maskVal, falseVal, alignmentBytes, nt); + // If it is not a vector, remember to bitcast back to a scalar + vecData = bitcast(vecData, elemTy); + return vecData; + } + Type funcType = getFunctionType(elemTy, ValueRange({ptr, pred, falseVal})); auto parent = ptr.getParentRegion()->getParentOfType(); - auto funcName = mangleFunc(mlir::LLVM::AMD::Predicated_Load, funcType); + auto getLoadNameRaw = [](triton::CacheModifier cm) { + switch (cm) { + case triton::CacheModifier::CA: + return predicatedLoadCA; + case triton::CacheModifier::CG: + return predicatedLoadCG; + case triton::CacheModifier::CV: + return predicatedLoadCV; + default: + // Do not fail in compile time in the case of unsupported modifier. + // Just apply default config. + return predicatedLoad; + } + }; + + auto funcName = mangleFunc(getLoadNameRaw(cm), funcType); LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp(rewriter, parent, funcName, funcType); - auto loadVal = - rewriter - .create(loc, funcOp, ValueRange({ptr, pred, falseVal})) - .getResult(); - return loadVal; + return LLVM::createLLVMCallOp(rewriter, loc, funcOp, + ValueRange({ptr, pred, falseVal})) + .getResult(); } -void llStore(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Value val, Value pred) { +void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred, int64_t alignmentBytes, triton::CacheModifier cm) { + // Try to emit llvm.intr.masked.store if we can. In theory the backend should + // be happier because we emit less branchy code to optimize. The backend will + // lower it down however it wants at some point. + if (alignmentBytes && cm == triton::CacheModifier::NONE) { + // `llvm.intr.masked.store` only accepts vectors. If we see a scalar we need + // to bitcast to `vector<1xelemTy>` + Type elemTy = val.getType(); + int64_t vecSize = getNumElements(elemTy); + Type vecType = castToVectorType(elemTy); + val = bitcast(val, vecType); + Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize); + auto op = rewriter.create(loc, val, ptr, maskVal, + alignmentBytes); + return; + } + auto ctx = ptr.getContext(); Type funcType = getFunctionType(void_ty(ctx), ValueRange({ptr, val, pred})); auto parent = ptr.getParentRegion()->getParentOfType(); - auto funcName = mangleFunc(mlir::LLVM::AMD::Predicated_Store, funcType); + auto getStoreNameRaw = [](triton::CacheModifier cm) { + switch (cm) { + case triton::CacheModifier::WT: + return predicatedStoreWT; + case triton::CacheModifier::CG: + return predicatedStoreCG; + case triton::CacheModifier::CS: + return predicatedStoreCS; + default: + // Do not fail in compile time in the case of unsupported modifier. + // Just apply default config. + return predicatedStore; + } + }; + auto funcName = mangleFunc(getStoreNameRaw(cm), funcType); LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp(rewriter, parent, funcName, funcType); - rewriter.create(loc, funcOp, ValueRange({ptr, val, pred})); + LLVM::createLLVMCallOp(rewriter, loc, funcOp, ValueRange({ptr, val, pred})); } } // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index c60d53f4b..123234fd4 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -4,35 +4,39 @@ #include "TritonAMDGPUToLLVM/GCNAsmFormat.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" namespace mlir::LLVM::AMD { -const char Predicated_Load[] = "__predicated_load"; -const char Predicated_Store[] = "__predicated_store"; +const char predicatedLoad[] = "__predicated_load"; +const char predicatedLoadCA[] = "__predicated_load_CA"; +const char predicatedLoadCG[] = "__predicated_load_CG"; +const char predicatedLoadCV[] = "__predicated_load_CV"; +const char predicatedStore[] = "__predicated_store"; +const char predicatedStoreCG[] = "__predicated_store_CG"; +const char predicatedStoreCS[] = "__predicated_store_CS"; +const char predicatedStoreWT[] = "__predicated_store_WT"; -Value shuffleXor(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleUp(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - Value i); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i); -Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, - ModuleOp moduleOp, int axis); +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis); // Loads from shared or global memory with predication. // `otherElems` is used to mask out the elements that are not loaded -Value llLoad(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Type elemTy, Value pred, Value falseVal); +Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, + Value pred, Value falseVal, int64_t alignmentBytes = 0, + triton::CacheModifier cm = triton::CacheModifier::NONE); // Stores to shared or global memory with predication. -void llStore(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Value val, Value pred); +void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred, int64_t alignmentBytes = 0, + triton::CacheModifier cm = triton::CacheModifier::NONE); } // namespace mlir::LLVM::AMD #endif diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index 86505386c..6f93bfee9 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -1,89 +1,72 @@ +#include "TritonAMDGPUToLLVM/TargetUtils.h" #include "TritonAMDGPUTransforms/MfmaGroup.h" #include "TritonAMDGPUTransforms/Passes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Passes.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include "triton/Tools/Sys/GetEnv.hpp" -#include "llvm/Support/Debug.h" #include using namespace mlir; namespace tt = mlir::triton; namespace ttg = mlir::triton::gpu; namespace { -using tt::DotOp; -using ttg::AMDMfmaEncodingAttr; -using ttg::AMDWmmaEncodingAttr; -using ttg::BlockedEncodingAttr; -using ttg::ConvertLayoutOp; -using ttg::DotOperandEncodingAttr; -using ttg::SliceEncodingAttr; - -enum class MatrixCoreVersion { - CDNA_MFMA1, - CDNA_MFMA2, - CDNA_MFMA3, - RDNA_WMMA, - UNKNOWN -}; +using triton::AMD::ISAFamily; -MatrixCoreVersion getMatrixCoreVersion(StringRef archGen) { - if (archGen.contains("gfx11")) - return MatrixCoreVersion::RDNA_WMMA; - if (archGen.contains("gfx908")) - return MatrixCoreVersion::CDNA_MFMA1; - if (archGen.contains("gfx90a")) - return MatrixCoreVersion::CDNA_MFMA2; - if (archGen.contains("gfx940") || archGen.contains("gfx941") || - archGen.contains("gfx942")) - return MatrixCoreVersion::CDNA_MFMA3; - return MatrixCoreVersion::UNKNOWN; +int getMfmaVersion(ISAFamily isaFamily) { + switch (isaFamily) { + case ISAFamily::CDNA1: + return 1; + case ISAFamily::CDNA2: + return 2; + case ISAFamily::CDNA3: + return 3; + default: + break; + } + return 0; } -int getMfmaVersion(MatrixCoreVersion matrixCoreVer) { - if (MatrixCoreVersion::CDNA_MFMA1 == matrixCoreVer) +int getWmmaVersion(StringRef archGen) { + if (archGen.contains("gfx11")) return 1; - if (MatrixCoreVersion::CDNA_MFMA2 == matrixCoreVer) + if (archGen.contains("gfx12")) return 2; - if (MatrixCoreVersion::CDNA_MFMA3 == matrixCoreVer) - return 3; return 0; } -SmallVector warpsPerTile(tt::DotOp dotOp, - const ArrayRef shape, - int numWarps, - SmallVector shapePerWarp) { +SmallVector +warpsPerTile(Operation *dotOp, ArrayRef shape, int numWarps, + std::pair shapePerWarp) { auto rank = shape.size(); // Early exit for batched matmul if (rank == 3) return {(unsigned)numWarps, 1, 1}; - auto filter = [&dotOp](Operation *op) { + auto filter = [dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); }; - mlir::ForwardSliceOptions fwdOpt; + ForwardSliceOptions fwdOpt; fwdOpt.filter = filter; - mlir::BackwardSliceOptions bwdOpt; + BackwardSliceOptions bwdOpt; bwdOpt.omitBlockArguments = true; bwdOpt.filter = filter; - auto slices = mlir::getSlice(dotOp, bwdOpt, fwdOpt); + auto slices = getSlice(dotOp, bwdOpt, fwdOpt); for (Operation *op : slices) - if (isa(op) && (op != dotOp)) + if (op->hasTrait() && (op != dotOp)) return {(unsigned)numWarps, 1}; SmallVector tensorShape = {shape[0], shape[1]}; - SmallVector ret = {1, 1}; + SmallVector ret = {1, 1}; do { if (ret[0] * ret[1] >= numWarps) break; - if (tensorShape[0] / (shapePerWarp[0] * 2) / ret[0] >= - tensorShape[1] / shapePerWarp[1] / ret[1]) { - if (ret[0] < tensorShape[0] / shapePerWarp[0]) { + if (tensorShape[0] / (shapePerWarp.first * 2) / ret[0] >= + tensorShape[1] / shapePerWarp.second / ret[1]) { + if (ret[0] < tensorShape[0] / shapePerWarp.first) { ret[0] *= 2; } else ret[1] *= 2; @@ -92,24 +75,184 @@ SmallVector warpsPerTile(tt::DotOp dotOp, } } while (true); - if (ret[1] * shapePerWarp[1] > tensorShape[1]) { + if (ret[1] * shapePerWarp.second > tensorShape[1]) { return {ret[1], ret[0]}; } return ret; } -SmallVector -warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef shape, int numWarps, - SmallVector shapePerWarp) { +SmallVector +warpsPerTileMFMA(Operation *dotOp, ArrayRef shape, int numWarps, + std::pair shapePerWarp) { return warpsPerTile(dotOp, shape, numWarps, shapePerWarp); } -SmallVector -warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef shape, int numWarps) { - return warpsPerTile(dotOp, shape, numWarps, - {AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr()[0], - AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr()[1]}); +SmallVector +warpsPerTileWMMA(Operation *dotOp, ArrayRef shape, int numWarps) { + auto mnk = ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr(); + return warpsPerTile(dotOp, shape, numWarps, {mnk[0], mnk[1]}); +} + +// Chooses a proper MFMA instruction that can used to compute the given dot op. +// If enforcedNonKDim is not zero, it will be used to overwrite the default +// logic to chose a MFMA with matching M/N dim. +FailureOr chooseMfmaInstruction(RankedTensorType cType, + Type aElemType, Type bElemType, + int inputKSize, int mfmaVersion, + int enforcedNonKDim) { + // number of matrix elements along k dim per one MFMA intruction + unsigned kDim = 0; + + auto resShape = cType.getShape(); + auto rank = resShape.size(); + auto M = resShape[rank - 2]; + auto N = resShape[rank - 1]; + + unsigned mDim = 0; + unsigned nDim = 0; + if (enforcedNonKDim != 0) { + mDim = nDim = enforcedNonKDim; + } else { + int minSize = std::min(M, N); + if (minSize >= 32) { + mDim = 32; + nDim = 32; + } + if (minSize >= 16 && minSize < 32) { + mDim = 16; + nDim = 16; + } + if (minSize < 16) { + if (M < 16 && N >= 64) { + mDim = 4; + nDim = 64; + } else if (M >= 64 && N < 16) { + mDim = 64; + nDim = 4; + } else { + assert(inputKSize >= 64 && + "k should be at least 64 to use this layout"); + mDim = 4; + nDim = 4; + } + } + } + assert(mDim != 0 && nDim != 0); + + auto maybeMfmaInsn = + MfmaInsn::selectMfma(mDim, nDim, aElemType, bElemType, mfmaVersion); + if (failed(maybeMfmaInsn)) + llvm::report_fatal_error("No match found in MFMA database\n"); + + kDim = maybeMfmaInsn->getKDim(); + assert(kDim != 0); + assert(M % mDim == 0 && N % nDim == 0); + assert(inputKSize % kDim == 0); + return maybeMfmaInsn; +} + +FailureOr chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion, + int nonKDim) { + RankedTensorType aType = dot.getA().getType(); + return chooseMfmaInstruction(dot.getC().getType(), aType.getElementType(), + dot.getB().getType().getElementType(), + aType.getShape().back(), mfmaVersion, nonKDim); +} + +using OperandTypesVector = SmallVector; +OperandTypesVector +selectMatrixCoreOperandTypes(tt::DotOp dot, + ArrayRef applicableTypes) { + SmallVector dotOperands = {dot.getA(), dot.getB(), dot.getC(), + dot.getD()}; + OperandTypesVector initElemTypes; + llvm::transform(dotOperands, std::back_inserter(initElemTypes), [](Value v) { + return cast(v.getType()).getElementType(); + }); + + // Use simple costmodel to define optimal set of the dot operands. + // Most expensive - accuracy loss conversions: + // - any larger type -> any smaller type; + // - float -> int; + // - int -> float (not supported for now); + // - signed int -> unsigned int; + // - unsigned int -> signed int with same or less size. + // They are never performed, better to use FMA. + // Supported conversion for now costs `1`, no conversion costs `0`. + // The model could be improved in the future. For example taken into account + // chain dot could be detected and result conversion score is decreased. + int maxConvertCost = + std::numeric_limits::max() / applicableTypes.front().size(); + auto calcConvertCost = [&](Type fromTy, Type toTy) -> int32_t { + if (fromTy == toTy) + return 0; + + // Skip conversion between int and float. Int16/int32 cases are lowered to + // FMA. + if (fromTy.isIntOrIndex() != toTy.isIntOrIndex()) + return maxConvertCost; + + if (fromTy.isIntOrIndex() && toTy.isIntOrIndex() && + fromTy.isUnsignedInteger() != toTy.isUnsignedInteger()) + return fromTy.isUnsignedInteger() && fromTy.getIntOrFloatBitWidth() < + toTy.getIntOrFloatBitWidth() + ? 1 + : maxConvertCost; + + return fromTy.getIntOrFloatBitWidth() <= toTy.getIntOrFloatBitWidth() + ? 1 + : maxConvertCost; + }; + auto minCost = maxConvertCost; + auto optTypes = OperandTypesVector(); + for (auto types : applicableTypes) { + assert(types.size() == initElemTypes.size()); + int accumulatedConvertCost = 0; + for (int i = 0; i < initElemTypes.size(); ++i) { + accumulatedConvertCost += calcConvertCost(initElemTypes[i], types[i]); + } + if (accumulatedConvertCost < minCost) { + minCost = accumulatedConvertCost; + optTypes = types; + } + } + return optTypes; +} + +OperandTypesVector getOperandTypesForWmmaOp(PatternRewriter &rewriter, + tt::DotOp dot, int version) { + Type f16 = rewriter.getF16Type(); + Type f32 = rewriter.getF32Type(); + Type bf16 = rewriter.getBF16Type(); + Type i8 = rewriter.getIntegerType(8); + Type i32 = rewriter.getIntegerType(32); + SmallVector applicableTypes = { + // clang-format off + {f16, f16, f32, f32}, + {f16, f16, f16, f16}, + {bf16, bf16, f32, f32}, + {bf16, bf16, bf16, bf16}, + {i8, i8, i32, i32}, + // i4, i4, i32, i32 - is supported configuration + // by WMMA instruction, but not supported by triton + // clang-format on + }; + // TODO: support fp8 configurations for WMMAv2. The code should be as + // following: + // if (version == 2) { + // Type fp8 = rewriter.getFp8Type(); + // Type bf8 = rewriter.getBF8Type(); + // applicableTypes.append({ + // // clang-format off + // {fp8, fp8, f32, f32}, + // {fp8, bf8, f32, f32}, + // {bf8, fp8, f32, f32}, + // {bf8, bf8, f32, f32}, + // // clang-format on + // }); + // } + return selectMatrixCoreOperandTypes(dot, applicableTypes); } /** @@ -129,8 +272,8 @@ warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef shape, int numWarps) { * @param newElemType new element type for the tensor * @return converted and optionaly casted tensor value */ -Value convertAndCastTensor(mlir::PatternRewriter &rewriter, Value value, - ::mlir::Attribute newEncoding, Type newElemType) { +Value convertAndCastTensor(PatternRewriter &rewriter, Value value, + Attribute newEncoding, Type newElemType) { assert(newElemType.isIntOrFloat()); auto loc = value.getLoc(); @@ -157,24 +300,24 @@ Value convertAndCastTensor(mlir::PatternRewriter &rewriter, Value value, unsigned oldWidth = oldElemType.getIntOrFloatBitWidth(); unsigned newWidth = newElemType.getIntOrFloatBitWidth(); if (oldWidth == newWidth) - castedTensor = rewriter.create(loc, convertedType, - convertedTensor); + castedTensor = rewriter.create(loc, convertedType, + convertedTensor); else if (oldWidth > newWidth) - castedTensor = rewriter.create(loc, castedType, - convertedTensor); + castedTensor = + rewriter.create(loc, castedType, convertedTensor); else if (oldElemType.isSignedInteger()) - castedTensor = rewriter.create(loc, castedType, - convertedTensor); + castedTensor = + rewriter.create(loc, castedType, convertedTensor); else - castedTensor = rewriter.create(loc, castedType, - convertedTensor); + castedTensor = + rewriter.create(loc, castedType, convertedTensor); } else { if (oldElemType.isF16() && newElemType.isF32()) - castedTensor = rewriter.create(loc, castedType, - convertedTensor); + castedTensor = + rewriter.create(loc, castedType, convertedTensor); else if (oldElemType.isF32() && newElemType.isF16()) - castedTensor = rewriter.create(loc, castedType, - convertedTensor); + castedTensor = + rewriter.create(loc, castedType, convertedTensor); else castedTensor = rewriter.create(loc, castedType, convertedTensor); @@ -182,43 +325,26 @@ Value convertAndCastTensor(mlir::PatternRewriter &rewriter, Value value, return castedTensor; } -class BlockedToMFMA : public mlir::RewritePattern { +class BlockedToMFMA : public OpRewritePattern { int mfmaVersion; - int enforcedNonKDim; + int nonKDim; int kPack; public: - BlockedToMFMA(mlir::MLIRContext *context, int mfmaVersion, int nonKDim, - int kPack) - : mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context), - mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {} - - bool isChainDot(tt::DotOp &dotOp) const { - auto filter = [&dotOp](Operation *op) { - return op->getParentRegion() == dotOp->getParentRegion(); - }; - mlir::ForwardSliceOptions fwdOpt; - fwdOpt.filter = filter; - mlir::BackwardSliceOptions bwdOpt; - bwdOpt.omitBlockArguments = true; - bwdOpt.filter = filter; - auto slices = mlir::getSlice(dotOp, bwdOpt, fwdOpt); - for (Operation *op : slices) { - if (isa(op) && (op != dotOp)) - return true; - } - return false; - } + BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion), + nonKDim(nonKDim), kPack(kPack) {} bool isSecondDot(tt::DotOp &dotOp) const { auto filter = [&dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); }; - mlir::BackwardSliceOptions bwdOpt; + BackwardSliceOptions bwdOpt; bwdOpt.omitBlockArguments = true; bwdOpt.filter = filter; SetVector slices; - mlir::getBackwardSlice(dotOp.getResult(), &slices, bwdOpt); + getBackwardSlice(dotOp.getResult(), &slices, bwdOpt); if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != slices.end()) @@ -226,81 +352,15 @@ class BlockedToMFMA : public mlir::RewritePattern { return false; } - /// @brief Choose MFMA instruction parameters - /// @param dot target dot operation - /// @return pair {mDim, nDim, kDim, kBase} sizes of one MFMA instruction - /// arguments - std::tuple - chooseMfmaDimensions(tt::DotOp dot) const { - // number of matrix elements along k dim per one MFMA intruction - unsigned kDim = 0; - auto opType = cast(dot.getA().getType()); - auto dataTypeA = opType.getElementType(); - auto dataTypeB = - cast(dot.getB().getType()).getElementType(); - - auto resType = cast(dot.getD().getType()); - auto resShape = resType.getShape(); - auto rank = resShape.size(); - auto M = resShape[rank - 2]; - auto N = resShape[rank - 1]; - - unsigned mDim = 0; - unsigned nDim = 0; - if (enforcedNonKDim != 0) { - mDim = enforcedNonKDim; - nDim = enforcedNonKDim; - } else { - int minSize = std::min(M, N); - if (minSize >= 32) { - mDim = 32; - nDim = 32; - } - if (minSize >= 16 && minSize < 32) { - mDim = 16; - nDim = 16; - } - if (minSize < 16) { - if (M < 16 && N >= 64) { - mDim = 4; - nDim = 64; - } else if (M >= 64 && N < 16) { - mDim = 64; - nDim = 4; - } else { - assert(opType.getShape()[rank - 1] >= 64 && - "k should be at least 64 to use this layout"); - mDim = 4; - nDim = 4; - } - } - } - assert(mDim != 0 && nDim != 0); - - auto maybeMfmaInsn = - MfmaInsn::selectMfma(mDim, nDim, dataTypeA, dataTypeB, mfmaVersion); - if (failed(maybeMfmaInsn)) - llvm::report_fatal_error("No match found in MFMA database\n"); - - kDim = maybeMfmaInsn->getKDim(); - unsigned kBase = maybeMfmaInsn->getKBase(); - - assert(kDim != 0); - - assert(M % mDim == 0 && N % nDim == 0); - assert(opType.getShape()[rank - 1] % kDim == 0); - return {mDim, nDim, kDim, kBase}; - } - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto dotOp = cast(op); - + LogicalResult matchAndRewrite(tt::DotOp dotOp, + PatternRewriter &rewriter) const override { RankedTensorType oldRetType = dotOp.getType(); if (!oldRetType.getEncoding() || !isa(oldRetType.getEncoding())) return failure(); + if (!isa_and_nonnull(dotOp.getType().getEncoding())) + return rewriter.notifyMatchFailure( + dotOp, "expected blocked encoding result tensor"); if (!supportMFMA(dotOp)) return failure(); @@ -309,7 +369,7 @@ class BlockedToMFMA : public mlir::RewritePattern { // get MFMA encoding for the given number of warps auto retShape = oldRetType.getShape(); - auto mod = op->getParentOfType(); + auto mod = dotOp->getParentOfType(); int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); // operands @@ -321,16 +381,21 @@ class BlockedToMFMA : public mlir::RewritePattern { ttg::AMDMfmaEncodingAttr mfmaEnc; - auto [mDim, nDim, kDim, kBase] = chooseMfmaDimensions(dotOp); + auto mfmaInstr = chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim); + auto mDim = mfmaInstr.value().getMDim(); + auto nDim = mfmaInstr.value().getNDim(); + auto kDim = mfmaInstr.value().getKDim(); + auto kBase = mfmaInstr.value().getKBase(); auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim}); - bool isTransposed = isChainDot(dotOp); + // Always use transposed mfma layout. This enables larger vectorization + // for global store instructions mfmaEnc = ttg::AMDMfmaEncodingAttr::get( oldRetType.getContext(), /*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile, - /*instrShape*/ mDim, nDim, isTransposed, CTALayout); + /*instrShape*/ mDim, nDim, /*isTransposed*/ true, CTALayout); Type mfmaAccType; if (oldRetType.getElementType().isIntOrIndex()) @@ -339,7 +404,7 @@ class BlockedToMFMA : public mlir::RewritePattern { mfmaAccType = rewriter.getF32Type(); // convert accumulator - auto oldAcc = dotOp.getOperand(2); + auto oldAcc = dotOp.getC(); auto newAcc = convertAndCastTensor(rewriter, oldAcc, mfmaEnc, mfmaAccType); // Here is a brief explanation of kWidth, kBase, and kDim @@ -372,7 +437,7 @@ class BlockedToMFMA : public mlir::RewritePattern { // in mfma 4x4 case argument matrix groups in 16 groups if (mDim == 4 && nDim == 4) kWidth = kDim / 16; - if (mDim == 4 && nDim == 64 || mDim == 64 && nDim == 4) + if ((mDim == 4 && nDim == 64) || (mDim == 64 && nDim == 4)) kWidth = kDim; // We want to extend kWidth by kPack (kPack=1 means no extension) @@ -382,14 +447,14 @@ class BlockedToMFMA : public mlir::RewritePattern { if (!isSecondDot(dotOp)) kWidth *= kPack; - auto newAType = RankedTensorType::get( - oldAType.getShape(), oldAType.getElementType(), - ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth)); - auto newBType = RankedTensorType::get( - oldBType.getShape(), oldBType.getElementType(), - ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth)); - a = rewriter.create(a.getLoc(), newAType, a); - b = rewriter.create(b.getLoc(), newBType, b); + auto newAEncoding = + ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth); + auto newBEncoding = + ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth); + a = convertAndCastTensor(rewriter, a, newAEncoding, + mfmaInstr.value().getElementTypeA()); + b = convertAndCastTensor(rewriter, b, newBEncoding, + mfmaInstr.value().getElementTypeB()); auto newDot = rewriter.create( dotOp.getLoc(), newAcc.getType(), a, b, newAcc, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); @@ -398,11 +463,12 @@ class BlockedToMFMA : public mlir::RewritePattern { convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(), oldRetType.getElementType()); - rewriter.replaceOp(op, dotOutput); + rewriter.replaceOp(dotOp, dotOutput); return success(); } }; + static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, Type promotedType) { Type tensorPromotedType = cast(operand.getType()) @@ -418,7 +484,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod) { OpBuilder builder(dotOp); Type AElType = dotOp.getA().getType().getElementType(); Type promoteType; - if (isa(D.getType().getEncoding())) { + if (isa(D.getType().getEncoding())) { Type BElType = dotOp.getB().getType().getElementType(); auto maxBitWidth = std::max(AElType.getIntOrFloatBitWidth(), @@ -435,7 +501,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod) { promoteType = builder.getF16Type(); else if (maxBitWidth <= 32) promoteType = builder.getF32Type(); - } else if (isa(D.getType().getEncoding())) { + } else if (isa(D.getType().getEncoding())) { Type BElType = dotOp.getB().getType().getElementType(); if (AElType == BElType) @@ -465,12 +531,10 @@ static void decomposeMixedModeDotOp(ModuleOp mod) { Type srcElType = vTy.getElementType(); return !srcElType.isUnsignedInteger() ? builder - .create(dotOp.getLoc(), - dstType, v) + .create(dotOp.getLoc(), dstType, v) .getResult() : builder - .create(dotOp.getLoc(), - dstType, v) + .create(dotOp.getLoc(), dstType, v) .getResult(); }; auto convertTensorFPToI = [&](Type dstElType, Value v) -> Value { @@ -478,12 +542,10 @@ static void decomposeMixedModeDotOp(ModuleOp mod) { Type dstType = vTy.cloneWith(std::nullopt, dstElType); return !dstElType.isUnsignedInteger() ? builder - .create(dotOp.getLoc(), - dstType, v) + .create(dotOp.getLoc(), dstType, v) .getResult() : builder - .create(dotOp.getLoc(), - dstType, v) + .create(dotOp.getLoc(), dstType, v) .getResult(); }; @@ -512,90 +574,93 @@ static void decomposeMixedModeDotOp(ModuleOp mod) { }); } -class BlockedToWMMA : public mlir::RewritePattern { +class BlockedToWMMA : public OpRewritePattern { + int wmmaVersion; + public: - BlockedToWMMA(mlir::MLIRContext *context) - : mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context) {} + BlockedToWMMA(MLIRContext *context, int wmmaVersion, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), wmmaVersion(wmmaVersion) {} - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto dotOp = cast(op); + LogicalResult matchAndRewrite(tt::DotOp dotOp, + PatternRewriter &rewriter) const override { + auto ctx = dotOp->getContext(); + + Value a = dotOp.getA(); + Value b = dotOp.getB(); auto oldRetType = cast(dotOp.getResult().getType()); - if (!oldRetType.getEncoding() || - !isa(oldRetType.getEncoding())) + auto oldRetEncoding = oldRetType.getEncoding(); + if (!oldRetEncoding || !isa(oldRetEncoding)) return failure(); - if (!supportWMMA(dotOp)) + auto oldAType = cast(a.getType()); + auto oldBType = cast(b.getType()); + auto retShape = oldRetType.getShape(); + auto aShape = oldAType.getShape(); + auto bShape = oldBType.getShape(); + + // check shape + auto mnkDim = ttg::AMDWmmaEncodingAttr::getMNKDimPerInstr(); + auto rank = aShape.size(); + if (aShape[rank - 2] % mnkDim[0] != 0 || // m + bShape[rank - 1] % mnkDim[1] != 0 || // n + aShape[rank - 1] % mnkDim[2] != 0) // k + return failure(); + + if (wmmaVersion == 2 && llvm::isa(oldAType) && + oldAType.getIntOrFloatBitWidth() == 8) { + return rewriter.notifyMatchFailure(dotOp, "not supported yet"); + } + + // get operand types + auto operandTypes = getOperandTypesForWmmaOp(rewriter, dotOp, wmmaVersion); + if (operandTypes.empty()) return failure(); // get WMMA encoding for the given number of warps - auto retShape = oldRetType.getShape(); - auto mod = op->getParentOfType(); + auto mod = dotOp->getParentOfType(); int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); - // operands - Value a = dotOp.getA(); - Value b = dotOp.getB(); - auto oldAType = cast(a.getType()); - auto oldBType = cast(b.getType()); - auto ctx = oldAType.getContext(); - - AMDWmmaEncodingAttr wmmaEnc; + ttg::AMDWmmaEncodingAttr wmmaEnc; - auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); auto warpsPerTile = warpsPerTileWMMA(dotOp, retShape, numWarps); - // Not supported yet - // if (retShape[0] < warpsPerTile[0] * mnkDim[0] || retShape[1] < - // warpsPerTile[1] * mnkDim[1]) - // return failure(); - auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding()); - wmmaEnc = AMDWmmaEncodingAttr::get(oldRetType.getContext(), warpsPerTile, - CTALayout); - - Type wmmaAccType; - auto oldRetElemType = oldRetType.getElementType(); - auto aElemType = oldAType.getElementType(); - auto bElemType = oldBType.getElementType(); - if (oldRetElemType.isIntOrIndex()) { - wmmaAccType = rewriter.getIntegerType(32); - } else if (isa(oldRetElemType) && - aElemType == oldRetElemType) { - wmmaAccType = oldRetElemType; - } else if (isa(oldRetElemType) && - aElemType.getIntOrFloatBitWidth() < 16) { - aElemType = rewriter.getF16Type(); - bElemType = rewriter.getF16Type(); - wmmaAccType = rewriter.getF16Type(); - } else { - wmmaAccType = rewriter.getF32Type(); - } - auto newRetType = RankedTensorType::get(retShape, wmmaAccType, wmmaEnc); + auto CTALayout = ttg::getCTALayout(oldRetEncoding); + wmmaEnc = ttg::AMDWmmaEncodingAttr::get(ctx, wmmaVersion, warpsPerTile, + CTALayout); + + auto newRetType = RankedTensorType::get(retShape, operandTypes[3], wmmaEnc); // convert accumulator - auto oldAcc = dotOp.getOperand(2); - auto newAcc = convertAndCastTensor(rewriter, oldAcc, wmmaEnc, wmmaAccType); - - auto newAType = RankedTensorType::get( - oldAType.getShape(), aElemType, - ttg::DotOperandEncodingAttr::get(ctx, 0, wmmaEnc, mnkDim[2])); - auto newBType = RankedTensorType::get( - oldBType.getShape(), bElemType, - ttg::DotOperandEncodingAttr::get(ctx, 1, wmmaEnc, mnkDim[2])); - - Value castedA = - convertAndCastTensor(rewriter, a, newAType.getEncoding(), aElemType); - Value castedB = - convertAndCastTensor(rewriter, b, newBType.getEncoding(), bElemType); + auto oldAcc = dotOp.getC(); + auto newAcc = + convertAndCastTensor(rewriter, oldAcc, wmmaEnc, operandTypes[2]); + + auto newAType = + RankedTensorType::get(aShape, operandTypes[0], + ttg::DotOperandEncodingAttr::get( + ctx, 0, wmmaEnc, + wmmaEnc.getSizePerThreadForOperand( + /*kWidth=*/0, /*opIdx=*/0)[rank - 1])); + auto newBType = + RankedTensorType::get(bShape, operandTypes[1], + ttg::DotOperandEncodingAttr::get( + ctx, 1, wmmaEnc, + wmmaEnc.getSizePerThreadForOperand( + /*kWidth=*/0, /*opIdx=*/1)[rank - 2])); + + Value castedA = convertAndCastTensor(rewriter, a, newAType.getEncoding(), + operandTypes[0]); + Value castedB = convertAndCastTensor(rewriter, b, newBType.getEncoding(), + operandTypes[1]); auto newDot = rewriter.create( dotOp.getLoc(), newRetType, castedA, castedB, newAcc, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); - Value dotOutput = convertAndCastTensor( - rewriter, newDot, oldRetType.getEncoding(), oldRetElemType); - rewriter.replaceOp(op, dotOutput); + Value dotOutput = convertAndCastTensor(rewriter, newDot, oldRetEncoding, + oldRetType.getElementType()); + rewriter.replaceOp(dotOp, dotOutput); return success(); } }; @@ -616,18 +681,24 @@ class TritonAMDGPUAccelerateMatmulPass this->kPack = kPack; } void runOnOperation() override { + MLIRContext *context = &getContext(); ModuleOp m = getOperation(); - mlir::RewritePatternSet patterns(context); - auto matrixCoreVer = getMatrixCoreVersion(archGenerationName); - if (MatrixCoreVersion::CDNA_MFMA1 == matrixCoreVer || - MatrixCoreVersion::CDNA_MFMA2 == matrixCoreVer || - MatrixCoreVersion::CDNA_MFMA3 == matrixCoreVer) { - patterns.add<::BlockedToMFMA>(context, getMfmaVersion(matrixCoreVer), + RewritePatternSet patterns(context); + switch (auto isaFamily = triton::AMD::deduceISAFamily(archGenerationName)) { + case ISAFamily::CDNA1: + case ISAFamily::CDNA2: + case ISAFamily::CDNA3: + patterns.add<::BlockedToMFMA>(context, getMfmaVersion(isaFamily), matrixInstructionSize, kPack); - } else if (matrixCoreVer == MatrixCoreVersion::RDNA_WMMA) { - patterns.add<::BlockedToWMMA>(context); + break; + case ISAFamily::RDNA3: + patterns.add<::BlockedToWMMA>(context, + getWmmaVersion(archGenerationName)); + break; + default: + break; } if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { signalPassFailure(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index d96860c3e..7da8083cf 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -1,8 +1,10 @@ add_triton_library(TritonAMDGPUTransforms AccelerateAMDMatmul.cpp + CanonicalizePointers.cpp + ConvertToBufferOps.cpp OptimizeEpilogue.cpp ReorderInstructions.cpp - StreamPipeline.cpp + StreamPipelineV2.cpp MfmaGroup.cpp DEPENDS diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp new file mode 100644 index 000000000..a5b32abfe --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -0,0 +1,1036 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" +#include + +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/Pass/Pass.h" + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h.inc" + +#define DEBUG_TYPE "tritonamdgpu-canonicalize-pointers" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; + +// ----------------------------------------------------------------------------- +// Pointer canonicalizer utility class +// ----------------------------------------------------------------------------- +// This class iterates through the argument of the `funcOp`, if the argument is +// a pointer, starts a walk through its transitive uses to build a in-memory +// data structure to record the current offset to that pointer. Only when the +// pointer is really loaded/stored we materialize the base pointer with the +// offset. +// +// Let's suppose that `arg0` is a pointer. The algorithm works like that: +// +// a) At the beginning the offset is a tensor initialized to zero, and we +// associate with `%arg0` a `FatPtr{basePtr=%arg0, offset=0}`. Through the +// algorithm `FatPtr.basePtr` represents the scalar base pointer (all the +// uniform updates will go into that) and `FatPtr.offset` represents the +// tensor offset (all the non-uniform updates will go into that) +// +// +// b) Follow the pointer through the IR. When we meet: +// `%ptr = tt.addptr(%arg0, %offset)` +// +// Isolate the uniform and the non-uniform contributions of %offset = +// (%u_offset, %nu_offset) and update the scalar pointer and the tensor +// offset +// ``` +// %s_ptr = addi(%fatPoniters[ptr].basePtr, %u_offset) +// %t_offset = addi(%fatPoniters[ptr].offset, %nu_offset) +// %fatPointers[%ptr0] = FatPtr{base=%s_ptr, offset=%t_offset} +// ``` +// c) When we meet the `tt.load(%ptr)` or `tt.store(%ptr)` instructions, +// replace that instruction with: +// `%t_ptr = tt.splat(%fatPointers[%ptr].basePtr) +// `%fat_ptr = tt.addptr(%t_ptr, %fatPointers[ptr].offset)` +// `%data = tt.load(%fat_ptr)` +// +// Please note that `%offset` might be a 32bit or 64bit integer. If +// we can, we would like to use 32 bit integers. This can happen under +// certain conditions: +// +// a) We can determine that the offset cannot overflow. In this case, we can +// downcast the pointer just before emitting the load +// b) We know that the underlying memory size can be expressed as a 32 bit +// value. In this case we can simply start with a 32bit offset and downcast +// if we ever meet 64 bit operations (because we know that the offset can be +// contained in 32 bits) +// +class PointerCanonicalizer { +public: + explicit PointerCanonicalizer(ModuleOp moduleOp) + : rewriter(moduleOp.getContext()), mod(moduleOp) {} + + // Propagate fat pointers in all the functions of the module + LogicalResult run(); + +private: + // A fat pointer is represented as `basePtr + offset` internally. + struct FatPtr { + // Scalar base pointer. Needs to be `tt.splat`ed before used + Value basePtr; + // Tensor offset + Value offset; + // Flag to express if we can narrow the uses of the offset down to 32 bits + bool canNarrow = false; + // Collection of attributes that need to be applied to the pointer + SmallVector attributes; + + // Utility copy functions + FatPtr copy(Value newBasePtr, Value newOffset) { + return FatPtr{newBasePtr, newOffset, canNarrow}; + }; + FatPtr copyWithBase(Value newOffset) { + return FatPtr{basePtr, newOffset, canNarrow}; + } + FatPtr copyWithOffset(Value newBase) { + return FatPtr{newBase, offset, canNarrow}; + } + // Attribute functions + void setAttr(NamedAttribute attr) { attributes.push_back(attr); } + void setAttrs(ArrayRef attrs) { + llvm::append_range(attributes, attrs); + } + }; + + // Rewrite any operation that needs a pointer + LogicalResult materializeFatPointer(Operation *op, Location loc, Value ptr); + + // Start from an argument of a function and propagate its fat pointers + LogicalResult rewritePointer(Value argPtr); + + // Create a tensor pointer from a fat pointer `fatPtr`. The tensor pointer is + // obtained by splatting the `fatPtr.basePtr` using the `fatPtr.offset` shape + // and adding the offset to it. + Value createTensorPointer(FatPtr fatPtr, Location loc); + + // Push the attributes of the given operation `op` to the fat pointer + // corresponding to `val` + void collectFatPointerAttributes(Operation *op, Value val); + + // Rewrite a given function, canonicalizing the different pointer arguments of + // the region + LogicalResult rewriteFunction(triton::FuncOp funcOp); + + // Rewriters for different operation a pointer can walk into + LogicalResult rewriteSplatOp(triton::SplatOp splatOp, Location curLoc, + Value &nextPtr); + LogicalResult rewriteBroadcastOp(triton::BroadcastOp broadcastOp, + Location curLoc, Value &nextPtr); + LogicalResult rewriteAddPtrOp(triton::AddPtrOp addPtrOp, Location curLoc, + Value &nextPtr); + LogicalResult rewriteForOp(scf::ForOp forOp, Location curLoc, + OpOperand *operand, Value &nextPtr); + LogicalResult rewriteYieldOp(scf::YieldOp yieldOp, Location curLoc, + OpOperand *operand, Value &nextPtr); + LogicalResult rewriteWhileOp(scf::WhileOp whileOp, Location curLoc, + OpOperand *operand, Value &nextPtr); + LogicalResult rewriteConditionOp(scf::ConditionOp conditionOp, + Location curLoc, OpOperand *operand, + Value &nextPtr); + LogicalResult rewriteCondBranchOp(cf::CondBranchOp condBrOp, Location curLoc, + OpOperand *operand, Value &nextPtr); + LogicalResult rewriteSelectOp(arith::SelectOp selectOp, Location curLoc, + OpOperand *operand, Value &nextPtr); + LogicalResult rewriteBranchOp(cf::BranchOp branchOp, Location curLoc, + OpOperand *operand, Value &nextPtr); + + // Perform simplified scalar extraction. An offset can be composed by Unifrom + // (U) and non-uniform(N) components. A uniform component is basically a + // tensor constant (or a splat). A NonUniform value is a `make_range` or + // whatever we multiply with a `make_range` operation. We consider the generic + // expressions: + // offset = (N+U)*(N+U) + // + // Where the `uniformOffset=U*U` and the `nonUniformOffset=(N*U+U*N+N*N). + // + // We do not consider any expression not involving * and +. + // + // The function accepts the `rewriter`, the `location` and start recursing at + // the given `expr`. + // + // We also pass the bitness of the offset. + // + // The function returns the two components of the given offset as a + // std::pair{U, NU} + std::pair decomposeOffsetFromExpr(Location loc, Value expr, + int64_t bitness); + std::pair decomposeOffsetFromAdd(Location loc, Value expr, + int64_t bitness); + std::pair decomposeOffsetFromMul(Location loc, Value expr, + int64_t bitness); + + // Return either the operation or its rewritten op + template + OpTy resolveOp(Operation *op, + const DenseMap &rewriteOpMap) { + OpTy resolvedOp = dyn_cast(op); + if (rewriteOpMap.contains(op)) + resolvedOp = dyn_cast(rewriteOpMap.at(op)); + return resolvedOp; + } + + mlir::IRRewriter rewriter; + ModuleOp mod; + + // Symbol table: association between pointers and fatPointers + llvm::MapVector pointers; + + void clearFunctionState() { + rewriteOpMap.clear(); + queue.clear(); + opToDelete.clear(); + } + + // This structure is used to point to the right operation during the traversal + // of a function + DenseMap rewriteOpMap; + + // Queue of operations to visit in the current function + SmallVector queue; + + // List of IR to delete in the current function + SetVector opToDelete; +}; + +namespace { + +// Extend a 32bit `offset` into 64bit using a arith.extsi operation +static Value extend32bitOffsetTo64Bits(IRRewriter &rewriter, Location loc, + Value offset) { + Type elementType = getElementTypeOrSelf(offset); + + if (auto tensorType = dyn_cast(offset.getType())) { + auto shape = tensorType.getShape(); + auto newTensorType = RankedTensorType::get(shape, rewriter.getI64Type(), + tensorType.getEncoding()); + return rewriter.create(loc, newTensorType, offset); + } + return rewriter.create(loc, rewriter.getI64Type(), offset); +} + +// Narrow a 64bit `offset` into 32bit using a arith.trunci operation +static Value narrow64bitOffsetTo32bits(IRRewriter &rewriter, Location loc, + Value offset) { + Type elementType = getElementTypeOrSelf(offset); + if (elementType.isInteger(32)) + return offset; + + if (auto tensorType = dyn_cast(offset.getType())) { + auto shape = tensorType.getShape(); + auto newTensorType = RankedTensorType::get(shape, rewriter.getI32Type(), + tensorType.getEncoding()); + return rewriter.create(loc, newTensorType, offset); + } + return rewriter.create(loc, rewriter.getI32Type(), offset); +} + +// Helper function to determine if the given `op` is a constant tensor and in +// that case return the scalar value. +Value getScalarConstant(IRRewriter &rewriter, Location loc, Value expr) { + Operation *op = expr.getDefiningOp(); + + // Check for splatness + if (auto splatOp = dyn_cast_or_null(op)) + return splatOp.getSrc(); + + // Check for constant + DenseIntElementsAttr constVal; + if (auto constOp = dyn_cast_or_null(op)) { + Value val = constOp.getResult(); + if (matchPattern(val, m_Constant(&constVal)) && constVal.isSplat()) + return rewriter.create( + loc, constVal.getSplatValue()); + } + + // Check for block arguments + if (auto blockArg = dyn_cast_or_null(expr)) { + Type type = blockArg.getType(); + if (!isa(type)) + return blockArg; + } + + return Value(); +} + +// Narrowing logic +// For now we allow to narrow down to 32 bits only in the following case: +// - `baseOffset` is 32-bits and `addOffset`(64-bits) is zero +bool canNarrowOffset(Value baseOffset, Value addOffset) { + Type addOffsetType = getElementTypeOrSelf(addOffset); + auto baseSplatOp = baseOffset.getDefiningOp(); + return baseSplatOp && addOffsetType.isInteger(32); +} + +// Create a zero tensor with a given `type` +Value createTensorZero(IRRewriter &rw, Location loc, RankedTensorType type) { + mlir::Attribute zeroAttr = rw.getZeroAttr(type.getElementType()); + auto zeroDenseAttr = DenseElementsAttr::get(type, zeroAttr); + return rw.create(loc, zeroDenseAttr); +} + +} // namespace + +void PointerCanonicalizer::collectFatPointerAttributes(Operation *op, + Value val) { + auto addBlockArgumentAttr = [&](BlockArgument arg) { + // If the value is a block parameter, the operation can specify + // an attribute for the given parameter by using `tt.property_argi` + // where `argi` refers to the arg number of the given parameter. + // So we need to iterate through the property, find the right one + // and push the property onto the pointers attributes. + llvm::SmallString<8> scratchStr; + for (NamedAttribute namedAttr : op->getAttrs()) { + scratchStr.clear(); + llvm::raw_svector_ostream sstream(scratchStr); + sstream << "_arg" << arg.getArgNumber(); + StringRef attrName = namedAttr.getName().getValue(); + if (attrName.ends_with(scratchStr)) { + StringRef newAttrName = attrName.drop_back(scratchStr.size()); + namedAttr.setName(rewriter.getStringAttr(newAttrName)); + pointers[val].setAttr(namedAttr); + // Propagate the argument to the offset if it is also a block argument + if (auto offsetArg = dyn_cast(pointers[val].offset)) { + scratchStr.clear(); + sstream << newAttrName << "_arg" << offsetArg.getArgNumber(); + op->setAttr(scratchStr, namedAttr.getValue()); + } + } + } + }; + + // If it is the i-th block argument, then look if the operation defined some + // _argi attribute and add it to the fat pointer attributes + if (auto arg = dyn_cast(val)) { + addBlockArgumentAttr(arg); + return; + } + + // Otherwise add the attributes of the operation to the fat pointer + for (NamedAttribute attr : op->getAttrs()) + pointers[val].setAttr(attr); +} + +// Offset extraction logic for an addition op: +// decompose(A+B) = {U(A)+U(B), NU(A)+NU(B)} +std::pair +PointerCanonicalizer::decomposeOffsetFromAdd(Location loc, Value expr, + int64_t bitness) { + auto addOp = expr.getDefiningOp(); + auto [uniformOffsetL, nonUniformOffsetL] = + decomposeOffsetFromExpr(loc, addOp.getLhs(), bitness); + auto [uniformOffsetR, nonUniformOffsetR] = + decomposeOffsetFromExpr(loc, addOp.getRhs(), bitness); + Value uniformAdd = + rewriter.create(loc, uniformOffsetL, uniformOffsetR); + Value nonUniformAdd = + rewriter.create(loc, nonUniformOffsetL, nonUniformOffsetR); + return {uniformAdd, nonUniformAdd}; +} + +// Offset extraction logic for a multiplication op: +// decompose(A*B) = {U(A)*U(B), NU(A)*NU(B)+NU(B)*U(A)+U(A)*NU(B)} +std::pair +PointerCanonicalizer::decomposeOffsetFromMul(Location loc, Value expr, + int64_t bitness) { + auto mulOp = expr.getDefiningOp(); + auto [uniformOffsetL, nonUniformOffsetL] = + decomposeOffsetFromExpr(loc, mulOp.getLhs(), bitness); + auto [uniformOffsetR, nonUniformOffsetR] = + decomposeOffsetFromExpr(loc, mulOp.getRhs(), bitness); + Value uniformMul = + rewriter.create(loc, uniformOffsetL, uniformOffsetR); + + Value uniformOffsetLSplat = rewriter.create( + loc, nonUniformOffsetL.getType(), uniformOffsetL); + Value uniformOffsetRSplat = rewriter.create( + loc, nonUniformOffsetR.getType(), uniformOffsetR); + + Value nonUNonU = + rewriter.create(loc, nonUniformOffsetL, nonUniformOffsetR); + Value nonUU = rewriter.create(loc, uniformOffsetLSplat, + nonUniformOffsetR); + Value uNonU = rewriter.create(loc, nonUniformOffsetL, + uniformOffsetRSplat); + + Value tmp = rewriter.create(loc, nonUNonU, nonUU); + Value nonUniformMul = rewriter.create(loc, tmp, uNonU); + return {uniformMul, nonUniformMul}; +} + +std::pair +PointerCanonicalizer::decomposeOffsetFromExpr(Location loc, Value expr, + int64_t bitness) { + + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfterValue(expr); + + // Base case 1: it is a splat. Return the scalar constant as the uniform part + if (Value scalarConst = getScalarConstant(rewriter, loc, expr)) { + auto tensorZero = + createTensorZero(rewriter, loc, cast(expr.getType())); + return {scalarConst, tensorZero}; + } + + // Base case 2: block argument. Since it is not a scalar constant, it must be + // a tensor. Note that this means we won't be able to decompose across loop + // boundaries (TODO: giuseros). + if (auto blockArg = dyn_cast(expr)) { + Value scalarZero = rewriter.create(loc, 0, bitness); + return std::make_pair(scalarZero, expr); + } + + auto offsets = + llvm::TypeSwitch>( + expr.getDefiningOp()) + .Case([&](auto broadcastOp) { + auto [uniform, nonUniform] = + decomposeOffsetFromExpr(loc, broadcastOp.getSrc(), bitness); + auto broadcastNonUniform = rewriter.create( + loc, broadcastOp.getType(), nonUniform); + return std::make_pair(uniform, broadcastNonUniform); + }) + .Case([&](auto expandOp) { + auto [uniform, nonUniform] = + decomposeOffsetFromExpr(loc, expandOp.getSrc(), bitness); + auto expandNonUniform = rewriter.create( + loc, nonUniform, expandOp.getAxis()); + return std::make_pair(uniform, expandNonUniform); + }) + .Case([&](Operation *op) { + return decomposeOffsetFromAdd(loc, expr, bitness); + }) + .Case([&](Operation *op) { + return decomposeOffsetFromMul(loc, expr, bitness); + }) + .Default([&](Operation *op) { + // Base case 3: it is not a supported operation. We assume no + // uniform part + Value scalarZero = + rewriter.create(loc, 0, bitness); + return std::make_pair(scalarZero, expr); + }); + + return offsets; +} + +Value PointerCanonicalizer::createTensorPointer(FatPtr fatPtr, Location loc) { + Value basePtr = fatPtr.basePtr; + Value offset = fatPtr.offset; + auto tensorType = dyn_cast(offset.getType()); + + // Scalar case: we only need to `tt.addptr %basePtr, %offset` + if (!tensorType) { + auto addPtrOp = rewriter.create(loc, basePtr.getType(), + basePtr, offset); + addPtrOp->setAttrs(fatPtr.attributes); + return addPtrOp.getResult(); + } + + // Tensor case: splat the scalar pointer and add the (tensor) offset: + // ``` + // %tensorBasePtr = tt.splat %basePtr + // %tensorPtr = tt.addptr %tensorBasePtr, %offset + // ``` + ArrayRef offsetShape = tensorType.getShape(); + auto tensorPtrType = RankedTensorType::get(offsetShape, basePtr.getType(), + tensorType.getEncoding()); + if (fatPtr.canNarrow) + offset = narrow64bitOffsetTo32bits(rewriter, loc, offset); + + Value tensorPtr = + rewriter.create(loc, tensorPtrType, basePtr); + + auto addPtrOp = + rewriter.create(loc, tensorPtrType, tensorPtr, offset); + + addPtrOp->setAttrs(fatPtr.attributes); + return addPtrOp.getResult(); +} + +// Rewrite a memory operation +LogicalResult PointerCanonicalizer::materializeFatPointer(Operation *op, + Location loc, + Value ptr) { + auto fatPtr = pointers[ptr]; + Value basePtr = fatPtr.basePtr; + Value offset = fatPtr.offset; + + // Create the tensor pointer (i.e., splat the base && add the offset) + Value newPtr = basePtr; + if (isa(ptr.getType())) + newPtr = createTensorPointer(fatPtr, loc); + + // Save the fat pointer in the table + pointers[newPtr] = fatPtr; + + // Map and replace the load + IRMapping mapper; + mapper.map(ptr, newPtr); + Operation *newOp = rewriter.clone(*op, mapper); + rewriter.replaceAllOpUsesWith(op, newOp); + opToDelete.insert(op); + return success(); +} + +LogicalResult PointerCanonicalizer::rewriteSplatOp(triton::SplatOp splatOp, + Location curLoc, + Value &nextPtr) { + nextPtr = splatOp.getResult(); + auto fatPtr = pointers[splatOp.getSrc()]; + auto outType = splatOp.getResult().getType(); + auto ptrShape = outType.getShape(); + auto newOffsetType = RankedTensorType::get(ptrShape, fatPtr.offset.getType(), + outType.getEncoding()); + Value offset = + rewriter.create(curLoc, newOffsetType, fatPtr.offset); + // The shape of the fat pointer is contained within the offset. We don't + // need to keep the `splat` operation here. + opToDelete.insert(splatOp); + pointers[nextPtr] = fatPtr.copy(splatOp.getSrc(), offset); + return success(); +} + +LogicalResult +PointerCanonicalizer::rewriteBroadcastOp(triton::BroadcastOp broadcastOp, + Location curLoc, Value &nextPtr) { + nextPtr = broadcastOp.getResult(); + auto fatPtr = pointers[broadcastOp.getSrc()]; + auto outType = dyn_cast(broadcastOp.getResult().getType()); + auto ptrShape = outType.getShape(); + auto offsetType = dyn_cast(fatPtr.offset.getType()); + if (!offsetType) + return failure(); + + opToDelete.insert(broadcastOp); + + auto newOffsetType = RankedTensorType::get( + ptrShape, offsetType.getElementType(), outType.getEncoding()); + Value offset = rewriter.create(curLoc, newOffsetType, + fatPtr.offset); + pointers[nextPtr] = fatPtr.copyWithBase(offset); + return success(); +} + +LogicalResult PointerCanonicalizer::rewriteAddPtrOp(triton::AddPtrOp addPtrOp, + Location curLoc, + Value &nextPtr) { + nextPtr = addPtrOp.getResult(); + auto fatPtr = pointers[addPtrOp.getPtr()]; + Value newPtr = fatPtr.basePtr; + // If it is a scalar pointer update, simply bump the base pointer + if (!isa(addPtrOp.getPtr().getType())) { + addPtrOp->setOperand(0, newPtr); + pointers[nextPtr] = fatPtr.copyWithOffset(nextPtr); + return success(); + } + Value offset = addPtrOp.getOffset(); + + // Early exit for the case of a constant tensor + if (Value scalarConst = getScalarConstant(rewriter, curLoc, offset)) { + newPtr = rewriter.create(curLoc, newPtr.getType(), newPtr, + scalarConst); + pointers[nextPtr] = fatPtr.copyWithOffset(newPtr); + // If we are updating the tensor pointer with a uniform value, we can + // propagate the attributes of the tensor pointer to the fat pointer. + pointers[nextPtr].setAttrs(fatPtr.attributes); + opToDelete.insert(addPtrOp); + return success(); + } + + int64_t bitness = + cast(offset.getType()).getElementTypeBitWidth(); + auto [uniformOffset, nonUniformOffset] = + decomposeOffsetFromExpr(curLoc, offset, bitness); + + // Scalar pointer update: bump the scalar pointer + newPtr = rewriter.create(curLoc, newPtr.getType(), newPtr, + uniformOffset); + + // Vector offset update (if any): bump the tensor offset + Value fatPtrOffset = fatPtr.offset; + bool canNarrow = fatPtr.canNarrow; + Value newOffset = fatPtrOffset; + bool propagateAtrs = true; + if (!isZeroConst(nonUniformOffset)) { + Type addPtrOffsetType = getElementTypeOrSelf(nonUniformOffset); + Type fatPtrOffsetType = getElementTypeOrSelf(fatPtrOffset); + canNarrow = canNarrow && canNarrowOffset(fatPtrOffset, nonUniformOffset); + + // Upcast or downcast the offset accordingly + if (addPtrOffsetType.isInteger(32) && fatPtrOffsetType.isInteger(64)) + nonUniformOffset = + extend32bitOffsetTo64Bits(rewriter, curLoc, nonUniformOffset); + else if (addPtrOffsetType.isInteger(64) && fatPtrOffsetType.isInteger(32)) + nonUniformOffset = + narrow64bitOffsetTo32bits(rewriter, curLoc, nonUniformOffset); + + newOffset = + rewriter.create(curLoc, nonUniformOffset, fatPtrOffset); + propagateAtrs = false; + } + opToDelete.insert(addPtrOp); + pointers[nextPtr] = FatPtr{newPtr, newOffset, canNarrow}; + + // If we are updating the tensor pointer with a uniform value, we can + // propagate the attributes of the tensor pointer to the fat pointer. + if (propagateAtrs) + pointers[nextPtr].setAttrs(fatPtr.attributes); + return success(); +} + +LogicalResult PointerCanonicalizer::rewriteForOp(scf::ForOp forOp, + Location curLoc, + OpOperand *curOperand, + Value &nextPtr) { + size_t operandNum = curOperand->getOperandNumber(); + FatPtr fatPtr = pointers[curOperand->get()]; + Value offset = fatPtr.offset; + Value basePtr = fatPtr.basePtr; + + // Replace the forOp with two additional argument (i.e., the curOperand's + // scalar pointer and the offset) + Value tensorPtr = createTensorPointer(fatPtr, curLoc); + auto newForOp = + replaceForOpWithNewSignature(rewriter, forOp, {basePtr, offset}); + rewriteOpMap[forOp] = newForOp; + + newForOp->setOperand(operandNum, tensorPtr); + OpOperand *forOperand = &newForOp->getOpOperand(operandNum); + // This is making sure we propagate the visit from the forOp result + nextPtr = newForOp.getTiedLoopResult(forOperand); + + // This is making sure we visit the uses within the forOp region + Value arg = newForOp.getTiedLoopRegionIterArg(forOperand); + size_t numIterArgs = newForOp.getNumRegionIterArgs(); + pointers[arg] = fatPtr.copy(newForOp.getRegionIterArg(numIterArgs - 2), + newForOp.getRegionIterArg(numIterArgs - 1)); + + // Collect attributes before continuing the visit + collectFatPointerAttributes(newForOp, arg); + + for (OpOperand &use : arg.getUses()) + queue.push_back(&use); + + // This is setting the fat pointer for the users of the loop + // and then propagate the result + size_t numResults = newForOp->getNumResults(); + pointers[nextPtr] = fatPtr.copy(newForOp->getResult(numResults - 2), + newForOp.getResult(numResults - 1)); + opToDelete.insert(forOp); + return success(); +} + +LogicalResult PointerCanonicalizer::rewriteYieldOp(scf::YieldOp yieldOp, + Location curLoc, + OpOperand *curOperand, + Value &nextPtr) { + + // Rewriting the yield op is a bit more complicated, because a + // yield op can be inside of a ForOp, WhileOp(in the AfterRegion) or + // IfOp + size_t operandNum = curOperand->getOperandNumber(); + FatPtr fatPtr = pointers[curOperand->get()]; + yieldOp.getResultsMutable().append(fatPtr.basePtr); + yieldOp.getResultsMutable().append(fatPtr.offset); + + if (auto forOp = dyn_cast(yieldOp->getParentOp())) { + yieldOp->setOperand(operandNum, forOp.getRegionIterArg(operandNum)); + } else if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { + // Case 1: the yieldOp is contained within an IfOp. One of the + // two branches is responsible to rewrite the operation. The other + // branch only update the yieldOp with the right parameters + Value tensorPtr = createTensorPointer(fatPtr, curLoc); + yieldOp->setOperand(operandNum, tensorPtr); + + if (yieldOp->getBlock() == &ifOp.getThenRegion().front()) { + auto newIfOp = replaceIfOpWithNewSignature( + rewriter, ifOp, {fatPtr.basePtr.getType(), fatPtr.offset.getType()}); + nextPtr = newIfOp.getResult(operandNum); + size_t numResults = newIfOp->getNumResults(); + pointers[nextPtr] = fatPtr.copy(newIfOp->getResult(numResults - 2), + newIfOp.getResult(numResults - 1)); + opToDelete.insert(ifOp); + } + + } else if (auto whileOp = resolveOp(yieldOp->getParentOp(), + rewriteOpMap)) { + // Case 2: the yieldOp is contained within the AfterRegion of a + // WhileOp. In this case, we know that the before region should have + // already been replaced (when we met the WhileOp), hence we can + // simply replace the WhileOp with a new AfterRegion (and hance a new + // set of return types) + auto newWhileOp = replaceWhileOpWithNewSignature( + rewriter, whileOp, {}, + {fatPtr.basePtr.getType(), fatPtr.offset.getType()}); + nextPtr = newWhileOp.getResult(operandNum); + size_t numResults = newWhileOp->getNumResults(); + pointers[nextPtr] = fatPtr.copy(newWhileOp->getResult(numResults - 2), + newWhileOp->getResult(numResults - 1)); + rewriteOpMap[whileOp] = newWhileOp; + opToDelete.insert(whileOp.getOperation()); + yieldOp.setOperand(operandNum, newWhileOp.getAfterArguments()[operandNum]); + } + return success(); +} + +LogicalResult PointerCanonicalizer::rewriteWhileOp(scf::WhileOp whileOp, + Location curLoc, + OpOperand *curOperand, + Value &nextPtr) { + // WhileOp rewrite happens in two phases: first rewrite the operand list + // and then rewrite the types when we meet the yieldOp + size_t operandNum = curOperand->getOperandNumber(); + FatPtr fatPtr = pointers[curOperand->get()]; + Value offset = fatPtr.offset; + Value basePtr = fatPtr.basePtr; + // Rewrite the while op with a new set of operands (but with the same + // set of return types) + Value tensorPtr = createTensorPointer(fatPtr, curLoc); + auto newWhileOp = + replaceWhileOpWithNewSignature(rewriter, whileOp, {basePtr, offset}, {}); + newWhileOp->setOperand(operandNum, tensorPtr); + Value arg = newWhileOp.getBeforeBody()->getArgument(operandNum); + // Propagate inside the BeforeRegion + size_t numArguments = newWhileOp.getBeforeBody()->getNumArguments(); + pointers[arg] = + fatPtr.copy(newWhileOp.getBeforeBody()->getArgument(numArguments - 2), + newWhileOp.getBeforeBody()->getArgument(numArguments - 1)); + nextPtr = arg; + rewriteOpMap[whileOp] = newWhileOp; + opToDelete.insert(whileOp); + return success(); +} + +// ConditionOp can only be contained within the BeforeRegion of a +// WhileOp. We already rewrote the WhileOp with the right operands, so +// we need only to add the offset the current operand to be the base +// pointer and continue the walk inside the AfterRegion +LogicalResult +PointerCanonicalizer::rewriteConditionOp(scf::ConditionOp conditionOp, + Location curLoc, OpOperand *curOperand, + Value &nextPtr) { + + size_t operandNum = curOperand->getOperandNumber(); + FatPtr fatPtr = pointers[curOperand->get()]; + Value offset = fatPtr.offset; + Value basePtr = fatPtr.basePtr; + auto whileOp = cast(conditionOp->getParentOp()); + + // Update the condition op + auto afterBlock = whileOp.getAfterBody(); + conditionOp.getArgsMutable().append({basePtr, offset}); + + // Propagate through the after region + afterBlock->addArgument(basePtr.getType(), curLoc); + afterBlock->addArgument(offset.getType(), curLoc); + nextPtr = afterBlock->getArgument(operandNum - 1); + size_t numArguments = afterBlock->getNumArguments(); + conditionOp.setOperand(operandNum, + whileOp.getRegionIterArgs()[operandNum - 1]); + pointers[nextPtr] = fatPtr.copy(afterBlock->getArgument(numArguments - 2), + afterBlock->getArgument(numArguments - 1)); + return success(); +} + +LogicalResult PointerCanonicalizer::rewriteCondBranchOp( + cf::CondBranchOp condBrOp, Location curLoc, OpOperand *curOperand, + Value &nextPtr) { + // CondBranchOp is a bit tricky to handle. Because we might be inserting + // the basePtr+offset as a TrueDestOperand(s), which is not the end of + // `condBrOp.getOperands()` + auto falseOperands = llvm::to_vector(condBrOp.getFalseDestOperands()); + auto trueOperands = llvm::to_vector(condBrOp.getTrueOperands()); + auto it = llvm::find(falseOperands, curOperand->get()); + bool isFalseOperand = (it != falseOperands.end()); + size_t operandNum = curOperand->getOperandNumber(); + + if (rewriteOpMap.contains(condBrOp)) { + // If we need to use a different condBrOp, we might also need to + // update `operandNum` + auto condBranchReplacement = + dyn_cast(rewriteOpMap[condBrOp]); + if (isFalseOperand) { + // basePtr+offset need to be added if we are on the FalseOperands + // side, but the true operands have been rewritten + bool needOffset = (condBranchReplacement.getTrueDestOperands().size() != + condBrOp.getTrueDestOperands().size()); + int maybeOffset = (needOffset ? 2 : 0); + operandNum += maybeOffset; + curOperand = &condBranchReplacement->getOpOperand(operandNum); + } + // Now we need to recompute the currentOperation and its {true,false} + // operands + falseOperands = + llvm::to_vector(condBranchReplacement.getFalseDestOperands()); + trueOperands = llvm::to_vector(condBranchReplacement.getTrueDestOperands()); + condBrOp = condBranchReplacement; + } + + // Now we can proceed almost normally + FatPtr fatPtr = pointers[curOperand->get()]; + Value offset = fatPtr.offset; + Value basePtr = fatPtr.basePtr; + + Block *falseDest = condBrOp.getFalseDest(); + Block *trueDest = condBrOp.getTrueDest(); + // Walk the destination block only if you don't have visited it yet + if (isFalseOperand) { + falseOperands.push_back(basePtr); + falseOperands.push_back(offset); + Value falseDestArg = + falseDest->getArgument(operandNum - condBrOp.getNumTrueOperands() - 1); + if (!pointers.contains(falseDestArg)) { + nextPtr = falseDestArg; + Value basePtrArg = falseDest->addArgument(basePtr.getType(), curLoc); + Value offsetArg = falseDest->addArgument(offset.getType(), curLoc); + pointers[nextPtr] = fatPtr.copy(basePtrArg, offsetArg); + } + } else { + trueOperands.push_back(basePtr); + trueOperands.push_back(offset); + Value trueDestArg = trueDest->getArgument(operandNum - 1); + if (!pointers.contains(trueDestArg)) { + nextPtr = trueDestArg; + Value basePtrArg = trueDest->addArgument(basePtr.getType(), curLoc); + Value offsetArg = trueDest->addArgument(offset.getType(), curLoc); + pointers[nextPtr] = fatPtr.copy(basePtrArg, offsetArg); + } + } + + // Create a new condBranch. We cannot simply extend the operands, + // because this would invalidate other operands pointing at the same + // cond branch + Value tensorPtr = createTensorPointer(fatPtr, curLoc); + auto newCondBranch = rewriter.create( + curLoc, condBrOp.getCondition(), trueDest, trueOperands, falseDest, + falseOperands); + + newCondBranch.setOperand(operandNum, tensorPtr); + rewriteOpMap[condBrOp] = newCondBranch; + opToDelete.insert(condBrOp); + return success(); +} + +LogicalResult PointerCanonicalizer::rewriteSelectOp(arith::SelectOp selectOp, + Location curLoc, + OpOperand *curOperand, + Value &nextPtr) { + Value trueVal = selectOp.getTrueValue(); + Value falseVal = selectOp.getFalseValue(); + Value cond = selectOp.getCondition(); + // If we didn't traverse both operands, simply materialize the pointer + if (!pointers.contains(trueVal) || !pointers.contains(falseVal)) + return materializeFatPointer(selectOp, curLoc, curOperand->get()); + + // If both have been traversed, then we can rewrite select of pointers as a + // select of base and offset + FatPtr fatPtrT = pointers[trueVal]; + FatPtr fatPtrF = pointers[falseVal]; + nextPtr = selectOp.getResult(); + + // Simple case of a scalar select: update the base pointer + if (!isa(selectOp.getType())) { + FatPtr fatPtr = pointers[trueVal]; + pointers[nextPtr] = fatPtr.copyWithOffset(nextPtr); + nextPtr = selectOp.getResult(); + return success(); + } + + // Rewrite `select` for base and offset + Value newBase = rewriter.create( + curLoc, cond, fatPtrT.basePtr, fatPtrF.basePtr); + Value newOffset = rewriter.create( + curLoc, cond, fatPtrT.offset, fatPtrF.offset); + assert(fatPtrT.canNarrow == fatPtrF.canNarrow); + + pointers[nextPtr] = fatPtrT.copy(newBase, newOffset); + opToDelete.insert(selectOp); + return success(); +} + +LogicalResult PointerCanonicalizer::rewriteBranchOp(cf::BranchOp branchOp, + Location curLoc, + OpOperand *curOperand, + Value &nextPtr) { + size_t operandNum = curOperand->getOperandNumber(); + FatPtr fatPtr = pointers[curOperand->get()]; + Value offset = fatPtr.offset; + Value basePtr = fatPtr.basePtr; + branchOp.getDestOperandsMutable().append({basePtr, fatPtr.offset}); + Value tensorPtr = createTensorPointer(fatPtr, curLoc); + branchOp->setOperand(operandNum, tensorPtr); + Block *dest = branchOp.getDest(); + + // Walk the destination block only if you don't have visited it yet + if (!pointers.contains(dest->getArgument(operandNum))) { + Value basePtrArg = dest->addArgument(basePtr.getType(), curLoc); + Value offsetArg = dest->addArgument(offset.getType(), curLoc); + nextPtr = dest->getArgument(operandNum); + pointers[nextPtr] = {basePtrArg, offsetArg, fatPtr.canNarrow}; + } + return success(); +} + +// Start from an argument of a function and propagate its +// fat pointers +LogicalResult PointerCanonicalizer::rewritePointer(Value argPtr) { + // Start the visit + for (OpOperand &use : argPtr.getUses()) + queue.push_back(&use); + + while (!queue.empty()) { + OpOperand *curOperand = queue.pop_back_val(); + Operation *curOp = curOperand->getOwner(); + Location curLoc = curOp->getLoc(); + + rewriter.setInsertionPoint(curOp); + LogicalResult res = success(); + Value nextPtr; + // We need to propagate the fat pointer throughout the IR + llvm::TypeSwitch(curOp) + .Case([&](auto splatOp) { + res = rewriteSplatOp(splatOp, curLoc, nextPtr); + }) + .Case([&](auto broadcastOp) { + res = rewriteBroadcastOp(broadcastOp, curLoc, nextPtr); + }) + .Case([&](auto addPtrOp) { + res = rewriteAddPtrOp(addPtrOp, curLoc, nextPtr); + }) + .Case([&](auto forOp) { + res = rewriteForOp(resolveOp(forOp, rewriteOpMap), curLoc, + curOperand, nextPtr); + }) + .Case([&](auto yieldOp) { + res = rewriteYieldOp(yieldOp, curLoc, curOperand, nextPtr); + }) + .Case([&](auto whileOp) { + res = rewriteWhileOp(resolveOp(whileOp, rewriteOpMap), + curLoc, curOperand, nextPtr); + }) + .Case([&](auto conditionOp) { + res = rewriteConditionOp(conditionOp, curLoc, curOperand, nextPtr); + }) + .Case([&](auto condBrOp) { + res = rewriteCondBranchOp(condBrOp, curLoc, curOperand, nextPtr); + }) + .Case([&](auto selectOp) { + res = rewriteSelectOp(selectOp, curLoc, curOperand, nextPtr); + }) + .Case([&](auto branchOp) { + res = rewriteBranchOp(branchOp, curLoc, curOperand, nextPtr); + }) + .Case([&](Operation *op) { + res = materializeFatPointer(curOp, curLoc, op->getOperand(0)); + }) + .Default([&](Operation *op) { + // If we meet an unsupported operation, materialize the fat pointer + // and continue. + LDBG("Unknown op during pointer canonicalization: " << *curOp); + res = materializeFatPointer(op, curLoc, curOperand->get()); + }); + + // Collect the attributes and Keep propagating the fat pointer down the IR + if (nextPtr) { + collectFatPointerAttributes(curOp, nextPtr); + for (OpOperand &use : nextPtr.getUses()) + if (!opToDelete.contains(use.getOwner())) + queue.push_back(&use); + } + } + return success(); +} + +LogicalResult PointerCanonicalizer::rewriteFunction(triton::FuncOp funcOp) { + Region ®ion = funcOp.getRegion(); + for (auto [idx, arg] : llvm::enumerate(region.getArguments())) { + // The pointer argument needs to be a scalar + if (!isa(arg.getType())) + continue; + int64_t bitness = 64; + if (IntegerAttr pointerRangeAttr = + funcOp.getArgAttrOfType(idx, "tt.pointer_range")) + bitness = pointerRangeAttr.getInt(); + + rewriter.setInsertionPointToStart(®ion.front()); + Value zeroOffset = + rewriter.create(region.getLoc(), 0, bitness); + + // Start the rewrite + clearFunctionState(); + pointers[arg] = FatPtr{arg, zeroOffset, true}; + if (failed(rewritePointer(arg))) + return failure(); + + // Clean-up: don't assume the operation to delete are in the correct order, + // but force dropping the reference of the ops before we delete them + for (Operation *op : opToDelete) { + op->dropAllReferences(); + op->dropAllDefinedValueUses(); + rewriter.eraseOp(op); + } + } + return success(); +} + +LogicalResult PointerCanonicalizer::run() { + llvm::SmallVector funcOps; + + // For now we don't cross function boundaries, but we should do that whenever + // is possible + mod.walk([&](triton::FuncOp funcOp) { funcOps.push_back(funcOp); }); + + for (triton::FuncOp funcOp : funcOps) { + if (failed(rewriteFunction(funcOp))) + return failure(); + } + return success(); +} +// This pass is calling the pointer canonicalization utility +// on the given MLIR module +class TritonAMDGPUCanonicalizePointersPass + : public TritonAMDGPUCanonicalizePointersBase< + TritonAMDGPUCanonicalizePointersPass> { +public: + TritonAMDGPUCanonicalizePointersPass() = default; + + void runOnOperation() override { + ModuleOp m = getOperation(); + if (failed(PointerCanonicalizer(m).run())) + signalPassFailure(); + } +}; + +std::unique_ptr mlir::createTritonAMDGPUCanonicalizePointersPass() { + return std::make_unique(); +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp new file mode 100644 index 000000000..f1d922041 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -0,0 +1,260 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/TypeSwitch.h" +#include +#include + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" + +#define DEBUG_TYPE "tritonamdgpu-convert-buffer-ops" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace ttg = mlir::triton::gpu; +namespace tt = mlir::triton; + +namespace { +bool verifyNonNegativeByAssumption(Value expr, + const DenseSet &assumptions) { + for (Value assume : assumptions) { + LDBG("Assumption:" << assume); + if (auto cmpOp = assume.getDefiningOp()) { + bool isGreaterThan = (cmpOp.getPredicate() == arith::CmpIPredicate::sge || + cmpOp.getPredicate() == arith::CmpIPredicate::sgt); + APInt cst; + if (isGreaterThan && (cmpOp.getLhs() == expr) && + matchPattern(cmpOp.getRhs(), m_ConstantInt(&cst))) { + return cst.isNonNegative(); + } + } + } + return false; +} + +bool verifyNonNegativeExpr(Value expr, const DenseSet &assumptions) { + + // Check if the expression is contained in any assumption + if (verifyNonNegativeByAssumption(expr, assumptions)) { + LDBG("Non negative by assumption"); + return true; + } + + // Recurse if the operation is defined + Operation *op = expr.getDefiningOp(); + if (!op) + return false; + + bool nonNegative = + llvm::TypeSwitch(expr.getDefiningOp()) + .Case([&](auto broadcastOp) { + return verifyNonNegativeExpr(broadcastOp.getSrc(), assumptions); + }) + .Case([&](auto expandOp) { + return verifyNonNegativeExpr(expandOp.getSrc(), assumptions); + }) + .Case([&](auto splatOp) { + return verifyNonNegativeExpr(splatOp.getSrc(), assumptions); + }) + .Case([&](auto makeRangeOp) { + return makeRangeOp.getStart() >= 0 && makeRangeOp.getEnd() >= 0; + }) + .Case( + [&](auto constIntOp) { return constIntOp.value() >= 0; }) + .Case([&](arith::ConstantOp constOp) { + Value val = constOp.getResult(); + DenseIntElementsAttr constVal; + if (matchPattern(val, m_Constant(&constVal)) && constVal.isSplat()) + return constVal.getSplatValue().isNonNegative(); + return false; + }) + .Case([&](auto pidOp) { return true; }) + .Case([&](auto maxOp) { + // max(a,b) >= 0 iff a>=0 || b>=0 + bool nnLhs = verifyNonNegativeExpr(maxOp.getLhs(), assumptions); + bool nnRhs = verifyNonNegativeExpr(maxOp.getRhs(), assumptions); + return nnLhs || nnRhs; + }) + .Case([&](auto remsiOp) { + // a % b >= 0 iff a>=0 + return verifyNonNegativeExpr(remsiOp.getLhs(), assumptions); + }) + .Case([&](Operation *unaryOp) { + // a = OP b >= 0 iff b >= 0 + return verifyNonNegativeExpr(unaryOp->getOperand(0), assumptions); + }) + .Case( + // Generally speaking, a OP b >= 0 iff a >= 0 && b >= 0 when + // OP != sub + [&](Operation *binOp) { + bool nnLhs = + verifyNonNegativeExpr(binOp->getOperand(0), assumptions); + bool nnRhs = + verifyNonNegativeExpr(binOp->getOperand(1), assumptions); + return nnLhs && nnRhs; + }) + .Default([&](Operation *op) { + // Conservatively assume that the expression is negative + return false; + }); + return nonNegative; +} + +// Quick analysis on the Triton IR to decide if we can safely use +// buffer operations +bool canUseBufferOps(Value ptr, const DenseSet &assumptions) { + // 1. Check if the pointer is uniform: i.e., if it comes from a uniform + // pointer(splatted) and non-uniform offset addition + + LDBG("Buffer op checks for: " << ptr); + auto addPtrOp = ptr.getDefiningOp(); + if (!addPtrOp) + return false; + + auto maybeSplatOp = addPtrOp.getPtr().getDefiningOp(); + if (!maybeSplatOp) + return false; + LDBG("Pattern matched"); + + // 2. Check if the offset is a 32-bit tensor + Value offset = addPtrOp.getOffset(); + if (cast(offset.getType()).getElementTypeBitWidth() != 32) + return false; + LDBG("32 bit offset"); + + // 3. Check if the offset is non-negative + if (!verifyNonNegativeExpr(offset, assumptions)) + return false; + + LDBG("Non-negative"); + return true; +} +} // namespace + +struct ConvertTritonLoadToBufferLoad + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ConvertTritonLoadToBufferLoad(mlir::MLIRContext *context, + DenseSet &assumptions) + : mlir::OpRewritePattern(context), + assumptions(assumptions) {} + + mlir::LogicalResult + matchAndRewrite(triton::LoadOp op, PatternRewriter &rewriter) const override { + LDBG("Try to convert: " << op); + Value ptr = op.getPtr(); + + if (op.getCache() != triton::CacheModifier::NONE) + return failure(); + + if (canUseBufferOps(ptr, assumptions)) { + auto addPtrOp = ptr.getDefiningOp(); + Value tensorPtr = addPtrOp.getPtr(); + Value tensorOffset = addPtrOp.getOffset(); + auto splatOp = tensorPtr.getDefiningOp(); + Value basePtr = splatOp.getSrc(); + Value maybeOther{}; + if (op.getOther() && !isZeroConst(op.getOther())) + maybeOther = op.getOther(); + Value maybeMask{}; + if (op.getMask() && !isZeroConst(op.getMask())) + maybeMask = op.getMask(); + rewriter.replaceOpWithNewOp( + op, op.getType(), basePtr, tensorOffset, maybeMask, maybeOther); + return success(); + } + LDBG("Failed to convert: " << op); + return failure(); + } + +private: + // Assumptions collected through the function + DenseSet assumptions; +}; + +struct ConvertTritonStoreToBufferStore + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ConvertTritonStoreToBufferStore(mlir::MLIRContext *context, + DenseSet &assumptions) + : mlir::OpRewritePattern(context), + assumptions(assumptions) {} + + mlir::LogicalResult + matchAndRewrite(triton::StoreOp op, + PatternRewriter &rewriter) const override { + LDBG("Try to convert: " << op); + Value ptr = op.getPtr(); + + if (op.getCache() != triton::CacheModifier::NONE) + return failure(); + + if (canUseBufferOps(ptr, assumptions)) { + auto addPtrOp = ptr.getDefiningOp(); + Value tensorPtr = addPtrOp.getPtr(); + Value tensorOffset = addPtrOp.getOffset(); + auto splatOp = tensorPtr.getDefiningOp(); + Value basePtr = splatOp.getSrc(); + Value maybeMask{}; + if (op.getMask() && !isZeroConst(op.getMask())) + maybeMask = op.getMask(); + rewriter.replaceOpWithNewOp( + op, op.getValue(), basePtr, tensorOffset, maybeMask); + return success(); + } + LDBG("Failed to convert: " << op); + return failure(); + } + +private: + // Assumptions collected through the function + DenseSet assumptions; +}; + +class TritonAMDGPUConvertToBufferOpsPass + : public TritonAMDGPUConvertToBufferOpsBase< + TritonAMDGPUConvertToBufferOpsPass> { + +public: + TritonAMDGPUConvertToBufferOpsPass() = default; + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + // Collect assumptions in the function + DenseSet assumptions; + m.walk([&](LLVM::AssumeOp op) { + if (op->getOperand(0).getDefiningOp()) + assumptions.insert(op->getOperand(0)); + }); + LDBG("Number of assumptions found: " << assumptions.size()); + + patterns.add(context, assumptions); + patterns.add(context, assumptions); + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +std::unique_ptr mlir::createTritonAMDGPUConvertToBufferOpsPass() { + return std::make_unique(); +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp index 9207d1558..d3b2b70f8 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp @@ -2,7 +2,8 @@ namespace mlir { -static MfmaTypeId convertTypesToId(mlir::Type dataTypeA, mlir::Type dataTypeB) { +static MfmaTypeId chooseAppropriateMfmaId(mlir::Type dataTypeA, + mlir::Type dataTypeB) { if (dataTypeA.isF32() && dataTypeB.isF32()) { return MfmaTypeId::Fp32TyId; } @@ -27,6 +28,9 @@ static MfmaTypeId convertTypesToId(mlir::Type dataTypeA, mlir::Type dataTypeB) { if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) { return MfmaTypeId::Bf8Bf8TyId; } + if (dataTypeA.isFloat8E5M2() && dataTypeB.isFloat8E5M2()) { + return MfmaTypeId::Fp16TyId; + } llvm_unreachable("Unsupported input argument type."); } @@ -205,16 +209,48 @@ auto getMfmaInsnGroupAttrMap = []() -> const MfmaInsnGroupMap & { return MfmaInsnMap; }; +std::pair TypesFromMfmaId(mlir::MLIRContext *ctx, + MfmaTypeId id) { + auto f8e5m2 = Float8E5M2Type::get(ctx); + auto f8e4m3fnuz = Float8E4M3FNUZType::get(ctx); + auto f8e5m2fnuz = Float8E5M2FNUZType::get(ctx); + auto f16 = Float16Type::get(ctx); + auto bf16 = BFloat16Type::get(ctx); + auto f32 = Float32Type::get(ctx); + auto i8 = IntegerType::get(ctx, 8, IntegerType::Signed); + switch (id) { + case MfmaTypeId::Fp32TyId: + return {f32, f32}; + case MfmaTypeId::Fp16TyId: + return {f16, f16}; + case MfmaTypeId::Bf16TyId: + return {bf16, bf16}; + case MfmaTypeId::I8TyId: + return {i8, i8}; + case MfmaTypeId::Fp8Fp8TyId: + return {f8e4m3fnuz, f8e4m3fnuz}; + case MfmaTypeId::Fp8Bf8TyId: + return {f8e4m3fnuz, f8e5m2fnuz}; + case MfmaTypeId::Bf8Fp8TyId: + return {f8e5m2fnuz, f8e4m3fnuz}; + case MfmaTypeId::Bf8Bf8TyId: + return {f8e5m2fnuz, f8e5m2fnuz}; + } + assert(false && "unsupported MfmaTypeId"); +} + FailureOr MfmaInsn::selectMfma(unsigned mDim, unsigned nDim, Type elementTypeA, Type elementTypeB, int mfmaVersion) { auto mfmaInsnAttrMap = getMfmaInsnGroupAttrMap(); - MfmaInsnGroupSelectKey key = { - mDim, nDim, convertTypesToId(elementTypeA, elementTypeB), mfmaVersion}; + MfmaTypeId mfmaId = chooseAppropriateMfmaId(elementTypeA, elementTypeB); + MfmaInsnGroupSelectKey key = {mDim, nDim, mfmaId, mfmaVersion}; auto it = mfmaInsnAttrMap.find(key); if (it == mfmaInsnAttrMap.end()) return failure(); - return MfmaInsn(elementTypeA, elementTypeB, (*it).second); + auto [instrElementTypeA, instrElementTypeB] = + TypesFromMfmaId(elementTypeA.getContext(), mfmaId); + return MfmaInsn(instrElementTypeA, instrElementTypeB, it->second); } MfmaInsn::MfmaInsn(Type elementTypeA, Type elementTypeB, @@ -226,4 +262,6 @@ unsigned MfmaInsn::getMDim() { return attr.m; } unsigned MfmaInsn::getNDim() { return attr.n; } StringRef MfmaInsn::getInsnName() { return attr.insn; } unsigned MfmaInsn::getKBase() { return attr.kBase; } +Type MfmaInsn::getElementTypeA() { return elementTypeA; } +Type MfmaInsn::getElementTypeB() { return elementTypeB; } } // namespace mlir diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp index 6c6475b75..f2818297f 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp @@ -100,6 +100,8 @@ class BypassEpilogueSMEM : public mlir::RewritePattern { llvm::SmallVector chainedOps; while (true) { auto chainedOp = val.getDefiningOp(); + if (!chainedOp) + return mlir::failure(); if (llvm::isa(chainedOp)) break; if (!chainedOp->hasOneUse()) @@ -110,7 +112,7 @@ class BypassEpilogueSMEM : public mlir::RewritePattern { chainedOps.push_back(chainedOp); } - auto cvtOp = dyn_cast(val.getDefiningOp()); + auto cvtOp = val.getDefiningOp(); if (!cvtOp) return mlir::failure(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index f9fac1bf5..e122f15fd 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -2,86 +2,251 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dominance.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/Passes.h" -#include "mlir/Transforms/RegionUtils.h" -#include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Passes.h" -#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include + #define GEN_PASS_CLASSES #include "TritonAMDGPUTransforms/Passes.h" using namespace mlir; +namespace ttg = mlir::triton::gpu; +namespace tt = mlir::triton; -static bool willIncreaseRegisterPressure(Operation *op) { - if (isa(op)) - return true; - auto cvt = dyn_cast(op); - if (!cvt) - return false; - if (isa(cvt.getType().getEncoding())) +static bool isLocalLoadOrDotLayoutConversion(Operation *op) { + if (isa(op)) return true; + if (auto cvt = dyn_cast(op)) + return isa(cvt.getType().getEncoding()); return false; } +// Search through block to find earliest insertion point for move op. This can +// be either an atomic op or last usage of source pointer. Search ends when move +// op is encountered. +static llvm::ilist::iterator +findEarlyInsertionPoint(Block *block, Operation *move) { + Value src; + if (auto ld = dyn_cast(move)) + src = ld.getPtr(); + + auto ipnt = block->end(); + for (auto bi = block->begin(); bi != block->end(); ++bi) { + auto *op = &*bi; + if (op == move) // Don't move later than current location + break; + + op->walk([&](Operation *wop) { + if (src) { + // Check for ops accessing src value. + for (auto opr : wop->getOperands()) { + if (opr == src) + ipnt = bi; + } + } + // Atomics used for global synchronization. + if (isa(wop)) + ipnt = bi; + // Break at barrier + if (isa(wop)) + ipnt = bi; + // Break at loops. + if (isa(wop)) + ipnt = bi; + }); + } + return ipnt; +} + +// Check if the operation opInsideLoop is inside any scf::ForOp and +// opOutsideLoop is not inside the same loop. +bool isCrossLoopBoundary(mlir::Operation *opInsideLoop, + mlir::Operation *opOutsideLoop) { + scf::ForOp parentForOp = opInsideLoop->getParentOfType(); + return parentForOp && !parentForOp->isAncestor(opOutsideLoop); +} + class TritonAMDGPUReorderInstructionsPass : public TritonAMDGPUReorderInstructionsBase< TritonAMDGPUReorderInstructionsPass> { public: TritonAMDGPUReorderInstructionsPass() = default; + Operation *getFirstUse(Operation *op) { + std::vector users; + for (auto user : op->getUsers()) { + if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) + users.push_back(ancestor); + } + auto minOpIt = std::min_element(users.begin(), users.end(), + [](mlir::Operation *a, mlir::Operation *b) { + return a->isBeforeInBlock(b); + }); + return minOpIt != users.end() ? *minOpIt : nullptr; + } + void runOnOperation() override { ModuleOp m = getOperation(); - mlir::DominanceInfo dom(m); - // Sink conversions into loops when they will increase - // register pressure + + // Sink shared memory loads and layout conversions into loops to decrease + // register pressure when possible. DenseMap opToMove; - auto moveAfter = [](Operation *lhs, Operation *rhs) { - lhs->moveAfter(rhs); - }; m.walk([&](Operation *op) { - if (!willIncreaseRegisterPressure(op)) + if (!isLocalLoadOrDotLayoutConversion(op)) return; - auto user_begin = op->user_begin(); - auto user_end = op->user_end(); - if (std::distance(user_begin, user_end) != 1) + if (!op->hasOneUse()) return; - if (user_begin->getParentOfType() == + Operation *user = *op->getUsers().begin(); + if (user->getParentOfType() == op->getParentOfType()) return; - opToMove.insert({op, *user_begin}); + opToMove.insert({op, user}); }); for (auto &kv : opToMove) kv.first->moveBefore(kv.second); - // Move LocalLoadOp and LocalAllocOp immediately after their operands. - m.walk([&](Operation *op) { - if (!isa(op)) { + opToMove.clear(); + + // Adjust the placement of LDS writes and reads to immediately follow the + // definition of their operands in case where LDS write is in the + // loop but it's operand is not. This is a heuristic for optimizing fused + // attention by hoisting Q tensor LDS read/write operations outside of the + // loop, as Q is a loop invariant and can be loaded once before entering the + // loop. + // There are two possible patterns for this adjustment depending on + // whether the write to LDS is performed using an optional `local_alloc` + // argument or a `local_store` instruction. + // + // clang-format off + // + // 1) %1 = some_op ... (typically a load or an operation that scales the tensor after loading) + // %2 = local_alloc %1 + // %3 = local_load %2 + // + // 2) %1 = some_op ... + // %2 = local_alloc + // %3 = local_store %1, %2 + // %4 = local_load %2 + // + // clang-format on + m.walk([&](ttg::LocalLoadOp localLoad) { + auto localAlloc = localLoad.getSrc().getDefiningOp(); + if (!localAlloc) + return; + + // Case when localAlloc has operands + if (localAlloc->getNumOperands() == 1) { + if (!localAlloc->hasOneUse()) + return; + + auto srcTensorOp = localAlloc->getOperand(0).getDefiningOp(); + // Check if localAlloc is in the loop but it's src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) { + return; + } + + localAlloc->moveAfter(srcTensorOp); + localLoad->moveAfter(localAlloc); return; } - Operation *argOp = op->getOperand(0).getDefiningOp(); - if (!argOp) + + // Case when localAlloc has no operands + assert(localAlloc->getNumOperands() < 1); + auto allocVal = localAlloc->getResult(0); + + // Check if the localAlloc has exactly two uses (localStore and localLoad) + int numUses = std::distance(allocVal.use_begin(), allocVal.use_end()); + if (numUses != 2) + return; + + // localStore comes before localLoad in block. + Operation *localStore = getFirstUse(localAlloc); + if (!isa(localStore)) + return; + + auto srcTensorOp = localStore->getOperand(0).getDefiningOp(); + // Check if localStore is in the loop but it's src tensor defining op is + // outside of it. + if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) { return; - moveAfter(op, argOp); + } + + localAlloc->moveAfter(srcTensorOp); + localStore->moveAfter(localAlloc); + localLoad->moveAfter(localStore); }); - // Move transpositions just after their definition - opToMove.clear(); + + // Sink conversion after the last dealloc but before the first use ancestor + // in its block. This helps to avoid unnecessary shared memory allocation. + m.walk([&](triton::gpu::ConvertLayoutOp op) { + auto curr = mlir::Block::iterator(op); + for (; &*curr != getFirstUse(op); curr++) + if (isa(&*curr)) + op->moveAfter(&*curr); + }); + + // Move transpositions just after their definition. m.walk([&](triton::TransOp op) { - Operation *argOp = op.getSrc().getDefiningOp(); - if (!argOp) - return; - moveAfter(op, argOp); + if (Operation *argOp = op.getSrc().getDefiningOp()) + op->moveAfter(argOp); }); - return; + + SmallVector moveOps; + // Move global loads early to prefetch. This may increase register pressure + // but it enables issuing global loads early. + m.walk([&](triton::LoadOp op) { moveOps.push_back(op); }); + // Move local_stores early if dependence distance greater than + // one iteration. + // Best perf on GEMM when these precede global loads. + m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); }); + + for (auto op : llvm::reverse(moveOps)) { + // Gather use-def chain in block. + Block *block = op->getBlock(); + bool leadsToLoad = false; + SetVector backwardSet; + + BackwardSliceOptions options; + options.omitBlockArguments = true; + options.inclusive = false; + options.filter = [&](Operation *defOp) -> bool { + Block *defBlock = defOp->getBlock(); + if (!block->findAncestorOpInBlock(*defOp)) + return false; + // Check for a `load` dependent path. + leadsToLoad |= isa(defOp); + // Only move ops residing in the same block. + return defBlock == block; + }; + mlir::getBackwardSlice(op, &backwardSet, options); + backwardSet.insert(op); + + // Don't move a local_store if its source is a load from + // the same iteration. + if (isa(op) && leadsToLoad) + continue; + + auto ipoint = findEarlyInsertionPoint(block, op); + // Remove ops that already precede the insertion point. This is done + // before moves happen to avoid `Operation::isBeforeInBlock` N^2 + // complexity. + + SmallVector dfg = backwardSet.takeVector(); + if (ipoint != block->end()) { + // Move ops to insertion point. + llvm::erase_if( + dfg, [&](Operation *op) { return !ipoint->isBeforeInBlock(op); }); + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveAfter(block, ipoint); + } else { + // Move ops to block begin. + for (auto *dfgop : llvm::reverse(dfg)) + dfgop->moveBefore(block, block->begin()); + } + } } }; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp deleted file mode 100644 index 6f9ed6a23..000000000 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ /dev/null @@ -1,860 +0,0 @@ -#include "TritonAMDGPUTransforms/Passes.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include "llvm/ADT/MapVector.h" - -//===----------------------------------------------------------------------===// -// This file implements stream software pipelining for loops. The implementation -// here is inspired by the pipeline pass in Triton and the rocMLIR pipeliner. -// -// We divide the loop body into the following phases: -// a. Pre-load operations: for instance, index computation. -// b. Load operations: loading from global memory to shared memory. -// c. Compute operations: for instance, Triton dot. -// d. Post-load operations: for instance, index computation. -// -// To pipeline the loop, we need to: -// - Find all the dependencies of the load operations. -// - Prologue: Hoist the pipelinable load operations and shared memory store -// for the ramp up stage -// - Pipelined Loop: Assemble the loop body minus last iteration -// - Prefetch next tile from global into regs (while computing from previous) -// - Non-load loop body -// - Store next tile into shared mem -// - Epilogue: Peeled non-load loop body for last iteration -// -//===----------------------------------------------------------------------===// - -using llvm::MapVector; -using namespace mlir; -namespace ttg = triton::gpu; - -#define GEN_PASS_CLASSES -#include "TritonAMDGPUTransforms/Passes.h.inc" - -namespace { - -class LoopPipeliner { - /// Cache of ForOp and YieldOp related to this pipeliner. - scf::ForOp forOp; - scf::YieldOp yieldOp; - - bool peelLastIter = true; - - /// The new pipelined ForOp. - scf::ForOp pplForOp; - - /// Loads to be pipelined - SetVector validLoads; - /// The value that each load will be mapped to (after layout conversion) - DenseMap convertMapping; - /// load => buffer - DenseMap loadsBuffer; - /// load => buffer type (with shared layout after swizzling) - DenseMap loadsBufferType; - - /// Iterator values - Value nextLoopCond; - - /// Yield values - SmallVector yieldValues; - - /// The number of stages in the pipeline is fixed to '2' for - /// analysis since there will be a current buffer stored in - /// shared mem and a next buffer stored in regs. - int numStages = 2; - - /// Arg indicies - size_t depArgsBeginIdx; - DenseMap depArgsIdx; - - /// value (in loop) => value at stage N - DenseMap> valueMapping; - /// loop iter arg => value - DenseMap depArgsMapping; - - /// forOp value => pplForOp value - IRMapping curMapping; - /// forOp value => prefetch value - IRMapping nextMapping; - - /// Dependency ops by program order - SmallVector orderedDeps; - - SetVector currentDeps; - - /// block arguments that loads depend on - SetVector depArgs; - - /// operation => source operand defined stages - DenseMap> immediateOpStages; - - /// operations that loads depend on - SetVector depOps; - - /// Collect values that `v` depends on and are defined inside the loop - void collectValueDep(Value v, int stage, SetVector &deps, - SetVector &args); - - /// Collect all op dependencies - void collectDeps(SetVector &ops, - MapVector> &opDeps); - - void collectDepChain(Operation *op, SetVector &ops); - - /// Check if none of the for-ops has valid uses - LogicalResult checkOpUses(); - - /// Check if ops have dependencies that are not pipelinable - LogicalResult checkOpDeps(); - - void createBufferTypes(); - - void createOrderedDeps(); - - void createCurrentDeps(); - - /// Return the stage at which `v` is defined prior to `stage` - int getValueDefStage(Value v, int stage); - - /// Map `origin` to `newValue` at `stage` - void setValueMapping(Value origin, Value newValue, int stage); - - /// Map `origin` to `newValue` at `stage` according to the association between - /// yieldOp and forOp - void setValueMappingYield(Value origin, Value newValue, int stage); - - /// Map `origin` to `newValue` at the next stage according to the association - /// between yieldOp and forOp - void setValueMappingYield(Value origin, Value newValue); - - /// Return the value mapped to `origin` at `stage`, if it exists. - Value lookupOrDefault(Value origin, int stage); - - Value getLoadMask(triton::LoadOp loadOp, Value mappedMask, Value loopCond, - OpBuilder &builder); - /// Collect all args of the new loop - SmallVector collectNewLoopArgs(); - - /// Clone the forOp and return the new forOp - scf::ForOp cloneForOp(ArrayRef newLoopArgs, OpBuilder &builder); - - void updateLoadMask(triton::LoadOp loadOp, Value newMask); - /// Prefetch the next iteration for `pplForOp` - void prefetchNextBuffer(OpBuilder &builder); - void cloneCurrentBody(OpBuilder &builder); - void storeNextBuffer(OpBuilder &builder); - - bool isLoadChain(Operation *op) const; - - /// Assemble `pplForOp`'s yield op - void finalizeYield(OpBuilder &builder); - -public: - LoopPipeliner(scf::ForOp forOp) : forOp(forOp) { - yieldOp = cast(forOp.getBody()->getTerminator()); - } - - /// Collect loads to pipeline. Return success if we can pipeline this loop - LogicalResult initialize(); - - /// Emit pipelined loads (before loop body) - void emitPrologue(); - - /// emit pipelined loads (after loop body) - void emitEpilogue(DenseMap &newResults); - - /// create the new ForOp (add new args & insert prefetched ops) - scf::ForOp createNewForOp(); - - friend struct PipelinePass; -}; - -void LoopPipeliner::collectValueDep(Value v, int stage, - SetVector &deps, - SetVector &args) { - // Since we only need to peel the loop numStages-1 times, don't worry - // about depends that are too far away - if (stage < 0) - return; - - // Loop-invariant value, skip - if (v.getParentRegion() != &forOp.getRegion()) - return; - - if (Operation *op = v.getDefiningOp()) { - if (!deps.contains(op)) { - deps.insert(op); - for (Value opr : op->getOperands()) - collectValueDep(opr, stage, deps, args); - } - } else if (auto arg = dyn_cast(v)) { - if (arg.getArgNumber() > 0) { - args.insert(arg); - collectValueDep(yieldOp->getOperand(arg.getArgNumber() - 1), stage - 1, - deps, args); - } - } -} - -void LoopPipeliner::collectDeps( - SetVector &ops, - MapVector> &valueDeps) { - for (auto op : ops) { - for (Value v : op->getOperands()) { - SetVector deps; - SetVector args; - collectValueDep(v, numStages - 1, deps, args); - valueDeps[op] = deps; - } - } -} - -LogicalResult LoopPipeliner::checkOpUses() { - SetVector ops; - // We cannot use forOp.walk(...) here because we only want to visit the - // operations in the loop body block. Nested blocks are handled separately. - for (Operation &op : forOp) { - if (auto loadOp = dyn_cast(&op)) - ops.insert(&op); - } - - // Collect all ops' dependencies - MapVector> opDeps; - collectDeps(ops, opDeps); - - for (Operation *op : ops) { - auto loadOp = dyn_cast(op); - // Don't pipeline valid loads that depend on other valid loads - // (Because if a valid load depends on another valid load, this load needs - // to wait on the other load in the prologue, which is against the point - // of the pipeline pass) - bool isCandidate = true; - for (Operation *other : ops) - if (isa(other)) - if (opDeps[op].contains(other)) { - isCandidate = false; - break; - } - // We only pipeline loads that have one covert_layout (to dot_op) use - // TODO: lift this constraint in the future - if (isCandidate && loadOp.getResult().hasOneUse()) { - isCandidate = false; - Operation *use = *loadOp.getResult().getUsers().begin(); - - // Advance to the first conversion as long as the use resides in shared - // memory and it has a single use itself - while (use) { - if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse()) - break; - auto tensorType = - dyn_cast(use->getResult(0).getType()); - if (!tensorType || - !isa(tensorType.getEncoding())) - break; - use = *use->getResult(0).getUsers().begin(); - } - - // TODO: handle fp_to_fp conversions in between - if (auto convertLayout = llvm::dyn_cast(use)) - if (auto tensorType = - dyn_cast(convertLayout.getResult().getType())) - if (auto dotOpEnc = dyn_cast( - tensorType.getEncoding())) { - isCandidate = true; - convertMapping[loadOp] = convertLayout; - } - } else - isCandidate = false; - - if (isCandidate) - validLoads.insert(op); - } - - return validLoads.empty() ? failure() : success(); -} - -LogicalResult LoopPipeliner::checkOpDeps() { - /// arg => source operand defined stages - DenseMap> immediateArgStages; - SetVector nonImmediateDepArgs; - SetVector nonImmediateOps; - for (Operation *op : validLoads) { - for (Value v : op->getOperands()) { - SetVector deps; - SetVector args; - collectValueDep(v, numStages - 1, deps, args); - int defStage = getValueDefStage(v, numStages - 1); - if (defStage < 0) { - // assert(defStage >= 0 && - // "newLoopArgs has null args without a define op. Consider - // either " "rewrite the loop to reduce cross iteration - // dependencies or " "increase the num_stages value."); - return failure(); - } - bool immediate = args.size() > 0; - for (auto *dep : deps) { - depOps.insert(dep); - if (immediate) - immediateOpStages[dep].insert(defStage); - else - nonImmediateOps.insert(dep); - } - for (auto arg : args) { - depArgs.insert(arg); - if (immediate) - immediateArgStages[arg].insert(defStage); - else - nonImmediateDepArgs.insert(arg); - } - } - } - - // XXX: We could remove the following constraints if we can rematerialize in - // the loop. - // Check if immediateDepArgs and nonImmediateDepArgs are disjoint. - for (auto &[arg, stages] : immediateArgStages) { - assert(stages.size() == 1 && - "Triton doesn't support an argument provides values for " - "immediate operands of loads from multiple stages. Consider " - "removing post load instructions dependency on this argument."); - assert(!(nonImmediateDepArgs.contains(arg) && - stages.contains(numStages - 2)) && - "Loop-carried arguments provide values for both immediate and " - "non-immediate operands of loads. Please consider removing " - "pre/post load instructions dependency on this argument."); - } - - // Check if immediateOps and nonImmediateOps are disjoint. - for (auto &[op, stages] : immediateOpStages) { - assert(stages.size() == 1 && - "Triton doesn't support an operation provides values for " - "immediate operands of loads from multiple stages. Consider " - "removing post load instructions dependency on this argument."); - assert(!(nonImmediateOps.contains(op) && stages.contains(numStages - 2)) && - "Operations provide values for both immediate and " - "non-immediate operands of loads. Please consider " - "removing pre/post load instructions dependency on this " - "operation."); - } - return success(); -} - -// helpers -void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) { - if (valueMapping.find(origin) == valueMapping.end()) - valueMapping[origin] = SmallVector(numStages); - valueMapping[origin][stage] = newValue; -} - -void LoopPipeliner::setValueMappingYield(Value origin, Value newValue, - int stage) { - for (OpOperand &operand : origin.getUses()) { - if (operand.getOwner() == yieldOp) { - auto yieldIdx = operand.getOperandNumber(); - auto value = forOp.getRegionIterArgs()[yieldIdx]; - setValueMapping(value, newValue, stage); - } - } -} - -void LoopPipeliner::setValueMappingYield(Value origin, Value newValue) { - for (OpOperand &operand : origin.getUses()) { - if (operand.getOwner() == yieldOp) { - auto yieldIdx = operand.getOperandNumber(); - auto depYieldIdx = depArgsIdx[forOp.getRegionIterArgs()[yieldIdx]]; - auto originArg = forOp.getRegionIterArgs()[yieldIdx]; - nextMapping.map(originArg, newValue); - auto newArg = pplForOp.getRegionIterArgs()[depYieldIdx]; - if (!depArgsMapping.contains(newArg)) - depArgsMapping[newArg] = newValue; - } - } -} - -Value LoopPipeliner::lookupOrDefault(Value origin, int stage) { - if (valueMapping.find(origin) == valueMapping.end()) - return origin; - return valueMapping[origin][stage]; -} - -void LoopPipeliner::createBufferTypes() { - for (auto loadCvt : convertMapping) { - auto loadOp = loadCvt.first; - Value cvt = loadCvt.second; - auto dotOpEnc = cast( - cast(cvt.getType()).getEncoding()); - auto ty = cast(loadOp.getType()); - SmallVector bufferShape(ty.getShape().begin(), - ty.getShape().end()); - Type eType = ty.getElementType(); - auto blockedEnc = cast(ty.getEncoding()); - auto CTALayout = ttg::getCTALayout(ty.getEncoding()); - // unsigned bitWidth = dotOpEnc.getMMAv2kWidth() - // ? 32 / dotOpEnc.getMMAv2kWidth() - // : ty.getElementType().getIntOrFloatBitWidth(); - auto sharedEnc = ttg::SharedEncodingAttr::get( - ty.getContext(), dotOpEnc, ty.getShape(), - ttg::getOrder(ty.getEncoding()), CTALayout, eType); - loadsBufferType[loadOp] = - triton::MemDescType::get(bufferShape, eType, sharedEnc); - } -} - -void LoopPipeliner::createOrderedDeps() { - for (Operation &op : forOp.getBody()->without_terminator()) { - if (depOps.contains(&op)) - orderedDeps.push_back(&op); - else if (op.getNumResults() > 0 && validLoads.contains(&op)) - orderedDeps.push_back(&op); - } - assert(depOps.size() + validLoads.size() == orderedDeps.size() && - "depOps contains invalid values"); -} - -void LoopPipeliner::collectDepChain(Operation *op, - SetVector &ops) { - if (op->getNumResults() == 1 && validLoads.contains(op)) - return; - if (!ops.contains(op)) { - ops.insert(op); - for (Value opr : op->getOperands()) - if (Operation *oprOp = opr.getDefiningOp()) - collectDepChain(oprOp, ops); - } -} - -void LoopPipeliner::createCurrentDeps() { - for (Operation &op : forOp.getBody()->without_terminator()) { - if (!llvm::is_contained(orderedDeps, &op)) - collectDepChain(&op, currentDeps); - } -} - -int LoopPipeliner::getValueDefStage(Value v, int stage) { - if (stage < 0) - return -1; - if (auto arg = dyn_cast(v)) { - if (arg.getArgNumber() > 0) - return getValueDefStage(yieldOp->getOperand(arg.getArgNumber() - 1), - stage - 1); - llvm_unreachable("Loop induction variable should not be a dependency"); - } else - return stage; -} - -LogicalResult LoopPipeliner::initialize() { - if (checkOpUses().failed()) - return failure(); - - if (checkOpDeps().failed()) - return failure(); - - createBufferTypes(); - - createOrderedDeps(); - - createCurrentDeps(); - - return success(); -} - -Value LoopPipeliner::getLoadMask(triton::LoadOp loadOp, Value mappedMask, - Value loopCond, OpBuilder &builder) { - if (!peelLastIter) { - // add mask for last iteration when not peeled to epilogue - Value mask = loadOp.getMask(); - Type maskType = triton::getI1SameShape(loadOp.getType()); - Value newMask; - if (mask) { - Value cond = loopCond; - if (isa(maskType)) { - cond = - builder.create(mask.getLoc(), maskType, loopCond); - } - newMask = builder.create(mask.getLoc(), mappedMask, cond); - } else { - if (isa(maskType)) { - newMask = builder.create(loopCond.getLoc(), maskType, - loopCond); - } else { - newMask = loopCond; - } - } - return newMask; - } - // use original mask when peeling last iteration bc the loop will not do - // extra loads for the tail of the pipeline - return mappedMask; -} - -bool LoopPipeliner::isLoadChain(Operation *op) const { - if (auto cvtOp = dyn_cast(op)) { - Value loadVal = cvtOp.getSrc(); - if (auto f2fOp = dyn_cast(op)) - loadVal = f2fOp.getSrc(); - if (validLoads.contains(loadVal.getDefiningOp())) { - if (isa(cvtOp.getType().getEncoding())) - return true; - } - } - return false; -} - -void LoopPipeliner::emitPrologue() { - /// forOp block args => forOp operands - /// forOp iterator => lower bound - IRMapping prologueMap; - OpBuilder builder(forOp); - // Get init operands for loop carried values - for (BlockArgument &arg : forOp.getRegionIterArgs()) { - OpOperand &operand = *forOp.getTiedLoopInit(arg); - prologueMap.map(arg, operand.get()); - } - - // Emit prologue - // Map IV to lower bound - prologueMap.map(forOp.getInductionVar(), forOp.getLowerBound()); - - // Emit Iteration 0 loads, etc - for (Operation *op : orderedDeps) { - Operation *newOp = nullptr; - if (validLoads.contains(op)) { - auto loadOp = cast(op); - // Load from global -> regs - auto newLoadOp = cloneWithInferType(builder, op, prologueMap); - Value loadVal = newLoadOp->getResult(0); - // Convert from regs to shared mem - newOp = builder.create( - loadOp.getLoc(), loadsBufferType[loadOp], loadVal); - Value cvtVal = newOp->getResult(0); - prologueMap.map(loadOp->getResult(0), cvtVal); - loadsBuffer[op] = cvtVal; - } else { - newOp = cloneWithInferType(builder, op, prologueMap); - } - // Capture loop carried results for pipelined for input - for (unsigned idx : llvm::seq(unsigned(0), op->getNumResults())) - setValueMappingYield(op->getResult(idx), newOp->getResult(idx), 1); - } // for (Operation *op : orderedDeps) -} - -void LoopPipeliner::emitEpilogue(DenseMap &newResults) { - if (!peelLastIter) - return; - OpBuilder builder(pplForOp); - builder.setInsertionPointAfter(pplForOp); - - IRMapping epilogueMap; - // Map 'for' iteration args to pipelined-for results - auto args = forOp.getRegionIterArgs(); - for (uint32_t i = 0; i < args.size(); ++i) - epilogueMap.map(args[i], pplForOp.getResult(i)); - for (auto *loadOp : validLoads) - epilogueMap.map(loadOp->getResult(0), loadsBuffer[loadOp]); - - // This is computing the upper bound of the pipelined loop as: - // pplUpperBound = lb+((ub-1-lb)/step)*step - Location loc = forOp.getLoc(); - Value ub = forOp.getUpperBound(); - Value lb = forOp.getLowerBound(); - Value step = forOp.getStep(); - Value one = builder.create(loc, 1, 32); - - // pplRange = ub-1-lb - Value pplRange = builder.create( - loc, builder.create(loc, ub, one), lb); - - // pplIters = (pplrRange/step)*step - Value pplIters = builder.create( - loc, builder.create(loc, pplRange, step), step); - - // pplUpperBound = lb+pplIters - Value pplUpperBound = builder.create(loc, lb, pplIters); - epilogueMap.map(forOp.getInductionVar(), pplUpperBound); - - const auto &yieldOprs = yieldOp.getOperands(); - // Clone the loop body after the new ForOp - // , replace original args with results of the new ForOp. - for (Operation &op : forOp.getBody()->without_terminator()) { - if (currentDeps.contains(&op)) { - Operation *newOp = nullptr; - if (isLoadChain(&op)) { - if (auto cvt = dyn_cast(&op)) { - Value mappedValue = epilogueMap.lookup(cvt.getSrc()); - if (isa(mappedValue.getType())) { - auto newCvt = builder.create( - cvt.getLoc(), cvt.getType(), mappedValue); - epilogueMap.map(cvt.getResult(), newCvt); - newOp = newCvt; - } - } - if (!newOp) - newOp = builder.clone(op, epilogueMap); - } else { - newOp = cloneWithInferType(builder, &op, epilogueMap); - } - // substitute for these results for the results of the new for loop - for (const auto &pair : llvm::zip(op.getResults(), newOp->getResults())) { - auto val = std::get<0>(pair); - auto it = llvm::find(yieldOprs, val); - if (it != yieldOprs.end()) { - uint32_t idx = std::distance(yieldOprs.begin(), it); - newResults[forOp->getResult(idx)] = std::get<1>(pair); - } - } - } - } -} - -SmallVector LoopPipeliner::collectNewLoopArgs() { - // Order of new args: - // (original args) - // (shared mem buffers for each load) - // (depArgs at stage numStages - 1) - - // We need this to update operands for yield - // original block arg => new arg's idx - SmallVector newLoopArgs; - for (auto v : forOp.getInitArgs()) { - newLoopArgs.push_back(lookupOrDefault(v, numStages - 1)); /*1*/ - } - - // Loop carried vals - depArgsBeginIdx = newLoopArgs.size(); - for (auto depArg : depArgs) { - depArgsIdx[depArg] = newLoopArgs.size(); - newLoopArgs.push_back(valueMapping[depArg][numStages - 1]); /*1*/ - } - - return newLoopArgs; -} - -scf::ForOp LoopPipeliner::cloneForOp(ArrayRef newLoopArgs, - OpBuilder &builder) { - auto loc = forOp.getLoc(); - // Peel off the last iteration - auto pplUpperBound = forOp.getUpperBound(); - if (peelLastIter) - pplUpperBound = - builder.create(loc, pplUpperBound, forOp.getStep()); - - // Clone the original ForOp - pplForOp = builder.create( - loc, forOp.getLowerBound(), pplUpperBound, forOp.getStep(), newLoopArgs); - - // Set mapping on body of the new ForOp - builder.setInsertionPointToStart(pplForOp.getBody()); - for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) - curMapping.map(arg.value(), pplForOp.getRegionIterArgs()[arg.index()]); - for (auto *loadOp : validLoads) - curMapping.map(loadOp->getResult(0), loadsBuffer[loadOp]); - curMapping.map(forOp.getInductionVar(), pplForOp.getInductionVar()); - - nextMapping = curMapping; - // Map the dep args of the next iteration to the dep args of the current - auto iterArgs = pplForOp.getRegionIterArgs(); - size_t argIdx = 0; - for (auto depArg : depArgs) { - BlockArgument nextArg = iterArgs[argIdx + depArgsBeginIdx]; - nextMapping.map(depArg, nextArg); - ++argIdx; - } - - // Compute next IV for pre-loads - Value iv = pplForOp.getInductionVar(); - curMapping.map(forOp.getInductionVar(), iv); - Value nextIV = - builder.create(iv.getLoc(), iv, pplForOp.getStep()); - nextMapping.map(forOp.getInductionVar(), nextIV); - nextLoopCond = - builder.create(nextIV.getLoc(), arith::CmpIPredicate::slt, - nextIV, pplForOp.getUpperBound()); - - return pplForOp; -} - -void LoopPipeliner::updateLoadMask(triton::LoadOp loadOp, Value newMask) { - if (newMask) { - if (loadOp->getNumOperands() > 1) - loadOp->setOperand(1, newMask); - else { - auto mask = loadOp.getMaskMutable(); - mask.assign(newMask); - } - } -} - -void LoopPipeliner::prefetchNextBuffer(OpBuilder &builder) { - // Emit prefetch loads of next buffer before compute of current buffer - for (Operation *op : orderedDeps) { - Operation *nextOp = nullptr; - if (validLoads.contains(op)) { - // Update loading mask - auto loadOp = llvm::cast(op); - auto mask = loadOp.getMask(); - // pre-load global -> regs - Value newMask = getLoadMask(loadOp, nextMapping.lookupOrDefault(mask), - nextLoopCond, builder); - if (mask) { - // If mask is defined outside the loop, don't update the map more than - // once - if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask))) - nextMapping.map(loadOp.getMask(), newMask); - newMask = nextMapping.lookupOrDefault(mask); - } - auto newOp = builder.clone(*op, nextMapping); - updateLoadMask(cast(newOp), newMask); - } else if (!immediateOpStages[op].contains(numStages - 2)) { - Operation *nextOp = builder.clone(*op, nextMapping); - if (auto loadOp = dyn_cast(op)) { - if (auto newMask = getLoadMask( - loadOp, nextMapping.lookupOrDefault(loadOp.getMask()), - nextLoopCond, builder)) { - updateLoadMask(cast(nextOp), newMask); - } - } - - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx)); - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - setValueMappingYield(op->getResult(dstIdx), nextOp->getResult(dstIdx)); - } - } -} - -void LoopPipeliner::cloneCurrentBody(OpBuilder &builder) { - auto loc = forOp.getLoc(); - // only add instructions that are not part of the restructuring - for (Operation &op : forOp.getBody()->without_terminator()) { - if (currentDeps.contains(&op)) { - Operation *newOp = nullptr; - if (isLoadChain(&op)) { - if (auto cvt = dyn_cast(&op)) { - Value mappedValue = curMapping.lookup(cvt.getSrc()); - if (isa(mappedValue.getType())) { - auto newCvt = builder.create( - cvt.getLoc(), cvt.getType(), mappedValue); - curMapping.map(cvt.getResult(), newCvt); - newOp = newCvt; - } - } - if (!newOp) - newOp = builder.clone(op, curMapping); - } else { - newOp = cloneWithInferType(builder, &op, curMapping); - } - } - } -} - -void LoopPipeliner::storeNextBuffer(OpBuilder &builder) { - // Store the next buffer at the end of the loop body for the next iteration - for (Operation *op : orderedDeps) { - if (!validLoads.contains(op)) { - if (immediateOpStages[op].contains(numStages - 2)) { - Operation *nextOp = builder.clone(*op, nextMapping); - if (auto loadOp = dyn_cast(op)) { - auto newMask = - getLoadMask(loadOp, nextMapping.lookupOrDefault(loadOp.getMask()), - nextLoopCond, builder); - updateLoadMask(cast(nextOp), newMask); - } - - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - setValueMappingYield(op->getResult(dstIdx), - nextOp->getResult(dstIdx)); - } - } - } - - // PL loads -> store next to shared - for (auto *loadOp : validLoads) { - Value loadVal = nextMapping.lookup(loadOp->getResult(0)); - // then store regs -> shared - Value storeBuf = loadsBuffer[loadOp]; - builder.create(loadOp->getLoc(), loadVal, storeBuf); - } - - // Some values have not been used by any ops in the loop body - for (BlockArgument arg : forOp.getRegionIterArgs()) - setValueMappingYield(arg, pplForOp.getRegionIterArgs()[depArgsIdx[arg]]); -} - -void LoopPipeliner::finalizeYield(OpBuilder &builder) { - SmallVector yieldValues; - for (const auto &opr : llvm::enumerate(yieldOp->getOperands())) { - if (curMapping.contains(opr.value())) - yieldValues.push_back(curMapping.lookup(opr.value())); - else - yieldValues.push_back(pplForOp.getRegionIterArgs()[opr.index()]); - } - for (size_t i = 0; i < depArgsMapping.size(); ++i) { - auto arg = pplForOp.getRegionIterArgs()[depArgsBeginIdx + i]; - assert(depArgsMapping.count(arg) && "Missing loop-carried value"); - yieldValues.push_back(depArgsMapping[arg]); - } - - builder.setInsertionPointToEnd(pplForOp.getBody()); - builder.create(yieldOp->getLoc(), yieldValues); -} - -scf::ForOp LoopPipeliner::createNewForOp() { - OpBuilder builder(forOp); - auto newLoopArgs = collectNewLoopArgs(); - cloneForOp(newLoopArgs, builder); - prefetchNextBuffer(builder); - cloneCurrentBody(builder); - storeNextBuffer(builder); - finalizeYield(builder); - return pplForOp; -} - -// Stream Pipeline -struct PipelinePass : public TritonAMDGPUStreamPipelineBase { - PipelinePass() = default; - - void runOnOperation() override { - // Pre-processing - // we make sure element-wise ops are done *after* the conversion - // to dot operands - // we can achieve this with simple recursive pattern matching - // MLIRContext *context = &getContext(); - // mlir::RewritePatternSet patterns(context); - // patterns.add(context); - // auto didPreprocess = - // applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - - // Do the pipelining - getOperation()->walk([&](scf::ForOp forOp) -> void { - LoopPipeliner pipeliner(forOp); - - if (pipeliner.initialize().failed()) - return; - - pipeliner.emitPrologue(); - scf::ForOp pplForOp = pipeliner.createNewForOp(); - DenseMap newResults; - for (unsigned i = 0; i < forOp->getNumResults(); ++i) - newResults[forOp->getResult(i)] = pplForOp->getResult(i); - pipeliner.emitEpilogue(newResults); - - // Replace the original loop - for (auto &pair : newResults) - std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair)); - forOp->erase(); - }); - } -}; -} // anonymous namespace - -std::unique_ptr mlir::createTritonAMDGPUStreamPipelinePass() { - return std::make_unique(); -} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp new file mode 100644 index 000000000..deb566a8b --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp @@ -0,0 +1,720 @@ +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// This file will create a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces, one that create a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue. This pass first calls a helper that will pre-process the IR +// to create stream operations and create a modulo schedule. Then we call the +// expander to generate the prologue and new loop. +//===----------------------------------------------------------------------===// + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h.inc" + +#define DEBUG_TYPE "tritonamdgpu-stream-pipeline-v2" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +static Operation *streamPredication(RewriterBase &rewriter, Operation *op, + Value pred) { + // The epilogue peeling generates a select for the stage output. This causes + // too much register pressure with the loop result and the epilogue-dot in + // regs for the select. Conditionally executing the dot will allow the backend + // to optimize the select away as redundant. + if (auto dotOp = dyn_cast(op)) { + auto loc = dotOp->getLoc(); + auto ifOp = rewriter.create(loc, dotOp.getResult().getType(), + pred, /*withElseRegion=*/true); + auto thenB = ifOp.getThenBodyBuilder(); + auto yield = thenB.create(loc, dotOp.getResult()); + dotOp->moveBefore(yield); + ifOp.getElseBodyBuilder().create(loc, dotOp.getC()); + return ifOp; + } else if (isa(op)) { + return op; + } + return tt::predicateOp(rewriter, op, pred); +} + +namespace { + +// Encapsulate stream pipelining +// For each `scf.for` create a StreamPipeliner manager. +class StreamPipeliner { +public: + StreamPipeliner(scf::ForOp _forOp, int _numStages) + : forOp(_forOp), schedule(_numStages), numStages(_numStages), + axisInfoAnalysis(forOp->getParentOfType()) { + options.supportDynamicLoops = true; + options.peelEpilogue = true; + options.predicateFn = streamPredication; + } + + void computeLoadOpsToIndirectionLevelAndUse(); + void assignMemoryLayouts(); + void scheduleLoads(DenseSet &rootUsers); + void scheduleDependencies(); + void scheduleDistanceOneDependencies(); + void scheduleRemainingToLastStage(tt::CoarseSchedule::Cluster afterPrologue); + + bool preprocessLoopAndBuildSchedule(); + bool pipelineLoop(); + + Value createAlloc(Operation *loadOp, ttg::SharedEncodingAttr sharedEnc, + unsigned numBuffers); + void createStreamCopy(tt::LoadOp loadOp, Value alloc, Value extractIdx, + tt::CoarseSchedule::Cluster prefetchCluster); + void createStreamOps(); + +private: + scf::ForOp forOp; + tt::CoarseSchedule schedule; + int numStages; + + // Mapping and indirection level for each `tt.load` to its use. + llvm::SmallVector> + loadOpToIndLevelAndUse; + + struct LoadInfo { + // Shared layout is used for loads feeding into dot ops. + ttg::SharedEncodingAttr sharedEncoding = nullptr; + // The distance of this load's stage to its use' stage. + int distToUse = 0; + bool usedByDot = false; + }; + + // Mapping for each pipelined load to scheduling details. + llvm::MapVector loadToInfo; + + // Lookup alignment/contiguity mappings for the current module. + tt::ModuleAxisInfoAnalysis axisInfoAnalysis; + + // Capture list of new shared memory buffers. + SmallVector sharedMemAllocs; + + // Pipelining options for the PipelineExpander + tt::PipeliningOption options; +}; + +} // namespace + +void StreamPipeliner::createStreamCopy( + tt::LoadOp loadOp, Value alloc, Value extractIdx, + tt::CoarseSchedule::Cluster prefetchCluster) { + OpBuilder builder(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + // Replace the load with insert/extract slice. + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + Value src = loadOp.getPtr(); + Value mask = loadOp.getMask(); + + tt::MemDescType allocTy = cast(alloc.getType()); + SmallVector copyOffsets(allocTy.getRank(), zero); + Operation *copy = builder.clone(*loadOp); + + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(copy, stage, cluster); + + // Extract part. + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); + auto subviewTy = tt::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + auto storeOp = + builder.create(loc, copy->getResult(0), viewLoad); + // Clean up old local caches. + SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto alloc = dyn_cast(user)) { + triton::replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); + allocsToErase.push_back(alloc); + } + } + for (auto alloc : allocsToErase) + alloc.erase(); + + auto sharedLoad = + builder.create(loc, loadOp.getType(), viewLoad); + auto result = sharedLoad->getResults(); + + // Create a select for non-zero other values. + Value other = loadOp.getOther(); + if (other && !isZeroConst(other)) { + auto select = builder.create( + loc, loadOp.getType(), mask, sharedLoad.getResult(), other); + result = select->getResults(); + } + + loadOp->replaceAllUsesWith(result); + + // Prefetch load ahead of the dot stage if is used by the dot. + if (loadToInfo[loadOp].usedByDot) { + assert(numStages >= 2 && "requires num_stages=2 at least"); + schedule.insert(storeOp, numStages - 2, prefetchCluster); + schedule.insert(viewLoad, numStages - 2, prefetchCluster); + } + loadOp.erase(); +} + +// If all the transitive uses of the given value have are used by a convert to +// the same dot operand encoding, return true and get the shared encoding that +// needs to be used to be compatible with users' layouts. +static std::optional +getSharedEncIfAllUsersAreDotEnc(Value val) { + ttg::SharedEncodingAttr attr; + for (Operation *user : val.getUsers()) { + ttg::SharedEncodingAttr tempAttr; + if (user->getNumResults() != 1) + return std::nullopt; + if (auto memDesc = + dyn_cast(user->getResult(0).getType())) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = cast(memDesc.getEncoding()); + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value()) + return std::nullopt; + } else { + if (!isa(user)) + return std::nullopt; + auto dotOpEnc = dyn_cast( + cast(user->getResult(0).getType()).getEncoding()); + if (!dotOpEnc) + return std::nullopt; + auto srcTy = cast(val.getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy.getEncoding()); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + SmallVector sharedOrder; + int rank = order.size(); + // TODO rework this when shared -> dotOp conversions support arbitrary + // shared memory ordering + if (rank == 3) { + // Move the batch dimension (dim #0) to be the last so that it will be + // the slowest varying dimension. + for (unsigned i = 0; i < rank; ++i) + if (order[i] != 0) + sharedOrder.emplace_back(order[i]); + sharedOrder.emplace_back(0); + } else { + sharedOrder = order; + } + tempAttr = ttg::SharedEncodingAttr::get( + val.getContext(), dotOpEnc, srcTy.getShape(), sharedOrder, CTALayout, + bitWidth, /*needTrans=*/false); + } + // Check that the shared encodings needed by the users are compatible. + if (!tempAttr || (attr != nullptr && attr != tempAttr)) + return std::nullopt; + attr = tempAttr; + } + return attr; +} + +// Create a map from load ops to their indirection levels and the final uses +// of the load op (another load op, or a dot op). +// +// Indirection level is "0" for the load op directly used by the dot op, +// "1" for the load op used by the load op used by the dot op, and so on. +void StreamPipeliner::computeLoadOpsToIndirectionLevelAndUse() { + DenseSet seen; + + // Recursively visit the given op and its operands to discover all load ops + // and collect their indirection levels and uses. + std::function dfs = + [&](Operation *op, int distance, Operation *use) { + // Skip previously visited load ops. + if (!seen.insert(op).second) + return; + + if (isa(op)) { + // TODO: What if there are multiple uses at different distances? + loadOpToIndLevelAndUse.emplace_back(op, distance, use); + use = op; + ++distance; + } + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, distance, use); + } + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!op.hasTrait()) + continue; + seen.clear(); + dfs(&op, 0, &op); + } + + // If the loop has numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (forOp->hasAttr(tt::kNumStagesAttrName)) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, 0, &op); + } + } +} + +// Goes through all load ops to identify those that can be pipelined and assign +// layout to them. +void StreamPipeliner::assignMemoryLayouts() { + for (auto &[op, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(op)) + // TODO: We'd need to verify that the distance is the same. + continue; + + LoadInfo loadInfo; + auto loadOp = cast(op); + assert(!isLoadFromTensorPtr(loadOp) && + "Block ptr should have been lowered before this pass."); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) { + LDBG("Skip non-tensor load " << *loadOp); + continue; + } + + auto pointeeTy = + cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * pointeeTy.getIntOrFloatBitWidth(); + + // Limit shared memory sharing to width >= 32 elements. + LDBG("Load " << *loadOp << " has width " << width); + if (width < 32) { + LDBG("Skip width<32 load " << *loadOp); + continue; + } + + if (use->hasTrait()) { + // Only use shared memory when feeding into a dot op. + loadInfo.usedByDot = true; + loadInfo.sharedEncoding = + getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); + } else if (auto useOp = dyn_cast(use)) { + // The use of this loadOp is another loadOp. If the use is not in the + // loadToInfo already, it means that the use is not valid for pipelining + // for some reason. We should skip this loadOp, too. + // + // Note that we have an assumption that the use of this loadOp has already + // be processed in a previous loop iteration. This assumption is held by + // how loadOpsToIndirectionLevelAndUse recursively collects + // loadOpToIndLevelAndUse using DFS. + if (loadToInfo.count(useOp) == 0) { + continue; + } + } + + loadToInfo[op] = loadInfo; + } +} + +void StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { + // Get all loads that are (transitively) used by dot ops and their distance + // to the dot op. + computeLoadOpsToIndirectionLevelAndUse(); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + if (loadOpToIndLevelAndUse.empty()) + return; + + // Check which loads are good for pipelining, and assign them memory layouts. + assignMemoryLayouts(); + if (loadToInfo.empty()) + return; + + // Filter out load ops that cannot be pipelined. + int resize = 0; + for (int i = 0, e = loadOpToIndLevelAndUse.size(); i < e; ++i) { + auto [loadOp, distance, use] = loadOpToIndLevelAndUse[i]; + if (loadToInfo.count(loadOp) != 0) + loadOpToIndLevelAndUse[resize++] = loadOpToIndLevelAndUse[i]; + } + loadOpToIndLevelAndUse.resize(resize); + + // Calculate the stage distance between applicable loads. + int maxIndirectionLevel = -1; + for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) + maxIndirectionLevel = std::max(maxIndirectionLevel, dist); + + // The stage gap between chained loads--this allows us to "spread" loads + // with a non-one step in case the number of stages given by the user is + // large. + assert(numStages >= 2 && "requires num_stages=2 at least"); + unsigned stagesBetweenLoads = + llvm::divideCeil(numStages - 2, maxIndirectionLevel + 1); + LDBG("stagesBetweenLoads = " << stagesBetweenLoads); + + // Put the root uses of the loads in the last stage. + tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); + for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { + // Non-LoadOp(s) are the (final) root uses of all LoadOp(s). + if (!isa(use)) { + schedule.insert(use, numStages - 1, rootUsersCluster); + rootUsers.insert(use); + } + } + + // Create a cluster for load ops at each indirection level. + SmallVector loadsClusters; + for (int i = 0; i <= maxIndirectionLevel; i++) { + loadsClusters.push_back(schedule.clusters.newAtBack()); + } + // Assign stages to the loads. + for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { + int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; + schedule.insert(loadOp, stage, loadsClusters[indLevel]); + } + + // Calculate distance from the load to the use. + for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { + loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; + } + + LLVM_DEBUG({ + LDBG("Chosen loads to pipeline:"); + for (const auto &[load, info] : loadToInfo) { + LDBG(" - load: " << *load); + LDBG(" distToUse: " << info.distToUse); + LDBG(" usedByDot: " << info.usedByDot); + } + }); +} + +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +void StreamPipeliner::scheduleDependencies() { + SmallVector> + opsInOrder = schedule.getOpsInOrder(forOp); + // Schedule dependencies stage by stage. + for (int stage = 0; stage < numStages; ++stage) { + for (auto [op, stage_, cluster] : opsInOrder) { + if (stage_ != stage) + continue; + schedule.insertDepsOfOp(op, stage, cluster, false); + } + } +} + +// Find dependencies with distance of 1. They will go to the next stage, +// but in the cluster before the current op. +void StreamPipeliner::scheduleDistanceOneDependencies() { + auto getNestedOperands = [](Operation *op) { + SmallVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + if (operand.getParentBlock()->getParentOp()->isAncestor(nestedOp)) + operands.push_back(operand); + } + }); + return operands; + }; + + // Mapping from the cluster to the cluster before it. + DenseMap + dist1Cluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + continue; + auto [stage, cluster] = schedule[&op]; + // Can't schedule past the last stage. + if (stage == numStages - 1) + continue; + for (Value operand : getNestedOperands(&op)) { + auto arg = dyn_cast(operand); + if (!arg || arg.getArgNumber() == 0 || arg.getOwner() != op.getBlock()) + continue; + auto yieldOp = op.getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (!defOp || schedule.count(defOp) != 0) + continue; + if (isa(defOp)) { + // Exception: schedule loads with a distance of 1 together with the + // current op. + schedule.insertIfAbsent(defOp, stage, cluster); + schedule.insertDepsOfOp(defOp, stage, cluster, true); + } else { + if (dist1Cluster.count(&cluster) == 0) { + dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster); + } + schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]); + schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster], true); + } + } + } +} + +void StreamPipeliner::scheduleRemainingToLastStage( + tt::CoarseSchedule::Cluster afterPrologue) { + // Assign the rest of the ops to the last stage. + // Take care of the ordering of the ops - uses cannot be scheduled to the + // cluster before the definition. + DenseMap opToCluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) { + opToCluster[&op] = afterPrologue; + } + } + SmallVector queue; + for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { + // We really only care about the producers from the last stage. + // Others will be scheduled before these ops anyway. + if (stage == numStages - 1) { + queue.push_back(op); + } + } + while (!queue.empty()) { + Operation *op = queue.pop_back_val(); + for (auto user : op->getUsers()) { + if (opToCluster.count(user)) { + tt::CoarseSchedule::Cluster userCluster = opToCluster[user]; + tt::CoarseSchedule::Cluster opCluster = schedule[op].second; + if (*userCluster < *opCluster) { + opToCluster[user] = opCluster; + queue.push_back(user); + } + } + } + } + for (auto [op, cluster] : opToCluster) { + schedule.insert(op, numStages - 1, cluster); + } +} + +// Create an allocation that can hold distance number of loadOp shapes. +Value StreamPipeliner::createAlloc(Operation *loadOp, + ttg::SharedEncodingAttr sharedEnc, + unsigned numBuffers) { + OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); + auto ty = cast(loadOp->getResultTypes()[0]); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), numBuffers); + Type memdescType = tt::MemDescType::get(bufferShape, ty.getElementType(), + sharedEnc, sharedMemorySpace, + /*mutableMemory=*/true); + return builder.create(loadOp->getLoc(), memdescType, + Value()); +} + +// Convert load ops into shared memory allocation loads and apply +// multi-buffering based on the required number of buffers. +void StreamPipeliner::createStreamOps() { + // Calculate the number of buffers needed for each load. + // TODO: Use the precise number of buffers needed by the particular load. + int numBuffers = -1; + for (auto &[_, info] : loadToInfo) + numBuffers = std::max(numBuffers, info.distToUse); + LDBG("deduced shared memory buffer number = " << numBuffers); + + SmallVector> loadToAllocs; + for (auto &[loadOp, info] : loadToInfo) { + if (!info.sharedEncoding) + continue; + + Value alloc = createAlloc(loadOp, info.sharedEncoding, numBuffers); + assert(alloc && "Failed to create alloc for the async load."); + sharedMemAllocs.push_back(alloc); + loadToAllocs.emplace_back(loadOp, alloc); + } + + IRRewriter builder(forOp.getContext()); + builder.setInsertionPoint(forOp); + + Location loc = forOp.getLoc(); + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + Value extractIdx = minusOne; + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Patch the loop to add the new loop carried dependencies. + scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, {extractIdx}); + forOp.erase(); + forOp = newForOp; + + // Create one counter for the extract indices to avoid creating long + // live range. + extractIdx = newForOp.getBody()->getArgument(newOperandIndex); + + builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + extractIdx = builder.create(loc, extractIdx, one); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + + // Create a cluster for prefetching global reads for the dot. + tt::CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); + + for (auto &[op, alloc] : loadToAllocs) { + if (auto loadOp = dyn_cast(op)) + createStreamCopy(loadOp, alloc, extractIdx, prefetchCluster); + } + // Patch the yield with the updated counters. + appendToForOpYield(forOp, {extractIdx}); +} + +bool StreamPipeliner::preprocessLoopAndBuildSchedule() { + // Schedule the loads and root ops (dot ops) in the loop. This will give us + // a scaffold for the final schedule. + DenseSet rootUsers; + scheduleLoads(rootUsers); + if (loadToInfo.empty()) + return false; + + LLVM_DEBUG({ + LDBG("Coarse schedule loads only:"); + schedule.dump(); + }); + + // Convert the loads into shared memory allocations and loads from them. + createStreamOps(); + + LLVM_DEBUG({ + LDBG("Coarse schedule with stream loads:"); + schedule.dump(); + }); + + tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + + scheduleDependencies(); + LLVM_DEBUG({ + LDBG("Coarse schedule with dependencies:"); + schedule.dump(); + }); + + scheduleDistanceOneDependencies(); + LLVM_DEBUG({ + LDBG("Coarse schedule with dist 1:"); + schedule.dump(); + }); + + scheduleRemainingToLastStage(afterPrologue); + LLVM_DEBUG({ + LDBG("Final coarse schedule:"); + schedule.dump(); + }); + + // Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> coarseSchedule = + schedule.createFinalSchedule(forOp); + + // Fill out the pipeline options. + options.getScheduleFn = + [coarseSchedule](scf::ForOp, + std::vector> &s) { + s = std::move(coarseSchedule); + }; + + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + // Explicitly deallocate created allocations. + for (auto alloc : sharedMemAllocs) + builder.create(forOp.getLoc(), alloc); + return true; +} + +// Return true if the preconditions for pipelining the loop are met. +static bool checkPrecondition(scf::ForOp forOp) { + // Skip loop with distance > 1 for now. + // TODO: relax the constraint in the expander. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { return !operand.getDefiningOp(); })) + return false; + + // Don't pipeline outer loops. + auto hasNestedLoopInside = [forOp](Operation *op) { + if (op != forOp && isa(op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }; + return !forOp->walk(hasNestedLoopInside).wasInterrupted(); +} + +bool StreamPipeliner::pipelineLoop() { + if (!checkPrecondition(forOp)) + return false; + + if (!preprocessLoopAndBuildSchedule()) + return false; + LDBG("Loop before sending to expander:\n" << *forOp); + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + return succeeded(tt::pipelineForLoop(rewriter, forOp, options)); +} + +namespace { +struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base { + PipelinePass() = default; + PipelinePass(int32_t numStages) { this->numStages = numStages; } + + void runOnOperation() override { + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1. + if (getNumStagesOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + + for (scf::ForOp forOp : loops) { + StreamPipeliner sp(forOp, getNumStagesOrDefault(forOp)); + sp.pipelineLoop(); + } + } + +private: + int getNumStagesOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists, otherwise use the + // global control. + if (auto attr = forOp->getAttrOfType(tt::kNumStagesAttrName)) + return attr.getInt(); + return numStages; + } +}; +} // anonymous namespace + +std::unique_ptr +mlir::createTritonAMDGPUStreamPipelineV2Pass(int numStages) { + return std::make_unique(numStages); +} diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index ddc1feb2a..f97676aaf 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -1,3 +1,4 @@ +#include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "TritonAMDGPUToLLVM/Passes.h" #include "TritonAMDGPUToLLVM/TargetUtils.h" #include "TritonAMDGPUTransforms/Passes.h" @@ -32,29 +33,45 @@ namespace py = pybind11; namespace { +const char *const amdTargetTriple = "amdgcn-amd-amdhsa"; + void init_triton_amd_passes_ttgpuir(py::module &&m) { using namespace mlir::triton; m.def("add_to_llvmir", [](mlir::PassManager &pm, const std::string &arch, bool ftz) { pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz)); }); - m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) { - pm.addPass(createConvertBuiltinFuncToLLVMPass()); + m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm, bool ftz) { + pm.addPass(createConvertBuiltinFuncToLLVMPass(ftz)); + }); + m.def("insert_instruction_sched_hints", [](mlir::PassManager &pm) { + pm.addPass(createInsertInstructionSchedHintsPass()); }); + m.def("lower_instruction_sched_hints", + [](mlir::PassManager &pm, std::string variant) { + pm.addPass(createLowerInstructionSchedHintsPass(variant)); + }); m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm, const std::string &arch) { pm.addPass( mlir::triton::AMD::createDecomposeUnsupportedConversionsPass(arch)); }); + ADD_PASS_WRAPPER_2("add_optimize_lds_usage", + mlir::triton::AMD::createOptimizeLDSUsagePass, + const std::string &, int32_t); ADD_PASS_WRAPPER_3("add_accelerate_matmul", mlir::createTritonAMDGPUAccelerateMatmulPass, const std::string, int, int); ADD_PASS_WRAPPER_0("add_optimize_epilogue", mlir::createTritonAMDGPUOptimizeEpiloguePass); + ADD_PASS_WRAPPER_0("add_canonicalize_pointers", + mlir::createTritonAMDGPUCanonicalizePointersPass); + ADD_PASS_WRAPPER_0("add_convert_to_buffer_ops", + mlir::createTritonAMDGPUConvertToBufferOpsPass); ADD_PASS_WRAPPER_0("add_reorder_instructions", mlir::createTritonAMDGPUReorderInstructionsPass); - ADD_PASS_WRAPPER_0("add_stream_pipeline", - mlir::createTritonAMDGPUStreamPipelinePass); + ADD_PASS_WRAPPER_1("add_stream_pipelinev2", + mlir::createTritonAMDGPUStreamPipelineV2Pass, int); } void addControlConstant(llvm::Module *module, const char *name, @@ -81,18 +98,22 @@ void init_triton_amd(py::module &&m) { auto passes = m.def_submodule("passes"); init_triton_amd_passes_ttgpuir(passes.def_submodule("ttgpuir")); - m.attr("TARGET_TRIPLE") = "amdgcn-amd-amdhsa"; + m.attr("TARGET_TRIPLE") = amdTargetTriple; m.attr("CALLING_CONV_AMDGPU_KERNEL") = (unsigned)llvm::CallingConv::AMDGPU_KERNEL; m.def("load_dialects", [](mlir::MLIRContext &context) { mlir::DialectRegistry registry; + registry.insert(); // registry.insert(); mlir::registerROCDLDialectTranslation(registry); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); }); + m.def("attach_target_triple", + [](llvm::Module *module) { module->setTargetTriple(amdTargetTriple); }); + // Set target architecture ISA version m.def("set_isa_version", [](llvm::Module *module, const std::string &arch) { llvm::AMDGPU::IsaVersion version = llvm::AMDGPU::getIsaVersion(arch); @@ -144,8 +165,7 @@ void init_triton_amd(py::module &&m) { const std::string &features) { std::string error; - const char *targetTriple = "amdgcn-amd-amdhsa"; - llvm::Triple triple(targetTriple); + llvm::Triple triple(amdTargetTriple); const llvm::Target *target = llvm::TargetRegistry::lookupTarget(triple.normalize(), error); if (!target) @@ -157,11 +177,11 @@ void init_triton_amd(py::module &&m) { const llvm::MCTargetOptions mcOptions; std::unique_ptr mri( - target->createMCRegInfo(targetTriple)); + target->createMCRegInfo(amdTargetTriple)); std::unique_ptr mai( - target->createMCAsmInfo(*mri, targetTriple, mcOptions)); + target->createMCAsmInfo(*mri, amdTargetTriple, mcOptions)); std::unique_ptr sti( - target->createMCSubtargetInfo(targetTriple, arch, features)); + target->createMCSubtargetInfo(amdTargetTriple, arch, features)); llvm::MCContext ctx(triple, mai.get(), mri.get(), sti.get(), &srcMgr, &mcOptions); @@ -184,11 +204,9 @@ void init_triton_amd(py::module &&m) { target->createMCCodeEmitter(*mcii, ctx)); std::unique_ptr mab( target->createMCAsmBackend(*sti, *mri, mcOptions)); + std::unique_ptr ow(mab->createObjectWriter(svos)); mcStreamer.reset(target->createMCObjectStreamer( - triple, ctx, std::move(mab), mab->createObjectWriter(svos), - std::move(ce), *sti, mcOptions.MCRelaxAll, - mcOptions.MCIncrementalLinkerCompatible, - /*DWARFMustBeAtTheEnd=*/false)); + triple, ctx, std::move(mab), std::move(ow), std::move(ce), *sti)); std::unique_ptr parser( createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai)); @@ -239,4 +257,13 @@ void init_triton_amd(py::module &&m) { return false; } }); + + m.def("set_all_fn_arg_inreg", [](llvm::Function *fn) { + for (llvm::Argument &arg : fn->args()) { + // Check for incompatible attributes. + if (arg.hasByRefAttr() || arg.hasNestAttr()) + continue; + arg.addAttr(llvm::Attribute::InReg); + } + }); } diff --git a/third_party/amd/unittest/CMakeLists.txt b/third_party/amd/unittest/CMakeLists.txt new file mode 100644 index 000000000..bd3c0c6c0 --- /dev/null +++ b/third_party/amd/unittest/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Conversion) diff --git a/third_party/amd/unittest/Conversion/CMakeLists.txt b/third_party/amd/unittest/Conversion/CMakeLists.txt new file mode 100644 index 000000000..6d7a6b293 --- /dev/null +++ b/third_party/amd/unittest/Conversion/CMakeLists.txt @@ -0,0 +1,6 @@ +add_triton_ut(NAME TestOptimizeLDS +SRCS OptimizeLDSTest.cpp +LIBS + TritonAnalysis + TritonIR + TritonGPUIR) diff --git a/third_party/amd/unittest/Conversion/OptimizeLDSTest.cpp b/third_party/amd/unittest/Conversion/OptimizeLDSTest.cpp new file mode 100644 index 000000000..a9f112239 --- /dev/null +++ b/third_party/amd/unittest/Conversion/OptimizeLDSTest.cpp @@ -0,0 +1,42 @@ +//===- OptimizeLDSTest.cpp - Tests for OptimizeLDSUtility -----------------===// + +#include "third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h" +#include +#include + +namespace mlir { + +template bool checkProdEq(ArrayRef a) { + unsigned prod = + std::reduce(a.begin(), a.end(), 1u, std::multiplies()); + return prod == P; +} + +TEST(OptimizeLDSUtility, factorizePowerOf2) { + int numwarps; + int rank; + // check rank=1 generation + numwarps = 4; + rank = 1; + auto output1 = triton::AMD::factorizePowerOf2(numwarps, rank); + ASSERT_EQ(output1.size(), 1); + ASSERT_EQ(output1[0][0], numwarps); + // check rank=2 generation + numwarps = 8; + rank = 2; + auto output2 = triton::AMD::factorizePowerOf2(numwarps, rank); + ASSERT_EQ(output2.size(), 4); + ASSERT_TRUE(std::all_of(output2.begin(), output2.end(), checkProdEq<8>)); + ASSERT_TRUE(std::all_of(output2.begin(), output2.end(), + [](auto a) { return a.size() == 2; })); + // check rank=3 generation + numwarps = 8; + rank = 3; + auto output3 = triton::AMD::factorizePowerOf2(numwarps, rank); + ASSERT_EQ(output3.size(), 10); + ASSERT_TRUE(std::all_of(output3.begin(), output3.end(), checkProdEq<8>)); + ASSERT_TRUE(std::all_of(output3.begin(), output3.end(), + [](auto a) { return a.size() == 3; })); +} + +} // namespace mlir diff --git a/third_party/f2reduce/f2reduce.cpp b/third_party/f2reduce/f2reduce.cpp index e3aa7dfe9..0f87531b9 100644 --- a/third_party/f2reduce/f2reduce.cpp +++ b/third_party/f2reduce/f2reduce.cpp @@ -1,10 +1,19 @@ +#include "f2reduce.h" + #include #include -#include "f2reduce.h" -namespace { +#if defined(_MSC_VER) +#define RESTRICT __restrict +#define NO_INLINE __declspec(noinline) +#elif defined(__GNUC__) +#define RESTRICT __restrict__ +#define NO_INLINE __attribute__ ((noinline)) +#endif + +namespace f2reduce { -void swap_rows(uint64_t* __restrict__ x, uint64_t* __restrict__ y, uint64_t n) { +void swap_rows(uint64_t* RESTRICT x, uint64_t* RESTRICT y, uint64_t n) { for (uint64_t i = 0; i < n; i++) { uint64_t z = x[i]; x[i] = y[i]; y[i] = z; } @@ -12,40 +21,40 @@ void swap_rows(uint64_t* __restrict__ x, uint64_t* __restrict__ y, uint64_t n) { // the noinline attribute is necessary for gcc to properly vectorise this: template -__attribute__ ((noinline)) void memxor_lop7(uint64_t* __restrict__ dst, - const uint64_t* __restrict__ src1, - const uint64_t* __restrict__ src2, - const uint64_t* __restrict__ src3, - const uint64_t* __restrict__ src4, - const uint64_t* __restrict__ src5, - const uint64_t* __restrict__ src6) { +NO_INLINE void memxor_lop7(uint64_t* RESTRICT dst, + const uint64_t* RESTRICT src1, + const uint64_t* RESTRICT src2, + const uint64_t* RESTRICT src3, + const uint64_t* RESTRICT src4, + const uint64_t* RESTRICT src5, + const uint64_t* RESTRICT src6) { for (uint64_t i = 0; i < N; i++) { dst[i] ^= src1[i] ^ src2[i] ^ src3[i] ^ src4[i] ^ src5[i] ^ src6[i]; } } template -__attribute__ ((noinline)) void memxor_lop5(uint64_t* __restrict__ dst, - const uint64_t* __restrict__ src1, - const uint64_t* __restrict__ src2, - const uint64_t* __restrict__ src3, - const uint64_t* __restrict__ src4) { +NO_INLINE void memxor_lop5(uint64_t* RESTRICT dst, + const uint64_t* RESTRICT src1, + const uint64_t* RESTRICT src2, + const uint64_t* RESTRICT src3, + const uint64_t* RESTRICT src4) { for (uint64_t i = 0; i < N; i++) { dst[i] ^= src1[i] ^ src2[i] ^ src3[i] ^ src4[i]; } } template -__attribute__ ((noinline)) void memxor_lop3(uint64_t* __restrict__ dst, - const uint64_t* __restrict__ src1, - const uint64_t* __restrict__ src2) { +NO_INLINE void memxor_lop3(uint64_t* RESTRICT dst, + const uint64_t* RESTRICT src1, + const uint64_t* RESTRICT src2) { for (uint64_t i = 0; i < N; i++) { dst[i] ^= src1[i] ^ src2[i]; } } template -void memxor_inplace(uint64_t* __restrict__ dst, const uint64_t* __restrict__ src1, const uint64_t* __restrict__ src2) { +void memxor_inplace(uint64_t* RESTRICT dst, const uint64_t* RESTRICT src1, const uint64_t* RESTRICT src2) { for (uint64_t i = 0; i < N; i++) { dst[i] = src1[i] ^ src2[i]; } @@ -77,7 +86,7 @@ void split_k(int k, int* subkays) { * AVX512-friendly. */ template -void kronrod(uint64_t* __restrict__ matrix, uint64_t rows, uint64_t stride, const uint64_t* __restrict__ workspace, uint64_t* __restrict__ cache, const uint64_t* __restrict__ pivots, int k) { +void kronrod(uint64_t* RESTRICT matrix, uint64_t rows, uint64_t stride, const uint64_t* RESTRICT workspace, uint64_t* RESTRICT cache, const uint64_t* RESTRICT pivots, int k) { constexpr int logwidth = 5; static_assert(N <= (1ull << logwidth), "kronrod assumes that N <= 32"); @@ -121,10 +130,12 @@ void kronrod(uint64_t* __restrict__ matrix, uint64_t rows, uint64_t stride, cons if (N >= 32) { // prefetch 256 bytes, 15 rows later: uint64_t* ppp = matrix + (r + 15) * stride; +#if defined(__GNUC__) __builtin_prefetch(ppp); __builtin_prefetch(ppp + 8); __builtin_prefetch(ppp + 16); __builtin_prefetch(ppp + 24); +#endif } uint64_t w = workspace[r]; @@ -154,7 +165,7 @@ void kronrod(uint64_t* __restrict__ matrix, uint64_t rows, uint64_t stride, cons } -bool find_pivots(uint64_t* __restrict__ pivots, uint64_t* __restrict__ this_strip, uint64_t rows, uint64_t &starting_row, uint64_t *workspace, uint64_t &next_b, uint64_t final_b, int K, int& k) { +bool find_pivots(uint64_t* RESTRICT pivots, uint64_t* RESTRICT this_strip, uint64_t rows, uint64_t &starting_row, uint64_t *workspace, uint64_t &next_b, uint64_t final_b, int K, int& k) { // sorted copy, so that we can skip existing pivots: uint64_t spivots[64] = {(uint64_t) -1}; @@ -237,7 +248,7 @@ bool find_pivots(uint64_t* __restrict__ pivots, uint64_t* __restrict__ this_stri * The long switch statements are because we generate bespoke code for each * value of the chunk width N, which outperforms having a variable-length loop. */ -void chunked_kronrod(const uint64_t* __restrict__ pivots, uint64_t* __restrict__ matrix, uint64_t rows, uint64_t strips, uint64_t stride, const uint64_t* workspace, uint64_t* __restrict__ cache, int k) { +void chunked_kronrod(const uint64_t* RESTRICT pivots, uint64_t* RESTRICT matrix, uint64_t rows, uint64_t strips, uint64_t stride, const uint64_t* workspace, uint64_t* RESTRICT cache, int k) { uint64_t re = strips - 1; @@ -307,7 +318,7 @@ void chunked_kronrod(const uint64_t* __restrict__ pivots, uint64_t* __restrict__ * Find up to K pivot rows in this strip of 64 columns, remove them from all * other rows, and permute them into the correct places. */ -bool perform_K_steps(uint64_t* __restrict__ matrix, uint64_t* __restrict__ stripspace, uint64_t rows, uint64_t strips, uint64_t stride, uint64_t &starting_row, uint64_t *workspace, uint64_t* __restrict__ cache, uint64_t &next_b, int K, uint64_t final_b) { +bool perform_K_steps(uint64_t* RESTRICT matrix, uint64_t* RESTRICT stripspace, uint64_t rows, uint64_t strips, uint64_t stride, uint64_t &starting_row, uint64_t *workspace, uint64_t* RESTRICT cache, uint64_t &next_b, int K, uint64_t final_b) { memset(workspace, 0, 8 * rows); @@ -354,7 +365,7 @@ bool perform_K_steps(uint64_t* __restrict__ matrix, uint64_t* __restrict__ strip } -void inplace_rref_strided_K(uint64_t* __restrict__ matrix, uint64_t* __restrict__ stripspace, uint64_t rows, uint64_t cols, uint64_t stride, uint64_t *workspace, uint64_t *cache, int K) { +void inplace_rref_strided_K(uint64_t* RESTRICT matrix, uint64_t* RESTRICT stripspace, uint64_t rows, uint64_t cols, uint64_t stride, uint64_t *workspace, uint64_t *cache, int K) { uint64_t strips = (cols + 63) >> 6; diff --git a/third_party/iluvatar/CMakeLists.txt b/third_party/iluvatar/CMakeLists.txt new file mode 100644 index 000000000..189faac8e --- /dev/null +++ b/third_party/iluvatar/CMakeLists.txt @@ -0,0 +1,24 @@ +add_subdirectory(include) +add_subdirectory(lib) + +if(TRITON_BUILD_PYTHON_MODULE) + if(FLAGTREE_PLUGIN) + add_subdirectory(plugin) + add_triton_plugin(TritonILUVATAR + SHARED_LIB iluvatarTritonPlugin + ) + else() + find_library(iluvatarTritonPluginLib + NAMES + iluvatarTritonPlugin.so + PATHS + ${CMAKE_CURRENT_SOURCE_DIR} + REQUIRED + ) + add_triton_plugin(TritonILUVATAR + SHARED_LIB ${iluvatarTritonPluginLib} + ) + endif() +endif() + +add_subdirectory(bin) diff --git a/third_party/iluvatar/backend/__init__.py b/third_party/iluvatar/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/iluvatar/backend/compiler.py b/third_party/iluvatar/backend/compiler.py new file mode 100644 index 000000000..8778b7892 --- /dev/null +++ b/third_party/iluvatar/backend/compiler.py @@ -0,0 +1,232 @@ +from triton.backends.compiler import BaseBackend, GPUTarget +from triton._C.libtriton import ir, passes, llvm, iluvatar + +from dataclasses import dataclass +import functools +from typing import Any, Tuple, Optional +import hashlib +import re +import tempfile +import signal +import os +import subprocess +from pathlib import Path +from triton.backends.iluvatar.driver import cuda_home_dirs + + +@functools.lru_cache(None) +def file_hash(path): + with open(path, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + +@dataclass(frozen=False) +class CUDAOptions: + num_warps: int = 4 + num_ctas: int = 1 + num_stages: int = 3 + # maxnreg corresponds to the ptx parameter .maxnreg, which controls the + # maximum number of 32-bit registers used by one thread. + maxnreg: Optional[int] = None + cluster_dims: tuple = (1, 1, 1) + enable_fp_fusion: bool = True + allow_fp8e4nv: bool = False + allow_fp8e4b15: bool = False + default_dot_input_precision: str = "tf32" + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + max_num_imprecise_acc_default: bool = None + extern_libs: dict = None + debug: bool = False + backend_name: str = 'cuda' + use_sme: int = 0 + enable_sme: bool = True + num_vgpr: int = 0 + + def __post_init__(self): + default_libdir = cuda_home_dirs() + "/nvvm/libdevice/" + extern_libs = {} if self.extern_libs is None else dict(self.extern_libs) + if not extern_libs.get('libdevice', None): + extern_libs['libdevice'] = os.getenv("TRITON_LIBDEVICE_PATH", + str(default_libdir + 'libdevice.compute_bi.10.bc')) + object.__setattr__(self, 'extern_libs', tuple(extern_libs.items())) + assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ + "num_warps must be a power of 2" + + def hash(self): + hash_dict = dict(self.__dict__) + hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"])) + key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +class CUDABackend(BaseBackend): + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == 'cuda' + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + self.capability = target.arch + assert isinstance(self.capability, int) + self.binary_ext = "cubin" + + def parse_options(self, opts) -> Any: + args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts} + # args["allow_fp8e4nv"] = self.capability >= 89 + # args["allow_fp8e4b15"] = self.capability < 90 + args["allow_fp8e4nv"] = False + args["allow_fp8e4b15"] = False + args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0 + return CUDAOptions(**args) + + def pack_metadata(self, metadata): + return ( + metadata.num_warps, + metadata.num_ctas, + metadata.shared, + metadata.cluster_dims[0], + metadata.cluster_dims[1], + metadata.cluster_dims[2], + ) + + def get_codegen_implementation(self): + codegen_fns = dict() + return codegen_fns + + def load_dialects(self, ctx): + iluvatar.load_dialects(ctx) + + @staticmethod + def make_ttir(mod, metadata, opt): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_rewrite_tensor_pointer(pm) + passes.ttir.add_combine(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + pm.run(mod) + return mod + + @staticmethod + def make_ttgir(mod, metadata, opt, capability): + # TTIR -> TTGIR + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 64, opt.num_ctas) + # optimize TTGIR + passes.ttgpuir.add_coalesce(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_thread_locality(pm) + iluvatar.passes.ttgpuir.add_accelerate_matmul(pm, capability, opt.use_sme) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_dot_operands(pm, True) + passes.common.add_cse(pm) + iluvatar.passes.ttgpuir.add_matmul_load(pm, capability) # only MR(71) support sme + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_dot_operands(pm, True) + passes.common.add_cse(pm) + passes.ttgpuir.add_pipeline(pm, opt.num_stages) + passes.ttgpuir.add_prefetch(pm) + passes.ttgpuir.add_optimize_dot_operands(pm, True) + passes.ttgpuir.add_remove_layout_conversions(pm) + iluvatar.passes.ttgpuir.add_matmul_mmastore(pm, capability) + passes.ttgpuir.add_remove_layout_conversions(pm) + iluvatar.passes.ttgpuir.add_mmareduce(pm, capability) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_reduce_data_duplication(pm) + passes.ttgpuir.add_reorder_instructions(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + passes.common.add_canonicalizer(pm) + pm.run(mod) + return mod + + @staticmethod + def make_llir(src, metadata, options, capability): + mod = src + # TritonGPU -> LLVM-IR (MLIR) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + iluvatar.passes.ttgpuir.add_decompose_unsupported_conversions(pm) + passes.convert.add_scf_to_cf(pm) + passes.convert.add_index_to_llvmir(pm) + passes.ttgpuir.add_allocate_shared_memory(pm) + iluvatar.passes.ttgpuir.add_to_llvmir(pm, capability) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + + passes.convert.add_scf_to_cf(pm) + passes.convert.add_cf_to_llvmir(pm) + passes.convert.add_arith_to_llvmir(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": + passes.llvmir.add_di_scope(pm) + pm.run(mod) + # LLVM-IR (MLIR) -> LLVM-IR (LLVM) + llvm.init_targets() + context = llvm.context() + llvm_mod = llvm.to_module(mod, context) + iluvatar.set_nvvm_reflect_ftz(llvm_mod) + + # Set maxnreg on all kernels, if it was provided. + if options.maxnreg is not None: + for k in llvm_mod.get_functions(): + if not k.is_declaration() and k.is_external_linkage(): + k.set_nvvm_maxnreg(options.maxnreg) + + # Set kernel attributes first given this may affect later optimizations. + fns = [fn for fn in llvm_mod.get_functions() if not fn.is_declaration()] + # The public kernel should be kernel 0. + fns[0].set_calling_conv(iluvatar.CALLING_CONV_ILUVATAR_KERNEL) + if (options.num_vgpr > 0): + fns[0].add_fn_attr("iluvatar-num-vgpr", f"{options.num_vgpr}") + + if options.extern_libs: + paths = [path for (name, path) in options.extern_libs] + llvm.link_extern_libs(llvm_mod, paths) + + llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, iluvatar.TARGET_TRIPLE) + + # Get some metadata + metadata["shared"] = src.get_int_attr("triton_gpu.shared") + ret = str(llvm_mod) + del llvm_mod + del context + return ret + + @staticmethod + def make_cubin(src, metadata, options, capability): + names = re.findall(r"define iluvatar_kernel void @([a-zA-Z_][a-zA-Z0-9_]*)", src) + assert len(names) == 1 + metadata["name"] = names[0] + + triple = "bi-iluvatar-ilurt" + proc = "ivcore11" + if capability == 70: + proc = "ivcore10" + elif capability == 71: + proc = "ivcore11" + elif capability == 80: + proc = "ivcore20" + else: + print("iluvatar not support current compute capability", capability) + cubin = iluvatar.translate_llvmir_to_cubin(src, triple, proc, '', [], options.enable_fp_fusion, False) + return cubin + + def add_stages(self, stages, options): + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability) + stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability) + + @functools.lru_cache() + def hash(self): + version = '' + return f'{version}-{self.capability}' diff --git a/third_party/iluvatar/backend/driver.c b/third_party/iluvatar/backend/driver.c new file mode 100644 index 000000000..0e9e1359a --- /dev/null +++ b/third_party/iluvatar/backend/driver.c @@ -0,0 +1,291 @@ +#include "cuda.h" +#include +#include +#define PY_SSIZE_T_CLEAN +#include +#include + +// Raises a Python exception and returns false if code is not CUDA_SUCCESS. +static bool gpuAssert(CUresult code, const char *file, int line) { + if (code == CUDA_SUCCESS) + return true; + + const char *prefix = "Triton Error [CUDA]: "; + const char *str; + cuGetErrorString(code, &str); + char err[1024] = {0}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + return false; +} + +// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. +#define CUDA_CHECK_AND_RETURN_NULL(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) \ + return NULL; \ + } while (0) + +// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. +#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) { \ + PyEval_RestoreThread(_save); \ + return NULL; \ + } \ + } while (0) + +// Used to check if functions exist in old CUDA driver versions. +#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \ + do { \ + if ((funcPointer) == NULL) { \ + (funcPointer) = (initializerFunction)(); \ + if ((funcPointer) == NULL) { \ + return NULL; \ + } \ + } \ + } while (0) + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + // Get device handle + CUdevice device; + cuDeviceGet(&device, device_id); + + // create a struct to hold device properties + int max_shared_mem; + int max_num_regs; + int multiprocessor_count; + int warp_size; + int sm_clock_rate; + int mem_clock_rate; + int mem_bus_width; + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); + CUDA_CHECK_AND_RETURN_NULL( + cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", + max_shared_mem, "max_num_regs", max_num_regs, + "multiprocessor_count", multiprocessor_count, "warpSize", + warp_size, "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, "mem_bus_width", + mem_bus_width); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + int device; + if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, + &device)) { + return NULL; + } + CUfunction fun; + CUmodule mod; + int32_t n_regs = 0; + int32_t n_spills = 0; + // create driver handles + CUcontext pctx = 0; + + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx)); + if (!pctx) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx)); + } + + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuModuleGetFunction(&fun, mod, name)); + // get allocated registers and spilled registers from the function + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + n_spills /= 4; + // set dynamic shared memory if necessary + int shared_optin; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( + &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + if (shared > 49152 && shared_optin > 49152) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); + int shared_total, shared_static; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( + &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, + device)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute( + &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_optin - shared_static)); + } + Py_END_ALLOW_THREADS; + + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, + n_spills); +} + +typedef CUresult (*cuOccupancyMaxActiveClusters_t)( + int *numClusters, CUfunction func, const CUDA_LAUNCH_PARAMS *config); + +#define defineGetFunctionHandle(name, symbolName) \ + static symbolName##_t name() { \ + /* Open the shared library */ \ + void *libHandle = dlopen("libcuda.so.1", RTLD_LAZY); \ + if (!libHandle) { \ + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); \ + return NULL; \ + } \ + /* Clear any existing error */ \ + dlerror(); \ + symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \ + /* Check for errors */ \ + const char *err = dlerror(); \ + if (err) { \ + PyErr_SetString(PyExc_RuntimeError, \ + "Failed to retrieve " #symbolName " from libcuda.so.1"); \ + dlclose(libHandle); \ + return NULL; \ + } \ + return funcHandle; \ + } + +defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, + cuOccupancyMaxActiveClusters); + +static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { + int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, + maxActiveClusters = -1; + int shared = 0; + CUfunction func; + + if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX, + &clusterDimY, &clusterDimZ)) { + return NULL; + } + + // Let each SM have one block + int maxActiveBlocks = 1; + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( + func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared)); + Py_END_ALLOW_THREADS; + + CUDA_LAUNCH_PARAMS config; + config.gridDimX = clusterDimX; + config.gridDimY = maxActiveBlocks * clusterDimY; + config.gridDimZ = clusterDimZ; + config.blockDimX = 128; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared; + config.hStream = 0; + + static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL; + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuOccupancyMaxActiveClusters, + getCuOccupancyMaxActiveClustersHandle); + + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config)); + Py_END_ALLOW_THREADS; + return PyLong_FromLong(maxActiveClusters); +} + +static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { + long size; + if (!PyArg_ParseTuple(args, "l", &size)) { + return NULL; + } + if (size < 0) { + PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS; + + // Ensure we have an active context. + CUcontext ctx = NULL; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&ctx)); + if (!ctx) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuDevicePrimaryCtxRetain(&ctx, /*device=*/0)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(ctx)); + } + + // We can't set the fifo size after running a kernel that calls printf. This + // is true even if the set() call is a nop and the new size is the same as the + // old size. + // + // This is unfriendly, so check if the old size matches the new size, and skip + // the set() call if so. + size_t oldSize = 0; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuCtxGetLimit(&oldSize, CU_LIMIT_PRINTF_FIFO_SIZE)); + if (oldSize != size) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, size)); + } + + Py_END_ALLOW_THREADS; + return Py_None; +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load provided cubin into CUDA driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS, + "Python interface for cuOccupancyMaxActiveClusters function"}, + {"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS, + "Python interface for cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, x), which " + "controls how many bytes can be streamed from kernels before data starts " + "being dropped. This inherits all the limitations of this call; in " + "particular it's an error to change this value after launching any kernel " + "that calls printf()."}, + + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_cuda_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + + PyModule_AddFunctions(m, ModuleMethods); + + return m; +} diff --git a/third_party/iluvatar/backend/driver.py b/third_party/iluvatar/backend/driver.py new file mode 100644 index 000000000..9f62f8a63 --- /dev/null +++ b/third_party/iluvatar/backend/driver.py @@ -0,0 +1,411 @@ +import functools +import os +import hashlib +import subprocess +import tempfile +from pathlib import Path +from triton.runtime.build import _build +from triton.runtime.cache import get_cache_manager +from triton.backends.compiler import GPUTarget +from triton.backends.driver import GPUDriver + +dirname = os.path.dirname(os.path.realpath(__file__)) +include_dir = [os.path.join(dirname, "include")] +libdevice_dir = os.path.join(dirname, "lib") +libraries = ['cuda'] + + +def is_corex(): + import torch + return hasattr(torch, "corex") and torch.corex == True + + +@functools.lru_cache() +def cuda_home_dirs(): + loc = subprocess.check_output(["whereis", "ixsmi"]).decode().strip().split() + default_dir = '/usr/local/corex' + if (len(loc) > 1): + default_dir = os.path.dirname(os.path.dirname(loc[1])) + return os.getenv("CUDA_HOME", default=default_dir) + + +@functools.lru_cache() +def libcuda_dirs(): + env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH") + if env_libcuda_path: + return [env_libcuda_path] + + libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() + # each line looks like the following: + # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1 + locs = [line.split()[-1] for line in libs.splitlines() if "libcuda.so.1" in line] + dirs = [os.path.dirname(loc) for loc in locs] + env_ld_library_path = os.getenv("LD_LIBRARY_PATH") + if env_ld_library_path and not dirs: + dirs = [dir for dir in env_ld_library_path.split(":") if os.path.exists(os.path.join(dir, "libcuda.so.1"))] + msg = 'libcuda.so cannot found!\n' + if locs: + msg += 'Possible files are located at %s.' % str(locs) + msg += 'Please create a symlink of libcuda.so to any of the files.' + else: + msg += 'Please make sure GPU is set up and then run "/sbin/ldconfig"' + msg += ' (requires sudo) to refresh the linker cache.' + assert any(os.path.exists(os.path.join(path, 'libcuda.so.1')) for path in dirs), msg + return dirs + + +@functools.lru_cache() +def library_dirs(): + if is_corex: + cuda_path = cuda_home_dirs() + cuda_lib_dirs = os.path.join(cuda_path, "lib64") + return [libdevice_dir, cuda_lib_dirs] + else: + return [libdevice_dir, *libcuda_dirs()] + + +@functools.lru_cache() +def include_dirs(): + if is_corex(): + cuda_path = cuda_home_dirs() + cu_include_dir = os.path.join(cuda_path, "include") + return include_dir + [cu_include_dir] + else: + return include_dir + + +def compile_module_from_src(src, name): + key = hashlib.sha256(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + so = _build(name, src_path, tmpdir, library_dirs(), include_dirs(), libraries) + cache.put(src, "main.c", binary=False) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod, cache_path + + +# ------------------------ +# Utils +# ------------------------ + + +class CudaUtils(object): + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(CudaUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + mod, self.cache_path = compile_module_from_src( + Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils") + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters + self.set_printf_fifo_size = mod.set_printf_fifo_size + + +# ------------------------ +# Launcher +# ------------------------ + + +def ty_to_cpp(ty): + if ty[0] == '*': + return "CUdeviceptr" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + +def make_launcher(constants, signature, ids): + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + + def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return ty_to_cpp(ty) + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty] + + args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiKKOOOO" + args_format + args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + + # generate glue code + params = [i for i in signature.keys() if i not in constants] + warp_size = 64 if is_corex() else 32 + src = f""" +#include \"cuda.h\" +#include +#include +#include + +static inline void gpuAssert(CUresult code, const char *file, int line) +{{ + if (code != CUDA_SUCCESS) + {{ + const char* prefix = "Triton Error [CUDA]: "; + const char* str; + cuGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + }} +}} + +#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + +typedef CUresult (*cuLaunchKernelEx_t)(const CUDA_LAUNCH_PARAMS* config, CUfunction f, void** kernelParams, void** extra); + +static cuLaunchKernelEx_t getLaunchKernelExHandle() {{ + // Open the shared library + void* handle = dlopen("libcuda.so.1", RTLD_LAZY); + if (!handle) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); + return NULL; + }} + // Clear any existing error + dlerror(); + cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx"); + // Check for errors + const char *dlsym_error = dlerror(); + if (dlsym_error) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so.1"); + return NULL; + }} + return cuLaunchKernelExHandle; +}} + +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }}; + if (gridX*gridY*gridZ > 0) {{ + if (num_ctas == 1) {{ + CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0)); + }} else {{ + CUDA_LAUNCH_PARAMS config; + config.gridDimX = gridX * clusterDimX; + config.gridDimY = gridY * clusterDimY; + config.gridDimZ = gridZ * clusterDimZ; + config.blockDimX = {warp_size} * num_warps; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared_memory; + config.hStream = stream; + static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; + if (cuLaunchKernelExHandle == NULL) {{ + cuLaunchKernelExHandle = getLaunchKernelExHandle(); + }} + CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0)); + }} + }} +}} + +typedef struct _DevicePtrInfo {{ + CUdeviceptr dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret); + if(!ptr_info.dev_ptr) + return ptr_info; + /* + uint64_t dev_ptr; + int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); + if (status == CUDA_ERROR_INVALID_VALUE) {{ + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); + ptr_info.valid = false; + }} + ptr_info.dev_ptr = dev_ptr; + */ + Py_DECREF(ret); // Thanks ChatGPT! + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {args_list})) {{ + return NULL; + }} + + int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; + if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); + return NULL; + }} + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + Py_BEGIN_ALLOW_THREADS; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); + Py_END_ALLOW_THREADS; + if (PyErr_Occurred()) {{ + return NULL; + }} + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src + + +class CudaLauncher(object): + + def __init__(self, src, metadata): + ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + src = make_launcher(constants, signature, ids) + mod, _ = compile_module_from_src(src, "__triton_launcher") + self.launch = mod.launch + + def __call__(self, *args, **kwargs): + self.launch(*args, **kwargs) + + +class CudaDriver(GPUDriver): + + def __init__(self): + self.utils = CudaUtils() # TODO: make static + self.launcher_cls = CudaLauncher + super().__init__() + + def get_current_target(self): + device = self.get_current_device() + capability = self.get_device_capability(device) + capability = capability[0] * 10 + capability[1] + warp_size = 64 + return GPUTarget("cuda", capability, warp_size) + + def get_cache_path(self): + return self.utils.cache_path + + @staticmethod + def is_active(): + import torch + return torch.cuda.is_available() and (torch.version.hip is None) diff --git a/third_party/iluvatar/bin/CMakeLists.txt b/third_party/iluvatar/bin/CMakeLists.txt new file mode 100644 index 000000000..6a645a52a --- /dev/null +++ b/third_party/iluvatar/bin/CMakeLists.txt @@ -0,0 +1,95 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) + +add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED) + +# TODO: what's this? +llvm_update_compile_flags(triton-opt) +target_link_libraries(triton-opt PRIVATE + TritonLLVMIR + TritonAnalysis + TritonTransforms + TritonGPUTransforms + MLIRGPUToROCDLTransforms + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + # MLIR core + MLIROptLib + MLIRPass + MLIRTransforms +) +set_target_properties(triton-opt PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/ +) + +mlir_check_all_link_libraries(triton-opt) + +add_llvm_executable(triton-reduce triton-reduce.cpp PARTIAL_SOURCES_INTENDED) +mlir_check_all_link_libraries(triton-reduce) + +llvm_update_compile_flags(triton-reduce) +target_link_libraries(triton-reduce PRIVATE + TritonLLVMIR + TritonAnalysis + TritonTransforms + TritonGPUTransforms + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + # MLIR core + MLIRReduceLib + MLIRPass + MLIRTransforms +) +set_target_properties(triton-reduce PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/ +) + +mlir_check_all_link_libraries(triton-reduce) + +add_llvm_executable(triton-lsp triton-lsp.cpp PARTIAL_SOURCES_INTENDED) +mlir_check_all_link_libraries(triton-lsp) + +llvm_update_compile_flags(triton-lsp) +target_link_libraries(triton-lsp PRIVATE + TritonAnalysis + TritonTransforms + TritonGPUTransforms + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + # MLIR core + MLIRLspServerLib + MLIRPass + MLIRTransforms +) +set_target_properties(triton-lsp PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/ +) + +mlir_check_all_link_libraries(triton-lsp) + + +add_llvm_executable(triton-llvm-opt + triton-llvm-opt.cpp + + PARTIAL_SOURCES_INTENDED + DEPENDS + intrinsics_gen + SUPPORT_PLUGINS + ) +target_link_libraries(triton-llvm-opt PRIVATE + TritonLLVMIR + + LLVMAnalysis + LLVMCore + LLVMSupport + LLVMOption + LLVMCodeGen + ) +export_executable_symbols_for_plugins(triton-llvm-opt) +set_target_properties(triton-llvm-opt PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/ +) diff --git a/third_party/iluvatar/bin/RegisterTritonDialects.h b/third_party/iluvatar/bin/RegisterTritonDialects.h new file mode 100644 index 000000000..66693b9e2 --- /dev/null +++ b/third_party/iluvatar/bin/RegisterTritonDialects.h @@ -0,0 +1,109 @@ +#pragma once + +#ifdef __NVIDIA__ +#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#endif +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#ifdef __NVIDIA__ +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#endif + +// Below headers will allow registration to ROCm passes +#ifdef __AMD__ +#include "TritonAMDGPUToLLVM/Passes.h" +#include "TritonAMDGPUTransforms/Passes.h" +#include "TritonAMDGPUTransforms/TritonGPUConversion.h" +#endif + +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#ifdef __NVIDIA__ +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#endif + +#ifdef __NVIDIA__ +#include "nvidia/include/NVGPUToLLVM/Passes.h" +#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" +#endif +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Target/LLVMIR/Passes.h" + +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/InitAllPasses.h" + +#include "python/src/plugin.h" + +#include +#include +#include + +#ifndef __ILUVATAR__ +namespace mlir { +namespace test { +void registerTestAliasPass(); +void registerTestAlignmentPass(); +void registerTestAllocationPass(); +void registerTestMembarPass(); +} // namespace test +} // namespace mlir +#endif + +using registerConvertTritonGPUToLLVMPassFunc = void (*)(); +DEFINE_LOAD_FUNC(registerConvertTritonGPUToLLVMPass) + +inline void registerTritonDialects(mlir::DialectRegistry ®istry) { + mlir::registerAllPasses(); + mlir::registerTritonPasses(); + mlir::triton::gpu::registerTritonGPUPasses(); +#ifdef __NVIDIA__ + mlir::registerTritonNvidiaGPUPasses(); + mlir::test::registerTestAliasPass(); + mlir::test::registerTestAlignmentPass(); + mlir::test::registerTestAllocationPass(); + mlir::test::registerTestMembarPass(); +#endif + mlir::triton::registerConvertTritonToTritonGPUPass(); + mlir::triton::registerAllocateSharedMemoryPass(); +#ifdef __ILUVATAR__ + DEFINE_CALL_LOAD_FUNC(iluvatar, registerConvertTritonGPUToLLVMPass) + func(); +#endif +#ifdef __NVIDIA__ + mlir::triton::registerConvertTritonGPUToLLVMPass(); + mlir::triton::registerConvertNVGPUToLLVMPass(); + mlir::triton::registerDecomposeUnsupportedNVIDIAConversions(); +#endif + mlir::registerLLVMDIScope(); + +#ifdef __AMD__ + // TritonAMDGPUToLLVM passes + mlir::triton::registerConvertTritonAMDGPUToLLVM(); + mlir::triton::registerConvertBuiltinFuncToLLVM(); + mlir::triton::registerDecomposeUnsupportedAMDConversions(); + + // TritonAMDGPUTransforms passes + mlir::registerTritonAMDGPUAccelerateMatmul(); + mlir::registerTritonAMDGPUOptimizeEpilogue(); + mlir::registerTritonAMDGPUReorderInstructions(); + mlir::registerTritonAMDGPUStreamPipeline(); +#endif + + // TODO: register Triton & TritonGPU passes + registry.insert(); +} diff --git a/third_party/iluvatar/bin/triton-llvm-opt.cpp b/third_party/iluvatar/bin/triton-llvm-opt.cpp new file mode 100644 index 000000000..1ec804cb5 --- /dev/null +++ b/third_party/iluvatar/bin/triton-llvm-opt.cpp @@ -0,0 +1,121 @@ +/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir +/// passes. +#include "lib/Target/LLVMIR/LLVMPasses.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/SystemUtils.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/TargetParser/Triple.h" +#include + +using namespace llvm; + +static cl::opt InputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +static cl::opt OutputFilename("o", + cl::desc("Override output filename"), + cl::value_desc("filename")); + +static cl::opt ClDataLayout("data-layout", + cl::desc("data layout string to use"), + cl::value_desc("layout-string"), + cl::init("")); +static cl::opt + TargetTriple("mtriple", cl::desc("Override target triple for module")); + +static cl::opt + BreakStructPhiNodes("break-struct-phi-nodes", + llvm::cl::desc("run pass to break phi struct"), + cl::init(false)); + +namespace { +static std::function makeOptimizingPipeline() { + return [](Module *m) -> Error { + PipelineTuningOptions tuningOptions; + PassBuilder pb(nullptr, tuningOptions); + + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + llvm::FunctionPassManager fpm; + if (BreakStructPhiNodes) + fpm.addPass(BreakStructPhiNodesPass()); + mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm))); + mpm.run(*m, mam); + return Error::success(); + }; +} +} // namespace + +int main(int argc, char **argv) { + InitLLVM X(argc, argv); + cl::ParseCommandLineOptions( + argc, argv, "llvm .bc -> .bc modular optimizer and analysis printer\n"); + + LLVMContext Context; + SMDiagnostic Err; + + // Load the input module... + auto SetDataLayout = [](StringRef, StringRef) -> std::optional { + if (ClDataLayout.empty()) + return std::nullopt; + return ClDataLayout; + }; + std::unique_ptr M; + M = parseIRFile(InputFilename, Err, Context, ParserCallbacks(SetDataLayout)); + if (!M) { + Err.print(argv[0], errs()); + return 1; + } + // If we are supposed to override the target triple or data layout, do so now. + if (!TargetTriple.empty()) + M->setTargetTriple(Triple::normalize(TargetTriple)); + auto optPipeline = makeOptimizingPipeline(); + if (auto err = optPipeline(M.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + } + + if (verifyModule(*M, &errs())) { + errs() << argv[0] << ": " << InputFilename + << ": error: input module is broken!\n"; + return 1; + } + + // Write to standard output. + std::unique_ptr Out; + // Default to standard output. + if (OutputFilename.empty()) + OutputFilename = "-"; + std::error_code EC; + sys::fs::OpenFlags Flags = sys::fs::OF_TextWithCRLF; + Out.reset(new ToolOutputFile(OutputFilename, EC, Flags)); + if (EC) { + errs() << EC.message() << '\n'; + return 1; + } + Out->os() << *M << "\n"; + Out->keep(); + return 0; +} diff --git a/third_party/iluvatar/bin/triton-lsp.cpp b/third_party/iluvatar/bin/triton-lsp.cpp new file mode 100644 index 000000000..b185b0374 --- /dev/null +++ b/third_party/iluvatar/bin/triton-lsp.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + mlir::MLIRContext context(registry); + return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); +} diff --git a/third_party/iluvatar/bin/triton-opt.cpp b/third_party/iluvatar/bin/triton-opt.cpp new file mode 100644 index 000000000..2d2570771 --- /dev/null +++ b/third_party/iluvatar/bin/triton-opt.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-opt/MlirOptMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + return mlir::asMainReturnCode(mlir::MlirOptMain( + argc, argv, "Triton (GPU) optimizer driver\n", registry)); +} diff --git a/third_party/iluvatar/bin/triton-reduce.cpp b/third_party/iluvatar/bin/triton-reduce.cpp new file mode 100644 index 000000000..8235f8fc8 --- /dev/null +++ b/third_party/iluvatar/bin/triton-reduce.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-reduce/MlirReduceMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + mlir::MLIRContext context(registry); + return mlir::failed(mlir::mlirReduceMain(argc, argv, context)); +} diff --git a/third_party/iluvatar/include/CMakeLists.txt b/third_party/iluvatar/include/CMakeLists.txt new file mode 100644 index 000000000..109c292fe --- /dev/null +++ b/third_party/iluvatar/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(triton) diff --git a/third_party/iluvatar/include/triton/Analysis/Alias.h b/third_party/iluvatar/include/triton/Analysis/Alias.h new file mode 100644 index 000000000..a06df5ae2 --- /dev/null +++ b/third_party/iluvatar/include/triton/Analysis/Alias.h @@ -0,0 +1,96 @@ +#ifndef TRITON_ANALYSIS_ALIAS_H +#define TRITON_ANALYSIS_ALIAS_H + +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/ADT/DenseSet.h" + +namespace mlir { + +class AliasInfo { +public: + AliasInfo() = default; + AliasInfo(Value value) { insert(value); } + + void insert(Value value) { allocs.insert(value); } + + const DenseSet &getAllocs() const { return allocs; } + + bool operator==(const AliasInfo &other) const { + return allocs == other.allocs; + } + + /// The pessimistic value state of a value without alias + static AliasInfo getPessimisticValueState(MLIRContext *context = nullptr) { + return AliasInfo(); + } + static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); } + + /// The union of both arguments + static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs); + + void print(raw_ostream &os) const { + llvm::interleaveComma(allocs, os, [&](Value alloc) { alloc.print(os); }); + } + +private: + /// The set of allocated values that are aliased by this lattice. + /// For now, we only consider aliased value produced by the following + /// situations: + /// 1. values returned by scf.yield + /// 2. block arguments in scf.for + /// Example: + /// alloc v1 alloc v2 + /// | | + /// |--------------| |------------| + /// scf.for v3 scf.for v4 scf.for v5 + /// | + /// scf.yield v6 + /// + /// v1's alloc [v1] + /// v2's alloc [v2] + /// v3's alloc [v1] + /// v4's alloc [v1, v2] + /// v5's alloc [v2] + /// v6's alloc [v1] + /// + /// Therefore, v1's liveness range is the union of v3, v4, and v6 + /// v2's liveness range is the union of v4 and v5. + DenseSet allocs; +}; + +//===----------------------------------------------------------------------===// +// Shared Memory Alias Analysis +//===----------------------------------------------------------------------===// +class SharedMemoryAliasAnalysis + : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +public: + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::SparseForwardDataFlowAnalysis; + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + + /// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use. + /// Given two values, returns their aliasing behavior. + AliasResult alias(Value lhs, Value rhs); + + /// Returns the modify-reference behavior of `op` on `location`. + ModRefResult getModRef(Operation *op, Value location); + + void setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged( + lattice, lattice->join( + AliasInfo::getPessimisticValueState(lattice->getPoint()))); + } + + /// Computes if the alloc set of the results are changed. + void + visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_ALIAS_H diff --git a/third_party/iluvatar/include/triton/Analysis/Allocation.h b/third_party/iluvatar/include/triton/Analysis/Allocation.h new file mode 100644 index 000000000..92f63eb48 --- /dev/null +++ b/third_party/iluvatar/include/triton/Analysis/Allocation.h @@ -0,0 +1,257 @@ +#ifndef TRITON_ANALYSIS_ALLOCATION_H +#define TRITON_ANALYSIS_ALLOCATION_H + +#include "triton/Analysis/Utility.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/raw_ostream.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include +#include + +namespace mlir { + +namespace triton { +class AllocationAnalysis; + +SmallVector +getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, + unsigned &outVec); +SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op); + +} // namespace triton + +/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h +/// A class that represents an interval, specified using a start and an end +/// values: [Start, End). +template class Interval { +public: + Interval() {} + Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); } + T start() const { return Start; } + T end() const { return End; } + T size() const { return End - Start; } + bool contains(T Addr) const { return Start <= Addr && Addr < End; } + bool intersects(const Interval &R) const { + return Start < R.End && R.Start < End; + } + bool operator==(const Interval &R) const { + return Start == R.Start && End == R.End; + } + bool operator!=(const Interval &R) const { return !(*this == R); } + bool operator<(const Interval &R) const { + return std::make_pair(Start, End) < std::make_pair(R.Start, R.End); + } + +private: + T Start = std::numeric_limits::min(); + T End = std::numeric_limits::max(); +}; + +template Interval(T, T) -> Interval; + +class Allocation { +public: + /// A unique identifier for shared memory buffers + using BufferId = size_t; + using BufferIdSetT = DenseSet; + using FuncAllocMapT = CallGraph::FuncDataMapT; + + static constexpr BufferId InvalidBufferId = + std::numeric_limits::max(); + + Allocation() = default; + /// Creates a new Allocation analysis that computes the shared memory + /// information for all associated shared memory values. + explicit Allocation(Operation *operation) : operation(operation) {} + + /// Runs allocation analysis on the given top-level operation. + void run(FuncAllocMapT &funcAllocMap); + + /// Returns the operation this analysis was constructed from. + Operation *getOperation() const { return operation; } + + /// Returns the offset of the given buffer in the shared memory. + size_t getOffset(BufferId bufferId) const { + return bufferSet.at(bufferId).offset; + } + + /// Returns the size of the given buffer in the shared memory. + size_t getAllocatedSize(BufferId bufferId) const { + return bufferSet.at(bufferId).size; + } + + /// Returns the allocated interval of the given buffer. + Interval getAllocatedInterval(BufferId bufferId) const { + auto &buffer = bufferSet.at(bufferId); + return Interval(buffer.offset, buffer.offset + buffer.size); + } + + /// Returns the buffer id of the given value. + /// This interface only returns the allocated buffer id. + /// If you want to get all the buffer ids that are associated with the given + /// value, including alias buffers, use getBufferIds. + BufferId getBufferId(Value value) const { + if (valueBuffer.count(value)) { + return valueBuffer.lookup(value)->id; + } else { + return InvalidBufferId; + } + } + + /// Returns all the buffer ids of the given value, including alias buffers. + BufferIdSetT getBufferIds(Value value) const { + BufferIdSetT bufferIds; + auto allocBufferId = getBufferId(value); + if (allocBufferId != InvalidBufferId) + bufferIds.insert(allocBufferId); + for (auto *buffer : aliasBuffer.lookup(value)) { + if (buffer->id != InvalidBufferId) + bufferIds.insert(buffer->id); + } + return bufferIds; + } + + /// Returns the scratch buffer id of the given value. + BufferId getBufferId(Operation *operation) const { + if (opScratch.count(operation)) { + return opScratch.lookup(operation)->id; + } else if (opVirtual.count(operation)) { + return opVirtual.lookup(operation)->id; + } else { + return InvalidBufferId; + } + } + + /// Returns if the given buffer is a virtual buffer. + bool isVirtualBuffer(BufferId bufferId) const { + return bufferSet.at(bufferId).kind == BufferT::BufferKind::Virtual; + } + + /// Returns the size of total shared memory allocated + size_t getSharedMemorySize() const { return sharedMemorySize; } + +private: + /// A class that represents a shared memory buffer + struct BufferT { + /// Explicit: triton_gpu.local_alloc + /// Scratch: triton_gpu.convert_layout + /// Virtual: triton.call + enum class BufferKind { Explicit, Scratch, Virtual }; + + /// MT: thread-safe + inline static std::atomic nextId = 0; + + BufferKind kind; + BufferId id; + size_t size; + size_t alignment; + size_t offset; + + bool operator==(const BufferT &other) const { return id == other.id; } + bool operator<(const BufferT &other) const { return id < other.id; } + + BufferT() : BufferT(BufferKind::Explicit, 0) {} + BufferT(BufferKind kind, size_t size, size_t alignment = 4, + size_t offset = 0) + : kind(kind), id(nextId++), size(size), alignment(alignment), + offset(offset) {} + + size_t setOffsetAligned(size_t newOffset) { + return offset = llvm::alignTo(newOffset, alignment); + } + }; + + /// Op -> Scratch Buffer + using OpScratchMapT = DenseMap; + /// Value -> Explicit Buffer + using ValueBufferMapT = llvm::MapVector; + /// Value -> Alias Buffer + using AliasBufferMapT = llvm::MapVector>; + /// BufferId -> Buffer + using BufferSetT = std::map; + +private: + template + void addBuffer(KeyType &key, Args &&...args) { + auto buffer = BufferT(Kind, std::forward(args)...); + bufferSet[buffer.id] = std::move(buffer); + if constexpr (Kind == BufferT::BufferKind::Explicit) { + valueBuffer[key] = &bufferSet[buffer.id]; + } else if constexpr (Kind == BufferT::BufferKind::Virtual) { + opVirtual[key] = &bufferSet[buffer.id]; + } else { + opScratch[key] = &bufferSet[buffer.id]; + } + } + + void addAlias(Value value, Value alloc) { + aliasBuffer[value].insert(valueBuffer[alloc]); + } + +private: + Operation *operation = nullptr; + OpScratchMapT opScratch; + OpScratchMapT opVirtual; + ValueBufferMapT valueBuffer; + AliasBufferMapT aliasBuffer; + BufferSetT bufferSet; + size_t sharedMemorySize = 0; + + friend class triton::AllocationAnalysis; +}; + +/// Static analysis that computes the allocation of shared memory buffers +/// of the entire call graph. +/// The allocation is performed in a post-order walk of the call graph. +/// Each call op is treated like convert_layout that allocates a scratch buffer. +/// At each call, we compute the start offset of the scratch buffer and pass it +/// as an argument to the callee. +class ModuleAllocation : public CallGraph { +public: + using FuncOffsetMapT = DenseMap; + + explicit ModuleAllocation(ModuleOp moduleOp) + : CallGraph(moduleOp) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp); + if (inserted) + iter->second.run(funcMap); + }); + } + + size_t getSharedMemorySize() { + size_t size = 0; + for (auto funcOp : getRoots()) { + auto *alloc = getFuncData(funcOp); + size = std::max(size, alloc->getSharedMemorySize()); + } + return size; + } + + size_t getSharedMemorySize(FunctionOpInterface funcOp) { + return getFuncData(funcOp)->getSharedMemorySize(); + } + + void setFunctionSharedMemoryValue(FunctionOpInterface funcOp, Value value) { + sharedMemoryValue[funcOp] = value; + } + + Value getFunctionSharedMemoryBase(FunctionOpInterface funcOp) { + return sharedMemoryValue[funcOp]; + } + +private: + FuncOffsetMapT sharedMemoryValue; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_ALLOCATION_H diff --git a/third_party/iluvatar/include/triton/Analysis/AxisInfo.h b/third_party/iluvatar/include/triton/Analysis/AxisInfo.h new file mode 100644 index 000000000..22a7ed554 --- /dev/null +++ b/third_party/iluvatar/include/triton/Analysis/AxisInfo.h @@ -0,0 +1,215 @@ +#ifndef TRITON_ANALYSIS_AXISINFO_H +#define TRITON_ANALYSIS_AXISINFO_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include +#include + +namespace mlir::triton { + +//===----------------------------------------------------------------------===// +// AxisInfo +//===----------------------------------------------------------------------===// + +/// This lattice value represents known information on the axes of a lattice. +class AxisInfo { +public: + typedef SmallVector DimVectorT; + +public: + AxisInfo() : AxisInfo({}, {}, {}) {} + + AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy) + : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {} + + AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy, + std::optional constantValue) + : contiguity(contiguity), divisibility(divisibility), + constancy(constancy), constantValue(constantValue) { + assert(divisibility.size() == contiguity.size()); + assert(constancy.size() == contiguity.size()); + } + + // contiguity[d] is the length of the shortest sequence of contiguous integers + // along dimension d. + // + // If we have an array of N elements with a contiguity value C, then the array + // can be divided into a list of N/C sequences of C contiguous elements. + // Since we have N = 2^k, C must be a power of two. + // + // For example, the 2D array + // + // [[10, 11, 12, 13, 18, 19, 20, 21], + // [20, 21, 22, 23, 28, 29, 30, 31]] + // + // has contiguity [1, 4], and + // + // [[12, 16, 20, 24], + // [13, 17, 21, 25], + // [14, 18, 22, 26], + // [15, 19, 23, 27], + // [18, 22, 26, 30], + // [19, 23, 27, 31]] + // + // has contiguity [2, 1]. + int64_t getContiguity(size_t dim) const { return contiguity[dim]; } + const DimVectorT &getContiguity() const { return contiguity; } + + // divisibility[d] is the largest power of two that divides the first element + // of all groups of length contiguity[d] along dimension d. + // + // For example, + // + // [[10, 11, 12, 13, 18, 19, 20, 21], + // [20, 21, 22, 23, 28, 29, 30, 31]] + // + // has divisibility [1, 2], and + // + // [[12, 16, 20, 24], + // [13, 17, 21, 25], + // [14, 18, 22, 26], + // [15, 19, 23, 27]] + // + // has divisibility [4, 1]. + // + // On the other hand, + // + // [0, 1, 2, 0, 4, 5, 6, 7] + // + // has divisibility 1 because its contiguity is 1. + int64_t getDivisibility(size_t dim) const { return divisibility[dim]; } + const DimVectorT &getDivisibility() const { return divisibility; } + + // constancy[d] is the length of the shortest sequence of repeating integers + // along dimension d. + // + // This is particularly useful to infer the contiguity of operations (e.g. + // add) involving a constant. + // + // If we have an array of N elements, with a constancy value C, then the array + // can be divided into a list of N/C sequences of C elements with the same + // value. Since we have N = 2^k, C must be a power of two. + // + // For example + // + // [[8, 8, 8, 8, 12, 12, 12, 12], + // [16, 16, 16, 16, 20, 20, 20, 20]] + // + // has constancy [1, 4]. + int64_t getConstancy(size_t dim) const { return constancy[dim]; } + const DimVectorT &getConstancy() const { return constancy; } + + int getRank() const { return contiguity.size(); } + + std::optional getConstantValue() const { return constantValue; } + + template + static void + initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity, + DimVectorT *divisibility, DimVectorT *constancy); + + bool operator==(const AxisInfo &other) const { + return contiguity == other.contiguity && + divisibility == other.divisibility && constancy == other.constancy && + constantValue == other.constantValue; + } + + static AxisInfo getPessimisticValueState(Value value); + + // The gcd of both arguments for each dimension + static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs); + + void print(raw_ostream &os) const { + auto print = [&](StringRef name, DimVectorT vec) { + os << name << " = ["; + llvm::interleaveComma(vec, os); + os << "]"; + }; + print("contiguity", contiguity); + print(", divisibility", divisibility); + print(", constancy", constancy); + os << ", constant_value = "; + if (constantValue) + os << *constantValue; + else + os << ""; + } + +private: + DimVectorT contiguity; + DimVectorT divisibility; + DimVectorT constancy; + + // The constant value of the lattice if we can infer it. + std::optional constantValue; +}; + +// Module level axis info analysis based on the call graph, assuming that we do +// not have recursive functions. +// +// Since each function will be called multiple times, we need to calculate the +// axis info based on the axis info of all the callers. In the future, we can +// perform optimization using function cloning so that each call site will have +// unique axis info. +using AxisInfoMapT = DenseMap; +class ModuleAxisInfoAnalysis : public CallGraph { +public: + explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp) + : CallGraph(moduleOp) { + SmallVector funcs; + for (auto root : getRoots()) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + funcs.push_back(funcOp); + funcMap.try_emplace(funcOp, AxisInfoMapT{}); + }); + } + SetVector sortedFuncs(funcs.begin(), funcs.end()); + SymbolTableCollection symbolTable; + for (auto funcOp : llvm::reverse(sortedFuncs)) { + initialize(funcOp); + funcOp.walk([&](CallOpInterface callOp) { + auto callee = + dyn_cast(callOp.resolveCallable(&symbolTable)); + update(callOp, callee); + }); + } + } + + AxisInfo *getAxisInfo(Value value) { + auto funcOp = + value.getParentRegion()->getParentOfType(); + auto *axisInfoMap = getFuncData(funcOp); + if (!axisInfoMap) { + return nullptr; + } + auto it = axisInfoMap->find(value); + if (it == axisInfoMap->end()) { + return nullptr; + } + return &(it->second); + } + + unsigned getPtrContiguity(Value ptr); + unsigned getPtrAlignment(Value ptr); + unsigned getMaskAlignment(Value mask); + +private: + void initialize(FunctionOpInterface funcOp); + void update(CallOpInterface callOp, FunctionOpInterface funcOp); +}; + +} // namespace mlir::triton + +#endif diff --git a/third_party/iluvatar/include/triton/Analysis/Membar.h b/third_party/iluvatar/include/triton/Analysis/Membar.h new file mode 100644 index 000000000..43bd5d15b --- /dev/null +++ b/third_party/iluvatar/include/triton/Analysis/Membar.h @@ -0,0 +1,154 @@ +#ifndef TRITON_ANALYSIS_MEMBAR_H +#define TRITON_ANALYSIS_MEMBAR_H + +#include "Allocation.h" +#include "llvm/ADT/SmallPtrSet.h" + +#include + +namespace mlir { + +class OpBuilder; + +struct BlockInfo { + using BufferIdSetT = Allocation::BufferIdSetT; + using IntervalSetT = std::set>; + + IntervalSetT syncReadIntervals; + IntervalSetT syncWriteIntervals; + + BlockInfo() = default; + + /// Unions two BlockInfo objects. + BlockInfo &join(const BlockInfo &other) { + syncReadIntervals.insert(other.syncReadIntervals.begin(), + other.syncReadIntervals.end()); + syncWriteIntervals.insert(other.syncWriteIntervals.begin(), + other.syncWriteIntervals.end()); + return *this; + } + + /// Returns true if intervals in two BlockInfo objects are intersected. + bool isIntersected(const BlockInfo &other) const { + return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals) || + /*WAR*/ + isIntersected(syncReadIntervals, other.syncWriteIntervals) || + /*WAW*/ + isIntersected(syncWriteIntervals, other.syncWriteIntervals); + } + + /// Clears the intervals because a barrier is inserted. + void sync() { + syncReadIntervals.clear(); + syncWriteIntervals.clear(); + } + + /// Compares two BlockInfo objects. + bool operator==(const BlockInfo &other) const { + return syncReadIntervals == other.syncReadIntervals && + syncWriteIntervals == other.syncWriteIntervals; + } + + bool operator!=(const BlockInfo &other) const { return !(*this == other); } + +private: + bool isIntersected(const IntervalSetT &lhsIntervalSet, + const IntervalSetT &rhsIntervalSet) const { + for (auto &lhs : lhsIntervalSet) + for (auto &rhs : rhsIntervalSet) + if (lhs.intersects(rhs)) + return true; + return false; + } +}; + +//===----------------------------------------------------------------------===// +// Shared Memory Barrier Analysis +//===----------------------------------------------------------------------===// +class MembarAnalysis { +public: + using FuncBlockInfoMapT = CallGraph::FuncDataMapT; + /// Creates a new Membar analysis that generates the shared memory barrier + /// in the following circumstances: + /// - RAW: If a shared memory write is followed by a shared memory read, and + /// their addresses are intersected, a barrier is inserted. + /// - WAR: If a shared memory read is followed by a shared memory write, and + /// their addresses are intersected, a barrier is inserted. + /// The following circumstances do not require a barrier: + /// - WAW: not possible because overlapped memory allocation is not allowed. + /// - RAR: no write is performed. + /// Temporary storage of operations such as Reduce are considered as both + /// a shared memory read. If the temporary storage is written but not read, + /// it is considered as the problem of the operation itself but not the membar + /// analysis. + MembarAnalysis() = default; + explicit MembarAnalysis(Allocation *allocation) : allocation(allocation) {} + + /// Runs the membar analysis to the given operation, inserts a barrier if + /// necessary. + void run(FuncBlockInfoMapT &funcBlockInfoMap); + +private: + /// Applies the barrier analysis based on the SCF dialect, in which each + /// region has a single basic block only. + /// Example: + /// region1 + /// op1 + /// op2 (scf.if) + /// region2 + /// op3 + /// op4 + /// region3 + /// op5 + /// op6 + /// op7 + /// TODO: Explain why we don't use ForwardAnalysis: + void resolve(FunctionOpInterface funcOp, FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder); + + /// Updates the BlockInfo operation based on the operation. + void update(Operation *operation, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder); + + /// Collects the successors of the terminator + void visitTerminator(Operation *operation, SmallVector &successors); + + void insertBarrier(Operation *operation, OpBuilder *builder); + +private: + Allocation *allocation = nullptr; +}; + +/// Postorder traversal on the callgraph to insert membar instructions +/// of each function. +/// Each function maintains a BlockInfo map that includes all potential buffers +/// after returning. This way users do not have to explicitly insert membars +/// before and after function calls, but might be a bit conservative. +class ModuleMembarAnalysis : public CallGraph { +public: + ModuleMembarAnalysis(ModuleAllocation *moduleAllocation) + : CallGraph(moduleAllocation->getModuleOp()), + moduleAllocation(moduleAllocation) {} + + void run() { + walk( + // Pre-order walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order walk callback + [&](FunctionOpInterface funcOp) { + auto *allocation = moduleAllocation->getFuncData(funcOp); + auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo()); + if (inserted) { + MembarAnalysis analysis(allocation); + analysis.run(funcMap); + } + }); + } + +private: + ModuleAllocation *moduleAllocation; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_MEMBAR_H diff --git a/third_party/iluvatar/include/triton/Analysis/Utility.h b/third_party/iluvatar/include/triton/Analysis/Utility.h new file mode 100644 index 000000000..528d43575 --- /dev/null +++ b/third_party/iluvatar/include/triton/Analysis/Utility.h @@ -0,0 +1,377 @@ +#ifndef TRITON_ANALYSIS_UTILITY_H +#define TRITON_ANALYSIS_UTILITY_H + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir { + +inline bool isZeroConst(Value v) { + auto constantOp = v.getDefiningOp(); + if (!constantOp) + return false; + if (auto denseAttr = dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + if (auto denseAttr = + dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + return false; +} + +class ReduceOpHelper { +public: + explicit ReduceOpHelper(triton::ReduceOp op) + : op(op.getOperation()), axis(op.getAxis()) { + auto firstTy = cast(op.getOperands()[0].getType()); + srcShape = firstTy.getShape(); + srcEncoding = firstTy.getEncoding(); + srcElementTypes = op.getElementTypes(); + + for (const auto &t : op.getInputTypes()) { + if (t.getShape() != srcShape) { + op.emitError() << "shape mismatch"; + } + if (t.getEncoding() != srcEncoding) { + op.emitError() << "encoding mismatch"; + } + } + } + + ArrayRef getSrcShape() { return srcShape; } + + Attribute getSrcLayout() { return srcEncoding; } + + triton::ReduceOp getOperation() { return op; } + + bool isReductionOnLayoutFastAxis(); + + unsigned getThreadOffsetOnReductionAxis(); + + bool isWarpSynchronous(); + + unsigned getInterWarpSize(); + + unsigned getIntraWarpSize(); + + unsigned getInterWarpSizeWithUniqueData(); + + unsigned getIntraWarpSizeWithUniqueData(); + + unsigned getThreadsReductionAxis(); + + SmallVector getScratchConfig(); + + SmallVector getOrderWithAxisAtBeginning(); + + unsigned getScratchSizeInBytes(); + + bool isSupportedLayout(); + + bool isReduceWithinCTA(); + + unsigned getAxis() { return axis; } + +private: + triton::ReduceOp op; + ArrayRef srcShape; + Attribute srcEncoding; + SmallVector srcElementTypes; + int axis; +}; + +class ScanLoweringHelper { +public: + explicit ScanLoweringHelper(triton::ScanOp op) : scanOp(op) { + auto firstTy = cast(op.getOperands()[0].getType()); + srcShape = firstTy.getShape(); + srcEncoding = firstTy.getEncoding(); + srcElementTypes = op.getElementTypes(); + + for (const auto &t : op.getInputTypes()) { + if (t.getShape() != srcShape) { + op.emitError() << "shape mismatch"; + } + if (t.getEncoding() != srcEncoding) { + op.emitError() << "encoding mismatch"; + } + } + } + // Return true if the lowering of the scan op is supported. + bool isSupported(); + // Return the number of elements per thread along axis dim. + unsigned getAxisNumElementsPerThread(); + // Return the number of elements per thread along non-axis dims. + unsigned getNonAxisNumElementsPerThread(); + // Return the number of threads per warp along non-axis dims. + unsigned getNonAxisNumThreadsPerWarp(); + // Return the flat numbers of threads computing independent scan results. + unsigned getNonAxisNumThreadsPerCTA(); + // Return the number of warps per CTA along axis dim. + unsigned getAxisNumWarps(); + // Return the number of warps per CTA along axis dim with unique data. + unsigned getAxisNumWarpsWithUniqueData(); + // Return the number of threads per warp along axis dim. + unsigned getAxisNumThreadsPerWarp(); + // Return the number of threads per warp along axis dim with unique data. + unsigned getAxisNumThreadsPerWarpWithUniqueData(); + // Return the number of blocks along axis dim. + unsigned getAxisNumBlocks(); + // Return the number of blocks along non axis dim. + unsigned getNonAxisNumBlocks(); + // Return the size of the scratch space needed for scan lowering. + unsigned getScratchSizeInBytes(); + // Return the number of elements of the scratch space needed for scan + // lowering. + unsigned getScratchSizeInElems(); + + // Stride between contiguous element along axis dim. + unsigned getAxisElementStride(); + // Stride between contiguous threads along axis dim. + unsigned getAxisThreadStride(); + // Stride between contiguous blocks along axis dim. + unsigned getAxisBlockStride(); + + Location getLoc() { return scanOp.getLoc(); } + unsigned getAxis() { return scanOp.getAxis(); } + bool getReverse() { return scanOp.getReverse(); } + triton::gpu::BlockedEncodingAttr getEncoding(); + llvm::ArrayRef getShape() { return srcShape; } + unsigned getNumOperands() { return scanOp.getNumOperands(); } + SmallVector getElementTypes() { return srcElementTypes; } + Attribute getSrcLayout() { return srcEncoding; } + Region &getCombineOp(); + +private: + triton::ScanOp scanOp; + Attribute srcEncoding; + llvm::ArrayRef srcShape; + SmallVector srcElementTypes; +}; + +// Decomposes a reshape into simpler pieces. +// +// As an example, suppose we have a reshape from [4,4,4] to [2,2,8,2]. +// You might explain what this does as follows. +// +// - Split the first input dimension into [2,2]. +// - Take the remaining two input dimensions, merge them into a single [16] +// dim, and then split that into [8,2]. +// +// In general, a reshape can be described a sequence of smushing one or more +// input dimensions together and then breaking them apart into one or more +// output dimensions. So we could represent the example above as follows. +// +// [ +// ([0], [0, 1]), # input dim [0] -> output dims [0, 1] +// ([1, 2], [2, 3]), # input dims [1, 2] -> output dims [2, 3] +// ] +// +// Notice that the input dims (first tuple elems) appear in sequential order if +// you read left-to-right-top-to-bottom, and so do the output dims. +// +// This function returns the above decomposition. +SmallVector, SmallVector>> +getReshapeDecomposition(ArrayRef srcShape, ArrayRef dstShape); + +bool maybeSharedAllocationOp(Operation *op); + +bool supportMFMA(triton::DotOp op); + +bool supportWMMA(triton::DotOp op); + +bool supportMMA(triton::DotOp op, int version); + +bool supportMMA(Value value, int version); + +bool isSingleValue(Value value); + +bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); + +bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); + +bool isMmaToDotSlowShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); + +bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy); + +// Return true if the src and dst layout match. +bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, + RankedTensorType dstTy); + +// TODO: Move utility functions that belong to ConvertLayoutOp to class +// ConvertLayoutOpHelper in the future +bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout); + +/// Multi-root DAG topological sort. +/// Performs a topological sort of the Operation in the `toSort` SetVector. +/// Returns a topologically sorted SetVector. +/// It is faster than mlir::topologicalSort because it prunes nodes that have +/// been visited before. +SetVector +multiRootTopologicalSort(const SetVector &toSort); + +void getBackwardSliceCorex(Operation *op, SetVector *backwardSlice, + TransitiveFilter filter, + bool omitBlockArguments = false); + +void getBackwardSliceImplCorex(Operation *op, + SetVector *backwardSlice, + TransitiveFilter filter, + bool omitBlockArguments = false); + +/// This uses the toplogicalSort above +SetVector +multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr, + TransitiveFilter forwardFilter = nullptr); + +/// Create a basic DataFlowSolver with constant and dead code analysis included. +std::unique_ptr createDataFlowSolver(); + +/// This class represents a call graph for a given ModuleOp and holds +/// data of type T associated with each FunctionOpInterface. +template class CallGraph { +public: + using FuncDataMapT = DenseMap; + + /// Constructor that builds the call graph for the given moduleOp. + explicit CallGraph(ModuleOp moduleOp) : moduleOp(moduleOp) { build(); } + + /// Walks the call graph and applies the provided update functions + /// to the edges and nodes. + template + void walk(UpdateEdgeFn updateEdgeFn, UpdateNodeFn updateNodeFn) { + DenseSet visited; + for (auto root : roots) { + doWalk(root, visited, updateEdgeFn, + updateNodeFn); + } + } + + /// Retrieves the data associated with a function + T *getFuncData(FunctionOpInterface funcOp) { + if (funcMap.count(funcOp)) { + return &funcMap[funcOp]; + } + return nullptr; + } + + /// Getters + ModuleOp getModuleOp() const { return moduleOp; } + SmallVector getRoots() const { return roots; } + size_t getNumFunctions() const { return funcMap.size(); } + + /// Returns true if the given function is a root. + bool isRoot(FunctionOpInterface funcOp) const { + return llvm::is_contained(roots, funcOp); + } + + /// Maps the data and the graph nodes associated with a funcOp to a + /// targetFuncOp. + template + void mapFuncOp(FROM funcOp, TO targetFuncOp) { + // Iterate over graph and replace + for (auto &kv : graph) { + for (auto &edge : kv.second) { + if (edge.second == funcOp) { + edge.second = targetFuncOp; + } + } + } + graph[targetFuncOp] = graph[funcOp]; + // Replace in roots + for (auto it = roots.begin(); it != roots.end(); ++it) { + if (*it == funcOp) { + *it = targetFuncOp; + break; + } + } + // Replace in funcMap + funcMap[targetFuncOp] = funcMap[funcOp]; + } + + /// Maps the graph edges associated with a callOp to a targetCallOp. + template + void mapCallOp(FROM callOp, TO targetCallOp) { + // Iterate over graph and replace + for (auto &kv : graph) { + for (auto &edge : kv.second) { + if (edge.first == callOp) { + edge.first = targetCallOp; + } + } + } + } + +private: + void build() { + SymbolTableCollection symbolTable; + DenseSet visited; + // Build graph + moduleOp.walk([&](Operation *op) { + auto caller = op->getParentOfType(); + if (auto callOp = dyn_cast(op)) { + auto *callee = callOp.resolveCallable(&symbolTable); + auto funcOp = dyn_cast_or_null(callee); + if (funcOp) { + graph[caller].emplace_back( + std::pair(callOp, funcOp)); + visited.insert(funcOp); + } + } + }); + // Find roots + moduleOp.walk([&](FunctionOpInterface funcOp) { + if (!visited.count(funcOp)) { + roots.push_back(funcOp); + } + }); + } + + template + void doWalk(FunctionOpInterface funcOp, + DenseSet &visited, UpdateEdgeFn updateEdgeFn, + UpdateNodeFn updateNodeFn) { + if (visited.count(funcOp)) { + llvm::report_fatal_error("Cycle detected in call graph"); + } + if constexpr (UpdateNodeOrder == WalkOrder::PreOrder) { + updateNodeFn(funcOp); + } + for (auto [callOp, callee] : graph[funcOp]) { + if constexpr (UpdateEdgeOrder == WalkOrder::PreOrder) { + updateEdgeFn(callOp, callee); + } + doWalk(callee, visited, updateEdgeFn, + updateNodeFn); + if constexpr (UpdateEdgeOrder == WalkOrder::PostOrder) { + updateEdgeFn(callOp, callee); + } + } + if constexpr (UpdateNodeOrder == WalkOrder::PostOrder) { + updateNodeFn(funcOp); + } + visited.erase(funcOp); + } + +protected: + ModuleOp moduleOp; + DenseMap>> + graph; + FuncDataMapT funcMap; + SmallVector roots; +}; +// Create a basic DataFlowSolver with constant and dead code analysis included. +std::unique_ptr createDataFlowSolver(); + +triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v); + +} // namespace mlir + +#endif // TRITON_ANALYSIS_UTILITY_H diff --git a/third_party/iluvatar/include/triton/CMakeLists.txt b/third_party/iluvatar/include/triton/CMakeLists.txt new file mode 100644 index 000000000..27c703b3c --- /dev/null +++ b/third_party/iluvatar/include/triton/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Conversion) +add_subdirectory(Dialect) +add_subdirectory(Target) diff --git a/third_party/iluvatar/include/triton/Conversion/CMakeLists.txt b/third_party/iluvatar/include/triton/Conversion/CMakeLists.txt new file mode 100644 index 000000000..730f5cadd --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonGPUToLLVM) +add_subdirectory(TritonToTritonGPU) diff --git a/third_party/iluvatar/include/triton/Conversion/MLIRTypes.h b/third_party/iluvatar/include/triton/Conversion/MLIRTypes.h new file mode 100644 index 000000000..fadba413f --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/MLIRTypes.h @@ -0,0 +1,42 @@ +#ifndef TRITON_CONVERSION_MLIR_TYPES_H +#define TRITON_CONVERSION_MLIR_TYPES_H + +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// This file redefines some common MLIR types for easy usage. +namespace mlir { +namespace triton { +namespace type { + +// Integer types +inline Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); } +inline Type i16Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 16); } +inline Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); } +inline Type u32Ty(MLIRContext *ctx) { + return IntegerType::get(ctx, 32, IntegerType::Unsigned); +} +inline Type u1Ty(MLIRContext *ctx) { + return IntegerType::get(ctx, 1, IntegerType::Unsigned); +} + +// Float types +inline Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); } +inline Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); } +inline Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); } +inline Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); } + +inline bool isFloat(Type type) { + return type.isF32() || type.isF64() || type.isF16() || type.isF128() || + type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() || + type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() || + type.isFloat8E5M2FNUZ(); +} + +inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); } + +} // namespace type +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_MLIR_TYPES_H diff --git a/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h new file mode 100644 index 000000000..00ec88089 --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h @@ -0,0 +1,27 @@ +#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ +#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ + +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include +#include + +namespace mlir { +class ConversionPatternRewriter; +class Location; + +namespace triton { +using llvm::StringRef; + +inline std::string strJoin(llvm::ArrayRef strs, + llvm::StringRef delimiter) { + return llvm::join(strs.begin(), strs.end(), delimiter); +} + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ diff --git a/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..93f8374e5 --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonGPUToLLVM) +add_public_tablegen_target(TritonGPUConversionPassIncGen) diff --git a/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h new file mode 100644 index 000000000..1286b4e56 --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h @@ -0,0 +1,243 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H +#define TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::triton { + +namespace gpu { + +SmallVector reorderValues(const SmallVector &values, Type inType, + Type ouType); + +SmallVector unpackI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter); + +SmallVector packI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter); + +Type getElementType(Value value); + +class MultipleOperandsRange + : public iterator_range>::iterator> { + using ContainerT = SmallVector>; + +public: + using iterator_range::iterator_range; + ContainerT::reference operator[](ContainerT::size_type idx) { + return begin()[idx]; + } + ContainerT::const_reference operator[](ContainerT::size_type idx) const { + return begin()[idx]; + } + ContainerT::size_type size() const { return end() - begin(); } +}; + +// Base pattern for elementwise conversion using ConcreteT. Unpacks individual +// elements from a `!llvm.struct` via `llvm.extactvalue`, calls +// ConcreteT::createDestOps on each element, and packs them back into an +// `!llvm.struct` using `llvm.insertvalue`. +// +// Also supports processing the inputs in a vectorized form by consuming and +// producing multiple operand sets in ConcreteT::createDestOps. +template +class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { +public: + using OpAdaptor = typename SourceOp::Adaptor; + + explicit ElementwiseOpConversionBase( + LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit), + axisAnalysisPass(axisAnalysisPass) {} + + // Try to deduplicate the resultVals based on the + // constancy properties of the result discovered by + // the axis analysis pass. If possible, redundant + // computation is eliminated. + SmallVector maybeDeduplicate(SourceOp op, + SmallVector resultVals) const { + if (!isMemoryEffectFree(op)) + // the op has side effects: can't dedup + return resultVals; + SmallVector results = op->getResults(); + if (results.size() == 0 || results.size() > 1) + // there must be exactly 1 result + return resultVals; + Value result = results[0]; + Type type = result.getType(); + if (!type) + return resultVals; + RankedTensorType rtType = dyn_cast(type); + if (!rtType) + // the result must be a tensor + return resultVals; + Attribute encoding = rtType.getEncoding(); + if (!encoding) + // encoding not available + return resultVals; + Attribute baseEncoding = encoding; + if (isa(baseEncoding)) + // TODO: this logic seems incorrect for mfma layout. Skip for now. + // We saw mismatches for some flash-attention tests on AMD backend. + // Note that this logic works for sliced layout whose parent is + // mfma layout. Therefore, this is not combined with the following check. + return resultVals; + while (auto sliced = dyn_cast(baseEncoding)) + baseEncoding = sliced.getParent(); + if (isa(baseEncoding)) { + // TODO: this logic seems incorrect for mma layout. Skip for now. + // The following test crashes and some other miscompile: + // test_core::test_fp8_dot_acc + return resultVals; + } + if (isa(baseEncoding)) { + // TODO: this logic seems incorrect for mma layout. Skip for now. + // The following test crashes and some other miscompile: + // test_core::test_fp8_dot_acc + return resultVals; + } + + SmallVector elemsPerThread = getElemsPerThread(rtType); + int rank = elemsPerThread.size(); + if (product(elemsPerThread) != resultVals.size()) + return resultVals; + AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(result); + if (!axisInfo) + // axis info (e.g., constancy) not available + return resultVals; + SmallVector contigPerThread = getContigPerThread(encoding); + if (rank != contigPerThread.size()) + return resultVals; + + SmallVector constancy = axisInfo->getConstancy(); + if (rank != constancy.size()) + return resultVals; + bool hasConstancy = false; + for (int i = 0; i < rank; ++i) { + if (constancy[i] > contigPerThread[i]) { + if (constancy[i] % contigPerThread[i] != 0) + // constancy is not evenly covered by contigPerThread + return resultVals; + // can't move the values across different + // "contigPerThread"-sized blocks + constancy[i] = contigPerThread[i]; + } + if (elemsPerThread[i] < 1 || constancy[i] < 1) + return resultVals; + if (!(elemsPerThread[i] % constancy[i] == 0 || + constancy[i] % elemsPerThread[i] == 0)) + // either the constancy along each dimension must fit + // into the elemsPerThread or the other way around + return resultVals; + if (constancy[i] > 1) + hasConstancy = true; + } + if (!hasConstancy) + // nothing to deduplicate + return resultVals; + + if (rank > 1) { + // reorder the shape and constancy vectors by the axis order: + // from the fastest-changing to the smallest-changing axis + SmallVector order = getOrder(encoding); + if (rank != order.size()) + return resultVals; + elemsPerThread = applyPermutation(elemsPerThread, order); + constancy = applyPermutation(constancy, order); + } + + SmallVector strides(rank, 1); + for (int i = 1; i < rank; ++i) { + strides[i] = strides[i - 1] * elemsPerThread[i - 1]; + } + SmallVector dedupResultVals; + dedupResultVals.reserve(resultVals.size()); + for (int i = 0; i < resultVals.size(); ++i) { + // each coordinate of the orig_idx is "coarsened" using the + // constancy along this dimension: the resulting dedup_idx + // points to the reused value in the original resultsVal + int orig_idx = i; + int dedup_idx = 0; + for (int j = 0; j < rank; ++j) { + int coord_j = orig_idx % elemsPerThread[j]; + dedup_idx += (coord_j / constancy[j] * constancy[j]) * strides[j]; + orig_idx /= elemsPerThread[j]; + } + dedupResultVals.push_back(resultVals[dedup_idx]); + } + + return dedupResultVals; + } + LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resultTy = op.getType(); + Location loc = op->getLoc(); + // element type + auto resultElementTy = getElementTypeOrSelf(resultTy); + Type elemTy = this->getTypeConverter()->convertType(resultElementTy); +#ifdef __ILUVATAR__ + auto srcType = this->getTypeConverter()->convertType(resultTy); + if (auto structTy = dyn_cast(srcType)) + elemTy = structTy.getBody()[0]; +#endif + SmallVector> allOperands; + for (auto operand : adaptor.getOperands()) { + auto argTy = op->getOperand(0).getType(); + auto subOperands = unpackLLElements(loc, operand, rewriter); + subOperands = unpackI32(subOperands, argTy, rewriter, loc, + this->getTypeConverter()); + allOperands.resize(subOperands.size()); + for (auto v : llvm::enumerate(subOperands)) + allOperands[v.index()].push_back(v.value()); + } + if (allOperands.size() == 0) + allOperands.push_back({}); + + SmallVector resultVals; + for (auto it = allOperands.begin(), end = allOperands.end(); it != end;) { + auto curr = static_cast(this)->createDestOps( + op, adaptor, rewriter, elemTy, MultipleOperandsRange(it, end), loc); + if (curr.size() == 0) + return failure(); + for (auto v : curr) { + if (!static_cast(v)) + return failure(); + resultVals.push_back(v); + } + it += curr.size(); + } + if (op->getNumOperands() > 0) { + auto argTy = op->getOperand(0).getType(); + resultVals = reorderValues(resultVals, argTy, resultTy); + } + resultVals = maybeDeduplicate(op, resultVals); + resultVals = + packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); + Value view = packLLElements(loc, this->getTypeConverter(), resultVals, + rewriter, resultTy); + rewriter.replaceOp(op, view); + + return success(); + } + +protected: + ModuleAxisInfoAnalysis &axisAnalysisPass; +}; + +} // namespace gpu + +} // namespace mlir::triton +#endif diff --git a/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/Passes.h b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/Passes.h new file mode 100644 index 000000000..b013f2628 --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/Passes.h @@ -0,0 +1,32 @@ +#ifndef TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PASSES_H +#define TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +#define GEN_PASS_DECL +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" + +namespace gpu { +std::unique_ptr> createAllocateSharedMemoryPass(); + +} // namespace gpu + +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/Passes.td b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/Passes.td new file mode 100644 index 000000000..700dcd6b4 --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/Passes.td @@ -0,0 +1,11 @@ +#ifndef TRITONCOMMONGPU_CONVERSION_PASSES +#define TRITONCOMMONGPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def AllocateSharedMemory : Pass<"allocate-shared-memory", "mlir::ModuleOp"> { + let summary = "Add metadata for shared memory allocation"; + let constructor = "mlir::triton::gpu::createAllocateSharedMemoryPass()"; +} + +#endif diff --git a/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h new file mode 100644 index 000000000..d1494fd7e --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -0,0 +1,104 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H + +#include "TargetInfoBase.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Analysis/AxisInfo.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::BlockedEncodingAttr; + +namespace SharedToDotOperandFMA { +Value convertLayout(int opIdx, Value val, Value llVal, + BlockedEncodingAttr dLayout, Value thread, Location loc, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); +} +LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); +namespace mlir { +namespace triton { + +constexpr int patternBenefitDefault = 1; +constexpr int patternBenefitPrioritizeOverLLVMConversions = 10; +constexpr int patternBenefitClampOptimizedPattern = 20; +constexpr int patternBenefitConvertLayoutOptimizedPattern = 20; + +void populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateMakeRangeOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateViewOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateMinMaxFOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + bool hwNanPropagationSupported, + PatternBenefit benefit); +void populateClampFOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateHistogramOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); +void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); +void populateScanOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, int numWarps, + PatternBenefit benefit); + +void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/Patterns.h b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/Patterns.h new file mode 100644 index 000000000..934501ad3 --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/Patterns.h @@ -0,0 +1,32 @@ +#ifndef TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PATTERNS_H +#define TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PATTERNS_H + +#include + +namespace mlir { +class ModuleOp; +class RankedTensorType; + +namespace triton::gpu { + +/// Replaces `blocked -> dot_op` with `blocked -> shared -> dot_op` in the given +/// |module| op because the codegen doesn't handle `blocked -> dot_op` directly. +void decomposeBlockedToDotLayoutConversion(ModuleOp module); + +/// Replaces `splat -> shared` with `splat -> blocked -> shared` in the given +/// |module| op. +void decomposeSplatOpToSharedLayoutConversion(ModuleOp module); + +/// Replaces `mma/mfma -> dot_op` with `mma/mfma -> blocked -> dot_op` in the +/// given |module| op, but bypass the decomposition if |shortcutFn| returns +/// true. +using ShortcutFn = std::function; +template +void decomposeTensorCoreToDotLayoutConversion(ModuleOp module, + ShortcutFn shortcutFn); + +} // namespace triton::gpu + +} // namespace mlir + +#endif diff --git a/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h new file mode 100644 index 000000000..380c8cc1d --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -0,0 +1,65 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H + +#include "triton/Conversion/MLIRTypes.h" + +namespace mlir::triton { +class TargetInfoBase { +public: + virtual bool supportMaximumMinimum() const = 0; + + virtual Value getClusterCTAId(RewriterBase &rewriter, Location loc) const = 0; + + virtual Value ballot(ConversionPatternRewriter &rewriter, Location loc, + Type type, Value cmp) const = 0; + + virtual Value storeShared(ConversionPatternRewriter &rewriter, Location loc, + Value ptr, Value val, Value pred) const = 0; + virtual Value loadShared(ConversionPatternRewriter &rewriter, Location loc, + Value ptr, Type elemTy, Value pred) const = 0; + + virtual Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const = 0; + virtual Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const = 0; + virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const = 0; + virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, + Value val, Value i) const = 0; + + virtual Value programId(ConversionPatternRewriter &rewriter, Location loc, + ModuleOp moduleOp, int axis) const = 0; + + virtual bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce) const = 0; + + virtual bool processReplicaUsingStMatrix( + ConversionPatternRewriter &rewriter, Location loc, Value smemBase, + SmallVector &vals, RankedTensorType srcTy, Type elemTy, + ArrayRef paddedRepShape, ArrayRef origRepShape, + ArrayRef outOrd, unsigned accumNumReplicates, + int swizzleByteWidth = 0) const = 0; + + virtual std::string getMulhiFuncName(Type resultElementTy) const = 0; + // Emits LLVM code with |rewriter| to print a message following the given + // format from the device. |formatStrStart| is the pointer to the start of + // the format string global variable; |args| are the arguments to fill + // placeholders in the format string. + virtual void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args) const = 0; + // Emits LLVM code with |rewriter| to perform assertion failure with the given + // |message| from the given |func| in |file|. + virtual void assertFail(ConversionPatternRewriter &rewriter, Location loc, + StringRef message, StringRef file, StringRef func, + int line) const = 0; + + // Whether to enable linear layout. This is a per-backend temporary escape + // hatch to disable linear layout while figuring out issues. Eventually we + // want to enable linear layout everywhere and delete this control. + virtual bool enableLinearLayout() const { return true; } + + virtual ~TargetInfoBase() {} +}; +} // namespace mlir::triton +#endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H diff --git a/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h new file mode 100644 index 000000000..ab9d0ebf8 --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h @@ -0,0 +1,26 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter { +public: + using TypeConverter::convertType; + + TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis = nullptr); + + Type getElementTypeForStruct(TensorOrMemDesc type); + Type convertTritonPointerType(triton::PointerType type); + Type convertTritonTensorType(RankedTensorType type); + Type convertMemDescType(MemDescType type); + Type convertAsyncToken(triton::gpu::AsyncTokenType type); +}; + +#endif diff --git a/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/Utility.h new file mode 100644 index 000000000..46d02d5bd --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -0,0 +1,1645 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H + +#include + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "python/src/plugin.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/StrUtil.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/ErrorHandling.h" + +#define DEBUG_TYPE "ttgpu_to_llvm" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::triton; + +using emitOffsetForTCULayoutFunc = SmallVector> (*)( + const triton::gpu::IluvatarMmaEncodingAttr &, RankedTensorType); +DEFINE_LOAD_FUNC(emitOffsetForTCULayout) + +using emitBaseIndexForTCULayoutFunc = SmallVector (*)( + Location, RewriterBase &, const triton::gpu::IluvatarMmaEncodingAttr &, + RankedTensorType); +DEFINE_LOAD_FUNC(emitBaseIndexForTCULayout) + +using remapOffsetFunc = Value (*)(Value, Value, RankedTensorType, bool, + Location, RewriterBase &, int, bool); +DEFINE_LOAD_FUNC(remapOffset) + +// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive +// Operators +#define inttofloat(...) rewriter.create(loc, __VA_ARGS__) +#define inttoptr(...) rewriter.create(loc, __VA_ARGS__) +#define ptrtoint(...) rewriter.create(loc, __VA_ARGS__) +#define zext(...) rewriter.create(loc, __VA_ARGS__) +#define sext(...) rewriter.create(loc, __VA_ARGS__) +#define fpext(...) rewriter.create(loc, __VA_ARGS__) +#define trunc(...) rewriter.create(loc, __VA_ARGS__) +#define udiv(...) rewriter.create(loc, __VA_ARGS__) +#define urem(...) rewriter.create(loc, __VA_ARGS__) +#define add(...) rewriter.create(loc, __VA_ARGS__) +#define sub(...) rewriter.create(loc, __VA_ARGS__) +#define fadd(...) rewriter.create(loc, __VA_ARGS__) +#define mul(...) rewriter.create(loc, __VA_ARGS__) +#define fmul(...) rewriter.create(loc, __VA_ARGS__) +#define smax(...) rewriter.create(loc, __VA_ARGS__) +#define umax(...) rewriter.create(loc, __VA_ARGS__) +#define fmax(...) rewriter.create(loc, __VA_ARGS__) +#define smin(...) rewriter.create(loc, __VA_ARGS__) +#define umin(...) rewriter.create(loc, __VA_ARGS__) +#define fmin(...) rewriter.create(loc, __VA_ARGS__) +#define shl(...) rewriter.create(loc, __VA_ARGS__) +#define lshr(...) rewriter.create(loc, __VA_ARGS__) +#define and_(...) rewriter.create(loc, __VA_ARGS__) +#define xor_(...) rewriter.create(loc, __VA_ARGS__) +#define or_(...) rewriter.create(loc, __VA_ARGS__) +#define bitcast(val__, type__) \ + rewriter.create(loc, type__, val__) +#define addrspacecast(...) \ + rewriter.create(loc, __VA_ARGS__) +#define gep(...) rewriter.create(loc, __VA_ARGS__) +#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__) +#define insert_val(...) rewriter.create(loc, __VA_ARGS__) +#define extract_val(...) rewriter.create(loc, __VA_ARGS__) +#define insert_element(...) \ + rewriter.create(loc, __VA_ARGS__) +#define extract_element(...) \ + rewriter.create(loc, __VA_ARGS__) +#define load(...) rewriter.create(loc, __VA_ARGS__) +#define store(...) rewriter.create(loc, __VA_ARGS__) +#define fcmp_ogt(lhs, rhs) \ + rewriter.create(loc, rewriter.getI1Type(), \ + LLVM::FCmpPredicate::ogt, lhs, rhs) +#define fcmp_olt(lhs, rhs) \ + rewriter.create(loc, rewriter.getI1Type(), \ + LLVM::FCmpPredicate::olt, lhs, rhs) +#define fcmp_eq(lhs, rhs) \ + rewriter.create(loc, rewriter.getI1Type(), \ + LLVM::FCmpPredicate::oeq, lhs, rhs) +#define icmp_eq(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__) +#define icmp_ne(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__) +#define icmp_slt(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__) +#define icmp_sle(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::sle, __VA_ARGS__) +#define icmp_sgt(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::sgt, __VA_ARGS__) +#define icmp_sge(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::sge, __VA_ARGS__) +#define icmp_ult(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::ult, __VA_ARGS__) +#define icmp_ule(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::ule, __VA_ARGS__) +#define icmp_ugt(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::ugt, __VA_ARGS__) +#define icmp_uge(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::uge, __VA_ARGS__) +#define select(...) rewriter.create(loc, __VA_ARGS__) +#define address_of(...) rewriter.create(loc, __VA_ARGS__) +#define barrier() rewriter.create(loc) +#define undef(...) rewriter.create(loc, __VA_ARGS__) +#define null(...) rewriter.create(loc, __VA_ARGS__) +#define call(...) rewriter.create(loc, __VA_ARGS__) + +// Types +#define int_ty(width) rewriter.getIntegerType(width) +#define i64_ty rewriter.getIntegerType(64) +#define i32_ty rewriter.getIntegerType(32) +#define i16_ty rewriter.getIntegerType(16) +#define i32_ty rewriter.getIntegerType(32) +#define i64_ty rewriter.getIntegerType(64) +#define ui32_ty rewriter.getIntegerType(32, false) +#define ui64_ty rewriter.getIntegerType(64, false) +#define f16_ty rewriter.getF16Type() +#define bf16_ty rewriter.getBF16Type() +#define i8_ty rewriter.getIntegerType(8) +#define i1_ty rewriter.getI1Type() +#define f32_ty rewriter.getF32Type() +#define f64_ty rewriter.getF64Type() +#define vec_ty(type, num) VectorType::get(num, type) +#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx) +#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__) +#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count) + +// Constants +#define f16_val(...) LLVM::createConstantF16(loc, rewriter, __VA_ARGS__) +#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__) +#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__) +#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__) +#define i64_val(...) LLVM::createConstantI64(loc, rewriter, __VA_ARGS__) +#define int_val(width, val) \ + LLVM::createLLVMIntegerConstant(rewriter, loc, width, val) +#define tid_val() getThreadId(rewriter, loc) + +// Attributes +#define i32_arr_attr(...) rewriter.getI32ArrayAttr({__VA_ARGS__}) +#define i64_arr_attr(...) rewriter.getI64ArrayAttr({__VA_ARGS__}) +#define str_attr(str) ::mlir::StringAttr::get(ctx, (str)) + +namespace mlir { +namespace triton { + +// Delinearize supposing order is [0, 1, .. , n] +template +llvm::SmallVector getMultiDimIndexImpl(T linearIndex, + llvm::ArrayRef shape) { + // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c} + size_t rank = shape.size(); + T accMul = product(shape.drop_back()); + T linearRemain = linearIndex; + llvm::SmallVector multiDimIndex(rank); + for (int i = rank - 1; i >= 0; --i) { + multiDimIndex[i] = linearRemain / accMul; + linearRemain = linearRemain % accMul; + if (i != 0) { + accMul = accMul / shape[i - 1]; + } + } + return multiDimIndex; +} + +template +llvm::SmallVector getMultiDimIndex(T linearIndex, llvm::ArrayRef shape, + llvm::ArrayRef order) { + size_t rank = shape.size(); + assert(rank == order.size()); + auto reordered = applyPermutation(shape, order); + auto reorderedMultiDim = getMultiDimIndexImpl(linearIndex, reordered); + llvm::SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +// Linearize supposing order is [0, 1, .. , n] +template +T getLinearIndexImpl(llvm::ArrayRef multiDimIndex, llvm::ArrayRef shape) { + assert(multiDimIndex.size() == shape.size()); + // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c} + size_t rank = shape.size(); + T accMul = product(shape.drop_back()); + T linearIndex = 0; + for (int i = rank - 1; i >= 0; --i) { + linearIndex += multiDimIndex[i] * accMul; + if (i != 0) { + accMul = accMul / shape[i - 1]; + } + } + return linearIndex; +} + +template +T getLinearIndex(llvm::ArrayRef multiDimIndex, llvm::ArrayRef shape, + llvm::ArrayRef order) { + assert(shape.size() == order.size()); + return getLinearIndexImpl(applyPermutation(multiDimIndex, order), + applyPermutation(shape, order)); +} + +namespace gpu { +Type getFunctionType(Type resultType, ValueRange operands); + +LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter, + Operation *op, StringRef funcName, + Type funcType, StringRef libname = "", + StringRef libpath = ""); +} // namespace gpu + +} // namespace triton + +namespace LLVM { +using namespace mlir::triton; + +Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v); + +/// Create a 64-bit integer constant. +Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v); + +/// Create a 16-bit float constant. +Value createConstantF16(Location loc, OpBuilder &rewriter, float v); + +/// Create a 32-bit float constant. +Value createConstantF32(Location loc, OpBuilder &rewriter, float v); + +/// Create a 64-bit float constant. +Value createConstantF64(Location loc, OpBuilder &rewriter, double v); + +/// Create NaN constant of specified type. +Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type); + +/// Create an index type constant. +Value createIndexConstant(OpBuilder &builder, Location loc, + TypeConverter *converter, int64_t value); + +/// Create an integer constant of \param width bits. +Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, + int64_t value); + +/// Helper function to get strides from a given shape and its order +SmallVector getStridesFromShapeAndOrder(ArrayRef shape, + ArrayRef order, + Location loc, + RewriterBase &rewriter); +struct SharedMemoryObject { + Value base; // i32 ptr. The start address of the shared memory object after + // the initial allocation or the last slicing operation. + Type baseElemType; + // We need to store strides as Values, not integers, because the + // extract_slice instruction can take a slice at arbitrary offsets. + // Take $a[16:32, 16:32] as an example; though we know the stride of $a[0] is + // 32, we need to let the instruction that uses $a be aware of that. + // Otherwise, when we use $a, we only know that the shape of $a is 16x16. If + // we store strides into an attribute array of integers, the information + // cannot pass through block argument assignment because attributes are + // associated with operations, not Values. + // TODO(Keren): We may need to figure out a way to store strides as integers + // if we want to support more optimizations. + SmallVector + strides; // i32 int. The strides of the shared memory object. + SmallVector offsets; // i32 int. + // Offsets are applied at the last slicing operation. + // We can use offsets to recover the previous base. + // The offsets are zero at the initial allocation. + + SharedMemoryObject(Value base, Type baseElemType, ArrayRef strides, + ArrayRef offsets) + : base(base), baseElemType(baseElemType), + strides(strides.begin(), strides.end()), + offsets(offsets.begin(), offsets.end()) {} + + SharedMemoryObject(Value base, Type baseElemType, ArrayRef shape, + ArrayRef order, Location loc, + RewriterBase &rewriter) + : base(base), baseElemType(baseElemType) { + strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter); + offsets.append(order.size(), i32_val(0)); + } + + SmallVector getStrides() const { return strides; } + SmallVector getOffsets() const { return offsets; } + Value getBase() const { return base; } + Type getBaseElemType() const { return baseElemType; } + + SmallVector getElems() const { + SmallVector elems; + elems.push_back(base); + elems.append(strides.begin(), strides.end()); + elems.append(offsets.begin(), offsets.end()); + return elems; + } + + SmallVector getTypes() const { + SmallVector types; + types.push_back(base.getType()); + types.append(strides.size(), IntegerType::get(base.getContext(), 32)); + types.append(offsets.size(), IntegerType::get(base.getContext(), 32)); + return types; + } + + Value getCSwizzleOffset(int order) const { + assert(order >= 0 && order < strides.size()); + return offsets[order]; + } + + Value getBaseBeforeSlice(int order, Location loc, + ConversionPatternRewriter &rewriter) const { + Value cSwizzleOffset = getCSwizzleOffset(order); + Value offset = sub(i32_val(0), cSwizzleOffset); + Type type = base.getType(); + return gep(type, baseElemType, base, offset); + } +}; + +SharedMemoryObject +getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, + ConversionPatternRewriter &rewriter); + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape, + ArrayRef order); + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + unsigned linear, ArrayRef shape); + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape); + +Value linearize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDim, ArrayRef shape, + ArrayRef order); + +Value linearize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDim, ArrayRef shape); + +Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, + StringRef key, StringRef content); + +// Given an elemId which represents the index of an element from the list of +// elements that are in the thread's registers (i.e. total of +// numel(sizePerThread)), it calculates the multi dim offset of the element in +// the smem buffer. Recall that the smem buffer will only store a single replica +// when converting distributed to distributed layout. Also, a replica is the +// smallest CTA tile that is common between input and output layouts. +SmallVector getMultiDimOffset( + Attribute layout, Location loc, ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, unsigned elemId, RankedTensorType type, + ArrayRef multiDimCTAInRepId, ArrayRef shapePerCTATile, + bool isTrans = false, bool stNotRd = false); + +// Given a multiDimOffset, this function wraps around each dimension to be +// within shape. +SmallVector getWrappedMultiDimOffset( + ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDimOffset, ArrayRef shape, + SmallVector shapePerCTATile, SmallVector shapePerCTA); + +inline bool isKernel(FunctionOpInterface funcOp) { + return funcOp.getVisibility() == SymbolTable::Visibility::Public; +} + +inline Value getStackPointer(PatternRewriter &rewriter, + FunctionOpInterface funcOp) { + auto mod = funcOp->getParentOfType(); + LLVM::GlobalOp globalBase = nullptr; + mod.walk([&](LLVM::GlobalOp op) { + if (op.getSymName() == "global_smem") + globalBase = op; + }); + assert(globalBase); + if (isKernel(funcOp)) + return rewriter.create(funcOp.getLoc(), globalBase); + else + return funcOp.getArgument(funcOp.getNumArguments() - 1); +} + +inline Value getSharedMemoryBase(Location loc, + ConversionPatternRewriter &rewriter, + Operation *op) { + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); + FunctionOpInterface func = + op->template getParentOfType(); + assert(op->hasAttr("allocation.offset")); + size_t offset = cast(op->getAttr("allocation.offset")) + .getValue() + .getZExtValue(); + Value offVal = i32_val(offset); + Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); + return base; +} +} // namespace LLVM + +/* ------------------------------------ */ +// Returns CTA level thread idx +inline Value getThreadIdInCTA(RewriterBase &rewriter, Location loc) { + Value tid = + rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x); + return rewriter.create(loc, i32_ty, tid); +} + +// Returns CTA level thread idx. +inline Value getThreadId(RewriterBase &rewriter, Location loc) { + Value tid = getThreadIdInCTA(rewriter, loc); + auto mod = rewriter.getBlock()->getParent()->getParentOfType(); + return tid; +} + +// ----------------------------------------------------------------------- +// Shared memory utilities +// ----------------------------------------------------------------------- +using LLVM::getMultiDimIndex; +using LLVM::SharedMemoryObject; +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::SharedMemoryObject; +using ::mlir::triton::gpu::AMDMfmaEncodingAttr; +using ::mlir::triton::gpu::AMDWmmaEncodingAttr; +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::CTALayoutAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::IluvatarMmaEncodingAttr; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; + +inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef offsets, + ArrayRef strides) { + assert(offsets.size() == strides.size()); + Value ret = i32_val(0); + for (auto [offset, stride] : llvm::zip(offsets, strides)) { + ret = add(ret, mul(offset, stride)); + } + return ret; +} + +// ----------------------------------------------------------------------- +// Blocked layout indices +// ----------------------------------------------------------------------- + +// "Applies" the given layout by computing layout(indices) and returning the +// resulting Values. +// +// In other words, this generates LLVM-dialect MLIR code to "run" the layout +// function. +SmallVector> +applyLinearLayout(Location loc, RewriterBase &rewriter, + const LinearLayout &layout, + ArrayRef> indices); + +inline SmallVector +emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter, + const BlockedEncodingAttr &blockedLayout, + RankedTensorType type) { + MLIRContext *ctx = rewriter.getContext(); + auto shape = type.getShape(); + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(triton::gpu::getWarpSize(blockedLayout)); + Value laneId = urem(threadId, warpSize); + Value warpId = udiv(threadId, warpSize); + auto sizePerThread = blockedLayout.getSizePerThread(); + auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); + auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); + auto order = blockedLayout.getOrder(); + auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape); + unsigned rank = shape.size(); + + // delinearize threadId to get the base index + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, warpsPerCTA, order); + SmallVector multiDimThreadId = + delinearize(rewriter, loc, laneId, threadsPerWarp, order); + + SmallVector multiDimBase(rank); + for (unsigned k = 0; k < rank; ++k) { + // Wrap around multiDimWarpId/multiDimThreadId in case + // shapePerCTATile[k] > shapePerCTA[k] + auto maxWarps = + ceil(shapePerCTA[k], sizePerThread[k] * threadsPerWarp[k]); + auto maxThreads = ceil(shapePerCTA[k], sizePerThread[k]); + multiDimWarpId[k] = urem(multiDimWarpId[k], i32_val(maxWarps)); + multiDimThreadId[k] = urem(multiDimThreadId[k], i32_val(maxThreads)); + // multiDimBase[k] = (multiDimThreadId[k] + + // multiDimWarpId[k] * threadsPerWarp[k]) * + // sizePerThread[k]; + Value threadsPerWarpK = i32_val(threadsPerWarp[k]); + Value sizePerThreadK = i32_val(sizePerThread[k]); + multiDimBase[k] = + mul(sizePerThreadK, + add(multiDimThreadId[k], mul(multiDimWarpId[k], threadsPerWarpK))); + } + + return multiDimBase; +} + +inline SmallVector> +emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout, + RankedTensorType type) { + auto ctx = type.getContext(); + auto shape = type.getShape(); + auto sizePerThread = blockedLayout.getSizePerThread(); + auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); + auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); + auto order = blockedLayout.getOrder(); + auto shapePerCTATile = getShapePerCTATile(blockedLayout); + auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape); + + unsigned rank = shape.size(); + SmallVector tilesPerDim(rank); + for (unsigned k = 0; k < rank; ++k) + tilesPerDim[k] = ceil(shapePerCTA[k], shapePerCTATile[k]); + + unsigned elemsPerThread = triton::gpu::getTotalElemsPerThread(type); + unsigned totalSizePerThread = product(sizePerThread); + SmallVector> reorderedOffset(elemsPerThread); + for (unsigned n = 0; n < elemsPerThread; ++n) { + unsigned linearNanoTileId = n / totalSizePerThread; + unsigned linearNanoTileElemId = n % totalSizePerThread; + SmallVector multiDimNanoTileId = + getMultiDimIndex(linearNanoTileId, tilesPerDim, order); + SmallVector multiDimNanoTileElemId = + getMultiDimIndex(linearNanoTileElemId, sizePerThread, order); + for (unsigned k = 0; k < rank; ++k) { + unsigned reorderedMultiDimId = + (multiDimNanoTileId[k] * + (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) + + multiDimNanoTileElemId[k]) % + shapePerCTA[k]; + + reorderedOffset[n].push_back(reorderedMultiDimId); + } + } + + return reorderedOffset; +} + +// ----------------------------------------------------------------------- +// Mma layout indices +// ----------------------------------------------------------------------- + +inline SmallVector +emitBaseIndexWithinCTAForMmaLayoutV1(Location loc, RewriterBase &rewriter, + const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto wpt = mmaLayout.getWarpsPerCTA(); + static constexpr std::array fpw{{2, 2, 1}}; + auto [isARow, isBRow, isAVec4, isBVec4, _] = + mmaLayout.decodeVoltaLayoutStates(); + + Value thread = getThreadId(rewriter, loc); + auto *ctx = thread.getContext(); + Value _1 = i32_val(1); + Value _2 = i32_val(2); + Value _4 = i32_val(4); + Value _16 = i32_val(16); + Value _32 = i32_val(32); + Value _fpw0 = i32_val(fpw[0]); + Value _fpw1 = i32_val(fpw[1]); + + // A info + auto aRep = mmaLayout.getMMAv1Rep(0); + auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0); + // B info + auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1); + auto bRep = mmaLayout.getMMAv1Rep(1); + + SmallVector rep({aRep[0], bRep[1]}); + SmallVector spw({aSpw[0], bSpw[1]}); + SmallVector shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]}); + + Value lane = urem(thread, _32); + Value warp = udiv(thread, _32); + + Value warp0 = urem(warp, i32_val(wpt[0])); + Value warp12 = udiv(warp, i32_val(wpt[0])); + Value warp1 = urem(warp12, i32_val(wpt[1])); + + // warp offset + Value offWarpM = mul(warp0, i32_val(spw[0])); + Value offWarpN = mul(warp1, i32_val(spw[1])); + // quad offset + Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0); + Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1); + // pair offset + Value offPairM = udiv(urem(lane, _16), _4); + offPairM = urem(offPairM, _fpw0); + offPairM = mul(offPairM, _4); + Value offPairN = udiv(urem(lane, _16), _4); + offPairN = udiv(offPairN, _fpw0); + offPairN = urem(offPairN, _fpw1); + offPairN = mul(offPairN, _4); + offPairM = mul(offPairM, i32_val(rep[0] / 2)); + offQuadM = mul(offQuadM, i32_val(rep[0] / 2)); + offPairN = mul(offPairN, i32_val(rep[1] / 2)); + offQuadN = mul(offQuadN, i32_val(rep[1] / 2)); + // quad pair offset + Value offLaneM = add(offPairM, offQuadM); + Value offLaneN = add(offPairN, offQuadN); + // a, b offset + Value offsetAM = add(offWarpM, offLaneM); + Value offsetBN = add(offWarpN, offLaneN); + // m indices + Value offsetCM = add(and_(lane, _1), offsetAM); + // n indices + Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN))); + return {offsetCM, offsetCN}; +} + +inline SmallVector> +emitOffsetForMmaLayoutV1(const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + + auto [isARow, isBRow, isAVec4, isBVec4, _] = + mmaLayout.decodeVoltaLayoutStates(); + + // TODO: seems like the pattern below to get `rep`/`spw` appears quite often + // A info + auto aRep = mmaLayout.getMMAv1Rep(0); + auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0); + // B info + auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1); + auto bRep = mmaLayout.getMMAv1Rep(1); + + auto wpt = mmaLayout.getWarpsPerCTA(); + static constexpr std::array fpw{{2, 2, 1}}; + SmallVector rep({aRep[0], bRep[1]}); + SmallVector spw({aSpw[0], bSpw[1]}); + SmallVector shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]}); + + SmallVector idxM; + for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0]) + for (unsigned mm = 0; mm < rep[0]; ++mm) + idxM.push_back(m + mm * 2); + + SmallVector idxN; + for (int n = 0; n < shape[1]; n += shapePerCTA[1]) { + for (int nn = 0; nn < rep[1]; ++nn) { + idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1]); + idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1); + } + } + + SmallVector> ret; + for (unsigned x1 : idxN) { // N + for (unsigned x0 : idxM) { // M + SmallVector idx(2); + idx[0] = x0; // M + idx[1] = x1; // N + ret.push_back(std::move(idx)); + } + } + return ret; +} + +inline SmallVector> +emitOffsetForMmaLayoutV2(const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + SmallVector> ret; + + auto rank = shape.size(); + for (unsigned i = 0; i < shapePerCTA[rank - 2]; + i += getShapePerCTATile(mmaLayout)[rank - 2]) { + for (unsigned j = 0; j < shapePerCTA[rank - 1]; + j += getShapePerCTATile(mmaLayout)[rank - 1]) { + if (rank == 3) { + ret.push_back({0, i, j}); + ret.push_back({0, i, j + 1}); + ret.push_back({0, i + 8, j}); + ret.push_back({0, i + 8, j + 1}); + } else { + ret.push_back({i, j}); + ret.push_back({i, j + 1}); + ret.push_back({i + 8, j}); + ret.push_back({i + 8, j + 1}); + } + } + } + return ret; +} + +// Note that this may return a null Value for one or more dimensions. This is +// valid only if you're going to slice off the relevant dimension. +inline SmallVector +emitBaseIndexWithinCTAForMmaLayoutV2V3(Location loc, RewriterBase &rewriter, + const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto _warpsPerCTA = mmaLayout.getWarpsPerCTA(); + auto rank = shape.size(); + assert(rank == 2 || rank == 3); + auto warpOrder = triton::gpu::getWarpOrder(mmaLayout); + ArrayRef instrShape = mmaLayout.getInstrShape(); + SmallVector warpsPerCTA; + for (unsigned i = 0; i < rank; ++i) + warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(32); + Value laneId = urem(threadId, warpSize); + Value warpId = udiv(threadId, warpSize); + + uint32_t repM = + (_warpsPerCTA[rank - 2] * instrShape[rank - 2]) / shapePerCTA[rank - 2]; + uint32_t repN = + (_warpsPerCTA[rank - 1] * instrShape[rank - 1]) / shapePerCTA[rank - 1]; + + uint32_t warpsM; + if (repM > 1) + warpsM = _warpsPerCTA[rank - 2] / repM; + else + warpsM = shape[rank - 2] / instrShape[rank - 2]; + + uint32_t warpsN; + if (repN > 1) + warpsN = _warpsPerCTA[rank - 1] / repN; + else + warpsN = shape[rank - 1] / instrShape[rank - 1]; + + SmallVector multiDimWarpId(rank); + multiDimWarpId = delinearize(rewriter, loc, warpId, _warpsPerCTA, warpOrder); + Value warpIdM = urem(multiDimWarpId[rank - 2], i32_val(warpsM)); + Value warpIdN = urem(multiDimWarpId[rank - 1], i32_val(warpsN)); + + Value offWarpM = mul(warpIdM, i32_val(instrShape[rank - 2])); + Value offWarpN = mul(warpIdN, i32_val(instrShape[rank - 1])); + + SmallVector multiDimBase(rank); + if (rank == 3) + multiDimBase[0] = multiDimWarpId[0]; + + // warpsM/N may be 0, in which case warpIDM/N is poison (division by 0), which + // will cause LLVM to eliminate all ops that depend on the poison value. This + // *can* be okay, if the bad dimension is filtered out by a slice layout. So + // we rely on the caller to check. Worst case we crash, which is better than + // silently producing bad code. + if (warpsM != 0) + multiDimBase[rank - 2] = add(udiv(laneId, i32_val(4)), offWarpM); + if (warpsN != 0) + multiDimBase[rank - 1] = + add(mul(i32_val(2), urem(laneId, i32_val(4))), offWarpN); + + return multiDimBase; +} + +inline SmallVector> +emitOffsetForMmaLayoutV3(const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + SmallVector> ret; + ArrayRef instrShape = mmaLayout.getInstrShape(); + + for (unsigned i = 0; i < shapePerCTA[0]; + i += getShapePerCTATile(mmaLayout)[0]) { + for (unsigned j = 0; j < shapePerCTA[1]; + j += getShapePerCTATile(mmaLayout)[1]) { + for (unsigned k = 0; k < instrShape[1]; k += 8) { + ret.push_back({i, j + k}); + ret.push_back({i, j + k + 1}); + ret.push_back({i + 8, j + k}); + ret.push_back({i + 8, j + k + 1}); + } + } + } + return ret; +} + +inline SmallVector +emitBaseIndexForMfmaLayout(Location loc, RewriterBase &rewriter, + const AMDMfmaEncodingAttr &mfmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto rank = shape.size(); + assert(rank == 2 || rank == 3); + auto _warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + SmallVector warpsPerCTA; + for (unsigned i = 0; i < rank; ++i) + warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); + unsigned mDim = mfmaLayout.getMDim(); + unsigned nDim = mfmaLayout.getNDim(); + assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout)); + Value effectiveWarpSize = warpSize; + if (mDim == 4 && nDim == 4) { + const int uniqueValuesPerWarp = 4; + effectiveWarpSize = i32_val(uniqueValuesPerWarp); + } + Value laneId = urem(threadId, effectiveWarpSize); + Value warpId = udiv(threadId, warpSize); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, _warpsPerCTA, + triton::gpu::getWarpOrder(mfmaLayout)); + if (shape[rank - 2] >= mDim) { + assert(shape[rank - 2] % mDim == 0); + multiDimWarpId[rank - 2] = + urem(multiDimWarpId[rank - 2], + i32_val(ceil(shape[rank - 2], mDim))); + } + if (shape[rank - 1] >= nDim) { + assert(shape[rank - 1] % nDim == 0); + multiDimWarpId[rank - 1] = + urem(multiDimWarpId[rank - 1], + i32_val(ceil(shape[rank - 1], nDim))); + } + Value offWarp0 = mul(multiDimWarpId[rank - 2], i32_val(mDim)); + Value offWarp1 = mul(multiDimWarpId[rank - 1], i32_val(nDim)); + + SmallVector multiDimBase(rank); + if (mfmaLayout.getIsTransposed()) { + multiDimBase[rank - 1] = + add(mul(i32_val(4), udiv(laneId, i32_val(mDim))), offWarp1); + multiDimBase[rank - 2] = add(urem(laneId, i32_val(mDim)), offWarp0); + } else { + multiDimBase[rank - 2] = + add(mul(i32_val(4), udiv(laneId, i32_val(nDim))), offWarp0); + multiDimBase[rank - 1] = add(urem(laneId, i32_val(nDim)), offWarp1); + } + // TODO(Lixun): It is assumed when rank = 3, warpsPerCTA is set to + // {numWarps, 1, 1}. We need to generalize the offset computation. + if (rank == 3) { + assert(_warpsPerCTA[1] == 1 && _warpsPerCTA[2] == 1); + multiDimBase[0] = urem(warpId, i32_val(shape[0])); + } + return multiDimBase; +} + +inline void emitMfmaOffsetForCTA(const AMDMfmaEncodingAttr &mfmaLayout, + SmallVector> &offsets, + unsigned bOff, unsigned ctaOffsetX, + unsigned ctaOffsetY) { + auto mDim = mfmaLayout.getMDim(); + auto nDim = mfmaLayout.getNDim(); + assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + // MFMA output tile consists of repeated "dot operand B" layout groups along + // row axis. This variable defines number of these groups. + DenseMap groups{{4, 1}, {16, 1}, {32, 4}}; + unsigned numGroups = groups.at(std::min(mDim, nDim)); + const unsigned elemsPerThreadPerGroup = 4; + auto warpSize = getWarpSize(mfmaLayout); + assert(warpSize == 64); + auto shapePerCta = getShapePerCTATile(mfmaLayout); + auto rank = shapePerCta.size(); + SmallVector elemOff(rank, 0); + for (unsigned block = 0; block < numGroups; block++) { + unsigned rowOrColOffset = + block * elemsPerThreadPerGroup * warpSize / std::min(mDim, nDim); + for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) { + if (mfmaLayout.getIsTransposed()) { + elemOff[rank - 2] = ctaOffsetX * shapePerCta[rank - 2]; + elemOff[rank - 1] = + ctaOffsetY * shapePerCta[rank - 1] + elem + rowOrColOffset; + } else { + elemOff[rank - 2] = + ctaOffsetX * shapePerCta[rank - 2] + elem + rowOrColOffset; + elemOff[rank - 1] = ctaOffsetY * shapePerCta[rank - 1]; + } + if (rank == 3) + elemOff[0] = bOff; + offsets.push_back(elemOff); + } + } +} + +inline SmallVector> +emitOffsetForMfmaLayout(const AMDMfmaEncodingAttr &mfmaLayout, + RankedTensorType type) { + auto tensorShape = type.getShape(); + SmallVector> offsets; + auto shapePerCTA = getShapePerCTA(mfmaLayout, tensorShape); + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + auto rank = type.getRank(); + SmallVector numReps(rank); + unsigned mDim = mfmaLayout.getMDim(); + unsigned nDim = mfmaLayout.getNDim(); + assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + SmallVector shapePerWarp(rank, 1); + shapePerWarp[rank - 2] = mDim; + shapePerWarp[rank - 1] = nDim; + for (unsigned d = 0; d < rank; ++d) { + unsigned inPerCTA = std::min(tensorShape[d], shapePerCTA[d]); + unsigned inPerWarp = ceil(inPerCTA, warpsPerCTA[d]); + numReps[d] = ceil(inPerWarp, shapePerWarp[d]); + } + + unsigned repBatch = rank == 3 ? numReps[0] : 1; + auto warpsPerBatch = + rank == 3 ? std::min(tensorShape[0], warpsPerCTA[0]) : 1; + + for (unsigned b = 0; b < repBatch; ++b) { + for (unsigned i = 0; i < numReps[rank - 2]; ++i) { + for (unsigned j = 0; j < numReps[rank - 1]; ++j) { + emitMfmaOffsetForCTA(mfmaLayout, offsets, b * warpsPerBatch, i, j); + } + } + } + return offsets; +} + +inline void emitWmmaOffsetForCTA(const AMDWmmaEncodingAttr &wmmaLayout, + SmallVector> &offsets, + unsigned ctaBatchOffset, unsigned ctaOffsetX, + unsigned ctaOffsetY) { + const unsigned elemsPerThreadPerGroup = 8; + auto warpSize = getWarpSize(wmmaLayout); + assert(warpSize == 32); + auto shapePerCta = getShapePerCTATile(wmmaLayout); + auto rank = shapePerCta.size(); + assert(rank == 2 || rank == 3); + SmallVector elemOffset(rank, 0); + if (rank == 3) + elemOffset[0] = ctaBatchOffset; + for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) { + elemOffset[rank - 2] = ctaOffsetX * shapePerCta[rank - 2] + 2 * elem; + elemOffset[rank - 1] = ctaOffsetY * shapePerCta[rank - 1]; + offsets.push_back(elemOffset); + } +} + +inline SmallVector +emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter, + const AMDWmmaEncodingAttr &wmmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto _warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + auto rank = _warpsPerCTA.size(); + assert(rank == 2 || rank == 3); + SmallVector warpsPerCTA; + for (unsigned i = 0; i < rank; ++i) + warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); + auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(triton::gpu::getWarpSize(wmmaLayout)); + Value laneId = + urem(threadId, i32_val(triton::gpu::getWarpSize(wmmaLayout) / 2)); + Value threadIdPerWarp = urem(threadId, warpSize); + + Value warpId = udiv(threadId, warpSize); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, _warpsPerCTA, + triton::gpu::getWarpOrder(wmmaLayout)); + if (shape[rank - 2] >= mnkDim[0]) { + assert(shape[rank - 2] % mnkDim[0] == 0); + multiDimWarpId[rank - 2] = + urem(multiDimWarpId[rank - 2], + i32_val(ceil(shape[rank - 2], mnkDim[0]))); + } + if (shape[rank - 1] >= mnkDim[1]) { + assert(shape[rank - 1] % mnkDim[1] == 0); + multiDimWarpId[rank - 1] = + urem(multiDimWarpId[rank - 1], + i32_val(ceil(shape[rank - 1], mnkDim[1]))); + } + Value offWarp0 = mul(multiDimWarpId[rank - 2], i32_val(mnkDim[0])); + Value offWarp1 = mul(multiDimWarpId[rank - 1], i32_val(mnkDim[1])); + + SmallVector multiDimBase(rank); + + multiDimBase[rank - 2] = + add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0); + multiDimBase[rank - 1] = add(laneId, offWarp1); + + // TODO: It is assumed when rank = 3, warpsPerCTA is set to + // {numWarps, 1, 1}. We need to generalize the offset computation. + if (rank == 3) { + assert(_warpsPerCTA[1] == 1 && _warpsPerCTA[2] == 1); + multiDimBase[0] = urem(warpId, i32_val(shape[0])); + } + return multiDimBase; +} + +inline SmallVector> +emitOffsetForWmmaLayout(const AMDWmmaEncodingAttr &wmmaLayout, + RankedTensorType type) { + auto tensorShape = type.getShape(); + SmallVector> offsets; + auto shapePerCTA = getShapePerCTA(wmmaLayout, tensorShape); + auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + + auto rank = tensorShape.size(); + assert(rank == 2 || rank == 3); + + SmallVector numWarpsPerDim(rank, 1); + auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); + SmallVector shapePerWarp(rank, 1); + shapePerWarp[rank - 2] = mnkDim[0]; + shapePerWarp[rank - 1] = mnkDim[1]; + for (unsigned d = 0; d < rank; ++d) { + unsigned inPerCTA = std::min(tensorShape[d], shapePerCTA[d]); + unsigned inPerWarp = ceil(inPerCTA, warpsPerCTA[d]); + numWarpsPerDim[d] = ceil(inPerWarp, shapePerWarp[d]); + } + + unsigned repBatch = rank == 3 ? numWarpsPerDim[0] : 1; + unsigned repM = numWarpsPerDim[rank - 2]; + unsigned repN = numWarpsPerDim[rank - 1]; + auto warpsPerBatch = + rank == 3 ? std::min(tensorShape[0], warpsPerCTA[0]) : 1; + + for (unsigned b = 0; b < repBatch; ++b) { + for (unsigned i = 0; i < repM; ++i) { + for (unsigned j = 0; j < repN; ++j) { + emitWmmaOffsetForCTA(wmmaLayout, offsets, b * warpsPerBatch, i, j); + } + } + } + return offsets; +} + +inline SmallVector> +emitOffsetForLayout(Attribute layout, RankedTensorType type); + +inline SmallVector> +emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout, + RankedTensorType type) { + auto parentEncoding = sliceLayout.getParent(); + unsigned dim = sliceLayout.getDim(); + auto parentShape = sliceLayout.paddedShape(type.getShape()); + RankedTensorType parentTy = + RankedTensorType::get(parentShape, type.getElementType(), parentEncoding); + auto parentOffsets = emitOffsetForLayout(parentEncoding, parentTy); + if (parentOffsets.empty()) + return {}; + + SmallVector> resultOffsets; + std::set> uniqueOffsets; + + for (unsigned i = 0; i < parentOffsets.size(); ++i) { + SmallVector offsets(parentOffsets[i].begin(), + parentOffsets[i].end()); + offsets.erase(offsets.begin() + dim); + if (auto [it, inserted] = uniqueOffsets.insert(offsets); inserted) { + resultOffsets.push_back(offsets); + } + } + + // It can happen that after deduplicating elements above, resultOffsets has + // fewer than getTotalElementsPerThread() elements. In that case repeat the + // sequence. + int elemsPerThread = triton::gpu::getTotalElemsPerThread(type); + assert(resultOffsets.size() > 0); + assert(elemsPerThread % resultOffsets.size() == 0); + int numRepeats = elemsPerThread / resultOffsets.size(); + SmallVector> ret; + for (int i = 0; i < numRepeats; ++i) { + for (unsigned j = 0; j < resultOffsets.size(); ++j) { + ret.push_back(SmallVector(resultOffsets[j])); + } + } + return ret; +} + +// ----------------------------------------------------------------------- +// Get offsets / indices for any layout +// ----------------------------------------------------------------------- + +inline SmallVector emitCTAOffsetForLayout(Location loc, + RewriterBase &rewriter, + const TargetInfoBase &target, + Attribute layout, + ArrayRef shape) { + unsigned rank = shape.size(); + SmallVector CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout); + SmallVector CTASplitNum = triton::gpu::getCTASplitNum(layout); + SmallVector CTAOrder = triton::gpu::getCTAOrder(layout); + SmallVector shapePerCTA = + triton::gpu::getShapePerCTA(CTASplitNum, shape); + + // Delinearize clusterCTAId + Value clusterCTAId = target.getClusterCTAId(rewriter, loc); + SmallVector multiDimClusterCTAId = + delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); + + // CTA Wrapping + for (unsigned i = 0; i < rank; ++i) { + // This wrapping rule must be consistent with getShapePerCTA + unsigned splitNum = std::min(shape[i], CTASplitNum[i]); + multiDimClusterCTAId[i] = urem(multiDimClusterCTAId[i], i32_val(splitNum)); + } + + SmallVector CTAOffset(rank); + for (unsigned i = 0; i < rank; ++i) + CTAOffset[i] = mul(multiDimClusterCTAId[i], i32_val(shapePerCTA[i])); + + return CTAOffset; +} + +inline SmallVector +emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, bool withCTAOffset) { + auto shape = type.getShape(); + + SmallVector baseIndex; + RewriterBase::InsertionGuard guard(rewriter); + SmallVector result; + if (auto blockedLayout = mlir::dyn_cast(layout)) { + result = emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter, + blockedLayout, type); + } else if (auto mmaLayout = mlir::dyn_cast(layout)) { + if (mmaLayout.isVolta()) + result = + emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter, mmaLayout, type); + if (mmaLayout.isAmpere() || mmaLayout.isHopper()) + result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter, mmaLayout, + type); + } else if (auto mmaLayout = mlir::dyn_cast(layout)) { + if (mmaLayout.isVolta()) { + DEFINE_CALL_LOAD_FUNC(iluvatar, emitBaseIndexForTCULayout) + result = func(loc, rewriter, mmaLayout, type); + } + } else if (auto mfmaLayout = mlir::dyn_cast(layout)) { + result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type); + } else if (auto wmmaLayout = mlir::dyn_cast(layout)) { + result = emitBaseIndexForWmmaLayout(loc, rewriter, wmmaLayout, type); + } else if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(type.getShape()); + RankedTensorType parentTy = + RankedTensorType::get(parentShape, type.getElementType(), parentLayout); + result = emitBaseIndexForLayoutImpl(loc, rewriter, target, parentLayout, + parentTy, withCTAOffset); + result.erase(result.begin() + sliceLayout.getDim()); + // CTAOffset has been added in emitBaseIndexForLayout of parentLayout + return result; + } else { + llvm_unreachable("unsupported emitBaseIndexForLayout"); + } + if (withCTAOffset) { + auto CTAOffset = + emitCTAOffsetForLayout(loc, rewriter, target, layout, shape); + assert(CTAOffset.size() == result.size() && "Rank mismatch"); + for (unsigned k = 0; k < result.size(); ++k) { + // Individual elements of `result` may be null. In the caller + // (emitBaseIndexForLayout), we assert that all such dimensions are sliced + // off. + if (!result[k]) + continue; + result[k] = add(result[k], CTAOffset[k]); + } + } + return result; +} + +inline SmallVector +emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, bool withCTAOffset) { + SmallVector idx = emitBaseIndexForLayoutImpl( + loc, rewriter, target, layout, type, withCTAOffset); + + // Check that any null values were sliced out. + for (Value v : idx) { + if (!v) { + llvm::errs() << "Failed to generate indexing code, possibly due to bad " + "#mma layout. Please rerun your program with " + "MLIR_ENABLE_DUMP=1 and file a bug." + << "\nloc: " << loc << "\nlayout: " << layout + << "\ntype: " << type << "\nwithCTAOffset: " << withCTAOffset + << "\n"; + llvm::report_fatal_error("Failed to generate indexing code"); + } + } + + return idx; +} + +inline SmallVector> +emitOffsetForLayout(Attribute layout, RankedTensorType type) { + if (auto blockedLayout = dyn_cast(layout)) + return emitOffsetForBlockedLayout(blockedLayout, type); + if (auto mmaLayout = dyn_cast(layout)) { + if (mmaLayout.isVolta()) + return emitOffsetForMmaLayoutV1(mmaLayout, type); + if (mmaLayout.isAmpere()) + return emitOffsetForMmaLayoutV2(mmaLayout, type); + if (mmaLayout.isHopper()) + return emitOffsetForMmaLayoutV3(mmaLayout, type); + } + if (auto mmaLayout = dyn_cast(layout)) { + if (mmaLayout.isVolta()) { + DEFINE_CALL_LOAD_FUNC(iluvatar, emitOffsetForTCULayout) + return func(mmaLayout, type); + } + } + if (auto mfmaLayout = mlir::dyn_cast(layout)) { + return emitOffsetForMfmaLayout(mfmaLayout, type); + } + if (auto wmmaLayout = mlir::dyn_cast(layout)) { + return emitOffsetForWmmaLayout(wmmaLayout, type); + } + if (auto sliceLayout = mlir::dyn_cast(layout)) + return emitOffsetForSliceLayout(sliceLayout, type); + llvm_unreachable("unsupported emitOffsetForLayout"); +} + +// Eventually this will become the only emitIndices function. +std::optional>> +emitIndicesUsingLinearLayouts(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, bool withCTAOffset); + +// Emit indices calculation within each ConversionPattern, and returns a +// [elemsPerThread X rank] index matrix. +inline SmallVector> +emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + Attribute layout, RankedTensorType type, bool withCTAOffset, + bool allowLL = true) { + // Eventually the LinearLayout path will be the only one. For now we allow + // both paths so we can test that they produce the same results. + if (allowLL && target.enableLinearLayout()) { + std::optional>> llOffsets = + emitIndicesUsingLinearLayouts(loc, rewriter, target, layout, type, + withCTAOffset); + if (llOffsets.has_value()) + return *llOffsets; + } + + // step 1, delinearize threadId to get the base index + auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, target, layout, + type, withCTAOffset); + // step 2, get offset of each element + auto offset = emitOffsetForLayout(layout, type); + // step 3, add offset to base, and reorder the sequence + // of indices to guarantee that elems in the same + // sizePerThread are adjacent in order + auto shape = type.getShape(); + unsigned rank = shape.size(); + unsigned elemsPerThread = offset.size(); + SmallVector> multiDimIdx(elemsPerThread, + SmallVector(rank)); + for (unsigned n = 0; n < elemsPerThread; ++n) + for (unsigned k = 0; k < rank; ++k) + multiDimIdx[n][k] = add(multiDimBase[k], i32_val(offset[n][k])); + + return multiDimIdx; +} + +/* ---------------- */ +/* ---------------- */ + +inline DenseMap getSwizzledSharedPtrs( + Location loc, const TargetInfoBase &target, unsigned inVec, + RankedTensorType srcTy, triton::gpu::SharedEncodingAttr resSharedLayout, + Type resElemTy, SharedMemoryObject smemObj, RewriterBase &rewriter, + SmallVectorImpl &offsetVals, SmallVectorImpl &srcStrides) { + // This utility computes the pointers for accessing the provided swizzled + // shared memory layout `resSharedLayout`. More specifically, it computes, + // for all indices (row, col) of `srcEncoding` such that idx % inVec = 0, + // the pointer: ptr[(row, col)] = base + (rowOff * strides[ord[1]] + + // colOff) where : + // phase = (row // perPhase) % maxPhase + // rowOff = row + // colOff = colOffSwizzled + colOffOrdered + // colOffSwizzled = ((col // outVec) ^ phase) * outVec + // colOffOrdered = (col % outVec) // minVec * minVec + // + // Note 1: + // ------- + // Because swizzling happens at a granularity of outVec, we need to + // decompose the offset into a swizzled factor and a non-swizzled + // (ordered) factor + // + // Note 2: + // ------- + // If we have x, y, z of the form: + // x = 0b00000xxxx + // y = 0byyyyy0000 + // z = 0b00000zzzz + // then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y + // This means that we can use some immediate offsets for shared memory + // operations. + auto dstPtrTy = ptr_ty(rewriter.getContext(), 3); + auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides); + Value dstPtrBase = gep(dstPtrTy, resElemTy, smemObj.base, dstOffset); + + auto srcEncoding = srcTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto srcShapePerCTA = triton::gpu::getShapePerCTA(srcTy); + unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); + // swizzling params as described in TritonGPUAttrDefs.td + unsigned outVec = resSharedLayout.getVec(); + unsigned perPhase = resSharedLayout.getPerPhase(); + unsigned maxPhase = resSharedLayout.getMaxPhase(); + // Order + auto inOrder = triton::gpu::getOrder(srcEncoding); + auto outOrder = triton::gpu::getOrder(resSharedLayout); + assert(maxPhase == 1 || + outVec * maxPhase <= srcShape[outOrder[0]] && + "Swizzling would generate out of bounds memory accesses"); + // Tensor indices held by the current thread, as LLVM values + auto srcIndices = emitIndices(loc, rewriter, target, srcEncoding, srcTy, + /*withCTAOffset=*/false); + // Swizzling with leading offsets (e.g. Hopper GMMA) + unsigned swizzlingByteWidth = 0; + if (resSharedLayout.getHasLeadingOffset()) { + if (perPhase == 4 && maxPhase == 2) + swizzlingByteWidth = 32; + else if (perPhase == 2 && maxPhase == 4) + swizzlingByteWidth = 64; + else if (perPhase == 1 && maxPhase == 8) + swizzlingByteWidth = 128; + else + llvm::report_fatal_error("Unsupported shared layout."); + } + unsigned numElemsPerSwizzlingRow = + swizzlingByteWidth * 8 / resElemTy.getIntOrFloatBitWidth(); + Value numElemsPerSwizzlingRowVal = i32_val(numElemsPerSwizzlingRow); + unsigned leadingDimOffset; + if (outOrder.size() >= 2) { + leadingDimOffset = numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]]; + } else { + leadingDimOffset = numElemsPerSwizzlingRow; + } + + Value leadingDimOffsetVal = i32_val(leadingDimOffset); + // Return values + DenseMap ret; + // cache for non-immediate offsets + DenseMap cacheCol, cacheRow; + unsigned minVec = std::min(outVec, inVec); + Value strideRow = outOrder.size() >= 2 ? srcStrides[outOrder[1]] : i32_val(0); + Value strideCol = srcStrides[outOrder[0]]; + LDBG("getSwizzledSharedPtrs: perPhase = " + << perPhase << " maxPhase = " << maxPhase << " minVec = " << minVec + << " inVec = " << inVec << " outVec = " << outVec << " strideRow " + << strideRow << " strideCol " << strideCol); + for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) { + Value offset = i32_val(0); + // Extract multi dimensional index for current element + auto idx = srcIndices[elemIdx]; + Value idxCol = idx[outOrder[0]]; // contiguous dimension + Value idxRow; + if (outOrder.size() >= 2) { + idxRow = idx[outOrder[1]]; // discontiguous dimension + } else { + idxRow = i32_val(0); + } + // compute phase = (row // perPhase) % maxPhase + Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase)); +#if defined(__ILUVATAR__) + // corex swizzle + bool isRow = outOrder[0] == 1; + Value off = NULL; + auto capability = getNVIDIAComputeCapability( + smemObj.base.getDefiningOp()->getParentOfType()); + if (resSharedLayout.getUseTcu() && idx.size() == 2) { + DEFINE_CALL_LOAD_FUNC(iluvatar, remapOffset) + off = func(idx[0], idx[1], srcTy, isRow, loc, rewriter, capability, + !perPhase); + } else { + off = add(mul(idxCol, strideCol), mul(idxRow, strideRow)); + } + ret[elemIdx] = gep(dstPtrTy, resElemTy, dstPtrBase, off); +#else + // extract dynamic/static offset for immediate offsetting + unsigned immedateOffCol = 0; + unsigned immedateOffRow = 0; + if (leadingDimOffset) { + // hopper + offset = + mul(udiv(idxCol, numElemsPerSwizzlingRowVal), leadingDimOffsetVal); + // Shrink by swizzling blocks + idxCol = urem(idxCol, numElemsPerSwizzlingRowVal); + strideRow = numElemsPerSwizzlingRowVal; + } + if (auto add = dyn_cast_or_null(idxCol.getDefiningOp())) { + if (auto _cst = dyn_cast_or_null( + add.getRhs().getDefiningOp())) { + unsigned cst = + cast(_cst.getValue()).getValue().getSExtValue(); + unsigned key = cst % (outVec * maxPhase); + cacheCol.insert({key, idxCol}); + idxCol = cacheCol[key]; + immedateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase); + } + } + if (auto add = dyn_cast_or_null(idxRow.getDefiningOp())) { + if (auto _cst = dyn_cast_or_null( + add.getRhs().getDefiningOp())) { + unsigned cst = + mlir::cast(_cst.getValue()).getValue().getSExtValue(); + unsigned key = cst % (perPhase * maxPhase); + cacheRow.insert({key, idxRow}); + idxRow = cacheRow[key]; + immedateOffRow = cst / (perPhase * maxPhase) * (perPhase * maxPhase); + } + } + // row offset is simply row index + Value rowOff = mul(idxRow, strideRow); + // because swizzling happens at a granularity of outVec, we need to + // decompose the offset into a swizzled factor and a non-swizzled + // (ordered) factor: colOffSwizzled = ((col // outVec) ^ phase) * outVec + // colOffOrdered = (col % outVec) // minVec * minVec + Value colOffSwizzled = xor_(udiv(idxCol, i32_val(outVec)), phase); + colOffSwizzled = mul(colOffSwizzled, i32_val(outVec)); + Value colOffOrdered = urem(idxCol, i32_val(outVec)); + colOffOrdered = udiv(colOffOrdered, i32_val(minVec)); + colOffOrdered = mul(colOffOrdered, i32_val(minVec)); + Value colOff = add(colOffSwizzled, colOffOrdered); + // compute non-immediate offset + if (outOrder.size() == 3) + offset = add(offset, mul(idx[outOrder[2]], srcStrides[outOrder[2]])); + offset = add(offset, add(rowOff, mul(colOff, strideCol))); + Value currPtr = gep(dstPtrTy, resElemTy, dstPtrBase, offset); + // compute immediate offset + Value immediateOff; + if (outOrder.size() >= 2) { + immediateOff = + add(mul(i32_val(immedateOffRow), strideRow), i32_val(immedateOffCol)); + } else { + immediateOff = i32_val(immedateOffCol); + } + + ret[elemIdx] = gep(dstPtrTy, resElemTy, currPtr, immediateOff); +#endif + } + return ret; +} + +inline SmallVector loadSharedToDistributed( + Value dst, Value src, SharedMemoryObject smemObj, Type elemTy, Location loc, + ConversionPatternRewriter &rewriter, const TargetInfoBase &target) { + auto dstTy = cast(dst.getType()); + auto dstShape = dstTy.getShape(); + assert(dstShape.size() <= 2 && "Unexpected rank of loadSharedToDistributed"); + auto srcTy = cast(src.getType()); + auto dstDistributedLayout = dstTy.getEncoding(); + if (auto mmaLayout = dyn_cast(dstDistributedLayout)) { + assert((!mmaLayout.isVolta()) && + "ConvertLayout Shared->MMAv1 is not supported yet"); + } + auto srcSharedLayout = + cast(srcTy.getEncoding()); + auto srcElemTy = srcTy.getElementType(); + auto dstElemTy = dstTy.getElementType(); + LDBG("loadSharedToDistributed elemTy " << elemTy << " srcElemTy " << srcElemTy + << " dstElemTy " << dstElemTy); + auto inOrd = triton::gpu::getOrder(srcSharedLayout); + auto outOrd = triton::gpu::getOrder(dstDistributedLayout); + unsigned outVec = inOrd == outOrd + ? triton::gpu::getUniqueContigPerThread( + dstDistributedLayout, dstShape)[outOrd[0]] + : 1; + + // If the shmem layout is not swizzled, we can trivially vectorize loads + // across the whole width of the most-minor dimension of the shape, because + // Triton requires all the dims are powers of 2. + unsigned inVec = srcSharedLayout.getMaxPhase() == 1 + ? srcTy.getShape()[inOrd[0]] + : srcSharedLayout.getVec(); + unsigned minVec = std::min(outVec, inVec); + unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy); + SmallVector offsetVals = {smemObj.strides.size(), i32_val(0)}; + + DenseMap sharedPtrs = + getSwizzledSharedPtrs(loc, target, outVec, dstTy, srcSharedLayout, elemTy, + smemObj, rewriter, offsetVals, smemObj.strides); + assert(outElems % minVec == 0 && "Unexpected number of elements"); + unsigned numVecs = outElems / minVec; + auto wordTy = vec_ty(elemTy, minVec); + SmallVector outVals(outElems); + for (unsigned i = 0; i < numVecs; ++i) { + Value smemAddr = sharedPtrs[i * minVec]; + smemAddr = bitcast(smemAddr, ptr_ty(rewriter.getContext(), 3)); + auto valVec = load(wordTy, smemAddr); + valVec.setAlignment(minVec * elemTy.getIntOrFloatBitWidth() / 8); + for (unsigned v = 0; v < minVec; ++v) { + Value currVal = extract_element(elemTy, valVec, i32_val(v)); + outVals[i * minVec + v] = currVal; + } + } + return outVals; +} + +inline void storeDistributedToShared(Value src, ArrayRef inVals, + ArrayRef dstStrides, Value dst, + Value smemBase, Type elemTy, Location loc, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &target) { + auto srcTy = cast(src.getType()); + auto srcShape = srcTy.getShape(); + auto rank = srcShape.size(); + assert(rank <= 3 && "Unexpected rank of storeDistributedToShared"); + auto dstTy = cast(dst.getType()); + auto srcDistributedLayout = srcTy.getEncoding(); + if (auto mmaLayout = dyn_cast(srcDistributedLayout)) { + assert((!mmaLayout.isVolta()) && + "ConvertLayout MMAv1->Shared is not supported yet"); + } + auto dstSharedLayout = + cast(dstTy.getEncoding()); + auto dstElemTy = dstTy.getElementType(); + auto inOrd = triton::gpu::getOrder(srcDistributedLayout); + auto outOrd = dstSharedLayout.getOrder(); + unsigned inVec = inOrd == outOrd + ? triton::gpu::getUniqueContigPerThread( + srcDistributedLayout, srcShape)[inOrd[0]] + : 1; + // If the shmem layout is not swizzled, we can trivially vectorize stores + // across the whole width of the most-minor dimension of the shape, because + // Triton requires all the dims are powers of 2. +#ifdef __ILUVATAR__ + unsigned outVec = dstSharedLayout.getVec(); +#else + unsigned outVec = dstSharedLayout.getMaxPhase() == 1 + ? dstTy.getShape()[inOrd[0]] + : dstSharedLayout.getVec(); +#endif + unsigned minVec = std::min(outVec, inVec); + unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); + auto wordTy = vec_ty(elemTy, minVec); + Value word; + + SmallVector srcStrides(dstStrides); + SmallVector offsetVals(rank, i32_val(0)); + SharedMemoryObject smemObj(smemBase, elemTy, srcStrides, offsetVals); + + DenseMap sharedPtrs = + getSwizzledSharedPtrs(loc, target, inVec, srcTy, dstSharedLayout, elemTy, + smemObj, rewriter, offsetVals, srcStrides); + LDBG("storeDistributedToShared: numElems = " << numElems << " minVec = " + << minVec << " " << wordTy); + for (unsigned i = 0; i < numElems; ++i) { + if (i % minVec == 0) + word = undef(wordTy); + word = insert_element(wordTy, word, inVals[i], i32_val(i % minVec)); + if (i % minVec == minVec - 1) { + Value smemAddr = sharedPtrs[i / minVec * minVec]; + smemAddr = bitcast(smemAddr, ptr_ty(rewriter.getContext(), 3)); + store(word, smemAddr) + .setAlignment(minVec * elemTy.getIntOrFloatBitWidth() / 8); + } + } +} + +inline Value +getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, + ConversionPatternRewriter &rewriter) { + auto elems = smemObj.getElems(); + auto types = smemObj.getTypes(); + auto structTy = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types); + // pack into struct + Value llvmStruct = rewriter.create(loc, structTy); + for (const auto &v : llvm::enumerate(elems)) { + assert(v.value() && "can not insert null values"); + llvmStruct = insert_val(structTy, llvmStruct, v.value(), v.index()); + } + return llvmStruct; +} + +inline SmallVector +unpackLLElements(Location loc, Value llvmStruct, + ConversionPatternRewriter &rewriter) { + assert(bool(llvmStruct) && "can not unpack null values"); + if (llvmStruct.getType().isIntOrIndexOrFloat() || + isa(llvmStruct.getType()) || + isa(llvmStruct.getType())) + return {llvmStruct}; + ArrayRef types = + cast(llvmStruct.getType()).getBody(); + SmallVector results(types.size()); + for (unsigned i = 0; i < types.size(); ++i) { + Type type = types[i]; + results[i] = extract_val(type, llvmStruct, i); + } + return results; +} + +inline Value packLLElements(Location loc, + const LLVMTypeConverter *typeConverter, + ValueRange resultVals, + ConversionPatternRewriter &rewriter, Type type) { + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (!structType) { + assert(resultVals.size() == 1); + return *resultVals.begin(); + } + + auto elementTypes = structType.getBody(); + if (elementTypes.size() != resultVals.size()) { + emitError(loc) << " size mismatch when packing elements for LLVM struct" + << " expected " << elementTypes.size() << " but got " + << resultVals.size(); + } + Value llvmStruct = rewriter.create(loc, structType); + for (const auto &v : llvm::enumerate(resultVals)) { + if (!v.value()) { + emitError(loc) + << "cannot insert null values into struct, but tried to insert" + << v.value(); + } + if (v.value().getType() != elementTypes[v.index()]) { + LDBG("type " << type << " structType " << structType); + LDBG("value " << v.value()); + emitError(loc) << "invalid element type in packLLEElements. Expected " + << elementTypes[v.index()] << " but got " + << v.value().getType(); + } + llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index()); + } + return llvmStruct; +} + +inline bool isLayoutMmaV1(Attribute layout) { + bool isMmaV1 = false; + if (auto mmaLayout = dyn_cast(layout)) { + isMmaV1 = mmaLayout.isVolta(); + } + if (auto sliceLayout = dyn_cast(layout)) { + isMmaV1 = isa(sliceLayout.getParent()) && + cast(sliceLayout.getParent()).isVolta(); + } + return isMmaV1; +} + +} // namespace mlir + +#endif diff --git a/third_party/iluvatar/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt b/third_party/iluvatar/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt new file mode 100644 index 000000000..99d90c4d7 --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonGPU) +add_public_tablegen_target(TritonConversionPassIncGen) diff --git a/third_party/iluvatar/include/triton/Conversion/TritonToTritonGPU/Passes.h b/third_party/iluvatar/include/triton/Conversion/TritonToTritonGPU/Passes.h new file mode 100644 index 000000000..e159406b3 --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/TritonToTritonGPU/Passes.h @@ -0,0 +1,15 @@ +#ifndef TRITON_CONVERSION_PASSES_H +#define TRITON_CONVERSION_PASSES_H + +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/iluvatar/include/triton/Conversion/TritonToTritonGPU/Passes.td b/third_party/iluvatar/include/triton/Conversion/TritonToTritonGPU/Passes.td new file mode 100644 index 000000000..84150fe67 --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/TritonToTritonGPU/Passes.td @@ -0,0 +1,37 @@ +#ifndef TRITON_CONVERSION_PASSES +#define TRITON_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> { + let summary = "Convert Triton to TritonGPU"; + let description = [{ + + }]; + let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + // TODO: Does this pass depend on SCF? + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect"]; + + let options = [ + Option<"numWarps", "num-warps", + "int32_t", /*default*/"4", + "number of warps">, + + Option<"threadsPerWarp", "threads-per-warp", + "int32_t", /*default*/"32", + "number of threads per warp">, + Option<"numCTAs", "num-ctas", + "int32_t", /*default*/"1", + "number of ctas in a cga">, + Option<"target", "target", + "std::string", /*default*/"\"\"", + "the GPU target, e.g., cuda:80, hip:gfx942"> + ]; +} + +#endif diff --git a/third_party/iluvatar/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h b/third_party/iluvatar/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h new file mode 100644 index 000000000..d3da1394e --- /dev/null +++ b/third_party/iluvatar/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h @@ -0,0 +1,31 @@ +#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H +#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H + +#include +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps"; +constexpr static char AttrNumCTAsName[] = "triton_gpu.num-ctas"; +constexpr static char AttrTargetName[] = "triton_gpu.target"; + +constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp"; + +// Create the pass with numWarps passed from cl::opt. +std::unique_ptr> createConvertTritonToTritonGPUPass(); + +// Create the pass with numWarps set explicitly. +std::unique_ptr> +createConvertTritonToTritonGPUPass(const std::string &target, int numWarps, + int threadsPerWarp = 32, int numCTAs = 1); + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/iluvatar/include/triton/Dialect/CMakeLists.txt b/third_party/iluvatar/include/triton/Dialect/CMakeLists.txt new file mode 100644 index 000000000..27cb65ce5 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Triton) +add_subdirectory(TritonGPU) diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/CMakeLists.txt b/third_party/iluvatar/include/triton/Dialect/Triton/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/IR/CMakeLists.txt b/third_party/iluvatar/include/triton/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..f682f54a1 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,27 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS TritonDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS TritonTypes.td) +mlir_tablegen(Types.h.inc -gen-typedef-decls) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs) + +set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td) +mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) + +set(LLVM_TARGET_DEFINITIONS TritonTypeInterfaces.td) +mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs) + +add_public_tablegen_target(TritonTableGen) diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/IR/Dialect.h b/third_party/iluvatar/include/triton/Dialect/Triton/IR/Dialect.h new file mode 100644 index 000000000..7a98204dd --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/IR/Dialect.h @@ -0,0 +1,83 @@ +#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITON_IR_DIALECT_H_ + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h.inc" +#include "triton/Dialect/Triton/IR/OpsEnums.h.inc" +#include "triton/Dialect/Triton/IR/Traits.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/Triton/IR/Ops.h.inc" + +namespace mlir { +namespace triton { + +struct GlobalMemory : public SideEffects::Resource::Base { + StringRef getName() final { return ""; } +}; + +class DialectInferLayoutInterface + : public DialectInterface::Base { +public: + DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {} + + virtual LogicalResult + inferTransOpEncoding(Attribute operandEncoding, ArrayRef order, + Attribute &resultEncoding) const = 0; + + virtual LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + bool noWarpReduce, Attribute &resultEncoding) const = 0; + + virtual LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional location) const = 0; + + // Note: This function only verifies the operand encoding. It doesn't infer + // the result encoding. + virtual LogicalResult + inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute retEncoding, + std::optional location) const = 0; + + // Tries to compute the encoding for the result of a reshape operation that + // makes the reshape a "nop", i.e. the same GPU threads contain the same + // elements as before the reshape. Note that this is not always possible (in + // which case you'd need to choose a different layout for the input to the + // reshape). + virtual LogicalResult + inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const = 0; + + virtual LogicalResult + inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const = 0; + + virtual LogicalResult + inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const = 0; + + // Verify that the encoding are compatible to be used together in a dot + // operation + virtual LogicalResult + verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, + Attribute operandEncodingB) const = 0; +}; + +} // namespace triton +} // namespace mlir + +#endif // TRITON_IR_DIALECT_H_ diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/IR/Interfaces.h b/third_party/iluvatar/include/triton/Dialect/Triton/IR/Interfaces.h new file mode 100644 index 000000000..f8f3a6f74 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/IR/Interfaces.h @@ -0,0 +1,9 @@ +#ifndef TRITON_IR_INTERFACES_H_ +#define TRITON_IR_INTERFACES_H_ + +#include "mlir/IR/OpDefinition.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/IR/Traits.h b/third_party/iluvatar/include/triton/Dialect/Triton/IR/Traits.h new file mode 100644 index 000000000..f34a0fd59 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/IR/Traits.h @@ -0,0 +1,120 @@ +#ifndef TRITON_IR_TRAITS_H_ +#define TRITON_IR_TRAITS_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LogicalResult.h" + +#include + +namespace mlir { +namespace OpTrait { + +// These functions are out-of-line implementations of the methods in the +// corresponding trait classes. This avoids them being template +// instantiated/duplicated. +namespace impl { +// The rationale for this trait is to prevent users from creating programs +// that would have catastrophic register pressure and cause the compiler to +// hang. +// Since H100 has 256KB registers, we should allow users to create tensors +// of size up to 256K elements. It will spill for datatypes wider than 1B, +// but we probably should limit number of elements (rather than bytes) to +// keep specs simple +int constexpr maxTensorNumElements = 1048576; + +LogicalResult verifyTensorSize(Operation *op); +LogicalResult verifyTensorLayouts(Operation *op); + +LogicalResult verifySameOperandsEncoding(Operation *op, + bool allowTensorPointerType = false); + +LogicalResult +verifySameOperandsAndResultEncoding(Operation *op, + bool allowTensorPointerType = false); + +LogicalResult verifySameLoadStoreOperandsShape(Operation *op); + +LogicalResult verifySameLoadStoreOperandsAndResultShape(Operation *op); + +} // namespace impl + +template +class TensorSizeTrait : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyTensorSize(op); + } +}; + +// Trait applied to all Triton MLIR ops. Checks that the layouts of tensors are +// valid. +template +class VerifyTensorLayoutsTrait + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyTensorLayouts(op); + } +}; + +template +class SameOperandsAndResultEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultEncoding(op); + } +}; + +template +class SameOperandsEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsEncoding(op); + } +}; + +template +class SameLoadStoreOperandsShape + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameLoadStoreOperandsShape(op); + } +}; + +template +class SameLoadStoreOperandsAndResultShape + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameLoadStoreOperandsAndResultShape(op); + } +}; + +template +class SameLoadStoreOperandsEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsEncoding(op, + /*allowTensorPointerType=*/true); + } +}; + +template +class SameLoadStoreOperandsAndResultEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultEncoding( + op, /*allowTensorPointerType=*/true); + } +}; + +} // namespace OpTrait +} // namespace mlir + +#endif diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonAttrDefs.td new file mode 100644 index 000000000..adfeaff6f --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -0,0 +1,121 @@ +#ifndef TRITON_ATTR_DEFS +#define TRITON_ATTR_DEFS + +include "mlir/IR/EnumAttr.td" + +// Attributes for LoadOp and StoreOp +def TT_CacheModifierAttr : I32EnumAttr< + "CacheModifier", "", + [ + I32EnumAttrCase<"NONE", 1, "none">, + I32EnumAttrCase<"CA", 2, "ca">, + I32EnumAttrCase<"CG", 3, "cg">, + I32EnumAttrCase<"WB", 4, "wb">, + I32EnumAttrCase<"CS", 5, "cs">, + I32EnumAttrCase<"WT", 6, "wt">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_MemSemanticAttr : I32EnumAttr< + "MemSemantic", "", + [ + I32EnumAttrCase<"RELAXED", 1, "relaxed">, + I32EnumAttrCase<"ACQUIRE", 2, "acquire">, + I32EnumAttrCase<"RELEASE", 3, "release">, + I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_EvictionPolicyAttr : I32EnumAttr< + "EvictionPolicy", "", + [ + I32EnumAttrCase<"NORMAL", 1, "evict_normal">, + I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">, + I32EnumAttrCase<"EVICT_LAST", 3, "evict_last"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_PaddingOptionAttr : I32EnumAttr< + "PaddingOption", "", + [ + I32EnumAttrCase<"PAD_ZERO", 1, "zero">, + // We can not set the string value to "NAN" because it is a keyword in C++ + I32EnumAttrCase<"PAD_NAN", 2, "nan"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +// atomic +def TT_AtomicRMWAttr : I32EnumAttr< + "RMWOp", "", + [ + I32EnumAttrCase<"AND", 1, "and">, + I32EnumAttrCase<"OR", 2, "or">, + I32EnumAttrCase<"XOR", 3, "xor">, + I32EnumAttrCase<"ADD", 4, "add">, + I32EnumAttrCase<"FADD", 5, "fadd">, + I32EnumAttrCase<"MAX", 6, "max">, + I32EnumAttrCase<"MIN", 7, "min">, + I32EnumAttrCase<"UMAX", 8, "umax">, + I32EnumAttrCase<"UMIN", 9, "umin">, + I32EnumAttrCase<"XCHG", 10, "exch"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_MemSyncScopeAttr : I32EnumAttr< + "MemSyncScope", "", + [ + I32EnumAttrCase<"GPU", 1, "gpu">, + I32EnumAttrCase<"CTA", 2, "cta">, + I32EnumAttrCase<"SYSTEM", 3, "sys">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// Program ID dimensions. +def TT_ProgramDim : I32EnumAttr< + "ProgramIDDim", "", + [ + I32EnumAttrCase<"X", 0, "x">, + I32EnumAttrCase<"Y", 1, "y">, + I32EnumAttrCase<"Z", 2, "z">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// Rounding mode. +def TT_RoundingModeAttr : I32EnumAttr< + "RoundingMode", "", + [ + I32EnumAttrCase<"RTZ", 0, "rtz">, + I32EnumAttrCase<"RTNE", 1, "rtne">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// PropagateNan. +def TT_PropagateNanAttr : I32EnumAttr< + "PropagateNan", "", + [ + I32EnumAttrCase<"NONE", 0, "none">, + I32EnumAttrCase<"ALL", 0xFFFF, "all">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// InputPrecision +def TT_InputPrecisionAttr : I32EnumAttr< + "InputPrecision", "", + [ + I32EnumAttrCase<"TF32", 0, "tf32">, + I32EnumAttrCase<"TF32x3", 1, "tf32x3">, + I32EnumAttrCase<"IEEE", 2, "ieee"> + ]>{ + let cppNamespace = "::mlir::triton"; +} + +#endif diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonDialect.td b/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonDialect.td new file mode 100644 index 000000000..c917538c7 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonDialect.td @@ -0,0 +1,46 @@ +#ifndef TRITON_DIALECT +#define TRITON_DIALECT + +include "mlir/IR/OpBase.td" + +def Triton_Dialect : Dialect { + let name = "tt"; + + let cppNamespace = "::mlir::triton"; + + let summary = "The Triton IR in MLIR"; + + let description = [{ + Triton Dialect. + + Dependent Dialects: + * Arith: + * addf, addi, andi, cmpf, cmpi, divf, fptosi, ... + * Math: + * exp, sin, cos, log, ... + * StructuredControlFlow: + * for, if, while, yield, condition + * ControlFlow: + * br, cond_br + }]; + + let dependentDialects = [ + "arith::ArithDialect", + "math::MathDialect", + "scf::SCFDialect", + "cf::ControlFlowDialect" + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + }]; + + let hasConstantMaterializer = 1; + let useDefaultTypePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +include "triton/Dialect/Triton/IR/TritonTypes.td" + + +#endif // TRITON_DIALECT diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonInterfaces.td b/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonInterfaces.td new file mode 100644 index 000000000..cfc7d0032 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonInterfaces.td @@ -0,0 +1,15 @@ +#ifndef TRITON_INTERFACES +#define TRITON_INTERFACES + +include "mlir/IR/OpBase.td" + +def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">; +def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">; +def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">; +def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">; +def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">; +def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAndResultShape">; +def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">; +def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">; + +#endif // TRITON_INTERFACES diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonOps.td new file mode 100644 index 000000000..7bbde2c72 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonOps.td @@ -0,0 +1,1154 @@ +#ifndef TRITON_OPS +#define TRITON_OPS + +include "triton/Dialect/Triton/IR/TritonDialect.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface +include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface +include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface +include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface +include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" + + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + +// +// Op Base +// +class TT_Op traits = []> : + Op { +} + +// +// Cast Ops +// +// Use cast ops in arith: +// bitcast +// fptoui, fptosi, uitofp, sitofp, +// extf, tructf, +// extui, extsi, tructi +def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast int64 to pointer"; + + let arguments = (ins TT_I64Like:$src); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast pointer to int64"; + + let arguments = (ins TT_PtrLike:$src); + + let results = (outs TT_I64Like:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +// arith.bitcast doesn't support pointers +def TT_BitcastOp : TT_Op<"bitcast", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast between types of the same bitwidth"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + // TODO: Add verifier +} + +def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Floating point casting for custom types"; + + let description = [{ + Floating point casting for custom types (F8), and non-default rounding modes. + + F8 <-> FP16, BF16, FP32, FP64 + }]; + + let arguments = ( + ins TT_FloatTensor:$src, + OptionalAttr:$rounding + ); + + let results = (outs TT_FloatTensor:$result); + + let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)"; + + let hasVerifier = 1; +} + +// +// Arithmetic Ops +// + +def TT_ClampFOp : TT_Op<"clampf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Clamp operation for floating point types"; + + let description = [{ + Clamp operation for floating point types. + + The operation takes three arguments: x, min, and max. It returns a tensor of the same shape as x with its values clamped to the range [min, max]. + }]; + + let arguments = ( + ins + TT_FloatLike:$x, + TT_FloatLike:$min, + TT_FloatLike:$max, + TT_PropagateNanAttr:$propagateNan + ); + + let results = (outs TT_FloatLike:$result); + + // List $propagateNan explicitly rather than relying on attr-dict to pick it + // up, because if it's inside attr-dict, its value will be printed as a + // number rather than as a meaningful string. + let assemblyFormat = "$x `,` $min `,` $max `,` `propagateNan` `=` $propagateNan attr-dict `:` type($result)"; +} + +// +// Math Ops +// + +def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Precise sqrt for floating point types"; + + let description = [{ + Precise sqrt for floating point types. + }]; + + let arguments = (ins TT_FloatLike:$x); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x attr-dict `:` type($x)"; +} + +def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Precise div for floating point types"; + + let description = [{ + Precise div for floating point types. + }]; + + let arguments = (ins TT_FloatLike:$x, TT_FloatLike:$y); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; +} + +def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Most significant N bits of the 2N-bit product of two integers"; + + let description = [{ + Most significant N bits of the 2N-bit product of two integers. + }]; + + let arguments = (ins TT_IntLike:$x, TT_IntLike:$y); + + let results = (outs TT_IntLike:$result); + + let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; +} + +// +// Pointer Arith Ops +// +def TT_AddPtrOp : TT_Op<"addptr", + [Pure, + Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { + let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)"; +} + +def TT_AdvanceOp : TT_Op<"advance", + [Pure, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { + let summary = "Advance a tensor pointer by offsets"; + + let arguments = (ins TT_TensorPtr:$ptr, Variadic:$offsets); + + let results = (outs TT_TensorPtr:$result); + + let assemblyFormat = "$ptr `,` `[` $offsets `]` attr-dict `:` type($result)"; +} + +// +// Load/Store Ops +// +def TT_LoadOp : TT_Op<"load", [ + // SameLoadStoreOperandsAndResultShape, + // SameLoadStoreOperandsAndResultEncoding, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result matches ptr type", "ptr", "result", "getPointeeType($_self)">, + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))", + // "($_op.getOperands().size() <= 1) || std::equal_to<>()">, + "($_op.getOperands().size() != 2) || std::equal_to<>()">, + TypesMatchWith<"other matches ptr type", "ptr", "other", "getPointeeType($_self)", + // "($_op.getOperands().size() <= 2) || std::equal_to<>()"> + "($_op.getOperands().size() != 3) || std::equal_to<>()"> +]> { + let summary = "Load from a tensor of pointers or from a tensor pointer"; + + let arguments = ( + ins + AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, + Optional:$mask, + Optional:$other, + + DefaultValuedAttr{}">:$boundaryCheck, + OptionalAttr:$padding, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile, + Optional:$inputStride, + Optional:$placeHolder0, Optional:$placeHolder1 + ); + + let results = (outs TT_Type:$result); + + let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor pointer with boundary check and padding + OpBuilder<(ins "Value":$ptr, "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor of pointers or a pointer to a scalar with mask and other + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A utility function to build the operation with all attributes + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, + "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)> + ]; + + // // Specify `cacheModifier` and `evictionPolicy` explicitly in the + // // assemblyFormat instead of as part of attr-dict so that they get printed + // // as strings rather than opaque integers. + + // // Note there's no comma between `other` and `cacheModifier` and between + // // `cacheModifier` and `evictionPolicy`. This is due to an apparent + // // limitation in the MLIR custom-format parser. In oilist, the initial + // // keywords of each clause have to be unique, so they can't be `,`. + + // // Even if we gave up on order-independence and used vanilla optional + // // clauses, the format (`,` `foo` `=` $foo^)? (`,` `bar` `=` $bar^)? will + // // not match the string ", bar = 0" because after the initial comma (first + // // token of the first optional clause) we expect to see "foo". + // let assemblyFormat = [{ + // $ptr (`,` $mask^)? (`,` $other^)? + // oilist( + // `cacheModifier` `=` $cache | + // `evictionPolicy` `=` $evict + // ) + // attr-dict `:` type($ptr) + // }]; + let hasCustomAssemblyFormat = 1; + + let hasCanonicalizer = 1; +} + +def TT_StoreOp : TT_Op<"store", [ + SameLoadStoreOperandsShape, + SameLoadStoreOperandsEncoding, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"value type matches ptr type", "ptr", "value", + "getPointeeType($_self)">, + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", + "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "Store by a tensor of pointers or by a tensor pointer"; + + let arguments = ( + ins + AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, + TT_Type:$value, + Optional:$mask, + DefaultValuedAttr{}">:$boundaryCheck, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$value, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)>, + // A tensor pointer with boundary check + OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef":$boundaryCheck, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)> + ]; + + // Specify cacheModifier and evictionPolicy explicitly, instead of leaving + // them in attr-dict, because this way their values get printed as strings, + // rather than as opaque integers. + // + // Note there are no commas between mask, cacheModifier, and evictionPolicy, + // due to limitations in MLIR's asm parser. + let assemblyFormat = [{ + $ptr `,` $value (`,` $mask^)? + oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + attr-dict `:` type($ptr) + }]; + + let hasCanonicalizer = 1; +} + +// +// Atomic Ops +// +def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [ + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"ptr type matches value type", "val", "ptr", + "getPointerTypeSameShape($_self)">, + TypesMatchWith<"mask type matches value type", + "val", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "atomic rmw"; + + let description = [{ + load data at $ptr, do $rmw_op with $val, and store result to $ptr. + + return old value at $ptr + }]; + + let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr, + TT_Type:$val, Optional:$mask, + TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); + + let results = (outs TT_Type:$result); + + // Explicitly list $atomic_rmw_op, $sem, and $scope rather than relying on + // attr-dict so they're printed as strings rather than opaque integers. + let assemblyFormat = [{ + $atomic_rmw_op `,` $sem `,` $scope `,` $ptr `,` $val (`,` $mask^)? attr-dict `:` + functional-type(operands, $result) + }]; +} + +def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding]> { + let summary = "atomic cas"; + + let description = [{ + compare $cmp with data $old at location $ptr, + + if $old == $cmp, store $val to $ptr, + + else store $old to $ptr, + + return $old + }]; + + let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val, + TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); + + let results = (outs TT_Type:$result); + + // Explicitly list $sem and $scope rather than relying on attr-dict so + // they're printed as strings rather than opaque integers. + let assemblyFormat = [{ + $sem `,` $scope `,` $ptr `,` $cmp `,` $val attr-dict `:` + functional-type(operands, $result) + }]; +} + +// +// Shape Manipulation Ops +// +def TT_SplatOp : TT_Op<"splat", [Pure, + SameOperandsAndResultElementType, + SameOperandsAndResultEncoding]> { + let summary = "splat"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasFolder = 1; +} + +def TT_ExpandDimsOp : TT_Op<"expand_dims", [Pure, + DeclareOpInterfaceMethods, + SameOperandsAndResultElementType]> { + let summary = "expand_dims"; + + let arguments = (ins TT_Tensor:$src, I32Attr:$axis); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizeMethod = 1; + let hasFolder = 1; +} + +def TT_ReshapeOp : TT_Op<"reshape", [Pure, + SameOperandsAndResultElementType]> { + let summary = "reinterpret a tensor to a different shape. It may change elements order if the attribute is set."; + let description = [{ + reinterpret a tensor to a different shape. + + If allow_reorder is set the compiler is free to change the order of + elements to generate more efficient code. + + If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason. + The compiler is still free to change it for better performance. + }]; + let arguments = (ins TT_Tensor:$src, BoolAttr:$allow_reorder, OptionalAttr:$efficient_layout); + let results = (outs TT_Tensor:$result); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + let hasCanonicalizeMethod = 1; + let hasFolder = 1; + let hasVerifier = 1; + let builders = [ + OpBuilder<(ins "Type":$type, "Value":$src, "bool":$allow_reorder), + [{ + build($_builder, $_state, type, src, allow_reorder, /*efficient_layout=*/UnitAttr()); + }]>]; +} + +def TT_BroadcastOp : TT_Op<"broadcast", [Pure, + SameOperandsAndResultElementType, + SameOperandsAndResultEncoding]> { + let summary = "broadcast a tensor"; + + let description = [{ + For a given tensor, broadcast changes one or more dimensions with size 1 + to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot + change the size of a non-1 dimension. + }]; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizeMethod = 1; + let hasFolder = 1; +} + +// cat is not `pure` because it may reorder elements +def TT_CatOp : TT_Op<"cat", [NoMemoryEffect, + SameTypeOperands, + SameOperandsAndResultElementType]> { + let summary = "concatenate 2 tensors"; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; +} + +def TT_JoinOp : TT_Op<"join", [ + NoMemoryEffect, SameTypeOperands, + DeclareOpInterfaceMethods, +]> { + let summary = "join two tensors along a new, minor dimension"; + let description = [{ + For example, if the two input tensors are 4x8xf32, returns a tensor of + shape 4x8x2xf32. + + Because Triton tensors always have a power-of-two number of elements, + the two input tensors must have the same shape. + }]; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + let results = (outs TT_Tensor:$result); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; +} + +def TT_SplitOp : TT_Op<"split", [ + NoMemoryEffect, + DeclareOpInterfaceMethods, + TypesMatchWith<"outLHS and outRHS types match", + "outLHS", "outRHS", "$_self">, +]> { + let summary = "splits a tensor into two, along its last dimension"; + let description = [{ + The input must be a tensor whose last dimension has size 2. Returns two + tensors, src[..., 0] and src[..., 1]. + + For example, if the input shape is 4x8x2xf32, returns two tensors of + shape 4x8xf32. + }]; + + let arguments = (ins TT_Tensor:$src); + let results = (outs TT_Tensor:$outLHS, TT_Tensor:$outRHS); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($outLHS)"; +} + +def TT_TransOp : TT_Op<"trans", [Pure, + DeclareOpInterfaceMethods, + SameOperandsAndResultElementType]> { + + let summary = "rearrange the dimensions of a tensor"; + let description = [{ + For example, given a tensor x with shape [1,2,4], transpose(x) with + order=[2,0,1] rearranges the tensor to have shape [4,1,2]. + + Although this op is called "trans", it implements both tl.trans() and + tl.permute(). ("permute" might be a better name, but it's called "trans" + because originally it only supported 2D tensors.) + + ## Implementation note on encodings: + + In the TritonGPU dialect (and probably others), an encoding is chosen for + this op's output so it's a nop from the perspective of code generation. + + For example, suppose tensor x has an encoding such that GPU thread [i,j,k] + has a register containing element [i,j,k] of the tensor. Now we transpose + x with order [2,1,0], i.e. we reverse the order of its dimensions. In + TritonGPU, we will choose a layout for the output of the transpose so that + GPU thread [i,j,k] has element [k,j,i] of transpose(x). But this is the + same element it had before! All we've done is "rename" the element that + thread [i,j,k] has. + + The "real" transpose -- i.e. moving data between GPU threads -- occurs in + convertLayout ops that appear before and/or after the operation. + + We do this so that you can chain multiple data-movement ops (e.g. + transpose+reshape+concat) without going to shared memory after each one. + }]; + + let arguments = ( + ins TT_TensorOrMemDesc:$src, + DenseI32ArrayAttr:$order + ); + + let results = (outs TT_TensorOrMemDesc:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasFolder = 1; + let hasVerifier = 1; +} + +// +// SPMD Ops +// +def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> { + let arguments = (ins TT_ProgramDim:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "$axis attr-dict `:` type($result)"; + + let extraClassDeclaration = [{ + int32_t getAxisAsInt() { + return static_cast(getAxis()); + } + }]; +} + +def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> { + let arguments = (ins TT_ProgramDim:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "$axis attr-dict `:` type($result)"; + + let extraClassDeclaration = [{ + int32_t getAxisAsInt() { + return static_cast(getAxis()); + } + }]; +} + +// +// Dot Op +// +def TT_DotOp : TT_Op<"dot", [Pure, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot"; + + let description = [{ + $d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TC + when the inputs are f32. It can be one of: tf32, tf32x3, ieee. + tf32: use TC with tf32 ops. + tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp + ieee: don't use TC, implement dot in software. + If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored. + }]; + + let arguments = ( + ins + TT_TensorOrMemDesc:$a, + TT_TensorOrMemDesc:$b, + TT_FpIntTensor:$c, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc + ); + + let results = (outs TT_FpIntTensor:$d); + + // attr-dict prints enums as integers. To get inputPrecision printed as a + // string, we need to specify it explicitly. + let assemblyFormat = [{ + $a`,` $b`,` $c (`,` `inputPrecision` `=` $inputPrecision^)? attr-dict `:` + type($a) `*` type($b) `->` type($d) + }]; + let hasVerifier = 1; +} + +// +// Reduce Op +// +def TT_ReduceOp: TT_Op<"reduce", + [Pure, + SameOperandsEncoding, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Reduction using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis, BoolAttr:$noWarpReduce); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$srcs, "int":$axis, "bool":$noWarpReduce)>, + ]; + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + }]; +} + +def TT_ReduceReturnOp: TT_Op<"reduce.return", + [HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for reduce operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + +// +// Scan Op +// +def TT_ScanOp: TT_Op<"scan", + [Pure, + SameOperandsAndResultEncoding, + SameOperandsAndResultShape, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Associative scan using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis, BoolAttr:$reverse); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$srcs, "int":$axis, "bool":$reverse)>, + ]; + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + }]; +} + +def TT_ScanReturnOp: TT_Op<"scan.return", + [HasParent<"ScanOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for scan operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + + +// +// External Elementwise op +// +def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise, + SameOperandsAndResultEncoding, + SameVariadicOperandSize, + DeclareOpInterfaceMethods]> { + + let description = [{ + call an external function $symbol implemented in $libpath/$libname with $args + return $libpath/$libname:$symbol($args...) + }]; + + let arguments = (ins Variadic:$srcs, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; +} + +// +// Make Range Op +// +def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> { + let summary = "make range"; + + let description = [{ + Returns an 1D int32 tensor. + + Values span from $start to $end (exclusive), with step = 1 + }]; + + // WARNING: MLIR generates getStart()/getEnd() functions which return + // uint32_t, even though these arguments are to be interpreted as *signed* + // int32 values. If this matters, use get{Start,End}Attr().getInt(), which + // return int64_t. + let arguments = (ins I32Attr:$start, I32Attr:$end); + + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = "attr-dict `:` type($result)"; + + let hasFolder = 1; + let hasVerifier = 1; +} + +// +// ElementwiseInlineAsm Op +// +def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [ + Elementwise, + SameOperandsAndResultEncoding, + DeclareOpInterfaceMethods +]> { + let summary = "inline assembly applying an elementwise operation to a group of packed elements."; + let description = [{ + Runs an inline asm block to generate one or more tensors. + + The asm block is given `packed_element` elements at a time. Exactly which + elems it receives is unspecified. + }]; + + let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic>:$args); + let results = (outs Variadic:$result); + + let assemblyFormat = [{ + $asm_string attr-dict ($args^ `:` type($args))? `->` type($result) + }]; + + let hasVerifier = 1; +} + +// +// Histogram Op +// +def TT_HistogramOp : TT_Op<"histogram", [Pure]> { + let summary = "return a histgram of the inputs."; + let description = [{ + Return the histogram of the input tensor. The number of bins is equal to + the dimension of the output tensor. Each bins has a width of 1 and bins + start at 0. + }]; + + let arguments = (ins TT_IntTensor:$src); + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = [{ + $src attr-dict `:` type($src) `->` type($result) + }]; +} + +// +// Print Op +// +def TT_PrintOp : TT_Op<"print", [MemoryEffects<[MemWrite]>]>, + Arguments<(ins StrAttr:$prefix, BoolAttr:$hex, Variadic>:$args)> { + let summary = "Device-side print, as in CUDA for debugging"; + let description = [{ + `tt.print` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed. + format are generated automatically from the arguments. + }]; + let assemblyFormat = [{ + $prefix attr-dict (`:` $args^ `:` type($args))? + }]; +} + +// +// Assert Op +// +def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> { + let summary = "Device-side assert, as in CUDA for correctness checking"; + let description = [{ + `tt.assert` takes a condition tensor, a message string, a file string, a function string, and a line number. + If the condition is false, the message is printed, and the program is aborted. + }]; + let arguments = (ins TT_Tensor:$condition, StrAttr:$message, StrAttr:$file, StrAttr:$func, I32Attr:$line); + let assemblyFormat = "$condition `,` $message `,` $file `,` $func `,` $line attr-dict `:` type($condition)"; +} + +// +// Make Tensor Pointer Op +// +def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr", + [Pure, + SameVariadicOperandSize, + TypesMatchWith<"infer pointer type from the result type", + "result", "base", + "getPointerType(getElementTypeOfTensorPointerType($_self))">]> { + let summary = "Make a tensor pointer type with meta information of the parent tensor and the block specified"; + + let description = [{ + `tt.make_tensor_ptr` takes both meta information of the parent tensor and the block tensor, then it returns a + pointer to the block tensor, e.g. returns a type of `tt.ptr>`. + }]; + + // TODO(Chenggang): unify the integer types. Currently we cannot do that due to hardware constraints. + let arguments = (ins + TT_Ptr:$base, + Variadic:$shape, + Variadic:$strides, + Variadic:$offsets, + DenseI32ArrayAttr:$order + ); + + let results = (outs TT_TensorPtr:$result); + + // TODO(Keren): define a custom assembly format for this op because the result type cannot be printed correctly + // Add additional `[]` to increase readability and split variadic lists + let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)"; + + let builders = [ + OpBuilder<(ins + "Value":$base, + "ValueRange":$shape, + "ValueRange":$strides, + "ValueRange":$offsets, + "ArrayRef":$tensorShape, + "ArrayRef":$order + )> + ]; +} + +// The following ops, including `call`, `func`, and `return` are copied and modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +// We could revert it back once MLIR has a better inliner interface. +// +// Function Ops +// +def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpInterfaceMethods]> { + let summary = "call operation"; + let description = [{ + The `tt.call` operation represents a direct call to a function that is + within the same symbol scope as the call. The operands and result types of + the call must match the specified function type. The callee is encoded as a + symbol reference attribute named "callee". + + Example: + + ```mlir + %2 = tt.call @my_add(%0, %1) : (f32, f32) -> f32 + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", SymbolRefAttr::get(callee)); + $_state.addTypes(callee.getFunctionType().getResults()); + }]>, + OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", callee); + $_state.addTypes(results); + }]>, + OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, SymbolRefAttr::get(callee), results, operands); + }]>, + OpBuilder<(ins "StringRef":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), + results, operands); + }]>]; + + let extraClassDeclaration = [{ + FunctionType getCalleeType() { + return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); + } + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + /// Set the callee for this operation. + void setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); + } + + // Required by CallOpInterface. + MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + +def FuncOp : TT_Op<"func", [AffineScope, AutomaticAllocationScope, CallableOpInterface, FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface]> { + let summary = "An operation with a name containing a single `SSACFG` region"; + let description = [{ + Operations within the function cannot implicitly capture values defined + outside of the function, i.e. Functions are `IsolatedFromAbove`. All + external references must use function arguments or attributes that establish + a symbolic connection (e.g. symbols referenced by name via a string + attribute like SymbolRefAttr). An external function declaration (used when + referring to a function declared in some other module) has no body. While + the MLIR textual form provides a nice inline syntax for function arguments, + they are internally represented as “block arguments” to the first block in + the region. + + Only dialect attribute names may be specified in the attribute dictionaries + for function arguments, results, or the function itself. + + Example: + + ```mlir + // External function definitions. + tt.func @abort() + tt.func @scribble(i32, i64, memref) -> f64 + + // A function that returns its argument twice: + tt.func @count(%x: i64) -> (i64, i64) + attributes {fruit: "banana"} { + return %x, %x: i64, i64 + } + + // A function with an argument attribute + tt.func @example_fn_arg(%x: i32 {swift.self = unit}) + + // A function with a result attribute + tt.func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64}) + + // A function with an attribute + tt.func @example_fn_attr() attributes {dialectName.attrName = false} + ``` + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs, + CArg<"ArrayRef", "{}">:$argAttrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // CallableOpInterface + //===------------------------------------------------------------------===// + + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } + + /// Returns the results types that the callable region produces when + /// executed. + ArrayRef getCallableResults() { return getFunctionType().getResults(); } + + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or + /// null if there are none. + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } + + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + bool isDeclaration() { return isExternal(); } + }]; + let hasCustomAssemblyFormat = 1; +} + +def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable, */ReturnLike, Terminator]> { + let summary = "Function return operation"; + let description = [{ + The `tt.return` operation represents a return operation within a function. + The operation takes variable number of operands and produces no results. + The operand number and types must match the signature of the function + that contains the operation. + + Example: + + ```mlir + tt.func @foo() : (i32, f8) { + ... + tt.return %0, %1 : i32, f8 + } + ``` + }]; + + let arguments = (ins Variadic:$srcs); + + let builders = [OpBuilder<(ins), [{ + build($_builder, $_state, std::nullopt); + }]>]; + + let assemblyFormat = "attr-dict ($srcs^ `:` type($srcs))?"; + let hasVerifier = 1; +} + + +def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [ + MemoryEffects<[MemRead]>]> { + let summary = "Load from descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA load operation on targets supporting it. + `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + The destination tensor type and shape must match the descriptor otherwise the result is undefined. + + This is an escape hatch and is only there for testing/experimenting. + This op will be removed in the future. + }]; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + Variadic:$indices, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $desc_ptr `[` $indices `]` + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` qualified(type($desc_ptr)) `->` type($result) + }]; +} + +def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [ + MemoryEffects<[MemWrite]>]> { + let summary = "store value based on descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA store operation on targets supporting it. + `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + The shape and types of `src` must match the descriptor otherwise the result is undefined. + + This is an escape hatch and is only there for testing/experimenting. + This op will be removed in the future. + }]; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + TT_Tensor:$src, + Variadic:$indices + ); + + let assemblyFormat = [{ + $desc_ptr `[` $indices `]` `,` $src + attr-dict `:` qualified(type($desc_ptr)) `,` type($src) + }]; +} + +#endif // Triton_OPS diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td b/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td new file mode 100644 index 000000000..e3aed2262 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td @@ -0,0 +1,24 @@ +#ifndef TRITON_TYPE_INTERFACES +#define TRITON_TYPE_INTERFACES + +include "mlir/IR/OpBase.td" + +// Interface dynamically attached to RankedTensorType and MemDescType. +def TT_TensorOrMemDesc : TypeInterface<"TensorOrMemDesc"> { + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod<"Returns the encoding of the tensor or memory descriptor", + "mlir::Attribute", "getEncoding", (ins)>, + InterfaceMethod<"Returns element type", + "mlir::Type", "getElementType", (ins)>, + InterfaceMethod<"Returns the type shape", + "llvm::ArrayRef", "getShape", (ins)>, + InterfaceMethod<"Returns the tensor or buffer rank", + "int64_t", "getRank", (ins)>, + InterfaceMethod<"Returns the element type bit width", + "int64_t", "getElementTypeBitWidth", (ins)>, + + ]; +} + +#endif // TRITON_TYPE_INTERFACES diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonTypes.td b/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonTypes.td new file mode 100644 index 000000000..fd5af9cc8 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -0,0 +1,140 @@ +#ifndef TRITON_TYPES +#define TRITON_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "triton/Dialect/Triton/IR/TritonDialect.td" + +// +// Types +// +class TritonTypeDef traits = []> + : TypeDef { + // Used by printer/parser + let mnemonic = _mnemonic; +} + +// Floating-point Type +def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; +def TT_FloatTensor : RankedTensorOf<[TT_Float]>; +def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>; + +// Boolean Type +// TT_Bool -> I1 +def TT_BoolTensor : RankedTensorOf<[I1]>; +def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>; + +// Integer Type +def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">; +def TT_IntTensor : RankedTensorOf<[TT_Int]>; +def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>; + +// I32 Type +// TT_I32 -> I32 +// TT_I32Tensor -> I32Tensor +def TT_I32Like : AnyTypeOf<[I32, I32Tensor]>; + +// I64 Type +// TT_I64 -> I64 +// TT_I64Tensor -> I64Tensor +def TT_I64Like : AnyTypeOf<[I64, I64Tensor]>; + +// Pointer Type in TableGen +class TT_PtrOf pointeeTypes> : + DialectType($_self)">, + Concat<"[](::mlir::Type pointeeType) { return ", + SubstLeaves<"$_self", "pointeeType", AnyTypeOf.predicate>, + "; }(::mlir::cast<::mlir::triton::PointerType>($_self).getPointeeType())">]>, + "ptr", "::mlir::triton::PointerType">; + +// Pointer Type in C++ (corresponding to `TT_PtrOf`) +def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> { + let summary = "Pointer type (`::mlir::triton::PointerType`) in Triton IR type system"; + + let description = [{ + Pointer type in Triton IR type system, which could be pointing to scalars or tensors. + }]; + + let parameters = (ins "Type":$pointeeType, "int":$addressSpace); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "Type":$pointeeType, + "int":$addressSpace + ), [{ + return $_get(pointeeType.getContext(), pointeeType, addressSpace); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +// Scalar Pointer Type: `ptr<>` +def TT_Ptr : TT_PtrOf<[AnyType]>; + +// Tensor of Pointer Type: `tensor>` +def TT_PtrTensor : RankedTensorOf<[TT_Ptr]>; + +// Tensor of Pointer Type or Pointer type: `tensor>` or `ptr<>` +def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>; + +// Tensor Type +def TT_FpIntTensor : RankedTensorOf<[TT_Float, TT_Int]>; +def TT_Tensor : RankedTensorOf<[TT_Float, TT_Int, TT_Ptr]>; + +// Pointer Type to Tensor Type: `ptr>` +def TT_TensorPtr : TT_PtrOf<[TT_Tensor]>; + +// Any Type in Triton IR +def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike, TT_TensorPtr]>; + +// Memory descriptor type. +def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> { + let summary = "memory descriptor type (`::mlir::triton::MemDescType`) in Triton IR type system"; + + let description = [{ + Memory descriptor contains a base pointer (scalar) and a descriptor of the memory. + If mutable memory is false that means the memory is constant and can only be allocated and stored once. + A constant memory allocation is different than a tensor as it can have multiple views and the descriptor + can be changed without changing the underlying memory. + }]; + + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "Type":$elementType, + "Attribute":$encoding, + "bool":$mutable_memory + ); + let extraClassDeclaration = [{ + MemDescType cloneWith(std::optional> shape, + Type elementType) const { + return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding()); + } + + bool hasRank() const { return true; } + }]; + let builders = [ + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, /*mutableMemory=*/false); + }]>, + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding, + "bool":$mutableMemory + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, mutableMemory); + }]> + ]; + let hasCustomAssemblyFormat = 1; +} + + +#endif diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/IR/Types.h b/third_party/iluvatar/include/triton/Dialect/Triton/IR/Types.h new file mode 100644 index 000000000..bf1967f1b --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/IR/Types.h @@ -0,0 +1,39 @@ +#ifndef TRITON_IR_TYPES_H_ +#define TRITON_IR_TYPES_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/Types.h.inc" + +#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.h.inc" + +namespace mlir { + +namespace triton { + +bool isTensorPointerType(Type type); + +bool isTensorOrTensorPointerType(Type type); + +unsigned getPointeeBitWidth(Type type); + +Type getPointeeType(Type type); + +Type getPointerType(Type type); + +Type getElementTypeOfTensorPointerType(Type type); + +Type getI1SameShape(Type type); + +Type getI32SameShape(Type type); + +Type getPointerTypeSameShape(Type type); + +} // namespace triton + +} // namespace mlir + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/IR/Utility.h b/third_party/iluvatar/include/triton/Dialect/Triton/IR/Utility.h new file mode 100644 index 000000000..0ef597147 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/IR/Utility.h @@ -0,0 +1,190 @@ +#ifndef TRITON_IR_UTILITY_H_ +#define TRITON_IR_UTILITY_H_ + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include +#include + +namespace mlir { + +template SmallVector convertType(ArrayRef in) { + SmallVector out; + for (const auto &i : in) + out.push_back(T(i)); + return out; +} + +template +SmallVector convertType(const VecU &in) { + return convertType(ArrayRef(in)); +} + +template Int product(llvm::ArrayRef arr) { + return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{}); +} +template auto product(const VecT &vec) { + return product(llvm::ArrayRef(vec)); +} + +// TODO(jlebar): Rename to ceilOfRatio. +template Int ceil(Int m, Int n) { return (m + n - 1) / n; } + +/// Get the highest power of 2 divisor of an integer. +template T highestPowOf2Divisor(T n) { + if (n == 0) { + return (static_cast(1) << (sizeof(T) * 8 - 2)); + } + return (n & (~(n - 1))); +} + +/// Get the next power of 2 for an integer (or the integer itself if it is a +/// power of 2). +template T nextPowOf2(T n) { + if (n == 0) { + return 1; + } + n--; + for (unsigned i = 1; i < sizeof(T) * 8; i <<= 1) { + n |= n >> i; + } + return n + 1; +} + +namespace triton { + +// Many functions here have two overloads, fn(ArrayRef) and fn(const VecT&). +// This is helpful because C++ won't both convert a vector to ArrayRef *and* +// infer the proper type T in one step. So without the second overload, we +// would have to explicitly convert most arguments to ArrayRef at the callsite. + +template +SmallVector applyPermutation(ArrayRef vec, ArrayRef permutation) { + static_assert(std::is_integral_v); + assert(vec.size() == permutation.size()); + + // Check that `permutation` is actually a permutation. +#ifndef NDEBUG + SmallVector sortedPerm(permutation); + llvm::sort(sortedPerm); + for (U i = 0; i < static_cast(sortedPerm.size()); i++) { + assert(sortedPerm[i] == i); + } +#endif + + SmallVector ret; + ret.reserve(vec.size()); + for (const U &i : permutation) { + ret.push_back(vec[i]); + } + return ret; +} + +template +auto applyPermutation(const VecT &vec, const PermT &permutation) { + return applyPermutation(ArrayRef(vec), ArrayRef(permutation)); +} + +template +[[nodiscard]] SmallVector inversePermutation(ArrayRef permutation) { + // Check that `permutation` is actually a permutation. +#ifndef NDEBUG + SmallVector sortedPerm(permutation); + llvm::sort(sortedPerm); + for (int i = 0; i < sortedPerm.size(); ++i) { + assert(sortedPerm[i] == i); + } +#endif + + SmallVector ret(permutation.size()); + for (int i = 0; i < permutation.size(); ++i) { + ret[permutation[i]] = i; + } + return ret; +} + +template +[[nodiscard]] auto inversePermutation(const VecT &permutation) { + return inversePermutation(ArrayRef(permutation)); +} + +template +[[nodiscard]] SmallVector gather(ArrayRef elems, ArrayRef indices) { + SmallVector ret; + ret.reserve(indices.size()); + for (const U &i : indices) { + ret.push_back(elems[i]); + } + return ret; +} + +template +[[nodiscard]] auto gather(const VecT &elems, const IdxT &indices) { + return gather(ArrayRef(elems), ArrayRef(indices)); +} + +// Is `vec` [0, 1, ..., n]? Returns true on empty list. +template bool isIota(ArrayRef vec) { + static_assert(std::is_integral_v); + for (T i = 0; i < vec.size(); ++i) { + if (vec[i] != i) { + return false; + } + } + return true; +} + +template bool isIota(const VecT &vec) { + return isIota(ArrayRef(vec)); +} + +// Is `vals` some permutation of the numbers 0..(vals.size()-1)? +template bool isPermutationOfIota(ArrayRef vals) { + SmallVector sorted(vals); + llvm::sort(sorted); + return isIota(sorted); +} + +template bool IsPermutationOfIota(const VecT &vec) { + return isPermutationOfIota(ArrayRef(vec)); +} + +// Is `vec` [i, i+1, ..., i+n]? Returns true on empty list. +template bool isConsecutive(ArrayRef vec) { + static_assert(std::is_integral_v); + for (int i = 1; i < vec.size(); i++) { + if (vec[i] != vec[i - 1] + 1) { + return false; + } + } + return true; +} + +template bool isConsecutive(const VecT &vec) { + return isConsecutive(ArrayRef(vec)); +} + +// LLVM's STLExtras.h provides a bunch of functions that work over ranges, but +// it's missing min/max_element until +// https://github.com/llvm/llvm-project/commit/fab2bb8b makes it into Triton. +// TODO(jlebar): Remove this once we have the LLVM helpers. +template auto min_element(R &&Range) { + return std::min_element(llvm::adl_begin(Range), llvm::adl_end(Range)); +} +template +auto min_element(R &&Range, Compare &&C) { + return std::min_element(llvm::adl_begin(Range), llvm::adl_end(Range), + std::forward(C)); +} +template auto max_element(R &&Range) { + return std::max_element(llvm::adl_begin(Range), llvm::adl_end(Range)); +} +template +auto max_element(R &&Range, Compare &&C) { + return std::max_element(llvm::adl_begin(Range), llvm::adl_end(Range), + std::forward(C)); +} + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/Transforms/CMakeLists.txt b/third_party/iluvatar/include/triton/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 000000000..372a9ec11 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton) +add_public_tablegen_target(TritonTransformsIncGen) diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/Transforms/Passes.h b/third_party/iluvatar/include/triton/Dialect/Triton/Transforms/Passes.h new file mode 100644 index 000000000..fde54fe17 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/Transforms/Passes.h @@ -0,0 +1,21 @@ +#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { + +std::unique_ptr createCombineOpsPass(); + +std::unique_ptr createReorderBroadcastPass(); +std::unique_ptr createRewriteTensorPointerPass(); + +} // namespace triton + +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +} // namespace mlir + +#endif diff --git a/third_party/iluvatar/include/triton/Dialect/Triton/Transforms/Passes.td b/third_party/iluvatar/include/triton/Dialect/Triton/Transforms/Passes.td new file mode 100644 index 000000000..4ebff63fa --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/Triton/Transforms/Passes.td @@ -0,0 +1,44 @@ +#ifndef TRITON_PASSES +#define TRITON_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonCombineOps : Pass { + let summary = "combine ops"; + let description = [{ + dot(a, b, 0) + c => dot(a, b, c) + + addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1)) + + select(cond, load(ptrs, broadcast(cond), ???), other) => + load(ptrs, broadcast(cond), other) + }]; + + let constructor = "mlir::triton::createCombineOpsPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect"]; +} + +def TritonReorderBroadcast : Pass { + let summary = "Moves broadcast and splat after elementwise operations"; + let description = [{ + elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) + }]; + let constructor = "mlir::triton::createReorderBroadcastPass()"; + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + +def TritonRewriteTensorPointer : Pass { + let summary = "Rewrite load/stores with tensor pointers into legacy load/stores"; + let description = [{ + This pass rewrites all load/store semantics initiated by a `tt.make_tensor_ptr` and `tt.advance` into legacy + semantics. After this pass, `tt.make_tensor_ptr` and `tt.advance` will disappear, and it generates logics to compute + the pointer/mask/other for each load/store. + }]; + + let constructor = "mlir::triton::createRewriteTensorPointerPass()"; + + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + +#endif diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/CMakeLists.txt b/third_party/iluvatar/include/triton/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/Attributes.h b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/Attributes.h new file mode 100644 index 000000000..a99ddfc17 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/Attributes.h @@ -0,0 +1,10 @@ +#ifndef TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ +#define TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ + +#include "mlir/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc" + +#endif // TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..73c9401c1 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt @@ -0,0 +1,21 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu) +add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(TritonGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td) +mlir_tablegen(TritonGPUAttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(TritonGPUAttrInterfaces.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonGPUAttrDefsIncGen) diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/Dialect.h b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/Dialect.h new file mode 100644 index 000000000..fb8b93a9f --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -0,0 +1,132 @@ +#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" + +// TritonGPU depends on Triton +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" +#include "triton/Dialect/TritonGPU/IR/Types.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonGPU/IR/Ops.h.inc" + +namespace mlir { +namespace triton { +namespace gpu { + +struct SharedMemory : public SideEffects::Resource::Base { + StringRef getName() final { return ""; } +}; + +unsigned getTotalElemsPerThread(Type type); + +unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape, + Type eltTy); + +SmallVector getElemsPerThread(Type type); + +// Returns the number of threads per warp that may have access to replicated +// elements. If you want non-replicated threads, use +// getThreadsPerWarpWithUniqueData. +SmallVector getThreadsPerWarp(Attribute layout); + +unsigned getWarpSize(Attribute layout); + +// Returns the number of warps per CTA that may have access to replicated +// elements. If you want non-replicated warps, use getWarpsPerCTAWithUniqueData. +SmallVector getWarpsPerCTA(Attribute layout); + +SmallVector getSizePerThread(Attribute layout); + +// Returns the number of contiguous elements that each thread +// has access to, on each dimension of the tensor. E.g. +// for a blocked layout with sizePerThread = [1, 4], returns [1, 4], +// regardless of the shape of the tensor. +SmallVector getContigPerThread(Attribute layout); + +// Returns the number of non-replicated contiguous elements that each thread +// has access to, on each dimension of the tensor. For a blocked layout +// with sizePerThread = [1, 4] and tensor shape = [128, 1], the elements +// for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], returns [1, +// 1]. Whereas for a tensor shape [128, 128], the elements for thread 0 would be +// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], returns [1, 4]. +SmallVector getUniqueContigPerThread(Attribute layout, + ArrayRef tensorShape); + +// Returns the number of threads per warp that have access to non-replicated +// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1, +// 1], threadsPerWarp = [2, 16] and tensor shape = [2, 2], threads 0, 1, 16, 17 +// have access to the full tensor, whereas the other threads have access to +// replicated elements, so this function returns [2, 2]. +SmallVector +getThreadsPerWarpWithUniqueData(Attribute layout, + ArrayRef tensorShape); + +// Returns the number of warps per CTA that have access to non-replicated +// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1, +// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2], +// returns [1, 1], since the first warp has access to the full tensor, whereas +// the other warps have access to replicated elements. +SmallVector +getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape); + +SmallVector getWarpOrder(Attribute layout); + +SmallVector getOrder(Attribute layout); + +CTALayoutAttr getCTALayout(Attribute layout); + +SmallVector getCTAsPerCGA(Attribute layout); + +SmallVector getCTASplitNum(Attribute layout); + +SmallVector getCTAOrder(Attribute layout); + +/* The difference between ShapePerCTATile and ShapePerCTA: + * (1) ShapePerCTATile is defined by SizePerThread * ThreadsPerWarp * + * WarpsPerCTA in each dimension and is independent from the tensor shape. + * (2) ShapePerCTA is defined by shape / CTASplitNum in each dimension. + * (3) In the implementation of emitIndices, ShapePerCTATile will + * be replicated or wrapped to fit ShapePerCTA. + */ +SmallVector +getShapePerCTATile(Attribute layout, + ArrayRef tensorShape = ArrayRef()); + +SmallVector getShapePerCTA(ArrayRef CTASplitNum, + ArrayRef shape); +SmallVector getShapePerCTA(Attribute layout, ArrayRef shape); +SmallVector getShapePerCTA(Type type); + +unsigned getNumWarpsPerCTA(Attribute layout); + +unsigned getNumCTAs(Attribute layout); + +bool isaDistributedLayout(Attribute layout); + +bool isExpensiveCat(CatOp cat, Attribute targetEncoding); + +// Return true if a view between the two types cannot be implemented as a no-op. +bool isExpensiveView(Type srcType, Type dstType); + +bool isMmaConvertLayout(Operation *op); + +bool isSliceMmaConvertLayout(Operation *op, bool srcNoWarpReduce, + bool dstNoWarpReduce); + +// Return a blocked encoding where the shape is distributed contiguously amongst +// the threads, warps, CTAs with 1 element per threads. +triton::gpu::BlockedEncodingAttr +getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, + int numWarps, int threadsPerWarp, int numCTAs); + +} // namespace gpu +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h new file mode 100644 index 000000000..d4f274742 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -0,0 +1,37 @@ +// Conversions from TritonGPU layouts (e.g. BlockedEncodingAttr) to +// LinearLayout. + +#include + +#include "triton/Tools/LinearLayout.h" + +namespace mlir::triton::gpu { + +// - BlockedEncodingAttrs have the following input dimensions. +// +// "register": elements in one thread +// "lane": threads in a warp +// "warp": warps in a block/CTA +// "block": blocks in a cluster +// +// - An n-dimensional SharedEncodingAttr has the following input dimensions. +// +// "offset": the n'th element in the allocation, within a particular block +// "block": blocks in a cluster +// +// All layouts have the following output dimensions. +// +// "dimi" for i in 0..n-1: the location in the n'th logical dimension of the +// output tensor. These also are not reordered according to the layout's +// `order`. +// +// You can flatten the input or output dimensions into a single dimension using +// LinearLayout::flattenIns/Outs(). +// +// Returns std::nullopt if the given layout can't be converted to an LL. +// TODO(jlebar): Remove the std::optional once all layouts are supported. +// +std::optional toLinearLayout(ArrayRef shape, + Attribute layout); + +} // namespace mlir::triton::gpu diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td new file mode 100644 index 000000000..d224d053d --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -0,0 +1,1410 @@ +#ifndef TRITONGPU_ATTRDEFS +#define TRITONGPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" + +//===----------------------------------------------------------------------===// +// TritonGPU Attribute Definitions +//===----------------------------------------------------------------------===// +def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let methods = [ + InterfaceMethod<"Return total element size per thread.", + "unsigned", + "getTotalElemsPerThread", + (ins "ArrayRef":$tensorShape, + "Type":$eltTy)>, + + InterfaceMethod<"Return element size per thread in each dimension.", + "SmallVector", + "getElemsPerThread", + (ins "ArrayRef":$tensorShape, + "Type":$eltTy)>, + ]; +} + +class TritonGPU_Attr traits = [], + Dialect dialect = TritonGPU_Dialect, + string baseCppClass = "::mlir::Attribute"> + : AttrDef { + + let description = [{ +TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines +how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function +\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding +to the indices of the CUDA threads allowed to access some data at index $i$. + +For example, let us consider the layout function: +\mathcal{L}(0, 0) = {0, 4} +\mathcal{L}(0, 1) = {1, 5} +\mathcal{L}(1, 0) = {2, 6} +\mathcal{L}(1, 1) = {3, 7} + +Then, attaching $\mathcal{L} to a tensor $T$ would mean that: +- T[0,0] is owned by both cuda thread 0 and 4 +- T[0,1] is owned by both cuda thread 1 and 5 +- T[1,0] is owned by both cuda thread 2 and 6 +- T[1,1] is owned by both cuda thread 3 and 7 + +Right now, Triton implements two main classes of layouts: shared, and distributed. + }]; + let attrName = "triton.gpu." # attrMnemonic; + + code extraBaseClassDeclaration = [{ + unsigned getTotalElemsPerThread(ArrayRef shape, Type eltTy) const; + SmallVector getElemsPerThread(ArrayRef shape, Type eltTy) const; + ::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const; + }]; +} + +//===----------------------------------------------------------------------===// +// CTA Layout +//===----------------------------------------------------------------------===// + +def CTALayoutAttr : TritonGPU_Attr<"CTALayout", "cta_layout"> { + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$CTAsPerCGA, + ArrayRefParameter<"unsigned">:$CTASplitNum, + ArrayRefParameter<"unsigned">:$CTAOrder + ); + + let description = [{ +Describes how blocks are distributed among the cooperate thread arrays (aka +CTAs, aka thread blocks) in a cooperate thread group (aka CTG, aka thread group +cluster). CGAs were introduced in Hopper (sm90). + +The tensor is divided up into CTASplitNum pieces, which are distributed among +the CTAsPerCGA thread blocks. Each CTA processes a subtensor of shape +`tensor_shape / CTASplitNum`. + +Example 0: The tensor shape is [64, 128] and, there are two CTAs, each +processing half the tensor [64, 64]. Then CTAsPerCGA = [1, 2] and +CTASplitNum = [1, 2]. + +Example 1: The tensor shape is [64, 128] and, there are two CTAs, both +processing the complete tensor [64, 128]. This happens when multicast is +enabled. In this case, CTAsPerCTA = [1, 2] but CTASplitNum = [1, 1]. + +Example 2: Consider a matmul AxB=C, where A=[M,K], B=[K,N], C=[M,N]. The +CTAsPerCGA for A, B, C are the same, [SplitM, SplitN], but the CTASplitNum are +different. CTASplitNum_A = [SplitM, 1], which means multicast on dim1, +CTASplitNum_B = [1, SplitN], which means multicast on dim0, CTASplitNum_C = +[SplitM, SplitN] which means no multicast. + +Currently programs with multiple CTAs per CGA are an experimental feature in +Triton, not enabled by default. + +You can leave off the CTALayout properties in the textual IR and Triton will +fill in the "default" CTALayout of CTAsPerCGA = CTASplitNum = [1...1]. In +addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to +[n-1,...,0] (it doesn't matter in this case). + }]; + + // CTALayout::get canonicalizes CTAOrder to [n,n-1,...,0] if CTAsPerCGA is + // [1...1]. The CTAOrder doesn't matter in this case. + // + // This is a little weird because if you write textual IR with a one order and + // then print it back out, you might get a different order. But it seems this + // is the best way to canonicalize an attribute in MLIR. + let builders = [ + AttrBuilder<(ins "ArrayRef":$CTAsPerCGA, + "ArrayRef":$CTASplitNum, + "ArrayRef":$CTAOrder), [{ + if (llvm::all_of(CTAsPerCGA, [](unsigned x) { return x == 1; })) { + SmallVector order; + for (int i = CTAsPerCGA.size() - 1; i >= 0; --i) + order.push_back(i); + return $_get(context, CTAsPerCGA, CTASplitNum, order); + } + return $_get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + }]>, + ]; + + let extraClassDeclaration = [{ + SmallVector getElemsPerThread(ArrayRef shape, Type eltTy) const { + llvm::report_fatal_error( + "Unsupported getElemsPerThread in CTALayoutAttr."); + } + unsigned getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { + llvm::report_fatal_error( + "Unsupported getTotalElemsPerThread in CTALayoutAttr."); + } + + static CTALayoutAttr getDefault(MLIRContext *context, int rank) { + SmallVector CTAsPerCGA(rank, 1); + SmallVector CTASplitNum(rank, 1); + SmallVector CTAOrder; + for (int i = rank - 1; i >= 0; --i) + CTAOrder.push_back(i); + return get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + } + }]; + + let genVerifyDecl = 1; + let skipDefaultBuilders = 1; +} + +//===----------------------------------------------------------------------===// +// Shared Layout Encoding +//===----------------------------------------------------------------------===// + +def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding", "shared_encoding"> { + let mnemonic = "shared"; + + let description = [{ +An encoding for tensors whose elements may be simultaneously accessed by +different cuda threads in the programs, via shared memory. In other words, +for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}. + +In order to avoid shared memory bank conflicts, elements may be swizzled. +Here are some examples. In all cases, the input tensor is [0, 1, ..., n-1]. + +1. Basic swizzling + + #shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3], // xor with 0 + [ 5, 4, 7, 6], // xor with 1 + [10, 11, 8, 9], // xor with 2 + [15, 14, 13, 12] // xor with 3 + +Here elements of row r are xor'ed with r (or more properly, in[r][c] -> +out[r][c^r]). + +2. Multiple rows per phase + + #shared<{vec=1, perPhase=2, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 4, 5, 6, 7], + [ 9, 8, 11, 10], // phase 1 (xor with 1) + [13, 12, 15, 14] + +Elements of row r are xor'ed with r/2. In other words, perPhase=2 +means that pairs of 2 rows get the same swizzling. + +3. Max-phase applied + + $shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 5, 4, 7, 6], // phase 1 (xor with 1) + [ 8, 9, 10, 11], // phase 0 + [13, 12, 15, 14], // phase 1 + [16, 17, 18, 19], // ... + [21, 20, 23, 22], + [24, 25, 26, 27], + [29, 28, 31, 30] + +Elements of row r are xor'ed with (r/2) % 2. In other words, maxPhase=m has the +effect of limiting the maximum value of the xor to m-1. + +4. Max-phase and per-phase + + #shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 4, 5, 6, 7], // phase 0 + [ 9, 8, 11, 10], // phase 1 (xor with 1) + [13, 12, 15, 14], // phase 1 + [16, 17, 18, 19], // phase 0 + [20, 21, 22, 23], // phase 0 + [25, 24, 27, 26], // phase 1 + [29, 28, 31, 30]] // phase 1 + +Here the xor value (the "phase", I guess?) changes every perPhase rows, up to a +maximum value of maxPhase-1. In other words, elements of row r are xor'ed with +(r/2) % 2. + +5. Adding vec + + #shared<{vec=2, perPhase=1, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3, 4, 5, 6, 7], + [10, 11, 8, 9, 14, 15, 12, 13], + [20, 21, 22, 23, 16, 17, 18, 19], + [30, 31, 28, 29, 26, 27, 24, 25] + +When vec=2, elements are swizzled in pairs of 2. In other words, the element at +(r,c) has value + + ((c / 2) ^ r) * 2 + (c % 2). + +For MMAv3 eg Hopper GMMA, hasLeadingOffset should be true. In this case, +when the matrix is stored in shared memory, there will be an offset not +only in the stride dimension, but also in the leading dimension. For example, +a matrix of size 16x128 and data type I8 is stored in the shared memory with +64B-swizzle mode. The offset of the element with index (0, 64) will be 16*64, +compared to 1*64 when the hasLeadingOffset is false. + }]; + + // swizzle info: vec, perPhase, maxPhase + // order: the fastest-changing axis first + let parameters = ( + ins + "unsigned":$vec, + "unsigned":$perPhase, + "unsigned":$maxPhase, + ArrayRefParameter<"unsigned">:$order, + "CTALayoutAttr":$CTALayout, + "bool":$hasLeadingOffset, + "bool":$useTcu + ); + + let builders = [ + AttrBuilder<(ins "unsigned":$vec, + "unsigned":$perPhase, + "unsigned":$maxPhase, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout), [{ + bool hasLeadingOffset = false; // default value + bool useTcu; +#if defined(__ILUVATAR__) + useTcu = true; +#else + useTcu = false; +#endif + return $_get(context, vec, perPhase, maxPhase, order, CTALayout, hasLeadingOffset, useTcu); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$typeWidthInBit), [{ + bool needTrans = false; // default value + return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans); + }]>, + + // TODO(jlebar): This should not be an overload of + // SharedEncodingAttr::get(). It's misleading, because it does a bunch of + // nontrivial work based on the given dotOpEnc. + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$typeWidthInBit, + "bool":$needTrans), [{ + + // ---- begin GFX908/GFX90A ---- + if (auto mfmaEnc = mlir::dyn_cast(dotOpEnc.getParent())) { + int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0; + if (needTrans) + kDimNum = 1 - kDimNum; + bool isKDimInner = (order[0] == kDimNum); + if (isKDimInner) { + const int numBanks = 32; + const int bankBitWidth = 32; + const int SIMDWidth = 16; + + // number of inner dimension rows per one pattern repeat + int innerDimLength = shape[order[0]]; + int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; + + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + // vecSize is set to kWidth of the dotop layout + int vecSize = dotOpEnc.getKWidth(); + int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize); + + // TODO (zhanglx): figure out better parameters for mfma4 + if (mfmaEnc.getMDim() == 4) + maxPhase = 4; + + return get(context, vecSize, perPhase, maxPhase, order, CTALayout); + } else { + // Do not swizzle in case k dimension is not innermost. + // In this case accesses will go in different banks even without swizzling. + return get(context, 1, 1, 1, order, CTALayout); + } + } + + // ---- begin GFX11 ---- + if (mlir::isa(dotOpEnc.getParent())) { + if (dotOpEnc.getOpIdx() == 0) { + const int numBanks = 32; + const int bankBitWidth = 32; + + // number of inner dimension rows per one pattern repeat + int innerDimLength = shape[order[0]]; + int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; + + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit; + int maxPhase = 16 / perPhase; + + return get(context, vecSize, perPhase, maxPhase, order, CTALayout); + } else { + // Do not swizzle in case k dimension is not innermost. + // In this case accesses will go in different banks even without swizzling. + return get(context, 1, 1, 1, order, CTALayout); + } + } + + int opIdx = dotOpEnc.getOpIdx(); + // ---- begin Iluvatar ---- + if (auto mmaEnc = mlir::dyn_cast(dotOpEnc.getParent())) { + if (mmaEnc.isVolta()) { + // iluvatar not use swizzle, so use perPhase to store opIdx + return get(context, 1, opIdx, 1, order, CTALayout); + } + } + + // ---- begin Nvidia ---- + auto mmaEnc = mlir::dyn_cast(dotOpEnc.getParent()); + + if(!mmaEnc) + return get(context, 1, 1, 1, order, CTALayout); + + // int opIdx = dotOpEnc.getOpIdx(); + auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + + // number of rows per phase + + // index of the inner dimension in `order` + unsigned inner = (opIdx == 0) ? 0 : 1; + + // ---- begin Volta ---- + if (mmaEnc.isVolta()) { + int perPhase = 128 / (shapePerCTA[order[0]] * (typeWidthInBit / 8)); + perPhase = std::max(perPhase, 1); + bool is_row = order[0] != 0; + bool is_vec4 = opIdx == 0 ? !is_row && (shapePerCTA[order[0]] <= 16) : + is_row && (shapePerCTA[order[0]] <= 16); + int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) : + ((is_row && !is_vec4) ? 2 : 1); + int rep = 2 * pack_size; + int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase; + int vec = 2 * rep; + return get(context, vec, perPhase, maxPhase, order, CTALayout); + } + + // ---- begin Ampere ---- + if (mmaEnc.isAmpere()) { + int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()); + perPhase = std::max(perPhase, 1); + std::vector matShape = {8, 8, 4 * dotOpEnc.getKWidth()}; + int vecWidth = 32 / typeWidthInBit; + if (vecWidth != dotOpEnc.getKWidth() && order[0] == inner) { + perPhase = std::max(perPhase, 2 * vecWidth); + } + int rank = order.size(); + // --- handle A operand --- + if (opIdx == 0) { // compute swizzling for A operand + int m = (needTrans) ? matShape[2] : matShape[0]; + int k = (needTrans) ? matShape[0] : matShape[2]; + int vec = (order[0] == rank-1) ? k : m; + int mmaStride = (order[0] == rank-1) ? m : k; + int maxPhase = mmaStride / perPhase; + return get(context, vec, perPhase, maxPhase, order, CTALayout); + } + + // --- handle B operand --- + if (opIdx == 1) { + // we compute vec and maxPhase m, n and k size of the mma + // instruction. when matmul operands is transposed, we should + // consider that to get m, n and k. + int n = needTrans ? matShape[2] : matShape[1]; + int k = needTrans ? matShape[1] : matShape[2]; + int vec = (order[0] == rank-1) ? n : k; + int mmaStride = (order[0] == rank-1) ? k : n; + int maxPhase = mmaStride / perPhase; + return get(context, vec, perPhase, maxPhase, order, CTALayout); + } + + llvm_unreachable("invalid operand index"); + } + + // ---- begin version 3 ---- + if (mmaEnc.isHopper()) { + llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr" + " is Hopper has not been implemented yet"); + return $_get(context, 1, 1, 1, order, CTALayout, true, false); + } + + // ---- not implemented ---- + llvm_unreachable("unsupported swizzling for provided MMA version"); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CTALayout, bitwidth); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy, + "bool":$needTrans), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans); + }]>, + + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy), [{ + auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + + int32_t eleBitWidth = eltTy.getIntOrFloatBitWidth(); + int32_t vec = 128 / eleBitWidth, perPhase = 1, maxPhase = 1; + + // get proper shared memory swizzling mode from the contiguous dimension + // size of the origin blocked layout. + auto contigDimSizeInByte = shapePerCTA[order[0]] * eleBitWidth / 8; + if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) { + perPhase = 1; + maxPhase = 8; + } else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) { + perPhase = 2; + maxPhase = 4; + } else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) { + perPhase = 4; + maxPhase = 2; + } else { + llvm_unreachable("unsupported shared memory layout for MMAv3"); + } + + bool useTcu; +#if defined(__ILUVATAR__) + useTcu = true; +#else + useTcu = false; +#endif + return $_get(context, vec, perPhase, maxPhase, order, CTALayout, true, useTcu); + }]> + ]; + + let extraClassDeclaration = extraBaseClassDeclaration; + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// Distributed Layout Encoding +//===----------------------------------------------------------------------===// +def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let description = [{ +The Distributed encoding describes the layout L with the 4-level compute hierarchy on GPU. +It is abstracted from the top to the bottom as CTAs Per CGA->Warps Per CTA->Threads Per Warp->Values Per Thread. + +For CTAs Per CGA and Warps Per CTA level, the linear id is distributed contiguously with the shape and order. +For example, for a shape/order pair defines a distribution layout +shape = [4, 4] +order = [0, 1] // The fastest-changing axis first +-> +layout = [0 4 8 12] + [1 5 9 13] + [2 6 10 14] + [3 7 11 15] + +For the Threads Per Warp and Values Per Thread level, the linear id distribution is variant for each sub-class encoding. + }]; + + let methods = [ + // Interface for the meta information about the multiple thread hierarchy. + InterfaceMethod<"Get the shape of the CTAs per CGA.", + "SmallVector", + "getCTAsPerCGA">, + + InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first", + "SmallVector", + "getCTAOrder">, + + InterfaceMethod<"Get the shape of the warps per CTA.", + "SmallVector", + "getWarpsPerCTA">, + + InterfaceMethod<"Get the order of the warps per CTA. The fastest-changing axis first", + "SmallVector", + "getWarpOrder">, + + InterfaceMethod<"Get the shape of the threads per warp", + "SmallVector", + "getThreadsPerWarp">, + + InterfaceMethod<"Get the order of the threads per warp. The fastest-changing axis first", + "SmallVector", + "getThreadOrder">, + + InterfaceMethod<"Get the shape of the values per thread.", + "SmallVector", + "getSizePerThread">, + + InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.", + "SmallVector", + "getCTASplitNum">, + + InterfaceMethod<"Gets the shape of the encoding's tile, e.g. sizePerThread * threadsPerWarp * warpsPerCTA", + "SmallVector", + "getShapePerCTATile", + (ins "ArrayRef":$tensorShape)>, + + InterfaceMethod<"Gets the number of contiguous elements per thread.", + "SmallVector", + "getContigPerThread">, + ]; +} + +class DistributedEncoding traits = [], + Dialect dialect = TritonGPU_Dialect> + : TritonGPU_Attr { + + let description = [{ +Distributed encodings have a layout function L that is entirely characterized +by a d-dimensional tensor T. Note that L doesn't need to have the same shape +(or even the same rank) as the tensor it is encoding. + +The layout function \mathcal{L} of this layout is then defined, for an +index `i` \in Z^d, as follows: + +\mathcal{L}(T)[i_d] = L[(i_d + k_d*T.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*T.shape[d] < L.shape[d] + +Intuitively, when the tensor dim size T.shape[d] is larger than the layout +dim size L.shape[d], on that particular dim, we distribute values from the +tensor to threads mapped in the layout in a "wrapped around" manner, with +each thread owning multiple values. + +OTOH, when the tensor dim size T.shape[d] is smaller than the layout +dim size L.shape[d], on that particular dim, we distribute values from the +tensor to threads mapped in the layout in a "broadcasted" manner, with +each value owned by multiple threads. + +For example, for a tensor/layout pair +T = [x x x x x x x x] + [x x x x x x x x] +L = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] + +Then the data of T would be distributed as follow between the 16 CUDA threads: +L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, + {4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ] + }]; + + code extraDistributedDeclaration = extraBaseClassDeclaration # [{ + SmallVector getCTAsPerCGA() const; + SmallVector getCTAOrder() const; + SmallVector getCTASplitNum() const; + SmallVector getWarpsPerCTA() const; + SmallVector getWarpOrder() const; + SmallVector getThreadsPerWarp() const; + SmallVector getThreadOrder() const; + + SmallVector getSizePerThread() const; + SmallVector getShapePerCTATile(ArrayRef tensorShape = ArrayRef()) const; + }]; +} + +//===----------------------------------------------------------------------===// +// Blocked Layout Encoding +//===----------------------------------------------------------------------===// + +def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding", "blocked_encoding"> { + let mnemonic = "blocked"; + + let description = [{ +An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout +used to promote memory coalescing in LoadInst and StoreInst. +It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which +specify the amount of elements owned by each CUDA thread, warp and CTA respectively. + +Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows: + +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] + +for + +#triton_gpu.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {1, 1} + CTASplitNum = {1, 1} +}> + +Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) as follows: + +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +for + +#triton_gpu.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {1, 1} + CTASplitNum = {1, 1} +}> + +Example 3, A row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) and +4 CTAs (taking 2x2 for example) as follows: + +CTA [0,0] CTA [0,1] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] + +CTA [1,0] CTA [1,1] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +for + +#triton_gpu.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {2, 2} + CTASplitNum = {2, 2} +}> +}]; + + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$sizePerThread__, + ArrayRefParameter<"unsigned">:$threadsPerWarp__, + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first + + // CTALayout is optional in the textual IR. If omitted, we infer it to be a + // single CTA (so CTAsPerCGA = [1,...,1], CTASplitNum = [1,...,1], + // CTAOrder=[n,n-1,...,0]). + "CTALayoutAttr":$CTALayout, + "unsigned":$loadType, + ArrayRefParameter<"unsigned">:$smeWarpsPerCTA + ); + let genVerifyDecl = 1; + + let builders = [ + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$sizePerThread, + "ArrayRef":$order, + "unsigned":$numWarps, + "unsigned":$numThreadsPerWarp, + "CTALayoutAttr":$CTALayout), [{ + unsigned rank = sizePerThread.size(); + SmallVector threadsPerWarp(rank); + SmallVector warpsPerCTA(rank); + SmallVector smeWpt(rank); + SmallVector shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + + unsigned remainingLanes = numThreadsPerWarp; + unsigned remainingThreads = numWarps * numThreadsPerWarp; + unsigned remainingWarps = numWarps; + unsigned prevLanes = 1; + unsigned prevWarps = 1; + + // starting from the contiguous dimension + for (unsigned d = 0; d < rank - 1; ++d) { + unsigned i = order[d]; + unsigned threadsPerCTA = std::clamp(remainingThreads, 1, shapePerCTA[i] / sizePerThread[i]); + threadsPerWarp[i] = std::clamp(threadsPerCTA, 1, remainingLanes); + warpsPerCTA[i] = std::clamp(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps); + remainingWarps /= warpsPerCTA[i]; + remainingLanes /= threadsPerWarp[i]; + remainingThreads /= threadsPerCTA; + prevLanes *= threadsPerWarp[i]; + prevWarps *= warpsPerCTA[i]; + } + + // Expand the last dimension to fill the remaining lanes and warps + threadsPerWarp[order[rank - 1]] = numThreadsPerWarp / prevLanes; + warpsPerCTA[order[rank - 1]] = numWarps / prevWarps; + + return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout, 0, smeWpt); + }]>, + + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$sizePerThread, + "ArrayRef":$order, + "unsigned":$numWarps, + "unsigned":$numThreadsPerWarp, + "unsigned":$numCTAs), [{ + unsigned rank = sizePerThread.size(); + SmallVector CTAsPerCGA(rank); + SmallVector CTASplitNum(rank); + ArrayRef CTAOrder = order; + + unsigned remainingCTAs = numCTAs; + + // starting from the most strided dimension + for (int d = rank - 1; d >= 0; --d) { + unsigned i = order[d]; + CTAsPerCGA[i] = std::clamp(remainingCTAs, 1, shape[i] / sizePerThread[i]); + CTASplitNum[i] = CTAsPerCGA[i]; + remainingCTAs /= CTAsPerCGA[i]; + } + + CTAsPerCGA[rank - 1] *= remainingCTAs; // wrap at CTA level + + CTALayoutAttr CTALayout = CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + return get(context, shape, sizePerThread, order, numWarps, numThreadsPerWarp, CTALayout); + }]>, + + #define GET_ILUVATAR_BLOKED_LAYOUT_BUILDER + #ifdef FLAGTREE_PLUGIN + include "TritonILUVATARGPUToLLVM/TritonGPUAttrDefs.td" + #else + AttrBuilder<(ins "unsigned":$loadType, + "unsigned":$numWarps, + "Type": $eltTy, + "ArrayRef":$shape, + "ArrayRef":$order, + "ArrayRef":$sizePerThread, + "ArrayRef":$threadsPerWarp, + "ArrayRef":$warpsPerCTA, + "unsigned":$numCTAs), [{ + SmallVector wpt({1, 1}); + auto impl = load_AttrBuilder_func("iluvatar", "blockedLayoutBuilder"); + CTALayoutAttr CTALayout = impl(loadType, numWarps, eltTy, + shape, order, sizePerThread, threadsPerWarp, warpsPerCTA, numCTAs, + wpt, context); + return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout, loadType, wpt); + }]> + #endif + + ]; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + SliceEncodingAttr squeeze(int axis); + + SmallVector getContigPerThread() { + // Block encoding is dense stride layout. The elements per thread are contiguous. + return getSizePerThread(); + }; + }]; + + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// MMA Layout Encoding +//===----------------------------------------------------------------------===// +// TODO: MMAv1 and MMAv2 should be two instances of the same class +def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + let methods = [ + + InterfaceMethod<"Return whether the layout support reduction op.", + "bool", + "supportReduction">, + + InterfaceMethod<"Return shape per CTA.", + "SmallVector", + "getShapePerCTATileForDotOperands", + (ins "ArrayRef":$tensorShape, + "unsigned":$opIdx)>, + + InterfaceMethod<"Return total element size per thread for dot operands.", + "unsigned", + "getTotalElemsPerThreadForOperands", + (ins "ArrayRef":$tensorShape, + "Type":$eltTy, + "unsigned":$kWidth, + "unsigned":$opIdx)>, + + InterfaceMethod<"Return size per thread for dot operands.", + "SmallVector", + "getSizePerThreadForOperands", + (ins "unsigned":$opIdx)>, + ]; +} + +def AMDMfmaEncodingAttr : DistributedEncoding<"AMDMfmaEncoding", "amd_mfma_encoding", [MmaEncodingTrait]> { + let mnemonic = "amd_mfma"; + + let description = [{ +An encoding for tensors that have been produced by MFMA matrix core instructions, +available on AMD Instinct GPUs of CDNA architectures. + +It is characterized by the following parameters: +- `versionMajor` and `versionMinor` indicates the GPU architecture: + - 1.0: gfx908, i.e. MI100 + - 2.0: gfx90a: i.e. MI200, MI210, MI250 + - 3.0: gfx940, gfx941, gfx942: MI300 +- `warpsPerCTA` indicates the wave layout in the workgroup. +- `MDim` and `NDim` indicate the dimension of the output of the mfma instruction. +- `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout +without going to shared memory. This is used in the case of chained dot (E.g. Flash-Attention kernel). + +Example 1: +Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and MDim=NDim=32. +The data will be distributed between threads as follows: + + wave 0 wave 1 +-----------------/\-------------- -----------------/\-------------- +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] + +Example 2: +Suppose we have a tensor with a shape of [16, 32], warpsPerCTA set to [1, 2] and MDim=NDim=16. +The data will be distributed between threads as follows: + + wave 0 wave 1 +-----------------/\------------- ------------------/\--------------- +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] + +Example 3: +Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and nonKDim set to 4. +The data will be distributed between threads as follows(note that each element is duploicated in 16 threads): +Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and MDim=NDim=4. +The data will be distributed between threads as follows(note that each element is duplicated in 16 threads): + +M N -> wave 0 wave 2 +| --------------------------/\-------------------------- ------------------------------/\------------------------------ +V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + wave 1 wave 3 + --------------------------/\-------------------------- ------------------------------/\------------------------------ + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] +}]; + + let parameters = ( + ins + "unsigned": $versionMajor, + "unsigned": $versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + "unsigned":$MDim, + "unsigned":$NDim, + "bool":$isTransposed, + "CTALayoutAttr":$CTALayout + ); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool supportReduction() const { + return true; + } + SmallVector getSizePerThreadForOperands(unsigned opIdx) const; + SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; + unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + SmallVector getMFMAInstrShapeForOperands(int kWidth, int opIdx) const; + SmallVector getMFMARepForOperands(ArrayRef operandShape, int kWidth, int opIdx) const; + + SmallVector getContigPerThread() { + auto rank = getWarpsPerCTA().size(); + SmallVector contigPerThread(rank, 1); + if (getIsTransposed()) + contigPerThread[rank - 1] = 4; + else + contigPerThread[rank - 2] = 4; + return contigPerThread; + }; + + }]; + + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; +} + +def AMDWmmaEncodingAttr : DistributedEncoding<"AMDWmmaEncoding", "amd_wmma_encoding", [MmaEncodingTrait]> { + let mnemonic = "amd_wmma"; + + let description = [{ +An important limitation of WMMA for layout is a shape for tiles proccessed +by a single wave. It is [16, 16]. +This encoding assumes specific access to matrix elements by threads. + +Example: +Suppose we have a tensor with shape [32, 48], `warpsPerCTA` set to [2, 3]. + + wave 0 [16, 16] wave 1 [16, 16] wave 2 [16, 16] +-----------/\---------- -----------/\---------- -----------/\---------- +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +... ... ... +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + + wave 3 [16, 16] wave 4 [16, 16] wave 5 [16, 16] +-----------/\---------- -----------/\---------- -----------/\---------- +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +... ... ... +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + }]; + + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + "CTALayoutAttr":$CTALayout + ); + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool supportReduction() const { + return true; + } + SmallVector getSizePerThreadForOperands(unsigned opIdx) const; + SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; + unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + SmallVector getWMMAElemsPerInstrForOperands() const; + SmallVector getWMMARepForOperands(ArrayRef operandShape, + Type elemType, int kWidth, int opIdx) const; + static SmallVector getMNKDimPerWMMAInstr(); + + SmallVector getContigPerThread() { + auto rank = getWarpsPerCTA().size(); + SmallVector contigPerThread(rank, 1); + return contigPerThread; + }; + }]; +} + +def NvidiaMmaEncodingAttr : DistributedEncoding<"NvidiaMmaEncoding", "nvidia_mma_encoding", [MmaEncodingTrait]> { + let mnemonic = "nvidia_mma"; + + let description = [{ +An encoding for tensors that have been produced by tensor cores. + +It is characterized by two parameters: +- A 'versionMajor' which specifies the generation the tensor cores + whose output is being partitioned: + - 1 for first-gen tensor cores (Volta), and + - 2 for second-gen tensor cores (Turing/Ampere). +- A 'versionMinor' which indicates the specific layout of a tensor core + generation, e.g. for Volta, there might be multiple kinds of layouts + annotated by 0,1,2 and so on. +- A `blockTileSize` to indicate how data should be partitioned between warps. + +// -------------------------------- version = 1 --------------------------- // + +For first-gen tensor cores, the implicit warpTileSize is [16, 16]. +Note: the layout is different from the recommended in PTX ISA +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html +(mma.884 section, FP32 accumulator). + +For example, when versionMinor=1, the matrix L corresponding to +blockTileSize=[32,16] is: + + warp 0 +--------------------------------/\------------------------------- +[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ] +[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ] +[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ] +[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ] +[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ] +[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ] +[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ] +[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ] +[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ] +[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ] +[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ] +[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ] +[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ] +[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ] +[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ] +[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ] + + warp 1 = warp0 + 32 +--------------------------------/\------------------------------- +[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ] +[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ] +[ ............................................................... ] + + +// -------------------------------- version = 2 --------------------------- // + +For second-gen tensor cores, the implicit warpTileSize is [16, 8]. +Information about this layout can be found in the official PTX documentation +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html +(mma.16816 section, FP32 accumulator). + +For example, the matrix L corresponding to blockTileSize=[32,16] is: + warp 0 warp 2 +-----------------/\------------- ----------------/\------------- +[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35 +[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39 +[ .............................. .............................. +[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63 +[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35 +[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39 +[ .............................. .............................. +[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63 + + warp 1 warp 3 +----------------/\------------- ----------------/\------------- +[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99 +[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103 +[ .............................. ............................... +[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127 +[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99 +[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103 +[ .............................. ............................... +[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127 + +}]; + + let parameters = ( + ins + "unsigned":$versionMajor, + "unsigned":$versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + "CTALayoutAttr":$CTALayout, + ArrayRefParameter<"unsigned">:$instrShape + ); + + let builders = [ + // Specially for MMAV1(Volta) + AttrBuilder<(ins "int":$versionMajor, + "int":$numWarps, + "CTALayoutAttr":$CTALayout, + "ArrayRef":$instrShape, + "ArrayRef":$shapeC, + "bool":$isARow, + "bool":$isBRow, + "bool":$isAVec4, + "bool":$isBVec4, + "int":$id), [{ + assert(versionMajor == 1 && "This builder is specially for versionMajor==1"); + // 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4] + int versionMinor = (isARow * (1<<0)) |\ + (isBRow * (1<<1)) |\ + (isAVec4 * (1<<2)) |\ + (isBVec4 * (1<<3)); + + // TODO: Share code with + // DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the + // rep,spw and fpw. + SmallVector wpt({1, 1}); + SmallVector wpt_nm1; + + SmallVector rep(2), spw(2); + std::array fpw{{2, 2, 1}}; + int packSize0 = (isARow || isAVec4) ? 1 : 2; + rep[0] = 2 * packSize0; + spw[0] = fpw[0] * 4 * rep[0]; + + int packSize1 = (isBRow && !isBVec4) ? 2 : 1; + rep[1] = 2 * packSize1; + spw[1] = fpw[1] * 4 * rep[1]; + + do { + wpt_nm1 = wpt; + if (wpt[0] * wpt[1] < numWarps) + wpt[0] = std::clamp(wpt[0] * 2, 1, shapeC[0] / spw[0]); + if (wpt[0] * wpt[1] < numWarps) + wpt[1] = std::clamp(wpt[1] * 2, 1, shapeC[1] / spw[1]); + } while (wpt_nm1 != wpt); + + return $_get(context, versionMajor, versionMinor, wpt, CTALayout, instrShape); + }]>, + + + AttrBuilder<(ins "int":$versionMajor, + "int":$numWarps, + "CTALayoutAttr":$CTALayout, + "ArrayRef":$instrShape, + "ArrayRef":$shapeA, + "ArrayRef":$shapeB, + "ArrayRef":$shapeC, + "bool":$isARow, + "bool":$isBRow, + "int":$id), [{ + assert(versionMajor == 1 && "This builder is specially for versionMajor==1"); + bool isAVec4 = !isARow && (shapeA[isARow] <= 16); + bool isBVec4 = isBRow && (shapeB[isBRow] <= 16); + return get(context, versionMajor, numWarps, CTALayout, instrShape, shapeC, isARow, isBRow, isAVec4, isBVec4, id); + }]> + ]; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool isVolta() const; + bool isTuring() const; + bool isAmpere() const; + bool isHopper() const; + + unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef shape) const; + + // Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor + std::tuple decodeVoltaLayoutStates() const; + + // Number of bits in versionMinor to hold the ID of the MMA encoding instance. + // Here 5 bits can hold 32 IDs in a single module. + static constexpr int numBitsToHoldMmaV1ID{5}; + + // For MMA v1, method `getMMAv1IsRow` returns whether e.g. the a operand is used + // in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation + // section 9.7.13.4.1 for more details. + bool getMMAv1IsRow(int opIdx) const; + bool getMMAv1IsVec4(int opIdx) const; + int getMMAv1NumOuter(ArrayRef shape, int opIdx) const; + SmallVector getMMAv1Rep(int opIdx) const; + SmallVector getMMAv1ShapePerWarp(int opIdx) const; + int getMMAv1Vec(int opIdx) const; + SmallVector getMMAv2Rep(ArrayRef shape, + int bitwidth, int opIdx) const; + + bool supportReduction() const { + if (isAmpere() || isHopper()) { + return true; + } + return false; + }; + SmallVector getSizePerThreadForOperands(unsigned opIdx) const; + SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; + unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + + SmallVector getContigPerThread() { + assert(isVolta() || isAmpere() || isHopper()); + auto rank = getWarpsPerCTA().size(); + SmallVector contigPerThread(rank, 1); + contigPerThread[rank - 1] = 2; + return contigPerThread; + }; + + }]; + + let hasCustomAssemblyFormat = 1; +} + +def IluvatarMmaEncodingAttr : DistributedEncoding<"IluvatarMmaEncoding", "iluvatar_mma_encoding", [MmaEncodingTrait]> { + let mnemonic = "iluvatar_mma"; + + let description = [{ +An encoding for tensors that have been produced by tensor cores. + +}]; + + let parameters = ( + ins + "unsigned":$versionMajor, + "unsigned":$versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + "CTALayoutAttr":$CTALayout, + ArrayRefParameter<"unsigned">:$instrShape + ); + + // let builders = []; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool isVolta() const; + unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef shape) const; + + bool supportReduction() const { + return true; + }; + SmallVector getSizePerThreadForOperands(unsigned opIdx) const; + SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; + unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + + SmallVector getTCUShapePerWarp(int bitwidth) const; + SmallVector getTCUShapePerCTA(int bitwidth) const; + SmallVector getTCURep(ArrayRef shape, int bitwidth, int opIdx) const; + + SmallVector getContigPerThread() { + assert(isVolta()); + auto rank = getWarpsPerCTA().size(); + SmallVector contigPerThread(rank, 1); + if (isVolta() && getVersionMinor() == 1) { + contigPerThread[rank - 1] = 2; + } + return contigPerThread; + }; + }]; + + let hasCustomAssemblyFormat = 1; +} + +def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> { + let mnemonic = "slice"; + + let description = [{ + Given a `parent` layout and a `dim`, squeezes the given `dim` in the `parent` + layout and distributes values in a tensor T according to the new layout. + + For example, given + + T = [x x x x x x x x] + L_parent = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] (with 16 CUDA threads) + + With dim = 0, squeezing out dim 0, we have + L = [{0,4,8,12}, {1,5,9,13}, {2,6,10,14}, {3,7,11,15} ] + + Then the data of T would be distributed as follow between the 16 CUDA threads: + L(T) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ] + + With dim = 1, squeezing out dim 1, we have + L = [ {0,1,2,3}, {4,5,6,7}, {8,9,10,11}, {12,13,14,15} ] + + Then the data of T would be distributed as follow between the 16 CUDA threads: + L = [ {0,1,2,3}, {4,5,6,7}, ..., {12,13,14,15}, {0,1,2,3}, ..., {12,13,14,15} ] + + This is useful for constructing the inverse layout of an expand_dims operation + during some optimization passes. + }]; + + let parameters = ( + ins + "unsigned":$dim, + // TODO: constraint here to only take distributed encodings + "Attribute":$parent, + DefaultValuedParameter<"bool", "false">:$noWarpReduce + ); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + template + SmallVector paddedShape(ArrayRef shape) const; + + SmallVector getContigPerThread() { + auto parentLayout = mlir::cast(getParent()); + auto parentContigPerThread = parentLayout.getContigPerThread(); + parentContigPerThread.erase(parentContigPerThread.begin() + getDim()); + return parentContigPerThread; + }; + }]; + + let hasCustomAssemblyFormat = 1; +} + +def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> { + let mnemonic = "dot_op"; + + let description = [{ +In the TritonGPU dialect, given `d = tt.dot a, b, c` tt.dot's operands a and b +must be of DotOperandEncodingAttr layout, if the dot is MMA v1 or v2 (i.e. +pre-Hopper). For MMA v3, the operands are *almost always* in a regular shared +encoding, but sometimes the LHS is also a dot-operand encoding. + +a's opIdx is 0, b's opIdx is 1. + +The parent field is the layout of d. + +kWidth defines number of consecutive elements stored by one thread along k dimension. +Some layouts do not use this parameter, either because they have a fixed number of +elements along the K dim, or they use all elements of the tensor along the K dim. + }]; + + let parameters = ( + ins + "unsigned":$opIdx, + "Attribute":$parent, + DefaultValuedParameter<"unsigned", "0">:$kWidth, + "unsigned":$useSme + ); + + let builders = [ + // Specially for MMAV1(Volta) + AttrBuilder<(ins "unsigned":$opIdx, + "Attribute":$parent, + "Type":$eltTy), [{ + NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast(parent); + if (!parentAttr || !parentAttr.isAmpere()) + return $_get(context, opIdx, parent, 0, 0); + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + unsigned MMAv2kWidth = 32 / bitwidth; + return $_get(context, opIdx, parent, MMAv2kWidth, 0); + }]>, + + // Specially for MR/BI150 + AttrBuilder<(ins "unsigned":$opIdx, + "Attribute":$parent, + "Type":$eltTy, + "unsigned":$useSme), [{ + IluvatarMmaEncodingAttr parentAttr = mlir::dyn_cast(parent); + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + unsigned kWidth = 32 / bitwidth; + return $_get(context, opIdx, parent, kWidth, useSme); + }]> + ]; + + let assemblyFormat = "`<` `{` struct(params) `}` `>`"; + let genVerifyDecl = 1; + let extraClassDeclaration = extraDistributedDeclaration # [{ + SmallVector getContigPerThread() { + return getSizePerThread(); + }; + }]; +} + +#endif diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefsPlugin.h b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefsPlugin.h new file mode 100644 index 000000000..3714de635 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefsPlugin.h @@ -0,0 +1,15 @@ +#ifndef GET_ILUVATAR_BLOKED_LAYOUT_BUILDER_PLUGIN_H +#define GET_ILUVATAR_BLOKED_LAYOUT_BUILDER_PLUGIN_H + +#include "mlir/Support/LLVM.h" +#include "python/src/plugin.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" + +using AttrBuilderFunc = mlir::triton::gpu::CTALayoutAttr (*)( + unsigned, unsigned, mlir::Type, llvm::ArrayRef, + llvm::ArrayRef, llvm::ArrayRef, + llvm::ArrayRef, llvm::ArrayRef, unsigned, + llvm::SmallVector &, mlir::MLIRContext *); +DEFINE_LOAD_FUNC(AttrBuilder) + +#endif diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td new file mode 100644 index 000000000..10f2c8c68 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -0,0 +1,54 @@ +#ifndef TRITONGPU_DIALECT +#define TRITONGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonGPU_Dialect : Dialect { + let name = "triton_gpu"; + + let cppNamespace = "::mlir::triton::gpu"; + + let hasOperationAttrVerify = 1; + + let description = [{ + Triton GPU Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect", + "mlir::gpu::GPUDialect", + "tensor::TensorDialect", + ]; + + let extraClassDeclaration = [{ + static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; } + static int getNumWarps(ModuleOp mod) { + if (!mod->hasAttr("triton_gpu.num-warps")) + llvm::report_fatal_error( + "TritonGPU module should contain a triton_gpu.num-warps attribute"); + return cast(mod->getAttr("triton_gpu.num-warps")).getInt(); + } + static int getNumCTAs(ModuleOp mod) { + if (!mod->hasAttr("triton_gpu.num-ctas")) + return 1; + return cast(mod->getAttr("triton_gpu.num-ctas")).getInt(); + } + void registerTypes(); + + static std::string getThreadsPerWarpAttrName() { return "triton_gpu.threads-per-warp"; } + + static int getThreadsPerWarp(ModuleOp mod) { + Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp"); + if(!threadsPerWarp) { + return 32; + } + return cast(threadsPerWarp).getInt(); + } + }]; + + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +#endif diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h new file mode 100644 index 000000000..0ee2cfeca --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h @@ -0,0 +1,6 @@ +#ifndef TRITON_GPU_DIALECT_INTERFACES_H +#define TRITON_GPU_DIALECT_INTERFACES_H + +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc" + +#endif // TRITON_GPU_DIALECT_INTERFACES_H diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td new file mode 100644 index 000000000..cd83c70af --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -0,0 +1,239 @@ +#ifndef TRITONGPU_OPS +#define TRITONGPU_OPS + +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/ViewLikeInterface.td" + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">; + +class TTG_Op traits = []> : + Op { +} + +def TTG_ConvertLayoutOp : TTG_Op<"convert_layout", + [SameOperandsAndResultShape, + SameOperandsAndResultElementType, + Pure]> { + let summary = "convert layout"; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let hasCanonicalizer = 1; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TTG_AsyncWaitOp : TTG_Op<"async_wait"> { + let summary = "async wait"; + + let arguments = (ins Variadic:$asyncToken, I32Attr:$num); + + let results = (outs TTG_AsyncToken:$retToken); + + let assemblyFormat = "$asyncToken attr-dict"; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 80; + } + }]; +} + +def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> { + let summary = "async commit group"; + + let results = (outs TTG_AsyncToken:$asyncToken); + let arguments = (ins Variadic:$inputTokens); + + let assemblyFormat = [{ + $inputTokens attr-dict + }]; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 80; + } + }]; +} + +def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + TypesMatchWith<"infer mask type from src type", + "src", "mask", "getI1SameShape($_self)", + // "($_op.getOperands().size() <= 3) || std::equal_to<>()">, + "($_op.getOperands().size() == 6) || ($_op.getOperands().size() <= 3) || std::equal_to<>()">, + TypesMatchWith<"infer other type from src type", + "src", "other", "getPointeeType($_self)", + // "($_op.getOperands().size() <= 4) || std::equal_to<>()"> + "($_op.getOperands().size() != 4) || std::equal_to<>()"> +]> { + let summary = "copy data from global memory to local memory asynchronously"; + + let description = [{ + This operation copies data from global memory to local memory asynchronously. + This is analogue to tt.load except the data are copied to local memory pointed + by by the memory descriptor instread of a distributed tensor. The rest of the + operands are the same as tt.load. + }]; + + let arguments = ( + ins TT_Tensor:$src, + TT_MemDescType:$result, + Optional:$mask, + Optional:$other, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile, + Optional:$inputStride, Optional:$placeHolder0, Optional:$placeHolder1 + ); + + let builders = [ + OpBuilder<(ins "Value":$src, "Value":$result, + "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + ]; + + let results = (outs TTG_AsyncToken:$token); + + let extraClassDeclaration = [{ + static DenseSet getEligibleLoadByteWidth(int computeCapability) { + DenseSet validLoadBytes; + if (computeCapability >= 80) { + validLoadBytes = {4, 8, 16}; + } + return validLoadBytes; + } + }]; + + // Specify cacheModifier and evictionPolicy explicitly, instead of leaving + // them in attr-dict, because this way their values get printed as strings, + // rather than as opaque integers. + // + // Note there are no commas between other, cacheModifier, and evictionPolicy, + // due to limitations in MLIR's asm parser. + // let assemblyFormat = [{ + // $src `,` $result (`mask` $mask^)? (`other` $other^)? + // oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + // attr-dict `:` type($src) `->` type($result) + // }]; + let hasCustomAssemblyFormat = 1; +} + + +// Allocate shared memory +def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods]> { + let summary = "allocate tensor"; + let description = [{ + This operation allocates buffer in shared memory and return a descriptor + containing the address and a view of the buffer. + + Explicitly deallocating a buffer is optional; see local_dealloc. + }]; + let arguments = (ins Optional:$src); + + let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}]; + + let results = (outs TT_MemDescType:$result); +} + +// Deallocate shared memory +def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree]>]> { + let summary = "dealloc buffer"; + + let description = [{ + This operation deallocates a buffer explicitly. Using the buffer after this + operation is undefined. + + This operation is optional. If you don't explicitly dealloc a buffer, the + compiler assumes it's deallocated at the first point that post-dominates all + uses of the alloc. + + Because we assume a memdesc is dead at the first point that post-dominates + its uses, ops that wait for an async operation on a memdesc to complete + (such as triton_nvidia_gpu.dot_wait) should also take the memdesc as an + operand. + }]; + + let arguments = (ins TT_MemDescType:$src); + + // Use qualified() otherwise "!tt.memdesc" is printed as "". + let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}]; +} + +def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> { + let summary = "take a subview of the descriptor."; + + let description = [{ + This operation returns a new descriptor representing a subview of the buffer. + It doesn't affect the underlying memory. The subview can be rank-reduced. + + For example, suppose that + - the input shape is 2x4x16xf16, + - the output shape is 4x4xf16, and + - offsets = [1, 0, 4]. + + Then in Python syntax, the subview covers input[1][0:4][4:8]. + }]; + let arguments = ( + ins TT_MemDescType:$src, Variadic:$offsets); + + // Use qualified() otherwise "!tt.memdesc" is printed as "". + let assemblyFormat = [{$src `[` $offsets `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}]; + + let results = (outs TT_MemDescType:$result); + + let hasVerifier = 1; +} + +def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods]> { + let summary = "Load a buffer from local memory into a distributed tensor"; + + let description = [{ + Load a tensor from the local memory descriptor into a distributed tensor. + }]; + let arguments = (ins TT_MemDescType:$src, Optional :$token); + + let builders = [ + OpBuilder<(ins "Type":$retType, "Value":$src), + [{ + build($_builder, $_state, retType, src, /*token=*/static_cast(nullptr)); + }]>]; + + // Use qualified() otherwise "!tt.memdesc" is printed as "". + let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}]; + + let results = (outs TT_Tensor:$result); +} + +def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods]> { + let summary = "Store a distributed tensor into a buffer in local memory"; + + let description = [{ + Store a distributed tensor into a buffer in local memory. + }]; + let arguments = (ins TT_Tensor:$src, TT_MemDescType:$dst); + + // Use qualified() otherwise "!tt.memdesc" is printed as "". + let assemblyFormat = [{ + $src `,` $dst attr-dict `:` type($src) `->` qualified(type($dst)) + }]; +} + +#endif diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td new file mode 100644 index 000000000..6765ac40c --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td @@ -0,0 +1,36 @@ +#ifndef TRITONGPU_TYPES +#define TRITONGPU_TYPES + +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "mlir/IR/AttrTypeBase.td" + +class TTG_TypeDef traits = []> + : TypeDef { + let mnemonic = _mnemonic; +} + +def TTG_TokenType : TTG_TypeDef<"Token", "token"> { + let parameters = (ins "int32_t":$type); + + let builders = [ + TypeBuilder<(ins "unsigned":$type), [{ + return $_get($_ctxt, type); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +def TTG_AsyncToken : TTG_TypeDef<"AsyncToken", + "async.token", []> { + let summary = "async token type"; + let description = [{ + `ttg.async.token` is a type returned by an asynchronous operation. + It is used to establish an SSA-based link between async operations + and operations that group or synchronize the async operations. + }]; +} + +#endif diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/Types.h b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/Types.h new file mode 100644 index 000000000..edf37fef6 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/IR/Types.h @@ -0,0 +1,10 @@ +#ifndef TRITONGPU_IR_TYPES_H_ +#define TRITONGPU_IR_TYPES_H_ + +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/Types.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt b/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..6be94d1a8 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGPU) +add_public_tablegen_target(TritonGPUTransformsIncGen) diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/Passes.h new file mode 100644 index 000000000..e5605a791 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -0,0 +1,22 @@ +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace gpu { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +} // namespace gpu +} // namespace triton +} // namespace mlir +#endif diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/Passes.td new file mode 100644 index 000000000..e70c99ee4 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -0,0 +1,145 @@ +#ifndef TRITONGPU_PASSES +#define TRITONGPU_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { + let summary = "pipeline"; + + let description = [{ + Applies software pipelining to loops in the module based on number of stages. + This may convert some load into asynchronous loads, and multi-buffer the data. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; + + let options = [ + Option<"numStages", "num-stages", + "int32_t", /*default*/"3", + "number of pipeline stages"> + ]; +} + +def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> { + let summary = "3xTF32 trick"; + + let description = [{ + Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 fp16 `DotOp`s + to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385 + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + ]; +} + +def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> { + let summary = "prefetch"; + + let description = [{ + Decompose `DotOp` instructions in loops into several finer-grained `DotOp` + that may have their operands constructed at the end of the previous iteration + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; +} + +def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> { + let summary = "accelerate matmul"; + + let description = [{ + Optimize the input/output layout of `dot` instruction to make them compatible hardware accelerators + (e.g., Nvidia tensor cores) + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> { + let summary = "fuse transpositions"; + + let description = [{ + Re-arranged layouts of tensors used as matrix multiplication operands so as to promote the use of + hardware-accelerated transpositions. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + + let options = [ + Option<"hoistLayoutConversion", "hoist-layout-conversion", + "bool", /*default*/"true", + "whether to move conver to dot operand earlier pass elementwise ops"> + ]; +} + +def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> { + let summary = "coalesce"; + + let description = [{ + TODO + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; +} + + +def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions", "mlir::ModuleOp"> { + let summary = "remove superfluous layout conversions"; + + let description = [{ + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + +} + +def TritonGPUOptimizeThreadLocality : Pass<"tritongpu-optimize-thread-locality", "mlir::ModuleOp"> { + let summary = "Reduce the cost of synchronization between threads in an SM"; + + let description = [{ + Today, this optimizes reduction yielded by loop to be thread-local until after the loop completes. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> { + let summary = "Reorder instructions"; + + let description = "This pass reorder instructions so as to (1) decrease register pressure (e.g., by moving " + "conversions from shared memory before their first use) and (2) promote LLVM instruction " + "order more friendly to `ptxas`."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUReduceDataDuplication: Pass<"tritongpu-reduce-data-duplication", "mlir::ModuleOp"> { + let summary = "Reduce data duplication in register by decomposing convert[distributed -> dotOperand] " + "into convert[distributed -> shared -> dotOperand]"; + + let description = "Decomposing conversions this way makes it possible to use CSE and reuse #shared tensors"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUCombineTensorSelectAndIf: Pass<"tritongpu-combine-tensor-select-and-if", "mlir::ModuleOp"> { + let summary = "Combine tensor select and if"; + + let description = "For select instruction that uses the same condidtion as the if instruction in the same block " + "this pass combines the select into the if instruction, making the select operands returned by the " + "then/else yields."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +#endif diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h b/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h new file mode 100644 index 000000000..fbfa235fc --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// +// Defines utilities to use while converting to the TritonGPU dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +class TritonGPUTypeConverter : public TypeConverter { +public: + TritonGPUTypeConverter(MLIRContext *context, int numWarps, int threadsPerWarp, + int numCTAs); + int getNumWarps() const { return numWarps; } + int getThreadsPerWarp() const { return threadsPerWarp; } + int getNumCTAs() const { return numCTAs; } + +private: + MLIRContext *context; + int numWarps; + int threadsPerWarp; + int numCTAs; +}; + +class TritonGPUConversionTarget : public ConversionTarget { + +public: + explicit TritonGPUConversionTarget(MLIRContext &ctx, + TritonGPUTypeConverter &typeConverter); +}; + +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ diff --git a/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/Utility.h new file mode 100644 index 000000000..114c18142 --- /dev/null +++ b/third_party/iluvatar/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -0,0 +1,177 @@ +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include +#include + +namespace mlir { + +namespace triton { +class ModuleAxisInfoAnalysis; +class LoadOp; +class StoreOp; +class FuncOp; +namespace gpu { +class SharedEncodingAttr; +} +} // namespace triton + +SmallVector mmaVersionToInstrShape(int version, + const ArrayRef &shape, + TensorOrMemDesc type, + int numWarps); + +/// Returns true if the Load uses block pointer. +bool isLoadFromTensorPtr(triton::LoadOp op); + +// Return an array of indices enumerating the elements of 'arr' in descending +// order (so that result[i] is the index of the i-th largest element of 'arr') +SmallVector argSort(const SmallVector &arr); + +// Return the operand used to access the memory in the operation +Value getMemAccessPtr(Operation *op); + +// Return bitwidth of tensor element +unsigned getElementBitWidth(RankedTensorType type); + +// Calculate the optimal number of elements per thread for a given operation +// along an axis with greatest continuity. +unsigned +getNumElementsPerThread(Operation *op, SmallVector order, + triton::ModuleAxisInfoAnalysis &axisInfoAnalysis); + +/* Dump Triton IR in graphviz dot format. + * + * You can override `onValue` and `onOperation` in a subclass to mark + * specific Values and Operations. The below subclass + * GraphLayoutMarker is an example. + * + * Default NodeInfo for Value nodes: + * {{"shape": "box"}, + * {"style", "filled"}, + * {"fillcolor", "white"}, + * {"label", shapeStr}} + * + * Default NodeInfo for Operation nodes: + * {{"shape": "ellipse"}, + * {"style", "filled"}, + * {"fillcolor", "white"}, + * {"label", operationName}} + * + * If the key "label" is not set by `onValue` or `onOperation`, default labels + * will be generated. For Value node, the default label is the shape string and + * for Operation node, it is the operation name. + * + * Reference: + * https://graphviz.org/doc/info/shapes.html + * https://graphviz.org/doc/info/colors.html + * + * Usage: + * C++: GraphDumper().dumpToFile(func, "func.dot"); + * Shell: dot -Tjpg func.dot -o func.jpg + */ +class GraphDumper { +public: + using NodeInfo = std::map; + + // Override this function to mark specific Values + virtual NodeInfo onValue(Value value) const; + // Override this function to mark specific Operations + virtual NodeInfo onOperation(Operation *op) const; + + std::string dump(triton::FuncOp func) const; + void dumpToFile(triton::FuncOp func, const std::string &filename) const; + +protected: + std::string getShapeStr(const Type &type) const; + + std::string getUniqueId(Value value) const; + std::string getUniqueId(Operation *op) const; + + std::string emitNode(const std::string &id, const NodeInfo style) const; + std::string emitEdge(const std::string &srcId, + const std::string &destId) const; + + std::string emitValueNode(Value value) const; + std::string emitOperationNode(Operation *op) const; +}; + +/* A subclass of GraphDumper that marks different layout kinds in different + * colors.*/ +class GraphLayoutMarker : public GraphDumper { +public: + NodeInfo onValue(Value value) const override; + +protected: + std::string getColor(const Type &type) const; +}; + +// Infers the encoding of the result of op given the source encoding. +std::optional inferDstEncoding(Operation *op, Attribute encoding); + +// Infers the encoding of the source of op given the result encoding. +std::optional inferSrcEncoding(Operation *op, Attribute encoding); + +bool isExpensiveLoadOrStore(Operation *op); + +bool canFoldIntoConversion(Operation *op, Attribute targetEncoding); + +// Replace ForOp with a new ForOp with extra operands. The YieldOp is not +// updated and needs to be updated separately for the loop to be correct. +scf::ForOp replaceForOpWithNewSignature( + RewriterBase &rewriter, scf::ForOp loop, ValueRange newIterOperands, + SmallVectorImpl> &replacements); +scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, + ValueRange newIterOperands); + +// Replace IfOp with a new IfOp with extra results operands. The YieldOp is not +// updated and needs to be updated separately for the bodies to be correct. +scf::IfOp replaceIfOpWithNewSignature( + RewriterBase &rewriter, scf::IfOp loop, TypeRange newResultTypes, + SmallVectorImpl> &replacements); + +Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, + IRMapping &mapping); + +// Get backward slice of tensor values starting from the root node along with +// encoding propagation. +LogicalResult getConvertBackwardSlice( + Value root, SetVector &slice, Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation = nullptr); + +// Populate pattern to remove dead cycles in ForOp. +void populateForOpDeadArgumentElimination(RewritePatternSet &patterns); + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape, + ArrayRef order); + +SmallVector delinearize(OpBuilder &b, Location loc, unsigned linear, + ArrayRef shape); + +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape); +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order); + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape); + +// Return true if the op is a pure elementwise_inline_asm op with a single +// operand and single result. +bool isPureUnaryInlineAsm(Operation *op); + +// read the compute capability from the module attributes +int getNVIDIAComputeCapability(Operation *module); + +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/third_party/iluvatar/include/triton/Target/CMakeLists.txt b/third_party/iluvatar/include/triton/Target/CMakeLists.txt new file mode 100644 index 000000000..39d31dc9b --- /dev/null +++ b/third_party/iluvatar/include/triton/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/third_party/iluvatar/include/triton/Target/LLVMIR/CMakeLists.txt b/third_party/iluvatar/include/triton/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 000000000..1f6c1b351 --- /dev/null +++ b/third_party/iluvatar/include/triton/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name LLVMIR) +add_public_tablegen_target(LLVMIRIncGen) diff --git a/third_party/iluvatar/include/triton/Target/LLVMIR/Passes.h b/third_party/iluvatar/include/triton/Target/LLVMIR/Passes.h new file mode 100644 index 000000000..27ecb5c3d --- /dev/null +++ b/third_party/iluvatar/include/triton/Target/LLVMIR/Passes.h @@ -0,0 +1,17 @@ +#ifndef TRITON_TARGET_LLVM_IR_PASSES_H +#define TRITON_TARGET_LLVM_IR_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +/// Create a pass to add DIScope +std::unique_ptr createLLVMDIScopePass(); + +/// Generate the code for registering conversion passes. +#define GEN_PASS_REGISTRATION +#include "triton/Target/LLVMIR/Passes.h.inc" + +} // namespace mlir + +#endif // TRITON_TARGET_LLVM_IR_PASSES_H diff --git a/third_party/iluvatar/include/triton/Target/LLVMIR/Passes.td b/third_party/iluvatar/include/triton/Target/LLVMIR/Passes.td new file mode 100644 index 000000000..999b0b889 --- /dev/null +++ b/third_party/iluvatar/include/triton/Target/LLVMIR/Passes.td @@ -0,0 +1,15 @@ +#ifndef TRITON_TARGET_LLVMIR_PASSES +#define TRITON_TARGET_LLVMIR_PASSES + +include "mlir/Pass/PassBase.td" + +def LLVMDIScope: Pass<"enable-line-info", "mlir::ModuleOp"> { + let summary = "Materialize LLVM line info"; + let description = [{ + This pass materializes line mapping information for LLVM IR dialect operations. + }]; + + let constructor = "mlir::createLLVMDIScopePass()"; +} + +#endif diff --git a/third_party/iluvatar/include/triton/Tools/LinearLayout.h b/third_party/iluvatar/include/triton/Tools/LinearLayout.h new file mode 100644 index 000000000..fb2680241 --- /dev/null +++ b/third_party/iluvatar/include/triton/Tools/LinearLayout.h @@ -0,0 +1,532 @@ +#ifndef TRITON_TOOLS_LINEARLAYOUT_H +#define TRITON_TOOLS_LINEARLAYOUT_H + +#include +#include +#include +#include +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir::triton { + +// # High-level overview of linear layouts +// +// The idea for linear layouts is due to Adam P. Goucher. +// +// In Triton, a linear layout (LL) is a function that maps from a "hardware +// location" to a "logical tensor index". +// +// For example, suppose we have a 2D tensor T stored in GPU registers. T's +// layout is the function that, given a "hardware location" tuple of (thread-id, +// warp-id), returns an index (x,y) into T. In other words, if L(t,w) = (x,y) +// is our linear layout func, then a register in thread t in warp w contains the +// value T[x,y]. +// +// The key fact about LLs is, the mapping from (t,w) to (x,y) is not arbitrary. +// We only need to specify the value of L(t,w) at certain special points +// (namely, the values L(t,0) and L(0,w) where t and w are powers of 2), and +// from those we can compute all the other values of L. +// +// Here's an example LL where we have 4 warps and 4 threads per warp, and the +// tensor T has shape 4x4. We define the function L by choosing the values of +// L(0,1), L(0,2), L(1,0), and L(2,0). Our choices are shown below. +// +// t/w 0 1 2 3 +// 0 ? (0,1) (0,2) ? +// L(t,w) = 1 (1,1) ? ? ? +// 2 (2,2) ? ? ? +// 3 ? ? ? ? +// +// You only need to specify these four values to define the whole linear layout. +// These special values are called the "basis vectors" or "bases" of the layout. +// We complete the table by xor'ing together the bases, according to the +// following rule. (I write "⊕" for xor.) +// +// L(t1 ⊕ t2, w1 ⊕ w2) = L(t1, w1) ⊕ L(t2, w2) (linearity rule). +// +// The linearity rule plus our four choices allows us to fill in the whole +// table. Here's how we might compute some of the values. +// +// L(0,0) = L(1 ⊕ 1, 0 ⊕ 0) = L(1,0) ⊕ L(1,0) = (1,1) ⊕ (1,1) = (0,0) +// L(0,3) = L(0 ⊕ 0, 2 ⊕ 1) = L(0,2) ⊕ L(0,1) = (0,2) ⊕ (0,1) = (0,3) +// L(3,0) = L(2 ⊕ 1, 0 ⊕ 0) = L(2,0) ⊕ L(1,0) = (2,2) ⊕ (1,1) = (3,3) +// L(3,3) = L(3 ⊕ 0, 0 ⊕ 3) = L(3,0) ⊕ L(0,3) = (3,3) ⊕ (0,3) = (3,0). +// +// (Notice it's a consequence of the linearity rule that L(0,0) = (0,0), no +// matter what values we chose for the table.) +// +// The whole table looks like this. +// +// t/w 0 1 2 3 +// 0 (0,0) (0,1) (0,2) (0,3) +// L(t,w) = 1 (1,1) (1,0) (1,3) (1,2) +// 2 (2,2) (2,3) (2,0) (2,1) +// 3 (3,3) (3,2) (3,1) (3,0). +// +// Careful readers will recognize this as a classic "swizzled" layout where +// (t, w) -> (t, w ⊕ t). To go from this formula to an LL, you only need to +// compute the results at input points (0,1), (0,2), (1,0), and (2,0). + +// Indeed the whole point of LLs is that they allow us to specify transposed and +// swizzled layouts as a "general case". Instead of a layout class for +// registers in a thread, and another layout for registers in a thread but in +// MMAv2 order, and so on, all of these can be represented by different LLs. +// This gets rid of special cases and lets us write more general code. +// +// In this example, L was a 2D -> 2D function, but LLs are general MD -> ND +// functions. In practice, a GPU register layout usually has input dims (reg, +// thread-id, warp-id, block-id), where reg represents the fact that one thread +// may store values for the tensor in multiple registers. +// +// To summarize, a linear layout is a function from tuples of integers to tuples +// of integers. We specify some key values of the function, and then we can +// compute all the other values using the linearity rule. +// +// Here are the key things you can do with linear layout objects. +// +// 1. Given an LL, construct a new LL by modifying it or combining it with +// another LL. +// +// 2. "Apply" an LL, i.e. use it to map an input index to an output index. +// A function for this that uses LLVM-dialect MLIR as its input and output +// lives in TritonGPUToLLVM.h. +// +// 3. Convert an existing Triton layout (e.g. BlockedLayoutAttr) to an LL. +// These functions live in TritonGPU/LinearLayoutConversions.h. During +// TTGIR -> LLVM codegen, we convert Triton layouts to linear layouts and +// then apply them. In the future, we intend to remove the Triton layouts +// entirely. +// +// # Examples of linear layouts +// +// 1. The 1D identity layout. This maps L(x) = x. +// +// Recall that our bases are the values of L(x) where x is a power of two. +// So for e.g. an 8-element layout, we have L(1) = 1, L(2) = 2, L(4) = 4, and +// therefore our bases are [1, 2, 4]. +// +// 2. The 1D zeros layout. This maps L(x) = 0. +// +// For an 8-element layout, we have L(1) = L(2) = L(4) = 0, so our bases are +// [0, 0, 0]. +// +// 3. A 2D -> 2D identity layout. Our basis vectors are the values of L(x,0) +// and L(0,y) where x and y are powers of two. The bases are +// +// - L(0,1) = (0,1) +// - L(0,2) = (0,2) +// - L(1,0) = (1,0) +// - L(2,0) = (2,0). +// +// 4. A 2D -> 2D transpose layout. For a 4x4 layout, we have: +// +// - L(0,1) = (1,0) +// - L(0,2) = (2,0) +// - L(1,0) = (0,1) +// - L(2,0) = (0,2). +// +// 5. A 1D -> 1D "transpose" layout. Consider the 16-element layout that maps +// +// x = 0 1 2 3 4 5 6 7 8 9 A B C D E F +// L(x) = 0 4 8 C 1 5 9 D 2 6 A E 3 7 B F. +// +// The bases are [L(1), L(2), L(4), L(8)] = [4, 8, 1, 2]. You can also think +// of this as a rearrangement of the 1D identity layout [1, 2, 4, 8]. +// +// 6. A 2D -> 1D broadcasted layout. L(x,y) = x. For a 4x4 -> 4 layout, our +// bases are +// +// - L(0,1) = 0 +// - L(0,2) = 0 +// - L(1,0) = 1 +// - L(2,0) = 2. +// +// # Implementation notes +// +// ## Dimension order +// +// An LL's input and output dimensions have an order. This order only affects +// the reshapeIns/Outs operations, where the layout is logically flattened +// according to the dimension order and then chopped up again. +// +// ## Surjectivity +// +// We require that all output values are covered by some input value, i.e. the +// function L is surjective. But multiple input values can map to the same +// output value. This represents the idea that the same logical tensor element +// can be stored in multiple places in the hardware. +// +// ## Why map hardware loc -> tensor index and not the other way around? +// +// In Triton, a linear layout usually tells us which logical tensor value is +// stored at a particular place in the hardware. For example, an LL might map +// the tuple (thread-id, warp-id, block-id) to a 2D index into a tensor, (x,y), +// meaning that the register at (t,w,b) has value tensor[x,y]. Or it might map +// from a shared memory (offset, block) to a tensor index. +// +// It might seem more natural to go the other way around, from tensor index to +// place in the hardware. But a particular tensor[x,y] value might be stored in +// more than one place in the hardware, so if we went in this direction, the +// layout would no longer be a proper function. This would complicate +// everything else. +// +// # Optional mathematical background: Linear functions over GF(2) +// +// (You shouldn't need to understand this math to use linear layouts, but it +// helps with the implementation.) +// +// One way to define a linear function is to say it's any function F that can be +// written as +// +// L(a) = a1 * B1 + a2 * B2 + ... + aM * BM, +// +// where +// +// - a is a vector [a1...aM], and ai is a scalar in some field 𝔽 (for +// example, ai might be a real number), and +// - each Bj is a vector [b1j, b1j, ..., bNj] of N scalars in 𝔽. +// +// We can also write this as a matrix-vector product Ba, where +// +// - a is the column vector [a1, ..., aM] and +// +// - B is the matrix formed by concatenating the column vectors B1, ..., BM: +// +// | ↑ ↑ ↑ | +// B = | B1, B2, ..., BM| +// | ↓ ↓ ↓ | +// +// |b11, b12, ..., b1M| +// |b21, b22, ..., b2M| +// = | ↓ ↓ ↓ | +// |bN1, bN2, ..., bNM|. +// +// Usually when we do linear algebra, the field 𝔽 from which `ai` and `bij` are +// drawn is the real or complex numbers. But in linear layouts, we let 𝔽 be a +// different field: GF(2). +// +// GF(2) is the two-element field of bits. To define a field, I need to give +// you the set of elements and also addition and multiplication operations. For +// GF(2) the elements are simply {0,1}. We define addition as xor, and +// multiplication as binary `and`. +// +// Here's an example of a 4x4 matrix-vector multiply where the elements are in +// GF(2). I'm using ⊕ to represent GF(2)'s addition operation (i.e xor) and × +// to represent multiplication (i.e. binary `and`). +// +// | 1 0 0 0 | | 0 | | 1 | | 0 | | 0 | | 0 | +// | 0 1 1 0 | | 1 | = | 0 | × 0 ⊕ | 1 | × 1 ⊕ | 1 | × 1 ⊕ | 0 | × 0 +// | 0 0 1 1 | | 1 | | 0 | | 0 | | 1 | | 1 | +// | 0 0 1 1 | | 0 | | 0 | | 0 | | 1 | | 1 | +// +// | 0 | | 0 | +// = | 1 | ⊕ | 1 | +// | 0 | | 1 | +// | 0 | | 1 | +// +// | 0 | +// = | 0 |. +// | 1 | +// | 1 | +// +// This works, but it's cumbersome. It's more compact to think of the vector +// `a` as an M-bit integer, and each column Bi of the matrix B as an N-bit +// integer. Here's the same matrix-vector product written this way. +// +// = | 1 2 14 12 | × 6 +// = | 1 2 14 12 | × 0b0110 +// = (1 × 0) ⊕ (2 × 1) ⊕ (14 × 1) ⊕ (12 × 0) +// = 2 ⊕ 14 +// = 12. +// +// And we confirm that our answer of 12 is equal to the binary value 0b1100 we +// got before. +// +// Notice that the function F(a) is fully specified by the matrix B, and that +// the four columns of B tell us the values of F at power-of-two values for `a`, +// namely F(1), F(2), F(4), and F(8). In other words, we specify four results +// of F(x) (we call these the function's "basis vectors" or its "bases") and we +// can then compute any other value by xor'ing together subsets of the bases. +// +// In the case of a 1D -> 1D layout, the implementation of an LL is +// straightforward from the mathematical description. If the LL is +// higher-dimensional, we can "stack" the bit vectors to create 1D vectors. +// For example, if we have a 2D LL and we're given input tuple (0b0011, 0b1100), +// we can treat this like a 1D input 0b0011'1100 and then do the regular 1D LL +// computation. Similarly we can "unstack" the output from 1D to ND. +// +// The linearity rule presented earlier is perhaps misleading at this point. In +// the 1D view of things, we really only need +// +// L(x ⊕ y) = L(x) ⊕ L(y) (1D linearity rule), +// +// which is part of the definition of L being a linear function. The new 1D +// linearity rule plus stacking/unstacking is equivalent to the earlier +// N-dimensional linearity rule. +// +// That's all we need in order to define linear layouts mathematically! +// +// # Comaprison to Nvidia CuTe +// +// (Note, I'm not an expert on CuTe; this is my best understanding.) +// +// CuTe is a programmatic layout system that's part of Nvidia CUTLASS; see +// https://github.com/NVIDIA/cutlass/blob/629f465/media/docs/cute/00_quickstart.md +// +// LLs and CuTe solve similar problems. Before CuTe, CUTLASS v2 had many +// handcrafted layouts, "RowMajor", "VoltaTensorOpMultiplicandCongruous", etc, +// see https://www.youtube.com/watch?v=QLdUML5MCfE&t=574s. Each of these was a +// special case. CUTLASS v3 introduced CuTe layouts, which are programmable and +// subsume all of these special cases. The CUTLASS folks say this simplified +// CUTLASS, in the same way that we hope LLs will simplify Triton. +// +// Like CuTe layouts, LLs are also programmable and composible. But there are +// also some differences. +// +// - Dimensions in LLs are named; CuTe dimensions are numbered. +// - CuTe layouts can be nested; LLs cannot be. (Nesting doesn't give CuTe +// layouts additional power; any nested layout can be flattened.) +// - CuTe layouts support non-power-of-two shapes; LLs do not. In particular +// this means that LLs cannot represent padded layouts. +// - In CuTe, swizzling is a separate step applied after specifying a layout. +// In LLs, swizzling is part of the layout itself. +// - The structure of LLs allows us to programmatically search for layouts that +// satisfy certain requirements, for example a shared layout that doesn't +// have bank conflicts when read into a particular register layout. CuTe +// expects a human to choose the layout using their brain. +// - CuTe emits code that is in the critical path of your CPU and GPU programs, +// therefore it needs to be fast. It uses C++ template magic to specialize +// on known-sized dimensions, and so on. LLs themselves do not need to be +// fast; only the emitted `apply` code is on the critical path. +// - CuTe requires a CUDA compiler such as nvcc; LLs do not. +// +class LinearLayout { +private: + // bases[inDim][i] = L(0, ..., inDim=2^i, ..., 0). All other values of L are + // computed by xor'ing bases together, using the linearity rule. In addition: + // + // - Each inDim has the same set of outDims, in the same order. + // - The order of dims is minor-to-major, although this only affects reshape. + llvm::MapVector /*size=getNumOutDims()*/> + /*size=getInDimSizeLog2(inDim)*/> + bases; + + llvm::SetVector outDimNames; + +public: + using BasesT = decltype(bases); + + // The 0-dimensional layout that maps everything to 0. This is useful as a + // starting point when doing something like + // + // LinearLayout ret = LinearLayout::empty(); + // for (...) ret *= ...; + // return ret; + static LinearLayout empty() { return LinearLayout(BasesT{}, {}); } + + // Creates a 1D -> 1D layout that's the identity function, i.e. L(x) = x + // for x in [0, size). + static LinearLayout identity1D(int32_t size, StringAttr inDim, + StringAttr outDim); + + // Creates a 1D -> 1D layout that maps every input value to 0, i.e. L(x) = 0 + // for x in [0, size). + static LinearLayout zeros1D(int32_t size, StringAttr inDim, + StringAttr outDim); + + // Creates a LinearLayout from a list of bases. These are interpreted + // according to the rules written for the member variable `bases`. + explicit LinearLayout(BasesT bases, ArrayRef outDimNames); + + // Construct a LinearLayout from an explicit list of bases. (This constructor + // is needed because llvm::MapVector does not have a constructor that accepts + // an initializer_list.) + // + // For example, given these bases + // + // L(in1=1, in2=0) = (out1=0, out2=1) + // L(in1=2, in2=0) = (out1=0, out2=2) + // L(in1=0, in2=1) = (out1=0, out2=4) + // L(in1=0, in2=2) = (out1=0, out2=8) + // L(in1=0, in2=4) = (out1=1, out2=1) + // + // we can use this constructor to build an equivalent LL: + // + // LinearLayout({ + // {"in1", {/*L(in1=1)=*/{0,1}, /*L(in1=2)=*/{0,2}}}, + // {"in2", {/*L(in2=1)=*/{0,4}, /*L(in2=2)=*/{0,8}, /*L(in2=4)=*/{1,1}}}, + // }, + // {"out1", "out2"}) + explicit LinearLayout( + ArrayRef>>> bases, + ArrayRef outDimNames); + + const BasesT &getBases() const { return bases; } + + // Get the pos'th basis vector for the inDim -> outDim mapping. + // getBasis(inDim, pos) = L(0, ..., inDim = 2^pos, ..., 0). + ArrayRef getBasis(StringAttr inDim, int32_t pos) const { + auto it = bases.find(inDim); + assert(it != bases.end()); + assert(pos < it->second.size()); + return it->second[pos]; + } + + int32_t getBasis(StringAttr inDim, int32_t pos, StringAttr outDim) const { + return getBasis(inDim, pos)[getOutDimIndex(outDim)]; + ; + } + + // These are in minor-to-major order, although if you don't flatten the dims + // (e.g. by reshaping) then the order doesn't really affect anything. + auto getInDimNames() const { return llvm::make_first_range(bases); } + ArrayRef getOutDimNames() const { + return outDimNames.getArrayRef(); + } + + // Gets the position that this outDim occupies in getOutDimNames(). Asserts + // if the dim is not present. + int32_t getOutDimIndex(StringAttr outDim) const; + + bool hasInDim(StringAttr inDim) const { return bases.contains(inDim); } + bool hasOutDim(StringAttr outDim) const { + return outDimNames.contains(outDim); + } + + int32_t getNumInDims() const { return bases.size(); } + int32_t getNumOutDims() const { return outDimNames.size(); } + + // Asserts if the dimension is not present. + int32_t getInDimSizeLog2(StringAttr inDim) const; + int32_t getInDimSize(StringAttr inDim) const { + return 1 << getInDimSizeLog2(inDim); + } + + // getOutDimSize(dim) == s means that there exists an input value that will + // produce each output value in [0,s). + // + // For example, if our bases are + // + // L(in0=1) = 1 + // L(in0=2) = 4 + // L(in1=1) = 2 + // L(in1=2) = 8 + // + // then the largest value we can produce is L(3,3) = 1 ⊕ 4 ⊕ 2 ⊕ 8 = 15 (and + // indeed we can produce all values in [0,16) by xor'ing subsets of the bases + // 1,2,4,8), so getOutDimSize(out_dim0) == 16. + // + // Asserts if the dimension is not present. + int32_t getOutDimSizeLog2(StringAttr outDim) const; + int32_t getOutDimSize(StringAttr outDim) const { + return 1 << getOutDimSizeLog2(outDim); + } + + // Reorders the in/out dimensions of the layout. This is mostly cosmetic + // (affecting e.g. the order of getIn/OutDimNames), but it also affects the + // behavior of reshape. + [[nodiscard]] LinearLayout + transposeIns(ArrayRef newInDimOrder) const; + [[nodiscard]] LinearLayout + transposeOuts(ArrayRef newOutDimOrder) const; + + // Creates a new layout which, roughly speaking, is equivalent to one where + // every element of the `outer` layout is replaced by a full instance of the + // `inner` layout. + // + // Examples: + // + // - empty() is the multiplicative identity: + // + // L * empty() == empty() * L == L. + // + // - Multiplying two identity1D layouts with disjoint in/out dimensions gives + // a 2D identity layout: + // + // identity1D(4, "i1", "o1") * identity1D(8, "i2", "o2") => + // L(i1,i2) = (i1,i2), + // + // with in-dims ("i1", "i2") and out-dims ("o1", "o2"), in that order. + // + // - If out-dims overlap, they are combined, as in the following examples. + // + // - identity1D(4, "i", "o") * identity1D(2, "i", "o") == + // identity1D(8, "i", "o") + // + // - identity1D(4, "i", "o") * zeros1D(2, "i", "o") => L(x) = x % 4 + // for x in [0,8). + // + // - zeros1D(2, "i", "o") * identity1D(4, "i", "o") => L(x) = x / 2 + // for x in [0,8). + // + // - identity1D(4, "i", "o1") * identity1D(8, "i", "o2") => + // L(x) = (x % 4, x / 4) for x in [0,32). + // + // Notice that this operation is not commutative. It's also not associative. + // TODO(jlebar): Can I modify the definition to make it associative? Pretty + // confusing if not. If I can't, add an example. + // + // Requires: Any in/out dimensions which are in both outer and inner appear in + // the same relative order. + friend LinearLayout operator*(LinearLayout inner, LinearLayout outer); + LinearLayout &operator*=(LinearLayout outer) { + *this = *this * outer; + return *this; + } + + // Computes and returns L(x, y, z). + // + // If you want to apply the layout to mlir Values instead of integers, that + // function lives in TritonGPUToLLVM/Utility.h. + SmallVector> + apply(ArrayRef> ins) const; + + // Creates a new layout which is equivalent to running this layout, then + // running `outer`. That is, + // + // - let this layout be L(x), and + // - let `outer` be O(x). + // - Then compose(outer) returns the layout (O∘L)(x), aka O(L(x)). + // + // Requires: The output dimensions of this layout equal the input dimensions + // of outer (order doesn't matter). + [[nodiscard]] LinearLayout compose(const LinearLayout &outer) const; + + // TODO(jlebar): Not yet implemented. + // [[nodiscard]] LinearLayout reshapeIns( + // std::vector> + // newInDims) const; + + // TODO(jlebar): Not yet implemented. + // [[nodiscard]] LinearLayout reshapeOuts( + // std::vector> + // newOutDims) const; + + std::string toString() const; + + friend bool operator==(LinearLayout lhs, LinearLayout rhs); + friend bool operator!=(LinearLayout lhs, LinearLayout rhs) { + return !(lhs == rhs); + } +}; + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const LinearLayout &layout) { + os << layout.toString(); + return os; +} + +inline std::ostream &operator<<(std::ostream &os, const LinearLayout &layout) { + os << layout.toString(); + return os; +} + +} // namespace mlir::triton + +#endif diff --git a/third_party/iluvatar/include/triton/Tools/StrUtil.h b/third_party/iluvatar/include/triton/Tools/StrUtil.h new file mode 100644 index 000000000..8b59f7d2b --- /dev/null +++ b/third_party/iluvatar/include/triton/Tools/StrUtil.h @@ -0,0 +1,54 @@ +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir::triton { + +// Better version of llvm::join. This one works when T is an integer or any +// other type which defines operator<<(raw_ostream). +template +std::string join(C &&container, llvm::StringRef sep = ", ") { + std::string ret; + llvm::raw_string_ostream s(ret); + for (const auto &elem : container) { + if (!ret.empty()) + s << sep; + s << elem; + } + return ret; +} + +// Joins a container of elements into a string, using `sep` as a separator. +// +// fn is called to transform each element of the container before it's added to +// the string. fn must have one of the following two signatures. +// +// - void fn(llvm::raw_ostream&, E), where E is the element type of the +// container, or +// - T fn(E), where T is a type which can be passed to +// raw_ostream::operator<<. +// +template +std::string join(C &&container, llvm::StringRef sep, Fn &&fn) { + std::string ret; + llvm::raw_string_ostream s(ret); + for (const auto &elem : container) { + if (!ret.empty()) + s << sep; + + if constexpr (std::is_invocable_v) { + static_assert( + std::is_void_v< + std::invoke_result_t>); + fn(s, elem); + } else { + s << fn(elem); + } + } + return ret; +} + +} // namespace mlir::triton diff --git a/third_party/iluvatar/include/triton/Tools/Sys/GetEnv.hpp b/third_party/iluvatar/include/triton/Tools/Sys/GetEnv.hpp new file mode 100644 index 000000000..e02426f95 --- /dev/null +++ b/third_party/iluvatar/include/triton/Tools/Sys/GetEnv.hpp @@ -0,0 +1,130 @@ +#ifndef TRITON_TOOLS_SYS_GETENV_HPP +#define TRITON_TOOLS_SYS_GETENV_HPP + +#include +#include +#include +#include +#include +#include + +#ifdef __ILUVATAR__ +#include +#include +namespace fs = std::filesystem; +#endif + +namespace mlir::triton { + +inline const std::set CACHE_INVALIDATING_ENV_VARS = { + // clang-format off + "AMDGCN_ENABLE_DUMP", + "DISABLE_FAST_REDUCTION", + "DISABLE_LLVM_OPT", + "DISABLE_MMA_V3", + "DISABLE_PTXAS_OPT", + "LLVM_IR_ENABLE_DUMP", + "LLVM_ENABLE_TIMING", + "MLIR_ENABLE_DIAGNOSTICS", + "MLIR_ENABLE_DUMP", + "MLIR_ENABLE_TIMING", + "TRITON_DISABLE_LINE_INFO", + "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE", + "TRITON_ENABLE_LLVM_DEBUG", + "TRITON_LLVM_DEBUG_ONLY", + "USE_TTGIR_LOC", + "NVPTX_ENABLE_DUMP", +#ifdef __ILUVATAR__ + "ILUIR_ENABLE_DUMP", +#endif + // clang-format on +}; + +inline const std::set CACHE_NEUTRAL_ENV_VARS = { + "TRITON_REPRODUCER_PATH", +}; + +namespace tools { + +inline void assertIsRecognized(const std::string &env) { + bool is_invalidating = CACHE_INVALIDATING_ENV_VARS.find(env.c_str()) != + CACHE_INVALIDATING_ENV_VARS.end(); + bool is_neutral = + CACHE_NEUTRAL_ENV_VARS.find(env.c_str()) != CACHE_NEUTRAL_ENV_VARS.end(); + std::string errmsg = env + "is not recognized. " + "Please add it to triton/tools/sys/getenv.hpp"; + assert((is_invalidating || is_neutral) && errmsg.c_str()); +} + +inline std::string getStrEnv(const std::string &env) { + assertIsRecognized(env); + const char *cstr = std::getenv(env.c_str()); + if (!cstr) + return ""; + std::string result(cstr); + return result; +} + +// return value of a cache-invalidating boolean environment variable +inline bool getBoolEnv(const std::string &env) { + assertIsRecognized(env); + const char *s = std::getenv(env.c_str()); + std::string str(s ? s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return str == "on" || str == "true" || str == "1"; +} + +inline std::optional isEnvValueBool(std::string str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + if (str == "on" || str == "true" || str == "1") + return true; + if (str == "off" || str == "false" || str == "0") + return false; + return std::nullopt; +} + +#ifdef __ILUVATAR__ +static fs::path &getCudaPath(void) { + static fs::path cuda_path = [] { + void *handle = dlopen("libnvrtc.so", RTLD_LAZY); + if (!handle) { + std::fprintf(stderr, "%s\n", dlerror()); + exit(EXIT_FAILURE); + } + void *pfunc = dlsym(handle, "nvrtcCompileProgram"); + Dl_info info; + if (dladdr(pfunc, &info) == 0) { + std::fprintf(stderr, "Failed to get symbol information: %s\n", dlerror()); + exit(EXIT_FAILURE); + } + return fs::path(info.dli_fname).parent_path().parent_path(); + }(); + return cuda_path; +} + +static fs::path &getLinkerPath(void) { + static fs::path linker_path = [] { + fs::path cuda_path = getCudaPath(); + fs::path linker_path1 = cuda_path / "bin/ld.lld"; + fs::path linker_path2 = cuda_path / "../bin/ld.lld"; + if (!fs::exists(linker_path1)) { + if (fs::exists(linker_path2)) { + linker_path1 = linker_path2; + } else { + fprintf(stderr, "iluvatar linker not found in %s and %s\n", + linker_path1.c_str(), linker_path2.c_str()); + exit(EXIT_FAILURE); + } + } + return linker_path1; + }(); + return linker_path; +} +#endif + +} // namespace tools +} // namespace mlir::triton + +#endif diff --git a/third_party/iluvatar/lib/Analysis/Alias.cpp b/third_party/iluvatar/lib/Analysis/Alias.cpp new file mode 100644 index 000000000..dde554319 --- /dev/null +++ b/third_party/iluvatar/lib/Analysis/Alias.cpp @@ -0,0 +1,64 @@ +#include "triton/Analysis/Alias.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir { + +AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) { + if (lhs == rhs) + return lhs; + AliasInfo ret; + for (auto value : lhs.allocs) { + ret.insert(value); + } + for (auto value : rhs.allocs) { + ret.insert(value); + } + return ret; +} + +void SharedMemoryAliasAnalysis::visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { + AliasInfo aliasInfo; + bool pessimistic = true; + // These ops may allocate a new shared memory buffer. + auto result = op->getResult(0); + + // Only LocalAllocOp creates a new buffer. + if (isa(op)) { + aliasInfo.insert(result); + pessimistic = false; + } else if (isa(op)) { + // extract_slice %src + // trans %src + aliasInfo = AliasInfo(operands[0]->getValue()); + pessimistic = false; + } else { + assert(!isa(result.getType()) && + "unknown operation creating memory descriptor"); + } + + if (pessimistic) { + return setAllToEntryStates(results); + } + // Join all lattice elements + for (auto *result : results) + propagateIfChanged(result, result->join(aliasInfo)); +} + +AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) { + // TODO: implement + return AliasResult::MayAlias; +} + +ModRefResult SharedMemoryAliasAnalysis::getModRef(Operation *op, + Value location) { + // TODO: implement + return ModRefResult::getModAndRef(); +} + +} // namespace mlir diff --git a/third_party/iluvatar/lib/Analysis/Allocation.cpp b/third_party/iluvatar/lib/Analysis/Allocation.cpp new file mode 100644 index 000000000..b3cdc2da0 --- /dev/null +++ b/third_party/iluvatar/lib/Analysis/Allocation.cpp @@ -0,0 +1,678 @@ +#include "triton/Analysis/Allocation.h" + +#include +#include +#include + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Alias.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" + +using ::mlir::triton::gpu::AMDMfmaEncodingAttr; +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getContigPerThread; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getShapePerCTATile; +using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::getUniqueContigPerThread; +using ::mlir::triton::gpu::IluvatarMmaEncodingAttr; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::SharedEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; + +namespace mlir { + +//===----------------------------------------------------------------------===// +// Shared Memory Allocation Analysis +//===----------------------------------------------------------------------===// +namespace triton { + +// Bitwidth of pointers +constexpr int kPtrBitWidth = 64; + +static std::pair, SmallVector> +getCvtOrder(Attribute srcLayout, Attribute dstLayout) { + // REBASE TODO: add IluvatarMmaEncodingAttr case? + auto srcMmaLayout = mlir::dyn_cast(srcLayout); + auto srcDotLayout = mlir::dyn_cast(srcLayout); + auto dstMmaLayout = mlir::dyn_cast(dstLayout); + auto dstDotLayout = mlir::dyn_cast(dstLayout); + + assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) && + "mma -> mma layout conversion is only supported on Ampere"); + + // mma or dot layout does not have an order, so the order depends on the + // layout of the other operand. + auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout) + : getOrder(srcLayout); + auto outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout) + : getOrder(dstLayout); + + return {inOrd, outOrd}; +} + +SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) { + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + if (shouldUseDistSmem(srcLayout, dstLayout)) { + // TODO: padding to avoid bank conflicts + return convertType(getShapePerCTA(srcTy)); + } + + if (isMfmaToDotShortcut(srcTy, dstTy)) + return {}; + + // MmaToDotShortcut and MmaToMmaShortcut doesn't use shared mem + if (auto srcMmaLayout = mlir::dyn_cast(srcLayout)) { + if (mlir::isa(dstLayout)) { + if (isMmaToDotShortcut(srcTy, dstTy)) { + return {}; + } + } else if (auto dstMmaLayout = + mlir::dyn_cast(dstLayout)) { + if (isMmaToMmaShortcut(srcTy, dstTy)) { + return {}; + } + } + } + if (auto srcMmaLayout = mlir::dyn_cast(srcLayout)) { + if (mlir::isa(dstLayout)) { + if (isMmaToDotShortcut(srcTy, dstTy)) { + return {}; + } else if (isMmaToDotSlowShortcut(srcTy, dstTy)) { + return getShapePerCTATile(srcMmaLayout); + } + } else if (auto dstMmaLayout = + mlir::dyn_cast(dstLayout)) { + if (isMmaToMmaShortcut(srcTy, dstTy)) { + return {}; + } + } + } + + if (auto srcSliceLayout = srcLayout.dyn_cast()) { + if (auto dstSliceLayout = dstLayout.dyn_cast()) { + if (srcSliceLayout.getParent().isa() && + dstSliceLayout.getParent().isa()) { + return {}; + } + } + } + + assert(srcLayout && dstLayout && "Unexpected layout in getRepShape()"); + + auto srcShapePerCTA = getShapePerCTA(srcTy); + auto dstShapePerCTA = getShapePerCTA(dstTy); + auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); + auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape()); + + unsigned rank = dstTy.getRank(); + SmallVector repShape(rank); + for (unsigned d = 0; d < rank; ++d) { + repShape[d] = + std::max(std::min(srcShapePerCTA[d], srcShapePerCTATile[d]), + std::min(dstShapePerCTA[d], dstShapePerCTATile[d])); + } + return repShape; +} + +SmallVector +getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, + unsigned &outVec) { + auto repShape = getRepShapeForCvtLayout(op); + if (repShape.empty()) + return repShape; + auto rank = repShape.size(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + assert(!isMfmaToDotShortcut(srcTy, dstTy)); + if (isMmaToDotSlowShortcut(srcTy, dstTy)) + return repShape; + + auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout); + unsigned srcContigPerThread = + getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]]; + unsigned dstContigPerThread = + getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]]; + // TODO: Fix the legacy issue that ourOrd[0] == 0 always means + // that we cannot do vectorization. + unsigned innerDim = rank - 1; + inVec = outOrd[0] != innerDim ? 1 + : inOrd[0] != innerDim ? 1 + : srcContigPerThread; + outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread; + + // For conversions to MmaV1 (Nvidia V100), this inVec is hardcoded in the + // codegen. + if (auto mma = mlir::dyn_cast(srcLayout)) { + if (mma.getVersionMajor() == 1) { + inVec = srcContigPerThread; + } else if (mlir::isa(dstLayout)) { + // when storing from mma layout and loading in blocked layout vectorizing + // the load back gives better performance even if there is a + // transposition. + outVec = dstContigPerThread; + } + } + + if (rank <= 1) + return repShape; + // pad the last dimension + unsigned paddedDim = rank - 1; + if (auto dstBlockedLayout = mlir::dyn_cast(dstLayout)) { + paddedDim = dstBlockedLayout.getOrder()[0]; + } + unsigned pad = std::max(inVec, outVec); + if (mlir::dyn_cast(srcLayout) && + mlir::isa(dstLayout)) { + pad = 16; + } + repShape[paddedDim] += pad; + return repShape; +} + +// TODO: extend beyond scalars +SmallVector getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) { + SmallVector smemShape; + if (isa(op.getPtr().getType())) { + // do nothing or just assert because shared memory is not used in tensor up + // to now + } else { + // need only bytes for scalar + // always vec = 1 and elemsPerThread = 1 for scalar? + smemShape.push_back(1); + } + return smemShape; +} + +SmallVector getScratchConfigForAtomicCAS(triton::AtomicCASOp op) { + return SmallVector{1}; +} + +class AllocationAnalysis { +public: + AllocationAnalysis(Operation *operation, + Allocation::FuncAllocMapT *funcAllocMap, + Allocation *allocation) + : operation(operation), funcAllocMap(funcAllocMap), + allocation(allocation) { + run(); + } + +private: + using BufferT = Allocation::BufferT; + + /// Value -> Liveness Range + /// Use MapVector to ensure determinism. + using BufferRangeMapT = llvm::MapVector>; + /// Nodes -> Nodes + using GraphT = DenseMap>; + + void run() { + getValuesAndSizes(); + resolveLiveness(); + computeOffsets(); + } + + /// Initializes explicitly defined shared memory values for a given operation. + void getExplicitValueSize(Operation *op) { + // Values returned from scf.yield will not be allocated even though they + // have the shared encoding. + // For example: %a = scf.if -> yield + // %a must be allocated elsewhere by other operations. + // FIXME(Keren): extract and insert are always alias for now + if (!maybeSharedAllocationOp(op)) + return; + + // XXX(Keren): Why this hard-coded alignment? + size_t kAlignment = 8; + for (Value result : op->getResults()) { + if (auto alloc = result.getDefiningOp()) { + // Bytes could be a different value once we support padding or other + // allocation policies. + auto allocType = alloc.getType(); + auto shapePerCTA = triton::gpu::getShapePerCTA(allocType); + auto bytes = product(shapePerCTA) * + allocType.getElementTypeBitWidth() / 8; + + // XXX(Keren): magic numbers 256 and 1024 + // benzh@maybe alignment should be passed in. + // Software swizzling calculates phase based on offset, while hardware + // swizzling do that based on physical address. Thus only by setting the + // alignment to 1024 can ensure the correctness.  + if (bytes > 256) + kAlignment = 1024; + allocation->addBuffer(result, bytes, + kAlignment); + } + } + } + + template + void maybeAddScratchBuffer(Operation *op, unsigned bytes, + unsigned alignment) { + if (bytes > 0) + allocation->addBuffer(op, bytes, alignment); + } + + template + void maybeAddScratchBuffer(Operation *op, unsigned bytes) { + if (bytes > 0) + allocation->addBuffer(op, bytes); + } + + /// Initializes temporary shared memory for a given operation. + void getScratchValueSize(Operation *op) { + const size_t scratchAlignment = 128; + if (auto reduceOp = dyn_cast(op)) { + ReduceOpHelper helper(reduceOp); + unsigned bytes = helper.getScratchSizeInBytes(); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } else if (auto scanOp = dyn_cast(op)) { + ScanLoweringHelper helper(scanOp); + unsigned bytes = helper.getScratchSizeInBytes(); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } else if (auto histogram = dyn_cast(op)) { + auto dstTy = histogram.getType(); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp( + op->getParentOfType()); + auto bytes = std::max(dstTy.getNumElements(), threadsPerWarp) * + std::max(8, dstTy.getElementTypeBitWidth()) / 8; + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } else if (auto cvtLayout = dyn_cast(op)) { + auto srcTy = cvtLayout.getSrc().getType(); + auto dstTy = cvtLayout.getType(); + auto srcEncoding = srcTy.getEncoding(); + auto dstEncoding = dstTy.getEncoding(); + if (mlir::isa(srcEncoding) || + mlir::isa(dstEncoding)) { + // Conversions from/to shared memory do not need scratch memory. + return; + } + // ConvertLayoutOp with both input/output non-shared_layout + // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's + // also possible to realize it with other approaches in restricted + // conditions, such as warp-shuffle + unsigned inVec = 0; + unsigned outVec = 0; + auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec); + unsigned elems = 0; + if (!smemShape.empty()) + elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto bytes = + isa(srcTy.getElementType()) + ? elems * kPtrBitWidth / 8 + : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } else if (auto atomicRMWOp = dyn_cast(op)) { + auto value = op->getOperand(0); + // only scalar requires scratch memory + // make it explicit for readability + if (dyn_cast(value.getType())) { + // nothing to do + } else { + auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto elemTy = + cast(value.getType()).getPointeeType(); + auto bytes = + isa(elemTy) + ? elems * kPtrBitWidth / 8 + : elems * std::max(8, elemTy.getIntOrFloatBitWidth()) / 8; + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } + } else if (auto atomicCASOp = dyn_cast(op)) { + // only scalar requires scratch memory + // make it explicit for readability + auto value = op->getOperand(0); + if (dyn_cast(value.getType())) { + // nothing to do + } else { + auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto elemTy = + cast(value.getType()).getPointeeType(); + auto bytes = isa(elemTy) + ? elems * kPtrBitWidth / 8 + : elems * elemTy.getIntOrFloatBitWidth() / 8; + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } + } else if (auto callOp = dyn_cast(op)) { + auto callable = callOp.resolveCallable(); + auto funcOp = dyn_cast(callable); + auto *funcAlloc = &(*funcAllocMap)[funcOp]; + auto bytes = funcAlloc->getSharedMemorySize(); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } + } + + void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) { + dataflow::Lattice *latticeElement = + analysis.getLatticeElement(value); + if (latticeElement) { + AliasInfo &info = latticeElement->getValue(); + if (!info.getAllocs().empty()) { + for (auto alloc : info.getAllocs()) { + allocation->addAlias(value, alloc); + } + } + } + } + + /// Extract all shared memory values and their sizes + void getValuesAndSizes() { + // Get the alloc values + operation->walk([&](Operation *op) { + getExplicitValueSize(op); + getScratchValueSize(op); + }); + // Get the alias values + std::unique_ptr solver = createDataFlowSolver(); + SharedMemoryAliasAnalysis *aliasAnalysis = + solver->load(); + if (failed(solver->initializeAndRun(operation))) { + // TODO: return error instead of bailing out.. + llvm_unreachable("failed to run SharedMemoryAliasAnalysis"); + } + operation->walk([&](Operation *op) { + for (auto operand : op->getOperands()) { + getValueAlias(operand, *aliasAnalysis); + } + for (auto value : op->getResults()) { + getValueAlias(value, *aliasAnalysis); + } + }); + } + + /// Computes the liveness range of the allocated value. + /// Each buffer is allocated only once. + void resolveExplicitBufferLiveness( + function_ref(Value value)> getLiveness) { + for (auto valueBufferIter : allocation->valueBuffer) { + auto value = valueBufferIter.first; + auto *buffer = valueBufferIter.second; + bufferRange[buffer] = getLiveness(value); + } + } + + /// Extends the liveness range by unionizing the liveness range of the aliased + /// values because each allocated buffer could be an alias of others, if block + /// arguments are involved. + void resolveAliasBufferLiveness( + function_ref(Value value)> getLiveness) { + for (auto aliasBufferIter : allocation->aliasBuffer) { + auto value = aliasBufferIter.first; + auto buffers = aliasBufferIter.second; + auto range = getLiveness(value); + for (auto *buffer : buffers) { + auto minId = range.start(); + auto maxId = range.end(); + if (bufferRange.count(buffer)) { + // Extend the allocated buffer's range + minId = std::min(minId, bufferRange[buffer].start()); + maxId = std::max(maxId, bufferRange[buffer].end()); + } + bufferRange[buffer] = Interval(minId, maxId); + } + } + } + + /// Computes the liveness range of scratched buffers. + /// Some operations may have a temporary buffer that is not explicitly + /// allocated, but is used to store intermediate results. + void resolveScratchBufferLiveness( + const DenseMap &operationId) { + // Analyze liveness of scratch buffers and virtual buffers. + auto processScratchMemory = [&](const auto &container) { + for (auto opScratchIter : container) { + // Any scratch memory's live range is the current operation's live + // range. + auto *op = opScratchIter.first; + auto *buffer = opScratchIter.second; + bufferRange.insert({buffer, Interval(operationId.lookup(op), + operationId.lookup(op) + 1)}); + } + }; + processScratchMemory(allocation->opScratch); + processScratchMemory(allocation->opVirtual); + } + + /// Resolves liveness of all values involved under the root operation. + void resolveLiveness() { + // Assign an ID to each operation using post-order traversal. + // To achieve the correct liveness range, the parent operation's ID + // should be greater than each of its child operation's ID . + // Example: + // ... + // %5 = triton.convert_layout %4 + // %6 = scf.for ... iter_args(%arg0 = %0) -> (i32) { + // %2 = triton.convert_layout %5 + // ... + // scf.yield %arg0 + // } + // For example, %5 is defined in the parent region and used in + // the child region, and is not passed as a block argument. + // %6 should should have an ID greater than its child operations, + // otherwise %5 liveness range ends before the child operation's liveness + // range ends. + DenseMap operationId; + operation->walk( + [&](Operation *op) { operationId[op] = operationId.size(); }); + + // Analyze liveness of explicit buffers + Liveness liveness(operation); + auto getValueLivenessRange = [&](Value value) { + auto liveOperations = liveness.resolveLiveness(value); + auto minId = std::numeric_limits::max(); + auto maxId = std::numeric_limits::min(); + std::for_each(liveOperations.begin(), liveOperations.end(), + [&](Operation *liveOp) { + if (operationId[liveOp] < minId) { + minId = operationId[liveOp]; + } + if ((operationId[liveOp] + 1) > maxId) { + maxId = operationId[liveOp] + 1; + } + }); + return Interval(minId, maxId); + }; + + resolveExplicitBufferLiveness(getValueLivenessRange); + resolveAliasBufferLiveness(getValueLivenessRange); + resolveScratchBufferLiveness(operationId); + } + + /// Computes the shared memory offsets for all related values. + /// Paper: Algorithms for Compile-Time Memory Optimization + /// (https://dl.acm.org/doi/pdf/10.5555/314500.315082) + void computeOffsets() { + SmallVector buffers; + for (auto bufferIter : bufferRange) { + buffers.emplace_back(bufferIter.first); + } + + calculateStarts(buffers); + + // NOTE: The original paper doesn't consider interference between + // the bumped ranges. Buffers that previously do not interfere with + // could interfere after offset bumping if their liveness ranges overlap. + // Therefore, we rerun the interference graph algorithm after bumping so + // that we regroup the buffers and color them again. Since we always + // increase the buffer offset and keep reducing conflicts, we will + // eventually reach a fixed point. + GraphT interference; + buildInterferenceGraph(buffers, interference); + do { + allocate(buffers, interference); + buildInterferenceGraph(buffers, interference); + } while (!interference.empty()); + } + + /// Computes the initial shared memory offsets. + void calculateStarts(const SmallVector &buffers) { + // v = values in shared memory + // t = triplet of (size, start, end) + // shared memory space + // - + // | *******t4 + // | /|\ v2 inserts t4, t5, and t6 + // | | + // | ******t5 ************t6 + // | ^^^^^v2^^^^^^ + // | | *********************t2 + // | \|/ v2 erases t1 + // | ******t1 ^^^^^^^^^v1^^^^^^^^^ ************t3 + // |---------------------------------------------| liveness range + // 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ... + // If the available triple's range is less than a given buffer range, + // we won't know if there has been an overlap without using graph coloring. + // Start -> Liveness Range + using TripleMapT = std::multimap>; + TripleMapT tripleMap; + tripleMap.insert(std::make_pair(0, Interval())); + SmallVector xBuffers = buffers; + while (!xBuffers.empty()) { + auto tripleIt = tripleMap.begin(); + auto offset = tripleIt->first; + auto range = tripleIt->second; + tripleMap.erase(tripleIt); + auto bufferIt = + std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) { + auto xRange = bufferRange[buffer]; + bool res = xRange.intersects(range); + for (auto val : tripleMap) + res = res && + !val.second.intersects(xRange); // only one buffer intersect + return res; + }); + if (bufferIt != xBuffers.end()) { + auto buffer = *bufferIt; + auto xSize = buffer->size; + auto xRange = bufferRange.lookup(buffer); + // TODO(Keren): A buffer's size shouldn't be determined here, have to + // clean it up + size_t alignOffset = buffer->setOffsetAligned(offset); + tripleMap.insert({alignOffset + xSize, + Interval{std::max(range.start(), xRange.start()), + std::min(range.end(), xRange.end())}}); + // We could either insert (range.start, xRange.start) or (range.start, + // xRange.end), both are correct and determine the potential buffer + // offset, and the graph coloring algorithm will solve the interference, + // if any + if (range.start() < xRange.start()) + tripleMap.insert({offset, Interval{range.start(), xRange.end()}}); + if (xRange.end() < range.end()) + tripleMap.insert({offset, Interval{xRange.start(), range.end()}}); + xBuffers.erase(bufferIt); + } + } + } + + /// Builds a graph of all shared memory values. Edges are created between + /// shared memory values that are overlapping. + void buildInterferenceGraph(const SmallVector &buffers, + GraphT &interference) { + // Reset interference graph + interference.clear(); + for (auto x : buffers) { + for (auto y : buffers) { + if (x == y) + continue; + auto xStart = x->offset; + auto yStart = y->offset; + auto xSize = x->size; + auto ySize = y->size; + Interval xSizeRange = {xStart, xStart + xSize}; + Interval ySizeRange = {yStart, yStart + ySize}; + auto xOpRange = bufferRange.lookup(x); + auto yOpRange = bufferRange.lookup(y); + if (xOpRange.intersects(yOpRange) && + xSizeRange.intersects(ySizeRange)) { + interference[x].insert(y); + } + } + } + } + + /// Finalizes shared memory offsets considering interference. + void allocate(const SmallVector &buffers, + const GraphT &interference) { + // Reset shared memory size + allocation->sharedMemorySize = 0; + // First-fit graph coloring + // Neighbors are nodes that interfere with each other. + // We color a node by finding the index of the first available + // non-neighboring node or the first neighboring node without any color. + // Nodes with the same color do not interfere with each other. + DenseMap colors; + for (auto value : buffers) { + colors[value] = (value == buffers[0]) ? 0 : -1; + } + SmallVector available(buffers.size()); + for (auto x : buffers) { + std::fill(available.begin(), available.end(), true); + for (auto y : interference.lookup(x)) { + int color = colors[y]; + if (color >= 0) { + available[color] = false; + } + } + auto it = std::find(available.begin(), available.end(), true); + colors[x] = std::distance(available.begin(), it); + } + // Finalize allocation + // color0: [0, 7), [0, 8), [0, 15) -> [0, 7), [0, 8), [0, 15) + // color1: [7, 9) -> [0 + 1 * 15, 9 + 1 * 15) -> [15, 24) + // color2: [8, 12) -> [8 + 2 * 15, 12 + 2 * 15) -> [38, 42) + // TODO(Keren): We are wasting memory here. + // Nodes with color2 can actually start with 24. + for (auto x : buffers) { + size_t newOffset = 0; + for (auto y : interference.lookup(x)) { + newOffset = std::max(newOffset, y->offset + y->size); + } + if (colors.lookup(x) != 0) + x->setOffsetAligned(newOffset); + allocation->sharedMemorySize = + std::max(allocation->sharedMemorySize, x->offset + x->size); + } + } + +private: + Operation *operation; + Allocation::FuncAllocMapT *funcAllocMap; + Allocation *allocation; + BufferRangeMapT bufferRange; +}; + +} // namespace triton + +void Allocation::run(FuncAllocMapT &funcAllocMap) { + triton::AllocationAnalysis(getOperation(), &funcAllocMap, this); +} + +} // namespace mlir diff --git a/third_party/iluvatar/lib/Analysis/AxisInfo.cpp b/third_party/iluvatar/lib/Analysis/AxisInfo.cpp new file mode 100644 index 000000000..7a4671c55 --- /dev/null +++ b/third_party/iluvatar/lib/Analysis/AxisInfo.cpp @@ -0,0 +1,1318 @@ +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#define DEBUG_TYPE "axis-info" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir::triton { +namespace { + +int64_t gcdImpl(int64_t a, int64_t b, int64_t *x, int64_t *y) { + // Base Case + if (a == 0) { + *x = 0; + *y = 1; + return b; + } + int64_t x1, y1; // To store results of recursive call + int64_t gcd = gcdImpl(b % a, a, &x1, &y1); + // Update x and y using results of + // recursive call + *x = y1 - (b / a) * x1; + *y = x1; + return gcd; +} + +int64_t gcd(int64_t a, int64_t b) { + if (a == 0) + return b; + if (b == 0) + return a; + int64_t x, y; + return gcdImpl(a, b, &x, &y); +} + +constexpr int log2Int(int64_t num) { + return (num > 1) ? 1 + log2Int(num / 2) : 0; +} + +// If lhs * rhs overflows, return max value possible value for the type +int64_t multiplyDivisor(int64_t lhs, int64_t rhs) { + int64_t maxDivisor = highestPowOf2Divisor(0); + if (lhs > maxDivisor / rhs) + return maxDivisor; + return lhs * rhs; +} + +class AxisInfoVisitor { +public: + AxisInfoVisitor() = default; + virtual ~AxisInfoVisitor() = default; + + static bool isContiguousDim(const AxisInfo &info, ArrayRef shape, + int dim) { + return info.getContiguity(dim) == shape[dim]; + } + + static bool isConstantDim(const AxisInfo &info, ArrayRef shape, + int dim) { + return info.getConstancy(dim) == shape[dim]; + } + + virtual AxisInfo + getAxisInfo(Operation *op, + ArrayRef *> operands) = 0; + + virtual bool match(Operation *op) = 0; +}; + +// Base class for all operations +template class AxisInfoVisitorImpl : public AxisInfoVisitor { +public: + using AxisInfoVisitor::AxisInfoVisitor; + + AxisInfo + getAxisInfo(Operation *op, + ArrayRef *> operands) final { + return getAxisInfo(cast(op), operands); + } + + bool match(Operation *op) final { return isa(op); } + + virtual AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) = 0; +}; + +// Binary operations +template +class BinaryOpVisitorImpl : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + assert(operands.size() == 2 && "Expected two operands"); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + auto constantValue = getConstantValue(op, lhsInfo, rhsInfo); + for (auto d = 0; d < rank; ++d) { + if (constantValue.has_value()) { + contiguity.push_back(1); + constancy.push_back( + std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + highestPowOf2Divisor(constantValue.value())); + } else { + contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d)); + constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d)); + divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d)); + } + } + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } + +protected: + virtual int64_t getContiguity(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual int64_t getDivisibility(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) { + return {}; + } +}; + +class AxisInfoVisitorList { +public: + template > + void append() { + (visitors.emplace_back(std::make_unique()), ...); + } + + AxisInfo apply(Operation *op, + ArrayRef *> operands) { + for (auto &visitor : visitors) + if (visitor->match(op)) + return visitor->getAxisInfo(op, operands); + return AxisInfo(); + } + +private: + std::vector> visitors; +}; + +class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +private: + AxisInfoVisitorList visitors; + + void setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged( + lattice, + lattice->join(AxisInfo::getPessimisticValueState(lattice->getPoint()))); + } + + void visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &successor, + ArrayRef *> argLattices, + unsigned firstIndex) override { + if (auto forOp = dyn_cast(op)) { + visitForOpInductionVar(forOp, argLattices); + } else { + setAllToEntryStates(argLattices.take_front(firstIndex)); + setAllToEntryStates(argLattices.drop_front( + firstIndex + successor.getSuccessorInputs().size())); + } + } + +public: + AxisInfoAnalysis(DataFlowSolver &solver); + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + using FuncAxisInfoMapT = DenseMap; + + void visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; + void + visitForOpInductionVar(scf::ForOp op, + ArrayRef *> argLattices); +}; + +template +class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + return operands[0]->getValue(); + } +}; + +class MakeRangeOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::MakeRangeOp op, + ArrayRef *> operands) override { + auto start = op.getStart(); + auto end = op.getEnd(); + return AxisInfo(/*contiguity=*/{end - start}, + /*divisibility=*/{highestPowOf2Divisor(start)}, + /*constancy=*/{1}); + } +}; + +template +class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto intAttr = dyn_cast(op.getValue()); + auto boolAttr = dyn_cast(op.getValue()); + if (intAttr || boolAttr) { + int64_t value{}; + if (intAttr) + value = intAttr.getValue().getZExtValue(); + else + value = boolAttr.getValue() ? 1 : 0; + return AxisInfo(/*contiguity=*/{1}, + /*divisibility=*/{highestPowOf2Divisor(value)}, + /*constancy=*/{1}, + /*knownConstantValue=*/{value}); + } + // TODO: generalize to dense attr + auto splatAttr = dyn_cast(op.getValue()); + if (splatAttr && splatAttr.getElementType().isIntOrIndex()) { + int64_t value = splatAttr.template getSplatValue().getZExtValue(); + TensorType ty = cast(splatAttr.getType()); + return AxisInfo( + /*contiguity=*/AxisInfo::DimVectorT(ty.getRank(), 1), + /*divisibility=*/ + AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)), + /*constancy=*/ + AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()), + /*knownConstantValue=*/{value}); + } + return AxisInfo(); + } +}; + +template +class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return std::max(gcd(lhs.getConstancy(dim), rhs.getContiguity(dim)), + gcd(lhs.getContiguity(dim), rhs.getConstancy(dim))); + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs = k * d_lhs = k * k' * gcd(d_lhs, d_rhs) + // rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs) + // lhs + rhs = k * d_lhs + p * d_rhs = (k * d_lhs + p * d_rhs) * + // gcd(d_lhs, d_rhs) + auto rhsDivisibility = rhs.getDivisibility(dim); + if constexpr (std::is_same_v) { + // %ptr = addptr %lhs, %rhs + // is equivalent to + // %0 = mul %rhs, %elemSize + // %ptr = add %lhs, %0 + // The result will still be contiguous in terms of elements but not bytes + // For example: + // addptr [16] : !ptr, [0, 1, 2, 3] : i32 -> !ptr + // returns: + // [16, 20, 24, 28] : !ptr + // with element locations: + // [4, 5, 6, 7] + // It is "strided contiguous" with a divisilibity of 16 bytes + auto rank = lhs.getRank(); + auto elemSize = std::max( + 1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8); + rhsDivisibility = multiplyDivisor(rhs.getDivisibility(dim), elemSize); + } + return gcd(lhs.getDivisibility(dim), rhsDivisibility); + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) { + if constexpr (std::is_same_v || + std::is_same_v) { + return {lhs.getConstantValue().value() + + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() - + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + auto rank = lhs.getRank(); + auto elemSize = std::max( + 1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8); + auto rhsValue = rhs.getConstantValue().value() * elemSize; + return {lhs.getConstantValue().value() + rhsValue}; + } + } + return {}; + } +}; + +class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + // lhs * 1 = lhs + auto lhsContiguity = + rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1 + ? lhs.getContiguity(dim) + : 1; + // 1 * rhs = rhs + auto rhsContiguity = + lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1 + ? rhs.getContiguity(dim) + : 1; + return std::max(lhsContiguity, rhsContiguity); + } + + int64_t getConstancy(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && + !(rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1)) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + auto rhsDivisibility = rhs.getDivisibility(dim); + if (rhs.getContiguity(dim) > 1 && + !(lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1)) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + rhsDivisibility = 1; + } + return multiplyDivisor(lhsDivisibility, rhsDivisibility); + } + + std::optional getConstantValue(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() * rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs / 1 = lhs + return rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1 + ? lhs.getContiguity(dim) + : 1; + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + // Case 1: both lhs and rhs are constants. + auto constancy = gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + // Case 2: lhs contiguous, rhs constant. + // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n + // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p + // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p), + // ..., (d_lhs * k + n) / (d_rhs * p) + // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0, + // the minimal constancy is gcd(d_lhs, d_rhs). + // Since gcd(d_lhs, d_rhs) maybe > len(lhs), + // we need to use another gcd to get the actual constancy. + if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) && + AxisInfoVisitor::isConstantDim(rhs, shape, dim)) { + constancy = std::max(constancy, gcd(lhs.getContiguity(dim), + gcd(lhs.getDivisibility(dim), + rhs.getDivisibility(dim)))); + } + return constancy; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // Case 1: lhs is 0 + if (lhs.getConstantValue().has_value() && + lhs.getConstantValue().value() == 0) + return lhs.getDivisibility(dim); + // Case 2: rhs is 1 + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return lhs.getDivisibility(dim); + // otherwise: return 1 + return 1; + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() / rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return BinaryOpVisitorImpl::getContiguity(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + int64_t contiguity = 1; + // lhs contiguous, rhs constant + // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n + // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p + // lhs % rhs = d_lhs * k % (d_rhs * p), (d_lhs * k + 1) % (d_rhs * p), + // ..., (d_lhs * k + n) % (d_rhs * p) + // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0, + // The minimal contiguity is gcd(d_lhs, d_rhs). + // Since gcd(d_lhs, d_rhs) maybe > len(lhs), + // we need to use another gcd to get the actual contiguity. + if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) && + AxisInfoVisitor::isConstantDim(rhs, shape, dim)) { + contiguity = std::max(contiguity, gcd(lhs.getContiguity(dim), + gcd(lhs.getDivisibility(dim), + rhs.getDivisibility(dim)))); + } + return contiguity; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs: d_lhs * k = gcd(d_lhs, d_rhs) * k' * k = gcd(d_lhs, d_rhs) * k'' + // rhs: d_rhs * p = gcd(d_lhs, d_rhs) * p' * p = gcd(d_lhs, d_rhs) * p'' + // lhs = gcd(d_lhs, d_rhs) * k'' = gcd(d_lhs, d_rhs) * d + r + // r must be divisible by gcd(d_lhs, d_rhs) + return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim)); + }; + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + // lhs % 1 = 0 + return rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1 + ? shape[dim] + : gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() % rhs.getConstantValue().value()}; + else if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return {0}; + return {}; + } +}; + +class SplatOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::SplatOp op, + ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + TensorType retTy = cast(_retTy); + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + for (int d = 0; d < retTy.getRank(); ++d) { + contiguity.push_back(1); + divisibility.push_back(opInfo.getDivisibility(0)); + constancy.push_back(retTy.getShape()[d]); + } + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +class LoadOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::LoadOp op, + ArrayRef *> operands) override { + // If pointers and mask both have constancy properties, those properties + // will also extend to output. + AxisInfo ptrInfo = operands[0]->getValue(); + std::optional maskInfo; + if (operands.size() > 1) { + maskInfo = operands[1]->getValue(); + } + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + + for (int d = 0; d < ptrInfo.getRank(); ++d) { + contiguity.push_back(1); + divisibility.push_back(1); + constancy.push_back( + gcd(ptrInfo.getConstancy(d), + (maskInfo.has_value() && (d < maskInfo->getRank())) + ? maskInfo->getConstancy(d) + : 0)); + } + + return AxisInfo(contiguity, divisibility, constancy); + } +}; + +class ExpandDimsOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::ExpandDimsOp op, + ArrayRef *> operands) override { + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity = opInfo.getContiguity(); + AxisInfo::DimVectorT divisibility = opInfo.getDivisibility(); + AxisInfo::DimVectorT constancy = opInfo.getConstancy(); + int64_t newDivisibility = 1; + if (opInfo.getConstantValue().has_value()) { + // The tensor is constant, same as ConstantOpAxisInfoVisitor + newDivisibility = highestPowOf2Divisor(opInfo.getConstantValue().value()); + } else if (opInfo.getRank()) { + // Otherwise, calculate the GCD as the new divisibility + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + newDivisibility = + opInfo.getContiguity(0) > 1 ? 1 : opInfo.getDivisibility(0); + for (int d = 1; d < opInfo.getRank(); ++d) { + newDivisibility = + gcd(newDivisibility, + opInfo.getContiguity(d) > 1 ? 1 : opInfo.getDivisibility(d)); + } + } + contiguity.insert(contiguity.begin() + op.getAxis(), 1); + divisibility.insert(divisibility.begin() + op.getAxis(), newDivisibility); + constancy.insert(constancy.begin() + op.getAxis(), 1); + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +class BroadcastOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::BroadcastOp op, + ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + Type _opTy = *op->operand_type_begin(); + TensorType retTy = cast(_retTy); + TensorType opTy = cast(_opTy); + ArrayRef retShape = retTy.getShape(); + ArrayRef opShape = opTy.getShape(); + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + for (int d = 0; d < retTy.getRank(); ++d) { + contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d)); + divisibility.push_back(opInfo.getDivisibility(d)); + constancy.push_back(opShape[d] == 1 ? retShape[d] + : opInfo.getConstancy(d)); + } + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +template +class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return AxisInfo(); + auto shape = resTy.getShape(); + short rank = resTy.getRank(); + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + + AxisInfo::DimVectorT contiguity, divisibility, constancy; + std::optional constantValue; + for (short d = 0; d < rank; ++d) { + int64_t constHint = 1; + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + constHint = lhsInfo.getConstancy(d); + constantValue = + compare(getPredicate(op), lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value()) + ? 1 + : 0; + } else { + // Case 1: lhs and rhs are both partial constants + constHint = gcd(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)); + if ((gtPredicate(getPredicate(op)) || lePredicate(getPredicate(op))) && + AxisInfoVisitor::isConstantDim(lhsInfo, shape, d)) { + // Case 2: lhs all constant, rhs all contiguous + // NOTE: + // lhs: 4 4 4 4 + // rhs: 4 5 6 7 + // lhs eq rhs: 1, 0, 0, 0 + // lhs ne rhs: 0, 1, 1, 1 + // lhs lt rhs: 0, 1, 1, 1 + // lhs le rhs: 1, 1, 1, 1 + // lhs ge rhs: 1, 0, 0, 0 + // lhs gt rhs: 0, 0, 0, 0 + constHint = std::max(constHint, gcd(rhsInfo.getContiguity(d), + gcd(lhsInfo.getDivisibility(d), + rhsInfo.getDivisibility(d)))); + } else if ((ltPredicate(getPredicate(op)) || + gePredicate(getPredicate(op))) && + AxisInfoVisitor::isConstantDim(rhsInfo, shape, d)) { + // Case 3: lhs all contiguous, rhs all constant + // NOTE + // lhs: 4 5 6 7 + // rhs: 4 4 4 4 + // lhs eq rhs: 1, 0, 0, 0 + // lhs ne rhs: 0, 1, 1, 1 + // lhs le rhs: 1, 0, 0, 0 + // lhs lt rhs: 0, 0, 0, 0 + // lhs gt rhs: 0, 1, 1, 1 + // lhs ge rhs: 1, 1, 1, 1 + constHint = std::max(constHint, gcd(lhsInfo.getContiguity(d), + gcd(lhsInfo.getDivisibility(d), + rhsInfo.getDivisibility(d)))); + } + } + + constancy.push_back(constHint); + divisibility.push_back(1); + contiguity.push_back(1); + } + + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } + +private: + static arith::CmpIPredicate getPredicate(arith::CmpIOp op) { + return op.getPredicate(); + } + + static bool gtPredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sgt || + predicate == arith::CmpIPredicate::ugt; + } + + static bool gePredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sge || + predicate == arith::CmpIPredicate::uge; + } + + static bool ltPredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::slt || + predicate == arith::CmpIPredicate::ult; + } + + static bool lePredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sle || + predicate == arith::CmpIPredicate::ule; + } + + static bool compare(arith::CmpIPredicate predicate, int64_t lhs, + int64_t rhs) { + switch (predicate) { + case arith::CmpIPredicate::eq: + return lhs == rhs; + case arith::CmpIPredicate::ne: + return lhs != rhs; + case arith::CmpIPredicate::slt: + return lhs < rhs; + case arith::CmpIPredicate::sle: + return lhs <= rhs; + case arith::CmpIPredicate::sgt: + return lhs > rhs; + case arith::CmpIPredicate::sge: + return lhs >= rhs; + case arith::CmpIPredicate::ult: + return (uint64_t)lhs < (uint64_t)rhs; + case arith::CmpIPredicate::ule: + return (uint64_t)lhs <= (uint64_t)rhs; + case arith::CmpIPredicate::ugt: + return (uint64_t)lhs > (uint64_t)rhs; + case arith::CmpIPredicate::uge: + return (uint64_t)lhs >= (uint64_t)rhs; + default: + break; + } + llvm_unreachable("unknown comparison predicate"); + } +}; + +template +class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto condConstancy = operands[0]->getValue().getConstancy(); + auto lhsInfo = operands[1]->getValue(); + auto rhsInfo = operands[2]->getValue(); + auto rank = lhsInfo.getRank(); + + AxisInfo::DimVectorT contiguity, divisibility, constancy; + std::optional constantValue; + if (operands[0]->getValue().getConstantValue().has_value()) { + if (operands[0]->getValue().getConstantValue() == 0) { + contiguity = rhsInfo.getContiguity(); + divisibility = rhsInfo.getDivisibility(); + constancy = rhsInfo.getConstancy(); + constantValue = rhsInfo.getConstantValue(); + } else { + contiguity = lhsInfo.getContiguity(); + divisibility = lhsInfo.getDivisibility(); + constancy = lhsInfo.getConstancy(); + constantValue = lhsInfo.getConstantValue(); + } + } else { + // The condition can be either a tensor or i1. + // If i1 is used as the condition, the entire tensor of either + // lhs or rhs is selected. + bool i1Cond = isa(op.getOperand(0).getType()); + for (auto d = 0; d < rank; ++d) { + if (i1Cond) { + constancy.push_back( + std::min(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + contiguity.push_back( + std::min(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d))); + } else { + constancy.push_back( + std::min(gcd(lhsInfo.getConstancy(d), condConstancy[d]), + gcd(rhsInfo.getConstancy(d), condConstancy[d]))); + contiguity.push_back( + std::min(gcd(lhsInfo.getContiguity(d), condConstancy[d]), + gcd(rhsInfo.getContiguity(d), condConstancy[d]))); + if (contiguity.back() == lhsInfo.getContiguity(d) && + contiguity.back() == rhsInfo.getContiguity(d)) { + // Contiguity not changed + divisibility.push_back( + gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + } else { + // Contiguity changed, we cannot use only divisibility. + // For example, the following example should have contiguity 2 and + // divisibility 2 + // [[0, 1], [4, 5]] + // [[16, 17, 18, 19]] + divisibility.push_back( + std::min(gcd(lhsInfo.getDivisibility(d), contiguity.back()), + gcd(rhsInfo.getDivisibility(d), contiguity.back()))); + } + } + } + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value() && + lhsInfo.getConstantValue() == rhsInfo.getConstantValue()) + constantValue = lhsInfo.getConstantValue(); + } + + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } +}; + +template +class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) { + if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() & + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() | + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() ^ + rhs.getConstantValue().value()}; + } + } + return {}; + } +}; + +class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 0) + return lhs.getContiguity(dim); + else + return 1; + } + + int64_t getDivisibility(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + auto shift = rhs.getConstantValue().has_value() + ? rhs.getConstantValue().value() + : rhs.getDivisibility(dim); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + auto numBits = log2Int(lhsDivisibility); + return multiplyDivisor(lhsDivisibility, 1 << shift); + } + + int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() << rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 0) + return lhs.getContiguity(dim); + else + return 1; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto shift = rhs.getConstantValue().has_value() + ? rhs.getConstantValue().value() + : rhs.getDivisibility(dim); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + return std::max(1, lhsDivisibility / (1 << shift)); + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() >> rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + std::optional constantValue; + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + if constexpr (std::is_same_v || + std::is_same_v) { + constantValue = {std::max(lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value())}; + } else if constexpr (std::is_same_v || + std::is_same_v) { + constantValue = {std::min(lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value())}; + } + return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1), + /*knownDivisibility=*/AxisInfo::DimVectorT(rank, 1), + /*knownConstancy=*/AxisInfo::DimVectorT(rank, 1), + /*constantValue=*/constantValue); + } else { + AxisInfo::DimVectorT contiguity, divisibility, constancy; + for (auto d = 0; d < rank; ++d) { + constancy.push_back( + std::min(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + contiguity.push_back( + std::min(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d))); + } + return AxisInfo(contiguity, divisibility, constancy, std::nullopt); + } + } +}; + +//===----------------------------------------------------------------------===// +// AxisInfoAnalysis +//===----------------------------------------------------------------------===// + +AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) + : dataflow::SparseForwardDataFlowAnalysis>( + solver) { + // UnrealizedConversionCast: + // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is + // in the process of a PartialConversion, where UnrealizedConversionCast + // may exist + visitors.append, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor>(); + // TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp + // when scf.for supports integer induction variables + visitors.append(); + visitors.append, + ConstantOpAxisInfoVisitor>(); + visitors.append, + AddSubOpAxisInfoVisitor, + AddSubOpAxisInfoVisitor, + AddSubOpAxisInfoVisitor>(); + visitors.append(); + visitors.append, + DivOpAxisInfoVisitor>(); + visitors.append, + RemOpAxisInfoVisitor>(); + visitors.append(); + visitors.append(); + visitors.append(); + visitors.append>(); + visitors.append, + LogicalOpAxisInfoVisitor, + LogicalOpAxisInfoVisitor>(); + visitors.append>(); + visitors.append, + ShROpAxisInfoVisitor>(); + visitors.append, + MaxMinOpAxisInfoVisitor, + MaxMinOpAxisInfoVisitor, + MaxMinOpAxisInfoVisitor>(); + visitors.append(); +} + +void AxisInfoAnalysis::visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { + // TODO: For sure not the right way to do this + // but why is scf.if not initialized otherwise? + for (auto op : operands) + if (op->getValue().getRank() == 0) + setToEntryState((dataflow::Lattice *)op); + AxisInfo curr = visitors.apply(op, operands); + if (curr.getRank() == 0) + return setAllToEntryStates(results); + // override with hint + auto newContiguity = curr.getContiguity(); + auto newDivisibility = curr.getDivisibility(); + auto newConstancy = curr.getConstancy(); + if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) { + auto vals = cast(attr).getValues(); + newContiguity = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) { + auto vals = cast(attr).getValues(); + newDivisibility = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.constancy")) { + auto vals = cast(attr).getValues(); + newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + curr = AxisInfo(newContiguity, newDivisibility, newConstancy, + curr.getConstantValue()); + // join all lattice elements + for (auto *result : results) + propagateIfChanged(result, result->join(curr)); +} + +void AxisInfoAnalysis::visitForOpInductionVar( + scf::ForOp op, ArrayRef *> argLattices) { + auto lb = getLatticeElementFor(op, op.getLowerBound())->getValue(); + auto step = getLatticeElementFor(op, op.getStep())->getValue(); + + AxisInfo::DimVectorT knownContiguity(1, 1); + AxisInfo::DimVectorT knownDivisibility(1, 1); + AxisInfo::DimVectorT knownConstancy(1, 1); + knownDivisibility[0] = gcd(lb.getDivisibility(0), step.getDivisibility(0)); + auto inductionVar = + AxisInfo(knownContiguity, knownDivisibility, knownConstancy); + (void)argLattices[0]->join(inductionVar); +} + +} // anonymous namespace + +template +void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp, + DimVectorT *contiguity, + DimVectorT *divisibility, + DimVectorT *constancy) { + // liast of attributes that we care about + SmallVector> retVecs; + retVecs.push_back({contiguity, "tt.contiguity"}); + retVecs.push_back({divisibility, "tt.divisibility"}); + retVecs.push_back({constancy, "tt.constancy"}); + // initialize attributes one by one + for (auto [vec, attrName] : retVecs) { + Attribute attr = funcOp.getArgAttr(argNumber, attrName); + if (auto int_attr = dyn_cast_or_null(attr)) + *vec = DimVectorT(contiguity->size(), int_attr.getValue().getZExtValue()); + if (auto dense_attr = dyn_cast_or_null(attr)) { + auto vals = dense_attr.getValues(); + *vec = DimVectorT(vals.begin(), vals.end()); + } + } +} + +/*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) { + auto rank = 1; + if (TensorType ty = dyn_cast(value.getType())) + rank = ty.getRank(); + if (triton::PointerType ty = dyn_cast(value.getType())) + if (TensorType elemTy = dyn_cast(ty.getPointeeType())) + rank = elemTy.getRank(); + + DimVectorT knownContiguity(rank, 1); + DimVectorT knownDivisibility(rank, 1); + DimVectorT knownConstancy(rank, 1); + + BlockArgument blockArg = dyn_cast(value); + + if (blockArg && blockArg.getOwner()->isEntryBlock()) { + Operation *op = blockArg.getOwner()->getParentOp(); + if (auto fun = dyn_cast(op)) + initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, + &knownContiguity, &knownDivisibility, + &knownConstancy); + // llvm codegen check alignment to generate vector load/store + // would be nice if this wasn't the case + else if (auto fun = dyn_cast(op)) + initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, + &knownContiguity, &knownDivisibility, + &knownConstancy); + } else if (Operation *op = value.getDefiningOp()) { + if (isa(op)) { + // scf::ForOp, scf::IfOp, scf::WhileOp + // Control flow operations are initialized with "unknown" state: + // the maximum possible divisibility, contiguity, and constancy. + knownDivisibility = DimVectorT(rank, highestPowOf2Divisor(0)); + knownConstancy = DimVectorT(rank, highestPowOf2Divisor(0)); + knownContiguity = DimVectorT(rank, highestPowOf2Divisor(0)); + } + // Other operations are conservatively initialized with the lowest possible + // divisibility, contiguity, and constancy unless they have specified. + if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) { + auto vals = cast(attr).getValues(); + knownDivisibility = DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) { + auto vals = cast(attr).getValues(); + knownContiguity = DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.constancy")) { + auto vals = cast(attr).getValues(); + knownConstancy = DimVectorT(vals.begin(), vals.end()); + } + } + + return AxisInfo(knownContiguity, knownDivisibility, knownConstancy); +} + +/*static*/ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) { + // If one argument is not initialized, return the other. + if (lhs.getRank() == 0) + return rhs; + if (rhs.getRank() == 0) + return lhs; + DimVectorT contiguity; + DimVectorT divisibility; + DimVectorT constancy; + for (auto d = 0; d < lhs.getRank(); ++d) { + contiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d))); + divisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); + constancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d))); + } + std::optional constantValue; + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value() && + lhs.getConstantValue() == rhs.getConstantValue()) + constantValue = lhs.getConstantValue(); + return AxisInfo(contiguity, divisibility, constancy, constantValue); +} + +unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + auto layout = tensorTy.getEncoding(); + + // Here order should be ordered by contiguous first, so the first element + // should have the largest contiguous. + auto order = triton::gpu::getOrder(layout); + unsigned align = getPtrAlignment(ptr); + + auto uniqueContigPerThread = + triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape()); + assert(order[0] < uniqueContigPerThread.size() && + "Unexpected uniqueContigPerThread size"); + unsigned contiguity = uniqueContigPerThread[order[0]]; + LDBG("getPtrContiguity uniqueContigPerThread = " << contiguity); + contiguity = std::min(align, contiguity); + + return contiguity; +} + +unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + auto *axisInfo = getAxisInfo(ptr); + if (!axisInfo) + return 1; + auto layout = tensorTy.getEncoding(); + auto order = triton::gpu::getOrder(layout); + auto maxMultipleBytes = axisInfo->getDivisibility(order[0]); + auto maxContig = axisInfo->getContiguity(order[0]); + auto elemNumBits = triton::getPointeeBitWidth(tensorTy); + auto elemNumBytes = std::max(elemNumBits / 8, 1); + auto maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1); + unsigned alignment = std::min(maxMultiple, maxContig); + LDBG("getPtrAlignment order[0] " + << order[0] << " maxMultipleBytes = " << maxMultipleBytes + << " maxContig = " << maxContig << " elemNumBits = " << elemNumBits + << " maxMultiple = " << maxMultiple << " alignment " << alignment); + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo->print(os); + LDBG("-- " << axisStr); + }); + return alignment; +} + +unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { + auto tensorTy = dyn_cast(mask.getType()); + if (!tensorTy) + return 1; + auto *axisInfo = getAxisInfo(mask); + if (!axisInfo) + return 1; + auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding()); + auto alignment = std::max(axisInfo->getConstancy(maskOrder[0]), 1); + LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment " + << alignment); + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo->print(os); + LDBG("-- " << axisStr); + }); + return alignment; +} + +void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp) { + std::unique_ptr solver = createDataFlowSolver(); + AxisInfoAnalysis *analysis = solver->load(); + if (failed(solver->initializeAndRun(funcOp))) + return; + auto *axisInfoMap = getFuncData(funcOp); + auto updateAxisInfoMap = [&](Value value) { + auto axisInfo = analysis->getLatticeElement(value)->getValue(); + AxisInfo curAxisInfo; + if (axisInfoMap->count(value)) { + curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value)); + } else { + curAxisInfo = axisInfo; + } + (*axisInfoMap)[value] = curAxisInfo; + }; + funcOp.walk([&](Operation *op) { + for (auto value : op->getResults()) { + updateAxisInfoMap(value); + } + }); + funcOp.walk([&](Block *block) { + for (auto value : block->getArguments()) { + updateAxisInfoMap(value); + } + }); +} + +void ModuleAxisInfoAnalysis::update(CallOpInterface callOp, + FunctionOpInterface callee) { + auto caller = callOp->getParentOfType(); + auto *axisInfoMap = getFuncData(caller); + for (auto entry : llvm::enumerate(callOp->getOperands())) { + auto index = entry.index(); + auto value = entry.value(); + auto setAttrFn = [&](StringRef attrName, int64_t prevValue) { + auto curValue = highestPowOf2Divisor(0); + if (callee.getArgAttrOfType(index, attrName)) { + curValue = + callee.getArgAttrOfType(index, attrName).getInt(); + } + auto attr = IntegerAttr::get(IntegerType::get(callee.getContext(), 64), + gcd(prevValue, curValue)); + callee.setArgAttr(index, attrName, attr); + }; + auto axisInfo = axisInfoMap->lookup(value); + assert(axisInfo.getRank() == 1 && "only scalar arguments are supported"); + setAttrFn("tt.contiguity", axisInfo.getContiguity(0)); + setAttrFn("tt.divisibility", axisInfo.getDivisibility(0)); + setAttrFn("tt.constancy", axisInfo.getConstancy(0)); + } +} + +} // namespace mlir::triton diff --git a/third_party/iluvatar/lib/Analysis/CMakeLists.txt b/third_party/iluvatar/lib/Analysis/CMakeLists.txt new file mode 100644 index 000000000..12deb6143 --- /dev/null +++ b/third_party/iluvatar/lib/Analysis/CMakeLists.txt @@ -0,0 +1,18 @@ +add_triton_library(TritonAnalysis + AxisInfo.cpp + Allocation.cpp + Membar.cpp + Alias.cpp + Utility.cpp + + DEPENDS + TritonTableGen + TritonGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRAnalysis + MLIRLLVMDialect + TritonIR + TritonGPUIR + #TritonNvidiaGPUIR +) diff --git a/third_party/iluvatar/lib/Analysis/Membar.cpp b/third_party/iluvatar/lib/Analysis/Membar.cpp new file mode 100644 index 000000000..407a5ae15 --- /dev/null +++ b/third_party/iluvatar/lib/Analysis/Membar.cpp @@ -0,0 +1,178 @@ +#include "triton/Analysis/Membar.h" +#include "triton/Analysis/Alias.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include + +namespace mlir { + +void MembarAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) { + FunctionOpInterface funcOp = + dyn_cast(allocation->getOperation()); + OpBuilder builder(funcOp.getContext()); + resolve(funcOp, &funcBlockInfoMap, &builder); +} + +void MembarAnalysis::resolve(FunctionOpInterface funcOp, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) { + // Initialize the blockList + DenseMap inputBlockInfoMap; + DenseMap outputBlockInfoMap; + std::deque blockList; + funcOp.walk([&](Block *block) { + for (auto &op : block->getOperations()) { + // Check if the operation belongs to scf dialect, if so, we need to + // throw an error + if (op.getDialect()->getNamespace() == "scf") { + llvm::report_fatal_error( + "scf dialect is not supported in membar. Please lower it " + "to cf dialect first."); + return; + } + } + if (block->isEntryBlock()) + blockList.emplace_back(block); + }); + + // A fixed point algorithm + while (!blockList.empty()) { + auto *block = blockList.front(); + blockList.pop_front(); + // Make a copy of the inputblockInfo but not update + auto inputBlockInfo = inputBlockInfoMap[block]; + SmallVector successors; + for (auto &op : block->getOperations()) { + if (op.hasTrait()) { + visitTerminator(&op, successors); + } else { + update(&op, &inputBlockInfo, funcBlockInfoMap, builder); + } + } + // Get the reference because we want to update if it changed + if (outputBlockInfoMap.count(block) && + inputBlockInfo == outputBlockInfoMap[block]) { + // If we have seen the block before and the inputBlockInfo is the same as + // the outputBlockInfo, we skip the successors + continue; + } + // Update the current block + outputBlockInfoMap[block].join(inputBlockInfo); + // Update the successors + for (auto *successor : successors) { + inputBlockInfoMap[successor].join(outputBlockInfoMap[block]); + blockList.emplace_back(successor); + } + } + + // Update the final dangling buffers that haven't been synced + auto &funcBlockInfo = (*funcBlockInfoMap)[funcOp]; + funcOp.walk([&](Block *block) { + block->walk([&](triton::ReturnOp returnOp) { + funcBlockInfo.join(outputBlockInfoMap[block]); + }); + }); +} + +void MembarAnalysis::visitTerminator(Operation *op, + SmallVector &successors) { + if (auto branchInterface = dyn_cast(op)) { + Block *parentBlock = branchInterface->getBlock(); + successors.append(std::begin(parentBlock->getSuccessors()), + std::end(parentBlock->getSuccessors())); + return; + } + // Otherwise, it could be a return op + if (op->hasTrait()) + return; + llvm_unreachable("Unknown terminator encountered in membar analysis"); +} + +void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) { + OpBuilder::InsertionGuard g(*builder); + auto barrierOp = builder->create(op->getLoc()); +} + +void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) { + if (isa(op)) { + // If the current op is a barrier, we sync previous reads and writes + blockInfo->sync(); + return; + } + + if (isa(op) && + !isa(op->getNextNode())) { + // If the current op is an async wait and the next op is not a barrier we + // insert a barrier op and sync + builder->setInsertionPointAfter(op); + insertBarrier(op, builder); + blockInfo->sync(); + return; + } + + BlockInfo curBlockInfo; + if (isa(op)) { + // Inter-function dependencies + auto callOpInterface = dyn_cast(op); + if (auto callee = + dyn_cast(callOpInterface.resolveCallable())) + curBlockInfo = funcBlockInfoMap->lookup(callee); + } else { + // Intra-function dependencies + if (auto memoryEffectOpInterface = dyn_cast(op)) { + // Explicit buffer + SmallVector> + effectInstances; + memoryEffectOpInterface.getEffects(effectInstances); + for (auto effectInstance : effectInstances) { + if (auto value = effectInstance.getValue()) { + for (auto bufferId : allocation->getBufferIds(value)) { + if (bufferId != Allocation::InvalidBufferId) { + if (isa(effectInstance.getEffect())) + curBlockInfo.syncWriteIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + else if (isa(effectInstance.getEffect())) + curBlockInfo.syncReadIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + } + } + } + } + } + // XXX(Keren): This is a hack as we cannot set side effects for dot ops, but + // on hopper they do have side effects. Need to clean it up + if (auto dotOp = dyn_cast(op)) { + for (auto value : dotOp.getOperands()) { + for (auto bufferId : allocation->getBufferIds(value)) { + if (bufferId != Allocation::InvalidBufferId) + curBlockInfo.syncReadIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + } + } + } + // Scratch buffer is considered as both shared memory write & read + auto bufferId = allocation->getBufferId(op); + if (bufferId != Allocation::InvalidBufferId) { + curBlockInfo.syncWriteIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + curBlockInfo.syncReadIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + } + } + + if (blockInfo->isIntersected(curBlockInfo)) { + builder->setInsertionPoint(op); + insertBarrier(op, builder); + blockInfo->sync(); + } + // Update the region info, even if barrier is inserted, we have to maintain + // the current op's read/write buffers. + blockInfo->join(curBlockInfo); +} +} // namespace mlir diff --git a/third_party/iluvatar/lib/Analysis/Utility.cpp b/third_party/iluvatar/lib/Analysis/Utility.cpp new file mode 100644 index 000000000..c5c36505e --- /dev/null +++ b/third_party/iluvatar/lib/Analysis/Utility.cpp @@ -0,0 +1,1049 @@ +#include "triton/Analysis/Utility.h" + +#include + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace mlir { +namespace { + +using namespace triton; +using namespace triton::gpu; + +int getParentAxis(Attribute layout, int axis) { + if (auto sliceEncoding = dyn_cast(layout)) { + axis = axis < sliceEncoding.getDim() ? axis : axis + 1; + return getParentAxis(sliceEncoding.getParent(), axis); + } + return axis; +} + +SmallVector getParentOrder(Attribute layout) { + if (auto sliceEncoding = mlir::dyn_cast(layout)) { + return getParentOrder(sliceEncoding.getParent()); + } + return getOrder(layout); +} + +} // namespace + +// TODO(jlebar): Move this class into namespace triton. +bool ReduceOpHelper::isReductionOnLayoutFastAxis() { + return getParentAxis(getSrcLayout(), axis) == + getParentOrder(getSrcLayout())[0]; +} + +SmallVector ReduceOpHelper::getOrderWithAxisAtBeginning() { + auto srcLayout = getSrcLayout(); + auto order = getOrder(srcLayout); + auto it = std::find(order.begin(), order.end(), axis); + // delete the axis from order + order.erase(it); + // insert axis at the beginning of order + order.insert(order.begin(), axis); + return order; +} + +// Thread offset is the thread index offset of two adjacent threads on the +// reduction axis within the warp. +unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { + auto srcLayout = getSrcLayout(); + + // If the reduction axis is the fast axis of the parent layout + if (isReductionOnLayoutFastAxis()) { + return 1; + } + + unsigned threadOffset = 1; + if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + auto parentLayout = sliceLayout.getParent(); + auto threadsPerWarp = getThreadsPerWarp(parentLayout); + threadOffset = threadsPerWarp[sliceLayout.getDim()]; + } else { + auto threadsPerWarp = getThreadsPerWarp(srcLayout); + auto order = getOrder(srcLayout); + for (unsigned i = 0; i < order.size(); i++) { + if (order[i] == axis) + break; + threadOffset *= threadsPerWarp[order[i]]; + } + } + return threadOffset; +} + +// Cases where distributed shared memory is not required in ConvertLayout: +// (1) numCTAs == 1 +// (2) numCTAs > 1 but srcCTALayout == dstCTALayout +// TODO: Case with SliceLayout as srcLayout and numCTAs > 1 is to be implemented +// in the future +bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) { + unsigned numCTAs = getNumCTAs(srcLayout); + assert(numCTAs == getNumCTAs(dstLayout) && + "Invalid layout conversion: the numbers of CTAs of src and dst " + "layouts are different"); + + // Case (1): Never use dsmem when numCTAs == 1 + if (numCTAs == 1) + return false; + + // Case where CTAsPerCGA of srcLayout in the sliced dim is not 1 is not + // implemented yet + if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] != 1) + llvm::report_fatal_error("Layout conversion to be implemented"); + } + + // Case where CTAsPerCGA of dstLayout in the sliced dim is not 1 is supported + if (auto sliceLayout = mlir::dyn_cast(dstLayout)) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] != 1) + return true; + } + + // The above two branches make sure that it is legal to call getCTALayout of + // srcLayout and dstLayout + + // Case (2): Do not use dsmem when srcCTALayout == dstCTALayout + auto srcCTALayout = getCTALayout(srcLayout); + auto dstCTALayout = getCTALayout(dstLayout); + if (srcCTALayout == dstCTALayout) + return false; + + // Dsmem access is required when srcCTALayout != dstCTALayout + return true; +} + +unsigned ReduceOpHelper::getInterWarpSize() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + unsigned sizeIntraWarps = getIntraWarpSize(); + return std::min(srcReduceDimSize / sizeIntraWarps, + getWarpsPerCTA(getSrcLayout())[axis]); +} + +unsigned ReduceOpHelper::getIntraWarpSize() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + return std::min(srcReduceDimSize, getThreadsPerWarp(getSrcLayout())[axis]); +} + +unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + unsigned sizeIntraWarps = getIntraWarpSizeWithUniqueData(); + return std::min( + srcReduceDimSize / sizeIntraWarps, + getWarpsPerCTAWithUniqueData(getSrcLayout(), getSrcShape())[axis]); +} + +unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + unsigned elementPerThreads = + getUniqueContigPerThread(getSrcLayout(), getSrcShape())[axis]; + return std::min( + srcReduceDimSize / elementPerThreads, + getThreadsPerWarpWithUniqueData(getSrcLayout(), getSrcShape())[axis]); +} + +unsigned ReduceOpHelper::getThreadsReductionAxis() { + auto srcLayout = getSrcLayout(); + auto srcShape = getSrcShape(); + return getThreadsPerWarpWithUniqueData(srcLayout, srcShape)[axis] * + getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis]; +} + +bool ReduceOpHelper::isWarpSynchronous() { + auto srcLayout = getSrcLayout(); + auto srcShape = getSrcShape(); + return getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] == 1; +} + +SmallVector ReduceOpHelper::getScratchConfig() { + SmallVector smemShape; + // that case doesn't need inter-warp communication + if (isWarpSynchronous()) + return {0, 0}; + + smemShape = convertType(getSrcShape()); + smemShape[axis] = getInterWarpSizeWithUniqueData(); + + return smemShape; +} + +unsigned ReduceOpHelper::getScratchSizeInBytes() { + auto smemShape = getScratchConfig(); + auto elems = product(smemShape); + + unsigned bytesPerElem = 0; + for (const auto &ty : srcElementTypes) { + bytesPerElem += ceil(ty.getIntOrFloatBitWidth(), 8); + } + return bytesPerElem * elems; +} + +bool ReduceOpHelper::isReduceWithinCTA() { + auto axis = getAxis(); + auto srcLayout = getSrcLayout(); + auto CTASplitNum = getCTASplitNum(srcLayout); + assert(axis < CTASplitNum.size()); + return CTASplitNum[axis] == 1; +} + +bool ReduceOpHelper::isSupportedLayout() { + // Layout optimization passes such as PlanCTAPass and + // RemoveLayoutConversionPass should avoid cross-CTA reduction + if (!isReduceWithinCTA()) { + return false; + } + + auto srcLayout = getSrcLayout(); + if (isa(srcLayout)) { + return true; + } + if (auto mmaLayout = dyn_cast(srcLayout)) { + return mmaLayout.supportReduction(); + } + if (auto sliceLayout = dyn_cast(srcLayout)) { + return true; + } + return false; +} + +unsigned ScanLoweringHelper::getAxisNumElementsPerThread() { + return getEncoding().getSizePerThread()[getAxis()]; +} + +unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() { + SmallVector sizePerThreads = getContigPerThread(getEncoding()); + sizePerThreads[getAxis()] = 1; + return product(sizePerThreads); +} + +Region &ScanLoweringHelper::getCombineOp() { return scanOp.getCombineOp(); } + +unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() { + return getThreadsPerWarp(getEncoding())[getAxis()]; +} + +unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() { + return getThreadsPerWarpWithUniqueData(getEncoding(), getShape())[getAxis()]; +} + +unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() { + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + threadsPerWarp[getAxis()] = 1; + return product(threadsPerWarp); +} + +// Return the flat numbers of threads computing independent scan results. +unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() { + unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp(); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + warpsPerCTA[getAxis()] = 1; + unsigned numParallelWarpsPerCTA = product(warpsPerCTA); + return numParallelThreadsPerWarp * numParallelWarpsPerCTA; +} + +unsigned ScanLoweringHelper::getAxisNumWarps() { + return getWarpsPerCTA(getEncoding())[getAxis()]; +} + +unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() { + return getWarpsPerCTAWithUniqueData(getEncoding(), getShape())[getAxis()]; +} + +unsigned ScanLoweringHelper::getAxisNumBlocks() { + auto sizePerThreads = getSizePerThread(getEncoding()); + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + unsigned axis = getAxis(); + return ceil( + getShape()[axis], + (sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis])); +} + +unsigned ScanLoweringHelper::getNonAxisNumBlocks() { + auto sizePerThreads = getSizePerThread(getEncoding()); + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + unsigned axis = getAxis(); + unsigned numBlocks = 1; + for (unsigned i = 0; i < sizePerThreads.size(); i++) { + if (i == axis) + continue; + numBlocks *= + ceil(getShape()[i], (sizePerThreads[i] * threadsPerWarp[i] * + warpsPerCTA[i])); + } + return numBlocks; +} + +bool ScanLoweringHelper::isSupported() { + // TODO: Support the following cases: + // 1. Scan on non-blocking encodings + if (!isa(getEncoding())) + return false; + return true; +} + +unsigned ScanLoweringHelper::getScratchSizeInElems() { + auto mod = scanOp->getParentOfType(); + unsigned numWarps = TritonGPUDialect::getNumWarps(mod); + unsigned numNonAxisElementsPerWarp = + getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread(); + unsigned numElements = numWarps * numNonAxisElementsPerWarp * + getAxisNumBlocks() * getNonAxisNumBlocks(); + return numElements; +} + +unsigned ScanLoweringHelper::getScratchSizeInBytes() { + unsigned axisNumWarps = getAxisNumWarpsWithUniqueData(); + if (axisNumWarps == 1) + return 0; + unsigned elementSizeInBytes = 0; + for (const auto &ty : srcElementTypes) { + elementSizeInBytes += ceil(ty.getIntOrFloatBitWidth(), 8); + } + return elementSizeInBytes * getScratchSizeInElems(); +} + +SmallVector, SmallVector>> +getReshapeDecomposition(ArrayRef srcShape, + ArrayRef dstShape) { + SmallVector, SmallVector>> ret; + + if (srcShape.empty()) { + assert(dstShape.empty()); + return ret; + } + ret.push_back({}); + + int srcIdx = 0; + int dstIdx = 0; + int srcNElems = 1; + int dstNElems = 1; + while (srcIdx < srcShape.size() || dstIdx < dstShape.size()) { + if (srcNElems < dstNElems || // + (srcIdx < srcShape.size() && srcNElems == 1) || + (srcIdx < srcShape.size() && srcShape[srcIdx] == 1)) { + assert(srcIdx < srcShape.size()); + srcNElems *= srcShape[srcIdx]; + ret.back().first.push_back(srcIdx); + srcIdx++; + } else if (dstNElems < srcNElems || + (dstIdx < dstShape.size() && dstShape[dstIdx] == 1)) { + assert(dstIdx < dstShape.size()); + dstNElems *= dstShape[dstIdx]; + ret.back().second.push_back(dstIdx); + dstIdx++; + } else { + ret.push_back({}); + srcNElems = 1; + dstNElems = 1; + } + } + return ret; +} + +BlockedEncodingAttr ScanLoweringHelper::getEncoding() { + return cast(srcEncoding); +} + +unsigned ScanLoweringHelper::getAxisElementStride() { + auto order = getOrder(getEncoding()); + unsigned stride = 1; + for (unsigned dim : order) { + if (dim == getAxis()) + return stride; + stride *= getContigPerThread(getEncoding())[dim]; + } + llvm_unreachable("Axis not found in order"); +} + +unsigned ScanLoweringHelper::getAxisThreadStride() { + auto order = getOrder(getEncoding()); + unsigned stride = 1; + for (unsigned dim : order) { + if (dim == getAxis()) + return stride; + stride *= getEncoding().getThreadsPerWarp()[dim]; + } + llvm_unreachable("Axis not found in order"); +} + +unsigned ScanLoweringHelper::getAxisBlockStride() { + auto order = getOrder(getEncoding()); + unsigned stride = 1; + auto sizePerThreads = getSizePerThread(getEncoding()); + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + for (unsigned dim : order) { + if (dim == getAxis()) + return stride; + stride *= ceil(getShape()[dim], sizePerThreads[dim] * + threadsPerWarp[dim] * + warpsPerCTA[dim]); + } + llvm_unreachable("Axis not found in order"); +} + +bool maybeSharedAllocationOp(Operation *op) { + // TODO(Keren): This function can be replaced by adding + // MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to + // query the memory effects of the op. + auto *dialect = op->getDialect(); + return dialect && + (dialect->getTypeID() == TypeID::get() || + dialect->getTypeID() == TypeID::get() || + dialect->getTypeID() == TypeID::get() || + dialect->getTypeID() == TypeID::get()); +} + +static bool supportMFMAGranularity(int m, int n, int k) { + // these limitations are dtype dependent, in future we may relax them + const static std::pair mfmaTypes[2] = {{32, 8}, {16, 16}}; + for (const auto &mfmaType : mfmaTypes) { + auto [granularityMN, granularityK] = mfmaType; + if (m % granularityMN != 0 || n % granularityMN != 0) + continue; + if (k % granularityK != 0) + continue; + return true; + } + return false; +} + +bool supportMFMATypes(Type a, Type b) { + if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth()) + return false; + + auto F8E4M3FNUZ = TypeID::get(); + auto F8E5M2FNUZ = TypeID::get(); + auto F16 = TypeID::get(); + auto BF16 = TypeID::get(); + auto F32 = TypeID::get(); + auto Int = TypeID::get(); + DenseSet> supportedTypes = { + {F32, F32}, + {F16, F16}, + {BF16, BF16}, + {F8E4M3FNUZ, F8E4M3FNUZ}, + {F8E4M3FNUZ, F8E5M2FNUZ}, + {F8E5M2FNUZ, F8E4M3FNUZ}, + {F8E5M2FNUZ, F8E5M2FNUZ}, + {Int, Int}}; + + if (!supportedTypes.contains({a.getTypeID(), b.getTypeID()})) + return false; + + if (a.isIntOrIndex() && a.getIntOrFloatBitWidth() != 8) + return false; + return true; +} + +bool supportMFMA(triton::DotOp op) { + auto aTy = cast(op.getA().getType()); + auto bTy = cast(op.getB().getType()); + + auto aElemTy = aTy.getElementType(); + auto bElemTy = bTy.getElementType(); + + if (!supportMFMATypes(aElemTy, bElemTy)) + return false; + + auto aShape = aTy.getShape(); + auto bShape = bTy.getShape(); + + auto rank = aShape.size(); + assert(bShape.size() == rank); + auto M = aShape[rank - 2]; + auto N = bShape[rank - 1]; + auto K = aShape[rank - 1]; + assert(K == bShape[rank - 2]); + if (!supportMFMAGranularity(M, N, K)) + return false; + + return true; +} + +static bool supportWMMAGranularity(int m, int n, int k) { + return m % 16 == 0 && n % 16 == 0 && k % 16 == 0; +} + +static bool supportWMMATypes(Type a, Type b, Type c, Type d) { + if (a != b || c != d) + return false; + auto aWidth = a.getIntOrFloatBitWidth(); + auto cWidth = c.getIntOrFloatBitWidth(); + if (a.isIntOrIndex()) { + if (!c.isIntOrIndex()) + return false; + bool aValid = aWidth <= 8; + bool cValid = cWidth <= 32; + return aValid && cValid; + } else if (isa(a) && isa(c)) { + if (a.isBF16()) + return c.isBF16() || c.isF32(); + if (a.isF16()) + return c.isF16() || c.isF32(); + return aWidth <= cWidth && aWidth <= 16; + } + return false; +} + +bool supportWMMA(triton::DotOp op) { + auto aTy = cast(op.getA().getType()); + auto bTy = cast(op.getB().getType()); + auto cTy = cast(op.getC().getType()); + auto dTy = cast(op.getResult().getType()); + + auto aElemTy = aTy.getElementType(); + auto bElemTy = bTy.getElementType(); + auto cElemTy = cTy.getElementType(); + auto dElemTy = dTy.getElementType(); + + if (!supportWMMATypes(aElemTy, bElemTy, cElemTy, dElemTy)) + return false; + + auto aShape = aTy.getShape(); + auto bShape = bTy.getShape(); + + auto rank = aShape.size(); + assert(bShape.size() == rank); + assert(aShape[rank - 1] == bShape[rank - 2]); + if (!supportWMMAGranularity(aShape[rank - 2], bShape[rank - 1], + aShape[rank - 1])) + return false; + + return true; +} + +bool supportMMA(triton::DotOp op, int version) { + // Refer to mma section for the data type supported by Volta and Hopper + // Tensor Core in + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + auto aElemTy = op.getA().getType().getElementType(); + auto bElemTy = op.getB().getType().getElementType(); + if (version == 3) { + if (triton::tools::getBoolEnv("DISABLE_MMA_V3")) + return false; + auto retType = op.getType(); + auto retShapePerCTA = getShapePerCTA(retType); + auto rank = retShapePerCTA.size(); + auto mod = op->getParentOfType(); + int numWarps = TritonGPUDialect::getNumWarps(mod); + if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 && + retShapePerCTA[rank - 1] % 8 == 0 && + (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() || + aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() || + aElemTy.isF32()))) { + return false; + } + // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op. + if (op.getMaxNumImpreciseAcc() < 32 && + (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) && + cast(op.getType()).getElementType().isF32()) { + return false; + } + } +#ifndef __ILUVATAR__ + if (aElemTy.isF32() && bElemTy.isF32()) { + return op.getInputPrecision() == InputPrecision::TF32 && version >= 2; + } +#else + auto retElemTy = + op.getResult().getType().cast().getElementType(); + if (retElemTy.isF16()) { + return false; + } +#endif + return supportMMA(op.getA(), version) && supportMMA(op.getB(), version); +} + +bool supportMMA(Value value, int version) { + // Tell whether a DotOp support MMA by the operand type(either $a or $b). + // We cannot get both the operand types(in TypeConverter), here we assume the + // types of both the operands are identical here. +#if defined(__ILUVATAR__) + assert((version == 1 || version == 2) && + "Unexpected MMA layout version found"); + auto elemTy = cast(value.getType()).getElementType(); + return elemTy.isF16() || elemTy.isBF16() || elemTy.isF32() || + elemTy.isInteger(8); +#else + assert((version == 1 || version == 2 || version == 3) && + "Unexpected MMA layout version found"); + auto elemTy = cast(value.getType()).getElementType(); + // FP8 is not natively supported on all mma versions but it can always be + // promoted to fp16 therefore we can always support it. + bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() || + elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ(); + return isFP8 || elemTy.isF16() || elemTy.isBF16() || + (elemTy.isF32() && version >= 2) || + (elemTy.isInteger(8) && version >= 2); +#endif +} + +bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + auto mfmaLayout = dyn_cast(srcLayout); + auto dotOperandLayout = dyn_cast(dstLayout); + if (mfmaLayout == nullptr || dotOperandLayout == nullptr) + return false; + // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is + // improved. In addition, we can enable this shortcut for regular MFMA + // layout when opIdx == 1. + return mfmaLayout.getWarpsPerCTA()[1] == 1 && + dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && + dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] && + dotOperandLayout.getParent() == mfmaLayout && + (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) && + (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()); +} + +static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) { +#ifdef __ILUVATAR__ + auto src = dyn_cast(srcEncoding); + auto dst = dyn_cast(dstEncoding); + if (!src || !dst) + return false; + return src.getVersionMinor() == 0 && dst.getVersionMinor() > 0; +#else + auto src = dyn_cast(srcEncoding); + auto dst = dyn_cast(dstEncoding); + if (!src || !dst) + return false; + // when #mma = MmaEncoding + return src && dst && src.getVersionMajor() == 3 && + src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 && + dst.getWarpsPerCTA()[1] == 1; +#endif +} + +bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { + return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding()); +} + +// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases. +bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, + RankedTensorType dstTy) { + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + auto mmaLayout = cast(srcLayout); + auto dotOperandLayout = cast(dstLayout); + int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth(); + auto ans = mmaLayout.getVersionMajor() == 3 && + dotOperandLayout.getOpIdx() == 0 && + isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) && + (elementTypeSize == 16 || elementTypeSize == 8); + return ans; +} + +bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { + // dot_op = #mma + // when #mma = MmaEncoding + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); +#ifdef __ILUVATAR__ + auto mmaLayout = mlir::cast(srcLayout); + auto dotOperandLayout = mlir::cast(dstLayout); + return mmaLayout.getWarpsPerCTA()[1] == 1 && + dotOperandLayout.getOpIdx() == 0 && + dotOperandLayout.getParent() == mmaLayout && + !srcTy.getElementType().isF32(); +#else + if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) + return true; + auto mmaLayout = mlir::cast(srcLayout); + auto dotOperandLayout = mlir::cast(dstLayout); + return mmaLayout.getVersionMajor() == 2 && + mmaLayout.getWarpsPerCTA()[1] == 1 && + dotOperandLayout.getOpIdx() == 0 && + dotOperandLayout.getParent() == mmaLayout && + !srcTy.getElementType().isF32(); +#endif +} + +bool isMmaToDotSlowShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { + + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + if (!srcLayout.isa()) + return false; + auto mmaLayout = srcLayout.cast(); + if (!dstLayout.isa()) + return false; + auto dotOperandLayout = dstLayout.cast(); + auto dstParLayout = dotOperandLayout.getParent(); + if (!dstParLayout.isa()) + return false; + auto dstMmaLayout = + dstParLayout.dyn_cast(); + return !isMmaToDotShortcut(srcTy, dstTy) && + mmaLayout.getVersionMajor() == 1 && + dstMmaLayout.getVersionMajor() == 1 && + mmaLayout.getWarpsPerCTA()[0] == dstMmaLayout.getWarpsPerCTA()[0] && + dotOperandLayout.getOpIdx() == 0 && !srcTy.getElementType().isF32(); +} + +namespace { + +/// A data structure similar to SetVector but maintains +/// a deque instead of a vector to allow for efficient +/// push_back and pop_front operations. +/// Using SetVector doesn't suffice our needs because +/// it only pushes and pops from the back. +/// For example, if we have a queue like this: +/// 0->4 1->2->3 +/// ^-------- +/// where 3 depends on 4, once we pop 3, we found +/// 4 is not ready, so we check 2 and push 3 back +/// to the queue. +struct DFSSubgraphState { + DFSSubgraphState() : set(), deque() {} + DenseSet set; + std::deque deque; + + bool push_back(Operation *op) { + if (set.insert(op).second) { + deque.push_back(op); + return true; + } + return false; + } + + Operation *pop_front() { + Operation *op = deque.front(); + deque.pop_front(); + set.erase(op); + return op; + } + + bool empty() { return deque.empty(); } +}; + +/// DFS post-order implementation that maintains a global count to work across +/// multiple invocations, to help implement topological sort on multi-root DAGs. +/// We traverse all operations but only record the ones that appear in +/// `toSort` for the final result. +struct DFSState { + DFSState(const SetVector &set) : toSort(set), seen() {} + const SetVector &toSort; + SmallVector topologicalCounts; + DenseSet seen; + + /// We mark each op as ready if all its operands and parents ops are seen. If + /// an op is ready, we add it to the queue. Otherwise, we keep adding its + /// operands to the ancestors set. + /// We always want an op to be scheduled after all its parents to handle + /// correctly cases with scf operations. + void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph, + SmallVector &readyQueue) { + bool ready = true; + for (Value operand : op->getOperands()) { + auto def = operand.getDefiningOp(); + if (def && !seen.count(def)) { + subGraph.push_back(def); + ready = false; + } + } + Operation *parent = op->getParentOp(); + while (parent) { + if (!seen.count(parent)) { + subGraph.push_back(parent); + ready = false; + } + parent = parent->getParentOp(); + } + if (ready) + readyQueue.push_back(op); + } +}; + +void dfsPostorder(Operation *root, DFSState *state) { + DFSSubgraphState subGraph; + subGraph.push_back(root); + SmallVector ops; + while (!subGraph.empty()) { + // Nodes in the ready queue are ready to be processed. + // Meaning that either their operands are all seen or they have null + // operands. + SmallVector readyQueue; + auto *current = subGraph.pop_front(); + state->addToReadyQueue(current, subGraph, readyQueue); + while (!readyQueue.empty()) { + Operation *current = readyQueue.pop_back_val(); + if (!state->seen.insert(current).second) + continue; + ops.push_back(current); + for (Value result : current->getResults()) { + for (Operation *op : result.getUsers()) + state->addToReadyQueue(op, subGraph, readyQueue); + } + for (Region ®ion : current->getRegions()) { + for (Operation &op : region.getOps()) + state->addToReadyQueue(&op, subGraph, readyQueue); + } + } + } + + for (Operation *op : llvm::reverse(ops)) { + if (state->toSort.count(op) > 0) + state->topologicalCounts.push_back(op); + } +} + +} // namespace + +SetVector +multiRootTopologicalSort(const SetVector &toSort) { + if (toSort.empty()) { + return toSort; + } + + // Run from each root with global count and `seen` set. + DFSState state(toSort); + for (auto *s : toSort) { + assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); + dfsPostorder(s, &state); + } + + // Reorder and return. + SetVector res; + for (auto it = state.topologicalCounts.rbegin(), + eit = state.topologicalCounts.rend(); + it != eit; ++it) { + res.insert(*it); + } + return res; +} + +#ifdef __ILUVATAR__ +void getBackwardSliceImplCorex(Operation *op, + SetVector *backwardSlice, + TransitiveFilter filter, + bool omitBlockArguments) { + if (!op || op->hasTrait()) + return; + + // Evaluate whether we should keep this def. + // This is useful in particular to implement scoping; i.e. return the + // transitive backwardSlice in the current scope. + if (filter && !filter(op)) + return; + + for (const auto &en : llvm::enumerate(op->getOperands())) { + auto operand = en.value(); + if (auto *definingOp = operand.getDefiningOp()) { + if (backwardSlice->count(definingOp) == 0) + getBackwardSliceImplCorex(definingOp, backwardSlice, filter, + omitBlockArguments); + } else if (auto blockArg = operand.dyn_cast()) { + if (omitBlockArguments) + continue; + + Block *block = blockArg.getOwner(); + Operation *parentOp = block->getParentOp(); + // TODO: determine whether we want to recurse backward into the other + // blocks of parentOp, which are not technically backward unless they flow + // into us. For now, just bail. + if (parentOp && backwardSlice->count(parentOp) == 0) { + assert(parentOp->getNumRegions() == 1 && + parentOp->getRegion(0).getBlocks().size() == 1); + getBackwardSliceImplCorex(parentOp, backwardSlice, filter, + omitBlockArguments); + } + } else { + llvm_unreachable("No definingOp and not a block argument."); + } + } + + backwardSlice->insert(op); +} + +void getBackwardSliceCorex(Operation *op, SetVector *backwardSlice, + TransitiveFilter filter, bool omitBlockArguments) { + getBackwardSliceImplCorex(op, backwardSlice, filter, omitBlockArguments); + + // Don't insert the top level operation, we just queried on it and don't + // want it in the results. + backwardSlice->remove(op); +} +#endif + +SetVector multiRootGetSlice(Operation *op, + TransitiveFilter backwardFilter, + TransitiveFilter forwardFilter) { + SetVector slice; + slice.insert(op); + + unsigned currentIndex = 0; + SetVector backwardSlice; + SetVector forwardSlice; + while (currentIndex != slice.size()) { + auto *currentOp = (slice)[currentIndex]; + // Compute and insert the backwardSlice starting from currentOp. + backwardSlice.clear(); + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = backwardFilter; +#ifdef __ILUVATAR__ + getBackwardSliceCorex(currentOp, &backwardSlice, opt.filter, + opt.omitBlockArguments); +#elif + getBackwardSlice(currentOp, &backwardSlice, opt); +#endif + slice.insert(backwardSlice.begin(), backwardSlice.end()); + + // Compute and insert the forwardSlice starting from currentOp. + forwardSlice.clear(); + getForwardSlice(currentOp, &forwardSlice, forwardFilter); + slice.insert(forwardSlice.begin(), forwardSlice.end()); + ++currentIndex; + } + return multiRootTopologicalSort(slice); +} + +namespace { +// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis +// interacts with constant propagation, but SparseConstantPropagation +// doesn't seem to be sufficient. +class ConstantAnalysis : public DataFlowAnalysis { +public: + using DataFlowAnalysis::DataFlowAnalysis; + + LogicalResult initialize(Operation *top) override { + WalkResult result = top->walk([&](Operation *op) { + if (failed(visit(op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return success(!result.wasInterrupted()); + } + + LogicalResult visit(ProgramPoint point) override { + Operation *op = point.get(); + Attribute value; + if (matchPattern(op, m_Constant(&value))) { + auto *constant = getOrCreate>( + op->getResult(0)); + propagateIfChanged(constant, constant->join(dataflow::ConstantValue( + value, op->getDialect()))); + return success(); + } + // Dead code analysis requires every operands has initialized ConstantValue + // state before it is visited. + // https://github.com/llvm/llvm-project/blob/2ec1aba2b69faa1de5f71832a48e25aa3b5d5314/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp#L322 + // That's why we need to set all operands to unknown constants. + setAllToUnknownConstants(op->getResults()); + for (Region ®ion : op->getRegions()) { + for (Block &block : region.getBlocks()) + setAllToUnknownConstants(block.getArguments()); + } + return success(); + } + +private: + /// Set all given values as not constants. + void setAllToUnknownConstants(ValueRange values) { + dataflow::ConstantValue unknownConstant(nullptr, nullptr); + for (Value value : values) { + auto *constant = + getOrCreate>(value); + propagateIfChanged(constant, constant->join(unknownConstant)); + } + } +}; +} // namespace + +std::unique_ptr createDataFlowSolver() { + auto solver = std::make_unique(); + solver->load(); + solver->load(); + return solver; +} + +static MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) { + + if (auto makeTensorPtrOp = dyn_cast(op)) { + return makeTensorPtrOp; + } + + if (auto advanceOp = dyn_cast(op)) { + return getMakeTensorPtrOp(advanceOp.getPtr()); + } + + if (auto branch = dyn_cast(op)) { + auto idx = cast(v).getResultNumber(); + llvm::SmallVector yieldOps; + op->walk([&](Operation *op) { + if (auto yieldOp = dyn_cast(op)) + yieldOps.push_back(yieldOp); + }); + + // benzh@ if multi yields, all yields operand should come from same arg. + Value newValue = yieldOps[0].getOperands()[idx]; + return getMakeTensorPtrOp(newValue); + } + + llvm_unreachable("Unable to getMakeTensorPtr()"); +} + +MakeTensorPtrOp getMakeTensorPtrOp(Value v) { + using BranchOps = llvm::SetVector>; + llvm::DenseMap blockToCFOps; + auto moduleOp = + v.getParentBlock()->getParentOp()->getParentOfType(); + + moduleOp.walk([&](Operation *op) { + if (auto br = dyn_cast(op)) { + Block *block = br.getDest(); + blockToCFOps[block].insert({op, -1}); + } + if (auto condBr = dyn_cast(op)) { + Block *blockT = condBr.getTrueDest(); + Block *blockF = condBr.getFalseDest(); + blockToCFOps[blockT].insert({condBr, 1}); + blockToCFOps[blockF].insert({condBr, 0}); + } + }); + + if (Operation *definingOp = v.getDefiningOp()) + return getMakeTensorPtrOpImpl(definingOp, v); + + // If there is no defining op, v must be a BlockArgument. + BlockArgument arg = cast(v); + unsigned argNum = arg.getArgNumber(); + Operation *argOwner = arg.getOwner()->getParentOp(); + + if (auto forOp = dyn_cast(argOwner)) + return getMakeTensorPtrOp( + forOp.getOperand(argNum + forOp.getNumControlOperands() - 1)); + if (auto funcOp = dyn_cast(argOwner)) { + Block *block = arg.getOwner(); + Operation *op; + int tOrF; + std::tie(op, tOrF) = blockToCFOps[block][0]; + if (auto br = dyn_cast(op)) + return getMakeTensorPtrOp(br.getDestOperands()[argNum]); + if (auto condBr = dyn_cast(op)) + return getMakeTensorPtrOp(tOrF ? condBr.getTrueDestOperands()[argNum] + : condBr.getFalseDestOperands()[argNum]); + return getMakeTensorPtrOp(argOwner->getOperand(argNum)); + } + llvm_unreachable("Unable to getMakeTensorPtr()"); +} + +} // namespace mlir diff --git a/third_party/iluvatar/lib/CMakeLists.txt b/third_party/iluvatar/lib/CMakeLists.txt new file mode 100644 index 000000000..e4e8a7ce1 --- /dev/null +++ b/third_party/iluvatar/lib/CMakeLists.txt @@ -0,0 +1,8 @@ +add_compile_options("-Wno-deprecated-declarations") +add_compile_options("-Wno-error=deprecated-declarations") + +add_subdirectory(Analysis) +add_subdirectory(Conversion) +add_subdirectory(Dialect) +add_subdirectory(Target) +add_subdirectory(Tools) diff --git a/third_party/iluvatar/lib/Conversion/CMakeLists.txt b/third_party/iluvatar/lib/Conversion/CMakeLists.txt new file mode 100644 index 000000000..143a4375a --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonToTritonGPU) +add_subdirectory(TritonGPUToLLVM) diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp new file mode 100644 index 000000000..aae9faf0e --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp @@ -0,0 +1,69 @@ +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_ALLOCATESHAREDMEMORY +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace { + +struct AllocateSharedMemory + : public mlir::triton::impl::AllocateSharedMemoryBase< + AllocateSharedMemory> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + ModuleAllocation allocation(mod); + + mod.walk([&](FunctionOpInterface funcOp) { + funcOp.walk([&](Operation *op) { + auto *funcAllocation = allocation.getFuncData(funcOp); + auto oBufferId = funcAllocation->getBufferId(op); + int offset = -1; + if (oBufferId != Allocation::InvalidBufferId) + offset = funcAllocation->getOffset(oBufferId); + else if (op->getNumResults() == 1) { + Value value = op->getResult(0); + auto vBufferId = funcAllocation->getBufferId(value); + if (vBufferId != Allocation::InvalidBufferId) + offset = funcAllocation->getOffset(vBufferId); + } + if (offset == -1) + return; + op->setAttr("allocation.offset", + IntegerAttr::get(IntegerType::get(ctx, 32), offset)); + }); + }); + mod->setAttr("triton_gpu.shared", + mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), + allocation.getSharedMemorySize())); + } +}; + +} // namespace + +namespace mlir { + +namespace triton { + +namespace gpu { + +std::unique_ptr> createAllocateSharedMemoryPass() { + return std::make_unique(); +} + +} // namespace gpu + +} // namespace triton + +} // namespace mlir diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp new file mode 100644 index 000000000..a3f55f1e7 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp @@ -0,0 +1,80 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; + +struct AssertOpConversion : public ConvertOpToLLVMPattern { + explicit AssertOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + auto elems = unpackLLElements(loc, adaptor.getCondition(), rewriter); + auto elemTy = elems[0].getType(); + Value condition = int_val(elemTy.getIntOrFloatBitWidth(), 0); + for (auto elem : elems) { + if (elemTy.isSignedInteger() || elemTy.isSignlessInteger()) { + condition = + or_(condition, + icmp_eq(elem, rewriter.create( + loc, elemTy, rewriter.getZeroAttr(elemTy)))); + } else { + assert(false && "Unsupported type for assert"); + return failure(); + } + } + llAssert(op, condition, adaptor.getMessage(), adaptor.getFile(), + adaptor.getFunc(), adaptor.getLine(), rewriter); + rewriter.eraseOp(op); + return success(); + } + // op: the op at which the assert is inserted. Unlike printf, we need to + // know about the op to split the block. + void llAssert(Operation *op, Value condition, StringRef message, + StringRef file, StringRef func, int line, + ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + auto ctx = rewriter.getContext(); + auto loc = op->getLoc(); + // #block1 + // if (condition) { + // #block2 + // __assertfail(message); + // } + // #block3 + Block *prevBlock = op->getBlock(); + + Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator()); + rewriter.setInsertionPointToStart(ifBlock); + targetInfo.assertFail(rewriter, loc, message, file, func, line); + + // Split a block after the call. + Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator()); + rewriter.setInsertionPointToEnd(ifBlock); + rewriter.create(loc, thenBlock); + rewriter.setInsertionPointToEnd(prevBlock); + rewriter.create(loc, condition, ifBlock, thenBlock); + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateAssertOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..a2a4af24e --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,38 @@ +add_triton_library(TritonGPUToLLVM + ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp + DotOpToLLVM/FMA.cpp + TypeConverter.cpp + Utility.cpp + ElementwiseOpToLLVM.cpp + MemoryOpToLLVM.cpp + AssertOpToLLVM.cpp + ViewOpToLLVM.cpp + MakeRangeOpToLLVM.cpp + HistogramOpToLLVM.cpp + AllocateSharedMemory.cpp + ReduceOpToLLVM.cpp + ScanOpToLLVM.cpp + ConvertLayoutOpToLLVM.cpp + ControlFlowOpToLLVM.cpp + FuncOpToLLVM.cpp + SPMDOpToLLVM.cpp + DecomposeUnsupportedConversions.cpp + PrintOpToLLVM.cpp + + DEPENDS + TritonGPUConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRGPUDialect + MLIRGPUToNVVMTransforms + MLIRGPUToROCDLTransforms + MLIRGPUTransforms + TritonAnalysis + TritonIR + TritonGPUIR + TritonGPUTransforms + #TritonNvidiaGPUTransforms + # NVGPUIR +) diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp new file mode 100644 index 000000000..9765d7bf0 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp @@ -0,0 +1,141 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct ReturnOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = op->getParentOfType(); + if (funcOp->hasAttr("nvvm.kernel")) { + // A GPU kernel + if (op.getNumOperands() > 0) { + return rewriter.notifyMatchFailure( + op, "Kernel functions do not support return with operands"); + } + rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), + op->getAttrs()); + } else { + // A device function + LLVM::ReturnOp newOp; + if (adaptor.getOperands().size() < 2) { + // Single or no return value. + newOp = + rewriter.create(op.getLoc(), adaptor.getOperands()); + } else { + // Pack the results into a struct. + auto packedResultsTy = this->getTypeConverter()->packFunctionResults( + funcOp.getResultTypes()); + Value packedResults = + rewriter.create(op.getLoc(), packedResultsTy); + auto loc = op.getLoc(); + for (auto it : llvm::enumerate(adaptor.getOperands())) { + packedResults = insert_val(packedResultsTy, packedResults, it.value(), + it.index()); + } + newOp = rewriter.create(op.getLoc(), packedResults); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + } + return success(); + } +}; + +// CallOpInterfaceLowering is adapted from +// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 +struct CallOpConversion : public ConvertOpToLLVMPattern { + CallOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto promotedOperands = promoteOperands(callOp, adaptor, rewriter); + auto newCallOp = + convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter); + if (!newCallOp) + return failure(); + auto results = getCallOpResults(callOp, newCallOp, rewriter); + rewriter.replaceOp(callOp, results); + return success(); + } + +private: + SmallVector + promoteOperands(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Get the last argument of the caller, which is the current stack pointer + // of shared memory and append it to the operands of the callOp. + auto loc = callOp.getLoc(); + auto caller = callOp->getParentOfType(); + auto promotedOperands = this->getTypeConverter()->promoteOperands( + callOp.getLoc(), /*opOperands=*/callOp->getOperands(), + adaptor.getOperands(), rewriter); + if (!caller->hasAttr("allocation.offset")) { + auto base = LLVM::getStackPointer(rewriter, caller); + promotedOperands.push_back(base); + return promotedOperands; + } + promotedOperands.push_back( + LLVM::getSharedMemoryBase(callOp->getLoc(), rewriter, callOp)); + return promotedOperands; + } + + LLVM::CallOp + convertCallOpToLLVMCallOp(triton::CallOp callOp, + ArrayRef promotedOperands, + ConversionPatternRewriter &rewriter) const { + // Pack the result types into a struct. + Type packedResult = nullptr; + unsigned numResults = callOp.getNumResults(); + auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); + + if (numResults != 0) { + if (!(packedResult = + this->getTypeConverter()->packFunctionResults(resultTypes))) + return nullptr; + } + auto newCallOp = rewriter.create( + callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), + promotedOperands, callOp->getAttrs()); + return newCallOp; + } + + SmallVector + getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp, + ConversionPatternRewriter &rewriter) const { + auto numResults = callOp.getNumResults(); + SmallVector results; + if (numResults < 2) { + // If < 2 results, packing did not do anything and we can just return. + results.append(newCallOp.result_begin(), newCallOp.result_end()); + } else { + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + results.push_back(rewriter.create( + callOp.getLoc(), newCallOp->getResult(0), i)); + } + } + return results; + } +}; + +} // namespace + +void mlir::triton::populateControlFlowOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp new file mode 100644 index 000000000..94894ceb1 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -0,0 +1,324 @@ +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" + +using mlir::isLayoutMmaV1; +using mlir::LLVM::getMultiDimOffset; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::LLVM::getStridesFromShapeAndOrder; +using mlir::LLVM::getWrappedMultiDimOffset; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getShapePerCTATile; +using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::isaDistributedLayout; +using ::mlir::triton::gpu::SharedEncodingAttr; + +namespace { + +struct LocalLoadOpConversion + : public ConvertOpToLLVMPattern { +public: + LocalLoadOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MemDescType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + // TODO: do we need to check if src is shared ? + if (isa(srcLayout) && isaDistributedLayout(dstLayout)) { + return lowerSharedToDistributed(op, adaptor, getTypeConverter(), + rewriter); + } + if (isa(dstLayout) && + isa( + cast(dstLayout).getParent())) { + return lowerSharedToDotOpFMA(op, adaptor, getTypeConverter(), rewriter); + } + return failure(); + } + +private: + LogicalResult + lowerSharedToDotOpFMA(triton::gpu::LocalLoadOp op, + triton::gpu::LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + RankedTensorType dstTy = op.getType(); + Attribute dstLayout = dstTy.getEncoding(); + auto dotLayout = cast(dstLayout); + auto blockedLayout = cast( + cast(dstLayout).getParent()); + auto thread = getThreadId(rewriter, loc); + Value res = SharedToDotOperandFMA::convertLayout( + dotLayout.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout, + thread, loc, getTypeConverter(), rewriter); + rewriter.replaceOp(op, res); + return success(); + } + LogicalResult + lowerSharedToDistributed(triton::gpu::LocalLoadOp op, + triton::gpu::LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getResult().getType(); + auto dstShape = dstTy.getShape(); + assert(dstShape.size() <= 2 && + "Unexpected rank of ConvertLayout(shared->blocked)"); + auto srcSharedLayout = cast(srcTy.getEncoding()); + auto dstLayout = dstTy.getEncoding(); + auto inOrd = getOrder(srcSharedLayout); + + auto smemObj = getSharedMemoryObjectFromStruct( + loc, adaptor.getSrc(), + typeConverter->convertType(srcTy.getElementType()), rewriter); + auto elemTy = typeConverter->convertType(dstTy.getElementType()); + + auto srcStrides = + getStridesFromShapeAndOrder(srcTy.getShape(), inOrd, loc, rewriter); + + SmallVector outVals = + loadSharedToDistributed(op.getResult(), op.getSrc(), smemObj, elemTy, + loc, rewriter, targetInfo); + + Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +struct ConvertLayoutOpConversion + : public ConvertOpToLLVMPattern { +public: + ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + if (isSupported(srcLayout, dstLayout)) { + return lowerDistributedToDistributed(op, adaptor, rewriter); + } + return failure(); + } + +private: + bool isSupported(Attribute srcLayout, Attribute dstLayout) const { + return isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout) && + !isLayoutMmaV1(srcLayout) && !isLayoutMmaV1(dstLayout); + } + // shared memory rd/st for blocked or mma layout with data padding + void processReplica(Location loc, ConversionPatternRewriter &rewriter, + bool stNotRd, RankedTensorType type, + ArrayRef numCTAsEachRep, + ArrayRef multiDimRepId, unsigned vec, + ArrayRef paddedRepShape, + ArrayRef origRepShape, + ArrayRef outOrd, SmallVector &vals, + Value smemBase) const { + auto accumNumCTAsEachRep = product(numCTAsEachRep); + auto layout = type.getEncoding(); + auto rank = type.getRank(); + auto sizePerThread = getSizePerThread(layout); + auto accumSizePerThread = product(sizePerThread); + SmallVector numCTATiles(rank); + auto shapePerCTATile = getShapePerCTATile(layout); + auto shapePerCTA = getShapePerCTA(layout, type.getShape()); + auto order = getOrder(layout); + for (unsigned d = 0; d < rank; ++d) { + numCTATiles[d] = ceil(shapePerCTA[d], shapePerCTATile[d]); + } + auto elemTy = type.getElementType(); + bool isInt1 = elemTy.isInteger(1); + bool isPtr = isa(elemTy); + auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy); + if (isInt1) + elemTy = IntegerType::get(elemTy.getContext(), 8); + else if (isPtr) + elemTy = IntegerType::get(elemTy.getContext(), 64); + + auto llvmElemTy = getTypeConverter()->convertType(elemTy); + + for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) { + auto multiDimCTAInRepId = + getMultiDimIndex(ctaId, numCTAsEachRep, order); + SmallVector multiDimCTAId(rank); + for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) { + auto d = it.index(); + multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value(); + } + + auto linearCTAId = + getLinearIndex(multiDimCTAId, numCTATiles, order); + // TODO: This is actually redundant index calculation, we should + // consider of caching the index calculation result in case + // of performance issue observed. + for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { + SmallVector multiDimOffset = + getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, type, + multiDimCTAInRepId, shapePerCTATile); + SmallVector multiDimOffsetWrapped = getWrappedMultiDimOffset( + rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile, + shapePerCTA); + Value offset = linearize(rewriter, loc, multiDimOffsetWrapped, + paddedRepShape, outOrd); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset); + auto vecTy = vec_ty(llvmElemTy, vec); + ptr = bitcast(ptr, ptr_ty(rewriter.getContext(), 3)); + if (stNotRd) { + Value valVec = undef(vecTy); + for (unsigned v = 0; v < vec; ++v) { + auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v]; + if (isInt1) + currVal = zext(llvmElemTy, currVal); + else if (isPtr) + currVal = ptrtoint(llvmElemTy, currVal); + valVec = insert_element(vecTy, valVec, currVal, i32_val(v)); + } + store(valVec, ptr); + } else { + Value valVec = load(vecTy, ptr); + for (unsigned v = 0; v < vec; ++v) { + Value currVal = extract_element(llvmElemTy, valVec, i32_val(v)); + if (isInt1) + currVal = icmp_ne(currVal, + rewriter.create( + loc, i8_ty, rewriter.getI8IntegerAttr(0))); + else if (isPtr) + currVal = inttoptr(llvmElemTyOrig, currVal); + vals[elemId + linearCTAId * accumSizePerThread + v] = currVal; + } + } + } + } + } + // blocked/mma -> blocked/mma. + // Data padding in shared memory to avoid bank conflict. + LogicalResult + lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto typeConverter = getTypeConverter(); + RankedTensorType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + if (product(srcTy.getShape()) == 1) { + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector outVals(getTotalElemsPerThread(dstTy), inVals[0]); + Value result = + packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + return success(); + } + + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + smemBase = bitcast(smemBase, elemPtrTy); + auto shape = dstTy.getShape(); + unsigned rank = dstTy.getRank(); + SmallVector numReplicates(rank); + SmallVector inNumCTAsEachRep(rank); + SmallVector outNumCTAsEachRep(rank); + SmallVector inNumCTAs(rank); + SmallVector outNumCTAs(rank); + auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); + auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape); + auto shapePerCTA = getShapePerCTA(srcLayout, shape); + + for (unsigned d = 0; d < rank; ++d) { + unsigned inPerCTA = + std::min(shapePerCTA[d], srcShapePerCTATile[d]); + unsigned outPerCTA = + std::min(shapePerCTA[d], dstShapePerCTATile[d]); + unsigned maxPerCTA = std::max(inPerCTA, outPerCTA); + numReplicates[d] = ceil(shapePerCTA[d], maxPerCTA); + inNumCTAsEachRep[d] = maxPerCTA / inPerCTA; + outNumCTAsEachRep[d] = maxPerCTA / outPerCTA; + assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0); + inNumCTAs[d] = ceil(shapePerCTA[d], inPerCTA); + outNumCTAs[d] = ceil(shapePerCTA[d], outPerCTA); + } + // Potentially we need to store for multiple CTAs in this replication + auto accumNumReplicates = product(numReplicates); + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + unsigned inVec = 0; + unsigned outVec = 0; + auto origRepShape = getRepShapeForCvtLayout(op); + auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec); + + unsigned outElems = getTotalElemsPerThread(dstTy); + auto outOrd = getOrder(dstLayout); + SmallVector outVals(outElems); + + for (unsigned repId = 0; repId < accumNumReplicates; ++repId) { + auto multiDimRepId = + getMultiDimIndex(repId, numReplicates, outOrd); + if (repId != 0) { + barrier(); + } + auto successful = targetInfo.processReplicaUsingStMatrix( + rewriter, loc, smemBase, vals, srcTy, + getTypeConverter()->convertType(srcTy.getElementType()), + paddedRepShape, origRepShape, outOrd, accumNumReplicates); + if (!successful) { + processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, + multiDimRepId, inVec, paddedRepShape, origRepShape, + outOrd, vals, smemBase); + } + barrier(); + processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep, + multiDimRepId, outVec, paddedRepShape, origRepShape, + outOrd, outVals, smemBase); + } + + Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; +} // namespace + +void mlir::triton::populateConvertLayoutOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp new file mode 100644 index 000000000..b7bd5fbc3 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -0,0 +1,234 @@ +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using ValueTable = std::map, Value>; +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::LLVM::getStridesFromShapeAndOrder; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getContigPerThread; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::isaDistributedLayout; +using ::mlir::triton::gpu::SharedEncodingAttr; + +SmallVector +getThreadIds(Value threadId, ArrayRef shapePerCTATile, + ArrayRef sizePerThread, ArrayRef order, + ConversionPatternRewriter &rewriter, Location loc) { + int dim = order.size(); + SmallVector threadIds(dim); + for (unsigned k = 0; k < dim - 1; k++) { + Value dimK = i32_val(shapePerCTATile[order[k]] / sizePerThread[order[k]]); + Value rem = urem(threadId, dimK); + threadId = udiv(threadId, dimK); + threadIds[order[k]] = rem; + } + Value dimK = i32_val(shapePerCTATile[order[dim - 1]]); + threadIds[order[dim - 1]] = urem(threadId, dimK); + return threadIds; +} + +// Get shapePerCTATile for M or N axis. +int getShapePerCTATileForMN(BlockedEncodingAttr layout, bool isM) { + auto order = layout.getOrder(); + auto shapePerCTATile = getShapePerCTATile(layout); + + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + return isM ? mShapePerCTATile : nShapePerCTATile; +} + +// Get sizePerThread for M or N axis. +int getSizePerThreadForMN(BlockedEncodingAttr layout, bool isM) { + auto order = layout.getOrder(); + auto sizePerThread = getSizePerThread(layout); + + int mSizePerThread = + order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int nSizePerThread = + order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + return isM ? mSizePerThread : nSizePerThread; +} + +Value getStructFromValueTable(ArrayRef vals, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter, + Type elemTy) { + SmallVector elemTypes(vals.size(), elemTy); + SmallVector elems; + elems.reserve(vals.size()); + for (auto &val : vals) { + elems.push_back(val); + } + MLIRContext *ctx = elemTy.getContext(); + Type structTy = struct_ty(elemTypes); + return packLLElements(loc, typeConverter, elems, rewriter, structTy); +} + +ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA, + int sizePerThread, + ConversionPatternRewriter &rewriter, + Location loc, + const LLVMTypeConverter *typeConverter, + Type type) { + ValueTable res; + auto elems = unpackLLElements(loc, val, rewriter); + int index = 0; + for (unsigned k = 0; k < K; ++k) { + for (unsigned m = 0; m < n0; m += shapePerCTA) + for (unsigned mm = 0; mm < sizePerThread; ++mm) { + res[{m + mm, k}] = elems[index++]; + } + } + return res; +} + +Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, + Location loc, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto aTensorTy = cast(A.getType()); + auto aLayout = cast(aTensorTy.getEncoding()); + auto aShapePerCTA = getShapePerCTA(aTensorTy); + + auto aOrder = aLayout.getOrder(); + auto order = dLayout.getOrder(); + + bool isARow = aOrder[0] == 1; + + auto aSmem = getSharedMemoryObjectFromStruct( + loc, llA, typeConverter->convertType(aTensorTy.getElementType()), + rewriter); + Value strideAM = aSmem.strides[0]; + Value strideAK = aSmem.strides[1]; + Value strideA0 = isARow ? strideAK : strideAM; + Value strideA1 = isARow ? strideAM : strideAK; + int aNumPtr = 8; + int K = aShapePerCTA[1]; + int M = aShapePerCTA[0]; + + auto shapePerCTATile = getShapePerCTATile(dLayout); + auto sizePerThread = getSizePerThread(dLayout); + + Value _0 = i32_val(0); + + Value mContig = i32_val(sizePerThread[order[1]]); + + // threadId in blocked layout + auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, + rewriter, loc); + Value threadIdM = threadIds[0]; + + Value offA0 = isARow ? _0 : mul(threadIdM, mContig); + Value offA1 = isARow ? mul(threadIdM, mContig) : _0; + SmallVector aOff(aNumPtr); + for (int i = 0; i < aNumPtr; ++i) { + aOff[i] = add(mul(offA0, strideA0), mul(offA1, strideA1)); + } + auto elemTy = typeConverter->convertType(aTensorTy.getElementType()); + + Type ptrTy = ptr_ty(rewriter.getContext(), 3); + SmallVector aPtrs(aNumPtr); + for (int i = 0; i < aNumPtr; ++i) + aPtrs[i] = gep(ptrTy, elemTy, aSmem.base, aOff[i]); + + SmallVector vas; + + int mShapePerCTATile = getShapePerCTATileForMN(dLayout, true /*isM*/); + int mSizePerThread = getSizePerThreadForMN(dLayout, true /*isM*/); + + for (unsigned k = 0; k < K; ++k) + for (unsigned m = 0; m < M; m += mShapePerCTATile) + for (unsigned mm = 0; mm < mSizePerThread; ++mm) { + Value offset = + add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK)); + Value pa = gep(ptrTy, elemTy, aPtrs[0], offset); + Value va = load(elemTy, pa); + vas.emplace_back(va); + } + + return getStructFromValueTable(vas, rewriter, loc, typeConverter, elemTy); +} + +Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, + Location loc, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto bTensorTy = cast(B.getType()); + auto bLayout = cast(bTensorTy.getEncoding()); + auto bShapePerCTA = getShapePerCTA(bTensorTy); + + auto bOrder = bLayout.getOrder(); + auto order = dLayout.getOrder(); + + bool isBRow = bOrder[0] == 1; + + auto bSmem = getSharedMemoryObjectFromStruct( + loc, llB, typeConverter->convertType(bTensorTy.getElementType()), + rewriter); + Value strideBN = bSmem.strides[1]; + Value strideBK = bSmem.strides[0]; + Value strideB0 = isBRow ? strideBN : strideBK; + Value strideB1 = isBRow ? strideBK : strideBN; + int bNumPtr = 8; + int K = bShapePerCTA[0]; + int N = bShapePerCTA[1]; + + auto shapePerCTATile = getShapePerCTATile(dLayout); + auto sizePerThread = getSizePerThread(dLayout); + + Value _0 = i32_val(0); + + Value nContig = i32_val(sizePerThread[order[0]]); + + // threadId in blocked layout + auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, + rewriter, loc); + Value threadIdN = threadIds[1]; + + Value offB0 = isBRow ? mul(threadIdN, nContig) : _0; + Value offB1 = isBRow ? _0 : mul(threadIdN, nContig); + SmallVector bOff(bNumPtr); + for (int i = 0; i < bNumPtr; ++i) { + bOff[i] = add(mul(offB0, strideB0), mul(offB1, strideB1)); + } + auto elemTy = typeConverter->convertType(bTensorTy.getElementType()); + + Type ptrTy = ptr_ty(rewriter.getContext(), 3); + SmallVector bPtrs(bNumPtr); + for (int i = 0; i < bNumPtr; ++i) + bPtrs[i] = gep(ptrTy, elemTy, bSmem.base, bOff[i]); + + SmallVector vbs; + + int nShapePerCTATile = getShapePerCTATileForMN(dLayout, false /*isM*/); + int nSizePerThread = getSizePerThreadForMN(dLayout, false /*isM*/); + + for (unsigned k = 0; k < K; ++k) + for (unsigned n = 0; n < N; n += nShapePerCTATile) + for (unsigned nn = 0; nn < nSizePerThread; ++nn) { + Value offset = + add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK)); + Value pb = gep(ptrTy, elemTy, bPtrs[0], offset); + Value vb = load(elemTy, pb); + vbs.emplace_back(vb); + } + + return getStructFromValueTable(vbs, rewriter, loc, typeConverter, elemTy); +} + +namespace SharedToDotOperandFMA { +Value convertLayout(int opIdx, Value val, Value llVal, + BlockedEncodingAttr dLayout, Value thread, Location loc, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + if (opIdx == 0) + return loadAFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter); + else + return loadBFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter); +} +} // namespace SharedToDotOperandFMA diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp new file mode 100644 index 000000000..bbbb5749c --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -0,0 +1,118 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Patterns.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +static void addAttrs(Operation *op, ArrayRef attrs) { + for (const NamedAttribute attr : attrs) + op->setAttr(attr.getName(), attr.getValue()); +} + +} // namespace + +namespace mlir::triton::gpu { + +void decomposeSplatOpToSharedLayoutConversion(ModuleOp module) { + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module); + module.walk([&](triton::SplatOp splatOp) -> void { + auto dstType = cast(splatOp.getType()); + auto shared = + dyn_cast(dstType.getEncoding()); + if (shared) { + OpBuilder builder(splatOp); + SmallVector sizePerThread(dstType.getRank(), 1); + auto newType = RankedTensorType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::BlockedEncodingAttr::get( + module.getContext(), dstType.getShape(), sizePerThread, + getOrder(shared), numWarps, threadsPerWarp, numCTAs)); + auto newSplat = builder.create(splatOp.getLoc(), newType, + splatOp.getSrc()); + auto newConvert = builder.create( + splatOp.getLoc(), dstType, newSplat.getResult()); + splatOp.replaceAllUsesWith(newConvert.getResult()); + splatOp.erase(); + } + }); +} + +template +void decomposeTensorCoreToDotLayoutConversion(ModuleOp module, + ShortcutFn shortcutFn) { + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module); + + module.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cast(cvtOp.getSrc().getType()); + auto dstType = cast(cvtOp.getType()); + auto srcMma = dyn_cast(srcType.getEncoding()); + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (srcMma && dstDotOp && !shortcutFn(srcType, dstType)) { + auto tmpType = RankedTensorType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::BlockedEncodingAttr::get( + module.getContext(), srcType.getShape(), getSizePerThread(srcMma), + getOrder(srcMma), numWarps, threadsPerWarp, numCTAs)); + auto tmp = builder.create( + cvtOp.getLoc(), tmpType, cvtOp.getSrc()); + addAttrs(tmp, cvtOp->getAttrs()); + auto newConvert = builder.create( + cvtOp.getLoc(), dstType, tmp); + addAttrs(newConvert, cvtOp->getAttrs()); + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + } + }); +} + +template void decomposeTensorCoreToDotLayoutConversion< + triton::gpu::NvidiaMmaEncodingAttr>(ModuleOp, ShortcutFn); +template void + decomposeTensorCoreToDotLayoutConversion( + ModuleOp, ShortcutFn); +template void decomposeTensorCoreToDotLayoutConversion< + triton::gpu::IluvatarMmaEncodingAttr>(ModuleOp, ShortcutFn); + +void decomposeBlockedToDotLayoutConversion(ModuleOp module) { + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module); + module.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cast(cvtOp.getSrc().getType()); + auto dstType = cast(cvtOp.getType()); + auto srcBlocked = + dyn_cast(srcType.getEncoding()); + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (srcBlocked && dstDotOp) { + auto tmpType = MemDescType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::SharedEncodingAttr::get( + module.getContext(), dstDotOp, srcType.getShape(), + srcBlocked.getOrder(), srcBlocked.getCTALayout(), + srcType.getElementType())); + auto tmp = builder.create( + cvtOp.getLoc(), tmpType, cvtOp.getSrc()); + addAttrs(tmp, cvtOp->getAttrs()); + auto newConvert = builder.create(cvtOp.getLoc(), + dstType, tmp); + addAttrs(newConvert, cvtOp->getAttrs()); + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + } + }); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp new file mode 100644 index 000000000..afb5bf01d --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -0,0 +1,102 @@ +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; + +using ValueTableFMA = std::map, Value>; + +static ValueTableFMA +getValueTableFromStructFMA(Value val, int K, int n0, int shapePerCTATile, + int sizePerThread, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter, Type type) { + ValueTableFMA res; + auto elems = unpackLLElements(loc, val, rewriter); + int index = 0; + for (unsigned k = 0; k < K; ++k) { + for (unsigned m = 0; m < n0; m += shapePerCTATile) + for (unsigned mm = 0; mm < sizePerThread; ++mm) { + res[{m + mm, k}] = elems[index++]; + } + } + return res; +} + +LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + + auto A = op.getA(); + auto B = op.getB(); + auto C = op.getC(); + auto D = op.getResult(); + + auto aTensorTy = cast(A.getType()); + auto bTensorTy = cast(B.getType()); + auto dTensorTy = cast(D.getType()); + + auto aShapePerCTA = getShapePerCTA(aTensorTy); + auto bShapePerCTA = getShapePerCTA(bTensorTy); + + BlockedEncodingAttr dLayout = + cast(dTensorTy.getEncoding()); + auto order = dLayout.getOrder(); + auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); + + Value llA = adaptor.getA(); + Value llB = adaptor.getB(); + + auto sizePerThread = getSizePerThread(dLayout); + auto shapePerCTATile = getShapePerCTATile(dLayout); + + int K = aShapePerCTA[1]; + int M = aShapePerCTA[0]; + int N = bShapePerCTA[1]; + + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int mSizePerThread = + order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int nSizePerThread = + order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + + auto has = + getValueTableFromStructFMA(llA, K, M, mShapePerCTATile, mSizePerThread, + rewriter, loc, typeConverter, aTensorTy); + auto hbs = + getValueTableFromStructFMA(llB, K, N, nShapePerCTATile, nSizePerThread, + rewriter, loc, typeConverter, bTensorTy); + + SmallVector ret = cc; + bool isCRow = order[0] == 1; + + for (unsigned k = 0; k < K; k++) { + for (unsigned m = 0; m < M; m += mShapePerCTATile) + for (unsigned n = 0; n < N; n += nShapePerCTATile) + for (unsigned mm = 0; mm < mSizePerThread; ++mm) + for (unsigned nn = 0; nn < nSizePerThread; ++nn) { + int mIdx = m / mShapePerCTATile * mSizePerThread + mm; + int nIdx = n / nShapePerCTATile * nSizePerThread + nn; + + int z = isCRow + ? mIdx * N / nShapePerCTATile * mSizePerThread + nIdx + : nIdx * M / mShapePerCTATile * nSizePerThread + mIdx; + ret[z] = rewriter.create(loc, has[{m + mm, k}], + hbs[{n + nn, k}], ret[z]); + } + } + + auto res = packLLElements(loc, typeConverter, ret, rewriter, dTensorTy); + rewriter.replaceOp(op, res); + + return success(); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp new file mode 100644 index 000000000..7b4b05d78 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -0,0 +1,850 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir::triton::gpu; + +namespace mlir::triton::gpu { + +Type getElementType(Value value) { + auto type = value.getType(); + if (auto tensorType = dyn_cast(type)) + return tensorType.getElementType(); + return type; +} +// MMA encoding has a different order depending on the element's bit width; +// reorder if we're in this case. +SmallVector reorderValues(const SmallVector &values, Type inType, + Type ouType) { +#ifdef __ILUVATAR__ + return values; +#endif + auto inTensorTy = dyn_cast(inType); + auto ouTensorTy = dyn_cast(ouType); + if (!inTensorTy || !ouTensorTy) + return values; + auto inEncoding = dyn_cast(inTensorTy.getEncoding()); + auto ouEncoding = dyn_cast(ouTensorTy.getEncoding()); + assert(inEncoding == ouEncoding); + if (!inEncoding) + return values; + // If the parent of the dot operand is in block encoding, we don't need to + // reorder elements + auto parentEncoding = dyn_cast(ouEncoding.getParent()); + if (!parentEncoding) + return values; + size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth(); + size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth(); + auto ouEltTy = ouTensorTy.getElementType(); + if (inBitWidth == ouBitWidth) + return values; + if (inBitWidth == 16 && ouBitWidth == 32) { + SmallVector ret; + for (unsigned i = 0; i < values.size(); i += 8) { + ret.push_back(values[i]); + ret.push_back(values[i + 1]); + ret.push_back(values[i + 4]); + ret.push_back(values[i + 5]); + ret.push_back(values[i + 2]); + ret.push_back(values[i + 3]); + ret.push_back(values[i + 6]); + ret.push_back(values[i + 7]); + } + return ret; + } + if (inBitWidth == 8 && ouBitWidth == 16) { + SmallVector ret; + for (unsigned i = 0; i < values.size(); i += 16) { + ret.push_back(values[i + 0]); + ret.push_back(values[i + 1]); + ret.push_back(values[i + 2]); + ret.push_back(values[i + 3]); + ret.push_back(values[i + 8]); + ret.push_back(values[i + 9]); + ret.push_back(values[i + 10]); + ret.push_back(values[i + 11]); + ret.push_back(values[i + 4]); + ret.push_back(values[i + 5]); + ret.push_back(values[i + 6]); + ret.push_back(values[i + 7]); + ret.push_back(values[i + 12]); + ret.push_back(values[i + 13]); + ret.push_back(values[i + 14]); + ret.push_back(values[i + 15]); + } + return ret; + } + llvm_unreachable("unimplemented code path"); +} + +SmallVector unpackI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter) { +#ifdef __ILUVATAR__ + return inValues; +#endif + auto tensorTy = dyn_cast(srcTy); + if (!tensorTy) + return inValues; + auto encoding = dyn_cast(tensorTy.getEncoding()); + if (!(encoding && isa(encoding.getParent()))) + return inValues; + SmallVector outValues; + for (auto v : inValues) { + // cast i32 to appropriate eltType vector and extract elements + auto eltType = typeConverter->convertType(tensorTy.getElementType()); + auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth()); + auto vec = bitcast(v, vecType); + for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) { + outValues.push_back(extract_element(vec, i32_val(i))); + } + } + return outValues; +} + +SmallVector packI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter) { +#ifdef __ILUVATAR__ + return inValues; +#endif + auto tensorTy = dyn_cast(srcTy); + if (!tensorTy) + return inValues; + auto encoding = dyn_cast(tensorTy.getEncoding()); + if (!(encoding && isa(encoding.getParent()))) + return inValues; + SmallVector outValues; + auto eltType = typeConverter->convertType(tensorTy.getElementType()); + int vecWidth = 32 / eltType.getIntOrFloatBitWidth(); + auto vecType = vec_ty(eltType, vecWidth); + for (int i = 0; i < inValues.size(); i += vecWidth) { + Value vec = undef(vecType); + for (int j = 0; j < vecWidth; j++) { + vec = insert_element(vec, inValues[i + j], i32_val(j)); + } + outValues.push_back(bitcast(vec, i32_ty)); + } + return outValues; +} + +int getNumElementsPerThreads(Type type, + const LLVMTypeConverter *typeConverter) { + int numElemsPerThread = 1; + auto tensorTy = dyn_cast(type); + if (!tensorTy) + return numElemsPerThread; + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (structType) { + numElemsPerThread = structType.getBody().size(); + } + auto encoding = dyn_cast(tensorTy.getEncoding()); + if (!(encoding && isa(encoding.getParent()))) + return numElemsPerThread; + auto eltType = tensorTy.getElementType(); + assert(eltType.getIntOrFloatBitWidth() <= 32 && + "Only support element type with bit width <= 32 in dot operand mma " + "layout"); + // dot operand data are packed into i32 elements so use the following formula + // to get the number of elements per thread. + return (32 / eltType.getIntOrFloatBitWidth()) * numElemsPerThread; +} + +} // namespace mlir::triton::gpu + +namespace { +struct AddPtrOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = op.getType(); + auto typeConverter = getTypeConverter(); + auto resultTensorTy = dyn_cast(resultTy); + if (resultTensorTy) { + // auto ptrs = unpackLLElements(loc, adaptor.getPtr(), rewriter); + unsigned elems = getTotalElemsPerThread(resultTy); + Type elemTy = typeConverter->convertType( + cast(resultTensorTy.getElementType()).getPointeeType()); + Type ptrTy = typeConverter->convertType(resultTensorTy.getElementType()); + auto ptrs = unpackLLElements(loc, adaptor.getPtr(), rewriter); + auto offsets = unpackLLElements(loc, adaptor.getOffset(), rewriter); + SmallVector resultVals(elems); + for (unsigned i = 0; i < elems; ++i) { + resultVals[i] = gep(ptrTy, elemTy, ptrs[i], offsets[i]); + } + Value view = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, view); + } else { + assert(isa(resultTy)); + auto resultPtrTy = typeConverter->convertType(resultTy); + auto resultElemTy = typeConverter->convertType( + cast(resultTy).getPointeeType()); + Value result = + gep(resultPtrTy, resultElemTy, adaptor.getPtr(), adaptor.getOffset()); + rewriter.replaceOp(op, result); + } + return success(); + } +}; + +struct CmpIOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + SmallVector createDestOps(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, + MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create( + loc, elemTy, ArithCmpIPredicateToLLVM(op.getPredicate()), + operands[0][0], operands[0][1])}; + } + + static LLVM::ICmpPredicate + ArithCmpIPredicateToLLVM(arith::CmpIPredicate predicate) { + switch (predicate) { +#define __PRED_ENUM(item__) \ + case arith::CmpIPredicate::item__: \ + return LLVM::ICmpPredicate::item__ + + __PRED_ENUM(eq); + __PRED_ENUM(ne); + __PRED_ENUM(sgt); + __PRED_ENUM(sge); + __PRED_ENUM(slt); + __PRED_ENUM(sle); + __PRED_ENUM(ugt); + __PRED_ENUM(uge); + __PRED_ENUM(ult); + __PRED_ENUM(ule); + +#undef __PRED_ENUM + } + llvm_unreachable("Unknown arith::CmpIPredicate"); + } +}; + +struct CmpFOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + static SmallVector + createDestOps(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Type elemTy, + MultipleOperandsRange operands, Location loc) { + return {rewriter.create( + loc, elemTy, ArithCmpFPredicateToLLVM(op.getPredicate()), + operands[0][0], operands[0][1])}; + } + + static LLVM::FCmpPredicate + ArithCmpFPredicateToLLVM(arith::CmpFPredicate predicate) { + switch (predicate) { +#define __PRED_ENUM(item__, item1__) \ + case arith::CmpFPredicate::item__: \ + return LLVM::FCmpPredicate::item1__ + + __PRED_ENUM(OEQ, oeq); + __PRED_ENUM(ONE, one); + __PRED_ENUM(OGT, ogt); + __PRED_ENUM(OGE, oge); + __PRED_ENUM(OLT, olt); + __PRED_ENUM(OLE, ole); + __PRED_ENUM(ORD, ord); + __PRED_ENUM(UEQ, ueq); + __PRED_ENUM(UGT, ugt); + __PRED_ENUM(UGE, uge); + __PRED_ENUM(ULT, ult); + __PRED_ENUM(ULE, ule); + __PRED_ENUM(UNE, une); + __PRED_ENUM(UNO, uno); + __PRED_ENUM(AlwaysTrue, _true); + __PRED_ENUM(AlwaysFalse, _false); + +#undef __PRED_ENUM + } + llvm_unreachable("Unknown arith::CmpFPredicate"); + } +}; + +struct MulhiUIOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + explicit MulhiUIOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + targetInfo(targetInfo) {} + + SmallVector createDestOps(MulhiUIOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + + Type resultElementTy = getElementTypeOrSelf(op.getResult().getType()); + assert(resultElementTy.isInteger(32) || resultElementTy.isInteger(64)); + + auto funcName = targetInfo.getMulhiFuncName(resultElementTy); + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + return { + rewriter.create(loc, funcOp, operands[0]).getResult()}; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +struct ExternElementwiseOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + typedef typename Base::OpAdaptor OpAdaptor; + + SmallVector createDestOps(ExternElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + StringRef funcName = op.getSymbol(); + if (funcName.empty()) + llvm::errs() << "ExternElementwiseOpConversion"; + + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp( + rewriter, op, funcName, funcType, op.getLibname(), op.getLibpath()); + return { + rewriter.create(loc, funcOp, operands[0]).getResult()}; + } +}; + +template +struct ElementwiseOpConversion + : public ElementwiseOpConversionBase< + SourceOp, ElementwiseOpConversion> { + using Base = + ElementwiseOpConversionBase>; + using Base::Base; + using OpAdaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + SmallVector createDestOps(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create(loc, elemTy, operands[0], + adaptor.getAttributes().getValue())}; + } +}; + +struct ElementwiseInlineAsmOpConversion + : public ConvertOpToLLVMPattern { + using Base = ConvertOpToLLVMPattern; + + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + typedef typename Base::OpAdaptor OpAdaptor; + + // If operand size is smaller than 32 bits, pack in groups of 32 bits. + SmallVector packOperands(ElementwiseInlineAsmOp op, + MultipleOperandsRange operands, + ConversionPatternRewriter &rewriter, + Location loc) const { + SmallVector packedOperands; + unsigned numPackedElements = op.getPackedElement(); + for (int i = 0, e = op.getNumOperands(); i < e; i++) { + Type elemTy = getElementType(op.getOperand(i)); + unsigned bitWidth = + elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : 64; + unsigned numElementPerReg = bitWidth < 32 ? 32 / bitWidth : 1; + numElementPerReg = std::min(numElementPerReg, numPackedElements); + for (int j = 0; j < numPackedElements; j += numElementPerReg) { + if (numElementPerReg == 1) { + packedOperands.push_back(operands[j][i]); + continue; + } + Type t = + vec_ty(getTypeConverter()->convertType(elemTy), numElementPerReg); + Value packed = undef(t); + for (int k = 0; k < numElementPerReg; k++) { + packed = insert_element(packed, operands[j + k][i], i32_val(k)); + } + packedOperands.push_back(packed); + } + } + return packedOperands; + } + + SmallVector> + createDestOps(ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + MultipleOperandsRange operands, Location loc) const { + auto ctx = op->getContext(); + + if (operands.size() % op.getPackedElement() != 0) + llvm::report_fatal_error("Inline asm op has more packed elements than " + "number of elements per thread."); + + // Pack elems smaller than 32 bits into 32-bit registers. + SmallVector packedOperands = + packOperands(op, operands, rewriter, loc); + + // Types returned by the LLVM asm op. If there's more than one, they'll be + // wrapped in a struct. + SmallVector asmRetTypes; + for (auto result : op.getResult()) { + auto ty = getTypeConverter()->convertType(getElementType(result)); + + // Pack return elements into 32-bits. + unsigned bitWidth = ty.isIntOrFloat() ? ty.getIntOrFloatBitWidth() : 64; + unsigned numElemsPerReg = + std::min(bitWidth < 32 ? 32 / bitWidth : 1, op.getPackedElement()); + assert(op.getPackedElement() % numElemsPerReg == 0); + if (numElemsPerReg > 1) { + ty = vec_ty(ty, numElemsPerReg); + } + for (unsigned i = 0; i < op.getPackedElement() / numElemsPerReg; i++) { + asmRetTypes.push_back(ty); + } + } + Type asmRetType = + asmRetTypes.size() > 1 ? struct_ty(asmRetTypes) : asmRetTypes[0]; + + Value asmResults = + rewriter + .create( + loc, asmRetType, + /*operands=*/packedOperands, + /*asm_string=*/op.getAsmString(), + /*constraints=*/op.getConstraints(), + /*has_side_effects=*/!op.getPure(), + /*is_align_stack=*/false, + /*asm_dialect=*/ + LLVM::AsmDialectAttr::get(rewriter.getContext(), + LLVM::AsmDialect::AD_ATT), + /*operand_attrs=*/ArrayAttr()) + ->getResult(0); + + // asmResults is a flat struct; pack its values into + // [return_value][op.getPackedElement()]. + SmallVector> ret(op->getNumResults()); + for (int i = 0; i < op->getNumResults(); i++) { + int structIdx = 0; + for (int j = 0; j < op.getPackedElement(); j++) { + Value val; + if (asmRetTypes.size() > 1) { + val = + extract_val(asmResults, i * op.getPackedElement() + structIdx++); + } else { + val = asmResults; + } + if (auto vectorTy = dyn_cast(val.getType())) { + for (int k = 0; k < vectorTy.getNumElements(); k++) { + ret[i].push_back(extract_element(val, i32_val(k))); + } + j += vectorTy.getNumElements() - 1; + } else { + ret[i].push_back(val); + } + } + } + return ret; + } + + LogicalResult + matchAndRewrite(ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + // Layout is unpackedOperands[operand][elem]. + SmallVector> unpackedOperands; + for (auto operand : adaptor.getOperands()) { + auto argTy = op->getOperand(0).getType(); + auto subOperands = unpackLLElements(loc, operand, rewriter); + unpackedOperands.push_back( + unpackI32(subOperands, argTy, rewriter, loc, getTypeConverter())); + } + + int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(), + getTypeConverter()); + + // These are checked by the verifier, so we don't need to raise a nice + // error. + assert(all_of(unpackedOperands, [&](auto &operands) { + return operands.size() == numElemsPerThread; + })); + if (numElemsPerThread % op.getPackedElement() != 0) { + // Pad with the undef for each operand to have a multiple of + // op.getPackedElement() elements. + int numPaddedValue = + op.getPackedElement() - numElemsPerThread % op.getPackedElement(); + for (auto &operands : unpackedOperands) { + for (int i = 0; i < numPaddedValue; i++) { + operands.push_back(undef(operands[0].getType())); + } + } + } + + // Run the inline asm op on each block of elements. + // + // Layout is unpackedResults[result_idx][elem]. + // + // This loop always runs at least once, even when the asm has no input + // elements. + SmallVector> unpackedResults(op->getNumResults()); + for (unsigned i = 0; i < numElemsPerThread; i += op.getPackedElement()) { + // Block of elements to process with one call to the inline asm. This is + // ordered opposite `unpackedResults`: The outer dim is + // op.getPackedElement(), and the inner dim is the operand. + SmallVector> block(op.getPackedElement()); + for (auto &os : unpackedOperands) { + for (int j = 0; j < op.getPackedElement(); j++) { + block[j].push_back(os[i + j]); + } + } + auto cur = createDestOps(op, adaptor, rewriter, block, loc); + assert(cur.size() == unpackedResults.size()); + for (unsigned j = 0; j < cur.size(); j++) { + unpackedResults[j].insert(unpackedResults[j].end(), cur[j].begin(), + cur[j].end()); + } + } + for (auto &results : unpackedResults) { + results.resize(numElemsPerThread); + } + // Reorder and pack the results. + SmallVector outs; + for (int i = 0; i < unpackedResults.size(); i++) { + // We reordered all the inputs so they match operand 0. Reorder the + // outputs accordingly. + if (op->getNumOperands() > 0) { + unpackedResults[i] = reorderValues( + unpackedResults[i], /*inType=*/op->getOperand(0).getType(), + /*ouType=*/op->getResult(i).getType()); + } + auto packed = packI32(unpackedResults[i], op->getResult(i).getType(), + rewriter, loc, getTypeConverter()); + outs.push_back(packLLElements(loc, getTypeConverter(), packed, rewriter, + op->getResult(i).getType())); + } + + rewriter.replaceOp(op, outs); + return success(); + } +}; + +struct AbsIOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(math::AbsIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create(loc, elemTy, operands[0][0], + /*is_int_min_poison=*/false)}; + } +}; + +struct AbsFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(math::AbsFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + if (llvm::isa(elemTy)) { + // Mask out the sign bit + auto num_bits = + getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); + assert(num_bits <= 16); + auto mask = (1u << (num_bits - 1u)) - 1u; + auto maskAttr = rewriter.getIntegerAttr(elemTy, mask); + auto maskConst = rewriter.create(loc, maskAttr); + return {and_(operands[0][0], maskConst)}; + } + + return {rewriter.create(loc, elemTy, operands[0][0])}; + } +}; +/// The lowering of index_cast becomes an integer conversion since index +/// becomes an integer. If the bit width of the source and target integer +/// types is the same, just erase the cast. If the target type is wider, +/// sign-extend the value, otherwise truncate it. +struct IndexCastOpLowering + : public ElementwiseOpConversionBase { + using Base = + ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto inElemTy = + this->getTypeConverter()->convertType(getElementType(op.getIn())); + unsigned targetBits = elemTy.getIntOrFloatBitWidth(); + unsigned sourceBits = inElemTy.getIntOrFloatBitWidth(); + + if (targetBits == sourceBits) + return {operands[0][0]}; + if (targetBits < sourceBits) + return {rewriter.replaceOpWithNewOp(op, elemTy, + operands[0][0])}; + return { + rewriter.replaceOpWithNewOp(op, elemTy, operands[0][0])}; + } +}; + +struct SelectOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + std::array llvmOperands; + if (operands[0].size() == 2) { + // Case of scalar condition with tensor operands. + assert(op.getCondition().getType().isInteger(1)); + llvmOperands = {adaptor.getCondition(), operands[0][0], operands[0][1]}; + } else { + llvmOperands = {operands[0][0], operands[0][1], operands[0][2]}; + } + return {rewriter.create( + loc, llvmOperands[1].getType(), llvmOperands, + adaptor.getAttributes().getValue())}; + } +}; +template +struct MinMaxFOpConversion + : ElementwiseOpConversionBase> { + using Base = ElementwiseOpConversionBase>; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + static_assert(std::is_same::value || + std::is_same::value, + "OpTy must be arith::MinimumFOp or arith::MaximumFOp"); + + // Choose the destination op based on the OpTy. + using DestOpNanProp = + typename std::conditional::value, + LLVM::MinimumOp, LLVM::MaximumOp>::type; + using DestOpNoNanProp = + typename std::conditional::value, + LLVM::MinNumOp, LLVM::MaxNumOp>::type; + + explicit MinMaxFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + bool hwNanPropagationSupported, + PatternBenefit benefit = 1) + : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, + benefit), + hwNanPropagationSupported(hwNanPropagationSupported) {} + + SmallVector createDestOps(OpTy op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + if (hwNanPropagationSupported) { + return {rewriter.create(loc, elemTy, operands[0][0], + operands[0][1])}; + } + // Handle workaround for NaN propagation, i.e. software emulation of NaN + // propagation. If any of the operands is NaN, return NaN. + auto lhs = operands[0][0]; + auto rhs = operands[0][1]; + auto lhsIsNan = + rewriter.create(loc, LLVM::FCmpPredicate::une, lhs, lhs); + auto rhsIsNan = + rewriter.create(loc, LLVM::FCmpPredicate::une, rhs, rhs); + auto isNan = rewriter.create(loc, lhsIsNan, rhsIsNan); + auto nonNanRes = rewriter.create(loc, elemTy, lhs, rhs); + + auto nan = LLVM::createNaNConstant(loc, rewriter, elemTy); + + // Select the result based on the isNan flag. + return {rewriter.create(loc, isNan, nan, nonNanRes)}; + } + +private: + bool hwNanPropagationSupported; +}; + +struct ClampFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + explicit ClampFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + targetInfo(targetInfo) {} + + SmallVector createDestOps(ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + // Clip pattern not found, use min/max. + if (op.getPropagateNan() == PropagateNan::ALL) { + if (targetInfo.supportMaximumMinimum()) { + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + return {rewriter.create(loc, v, operands[0][2])}; + } + // On pre-80 compute capability, we need to handle NaN propagation + // manually. We need to check only the first operand for clamp. + auto lhs = operands[0][0]; + auto isNan = rewriter.create(loc, LLVM::FCmpPredicate::une, + lhs, lhs); + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + auto nonNanRes = rewriter.create(loc, v, operands[0][2]); + auto nan = LLVM::createNaNConstant(loc, rewriter, elemTy); + // Select the result based on the isNan flag. + return {rewriter.create(loc, isNan, nan, nonNanRes)}; + } + + // No NaN propagation. + assert(op.getPropagateNan() == PropagateNan::NONE); + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + return {rewriter.create(loc, v, operands[0][2])}; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMinMaxFOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, bool hwNanPropagationSupported, + PatternBenefit benefit) { + patterns.add>( + typeConverter, axisInfoAnalysis, hwNanPropagationSupported, benefit); + patterns.add>( + typeConverter, axisInfoAnalysis, hwNanPropagationSupported, benefit); +} + +void mlir::triton::populateClampFOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, axisInfoAnalysis, targetInfo, + benefit); +} + +void mlir::triton::populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit) { +#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + + POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp) + POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp) + POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp) + POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp) + POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp) + POPULATE_UNARY_OP(math::FloorOp, math::FloorOp) + POPULATE_UNARY_OP(math::CeilOp, math::CeilOp) + POPULATE_UNARY_OP(math::LogOp, math::LogOp) + POPULATE_UNARY_OP(math::Log2Op, math::Log2Op) + POPULATE_UNARY_OP(math::CosOp, math::CosOp) + POPULATE_UNARY_OP(math::SinOp, math::SinOp) + POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp) + POPULATE_UNARY_OP(math::RsqrtOp, math::RsqrtOp) + POPULATE_UNARY_OP(math::ExpOp, math::ExpOp) + POPULATE_UNARY_OP(math::Exp2Op, math::Exp2Op) + POPULATE_UNARY_OP(math::ErfOp, math::ErfOp) + POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp) + POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp) + POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp) +#undef POPULATE_UNARY_OP + +#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + + POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // - + POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // + + POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // * + POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp) + POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp) + POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // % + POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp) + POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp) + POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // & + POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // | + POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^ + POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // << + POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> + POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> + // fmin (return non-NaN if either op is non-NaN) + POPULATE_BINARY_OP(arith::MinNumFOp, LLVM::MinNumOp) + // fmax (return non-NaN if either op is non-NaN) + POPULATE_BINARY_OP(arith::MaxNumFOp, LLVM::MaxNumOp) + POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin + POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax + POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin + POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax +#undef POPULATE_BINARY_OP + + patterns.add>( + typeConverter, axisInfoAnalysis, benefit); + + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, targetInfo, + benefit); + patterns.add(typeConverter, axisInfoAnalysis, + benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 000000000..47f40ebec --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,118 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +/// FuncOp legalization pattern that converts MemRef arguments to pointers to +/// MemRef descriptors (LLVM struct data types) containing all the MemRef type +/// information. +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, int numWarps, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), numWarps(numWarps) {} + + /// Only retain those attributes that are not constructed by + /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument + /// attributes. + static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result) { + + for (const auto &attr : op->getAttrs()) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == op.getFunctionTypeAttrName() || + attr.getName() == "std.varargs" || + (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) + continue; + result.push_back(attr); + } + } + + triton::FuncOp amendFuncOp(triton::FuncOp funcOp, + ConversionPatternRewriter &rewriter) const { + // Push back a variable that indicates the current stack pointer of shared + // memory to the function arguments. + auto loc = funcOp.getLoc(); + auto ctx = funcOp->getContext(); + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); + // 1. Modify the function type to add the new argument. + auto funcTy = funcOp.getFunctionType(); + auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); + amendedInputTy.push_back(ptrTy); + auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, + funcTy.getResults()); + // 2. Modify the argument attributes to add the new argument. + SmallVector amendedAttrs; + filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); + auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedAttrs.push_back(rewriter.getNamedAttr( + funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); + // 3. Add a new argument to the region + auto amendedFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); + auto ®ion = funcOp.getBody(); + region.addArgument(ptrTy, loc); + rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), + amendedFuncOp.end()); + return amendedFuncOp; + } + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Prevent LLVM's inliner to inline this function + auto amendedFuncOp = funcOp; + if (!LLVM::isKernel(funcOp)) + amendedFuncOp = amendFuncOp(funcOp, rewriter); + + LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( + amendedFuncOp, rewriter, *getTypeConverter()); + if (!newFuncOp) { + return failure(); + } + + auto ctx = funcOp->getContext(); + + if (LLVM::isKernel(funcOp)) { + // Set an attribute to indicate this function is a kernel entry. + newFuncOp->setAttr("nvvm.kernel", + rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + newFuncOp.setLinkage(LLVM::Linkage::External); + } else { + // The noinline attribute will be used by the LLVM codegen to prevent + // inlining. + // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 + newFuncOp.setPassthroughAttr( + ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); + rewriter.eraseOp(amendedFuncOp); + newFuncOp.setLinkage(LLVM::Linkage::Internal); + } + // Set an attribute for maxntidx, it could be used in latter LLVM codegen + // for `nvvm.annotation` metadata. + newFuncOp->setAttr("nvvm.maxntid", + rewriter.getDenseI32ArrayAttr(32 * numWarps)); + rewriter.eraseOp(funcOp); + return success(); + } + +private: + int numWarps{0}; +}; + +} // namespace + +void mlir::triton::populateFuncOpConversionPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, + PatternBenefit benefit) { + patterns.add(typeConverter, numWarps, benefit); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp new file mode 100644 index 000000000..acf940b3e --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp @@ -0,0 +1,212 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +static int log2Int(int64_t num) { return (num > 1) ? 1 + log2Int(num / 2) : 0; } + +// Compute a histogram within a warp. This uses an algorithm by @apgoucher +// that does the following: +// Create a ballot for each bit of the bin index (there +// are only log2(num_bins) of these) and then apply bitwise operations to get +// the indicator functions for the bins owned by this particular thread, and +// only popcount those. +static SmallVector computeWarpLevelHistogram( + Location loc, RankedTensorType srcType, SmallVector &srcValues, + int numBins, int numThreadPerWarp, Value threadId, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo) { + assert(numBins % numThreadPerWarp == 0 && + "numBins must be divisible by numThreadPerWarp"); + Value zero = i32_val(0); + int numBits = log2Int(numBins); + int numBitsLaneId = log2Int(numThreadPerWarp); + unsigned numElementsPerThreads = triton::gpu::getTotalElemsPerThread(srcType); + unsigned numThreadWithUniqueData = + triton::gpu::getThreadsPerWarpWithUniqueData(srcType.getEncoding(), + srcType.getShape())[0]; + // The histogram is distributed across threads, each thread owns `numBins / + // numThreadPerWarp` bins. + SmallVector warpLevelHistogram(numBins / numThreadPerWarp, zero); + for (int i = 0; i < numElementsPerThreads; ++i) { + Value value = srcValues[i]; + SmallVector ballotBits; + for (int j = 0; j < numBits; ++j) { + Value bitSet = and_(value, i32_val(1 << j)); + Value cmp = icmp_ne(bitSet, zero); + Value bit = + targetInfo.ballot(rewriter, loc, int_ty(numThreadPerWarp), cmp); + ballotBits.push_back(bit); + } + uint64_t fullMaskValue = + numThreadPerWarp == 32 ? 0xFFFFFFFF : 0xFFFFFFFFFFFFFFFF; + Value fullMask = int_val(numThreadPerWarp, fullMaskValue); + Value mask = fullMask; + // If not all threads have unique data, mask out the redundant ones. + if (numThreadWithUniqueData < numThreadPerWarp) { + mask = int_val(numThreadPerWarp, (1ULL << numThreadWithUniqueData) - 1); + } + for (int i = 0; i < numBitsLaneId; i++) { + Value updateMask = select(icmp_ne(and_(threadId, i32_val(1 << i)), zero), + int_val(numThreadPerWarp, 0), fullMask); + mask = + and_(mask, xor_(ballotBits[i + numBits - numBitsLaneId], updateMask)); + } + // at this point, 'mask' tells you which elements are in a bin owned by this + // thread. + for (int k = 0; k < warpLevelHistogram.size(); k++) { + Value binMask = mask; + for (int j = 0; j < numBits - numBitsLaneId; j++) { + Value updateMask = + int_val(numThreadPerWarp, ((k & (1 << j)) ? 0 : fullMaskValue)); + binMask = and_(binMask, xor_(ballotBits[j], updateMask)); + } + // at this point, 'bin_mask' tells you which elements are in the kth bin + // owned by this thread. + Value bitCount = rewriter.create( + loc, int_ty(numThreadPerWarp), binMask); + if (numThreadPerWarp > 32) + bitCount = trunc(i32_ty, bitCount); + warpLevelHistogram[k] = add(warpLevelHistogram[k], bitCount); + } + } + return warpLevelHistogram; +} + +static void atomicAdd(Value ptr, Value val, Location loc, + ConversionPatternRewriter &rewriter) { + rewriter.create(loc, LLVM::AtomicBinOp::add, ptr, val, + LLVM::AtomicOrdering::monotonic); +} + +static SmallVector computeCrossWarpHistogram( + Location loc, ConversionPatternRewriter &rewriter, RankedTensorType srcType, + Value baseSharedMemPtr, const SmallVector &warpLevelHistogram, + int numBins, int numThreadPerWarp, const SmallVector &indices, + Value threadId, int numWarps) { + SmallVector histogramValues; + unsigned numWarpsWithUniqueData = + mlir::triton::gpu::getWarpsPerCTAWithUniqueData(srcType.getEncoding(), + srcType.getShape())[0]; + Value laneId = and_(threadId, i32_val(numThreadPerWarp - 1)); + // Initialize the shared memory with zeros. + int64_t numElementPerThread = + ceil(numBins, numThreadPerWarp * numWarps); + for (int i = 0; i < numElementPerThread; ++i) { + Value offset = add(threadId, i32_val((i * numWarps * numThreadPerWarp))); + offset = urem(offset, i32_val(numBins)); + Value sharedMemPtr = + gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); + store(i32_val(0), sharedMemPtr); + } + barrier(); + Block *afterAtomics = nullptr; + // If some warps have replicated data we need to skip those warps when + // accumulating. + if (numWarpsWithUniqueData < numWarps) { + Block *currentBlock = rewriter.getInsertionBlock(); + afterAtomics = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *atomicBlock = rewriter.createBlock(afterAtomics); + rewriter.setInsertionPointToEnd(currentBlock); + Value cond = + icmp_ult(threadId, i32_val(numWarpsWithUniqueData * numThreadPerWarp)); + rewriter.create(loc, cond, atomicBlock, afterAtomics); + rewriter.setInsertionPointToStart(atomicBlock); + } + // Apply atomic add to update the histogram in shared memory. + for (int i = 0; i < warpLevelHistogram.size(); ++i) { + Value warpLevelHistogramValue = warpLevelHistogram[i]; + Value offset = + add(mul(laneId, i32_val(warpLevelHistogram.size())), i32_val(i)); + Value sharedMemPtr = + gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); + atomicAdd(sharedMemPtr, warpLevelHistogramValue, loc, rewriter); + } + if (afterAtomics) { + rewriter.create(loc, afterAtomics); + rewriter.setInsertionPointToStart(afterAtomics); + } + barrier(); + // load the histogram to register with the right layout. + for (Value index : indices) { + Value sharedMemPtr = + gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, index); + Value val = load(i32_ty, sharedMemPtr); + histogramValues.push_back(val); + } + return histogramValues; +} + +namespace { +struct HistogramOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + explicit HistogramOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(triton::HistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = adaptor.getSrc(); + auto typeConverter = getTypeConverter(); + SmallVector srcValues = unpackLLElements(loc, input, rewriter); + int numBins = op.getType().getDimSize(0); + auto mod = op->getParentOfType(); + int numThreadsPerWarp = + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + assert(numThreadsPerWarp == 32 || + numThreadsPerWarp == 64 && + "Only supports 32 or 64 threads per warp"); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + // Pad out the bins so that we have at least one bin per thread within a + // warp. + numBins = std::max(numBins, numThreadsPerWarp); + Value threadId = getThreadId(rewriter, loc); + auto srcType = op.getSrc().getType(); + // First compute a warp local histogram based on values owned by each warps. + SmallVector warpLevelHistogram = computeWarpLevelHistogram( + loc, srcType, srcValues, numBins, numThreadsPerWarp, threadId, rewriter, + targetInfo); + + // Then use atomic to update the histogram in shared memory. + // TODO: we could skip this for cases with num_warps=1 as long as we can + // generate the right layout. Currently the warp level histogram generates + // data in the default blocked layout. + Value baseSharedMemPtr = + LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + auto dstType = op.getType(); + Attribute dstEncoding = dstType.getEncoding(); + auto indices = emitIndices(op.getLoc(), rewriter, targetInfo, dstEncoding, + dstType, true); + SmallVector innerDimIndices; + for (int i = 0; i < indices.size(); ++i) + innerDimIndices.push_back(indices[i][0]); + SmallVector histogramValue = computeCrossWarpHistogram( + loc, rewriter, srcType, baseSharedMemPtr, warpLevelHistogram, numBins, + numThreadsPerWarp, innerDimIndices, threadId, numWarps); + + Value results = packLLElements(loc, typeConverter, histogramValue, rewriter, + op.getType()); + rewriter.replaceOp(op, results); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; +} // namespace + +void mlir::triton::populateHistogramOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp new file mode 100644 index 000000000..43120c791 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp @@ -0,0 +1,53 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +struct MakeRangeOpConversion + : public ConvertOpToLLVMPattern { + MakeRangeOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + RankedTensorType ty = op.getType(); + auto shape = ty.getShape(); + auto layout = ty.getEncoding(); + auto elemTy = ty.getElementType(); + assert(elemTy.isInteger(32)); + Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.getStart()); + auto idxs = emitIndices(loc, rewriter, targetInfo, layout, ty, true); + unsigned elems = idxs.size(); + SmallVector retVals(elems); + // TODO: slice layout has more elements than expected. + // Unexpected behavior for make range, but generally OK when followed by + // expand dims + broadcast. very weird behavior otherwise potentially. + for (const auto &multiDim : llvm::enumerate(idxs)) { + assert(multiDim.value().size() == 1); + retVals[multiDim.index()] = add(multiDim.value()[0], start); + } + auto typeConverter = getTypeConverter(); + Value result = packLLElements(loc, typeConverter, retVals, rewriter, ty); + rewriter.replaceOp(op, result); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMakeRangeOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp new file mode 100644 index 000000000..12ab6684c --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -0,0 +1,145 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// blocked -> shared. +// Swizzling in shared memory to avoid bank conflict. Normally used for +// A/B operands of dots. +void lowerDistributedToShared(Location loc, Value src, Value dst, + Value adaptorSrc, + const SharedMemoryObject &smemObj, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) { + auto srcTy = cast(src.getType()); + auto dstTy = cast(dst.getType()); + auto dstShapePerCTA = triton::gpu::getShapePerCTA(dstTy); + auto srcLayout = srcTy.getEncoding(); + auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); + assert(srcTy.getShape().size() <= 2 || + (srcTy.getShape().size() == 3 && outOrd[2] == 0) && + "Unexpected rank of ConvertLayout(blocked->shared)"); + auto elemTy = typeConverter->convertType(srcTy.getElementType()); + + auto smemBase = smemObj.getBase(); + int32_t elemSize = elemTy.getIntOrFloatBitWidth(); + unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); + auto dstStrides = smemObj.getStrides(); + auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); + storeDistributedToShared(src, inVals, dstStrides, dst, smemBase, elemTy, loc, + rewriter, targetInfo); +} + +struct LocalAllocOpConversion + : public ConvertOpToLLVMPattern { + LocalAllocOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + auto resultTy = cast(op.getType()); + auto typeConverter = getTypeConverter(); + auto sharedLayout = + cast(resultTy.getEncoding()); + auto order = sharedLayout.getOrder(); + // Workaround for 3D tensors + // TODO: we need to modify the pipeline pass to give a proper shared + // encoding to 3D tensors + SmallVector newOrder; + if (resultTy.getShape().size() != order.size()) { + for (auto i = 0; i < order.size(); ++i) + newOrder.push_back(order[i] + 1); + newOrder.push_back(0); + } else { + newOrder = SmallVector(order.begin(), order.end()); + } + + auto llvmElemTy = typeConverter->convertType(resultTy.getElementType()); + auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape()); + auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA, + newOrder, loc, rewriter); + // If there is an initial tensor, store it into the shared memory. + if (op.getSrc()) { + lowerDistributedToShared(loc, op.getSrc(), op.getResult(), + adaptor.getSrc(), smemObj, typeConverter, + rewriter, targetInfo); + } + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +struct LocalDeallocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::LocalDeallocOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::LocalDeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +struct LocalStoreOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern; + + LocalStoreOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value memDescVal = op.getDst(); + auto llvmElemTy = + getTypeConverter()->convertType(op.getDst().getType().getElementType()); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter); + lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(), + adaptor.getSrc(), smemObj, getTypeConverter(), + rewriter, targetInfo); + rewriter.eraseOp(op); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMemoryOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp new file mode 100644 index 000000000..32c7835c2 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp @@ -0,0 +1,243 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace { + +// The input print op contains: +// - a "prefix" (string) specified by the user, and +// - one or more "operands" (tensors). +// +// For each operand, we print all of the values contained in this GPU thread, +// one per line, along with the index of the value in its tensor. +struct PrintOpConversion : public ConvertOpToLLVMPattern { + explicit PrintOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + auto getPid = [&](int axis) { + return targetInfo.programId(rewriter, loc, + op->getParentOfType(), axis); + }; + std::array pid = {getPid(0), getPid(1), getPid(2)}; + + // Simple printf of a string without any tensors. + if (op.getNumOperands() == 0) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << "pid (" << getFormatSubstr(pid[0]) << ", " + << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")" + << op.getPrefix(); + llPrintf(formatStr, {pid[0], pid[1], pid[2]}, rewriter); + rewriter.eraseOp(op); + return success(); + } + + for (size_t i = 0; i < op.getNumOperands(); i++) { + // Elements of the tensor that are resident in this GPU thread. + auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); + + // Get the indices of `elems` within the tensor. Note that if `elems` + // has an "interesting" layout, then these will not be in any + // particularly nice order. + + // Extract the shape of the tensor being printed and use it to figure + // out how many digits we need for each of the dimensions. + SmallVector dimWidths; + SmallVector> indices; + if (auto rankedTy = + dyn_cast(op.getOperand(i).getType())) { + indices = emitIndices(loc, rewriter, targetInfo, rankedTy.getEncoding(), + rankedTy, true); + for (int64_t dim : rankedTy.getShape()) { + if (dim > 0) { + dimWidths.push_back(static_cast(std::ceil(std::log10(dim)))); + } else { + dimWidths.push_back(0); + } + } + } else { + // We're printing a scalar. + assert(elems.size() == 1); + indices.push_back({}); + } + + if (!elems.empty()) { + printTensor(op.getPrefix(), /*operand=*/i, + /*numOperands=*/op.getNumOperands(), elems, pid, indices, + dimWidths, op.getHex(), rewriter); + } + } + rewriter.eraseOp(op); + return success(); + } + + void printTensor(StringRef prefixStr, size_t operand, size_t numOperands, + ArrayRef elems, std::array pid, + ArrayRef> indices, + ArrayRef dimWidths, bool hex, + ConversionPatternRewriter &rewriter) const { + assert(!elems.empty()); + assert(elems.size() == indices.size()); + assert(dimWidths.size() == indices.front().size()); + + size_t rank = dimWidths.size(); + + // Format is: + // pid (, , ) idx (, , ...) (operand ) + // where we leave off "(operand )" if there's only one operand. + // + // The Python wrapper munges `prefix` so that it prints nicely (e.g. starts + // with " " and ends with ": "). + + Value formatStrValue; + int formatStrByteCount = 0; + for (int i = 0; i < elems.size(); i++) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + + // nvptx printf can only accept 32 args; if we pass more than that, it + // will print garbage for the trailing args. + constexpr int kMaxPrintfOperands = 32; + SmallVector printfOperands; + + // TODO(jlebar): We really should pad the pid, but because the max pid is + // not known at compile-time, this would require nontrivial device-side + // work. + os << "pid ("; + for (int j = 0; j < pid.size(); j++) { + if (j != 0) { + os << ", "; + } + os << getFormatSubstr(pid[j]); + printfOperands.push_back(pid[j]); + } + os << ") "; + + // If `rank` is large enough, we could end up exceeding + // kMaxPrintfOperands. In that case, just truncate the index. + // (Subtract 2 because we're going to add two operands after the index.) + int maxAllowedRank = kMaxPrintfOperands - printfOperands.size() - 2; + + os << "idx ("; + const auto &index = indices[i]; + for (size_t dim = 0; dim < index.size(); dim++) { + if (dim != 0) { + os << ", "; + } + if (dim == maxAllowedRank) { + os << "... (truncated)"; + break; + } + os << getFormatSubstr(index[dim], /*hex=*/false, + /*width=*/dimWidths[dim]); + printfOperands.push_back(index[dim]); + } + os << ")" << prefixStr; + + if (numOperands > 1) { + os << "(operand " << operand << ") "; + } + + auto elem = elems[i]; + os << getFormatSubstr(elem, hex); + printfOperands.push_back(elem); + + // It's the same format string each iteration, but it's a lot easier if we + // construct the format string at the same time as we populate + // printfOperands. But we don't want to create BLOCK_SIZE duplicate + // strings, so we cache the Value. + if (i == 0) { + formatStrValue = + llPrintf(formatStr, printfOperands, rewriter, &formatStrByteCount); + } else { + targetInfo.printf(rewriter, formatStrValue, formatStrByteCount, + printfOperands); + } + } + } + + std::string getFormatSubstr(Value value, bool hex = false, + std::optional width = std::nullopt) const { + Type type = value.getType(); + if (isa(type)) { + return "%p"; + } + // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the + // type (so 4 for fp16, 8 for int32, 16 for int64). + if (hex) { + // Ignore `width` for `hex` values, pad to typeWidth. + std::string ret = + "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4); + if (type.getIntOrFloatBitWidth() > 32) { + ret += "ll"; + } + ret += "x"; + return ret; + } + + std::string prefix = "%"; + if (width.has_value()) { + prefix += std::to_string(*width); + } else if (hex) { + prefix += "0"; + prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4); + } + + if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { + return prefix + "f"; + } else if (type.isSignedInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + "lli"; + else + return prefix + "i"; + } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + "llu"; + else + return prefix + "u"; + } + assert(false && "not supported type"); + return ""; + } + + // Returns a Value for the format string, which you can reuse. Writes the byte + // count for the string to |formatStrByteCount| if not null. + Value llPrintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter, + int *formatStrByteCount = nullptr) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), + rewriter, "printfFormat_", msgNewline); + targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); + if (formatStrByteCount) + *formatStrByteCount = msgNewline.size_in_bytes(); + return msgValue; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populatePrintOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp new file mode 100644 index 000000000..c18586602 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -0,0 +1,436 @@ +#include "ReduceScanCommon.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getTotalElemsPerThread; + +namespace { +struct ReduceOpConversion + : public ConvertTritonGPUReduceScanToLLVMPattern { +public: + ReduceOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertTritonGPUReduceScanToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ReduceOpHelper helper(op); + assert(helper.isSupportedLayout() && + "Unexpected srcLayout in ReduceOpConversion"); + Location loc = op->getLoc(); + + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + std::map, SmallVector> accs; + std::map, SmallVector> indices; + // First reduce all the values along axis within each thread. + reduceWithinThreads(helper, srcValues, accs, indices, rewriter); + + // Then reduce across threads within a warp. + reduceWithinWarps(helper, accs, rewriter); + + if (helper.isWarpSynchronous()) { + // If all the values to be reduced are within the same warp there is + // nothing left to do. + packResults(helper, accs, rewriter); + return success(); + } + + // Compute a shared memory base per operand. + auto smemShape = helper.getScratchConfig(); + + SmallVector smemBases = + getSmemBases(op, product(smemShape), rewriter); + + storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter); + + sync(rewriter, loc, op); + + // The second round of shuffle reduction + // now the problem size: sizeInterWarps, s1, s2, .. , sn + // where sizeInterWarps is 2^m + // + // Each thread needs to process: + // elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads + accumulatePartialReductions(helper, smemBases, rewriter); + + // We could avoid this barrier in some of the layouts, however this is not + // the general case. + // TODO: optimize the barrier in case the layouts are accepted. + sync(rewriter, loc, op); + + // set output values + loadReductionAndPackResult(helper, smemShape, smemBases, rewriter); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; + + void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, + SmallVector &acc, ValueRange cur, bool isFirst) const { + if (isFirst) { + acc = SmallVector(cur.begin(), cur.end()); + return; + } + + // Create a new copy of the reduce block, and inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + rewriter.cloneRegionBefore(combineOp, &parent.front()); + auto &newReduce = parent.front(); + auto returnOp = dyn_cast(newReduce.getTerminator()); + + llvm::SmallVector combineArgs(2 * acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; + } + + rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), + combineArgs); + + auto results = returnOp.getResult(); + for (unsigned i = 0; i < acc.size(); ++i) { + acc[i] = results[i]; + } + + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + } + + SmallVector> + unpackInputs(Location loc, triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getTotalElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = unpackLLElements(loc, operands[i], rewriter); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } + } + return srcValues; + } + + void sync(ConversionPatternRewriter &rewriter, Location loc, + triton::ReduceOp op) const { + barrier(); + } + + // Reduce along op axis for elements that are in the same thread. The + // accumulated value is stored in accs. + void reduceWithinThreads( + ReduceOpHelper &helper, SmallVector> &srcValues, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + RankedTensorType operandType = op.getInputTypes()[0]; + // Assumes offsets don't actually depend on type + SmallVector> offsets = + emitOffsetForLayout(helper.getSrcLayout(), operandType); + + // Thread X might hold the same input value in two registers. Get the + // indices in `offsets` that hold unique values, and only accumualte over + // those. + llvm::MapVector, int> uniqueOffsets; + for (int i = 0; i < offsets.size(); ++i) { + uniqueOffsets.insert({offsets[i], i}); + } + + unsigned srcElems = getTotalElemsPerThread(operandType); + auto *combineOp = &op.getCombineOp(); + auto srcIndices = emitIndices(op.getLoc(), rewriter, targetInfo, + helper.getSrcLayout(), operandType, true); + // reduce within threads + for (const auto &[_, i] : uniqueOffsets) { + SmallVector key = offsets[i]; + key[op.getAxis()] = 0; + bool isFirst = accs.find(key) == accs.end(); + accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); + if (isFirst) + indices[key] = srcIndices[i]; + } + } + + // Apply warp reduction across the given number of contiguous lanes using op + // region and the accumulator values as source. + void warpReduce(ConversionPatternRewriter &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce, unsigned interleave) const { + auto success = + targetInfo.warpReduce(rewriter, loc, acc, op, numLaneToReduce); + if (success) + return; + for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) { + SmallVector shfl(acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave); + } + accumulate(rewriter, op.getCombineOp(), acc, shfl, false); + } + } + + // Reduce across threads within each warp. + void + reduceWithinWarps(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData(); + unsigned threadOffsetOnReductionAxis = + helper.getThreadOffsetOnReductionAxis(); + for (auto it : accs) { + const SmallVector &key = it.first; + SmallVector &acc = accs[key]; + warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps, + threadOffsetOnReductionAxis); + } + } + + // Pack the accumulator values and replace the reduce op with the result. + void packResults(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + unsigned axis = op.getAxis(); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (auto resultTy = + dyn_cast(op.getResult()[i].getType())) { + auto resultLayout = cast(resultTy.getEncoding()); + unsigned resultElems = getTotalElemsPerThread(resultTy); + SmallVector> resultOffset = + emitOffsetForLayout(resultLayout, resultTy); + SmallVector resultVals; + for (int j = 0; j < resultElems; j++) { + auto key = resultOffset[j]; + key.insert(key.begin() + axis, 0); + resultVals.push_back(accs[key][i]); + } + results[i] = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, resultTy); + } else + results[i] = accs.begin()->second[i]; + } + rewriter.replaceOp(op, results); + } + + SmallVector + getMultiDimWarpId(ReduceOpHelper &helper, Value &warpId, Location &loc, + ConversionPatternRewriter &rewriter) const { + auto srcLayout = helper.getSrcLayout(); + auto srcShape = helper.getSrcShape(); + auto order = triton::gpu::getWarpOrder(srcLayout); + SmallVector multiDimWarpId; + + // 2x2 warps with slice dim = 0, warpId = 2 ends up writing at the same + // address as warpId = 0 since the warpsPerCTA is [1, 2], need to figure out + // a way to properly delinearize warpId in the slice case + if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(parentLayout); + auto parentOrder = triton::gpu::getWarpOrder(parentLayout); + multiDimWarpId = + delinearize(rewriter, loc, warpId, parentWarpsPerCTA, parentOrder); + multiDimWarpId.erase(multiDimWarpId.begin() + sliceLayout.getDim()); + } else { + SmallVector warpsPerCTA = + triton::gpu::getWarpsPerCTA(srcLayout); + warpsPerCTA[helper.getAxis()] = triton::gpu::getWarpsPerCTAWithUniqueData( + srcLayout, srcShape)[helper.getAxis()]; + multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, order); + } + return multiDimWarpId; + } + + void storeWarpReduceToSharedMemory( + ReduceOpHelper &helper, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + Value threadId = getThreadId(rewriter, loc); + auto srcLayout = helper.getSrcLayout(); + Value warpSize = i32_val(triton::gpu::getWarpSize(srcLayout)); + Value warpId = udiv(threadId, warpSize); + Value laneId = urem(threadId, warpSize); + auto srcShape = helper.getSrcShape(); + unsigned axis = op.getAxis(); + auto smemShape = helper.getScratchConfig(); + + auto threadsPerWarp = + triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape); + auto order = getOrder(srcLayout); + SmallVector multiDimLaneId = + delinearize(rewriter, loc, laneId, threadsPerWarp, order); + Value laneIdAxis = multiDimLaneId[axis]; + Value zero = i32_val(0); + Value laneZero = icmp_eq(laneIdAxis, zero); + + SmallVector multiDimWarpId = + getMultiDimWarpId(helper, warpId, loc, rewriter); + Value warpIdAxis = multiDimWarpId[axis]; + + auto smemOrder = helper.getOrderWithAxisAtBeginning(); + for (auto it : accs) { + const SmallVector &key = it.first; + SmallVector &acc = it.second; + + SmallVector writeIdx = indices[key]; + writeIdx[axis] = warpIdAxis; + Value writeOffset = + linearize(rewriter, loc, writeIdx, smemShape, smemOrder); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + Value writePtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, + smemBases[i], writeOffset); + targetInfo.storeShared(rewriter, loc, writePtr, acc[i], laneZero); + } + } + } + + // Load the reduction of each warp and accumulate them to a final value and + // store back to shared memory. + void accumulatePartialReductions(ReduceOpHelper &helper, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + auto srcLayout = helper.getSrcLayout(); + auto smemShape = helper.getScratchConfig(); + unsigned elems = product(smemShape); + unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData(); + Location loc = op.getLoc(); + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(triton::gpu::getWarpSize(srcLayout)); + Value laneId = urem(threadId, warpSize); + Value zero = i32_val(0); + + auto mod = op.getOperation()->getParentOfType(); + unsigned numThreads = + product(triton::gpu::getWarpsPerCTA(srcLayout)) * + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + unsigned elemsPerThread = std::max(elems / numThreads, 1); + Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); + Value readOffset = threadId; + for (unsigned round = 0; round < elemsPerThread; ++round) { + SmallVector acc(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, + smemBases[i], readOffset); + acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy, + threadIsNeeded); + } + warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */); + // only the first thread in each sizeInterWarps is writing + Value writeOffset = readOffset; + SmallVector writePtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + writePtrs[i] = gep(ptr_ty(rewriter.getContext(), 3), elemTy, + smemBases[i], writeOffset); + } + + Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); + Value laneIdModSizeInterWarpsIsZero = + icmp_eq(laneIdModSizeInterWarps, zero); + Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero); + + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + targetInfo.storeShared(rewriter, loc, writePtrs[i], acc[i], pred); + } + + if (round != elemsPerThread - 1) { + readOffset = add(readOffset, i32_val(numThreads)); + } + } + } + + // Load the final reduction from shared memory and replace the reduce result + // with it. + void loadReductionAndPackResult(ReduceOpHelper &helper, + SmallVector smemShape, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + auto srcLayout = helper.getSrcLayout(); + auto axis = op.getAxis(); + auto smemOrder = helper.getOrderWithAxisAtBeginning(); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + if (auto resultTy = + dyn_cast(op.getResult()[i].getType())) { + // nd-tensor where n >= 1 + auto resultLayout = cast(resultTy.getEncoding()); + unsigned resultElems = getTotalElemsPerThread(resultTy); + auto resultIndices = emitIndices(loc, rewriter, targetInfo, + resultLayout, resultTy, true); + auto resultShape = resultTy.getShape(); + auto resultCTATile = getShapePerCTATile(resultLayout, resultShape); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (size_t j = 0; j < resultElems; ++j) { + SmallVector readIdx = resultIndices[j]; + readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0)); + for (size_t resultIdx = 0, resultDim = resultShape.size(); + resultIdx < resultDim; ++resultIdx) { + auto smemIdx = resultIdx < op.getAxis() ? resultIdx : resultIdx + 1; + if (resultCTATile[resultIdx] > smemShape[smemIdx] || + resultShape[resultIdx] > smemShape[smemIdx]) { + // When srcShape smaller then src sizePerThread, only srcShape + // elements is accumulated in smem. Modulo smemShape effectively + // replicates srcShape elements to src sizePerThread. + readIdx[smemIdx] = + urem(readIdx[smemIdx], i32_val(smemShape[smemIdx])); + } + } + Value readOffset = + linearize(rewriter, loc, readIdx, smemShape, smemOrder); + Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, + smemBases[i], readOffset); + resultVals[j] = load(elemTy, readPtr); + } + + results[i] = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, resultTy); + } else { + // 0d-tensor -> scalar + results[i] = load(elemTy, smemBases[i]); + } + } + rewriter.replaceOp(op, results); + } +}; +} // namespace + +void mlir::triton::populateReduceOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h new file mode 100644 index 000000000..09d11ba38 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h @@ -0,0 +1,84 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCESCANCOMMON_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCESCANCOMMON_H + +// TODO: refactor so that it doesn't fail if Allocation.h +// is included after utility.h (due to conflict in `store` macro +// and +#include "triton/Analysis/Allocation.h" + +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +// +#include "mlir/IR/TypeUtilities.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include +#include + +#define DEBUG_TYPE "ttgpu_to_llvm" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::SharedMemoryObject; +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::CTALayoutAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; +// namespace ttng = ::mlir::triton::nvidia_gpu; + +namespace mlir::triton { +class ReduceOp; +class ScanOp; +} // namespace mlir::triton + +template +class ConvertTritonGPUReduceScanToLLVMPattern + : public ConvertOpToLLVMPattern { +public: + // Make sure the class is only instantiated with Reduce and Scan + static_assert(std::is_same_v || + std::is_same_v); + + using ConvertOpToLLVMPattern::getTypeConverter; + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + // Return the pointee type of the shared memory pointer for operand i. + Type getElementType(SourceOp op, int i) const { + auto ty = op.getInputTypes()[i].getElementType(); + return getTypeConverter()->convertType(ty); + } + + // Helper to compute the smem bases in both reductions and scans + SmallVector getSmemBases(SourceOp op, unsigned elems, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + // indices will store the index of the op operands in descending order + // of their bitwidths + std::vector indices(op.getNumOperands()); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) { + return op.getElementTypes()[i].getIntOrFloatBitWidth() > + op.getElementTypes()[j].getIntOrFloatBitWidth(); + }); + // Assign base index to each operand in their order in indices + std::map indexToBase; + indexToBase[indices[0]] = + LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + for (unsigned i = 1; i < op.getNumOperands(); ++i) { + indexToBase[indices[i]] = gep( + ptr_ty(rewriter.getContext(), 3), getElementType(op, indices[i - 1]), + indexToBase[indices[i - 1]], i32_val(elems)); + } + // smemBases[k] is the base pointer for the k-th operand + SmallVector smemBases(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + smemBases[i] = indexToBase[i]; + } + return smemBases; + } +}; + +#endif diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 000000000..972fc5592 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,38 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct GetProgramIdOpConversion + : public ConvertOpToLLVMPattern { + explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value programId = targetInfo.programId(rewriter, op->getLoc(), + op->getParentOfType(), + op.getAxisAsInt()); + rewriter.replaceOp(op, programId); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp new file mode 100644 index 000000000..675bf5a34 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -0,0 +1,589 @@ +#include + +#include "ReduceScanCommon.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::getTotalElemsPerThread; + +// apply combine region to acc and cur and accumulate it into acc +// TODO(Lezcano) This is now duplicated with ReduceOpConversion::reduce. +// Deduplicate +static SmallVector accumulate(ConversionPatternRewriter &rewriter, + Region &combineOp, ValueRange acc, + ValueRange cur) { + // Allows for passing an unitialized acc and use cur as the neutral element + if (acc.size() == 0) { + return cur; + } + assert(cur.size() == acc.size()); + // Create a new copy of the reduce block, and inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + rewriter.cloneRegionBefore(combineOp, &parent.front()); + auto &newScan = parent.front(); + auto returnOp = dyn_cast(newScan.getTerminator()); + + SmallVector combineArgs(2 * acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; + } + + rewriter.inlineBlockBefore(&newScan, &*rewriter.getInsertionPoint(), + combineArgs); + SmallVector results; + llvm::transform(returnOp.getResult(), std::back_inserter(results), + [&](Value res) { return rewriter.getRemappedValue(res); }); + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + return results; +} + +// Scan a contiguous elements within a thread and update `srcValues` in place. +static void +scanThreadContiguousElements(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper) { + // Depending on layout contiguous elements along axis dim may not be + // contiguous in srcValues. Keep track of what elements belong to the same + // chunk of contiguous elements. + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned numChunks = srcValues.size() / scanElementsPerThreads; + unsigned stride = helper.getAxisElementStride(); + SmallVector> accs(numChunks); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + // Change this into emitOffsetForLayout? + unsigned accIndex = (srcIndex % stride) + + ((srcIndex / stride) / scanElementsPerThreads) * stride; + + accs[accIndex] = accumulate(rewriter, helper.getCombineOp(), accs[accIndex], + srcValues[srcIndex]); + srcValues[srcIndex] = accs[accIndex]; + } +} + +// Apply a scan across threads of the warp for the last element of each +// contiguous group of elements. +static void warpScan(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, Value laneIdAxis) { + Location loc = helper.getLoc(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + // Reduce within warps. + SmallVector acc = srcValues[srcIndex]; + for (unsigned i = 1; i <= scanDim / 2; i <<= 1) { + SmallVector shfl(acc.size()); + for (unsigned j = 0; j < acc.size(); ++j) { + shfl[j] = targetInfo.shuffleUp(rewriter, loc, acc[j], i * threadStride); + } + SmallVector tempAcc = + accumulate(rewriter, helper.getCombineOp(), shfl, acc); + Value mask = icmp_slt(laneIdAxis, i32_val(i)); + for (unsigned j = 0; j < acc.size(); ++j) { + acc[j] = select(mask, acc[j], tempAcc[j]); + } + } + srcValues[srcIndex] = acc; + } +} + +// For each set of contiguous elements within a thread we store the partial +// reduction into shared memory. Each parallel scan and each warp will store its +// own partial reductions. The shared memory is organized as follow: +// ----------------------------------------------------------------- +// chunk 0: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 | +// chunk 1: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 | +static void storeWarpAccumulator(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId, SmallVector smemBases, + SmallVector smemTypes, + Value parallelLaneId, + const TargetInfoBase &targetInfo) { + Location loc = helper.getLoc(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned chunkId = 0; + unsigned elementStride = helper.getAxisElementStride(); + + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + auto lastElement = srcValues[srcIndex]; + Value mask = icmp_eq(laneId, i32_val(scanDim - 1)); + Value index = add(parallelLaneId, mul(warpId, i32_val(numParallelLane))); + index = add(index, i32_val(chunkId * numParallelLane * axisNumWarps)); + for (unsigned i = 0; i < lastElement.size(); ++i) { + Value writePtr = gep(ptr_ty(rewriter.getContext(), 3), smemTypes[i], + smemBases[i], index); + targetInfo.storeShared(rewriter, loc, writePtr, lastElement[i], mask); + } + chunkId++; + } +} + +// Read the partial reductions from shared memory from each chunk of contiguous +// elements for each warp and parallel scan. Then combine the partial reduction +// with the right elements. Within a given contiguous element chunk we update +// all the elements by accumulating the value from the last element of the +// reduced value from the previous lane. +static void AddPartialReduce(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, + SmallVector smemBases, + SmallVector smemTypes, Value warpId, + Value laneIdAxis, Value parallelLaneId) { + Location loc = helper.getLoc(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); + Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); + Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); + struct Accumulator { + SmallVector acc; + SmallVector maskedAcc; + }; + unsigned numScanBlocks = helper.getAxisNumBlocks(); + unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); + assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * + scanElementsPerThreads == + srcValues.size()); + SmallVector accumulators(numParallelBlocks * + parallelElementsPerThread); + unsigned chunkId = 0; + unsigned blockStride = helper.getAxisBlockStride(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + // Accumulate the partial reduction from shared memory. Decide which + // accumulator to combine based on whether the elements belong to the same + // dimension along axis. + unsigned blockId = chunkId / parallelElementsPerThread; + unsigned parallelBlockId = + blockId % blockStride + + ((blockId / blockStride) / numScanBlocks) * blockStride; + unsigned accumulatorIndex = chunkId % parallelElementsPerThread + + parallelBlockId * parallelElementsPerThread; + Accumulator &accumulator = accumulators[accumulatorIndex]; + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + for (unsigned i = 0; i < axisNumWarps; ++i) { + Value index = add(parallelLaneId, i32_val(numParallelLane * + (i + chunkId * axisNumWarps))); + SmallVector partialReduce(helper.getNumOperands()); + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + auto elemTy = smemTypes[j]; + Value ptr = + gep(ptr_ty(rewriter.getContext(), 3), elemTy, smemBases[j], index); + partialReduce[j] = load(elemTy, ptr); + } + + if (accumulator.acc.size() == 0) { + accumulator.acc = partialReduce; + accumulator.maskedAcc = partialReduce; + continue; + } + accumulator.acc = accumulate(rewriter, helper.getCombineOp(), + accumulator.acc, partialReduce); + Value mask = icmp_slt(warpId, i32_val(i + 1)); + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + accumulator.maskedAcc[j] = + select(mask, accumulator.maskedAcc[j], accumulator.acc[j]); + } + } + auto temp = accumulate(rewriter, helper.getCombineOp(), + accumulator.maskedAcc, srcValues[srcIndex]); + if (axisBlockId == 0) { + // For the first warp and first chunk we don't have anything to + // accumulate. + auto val = srcValues[srcIndex]; + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + temp[i] = select(maskFirstWarp, val[i], temp[i]); + } + } + srcValues[srcIndex] = temp; + // Update the rest of the contiguous elements. + SmallVector lastElement(helper.getNumOperands()); + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + auto elem = targetInfo.shuffleUp(rewriter, loc, temp[i], threadStride); + lastElement[i] = select(maskFirstLane, accumulator.maskedAcc[i], elem); + } + for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + auto laneValue = srcValues[srcIndex - i * elementStride]; + laneValue = + accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); + if (axisBlockId == 0) { + // For the first warp and first chunk we don't have anything to + // accumulate. + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + laneValue[j] = + select(maskFirstThread, + srcValues[srcIndex - i * elementStride][j], laneValue[j]); + } + } + srcValues[srcIndex - i * elementStride] = laneValue; + } + // For the next chunk start back from the value containing the + // accumulated value of all the warps. + accumulator.maskedAcc = accumulator.acc; + chunkId++; + } +} + +static void AddPartialReduceOneWarp(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, Value warpId, + Value laneIdAxis, Value laneIdLast) { + Location loc = helper.getLoc(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); + Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); + Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); + unsigned numScanBlocks = helper.getAxisNumBlocks(); + unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); + assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * + scanElementsPerThreads == + srcValues.size()); + SmallVector> accumulators(numParallelBlocks * + parallelElementsPerThread); + unsigned chunkId = 0; + unsigned blockStride = helper.getAxisBlockStride(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + unsigned blockId = chunkId / parallelElementsPerThread; + unsigned parallelBlockId = + blockId % blockStride + + ((blockId / blockStride) / numScanBlocks) * blockStride; + unsigned accumulatorIndex = chunkId % parallelElementsPerThread + + parallelBlockId * parallelElementsPerThread; + auto &accumulator = accumulators[accumulatorIndex]; + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + if (axisBlockId == 0) // First chunk and first block + accumulator = srcValues[srcIndex]; + else + srcValues[srcIndex] = accumulate(rewriter, helper.getCombineOp(), + accumulator, srcValues[srcIndex]); + // Update the rest of the contiguous elements. + auto lastElement = srcValues[srcIndex]; + if (scanDim > 1) { + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + lastElement[i] = targetInfo.shuffleUp( + rewriter, loc, srcValues[srcIndex][i], threadStride); + lastElement[i] = select(maskFirstLane, accumulator[i], lastElement[i]); + if (numScanBlocks > 1) + // Update accumulator with the value from the last lane. + accumulator[i] = targetInfo.shuffleIdx( + rewriter, loc, srcValues[srcIndex][i], laneIdLast); + } + } else if (numScanBlocks > 1) { + accumulator = srcValues[srcIndex]; + } + for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + auto laneValue = srcValues[srcIndex - i * elementStride]; + laneValue = + accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); + if (axisBlockId == 0) { + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + // For the first warp and first chunk we don't have anything to + // accumulate. + laneValue[j] = + select(maskFirstThread, + srcValues[srcIndex - i * elementStride][j], laneValue[j]); + } + } + srcValues[srcIndex - i * elementStride] = laneValue; + } + // For the next chunk start back from the value containing the + // accumulated value of all the warps. + chunkId++; + } +} + +namespace { +struct ScanOpConversion + : public ConvertTritonGPUReduceScanToLLVMPattern { +public: + using ConvertTritonGPUReduceScanToLLVMPattern< + triton::ScanOp>::ConvertTritonGPUReduceScanToLLVMPattern; + explicit ScanOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertTritonGPUReduceScanToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (succeeded(emitFastScan(op, adaptor, rewriter))) + return success(); + return failure(); + } + +private: + const TargetInfoBase &targetInfo; + SmallVector getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value laneId) const; + SmallVector getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value warpId) const; + std::tuple + getDelinearizedIds(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId) const; + LogicalResult emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; +}; + +SmallVector +ScanOpConversion::getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value laneId) const { + auto loc = helper.getLoc(); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto order = triton::gpu::getOrder(srcEncoding); + return delinearize(rewriter, loc, laneId, threadsPerWarp, order); +} + +SmallVector +ScanOpConversion::getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value warpId) const { + auto loc = helper.getLoc(); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto warpOrder = triton::gpu::getWarpOrder(srcEncoding); + return delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); +} + +// Break up the threadId into lane and warp id along the scan dimension and +// compute a flat id for the parallel dimensions. +std::tuple +ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId) const { + auto loc = helper.getLoc(); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto order = triton::gpu::getOrder(srcEncoding); + auto warpOrder = triton::gpu::getWarpOrder(srcEncoding); + SmallVector multiDimLaneId = + delinearize(rewriter, loc, laneId, threadsPerWarp, order); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); + + Value laneIdAxis = multiDimLaneId[axis]; + Value warpIdAxis = multiDimWarpId[axis]; + + multiDimLaneId[axis] = i32_val(0); + threadsPerWarp[axis] = 1; + Value laneIdParallel = + linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, order); + multiDimWarpId[axis] = i32_val(0); + warpsPerCTA[axis] = 1; + Value warpIdParallel = + linearize(rewriter, loc, multiDimWarpId, warpsPerCTA, warpOrder); + Value flatIdParallel = + add(laneIdParallel, + mul(warpIdParallel, i32_val(helper.getNonAxisNumThreadsPerWarp()))); + return std::make_tuple(laneIdAxis, warpIdAxis, flatIdParallel); +} + +SmallVector> +unpackInputs(Location loc, triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter) { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getTotalElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = unpackLLElements(loc, operands[i], rewriter); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } + } + return srcValues; +} + +// Flip the srcValues. Both reverses the chunks and reverses the lanes. +// Lane reversal is done with a butterfly shuffle flip (divide and flip). +SmallVector> +flipSrcValues(Location loc, triton::ScanOp op, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + SmallVector> srcValues, int iWarpSize) { + SmallVector> values(srcValues.size()); + for (int i = 0; i < srcValues.size(); ++i) { + int revIndex = srcValues.size() - i - 1; + for (unsigned j = 0; j < op.getNumOperands(); ++j) { + for (unsigned k = iWarpSize / 2; k >= 1; k = k / 2) { + srcValues[revIndex][j] = + targetInfo.shuffleXor(rewriter, loc, srcValues[revIndex][j], k); + } + values[i].push_back(srcValues[revIndex][j]); + } + } + return values; +} + +// Lowering using warp shuffle operations to do warp level scan. +LogicalResult +ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + ScanLoweringHelper helper(op); + auto loc = helper.getLoc(); + if (!helper.isSupported()) + return failure(); + + Value threadId = getThreadId(rewriter, loc); + auto mod = op->getParentOfType(); + unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + Value warpSize = i32_val(iWarpSize); + Value warpId = udiv(threadId, warpSize); + Value laneId = urem(threadId, warpSize); + + auto [laneIdAxis, warpIdAxis, flatIdParallel] = + getDelinearizedIds(rewriter, helper, laneId, warpId); + auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + warpIdAxis = urem(warpIdAxis, i32_val(axisNumWarps)); + auto srcValues = + unpackInputs(loc, op, adaptor, rewriter, *getTypeConverter()); + + // For the reverse option we apply flip(scan(flip()) in + // order to avoid having a separate code path in the reverse direction. + // We do this by 1) reversing chunks, 2) reversing lanes, 3) reversing + // warp ids and then undoing this below. + // (Note: Tried pretty hard to get shflDownSync to work but I ended up + // having to add a lot of the complex cross warp code (if rev switch + // first/last etc). Reverse first seems more maintainable.) + if (op.getReverse()) { + warpIdAxis = sub(i32_val(axisNumWarps - 1), warpIdAxis); + srcValues = + flipSrcValues(loc, op, rewriter, targetInfo, srcValues, iWarpSize); + } + + // Scan contiguous elements in a thread and update `srcValues`. + scanThreadContiguousElements(srcValues, rewriter, helper); + // Apply warp level scan to the last element of each chunk of contiguous + // elements. + warpScan(srcValues, rewriter, targetInfo, helper, laneIdAxis); + + if (axisNumWarps > 1) { + // Slow path for the case where there are multiple warps with unique data on + // the axis. + auto elems = helper.getScratchSizeInElems(); + SmallVector smemBases = getSmemBases(op, elems, rewriter); + SmallVector smemTypes(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + smemTypes[i] = getElementType(op, i); + } + + // Store the partial reducing for each warp into shared memory. + storeWarpAccumulator(srcValues, rewriter, helper, laneIdAxis, warpIdAxis, + smemBases, smemTypes, flatIdParallel, targetInfo); + barrier(); + // Read back the partial reduction of each warp and accumulate them based on + // warpId. Then update each chunk of contiguous elements by adding the + // accumulated value from the previous lane. + AddPartialReduce(srcValues, rewriter, targetInfo, helper, smemBases, + smemTypes, warpIdAxis, laneIdAxis, flatIdParallel); + } else if (srcValues.size() > 1) { + // Fast path for the case where there is only one warp with unique data on + // the axis. + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + auto multiDimLaneId = getMultiDimLaneId(rewriter, helper, laneId); + multiDimLaneId[helper.getAxis()] = i32_val(scanDim - 1); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(helper.getEncoding()); + auto laneIdLast = linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, + triton::gpu::getOrder(helper.getEncoding())); + AddPartialReduceOneWarp(srcValues, rewriter, targetInfo, helper, warpIdAxis, + laneIdAxis, laneIdLast); + } // else axisNumWarps == 1 and srcValues.size() == 1, nothing to do. + + auto transpose = [](const SmallVector> &v) { + assert(v.size() > 0 && v[0].size() > 0); + auto ret = SmallVector>(v[0].size(), + SmallVector(v.size())); + for (int i = 0; i < v.size(); ++i) { + for (int j = 0; j < v[0].size(); ++j) { + ret[j][i] = v[i][j]; + } + } + return ret; + }; + + SmallVector results(op.getNumOperands()); + if (op.getReverse()) { + srcValues = + flipSrcValues(loc, op, rewriter, targetInfo, srcValues, iWarpSize); + } + + auto valuesTransposed = transpose(srcValues); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto resultTy = dyn_cast(op.getResult()[i].getType()); + results[i] = packLLElements(loc, getTypeConverter(), valuesTransposed[i], + rewriter, resultTy); + } + rewriter.replaceOp(op, results); + return success(); +} +} // namespace + +void mlir::triton::populateScanOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp new file mode 100644 index 000000000..a663d247a --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -0,0 +1,149 @@ +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::IluvatarMmaEncodingAttr; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::SharedEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; + +TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( + MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis) + : LLVMTypeConverter(ctx, option, analysis) { + addConversion([&](triton::PointerType type) -> std::optional { + return convertTritonPointerType(type); + }); + addConversion([&](RankedTensorType type) -> std::optional { + return convertTritonTensorType(type); + }); + addConversion([&](MemDescType type) -> std::optional { + return convertMemDescType(type); + }); + addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional { + return convertAsyncToken(type); + }); + addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }); + addConversion([&](mlir::Float8E5M2Type type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }); + addConversion([&](mlir::Float8E5M2FNUZType type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }); +#ifndef __ILUVATAR__ + // Internally store bfloat16 as int16 + addConversion([&](BFloat16Type type) -> std::optional { + return IntegerType::get(type.getContext(), 16); + }); +#endif +} + +Type TritonGPUToLLVMTypeConverter::convertTritonPointerType( + triton::PointerType type) { + auto ctx = type.getContext(); + auto pointeeType = type.getPointeeType(); + if (isa(pointeeType)) { + auto rankedTensorType = cast(pointeeType); + // struct { offset0, offset1, shape0, shape1, stride0, + // stride1, base_ptr}; + auto eleType = rankedTensorType.getElementType(); + auto shape = rankedTensorType.getShape(); + SmallVector types; + // offsets + for (size_t i = 0; i < shape.size(); ++i) + types.push_back(IntegerType::get(ctx, 32)); + // shapes, strides + for (size_t i = 0; i < 2 * shape.size(); ++i) + types.push_back(IntegerType::get(ctx, 64)); + + types.push_back(LLVM::LLVMPointerType::get(ctx, type.getAddressSpace())); + + return LLVM::LLVMStructType::getLiteral(ctx, types); + } + return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); +} + +Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct( + TensorOrMemDesc type) { + auto ctx = type.getContext(); + Attribute layout = type.getEncoding(); + Type elemTy = convertType(type.getElementType()); + auto dotOpLayout = mlir::dyn_cast(layout); + if (!dotOpLayout) + return elemTy; + if (auto iluvatarmmaParent = + mlir::dyn_cast(dotOpLayout.getParent())) { + if (iluvatarmmaParent.isVolta()) { + int bitwidth = elemTy.getIntOrFloatBitWidth(); + if (bitwidth == 8) + return vec_ty(elemTy, 8); + return vec_ty(elemTy, 4); + } + } + auto mmaParent = + mlir::dyn_cast(dotOpLayout.getParent()); + if (!mmaParent || mmaParent.isHopper()) + return elemTy; + int bitwidth = elemTy.getIntOrFloatBitWidth(); + assert(bitwidth <= 32); + return IntegerType::get(ctx, 32); +} + +Type TritonGPUToLLVMTypeConverter::convertTritonTensorType( + RankedTensorType type) { + auto ctx = type.getContext(); + Attribute layout = type.getEncoding(); + SmallVector shape(type.getShape().begin(), type.getShape().end()); + Type eltType = getElementTypeForStruct(cast(type)); + + if (auto shared_layout = mlir::dyn_cast(layout)) { + SmallVector types; + // base ptr + auto ptrType = LLVM::LLVMPointerType::get(ctx, 3); + types.push_back(ptrType); + // shape dims + auto rank = type.getRank(); + // offsets + strides + for (auto i = 0; i < rank * 2; i++) { + types.push_back(IntegerType::get(ctx, 32)); + } + return LLVM::LLVMStructType::getLiteral(ctx, types); + } + + unsigned numElementsPerThread = getTotalElemsPerThread(type); + SmallVector types(numElementsPerThread, eltType); + return LLVM::LLVMStructType::getLiteral(ctx, types); +} + +Type TritonGPUToLLVMTypeConverter::convertMemDescType(MemDescType type) { + auto ctx = type.getContext(); + Attribute layout = type.getEncoding(); + SmallVector shape(type.getShape().begin(), type.getShape().end()); + SmallVector types; + // base ptr + auto ptrType = LLVM::LLVMPointerType::get(ctx, 3); + types.push_back(ptrType); + // shape dims + auto rank = type.getShape().size(); + // offsets + strides + for (auto i = 0; i < rank * 2; i++) { + types.push_back(IntegerType::get(ctx, 32)); + } + return LLVM::LLVMStructType::getLiteral(ctx, types); +} + +Type TritonGPUToLLVMTypeConverter::convertAsyncToken( + triton::gpu::AsyncTokenType type) { + return IntegerType::get(type.getContext(), 32); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/Utility.cpp new file mode 100644 index 000000000..b65259974 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -0,0 +1,664 @@ +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "python/src/plugin.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "llvm/ADT/STLExtras.h" + +namespace SharedToDotOperandMMAv1 { + +using CoordTy = SmallVector; +using ValueTable = std::map, std::pair>; + +using getMNCoordsFunc = SmallVector (*)( + Value, Location, ConversionPatternRewriter &, ArrayRef, + const IluvatarMmaEncodingAttr &, ArrayRef, int, int, bool); +DEFINE_LOAD_FUNC(getMNCoords) + +static SmallVector +getMNCoords(Value thread, Location loc, ConversionPatternRewriter &rewriter, + ArrayRef wpt, const NvidiaMmaEncodingAttr &mmaLayout, + ArrayRef shape, bool isARow, bool isBRow, bool isAVec4, + bool isBVec4) { + static constexpr std::array fpw{{2, 2, 1}}; + + auto *ctx = thread.getContext(); + Value _1 = i32_val(1); + Value _2 = i32_val(2); + Value _4 = i32_val(4); + Value _16 = i32_val(16); + Value _32 = i32_val(32); + Value _fpw0 = i32_val(fpw[0]); + Value _fpw1 = i32_val(fpw[1]); + + // A info + auto aRep = mmaLayout.getMMAv1Rep(0); + auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0); + // B info + auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1); + auto bRep = mmaLayout.getMMAv1Rep(1); + + SmallVector rep({aRep[0], bRep[1]}); + SmallVector spw({aSpw[0], bSpw[1]}); + SmallVector shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]}); + + Value lane = urem(thread, _32); + Value warp = udiv(thread, _32); + + Value warp0 = urem(warp, i32_val(wpt[0])); + Value warp12 = udiv(warp, i32_val(wpt[0])); + Value warp1 = urem(warp12, i32_val(wpt[1])); + + // warp offset + Value offWarpM = mul(warp0, i32_val(spw[0])); + Value offWarpN = mul(warp1, i32_val(spw[1])); + // quad offset + Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0); + Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1); + // pair offset + Value offPairM = udiv(urem(lane, _16), _4); + offPairM = urem(offPairM, _fpw0); + offPairM = mul(offPairM, _4); + Value offPairN = udiv(urem(lane, _16), _4); + offPairN = udiv(offPairN, _fpw0); + offPairN = urem(offPairN, _fpw1); + offPairN = mul(offPairN, _4); + + // sclare + offPairM = mul(offPairM, i32_val(rep[0] / 2)); + offQuadM = mul(offQuadM, i32_val(rep[0] / 2)); + offPairN = mul(offPairN, i32_val(rep[1] / 2)); + offQuadN = mul(offQuadN, i32_val(rep[1] / 2)); + + // quad pair offset + Value offLaneM = add(offPairM, offQuadM); + Value offLaneN = add(offPairN, offQuadN); + // a, b offset + Value offsetAM = add(offWarpM, offLaneM); + Value offsetBN = add(offWarpN, offLaneN); + // m indices + Value offsetCM = add(and_(lane, _1), offsetAM); + SmallVector idxM; + for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0]) + for (unsigned mm = 0; mm < rep[0]; ++mm) + idxM.push_back(add(offsetCM, i32_val(m + mm * 2))); + + // n indices + Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN))); + SmallVector idxN; + for (int n = 0; n < shape[1]; n += shapePerCTA[1]) { + for (int nn = 0; nn < rep[1]; ++nn) { + idxN.push_back(add( + offsetCN, i32_val(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1]))); + idxN.push_back( + add(offsetCN, + i32_val(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1))); + } + } + + SmallVector> axes({idxM, idxN}); + + // product the axis M and axis N to get coords, ported from + // generator::init_idx method from triton2.0 + + // TODO[Superjomn]: check the order. + SmallVector coords; + for (Value x1 : axes[1]) { // N + for (Value x0 : axes[0]) { // M + SmallVector idx(2); + idx[0] = x0; // M + idx[1] = x1; // N + coords.push_back(std::move(idx)); + } + } + + return coords; // {M,N} in row-major +} +} // namespace SharedToDotOperandMMAv1 + +namespace mlir { + +namespace triton::gpu { +Type getFunctionType(Type resultType, ValueRange operands) { + SmallVector operandTypes(operands.getTypes()); + return LLVM::LLVMFunctionType::get(resultType, operandTypes); +} + +LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter, + Operation *op, StringRef funcName, + Type funcType, + StringRef libname /*= ""*/, + StringRef libpath /*= ""*/) { + using LLVM::LLVMFuncOp; + + auto funcAttr = StringAttr::get(op->getContext(), funcName); + Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + if (funcOp) + return cast(*funcOp); + + Operation *parent = op; + if (!isa(op)) + parent = op->getParentOfType(); + OpBuilder b(parent); + auto ret = b.create(op->getLoc(), funcName, funcType); + ret.getOperation()->setAttr("libname", + StringAttr::get(op->getContext(), libname)); + ret.getOperation()->setAttr("libpath", + StringAttr::get(op->getContext(), libpath)); + return ret; +} +} // namespace triton::gpu + +SmallVector> +applyLinearLayout(Location loc, RewriterBase &rewriter, + const LinearLayout &layout, + ArrayRef> indices) { + assert(layout.getNumInDims() == indices.size()); + for (auto [inDimName, idx] : indices) { + assert(layout.hasInDim(inDimName) && "Invalid inDimName"); + } + + // This function can emit a lot of MLIR code, which ultimately makes + // compilation slow. (We think this shouldn't be the case -- it's not *that* + // much code -- but we're not clear on how to fix the slowness, which happens + // in the bowels of MLIR.) + // + // As a result we go through some contortions to avoid emitting code where + // possible. + + // Manually constant-fold the layout where possible. + SmallVector> constantIns; + for (auto [inDimName, idx] : indices) { + if (auto constant = dyn_cast(idx.getDefiningOp())) { + constantIns.push_back( + {inDimName, constant.getValue().cast().getInt()}); + } else { + constantIns.push_back({inDimName, 0}); + } + } + SmallVector constantComponent = + llvm::to_vector(llvm::make_second_range(layout.apply(constantIns))); + + Value zero = i32_val(0); + SmallVector> outIndices; + for (auto [i, outDimName] : llvm::enumerate(layout.getOutDimNames())) { + if (constantComponent[i] == 0) + outIndices.push_back({outDimName, zero}); + else + outIndices.push_back({outDimName, i32_val(constantComponent[i])}); + } + + for (auto [inDimName, idx] : indices) { + if (isa(idx.getDefiningOp())) { + continue; + } + + int nBits = layout.getInDimSizeLog2(inDimName); + for (int i = 0; i < nBits; i++) { + Value bit = and_(idx, i32_val(1 << i)); + Value bit_is_zero = icmp_eq(bit, zero); + for (auto &[outDimName, outIdx] : outIndices) { + int32_t basis = layout.getBasis(inDimName, i, outDimName); + if (basis == 0) + continue; + outIdx = xor_(outIdx, select(bit_is_zero, zero, i32_val(basis))); + } + } + } + + return outIndices; +} + +std::optional>> +emitIndicesUsingLinearLayouts(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, bool withCTAOffset) { + MLIRContext *ctx = rewriter.getContext(); + auto shape = type.getShape(); + + std::optional ll = triton::gpu::toLinearLayout(shape, layout); + if (!ll.has_value()) { + return std::nullopt; + } + + // TODO(jlebar): We could add strong typing if we wanted; for now this is + // "stringly typed". + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + + Value threadId = getThreadId(rewriter, loc); + Value threadsPerWarp = i32_val(ll->getInDimSize(kLane)); + Value laneId = urem(threadId, threadsPerWarp); + Value warpId = udiv(threadId, threadsPerWarp); + Value blockId = + withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0); + unsigned rank = shape.size(); + SmallVector> ret; + for (unsigned reg = 0; reg < ll->getInDimSize(str_attr("register")); reg++) { + auto idxs = applyLinearLayout(loc, rewriter, *ll, + {{kRegister, i32_val(reg)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, blockId}}); + assert(idxs.size() == rank); + for (unsigned k = 0; k < rank; ++k) { + assert(idxs[k].first == str_attr("dim" + std::to_string(k))); + } + ret.push_back(llvm::to_vector(llvm::make_second_range(idxs))); + } + + return ret; +} + +namespace LLVM { +using namespace mlir::triton; +using mlir::triton::gpu::getOrder; +using mlir::triton::gpu::getSizePerThread; + +Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) { + auto i32ty = rewriter.getIntegerType(32); + return rewriter.create(loc, i32ty, + IntegerAttr::get(i32ty, v)); +} + +Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v) { + auto i64ty = rewriter.getIntegerType(64); + return rewriter.create(loc, i64ty, + IntegerAttr::get(i64ty, v)); +} + +Value createConstantF16(Location loc, OpBuilder &rewriter, float v) { + auto type = type::f16Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF16FloatAttr(v)); +} + +Value createConstantF32(Location loc, OpBuilder &rewriter, float v) { + auto type = type::f32Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF32FloatAttr(v)); +} + +Value createConstantF64(Location loc, OpBuilder &rewriter, double v) { + auto type = type::f64Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF64FloatAttr(v)); +} + +Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type) { + if (!isa(type)) { + llvm::report_fatal_error("Creating NaN constant for non-float type!"); + } + return rewriter.create( + loc, type, APFloat::getNaN(cast(type).getFloatSemantics())); +} + +// Create an index type constant. +Value createIndexConstant(OpBuilder &builder, Location loc, + TypeConverter *converter, int64_t value) { + Type ty = converter->convertType(builder.getIndexType()); + return builder.create(loc, ty, + builder.getIntegerAttr(ty, value)); +} + +// Create an integer constant of \param width bits. +Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, + int64_t value) { + Type ty = builder.getIntegerType(width); + return builder.create(loc, ty, + builder.getIntegerAttr(ty, value)); +} + +SharedMemoryObject +getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, + ConversionPatternRewriter &rewriter) { + ArrayRef types = + cast(llvmStruct.getType()).getBody(); + SmallVector elems(types.size()); + for (unsigned i = 0; i < types.size(); ++i) { + Type type = types[i]; + elems[i] = extract_val(type, llvmStruct, i); + } + + auto rank = (elems.size() - 1) / 2; + return {/*base=*/elems[0], + /*baseElemType=*/elemTy, + /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank}, + /*offsets=*/{elems.begin() + 1 + rank, elems.end()}}; +} + +SmallVector getStridesFromShapeAndOrder(ArrayRef shape, + ArrayRef order, + Location loc, + RewriterBase &rewriter) { + auto rank = shape.size(); + SmallVector strides(rank); + int64_t stride = 1; + for (auto idx : order) { + strides[idx] = i32_val(stride); + stride *= shape[idx]; + } + return strides; +} + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape, + ArrayRef order) { + unsigned rank = shape.size(); + assert(rank == order.size()); + auto reordered = applyPermutation(shape, order); + SmallVector reorderedMultiDim(rank); + if (auto constantOp = linear.getDefiningOp()) { + unsigned intVal = mlir::cast(constantOp.getValue()) + .getValue() + .getSExtValue(); + reorderedMultiDim = delinearize(rewriter, loc, intVal, reordered); + } else { + reorderedMultiDim = delinearize(rewriter, loc, linear, reordered); + } + SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + unsigned linear, ArrayRef shape) { + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + unsigned remained = linear; + for (auto &&en : llvm::enumerate(shape)) { + unsigned dimSize = en.value(); + multiDim[en.index()] = i32_val(remained % dimSize); + remained = remained / dimSize; + } + return multiDim; +} + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape) { + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + Value remained = linear; + for (auto &&en : llvm::enumerate(shape)) { + Value dimSize = i32_val(en.value()); + multiDim[en.index()] = urem(remained, dimSize); + remained = udiv(remained, dimSize); + } + return multiDim; +} + +Value linearize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDim, ArrayRef shape, + ArrayRef order) { + return linearize(rewriter, loc, applyPermutation(multiDim, order), + applyPermutation(shape, order)); +} + +Value linearize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDim, ArrayRef shape) { + auto rank = multiDim.size(); + Value linear = i32_val(0); + if (rank > 0) { + linear = multiDim.back(); + for (auto [dim, dimShape] : + llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { + Value dimSize = i32_val(dimShape); + linear = add(mul(linear, dimSize), dim); + } + } + return linear; +} + +Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, + StringRef key, StringRef content) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + auto ctx = moduleOp.getContext(); + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (key + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + + llvm::SmallString<64> contentStr(content); + size_t contentSize = contentStr.size_in_bytes(); + auto globalType = LLVM::LLVMArrayType::get(i8_ty, contentSize); + +#ifndef __ILUVATAR__ + LLVM::GlobalOp global; + { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + global = rewriter.create( + UnknownLoc::get(ctx), globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, + rewriter.getStringAttr(contentStr)); + } + + Value zero = i32_val(0); + Type globalPtrType = LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()); + Value globalPtr = rewriter.create( + UnknownLoc::get(ctx), globalPtrType, global.getSymName()); + Value stringStart = + gep(ptr_ty(ctx), i8_ty, globalPtr, SmallVector({zero})); +#else + LLVM::GlobalOp global; + { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + global = rewriter.create( + UnknownLoc::get(ctx), globalType, + /*isConstant=*/true, LLVM::Linkage::Private, stringConstName, + rewriter.getStringAttr(contentStr), 1, 4); + } + + Value zero = i32_val(0); + Type globalPtrType = LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()); + Value globalPtr = rewriter.create( + UnknownLoc::get(ctx), globalPtrType, global.getSymName()); + Value localPtr = addrspacecast(ptr_ty(ctx), globalPtr); + Value stringStart = + gep(ptr_ty(ctx), i8_ty, localPtr, SmallVector({zero})); +#endif + return stringStart; +} + +SmallVector getMultiDimOffset(Attribute layout, Location loc, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + unsigned elemId, RankedTensorType type, + ArrayRef multiDimCTAInRepId, + ArrayRef shapePerCTATile, + bool isTrans, bool stNotRd) { + auto shape = type.getShape(); + unsigned rank = shape.size(); + if (auto blockedLayout = dyn_cast(layout)) { + auto multiDimOffsetFirstElem = emitBaseIndexForLayout( + loc, rewriter, targetInfo, blockedLayout, type, false); + SmallVector multiDimOffset(rank); + SmallVector multiDimElemId = getMultiDimIndex( + elemId, getSizePerThread(layout), getOrder(layout)); + for (unsigned d = 0; d < rank; ++d) { + multiDimOffset[d] = + add(multiDimOffsetFirstElem[d], + i32_val(multiDimCTAInRepId[d] * shapePerCTATile[d] + + multiDimElemId[d])); + } + return multiDimOffset; + } + if (auto sliceLayout = mlir::dyn_cast(layout)) { + unsigned dim = sliceLayout.getDim(); + auto parentEncoding = sliceLayout.getParent(); + auto parentSizePerThread = getSizePerThread(parentEncoding); + auto parentShape = sliceLayout.paddedShape(shape); + auto parentTy = RankedTensorType::get(parentShape, type.getElementType(), + parentEncoding); + auto offsets = emitOffsetForLayout(layout, type); + auto parentOffset = emitOffsetForLayout(parentEncoding, parentTy); + SmallVector idxs; + for (SmallVector off : offsets) { + off.insert(off.begin() + dim, 0); + auto it = std::find(parentOffset.begin(), parentOffset.end(), off); + idxs.push_back(std::distance(parentOffset.begin(), it)); + } + auto multiDimOffsetParent = getMultiDimOffset( + parentEncoding, loc, rewriter, targetInfo, idxs[elemId], parentTy, + sliceLayout.paddedShape(multiDimCTAInRepId), + sliceLayout.paddedShape(shapePerCTATile)); + SmallVector multiDimOffset(rank); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d == dim) + continue; + unsigned slicedD = d < dim ? d : (d - 1); + multiDimOffset[slicedD] = multiDimOffsetParent[d]; + } + return multiDimOffset; + } + if (auto mmaLayout = mlir::dyn_cast(layout)) { + assert(rank == 2 || + (rank == 3 && mmaLayout.isAmpere()) && "Unexpected rank"); + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + auto instrShape = mmaLayout.getInstrShape(); + SmallVector mmaColIdx(2); + SmallVector mmaRowIdx(2); + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(32); + Value laneId = urem(threadId, warpSize); + Value warpId = udiv(threadId, warpSize); + // TODO: fix the bug in MMAEncodingAttr document + SmallVector multiDimWarpId(2); + auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); + auto warpOrder = triton::gpu::getWarpOrder(mmaLayout); + multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); + Value _1 = i32_val(1); + Value _2 = i32_val(2); + Value _4 = i32_val(4); + Value _8 = i32_val(8); + Value _16 = i32_val(16); + if (mmaLayout.isAmpere() || mmaLayout.isHopper()) { + multiDimWarpId[rank - 1] = urem( + multiDimWarpId[rank - 1], + i32_val(ceil(shapePerCTA[rank - 1], instrShape[rank - 1]))); + multiDimWarpId[rank - 2] = urem( + multiDimWarpId[rank - 2], + i32_val(ceil(shapePerCTA[rank - 2], instrShape[rank - 2]))); + + Value mmaGrpId = udiv(laneId, _4); + Value mmaGrpIdP8 = add(mmaGrpId, _8); + Value mmaThreadIdInGrp = urem(laneId, _4); + Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2); + Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1); + Value rowWarpOffset = + mul(multiDimWarpId[rank - 2], i32_val(instrShape[rank - 2])); + mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset); + mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset); + Value colWarpOffset = + mul(multiDimWarpId[rank - 1], i32_val(instrShape[rank - 1])); + mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset); + mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset); + } else if (mmaLayout.isVolta()) { + // Volta doesn't follow the pattern here. + } else { + llvm_unreachable("Unexpected MMALayout version"); + } + + SmallVector multiDimOffset(rank); + if (mmaLayout.isHopper()) { + unsigned elemIdRem4 = elemId % 4; + unsigned nGrpId = elemId / 4; + multiDimOffset[0] = elemIdRem4 < 2 ? mmaRowIdx[0] : mmaRowIdx[1]; + multiDimOffset[1] = elemIdRem4 % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1]; + multiDimOffset[1] = add(multiDimOffset[1], i32_val(8 * nGrpId)); + multiDimOffset[0] = add(multiDimOffset[0], i32_val(multiDimCTAInRepId[0] * + shapePerCTATile[0])); + multiDimOffset[1] = add(multiDimOffset[1], i32_val(multiDimCTAInRepId[1] * + shapePerCTATile[1])); + } else if (mmaLayout.isAmpere()) { + if (rank == 3) + multiDimOffset[0] = + add(multiDimWarpId[0], + i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0])); + multiDimOffset[rank - 2] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1]; + multiDimOffset[rank - 1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1]; + multiDimOffset[rank - 2] = + add(multiDimOffset[rank - 2], i32_val(multiDimCTAInRepId[rank - 2] * + shapePerCTATile[rank - 2])); + multiDimOffset[rank - 1] = + add(multiDimOffset[rank - 1], i32_val(multiDimCTAInRepId[rank - 1] * + shapePerCTATile[rank - 1])); + } else if (mmaLayout.isVolta()) { + auto [isARow, isBRow, isAVec4, isBVec4, _] = + mmaLayout.decodeVoltaLayoutStates(); + auto coords = SharedToDotOperandMMAv1::getMNCoords( + threadId, loc, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout, shape, + isARow, isBRow, isAVec4, isBVec4); + return coords[elemId]; + } else { + llvm_unreachable("Unexpected MMALayout version"); + } + return multiDimOffset; + } + if (auto mmaLayout = mlir::dyn_cast(layout)) { + assert(rank == 2 && "Unexpected rank"); + SmallVector multiDimOffset(rank); + Value threadId = getThreadId(rewriter, loc); + if (mmaLayout.isVolta()) { + int bitwidth = type.getElementType().getIntOrFloatBitWidth(); + int elemVecSize = stNotRd ? (32 / bitwidth) : 1; + static auto func = SharedToDotOperandMMAv1::load_getMNCoords_func( + "iluvatar", "getMNCoords"); + auto coords = func(threadId, loc, rewriter, mmaLayout.getWarpsPerCTA(), + mmaLayout, shape, bitwidth, elemVecSize, isTrans); + return coords[elemId]; + } else { + llvm_unreachable("Unexpected MMALayout version"); + } + } + if (isa(layout)) { + auto multiDimBase = + emitBaseIndexForLayout(loc, rewriter, targetInfo, layout, type, false); + SmallVector> offsets; + assert(rank == 2); + SmallVector multiDimOffset(rank); + if (auto mfmaLayout = dyn_cast(layout)) { + emitMfmaOffsetForCTA(mfmaLayout, offsets, 0, multiDimCTAInRepId[0], + multiDimCTAInRepId[1]); + } else if (auto wmmaLayout = dyn_cast(layout)) { + emitWmmaOffsetForCTA(wmmaLayout, offsets, 0, multiDimCTAInRepId[0], + multiDimCTAInRepId[1]); + } + multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0])); + multiDimOffset[1] = add(multiDimBase[1], i32_val(offsets[elemId][1])); + return multiDimOffset; + } + llvm_unreachable("unexpected layout in getMultiDimOffset"); +} + +SmallVector getWrappedMultiDimOffset( + ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDimOffset, ArrayRef shape, + SmallVector shapePerCTATile, SmallVector shapePerCTA) { + unsigned rank = shape.size(); + SmallVector multiDimOffsetWrapped(rank); + for (unsigned d = 0; d < rank; ++d) { + if (shapePerCTATile[d] > shapePerCTA[d]) + multiDimOffsetWrapped[d] = urem(multiDimOffset[d], i32_val(shape[d])); + else + multiDimOffsetWrapped[d] = multiDimOffset[d]; + } + return multiDimOffsetWrapped; +} + +} // namespace LLVM +} // namespace mlir diff --git a/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp new file mode 100644 index 000000000..45c0bde61 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -0,0 +1,399 @@ +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefsPlugin.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +namespace { +struct SplatOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a + // LLVM::StructType value. + // + // @elemType: the element type in operand. + // @resType: the return type of the Splat-like op. + // @constVal: a LLVM::ConstantOp or other scalar value. + static Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Location loc) { + auto tensorTy = cast(resType); + // Check the converted type for the tensor as depending on the encoding the + // converter may pick different element types. + auto srcType = typeConverter->convertType(tensorTy); + if (auto structTy = dyn_cast(srcType)) + srcType = structTy.getBody()[0]; + // If the type sizes don't match we need to pack constants. + if (srcType.isIntOrFloat() && constVal.getType().getIntOrFloatBitWidth() != + srcType.getIntOrFloatBitWidth()) { + unsigned cstBitWidth = constVal.getType().getIntOrFloatBitWidth(); + unsigned srcBitWidth = srcType.getIntOrFloatBitWidth(); + assert(cstBitWidth <= srcBitWidth && srcBitWidth % cstBitWidth == 0); + unsigned ratio = srcBitWidth / cstBitWidth; + Type intTy = IntegerType::get(elemType.getContext(), cstBitWidth); + VectorType vecType = VectorType::get(ratio, intTy); + Value intCst = bitcast(constVal, intTy); + Value vec = undef(vecType); + for (unsigned i = 0; i < ratio; ++i) + vec = insert_element(vecType, vec, intCst, int_val(32, i)); + constVal = vec; + } + auto llSrc = bitcast(constVal, srcType); + size_t elemsPerThread = getTotalElemsPerThread(tensorTy); + llvm::SmallVector elems(elemsPerThread, llSrc); + return packLLElements(loc, typeConverter, elems, rewriter, resType); + } + LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op->getLoc(); + auto src = adaptor.getSrc(); + auto typeConverter = getTypeConverter(); + auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src, + typeConverter, rewriter, loc); + rewriter.replaceOp(op, {llStruct}); + return success(); + } +}; +// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr), +// the logic is the same as triton::SplatOp, so the underlying implementation +// is reused. +struct ArithConstantSplatOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto value = op.getValue(); + if (!mlir::dyn_cast(value)) + return failure(); + auto loc = op->getLoc(); + LLVM::ConstantOp arithConstantOp; + auto values = mlir::dyn_cast(op.getValue()); + auto elemType = values.getElementType(); + Attribute val; + if (type::isFloat(elemType)) { + val = values.getValues()[0]; + } else if (type::isInt(elemType)) { + val = values.getValues()[0]; + } else { + llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: " + << value.getType() << "\n"; + return failure(); + } + auto constOp = rewriter.create(loc, elemType, val); + auto typeConverter = getTypeConverter(); + auto llStruct = SplatOpConversion::convertSplatLikeOp( + elemType, op.getType(), constOp, typeConverter, rewriter, loc); + rewriter.replaceOp(op, llStruct); + return success(); + } +}; +struct CatOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename CatOp::Adaptor; + explicit CatOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + unsigned elems = getTotalElemsPerThread(resultTy); + auto typeConverter = getTypeConverter(); + Type elemTy = typeConverter->convertType(resultTy.getElementType()); + SmallVector types(elems, elemTy); + // unpack input values + auto lhsVals = unpackLLElements(loc, adaptor.getLhs(), rewriter); + auto rhsVals = unpackLLElements(loc, adaptor.getRhs(), rewriter); + // concatenate (and potentially reorder) values + SmallVector retVals; + for (Value v : lhsVals) + retVals.push_back(v); + for (Value v : rhsVals) + retVals.push_back(v); + // pack and replace + Value ret = packLLElements(loc, typeConverter, retVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct JoinOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename JoinOp::Adaptor; + explicit JoinOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // We rely on the following invariants of this op (which are checked by its + // verifier): + // + // - The op has a blocked encoding. + // - The last dimension (the one we're joining) is also the most minor + // dimension. + // - The input and output encodings are the same, except the output has + // 2 elements per thread in the last dim. + // + // With these invariants, join is trivial: We just return the i'th element + // from lhs, followed by the i'th elem from rhs. + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + auto typeConverter = getTypeConverter(); + SmallVector lhsVals = + unpackLLElements(loc, adaptor.getLhs(), rewriter); + SmallVector rhsVals = + unpackLLElements(loc, adaptor.getRhs(), rewriter); + assert(lhsVals.size() == rhsVals.size()); + SmallVector joinedVals; + for (int i = 0; i < lhsVals.size(); i++) { + joinedVals.push_back(lhsVals[i]); + joinedVals.push_back(rhsVals[i]); + } + Value ret = + packLLElements(loc, typeConverter, joinedVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct SplitOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename SplitOp::Adaptor; + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // We rely on the following invariants of this op (which are checked by its + // verifier): + // + // - The op has a blocked encoding. + // - The last dimension (the one we're spliting) is also the most minor + // dimension, and has sizePerThread=2. + // + // With these invariants, split is trivial: Every other value goes into + // return value 0, and every other goes into return value 1. + Location loc = op->getLoc(); + auto typeConverter = getTypeConverter(); + SmallVector srcVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + assert(srcVals.size() % 2 == 0); + SmallVector outLhsVals; + SmallVector outRhsVals; + for (int i = 0; i < srcVals.size(); i += 2) { + outLhsVals.push_back(srcVals[i]); + outRhsVals.push_back(srcVals[i + 1]); + } + auto resultTy = cast(op.getResult(0).getType()); + Value retLhs = + packLLElements(loc, typeConverter, outLhsVals, rewriter, resultTy); + Value retRhs = + packLLElements(loc, typeConverter, outRhsVals, rewriter, resultTy); + rewriter.replaceOp(op, {retLhs, retRhs}); + return success(); + } +}; +struct ReshapeOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename ReshapeOp::Adaptor; + explicit ReshapeOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + if (triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType())) { + return emitOptionalError(loc, + "expensive view not supported on reshape op"); + } + auto resultTy = cast(op.getType()); + auto srcTy = cast(op.getSrc().getType()); + auto typeConverter = getTypeConverter(); + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + Value ret = packLLElements(loc, typeConverter, vals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct ExpandDimsOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename ExpandDimsOp::Adaptor; + explicit ExpandDimsOpConversion( + LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto typeConverter = getTypeConverter(); + auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto srcTy = cast(op.getSrc().getType()); + auto resultTy = cast(op.getType()); + auto srcLayout = dyn_cast(srcTy.getEncoding()); + if (!srcLayout) { + return emitOptionalError( + loc, "ExpandDimsOp only supports SliceEncodingAttr as its input"); + } + auto resultLayout = resultTy.getEncoding(); + auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); + auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); + std::map, Value> srcValues; + for (size_t i = 0; i < srcOffsets.size(); i++) { + srcValues[srcOffsets[i]] = srcVals[i]; + } + SmallVector resultVals; + for (size_t i = 0; i < resultOffsets.size(); i++) { + auto offset = resultOffsets[i]; + offset.erase(offset.begin() + srcLayout.getDim()); + resultVals.push_back(srcValues.at(offset)); + } + Value ret = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct TransOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + if (auto enc = dyn_cast(resultTy.getEncoding())) { + auto llvmElemTy = + getTypeConverter()->convertType(resultTy.getElementType()); + auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto dstSmemObj = SharedMemoryObject( + srcSmemObj.base, srcSmemObj.baseElemType, + /*strides=*/applyPermutation(srcSmemObj.strides, op.getOrder()), + /*offsets=*/applyPermutation(srcSmemObj.offsets, op.getOrder())); + auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } else if (auto enc = mlir::dyn_cast( + resultTy.getEncoding())) { + // If the dst encoding is blocked, then TransOp::inferReturnTypes + // ensures that: + // - the src encoding is also blocked, and + // - the translation from src to dst is just a "renaming" of the + // registers, i.e. each thread has exactly the same values. + // Thus the transpose op simply returns the same values it got. + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + Value ret = packLLElements(loc, this->getTypeConverter(), vals, rewriter, + resultTy); + rewriter.replaceOp(op, ret); + return success(); + } + return emitOptionalError(loc, "unsupported encoding for TransOp"); + } +}; +struct BroadcastOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Following the order of indices in the legacy code, a broadcast of: + // [s(0), s(1) ... s(k-1), 1, s(k+1), s(k+2) ... s(n-1)] + // => + // [s(0), s(1) ... s(k-1), s(k), s(k+1), s(k+2) ... s(n-1)] + // + // logically maps to a broadcast within a thread's scope: + // [cta(0)..cta(k-1), 1,cta(k+1)..cta(n-1),spt(0)..spt(k-1), + // 1,spt(k+1)..spt(n-1)] + // => + // [cta(0)..cta(k-1),cta(k),cta(k+1)..cta(n-1),spt(0)..spt(k-1),spt(k),spt(k+1)..spt(n-1)] + // + // regardless of the order of the layout + // + Location loc = op->getLoc(); + Value src = adaptor.getSrc(); + Value result = op.getResult(); + auto srcTy = cast(op.getSrc().getType()); + auto resultTy = cast(result.getType()); + auto srcLayout = srcTy.getEncoding(); + auto resultLayout = resultTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto resultShape = resultTy.getShape(); + unsigned rank = srcTy.getRank(); + auto typeConverter = getTypeConverter(); + assert(rank == resultTy.getRank()); + auto order = triton::gpu::getOrder(srcLayout); + auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); + auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); + SmallVector srcVals = unpackLLElements(loc, src, rewriter); + std::map, Value> srcValues; + for (size_t i = 0; i < srcOffsets.size(); i++) { + srcValues[srcOffsets[i]] = srcVals[i]; + } + SmallVector resultVals; + for (size_t i = 0; i < resultOffsets.size(); i++) { + auto offset = resultOffsets[i]; + for (size_t j = 0; j < srcShape.size(); j++) + if (srcShape[j] == 1) + offset[j] = 0; + resultVals.push_back(srcValues.at(offset)); + } + Value resultStruct = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct MemDescSubviewOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::MemDescSubviewOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // %dst = extract_slice %src[%offsets] + Location loc = op->getLoc(); + auto srcTy = op.getSrc().getType(); + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + + // newBase = base + offset + auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + SmallVector opOffsetVals = op.getOffsets(); + size_t destRank = op.getResult().getType().getRank(); + SmallVector offsetVals; + SmallVector strides; + int rankReduced = srcTy.getRank() - destRank; + for (int i = rankReduced; i < opOffsetVals.size(); i++) { + strides.push_back(smemObj.strides[i]); + offsetVals.push_back(opOffsetVals[i]); + } + // Compute the offset based on the original strides of the shared memory + // object + auto offset = dot(rewriter, loc, opOffsetVals, smemObj.strides); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + smemObj = + SharedMemoryObject(gep(elemPtrTy, llvmElemTy, smemObj.base, offset), + llvmElemTy, strides, offsetVals); + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; +} // namespace + +void mlir::triton::populateViewOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/third_party/iluvatar/lib/Conversion/TritonToTritonGPU/CMakeLists.txt new file mode 100644 index 000000000..1b629ba16 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(TritonToTritonGPU + TritonGPUConversion.cpp + TritonToTritonGPUPass.cpp + + DEPENDS + TritonConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTransforms + TritonIR + TritonGPUIR + TritonGPUTransforms +) diff --git a/third_party/iluvatar/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/third_party/iluvatar/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp new file mode 100644 index 000000000..34fb89954 --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -0,0 +1,123 @@ +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" + +#include +#include + +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +using namespace mlir::triton::gpu; + +// +// TypeConverter +// +TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, + int numWarps, int threadsPerWarp, + int numCTAs) + : context(context), numWarps(numWarps), threadsPerWarp(threadsPerWarp), + numCTAs(numCTAs) { + addConversion([](Type type) { return type; }); + + // Add encoding for tensor + addConversion([this](RankedTensorType tensorType) -> RankedTensorType { + // types with encoding are already in the right format + // TODO: check for layout encodings more specifically + if (tensorType.getEncoding()) + return tensorType; + ArrayRef shape = tensorType.getShape(); + triton::gpu::BlockedEncodingAttr encoding = + getDefaultBlockedEncoding(this->context, shape, this->numWarps, + this->threadsPerWarp, this->numCTAs); + return RankedTensorType::get(shape, tensorType.getElementType(), encoding); + }); + + // Add encoding for tensor pointer + addConversion([this](triton::PointerType ptrType) -> triton::PointerType { + // Check whether tensor pointer `tt.ptr>` + auto pointeeTensorType = + dyn_cast(ptrType.getPointeeType()); + if (pointeeTensorType == nullptr) + return ptrType; + + // Add layout into the tensor + auto convertedTensorType = convertType(pointeeTensorType); + return triton::PointerType::get(convertedTensorType, + ptrType.getAddressSpace()); + }); + + // + // Materializations + // + // This will be called when (newArgType != origArgType) + // This will create newArg, and map(origArg, newArg) + addArgumentMaterialization([&](OpBuilder &builder, + RankedTensorType tensorType, ValueRange inputs, + Location loc) -> std::optional { + llvm_unreachable("Argument rematerialization should not happen in Triton " + "-> TritonGPU conversion"); + return std::nullopt; + }); + + // If the origValue still has live user(s), use this to + // convert origValue to newValue + addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, + Location loc) -> std::optional { + llvm_unreachable("Source rematerialization should not happen in Triton -> " + "TritonGPU Conversion"); + return std::nullopt; + }); + + // This will be called when (desiredType != newOperandType) + // where, desiredType = typeConverter->convertType(origType) + // NOTE: only for remapped values. + addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, Location loc) { + auto cast = + builder.create(loc, tensorType, inputs); + return std::optional(cast.getResult()); + }); +} + +// +// TritonGPUConversion +// +TritonGPUConversionTarget::TritonGPUConversionTarget( + MLIRContext &context, TritonGPUTypeConverter &typeConverter) + : ConversionTarget(context) { + // TODO: we should also verify ops of TritonGPUDialect + addLegalDialect(); + + // Some ops from SCF are illegal + addIllegalOp(); + + addDynamicallyLegalDialect([&](Operation *op) { + bool hasLegalRegions = true; + for (auto ®ion : op->getRegions()) { + hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion); + } + if (hasLegalRegions && typeConverter.isLegal(op)) { + return true; + } + return false; + }); + + // We have requirements for the data layouts + addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { + Attribute aEncoding = + cast(dotOp.getA().getType()).getEncoding(); + Attribute bEncoding = + cast(dotOp.getB().getType()).getEncoding(); + if (aEncoding && isa(aEncoding) && + bEncoding && isa(bEncoding)) + return true; + return false; + }); +} diff --git a/third_party/iluvatar/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/third_party/iluvatar/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp new file mode 100644 index 000000000..32de13fab --- /dev/null +++ b/third_party/iluvatar/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -0,0 +1,826 @@ +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "llvm/ADT/APSInt.h" +#include + +#define GEN_PASS_CLASSES +#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// pass named attrs (e.g., tt.contiguity) from Triton to Triton +static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) { + for (const NamedAttribute attr : dictAttrs.getValue()) + if (!op->hasAttr(attr.getName())) + op->setAttr(attr.getName(), attr.getValue()); +} + +template struct GenericOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector retTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + retTypes))) + return failure(); + rewriter.replaceOpWithNewOp(op, retTypes, adaptor.getOperands(), + op->getAttrs()); + + return success(); + } +}; + +class ArithConstantPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = getTypeConverter()->convertType(op.getType()); + auto retShapedType = cast(retType); + auto value = dyn_cast(adaptor.getValue()); + if (dyn_cast(retShapedType)) { + assert(value); + if (value.getElementType().isInteger(1) && value.isSplat()) + // Workaround until https://reviews.llvm.org/D133743 is included. + value = + DenseElementsAttr::get(retShapedType, value.getSplatValue()); + else + // This is a hack. We just want to add encoding + value = value.reshape(retShapedType); + } + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retShapedType, value), + adaptor.getAttributes()); + return success(); + } +}; + +void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { + // -------------- + // Add legality and rewrite pattern rules for operations + // from the Arith dialect. The basic premise is that + // Arith operations require both inputs to have the same + // non-null encoding + // -------------- + MLIRContext *context = patterns.getContext(); + // TODO: there's probably a better way to avoid adding all ops one-by-one + patterns.add< + ArithConstantPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, // NegFOp + // Floating point + GenericOpPattern, GenericOpPattern, + // MaxMin + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + // Floating point + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + // Cmp + GenericOpPattern, GenericOpPattern, + // Select + GenericOpPattern, + // Cast Ops + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern>(typeConverter, context); +} + +void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { + MLIRContext *context = patterns.getContext(); + // Rewrite rule + patterns.add, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern>( + typeConverter, context); +} + +// +// Triton patterns +// +struct TritonExpandDimsPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Type retType = op.getType()); + RankedTensorType argType = + cast(adaptor.getSrc().getType()); + Attribute _argEncoding = argType.getEncoding(); + if (!_argEncoding) + return failure(); + auto argEncoding = cast(_argEncoding); + // return shape + auto retShape = argType.getShape().vec(); + retShape.insert(retShape.begin() + op.getAxis(), 1); + // return encoding + auto retSizePerThread = argEncoding.getSizePerThread(); + retSizePerThread.insert(retSizePerThread.begin() + op.getAxis(), 1); + auto retThreadsPerWarp = argEncoding.getThreadsPerWarp(); + retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.getAxis(), 1); + auto retWarpsPerCTA = argEncoding.getWarpsPerCTA(); + retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1); + SmallVector retOrder(retShape.size()); + std::iota(retOrder.begin(), retOrder.end(), 0); + + auto argCTALayout = argEncoding.getCTALayout(); + auto retCTAsPerCGA = insertOne(argCTALayout.getCTAsPerCGA(), op.getAxis()); + auto retCTASplitNum = + insertOne(argCTALayout.getCTASplitNum(), op.getAxis()); + auto retCTAOrder = insertOrder(argCTALayout.getCTAOrder(), op.getAxis()); + auto retCTALayout = triton::gpu::CTALayoutAttr::get( + getContext(), retCTAsPerCGA, retCTASplitNum, retCTAOrder); + + SmallVector smeCTA(retShape.size()); + triton::gpu::BlockedEncodingAttr retEncoding = + triton::gpu::BlockedEncodingAttr::get( + getContext(), retSizePerThread, retThreadsPerWarp, retWarpsPerCTA, + retOrder, retCTALayout, false, smeCTA); + // convert operand to slice of return type + Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get( + getContext(), op.getAxis(), retEncoding, false); + RankedTensorType newArgType = RankedTensorType::get( + argType.getShape(), argType.getElementType(), newArgEncoding); + // construct new op + auto newSrc = rewriter.create( + op.getLoc(), newArgType, adaptor.getSrc()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, newSrc, adaptor.getAxis()), + adaptor.getAttributes()); + return success(); + } + +private: + template + SmallVector insertOne(ArrayRef vec, unsigned axis) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + axis, 1); + return res; + } + + // Example: order = [ 0, 2, 1, 3], dim = 2 + // resOrder = [2, 0, 3, 1, 4] + SmallVector insertOrder(ArrayRef order, + unsigned axis) const { + SmallVector resOrder(order.begin(), order.end()); + for (unsigned i = 0; i < resOrder.size(); ++i) + if (resOrder[i] >= axis) + ++resOrder[i]; + resOrder.insert(resOrder.begin(), axis); + return resOrder; + } +}; + +struct TritonDotPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType origType = op.getType(); + auto origShape = origType.getShape(); + auto typeConverter = getTypeConverter(); + int numWarps = typeConverter->getNumWarps(); + int threadsPerWarp = typeConverter->getThreadsPerWarp(); + int numCTAs = typeConverter->getNumCTAs(); + auto rank = origShape.size(); + SmallVector retSizePerThread(rank, 1); + auto numElements = product(origShape); + if (numElements / (numWarps * threadsPerWarp) >= 4) { + retSizePerThread[rank - 1] = 2; + retSizePerThread[rank - 2] = 2; + } + if (numElements / (numWarps * threadsPerWarp) >= 16) { + retSizePerThread[rank - 1] = 4; + retSizePerThread[rank - 2] = 4; + } + SmallVector retOrder(rank); + for (unsigned i = 0; i < rank; ++i) + retOrder[i] = rank - 1 - i; + Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get( + getContext(), origShape, retSizePerThread, retOrder, numWarps, + threadsPerWarp, numCTAs); + RankedTensorType retType = + RankedTensorType::get(origShape, origType.getElementType(), dEncoding); + // a & b must be of smem layout + auto aType = cast(adaptor.getA().getType()); + auto bType = cast(adaptor.getB().getType()); + Type aEltType = aType.getElementType(); + Type bEltType = bType.getElementType(); + Attribute aEncoding = aType.getEncoding(); + Attribute bEncoding = bType.getEncoding(); + if (!aEncoding || !bEncoding) + return failure(); + Value a = adaptor.getA(); + Value b = adaptor.getB(); + Value c = adaptor.getC(); + if (!mlir::isa(aEncoding)) { + Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( + getContext(), 0, dEncoding, aEltType); + auto dstType = + RankedTensorType::get(aType.getShape(), aEltType, encoding); + a = rewriter.create(a.getLoc(), dstType, a); + } + if (!mlir::isa(bEncoding)) { + Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( + getContext(), 1, dEncoding, bEltType); + auto dstType = + RankedTensorType::get(bType.getShape(), bEltType, encoding); + b = rewriter.create(b.getLoc(), dstType, b); + } + c = rewriter.create(c.getLoc(), retType, c); + + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, a, b, c, adaptor.getInputPrecision(), + adaptor.getMaxNumImpreciseAcc()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonCatPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // The cat op satisfy two conditions: + // 1. output.numel = lhs.numel + rhs.numel + // 2. output.total_elems_per_thread = + // next_power_of_2(lhs.total_elems_per_thread + rhs.total_elems_per_thread) + // For now, this behaves like generic, but this + // will evolve when we add support for `can_reorder=False`. + auto retType = cast( + this->getTypeConverter()->convertType(op.getType())); + auto retEncoding = + cast(retType.getEncoding()); + auto lhsType = adaptor.getLhs().getType(); + auto rhsType = adaptor.getRhs().getType(); + auto lhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(lhsType); + auto rhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(rhsType); + auto retTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(retType); + auto retShape = retType.getShape(); + auto retOrder = retEncoding.getOrder(); + auto retSizePerThread = retEncoding.getSizePerThread(); + auto retThreadsPerWarp = retEncoding.getThreadsPerWarp(); + auto retWarpsPerCTA = retEncoding.getWarpsPerCTA(); + // Get new retSizePerThread if ret elems per thread is not enough. + // We have to round it up to the next power of 2 due to triton's tensor size + // constraint. + auto newRetTotalElemsPerThread = + nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread); + auto newRetSizePerThread = retSizePerThread; + newRetSizePerThread[retOrder[0]] *= + newRetTotalElemsPerThread / retTotalElemsPerThread; + SmallVector smeCTA(retShape.size()); + triton::gpu::BlockedEncodingAttr newRetEncoding = + triton::gpu::BlockedEncodingAttr::get( + getContext(), newRetSizePerThread, retThreadsPerWarp, + retWarpsPerCTA, retOrder, retEncoding.getCTALayout(), false, + smeCTA); + auto newRetType = RankedTensorType::get(retShape, retType.getElementType(), + newRetEncoding); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, newRetType, adaptor.getOperands()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonJoinOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Simply rely on type inference for this op. (Notably, GenericOpPattern + // does not do this, instead it assigns the default layout to the ins and + // outs.) + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, adaptor.getLhs(), adaptor.getRhs()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonSplitOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto src = adaptor.getSrc(); + auto srcTy = cast(src.getType()); + auto srcEnc = dyn_cast(srcTy.getEncoding()); + int rank = srcEnc.getOrder().size(); + auto typeConverter = getTypeConverter(); + + // The operand to split must have: + // - a blocked layout, with + // - sizePerThread = 2 in the last dimension, + // - threadsPerWarp, warpsPerCTA, and CTAsPerCGA = 1 in the last dim, and + // - the last dimension minor. + // If that's not the case, add a convert before the split. + if (!srcEnc || srcEnc.getSizePerThread().back() != 2 || + srcEnc.getOrder().front() != rank - 1) { + // If we take the default encoding for the op's result (i.e. post-split) + // and add 1 to the end of each dim, that gives us what we want. Other + // than making a legal src encoding, our choice of layout doesn't matter; + // it'll get fixed by RemoveLayoutConversions. + auto defaultEnc = getDefaultBlockedEncoding( + getContext(), + cast(op.getResult(0).getType()).getShape(), + typeConverter->getNumWarps(), typeConverter->getThreadsPerWarp(), + typeConverter->getNumCTAs()); + + auto append = [&](ArrayRef vals, unsigned val) { + SmallVector res(vals); + res.push_back(val); + return res; + }; + auto prepend = [&](ArrayRef vals, unsigned val) { + SmallVector res; + res.push_back(val); + res.append(vals.begin(), vals.end()); + return res; + }; + + srcEnc = BlockedEncodingAttr::get( + getContext(), append(defaultEnc.getSizePerThread(), 2), + append(defaultEnc.getThreadsPerWarp(), 1), + append(defaultEnc.getWarpsPerCTA(), 1), + prepend(defaultEnc.getOrder(), rank - 1), + CTALayoutAttr::get(getContext(), + append(defaultEnc.getCTAsPerCGA(), 1), + append(defaultEnc.getCTASplitNum(), 1), + prepend(defaultEnc.getCTAOrder(), rank - 1)), + defaultEnc.getLoadType(), defaultEnc.getSmeWarpsPerCTA()); + srcTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), + srcEnc); + src = rewriter.create(op.getLoc(), srcTy, src); + } + + addNamedAttrs(rewriter.replaceOpWithNewOp(op, src), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonTransPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = adaptor.getSrc(); + auto srcTy = cast(src.getType()); + auto srcEnc = srcTy.getEncoding(); + if (!srcEnc) + return failure(); + addNamedAttrs(rewriter.replaceOpWithNewOp(op, src, op.getOrder()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonBroadcastPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // This creates a tensor with the new shape but the argument's layout + LogicalResult + matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcType = cast(adaptor.getSrc().getType()); + auto srcEncoding = srcType.getEncoding(); + if (!srcEncoding) + return failure(); + Type retType = RankedTensorType::get( + op.getType().getShape(), op.getType().getElementType(), srcEncoding); + // Type retType = this->getTypeConverter()->convertType(op.getType()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, adaptor.getOperands()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonReducePattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newReduce = rewriter.create( + op.getLoc(), adaptor.getOperands(), adaptor.getAxis(), + adaptor.getNoWarpReduce()); + addNamedAttrs(newReduce, adaptor.getAttributes()); + + auto &newCombineOp = newReduce.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newReduce.getResult()); + return success(); + } +}; + +struct TritonScanPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newScan = rewriter.create( + op.getLoc(), adaptor.getOperands(), adaptor.getAxis(), op.getReverse()); + addNamedAttrs(newScan, adaptor.getAttributes()); + + auto &newCombineOp = newScan.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newScan.getResult()); + return success(); + } +}; + +class TritonFuncOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getName(), op.getFunctionType()); + addNamedAttrs(newOp, adaptor.getAttributes()); + rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(), + newOp.getBody().end()); + if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter))) + return failure(); + + return success(); + } +}; + +class TritonCallOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getCallee(), op.getResultTypes(), adaptor.getOperands()); + addNamedAttrs(newOp, adaptor.getAttributes()); + return success(); + } +}; + +class TritonReturnOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp op, ReturnOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } +}; + +void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, unsigned numCTAs) { + MLIRContext *context = patterns.getContext(); + patterns.insert< // TODO: view should have custom pattern that views the + // layout + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + TritonBroadcastPattern, GenericOpPattern, + TritonCatPattern, TritonJoinOpPattern, TritonSplitOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, TritonReducePattern, + GenericOpPattern, TritonScanPattern, + GenericOpPattern, + GenericOpPattern, TritonExpandDimsPattern, + TritonTransPattern, TritonDotPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, TritonFuncOpPattern>(typeConverter, + context); +} + +// +// SCF patterns +// +// This is borrowed from ConvertForOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +struct SCFForPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + // Ref: ConvertForOpTypes + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + + // Now, update all the types. + + // Convert the types of block arguments within the given region. This + // replaces each block with a new block containing the updated signature. + // The entry block may have a special conversion if `entryConversion` is + // provided. On success, the new entry block to the region is returned for + // convenience. Otherwise, failure is returned. + if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), + *getTypeConverter()))) { + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + // Change the clone to use the updated operands. We could have cloned with + // a IRMapping, but this seems a bit more direct. + newOp->setOperands(adaptor.getOperands()); + // Update the result types to the new converted types. + SmallVector newResultTypes; + for (Type type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + + rewriter.replaceOp(op, newOp.getResults()); + + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +class SCFIfPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::IfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // TODO: Generalize this to any type conversion, not just 1:1. + // + // We need to implement something more sophisticated here that tracks which + // types convert to which other types and does the appropriate + // materialization logic. + // For example, it's possible that one result type converts to 0 types and + // another to 2 types, so newResultTypes would at least be the right size to + // not crash in the llvm::zip call below, but then we would set the the + // wrong type on the SSA values! These edge cases are also why we cannot + // safely use the TypeConverter::convertTypes helper here. + SmallVector newResultTypes; + for (auto type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + + // See comments in the ForOp pattern for why we clone without regions and + // then inline. + scf::IfOp newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), + newOp.getThenRegion().end()); + rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), + newOp.getElseRegion().end()); + + // Update the operands and types. + newOp->setOperands(adaptor.getOperands()); + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +class SCFWhilePattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + assert(converter); + SmallVector newResultTypes; + if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes))) + return failure(); + + auto newOp = rewriter.create(op.getLoc(), newResultTypes, + adaptor.getOperands()); + for (auto i : {0u, 1u}) { + auto &dstRegion = newOp.getRegion(i); + rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); + if (failed(rewriter.convertRegionTypes(&dstRegion, *converter))) + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +class SCFConditionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::ConditionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.modifyOpInPlace(op, + [&]() { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + +void populateSCFPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add, SCFForPattern, SCFIfPattern, + SCFWhilePattern, SCFConditionPattern>(typeConverter, context); +} + +// CF + +class CFBranchPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getSuccessor(), adaptor.getOperands()); + if (failed(rewriter.convertRegionTypes(newOp.getSuccessor()->getParent(), + *converter))) + return failure(); + return success(); + } +}; + +class CFCondBranchPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), op.getTrueDest(), + adaptor.getTrueDestOperands(), op.getFalseDest(), + adaptor.getFalseDestOperands()); + addNamedAttrs(newOp, adaptor.getAttributes()); + + if (failed(rewriter.convertRegionTypes(newOp.getTrueDest()->getParent(), + *converter))) + return failure(); + if (failed(rewriter.convertRegionTypes(newOp.getFalseDest()->getParent(), + *converter))) + return failure(); + return success(); + } +}; + +void populateCFPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add(typeConverter, context); +} +// + +class ConvertTritonToTritonGPU + : public ConvertTritonToTritonGPUBase { +public: + ConvertTritonToTritonGPU() = default; + // constructor with some parameters set explicitly. + ConvertTritonToTritonGPU(const std::string &target, int numWarps, + int threadsPerWarp, int numCTAs) { + this->numWarps = numWarps; + this->threadsPerWarp = threadsPerWarp; + this->numCTAs = numCTAs; + this->target = target; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + // type converter + TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp, + numCTAs); + TritonGPUConversionTarget target(*context, typeConverter); + // rewrite patterns + RewritePatternSet patterns(context); + // add rules + populateArithPatternsAndLegality(typeConverter, patterns, target); + populateMathPatternsAndLegality(typeConverter, patterns, target); + populateTritonPatterns(typeConverter, patterns, numCTAs); + // TODO: can we use + // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? + populateSCFPatterns(typeConverter, patterns); + populateCFPatterns(typeConverter, patterns); + + auto inti = llvm::APSInt(32, false); + auto i32_ty = IntegerType::get(mod->getContext(), 32); + + mod->setAttr( + AttrNumWarpsName, + IntegerAttr::get(i32_ty, llvm::APInt(32, numWarps.getValue()))); + mod->setAttr( + AttrNumThreadsPerWarp, + IntegerAttr::get(i32_ty, llvm::APInt(32, threadsPerWarp.getValue()))); + + mod->setAttr(AttrNumCTAsName, + IntegerAttr::get(i32_ty, llvm::APInt(32, numCTAs.getValue()))); + + if (this->target.getValue().empty()) { + mod.emitError("expected target specification to attach to the module op"); + return signalPassFailure(); + } + mod->setAttr(AttrTargetName, + StringAttr::get(context, this->target.getValue())); + + if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); + + // update layouts + // broadcast src => multicast, dst => broadcasted + // if (failed(target.refineLayouts(mod, numWarps))) + // return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +mlir::triton::createConvertTritonToTritonGPUPass(const std::string &target, + int numWarps, + int threadsPerWarp, + int numCTAs) { + return std::make_unique<::ConvertTritonToTritonGPU>(target, numWarps, + threadsPerWarp, numCTAs); +} + +std::unique_ptr> +mlir::triton::createConvertTritonToTritonGPUPass() { + return std::make_unique<::ConvertTritonToTritonGPU>(); +} diff --git a/third_party/iluvatar/lib/Dialect/CMakeLists.txt b/third_party/iluvatar/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..27cb65ce5 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Triton) +add_subdirectory(TritonGPU) diff --git a/third_party/iluvatar/lib/Dialect/Triton/CMakeLists.txt b/third_party/iluvatar/lib/Dialect/Triton/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/Triton/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/iluvatar/lib/Dialect/Triton/IR/CMakeLists.txt b/third_party/iluvatar/lib/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..752daa7ff --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(TritonIR + Dialect.cpp + Ops.cpp + Traits.cpp + Types.cpp + + DEPENDS + TritonTableGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRArithDialect + MLIRMathDialect + MLIRSCFDialect +) diff --git a/third_party/iluvatar/lib/Dialect/Triton/IR/Dialect.cpp b/third_party/iluvatar/lib/Dialect/Triton/IR/Dialect.cpp new file mode 100644 index 000000000..8f46e8ca8 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/Triton/IR/Dialect.cpp @@ -0,0 +1,138 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/DialectImplementation.h" + +#include "mlir/Transforms/InliningUtils.h" +#include "triton/Dialect/Triton/IR/Dialect.cpp.inc" +#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; + +//===----------------------------------------------------------------------===// +// TritonDialect Dialect Interfaces +//===----------------------------------------------------------------------===// + +namespace { +struct TritonInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + auto funcOp = dyn_cast(callable); + if (!funcOp) + return true; + if (funcOp->hasAttr("noinline")) + return !funcOp->getAttrOfType("noinline").getValue(); + return true; + } + + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + + bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, + IRMapping &) const final { + return true; + } + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, Block *newDest) const final { + // Only return needs to be handled here. + auto returnOp = dyn_cast(op); + if (!returnOp) + return; + + // Replace the return with a branch to the dest. + OpBuilder builder(op); + builder.create(op->getLoc(), newDest, + returnOp.getOperands()); + op->erase(); + } + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { + // Only return needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } +}; + +struct TensorModel + : public TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getRank(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementTypeBitWidth(); + } +}; + +struct MemDescModel + : public TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getShape().size(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementType().getIntOrFloatBitWidth(); + } +}; + +} // namespace + +void TritonDialect::initialize() { + registerTypes(); + + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + >(); + + // We can also add interface here. + addInterfaces(); + + RankedTensorType::attachInterface(*getContext()); + MemDescType::attachInterface(*getContext()); +} + +Operation *TritonDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return arith::ConstantOp::materialize(builder, value, type, loc); +} diff --git a/third_party/iluvatar/lib/Dialect/Triton/IR/Ops.cpp b/third_party/iluvatar/lib/Dialect/Triton/IR/Ops.cpp new file mode 100644 index 000000000..f654a5e22 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/Triton/IR/Ops.cpp @@ -0,0 +1,1109 @@ +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +namespace mlir { +namespace triton { + +// Parser & printer for assembly forms +ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { + // Parse operands + SmallVector allOperands; + + SMLoc allOperandLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(allOperands) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()) + return failure(); + + // Operand types + SmallVector operandTypes; + + // Parse `optional(type(ptr)) -> type(result)` + Type ptrType, resultType; + if (parser.parseType(resultType)) + return failure(); + if (parser.parseOptionalArrow().succeeded()) { + ptrType = resultType; + if (parser.parseType(resultType)) + return failure(); + operandTypes.push_back(ptrType); + result.addTypes(resultType); + } else { + operandTypes.push_back(getPointerTypeSameShape(resultType)); + result.addTypes(resultType); + } + + // Determine `mask` and `other` + int hasMask = 0, hasOther = 0; + if (allOperands.size() == 2) { + operandTypes.push_back(getI1SameShape(resultType)); + hasMask = 1; + } + if (allOperands.size() == 3) { + operandTypes.push_back(getI1SameShape(resultType)); + operandTypes.push_back(resultType); + hasMask = 1; + hasOther = 1; + } + // Determine `inputStride` + int hasStride = 0; + if (allOperands.size() == 4) { + operandTypes.push_back( + IntegerType::get(parser.getBuilder().getContext(), 32)); + operandTypes.push_back( + IntegerType::get(parser.getBuilder().getContext(), 32)); // placeHolder0 + operandTypes.push_back( + IntegerType::get(parser.getBuilder().getContext(), 32)); // placeHolder1 + hasStride = 1; + } + + if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc, + result.operands)) + return failure(); + + // Deduce `operandSegmentSizes` from the number of the operands + auto operandSegmentSizesAttrName = + LoadOp::getOperandSegmentSizesAttrName(result.name); + result.addAttribute( + operandSegmentSizesAttrName, + parser.getBuilder().getDenseI32ArrayAttr( + {1, hasMask, hasOther, hasStride, hasStride, hasStride})); + + return success(); +} + +void LoadOp::print(OpAsmPrinter &printer) { + printer << " "; + printer << getOperation()->getOperands(); + + // `operandSegmentSizes` can be deduced, so we don't print it. + printer.printOptionalAttrDict(getOperation()->getAttrs(), + {getOperandSegmentSizesAttrName()}); + + // `type(ptr) -> type(result)` + printer << " : "; + // `type(ptr)` is optional during parsing, we only print for tensor pointers + if (isTensorPointerType(getPtr().getType())) { + printer.printStrippedAttrOrType(getPtr().getType()); + printer << " -> "; + } + printer.printStrippedAttrOrType(getResult().getType()); +} + +void LoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), getPtr(), + triton::GlobalMemory::get()); + if (getIsVolatile()) + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); +} + +} // namespace triton +} // namespace mlir + +#define GET_OP_CLASSES +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + +// enum attribute definitions +#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc" + +namespace mlir { +namespace triton { + +//-- LoadOp -- +static Type getLoadOpResultType(OpBuilder &builder, Type ptrType) { + auto ptrTensorType = ptrType.dyn_cast(); + if (!ptrTensorType) + return ptrType.cast().getPointeeType(); + auto shape = ptrTensorType.getShape(); + Type elementType = + ptrTensorType.getElementType().cast().getPointeeType(); + return RankedTensorType::get(shape, elementType); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + CacheModifier cache, EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, /*padding=*/std::nullopt, + cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck, + padding, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, CacheModifier cache, EvictionPolicy evict, + bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, other, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + // Operands + state.addOperands(ptr); + if (mask) { + state.addOperands(mask); + if (other) { + state.addOperands(other); + } + } + + // Attributes + state.addAttribute(getOperandSegmentSizesAttrName(state.name), + builder.getDenseI32ArrayAttr( + {1, (mask ? 1 : 0), (other ? 1 : 0), 0, 0, 0})); + state.addAttribute( + getBoundaryCheckAttrName(state.name), + DenseI32ArrayAttr::get(builder.getContext(), boundaryCheck)); + if (padding.has_value()) { + state.addAttribute( + getPaddingAttrName(state.name), + PaddingOptionAttr::get(builder.getContext(), padding.value())); + } + state.addAttribute(getCacheAttrName(state.name), + CacheModifierAttr::get(builder.getContext(), cache)); + state.addAttribute(getEvictAttrName(state.name), + EvictionPolicyAttr::get(builder.getContext(), evict)); + state.addAttribute(getIsVolatileAttrName(state.name), + builder.getBoolAttr(isVolatile)); + + // Result type + Type resultType = getLoadOpResultType(builder, ptr.getType()); + state.addTypes({resultType}); +} + +// load(ptr, splat(1), ...) -> load(ptr, ...) +// load(ptr, splat(0), other, ...) -> other +struct CanonicalizeMaskedLoadPattern : public OpRewritePattern { + CanonicalizeMaskedLoadPattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const override { + auto mask = loadOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = + llvm::dyn_cast_or_null(mask.getDefiningOp()); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(), + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile(), + loadOp.getInputStride(), loadOp.getInputStride(), + loadOp.getInputStride()); + } else { + // mask = splat(0) + + // If there's no "other", the value is "undef". Perhaps we want to + // optimize it in the future.x + auto otherVal = loadOp.getOther(); + if (!otherVal) + return failure(); + rewriter.replaceOp(loadOp, otherVal); + } + return success(); + } +}; + +void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- StoreOp -- +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + /*boundaryCheck=*/{}, cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, Value mask, CacheModifier cache, + EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, mask, /*boundaryCheck=*/{}, + cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, ArrayRef boundaryCheck, + CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + builder.getDenseI32ArrayAttr(boundaryCheck), cache, + evict); +} + +// store(ptr, value, splat(1), ...) -> store(ptr, value, ...) +// store(ptr, value, splat(0), ...) -> [none] +struct CanonicalizeMaskedStorePattern : public OpRewritePattern { + CanonicalizeMaskedStorePattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const override { + auto mask = storeOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = + llvm::dyn_cast_or_null(mask.getDefiningOp()); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(), + storeOp.getEvict()); + } else { + // mask = splat(0) + rewriter.eraseOp(storeOp); + } + return success(); + } +}; + +void StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- TransOp -- +OpFoldResult TransOp::fold(FoldAdaptor adaptor) { + // transpose(x, order=[0, 1, ...]) -> x + if (isIota(getOrder())) { + return getSrc(); + } + + // transpose(transpose(x)) -> transpose(x) + if (auto innerTrans = getSrc().getDefiningOp()) { + setOrder(applyPermutation(innerTrans.getOrder(), getOrder())); + setOperand(innerTrans.getSrc()); + return getResult(); + } + + return {}; +} + +LogicalResult TransOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the input + auto argTy = cast(operands[0].getType()); + auto order = properties.as()->order.asArrayRef(); + SmallVector retShape = applyPermutation(argTy.getShape(), order); + + auto retEltTy = argTy.getElementType(); + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = dyn_cast(&dialect); + if (inferLayoutInterface + ->inferTransOpEncoding(argEncoding, order, retEncoding) + .failed()) { + return failure(); + } + } + if (isa(argTy)) { + inferredReturnTypes.push_back( + MemDescType::get(retShape, retEltTy, retEncoding)); + } else { + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + } + return success(); +} + +LogicalResult TransOp::verify() { + // Check that the op's `order` attribute is a permutation of the right length. + auto srcTy = getSrc().getType(); + + ArrayRef order = getOrder(); + if (order.size() != srcTy.getRank()) { + return emitError("order must have the same size as the rank of the " + "operand and result"); + } + + SmallVector sortedOrder(order); + llvm::sort(sortedOrder); + for (int32_t i = 0; i < sortedOrder.size(); i++) { + if (sortedOrder[i] != i) { + return emitError("order must be a permutation of [0, ..., rank - 1]"); + } + } + + return success(); +} + +//-- DotOp -- +LogicalResult +DotOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + + // verify encodings + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc); + Dialect &dialect = aEnc.getDialect(); + auto interface = dyn_cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return failure(); + } + return success(); +} + +LogicalResult DotOp::verify() { + auto aTy = getA().getType(); + auto bTy = getB().getType(); + if (aTy.getElementType().getIntOrFloatBitWidth() != + bTy.getElementType().getIntOrFloatBitWidth()) + return emitError( + "element types of operands A and B must have same bit width"); + auto aEncoding = aTy.getEncoding(); + auto bEncoding = bTy.getEncoding(); + if (!aEncoding && !bEncoding) + return success(); + // Verify that the encodings are valid. + if (!aEncoding || !bEncoding) + return emitError("mismatching encoding between A and B operands"); + Dialect &dialect = aEncoding.getDialect(); + auto interface = cast(&dialect); + return interface->verifyDotOpEncodingCompatibility(getOperation(), aEncoding, + bEncoding); +} + +//-- MakeRangeOp -- +OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) { + // make_range(start, start + 1) -> constant(start) + if (adaptor.getStart() + 1 == adaptor.getEnd()) { + auto shapedType = cast(getType()); + return SplatElementsAttr::get(shapedType, adaptor.getStartAttr()); + } + return {}; +} + +LogicalResult MakeRangeOp::verify() { + int64_t start = getStartAttr().getInt(); + int64_t end = getEndAttr().getInt(); + if (start > end) { + return this->emitOpError() << "start must be less than or equal to end"; + } + auto ty = getType(); + if (ty.getShape().size() != 1) { + return this->emitOpError() << "return type must be a 1D tensor"; + } + if (end - start != ty.getShape()[0]) { + return this->emitOpError() + << "number of elements in returned tensor, " << ty.getShape()[0] + << ", must match size of range [" << start << ", " << end + << "), which has " << end - start << " elements"; + } + if (!ty.getElementType().isInteger(32)) { + return this->emitOpError() << "returned tensor must have i32 elements"; + } + return success(); +} + +//-- ReduceOp -- +static LogicalResult +inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy, + int axis, bool noWarpReduce, + SmallVectorImpl &inferredReturnTypes) { + auto retShape = argTy.getShape().vec(); + retShape.erase(retShape.begin() + axis); + if (retShape.empty()) { + // 0d-tensor -> scalar + inferredReturnTypes.push_back(retEltTy); + } else { + // nd-tensor where n >= 1 + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = + dyn_cast(&dialect); + if (inferLayoutInterface + ->inferReduceOpEncoding(argEncoding, axis, noWarpReduce, + retEncoding) + .failed()) { + llvm::report_fatal_error("failed to infer layout for ReduceOp"); + return failure(); + } + } + // create type + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + } + return success(); +} + +void ReduceOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, int axis, bool noWarpReduce) { + SmallVector inferredReturnTypes; + for (unsigned i = 0; i < operands.size(); ++i) { + auto argTy = cast(operands[i].getType()); + auto retEltTy = argTy.getElementType(); + (void)inferReduceReturnShape(argTy, retEltTy, axis, noWarpReduce, + inferredReturnTypes); + } + + ReduceOp::build(builder, state, inferredReturnTypes, operands, axis, + noWarpReduce); +} + +LogicalResult ReduceOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + bool noWarpReduce = prop->noWarpReduce.getValue(); + for (auto arg : operands) { + auto argTy = cast(arg.getType()); + auto retEltTy = argTy.getElementType(); + if (inferReduceReturnShape(argTy, retEltTy, axis, noWarpReduce, + inferredReturnTypes) + .failed()) { + return failure(); + } + } + return success(); +} + +// Helpers for Reductions and Scans +template LogicalResult verifyReduceScan(Op &op) { + if (op.getOperands().empty()) { + return op.emitOpError() << "must have at least 1 operand"; + } + if (op.getNumOperands() != op.getNumResults()) { + return op.emitOpError() << "must have the same number of inputs as outputs"; + } + + auto getElementType = [](Type ty) { + if (auto tensorType = dyn_cast(ty)) { + return tensorType.getElementType(); + } + return ty; + }; + + for (auto [opElemTy, resTy] : + llvm::zip(op.getElementTypes(), op.getResultTypes())) { + if (opElemTy != getElementType(resTy)) { + return op.emitOpError() << "operand types and result types must agree"; + } + } + return success(); +} + +template +static LogicalResult verifyRegionsImpl(Op &op) { + auto argElementTypes = op.getElementTypes(); + const auto &operands = op.getOperands(); + const auto numArgs = 2 * operands.size(); + auto &block = *op.getBody(); + if (block.getNumArguments() != numArgs) { + return op.emitOpError() << "nested block must take " << numArgs + << " arguments, but given block with " + << block.getNumArguments() << " arguments"; + } + unsigned i = 0; + const auto &blockArgTypes = block.getArgumentTypes(); + for (unsigned i = 0; i < numArgs; ++i) { + const auto &blockArgTy = blockArgTypes[i]; + const auto &argElemTy = argElementTypes[i % operands.size()]; + if (blockArgTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << blockArgTy; + } + } + + auto terminator = dyn_cast(block.getTerminator()); + if (!terminator) { + return op.emitOpError() + << "combine operation must be terminated " + << "with a ReduceReturnOp but got " << block.getTerminator(); + } + const auto &combineResults = terminator->getOperands(); + if (combineResults.size() != operands.size()) { + return op.emitOpError() + << "expected combine operation to return " << operands.size() + << " values but got " << combineResults.size(); + } + for (unsigned i = 0; i < combineResults.size(); ++i) { + const auto &resultTy = combineResults[i].getType(); + const auto &argElemTy = argElementTypes[i]; + if (resultTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << resultTy; + } + } + return success(); +} + +static llvm::SmallVector +getInputTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcTys; + srcTys.reserve(operands.size()); + for (const auto &ty : operands.getTypes()) { + srcTys.push_back(cast(ty)); + } + return srcTys; +} + +static llvm::SmallVector +getElementTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcElemTys; + srcElemTys.reserve(operands.size()); + for (const auto &op : operands) { + srcElemTys.push_back(cast(op.getType()).getElementType()); + } + return srcElemTys; +} + +LogicalResult ReduceOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ReduceOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ReduceOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ReduceOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } + +//-- ScanOp -- +void ScanOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, int axis, bool reverse) { + SmallVector inferredReturnTypes; + state.addAttribute("reverse", builder.getBoolAttr(reverse)); + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + ScanOp::build(builder, state, inferredReturnTypes, operands, axis, reverse); +} + +LogicalResult +ScanOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + return success(); +} + +LogicalResult ScanOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ScanOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ScanOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ScanOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +unsigned ScanOp::getNumOperands() { return this->getOperands().size(); } + +//-- SplatOp -- +OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { + auto value = adaptor.getSrc(); + if (!value) + return {}; + auto shapedType = cast(getType()); + auto ret = SplatElementsAttr::get(shapedType, ArrayRef(value)); + return ret; +} + +//-- ExpandDimsOp -- +LogicalResult ExpandDimsOp::inferReturnTypes( + MLIRContext *context, std::optional loc, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // infer shape + auto arg = operands[0]; + auto argTy = cast(arg.getType()); + auto retShape = argTy.getShape().vec(); + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + retShape.insert(retShape.begin() + axis, 1); + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = dyn_cast(&dialect); + if (inferLayoutInterface + ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc) + .failed()) + return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp"); + } + // create type + auto argEltTy = argTy.getElementType(); + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, argEltTy, retEncoding)); + return success(); +} + +LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op, + PatternRewriter &rewriter) { + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + // expand_dims(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + // expand_dims(broadcast(x)) -> broadcast(expand_dims(x)) + // + // On its own this doesn't do much, but consider + // broadcast(expand_dims(broadcast)) + // -> broadcast(broadcast(expand_dims)) + // -> broadcast(expand_dims) + if (auto broadcast = dyn_cast(definingOp)) { + auto src = broadcast.getSrc(); + auto srcTy = src.getType(); + SmallVector newExpandShape(srcTy.getShape()); + newExpandShape.insert(newExpandShape.begin() + op.getAxis(), 1); + + // Infer the encoding of the new expand op, if encodings are present. + Attribute newExpandEnc; + if (auto srcEnc = srcTy.getEncoding()) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferExpandDimsOpEncoding(srcEnc, op.getAxis(), newExpandEnc, + op.getLoc()) + .failed()) { + return emitOptionalError(op.getLoc(), + "failed to infer layout for ExpandDimsOp"); + } + } + + auto newExpandTy = RankedTensorType::get( + newExpandShape, srcTy.getElementType(), newExpandEnc); + auto newExpand = rewriter.create(op.getLoc(), newExpandTy, + src, op.getAxis()); + auto newBroadcast = rewriter.create( + broadcast.getLoc(), op.getType(), newExpand.getResult()); + rewriter.replaceOp(op, {newBroadcast.getResult()}); + return success(); + } + + return failure(); +} + +template +static OpFoldResult foldViewLikeOp(ViewLikeOp op, Attribute value) { + if (!value) + return {}; + + auto shapedType = cast(op.getType()); + if (auto denseElemsAttr = dyn_cast(value)) { + if (denseElemsAttr.isSplat()) { + return denseElemsAttr.resizeSplat(shapedType); + } else { + return denseElemsAttr.reshape(shapedType); + } + } + return {}; +} + +OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) { + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +//-- ReshapeOp -- +template +LogicalResult canonicalizeViewOrBroadcast(OpType op, + PatternRewriter &rewriter) { + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + + // view(view) -> view + if (auto parentView = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, TypeRange({op.getType()}), + parentView->getOperands(), + parentView->getAttrs()); + return success(); + } + + // view(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + + return failure(); +} + +LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) { + if (!op.getAllowReorder() || op.getEfficientLayout().has_value()) + return failure(); + return canonicalizeViewOrBroadcast(op, rewriter); +} + +OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) { + // no-op + return getSrc(); + } + + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +LogicalResult ReshapeOp::verify() { + auto dstTy = getType(); + auto srcTy = getSrc().getType(); + if (getType().getNumElements() != srcTy.getNumElements()) { + return emitError( + "number of src and dst elements of reshape must be the same"); + } + + Attribute srcEnc = srcTy.getEncoding(); + Attribute dstEnc = dstTy.getEncoding(); + if (!!srcEnc != !!dstEnc) { + return emitError("Op requires that either (a) src and dst both have " + "encodings, or (b) neither does."); + } + + if (srcEnc && !getAllowReorder()) { + Attribute inferredDstEnc; + if (cast(&srcEnc.getDialect()) + ->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc, + dstTy.getShape(), inferredDstEnc, + getLoc()) + .failed()) { + return emitError("This reshape is impossible without reordering, but " + "reordering is not allowed. Try choosing a different " + "encoding for the input tensor (or allow reordering)."); + } + if (inferredDstEnc != dstEnc) { + return emitError("Expected result encoding ") + << inferredDstEnc << " but was " << dstEnc; + } + } + + return success(); +} + +//-- FpToFpOp -- +LogicalResult FpToFpOp::verify() { + auto dstType = getType().getElementType(); + auto srcType = getSrc().getType().getElementType(); + if ((dstType.getIntOrFloatBitWidth() < srcType.getIntOrFloatBitWidth()) && + (!getRounding().has_value())) { + return emitError("Rounding mode is required for FP downcast"); + } + return success(); +} + +//-- BroadcastOp -- +LogicalResult BroadcastOp::canonicalize(BroadcastOp op, + PatternRewriter &rewriter) { + return canonicalizeViewOrBroadcast(op, rewriter); +} + +OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) { + // no-op + return getSrc(); + } + + auto value = adaptor.getSrc(); + if (!value) + return {}; + + if (auto denseElemsAttr = dyn_cast(value)) { + auto shapedType = cast(getType()); + return denseElemsAttr.resizeSplat(shapedType); + } + return {}; +} + +//-- MakeTensorPtrOp -- +void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange shape, ValueRange strides, + ValueRange offsets, ArrayRef tensorShape, + ArrayRef order) { + // Get pointer type from `base` + auto pointerType = cast(base.getType()); + assert(pointerType != nullptr); + + // Build type `tt.ptr>` + auto tensorType = RankedTensorType::get( + SmallVector(tensorShape.begin(), tensorShape.end()), + pointerType.getPointeeType()); + auto result = PointerType::get(tensorType, 1); + + return build(builder, state, result, base, shape, strides, offsets, + builder.getDenseI32ArrayAttr(order)); +} + +// The following ops, including `call`, `func`, and `return` are copied and +// modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp +// We could revert it back once MLIR has a better inliner interface. +//-- FuncOp -- +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs, + ArrayRef argAttrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); + + if (argAttrs.empty()) + return; + assert(type.getNumInputs() == argAttrs.size()); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); +} + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(OpAsmPrinter &printer) { + function_interface_impl::printFunctionOp( + printer, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +// -- CallOp -- +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee attribute was specified. + auto fnAttr = (*this).getProperties().callee; + if (!fnAttr) + return emitOpError("requires a 'callee' symbol reference attribute"); + FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. + auto fnType = fn.getFunctionType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (getOperand(i).getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " + << getOperand(i).getType() << " for operand number " << i; + + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (getResult(i).getType() != fnType.getResult(i)) { + auto diag = emitOpError("result type mismatch at index ") << i; + diag.attachNote() << " op result types: " << getResultTypes(); + diag.attachNote() << "function result types: " << fnType.getResults(); + return diag; + } + + return success(); +} + +// -- ReturnOp -- +LogicalResult ReturnOp::verify() { + auto function = cast((*this)->getParentOp()); + + // The operand number and types must match the function signature. + const auto &results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" + << function.getName() << ") returns " << results.size(); + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match function result type (" + << results[i] << ")" + << " in function @" << function.getName(); + + return success(); +} + +// -- JoinOp -- +LogicalResult +JoinOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // These should have been checked by tablegen-generated code. + assert(operands.size() == 2); + assert(operands[0].getType() == operands[1].getType()); + assert(isa(operands[0].getType())); + assert(isa(operands[1].getType())); + + Value lhs = operands[0]; + Value rhs = operands[1]; + auto srcTy = cast(lhs.getType()); + + SmallVector retShape(srcTy.getShape()); + retShape.push_back(2); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferJoinOpEncoding(srcEnc, retEnc, location) + .failed()) { + return failure(); + } + } + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, srcTy.getElementType(), retEnc)); + return success(); +} + +// -- SplitOp -- +LogicalResult SplitOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // These should have been checked by tablegen-generated code. + assert(operands.size() == 1); + assert(isa(operands[0].getType())); + + Value src = operands[0]; + auto srcTy = cast(src.getType()); + auto srcShape = srcTy.getShape(); + + if (srcShape.empty() || srcShape.back() != 2) { + return emitOptionalError(location, + "last dimension of input tensor must be 2"); + } + ArrayRef retShape(srcShape.begin(), srcShape.end() - 1); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferSplitOpEncoding(srcEnc, retEnc, location) + .failed()) { + return failure(); + } + } + auto retTy = RankedTensorType::get(retShape, srcTy.getElementType(), retEnc); + inferredReturnTypes.push_back(retTy); + inferredReturnTypes.push_back(retTy); + return success(); +} + +// -- ElementwiseInlineAsmOp -- +void ElementwiseInlineAsmOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +LogicalResult ElementwiseInlineAsmOp::verify() { + if (getNumOperands() >= 1) { + auto tensorType = dyn_cast(getOperand(0).getType()); + size_t numInputElems = tensorType ? tensorType.getNumElements() : 0; + if (numInputElems % this->getPackedElement() != 0) { + return emitError("number of input elements ") + << numInputElems + << " must be a multiple of the op's packed_element attribute, " + << getPackedElement(); + } + } + return success(); +} + +// -- ExternElementwiseOp -- +void ExternElementwiseOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/iluvatar/lib/Dialect/Triton/IR/Traits.cpp b/third_party/iluvatar/lib/Dialect/Triton/IR/Traits.cpp new file mode 100644 index 000000000..19729aee5 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/Triton/IR/Traits.cpp @@ -0,0 +1,239 @@ +#include "triton/Dialect/Triton/IR/Traits.h" + +#include + +#include "mlir/IR/TypeUtilities.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +namespace ttg = mlir::triton::gpu; + +static LogicalResult verifySameEncoding(Type typeA, Type typeB, + bool allowTensorPointerType) { + // TODO(Keren): the allowTensorPointerType argument is a hack to allow. + // The type checking code is kind of a mess with the current design. + auto getEncoding = [=](Type type) -> Attribute { + Attribute ret; + if (auto tensorType = dyn_cast(type)) { + ret = tensorType.getEncoding(); + } + if (!allowTensorPointerType) { + assert(!triton::isTensorPointerType(type)); + } + return ret; + }; + auto encodingA = getEncoding(typeA); + auto encodingB = getEncoding(typeB); + if (!encodingA || !encodingB) + return success(); + return encodingA == encodingB ? success() : failure(); +} + +LogicalResult +OpTrait::impl::verifySameOperandsEncoding(Operation *op, + bool allowTensorPointerType) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifySameEncoding(opType, type, allowTensorPointerType))) + return op->emitOpError() << "requires the same encoding for all operands"; + + return success(); +} + +LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding( + Operation *op, bool allowTensorPointerType) { + if (op->getNumOperands() == 0) + return success(); + + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto resultType : op->getResultTypes()) + if (failed(verifySameEncoding(resultType, type, allowTensorPointerType))) + return op->emitOpError() + << "requires the same encoding for all operands and results"; + + return verifySameOperandsEncoding(op, allowTensorPointerType); +} + +LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) { + for (auto opType : op->getOperandTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + if ((numElements & (numElements - 1)) != 0) + return op->emitError("Number of elements must be power-of-two, but ") + << *op << " doesn't follow the rule (" << numElements << ")" + << " elements"; + } + } + for (auto opType : op->getResultTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + if ((numElements & (numElements - 1)) != 0) + return op->emitError("Number of elements must be power-of-two, but ") + << *op << " doesn't follow the rule (" << numElements << ")" + << " elements"; + } + } + return success(); +} + +// Check that the Triton layouts on op's operands and return types are valid. +// For example, we check that the number of warps per block in a Triton GPU +// blocked layout matches that of its module. +// +// It's a little weird to check these properties of a layout only when the +// layout is used in an op, since most of the properties don't actually depend +// on the op. They do depend on the *module*, though, and a layout is attached +// to a module only by virtue of being used in one of the module's ops. +LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) { + auto module = op->getParentOfType(); + auto checkLayout = [&](Value val, auto makeErr) -> LogicalResult { + // Only ranked tensors can have layouts. + auto rankedTy = dyn_cast(val.getType()); + if (!rankedTy) + return success(); + + mlir::Attribute layout = rankedTy.getEncoding(); + if (!layout) + return success(); + + if (isa(layout)) + return makeErr() << "Shared layout is not allowed on tensor type."; + // TODO(jlebar): Currently this only checks blocked layouts, but other + // layouts also have invariants! + + // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr. + if (auto blocked = dyn_cast(layout)) { + // A different verifier should have checked that the layout itself is + // valid, including that threads-per-warp has the same rank as + // warps-per-block etc. + auto layoutRank = blocked.getThreadsPerWarp().size(); + if (layoutRank != rankedTy.getRank()) { + return makeErr() << layout << ".\nLayout has rank " << layoutRank + << ", but the tensor it's attached to has rank " + << rankedTy.getRank() << "."; + } + + int moduleThreadsPerWarp = + ttg::TritonGPUDialect::getThreadsPerWarp(module); + int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp()); + if (layoutThreadsPerWarp != moduleThreadsPerWarp) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutThreadsPerWarp + << " threads per warp, but the module specifies " + << moduleThreadsPerWarp << " threads per warp."; + } + + int moduleWarpsPerCTA = ttg::TritonGPUDialect::getNumWarps(module); + int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA()); + if (layoutWarpsPerCTA != moduleWarpsPerCTA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutWarpsPerCTA + << " warps per CTA, but the module specifies " + << moduleWarpsPerCTA << " warps per CTA."; + } + + if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) { + int moduleCTAsPerCGA = ttg::TritonGPUDialect::getNumCTAs(module); + int64_t layoutCTAsPerCGA = + product(blocked.getCTALayout().getCTAsPerCGA()); + if (layoutCTAsPerCGA != moduleCTAsPerCGA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutCTAsPerCGA + << " CTAs per CGA, but the module specifies " + << moduleCTAsPerCGA << " CTAs per CGA."; + } + } + } + + return success(); + }; + + for (size_t i = 0; i < op->getNumOperands(); i++) { + auto operand = op->getOperand(i); + auto err = checkLayout(operand, [&]() { + // Stringify the operand using `printAsOperand`. This prints e.g. "%42" + // rather than the full definition. + std::string operandStr; + llvm::raw_string_ostream os(operandStr); + // If we don't assume verified, dump() will recursively call this + // function! + operand.printAsOperand(os, OpPrintingFlags().assumeVerified()); + + return op->emitError("Operand ") + << i << " (" << operand << ") has an invalid layout: "; + }); + if (!err.succeeded()) + return err; + } + + for (size_t i = 0; i < op->getNumResults(); i++) { + auto result = op->getResult(i); + auto err = checkLayout(result, [&]() { + if (op->getNumResults() == 1) { + return op->emitError("Result has an invalid layout: "); + } else { + return op->emitError("Result ") << i << " has an invalid layout: "; + } + }); + if (!err.succeeded()) + return err; + } + + return success(); +} + +static ArrayRef getTypeShape(Type type) { + auto rankedType = dyn_cast(type); + if (auto ptrType = dyn_cast(type)) + rankedType = dyn_cast(ptrType.getPointeeType()); + return rankedType ? rankedType.getShape() : ArrayRef(); +} + +LogicalResult OpTrait::impl::verifySameLoadStoreOperandsShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() << "requires the same shape for all operands"; + + return success(); +} + +LogicalResult +OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : op->getResultTypes()) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() + << "requires the same shape for all operands and results"; + + return verifySameLoadStoreOperandsShape(op); +} diff --git a/third_party/iluvatar/lib/Dialect/Triton/IR/Types.cpp b/third_party/iluvatar/lib/Dialect/Triton/IR/Types.cpp new file mode 100644 index 000000000..0e1df5b74 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/Triton/IR/Types.cpp @@ -0,0 +1,171 @@ +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void TritonDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + >(); +} + +Type PointerType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + Type pointeeType; + if (parser.parseType(pointeeType)) + return Type(); + + int addressSpace = 1; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseInteger(addressSpace)) + return Type(); + } + + if (parser.parseGreater()) + return Type(); + + return PointerType::get(pointeeType, addressSpace); +} + +void PointerType::print(AsmPrinter &printer) const { + if (getAddressSpace() == 1) { + printer << "<" << getPointeeType() << ">"; + } else { + printer << "<" << getPointeeType() << ", " << getAddressSpace() << ">"; + } +} + +static constexpr llvm::StringRef kMutableMemory = "mutable"; + +Type MemDescType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + SmallVector dimensions; + if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false)) + return Type(); + + // Parse the element type. + Type elementType; + if (parser.parseType(elementType)) + return Type(); + + Attribute encoding; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseAttribute(encoding)) + return Type(); + } + bool mutableMemory = false; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseOptionalKeyword(kMutableMemory)) + return Type(); + mutableMemory = true; + } + if (parser.parseGreater()) + return Type(); + + return MemDescType::get(parser.getContext(), dimensions, elementType, + encoding, mutableMemory); +} + +void MemDescType::print(AsmPrinter &printer) const { + printer << "<"; + for (auto dim : getShape()) + printer << dim << "x"; + printer << getElementType(); + if (getEncoding()) + printer << ", " << getEncoding(); + if (getMutableMemory()) + printer << ", " << kMutableMemory; + printer << ">"; +} + +namespace mlir { + +namespace triton { + +unsigned getPointeeBitWidth(Type type) { + auto pointeeType = getPointeeType(type); + if (auto tensorTy = dyn_cast(pointeeType)) + return tensorTy.getElementType().getIntOrFloatBitWidth(); + return pointeeType.getIntOrFloatBitWidth(); +} + +Type getI1SameShape(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorTy = dyn_cast(type)) + return RankedTensorType::get(tensorTy.getShape(), i1Type, + tensorTy.getEncoding()); + return i1Type; +} + +Type getPointeeType(Type type) { + if (auto tensorTy = dyn_cast(type)) { + // Tensor of pointers + auto shape = tensorTy.getShape(); + auto ptrType = dyn_cast(tensorTy.getElementType()); + Type pointeeType = ptrType.getPointeeType(); + return RankedTensorType::get(shape, pointeeType, tensorTy.getEncoding()); + } else if (auto ptrType = dyn_cast(type)) { + // scalar pointer + Type pointeeType = ptrType.getPointeeType(); + return pointeeType; + } + return type; +} + +Type getI32SameShape(Type type) { + auto i32Type = IntegerType::get(type.getContext(), 32); + if (auto tensorTy = dyn_cast(type)) + return RankedTensorType::get(tensorTy.getShape(), i32Type, + tensorTy.getEncoding()); + return i32Type; +} + +Type getPointerTypeSameShape(Type type) { + if (auto tensorTy = dyn_cast(type)) { + Type elementType = tensorTy.getElementType(); + auto shape = tensorTy.getShape(); + PointerType ptrType = PointerType::get(elementType, 1); + return RankedTensorType::get(shape, ptrType, tensorTy.getEncoding()); + } else { + return PointerType::get(type, 1); + } +} + +Type getPointerType(Type type) { return PointerType::get(type, 1); } + +bool isTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + return isa(ptrType.getPointeeType()); + return false; +} + +bool isTensorOrTensorPointerType(Type type) { + return isa(type) || isTensorPointerType(type); +} + +Type getElementTypeOfTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + if (auto tensorTy = dyn_cast(ptrType.getPointeeType())) + return tensorTy.getElementType(); + return {}; +} + +} // namespace triton + +} // namespace mlir diff --git a/third_party/iluvatar/lib/Dialect/Triton/Transforms/CMakeLists.txt b/third_party/iluvatar/lib/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 000000000..298398750 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,18 @@ +set(LLVM_TARGET_DEFINITIONS Combine.td) +mlir_tablegen(TritonCombine.inc -gen-rewriters) +add_public_tablegen_target(TritonCombineIncGen) + +add_triton_library(TritonTransforms + Combine.cpp + ReorderBroadcast.cpp + RewriteTensorPointer.cpp + + DEPENDS + TritonTransformsIncGen + TritonCombineIncGen + + LINK_LIBS PUBLIC + MLIRPass + MLIRTransformUtils + TritonIR +) diff --git a/third_party/iluvatar/lib/Dialect/Triton/Transforms/Combine.cpp b/third_party/iluvatar/lib/Dialect/Triton/Transforms/Combine.cpp new file mode 100644 index 000000000..f905b521b --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/Triton/Transforms/Combine.cpp @@ -0,0 +1,257 @@ +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +#define GEN_PASS_CLASSES +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace mlir::triton { +namespace { + +bool isZero(Value val) { + if (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat())) + return true; + // broadcast(constant_0) + if (auto bc = val.getDefiningOp()) { + if (matchPattern(bc.getSrc(), m_Zero()) || + matchPattern(bc.getSrc(), m_AnyZeroFloat())) + return true; + } + return false; +} + +bool isBroadcastConstantCombinable(Attribute value) { + if (auto denseValue = dyn_cast(value)) { + return denseValue.isSplat(); + } + return isa(value); +} + +DenseElementsAttr getConstantValue(Builder &builder, Attribute value, + Value bcast_res) { + auto resType = cast(bcast_res.getType()); + DenseElementsAttr res; + if (auto denseValue = dyn_cast(value)) { + res = + DenseElementsAttr::get(resType, denseValue.getSplatValue()); + } else { + res = DenseElementsAttr::get(resType, value); + } + return res; +} + +bool isAddPtrOffsetCombinable(Value first, Value second) { + auto GetConstantIntValue = [](Value val) -> std::optional { + DenseElementsAttr constAttr; + auto defOp = val.getDefiningOp(); + if (defOp) { + if (auto splatOp = llvm::dyn_cast(defOp)) + val = splatOp.getSrc(); + else if (matchPattern(defOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto attr = constAttr.getSplatValue(); + // Check IntegerAttr + if (auto intAttr = dyn_cast_or_null(attr)) + return intAttr.getValue(); + } + } + + // Check constant value. + llvm::APInt intVal; + if (matchPattern(val, m_ConstantInt(&intVal))) + return intVal; + + return std::nullopt; + }; + + if (first.getType() == second.getType()) { + // Whether bitwidth of element type is equal to pointer + if (getElementTypeOrSelf(first.getType()).getIntOrFloatBitWidth() == 64) + return true; + + // first + second does not overflow + auto firstVal = GetConstantIntValue(first); + auto secondVal = GetConstantIntValue(second); + if (firstVal && secondVal) { + bool overflow = false; + auto resVal = firstVal->sadd_ov(*secondVal, overflow); + return !overflow; + } + } + return false; +} + +// TODO(csigg): remove after next LLVM integrate. +using FastMathFlags = arith::FastMathFlags; + +#include "TritonCombine.inc" + +// select(cond, load(ptrs, splat(cond), ???), other) +// => load(ptrs, splat(cond), other) +class CombineSelectMaskedLoadPattern : public RewritePattern { +public: + CombineSelectMaskedLoadPattern(MLIRContext *context) + : RewritePattern(arith::SelectOp::getOperationName(), 3, context, + {LoadOp::getOperationName()}) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto selectOp = llvm::dyn_cast(op); + if (!selectOp) + return failure(); + + Value trueValue = selectOp.getTrueValue(); + Value falseValue = selectOp.getFalseValue(); + Value condSelect = selectOp.getCondition(); + + auto *loadOpCandidate = trueValue.getDefiningOp(); + auto loadOp = llvm::dyn_cast_or_null(loadOpCandidate); + if (!loadOp) + return failure(); + + Value mask = loadOp.getMask(); + if (!mask) + return failure(); + + auto *splatOpCandidate = mask.getDefiningOp(); + auto splatOp = llvm::dyn_cast_or_null(splatOpCandidate); + if (!splatOp) + return failure(); + + auto splatCond = splatOp.getSrc(); + if (splatCond != condSelect) + return failure(); + + rewriter.replaceOpWithNewOp( + op, loadOp.getPtr(), loadOp.getMask(), /*other=*/falseValue, + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile(), + loadOp.getInputStride(), loadOp.getInputStride(), + loadOp.getInputStride()); + return success(); + } +}; + +// sum(x[:, :, None] * y[None, :, :], 1) +// -> dot(x, y) +class CombineBroadcastMulReducePattern : public RewritePattern { +private: + static bool isAddF32(const Operation *op) { + if (auto addf = dyn_cast_or_null(op)) + return addf.getType().getIntOrFloatBitWidth() <= 32; + return false; + } + + static SmallVector getEqualIndices(ArrayRef x, + ArrayRef y) { + SmallVector res; + for (int i = 0; i < x.size(); ++i) + if (x[i] == y[i]) + res.push_back(i); + return res; + } + +public: + CombineBroadcastMulReducePattern(MLIRContext *context) + : RewritePattern(ReduceOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + auto reduceOp = llvm::dyn_cast(op); + if (!reduceOp) + return failure(); + // only support reduce with simple addition + Region &combineOp = reduceOp.getCombineOp(); + bool isReduceAdd = combineOp.hasOneBlock() && + combineOp.front().getOperations().size() == 2 && + isAddF32(&*combineOp.front().getOperations().begin()); + if (!isReduceAdd) + return failure(); + // operand of reduce has to be mul + auto mulOp = llvm::dyn_cast_or_null( + reduceOp.getOperand(0).getDefiningOp()); + if (!mulOp) + return failure(); + // mul operand has to be broadcast + auto broadcastLhsOp = llvm::dyn_cast_or_null( + mulOp.getOperand(0).getDefiningOp()); + if (!broadcastLhsOp) + return failure(); + auto broadcastRhsOp = llvm::dyn_cast_or_null( + mulOp.getOperand(1).getDefiningOp()); + if (!broadcastRhsOp) + return failure(); + // broadcast operand is expand dims + auto expandLhsOp = llvm::dyn_cast_or_null( + broadcastLhsOp.getSrc().getDefiningOp()); + if (!expandLhsOp) + return failure(); + auto expandRhsOp = llvm::dyn_cast_or_null( + broadcastRhsOp.getSrc().getDefiningOp()); + if (!expandRhsOp) + return failure(); + // get not-broadcast dimensions + int expandLhsAxis = expandLhsOp.getAxis(); + int expandRhsAxis = expandRhsOp.getAxis(); + if (expandLhsAxis != 2 || expandRhsAxis != 0) + return failure(); + auto broadcastLhsShape = + cast(broadcastLhsOp.getType()).getShape(); + auto broadcastRhsShape = + cast(broadcastLhsOp.getType()).getShape(); + if (broadcastLhsShape[2] < 16 || broadcastRhsShape[0] < 16) + return failure(); + Type newAccType = RankedTensorType::get( + {broadcastLhsShape[0], broadcastRhsShape[2]}, + cast(broadcastLhsOp.getSrc().getType()).getElementType()); + rewriter.setInsertionPoint(op); + auto newAcc = rewriter.create( + op->getLoc(), newAccType, + rewriter.create(op->getLoc(), + rewriter.getF32FloatAttr(0))); + rewriter.replaceOpWithNewOp(op, expandLhsOp.getSrc(), + expandRhsOp.getSrc(), newAcc, + InputPrecision::TF32, 0); + return success(); + } +}; + +class CombineOpsPass : public TritonCombineOpsBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + + // Dot Add %{ + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + // %} + patterns.add(context); + // patterns.add(context); + patterns.add(context); + patterns.add(context); + + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // anonymous namespace + +std::unique_ptr createCombineOpsPass() { + return std::make_unique(); +} + +} // namespace mlir::triton diff --git a/third_party/iluvatar/lib/Dialect/Triton/Transforms/Combine.td b/third_party/iluvatar/lib/Dialect/Triton/Transforms/Combine.td new file mode 100644 index 000000000..2776ad5b0 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/Triton/Transforms/Combine.td @@ -0,0 +1,54 @@ +#ifndef TRITON_PATTERNS +#define TRITON_PATTERNS + +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "triton/Dialect/Triton/IR/TritonOps.td" +include "mlir/IR/PatternBase.td" + + +// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) + +// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) +// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) +def CombineDotAddIPattern : Pat< + (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $overflow), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; +def CombineDotAddFPattern : Pat< + (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; + +def CombineDotAddIRevPattern : Pat< + (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $overflow), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; +def CombineDotAddFRevPattern : Pat< + (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; + +// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1)) +// Note: leave (sub %c0, %c0) canceling to ArithDialect +// (ref: ArithCanonicalization.td) +// defvar DefOverflow = ConstantEnumCase; +// def CombineAddPtrPattern : Pat< +// (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1), +// (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)), +// [(Constraint> $idx0, $idx1)]>; + +// broadcast(cst) => cst +def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">; +def CombineBroadcastConstantPattern : Pat< + (TT_BroadcastOp:$bcast_res (Arith_ConstantOp $value)), + (Arith_ConstantOp (getConstantValue $value, $bcast_res), (location $bcast_res)), + [(Constraint> $value)]>; + +#endif diff --git a/third_party/iluvatar/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp b/third_party/iluvatar/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp new file mode 100644 index 000000000..43479a3d9 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp @@ -0,0 +1,247 @@ +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +// TODO(jlebar): Move this and all other generatede code into namespace +// mlir::triton. +#define GEN_PASS_DEF_TRITONREORDERBROADCAST +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace mlir::triton { +namespace { + +Operation *cloneWithNewArgsAndResultTypes(PatternRewriter &rewriter, + Operation *op, ValueRange newOperands, + TypeRange newTypes) { + OperationState newElementwiseState(op->getLoc(), op->getName()); + newElementwiseState.addOperands(newOperands); + newElementwiseState.addTypes(newTypes); + newElementwiseState.addAttributes(op->getAttrs()); + return rewriter.create(newElementwiseState); +} + +bool isSplat(Operation *op) { + if (auto splatOp = llvm::dyn_cast(op)) { + return true; + } + DenseElementsAttr constAttr; + return (matchPattern(op, m_Constant(&constAttr)) && constAttr.isSplat()); +} + +// elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) +struct MoveSplatAfterElementwisePattern + : public OpTraitRewritePattern { + + MoveSplatAfterElementwisePattern(MLIRContext *context) + : OpTraitRewritePattern(context) {} + + LogicalResult match(Operation *op) const override { + if (!isMemoryEffectFree(op)) { + return failure(); + } + + for (auto operand : op->getOperands()) { + auto definingOp = operand.getDefiningOp(); + if (!definingOp) + return failure(); + + if (!isSplat(definingOp)) { + return failure(); + } + } + return success(op->getNumOperands() > 0); + } + + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto operands = op->getOperands(); + + llvm::SmallVector scalarOperands(operands.size()); + for (unsigned iOp = 0; iOp < operands.size(); ++iOp) { + auto definingOp = operands[iOp].getDefiningOp(); + + DenseElementsAttr constAttr; + if (auto splatOp = llvm::dyn_cast(definingOp)) { + scalarOperands[iOp] = splatOp.getSrc(); + } else if (matchPattern(definingOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto value = constAttr.getSplatValue(); + scalarOperands[iOp] = arith::ConstantOp::materialize( + rewriter, value, constAttr.getElementType(), loc); + } else { + llvm_unreachable("Expected a splat"); + } + } + + auto resultTypes = op->getResultTypes(); + llvm::SmallVector scalarResultTys; + for (auto resultTy : resultTypes) { + auto elemTy = dyn_cast(resultTy).getElementType(); + scalarResultTys.push_back(elemTy); + } + + auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, scalarOperands, + scalarResultTys); + + for (unsigned iRes = 0; iRes < resultTypes.size(); ++iRes) { + auto newResult = rewriter.create(loc, resultTypes[iRes], + newOp->getResult(iRes)); + rewriter.replaceAllUsesWith(op->getResult(iRes), newResult); + } + } +}; + +// elementwise(broadcast(a)) => broadcast(elementwise(a)) +// This also generalizes to multiple arguments when the rest are splat-like +// Not handled: multiple broadcasted arguments +struct MoveBroadcastAfterElementwisePattern + : public OpTraitRewritePattern { + + MoveBroadcastAfterElementwisePattern(MLIRContext *context) + : OpTraitRewritePattern(context) {} + + LogicalResult match(Operation *op) const override { + if (!isMemoryEffectFree(op)) { + return failure(); + } + + auto operands = op->getOperands(); + bool seenBroadcast = false; + ArrayRef srcShape; + for (auto operand : operands) { + auto definingOp = operand.getDefiningOp(); + if (!definingOp) { + return failure(); + } + auto getSrcShape = [](BroadcastOp b) { + return b.getSrc().getType().getShape(); + }; + if (auto broadcastOp = llvm::dyn_cast(definingOp)) { + if (!seenBroadcast) { + seenBroadcast = true; + srcShape = getSrcShape(broadcastOp); + } else if (srcShape != getSrcShape(broadcastOp)) { + // If the broadcast have different types we cannot re-order. + return failure(); + } + } else if (!isSplat(definingOp)) { + // Not splat or broadcast + return failure(); + } + } + return success(seenBroadcast); + } + + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // Find broadcast op + auto operands = op->getOperands(); + BroadcastOp broadcastOp; + for (auto operand : operands) { + broadcastOp = operand.getDefiningOp(); + if (broadcastOp) { + break; + } + } + + auto srcTy = broadcastOp.getSrc().getType(); + auto srcShape = srcTy.getShape(); + auto srcEncoding = srcTy.getEncoding(); + + // Reshape operands to match srcShape + llvm::SmallVector newOperands; + for (auto operand : operands) { + auto definingOp = operand.getDefiningOp(); + if (auto broadcastSrcOp = llvm::dyn_cast(definingOp)) { + newOperands.push_back(broadcastSrcOp.getSrc()); + continue; + } + auto elemTy = + dyn_cast(operand.getType()).getElementType(); + auto newTy = RankedTensorType::get(srcShape, elemTy, srcEncoding); + if (auto splatOp = llvm::dyn_cast(definingOp)) { + auto newSplat = rewriter.create(loc, newTy, splatOp.getSrc()); + newOperands.push_back(newSplat); + continue; + } + DenseElementsAttr constAttr; + if (matchPattern(definingOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto scalarValue = constAttr.getSplatValue(); + auto splatValue = SplatElementsAttr::get(newTy, scalarValue); + auto newConstant = + rewriter.create(loc, newTy, splatValue); + newOperands.push_back(newConstant); + continue; + } + llvm_unreachable("Expected broadcast or splat"); + } + + // Reshape results to match srcShape + llvm::SmallVector newResultTypes; + auto resultTypes = op->getResultTypes(); + for (auto resultTy : resultTypes) { + auto elemTy = dyn_cast(resultTy).getElementType(); + newResultTypes.push_back( + RankedTensorType::get(srcShape, elemTy, srcEncoding)); + } + + // Create new op and broadcast results + auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, newOperands, + newResultTypes); + for (unsigned iRes = 0; iRes < newResultTypes.size(); ++iRes) { + auto newResult = rewriter.create(loc, resultTypes[iRes], + newOp->getResult(iRes)); + rewriter.replaceAllUsesWith(op->getResult(iRes), newResult); + } + } +}; + +template +class CanonicalizePattern : public OpRewritePattern { +public: + explicit CanonicalizePattern(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(OpType op, + PatternRewriter &rewriter) const override { + return OpType::canonicalize(op, rewriter); + } +}; + +class ReorderBroadcastPass + : public ::impl::TritonReorderBroadcastBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + + patterns.add>(context); + patterns.add>(context); + // elementwise(broadcast(a)) => broadcast(elementwise(a)) + patterns.add(context); + // elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) + patterns.add(context); + + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr createReorderBroadcastPass() { + return std::make_unique(); +} + +} // namespace mlir::triton diff --git a/third_party/iluvatar/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/third_party/iluvatar/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp new file mode 100644 index 000000000..27b22b99e --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -0,0 +1,601 @@ +#include +#include + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace { + +/// An additional struct to record the meta information of operations +/// with tensor pointers +struct RewritedInfo { +private: + Value base; + SmallVector shape; + SmallVector strides; + SmallVector offsets; + ArrayRef tensorShape; + ArrayRef order; + + // A cache to avoid generating the same offset with range + DenseMap cachedOffsetWithRange; + +public: + RewritedInfo() = default; + + RewritedInfo(const RewritedInfo &other) = default; + + RewritedInfo(Value base, const SmallVector &shape, + const SmallVector &strides, + const SmallVector &offsets, + const ArrayRef &tensorShape, + const ArrayRef &order) + : base(base), shape(shape), strides(strides), offsets(offsets), + tensorShape(tensorShape), order(order) { + assert(shape.size() == strides.size() && shape.size() == offsets.size() && + shape.size() == tensorShape.size() && shape.size() == order.size()); + } + + unsigned int length() const { return shape.size(); } + + Value getOffset(unsigned i) { return offsets[i]; } + + SmallVector getOffsets() { return offsets; } + +#if defined(__ILUVATAR__) + Value getContiguousStride() { + if (strides.size() == 2) + return strides[order[1]]; + return NULL; + } +#endif + + void setOffset(unsigned i, Value newOffset) { + offsets[i] = newOffset; + cachedOffsetWithRange.clear(); + } + + void setOffsets(const SmallVector &newOffsets) { + offsets = newOffsets; + cachedOffsetWithRange.clear(); + } + + Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc, + unsigned i) { + if (cachedOffsetWithRange.count(i)) + return cachedOffsetWithRange[i]; + + // Add range + auto indexI32RowType = + RankedTensorType::get({tensorShape[i]}, builder.getI32Type()); + auto indexRowType = + RankedTensorType::get({tensorShape[i]}, builder.getI64Type()); + Value splatOffset = + builder.create(loc, indexRowType, offsets[i]); + Value range = builder.create(loc, indexI32RowType, 0, + tensorShape[i]); + Value i64Range = builder.create(loc, indexRowType, range); + + // Expand dimensions + Value expandedResult = + builder.create(loc, splatOffset, i64Range); + for (int j = 0; j < tensorShape.size(); ++j) { + if (j == i) + continue; + expandedResult = + builder.create(loc, expandedResult, j); + } + + return cachedOffsetWithRange[i] = expandedResult; + } + + Value generatePtr(OpBuilder &builder, const Location &loc) { + assert(tensorShape.size() == offsets.size() && + tensorShape.size() == strides.size()); + auto indexTensorType = + RankedTensorType::get(tensorShape, builder.getI64Type()); + auto ptrType = cast(base.getType()); + auto ptrTensorType = RankedTensorType::get(tensorShape, ptrType); + + // Generate offsets per dimension + Value ptr = builder.create(loc, ptrTensorType, base); + for (unsigned i = 0; i < tensorShape.size(); ++i) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + + // We must splat strides into the expanded shape not a row for retaining + // the divisibility information given by strides + Value splatStride = builder.create( + loc, offsetWithRange.getType(), strides[i]); + Value offsetWithStride = + builder.create(loc, offsetWithRange, splatStride); + Value broadcasted = builder.create( + loc, indexTensorType, offsetWithStride); + + // Add to the pointer + ptr = builder.create(loc, ptrTensorType, ptr, + broadcasted); + } + + return ptr; + } + + Value generateMask(OpBuilder &builder, const Location &loc, + const std::optional> &boundaryCheck) { + if (!boundaryCheck.has_value()) + return {}; + + // Generate mask per dimension + auto maskTensorType = + RankedTensorType::get(tensorShape, builder.getI1Type()); + Value mask; + for (auto i : boundaryCheck.value()) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + + // Compare with lower bound + Value lowerBound = builder.create( + loc, 0, builder.getI64Type()); + Value splatLowerBound = builder.create( + loc, offsetWithRange.getType(), lowerBound); + Value cmpLower = builder.create( + loc, arith::CmpIPredicate::sge, offsetWithRange, splatLowerBound); + + // Compare with upper bound + Value splatUpperBound = builder.create( + loc, offsetWithRange.getType(), shape[i]); + Value cmpUpper = builder.create( + loc, arith::CmpIPredicate::slt, offsetWithRange, splatUpperBound); + + // And and broadcast + Value andResult = builder.create(loc, cmpLower, cmpUpper); + Value broadcasted = + builder.create(loc, maskTensorType, andResult); + + // And up all results + if (!mask) { + mask = broadcasted; + } else { + mask = builder.create(loc, mask, broadcasted); + } + } + + return mask; + } + + Value generateOther(OpBuilder &builder, const Location &loc, + const std::optional &padding) { + if (!padding.has_value()) + return Value(); + + // Create element attribute + auto elementType = + cast(base.getType()).getPointeeType(); + auto otherTensorType = RankedTensorType::get(tensorShape, elementType); + + // Set zero padding value + TypedAttr attr = + elementType.isIntOrIndex() + ? cast(builder.getIntegerAttr(elementType, 0)) + : cast(builder.getFloatAttr(elementType, 0)); + + // Float NaN padding case + if (padding.value() == triton::PaddingOption::PAD_NAN) { + assert(!elementType.isIntOrIndex()); + auto apNaN = llvm::APFloat::getNaN( + cast(attr).getValue().getSemantics()); + attr = builder.getFloatAttr(elementType, apNaN); + } + + // Create tensor + Value constant = builder.create(loc, attr); + return builder.create(loc, otherTensorType, constant); + } +}; + +} // namespace + +// TODO: this pass relies on assumptions of how block pointers are created and +// on pattern matches that walks the SSA links to find the base/strides. This is +// very fragile and to solve we should expose convert Ptr of tensor to a +// structure containins all values and not only offsets. +class RewriteTensorPointerPass + : public TritonRewriteTensorPointerBase { +private: + DenseMap rewritedInfo; + +public: + static bool needRewrite(Operation *op) { + return std::any_of(op->getOperands().begin(), op->getOperands().end(), + [](Value operand) { + return triton::isTensorPointerType(operand.getType()); + }); + } + + static SmallVector + generateNewOperands(const SmallVector &oldOperands, unsigned index, + const SmallVector &newValues) { + assert(index < oldOperands.size()); + SmallVector newOperands; + for (int i = 0; i < index; ++i) + newOperands.push_back(oldOperands[i]); + for (auto value : newValues) + newOperands.push_back(value); + for (auto i = index + 1; i < oldOperands.size(); ++i) + newOperands.push_back(oldOperands[i]); + return newOperands; + } + + Operation *rewriteMakeTensorPtrOp(OpBuilder &builder, + triton::MakeTensorPtrOp op, + std::stack &eraser) { + // Save info for later use + auto ptrType = cast(op.getType()); + auto tensorType = cast(ptrType.getPointeeType()); + + // Cast I32 offsets into I64 + SmallVector i64Offsets; + for (auto offset : op.getOffsets()) { + auto i64Offset = builder.create( + op.getLoc(), builder.getI64Type(), offset); + i64Offsets.push_back(i64Offset); + } + + // Save information + rewritedInfo[op.getResult()] = + RewritedInfo(op.getBase(), op.getShape(), op.getStrides(), i64Offsets, + tensorType.getShape(), op.getOrderAttr()); + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteAdvanceOp(OpBuilder &builder, triton::AdvanceOp op, + std::stack &eraser) { + // Get info from previous results + assert(rewritedInfo.count(op.getPtr())); + auto info = rewritedInfo[op.getPtr()]; + + // Calculate new offsets + assert(info.length() == op.getOffsets().size()); + SmallVector newOffsets; + for (int i = 0; i < info.length(); ++i) { + Value i64Offset = builder.create( + op.getLoc(), builder.getI64Type(), op.getOffsets()[i]); + Value newOffset = builder.create( + op.getLoc(), info.getOffset(i), i64Offset); + newOffsets.push_back(newOffset); + } + + // Save info for later use + info.setOffsets(newOffsets); + rewritedInfo[op.getResult()] = info; + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteLoadStoreOp(OpBuilder &builder, Operation *op, + std::stack &eraser) { + assert(isa(op) || isa(op)); + + // We only have to rewrite load/stores with tensor pointers + auto ptr = op->getOperand(0); + if (!triton::isTensorPointerType(ptr.getType())) + return nullptr; + + // Get info from previous results + assert(rewritedInfo.count(ptr)); + auto info = rewritedInfo[ptr]; + + // Load/store with tensor pointers implicitly will check the bound while + // accessing memory, so we should set `mask` and `other` (according to the + // padding). Also note that load with tensor pointers do not have `mask` and + // `other` while building IR from Python AST + std::optional> boundaryCheck; + if (auto loadOp = dyn_cast(op)) { + assert(!loadOp.getMask() && !loadOp.getOther()); + boundaryCheck = loadOp.getBoundaryCheck(); + } else if (auto storeOp = dyn_cast(op)) { + assert(!storeOp.getMask()); + boundaryCheck = storeOp.getBoundaryCheck(); + } + + // Generate new `ptr`, `mask` and `other` + auto newPtr = info.generatePtr(builder, op->getLoc()); + auto newMask = info.generateMask(builder, op->getLoc(), boundaryCheck); + Value newOther; + if (auto loadOp = dyn_cast(op)) + newOther = info.generateOther(builder, op->getLoc(), loadOp.getPadding()); + + // Create a new operation + if (auto loadOp = dyn_cast(op)) { +#if defined(__ILUVATAR__) + Value newResult; + Value resStride = info.getContiguousStride(); + if (!newMask && !newOther && resStride) { + Value matStride = builder.create( + loadOp.getLoc(), builder.getI32Type(), resStride); + newResult = builder.create( + loadOp.getLoc(), loadOp.getResult().getType(), newPtr, newMask, + newOther, loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile(), + matStride, matStride, matStride); + } else { + newResult = builder.create( + loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(), + loadOp.getEvict(), loadOp.getIsVolatile()); + } + op->getResult(0).replaceAllUsesWith(newResult); +#else + auto newResult = builder.create( + loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(), + loadOp.getEvict(), loadOp.getIsVolatile()); + op->getResult(0).replaceAllUsesWith(newResult); +#endif + } else if (auto storeOp = dyn_cast(op)) { + builder.create(storeOp.getLoc(), newPtr, + storeOp.getValue(), newMask, + storeOp.getCache(), storeOp.getEvict()); + } + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteIfOp(OpBuilder &builder, scf::IfOp op, + std::stack &eraser) { + auto thenYieldOp = op.thenYield(); + assert(op.getNumResults() == thenYieldOp.getNumOperands()); + SmallVector results = thenYieldOp.getOperands(); + + // get new result types + SmallVector newRetTypes; + bool needRewrite = false; + for (unsigned i = 0; i < results.size(); ++i) { + if (!triton::isTensorPointerType(results[i].getType())) { + newRetTypes.push_back(results[i].getType()); + continue; + } + needRewrite = true; + auto makeTensorPtrOp = getMakeTensorPtrOp(results[i]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + auto info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + newRetTypes.push_back(builder.getI64Type()); + } + } + if (!needRewrite) + return op; + // create and clone new IfOp + bool hasElse = !op.getElseRegion().empty(); + scf::IfOp newOp = builder.create(op.getLoc(), newRetTypes, + op.getCondition(), hasElse); + IRMapping mapping; + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + mapping.map(op->getOperand(i), newOp->getOperand(i)); + } + auto rematerialize = [&](Block *block) { + for (Operation &opInIf : block->getOperations()) { + auto newOp = builder.clone(opInIf, mapping); + } + }; + builder.setInsertionPointToStart(newOp.thenBlock()); + rematerialize(op.thenBlock()); + if (hasElse) { + builder.setInsertionPointToStart(newOp.elseBlock()); + rematerialize(op.elseBlock()); + } + + // update rewritedInfo + unsigned oldResIdx = 0, newResIdx = 0; + while (oldResIdx < results.size()) { + if (!triton::isTensorPointerType(results[oldResIdx].getType())) { + oldResIdx++; + newResIdx++; + } else { + auto makeTensorPtrOp = getMakeTensorPtrOp(results[oldResIdx]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + auto info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + info.setOffset(j, newOp->getResult(newResIdx++)); + } + rewritedInfo[op.getResult(oldResIdx)] = info; + oldResIdx++; + } + } + + eraser.push(op); + return newOp; + } + + Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op, + std::stack &eraser) { + // Generate new iteration operands and set rewrited information + SmallVector oldIterOperands = llvm::to_vector(op.getInitArgs()); + SmallVector newIterOperands = llvm::to_vector(op.getInitArgs()); + for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size; + ++i, ++oldI) { + if (!triton::isTensorPointerType(newIterOperands[i].getType())) + continue; + + // Expand the tensor pointer into offsets + assert(rewritedInfo.count(newIterOperands[i])); + auto info = rewritedInfo[newIterOperands[i]]; + newIterOperands = + generateNewOperands(newIterOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + + // Rebuild the loop type + auto newForOp = builder.create(op.getLoc(), op.getLowerBound(), + op.getUpperBound(), op.getStep(), + newIterOperands); + + // Create value mapping. Note that for tensor pointers, we use identity + // mapping. It may refer to a value in the old loop, but we will rewrite it + // later + IRMapping mapping; + for (unsigned i = 0, oldI = 0, sz = op.getInitArgs().size(); oldI < sz; + ++i, ++oldI) { + auto oldRegionIterArg = op.getRegionIterArg(oldI); + if (triton::isTensorPointerType(oldRegionIterArg.getType())) { + // Pass rewrited info inside + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + mapping.map(oldRegionIterArg, oldRegionIterArg); + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getRegionIterArg(i + j)); + rewritedInfo[oldRegionIterArg] = info; + i += info.length() - 1; + } else { + mapping.map(oldRegionIterArg, newForOp.getRegionIterArg(i)); + } + } + mapping.map(op.getInductionVar(), newForOp.getInductionVar()); + + // Clone body + builder.setInsertionPointToStart(newForOp.getBody()); + for (auto &opInFor : *op.getBody()) { + auto *newOp = builder.clone(opInFor, mapping); + for (unsigned i = 0; i < opInFor.getNumResults(); ++i) + mapping.map(op->getResult(i), newOp->getResult(i)); + } + + // Replace later usages + assert(op.getNumResults() == op.getInitArgs().size()); + for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { + auto oldResult = op.getResult(oldI); + if (triton::isTensorPointerType(oldResult.getType())) { + // Pack new offsets into rewrited info + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getResult(i + j)); + i += info.length() - 1; + rewritedInfo[oldResult] = info; + } else { + oldResult.replaceAllUsesWith(newForOp.getResult(i)); + } + } + + // Erase later + eraser.push(op); + return newForOp; + } + + Operation *rewriteYieldOp(OpBuilder &builder, scf::YieldOp op, + std::stack &eraser) { + // Replace tensor pointers with offsets + SmallVector newOperands = op->getOperands(); + for (unsigned i = 0, size = op.getNumOperands(); i < size; ++i) { + if (!triton::isTensorPointerType(newOperands[i].getType())) + continue; + + assert(rewritedInfo.count(newOperands[i])); + auto info = rewritedInfo[newOperands[i]]; + newOperands = generateNewOperands(newOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + op->setOperands(newOperands); + + // No need to erase + return nullptr; + } + + Operation *rewriteOp(Operation *op, std::stack &eraser) { + OpBuilder builder(op); + + // Rewrite `make_tensor_ptr` and `advance` and make a tensor of pointers + // Rewriting functions return the next operation to visit, if there is no + // next one, simply return `nullptr` + std::pair rewrited; + if (auto makeTensorPtrOp = dyn_cast(op)) { + return rewriteMakeTensorPtrOp(builder, makeTensorPtrOp, eraser); + } else if (auto advanceOp = dyn_cast(op)) { + return rewriteAdvanceOp(builder, advanceOp, eraser); + } else if (isa(op) || isa(op)) { + return rewriteLoadStoreOp(builder, op, eraser); + } else if (op->getDialect()->getNamespace() == "scf" || + op->getDialect()->getNamespace() == "cf") { + if (auto ifOp = dyn_cast(op)) { + return rewriteIfOp(builder, ifOp, eraser); + } + if (!needRewrite(op)) + return op; + + if (auto forOp = dyn_cast(op)) { + return rewriteForOp(builder, forOp, eraser); + } else if (auto yieldOp = dyn_cast(op)) { + return rewriteYieldOp(builder, yieldOp, eraser); + } else { + llvm_unreachable("Currently we only support tensor pointer usages " + "inside a `scf::ForOp` or `scf::IfOp`, others such as " + "`scf::WhileOp`, `cf::BranchOp` or `cf::CondBranchOp` " + "are not supported yet"); + } + } + + // Otherwise return the original one + return op; + } + + void visitOperation(Operation *op, std::stack &eraser) { + for (auto ®ion : op->getRegions()) { + for (auto &block : region) { + // We need an extra copy because erasing operations may break the + // iterator behavior + SmallVector blockCopy; + for (auto &nestedOp : block) + blockCopy.push_back(&nestedOp); + + // Rewrite and recursively visit + for (auto &nestedOp : blockCopy) { + if (auto newOp = rewriteOp(nestedOp, eraser)) + visitOperation(newOp, eraser); + } + } + } + } + + void runOnOperation() override { + // NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because + // MLIR does not support one-multiple value mapping. For example, if we use + // `ConversionPatternRewriter`, we can not make a type converter, which + // converts `ptr` into multiple types `ptr<>, int64, int64, ...` + // (containing the base/offsets/strides...). What we can do is to convert + // `ptr` into a single type `Tuple, int64, int64, ...>`. But + // in this way, we also have to define `PackTuple` and `UnpackTuple` + // operations and make a canonicalization pass to optimize, which is much + // So here we recursively build the IR, to be specific, we have to rewrite + // `tt.make_tensor_ptr`, `tt.advance`, `tt.load`, `tt.store`, + // `scf.for` (tensor pointer usages may be in a loop fashion) + std::stack eraser; + visitOperation(getOperation(), eraser); + + // The operation could not be erased during visit, because they may have + // later usages, so we erase after visit + rewritedInfo.clear(); + while (!eraser.empty()) { + auto op = eraser.top(); + eraser.pop(); + op->erase(); + } + } +}; + +std::unique_ptr triton::createRewriteTensorPointerPass() { + return std::make_unique(); +} diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/CMakeLists.txt b/third_party/iluvatar/lib/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/third_party/iluvatar/lib/Dialect/TritonGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..b5dcdb5ea --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_triton_library(TritonGPUIR + Dialect.cpp + LinearLayoutConversions.cpp + Types.cpp + + DEPENDS + TritonGPUTableGen + TritonGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRGPUDialect + TritonIR + TritonTools +) diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/IR/Dialect.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/IR/Dialect.cpp new file mode 100644 index 000000000..f191f3580 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -0,0 +1,3352 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefsPlugin.h" +#include "triton/Tools/StrUtil.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/TypeSwitch.h" + +// Include TableGen'erated code +#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// Utility +namespace mlir { +namespace triton { + +static Type getI1SameShapeFromTensorOrTensorPtr(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorType = dyn_cast(type)) { + return RankedTensorType::get(tensorType.getShape(), i1Type, + tensorType.getEncoding()); + } else if (auto ptrType = dyn_cast(type)) { + Type pointeeType = ptrType.getPointeeType(); + if (auto tensorType = dyn_cast(pointeeType)) { + return RankedTensorType::get(tensorType.getShape(), i1Type, + tensorType.getEncoding()); + } + } + return Type(); +} + +namespace gpu { + +// TODO: Inheritance of layout attributes +// so that all distributed layouts implement +// these utilities + +unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape, + Type eltTy) { + if (auto tritonGPUAttr = mlir::dyn_cast(layout)) { + return tritonGPUAttr.getTotalElemsPerThread(shape, eltTy); + } else { + llvm::report_fatal_error("getTotalElemsPerThread not implemented"); + return 0; + } +} + +SmallVector getElemsPerThread(Attribute layout, + ArrayRef shape, Type eltTy) { + if (auto tritonGPUAttr = mlir::dyn_cast(layout)) { + return tritonGPUAttr.getElemsPerThread(shape, eltTy); + } else { + llvm::report_fatal_error("getElemsPerThread not implemented"); + return SmallVector(); + } +} + +SmallVector getElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || isa(type)) + return SmallVector(1, 1); + auto tensorType = cast(type); + return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape(), + tensorType.getElementType()); +} + +unsigned getTotalElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || isa(type)) + return 1; + auto tensorType = cast(type); + return getTotalElemsPerThread(tensorType.getEncoding(), tensorType.getShape(), + tensorType.getElementType()); +} + +SmallVector getThreadsPerWarp(Attribute layout) { + if (auto distributedLayout = dyn_cast(layout)) { + return distributedLayout.getThreadsPerWarp(); + } else { + llvm::report_fatal_error("getThreadsPerWarp not implemented"); + return SmallVector(); + } +} + +unsigned getWarpSize(Attribute layout) { + unsigned size = 1; + auto threadsPerWarp = getThreadsPerWarp(layout); + for (auto e : threadsPerWarp) { + size *= e; + } + return size; +} + +SmallVector +getThreadsPerWarpWithUniqueData(Attribute layout, + ArrayRef tensorShape) { + if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(tensorShape); + auto parentThreadsPerWarp = + getThreadsPerWarpWithUniqueData(parentLayout, parentShape); + SmallVector threadsPerWarp = parentThreadsPerWarp; + threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim()); + return threadsPerWarp; + } + auto threadsPerWarp = getThreadsPerWarp(layout); + assert(threadsPerWarp.size() == tensorShape.size() && + "layout and tensor shape must have the same rank"); + for (unsigned i = 0; i < threadsPerWarp.size(); i++) { + threadsPerWarp[i] = std::min(threadsPerWarp[i], tensorShape[i]); + } + + return threadsPerWarp; +} + +SmallVector getWarpsPerCTA(Attribute layout) { + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return distributedLayout.getWarpsPerCTA(); + } + + llvm::report_fatal_error("getWarpsPerCTA not implemented"); + return SmallVector(); +} + +SmallVector +getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape) { + if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(tensorShape); + auto parentWarpsPerCTA = + getWarpsPerCTAWithUniqueData(parentLayout, parentShape); + SmallVector warpsPerCTA = parentWarpsPerCTA; + warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim()); + return warpsPerCTA; + } + auto warpsPerCTA = getWarpsPerCTA(layout); + assert(warpsPerCTA.size() == tensorShape.size() && + "layout and tensor shape must have the same rank"); + for (unsigned i = 0; i < warpsPerCTA.size(); i++) { + auto sizePerWarp = + getSizePerThread(layout)[i] * getThreadsPerWarp(layout)[i]; + auto maxWarpsPerDim = ceil(tensorShape[i], sizePerWarp); + warpsPerCTA[i] = std::min(warpsPerCTA[i], maxWarpsPerDim); + } + + return warpsPerCTA; +} + +SmallVector getSizePerThread(Attribute layout) { + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return distributedLayout.getSizePerThread(); + } else { + llvm::report_fatal_error("getSizePerThread not implemented"); + return {}; + } +} + +SmallVector getContigPerThread(Attribute layout) { + if (auto distributedLayout = dyn_cast(layout)) { + return distributedLayout.getContigPerThread(); + } else { + llvm::report_fatal_error("getContigPerThread not implemented"); + return {}; + } +} + +SmallVector getUniqueContigPerThread(Attribute layout, + ArrayRef shape) { + // If slice layout, call recursively on parent layout, and drop + // sliced dim + if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(shape); + auto parentUniqueContigPerThread = + getUniqueContigPerThread(parentLayout, parentShape); + parentUniqueContigPerThread.erase(parentUniqueContigPerThread.begin() + + sliceLayout.getDim()); + return parentUniqueContigPerThread; + } + // Base case + auto rank = shape.size(); + SmallVector ret(rank); + auto contigPerThread = getContigPerThread(layout); + assert(contigPerThread.size() == rank && "Unexpected contigPerThread size"); + for (int d = 0; d < rank; ++d) { + ret[d] = std::min(shape[d], contigPerThread[d]); + } + return ret; +} + +SmallVector getShapePerCTATile(Attribute layout, + ArrayRef tensorShape) { + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return distributedLayout.getShapePerCTATile(tensorShape); + } else { + llvm::report_fatal_error("getShapePerCTATile not implemented"); + return SmallVector(); + } +} + +bool isExpensiveView(Type srcType, Type dstType) { + return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType); +} + +/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr. + * Erase dim and decrease all values larger than dim by 1. + * Example: order = [0, 2, 4, 3, 1], dim = 2 + * resOrder = [0, 3, 2, 1] + */ +static SmallVector eraseOrder(ArrayRef order, + unsigned dim) { + unsigned rank = order.size(); + assert(dim < rank && "Invalid dim to erase"); + SmallVector resOrder; + for (unsigned i : order) + if (i < dim) + resOrder.push_back(i); + else if (i > dim) + resOrder.push_back(i - 1); + return resOrder; +} + +SmallVector getWarpOrder(Attribute layout) { + auto order = getOrder(layout); + if (auto mmaLayout = dyn_cast(layout)) { + if (mmaLayout.isHopper()) { + // Hopper MMA instructions force a warp order of [0, 1]. See docs: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8 + auto it = std::find(order.begin(), order.end(), 0); + order.erase(it); + order.insert(order.begin(), 0); + } + } + return order; +} + +SmallVector getOrder(Attribute layout) { + if (auto blockedLayout = dyn_cast(layout)) { + return SmallVector(blockedLayout.getOrder().begin(), + blockedLayout.getOrder().end()); + } else if (auto mmaLayout = dyn_cast(layout)) { + auto distributedLayout = cast(layout); + auto rank = distributedLayout.getWarpsPerCTA().size(); + SmallVector order(rank); + for (auto i = 0; i < rank; ++i) + order[i] = rank - 1 - i; + if (auto mfmaLayout = dyn_cast(layout)) { + if (mfmaLayout.getIsTransposed()) { + std::swap(order[rank - 2], order[rank - 1]); + } + } + return order; + } else if (auto dotLayout = dyn_cast(layout)) { + auto rank = getWarpsPerCTA(dotLayout.getParent()).size(); + SmallVector order(rank); + for (auto i = 0; i < rank; ++i) + order[i] = rank - 1 - i; + return order; + } else if (auto sliceLayout = dyn_cast(layout)) { + SmallVector parentOrder = getOrder(sliceLayout.getParent()); + unsigned dim = sliceLayout.getDim(); + SmallVector order; + for (unsigned d : parentOrder) { + if (d == dim) + continue; + else if (d > dim) + order.push_back(d - 1); + else + order.push_back(d); + } + return order; + } else if (auto sharedLayout = mlir::dyn_cast(layout)) { + return SmallVector(sharedLayout.getOrder().begin(), + sharedLayout.getOrder().end()); + } else { + llvm::report_fatal_error("Unimplemented usage of getOrder"); + } + return {}; +}; + +CTALayoutAttr getCTALayout(Attribute layout) { + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return CTALayoutAttr::get( + layout.getContext(), getCTAsPerCGA(distributedLayout), + getCTASplitNum(distributedLayout), getCTAOrder(distributedLayout)); + } else if (auto sharedLayout = mlir::dyn_cast(layout)) + return sharedLayout.getCTALayout(); + else + llvm::report_fatal_error("Unimplemented usage of getCTALayout"); + return {}; +} + +SmallVector getCTAsPerCGA(Attribute layout) { + ArrayRef ref; + if (auto distributedLayout = mlir::dyn_cast(layout)) + return distributedLayout.getCTAsPerCGA(); + else if (mlir::isa(layout)) + return {1, 1}; + else if (auto sharedLayout = mlir::dyn_cast(layout)) + ref = sharedLayout.getCTALayout().getCTAsPerCGA(); + else + llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA"); + return SmallVector(ref.begin(), ref.end()); +} + +SmallVector getCTASplitNum(Attribute layout) { + SmallVector res; + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return distributedLayout.getCTASplitNum(); + } else if (mlir::isa(layout)) { + res.resize(2); + res[0] = res[1] = 1; + } else if (auto sharedLayout = mlir::dyn_cast(layout)) { + res.assign(sharedLayout.getCTALayout().getCTASplitNum().begin(), + sharedLayout.getCTALayout().getCTASplitNum().end()); + } else { + assert(false && "Unimplemented usage of getCTASplitNum"); + } + return res; +} + +SmallVector getCTAOrder(Attribute layout) { + SmallVector res; + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + res = distributedLayout.getCTAOrder(); + } else if (mlir::isa(layout)) { + return {0, 1}; + } else if (auto sharedLayout = mlir::dyn_cast(layout)) { + res = SmallVector(sharedLayout.getCTALayout().getCTAOrder()); + } else { + llvm::report_fatal_error("Unimplemented usage of getCTAOrder"); + } + return res; +} + +SmallVector getShapePerCTA(ArrayRef CTASplitNum, + ArrayRef shape) { + unsigned rank = shape.size(); + SmallVector shapePerCTA(rank); + for (unsigned i = 0; i < rank; ++i) { + // This wrapping rule must be consistent with emitCTAOffsetForLayout + unsigned splitNum = std::min(shape[i], CTASplitNum[i]); + shapePerCTA[i] = shape[i] / splitNum; + } + return shapePerCTA; +} + +SmallVector getShapePerCTA(Attribute layout, ArrayRef shape) { + if (auto sharedLayout = mlir::dyn_cast(layout)) { + // Special logic for pipeline pass, where shape is 3D and CTALayout is 2D. + // The first dim of shape is numStages. This is a work around, otherwise too + // many places would have to be modified in pipeline pass. Maybe we need to + // refactor this logic in the future. + auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum(); + if (shape.size() == CTASplitNum.size() + 1) { + auto res = getShapePerCTA(CTASplitNum, shape.drop_front()); + res.insert(res.begin(), shape.front()); + return res; + } + } + return getShapePerCTA(getCTASplitNum(layout), shape); +} + +SmallVector getShapePerCTA(Type type) { + auto tensorType = cast(type); + return getShapePerCTA(tensorType.getEncoding(), tensorType.getShape()); +} + +unsigned getNumWarpsPerCTA(Attribute layout) { + SmallVector warpsPerCTA; + if (auto blockedLayout = dyn_cast(layout)) + warpsPerCTA = blockedLayout.getWarpsPerCTA(); + else if (auto sliceLayout = dyn_cast(layout)) + return getNumWarpsPerCTA(sliceLayout.getParent()); + else if (auto mmaLayout = dyn_cast(layout)) { + // Use the distributed layout interface to get the number of warps per CTA. + auto distributedLayout = cast(layout); + warpsPerCTA = distributedLayout.getWarpsPerCTA(); + } else if (auto mfmaLayout = dyn_cast(layout)) + warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + else if (auto wmmaLayout = dyn_cast(layout)) + warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + else if (auto dotLayout = dyn_cast(layout)) + return getNumWarpsPerCTA(dotLayout.getParent()); + else if (auto sharedLayout = dyn_cast(layout)) + llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr"); + else + llvm::report_fatal_error("Unimplemented usage of getNumWarpsPerCTA"); + return product(warpsPerCTA); +} + +unsigned getNumCTAs(Attribute layout) { + return product(getCTAsPerCGA(layout)); +} + +bool isaDistributedLayout(Attribute layout) { + return isa(layout); +} + +template bool hasEncoding(Value value) { + auto type = value.getType(); + if (auto tensorType = dyn_cast(type)) { + auto encoding = tensorType.getEncoding(); + return encoding && isa(encoding); + } + return false; +} + +bool hasDotOperandEncoding(Value value) { + return hasEncoding(value); +} + +bool isExpensiveCat(CatOp cat, Attribute targetEncoding) { + // If the new elements per thread is less than the old one, we will need to do + // convert encoding that goes through shared memory anyway. So we consider it + // as expensive. + RankedTensorType tensorTy = cat.getType(); + auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy); + auto shape = tensorTy.getShape(); + auto elemTy = tensorTy.getElementType(); + auto newTotalElemsPerThread = + gpu::getTotalElemsPerThread(targetEncoding, shape, elemTy); + return newTotalElemsPerThread < totalElemsPerThread; +} + +bool isMmaConvertLayout(Operation *op) { + // Match cvt(#mma(version_minor = 0) -> #mma(version_minor > 0)) + // The later is for storing dot result. + if (auto convertOp = dyn_cast(op)) { + auto srcType = convertOp.getOperand().getType().cast(); + auto dstType = convertOp.getResult().getType().cast(); + if (!srcType || !dstType) + return false; + auto srcMmaEnc = dyn_cast(srcType.getEncoding()); + auto dstMmaEnc = dyn_cast(dstType.getEncoding()); + if (!srcMmaEnc || !dstMmaEnc) + return false; + return srcMmaEnc.getVersionMinor() == 0 && dstMmaEnc.getVersionMinor() > 0; + } + return false; +} + +bool isSliceMmaConvertLayout(Operation *op, bool srcNoWarpReduce, + bool dstNoWarpReduce) { + // Match cvt(slice<{parent=#mma, noWarpReduce=srcNoWarpReduce}> + // -> slice<{parent=#mma, noWarpReduce=dstNoWarpReduce}>) + if (auto convertOp = dyn_cast(op)) { + auto srcType = convertOp.getOperand().getType().cast(); + auto dstType = convertOp.getResult().getType().cast(); + if (!srcType || !dstType) + return false; + auto srcLayout = + dyn_cast(srcType.getEncoding()); + auto dstLayout = + dyn_cast(dstType.getEncoding()); + if (!srcLayout || !dstLayout) + return false; + auto srcMmaLayout = + srcLayout.getParent().dyn_cast(); + auto dstMmaLayout = + dstLayout.getParent().dyn_cast(); + if (!srcMmaLayout || !dstMmaLayout) + return false; + return srcLayout.getNoWarpReduce() == srcNoWarpReduce && + dstLayout.getNoWarpReduce() == dstNoWarpReduce; + } + return false; +} + +LogicalResult CTALayoutAttr::verify( + function_ref emitError, ArrayRef CTAsPerCGA, + ArrayRef CTASplitNum, ArrayRef CTAOrder) { + if (CTAsPerCGA.size() != CTASplitNum.size() || + CTASplitNum.size() != CTAOrder.size()) { + return emitError() << "CTAsPerCGA, CTASplitNum, and CTAOrder must all have " + "the same rank."; + } + + if (!isPermutationOfIota(CTAOrder)) { + return emitError() + << "CTAOrder must be a permutation of 0..(rank-1), but was [" + << CTAOrder << "]"; + } + + return success(); +} + +LogicalResult BlockedEncodingAttr::verify( + function_ref emitError, + ArrayRef sizePerThread, ArrayRef threadsPerWarp, + ArrayRef warpsPerCTA, ArrayRef order, + CTALayoutAttr CTALayout, unsigned loadType, + ArrayRef smeWarpsPerCTA) { + if (sizePerThread.size() != threadsPerWarp.size() || + threadsPerWarp.size() != warpsPerCTA.size() || + warpsPerCTA.size() != order.size()) { + return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and " + "order must all have the same rank."; + } + + // Empty CTALayout is allowed, but if it's present its rank must match the + // BlockedEncodingAttr's rank. + if (CTALayout.getCTASplitNum().size() != 0 && + sizePerThread.size() != CTALayout.getCTASplitNum().size()) { + return emitError() << "BlockedEncodingAttr and CTALayout's fields must " + "have the same rank."; + } + if (!isPermutationOfIota(order)) { + return emitError() + << "order must be a permutation of 0..(rank-1), but was [" << order + << "]"; + } + return success(); +} + +// 1 element per thread +// order = reverse(arange(rank)) +triton::gpu::BlockedEncodingAttr +getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, + int numWarps, int threadsPerWarp, int numCTAs) { + int rank = shape.size(); + llvm::SmallVector order(rank); + std::iota(order.begin(), order.end(), 0); + std::reverse(order.begin(), order.end()); + llvm::SmallVector sizePerThread(rank, 1); + triton::gpu::BlockedEncodingAttr encoding = + triton::gpu::BlockedEncodingAttr::get(context, shape, sizePerThread, + order, numWarps, threadsPerWarp, + numCTAs); + return encoding; +} + +} // namespace gpu +} // namespace triton +} // namespace mlir + +static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr, + unsigned &value, StringRef desc) { + auto intAttr = mlir::dyn_cast(attr); + if (!intAttr) { + parser.emitError(parser.getNameLoc(), "expected an integer type in ") + << desc; + return failure(); + } + if (intAttr.getType().isSignedInteger()) { + int64_t attrVal = intAttr.getSInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else if (intAttr.getType().isSignlessInteger()) { + int64_t attrVal = intAttr.getInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else { + value = intAttr.getUInt(); + } + return success(); +} + +static LogicalResult parseBoolAttrValue(AsmParser &parser, + const NamedAttribute &attr, bool &value, + StringRef desc) { + auto boolAttr = mlir::dyn_cast(attr.getValue()); + if (!boolAttr) { + parser.emitError(parser.getNameLoc(), "expected an bool type in ") << desc; + return failure(); + } + value = boolAttr.getValue(); + return success(); +} + +// parse an array of integers +static LogicalResult parseIntArrayAttr(AsmParser &parser, + const NamedAttribute &attr, + SmallVector &res, + StringRef desc) { + auto arrayAttr = mlir::dyn_cast(attr.getValue()); + if (!arrayAttr) { + parser.emitError(parser.getNameLoc(), "expected an array for ") << desc; + return failure(); + } + for (Attribute i : arrayAttr) { + unsigned value; + if (parseIntAttrValue(parser, i, value, desc).failed()) + return failure(); + res.push_back(value); + } + return success(); +}; + +static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr, + unsigned &value, StringRef desc) { + return parseIntAttrValue(parser, attr.getValue(), value, desc); +}; + +static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr, + bool &value, StringRef desc) { + return parseBoolAttrValue(parser, attr, value, desc); +}; + +// Print the CTALayout if it's not equal to the default. +static void maybePrintCTALayout(mlir::MLIRContext *context, + mlir::AsmPrinter &printer, CTALayoutAttr layout, + unsigned rank) { + if (layout != CTALayoutAttr::getDefault(context, rank)) { + printer << ", CTAsPerCGA = [" << ArrayRef(layout.getCTAsPerCGA()) << "]" + << ", CTASplitNum = [" << ArrayRef(layout.getCTASplitNum()) << "]" + << ", CTAOrder = [" << ArrayRef(layout.getCTAOrder()) << "]"; + } +} + +//===----------------------------------------------------------------------===// +// Attribute methods +//===----------------------------------------------------------------------===// +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" + +SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) { + return SliceEncodingAttr::get(getContext(), axis, *this, false); +} +SmallVector +BlockedEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + size_t rank = shape.size(); + auto sizePerThread = getSizePerThread(); + auto warpsPerCTA = getWarpsPerCTA(); + auto threadsPerWarp = getThreadsPerWarp(); + auto shapePerCTA = getShapePerCTA(*this, shape); + assert(rank == sizePerThread.size() && + "unexpected rank in BlockedEncodingAttr::getElemsPerThread"); + SmallVector elemsPerThread(rank); + for (size_t i = 0; i < rank; ++i) { + unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i]; + elemsPerThread[i] = ceil(shapePerCTA[i], t) * sizePerThread[i]; + } + return elemsPerThread; +} +unsigned BlockedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + +// If we only had BlockedEncodingAttr, we could simply return ArrayRefs here. +// But we need to have a consistent interface with e.g. SliceEncodingAttr, which +// computes some of these fields. +SmallVector BlockedEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector BlockedEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector BlockedEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector BlockedEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector BlockedEncodingAttr::getWarpOrder() const { + return SmallVector(getOrder()); +} +SmallVector BlockedEncodingAttr::getThreadsPerWarp() const { + return SmallVector(getThreadsPerWarp__()); +} +SmallVector BlockedEncodingAttr::getThreadOrder() const { + return SmallVector(getOrder()); +} +SmallVector BlockedEncodingAttr::getSizePerThread() const { + return SmallVector(getSizePerThread__()); +} +SmallVector +BlockedEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + SmallVector shape; + for (unsigned d = 0, n = getOrder().size(); d < n; ++d) + shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] * + getWarpsPerCTA()[d]); + return shape; +} + +template +SmallVector SliceEncodingAttr::paddedShape(ArrayRef shape) const { + size_t rank = shape.size(); + unsigned dim = getDim(); + SmallVector retShape(rank + 1); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d < dim) + retShape[d] = shape[d]; + else if (d == dim) + retShape[d] = 1; + else + retShape[d] = shape[d - 1]; + } + return retShape; +} +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; + +SmallVector +SliceEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + auto parent = getParent(); + auto parentElemsPerThread = + ::getElemsPerThread(parent, paddedShape(shape), eltTy); + parentElemsPerThread.erase(parentElemsPerThread.begin() + getDim()); + return parentElemsPerThread; +} +unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} +SmallVector SliceEncodingAttr::getCTASplitNum() const { + SmallVector res = ::getCTASplitNum(getParent()); + res.erase(res.begin() + getDim()); + return res; +} +SmallVector SliceEncodingAttr::getCTAOrder() const { + auto parentCTAOrder = ::getCTAOrder(getParent()); + return eraseOrder(parentCTAOrder, getDim()); +} +SmallVector SliceEncodingAttr::getCTAsPerCGA() const { + auto parentCTAsPerCGA = ::getCTAsPerCGA(getParent()); + if (parentCTAsPerCGA[getDim()] == 1) { + parentCTAsPerCGA.erase(parentCTAsPerCGA.begin() + getDim()); + return parentCTAsPerCGA; + } + /* For getCTAsPerCGA of a slice layout, we have two choices: + * (1) Return CTAsPerCGA of its parent. This is not a perfect solution + * because the rank of the returned CTAsPerCGA does not match the rank of + * tensorShape. + * (2) Get CTAsPerCGA of its parent and erase the sliced dim. This is not a + * perfect solution because the product of the returned CTAsPerCGA might not + * match numCTAs. + * To avoid introducing inconsistencies to the shape and + * layout system, the usage of directly getting CTAsPerCGA of a slice layout + * in which the sliced dim is not 1 is banned. You should always consider + * slice layout as a special case and use getCTAsPerCGA(layout.getParent()) + * in the branch where layout is an instance of SliceEncodingAttr. This is + * inconvenient but safe. + */ + llvm::report_fatal_error( + "getCTAsPerCGA for SliceEncodingAttr is not well-defined"); +} +SmallVector SliceEncodingAttr::getWarpsPerCTA() const { + auto parent = getParent(); + auto parentWarpsPerCTA = ::getWarpsPerCTA(parent); + SmallVector warpsPerCTA = parentWarpsPerCTA; + warpsPerCTA.erase(warpsPerCTA.begin() + getDim()); + int32_t nextDim = getDim() < warpsPerCTA.size() ? getDim() : getDim() - 1; + warpsPerCTA[nextDim] *= parentWarpsPerCTA[getDim()]; + return warpsPerCTA; +} +SmallVector SliceEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector SliceEncodingAttr::getThreadsPerWarp() const { + auto parent = getParent(); + auto parentThreadsPerWarp = ::getThreadsPerWarp(parent); + SmallVector threadsPerWarp = parentThreadsPerWarp; + threadsPerWarp.erase(threadsPerWarp.begin() + getDim()); + int32_t nextDim = getDim() < threadsPerWarp.size() ? getDim() : getDim() - 1; + threadsPerWarp[nextDim] *= parentThreadsPerWarp[getDim()]; + return threadsPerWarp; +} +SmallVector SliceEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector SliceEncodingAttr::getSizePerThread() const { + auto sizePerThread = ::getSizePerThread(getParent()); + sizePerThread.erase(sizePerThread.begin() + getDim()); + return sizePerThread; +} +SmallVector +SliceEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + SmallVector shape = ::getShapePerCTATile(getParent(), tensorShape); + shape.erase(shape.begin() + getDim()); + return shape; +} + +// + +SmallVector +AMDMfmaEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + size_t rank = shape.size(); + assert((rank == 2 || rank == 3) && "Unexpected rank of mfma layout"); + + SmallVector elemsPerThread(rank); + auto nonKDim = getMDim(); + auto elemsPerThreadPerTile = (nonKDim == 16 ? 4 : 16); + if (rank == 3) + elemsPerThread[0] = ceil(shape[0], getWarpsPerCTA()[0]); + if (getIsTransposed()) { + unsigned elemsCol = + ceil(shape[rank - 1], nonKDim * getWarpsPerCTA()[rank - 1]) * + elemsPerThreadPerTile; + unsigned elemsRow = + ceil(shape[rank - 2], nonKDim * getWarpsPerCTA()[rank - 2]); + elemsPerThread[rank - 2] = elemsRow; + elemsPerThread[rank - 1] = elemsCol; + } else { + unsigned elemsCol = + ceil(shape[rank - 1], nonKDim * getWarpsPerCTA()[rank - 1]); + unsigned elemsRow = + ceil(shape[rank - 2], nonKDim * getWarpsPerCTA()[rank - 2]) * + elemsPerThreadPerTile; + elemsPerThread[rank - 2] = elemsRow; + elemsPerThread[rank - 1] = elemsCol; + } + return elemsPerThread; +} + +unsigned AMDMfmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + +// + +SmallVector +AMDWmmaEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + size_t rank = shape.size(); + assert((rank == 2 || rank == 3) && "Unexpected rank of wmma layout"); + + SmallVector elemsPerThread(rank); + auto mnkDim = getMNKDimPerWMMAInstr(); + auto elemsPerThreadPerTile = getSizePerThread(); + auto warpsPerCTA = getWarpsPerCTA(); + + if (rank == 3) + elemsPerThread[0] = ceil(shape[0], getWarpsPerCTA()[0]); + elemsPerThread[rank - 2] = + ceil(shape[rank - 2], mnkDim[0] * warpsPerCTA[rank - 2]) * + elemsPerThreadPerTile[rank - 2]; + elemsPerThread[rank - 1] = + ceil(shape[rank - 1], mnkDim[1] * warpsPerCTA[rank - 1]) * + elemsPerThreadPerTile[rank - 1]; + return elemsPerThread; +} + +unsigned AMDWmmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + +// + +SmallVector +NvidiaMmaEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + size_t rank = shape.size(); + assert(rank == 2 || + (rank == 3 && isAmpere()) && "Unexpected rank of mma layout"); + assert((isVolta() || isAmpere() || isHopper()) && + "For NvidiaMmaEncodingAttr only version 1~3 is supported"); + + auto shapePerCTA = getShapePerCTA(getCTALayout().getCTASplitNum(), shape); + + SmallVector elemsPerThread(rank); + if (isVolta()) { + auto [isARow, isBRow, isAVec4, isBVec4, id] = decodeVoltaLayoutStates(); + static constexpr std::array fpw{{2, 2}}; + unsigned packSize0 = (isARow || isAVec4) ? 1 : 2; + unsigned packSize1 = (isBRow && !isBVec4) ? 2 : 1; + unsigned repM = 2 * packSize0; + unsigned repN = 2 * packSize1; + unsigned spwM = fpw[0] * 4 * repM; + unsigned spwN = fpw[1] * 4 * repN; + unsigned wptM = getWarpsPerCTA()[0]; + unsigned wptN = getWarpsPerCTA()[1]; + unsigned resM = repM * std::max(1, shapePerCTA[0] / (spwM * wptM)); + unsigned resN = 2 * repN * std::max(1, shapePerCTA[1] / (spwN * wptN)); + elemsPerThread[0] = resM; + elemsPerThread[1] = resN; + } else if (isAmpere()) { + unsigned elemsRow = + ceil(shapePerCTA[rank - 2], 16 * getWarpsPerCTA()[rank - 2]) * + 2; + unsigned elemsCol = + ceil(shapePerCTA[rank - 1], 8 * getWarpsPerCTA()[rank - 1]) * + 2; + if (rank == 3) + elemsPerThread[0] = ceil(shapePerCTA[0], getWarpsPerCTA()[0]); + elemsPerThread[rank - 2] = elemsRow; + elemsPerThread[rank - 1] = elemsCol; + } else if (isHopper()) { + auto wpt = getWarpsPerCTA(); + auto instrMNK = getInstrShape(); + int repM = ceil(shapePerCTA[0], instrMNK[0] * wpt[0]); + int repN = ceil(shapePerCTA[1], instrMNK[1] * wpt[1]); + elemsPerThread[0] = 2 * repM; + elemsPerThread[1] = (instrMNK[1] / 4) * repN; + } else { + llvm_unreachable("Unexpected mma version"); + } + + return elemsPerThread; +} + +unsigned NvidiaMmaEncodingAttr::getElemsPerThreadOfOperand( + int opIdx, ArrayRef shape) const { + size_t rank = shape.size(); + assert(rank == 2 && "Unexpected rank of mma layout"); + auto shapePerCTA = getShapePerCTA(*this, shape); + int res = 0; + if (isVolta()) { + llvm_unreachable( + "getElemsPerThreadOfOperand() not supported for version 1"); + } else if (isAmpere()) { + llvm_unreachable( + "getElemsPerThreadOfOperand() not supported for version 2"); + } else if (isHopper()) { + auto wpt = getWarpsPerCTA(); + auto instrMNK = getInstrShape(); + if (opIdx == 0) { + int repM = ceil(shapePerCTA[0], instrMNK[0] * wpt[0]); + int repK = ceil(shapePerCTA[1], instrMNK[2]); + return 8 * repM * repK; + + } else if (opIdx == 1) { + int repK = ceil(shapePerCTA[0], instrMNK[2]); + int repN = ceil(shapePerCTA[1], instrMNK[1] * wpt[1]); + // benzh@ here need more check + return 4 * std::max(instrMNK[1] / 32, 1) * repK * repN; + } + } + return res; +} + +unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + +// + +using getElemsPerThreadFunc = SmallVector (*)( + const IluvatarMmaEncodingAttr *, ArrayRef, Type); + +getElemsPerThreadFunc load_get_elems_per_thread_func(const char *backend_name, + const char *func_name) { + void *symbol = load_backend_symbol(backend_name, func_name); + return reinterpret_cast(symbol); +} + +SmallVector +IluvatarMmaEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + static auto getElemsPerThread = + load_get_elems_per_thread_func("iluvatar", "getElemsPerThread"); + return getElemsPerThread(this, shape, eltTy); +} + +using getTotalElemsPerThreadFunc = unsigned (*)(const IluvatarMmaEncodingAttr *, + ArrayRef, Type); + +getTotalElemsPerThreadFunc +load_get_total_elems_per_thread_func(const char *backend_name, + const char *func_name) { + void *symbol = load_backend_symbol(backend_name, func_name); + return reinterpret_cast(symbol); +} + +unsigned +IluvatarMmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + static auto func = load_get_total_elems_per_thread_func( + "iluvatar", "getTotalElemsPerThread"); + return func(this, shape, eltTy); +} + +// + +SmallVector +SharedEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + llvm_unreachable("getElemsPerThread is not supported for shared layout"); + return SmallVector(); +} +unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + llvm_unreachable("getElemsPerThread is not supported for shared layout"); + return 0; +} + +SmallVector +DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + llvm_unreachable("getElemsPerThread is not supported for dot operand"); + return SmallVector(); +} + +unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + if (auto mmaParent = mlir::dyn_cast(getParent())) { + return mmaParent.getTotalElemsPerThreadForOperands(shape, eltTy, + getKWidth(), getOpIdx()); + } + if (auto blockedLayout = mlir::dyn_cast(getParent())) { + auto shapePerCTA = getShapePerCTA(*this, shape); + auto shapePerCTATile = ::getShapePerCTATile(blockedLayout); + auto order = blockedLayout.getOrder(); + auto sizePerThread = ::getSizePerThread(blockedLayout); + + int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0]; + int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0]; + + bool isM = getOpIdx() == 0; + + int mSizePerThread = + order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int nSizePerThread = + order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int sizePerThreadMN = isM ? mSizePerThread : nSizePerThread; + + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int shapePerCTAMNTile = isM ? mShapePerCTATile : nShapePerCTATile; + + return K * std::max(otherDim / shapePerCTAMNTile, 1) * sizePerThreadMN; + } + llvm_unreachable("unknown dot operand parent layout"); + return 0; +} +SmallVector DotOperandEncodingAttr::getCTAsPerCGA() const { + return ::getCTAsPerCGA(getParent()); +} +SmallVector DotOperandEncodingAttr::getCTAOrder() const { + return ::getCTAOrder(getParent()); +} +SmallVector DotOperandEncodingAttr::getCTASplitNum() const { + SmallVector res = ::getCTASplitNum(getParent()); + auto rank = res.size(); + assert(rank == 2 || rank == 3 && "Invalid dotLayout"); + + // Do not split CTA in K dimension + getOpIdx() == 0 ? res[rank - 1] = 1 : res[rank - 2] = 1; + return res; +} +SmallVector DotOperandEncodingAttr::getWarpsPerCTA() const { + auto parentLayout = getParent(); + assert(parentLayout && "DotOperandEncodingAttr must have a parent"); + if (auto distributedLayout = + mlir::dyn_cast(parentLayout)) { + return distributedLayout.getWarpsPerCTA(); + } else { + llvm::report_fatal_error( + "DotOperandEncodingAttr non-DistributedEncodingAttr parent not " + "supported yet"); + } +} +SmallVector DotOperandEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector DotOperandEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector DotOperandEncodingAttr::getShapePerCTATile( + ArrayRef tensorShape) const { + auto parentLayout = getParent(); + assert(parentLayout && "DotOperandEncodingAttr must have a parent"); + if (auto parentMmaLayout = mlir::dyn_cast(parentLayout)) { + return parentMmaLayout.getShapePerCTATileForDotOperands(tensorShape, + getOpIdx()); + } else { + llvm::report_fatal_error( + "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not " + "supported yet"); + } +} + +LogicalResult DotOperandEncodingAttr::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + unsigned opIdx, Attribute parent, unsigned kWidth, unsigned useSme) { + if (opIdx != 0 && opIdx != 1) { + return emitError() + << "triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: " + << opIdx; + } + if (!parent) { + return emitError() << "triton_gpu.dot_op parent paramenter cannot be null"; + } + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0 && !parentAttr.isAmpere()) + return emitError() << "triton_gpu.dot_op kWidth parameter can only be " + "non-zero for Ampere MMA parent"; + if (kWidth == 0 && parentAttr.isAmpere()) + return emitError() + << "triton_gpu.dot_op kWidth parameter is mandatory for " + "Ampere MMA parent"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + // TODO: remove this condition if new values are supported + if (kWidth != 16) + return emitError() << "triton_gpu.dot_op kWidth parameter supports " + "only 16 for WMMA parent"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth == 0) + return emitError() + << "triton_gpu.dot_op kWidth parameter is mandatory for " + "MFMA parent"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0) + return emitError() + << "triton_gpu.dot_op kWidth parameter is not supported " + "when the parent is a blocked layout"; + return success(); + } + + return emitError() << "triton_gpu.dot_op unexpected parent layout: " + << parent; +} + +//===----------------------------------------------------------------------===// +// Blocked Encoding +//===----------------------------------------------------------------------===// + +static std::optional getCTALayoutOrError( + AsmParser &parser, std::optional> CTAsPerCGA, + std::optional> CTASplitNum, + std::optional> CTAOrder, unsigned rank) { + if (CTAsPerCGA && CTASplitNum && CTAOrder) { + return CTALayoutAttr::get(parser.getContext(), *CTAsPerCGA, *CTASplitNum, + *CTAOrder); + } + if (!CTAsPerCGA && !CTASplitNum && !CTAOrder) { + return CTALayoutAttr::getDefault(parser.getContext(), rank); + } + parser.emitError(parser.getNameLoc(), "CTAsPerCGA, CTASplitNum, and CTAOrder " + "must all be present or all be absent"); + return std::nullopt; +} + +Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + SmallVector sizePerThread; + SmallVector threadsPerWarp; + SmallVector warpsPerCTA; + SmallVector order; + unsigned loadType = 0; + SmallVector smeWarpsPerCTA; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "sizePerThread") { + if (parseIntArrayAttr(parser, attr, sizePerThread, + "number of elements per thread") + .failed()) + return {}; + } else if (attr.getName() == "threadsPerWarp") { + if (parseIntArrayAttr(parser, attr, threadsPerWarp, + "number of threads per warp") + .failed()) + return {}; + } else if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, + "number of warps per CTA") + .failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } else if (attr.getName() == "loadType") { + loadType = attr.getValue().cast().getInt(); + } else if (attr.getName() == "smeWarpsPerCTA") { + if (parseIntArrayAttr(parser, attr, smeWarpsPerCTA, "smeWarpsPerCTA") + .failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/sizePerThread.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order, + *CTALayout, loadType, smeWarpsPerCTA); +} + +void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "sizePerThread = [" << ArrayRef(getSizePerThread()) << "]" + << ", threadsPerWarp = [" << ArrayRef(getThreadsPerWarp()) << "]" + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]" + << ", order = [" << getOrder() << "]" + << ", loadType = " << getLoadType() << ", smeWarpsPerCTA = [" + << getSmeWarpsPerCTA() << "]"; + + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getSizePerThread().size()); + + printer << "}>"; +} + +//===----------------------------------------------------------------------===// +// MMA encoding +//===----------------------------------------------------------------------===// + +Attribute NvidiaMmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + SmallVector instrShape; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { + if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) + return {}; + } + if (attr.getName() == "versionMinor") { + if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) { + return {}; + } + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, *CTALayout, + instrShape); +} + +void NvidiaMmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() + << ", versionMinor = " << getVersionMinor() // + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]"; + + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + + printer << ", instrShape = [" << getInstrShape() << "]}>"; +} + +//===----------------------------------------------------------------------===// +// MFMA encoding +//===----------------------------------------------------------------------===// + +Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + SmallVector instrShape; + bool isTransposed; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { + if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) + return {}; + } + if (attr.getName() == "versionMinor") { + if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) + return {}; + } + if (attr.getName() == "isTransposed") { + if (parseBool(parser, attr, isTransposed, "isTransposed").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, + instrShape[0], instrShape[1], isTransposed, *CTALayout); +} + +void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() // + << ", versionMinor = " << getVersionMinor() // + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]" // + << ", instrShape = [" << ArrayRef{getMDim(), getNDim()} << "]" // + << ", isTransposed = " << getIsTransposed(); + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + printer << "}>"; +} + +LogicalResult +AMDMfmaEncodingAttr::verify(function_ref emitError, + unsigned versionMajor, unsigned versionMinor, + llvm::ArrayRef warpsPerCTA, + unsigned mDim, unsigned nDim, bool isTransposed, + mlir::triton::gpu::CTALayoutAttr) { + if (!(versionMajor >= 0 && versionMajor <= 3)) { + return emitError() << "major version must be in the [0, 3] range"; + } + if (versionMinor != 0) { + return emitError() << "minor version must be 0"; + } + if (!((mDim == 32 && nDim == 32) || (mDim == 16 && nDim == 16))) { + return emitError() + << "(M, N) cases other than (32, 32) or (16, 16) unimplemented"; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// WMMA encoding +//===----------------------------------------------------------------------===// + +Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + SmallVector warpsPerCTA; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked(parser.getContext(), + warpsPerCTA, *CTALayout); +} + +void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]"; + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + printer << "}>"; +} + +//===----------------------------------------------------------------------===// +// Iluvatar MMA encoding +//===----------------------------------------------------------------------===// + +Attribute IluvatarMmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + SmallVector instrShape; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { + if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) + return {}; + } + if (attr.getName() == "versionMinor") { + if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) { + return {}; + } + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, *CTALayout, + instrShape); +} + +void IluvatarMmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() + << ", versionMinor = " << getVersionMinor() // + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]"; + + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + + printer << ", instrShape = [" << getInstrShape() << "]}>"; +} + +//===----------------------------------------------------------------------===// +// Sliced Encoding +//===----------------------------------------------------------------------===// + +Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + unsigned dim = mlir::cast(attrs.get("dim")).getInt(); + bool noWarpReduce = attrs.get("noWarpReduce").cast().getValue(); + Attribute parent = attrs.get("parent"); + return parser.getChecked(parser.getContext(), dim, parent, + noWarpReduce); +} + +void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "dim = " << getDim() << ", " + << "noWarpReduce = " << getNoWarpReduce() << ", " + << "parent = " << getParent() << "}>"; +} + +//===----------------------------------------------------------------------===// +// Shared encoding +//===----------------------------------------------------------------------===// + +Attribute SharedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned vec = 0; + unsigned perPhase = 0; + unsigned maxPhase = 0; + SmallVector order; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + bool hasLeadingOffset = false; + bool useTcu = false; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "vec") { + if (parseUInt(parser, attr, vec, "vec").failed()) + return {}; + } else if (attr.getName() == "perPhase") { + if (parseUInt(parser, attr, perPhase, "perPhase").failed()) + return {}; + } else if (attr.getName() == "maxPhase") { + if (parseUInt(parser, attr, maxPhase, "maxPhase").failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } else if (attr.getName() == "hasLeadingOffset") { + if (parseBool(parser, attr, hasLeadingOffset, "hasLeadingOffset") + .failed()) + return {}; + } else if (attr.getName() == "useTcu") { + if (parseBool(parser, attr, useTcu, "useTcu").failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/order.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), vec, perPhase, maxPhase, order, *CTALayout, + hasLeadingOffset, useTcu); +} + +void SharedEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "vec = " << getVec() // + << ", perPhase = " << getPerPhase() + << ", maxPhase = " << getMaxPhase() // + << ", order = [" << getOrder() << "]"; + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getOrder().size()); + printer << ", hasLeadingOffset = " << getHasLeadingOffset() + << ", useTcu = " << getUseTcu() << "}>"; +} + +//===----------------------------------------------------------------------===// +// Mfma encoding +//===----------------------------------------------------------------------===// +// TODO: there is a lot of common code with MmaEncoding here + +SmallVector +AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = warpsPerCTA.size(); + SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); + shapePerCTATile[rank - 1] *= getMDim(); + shapePerCTATile[rank - 2] *= getNDim(); + return shapePerCTATile; +} + +SmallVector AMDMfmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector AMDMfmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector AMDMfmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector AMDMfmaEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector AMDMfmaEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector AMDMfmaEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector AMDMfmaEncodingAttr::getThreadsPerWarp() const { + unsigned rows, cols; + auto rank = ::getOrder(*this).size(); + SmallVector res(rank, 1); + if (getMDim() == 32) { + cols = 2; + rows = 32; + } else { + assert(getMDim() == 16); + cols = 4; + rows = 16; + } + if (getIsTransposed()) { + res[rank - 1] = cols; + res[rank - 2] = rows; + } else { + res[rank - 1] = rows; + res[rank - 2] = cols; + } + return res; +} + +SmallVector AMDMfmaEncodingAttr::getSizePerThread() const { + unsigned rows, cols; + auto rank = ::getOrder(*this).size(); + SmallVector res(rank, 1); + if (getMDim() == 32) { + rows = 16; + cols = 1; + } else if (getMDim() == 16) { + rows = 4; + cols = 1; + } else + llvm_unreachable("Unexpected mfma non-k dim"); + + if (getIsTransposed()) { + res[rank - 1] = rows; + res[rank - 2] = cols; + } else { + res[rank - 1] = cols; + res[rank - 2] = rows; + } + return res; +} + +SmallVector +AMDMfmaEncodingAttr::getMFMAInstrShapeForOperands(int kWidth, int opIdx) const { + unsigned mDim = getMDim(); + unsigned nDim = getNDim(); + assert((mDim == nDim) && (mDim == 32 || mDim == 16 || mDim == 4) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + constexpr int waveSize = 64; // MFMA is used on wave64 architectures only + int kGroups = -1; + if (mDim == nDim) + kGroups = waveSize / mDim; + if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64) + kGroups = 1; + int64_t kDim = kWidth * kGroups; + if (opIdx == 0) + return {mDim, kDim}; + else + assert(opIdx == 1); + return {kDim, nDim}; +} + +SmallVector +AMDMfmaEncodingAttr::getMFMARepForOperands(ArrayRef operandShape, + int kWidth, int opIdx) const { + auto operandTileShape = getMFMAInstrShapeForOperands(kWidth, opIdx); + auto rank = operandShape.size(); + auto warpsPerCTA = getWarpsPerCTA(); + int numRepBatch = + rank == 3 ? std::max(1, operandShape[0] / warpsPerCTA[0]) : 1; + if (opIdx == 0) + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / + (operandTileShape[0] * warpsPerCTA[rank - 2])), + std::max(1, operandShape[rank - 1] / operandTileShape[1])}; + else { + assert(opIdx == 1); + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / operandTileShape[0]), + std::max(1, operandShape[rank - 1] / (operandTileShape[1] * + warpsPerCTA[rank - 1]))}; + } +} + +unsigned AMDMfmaEncodingAttr::getTotalElemsPerThreadForOperands( + ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { + auto rep = getMFMARepForOperands(shape, kWidth, opIdx); + return product(rep) * kWidth; +} + +SmallVector +AMDMfmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { + if (opIdx == 0) { + return {4, 1}; + } else if (opIdx == 1) { + return {1, 4}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + return {}; + } +} + +SmallVector +AMDMfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, + int opIdx) const { + assert(getMDim() == 32 || getMDim() == 16); + auto parentShapePerCTATile = getShapePerCTATile(shape); + auto rank = parentShapePerCTATile.size(); + if (opIdx == 0) { + if (rank == 2) + return {parentShapePerCTATile[rank - 2], 32}; + else + return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32}; + } else if (opIdx == 1) { + if (rank == 2) + return {32, parentShapePerCTATile[rank - 1]}; + else + return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + } + llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1"); +} + +SmallVector +AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = warpsPerCTA.size(); + SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); + + auto mnkDim = getMNKDimPerWMMAInstr(); + shapePerCTATile[rank - 2] *= mnkDim[0]; + shapePerCTATile[rank - 1] *= mnkDim[1]; + return shapePerCTATile; +} +SmallVector AMDWmmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector AMDWmmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector AMDWmmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector AMDWmmaEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector AMDWmmaEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector AMDWmmaEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector AMDWmmaEncodingAttr::getThreadsPerWarp() const { + auto rank = getWarpsPerCTA().size(); + SmallVector threads(rank, 1); + auto mnkInstr = getMNKDimPerWMMAInstr(); + threads[rank - 2] = mnkInstr[0] / getSizePerThread()[rank - 2]; + threads[rank - 1] = mnkInstr[1] / getSizePerThread()[rank - 1]; + return threads; +} + +SmallVector AMDWmmaEncodingAttr::getSizePerThread() const { + auto rank = getWarpsPerCTA().size(); + SmallVector sizePerThread(rank, 1); + sizePerThread[rank - 2] = 8; + sizePerThread[rank - 1] = 1; + return sizePerThread; +} +SmallVector +AMDWmmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { + auto rank = getWarpsPerCTA().size(); + SmallVector sizePerThread(rank, 1); + if (opIdx == 0) { + sizePerThread[rank - 2] = 1; + sizePerThread[rank - 1] = 16; + } else if (opIdx == 1) { + sizePerThread[rank - 2] = 16; + sizePerThread[rank - 1] = 1; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + } + return sizePerThread; +} + +SmallVector +AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, + int opIdx) const { + auto parentShapePerCTA = getShapePerCTATile(shape); + auto rank = shape.size(); + assert(rank = 2); + if (opIdx == 0) { + return {parentShapePerCTA[0], static_cast(shape[1])}; + } else if (opIdx == 1) { + return {static_cast(shape[0]), parentShapePerCTA[1]}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + } +} + +unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperands( + ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { + auto rep = getWMMARepForOperands(shape, eltTy, kWidth, opIdx); + return product(rep) * kWidth; +} + +SmallVector +AMDWmmaEncodingAttr::getWMMAElemsPerInstrForOperands() const { + return {16, 16}; +} + +SmallVector +AMDWmmaEncodingAttr::getWMMARepForOperands(ArrayRef operandShape, + Type elemType, int kWidth, + int opIdx) const { + auto operandTileShape = getWMMAElemsPerInstrForOperands(); + assert(operandTileShape.size() == 2); + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = operandShape.size(); + assert(rank == 2 || rank == 3); + int numRepBatch = + rank == 3 ? std::max(1, operandShape[0] / warpsPerCTA[0]) : 1; + if (opIdx == 0) + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / + (operandTileShape[0] * warpsPerCTA[rank - 2])), + std::max(1, operandShape[rank - 1] / operandTileShape[1])}; + else { + assert(opIdx == 1); + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / operandTileShape[0]), + std::max(1, operandShape[rank - 1] / (operandTileShape[1] * + warpsPerCTA[rank - 1]))}; + } +} + +SmallVector AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr() { + // TODO: move magic numbers out of the code + return {16, 16, 16}; +} + +//===----------------------------------------------------------------------===// +// Mma encoding +//===----------------------------------------------------------------------===// + +bool NvidiaMmaEncodingAttr::isVolta() const { return getVersionMajor() == 1; } + +bool NvidiaMmaEncodingAttr::isTuring() const { + return getVersionMajor() == 2 && getVersionMinor() == 1; +} + +bool NvidiaMmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; } + +bool NvidiaMmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; } + +SmallVector NvidiaMmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector NvidiaMmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector NvidiaMmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector NvidiaMmaEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector NvidiaMmaEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector NvidiaMmaEncodingAttr::getThreadsPerWarp() const { + auto rank = getWarpsPerCTA().size(); + SmallVector res(rank, 1); + if (isVolta()) { + res[rank - 2] = 4; + res[rank - 1] = 8; + return res; + } + if (isAmpere()) { + res[rank - 2] = 8; + res[rank - 1] = 4; + return res; + } + if (isHopper()) { + res[rank - 2] = 8; + res[rank - 1] = 4; + return res; + } + llvm::report_fatal_error( + "getThreadsPerWarp not implemented for unknown Mma version "); +} +SmallVector NvidiaMmaEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector NvidiaMmaEncodingAttr::getSizePerThread() const { + auto rank = ::getOrder(*this).size(); + SmallVector res(rank, 1); + if (isAmpere()) { + res[rank - 2] = 2; + res[rank - 1] = 2; + return res; + } + if (isVolta()) { + res[rank - 2] = 1; + res[rank - 1] = 2; + return res; + } + if (isHopper()) { + auto instrShape = getInstrShape(); + // WGMMA instructions have an order of [0, 1] with 4 warps, each with 8 + // unique thread ids (32 in a warp group) per column. It is 1 warp wide with + // 4 unique thread ids in the row. So the size per thread is the instruction + // size divided by the number of unique thread ids. + return SmallVector{instrShape[0] * 4 / 32, instrShape[1] / 4}; + } + llvm_unreachable("Unexpected mma version"); +} + +SmallVector +NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + if (isAmpere()) { + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = warpsPerCTA.size(); + SmallVector shapePerCTATile(warpsPerCTA.begin(), + warpsPerCTA.end()); + shapePerCTATile[rank - 1] *= 8; + shapePerCTATile[rank - 2] *= 16; + return shapePerCTATile; + } + if (isVolta()) { + assert(!tensorShape.empty() && "Volta needs the tensorShape"); + if (tensorShape.size() == 1) // must be SliceEncoding + return {static_cast(tensorShape[0]), + static_cast(tensorShape[0])}; + return {static_cast(tensorShape[0]), + static_cast(tensorShape[1])}; + } + if (isHopper()) { + auto instrShape = getInstrShape(); + return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]}; + } + llvm::report_fatal_error("Unexpected MMA layout version found"); +} + +// Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor +std::tuple +NvidiaMmaEncodingAttr::decodeVoltaLayoutStates() const { + unsigned versionMinor = getVersionMinor(); + bool isARow = versionMinor & (1 << 0); + bool isBRow = versionMinor & (1 << 1); + bool isAVec4 = versionMinor & (1 << 2); + bool isBVec4 = versionMinor & (1 << 3); + + int id = 0; + for (int i = numBitsToHoldMmaV1ID - 1; i >= 0; --i) + id = (id << 1) + static_cast(versionMinor & (1 << (4 + i))); + + return std::make_tuple(isARow, isBRow, isAVec4, isBVec4, id); +} + +bool NvidiaMmaEncodingAttr::getMMAv1IsRow(int opIdx) const { + auto [isARow, isBRow, _0, _1, _2] = decodeVoltaLayoutStates(); + return opIdx == 0 ? isARow : isBRow; +} +bool NvidiaMmaEncodingAttr::getMMAv1IsVec4(int opIdx) const { + auto [_0, _1, isAVec4, isBVec4, _2] = decodeVoltaLayoutStates(); + return opIdx == 0 ? isAVec4 : isBVec4; +} +int NvidiaMmaEncodingAttr::getMMAv1NumOuter(ArrayRef shape, + int opIdx) const { + auto spw = getMMAv1ShapePerWarp(opIdx); + auto rep = getMMAv1Rep(opIdx); + auto warpsPerCTA = getWarpsPerCTA(); + if (opIdx == 0) { + return rep[0] * shape[0] / (spw[0] * warpsPerCTA[0]); + } else { + return rep[1] * shape[1] / (spw[1] * warpsPerCTA[1]); + } +} +SmallVector NvidiaMmaEncodingAttr::getMMAv1Rep(int opIdx) const { + auto [isARow, isBRow, isAVec4, isBVec4, _] = decodeVoltaLayoutStates(); + // A + if (opIdx == 0) { + int packSize = (isARow || isAVec4) ? 1 : 2; + return {2 * packSize, 0, 1}; + } + // B + else { + int packSize = (isBRow && !isBVec4) ? 2 : 1; + return {0, 2 * packSize, 1}; + } +} +SmallVector NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const { + auto rep = getMMAv1Rep(opIdx); + if (opIdx == 0) { + return {8 * rep[0], 0, 1}; + } else { + return {0, 8 * rep[1], 1}; + } +} +int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const { + return 2 * getMMAv1Rep(opIdx)[opIdx]; +} +SmallVector NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef shape, + int bitwidth, + int opIdx) const { + auto rank = shape.size(); + auto warpsPerCTA = getWarpsPerCTA(); + SmallVector shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth}; + int numRepBatch = + rank == 3 + ? std::max(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])) + : 1; + assert(isAmpere()); + + if (opIdx == 0) + return {numRepBatch, + std::max(1, shape[rank - 2] / + (shapePerWarp[1] * warpsPerCTA[rank - 2])), + std::max(1, shape[rank - 1] / shapePerWarp[3])}; + else { + assert(opIdx == 1); + return {numRepBatch, + std::max(1, shape[rank - 2] / shapePerWarp[3]), + std::max(1, shape[rank - 1] / (shapePerWarp[2] * + warpsPerCTA[rank - 1]))}; + } +} +unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands( + ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { + auto shapePerCTA = getShapePerCTA(*this, shape); + int warpsPerCTAM = getWarpsPerCTA()[0]; + int warpsPerCTAN = getWarpsPerCTA()[1]; + // H100 + if (isHopper()) { + return getTotalElemsPerThread(shape, eltTy); + } + // A100 + if (isAmpere()) { + auto rep = getMMAv2Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth(), opIdx); + if (opIdx == 0) + return 4 * rep[0] * rep[1] * rep[2]; + if (opIdx == 1) + return 4 * rep[0] * rep[1] * std::max(rep[2] / 2, 1); + } + // V100 + if (isVolta()) { + bool isRow = getMMAv1IsRow(opIdx); + bool isVec4 = getMMAv1IsVec4(opIdx); + if (opIdx == 0) { + int packSizeM = (isRow || isVec4) ? 1 : 2; + int repM = 2 * packSizeM; + int spwM = 2 * 4 * repM; + int numM = getMMAv1NumOuter(shape, opIdx); + int NK = shape[1]; + int vec = 2 * repM; + // Here we mimic the logic in loadA, the result cannot be calculated + // directly. + llvm::DenseSet> visited; + auto ld = [&](int m, int k) { + visited.insert({m, k}); + if (vec > 4) { + if (isRow) + visited.insert({m, k + 4}); + else + visited.insert({m + 1, k}); + } + }; + for (unsigned k = 0; k < NK; k += 4) + for (unsigned m = 0; m < numM / 2; ++m) + if (!visited.count({m, k})) + ld(m, k); + return visited.size() * 2; + } + if (opIdx == 1) { + int packSizeN = (isRow && !isVec4) ? 2 : 1; + int repN = 2 * packSizeN; + int spwN = 2 * 4 * repN; + int numN = getMMAv1NumOuter(shape, opIdx); + int vec = 2 * repN; + + int NK = shape[0]; + // Here we mimic the logic in loadA, the result cannot be calculated + // directly. + llvm::DenseSet> visited; + int elemsPerLd = vec > 4 ? 4 : 2; + auto ld = [&](int n, int k) { + visited.insert({n, k}); + if (vec > 4) { + if (isRow) + visited.insert({n + 1, k}); + else + visited.insert({n, k + 4}); + } + }; + + for (unsigned k = 0; k < NK; k += 4) + for (unsigned n = 0; n < numN / 2; ++n) { + if (!visited.count({n, k})) + ld(n, k); + } + + return visited.size() * 2; + } + } + llvm_unreachable("unknown mma layout"); +} +SmallVector +NvidiaMmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, + int opIdx) const { + assert(isAmpere() && "mmaLayout version = 1 is not implemented yet"); + auto parentShapePerCTATile = getShapePerCTATile(shape); + auto rank = parentShapePerCTATile.size(); + if (opIdx == 0) { + if (rank == 2) + return {parentShapePerCTATile[rank - 2], 16}; + else + return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 16}; + } else if (opIdx == 1) { + if (rank == 2) + return {16, parentShapePerCTATile[rank - 1]}; + else + return {parentShapePerCTATile[0], 16, parentShapePerCTATile[rank - 1]}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + } +} +SmallVector +NvidiaMmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { + assert(isAmpere() && "mmaLayout version = 1 is not implemented yet"); + if (opIdx == 0) { + return {2, 4}; + } else if (opIdx == 1) { + return {4, 1}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + return {}; + } +} + +//===----------------------------------------------------------------------===// +// Iluvatar Mma encoding +//===----------------------------------------------------------------------===// + +using isVoltaFunc = bool (*)(const IluvatarMmaEncodingAttr *); +DEFINE_LOAD_FUNC(isVolta) +bool IluvatarMmaEncodingAttr::isVolta() const { + auto isVolta = load_isVolta_func("iluvatar", "isVolta"); + return isVolta(this); +} + +#define DEFINE_GET_SMALLVECTOR_FUNC(var_name) \ + using var_name##Func = \ + SmallVector (*)(const IluvatarMmaEncodingAttr *); \ + DEFINE_LOAD_FUNC(var_name) \ + SmallVector IluvatarMmaEncodingAttr::var_name() const { \ + static auto func = load_##var_name##_func("iluvatar", "" #var_name); \ + return func(this); \ + } + +DEFINE_GET_SMALLVECTOR_FUNC(getCTAsPerCGA) +DEFINE_GET_SMALLVECTOR_FUNC(getCTAOrder) +DEFINE_GET_SMALLVECTOR_FUNC(getCTASplitNum) +DEFINE_GET_SMALLVECTOR_FUNC(getWarpsPerCTA) +DEFINE_GET_SMALLVECTOR_FUNC(getWarpOrder) +DEFINE_GET_SMALLVECTOR_FUNC(getThreadsPerWarp) +DEFINE_GET_SMALLVECTOR_FUNC(getThreadOrder) +DEFINE_GET_SMALLVECTOR_FUNC(getSizePerThread) + +using getShapePerCTATileFunc = SmallVector (*)( + const IluvatarMmaEncodingAttr *, ArrayRef); +DEFINE_LOAD_FUNC(getShapePerCTATile) + +SmallVector IluvatarMmaEncodingAttr::getShapePerCTATile( + ArrayRef tensorShape) const { + auto func = load_getShapePerCTATile_func("iluvatar", "getShapePerCTATile"); + return func(this, tensorShape); +} + +using getTCUShapePerWarpFunc = + SmallVector (*)(const IluvatarMmaEncodingAttr *, int); +DEFINE_LOAD_FUNC(getTCUShapePerWarp) + +SmallVector +IluvatarMmaEncodingAttr::getTCUShapePerWarp(int bitwidth) const { + auto func = load_getTCUShapePerWarp_func("iluvatar", "getTCUShapePerWarp"); + return func(this, bitwidth); +} + +using getTCUShapePerCTAFunc = + SmallVector (*)(const IluvatarMmaEncodingAttr *, int); +DEFINE_LOAD_FUNC(getTCUShapePerCTA) + +SmallVector +IluvatarMmaEncodingAttr::getTCUShapePerCTA(int bitwidth) const { + auto func = load_getTCUShapePerCTA_func("iluvatar", "getTCUShapePerCTA"); + return func(this, bitwidth); +} + +using getTCURepFunc = SmallVector (*)(const IluvatarMmaEncodingAttr *, + ArrayRef, int, int); +DEFINE_LOAD_FUNC(getTCURep) + +SmallVector IluvatarMmaEncodingAttr::getTCURep(ArrayRef shape, + int bitwidth, + int opIdx) const { + auto func = load_getTCURep_func("iluvatar", "getTCURep"); + return func(this, shape, bitwidth, opIdx); +} + +using getTotalElemsPerThreadForOperandsFunc = unsigned (*)( + const IluvatarMmaEncodingAttr *, ArrayRef, Type, int, int); +DEFINE_LOAD_FUNC(getTotalElemsPerThreadForOperands) + +unsigned IluvatarMmaEncodingAttr::getTotalElemsPerThreadForOperands( + ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { + auto func = load_getTotalElemsPerThreadForOperands_func( + "iluvatar", "getTotalElemsPerThreadForOperands"); + return func(this, shape, eltTy, kWidth, opIdx); +} + +using getShapePerCTATileForDotOperandsFunc = SmallVector (*)( + const IluvatarMmaEncodingAttr *, ArrayRef, int); +DEFINE_LOAD_FUNC(getShapePerCTATileForDotOperands) + +SmallVector IluvatarMmaEncodingAttr::getShapePerCTATileForDotOperands( + ArrayRef shape, int opIdx) const { + auto func = load_getShapePerCTATileForDotOperands_func( + "iluvatar", "getShapePerCTATileForDotOperands"); + return func(this, shape, opIdx); +} + +using getSizePerThreadForOperandsFunc = + SmallVector (*)(const IluvatarMmaEncodingAttr *, unsigned); +DEFINE_LOAD_FUNC(getSizePerThreadForOperands) + +SmallVector +IluvatarMmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { + auto func = load_getSizePerThreadForOperands_func( + "iluvatar", "getSizePerThreadForOperands"); + return func(this, opIdx); +} + +//===----------------------------------------------------------------------===// +// DotOperand Encoding +//===----------------------------------------------------------------------===// +SmallVector DotOperandEncodingAttr::getThreadsPerWarp() const { + llvm::report_fatal_error( + "getThreadsPerWarp not implemented for DotOperandEncodingAttr"); +} +SmallVector DotOperandEncodingAttr::getSizePerThread() const { + auto parentLayout = getParent(); + assert(parentLayout && "DotOperandEncodingAttr must have a parent"); + if (auto parentMmaLayout = mlir::dyn_cast(parentLayout)) { + return parentMmaLayout.getSizePerThreadForOperands(getOpIdx()); + } else { + llvm::report_fatal_error( + "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not " + "supported yet"); + return {}; + } +} + +//===----------------------------------------------------------------------===// +// AsyncCopyGlobalToLocalOp +//===----------------------------------------------------------------------===// + +ParseResult AsyncCopyGlobalToLocalOp::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector allOperands; + Type srcType, dstType, strideType; + SMLoc allOperandLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(allOperands) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.parseCustomTypeWithFallback(srcType)) + return failure(); + + SmallVector operandTypes; + int hasStride = 0; + if (parser.parseOptionalComma().succeeded()) { + if (allOperands.size() >= 6) { + hasStride = 1; + if (parser.parseType(strideType)) + return failure(); + } + } + if (parser.parseArrow() || parser.parseCustomTypeWithFallback(dstType)) + return failure(); + result.addTypes(dstType); + operandTypes.push_back(srcType); // src + operandTypes.push_back(dstType); // result + + int hasMask = 0, hasOther = 0; + if (allOperands.size() >= 3 && allOperands.size() != 5) { + operandTypes.push_back( + triton::getI1SameShapeFromTensorOrTensorPtr(srcType)); // mask + hasMask = 1; + } + if (allOperands.size() == 4) { + operandTypes.push_back(triton::getPointeeType(srcType)); // other + hasOther = 1; + } + if (allOperands.size() >= 5) { + operandTypes.push_back(strideType); // inputStride + operandTypes.push_back(strideType); // placeHolder0 + operandTypes.push_back(strideType); // placeHolder1 + } + if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc, + result.operands)) + return failure(); + + // Deduce operand_segment_sizes from the number of the operands. + auto operandSegmentSizesAttrName = + AsyncCopyGlobalToLocalOp::getOperandSegmentSizesAttrName(result.name); + result.addAttribute( + operandSegmentSizesAttrName, + parser.getBuilder().getDenseI32ArrayAttr( + {1, 1, hasMask, hasOther, hasStride, hasStride, hasStride})); + return success(); +} + +void AsyncCopyGlobalToLocalOp::print(OpAsmPrinter &printer) { + printer << " "; + printer << getOperation()->getOperands(); + // "operand_segment_sizes" can be deduced, so we don't print it. + printer.printOptionalAttrDict(getOperation()->getAttrs(), + {getOperandSegmentSizesAttrName()}); + printer << " : "; + if (getInputStride()) { + printer.printStrippedAttrOrType( + llvm::ArrayRef{getSrc().getType(), getInputStride().getType()}); + } else + printer.printStrippedAttrOrType(getSrc().getType()); + printer << " -> "; + printer.printStrippedAttrOrType(getResult().getType()); +} + +//===----------------------------------------------------------------------===// +// ASM Interface (i.e.: alias) +//===----------------------------------------------------------------------===// + +class TritonGPUOpAsmInterface : public OpAsmDialectInterface { +public: + using OpAsmDialectInterface::OpAsmDialectInterface; + + AliasResult getAlias(Attribute attr, raw_ostream &os) const override { + if (auto mmaAttr = mlir::dyn_cast(attr)) { + os << "mma"; + return AliasResult::FinalAlias; + } else if (auto sharedAttr = mlir::dyn_cast(attr)) { + os << "shared"; + return AliasResult::FinalAlias; + } else if (auto blockedAttr = mlir::dyn_cast(attr)) { + os << "blocked"; + return AliasResult::FinalAlias; + } /* else if (auto sliceAttr = dyn_cast(attr)) { + os << "slice"; + return AliasResult::FinalAlias; + } */ + return OpAsmDialectInterface::getAlias(attr, os); + } +}; + +struct TritonGPUInferLayoutInterface + : public triton::DialectInferLayoutInterface { + using DialectInferLayoutInterface::DialectInferLayoutInterface; + + LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + bool noWarpReduce, + Attribute &resultEncoding) const override { + resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis, + operandEncoding, noWarpReduce); + return success(); + } + + // Infer the encoding of a tt.trans(x) given the encoding of x. + // + // Our goal is to choose an encoding so that the trans is a "nop". For + // example, in a blocked encoding, the same GPU threads hold the same + // elements, they're just "renamed" -- what was element [i,j] of the tensor is + // now element [j,i], but that element is held by the same GPU thread. + // + // For most properties of the encoding, we let + // outputEnc.prop = inputEnc.prop * trans.order, + // where `x * y` means we apply permutation y to x. + // + // This works because prop[i] tells you something about the i'th dimension of + // the tensor. (For example, sizePerThread[2] == 4 means that one GPU thread + // contains 4 elements along dim 2 of the tensor.) The transpose reorders the + // dimensions according to the perm trans.order, so we achieve our goal of + // having a "nop" transpose by reordering the values in the prop the same way. + // + // The big exception to this is the encoding's `order`. + // + // An encoding's order is a list of dimensions, from fastest moving (most + // minor) to slowest moving. Thus enc.order[i] does not tell you something + // about the i'th dimension of the tensor, and it would be disasterously + // incorrect to do enc.order * trans.order. + // + // But! If we invert enc.order, it *does* meet this criterion. For example, + // if enc.order = [2,0,1], inverse(enc.order) = [1,2,0]. If you stare at it, + // you'll see that inverse(enc.order)[i] == j means that dimension i is the + // j'th most minor. Therefore we can safely permute *this* by trans.order. + // + // Thus we have + // + // outputEnc.order = inverse(inverse(inputEnc.order) * trans.order) + // = inverse(trans.order) * inputEnc.order. + // + LogicalResult inferTransOpEncoding(Attribute operandEncoding, + ArrayRef order, // trans order + Attribute &resultEncoding) const override { + // Note: inferFooOpEncoding should not crash if given invalid inputs, which + // happens when someone creates invalid IR. If we return failure() on + // error, then MLIR will generate a helpful error message. + + auto invOrder = inversePermutation(order); + SmallVector invOrderUnsigned(invOrder.begin(), invOrder.end()); + + auto permuteCTALayout = + [&](const CTALayoutAttr &layout) -> FailureOr { + auto n = order.size(); + if (layout.getCTAsPerCGA().size() != n || + layout.getCTASplitNum().size() != n || + layout.getCTAOrder().size() != n) { + return failure(); + } + + return CTALayoutAttr::get( + getDialect()->getContext(), + applyPermutation(layout.getCTAsPerCGA(), order), + applyPermutation(layout.getCTASplitNum(), order), + applyPermutation(invOrderUnsigned, layout.getCTAOrder())); + }; + + if (auto enc = mlir::dyn_cast(operandEncoding)) { + if (enc.getOrder().size() != order.size()) { + return failure(); + } + FailureOr ctaLayout = permuteCTALayout(enc.getCTALayout()); + if (failed(ctaLayout)) { + return failure(); + } + resultEncoding = SharedEncodingAttr::get( + getDialect()->getContext(), enc.getVec(), enc.getPerPhase(), + enc.getMaxPhase(), applyPermutation(invOrderUnsigned, enc.getOrder()), + *ctaLayout, enc.getHasLeadingOffset(), enc.getUseTcu()); + return success(); + } + + if (auto enc = mlir::dyn_cast(operandEncoding)) { + auto n = order.size(); + if (enc.getSizePerThread().size() != n || + enc.getThreadsPerWarp().size() != n || + enc.getWarpsPerCTA().size() != n || enc.getOrder().size() != n) { + return failure(); + } + FailureOr ctaLayout = permuteCTALayout(enc.getCTALayout()); + if (failed(ctaLayout)) { + return failure(); + } + resultEncoding = BlockedEncodingAttr::get( + getDialect()->getContext(), + applyPermutation(enc.getSizePerThread(), order), + applyPermutation(enc.getThreadsPerWarp(), order), + applyPermutation(enc.getWarpsPerCTA(), order), + applyPermutation(invOrderUnsigned, enc.getOrder()), *ctaLayout, + enc.getLoadType(), enc.getSmeWarpsPerCTA()); + return success(); + } + + return failure(); // unhandled encoding + } + + LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional location) const override { + auto sliceEncoding = mlir::dyn_cast(operandEncoding); + if (!sliceEncoding) + return emitOptionalError( + location, "ExpandDimsOp operand encoding must be SliceEncodingAttr"); + if (sliceEncoding.getDim() != axis) + return emitOptionalError( + location, "Incompatible slice dimension for ExpandDimsOp operand"); + resultEncoding = sliceEncoding.getParent(); + return success(); + } + + LogicalResult + inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute retEncoding, + std::optional location) const override { + auto mmaRetEncoding = mlir::dyn_cast(retEncoding); + if (mmaRetEncoding && mmaRetEncoding.isHopper()) { + auto dotOpEnc = mlir::dyn_cast(operandEncoding); + if (!mlir::isa(operandEncoding) && + !(opIdx == 0 && dotOpEnc && dotOpEnc.getOpIdx() == 0 && + mlir::isa(dotOpEnc.getParent()))) { + return emitOptionalError( + location, "unexpected operand layout for NvidiaMmaEncodingAttr v3"); + } + } else if (auto dotOpEnc = + mlir::dyn_cast(operandEncoding)) { + if (opIdx != dotOpEnc.getOpIdx()) + return emitOptionalError(location, "Wrong opIdx"); + if (retEncoding != dotOpEnc.getParent()) + return emitOptionalError(location, "Incompatible parent encoding"); + } else + return emitOptionalError( + location, "Dot's a/b's encoding should be of DotOperandEncodingAttr"); + return success(); + } + + LogicalResult + verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, + Attribute operandEncodingB) const override { + auto aEncoding = + mlir::dyn_cast(operandEncodingA); + auto bEncoding = + mlir::dyn_cast(operandEncodingB); + if (!aEncoding && !bEncoding) + return mlir::success(); + auto mmaAEncoding = + mlir::dyn_cast_or_null(aEncoding.getParent()); + if (mmaAEncoding && mmaAEncoding.isHopper()) + return success(); + // Verify that the encodings are valid. + if (!aEncoding || !bEncoding) + return op->emitError("mismatching encoding between A and B operands"); + if (aEncoding.getKWidth() != bEncoding.getKWidth()) + return op->emitError("mismatching kWidth between A and B operands"); + return success(); + } + + // Given a src shape + encoding and a dst shape, our goal is to compute a dst + // encoding that makes the reshape a "nop". That is, if GPU thread [x,y,z] + // contains elements [a,b,c,d] before the reshape, it contains those same + // elements after the reshape, they're just "renamed". + // + // A dst encoding that satisfies this property does not exist for all inputs. + // Here are some positive and negative examples. + // + // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so + // dim 1 is the fastest-changing in the dst, but the src has the opposite + // order. + // - OK: 2x2x32 order=[1,0,2] -> 4x32. We choose dst order [0,1]. + // What's important is that the 2x2 dimensions appear in major-to-minor + // order. + // - NOT OK: 32x32 sizePerThread=[2,2] -> 1024. Thread 0 in the src + // contains elements [(0,0), (0,1), (1,0), and (1,1)]. We cannot express + // this with an encoding based on the dst shape. + // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will + // contain the same elements as before. + // + // Users of this function require that it is symmetrical: if + // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) => + // srcEnc. + LogicalResult + inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const override { + auto src = mlir::dyn_cast(srcEnc); + if (!src) { + return emitOptionalError( + loc, "Non-reordering reshape only supports BlockedEncoding"); + } + + // Nop reshape; we can always infer an encoding. + if (srcShape == dstShape) { + dstEnc = srcEnc; + return success(); + } + + // default -> default encoding is always a nop. + auto context = srcEnc.getContext(); + int32_t numWarps = product(src.getWarpsPerCTA()); + int32_t threadsPerWarp = product(src.getThreadsPerWarp()); + int32_t numCTAs = product(src.getCTALayout().getCTAsPerCGA()); + if (srcEnc == getDefaultBlockedEncoding(context, srcShape, numWarps, + threadsPerWarp, numCTAs)) { + dstEnc = getDefaultBlockedEncoding(context, dstShape, numWarps, + threadsPerWarp, numCTAs); + return success(); + } + + // Feature flag to disable this routine while it's relatively new. + // TODO(jlebar): Remove this once we're confident in the code. + if (triton::tools::getBoolEnv( + "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE")) { + return failure(); + } + + // Cowardly refuse to handle encodings with multiple CTAs. CTAsPerCGA + // should be like the other fields in blocked encoding, but I'm not sure how + // to handle CTASplitNum. + if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) || + !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) { + return emitOptionalError( + loc, "Non-reordering reshape does not currently support multi-CTA " + "layouts other than the default layout."); + } + + // Cowardly refuse to handle encodings where shape[dim] is not divisible by + // sizePerThread[dim], threadsPerWarp[dim], and warpsPerCTA[dim]. (We make + // an exception if the block is larger than the shape.) + auto checkDivisibility = [&](StringRef name, ArrayRef subblock) { + for (int dim = 0; dim < srcShape.size(); dim++) { + if (srcShape[dim] >= subblock[dim] && + srcShape[dim] % subblock[dim] != 0) { + return emitOptionalError(loc, + "Can't do a non-reordering reshape because " + "the size of dimension ", + dim, " (", srcShape[dim], ")", + " is not divisible by ", name, "[", dim, "]", + " = ", subblock[dim]); + } + } + return success(); + }; + if (!succeeded( + checkDivisibility("sizePerThread", src.getSizePerThread())) || + !succeeded( + checkDivisibility("threadsPerWarp", src.getThreadsPerWarp())) || + !succeeded(checkDivisibility("warpsPerCTA", src.getWarpsPerCTA()))) { + return failure(); + } + + SmallVector, SmallVector>> decomp = + getReshapeDecomposition(srcShape, dstShape); + + // enc.order[i] == j means that dimension j is the enc.order[i]'th most + // minor. But what we usually want is the inverse: inverse(enc.order)[i] = j + // means that dimension i is the j'th most minor (larger means more major). + auto srcInvOrder = inversePermutation(src.getOrder()); + + // If src dims [a,b,c] are to be merged, then they must be consecutive in + // physical order, with `a` being the most major. + for (const auto &[srcDims, dstDims] : decomp) { + if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) { + return emitOptionalError(loc, + "Cannot do a non-reordering reshape given " + "this src encoding order. Dimensions [", + join(srcDims), + "] must be physically consecutive."); + } + } + + // If src dims [a,b,c] are to be merged, then `c` must fill up sizePerThread + // / threadsPerWarp / blocksPerCTA before `b` can have any non-1 values. + // Examples: + // + // - NOT OK: shape=[4,4,4], sizePerThread=[1,2,2]. + // The total sizePerThread for dim 2 is 2, which is less than dim 2's + // size of 4. Therefore dim 1 cannot have non-1 sizePerThread. + // + // - OK: shape=[4,4,4], sizePerThread=[1,2,4]. + // Dim 2's sizePerThread covers its whole size, so dim 1 is allowed to + // have non-1 sizePerThread. + // + // - NOT OK: shape=[4,4,4], sizePerThread=[2,1,4]. + // Dim 1's sizePerThread does not cover its whole size, so dim 0 is not + // allowed to have non-1 sizePerThread. + // + // - NOT OK: shape=[4,4,4], sizePerThread=[1,1,2], + // threadsPerWarp=[1,2,1]. + // Dim 2 has 2 elems per thread and 1 thread per warp. 2*1 is less than + // dim 2's size. Therefore dim 1 must have threadsPerWarp=1. + // + // In addition, the encoding's block can be larger than the shape, but only + // in the most-major dimension of each decomposed chunk, and only after + // we've "used up" the more minor dims. Examples: + // + // - OK: shape=[4,4,4], sizePerThread=[1,2,4], threadsPerWarp=[16,2,1], + // warpsPerCTA=[4,1,1]. + // The whole size of dims 0 and 1 are covered by sizePerThread * + // threadsPerWarp. Therefore dim 2 is allowed to have threadsPerWarp and + // warpsPerCTA larger than its size. + for (const auto &[srcDims, dstDims] : decomp) { + auto shapeRemaining = gather(srcShape, srcDims); + auto checkSubblock = [&, srcDims = srcDims](ArrayRef subblock) { + // Iterate minor-to-major (i==0 is most major). + for (int i = srcDims.size() - 1; i >= 0; i--) { + int dim = srcDims[i]; + if (subblock[dim] == 1) { + continue; + } + + // Check that more-minor dims all have 1 in shapeRemaining. + for (int j = i + 1; j < srcDims.size(); j++) { + if (shapeRemaining[j] != 1) { + return emitOptionalError( + loc, + "Invalid src encoding for non-reordering reshape. Must use " + "up sizePerThread / threadsPerWarp / warpsPerCTA for " + "more-minor dimensions before more major-dims can use them."); + } + } + + if (shapeRemaining[i] >= subblock[dim]) { + assert(shapeRemaining[i] % subblock[dim] == 0); // checked earlier + shapeRemaining[i] /= subblock[dim]; + } else { + shapeRemaining[i] = 0; + } + + // Is the block larger than the shape in this dimension? This is OK + // only if we're the most-major dimension of the chunk and in all + // future chunks, only this most-major dim has a non-1 size. + if (shapeRemaining[i] == 0 && i != 0) { + return emitOptionalError( + loc, + "Invalid src encoding for non-reordering reshape. Block " + "size in dimension ", + dim, + " is larger than the shape that dimension, but this is only " + "allowed for the most-major dimension of a reshape chunk"); + } + } + return success(); + }; + if (!succeeded(checkSubblock(src.getSizePerThread())) || + !succeeded(checkSubblock(src.getThreadsPerWarp())) || + !succeeded(checkSubblock(src.getWarpsPerCTA()))) { + return failure(); + } + } + + // Given e.g. src.getSizePerThread(), computeSubblockSize computes e.g. + // dst.getSizePerThread(). This should be called for each of sizePerThread, + // threadsPerWarp, and warpsPerCTA, in that order. + SmallVector dstShapeRemaining(dstShape); + auto computeSubblockSize = [&](ArrayRef srcSubblock, + SmallVector &dstSubblock, + StringRef fieldName) -> LogicalResult { + // The dst subblock is "filled up" greedily starting with the most minor + // dim. When we're done, we are left with a smaller shape, of size + // dstShape / dstSubblock, which we store in dstShapeRemaining and use for + // the next call to computeSubblockSize. + dstSubblock.resize(dstShape.size()); + for (const auto &[srcDims, dstDims] : decomp) { + int64_t subblockRemaining = product(gather(srcSubblock, srcDims)); + for (int i = dstDims.size() - 1; i >= 0; i--) { + auto &val = dstSubblock[dstDims[i]]; + auto &shapeRemaining = dstShapeRemaining[dstDims[i]]; + val = std::min(subblockRemaining, shapeRemaining); + + assert(shapeRemaining % val == 0); // Checked earlier. + subblockRemaining /= val; + shapeRemaining /= val; + } + + // If there are any elems remaining in the subblock, it must be because + // the block is larger than the shape. This excess goes into the + // most-major dim of the subblock. + dstSubblock[dstDims[0]] *= subblockRemaining; + } + return success(); + }; + + SmallVector dstSizePerThread; + SmallVector dstThreadsPerWarp; + SmallVector dstWarpsPerCTA; + if (!succeeded(computeSubblockSize(src.getSizePerThread(), dstSizePerThread, + "sizePerThread")) || + !succeeded(computeSubblockSize(src.getThreadsPerWarp(), + dstThreadsPerWarp, "threadsPerWarp")) || + !succeeded(computeSubblockSize(src.getWarpsPerCTA(), dstWarpsPerCTA, + "warpsPerCTA"))) { + return failure(); + } + + // Since we know that each set of srcDims is consecutive, we can + // meaningfully sort decomp by the physical order of the src dimensions, + // major-to-minor. This will also be the order of the dst dimensions. + llvm::sort(decomp, [&](const auto &a, const auto &b) { + const auto &[srcDimsA, dstDimsA] = a; + const auto &[srcDimsB, dstDimsB] = b; + return srcInvOrder[srcDimsA.front()] < srcInvOrder[srcDimsB.front()]; + }); + + // Compute the dst order. Make the dimensions appear in the same order as + // their corresponding src dimensions. + SmallVector dstInvOrder(dstShape.size()); + int i = 0; + for (const auto &[srcDims, dstDims] : decomp) { + for (auto dim : reverse(dstDims)) { + dstInvOrder[dim] = i++; + } + } + auto dstOrder = inversePermutation(dstInvOrder); + + // CTALayout can be all 1's because we bailed on multi-CTA layouts above. + auto CTALayout = CTALayoutAttr::get( + src.getContext(), + /*CTAsPerCGA=*/SmallVector(dstShape.size(), 1), + /*CTASplitNum=*/SmallVector(dstShape.size(), 1), + /*CTAOrder=*/llvm::to_vector(llvm::seq(dstShape.size()))); + + unsigned loadType = src.getLoadType(); + ArrayRef smeWarpsPerCTA = src.getSmeWarpsPerCTA(); + + dstEnc = BlockedEncodingAttr::get( + src.getContext(), dstSizePerThread, dstThreadsPerWarp, dstWarpsPerCTA, + dstOrder, CTALayout, loadType, smeWarpsPerCTA); + + return success(); + } + + LogicalResult + inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const override { + auto enc = mlir::dyn_cast(srcEnc); + if (!enc) { + return emitOptionalError(loc, + "JoinOp can only operate on BlockedEncoding"); + } + + // JoinOp takes two tensors of shape AxBxC and generates a tensor of shape + // AxBxCx2. The encoding is the same as the input, but with 2 elems per + // thread in the new dimension. The new dimension is most-minor. + auto append = [](ArrayRef vals, int val) { + SmallVector ret(vals); + ret.push_back(val); + return ret; + }; + auto appendMinorDim = [](ArrayRef order) { + SmallVector ret(order); + ret.insert(ret.begin(), ret.size()); + return ret; + }; + dstEnc = BlockedEncodingAttr::get( + enc.getContext(), // + append(enc.getSizePerThread(), 2), // + append(enc.getThreadsPerWarp(), 1), // + append(enc.getWarpsPerCTA(), 1), // + appendMinorDim(enc.getOrder()), // + CTALayoutAttr::get(enc.getContext(), // + append(enc.getCTAsPerCGA(), 1), + append(enc.getCTASplitNum(), 1), + appendMinorDim(enc.getCTAOrder())), + enc.getLoadType(), // + enc.getSmeWarpsPerCTA()); + return success(); + } + + LogicalResult + inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const override { + auto enc = mlir::dyn_cast(srcEnc); + if (!enc) { + return emitOptionalError(loc, + "SplitOp can only operate on BlockedEncoding"); + } + + // SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of + // shape AxBxC. The input must have 2 elements per thread in the last + // dimension, which must be most-minor. The result encoding is the same as + // the input, but with the last dimension removed. + if (enc.getSizePerThread().back() != 2) { + return emitOptionalError(loc, + "SplitOp requires 2 elements per thread in the " + "last dimension of the input"); + } + if (enc.getThreadsPerWarp().back() != 1 || + enc.getWarpsPerCTA().back() != 1 || enc.getCTAsPerCGA().back() != 1) { + return emitOptionalError( + loc, "SplitOp requires threadsPerWarp, warpsPerCTA, " + "and CTAsPerCGA = 1 for the last dimension of the input"); + } + if (enc.getOrder().front() != enc.getOrder().size() - 1) { + return emitOptionalError( + loc, "SplitOp requires the last dimension to be most-minor in order"); + } + if (enc.getCTALayout().getCTAsPerCGA().back() != 1) { + return emitOptionalError( + loc, + "SplitOp requires the last dimension to be most-minor in CTAOrder"); + } + + dstEnc = BlockedEncodingAttr::get( + enc.getContext(), // + ArrayRef(enc.getSizePerThread()).drop_back(1), + ArrayRef(enc.getThreadsPerWarp()).drop_back(1), + ArrayRef(enc.getWarpsPerCTA()).drop_back(1), + ArrayRef(enc.getOrder()).drop_front(1), + CTALayoutAttr::get(enc.getContext(), // + ArrayRef(enc.getCTAsPerCGA()).drop_back(1), + ArrayRef(enc.getCTASplitNum()).drop_back(1), + ArrayRef(enc.getCTAOrder()).drop_front(1)), + enc.getLoadType(), enc.getSmeWarpsPerCTA()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Canonicalizer +//===----------------------------------------------------------------------===// + +// reshape(cvt) -> reshape +struct CanonicalizeConvertFromReshape + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::ReshapeOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + if (isExpensiveView(convert.getSrc().getType(), op.getType())) + return failure(); + if (!op.getAllowReorder() || op.getEfficientLayout().has_value()) + return failure(); + + rewriter.replaceOpWithNewOp( + op, op.getType(), convert.getSrc(), op.getAllowReorder()); + return mlir::success(); + } +}; + +// histogram(cvt) -> histogram +struct CanonicalizeConvertFromHistogram + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::HistogramOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), convert.getSrc()); + return mlir::success(); + } +}; + +// alloc(cvt) -> alloc +struct CanonicalizeConvertFromAlloc + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::gpu::LocalAllocOp op, + PatternRewriter &rewriter) const override { + if (!op.getSrc()) + return failure(); + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), convert.getSrc()); + return mlir::success(); + } +}; + +// local_store(cvt) -> local_store +struct CanonicalizeConvertFromLocalStore + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::gpu::LocalStoreOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + rewriter.replaceOpWithNewOp(op, convert.getSrc(), + op.getDst()); + return mlir::success(); + } +}; + +struct CanonicalizeConvertFromConvert + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(ConvertLayoutOp op, + PatternRewriter &rewriter) const override { + // Convert to the same layout is redundant. + if (op->getResultTypes() == op->getOperandTypes()) { + rewriter.replaceOp(op, op->getOperands()); + return success(); + } + + // We don't handle conversions to DotOperandEncodingAttr. This is a + // heuristic to accommodate fused attention. + auto srcType = op.getSrc().getType(); + auto dstType = op.getType(); + if (mlir::isa(dstType.getEncoding()) && + (mlir::isa(srcType.getEncoding()) || + mlir::isa(srcType.getEncoding()))) + return failure(); + + // for hopper MMAv3 + if (mlir::isa(dstType.getEncoding()) && + mlir::isa(srcType.getEncoding()) && + llvm::any_of(op.getResult().getUsers(), + [](Operation *dot) { return isa(dot); })) { + return failure(); + } + + Operation *arg = op.getSrc().getDefiningOp(); + if (!arg) + return failure(); + + // cvt(reshape) -> reshape + if (auto reshape = dyn_cast(arg)) { + if (!reshape.getAllowReorder() || + reshape.getEfficientLayout().has_value() || + isExpensiveView(reshape.getSrc().getType(), op.getType())) + return failure(); + + // In TritonGPUToLLVM phase, ViewOp is converted to unpacking and packing + // operations, which requires the element type to match between unpacking + // and packing. However, part of values with dot operand encoding will be + // packed/unpacked as i32 elements instead of the underlying element type. + // To avoid errors, skip this folding when either the operand or result + // of view has a dot operand encoding. + if (hasDotOperandEncoding(op->getOperand(0)) || + hasDotOperandEncoding(op->getResult(0))) + return failure(); + + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + reshape.getResult(), + reshape.getAllowReorder()); + return success(); + } + + // cvt(histogram) -> histogram + if (auto histogram = dyn_cast(arg)) { + // For histogram ops the input and output layouts are independent, so we + // can always fold convert into the histogram op. + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + histogram.getSrc()); + return success(); + } + + // cvt(local_load) -> local_load. + if (auto sharedLoad = dyn_cast(arg)) { + // Shared_load can load to any layout so we can always fold convert into + // it. + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + sharedLoad.getSrc()); + return success(); + } + + // cvt(cat) -> cat + if (auto cat = dyn_cast(arg)) { + if (isExpensiveCat(cat, op.getType().getEncoding())) + return failure(); + + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + cat.getOperands()); + return success(); + } + + // cvt(cvt(x, type1), type2) -> cvt(x, type2) + if (auto cvt = dyn_cast(arg)) { + auto srcType = op.getSrc().getType(); +#ifdef __ILUVATAR__ + if (triton::gpu::isMmaConvertLayout(cvt)) + return failure(); + if (triton::gpu::isSliceMmaConvertLayout(arg, true, false)) + return mlir::failure(); +#endif + rewriter.replaceOpWithNewOp( + op, op->getResultTypes().front(), cvt.getSrc()); + return success(); + } + + // cvt(type1, splat(type2, x)) -> splat(type1, x) + if (auto splat = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + splat.getSrc()); + return success(); + } + + // cvt(type1, make_range(type2, x)) -> make_range(type1, x) + if (auto range = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), range.getStart(), range.getEnd()); + return success(); + } + + // cvt(type, constant) -> constant + if (auto cst = llvm::dyn_cast(arg)) + if (auto ret = dyn_cast(cst.getValue())) { + auto ty = cast(op->getResultTypes().front()); + auto newRet = + SplatElementsAttr::get(ty, ret.getSplatValue()); + rewriter.replaceOpWithNewOp(op, newRet); + return success(); + } + return failure(); + } +}; + +void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); +} + +// LocalAllocOp +void LocalAllocOp::getEffects( + SmallVectorImpl> + &effects) { + Operation *op = getOperation(); + // If allocation is immutable, mark it as no side effect allow things like + // CSE, DCE to work in early compiler passes. + // After the memory offset is computed, we attach the true side effect to the + // op. + if (!getType().getMutableMemory() && !op->hasAttr("allocation.offset")) + return; + effects.emplace_back(MemoryEffects::Allocate::get(), + mlir::triton::gpu::SharedMemory::get()); + if (getSrc()) + effects.emplace_back(MemoryEffects::Write::get(), getResult(), + mlir::triton::gpu::SharedMemory::get()); +} + +// LocalLoadOp +void LocalLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), getSrc(), + mlir::triton::gpu::SharedMemory::get()); +} + +// LocalStoreOp +void LocalStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), getDst(), + mlir::triton::gpu::SharedMemory::get()); +} + +// AsyncCopyGlobalToLocalOp +void AsyncCopyGlobalToLocalOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), getSrc(), + mlir::triton::GlobalMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), getResult(), + mlir::triton::gpu::SharedMemory::get()); +} + +LogicalResult MemDescSubviewOp::verify() { + auto srcTy = getSrc().getType(); + auto dstTy = getType(); + + if (srcTy.getElementType() != dstTy.getElementType()) { + return emitError("result element type must match desc element type"); + } + if (getOffsets().size() != srcTy.getRank()) { + return emitError("offsets must have the same rank as input"); + } + if (srcTy.getRank() < dstTy.getRank()) { + return emitError("result rank must be less than or equal to input rank"); + } + auto rankDiff = srcTy.getRank() - dstTy.getRank(); + for (int i = 0; i < dstTy.getRank(); i++) { + if (dstTy.getDimSize(i) > srcTy.getDimSize(i + rankDiff)) { + return emitError( + "result shape cannot be larger than input shape at dimension ") + << i; + } + } + + auto srcEnc = srcTy.getEncoding(); + auto dstEnc = dstTy.getEncoding(); + if (!!srcEnc != !!dstEnc) { + return emitError("src and result must both have or not have an encoding"); + } + + if (!isa(srcEnc)) { + return emitError("src encoding must be SharedEncodingAttr"); + } + if (!isa(dstEnc)) { + return emitError("result encoding must be SharedEncodingAttr"); + } + + // TODO(jlebar): Currently we generate illegal encodings, so we can't add a + // verifier for them. In particular, we use the same encoding for the src and + // dst of a subview op, when the subview removes a dimension. That generates + // an illegal shared encoding (because the size of `order` doesn't match the + // rank of the tensor), but it's not checked anywhere, and we believe the + // resulting code ultimately works. + + return success(); +} + +void TritonGPUDialect::initialize() { + registerTypes(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc" + >(); + addInterfaces(); + addInterfaces(); +} + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" + +// verify TritonGPU ops +LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // TODO: fill this. + return success(); +} diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp new file mode 100644 index 000000000..ae34598ae --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -0,0 +1,489 @@ +#include + +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir::triton::gpu { +namespace { + +// We use the following nomenclature in this file. +// +// - ctaLayout: A layout for one block, i.e. input dims (register, lane, warp). +// - cgaLayout: Arrangement of multiple blocks, i.e. input dims (block). +// +// Note that this is inconsistent with the type name CTALayoutAttr. That type +// is equivalent to our cgaLayout. +// +// IMO the type name is wrong. If we tried to be consistent anyway, then we'd +// have to rename ctaLayout to "warpLayout". I think that's more confusing than +// being inconsistent about "cgaLayout", especially when we have to consider the +// size of the warpLayout (surely that's not the "warpSize"). + +#define S(v) StringAttr::get(ctx, (v)) + +// Returns ["out0", "out1", ..., "out"]. +SmallVector standardOutDimNames(MLIRContext *ctx, int rank) { + SmallVector ret; + for (int i = 0; i < rank; i++) { + ret.push_back(S("dim" + llvm::Twine(i))); + } + return ret; +} + +// Returns a 1D -> ND layout that's equivalent to creating a 1D -> 1D mapping of +// size product(shape) and then reshaping to permute(shape, order). +LinearLayout identityND(StringAttr inDimName, ArrayRef shape, + ArrayRef order, + ArrayRef outDimNames) { + assert(shape.size() == order.size()); + + MLIRContext *ctx = inDimName.getContext(); + LinearLayout ret = LinearLayout::empty(); + for (int i = 0; i < shape.size(); i++) { + // Start with the most-minor dimension, which is order[0]. + int dim = order[i]; + ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]); + } + return ret; +} + +// Make a LinearLayout that maps a block-id to an N-dimensional index. +// +// The tensor is split up into CTAsPerCGA pieces, which are distributed among +// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. groups). +// +// See the nomenclature note at the top of the file for an explanation of why +// this is called makeCgaLayout when it accepts a CTALayoutAttr. +LinearLayout makeCgaLayout(CTALayoutAttr layout) { + MLIRContext *ctx = layout.getContext(); + StringAttr kBlock = S("block"); + + int rank = layout.getCTAOrder().size(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + LinearLayout ret = LinearLayout::empty(); + for (int i = 0; i < rank; i++) { + // Start with the most minor dimension, which is order[0]. + int dim = layout.getCTAOrder()[i]; + int split = layout.getCTASplitNum()[dim]; + int ctas = layout.getCTAsPerCGA()[dim]; + assert(ctas % split == 0); + ret *= LinearLayout::identity1D(split, kBlock, outDimNames[dim]) * + LinearLayout::zeros1D(ctas / split, kBlock, outDimNames[dim]); + } + + // Transpose to standard order (dim0, dim1, ...). + return ret.transposeOuts(outDimNames); +} + +// Shrinks the output set of a layout function while leaving the input set +// unchanged, by making high-order inputs in inDimName map to the same output. +// Attempts to shrink down to desiredSize, but this is not always possible just +// by modifying one the specified input dimension. +// +// We do this by making the most-major inputs to the layout map to 0. This +// effectively duplicates data along that input dimension. For example, this +// layout has out-dim size 32: +// +// L(register=1) = 8 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 16. +// +// If we shrink it to size 16 along the `lane` dimension, we set L(lane=2) to 0: +// +// L(register=1) = 8 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0. +// +// This means that lane=2 has the same data as lane=0. +// +// If we shrink to size 8 along the lane dimension, we set L(lane=1) = 0 as +// well. But when we do this, we have to remove bit 1 (the value of L(lane=1)) +// from all other bases: +// +// L(register=1) = 4 +// L(register=2) = 2 +// L(register=1) = 1 +// L(lane=1) = 0 +// L(lane=2) = 0. +// +// Note this only works because the bases are powers of two. I don't quite know +// what to do when they're not. +LinearLayout shrinkCodomain(const LinearLayout &layout, StringAttr inDimName, + StringAttr outDimName, int desiredSize) { + assert(llvm::isPowerOf2_32(desiredSize)); + int outDimIdx = layout.getOutDimIndex(outDimName); + int desiredZeros = + llvm::Log2_32(layout.getOutDimSize(outDimName) / desiredSize); + if (desiredZeros == 0) { + return layout; + } + + // Find the desiredZeros most-major basis vectors that are not already zero. + // These are the ones we will set to zero. + SmallVector basesToZero; + for (int i = layout.getInDimSizeLog2(inDimName) - 1; + i >= 0 && basesToZero.size() < desiredZeros; i--) { + int basis = layout.getBasis(inDimName, i, outDimName); + if (basis != 0) { + basesToZero.push_back(basis); + } + } + + // Bail if all the bases are already zero; nothing more we can do. + if (basesToZero.empty()) { + return layout; + } + + // The algorithm below only works because the bases are powers of two. I'm + // not sure what to do otherwise. + assert(llvm::all_of(basesToZero, + [&](int basis) { return llvm::isPowerOf2_32(basis); })); + + // We want to zero out the bases in `basesToZero`, and also "shift out" the + // corresponding bits from all other bases. For example if we remove the + // basis with value 8 = 0b100, then if another basis has value 26 = 0b11010, + // the 1 in its 3rd position gets removed and it becomes 10 = 0b1010. + // + // We could manually alter the bases in `layout` to achieve this, but it's + // perhaps simpler to use the linearity of LLs to our advantage. + // + // Consider the function O which is the identity map from out-dims to + // out-dims. We can easily calculate what happens when we remove the relevant + // bases from O. Call this new function O'. + // + // Because of linearity, removing the bases from L is equivalent to composing + // L with O'. So that's what we do below. + + // Construct the out-dims -> out-dims identity layout O. + LinearLayout outputIdentity = LinearLayout::empty(); + for (StringAttr dim : layout.getOutDimNames()) { + outputIdentity *= + LinearLayout::identity1D(layout.getOutDimSize(dim), dim, dim); + } + + // Modify O to remove the relevant bases. + // + // TODO(jlebar): I don't like manually modifying bases here. Perhaps this + // should be a function on LinearLayout. + LinearLayout::BasesT newBases = outputIdentity.getBases(); + llvm::sort(basesToZero); + for (int basis : basesToZero) { + int idx = llvm::Log2_32(basis); + for (int i = newBases[outDimName].size() - 1; i > idx; i--) { + newBases[outDimName][i][outDimIdx] = + newBases[outDimName][i - 1][outDimIdx]; + } + newBases[outDimName][idx][outDimIdx] = 0; + } + + // Construct O'. + LinearLayout transform(std::move(newBases), layout.getOutDimNames()); + + // Compose O' with L. + return layout.compose(transform); +} + +// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no +// larger than shape[d]. Do this without changing the size of the layout's +// inputs (i.e. leave its domain unchanged). +// +// This function is invariant to the order of the layout's input and output +// dimensions. +LinearLayout ensureLayoutNotLargerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + MLIRContext *ctx = shape.begin()->first.getContext(); + + // For the purposes of this function, "block" is the "most-minor" dimension. + // This is just a consequence of how legacy layouts work: We only put the same + // tensor element into two different blocks as a last resort, only after all + // the registers in all the lanes in all the warps in a block already have the + // same tensor element. + SmallVector inDimNames = { + S("block"), + S("register"), + S("lane"), + S("warp"), + }; + + LinearLayout ret = layout; + for (auto outDimName : layout.getOutDimNames()) { + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + if (actualSize <= desiredSize) { + continue; + } + assert(actualSize % desiredSize == 0); + // TODO: We claim this is invariant to the order of dims, so can we get rid + // of llvm::reverse? + for (StringAttr inDimName : llvm::reverse(inDimNames)) { + if (ret.hasInDim(inDimName)) { + ret = shrinkCodomain(ret, inDimName, outDimName, desiredSize); + } + } + assert(ret.getOutDimSize(outDimName) == desiredSize); + } + return ret; +} + +// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no +// smaller than shape[d]. Do this by increasing the size of the layout's inputs +// along the "register" dimension. +// +// This function is invariant to the order of the layout's input dimensions, but +// it cares about the order of the output dims, which should be minor-to-major. +LinearLayout ensureLayoutNotSmallerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + + MLIRContext *ctx = shape.begin()->first.getContext(); + StringAttr kRegister = S("register"); + + LinearLayout ret = layout; + for (StringAttr outDimName : layout.getOutDimNames()) { + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + assert(actualSize > desiredSize || desiredSize % actualSize == 0); + ret *= LinearLayout::identity1D(desiredSize / actualSize, kRegister, + outDimName); + assert(ret.getOutDimSize(outDimName) >= desiredSize); + } + return ret; +} + +// Combines the layout of a CTA (input dims [register, lane, warp]) with the +// layout of a CGA (i.e. a block), and ensures that the resulting layout has the +// given shape. +// +// See the nomenclature note at the top of the file for why the variable with +// type CTALayoutAttr is called cgaLayoutAttr. +LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, + CTALayoutAttr cgaLayoutAttr, + ArrayRef shape) { + int rank = shape.size(); + assert(ctaLayout.getNumOutDims() == rank); + assert(cgaLayoutAttr.getCTAOrder().size() == rank); + MLIRContext *ctx = cgaLayoutAttr.getContext(); + + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + llvm::SmallDenseMap labeledShape; + for (auto [dim, size] : llvm::zip(outDimNames, shape)) { + labeledShape[dim] = size; + } + + LinearLayout cgaLayout = + ensureLayoutNotLargerThan(makeCgaLayout(cgaLayoutAttr), labeledShape) + .transposeOuts(ctaLayout.getOutDimNames()); + + // Calculate the shape of the ctaLayout, which is `shape` divided by the + // cgaLayout's size. + llvm::SmallDenseMap ctaShape; + assert(ctaLayout.getOutDimNames() == cgaLayout.getOutDimNames()); + for (auto dim : ctaLayout.getOutDimNames()) { + ctaShape[dim] = + std::max(int64_t{1}, labeledShape[dim] / cgaLayout.getOutDimSize(dim)); + } + + ctaLayout = ensureLayoutNotSmallerThan(ctaLayout, ctaShape); + ctaLayout = ensureLayoutNotLargerThan(ctaLayout, ctaShape); + + LinearLayout ret = (ctaLayout * cgaLayout).transposeOuts(outDimNames); + for (auto dim : ret.getOutDimNames()) { + assert(ret.getOutDimSize(dim) == labeledShape[dim]); + } + return ret; +} + +LinearLayout blockedToLinearLayout(ArrayRef shape, + BlockedEncodingAttr blocked) { + assert(shape.size() == blocked.getOrder().size()); + + int rank = shape.size(); + MLIRContext *ctx = blocked.getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + const auto &order = blocked.getOrder(); + LinearLayout ctaLayout = + identityND(S("register"), blocked.getSizePerThread(), order, + outDimNames) * + identityND(S("lane"), blocked.getThreadsPerWarp(), order, outDimNames) * + identityND(S("warp"), blocked.getWarpsPerCTA(), order, outDimNames); + + return combineCtaCgaWithShape(ctaLayout, blocked.getCTALayout(), shape); +} + +LinearLayout ampereMmaToLinearLayout(ArrayRef shape, + NvidiaMmaEncodingAttr mma) { + int rank = shape.size(); + + assert(mma.isAmpere()); + assert(rank == 2 || rank == 3); + assert(mma.getInstrShape().size() == rank); + assert((rank == 2 && mma.getInstrShape() == ArrayRef({16, 8})) || + (rank == 3 && mma.getInstrShape() == ArrayRef({1, 16, 8}))); + + MLIRContext *ctx = mma.getContext(); + SmallVector dimNames = standardOutDimNames(ctx, rank); + + LinearLayout ctaLayout( + {{S("register"), {{1, 0}, {0, 8}}}, + {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}}, + llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); + + ctaLayout *= identityND( + S("warp"), mma.getWarpsPerCTA(), + llvm::to_vector(llvm::reverse(llvm::seq(rank))), dimNames); + + return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); +} + +LinearLayout hopperMmaToLinearLayout(ArrayRef shape, + NvidiaMmaEncodingAttr mma) { + int rank = shape.size(); + assert(mma.isHopper()); + assert(rank == 2); + + // wgmma operates on groups of 4 warps. + assert(product(mma.getWarpsPerCTA()) % 4 == 0); + + // Check that it's a known MMA layout. + assert(mma.getInstrShape().size() == 3); + int m = mma.getInstrShape()[0]; + int n = mma.getInstrShape()[1]; + int k = mma.getInstrShape()[2]; + assert(m == 16); + assert(n == 16 || n == 32 || n == 64 || n == 128 || n == 256); + assert(k == 8 || k == 16 || k == 32); + + MLIRContext *ctx = mma.getContext(); + LinearLayout ctaLayout( + {{S("register"), {{1, 0}, {0, 8}}}, + {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}}, + {S("dim1"), S("dim0")}); + + // Expand the `register` dimension so the size of dim1 matches `n`. + ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")), + S("register"), S("dim1")); + + // Expand the `warp` dimension according to warpsPerCTA. + // + // It's weird that this is order [0,1] when MMAv2's warpsPerCTA is [1,0], but + // this really does seem to be correct. + ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1}, + {S("dim0"), S("dim1")}) + .transposeOuts(ctaLayout.getOutDimNames()); + + return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); +} + +std::optional toLinearLayout(ArrayRef shape, + SliceEncodingAttr slice) { + MLIRContext *ctx = slice.getContext(); + + // First compute the linear layout for this layout's parent. + SmallVector parentShape(shape); + parentShape.insert(parentShape.begin() + slice.getDim(), 1); + std::optional parentLL = + triton::gpu::toLinearLayout(parentShape, slice.getParent()); + if (!parentLL) { + return std::nullopt; + } + + // Remove dimension slice.getDim() from the parent layout. + // + // 1. Construct a layout `transform` from parent-out-dims to slice-out-dims + // that removes the relevant out-dim. + // 2. Compute linearSlice = parent.compose(transform). Now linearSlice maps + // from parent in-dims to slice out-dims. + // 3. Fix up duplicate registers introduced by slicing. + auto outDimNames = standardOutDimNames(ctx, shape.size() + 1); + LinearLayout transform = LinearLayout::empty(); + for (auto [idx, outDim] : llvm::enumerate(parentLL->getOutDimNames())) { + if (idx == slice.getDim()) { + // Because we're multiplying by all zeros, we could replace outDimNames[0] + // with any other valid out-dim; the layout will be the same. + transform *= LinearLayout::zeros1D(parentLL->getOutDimSize(outDim), + outDim, outDimNames[0]); + } else { + transform *= LinearLayout::identity1D( + parentLL->getOutDimSize(outDim), outDim, + outDimNames[idx - (idx < slice.getDim() ? 0 : 1)]); + } + } + LinearLayout sliceLL = parentLL->compose(transform); + + // Step 3: Along the "register" dim, remove any all-zero bases. + auto bases = sliceLL.getBases(); + std::vector> newRegBases; + for (const auto &basis : bases[S("register")]) { + if (llvm::any_of(basis, [](int b) { return b != 0; })) { + newRegBases.push_back(basis); + } + } + bases[S("register")] = newRegBases; + + LinearLayout ret = LinearLayout(std::move(bases), sliceLL.getOutDimNames()); + + // Match a hack in the legacy code that ensures that the number of registers + // matches getTotalElemsPerThread. Yup: We just removed all the zeros, now + // we're (maybe) adding some back. :) + // + // TODO(jlebar): Once getTotalElemsPerThread uses LLs instead of the existing + // legacy code, I think we can remove this. + int expectedNumRegisters = getTotalElemsPerThread(RankedTensorType::get( + shape, IntegerType::get(ctx, 32) /*dummy type*/, slice)); + if (ret.getInDimSize(S("register")) != expectedNumRegisters) { + int extraZeros = expectedNumRegisters / ret.getInDimSize(S("register")); + // Our use of "dim0" here is arbitrary; because we're adding zeros, any + // output dimension would work. + ret *= LinearLayout::zeros1D(extraZeros, S("register"), S("dim0")); + } + return ret; +} + +} // anonymous namespace + +std::optional toLinearLayout(ArrayRef shape, + Attribute layout) { + if (auto blocked = dyn_cast(layout)) { + return blockedToLinearLayout(shape, blocked); + } + if (auto mma = dyn_cast(layout)) { + if (mma.isAmpere()) { + return ampereMmaToLinearLayout(shape, mma); + } + if (mma.isHopper()) { + return hopperMmaToLinearLayout(shape, mma); + } + } + if (auto slice = dyn_cast(layout)) { + return toLinearLayout(shape, slice); + } + + // TODO(jlebar): Other layouts + return std::nullopt; +} + +} // namespace mlir::triton::gpu diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/IR/Types.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/IR/Types.cpp new file mode 100644 index 000000000..77f673cc2 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/IR/Types.cpp @@ -0,0 +1,38 @@ +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton::gpu; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/Types.cpp.inc" + +Type TokenType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + int type = 1; + if (parser.parseInteger(type)) + return Type(); + + if (parser.parseGreater()) + return Type(); + + return TokenType::get(parser.getContext(), type); +} + +void TokenType::print(AsmPrinter &printer) const { + printer << "<" << getType() << ">"; +} + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void ::mlir::triton::gpu::TritonGPUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/TritonGPU/IR/Types.cpp.inc" + >(); +} diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp new file mode 100644 index 000000000..df84c4e62 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -0,0 +1,405 @@ +#include + +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace triton { +namespace gpu { + +namespace { + +// Get the highest version supported for the hardware and the dot. +static int getMMAVersionSafe(int computeCapability, DotOp op) { + int baseVersion = 0; + if (computeCapability < 75) { + baseVersion = 1; + } else if (computeCapability < 90) { + baseVersion = 2; + } else if (computeCapability < 100) { + baseVersion = 3; + } else { + assert(false && "computeCapability not supported"); + } + + for (; baseVersion >= 1; baseVersion--) { + if (supportMMA(op, baseVersion)) { + return baseVersion; + } + } + + return 0; +} + +SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, + int numWarps) { + auto rank = shape.size(); + // Early exit for batched matmul + if (rank == 3) + return {(unsigned)numWarps, 1, 1}; + + auto filter = [&dotOp](Operation *op) { + return op->getParentRegion() == dotOp->getParentRegion() && + !isa(op); + }; + auto slices = multiRootGetSlice(dotOp, {filter}, {filter}); + bool hasChainedDot = false; + for (Operation *op : slices) { + if (isa(op) && (op != dotOp)) { + auto chainedDot = cast(op); + auto resTy = chainedDot.getResult().getType(); + if (resTy.getRank() != rank) { + continue; + } + if (auto mmaEncoding = + dyn_cast(resTy.getEncoding())) { + return getWarpsPerCTA(mmaEncoding); + } + hasChainedDot = true; + } + } + if (hasChainedDot) { + if (shape[0] >= shape[1]) { + return {(unsigned)numWarps, 1}; + } else { + return {1, (unsigned)numWarps}; + } + } + + SmallVector ret(rank, 1); + SmallVector shapePerWarp(rank, 1); + shapePerWarp[rank - 1] = 8; + shapePerWarp[rank - 2] = 16; + // TODO (@daadaada): double-check. + // original logic in + // https://github.com/triton-lang/triton/blob/master/lib/codegen/analysis/layout.cc#L252 + // seems buggy for shape = [32, 16] ? + do { + if (ret[0] * ret[1] >= numWarps) + break; + if (shape[0] / shapePerWarp[0] / ret[0] >= + shape[1] / (shapePerWarp[1] * 2) / ret[1]) { + if (ret[0] < shape[0] / shapePerWarp[0]) { + ret[0] *= 2; + } else + ret[1] *= 2; + } else { + ret[1] *= 2; + } + } while (true); + return ret; +} + +SmallVector +warpsPerTileV3(DotOp dotOp, const ArrayRef shape, int numWarps, + const SmallVector &instrShape) { + SetVector slices; + mlir::getForwardSlice(dotOp.getResult(), &slices); + if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != + slices.end()) + return {(unsigned)numWarps, 1}; + + // For MMAv3, the smallest indivisible unit of warp shape is (4, 1). + SmallVector ret = {4, 1}; + SmallVector shapePerWarp = {16, instrShape[1]}; + do { + if (ret[0] * ret[1] >= numWarps) + break; + if (shape[0] > shapePerWarp[0] * ret[0]) { + ret[0] *= 2; + } else { + ret[1] *= 2; + } + } while (true); + return ret; +} + +class BlockedToMMA : public mlir::RewritePattern { + int computeCapability; + mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding + mutable llvm::DenseMap dotOpInstNs; + + static bool bwdFilter(Operation *op) { + return op->getNumOperands() == 1 && + (isa(op) || + isPureUnaryInlineAsm(op) || + op->getDialect()->getTypeID() == + mlir::TypeID::get()); + } + + // Finds the first different bitwidth in the chain of shape-preserving + // unary ops that x depends on. + // There are two primary scenarios: + // (1) Upcasting: A sequence such as loading an fp16, followed by arithmetic + // operations, then bitcasting to fp32, and finally computing in fp32. + // (2) Downcasting: This might involve loading an fp32, performing arithmetic + // operations, bitcasting to fp16, and finally computing in fp16. + // In the upcasting scenario, element reordering converts the original + // elements distribution to the order of higher precision primitives. As a + // result, kwidth can be the bitwidth of the lower precision primitive. + // Conversely, in the downcasting scenario, no reordering is performed, + // making it directory use the lower precision primitive. + static int computeOrigBitWidth(Value x) { + int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth(); + int origBitWidth = finalBitWidth; + SetVector slice; + mlir::BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = bwdFilter; + getBackwardSlice(x, &slice, opt); + for (auto op : slice) { + if (Value arg = op->getOperand(0)) + if (auto argTy = dyn_cast(arg.getType())) { + auto argBitWidth = argTy.getElementType().getIntOrFloatBitWidth(); + if (argBitWidth != origBitWidth) { + origBitWidth = std::min(origBitWidth, argBitWidth); + break; + } + } + } + return origBitWidth; + } + +public: + BlockedToMMA(mlir::MLIRContext *context, int computeCapability) + : mlir::RewritePattern(DotOp::getOperationName(), 2, context), + computeCapability(computeCapability) {} + + static SmallVector + getWarpsPerTile(DotOp dotOp, const ArrayRef shape, int version, + int numWarps, const SmallVector &instrShape) { + switch (version) { + case 2: + return warpsPerTileV2(dotOp, shape, numWarps); + case 3: + return warpsPerTileV3(dotOp, shape, numWarps, instrShape); + default: + assert(false && "not supported version"); + return {0, 0}; + } + } + + static Value getMMAv3Operand(Value v, mlir::PatternRewriter &rewriter, + int opIdx) { + OpBuilder::InsertionGuard g(rewriter); + Value arg = v; + if (auto cvtOp = v.getDefiningOp()) + arg = cvtOp.getSrc(); + auto argType = cast(arg.getType()); + auto eltType = argType.getElementType(); + assert(argType.getEncoding() && "unexpected tensor type"); + auto newOrder = getOrder(argType.getEncoding()); + + // MMAv3 with transpose only supports f16 and bf16 data type + // fallback to MMAv3 without transpose for other data types + if (!eltType.isF16() && !eltType.isBF16()) { + if (opIdx == 1) { + newOrder = {0, 1}; + } else { + newOrder = {1, 0}; + } + } + + auto CTALayout = getCTALayout(argType.getEncoding()); + auto newLayout = + SharedEncodingAttr::get(argType.getContext(), argType.getShape(), + newOrder, CTALayout, argType.getElementType()); + auto newType = MemDescType::get(argType.getShape(), + argType.getElementType(), newLayout); + rewriter.setInsertionPointAfterValue(arg); + return rewriter.create(arg.getLoc(), newType, arg); + } + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + if (computeCapability < 70) + return failure(); + auto dotOp = cast(op); + auto ctx = op->getContext(); + // TODO: Check data-types and SM compatibility + RankedTensorType oldRetType = dotOp.getType(); + if (!oldRetType.getEncoding() || + mlir::isa(oldRetType.getEncoding())) + return failure(); + + // get MMA encoding for the given number of warps + auto retShapePerCTA = getShapePerCTA(oldRetType); + auto mod = op->getParentOfType(); + int numWarps = TritonGPUDialect::getNumWarps(mod); + auto CTALayout = getCTALayout(oldRetType.getEncoding()); + + int versionMajor = getMMAVersionSafe(computeCapability, dotOp); + if (!versionMajor) + return failure(); + + auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA, + dotOp.getA().getType(), numWarps); + // operands + Value a = dotOp.getA(); + Value b = dotOp.getB(); + auto oldAType = dotOp.getA().getType(); + auto oldBType = dotOp.getB().getType(); + + NvidiaMmaEncodingAttr mmaEnc; + if (versionMajor == 1) { + SetVector aBwdSlices, bBwdSlices; + auto isCvt = [](Operation *op) { return isa(op); }; + mlir::BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = isCvt; + getBackwardSlice(a, &aBwdSlices, opt); + getBackwardSlice(b, &bBwdSlices, opt); + // get the source of the first conversion found in slices + auto getCvtArgOrder = [](Operation *op) { + return mlir::cast( + cast(op).getSrc().getType().getEncoding()) + .getOrder(); + }; + bool isARow = true; + bool isBRow = true; + Operation *aOp = a.getDefiningOp(); + Operation *bOp = b.getDefiningOp(); + if (!aBwdSlices.empty()) + aOp = aBwdSlices[0]; + if (!bBwdSlices.empty()) + bOp = bBwdSlices[0]; + if (aOp) + isARow = getCvtArgOrder(aOp)[0] == 1; + if (bOp) + isBRow = getCvtArgOrder(bOp)[0] == 1; + + mmaEnc = NvidiaMmaEncodingAttr::get( + oldRetType.getContext(), versionMajor, numWarps, CTALayout, + instrShape, oldAType.getShape(), oldBType.getShape(), retShapePerCTA, + isARow, isBRow, mmaV1Counter++); + } else if (versionMajor == 2 || versionMajor == 3) { + int versionMinor = computeCapability == 75 ? 1 : 0; + auto warpsPerTile = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, + numWarps, instrShape); + mmaEnc = NvidiaMmaEncodingAttr::get(oldRetType.getContext(), versionMajor, + versionMinor, warpsPerTile, CTALayout, + instrShape); + } + auto newRetType = RankedTensorType::get( + oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); + // convert accumulator + auto oldAcc = dotOp.getOperand(2); + auto newAcc = + rewriter.create(oldAcc.getLoc(), newRetType, oldAcc); + + if (versionMajor == 3) { + a = getMMAv3Operand(a, rewriter, 0); + b = getMMAv3Operand(b, rewriter, 1); + } else { + + // convert operands + int minBitwidth = + std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)); + Type minType = IntegerType::get(ctx, minBitwidth); + // convert A operand + auto newAEncoding = DotOperandEncodingAttr::get( + oldAType.getContext(), 0, newRetType.getEncoding(), + minBitwidth > 0 ? minType : oldAType.getElementType()); + auto newAType = RankedTensorType::get( + oldAType.getShape(), oldAType.getElementType(), newAEncoding); + a = rewriter.create(a.getLoc(), newAType, a); + // convert B operand + auto newBEncoding = DotOperandEncodingAttr::get( + oldBType.getContext(), 1, newRetType.getEncoding(), + minBitwidth > 0 ? minType : oldBType.getElementType()); + auto newBType = RankedTensorType::get( + oldBType.getShape(), oldBType.getElementType(), newBEncoding); + b = rewriter.create(b.getLoc(), newBType, b); + } + // convert dot instruction + auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, + newAcc, dotOp.getInputPrecision(), + dotOp.getMaxNumImpreciseAcc()); + + rewriter.replaceOpWithNewOp(op, oldRetType, + newDot.getResult()); + return success(); + } +}; +} // namespace + +static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, + Type promotedType) { + Type tensorPromotedType = cast(operand.getType()) + .cloneWith(std::nullopt, promotedType); + return builder.create(loc, tensorPromotedType, operand); +} + +// promote operands of dot op if the existing combination is not natively +// supported. +static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) { + mod.walk([=](DotOp dotOp) -> void { + auto D = dotOp.getD(); + OpBuilder builder(dotOp); + Type AElType = dotOp.getA().getType().getElementType(); + Type promoteType; + NvidiaMmaEncodingAttr mmaLayout = + dyn_cast(D.getType().getEncoding()); + if (mmaLayout) { + bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ(); + // promote operands for sm < 89 since fp8 mma is not natively supported + // promote operands for sm >= 90 when mma is not v3 + if (!isNativeFP8 || + (isNativeFP8 && (computeCapability == 89 || mmaLayout.isHopper()))) + return; + promoteType = builder.getF16Type(); + } else { + // FMA case. + Type AElType = dotOp.getA().getType().getElementType(); + Type DElType = D.getType().getElementType(); + if (AElType == DElType) + return; + promoteType = DElType; + } + Location loc = dotOp.getLoc(); + Value promotedA = promoteOperand(builder, loc, dotOp.getA(), promoteType); + Value promotedB = promoteOperand(builder, loc, dotOp.getB(), promoteType); + dotOp.setOperand(0, promotedA); + dotOp.setOperand(1, promotedB); + }); +} + +#define GEN_PASS_DEF_TRITONGPUACCELERATEMATMUL +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUAccelerateMatmulPass + : public impl::TritonGPUAccelerateMatmulBase< + TritonGPUAccelerateMatmulPass> { +public: + using impl::TritonGPUAccelerateMatmulBase< + TritonGPUAccelerateMatmulPass>::TritonGPUAccelerateMatmulBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + auto computeCapability = getNVIDIAComputeCapability(m); + + mlir::RewritePatternSet patterns(context); + patterns.add(context, computeCapability); + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { + signalPassFailure(); + } + // Now that we have picked the mma type, decompose dot that are not natively + // supported. + decomposeMixedModeDotOp(m, computeCapability); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..11f6ed22b --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -0,0 +1,31 @@ +add_triton_library(TritonGPUTransforms + AccelerateMatmul.cpp + Coalesce.cpp + F32DotTC.cpp + CombineTensorSelectAndIf.cpp + ReduceDataDuplication.cpp + OptimizeDotOperands.cpp + OptimizeThreadLocality.cpp + Pipeliner/MatmulLoopPipeline.cpp + Pipeliner/OuterLoopPipeline.cpp + Pipeliner/PipelineExpander.cpp + Pipeliner/SoftwarePipeliner.cpp + #Pipeliner/TMAStoresPipeline.cpp + Pipeliner/PipeliningUtility.cpp + Prefetch.cpp + RemoveLayoutConversions.cpp + ReorderInstructions.cpp + Utility.cpp + + DEPENDS + TritonGPUTransformsIncGen + + LINK_LIBS PUBLIC + MLIRTransforms + MLIRTransformUtils + TritonAnalysis + TritonIR + TritonGPUIR + #TritonNvidiaGPUIR + MLIRTransformUtils +) diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp new file mode 100644 index 000000000..06a7d963d --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -0,0 +1,198 @@ +#include +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tritongpu-coalesce" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUCOALESCE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +struct CoalescePass : public impl::TritonGPUCoalesceBase { + void + setCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op, + int numWarps, int threadsPerWarp, + llvm::MapVector &layoutMap) { + Value ptr = getMemAccessPtr(op); + auto refTensorType = cast(ptr.getType()); + + LDBG("Considering op: " << *op); + LLVM_DEBUG({ + DBGS() << "axis info of pointer: "; + axisInfoAnalysis.getAxisInfo(ptr)->print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity(); + SmallVector order = argSort(contiguity); + LDBG("order=[" << triton::join(order, ", ") << "]"); + + auto matchesShape = [&refTensorType](const Value &val) { + auto rttType = dyn_cast(val.getType()); + return rttType && rttType.getShape() == refTensorType.getShape(); + }; + + // The desired divisibility is the maximum divisibility among all dependent + // pointers which have the same shape and order as `ptr`. + llvm::SmallSetVector memAccessesSameOrder; + memAccessesSameOrder.insert(op); + if (ptr.getDefiningOp()) { + for (Operation *use : mlir::multiRootGetSlice(op)) { + Value val = getMemAccessPtr(use); + if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use)) + continue; + auto currOrder = + argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity()); + if (order == currOrder) { + LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use); + memAccessesSameOrder.insert(use); + } + } + } + + auto shapePerCTA = triton::gpu::getShapePerCTA(refTensorType); + LDBG("shapePerCTA=[" << triton::join(shapePerCTA, ", ") << "]"); + + int numElems = product(shapePerCTA); + int numThreads = numWarps * threadsPerWarp; + + unsigned perThread = getNumElementsPerThread(op, order, axisInfoAnalysis); + LDBG("perThread for op: " << perThread); + + for (Operation *opSameOrder : memAccessesSameOrder) { + if (opSameOrder == op) + continue; + unsigned currPerThread = + getNumElementsPerThread(opSameOrder, order, axisInfoAnalysis); + LDBG("perThread for opSameOrder: " << currPerThread); + perThread = std::max(perThread, currPerThread); + } + + perThread = std::min(perThread, std::max(numElems / numThreads, 1)); + LDBG("perThread: " << perThread); + + if (!dyn_cast(op)) { + // For ops that can result in a global memory write, we should enforce + // that each thread handles at most 128 bits, which is the widest + // available vectorized store op; otherwise, the store will have "gaps" + // in the memory write at the warp level, resulting in worse performance. + // For loads, we can expect that the gaps won't matter due to the L1 + // cache. + unsigned elemNumBits = getElementBitWidth(refTensorType); + perThread = std::min( + perThread, getNumElementsPerThread(op, order, axisInfoAnalysis)); + } + SmallVector sizePerThread(refTensorType.getRank(), 1); + sizePerThread[order[0]] = perThread; + + auto CTALayout = triton::gpu::getCTALayout(refTensorType.getEncoding()); + layoutMap[op] = triton::gpu::BlockedEncodingAttr::get( + &getContext(), refTensorType.getShape(), sizePerThread, order, numWarps, + threadsPerWarp, CTALayout); + } + + static Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = cast(type); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + } + + void coalesceOp(Attribute encoding, Operation *op) { + OpBuilder builder(op); + // Convert operands + // For load/store with tensor pointers, we don't have to change the + // operands' type, we do this by changing the outputs' type of + // `make_tensor_ptr` + SmallVector newArgs; + for (auto operand : op->getOperands()) { + auto tensorType = dyn_cast(operand.getType()); + if (tensorType && + !isa(tensorType.getEncoding())) { + Type newType = getNewType(tensorType, encoding); + newArgs.push_back(builder.create( + op->getLoc(), newType, operand)); + } else { + newArgs.push_back(operand); + } + } + + // Convert output types + SmallVector newTypes; + for (auto t : op->getResultTypes()) { + bool isAsync = isa(op); + newTypes.push_back(isAsync ? t : getNewType(t, encoding)); + } + + // Construct new op with the new encoding + Operation *newOp = + builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs, + newTypes, op->getAttrs()); + + // Cast the results back to the original layout + for (size_t i = 0; i < op->getNumResults(); i++) { + Value newResult = newOp->getResult(i); + if (newTypes[i] != op->getResultTypes()[i]) { + newResult = builder.create( + op->getLoc(), op->getResult(i).getType(), newResult); + } + op->getResult(i).replaceAllUsesWith(newResult); + } + op->erase(); + } + + void runOnOperation() override { + // Run axis info analysis + ModuleOp moduleOp = getOperation(); + ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // For each i/o operation, we determine what layout + // the pointers should have for best memory coalescing + llvm::MapVector layoutMap; + moduleOp.walk([&](Operation *curr) { + Value ptr = getMemAccessPtr(curr); + if (!ptr) + return; + // We only convert `tensor>` load/store + bool isPtrTensor = false; + if (auto tensorType = dyn_cast(ptr.getType())) + isPtrTensor = isa(tensorType.getElementType()); + if (!isPtrTensor) + return; + auto mod = curr->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp, + layoutMap); + }); + + // For each memory op that has a layout L1: + // 1. Create a coalesced memory layout L2 of the pointer operands + // 2. Convert all operands from layout L1 to layout L2 + // 3. Create a new memory op that consumes these operands and + // produces a tensor with layout L2 + // 4. Convert the output of this new memory op back to L1 + // 5. Replace all the uses of the original memory op by the new one + for (auto &kv : layoutMap) { + coalesceOp(kv.second, kv.first); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp new file mode 100644 index 000000000..16183b1af --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp @@ -0,0 +1,124 @@ +#include "mlir/IR/Dominance.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUCOMBINETENSORSELECTANDIF +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// Return true if the select could be merged into the If without breaking SSA +// rules. +static bool canMergeIntoIf(arith::SelectOp selectOp, scf::IfOp ifOp, + DominanceInfo &dom) { + // If needs to be dominated by the select. + if (!dom.dominates(selectOp.getOperation(), ifOp.getOperation())) { + return false; + } + // If needs to dominate all the select's users. + for (auto user : selectOp.getResult().getUsers()) { + if (!dom.dominates(ifOp, user)) { + return false; + } + } + return true; +} + +class CombineTensorSelectAndIfPass + : public impl::TritonGPUCombineTensorSelectAndIfBase< + CombineTensorSelectAndIfPass> { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + DominanceInfo dom(m); + + // Go over the arith.select ops, look if there is an if + // with the same condition. + llvm::MapVector> selectToIf; + m.walk([&](arith::SelectOp selectOp) { + // Look if there is an if in the same block, with the same condition. + auto *parentBlock = selectOp->getBlock(); + Value condition = selectOp.getOperand(0); + SetVector conditionUsers(condition.getUsers().begin(), + condition.getUsers().end()); + // sort the users in topological order. + conditionUsers = multiRootTopologicalSort(conditionUsers); + // Get condition's users + for (Operation *user : conditionUsers) { + auto ifOp = dyn_cast(user); + if (!ifOp || ifOp->getBlock() != parentBlock) + continue; + if (canMergeIntoIf(selectOp, ifOp, dom)) { + selectToIf[ifOp].push_back(selectOp); + break; + } + } + }); + + for (auto [ifOp, selectOps] : selectToIf) { + // Add new return value to the if (and create else block if necessary), + // then yield the select value in the then block and the else block. + OpBuilder builder(ifOp); + auto loc = ifOp.getLoc(); + // Create an scf::IfOp with extra return value. + SmallVector newResultTypes = {ifOp.getResultTypes().begin(), + ifOp.getResultTypes().end()}; + for (arith::SelectOp selectOp : selectOps) { + newResultTypes.push_back(selectOp.getResult().getType()); + } + auto newIfOp = builder.create( + loc, newResultTypes, ifOp.getCondition(), /*hasElse*/ true); + // Move the existing blocks to the new if. + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + + if (ifOp.elseBlock()) { + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + } else { + // Create an empty yield + auto yieldOp = newIfOp.getElseBodyBuilder().create(loc); + } + + SmallVector ifYieldOperands = newIfOp.thenYield().getOperands(); + SmallVector elseYieldOperands = newIfOp.elseYield().getOperands(); + for (arith::SelectOp selectOp : selectOps) { + Value thenValue = selectOp.getTrueValue(); + Value elseValue = selectOp.getFalseValue(); + ifYieldOperands.push_back(thenValue); + elseYieldOperands.push_back(elseValue); + } + // Update yields + auto updateYield = [&](scf::YieldOp yield, SmallVector &operands) { + builder.setInsertionPoint(yield); + builder.create(loc, operands); + yield.erase(); + }; + updateYield(newIfOp.thenYield(), ifYieldOperands); + updateYield(newIfOp.elseYield(), elseYieldOperands); + + int resultIdx = 0; + // Replace old if with the new one. + for (auto result : ifOp.getResults()) { + result.replaceAllUsesWith(newIfOp->getResult(resultIdx++)); + } + // Replace the select with the new return value. + for (arith::SelectOp selectOp : selectOps) { + selectOp.replaceAllUsesWith(newIfOp->getResult(resultIdx++)); + selectOp.erase(); + } + + ifOp.erase(); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp new file mode 100644 index 000000000..f701634d4 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp @@ -0,0 +1,90 @@ +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUF32DOTTC +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +// nb. We call the trick TF32x3 as C++ disallows varaibles starting with numbers +// Implement 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385 +// For a, b f32 +// dot(a, b, inputPrecision="tf32x3") -> +// let aBig = f32ToTF32(a), aSmall = a - aBig; +// let bBig = f32ToTF32(b), bSmall = b - bBig; +// dot(aSmall, bBig, inputPrecision="tf32") + +// dot(aBig, bSmall, inputPrecision="tf32") + +// dot(aBig, bBig, inputPrecision="tf32") +class TF32x3 : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DotOp dotOp, + PatternRewriter &rewriter) const override { + + auto isF32 = [](Value operand) { + return cast(operand.getType()).getElementType().isF32(); + }; + + if (!(dotOp.getInputPrecision() == InputPrecision::TF32x3 && + isF32(dotOp.getA()) && isF32(dotOp.getB()))) { + return failure(); + } + + // Aux functions + auto f32ToTF32 = [&](Value value) -> Value { + return rewriter + .create(dotOp.getLoc(), value.getType(), + "cvt.rna.tf32.f32 $0, $1;", "=r,r", + /*isPure=*/true, /*pack=*/1, + ArrayRef{value}) + .getResult()[0]; + }; + auto sub = [&](Value a, Value b) -> Value { + return rewriter.create(dotOp.getLoc(), a, b); + }; + auto dot = [&](Value a, Value b, Value c) -> Value { + return rewriter.create(dotOp->getLoc(), c.getType(), a, b, c, + InputPrecision::TF32, + dotOp.getMaxNumImpreciseAcc()); + }; + + auto aBig = f32ToTF32(dotOp.getA()); + auto aSmall = sub(dotOp.getA(), aBig); + + auto bBig = f32ToTF32(dotOp.getB()); + auto bSmall = sub(dotOp.getB(), bBig); + + auto dot1 = dot(aSmall, bBig, dotOp.getC()); + auto dot2 = dot(aBig, bSmall, dot1); + auto dot3 = dot(aBig, bBig, dot2); + + rewriter.replaceOp(dotOp, dot3); + return success(); + } +}; + +} // anonymous namespace + +struct F32DotTCPass : public impl::TritonGPUF32DotTCBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + RewritePatternSet decomposePatterns(context); + decomposePatterns.add(context); + if (applyPatternsAndFoldGreedily(m, std::move(decomposePatterns)) + .failed()) { + signalPassFailure(); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp new file mode 100644 index 000000000..0c72a528b --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -0,0 +1,344 @@ +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include + +namespace mlir { +namespace triton { +namespace gpu { + +namespace { + +// Given +// convert(trans(src)) #dot_operand -> +// convert(local_load(trans(alloc(src)))) +// change the encoding of the inner convert to a special, swizzled shared +// encoding. +class SwizzleShmemConvert : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp, + PatternRewriter &rewriter) const override { + // Match outerCvt(trans(innerCvt(x))). + auto trans = cvtOp.getSrc().getDefiningOp(); + if (!trans || trans.getOrder() != ArrayRef{1, 0}) + return failure(); + + auto srcTy = dyn_cast(trans.getSrc().getType()); + + if (auto srcCvt = trans.getSrc().getDefiningOp()) { + srcTy = srcCvt.getSrc().getType(); + } + auto sharedLoadTy = cast(cvtOp.getType()); + auto cvtEncoding = + dyn_cast(sharedLoadTy.getEncoding()); + if (!cvtEncoding) + return failure(); + + // TODO(Qingyi): need to check whether the CTALayout of innerCvtEnc should + // be used here. For tests where numCTAs = 1, this is not a problem since + // all CTALayouts are the same. + // + // Set needTrans to true here. newInnerCvtEnc is computed based on + // argEncoding which is before the transpose. Without needTrans we will + // compute vec and maxPhase based on incorrect m, n and k size of mma. The + // type inference of TransOp simply swap the order but doesn't fix the vec + // and maxPhase for the YType, hence it would causing incorrect swizzling + // code. + auto newInnerCvtEnc = + SharedEncodingAttr::get(getContext(), cvtEncoding, srcTy.getShape(), + /*order=*/getOrder(srcTy.getEncoding()), + triton::gpu::getCTALayout(srcTy.getEncoding()), + srcTy.getElementType(), /*needTrans=*/true); + if (newInnerCvtEnc == cvtEncoding) + return failure(); + + rewriter.setInsertionPoint(trans); + auto alloc = rewriter.create( + trans.getLoc(), + MemDescType::get(srcTy.getShape(), srcTy.getElementType(), + newInnerCvtEnc), + trans.getSrc()); + auto newTrans = rewriter.create(trans.getLoc(), alloc, + ArrayRef({1, 0})); + rewriter.replaceOpWithNewOp(trans, sharedLoadTy, newTrans); + return success(); + } +}; + +// Move convert-to-dot-operand "up" past elementwise ops: +// +// convert(elementwise(x)) #dot_operand -> +// elementwise(convert(x, #dot_operand)). +// +// The goal is to put the convert right next to the originating load. If we can +// accomplish this, then we can save a shmem round-trip: +// +// Before: +// +// - Load from global into shmem using an async copy. +// - Load from shmem into a #blocked layout. +// - Do elementwise ops over #blocked layout. +// - Convert to #dot_operand (round-trip through shmem). +// - Do dot. +// +// After: +// +// - Load from global into shmem using an async copy (same as before). +// - Load from shmem into a #dot_operand layout. +// - Do elementwise ops over #dot_operand layout. +// - Do dot. +// +// Eliminating the shmem round-trip is such a big win, we're willing to do it +// even if this duplicates work because some of the elementwise ops have uses +// that don't flow into the dot. On the other hand, we only want to do this if +// we can in fact reduce shmem round-trips: For example, simply moving a convert +// up above e.g. an `add` now means we have *two* converts. That's worse, +// unless we can continue moving the converts upwards and eventually merge them. +// So we try to check that this will be beneficial before making any changes. +class HoistLayoutConversion : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConvertLayoutOp cvt, + PatternRewriter &rewriter) const override { + // Only consider conversions to dot operand. + auto cvtTy = cast(cvt.getType()); + if (!isa(cvtTy.getEncoding())) + return failure(); + + auto src = cvt.getSrc().getDefiningOp(); + if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1) + return failure(); + + auto srcTy = dyn_cast(src->getResult(0).getType()); + if (!srcTy) + return failure(); + + if (!all_of(src->getOperandTypes(), + [](Type ty) { return isa(ty); })) + return failure(); + + // Only consider custom conversions or arith ops. + // TODO(jlebar): Is this too restrictive? + if (!isa(src) && !isPureUnaryInlineAsm(src) && + src->getDialect()->getTypeID() != TypeID::get()) + return failure(); + + // Currently, these instructions are not supported during lowering of + // shared -> dot_operand layout. Not all types and type conversions are + // supported. + if (isa(src)) + return failure(); + + // Check that the conversion is transitively dependent on a load, and all + // operations between the load and the conversion are layout preserving. + // + // TODO(jlebar): This is accidentally quadratic; we iterate over the whole + // slice but then at the end we only modify one op! + SetVector slice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + // TODO(jlebar): Is this filter redundant with omitBlockArguments == true? + // That is, is it possible to get into a different region without going + // through a block argument? + opt.filter = [&](Operation *op) { + return op->getParentRegion() == cvt->getParentRegion(); + }; + getBackwardSlice(cvt.getOperation(), &slice, opt); + + // TODO(jlebar): This is too conservative when there are multiple loads in + // the chain (e.g. cvt(load(x) + load(y))). The intent is to check that all + // of the ops between the loads and the convert are elementwise. But + // actually we set foundLoad = true once we see the first load, and so we + // will reject the chain if the *second* load we encounter uses a + // non-elementwise op to calculate its pointers. + bool foundLoad = false; + for (Operation *currOp : slice) { + if (isa(currOp)) { + foundLoad = true; + } else if (foundLoad) { + // Bail out if there exists an op after Load that is not FpToFp, + // Bitcast, or Arith. + if (!isa(currOp) && + !isPureUnaryInlineAsm(currOp) && + currOp->getDialect()->getTypeID() != + TypeID::get()) + return failure(); + } + } + if (!foundLoad) + return failure(); + + SmallVector newOperands; + for (auto operand : src->getOperands()) { + // We checked earlier that all operands are ranked tensors. + auto operandTy = cast(operand.getType()); + Type newCvtTy = RankedTensorType::get( + srcTy.getShape(), operandTy.getElementType(), cvtTy.getEncoding()); + newOperands.push_back( + rewriter.create(cvt.getLoc(), newCvtTy, operand)); + } + auto newRet = rewriter.clone(*src); + for (int i = 0; i < newOperands.size(); i++) + newRet->setOperand(i, newOperands[i]); + newRet->getResult(0).setType(RankedTensorType::get( + srcTy.getShape(), srcTy.getElementType(), cvtTy.getEncoding())); + + rewriter.replaceOp(cvt, newRet->getResults()); + return success(); + } +}; + +#ifndef __ILUVATAR__ +// Rewrite +// +// dot(alloc(trans() #shared1) -> +// dot(trans(alloc() #shared2)) +// +// if dot is an MMAv3 (because MMAv3 allows us to fold transposes). +class FuseTransHopper : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LocalAllocOp allocOp, + PatternRewriter &rewriter) const override { + if (!allocOp->hasOneUse() || + !isa(*allocOp->getUsers().begin())) + return failure(); + + auto dot = *allocOp->getUsers().begin(); + auto dotEnc = dyn_cast( + cast(dot->getResult(0).getType()).getEncoding()); + if (!dotEnc || dotEnc.getVersionMajor() != 3) + return failure(); + + if (!allocOp.getSrc()) + return failure(); + + // Match outerCvt(trans(innerCvt(x))). + auto trans = allocOp.getSrc().getDefiningOp(); + if (!trans || trans.getOrder() != ArrayRef({1, 0})) + return failure(); + + MemDescType allocType = allocOp.getType(); + auto allocEncoding = cast(allocType.getEncoding()); + TensorOrMemDesc srcTy = trans.getSrc().getType(); + + // MMAv3 with transpose only supports f16 and bf16. Fall back to MMAv3 + // without transpose for other data types.) + auto newInnerCvtOrder = getOrder(srcTy.getEncoding()); + if (auto cvt = trans.getSrc().getDefiningOp()) { + newInnerCvtOrder = getOrder(cvt.getSrc().getType().getEncoding()); + } + auto srcElemTy = allocType.getElementType(); + if (!srcElemTy.isF16() && !srcElemTy.isBF16()) { + if (allocOp.getResult() == dot->getOperand(0)) { + newInnerCvtOrder = {0, 1}; + } else if (allocOp.getResult() == dot->getOperand(1)) { + newInnerCvtOrder = {1, 0}; + } + } + + // TODO(Qingyi): need to check whether the CTALayout of innerCvtEnc should + // be used here. For tests where numCTAs = 1, this is not a problem since + // all CTALayouts are the same. + auto newInnerEnc = SharedEncodingAttr::get( + getContext(), srcTy.getShape(), newInnerCvtOrder, + allocEncoding.getCTALayout(), srcTy.getElementType()); + + MemDescType innerTy = + MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc); + auto newAlloc = rewriter.create(allocOp.getLoc(), innerTy, + trans.getSrc()); + rewriter.replaceOpWithNewOp(allocOp, newAlloc, + ArrayRef({1, 0})); + return success(); + } +}; +#endif + +// Rewrite +// dot(convert(lhs #mma) #shared, rhs) #mma -> +// dot(convert(lhs #mma) #dot_operand, rhs) #mma, +// for fp16 or bf16 MMAv3 dots. +struct MMAV3UseRegOperand : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DotOp dotOp, + PatternRewriter &rewriter) const override { + auto alloc = dotOp.getOperand(0).getDefiningOp(); + if (!alloc || !alloc.getSrc()) + return failure(); + + auto getEncoding = [](Value v) { + return cast(v.getType()).getEncoding(); + }; + + if (!isa(getEncoding(dotOp.getOperand(0)))) + return failure(); + auto srcEnc = dyn_cast(getEncoding(alloc.getSrc())); + auto dstEnc = + dyn_cast(getEncoding(dotOp.getResult())); + if (!srcEnc || srcEnc.getVersionMajor() != 3 || !dstEnc || + dstEnc.getVersionMajor() != 3) + return failure(); + auto srcTy = cast(alloc.getSrc().getType()); + auto dotOperandEnc = DotOperandEncodingAttr::get( + dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/0); + auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), + dotOperandEnc); + if (!isMmaToDotShortcut(srcTy, newTy)) + return failure(); + + Value newOperand = + rewriter.create(dotOp.getLoc(), newTy, alloc.getSrc()); + rewriter.modifyOpInPlace(dotOp, [&]() { dotOp.setOperand(0, newOperand); }); + return success(); + } +}; + +} // namespace + +#define GEN_PASS_DEF_TRITONGPUOPTIMIZEDOTOPERANDS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUOptimizeDotOperandsPass + : public impl::TritonGPUOptimizeDotOperandsBase< + TritonGPUOptimizeDotOperandsPass> { +public: + using impl::TritonGPUOptimizeDotOperandsBase< + TritonGPUOptimizeDotOperandsPass>::TritonGPUOptimizeDotOperandsBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + mlir::PassManager pm(m.getContext()); + pm.addPass(mlir::createCanonicalizerPass()); + auto ret = pm.run(m); + + mlir::RewritePatternSet patterns(context); + patterns.add(context); + if (this->hoistLayoutConversion.getValue()) + patterns.add(context); +#ifndef __ILUVATAR__ + patterns.add(context); +#endif + patterns.add(context); + ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); + if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp new file mode 100644 index 000000000..394d5b78d --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp @@ -0,0 +1,437 @@ +#include +#include + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUOPTIMIZETHREADLOCALITY +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { +// Change the destination layout of reshape ops allowing reorder when used by a +// reduction in order to minimize the amount of cross thread communication for +// the reduction. +struct OptimizeReshapeLayoutPattern + : public mlir::OpRewritePattern { + OptimizeReshapeLayoutPattern(mlir::MLIRContext *context) + : OpRewritePattern(context, 1) {} + + mlir::LogicalResult + matchAndRewrite(triton::ReshapeOp viewOp, + mlir::PatternRewriter &rewriter) const override { + if (!viewOp.getAllowReorder()) + return failure(); + std::optional reductionAxis; + for (Operation *user : viewOp.getResult().getUsers()) { + if (auto reduceOp = dyn_cast(user)) { + if (reductionAxis) { + if (reductionAxis != reduceOp.getAxis()) + return failure(); + } else { + reductionAxis = reduceOp.getAxis(); + } + } + } + if (!reductionAxis) + return failure(); + RankedTensorType tensorType = viewOp.getType(); + if (auto blocked = mlir::dyn_cast( + tensorType.getEncoding())) { + // If the layout already has all the elements along the reduction + // dimension in the same thread we can skip. + if (blocked.getThreadsPerWarp()[*reductionAxis] == 1 && + blocked.getWarpsPerCTA()[*reductionAxis] == 1 && + blocked.getCTAsPerCGA()[*reductionAxis] == 1) + return failure(); + } + ArrayRef shape = tensorType.getShape(); + llvm::SmallVector order; + for (int i : triton::gpu::getOrder(tensorType.getEncoding())) { + if (i != *reductionAxis) + order.push_back(i); + } + // Make the reduction axis last so that elements won't be distributed + // amongst threads along this dimension. + order.push_back(*reductionAxis); + llvm::SmallVector sizePerThread(shape.size(), 1); + auto mod = viewOp->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); + triton::gpu::BlockedEncodingAttr encoding = + triton::gpu::BlockedEncodingAttr::get(viewOp.getContext(), shape, + sizePerThread, order, numWarps, + threadsPerWarp, numCTAs); + if (encoding == tensorType.getEncoding()) + return failure(); + RankedTensorType newType = + RankedTensorType::get(shape, tensorType.getElementType(), encoding); + if (triton::gpu::isExpensiveView(viewOp.getSrc().getType(), newType)) + return failure(); + rewriter.setInsertionPointAfter(viewOp); + rewriter.modifyOpInPlace(viewOp, [&]() { + viewOp.getResult().setType(newType); + viewOp.setEfficientLayout(true); + }); + auto cvt = rewriter.create( + viewOp.getLoc(), tensorType, viewOp.getResult()); + rewriter.replaceAllUsesExcept(viewOp.getResult(), cvt.getResult(), cvt); + return mlir::success(); + } +}; + +} // namespace + +class TritonGPUOptimizeThreadLocalityPass + : public impl::TritonGPUOptimizeThreadLocalityBase< + TritonGPUOptimizeThreadLocalityPass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + + // First try to optimize the layout of existing views. + mlir::RewritePatternSet viewLayoutPatterns(&getContext()); + viewLayoutPatterns.add(&getContext()); + if (mlir::applyPatternsAndFoldGreedily(mod, std::move(viewLayoutPatterns)) + .failed()) { + signalPassFailure(); + } + + DenseSet reduceOps; + mod.walk([&](triton::ReduceOp reduce) -> void { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + auto srcEncoding = srcType.getEncoding(); + auto reductionOp = getReductionOp(reduce); + if (!reductionOp || + !isa( + reductionOp.value())) + return; + // TODO: relax this restriction + if (!(isa(srcEncoding) && rank > 1)) + return; + for (auto operand : reduce->getOperands()) { + auto def = operand.getDefiningOp(); + if (!isa(def)) + return; + } + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + // Not worth applying this optimization if there is only one element per + // thread on the reduction axis + if (elemsPerThread == 1) + return; + if (!reduce->hasOneUse()) + return; + Operation *user = *(reduce->getUsers().begin()); + if (!user->hasOneUse()) + return; + OpOperand &yieldOpOperand = *(user->getUses().begin()); + auto yieldOp = dyn_cast(yieldOpOperand.getOwner()); + if (!yieldOp) + return; + auto operandNumber = yieldOpOperand.getOperandNumber(); + Block *block = reduce->getBlock(); + Operation *parentOp = block->getParentOp(); + auto forOp = dyn_cast(parentOp); + if (!forOp) + return; + auto argNum = yieldOpOperand.getOperandNumber(); + auto oldAccum = forOp.getInitArgs()[argNum]; + auto cstOp = dyn_cast(oldAccum.getDefiningOp()); + if (!cstOp) + return; + reduceOps.insert(reduce); + }); + + IRRewriter builder(&getContext()); + for (auto reduce : reduceOps) { + builder.setInsertionPoint(reduce); + auto srcType = cast(reduce.getOperands()[0].getType()); + auto srcShape = srcType.getShape(); + auto srcEncoding = srcType.getEncoding(); + assert(isa(srcEncoding) && + "Thread locality optimization only supports blocked encoding"); + auto blocked = dyn_cast(srcEncoding); + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + auto rank = srcShape.size(); + // create new layouts + auto blocked3d = getThreadLocalityOptimizedEncoding(reduce); + auto viewOpTensorShape = getThreadLocalityOptimizedShape(reduce); + auto viewOpTensorType = RankedTensorType::get( + viewOpTensorShape, srcType.getElementType(), blocked3d); + auto slice2d = triton::gpu::SliceEncodingAttr::get(mod.getContext(), rank, + blocked3d, false); + // Get forOp + assert(reduce->hasOneUse()); + OpOperand &use = *(reduce->getUses().begin()); + auto operandNumber = use.getOperandNumber(); + auto oldUpdate = use.getOwner(); + assert(oldUpdate->getNumOperands() == 2); + auto accumOperandNumber = (operandNumber == 0) ? 1 : 0; + auto accumOperand = oldUpdate->getOperand(accumOperandNumber); + assert(isa(accumOperand)); + auto blockArg = dyn_cast(accumOperand); + auto blockArgNum = blockArg.getArgNumber(); + auto forOp = dyn_cast(blockArg.getOwner()->getParentOp()); + // get oldAccum + auto oldAccum = + forOp.getInitArgs()[blockArgNum - forOp.getNumInductionVars()]; + // get old loop user + Value loopResult = + forOp.getResult(blockArgNum - forOp.getNumInductionVars()); + assert(loopResult.hasOneUse()); + OpOperand &loopUse = *(loopResult.getUses().begin()); + Operation *loopUser = loopUse.getOwner(); + // get old loop yield + auto oldYield = cast(forOp.getBody()->getTerminator()); + // create newAccum initialization + auto newAccum = + createAccum(builder, reduce, oldAccum, viewOpTensorShape, slice2d); + // create new loop by copying the old for op signature and appending + // newAccum to the block arguments + auto newLoop = replaceForOpWithNewSignature( + builder, forOp, ValueRange{newAccum->getResult(0)}); + // create thread local reduction (also adds viewOps) + auto newReduce = createReduce(builder, reduce, viewOpTensorType); + + // create new accum update + auto newUpdate = createUpdate(builder, newLoop, newReduce, oldUpdate); + // create new yield + auto newYield = createYield(builder, newLoop, oldYield, + newUpdate->getResult(0), blockArgNum); + // create post loop reduction on the original reduce axis + auto newReduce2 = createPostLoopReduce(builder, newLoop, reduce); + // add convert_layout to get back to original layout, the result layout + // should now match the layout of the old accumulator (%cst) + Type destType = loopResult.getType(); + auto cvtLayout = createConvertLayout(builder, destType, newReduce2); + // incorporate the original accumulator value into the final result + auto finalOp = incorporateOriginalAccumulatorValue(builder, oldUpdate, + cvtLayout, oldAccum); + // Replace the old loop user with the final result + loopUser->setOperand(loopUse.getOperandNumber(), finalOp->getResult(0)); + + // cleanup + oldYield.erase(); + forOp.erase(); + } + }; + +private: + std::optional getReductionOp(triton::ReduceOp reduce) const { + auto numRegions = reduce->getNumRegions(); + if (numRegions != 1) + return std::nullopt; + Region ®ion = reduce->getRegion(0); + auto numBlocks = region.getBlocks().size(); + if (numBlocks != 1) + return std::nullopt; + Block &block = region.front(); + auto blockWithoutTerminator = block.without_terminator(); + auto blockSizeWithoutTerminator = std::distance( + blockWithoutTerminator.begin(), blockWithoutTerminator.end()); + if (blockSizeWithoutTerminator != 1) + return std::nullopt; + Operation *op = &block.front(); + return std::optional(op); + } + Operation *incorporateOriginalAccumulatorValue(OpBuilder &builder, + Operation *oldUpdate, + Operation *cvtLayout, + Value oldAccum) const { + builder.setInsertionPointAfter(cvtLayout); + IRMapping mapping; + mapping.map(oldUpdate->getOperand(0), oldAccum); + mapping.map(oldUpdate->getOperand(1), cvtLayout->getResult(0)); + auto finalOp = cloneWithInferType(builder, &(*oldUpdate), mapping); + return finalOp; + } + Operation *createConvertLayout(OpBuilder &builder, Type destType, + Operation *newReduce) const { + builder.setInsertionPointAfter(newReduce); + auto newCvt = builder.create( + newReduce->getLoc(), destType, newReduce->getResult(0)); + return newCvt; + } + + Operation *createPostLoopReduce(OpBuilder &builder, scf::ForOp &loop, + triton::ReduceOp &reduce) const { + auto resultIndex = + loop.getBody()->getNumArguments() - 1 - loop.getNumInductionVars(); + auto newLoopResult = loop.getResult(resultIndex); + builder.setInsertionPointAfter(loop); + IRMapping mapping; + mapping.map(*(reduce.getOperands().begin()), newLoopResult); + auto newReduce2 = cloneWithInferType(builder, &(*reduce), mapping); + return newReduce2; + } + + Operation *createYield(OpBuilder &builder, scf::ForOp &loop, + scf::YieldOp &oldYield, Value newUpdate, + int oldAccumBlockArgNum) const { + builder.setInsertionPoint(oldYield); + SmallVector yieldValues = llvm::to_vector(oldYield.getOperands()); + yieldValues[oldAccumBlockArgNum - 1] = + loop.getBody()->getArgument(oldAccumBlockArgNum); + yieldValues.push_back(newUpdate); + auto newYield = + builder.create(oldYield.getLoc(), yieldValues); + return newYield; + } + + Operation *createUpdate(OpBuilder &builder, scf::ForOp &loop, + Operation *newReduce, Operation *oldUpdate) const { + auto blockArgNum = loop.getBody()->getNumArguments() - 1; + auto newArg = loop.getBody()->getArgument(blockArgNum); + builder.setInsertionPointAfter(newReduce); + IRMapping mapping; + mapping.map(oldUpdate->getOperand(0), newArg); + mapping.map(oldUpdate->getOperand(1), newReduce->getResult(0)); + auto newUpdate = cloneWithInferType(builder, oldUpdate, mapping); + return newUpdate; + } + + Operation *createReduce(OpBuilder &builder, triton::ReduceOp reduce, + Type viewOpTensorType) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + builder.setInsertionPointAfter(reduce); + IRMapping mapping; + for (auto operand : reduce.getOperands()) { + auto viewOp = builder.create( + reduce.getLoc(), viewOpTensorType, operand, /*allowReorder=*/true); + viewOp.setEfficientLayout(true); + mapping.map(operand, viewOp); + } + + auto newReduce = cloneWithInferType(builder, &(*reduce), mapping); + newReduce->setAttr("axis", builder.getI32IntegerAttr(rank)); + auto typeInfer = dyn_cast(newReduce); + if (typeInfer) { + SmallVector newTypes; + auto success = typeInfer.inferReturnTypes( + newReduce->getContext(), newReduce->getLoc(), + newReduce->getOperands(), newReduce->getAttrDictionary(), + newReduce->getPropertiesStorage(), newReduce->getRegions(), newTypes); + if (succeeded(success)) { + for (size_t i = 0; i < newTypes.size(); i++) + newReduce->getResult(i).setType(newTypes[i]); + } + } + return newReduce; + } + + // Work around the lack of support for MaxNumFOp and MinNumFOp in + // arith::getNeutralElement. + std::optional getNeutralElement(Operation *op) const { + if (isa(op)) { + OpBuilder builder(op->getContext()); + + Type resultType = op->getResult(0).getType(); + const llvm::fltSemantics &semantic = + llvm::cast(resultType).getFloatSemantics(); + if (isa(op)) { + return builder.getFloatAttr( + resultType, APFloat::getInf(semantic, /*Negative=*/true)); + } + if (isa(op)) { + return builder.getFloatAttr( + resultType, APFloat::getInf(semantic, /*Negative=*/false)); + } + } else { + return mlir::arith::getNeutralElement(op); + } + llvm_unreachable("Unhandled reduction op"); + return std::nullopt; + } + + Operation *createAccum(OpBuilder &builder, triton::ReduceOp reduce, + Value &oldAccum, SmallVector &shape, + Attribute &slice2d) const { + // Drop the last dimension (thread locality dimension) + SmallVector accumShape(shape.begin(), shape.end() - 1); + auto elemType = cast(oldAccum.getType()).getElementType(); + // Create tensor type for the new accumulator + auto accumType = RankedTensorType::get(accumShape, elemType, slice2d); + // Create new accumulator + builder.setInsertionPointAfter(oldAccum.getDefiningOp()); + auto reductionOp = getReductionOp(reduce); + assert(reductionOp && "Processing a reduce that is not supported!"); + auto neutralVal = getNeutralElement(reductionOp.value()); + assert(neutralVal && "Could not find neutral value for reduction op!"); + auto denseAttr = DenseElementsAttr::get(accumType, neutralVal.value()); + auto newAccum = builder.create(oldAccum.getLoc(), + accumType, denseAttr); + return newAccum; + } + + SmallVector + getThreadLocalityOptimizedShape(triton::ReduceOp reduce) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto srcShape = srcType.getShape(); + auto rank = srcShape.size(); + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + auto viewOpTensorShape = insertValue(srcShape, rank, 1); + viewOpTensorShape[reduce.getAxis()] /= elemsPerThread; + viewOpTensorShape[rank] = elemsPerThread; + return viewOpTensorShape; + } + + Attribute getThreadLocalityOptimizedEncoding(triton::ReduceOp reduce) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + auto srcEncoding = srcType.getEncoding(); + auto blocked = dyn_cast(srcEncoding); + auto sizePerThread3d = + insertValue(blocked.getSizePerThread(), rank, + blocked.getSizePerThread()[reduce.getAxis()]); + sizePerThread3d[reduce.getAxis()] = 1; + auto threadsPerWarp3d = insertValue(blocked.getThreadsPerWarp(), rank, 1); + auto warsPerCTA3d = insertValue(blocked.getWarpsPerCTA(), rank, 1); + auto order3d = insertValue(blocked.getOrder(), 0, rank); + auto ctasPerCGA3d = + insertValue(blocked.getCTALayout().getCTAsPerCGA(), rank, 1); + auto ctasSplitNum3d = + insertValue(blocked.getCTALayout().getCTASplitNum(), rank, 1); + auto ctaOrder3d = + insertValue(blocked.getCTALayout().getCTAOrder(), rank, rank); + auto ctaLayout3d = triton::gpu::CTALayoutAttr::get( + reduce.getContext(), ctasPerCGA3d, ctasSplitNum3d, ctaOrder3d); + SmallVector smeCTA(rank); + auto blocked3d = triton::gpu::BlockedEncodingAttr::get( + reduce.getContext(), sizePerThread3d, threadsPerWarp3d, warsPerCTA3d, + order3d, ctaLayout3d, false, smeCTA); + return blocked3d; + } + + template + SmallVector insertValue(ArrayRef vec, unsigned index, int value) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + index, static_cast(value)); + return res; + } + template + SmallVector insertValue(const SmallVector &vec, unsigned index, + int value) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + index, static_cast(value)); + return res; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp new file mode 100644 index 000000000..7aafac67b --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -0,0 +1,1785 @@ +#include "PipelineExpander.h" +#include "PipeliningUtility.h" +#include "Schedule.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" + +#include + +#define DEBUG_TYPE "triton-matmul-loop-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +#define int_attr(num) builder.getI64IntegerAttr(num) + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +// namespace ttng = mlir::triton::nvidia_gpu; + +// TODO: We can extra some helpers into common utilities once we add more +// schedules. + +namespace { + +struct LoadInfo { + // Layout of the data in the shared memory. + ttg::SharedEncodingAttr sharedEncoding = nullptr; + // Blocked encoding is used for loads not used by the dot. + ttg::BlockedEncodingAttr blockedEncoding = nullptr; + bool loadIsMMAV3 = false; + int distToUse = 0; + bool usedByDot = false; +}; + +} // namespace + +class CoarseSchedule { +public: + class ClusterList { + std::list orderClusters; + + public: + using iterator = decltype(orderClusters)::iterator; + ClusterList() = default; + iterator begin() { return orderClusters.begin(); } + iterator end() { return orderClusters.end(); } + size_t size() { return orderClusters.size(); } + iterator newAtBack() { + orderClusters.push_back(orderClusters.size()); + return std::prev(orderClusters.end()); + } + iterator newAtFront() { + orderClusters.push_front(-1); + for (auto &clusterId : orderClusters) { + clusterId++; + } + return orderClusters.begin(); + } + iterator newBefore(iterator cluster) { + auto ret = orderClusters.insert(cluster, *cluster); + for (auto &clusterId : llvm::make_range(cluster, orderClusters.end())) { + clusterId++; + } + return ret; + } + }; + + CoarseSchedule(int numStages) : numStages(numStages) {} + int numStages; + ClusterList clusters; + using Cluster = decltype(clusters)::iterator; + + DenseMap> opToStageAndCluster; + + void insert(Operation *op, int stage, Cluster cluster) { + opToStageAndCluster[op] = {stage, cluster}; + } + + bool insertIfAbsent(Operation *op, int stage, Cluster cluster) { + if (opToStageAndCluster.count(op)) + return false; + insert(op, stage, cluster); + return true; + } + + void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster, + bool includeArg) { + for (Value operand : op->getOperands()) { + Value v = operand; + llvm::SmallDenseSet seen; + while (auto arg = dyn_cast(v)) { + if (!includeArg) + break; + if (!seen.insert(v).second) + break; + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + v = yieldOp->getOperand(arg.getArgNumber() - 1); + continue; + } + break; + } + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + if (insertIfAbsent(defOp, stage, cluster)) { + insertDepsOfOp(defOp, stage, cluster, includeArg); + } + } + } + } + + void erase(Operation *op) { opToStageAndCluster.erase(op); } + + int count(Operation *op) { return opToStageAndCluster.count(op); } + + std::pair operator[](Operation *op) { + return opToStageAndCluster[op]; + } + + SmallVector> + getOpsInOrder(scf::ForOp forOp) { + SmallVector>, 8> + orderClusters(clusters.size()); + for (auto &op : forOp.getBody()->without_terminator()) { + if (opToStageAndCluster.count(&op) == 0) { + continue; + } + assert(opToStageAndCluster[&op].first < numStages && + "Op with invalid stage!"); + int clusterId = *opToStageAndCluster[&op].second; + assert(clusterId == std::distance(clusters.begin(), + opToStageAndCluster[&op].second) && + "Cluster ID mismatch!"); + orderClusters[clusterId].push_back( + make_tuple(&op, opToStageAndCluster[&op].first, + opToStageAndCluster[&op].second)); + } + SmallVector> opsInOrder; + for (int i = 0; i < orderClusters.size(); i++) { + for (auto [op, stage, cluster] : orderClusters[i]) { + opsInOrder.push_back({op, stage, cluster}); + } + } + + return opsInOrder; + } + + std::vector> + createFinalSchedule(scf::ForOp forOp) { + SmallVector> opsInOrder = + getOpsInOrder(forOp); + std::vector> schedule; + for (auto [op, stage, cluster] : opsInOrder) { + LDBG("Adding op to schedule at stage " << stage << " cluster " << *cluster + << ":" << *op); + schedule.push_back({op, stage}); + } + return schedule; + } + + void dump() { + for (int i = 0; i < numStages; i++) { + LDBG("- Ops in stage " << i); + for (auto &[op, stageAndCluster] : opToStageAndCluster) { + if (i == stageAndCluster.first) { + llvm::outs() << " cluster: " << *stageAndCluster.second << " "; + op->dump(); + } + } + } + } +}; + +static bool isMMAv3Dot(Operation *op) { + auto dot = dyn_cast(op); + if (!dot) + return false; + auto enc = + mlir::dyn_cast(dot.getType().getEncoding()); + return enc && enc.isHopper(); +} + +// Replace the ForOp's yield with a new one with the given operands appended. +static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { + // Fix up the yield op. + Operation *yieldOp = forOp.getBody()->getTerminator(); + SmallVector operands(yieldOp->getOperands()); + operands.append(newOperands.begin(), newOperands.end()); + + OpBuilder builder(yieldOp); + builder.create(yieldOp->getLoc(), operands); + yieldOp->erase(); +} + +static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, + CoarseSchedule &schedule, + CoarseSchedule::Cluster prefetchCluster, + llvm::MapVector &loadToInfo, + int numStages) { + OpBuilder builder(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + // Replace the load with insert/extract slice. + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + Value src = loadOp.getPtr(); + Value mask = loadOp.getMask(); + Value other = loadOp.getOther(); + if (!isExpensiveLoadOrStore(loadOp) && loadToInfo[loadOp].blockedEncoding) { + // For inexpensive loads that do not directly feed into dot ops + // we want to use optimal layout for the data. + ttg::BlockedEncodingAttr encoding = loadToInfo[loadOp].blockedEncoding; + auto convertBlockLayout = [&](Value src) { + auto ty = cast(src.getType()); + auto newTy = + RankedTensorType::get(ty.getShape(), ty.getElementType(), encoding); + auto cvt = + builder.create(loadOp->getLoc(), newTy, src); + return cvt.getResult(); + }; + src = convertBlockLayout(src); + if (mask) + mask = convertBlockLayout(mask); + if (other) + other = convertBlockLayout(other); + } + + tt::MemDescType allocTy = cast(alloc.getType()); + SmallVector copyOffsets(allocTy.getRank(), zero); + copyOffsets[0] = insertIdx; + tt::MemDescType subviewTy = tt::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), /*mutableMemory=*/true); + auto view = + builder.create(loc, subviewTy, alloc, copyOffsets); + Operation *copy = builder.create( + loc, src, view, mask, other, loadOp.getCache(), loadOp.getEvict(), + loadOp.getIsVolatile(), loadOp.getInputStride(), loadOp.getInputStride(), + loadOp.getInputStride()); + Operation *commmit = + builder.create(loc, copy->getResult(0)); + Operation *wait = + builder.create(loc, commmit->getResult(0), 0); + + bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3; + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(copy, stage, cluster); + schedule.insert(commmit, stage, cluster); + + // Extract part. + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + if (isMMV3Load) { + auto alloc = cast((*loadOp->getUsers().begin())); + alloc.replaceAllUsesWith(viewLoad.getResult()); + alloc.erase(); + } else { + SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto alloc = dyn_cast(user)) { + alloc.replaceAllUsesWith(viewLoad.getResult()); + allocsToErase.push_back(alloc); + } + } + for (auto alloc : allocsToErase) { + alloc.erase(); + } + + auto sharedLoad = builder.create( + loc, loadOp.getType(), viewLoad, wait->getResult(0)); + auto result = sharedLoad->getResults(); + + // Create a select for non-zero other values as they are not handled by + // AsyncCopyGlobalToLocalOp for now. + Value other = loadOp.getOther(); + if (other && !isZeroConst(other)) { + auto select = builder.create( + loc, loadOp.getType(), mask, sharedLoad.getResult(), other); + result = select->getResults(); + } + + loadOp->replaceAllUsesWith(result); + + // Prefetch load if is not MMAV3 and is used by the dot. + if (loadToInfo[loadOp].usedByDot) { + schedule.insert(wait, numStages - 2, prefetchCluster); + schedule.insert(viewLoad, numStages - 2, prefetchCluster); + } + } + loadOp.erase(); +} + +#ifndef __ILUVATAR__ +static void createTMAAsyncCopy( + scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp, + Value phase, CoarseSchedule &schedule, + llvm::MapVector &loadToInfo, int numStages) { + assert(phase && "Phase value is required for TMA async copy."); + OpBuilder builder(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + tt::MemDescType allocTy = cast(alloc.getType()); + SmallVector copyOffsets(allocTy.getRank(), zero); + copyOffsets[0] = insertIdx; + tt::MemDescType subviewTy = tt::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), /*mutableMemory=*/true); + auto view = + builder.create(loc, subviewTy, alloc, copyOffsets); + + Value pred = builder.create(loc, 1, 1); + Operation *copy = builder.create( + loc, loadOp.getDescPtr(), loadOp.getIndices(), barrier, view, pred); + + bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3; + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(copy, stage, cluster); + + builder.setInsertionPointAfter(waitOp); + // Extract part. + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + if (isMMV3Load) { + auto alloc = cast((*loadOp->getUsers().begin())); + alloc.replaceAllUsesWith(viewLoad.getResult()); + alloc.erase(); + } else { + SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto alloc = dyn_cast(user)) { + alloc.replaceAllUsesWith(viewLoad.getResult()); + allocsToErase.push_back(alloc); + } + } + for (auto alloc : allocsToErase) { + alloc.erase(); + } + + auto sharedLoad = builder.create( + loc, loadOp.getType(), viewLoad /*,wait->getResult(0)*/); + auto result = sharedLoad->getResults(); + loadOp->replaceAllUsesWith(result); + } + loadOp.erase(); +} +#endif + +// If all the transitive uses of the given value have are used by a convert to +// the same dot operand encoding, return true and get the shared encoding that +// needs to be used to be compatible with users' layouts. +static std::optional +getSharedEncIfAllUsersAreDotEnc(Value val) { + ttg::SharedEncodingAttr attr; + for (Operation *user : val.getUsers()) { + ttg::SharedEncodingAttr tempAttr; + if (user->getNumResults() != 1) + return std::nullopt; + if (auto memDesc = + dyn_cast(user->getResult(0).getType())) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = cast(memDesc.getEncoding()); + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value()) + return std::nullopt; + } else { + if (!isa(user)) + return std::nullopt; + auto dotOpEnc = dyn_cast( + cast(user->getResult(0).getType()).getEncoding()); + if (!dotOpEnc) + return std::nullopt; + auto srcTy = cast(val.getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy.getEncoding()); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + tempAttr = ttg::SharedEncodingAttr::get( + val.getContext(), dotOpEnc, srcTy.getShape(), + ttg::getOrder(srcTy.getEncoding()), + ttg::getCTALayout(srcTy.getEncoding()), + srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); + } + // Check that the shared encodings needed by the users are compatible. + if (!tempAttr || (attr != nullptr && attr != tempAttr)) + return std::nullopt; + attr = tempAttr; + } + return attr; +} + +static ttg::BlockedEncodingAttr +getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) { + Value src = loadOp.getPtr(); + auto ty = cast(src.getType()); + auto mod = loadOp->getParentOfType(); + int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); + tt::AxisInfo::DimVectorT contiguity = + axisInfo.getAxisInfo(src)->getContiguity(); + SmallVector order = argSort(contiguity); + unsigned currPerThread = getNumElementsPerThread(loadOp, order, axisInfo); + SmallVector sizePerThread(order.size(), 1); + sizePerThread[order[0]] = currPerThread; + ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(ty.getEncoding()); + return ttg::BlockedEncodingAttr::get(loadOp->getContext(), ty.getShape(), + sizePerThread, order, numWarps, + threadsPerWarp, ctaLayout); +} + +static std::optional +getSharedEncoding(Operation *loadOp, bool isMMAV3) { + auto ty = cast(loadOp->getResultTypes()[0]); + auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); + auto blockedOrder = ttg::getOrder(ty.getEncoding()); + SmallVector order; + if (blockedOrder.size() == 3) { + for (unsigned i = 0; i < blockedOrder.size(); ++i) { + if (blockedOrder[i] == 0) + continue; + order.push_back(blockedOrder[i]); + } + order.push_back(0); + } else { + order = blockedOrder; + } + if (isMMAV3) { + return ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), order, + ctaLayout, ty.getElementType()); + } + + // If the load is used by a LocalAllocOp, use the same encoding as the allocs. + // If the allocs don't all have the same encoding, bail. + if (llvm::any_of(loadOp->getUsers(), [&](Operation *user) { + return isa(user); + })) { + ttg::SharedEncodingAttr localAllocEnc; + for (auto user : loadOp->getUsers()) { + auto localAlloc = dyn_cast(user); + if (!localAlloc) + continue; + auto enc = mlir::cast( + localAlloc.getType().getEncoding()); + if (!localAllocEnc) { + localAllocEnc = enc; + } + if (enc != localAllocEnc) + return std::nullopt; + } + return localAllocEnc; + } + + // Use non-swizzled layout for loads that do not feed into dot ops. + // TODO: This won't be optimal for 2D tensors. + return ttg::SharedEncodingAttr::get(ty.getContext(), 1, 1, 1, order, + ctaLayout); +} + +// Create a map from load ops to their indirection level and the +// final use of the load op (another load op, or a dot op). +// Indirection level is "0" for the load op directly used by the dot op, +// "1" for the load op used by the load op used by the dot op, and so on. +static llvm::SmallVector> +loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { + llvm::SmallVector> + loadOpToIndLevelAndUse; + DenseSet seen; + + std::function dfs = + [&](Operation *op, int distance, Operation *use) { + if (!seen.insert(op).second) + return; + if (isa(op)) { + // TODO: What if there are multiple uses at different distances? + loadOpToIndLevelAndUse.push_back(std::make_tuple(op, distance, use)); + use = op; + distance++; + } + for (Value operand : op->getOperands()) { + Value v = operand; + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, distance, use); + } + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + continue; + seen.clear(); + dfs(&op, 0, &op); + } + + // If the loop has numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (forOp->hasAttr(tt::kNumStagesAttrName)) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, 0, &op); + } + } + + return loadOpToIndLevelAndUse; +} + +static bool loadIsMMAv3(Operation *loadOp) { + if (!loadOp->hasOneUse()) + return false; + auto alloc = dyn_cast(*loadOp->getUsers().begin()); + if (!alloc) + return false; + auto sharedEnc = cast(alloc.getType().getEncoding()); + if (!sharedEnc.getHasLeadingOffset()) + return false; + + // MMA V3 case. + auto newOrder = sharedEnc.getOrder(); + auto ty = cast(loadOp->getResultTypes()[0]); + auto oldOrder = ttg::getOrder(ty.getEncoding()); + + // The operand of MMAv3 is in SharedEncoding and its order should not + // be changed after FuseTranspositions Pass. So we only pipeline the + // load if the order of the loaded BlockedEncoding is the same as the + // order of the SharedEncoding it is converted to. + return oldOrder == newOrder; +} + +static llvm::MapVector +assignMemoryLayouts(llvm::SmallVector> + &loadOpToIndLevelAndUse, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + llvm::MapVector loadToInfo; + + for (auto &[op, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(op)) + // TODO pawel: err, we'd need to verify that the distance is the same + continue; + LoadInfo loadInfo; + + if (auto loadOp = dyn_cast(op)) { + assert(!isLoadFromTensorPtr(loadOp) && + "Block ptr should have been lowered before this pass."); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + continue; + auto ty = + cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * ty.getIntOrFloatBitWidth(); + + // We do not pipeline all loads for the following reasons: + // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16. + // 2. It's likely that pipling small loads won't offer much performance + // improvement and may even hurt performance by increasing register + // pressure. + LDBG("Load " << *loadOp << " has width " << width); + if (width < 32) + continue; + } + + if (auto nestedForOp = dyn_cast(op)) + return loadToInfo; + + if (auto dot = dyn_cast(use)) { + loadInfo.usedByDot = true; + if (loadIsMMAv3(op)) { + loadInfo.loadIsMMAV3 = true; + loadInfo.sharedEncoding = + getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); + } else if (isa(op)) { + loadInfo.sharedEncoding = + getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); + } else { + loadInfo.sharedEncoding = + getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); + + // HACK: Triton LLVM codegen has a bug where local_loads from #shared to + // #mma layout can lead to invalid code if the loaded shape is smaller + // than the mma tile (e.g. loading a 128x1 tensor for an MMAv2 dot with + // tile {16,8} is bad because 1 < 8). To work around this, don't + // pipeline such loads. + // + // The codegen bug is caught by an assertion, so if you think you've + // fixed it, feel free to delete this code and see if the assert still + // fails. :) + if (!loadInfo.sharedEncoding) { + if (auto dotEnc = dyn_cast( + dot.getResult().getType().getEncoding())) { + auto loadTy = cast(op->getResultTypes()[0]); + auto mmaInstrShape = dotEnc.getInstrShape(); + if (loadTy.getRank() < mmaInstrShape.size()) + continue; + bool ok = true; + for (int i = 0; i < mmaInstrShape.size(); i++) { + if (loadTy.getShape()[loadTy.getRank() - mmaInstrShape.size() + + i] < mmaInstrShape[i]) { + ok = false; + break; + } + } + // If this load might trigger the bug, don't do the fallback logic + // below, which might allow the load to be pipelined. + if (!ok) + continue; + } + } + } + } else if (auto loadOp = dyn_cast(use)) { + // The use of this loadOp is another loadOp. If the use is not in the + // loadsToPipeline already, it means that the use is not valid for + // pipelining for some reason. We should skip this loadOp, too. Note that + // we have an assumption that distAndUse.second (i.e. the use of this + // loadOp) has already be processed in a previous loop iteration. This + // assumption is held by how loadOpsToIndirectionLevelAndUse recursively + // collects loadOpToIndLevelAndUse using DFS. + if (loadToInfo.count(loadOp) == 0) { + continue; + } + } + + // If we still don't have a shared encoding, try a "generic" shared + // encoding. + if (!loadInfo.sharedEncoding && !isMMAv3Dot(use)) { + loadInfo.sharedEncoding = + getSharedEncoding(op, /*isMMAV3=*/loadInfo.loadIsMMAV3) + .value_or(nullptr); + if (auto loadOp = dyn_cast(op)) { + loadInfo.blockedEncoding = getBlockedEncoding(loadOp, axisInfoAnalysis); + } + } + + // If that still didn't work, bail on pipelining this load. + if (!loadInfo.sharedEncoding) { + continue; + } + loadToInfo[op] = loadInfo; + } + + return loadToInfo; +} + +static llvm::MapVector +scheduleLoads(scf::ForOp forOp, CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + ModuleOp moduleOp = forOp->getParentOfType(); + tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // Get all loads that are (transitively) used by dot ops and their distance + // to the dot op. + llvm::SmallVector> + loadOpToIndLevelAndUse = loadOpsToIndirectionLevelAndUse(forOp); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + if (loadOpToIndLevelAndUse.empty()) + return {}; + + // Check which loads are good for pipelining, and assign them + // memory layouts. + llvm::MapVector loadToInfo = + assignMemoryLayouts(loadOpToIndLevelAndUse, axisInfoAnalysis); + + if (loadToInfo.empty()) + return {}; + + // Calculate the stage distance between applicable loads. + int maxIndirectionLevel = -1; + for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + maxIndirectionLevel = std::max(maxIndirectionLevel, dist); + } + unsigned stagesBetweenLoads = + ceil(numStages - 2, maxIndirectionLevel + 1); + + CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); + // Put the root uses of the loads in the last stage. + for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + // Non-LoadOp(s) are the root uses of all LoadOp(s) and should be + // always present in the opInfo + if (!isa(use)) { + schedule.insert(use, numStages - 1, rootUsersCluster); + rootUsers.insert(use); + } + } + + SmallVector loadsClusters; + for (int i = 0; i < maxIndirectionLevel + 1; i++) { + loadsClusters.push_back(schedule.clusters.newAtBack()); + } + // Assign stages to the loads. + for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; + schedule.insert(loadOp, stage, loadsClusters[indLevel]); + } + + // Distance from the load to the use. + for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; + } + + return loadToInfo; +} + +// Schedule the prologue and epilogue `if` ops in the loop, pushing them as +// close to the loop boundaries as possible. Return the cluster after the +// prologue (or the beginning of the loop if there is no prologue). +static CoarseSchedule::Cluster +schedulePrologueAndEpilogue(scf::ForOp forOp, CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + + // Look for the IfOp that is in the backward slice any of the currently + // scheduled ops and put it at the beginning of the loop. + DenseMap ifsToStage; + // Go stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : schedule.getOpsInOrder(forOp)) { + if (stage_ != stage) + continue; + SetVector backwardSlice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + getBackwardSlice((Operation *)op, &backwardSlice, opt); + + for (auto op : backwardSlice) { + if (auto ifOp = dyn_cast(op)) { + ifsToStage.insert({ifOp, stage}); + } + } + } + } + CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); + for (auto [ifOp, stage] : ifsToStage) { + schedule.insert(ifOp, stage, prologueCluster); + } + + // Look for the IfOp that is in the forward slice of the root users and put it + // at the end of the loop. + CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); + for (auto rootUser : rootUsers) { + SetVector forwardSlice; + getForwardSlice(rootUser, &forwardSlice); + + int stage = schedule[rootUser].first; + for (auto op : forwardSlice) { + scf::IfOp ifOp = dyn_cast(op); + if (ifOp == nullptr) { + // check if the op is in the body of an if op that's part of the loop + auto parentOp = op->getParentOp(); + if (parentOp != nullptr && + parentOp->getParentOp() == forOp.getOperation()) { + ifOp = dyn_cast(parentOp); + } + } + if (ifOp) { + schedule.insertIfAbsent(ifOp, stage, + epilogueCluster); // after prefetch extracts + } + } + } + return afterPrologue; +} + +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +static void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule, + int numStages) { + SmallVector> + opsInOrder = schedule.getOpsInOrder(forOp); + // Schedule dependencies stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : opsInOrder) { + if (stage_ != stage) + continue; + schedule.insertDepsOfOp(op, stage, cluster, false); + } + } +} + +// Find dependencies with distance of 1. They will go to the next stage, +// but in the cluster before the current op. +static void scheduleDistanceOneDependencies(scf::ForOp forOp, + CoarseSchedule &schedule, + int numStages) { + auto getNestedOperands = [](Operation *op) -> SmallVector { + SmallVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + if (operand.getParentBlock()->getParentOp()->isAncestor(nestedOp)) + operands.push_back(operand); + } + }); + return operands; + }; + + // Mapping from the cluster to the cluster before it. + DenseMap dist1Cluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + continue; + auto [stage, cluster] = schedule[&op]; + // Can't schedule past the last stage. + if (stage == numStages - 1) + continue; + for (Value operand : getNestedOperands(&op)) { + if (auto arg = dyn_cast(operand)) { + if (arg.getArgNumber() > 0 && arg.getOwner() == op.getBlock()) { + auto yieldOp = op.getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (defOp && schedule.count(defOp) == 0) { + if (isa(defOp)) { + // Exception: Schedule loads with a distance of 1 together + // with the current op. + schedule.insertIfAbsent(defOp, stage, cluster); + schedule.insertDepsOfOp(defOp, stage, cluster, true); + } else { + if (dist1Cluster.count(&cluster) == 0) { + dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster); + } + schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]); + schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster], + true); + } + } + } + } + } + } +} + +static void scheduleRemainingToLastStage(scf::ForOp forOp, + CoarseSchedule &schedule, + CoarseSchedule::Cluster afterPrologue, + int numStages) { + // Assign the rest of the ops to the last stage. + // Take care of the ordering of the ops - uses cannot be scheduled to the + // cluster before the definition. + DenseMap opToCluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) { + opToCluster[&op] = afterPrologue; + } + } + SmallVector queue; + for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { + // We really only care about the producers from the last stage. + // Others will be scheduled before these ops anyway. + if (stage == numStages - 1) { + queue.push_back(op); + } + } + while (!queue.empty()) { + Operation *op = queue.pop_back_val(); + for (auto user : op->getUsers()) { + if (opToCluster.count(user)) { + CoarseSchedule::Cluster userCluster = opToCluster[user]; + CoarseSchedule::Cluster opCluster; + if (schedule.count(op)) + opCluster = schedule[op].second; + else + opCluster = opToCluster[op]; + if (*userCluster < *opCluster) { + opToCluster[user] = opCluster; + queue.push_back(user); + } + } + } + } + for (auto [op, cluster] : opToCluster) { + schedule.insert(op, numStages - 1, cluster); + } +} + +// Create an allocation that can hold distance number of loadOp shapes. +static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, + ttg::SharedEncodingAttr sharedEnc, unsigned distance) { + OpBuilder builder(forOp); + auto ty = cast(loadOp->getResultTypes()[0]); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), distance); + Type memdescType = mlir::triton::MemDescType::get( + bufferShape, ty.getElementType(), sharedEnc, /*mutableMemory*/ true); + Value alloc = builder.create( + loadOp->getLoc(), memdescType, Value()); + return alloc; +} + +#ifndef __ILUVATAR__ +// Create an allocation to hold the mbarriers. +static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) { + OpBuilder builder(forOp); + Location loc = forOp.getLoc(); + auto context = forOp.getContext(); + auto barrierCTALayout = + ttg::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = + ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout); + Type barrierMemDescType = + tt::MemDescType::get({distance}, builder.getI64Type(), barrierEncoding, + /*mutableMemory=*/true); + Type singleBarrierMemDescType = tt::MemDescType::get( + {1}, builder.getI64Type(), barrierEncoding, /*mutableMemory=*/true); + Value barrierAlloc = builder.create( + loc, barrierMemDescType, Value()); + for (unsigned i = 0; i < distance; i++) { + Value idx = builder.create(loc, i, 32); + Value barrierView = builder.create( + loc, singleBarrierMemDescType, barrierAlloc, idx); + builder.create(forOp->getLoc(), barrierView, 1); + } + return barrierAlloc; +} +#endif + +struct AsyncLoad { + AsyncLoad(Operation *loadOp, Value alloc) : loadOp(loadOp), alloc(alloc) {} + Operation *loadOp; + Value alloc; + Value barrier; + Operation *waitOp = nullptr; + bool isTMALoad = false; +}; + +#ifndef __ILUVATAR__ +// Create barriers and wait ops for the async loads. Barriers may be shared by +// multiple loads is the schedule allows it. +static void createTMABarrierAndWait( + scf::ForOp &forOp, SmallVector &asyncLoads, Value insertIdx, + Value extractIdx, Value phase, int numBuffers, CoarseSchedule &schedule, + SmallVector &barriers, + const llvm::MapVector &loadToInfo) { + llvm::SmallDenseMap loadToAsyncLoad; + for (AsyncLoad &asyncLoad : asyncLoads) { + loadToAsyncLoad[asyncLoad.loadOp] = &asyncLoad; + } + SmallVector> loadGroups; + llvm::SmallDenseSet visited; + // Find groups of loads that can share the same barrier. We look consecutive + // loads and check that there are uses in between. + for (AsyncLoad &asyncLoad : asyncLoads) { + if (!asyncLoad.isTMALoad || visited.count(asyncLoad.loadOp)) + continue; + llvm::SmallDenseSet users; + SmallVector group; + Block *loadBlock = asyncLoad.loadOp->getBlock(); + auto addToGroup = [&](AsyncLoad *loadInfo) { + group.push_back(loadInfo); + visited.insert(loadInfo->loadOp); + for (Operation *user : loadInfo->loadOp->getUsers()) { + auto it = loadToInfo.find(loadInfo->loadOp); + if (it != loadToInfo.end()) { + // Special case for MMAv3 loads, we can ignore the alloc and only + // consider uses of the alloc op since it will be removed. + if (it->second.loadIsMMAV3) { + auto alloc = cast( + (*loadInfo->loadOp->getUsers().begin())); + if (alloc->getBlock() == loadBlock) { + users.insert(alloc->getUsers().begin(), alloc->getUsers().end()); + continue; + } + } + } + Operation *userInBlock = loadBlock->findAncestorOpInBlock(*user); + if (userInBlock) + users.insert(userInBlock); + } + }; + addToGroup(&asyncLoad); + Operation *nextOp = asyncLoad.loadOp->getNextNode(); + while (nextOp) { + if (users.count(nextOp) || visited.count(nextOp)) + break; + if (isa(nextOp)) { + auto it = loadToAsyncLoad.find(nextOp); + if (it != loadToAsyncLoad.end() && it->second->isTMALoad) { + addToGroup(it->second); + } + } + nextOp = nextOp->getNextNode(); + } + loadGroups.push_back(group); + } + + // For each group calculate the size and insert the barrier after the last + // load. + for (SmallVector &group : loadGroups) { + int sizeInBytes = 0; + for (AsyncLoad *asyncLoad : group) { + auto tensorTy = + cast(asyncLoad->loadOp->getResult(0).getType()); + int loadSize = product(tensorTy.getShape()); + sizeInBytes += + loadSize * tensorTy.getElementType().getIntOrFloatBitWidth() / 8; + } + + Value barrierAlloc = createBarrierAlloc(forOp, numBuffers); + barriers.push_back(barrierAlloc); + Location loc = forOp.getLoc(); + OpBuilder builder(forOp); + tt::MemDescType barrierTy = tt::MemDescType::get( + {1}, builder.getI64Type(), + cast(barrierAlloc.getType()).getEncoding(), + /*mutableMemory=*/true); + builder.setInsertionPoint(group[0]->loadOp); + Value barrier = builder.create( + loc, barrierTy, barrierAlloc, ArrayRef({insertIdx})); + Value pred = builder.create(loc, 1, 1); + Operation *expect = builder.create( + forOp.getLoc(), barrier, sizeInBytes, pred); + auto [stage, cluster] = schedule[asyncLoads[0].loadOp]; + schedule.insert(expect, stage, cluster); + + builder.setInsertionPointAfter(group.back()->loadOp); + Value barrierViewWait = builder.create( + loc, barrierTy, barrierAlloc, ArrayRef({extractIdx})); + Operation *wait = + builder.create(loc, barrierViewWait, phase); + // Update the async loads info. + for (AsyncLoad *asyncLoad : group) { + asyncLoad->barrier = barrier; + asyncLoad->waitOp = wait; + } + } +} +#endif + +#ifndef __ILUVATAR__ +// Convert load ops into their asyn version and apply multi-buffering based on +// the required number of buffers. +static SmallVector +createAsyncOps(scf::ForOp &forOp, CoarseSchedule &schedule, + llvm::MapVector &loadToInfo, + SmallVector &barriers, int numStages) { + // Calculate the number of buffers needed for each load. + // TODO pawel: we could do more fine-grained allocation here and + // allocate only the number of buffers that specific loads need. + // Instead, we allocate the maximum number of buffers needed by any load. + int numBuffers = + // llvm::max_element(llvm::make_second_range(loadToInfo), [](auto &lhs, + std::max_element( + llvm::make_second_range(loadToInfo).begin(), + llvm::make_second_range(loadToInfo).end(), + [](auto &lhs, auto &rhs) { return lhs.distToUse < rhs.distToUse; }) + ->distToUse; + bool hasMMAV3 = + llvm::any_of(loadToInfo, [](auto &kv) { return kv.second.loadIsMMAV3; }); + if (hasMMAV3) { + // For MMAv3, we need an extra buffer as this is assumed in the wgmma + // pipelining post-processing. + numBuffers++; + }; + + SmallVector asyncLoads; + SmallVector allocs; + bool hasTMALoad = false; + for (auto &[loadOp, info] : loadToInfo) { + assert(info.sharedEncoding && "LoadOp shared encoding not defined."); + Value alloc = createAlloc(forOp, loadOp, info.sharedEncoding, numBuffers); + assert(alloc && "Failed to create alloc for the async load."); + allocs.push_back(alloc); + asyncLoads.emplace_back(loadOp, alloc); + if (isa(loadOp)) { + hasTMALoad = true; + asyncLoads.back().isTMALoad = true; + } + } + + IRRewriter builder(forOp.getContext()); + builder.setInsertionPoint(forOp); + + Location loc = forOp.getLoc(); + // Create two new counters to index into the allocs. + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + Value insertIdx = minusOne; + Value extractIdx = minusOne; + Value phase = Value(); + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + SmallVector newOperands; + newOperands.push_back(insertIdx); + newOperands.push_back(extractIdx); + if (hasTMALoad) { + phase = builder.create(loc, 0, 32); + newOperands.push_back(phase); + } + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Patch the loop to add the new loop carried dependencies. + scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, newOperands); + forOp.erase(); + forOp = newForOp; + insertIdx = newForOp.getBody()->getArgument(newOperandIndex); + extractIdx = newForOp.getBody()->getArgument(newOperandIndex + 1); + if (phase) { + phase = newForOp.getBody()->getArgument(newOperandIndex + 2); + } + + // Create two counters for the insert and extract indices to avoid creating + // long liverange. + builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + insertIdx = builder.create(loc, insertIdx, one); + Value cndIns = builder.create(loc, arith::CmpIPredicate::slt, + insertIdx, numBuffersVal); + insertIdx = builder.create(loc, cndIns, insertIdx, zero); + + extractIdx = builder.create(loc, extractIdx, one); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + if (phase) { + Value nextPhase = builder.create(loc, phase, one); + phase = builder.create(loc, cndExt, phase, nextPhase); + } + createTMABarrierAndWait(forOp, asyncLoads, insertIdx, extractIdx, phase, + numBuffers, schedule, barriers, loadToInfo); + + // Create a cluster for the prefetches. It may end up being empty, but this + // is OK. + CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); + + for (AsyncLoad &asyncLoad : asyncLoads) { + if (auto loadOp = dyn_cast(asyncLoad.loadOp)) { + createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx, + schedule, prefetchCluster, loadToInfo, numStages); + } else { + auto descLoad = cast(asyncLoad.loadOp); + createTMAAsyncCopy(forOp, descLoad, asyncLoad.alloc, insertIdx, + extractIdx, asyncLoad.barrier, asyncLoad.waitOp, phase, + schedule, loadToInfo, numStages); + } + } + SmallVector newYieldOperands = {insertIdx, extractIdx}; + if (phase) + newYieldOperands.push_back(phase); + // Patch the yield with the updated counters. + appendToYield(forOp, newYieldOperands); + + return allocs; +} +#endif + +#ifndef __ILUVATAR__ +static void invalidateBarriers(OpBuilder &builder, + SmallVector &barriers) { + for (Value barrier : barriers) { + int numBarriers = barrier.getType().cast().getShape()[0]; + for (int i = 0; i < numBarriers; i++) { + Value idx = builder.create(barrier.getLoc(), i, 32); + tt::MemDescType barrierTy = tt::MemDescType::get( + {1}, builder.getI64Type(), + cast(barrier.getType()).getEncoding(), + /*mutableMemory=*/true); + Value barrierView = builder.create( + barrier.getLoc(), barrierTy, barrier, idx); + builder.create(barrier.getLoc(), barrierView); + } + } +} +#endif + +#ifndef __ILUVATAR__ +bool mlir::triton::preProcessLoopAndGetSchedule( + scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) { + // Schedule the loads and root ops (dot ops) in the loop. This will give us + // a scaffold for the final schedule. + DenseSet rootUsers; + CoarseSchedule coarseSchedule(numStages); + llvm::MapVector loadToInfo = + scheduleLoads(forOp, coarseSchedule, rootUsers, numStages); + if (loadToInfo.empty()) + return false; + + LLVM_DEBUG({ + LDBG("Coarse schedule loads only:"); + coarseSchedule.dump(); + }); + + SmallVector barriers; + // Convert the loads into async loads and create the allocs. + SmallVector allocs = + createAsyncOps(forOp, coarseSchedule, loadToInfo, barriers, numStages); + + LLVM_DEBUG({ + LDBG("Coarse schedule with async loads:"); + coarseSchedule.dump(); + }); + + CoarseSchedule::Cluster afterPrologue = + schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with prologue and epilogue:"); + coarseSchedule.dump(); + }); + + scheduleDependencies(forOp, coarseSchedule, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with dependencies:"); + coarseSchedule.dump(); + }); + + scheduleDistanceOneDependencies(forOp, coarseSchedule, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with dist 1:"); + coarseSchedule.dump(); + }); + + scheduleRemainingToLastStage(forOp, coarseSchedule, afterPrologue, numStages); + LLVM_DEBUG({ + LDBG("Final coarse schedule:"); + coarseSchedule.dump(); + }); + + // Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> schedule = + coarseSchedule.createFinalSchedule(forOp); + + // Fill out the pipeline options. + options.getScheduleFn = + [schedule](scf::ForOp forOp, + std::vector> &s) { + s = std::move(schedule); + }; + options.peelEpilogue = false; + options.predicateFn = tt::predicateOp; + options.supportDynamicLoops = true; + options.annotateFn = [](Operation *op, + mlir::triton::PipeliningOption::PipelinerPart part, + unsigned iteration) {}; + // Insert a wait 0 after the loop + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + builder.create(forOp.getLoc(), ValueRange({}), 0); + // Invalidate any mbarrier create + invalidateBarriers(builder, barriers); + // Explicitly deallocate allocated tensors after the wait op + for (auto alloc : allocs) + builder.create(forOp.getLoc(), alloc); + return true; +} +#endif + +/// Find the minimum number of async_commit_group ops between the wait +/// and the associated async_commit_group. This can be safely used as the wait +/// number. +static int minNumInterleavedCommitOps(Operation *waitOp) { + auto countCommitsBetween = [](Operation *op1, Operation *op2) { + int count = 0; + for (auto op = op1; op != op2; op = op->getNextNode()) { + if (isa(op)) + count++; + // Intentionally skip block ops' children. This will give us + // convervatively low number of insert ops. + } + return count; + }; + + int minCommitNumber = INT_MAX; + + // DFS the def chain of the extract op to find the insert op. On each path + // we calculate the number of async_commit. Then we select the minimum number + // of async_commit ops among all the paths. + std::function minOverHistories = + [&](Value val, Operation *sinkOp, int thisHistorySum) -> int { + if (Operation *defOp = val.getDefiningOp()) { + thisHistorySum += countCommitsBetween(defOp->getNextNode(), sinkOp); + minCommitNumber = std::min(minCommitNumber, thisHistorySum); + return minCommitNumber; + } + if (auto arg = mlir::dyn_cast(val)) { + Block *block = arg.getOwner(); + auto forOp = dyn_cast(block->getParentOp()); + + // Failed to track, return 0 conservatively. + if (!forOp) + return 0; + + Operation *firstForInst = &*forOp.getBody()->begin(); + int insertsBetween = countCommitsBetween(firstForInst, sinkOp); + thisHistorySum += insertsBetween; + if (thisHistorySum >= minCommitNumber) + return minCommitNumber; + + // get the value value assigned to the argument coming from outside the + // loop + Value incomingVal = forOp.getInitArgs()[arg.getArgNumber() - 1]; + int min1 = minOverHistories(incomingVal, forOp, thisHistorySum); + + // get the value value assigned to the argument coming from the previous + // iteration + Operation *yieldOp = block->getTerminator(); + Value prevVal = yieldOp->getOperand(arg.getArgNumber() - 1); + int min2 = minOverHistories(prevVal, yieldOp, thisHistorySum); + return std::min(std::min(min1, min2), minCommitNumber); + } + // Failed to track, return 0 conservatively. + return 0; + }; + + if (waitOp->getNumOperands() != 1) + return 0; + int minCommits = minOverHistories(waitOp->getOperand(0), waitOp, 0); + return minCommits; +} + +// Look for consecutive wait ops and combine them into a single wait op. +static void +combineRedundantWaitOps(llvm::SmallSetVector &waitOps) { + llvm::MapVector toDelete; + for (auto waitOp : waitOps) { + if (toDelete.count(waitOp)) + continue; + SmallVector waitGroup = {waitOp}; + SmallVector depTokens; + unsigned minWaitNumber = waitOp.getNum(); + Operation *next = waitOp->getNextNode(); + while (next && isa(next)) { + if (auto nextWait = dyn_cast(next)) { + waitGroup.push_back(nextWait); + minWaitNumber = std::min(minWaitNumber, nextWait.getNum()); + depTokens.append(nextWait.getOperands().begin(), + nextWait.getOperands().end()); + } + next = next->getNextNode(); + } + if (waitGroup.size() == 1) + continue; + OpBuilder builder(waitGroup.back()); + auto newWaitOp = builder.create(waitOp.getLoc(), + depTokens, minWaitNumber); + for (auto waitOp : waitGroup) { + toDelete[waitOp] = newWaitOp; + } + } + for (auto waitOp : toDelete) { + waitOp.first->replaceAllUsesWith(waitOp.second); + waitOp.first->erase(); + } +} + +/// Update wait op number by analyzing the number of async_commit_group ops +/// along all paths. +void mlir::triton::updateWaits(ModuleOp module) { + llvm::SmallSetVector waitOps; + module.walk([&](ttg::AsyncWaitOp waitOp) { + int minNumCommits = minNumInterleavedCommitOps(waitOp); + waitOp.setNum(minNumCommits); + waitOps.insert(waitOp); + }); + combineRedundantWaitOps(waitOps); +} + +#ifndef __ILUVATAR__ +// Add the given values as operands of the given wait, and replace all uses of +// the values with the wait. Also adds related MemDesc's to the wait. +// +// Threading %a through the wait transforms +// +// %a = <...> +// (%x', %y') = ttng.async_wait %x, %y +// %b = fn(%a) +// +// into +// +// %a = <...> +// (%x', %y', %a') = ttng.async_wait %x, %y, %a +// %b = fn(%a') +// +// The wait must dominate all uses of the elements of `values`. +// +// In addition to adding each value from `values` to the wait, this function +// also adds some MemDesc's to the wait. The idea is that if you have +// +// %alloc = ttg.local_alloc ... +// %a = ttng.dot_async %alloc +// %a1 = ttng.dot_wait %a +// +// then we want the wait to depend on %alloc as well as %a. This extends the +// live range of %alloc, so that it won't be destroyed until after the dot is +// waited on. +// +// Specifically, this function finds all dot_async ops that elements of `values` +// depend on. Then it adds the MemDesc operands of those dots to the wait. +static void threadValuesThroughWait(ttng::DotWaitOp wait, + MutableArrayRef values) { + IRRewriter builder(wait.getContext()); + builder.setInsertionPoint(wait); + + // Operands are only added to the wait through this function, so we can have + // the invariant that the wait has no duplicates. This makes things a bit + // easier below. + size_t origNumOperands = wait.getNumOperands(); + SetVector newOperands(wait.getOperands().begin(), + wait.getOperands().end()); + assert(newOperands.size() == origNumOperands && + "Wait op has duplicate operands."); + + newOperands.insert(values.begin(), values.end()); + + // Find memdefs depended on by `values` through async dot ops. + SmallVector asyncDots; + for (Value v : values) { + BackwardSliceOptions options; + options.omitBlockArguments = true; + options.filter = [&](Operation *op) { + if (auto dot = dyn_cast(op)) { + asyncDots.push_back(dot); + return false; + } + return op->getBlock() == wait->getBlock(); + }; + SetVector slice; + getBackwardSlice(v, &slice, options); + } + + for (ttng::DotAsyncOp dot : asyncDots) { + for (Value operand : dot.getOperands()) { + if (isa(operand.getType())) { + newOperands.insert(operand); + } + } + } + + // We can't use replaceWithNewOp because we're changing the number of return + // values in the operation. + auto newWait = builder.create( + wait.getLoc(), llvm::to_vector(newOperands), wait.getPendings()); + + auto dominatedByNewWait = [&](OpOperand &operand) { + auto opInThisBlock = + newWait->getBlock()->findAncestorOpInBlock(*operand.getOwner()); + return opInThisBlock && newWait->isBeforeInBlock(opInThisBlock); + }; + for (int i = 0; i < origNumOperands; i++) { + Value operand = wait.getResult(i); + if (!isa(operand.getType())) + operand.replaceAllUsesWith(newWait.getResult(i)); + } + for (int i = origNumOperands; i < newOperands.size(); i++) { + Value operand = newWait.getOperand(i); + if (!isa(operand.getType())) + operand.replaceUsesWithIf(newWait.getResult(i), dominatedByNewWait); + } + wait->erase(); +} + +// Determines whether a given MMAv3 dot op, represented as ttng.dot_async, needs +// a wait immediately after it. +// +// In PTX, MMAv3 exists only as an asynchronous op. In Triton, we can represent +// MMAv3 ops as either tt.dot (synchronous) or ttng.dot_async. But even if we +// use ttng.dot_async, the conservative thing is to make a dot "effectively +// synchronous" by inserting a `ttng.dot_wait {pendings=0}` right after it. +// +// We can omit the wait and create a "properly async" dot if all of the +// following are true. +// +// 1. All operands that touch shared memory are multi-buffered, i.e. can't read +// an incomplete value while it's being written asynchronously by a load. +// +// 2. If the dot is used by any op in the loop, it must be used under an `if`, +// and will be synced with a `wait 0` at the beginning of the `if` block. +// +// 3. During iteration i, between the start of the loop up until the first +// `ttng.dot_wait {pendings=0}` op, the result of the dot from iteration i-1 +// is consumed only by other MMAv3 dots as the `c` operand. +// +// This is safe because the following pseudo-PTX is valid: +// +// %accum = dot_async %a1, %b1, %c1 +// %accum = dot_async %a2, %b2, %accum +// +// That is, the second async dot can use the result of the first one without +// an intervening wait. However, the only operation that can legally read +// %accum before the wait is another dot_async, and this only works for the +// `c` operand, not `a` or `b`. See +// https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence +// (ttng::DotAsyncOp corresponds to wgmma.fence followed by one or more +// wgmma.async ops, so our understanding is that the two ttng::DotAsyncOps +// don't have to correspond to wgmma.async ops with the same shapes as +// specified in the docs, because there's an intervening fence.) +// +// If the op can be properly async, this function returns the index of the dot +// in the loop's iter_args. (Rule (2) above ensures this is well-defined.) +// +static std::optional dotCanBeProperlyAsync(ttng::DotAsyncOp dotOp, + scf::ForOp forOp) { + LDBG("Considering whether to make MMAv3 dot properly async: " << dotOp); + + // Rule 1: All shmem operands are multi-buffered. + auto checkOperand = [&](Value operand) { + if (!isa( + cast(operand.getType()).getEncoding())) { + return true; + } + + // If it's a shmem operand, it must either be defined outside the loop, or + // come from an MemDescSubview op. Only ConvertLayout and Trans ops are + // allowed in between. + Value transitiveOperand = operand; + while (isa_and_nonnull( + transitiveOperand.getDefiningOp())) { + transitiveOperand = transitiveOperand.getDefiningOp()->getOperand(0); + } + return forOp.isDefinedOutsideOfLoop(transitiveOperand) || + isa(transitiveOperand.getDefiningOp()); + }; + + // We don't have to call checkOperand on getC() because it's always in + // registers, never in shmem. + assert(isa(dotOp.getC().getType().getEncoding())); + if (!checkOperand(dotOp.getA()) || !checkOperand(dotOp.getB())) { + LDBG("Can't make dot async because shmem operands aren't multi-buffered"); + return std::nullopt; + } + + // Rule 2: The dot cannot be unconditionally used by any op in the loop. + // Uses under `if` are allowed, as can be explicitly synced with a `wait 0`. + int iterArgIdx = -1; + Value iterArg = nullptr; + SmallVector> queue; + for (auto &use : dotOp->getUses()) { + queue.push_back({use.getOwner(), use.getOperandNumber()}); + } + while (!queue.empty()) { + auto [user, argIdx] = queue.pop_back_val(); + if (user->getParentOp() == forOp) { + if (isa(user)) { + if (iterArg) { + // The dot is used by the loop's yield, but we can't have any other + // uses. + LDBG("Can't make dot async because dot is used by multiple ops in " + "the loop."); + return std::nullopt; + } + iterArgIdx = argIdx; + iterArg = forOp.getRegionIterArg(argIdx); + continue; + } + LDBG("Can't make dot async because dot is unconditionally used in the " + "loop."); + return std::nullopt; + } + if (auto ifOp = dyn_cast(user->getParentOp())) { + if (isa(user)) { + // The result is returned by the if, follow it further. + auto uses = ifOp.getResult(argIdx).getUses(); + for (auto &use : uses) { + queue.push_back({use.getOwner(), use.getOperandNumber()}); + } + } + } else { + return std::nullopt; + } + } + + // Rule 3a: Are the only users of the dot's result from iteration i-1 other + // MMAv3 dots? If so, we're done, this dot can be properly async. + if (llvm::all_of(iterArg.getUses(), [&](OpOperand &use) { + return isa(use.getOwner()) && + use.getOperandNumber() == 2; + })) { + return iterArgIdx; + } + + // Rule 3b: Are all users of the dot's result from iteration i-1 after the + // first `dot_wait {pendings=0}` op? If so, the dot can be properly async, + // but we have to thread its result from iteration i-1 through the wait. + auto waitOps = forOp.getBody()->getOps(); + auto firstWaitOpIter = llvm::find_if( + waitOps, [&](auto waitOp) { return waitOp.getPendings() == 0; }); + if (firstWaitOpIter != waitOps.end() && + llvm::all_of(iterArg.getUsers(), [&](Operation *user) { + assert(forOp->isAncestor(user)); + while (user->getParentOp() != forOp) { + user = user->getParentOp(); + } + return (*firstWaitOpIter)->isBeforeInBlock(user); + })) { + LDBG("MMAv3 dot can be properly async because it follows a dot_wait " + "{pendings=0}.\n" + << " wait: " << *firstWaitOpIter << "\n" + << " dot: " << dotOp); + threadValuesThroughWait(*firstWaitOpIter, {iterArg}); + return iterArgIdx; + } + + LDBG("Can't make dot async because its result from i-1 is used by " + "something other than another MMAv3 dot as the `c` operand."); + return std::nullopt; +} + +// If necessary, insert a dot-wait inside the loop, waiting for the results of +// the properly-async dots from iteration i-1 to complete. (We pipeline to +// depth 2, so there are at most 2 copies of each dot_async in flight at a +// time.) +// +// We can skip inserting the wait if we have a `dot_wait {pendings=0}` somewhere +// in the loop. To see why, consider: +// +// dot_async +// dot_async; wait 0 // synchronous dot +// dot_async +// dot_async +// +// In this example, there are three properly-async dots, so we'd normally put +// `wait 3` at the end of the loop, meaning "wait until there are 3 or fewer +// pending async dots". But note that when this iteration of the loop +// completes, there are only *two* pending async dots from this iteration, so +// this wait would do nothing. This is true in general, no matter where the +// `wait 0` appears. +static void insertAsyncDotWaitInLoop( + scf::ForOp forOp, + const llvm::MapVector &properlyAsyncDots) { + if (properlyAsyncDots.empty()) + return; + + if (llvm::any_of(forOp.getBody()->getOps(), + [](auto wait) { return wait.getPendings() == 0; })) { + return; + } + + // Insert waits before the users of the properly async dots other than loop + // yield. + for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) { + SmallVector uses; + for (auto &use : asyncDot->getUses()) { + if (auto yieldOp = dyn_cast(use.getOwner())) { + continue; + } + uses.push_back(&use); + } + + DenseMap> blockToUsers; + for (auto use : uses) { + auto block = use->getOwner()->getBlock(); + blockToUsers[block].push_back(use->get()); + } + + for (auto [block, users] : blockToUsers) { + OpBuilder builder(block, block->begin()); + auto newWait = builder.create(asyncDot->getLoc(), + ArrayRef{}, 0); + + threadValuesThroughWait(newWait, users); + } + } + + // Add the wait right after the last properly-async dot. This only needs to + // wait for all properly-async dots from the i-1'th iteration to complete, IOW + // we wait until there are most `asyncDots.size()` dots in flight. + // + // (You might want to put the wait at the end of the loop instead of right + // after the last dot, but there could be a load into shmem between the last + // async dot and the end of the loop, and that could clobber memory being used + // by a dot.) + IRRewriter builder(forOp.getContext()); + auto lastAsyncDot = properlyAsyncDots.back().first; + builder.setInsertionPointAfter(lastAsyncDot); + auto wait = builder.create(lastAsyncDot->getLoc(), + /*inputs=*/ArrayRef{}, + properlyAsyncDots.size()); + + // Thread the results of the async dots through the wait. + SmallVector addlWaitOperands; + for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) { + addlWaitOperands.push_back(asyncDot->getResult(0)); + } + threadValuesThroughWait(wait, addlWaitOperands); +} + +// Convert MMAv3 tt::DotOps (i.e. Hopper wgmma) into ttng::DotAsyncOps and +// insert ttng::DotWaitOps as necessary. +// +// We assume we have space for each dot to be pipelined to depth 2, i.e. each +// dot op in the loop can have at most 2 dot_async ops in flight at once. (Each +// dot_async op usually corresponds to a series of wgmma.async ops.) +void triton::asyncLaunchDots(scf::ForOp forOp) { + LDBG("Original loop:\n" << *forOp); + + // First, change every MMAv3 tt.dot into ttng.dot_async. The rest of this + // function is concerned with inserting ttng.dot_wait ops in the appropriate + // places. + // + // It's not strictly necessary to convert every dot into dot_async: + // Synchronous MMAv3 dots can be represented equally well as `tt.dot` or + // `ttng.dot_async; wait 0`. But this makes things easier elsewhere. + // + // We call those dots that don't need to be followed immediately by a `wait 0` + // "properly async", or sometimes just "async". + IRRewriter builder(forOp.getContext()); + for (auto dotOp : llvm::to_vector(forOp.getBody()->getOps())) { + if (isMMAv3Dot(dotOp)) { + builder.setInsertionPoint(dotOp); + builder.replaceOpWithNewOp( + dotOp, dotOp.getA(), dotOp.getB(), dotOp.getC(), + dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); + } + } + + // For each dot, determine whether it can be properly async, or if it needs a + // sync immediately after. If it can be properly async, we know its only use + // is in the loop's `yield` statement; asyncDots maps the op to its index in + // the yield op. + llvm::MapVector properlyAsyncDots; + for (auto dotOp : forOp.getBody()->getOps()) { + if (auto iterArgIdx = dotCanBeProperlyAsync(dotOp, forOp)) { + properlyAsyncDots[dotOp] = *iterArgIdx; + } else { + builder.setInsertionPointAfter(dotOp); + auto wait = + builder.create(dotOp.getLoc(), ArrayRef{}, + /*pendings=*/0); + SmallVector waitOperands = {dotOp.getResult()}; + threadValuesThroughWait(wait, waitOperands); + } + } + + if (properlyAsyncDots.empty()) { + LDBG("No properly async dots."); + return; + } + + // Next, insert a wait inside the loop. We pipeline to depth 2, so the third + // iteration's set of asynchronous dots (and their corresponding async copies + // from global to shmem) can't start until the first iteration's set has + // completed. + insertAsyncDotWaitInLoop(forOp, properlyAsyncDots); + + // Finally, insert a wait after the loop, waiting for dots from the final + // iteration of the loop. + SmallVector waitOperands; + for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) { + waitOperands.push_back(forOp.getResult(iterArgIdx)); + } + // Wait until there are 0 outstanding async dot ops. + builder.setInsertionPointAfter(forOp); + auto dotWaitAfterLoop = + builder.create(forOp.getLoc(), ArrayRef{}, 0); + threadValuesThroughWait(dotWaitAfterLoop, waitOperands); +} +#endif diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp new file mode 100644 index 000000000..8b3f55bb8 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp @@ -0,0 +1,131 @@ +#include "PipelineExpander.h" +#include "PipeliningUtility.h" +#include "Schedule.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +// create the schedule for a matmul loop. This is ad hoc based on how we know +// matmul loops should be pipelined and is not a generic scheduler. +static std::vector> +createSchedule(scf::ForOp forOp, int numStages) { + SmallVector insertOps; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (isa(op)) + insertOps.emplace_back(&op); + } + DenseSet insertAndDeps; + for (Operation *op : insertOps) { + tt::addDep(op, insertAndDeps, true); + } + + DenseSet epilogue; + bool foundLoop = false; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (insertAndDeps.count(&op)) + continue; + if (isa(op)) + foundLoop = true; + if (isa(op)) + continue; + if (foundLoop) + epilogue.insert(&op); + } + + std::vector> schedule; + // Schedule stage 1 first. + tt::addOps(forOp, 1, schedule, [&](Operation *op) { + return insertAndDeps.count(op) == 0 && epilogue.count(op) == 0; + }); + + // Then Schedule stage 0. + tt::addOps(forOp, 0, schedule, + [&](Operation *op) { return insertAndDeps.count(op); }); + + // Then schedule the epilogue in stage 1 + tt::addOps(forOp, 1, schedule, + [&](Operation *op) { return epilogue.count(op); }); + return schedule; +} + +// pre-process the loop by hosting allocations/deallocation out of the +// loop. +static void hoistAllocAndConst(scf::ForOp forOp) { + SmallVector toHoist; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (auto allocOp = dyn_cast(op)) { + // We hoist the allocOp only if it is created by the inner loop + // pipelining. + if (!allocOp.getSrc()) + toHoist.push_back(&op); + } else if (isa(op)) { + toHoist.push_back(&op); + } + } + for (Operation *op : toHoist) { + op->moveBefore(forOp); + auto allocOp = dyn_cast(op); + if (!allocOp) + continue; + for (Operation *user : allocOp->getUsers()) { + if (auto dealloc = dyn_cast(user)) { + dealloc->moveAfter(forOp); + } + } + } +} + +static bool preCondition(scf::ForOp forOp) { + // Check if there is a dependency from the loop to the async copy op. In this + // case we cannot pipeline the async copy. + SmallVector insertOps; + int numForOps = 0; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (isa(op)) + insertOps.emplace_back(&op); + if (isa(op)) + numForOps++; + } + if (insertOps.empty() || numForOps != 1) + return false; + DenseSet insertAndDeps; + for (Operation *op : insertOps) { + tt::addDep(op, insertAndDeps, true); + } + // If there is a recurrence containing both the async and the for op we cannot + // pipeline. + for (Operation *op : insertAndDeps) { + if (isa(op)) + return false; + } + return true; +} + +bool mlir::triton::getOuterLoopSchedule( + scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) { + assert(numStages == 2 && "only support 2 stage pipelining for now"); + // 1. Check precondition, we cannot have a recurrence involving async cp ops + if (!preCondition(forOp)) + return false; + + // 2. pre-process the loop by hosting allocations. + hoistAllocAndConst(forOp); + + // 3. Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> schedule = + createSchedule(forOp, numStages); + + // 4. Fill out the pipeline options. + options.getScheduleFn = + [schedule](scf::ForOp forOp, + std::vector> &s) { + s = std::move(schedule); + }; + options.peelEpilogue = false; + options.predicateFn = mlir::triton::predicateOp; + options.supportDynamicLoops = true; + return true; +} diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp new file mode 100644 index 000000000..6dfd0e344 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -0,0 +1,776 @@ +//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop software pipelining +// +//===----------------------------------------------------------------------===// + +// Fork of upstream pipeliner. This will be merged upstream once things are +// stable. Modifications so far are: +// -Bug fix for def with a distance of 1 scheduled in stage 0. +// -Support dynamic loops and predicate operations in the prologue. +// -Support for non-index type for induction variable. +// -Support source with distance of 1 used multiple stages later. +// -Fix bug when a value yield is used outside the loop and the value def is not +// in the last stage. If we are not peeling the epilgue we need to remap the +// output correctly. + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/MathExtras.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/Support/Debug.h" + +#include "PipelineExpander.h" + +#define DEBUG_TYPE "triton-loop-pipelining" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::scf; +using namespace mlir::triton; + +namespace { + +/// Helper to keep internal information during pipelining transformation. +struct LoopPipelinerInternal { + /// Coarse liverange information for ops used across stages. + struct LiverangeInfo { + unsigned lastUseStage = 0; + unsigned defStage = 0; + }; + +protected: + ForOp forOp; + unsigned maxStage = 0; + DenseMap stages; + std::vector opOrder; + Value ub; + Value lb; + Value step; + bool dynamicLoop; + triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr; + bool peelEpilogue; + triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr; + + // When peeling the kernel we generate several version of each value for + // different stage of the prologue. This map tracks the mapping between + // original Values in the loop and the different versions + // peeled from the loop. + DenseMap> valueMapping; + + /// Assign a value to `valueMapping`, this means `val` represents the version + /// `idx` of `key` in the epilogue. + void setValueMapping(Value key, Value el, int64_t idx); + + /// Return the defining op of the given value, if the Value is an argument of + /// the loop return the associated defining op in the loop and its distance to + /// the Value. + std::pair getDefiningOpAndDistance(Value value); + + /// Return true if the schedule is possible and return false otherwise. A + /// schedule is correct if all definitions are scheduled before uses. + bool verifySchedule(); + +public: + /// Initialize the information for the given `op`, return true if it + /// satisfies the pre-condition to apply pipelining. + bool initializeLoopInfo(ForOp op, const triton::PipeliningOption &options); + /// Emits the prologue, this creates `maxStage - 1` part which will contain + /// operations from stages [0; i], where i is the part index. + void emitPrologue(RewriterBase &rewriter); + /// Gather liverange information for Values that are used in a different stage + /// than its definition. + llvm::MapVector analyzeCrossStageValues(); + scf::ForOp createKernelLoop( + const llvm::MapVector &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap); + /// Emits the pipelined kernel. This clones loop operations following user + /// order and remaps operands defined in a different stage as their use. + LogicalResult createKernel( + scf::ForOp newForOp, + const llvm::MapVector &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter); + /// Emits the epilogue, this creates `maxStage - 1` part which will contain + /// operations from stages [i; maxStage], where i is the part index. + void emitEpilogue(RewriterBase &rewriter, + llvm::SmallVector &returnValues); +}; + +bool LoopPipelinerInternal::initializeLoopInfo( + ForOp op, const triton::PipeliningOption &options) { + LDBG("Start initializeLoopInfo"); + forOp = op; + ub = forOp.getUpperBound(); + lb = forOp.getLowerBound(); + step = forOp.getStep(); + + dynamicLoop = true; + auto upperBoundCst = ub.getDefiningOp(); + auto lowerBoundCst = lb.getDefiningOp(); + auto stepCst = step.getDefiningOp(); + if (!upperBoundCst || !lowerBoundCst || !stepCst) { + if (!options.supportDynamicLoops) { + LDBG("--dynamic loop not supported -> BAIL"); + return false; + } + } else { + int64_t ubImm = upperBoundCst.value(); + int64_t lbImm = lowerBoundCst.value(); + int64_t stepImm = stepCst.value(); + int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm); + if (numIteration > maxStage) { + dynamicLoop = false; + } else if (!options.supportDynamicLoops) { + LDBG("--fewer loop iterations than pipeline stages -> BAIL"); + return false; + } + } + peelEpilogue = options.peelEpilogue; + predicateFn = options.predicateFn; + if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { + LDBG("--no epilogue or predicate set -> BAIL"); + return false; + } + if (dynamicLoop && peelEpilogue) { + LDBG("--dynamic loop doesn't support epilogue yet -> BAIL"); + return false; + } + std::vector> schedule; + options.getScheduleFn(forOp, schedule); + if (schedule.empty()) { + LDBG("--empty schedule -> BAIL"); + return false; + } + + opOrder.reserve(schedule.size()); + for (auto &opSchedule : schedule) { + maxStage = std::max(maxStage, opSchedule.second); + stages[opSchedule.first] = opSchedule.second; + opOrder.push_back(opSchedule.first); + } + + // All operations need to have a stage. + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!stages.contains(&op)) { + op.emitOpError("not assigned a pipeline stage"); + LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL"); + return false; + } + } + + if (!verifySchedule()) { + LDBG("--invalid schedule: " << op << " -> BAIL"); + return false; + } + + // Currently, we do not support assigning stages to ops in nested regions. The + // block of all operations assigned a stage should be the single `scf.for` + // body block. + for (const auto &[op, stageNum] : stages) { + (void)stageNum; + if (op == forOp.getBody()->getTerminator()) { + op->emitError("terminator should not be assigned a stage"); + LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL"); + return false; + } + if (op->getBlock() != forOp.getBody()) { + op->emitOpError("the owning Block of all operations assigned a stage " + "should be the loop body block"); + LDBG("--the owning Block of all operations assigned a stage " + "should be the loop body block: " + << *op << " -> BAIL"); + return false; + } + } + + // Support only loop-carried dependencies with a distance of one iteration or + // those defined outside of the loop. This means that any dependency within a + // loop should either be on the immediately preceding iteration, the current + // iteration, or on variables whose values are set before entering the loop. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [this](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def || + (!stages.contains(def) && forOp->isAncestor(def)); + })) { + LDBG("--only support loop carried dependency with a distance of 1 or " + "defined outside of the loop -> BAIL"); + return false; + } + annotateFn = options.annotateFn; + return true; +} + +/// Find operands of all the nested operations within `op`. +static SetVector getNestedOperands(Operation *op) { + SetVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + operands.insert(operand); + } + }); + return operands; +} + +/// Compute unrolled cycles of each op (consumer) and verify that each op is +/// scheduled after its operands (producers) while adjusting for the distance +/// between producer and consumer. +bool LoopPipelinerInternal::verifySchedule() { + int64_t numCylesPerIter = opOrder.size(); + // Pre-compute the unrolled cycle of each op. + DenseMap unrolledCyles; + for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) { + Operation *def = opOrder[cycle]; + auto it = stages.find(def); + assert(it != stages.end()); + int64_t stage = it->second; + unrolledCyles[def] = cycle + stage * numCylesPerIter; + } + for (Operation *consumer : opOrder) { + int64_t consumerCycle = unrolledCyles[consumer]; + for (Value operand : getNestedOperands(consumer)) { + auto [producer, distance] = getDefiningOpAndDistance(operand); + if (!producer) + continue; + auto it = unrolledCyles.find(producer); + // Skip producer coming from outside the loop. + if (it == unrolledCyles.end()) + continue; + int64_t producerCycle = it->second; + if (consumerCycle < producerCycle - numCylesPerIter * distance) { + consumer->emitError("operation scheduled before its operands"); + return false; + } + } + } + return true; +} + +/// Clone `op` and call `callback` on the cloned op's operands as well as any +/// operands of nested ops that: +/// 1) aren't defined within the new op or +/// 2) are block arguments. +static Operation * +cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, + function_ref callback) { + Operation *clone = rewriter.clone(*op); + clone->walk([&](Operation *nested) { + // 'clone' itself will be visited first. + for (OpOperand &operand : nested->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if ((def && !clone->isAncestor(def)) || isa(operand.get())) + callback(&operand); + } + }); + return clone; +} + +void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { + // Initialize the iteration argument to the loop initiale values. + for (auto [arg, operand] : + llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) { + setValueMapping(arg, operand.get(), 0); + } + auto yield = cast(forOp.getBody()->getTerminator()); + Location loc = forOp.getLoc(); + SmallVector predicates(maxStage); + for (int64_t i = 0; i < maxStage; i++) { + if (dynamicLoop) { + Type t = ub.getType(); + // pred = ub > lb + (i * step) + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, i)))); + predicates[i] = rewriter.create( + loc, arith::CmpIPredicate::slt, iv, ub); + } + + // special handling for induction variable as the increment is implicit. + // iv = lb + i * step + Type t = lb.getType(); + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, + rewriter.getIntegerAttr(t, i)))); + setValueMapping(forOp.getInductionVar(), iv, i); + for (Operation *op : opOrder) { + if (stages[op] > i) + continue; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[i - stages[op]]; + newOperand->set(replacement); + } + }); + int predicateIdx = i - stages[op]; + if (predicates[predicateIdx]) { + newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]); + assert(newOp && "failed to predicate op."); + } + rewriter.setInsertionPointAfter(newOp); + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Prologue, i); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(destId), newOp->getResult(destId), + i - stages[op]); + // If the value is a loop carried dependency update the loop argument + // mapping. + for (OpOperand &operand : yield->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), i - stages[op] + 1); + } + } + } + } +} + +llvm::MapVector +LoopPipelinerInternal::analyzeCrossStageValues() { + llvm::MapVector crossStageValues; + for (Operation *op : opOrder) { + unsigned stage = stages[op]; + + auto analyzeOperand = [&](OpOperand &operand) { + auto [def, distance] = getDefiningOpAndDistance(operand.get()); + if (!def) + return; + auto defStage = stages.find(def); + if (defStage == stages.end() || defStage->second == stage || + defStage->second == stage + distance) + return; + assert(stage > defStage->second); + LiverangeInfo &info = crossStageValues[operand.get()]; + info.defStage = defStage->second; + info.lastUseStage = std::max(info.lastUseStage, stage); + }; + + for (OpOperand &operand : op->getOpOperands()) + analyzeOperand(operand); + visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) { + analyzeOperand(*operand); + }); + } + return crossStageValues; +} + +std::pair +LoopPipelinerInternal::getDefiningOpAndDistance(Value value) { + int64_t distance = 0; + if (auto arg = dyn_cast(value)) { + if (arg.getOwner() != forOp.getBody()) + return {nullptr, 0}; + // Ignore induction variable. + if (arg.getArgNumber() == 0) + return {nullptr, 0}; + distance++; + value = + forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); + } + Operation *def = value.getDefiningOp(); + if (!def) + return {nullptr, 0}; + return {def, distance}; +} + +scf::ForOp LoopPipelinerInternal::createKernelLoop( + const llvm::MapVector + &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap) { + // Creates the list of initial values associated to values used across + // stages. The initial values come from the prologue created above. + // Keep track of the kernel argument associated to each version of the + // values passed to the kernel. + llvm::SmallVector newLoopArg; + // For existing loop argument initialize them with the right version from the + // prologue. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance of 1 or " + "outside the loop"); + auto defStage = stages.find(def); + if (defStage != stages.end()) { + Value valueVersion = + valueMapping[forOp.getRegionIterArgs()[retVal.index()]] + [maxStage - defStage->second]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + } else + newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]); + } + for (auto escape : crossStageValues) { + LiverangeInfo &info = escape.second; + Value value = escape.first; + for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage; + stageIdx++) { + Value valueVersion = + valueMapping[value][maxStage - info.lastUseStage + stageIdx]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage - + stageIdx)] = newLoopArg.size() - 1; + } + } + + // Create the new kernel loop. When we peel the epilgue we need to peel + // `numStages - 1` iterations. Then we adjust the upper bound to remove those + // iterations. + Value newUb = forOp.getUpperBound(); + if (peelEpilogue) { + Type t = ub.getType(); + Location loc = forOp.getLoc(); + // newUb = ub - maxStage * step + Value maxStageValue = rewriter.create( + loc, rewriter.getIntegerAttr(t, maxStage)); + Value maxStageByStep = + rewriter.create(loc, step, maxStageValue); + newUb = rewriter.create(loc, ub, maxStageByStep); + } + auto newForOp = + rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, + forOp.getStep(), newLoopArg); + // When there are no iter args, the loop body terminator will be created. + // Since we always create it below, remove the terminator if it was created. + if (!newForOp.getBody()->empty()) + rewriter.eraseOp(newForOp.getBody()->getTerminator()); + return newForOp; +} + +LogicalResult LoopPipelinerInternal::createKernel( + scf::ForOp newForOp, + const llvm::MapVector + &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter) { + valueMapping.clear(); + + // Create the kernel, we clone instruction based on the order given by + // user and remap operands coming from a previous stages. + rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) { + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + } + SmallVector predicates(maxStage + 1, nullptr); + if (!peelEpilogue) { + // Create a predicate for each stage except the last stage. + Location loc = newForOp.getLoc(); + Type t = ub.getType(); + for (unsigned i = 0; i < maxStage; i++) { + // c = ub - (maxStage - i) * step + Value c = rewriter.create( + loc, ub, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i))))); + + Value pred = rewriter.create( + newForOp.getLoc(), arith::CmpIPredicate::slt, + newForOp.getInductionVar(), c); + predicates[i] = pred; + } + } + for (Operation *op : opOrder) { + int64_t useStage = stages[op]; + auto *newOp = rewriter.clone(*op, mapping); + SmallVector operands; + // Collect all the operands for the cloned op and its nested ops. + op->walk([&operands](Operation *nestedOp) { + for (OpOperand &operand : nestedOp->getOpOperands()) { + operands.push_back(&operand); + } + }); + for (OpOperand *operand : operands) { + Operation *nestedNewOp = mapping.lookup(operand->getOwner()); + // Special case for the induction variable uses. We replace it with a + // version incremented based on the stage where it is used. + if (operand->get() == forOp.getInductionVar()) { + rewriter.setInsertionPoint(newOp); + + // offset = (maxStage - stages[op]) * step + Type t = step.getType(); + Value offset = rewriter.create( + forOp.getLoc(), step, + rewriter.create( + forOp.getLoc(), + rewriter.getIntegerAttr(t, maxStage - stages[op]))); + Value iv = rewriter.create( + forOp.getLoc(), newForOp.getInductionVar(), offset); + nestedNewOp->setOperand(operand->getOperandNumber(), iv); + rewriter.setInsertionPointAfter(newOp); + continue; + } + Value source = operand->get(); + auto arg = dyn_cast(source); + if (arg && arg.getOwner() == forOp.getBody()) { + Value ret = forOp.getBody()->getTerminator()->getOperand( + arg.getArgNumber() - 1); + Operation *dep = ret.getDefiningOp(); + if (!dep) + continue; + auto stageDep = stages.find(dep); + if (stageDep == stages.end() || stageDep->second == useStage) + continue; + // If the value is a loop carried value coming from stage N + 1 remap, + // it will become a direct use. + if (stageDep->second == useStage + 1) { + nestedNewOp->setOperand(operand->getOperandNumber(), + mapping.lookupOrDefault(ret)); + continue; + } + source = ret; + } + // For operands defined in a previous stage we need to remap it to use + // the correct region argument. We look for the right version of the + // Value based on the stage where it is used. + Operation *def = source.getDefiningOp(); + if (!def) + continue; + auto stageDef = stages.find(def); + if (stageDef == stages.end() || stageDef->second == useStage) + continue; + auto remap = loopArgMap.find( + std::make_pair(operand->get(), useStage - stageDef->second)); + assert(remap != loopArgMap.end()); + nestedNewOp->setOperand(operand->getOperandNumber(), + newForOp.getRegionIterArgs()[remap->second]); + } + + if (predicates[useStage]) { + newOp = predicateFn(rewriter, newOp, predicates[useStage]); + if (!newOp) + return failure(); + // Remap the results to the new predicated one. + for (auto values : llvm::zip(op->getResults(), newOp->getResults())) + mapping.map(std::get<0>(values), std::get<1>(values)); + } + rewriter.setInsertionPointAfter(newOp); + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Kernel, 0); + } + + // Collect the Values that need to be returned by the forOp. For each + // value we need to have `LastUseStage - DefStage` number of versions + // returned. + // We create a mapping between original values and the associated loop + // returned values that will be needed by the epilogue. + llvm::SmallVector yieldOperands; + for (OpOperand &yieldOperand : + forOp.getBody()->getTerminator()->getOpOperands()) { + Value source = mapping.lookupOrDefault(yieldOperand.get()); + // When we don't peel the epilogue and the yield value is used outside the + // loop we need to make sure we return the version from numStages - + // defStage. + if (!peelEpilogue && + !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) { + Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first; + if (def) { + auto defStage = stages.find(def); + if (defStage != stages.end() && defStage->second < maxStage) { + Value pred = predicates[defStage->second]; + source = rewriter.create( + pred.getLoc(), pred, source, + newForOp.getBody() + ->getArguments()[yieldOperand.getOperandNumber() + 1]); + } + } + } + yieldOperands.push_back(source); + } + + for (auto &it : crossStageValues) { + int64_t version = maxStage - it.second.lastUseStage + 1; + unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; + // add the original version to yield ops. + // If there is a live range spanning across more than 2 stages we need to + // add extra arg. + for (unsigned i = 1; i < numVersionReturned; i++) { + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back( + newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + + newForOp.getNumInductionVars()]); + } + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back(mapping.lookupOrDefault(it.first)); + } + // Map the yield operand to the forOp returned value. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance of 1 or " + "defined outside the loop"); + auto defStage = stages.find(def); + if (defStage == stages.end()) { + for (unsigned int stage = 1; stage <= maxStage; stage++) + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + retVal.value(), stage); + } else if (defStage->second > 0) { + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + newForOp->getResult(retVal.index()), + maxStage - defStage->second + 1); + } + } + rewriter.create(forOp.getLoc(), yieldOperands); + return success(); +} + +void LoopPipelinerInternal::emitEpilogue( + RewriterBase &rewriter, llvm::SmallVector &returnValues) { + // Emit different versions of the induction variable. They will be + // removed by dead code if not used. + for (int64_t i = 0; i < maxStage; i++) { + Location loc = forOp.getLoc(); + Type t = lb.getType(); + Value minusOne = + rewriter.create(loc, rewriter.getIntegerAttr(t, -1)); + // number of iterations = ((ub - 1) - lb) / step + Value totalNumIteration = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create(loc, ub, minusOne), lb), + step); + // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i) + Value minusI = + rewriter.create(loc, rewriter.getIntegerAttr(t, -i)); + Value newlastIter = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, totalNumIteration, minusI))); + setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i); + } + // Emit `maxStage - 1` epilogue part that includes operations from stages + // [i; maxStage]. + for (int64_t i = 1; i <= maxStage; i++) { + for (Operation *op : opOrder) { + if (stages[op] < i) + continue; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[maxStage - stages[op] + i]; + newOperand->set(replacement); + } + }); + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Epilogue, + i - 1); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(destId), newOp->getResult(destId), + maxStage - stages[op] + i); + // If the value is a loop carried dependency update the loop argument + // mapping and keep track of the last version to replace the original + // forOp uses. + for (OpOperand &operand : + forOp.getBody()->getTerminator()->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + unsigned version = maxStage - stages[op] + i + 1; + // If the version is greater than maxStage it means it maps to the + // original forOp returned value. + if (version > maxStage) { + returnValues[operand.getOperandNumber()] = newOp->getResult(destId); + continue; + } + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), version); + } + } + } + } +} + +void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { + auto it = valueMapping.find(key); + // If the value is not in the map yet add a vector big enough to store all + // versions. + if (it == valueMapping.end()) + it = + valueMapping + .insert(std::make_pair(key, llvm::SmallVector(maxStage + 1))) + .first; + it->second[idx] = el; +} + +} // namespace + +FailureOr +mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, + const triton::PipeliningOption &options, + bool *modifiedIR) { + if (modifiedIR) + *modifiedIR = false; + LoopPipelinerInternal pipeliner; + if (!pipeliner.initializeLoopInfo(forOp, options)) + return failure(); + + if (modifiedIR) + *modifiedIR = true; + + // 1. Emit prologue. + pipeliner.emitPrologue(rewriter); + + // 2. Track values used across stages. When a value cross stages it will + // need to be passed as loop iteration arguments. + // We first collect the values that are used in a different stage than where + // they are defined. + llvm::MapVector + crossStageValues = pipeliner.analyzeCrossStageValues(); + + // Mapping between original loop values used cross stage and the block + // arguments associated after pipelining. A Value may map to several + // arguments if its liverange spans across more than 2 stages. + llvm::DenseMap, unsigned> loopArgMap; + // 3. Create the new kernel loop and return the block arguments mapping. + ForOp newForOp = + pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); + // Create the kernel block, order ops based on user choice and remap + // operands. + if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, + rewriter))) + return failure(); + + llvm::SmallVector returnValues = + newForOp.getResults().take_front(forOp->getNumResults()); + if (options.peelEpilogue) { + // 4. Emit the epilogue after the new forOp. + rewriter.setInsertionPointAfter(newForOp); + pipeliner.emitEpilogue(rewriter, returnValues); + } + // 5. Erase the original loop and replace the uses with the epilogue output. + if (forOp->getNumResults() > 0) + rewriter.replaceOp(forOp, returnValues); + else + rewriter.eraseOp(forOp); + + return newForOp; +} diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h new file mode 100644 index 000000000..0a3d736c6 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h @@ -0,0 +1,101 @@ +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ + +// This is a fork of upstream pipeline transformation. This will be merged back +// upstream once we have a stable solution. + +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { + +class RewriterBase; +class Operation; +class Value; + +namespace scf { +class ForOp; +} + +namespace triton { + +/// Options to dictate how loops should be pipelined. +struct PipeliningOption { + /// Lambda returning all the operation in the forOp, with their stage, in the + /// order picked for the pipelined loop. + using GetScheduleFnType = std::function> &)>; + GetScheduleFnType getScheduleFn = nullptr; + enum class PipelinerPart { + Prologue, + Kernel, + Epilogue, + }; + /// Lambda called by the pipeliner to allow the user to annotate the IR while + /// it is generated. + /// The callback passes the operation created along with the part of the + /// pipeline and the iteration index. The iteration index is always 0 for the + /// kernel. For the prologue and epilogue, it corresponds to the iteration + /// peeled out of the loop in the range [0, maxStage[. + using AnnotationlFnType = + std::function; + AnnotationlFnType annotateFn = nullptr; + + /// Control whether the epilogue should be peeled out of the loop or + /// operations should be predicated to skip the early stages in the last loop + /// iterations. If the epilogue is predicated; the user needs to provide a + /// lambda to generate the predicated version of operations. + bool peelEpilogue = true; + + /// Control whether the transformation checks that the number of iterations is + /// greater or equal to the number of stages and skip the transformation if + /// this is not the case. If the loop is dynamic and this is set to true the + /// pipeliner will have to predicate operations in the the prologue/epilogue. + bool supportDynamicLoops = false; + + // Callback to predicate operations when the prologue or epilogue are not + // peeled. This takes the original operation, an i1 predicate value and the + // pattern rewriter. It is expected to replace the given operation with + // the predicated equivalent and return it, or return nullptr if the + // predication is impossible. In the latter case, pipelining will fail and + // may leave IR in a partially transformed state. + using PredicateOpFnType = + std::function; + PredicateOpFnType predicateFn = nullptr; + + // TODO: add option to decide if the prologue should be peeled. +}; + +/// Generate a pipelined version of the scf.for loop based on the schedule given +/// as option. This applies the mechanical transformation of changing the loop +/// and generating the prologue/epilogue for the pipelining and doesn't make any +/// decision regarding the schedule. +/// Based on the options the loop is split into several stages. +/// The transformation assumes that the scheduling given by user is valid. +/// For example if we break a loop into 3 stages named S0, S1, S2 we would +/// generate the following code with the number in parenthesis as the iteration +/// index: +/// +/// S0(0) // Prologue +/// S0(1) S1(0) // Prologue +/// scf.for %I = %C0 to %N - 2 { +/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel +/// } +/// S1(N) S2(N-1) // Epilogue +/// S2(N) // Epilogue +/// +/// If `modifiedIR` is provided, it will be set to a value that indicates +/// whether pipelining modified the IR before failing, signaling to the caller +/// whether they can proceed with different transformations. +FailureOr pipelineForLoop(RewriterBase &rewriter, scf::ForOp forOp, + const PipeliningOption &options, + bool *modifiedIR = nullptr); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp new file mode 100644 index 000000000..a4819f5c3 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -0,0 +1,124 @@ +#include "PipeliningUtility.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +// Combine the current mask with the given predicate. +static Value getPredMask(RewriterBase &rewriter, Type typeLike, + Value currentMask, Value pred) { + Type maskType = tt::getI1SameShape(typeLike); + Location loc = pred.getLoc(); + Value mask = pred; + if (isa(maskType)) { + mask = rewriter.create(loc, maskType, pred); + } + if (currentMask) { + mask = rewriter.create(loc, mask, currentMask); + } + return mask; +} + +// Function to mask operations during scheduling. +Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, + Value pred) { + OpBuilder::InsertionGuard guard(rewriter); + if (mlir::isMemoryEffectFree(op)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (auto ifOp = dyn_cast(op)) { + rewriter.setInsertionPoint(op); + Value cnd = getPredMask(rewriter, ifOp.getCondition().getType(), + ifOp.getCondition(), pred); + ifOp.getConditionMutable().assign(cnd); + return op; + } + if (auto asyncCopyOp = dyn_cast(op)) { + rewriter.setInsertionPoint(asyncCopyOp); + Value mask = getPredMask(rewriter, asyncCopyOp.getSrc().getType(), + asyncCopyOp.getMask(), pred); + asyncCopyOp.getMaskMutable().assign(mask); + return op; + } + if (auto loadOp = dyn_cast(op)) { + rewriter.setInsertionPoint(loadOp); + Value mask = getPredMask(rewriter, loadOp.getPtr().getType(), + loadOp.getMask(), pred); + loadOp.getMaskMutable().assign(mask); + return op; + } +#ifndef __ILUVATAR__ + if (auto copyOp = dyn_cast(op)) { + rewriter.setInsertionPoint(copyOp); + Value mask = getPredMask(rewriter, copyOp.getPred().getType(), + copyOp.getPred(), pred); + copyOp.getPredMutable().assign(mask); + return op; + } + if (auto expectOp = dyn_cast(op)) { + rewriter.setInsertionPoint(expectOp); + Value mask = getPredMask(rewriter, expectOp.getPred().getType(), + expectOp.getPred(), pred); + expectOp.getPredMutable().assign(mask); + return op; + } +#endif + + assert("don't know how to predicate this op" && false); + return op; +} + +/// Helper to recursively add dependencies to the same stage. +void mlir::triton::addDep(Operation *op, DenseSet &deps, + bool includeArg, DenseSet *filter) { + if (filter && filter->count(op)) + return; + if (!deps.insert(op).second) + return; + for (Value operand : op->getOperands()) { + Value v = operand; + llvm::SmallDenseSet seen; + while (auto arg = mlir::dyn_cast(v)) { + if (!includeArg) + break; + if (!seen.insert(v).second) + break; + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + v = yieldOp->getOperand(arg.getArgNumber() - 1); + continue; + } + break; + } + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + addDep(defOp, deps, includeArg, filter); + } + } +} + +// Add operations to the schedule with the given stage based on the filter +// function. +void mlir::triton::addOps( + scf::ForOp forOp, int stage, + std::vector> &schedule, + std::function filter) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!filter(&op)) + continue; + schedule.emplace_back(&op, stage); + } +} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.h b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.h similarity index 100% rename from lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.h rename to third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.h diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h new file mode 100644 index 000000000..729b76c05 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h @@ -0,0 +1,43 @@ +#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ +#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ + +#include "PipelineExpander.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include + +namespace mlir { +namespace triton { + +#ifndef __ILUVATAR__ +/// This fill out the pipelining options including schedule and annotations +/// for wait ops. This also does pre-processing by converting some of the +/// loads into async loads so that the IR is ready to be pipelined. +bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages, + mlir::triton::PipeliningOption &options); +#endif + +/// Fills out pipelining options for an outer loop pipelining case. This +/// schedules async copies to overlap with the epilogue of a loop. +bool getOuterLoopSchedule(scf::ForOp &forOp, int numStages, + mlir::triton::PipeliningOption &options); + +#ifndef __ILUVATAR__ +/// Pipeline the TMA stores in the loop. +bool pipelineTMAStores(scf::ForOp forOp); + +/// This does post-processing on the pipelined loop to try to pipeline wgmma +/// ops. +// TODO: this should be included as part of the pipeline but currently the wgmma +// wait modeling is problematic. +void asyncLaunchDots(scf::ForOp forOp); +#endif + +/// Post process the pipelined loop by updating the wait ops with the right +/// number of groups in flight. +void updateWaits(ModuleOp module); + +} // namespace triton +} // namespace mlir +#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_ diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp new file mode 100644 index 000000000..37a735451 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -0,0 +1,169 @@ +#include "PipelineExpander.h" +#include "PipeliningUtility.h" +#include "Schedule.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// This file will create a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces, one that create a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue. This pass first calls a helper that will pre-process the IR +// to create async operations and create a modulo schedule. Then we call the +// expander to generate the prologue and new loop. +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUPIPELINE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// Return true if the preconditions for pipelining the loop are met. +static bool preCondition(scf::ForOp forOp) { + // Skip loop with distance > 1 for now. + // TODO: relax the constraint in the expander. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def; + })) + return false; + // Don't pipeline outer loops. + if (forOp + ->walk([&](Operation *op) { + if (forOp.getOperation() == op) + return WalkResult::advance(); + if (isa(op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }) + .wasInterrupted()) + return false; + return true; +} + +static void tryAndPipelineOuterLoop(scf::ForOp forOp) { + mlir::triton::PipeliningOption options; + bool foundSchedule = false; + // Limit 2 stages to not require extra shared memory. + foundSchedule = getOuterLoopSchedule(forOp, /*numStage=*/2, options); + if (!foundSchedule) + return; + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + FailureOr newForOp = + mlir::triton::pipelineForLoop(rewriter, forOp, options); +} + +static bool pipelineLoop(scf::ForOp forOp, int numStages) { + mlir::triton::PipeliningOption options; + if (!preCondition(forOp)) + return false; + + bool foundSchedule = false; +#ifndef __ILUVATAR__ + foundSchedule = preProcessLoopAndGetSchedule(forOp, numStages, options); +#endif + + // TODO: add more pipelines strategy. + if (!foundSchedule) + return false; + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + FailureOr newForOp = + mlir::triton::pipelineForLoop(rewriter, forOp, options); + + if (failed(newForOp)) + return false; +#ifndef __ILUVATAR__ + mlir::triton::asyncLaunchDots(newForOp.value()); +#endif + return true; +} + +struct PipelinePass : public impl::TritonGPUPipelineBase { + + using impl::TritonGPUPipelineBase::TritonGPUPipelineBase; + + int getNumStagesOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists otherwise use the + // global control. + if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName)) + return numStages; + return mlir::cast( + forOp->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); + } + + void runOnOperation() override { + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1. + if (getNumStagesOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + + if (loops.empty()) + return; + + llvm::SmallSetVector outerLoops; + for (scf::ForOp forOp : loops) { + auto outerLoop = dyn_cast(forOp->getParentOp()); + int loopNumStages = getNumStagesOrDefault(forOp); + bool pipelined = pipelineLoop(forOp, loopNumStages); + if (pipelined && outerLoop && getNumStagesOrDefault(outerLoop) > 1) + outerLoops.insert(outerLoop); + } + + // schedule the waits + mlir::triton::updateWaits(getOperation()); + + // Clean up arithmetic before applying the next level of pipelining to + // simplify the IR. + auto arithDialect = + getOperation().getContext()->getLoadedDialect(); + RewritePatternSet patterns(getOperation().getContext()); + arithDialect->getCanonicalizationPatterns(patterns); + if (applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)) + .failed()) + return signalPassFailure(); + + // Try to pipeline the outer loop to overlap the prologue and epilogue of + // the inner loop. + for (scf::ForOp outerLoop : outerLoops) + tryAndPipelineOuterLoop(outerLoop); + + // Re-collect loop ops + loops.clear(); + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1. + if (getNumStagesOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + +#ifndef __ILUVATAR__ + for (scf::ForOp forOp : loops) { + mlir::triton::pipelineTMAStores(forOp); + } +#endif + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp new file mode 100644 index 000000000..7f921d9ea --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -0,0 +1,92 @@ +#include "Schedule.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +static SmallVector +getTMAStores(scf::ForOp forOp) { + SmallVector tmaStores; + + // Do not use walk, as we don't want to walk into nested loops. + std::function collectTMAStores = [&](Operation *op) { + if (auto storeOp = dyn_cast(op)) { + tmaStores.push_back(storeOp); + } + for (Region ®ion : op->getRegions()) { + for (Operation &op : region.getOps()) { + if (!isa(op)) + collectTMAStores(&op); + } + } + }; + collectTMAStores(forOp); + return tmaStores; +} + +static Value createAlloc(scf::ForOp &forOp, + tt::ExperimentalDescriptorStoreOp storeOp) { + OpBuilder builder(forOp); + auto ty = cast(storeOp.getSrc().getType()); + auto order = ttg::getOrder(ty.getEncoding()); + auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); + Attribute encoding = + ttg::SharedEncodingAttr::get(ty.getContext(), 1, 1, 1, order, ctaLayout); + if (ty.getRank() > 1) { + encoding = ttg::SharedEncodingAttr::get( + ty.getContext(), ty.getShape(), order, ctaLayout, ty.getElementType()); + } + + Type memdescType = tt::MemDescType::get(ty.getShape(), ty.getElementType(), + encoding, /*mutableMemory*/ true); + Value alloc = builder.create(storeOp->getLoc(), + memdescType, Value()); + return alloc; +} + +static void createTMAAsyncCopy(scf::ForOp &forOp, + tt::ExperimentalDescriptorStoreOp storeOp, + Value alloc) { + OpBuilder builder(storeOp); + auto loc = storeOp.getLoc(); + auto ty = cast(storeOp.getSrc().getType()); + auto order = ttg::getOrder(ty.getEncoding()); + auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); + + // Put wait before the local_store make the store truly async. We know + // that we are the only user of the CopyLocalToGlobal. + builder.create(loc, 0); + builder.create(loc, storeOp.getSrc(), alloc); + builder.create(loc, false); + builder.create( + loc, storeOp.getDescPtr(), storeOp.getIndices(), alloc); + + storeOp->erase(); +} + +bool mlir::triton::pipelineTMAStores(scf::ForOp forOp) { + SmallVector tmaStores = + getTMAStores(forOp); + if (tmaStores.empty()) + return false; + + DenseMap storeToAlloc; + for (tt::ExperimentalDescriptorStoreOp op : tmaStores) { + storeToAlloc[op] = createAlloc(forOp, op); + } + + for (tt::ExperimentalDescriptorStoreOp op : tmaStores) { + createTMAAsyncCopy(forOp, op, storeToAlloc[op]); + } + + // Deallocate shared memory buffers. + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + builder.create(forOp->getLoc(), 0); + for (auto it : storeToAlloc) { + builder.create(forOp->getLoc(), it.second); + } + return true; +} diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp new file mode 100644 index 000000000..31679c066 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -0,0 +1,428 @@ +//===----------------------------------------------------------------------===// +// +// This pass tries to prefetch operands (a and b) of tt.dot. +// Those ConvertLayoutOps will be lowered to shared memory loads. +// +// For example: +// %a: tensor<128x32xf16, #enc> +// scf.for %iv = ... iter_args(%a_arg = %a, ...) { +// %d = tt.dot %a_arg, %b, %c +// ... +// scf.yield %a_next, ... +// } +// +// will be translated to +// +// %a: tensor<128x32xf16, #enc> +// %a_tmp = tensor.subview %a[0, 0] [128, 16] +// %a_prefetch = triton_gpu.local_load %a_tmp +// scf.for %iv = ... iter_args(%a_buf = %a, ..., %a_prefetch_arg = %a_prefetch) +// { +// %x = tt.dot %a_prefetch_arg, %b, %c +// %a_tmp_rem = tensor.subview %a_buf[0, 16] [128, 16] +// %a_prefetch_next = triton_gpu.local_load %a_tmp_rem +// ... +// scf.yield %next_a, ..., %a_prefetch_next +// } +//===----------------------------------------------------------------------===// + +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUPREFETCH +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +class Prefetcher { + /// cache the ForOp we are working on + scf::ForOp forOp; + /// cache the YieldOp of this ForOp + scf::YieldOp yieldOp; + /// + // TODO: add a hook to infer prefetchWidth + unsigned prefetchWidth = 32; + + /// dots to be prefetched + SetVector dots; + /// dot => dot operand + DenseMap dot2aLoopArg; + DenseMap dot2aHeaderDef; + DenseMap dot2bLoopArg; + DenseMap dot2bHeaderDef; + DenseMap dot2aYield; + DenseMap dot2bYield; + DenseMap> dot2aVals; + DenseMap> dot2bVals; + /// operand => defining + DenseMap operand2headPrefetch; + + LogicalResult isForOpOperand(Value v); + + Value generatePrefetch(Value v, unsigned opIdx, bool isPrologue, + Attribute dotEncoding, OpBuilder &builder, + Attribute dotOperandEncoding, + std::optional offsetK = std::nullopt, + std::optional shapeK = std::nullopt); + + void cloneElementwiseOps(Value &bRem, const SmallVector &vals, + OpBuilder &builder); + +public: + Prefetcher() = delete; + + Prefetcher(scf::ForOp forOp) : forOp(forOp) { + yieldOp = cast(forOp.getBody()->getTerminator()); + } + + LogicalResult initialize(); + + void emitPrologue(); + + scf::ForOp createNewForOp(); +}; + +void Prefetcher::cloneElementwiseOps(Value &ret, const SmallVector &vals, + OpBuilder &builder) { + IRMapping mapping; + mapping.map(vals[1], ret); + for (int i = 2; i < vals.size(); i++) { + Value v = vals[i]; + Value curr = builder.clone(*v.getDefiningOp(), mapping)->getResult(0); + if (isa(curr.getType())) { + auto retType = RankedTensorType::get( + cast(ret.getType()).getShape(), + cast(curr.getType()).getElementType(), + cast(curr.getDefiningOp()->getOperand(0).getType()) + .getEncoding()); + curr.setType(retType); + } + mapping.map(v, curr); + } + if (vals.size() > 1) + ret = mapping.lookup(vals.back()); +} + +Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, + Attribute dotEncoding, OpBuilder &builder, + Attribute dotOperandEncoding, + std::optional offsetK, + std::optional shapeK) { + // opIdx: 0 => a, 1 => b + auto type = cast(v.getType()); + SmallVector shape{type.getShape().begin(), type.getShape().end()}; + SmallVector offset{0, 0}; + Type elementType = type.getElementType(); + + // k => (prefetchWidth, k - prefetchWidth) + int64_t kIdx = opIdx == 0 ? 1 : 0; + + offset[kIdx] = isPrologue ? 0 : prefetchWidth; + shape[kIdx] = isPrologue ? prefetchWidth : (shape[kIdx] - prefetchWidth); + + if (shapeK) + shape[kIdx] = *shapeK; + if (offsetK) + offset[kIdx] = *offsetK; + + SmallVector offsetsVal; + for (int64_t off : offset) + offsetsVal.push_back( + builder.create(v.getLoc(), off, 32)); + Value newSmem = builder.create( + v.getLoc(), + triton::MemDescType::get(shape, elementType, type.getEncoding()), v, + offsetsVal); + + auto encoding = + dyn_cast(dotOperandEncoding); + assert(encoding && "dotEncoding need be DotOperandEncodingAttr"); + auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( + builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8, + encoding.getUseSme()); + Value prefetchSlice = builder.create( + v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), + newSmem); + + return prefetchSlice; +} + +LogicalResult Prefetcher::initialize() { + Block *loop = forOp.getBody(); + + auto getEncoding = [](Value v) { + return cast(v.getType()).getEncoding(); + }; + + SmallVector dotsInFor; + for (Operation &op : *loop) + if (auto dotOp = dyn_cast(op)) { + // bail out if there exist non v2 dots. + auto dstEnc = + dyn_cast(getEncoding(dotOp.getResult())); + if (!dstEnc || dstEnc.getVersionMajor() != 2) + return failure(); + dotsInFor.push_back(dotOp); + } + + if (dotsInFor.empty()) + return failure(); + + // TODO: segfault (original for still has uses) + // when used in flash attention that has 2 dots in the loop + if (dotsInFor.size() > 1) + return failure(); + + // returns source of cvt + + // returns source of cvt + auto getPrefetchSrc = [](Value v) -> SmallVector { + // walk back to conversion + Operation *op = v.getDefiningOp(); + bool foundConvertFromShared = false; + SmallVector rets; + rets.push_back(op->getResult(0)); + while (op) { + if (op->getNumOperands() != 1) + break; + if (!op->getResult(0).hasOneUse()) + break; + rets.push_back(op->getOperand(0)); + if (auto cvt = dyn_cast(op)) { + foundConvertFromShared = true; + break; + } + op = op->getOperand(0).getDefiningOp(); + } + std::reverse(rets.begin(), rets.end()); + + if (foundConvertFromShared) + return rets; + return {}; + }; + + auto getIncomingOp = [this](Value v) -> Value { + if (auto arg = mlir::dyn_cast(v)) + if (arg.getOwner()->getParentOp() == forOp.getOperation()) + return forOp.getTiedLoopInit(arg)->get(); + return Value(); + }; + + auto getYieldOp = [this](Value v) -> Value { + auto arg = mlir::cast(v); + unsigned yieldIdx = arg.getArgNumber() - forOp.getNumInductionVars(); + return yieldOp.getOperand(yieldIdx); + }; + + for (triton::DotOp dot : dotsInFor) { + auto aType = dot.getA().getType(); + auto bType = dot.getB().getType(); + auto aEnc = + mlir::cast(aType.getEncoding()); + auto bEnc = + mlir::cast(bType.getEncoding()); + int aKWidth = aEnc.getKWidth(); + int bKWidth = bEnc.getKWidth(); + assert(aKWidth == bKWidth); + + auto kSize = aType.getShape()[1]; + + // works better with nvidia tensor cores + unsigned elementWidth = aType.getElementTypeBitWidth(); + if (aKWidth == 0) + prefetchWidth = 256 / elementWidth; + else + prefetchWidth = 8 * aKWidth; + +#ifdef __ILUVATAR__ + if (prefetchWidth < 16) + prefetchWidth = 16; +#endif + + // Skip prefetching if kSize is less than prefetchWidth + if (kSize < prefetchWidth) + continue; + auto aVals = getPrefetchSrc(dot.getA()); + auto bVals = getPrefetchSrc(dot.getB()); + + if (aVals.size() && bVals.size()) { + Value aSmem = aVals.front(); + Value bSmem = bVals.front(); + Value aHeaderDef = getIncomingOp(aSmem); + Value bHeaderDef = getIncomingOp(bSmem); + // Only prefetch loop arg + if (aHeaderDef && bHeaderDef) { + dots.insert(dot); + dot2aVals[dot] = aVals; + dot2bVals[dot] = bVals; + dot2aHeaderDef[dot] = aHeaderDef; + dot2bHeaderDef[dot] = bHeaderDef; + dot2aLoopArg[dot] = aSmem; + dot2bLoopArg[dot] = bSmem; + dot2aYield[dot] = getYieldOp(aSmem); + dot2bYield[dot] = getYieldOp(bSmem); + } + } + } + + return success(); +} + +void Prefetcher::emitPrologue() { + OpBuilder builder(forOp); + + for (triton::DotOp dot : dots) { + Attribute dotEncoding = dot.getType().getEncoding(); + Attribute dotOperandEncodingA = dot.getA().getType().getEncoding(); + Value aPrefetched = + generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder, + dotOperandEncodingA); + cloneElementwiseOps(aPrefetched, dot2aVals[dot], builder); + Attribute dotOperandEncodingB = dot.getB().getType().getEncoding(); + Value bPrefetched = + generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder, + dotOperandEncodingB); + cloneElementwiseOps(bPrefetched, dot2bVals[dot], builder); + + operand2headPrefetch[dot.getA()] = aPrefetched; + operand2headPrefetch[dot.getB()] = bPrefetched; + } +} + +scf::ForOp Prefetcher::createNewForOp() { + OpBuilder builder(forOp); + + SmallVector loopArgs; + for (auto v : forOp.getInitArgs()) + loopArgs.push_back(v); + for (triton::DotOp dot : dots) { + loopArgs.push_back(operand2headPrefetch[dot.getA()]); + loopArgs.push_back(operand2headPrefetch[dot.getB()]); + } + + auto newForOp = builder.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), loopArgs); + + builder.setInsertionPointToStart(newForOp.getBody()); + IRMapping mapping; + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + + for (Operation &op : forOp.getBody()->without_terminator()) { + Operation *newOp = builder.clone(op, mapping); + auto dot = dyn_cast(&op); + if (dot && dots.contains(dot)) { + Attribute dotEncoding = dot.getType().getEncoding(); + // prefetched dot + Operation *firstDot = builder.clone(*dot, mapping); + if (Value a = operand2headPrefetch.lookup(dot.getA())) + firstDot->setOperand( + 0, newForOp.getTiedLoopRegionIterArg(&*a.use_begin())); + if (Value b = operand2headPrefetch.lookup(dot.getB())) + firstDot->setOperand( + 1, newForOp.getTiedLoopRegionIterArg(&*b.use_begin())); + + // remaining part + int64_t kOff = prefetchWidth; + int64_t kRem = dot.getA().getType().getShape()[1] - prefetchWidth; + Operation *prevDot = firstDot; + while (kRem != 0) { + // int64_t kShape = largestPow2(kRem); + int64_t kShape = prefetchWidth; + auto insertionPoint = builder.saveInsertionPoint(); + builder.setInsertionPoint(prevDot); + Attribute dotOperandEncodingA = dot.getA().getType().getEncoding(); + Value aRem = generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, + false, dotEncoding, builder, + dotOperandEncodingA, kOff, kShape); + cloneElementwiseOps(aRem, dot2aVals[dot], builder); + Attribute dotOperandEncodingB = dot.getB().getType().getEncoding(); + Value bRem = generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, + false, dotEncoding, builder, + dotOperandEncodingB, kOff, kShape); + cloneElementwiseOps(bRem, dot2bVals[dot], builder); + builder.restoreInsertionPoint(insertionPoint); + newOp = builder.clone(*dot, mapping); + newOp->setOperand(0, aRem); + newOp->setOperand(1, bRem); + newOp->setOperand(2, prevDot->getResult(0)); + prevDot = newOp; + kOff += kShape; + kRem -= kShape; + } + } + // update mapping of results + for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults())) + mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx)); + } + + // prefetch next iteration + SmallVector yieldValues; + for (Value v : forOp.getBody()->getTerminator()->getOperands()) + yieldValues.push_back(mapping.lookupOrDefault(v)); + for (triton::DotOp dot : dots) { + Attribute dotEncoding = dot.getType().getEncoding(); + Attribute dotOperandEncodingA = dot.getA().getType().getEncoding(); + Value aToYield = + generatePrefetch(mapping.lookup(dot2aYield[dot]), 0, true, dotEncoding, + builder, dotOperandEncodingA); + cloneElementwiseOps(aToYield, dot2aVals[dot], builder); + yieldValues.push_back(aToYield); + // bToYield + Attribute dotOperandEncodingB = dot.getB().getType().getEncoding(); + Value bToYield = + generatePrefetch(mapping.lookup(dot2bYield[dot]), 1, true, dotEncoding, + builder, dotOperandEncodingB); + cloneElementwiseOps(bToYield, dot2bVals[dot], builder); + yieldValues.push_back(bToYield); + } + // Update ops of yield + if (!yieldValues.empty()) + builder.create(yieldOp.getLoc(), yieldValues); + return newForOp; +} + +} // anonymous namespace + +struct PrefetchPass : public impl::TritonGPUPrefetchBase { + void runOnOperation() override { + + // Canonicalize convert ops to make the pattern matching easier. + RewritePatternSet cleanUpPatterns(&getContext()); + triton::gpu::ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, + &getContext()); + if (mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(cleanUpPatterns)) + .failed()) { + signalPassFailure(); + } + getOperation()->walk([&](scf::ForOp forOp) { + Prefetcher prefetcher(forOp); + + if (prefetcher.initialize().failed()) + return; + + prefetcher.emitPrologue(); + + scf::ForOp newForOp = prefetcher.createNewForOp(); + + // replace the original loop + for (unsigned i = 0; i < forOp->getNumResults(); ++i) + forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i)); + forOp->erase(); + }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp new file mode 100644 index 000000000..e071ef104 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -0,0 +1,100 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUREDUCEDATADUPLICATION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUReduceDataDuplicationPass + : public impl::TritonGPUReduceDataDuplicationBase< + TritonGPUReduceDataDuplicationPass> { +public: + void runOnOperation() override { + ModuleOp mod = getOperation(); + mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cast(cvtOp.getSrc().getType()); + auto dstType = cast(cvtOp.getType()); + auto srcEncoding = srcType.getEncoding(); + if (isa(srcEncoding)) + return; + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (!dstDotOp) + return; + if (auto srcMmaEncoding = + dyn_cast(srcEncoding)) { + + if (srcMmaEncoding.getVersionMajor() == 1 || + srcMmaEncoding.getVersionMajor() == 2 || + (srcMmaEncoding.getWarpsPerCTA()[1] == 1 && + dstDotOp.getParent() == srcMmaEncoding)) + return; + } + if (auto srcMmaEncoding = + dyn_cast(srcEncoding)) { + + if (srcMmaEncoding.getVersionMajor() != 2 || + (srcMmaEncoding.getWarpsPerCTA()[1] == 1 && + dstDotOp.getParent() == srcMmaEncoding)) + return; + } + if (auto srcMfmaEncoding = + dyn_cast(srcEncoding)) { + + if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 && + srcMfmaEncoding.getIsTransposed() && + dstDotOp.getParent() == srcMfmaEncoding) + return; + } + auto srcOrder = triton::gpu::getOrder(srcEncoding); + auto rank = srcOrder.size(); + SmallVector sharedOrder; + if (rank == 3) { + // add all elements except the element that is zero + for (unsigned i = 0; i < rank; ++i) + if (srcOrder[i] != 0) + sharedOrder.emplace_back(srcOrder[i]); + sharedOrder.emplace_back(0); + } else { + sharedOrder = srcOrder; + } + auto tmpType = triton::MemDescType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::SharedEncodingAttr::get( + mod.getContext(), dstDotOp, srcType.getShape(), sharedOrder, + triton::gpu::getCTALayout(srcEncoding), + srcType.getElementType())); + auto tmp = builder.create( + cvtOp.getLoc(), tmpType, cvtOp.getSrc()); + auto newConvert = builder.create(cvtOp.getLoc(), + dstType, tmp); + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp new file mode 100644 index 000000000..9b653a3b7 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -0,0 +1,1463 @@ +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUREMOVELAYOUTCONVERSIONS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "tritongpu-remove-layout-conversions" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { + +// ----------------------------------------------------------------------------- +// +// ----------------------------------------------------------------------------- + +// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0)) +class ConvertDotConvert : public RewritePattern { +public: + ConvertDotConvert(MLIRContext *context) + : RewritePattern(ConvertLayoutOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto dstOp = cast(op); + auto dotOp = dstOp.getSrc().getDefiningOp(); + if (!dotOp) + return failure(); + if (std::distance(dstOp->user_begin(), dstOp->user_end()) != 1 || + std::distance(dotOp->user_begin(), dotOp->user_end()) != 1) + return failure(); + auto cvtOp = dotOp.getOperand(2).getDefiningOp(); + if (!cvtOp) + return failure(); + if (!cvtOp.getSrc().getDefiningOp()) + return failure(); + RankedTensorType dstTy = dstOp.getType(); + RankedTensorType srcTy = cvtOp.getSrc().getType(); + if (dstTy != srcTy) + return failure(); + + auto _0f = rewriter.create( + op->getLoc(), dstTy.getElementType(), + rewriter.getZeroAttr(dstTy.getElementType())); + auto _0 = rewriter.create(op->getLoc(), dotOp.getType(), _0f); + auto newDot = rewriter.create( + op->getLoc(), dotOp.getType(), dotOp.getOperand(0), dotOp.getOperand(1), + _0, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); + auto newCvt = rewriter.create(op->getLoc(), dstTy, + newDot.getResult()); + rewriter.replaceOpWithNewOp(op, newCvt, cvtOp.getSrc()); + return success(); + } +}; + +// convert(slice<{parent=#mma, noWarpReduce=true}>, blocked) -> +// convert(slice<{parent=#mma, noWarpReduce=true}>, slice<{parent=#mma, +// noWarpReduce=false}>) + convert(slice<{parent=#mma, noWarpReduce=false}>, +// blocked) this is a heuristic to accommodate some pattern seen in fused +// attention kernels. +// TODO: replace this by something more generic, i.e. layout-aware CSE +class SliceMMAConvert : public mlir::RewritePattern { +public: + SliceMMAConvert(mlir::MLIRContext *context) + : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), + 1, context) {} + + LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto dstOp = cast(op); + Value src = dstOp.getSrc(); + Value dst = dstOp.getResult(); + auto srcTy = src.getType().cast(); + auto dstTy = dst.getType().cast(); + auto srcLayout = + srcTy.getEncoding().dyn_cast(); + auto dstLayout = + dstTy.getEncoding().dyn_cast(); + if (!srcLayout || !dstLayout) + return mlir::failure(); + auto srcMmaLayout = + srcLayout.getParent().dyn_cast(); + if (!srcMmaLayout) + return mlir::failure(); + auto srcNoWarpReduce = srcLayout.getNoWarpReduce(); + if (!srcNoWarpReduce) + return mlir::failure(); + + Attribute reduceEncoding = triton::gpu::SliceEncodingAttr::get( + dstOp.getContext(), srcLayout.getDim(), srcMmaLayout, false); + auto newSrc = rewriter.create( + loc, + RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), + reduceEncoding), + src); + rewriter.replaceOpWithNewOp(op, dstTy, + newSrc); + return mlir::success(); + } +}; + +// The current algorithm works by analyzing the IR and doing a one-shot rewrite +// based on the analysis. The algorithm is as follows. +// +// 1. Find all the anchor ops. These are ops that have a layout we want to +// preserve. +// +// 2. For each anchor, propagate its layout to all its descendants. +// An op can have multiple ancestors that are anchors, so at this stage an op +// may have multiple layouts associated with it. +// +// 3. Resolve conflicts by deciding which of the multiple layouts the op should +// keep, inserting convert-layout ops to resolve conflicts. After this +// stage, each value has only one layout associated with it. +// +// 4. Rewrite the IR by walking the function in dominance order. Since we +// assume the IR is structured we just need to process the regions in the +// correct order. For each op, rewrite it using the layout decided by the +// analysis phase. +class LayoutPropagation { +public: + // Structure to keep track of the layout associated to a value. + struct LayoutInfo { + LayoutInfo(Attribute encoding) { encodings.insert(encoding); } + LayoutInfo() {} + llvm::SmallSetVector encodings; + }; + LayoutPropagation(FuncOp F) : funcOp(F) {} + // Find the anchor ops and set their layout in the data structure. + void initAnchorLayout(); + // Recursively Propagate the layout to all the users of the anchor ops until + // we reach a fix point. + void propagateLayout(); + // Add layouts given in `Info` to the uses of `value`. + SmallVector propagateToUsers(Value value, LayoutInfo &info); + // Set the encoding to all the values and fill out the values with new layout + // in `changed`. + void setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, Operation *op); + // Resolve cases where a value has multiple layouts associated to it. + void resolveConflicts(); + // Rewrite the IR for the full module. + void rewrite(); + // Rewrite the IR for a region. + void rewriteRegion(Region &R); + // Rewrite an op based on the layout picked by the analysis. + Operation *rewriteOp(Operation *op); + // Rewrite a for op based on the layout picked by the analysis. + Operation *rewriteForOp(scf::ForOp forOp); + Operation *rewriteWhileOp(scf::WhileOp whileOp); + Operation *rewriteIfOp(scf::IfOp ifOp); + void rewriteYieldOp(scf::YieldOp yieldOp); + void rewriteConditionOp(scf::ConditionOp conditionOp); + void rewriteReduceToScalar(Operation *reduceOp); + void rewriteAssertOp(AssertOp assertOp); + bool rewriteStoreOp(Operation *storeOp); + Operation *cloneElementwise(OpBuilder &rewriter, Operation *op, + Attribute encoding); + // Map the original value to the rewritten one. + void map(Value old, Value newV); + // Return the mapped value in the given encoding. This will insert a convert + // if the encoding is different than the encoding decided at resolve time. + Value getValueAs(Value value, Attribute encoding); + // Dump the current stage of layout information. + void dump(); + +private: + // map from value to layout information. + llvm::MapVector layouts; + // map of the values rewrite based on their encoding. + DenseMap, Value> rewriteMapping; + SetVector opToDelete; + FuncOp funcOp; +}; + +class LayoutRematerialization { +public: + LayoutRematerialization(FuncOp F) : funcOp(F) {} + // Map the original value to the remat'ed one. + void addRematValue(Value old, Attribute encoding, Value newV); + bool hasRematValue(Value value, Attribute encoding) { + return rematMapping.contains({value, encoding}); + } + // Return the remat'ed value in the given encoding. + Value getRematValue(Value value, Attribute encoding) { + auto it = rematMapping.find({value, encoding}); + assert(it != rematMapping.end()); + return it->second; + } + void cleanup(); + void backwardRematerialization(); + void backwardRematerialization(ConvertLayoutOp convertOp); + void hoistConvertOnTopOfExtOrBroadcast(); + void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp); + void rewriteSlice(SetVector &slice, DenseMap &layout, + ConvertLayoutOp convertOp, IRMapping &mapping); + void rewriteSlice(SetVector &slice, DenseMap &layout, + ConvertLayoutOp convertOp); + +private: + void updateRematMapping(SmallVector> &values); + // Existing tuples of (value, layout) that needs to be updated when recreating + // scf ops. This prevents keeping track of Values that have been delete when + // rewriting slices. + DenseMap mappedValues; + // map of the values remat based on encoding. + DenseMap, Value> rematMapping; + // DenseMap, Operation*> + SetVector opToDelete; + FuncOp funcOp; +}; + +void LayoutRematerialization::addRematValue(Value old, Attribute encoding, + Value newV) { + LDBG("addRematValue " << old << " encoding " << encoding << " " << newV); + rematMapping[{old, encoding}] = newV; + mappedValues[old] = encoding; +} + +// Remove unneeded values now that we are done with the rematMapping. +void LayoutRematerialization::cleanup() { + for (Operation *op : llvm::reverse(opToDelete)) + op->erase(); +} + +// Look ahead to at the transitive uses and see if there is a convert to mma +// operations. +bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { + SmallVector queue = {op->getResult(0)}; + SetVector forwardSlice; + llvm::SmallDenseSet seen; + while (!queue.empty()) { + Value currentValue = queue.back(); + queue.pop_back(); + getForwardSlice(currentValue, &forwardSlice); + for (Operation *op : forwardSlice) { + // HACK: Stop propagation if the ReduceOp is using mma layout but is + // producing tensor smaller than the layout we would like to propagate. + // This is to avoid stepping into the known bug. + if (isa(op)) { + auto tensorType = + dyn_cast(op->getOperand(0).getType()); + if (tensorType && + isa(tensorType.getEncoding())) { + auto mmaInstrShape = + cast(encoding).getInstrShape(); + if (tensorType.getShape()[tensorType.getRank() - 2] < + mmaInstrShape[0] || + tensorType.getShape()[tensorType.getRank() - 1] < + mmaInstrShape[1]) { + return false; + } + } + } + + if (auto convertOp = dyn_cast(op)) { + Attribute dstEncoding = convertOp.getType().getEncoding(); + if (auto mmaLayout = dyn_cast(dstEncoding)) + return (mmaLayout.getVersionMajor() > 1) ? true + : mmaLayout == encoding; + if (isa(dstEncoding)) + return true; + if (isa(dstEncoding)) + return true; + if (isa(dstEncoding)) { + if (auto mmaLayout = dyn_cast(encoding)) { + return mmaLayout.getVersionMajor() > 1; + } else if (isa(encoding)) { + return true; + } else { + assert((mlir::isa(encoding))); + return true; + } + } + } + bool isMMAV3 = + isa(encoding) && + cast(encoding).getVersionMajor() == 3; + if (isMMAV3 && (isa(op) || isa(op))) + return true; + auto yield = dyn_cast(op); + if (!yield) + continue; + if (auto ifOp = dyn_cast(yield->getParentOp())) { + for (OpOperand &operand : yield->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if (def && + (forwardSlice.count(def) || operand.get() == currentValue) && + (seen.insert(operand.get()).second == true)) + queue.push_back(ifOp.getResult(operand.getOperandNumber())); + } + } + auto forOp = dyn_cast(yield.getOperation()->getParentOp()); + if (!forOp) + continue; + for (OpOperand &operand : yield->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if (def && (forwardSlice.count(def) || operand.get() == currentValue) && + (seen.insert(operand.get()).second == true)) + queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber())); + } + } + } + return false; +} + +// Return true if the op is an op with a layout we don't want to change. We will +// propagate the layout starting from anchor ops. +bool isLayoutAnchor(Operation *op) { + if (isa(op)) + return isExpensiveLoadOrStore(op); + if (isa(op)) + return true; + + // Heuristic: Mark permuting reshape as a layout anchor. Its dst can be + // anything, so it stops forward-propagation of layouts. We rely on the + // backwards pass to fix it up if necessary. (If we didn't do this, then + // anything following the reshape won't be covered by the forward pass at + // all.) + if (auto reshape = dyn_cast(op)) + return reshape.getAllowReorder(); + +#ifdef __ILUVATAR__ + if (auto reduceOp = dyn_cast(op)) { + if (reduceOp.getNoWarpReduce() == true) + return true; + } + return isMmaConvertLayout(op); +#else + return false; +#endif +} + +void LayoutPropagation::initAnchorLayout() { + auto maybeAddAnchor = [&](Value v) { + if (auto tensorType = dyn_cast(v.getType())) { +#ifndef __ILUVATAR__ + // Workaround, don't popagate MMA layout unless there is a convert + // back to mma further down to avoid generating reduction with MMA + // layout that may have lower performance. + // This can be improved with more aggressive backward propagation. + if (isa(tensorType.getEncoding()) && + v.getDefiningOp() && + !hasConvertToMMATransisitiveUse(v.getDefiningOp(), + tensorType.getEncoding())) { + return; + } +#endif + layouts.insert({v, LayoutInfo(tensorType.getEncoding())}); + } + }; + + // Consider function args as anchors. This makes it easier to write tests -- + // you can pass a tensor with an encoding as an arg, instead of explicitly + // calling tt.load. + for (auto arg : funcOp.getArguments()) { + maybeAddAnchor(arg); + } + + funcOp.walk([&](Operation *op) { + if (isLayoutAnchor(op)) { + for (auto result : op->getResults()) { + maybeAddAnchor(result); + } + } + }); +} + +void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, + Operation *op) { + for (Value value : values) { + if (!isa(value.getType())) + continue; + bool hasChanged = false; + for (auto encoding : info.encodings) { + std::optional dstEncoding; + if (isa(op)) { + // Try to remove the convert by making the dst encoding match the source + // encoding. + dstEncoding = encoding; + } else { + dstEncoding = inferDstEncoding(op, encoding); + } + if (dstEncoding) + hasChanged |= layouts[value].encodings.insert(*dstEncoding); + } + if (hasChanged) + changed.push_back(value); + } +} + +SmallVector LayoutPropagation::propagateToUsers(Value value, + LayoutInfo &info) { + SmallVector changed; + for (OpOperand &use : value.getUses()) { + Operation *user = use.getOwner(); +#ifdef __ILUVATAR__ + assert(info.encodings.size() >= 1 && + "we should have at least one encoding."); + bool skip = false; + for (auto encoding : info.encodings) { + auto dstEncoding = inferDstEncoding(user, encoding); + if (auto sliceLayout = dyn_cast(encoding)) { + if (sliceLayout.getNoWarpReduce() && encoding == dstEncoding) { + Operation *defOp = value.getDefiningOp(); + if (defOp && dyn_cast(defOp)) { + skip = true; + break; + } + } + } + } + if (skip) + continue; +#endif + if (auto forOp = dyn_cast(user)) { + Value arg = forOp.getTiedLoopRegionIterArg(&use); + Value result = forOp.getTiedLoopResult(&use); + setEncoding({arg, result}, info, changed, user); + continue; + } + if (auto whileOp = dyn_cast(user)) { + Value arg = whileOp.getBeforeArguments()[use.getOperandNumber()]; + setEncoding({arg}, info, changed, user); + continue; + } + if (auto yieldOp = dyn_cast(user)) { + auto parent = yieldOp->getParentOp(); + SmallVector valuesToPropagate; + if (isa(parent)) + valuesToPropagate.push_back(parent->getResult(use.getOperandNumber())); + if (auto forOp = dyn_cast(parent)) + valuesToPropagate.push_back( + forOp.getRegionIterArg(use.getOperandNumber())); + if (auto whileOp = dyn_cast(parent)) { + valuesToPropagate.push_back( + whileOp.getBeforeArguments()[use.getOperandNumber()]); + valuesToPropagate.push_back( + whileOp->getOperand(use.getOperandNumber())); + } + if (isa(parent)) + setEncoding(valuesToPropagate, info, changed, user); + continue; + } + if (auto conditionOp = dyn_cast(user)) { + auto whileOp = cast(conditionOp->getParentOp()); + // Skip arg 0 as it is the condition. + unsigned argIndex = use.getOperandNumber() - 1; + Value afterArg = whileOp.getAfterArguments()[argIndex]; + Value result = whileOp->getResult(argIndex); + setEncoding({afterArg, result}, info, changed, user); + continue; + } + if (user->hasTrait() || + user->hasTrait() || + isa(user)) { + setEncoding(user->getResults(), info, changed, user); + continue; + } + } + return changed; +} + +void LayoutPropagation::propagateLayout() { + SmallVector queue; + for (auto it : layouts) { + queue.push_back(it.first); + } + while (!queue.empty()) { + Value currentValue = queue.back(); + LayoutInfo info = layouts[currentValue]; + queue.pop_back(); + SmallVector changed = propagateToUsers(currentValue, info); + + LLVM_DEBUG({ + DBGS() << "propagateLayout considering " << currentValue << ", which has " + << info.encodings.size() << " candidate encoding(s):\n"; + for (Attribute encoding : info.encodings) + DBGS() << " " << encoding << "\n"; + }); + + queue.insert(queue.end(), changed.begin(), changed.end()); + } +} + +void LayoutPropagation::resolveConflicts() { + for (auto &it : layouts) { + Operation *op = it.first.getDefiningOp(); + LayoutInfo &info = it.second; + if (info.encodings.size() <= 1) + continue; + // Hacky resolve, prefer block encoding. + // TODO: add a proper heuristic. + Attribute encoding = *info.encodings.begin(); + bool isLoadOrStore = + op && isa(op); + for (Attribute e : info.encodings) { + if ((isLoadOrStore && isa(e)) || + (!isLoadOrStore && isa(e))) { + encoding = e; + break; + } + } + info.encodings.clear(); + info.encodings.insert(encoding); + } +} + +void LayoutPropagation::dump() { + for (auto it : layouts) { + llvm::errs() << "Value: "; + OpPrintingFlags flags; + flags.skipRegions(); + it.first.print(llvm::errs(), flags); + llvm::errs() << " \n encoding:\n"; + for (auto encoding : it.second.encodings) { + encoding.print(llvm::errs()); + llvm::errs() << "\n"; + } + llvm::errs() << "--\n"; + } +} + +void LayoutPropagation::rewrite() { rewriteRegion(funcOp->getRegion(0)); } + +bool reduceToScalar(Operation *op) { + // For reductions returning a scalar we can change the src encoding without + // affecting the output. + return isa(op) && !isa(op->getResultTypes()[0]); +} + +void LayoutPropagation::rewriteRegion(Region ®ion) { + SmallVector queue = {®ion}; + while (!queue.empty()) { + Region *currentRegion = queue.back(); + queue.pop_back(); + for (Operation &op : currentRegion->getOps()) { + bool needRewrite = false; + SmallVector results = op.getResults(); + for (Value result : results) { + auto it = layouts.find(result); + // If we haven't mapped this value skip. + if (it == layouts.end()) + continue; + LayoutInfo &info = it->second; + assert(info.encodings.size() == 1 && + "we should have resolved to a single encoding"); + auto encoding = cast(result.getType()).getEncoding(); + // If the encoding is already what we want skip. + if (encoding == *info.encodings.begin()) + continue; + needRewrite = true; + } + if (needRewrite) { + Operation *newOp = rewriteOp(&op); + for (Region &R : newOp->getRegions()) + queue.push_back(&R); + } else if (auto yieldOp = dyn_cast(&op)) { + rewriteYieldOp(yieldOp); + } else if (auto conditionOp = dyn_cast(&op)) { + rewriteConditionOp(conditionOp); + } else if (reduceToScalar(&op)) { + rewriteReduceToScalar(&op); + } else if (auto assertOp = dyn_cast(&op)) { + rewriteAssertOp(assertOp); + } else { + bool changed = false; + if (auto storeOp = dyn_cast(&op)) + changed = rewriteStoreOp(storeOp); + if (changed) + continue; + // If we don't need to rewrite the op we still need to remap the + // operands. + for (OpOperand &operand : op.getOpOperands()) { + auto it = layouts.find(operand.get()); + if (it == layouts.end()) + continue; + Attribute encoding = + cast(operand.get().getType()).getEncoding(); + Value newOperand = getValueAs(operand.get(), encoding); + op.setOperand(operand.getOperandNumber(), newOperand); + } + for (Region &R : op.getRegions()) + queue.push_back(&R); + } + } + } + for (Operation *op : llvm::reverse(opToDelete)) + op->erase(); +} + +void LayoutPropagation::map(Value old, Value newV) { + rewriteMapping[{old, cast(newV.getType()).getEncoding()}] = + newV; +} + +Value LayoutPropagation::getValueAs(Value value, Attribute encoding) { + if (auto tensorType = dyn_cast(value.getType())) { + Value rewrittenValue; + auto layoutIt = layouts.find(value); + if (layoutIt == layouts.end()) { + rewrittenValue = value; + } else { + assert(layoutIt->second.encodings.size() == 1 && + "we should have resolved to a single encoding"); + Attribute encodingPicked = *(layoutIt->second.encodings.begin()); + if (encodingPicked == tensorType.getEncoding()) + rewrittenValue = value; + else + rewrittenValue = rewriteMapping[{value, encodingPicked}]; + } + assert(rewrittenValue); + if (cast(rewrittenValue.getType()).getEncoding() == + encoding) + return rewrittenValue; + OpBuilder rewriter(value.getContext()); + rewriter.setInsertionPointAfterValue(rewrittenValue); + auto tmpType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + Value converted = rewriter.create(value.getLoc(), tmpType, + rewrittenValue); + // TODO: we could cache the conversion. + return converted; + } + return value; +} + +Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter, + Operation *op, + Attribute encoding) { + Operation *newOp = rewriter.clone(*op); + + std::optional operandEnc; + if (op->getNumOperands() > 0) { + operandEnc = inferSrcEncoding(op, encoding); + assert(operandEnc.has_value()); + } + + for (OpOperand &operand : op->getOpOperands()) { + newOp->setOperand(operand.getOperandNumber(), + getValueAs(operand.get(), *operandEnc)); + } + + for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) { + auto origType = dyn_cast(op->getResult(i).getType()); + if (!origType) + continue; + auto newType = RankedTensorType::get(origType.getShape(), + origType.getElementType(), encoding); + newOp->getResult(i).setType(newType); + } + return newOp; +} + +Operation *LayoutPropagation::rewriteForOp(scf::ForOp forOp) { + SmallVector operands; + OpBuilder rewriter(forOp); + for (auto [operand, result] : + llvm::zip(forOp.getInitArgs(), forOp.getResults())) { + Value convertedOperand = operand; + if (layouts.count(result)) + convertedOperand = + getValueAs(operand, *layouts[result].encodings.begin()); + operands.push_back(convertedOperand); + } + auto newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), operands); + newForOp->setAttrs(forOp->getAttrs()); + newForOp.getBody()->getOperations().splice( + newForOp.getBody()->getOperations().begin(), + forOp.getBody()->getOperations()); + + for (auto [oldResult, newResult] : + llvm::zip(forOp.getResults(), newForOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + + for (auto [oldArg, newArg] : llvm::zip(forOp.getBody()->getArguments(), + newForOp.getBody()->getArguments())) { + if (oldArg.getType() == newArg.getType()) { + oldArg.replaceAllUsesWith(newArg); + continue; + } + map(oldArg, newArg); + } + return newForOp.getOperation(); +} + +Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) { + SmallVector operands; + SmallVector returnTypes; + OpBuilder rewriter(whileOp); + for (auto [operand, arg] : + llvm::zip(whileOp->getOperands(), whileOp.getBeforeArguments())) { + Value convertedOperand = operand; + if (layouts.count(arg)) + convertedOperand = getValueAs(operand, *layouts[arg].encodings.begin()); + operands.push_back(convertedOperand); + } + for (Value ret : whileOp.getResults()) { + auto it = layouts.find(ret); + if (it == layouts.end()) { + returnTypes.push_back(ret.getType()); + continue; + } + auto origType = dyn_cast(ret.getType()); + auto newType = + RankedTensorType::get(origType.getShape(), origType.getElementType(), + it->second.encodings[0]); + returnTypes.push_back(newType); + } + + auto newWhileOp = + rewriter.create(whileOp.getLoc(), returnTypes, operands); + SmallVector argsTypesBefore; + for (Value operand : operands) + argsTypesBefore.push_back(operand.getType()); + SmallVector bbArgLocsBefore(argsTypesBefore.size(), + whileOp.getLoc()); + SmallVector bbArgLocsAfter(returnTypes.size(), whileOp.getLoc()); + rewriter.createBlock(&newWhileOp.getBefore(), {}, argsTypesBefore, + bbArgLocsBefore); + rewriter.createBlock(&newWhileOp.getAfter(), {}, returnTypes, bbArgLocsAfter); + + for (int i = 0; i < whileOp.getNumRegions(); ++i) { + newWhileOp->getRegion(i).front().getOperations().splice( + newWhileOp->getRegion(i).front().getOperations().begin(), + whileOp->getRegion(i).front().getOperations()); + } + + auto remapArg = [&](Value oldVal, Value newVal) { + if (oldVal.getType() == newVal.getType()) + oldVal.replaceAllUsesWith(newVal); + else + map(oldVal, newVal); + }; + for (auto [oldResult, newResult] : + llvm::zip(whileOp.getResults(), newWhileOp.getResults())) + remapArg(oldResult, newResult); + for (auto [oldArg, newArg] : + llvm::zip(whileOp.getBeforeArguments(), newWhileOp.getBeforeArguments())) + remapArg(oldArg, newArg); + for (auto [oldArg, newArg] : + llvm::zip(whileOp.getAfterArguments(), newWhileOp.getAfterArguments())) + remapArg(oldArg, newArg); + return newWhileOp.getOperation(); +} + +Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) { + SmallVector operands; + OpBuilder rewriter(ifOp); + SmallVector newResultTypes(ifOp->getResultTypes()); + for (unsigned i = 0, e = ifOp->getNumResults(); i < e; ++i) { + auto it = layouts.find(ifOp->getResult(i)); + if (it == layouts.end()) + continue; + auto origType = cast(ifOp->getResult(i).getType()); + Attribute encoding = *(it->second.encodings.begin()); + newResultTypes[i] = RankedTensorType::get( + origType.getShape(), origType.getElementType(), encoding); + } + auto newIfOp = rewriter.create(ifOp.getLoc(), newResultTypes, + ifOp.getCondition(), true, true); + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + for (auto [oldResult, newResult] : + llvm::zip(ifOp.getResults(), newIfOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + return newIfOp.getOperation(); +} + +void LayoutPropagation::rewriteYieldOp(scf::YieldOp yieldOp) { + Operation *parentOp = yieldOp->getParentOp(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + Type yieldType = operand.get().getType(); + if (isa(parentOp)) + yieldType = parentOp->getResult(operand.getOperandNumber()).getType(); + if (auto whileOp = dyn_cast(parentOp)) + yieldType = + whileOp.getBeforeArguments()[operand.getOperandNumber()].getType(); + auto tensorType = dyn_cast(yieldType); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + yieldOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) { + scf::WhileOp whileOp = cast(conditionOp->getParentOp()); + for (unsigned i = 1; i < conditionOp->getNumOperands(); ++i) { + OpOperand &operand = conditionOp->getOpOperand(i); + Type argType = whileOp->getResult(operand.getOperandNumber() - 1).getType(); + auto tensorType = dyn_cast(argType); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + conditionOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteReduceToScalar(Operation *reduceOp) { + OpBuilder rewriter(reduceOp); + Attribute srcEncoding; + // Since all the operands need to have the same encoding pick the first one + // and use it for all the operands. + for (Value operand : reduceOp->getOperands()) { + auto it = layouts.find(operand); + if (it != layouts.end()) { + srcEncoding = it->second.encodings[0]; + break; + } + } + if (!srcEncoding) + return; + for (OpOperand &operand : reduceOp->getOpOperands()) { + Value newOperand = getValueAs(operand.get(), srcEncoding); + reduceOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteAssertOp(AssertOp assertOp) { + Attribute srcEncoding; + // Only need to deal with the first operand which is the condition tensor. + Value operand = assertOp->getOperand(0); + auto it = layouts.find(operand); + if (it == layouts.end()) + return; + srcEncoding = it->second.encodings[0]; + Value newOperand = getValueAs(operand, srcEncoding); + assertOp->setOperand(0, newOperand); +} + +bool LayoutPropagation::rewriteStoreOp(Operation *storeOp) { + IRMapping mapping; + Value value = storeOp->getOperand(1); + if (!value.getDefiningOp()) + return false; + Attribute srcEncoding; + auto cvt = dyn_cast(value.getDefiningOp()); + if (!cvt) + return false; + auto it = layouts.find(value); + if (it == layouts.end()) + return false; + srcEncoding = it->second.encodings[0]; + auto tensorType = cvt->getResult(0).getType().cast(); + unsigned bitWidth = tensorType.getElementType().getIntOrFloatBitWidth(); + if (!srcEncoding || !isa(srcEncoding) || + (dyn_cast(srcEncoding).getVersionMinor() == 0 && + bitWidth == 16)) + return false; + + OpBuilder rewriter(storeOp); + for (Value arg : storeOp->getOperands()) { + if (arg.getDefiningOp() == cvt) { + // mapping.map(arg, cvt.getOperand()); + auto src = getValueAs(arg, srcEncoding); + mapping.map(arg, src); + } else { + auto oldType = arg.getType().cast(); + auto newType = RankedTensorType::get( + oldType.getShape(), oldType.getElementType(), srcEncoding); + auto cvtI = rewriter.create(arg.getLoc(), newType, arg); + if (Operation *argOp = arg.getDefiningOp()) + cvtI->moveAfter(argOp); + mapping.map(arg, cvtI); + } + } + rewriter.setInsertionPoint(storeOp); + Operation *newOp = rewriter.clone(*storeOp, mapping); + opToDelete.insert(storeOp); + return true; +} + +Operation *LayoutPropagation::rewriteOp(Operation *op) { + opToDelete.insert(op); + if (auto forOp = dyn_cast(op)) + return rewriteForOp(forOp); + if (auto whileOp = dyn_cast(op)) + return rewriteWhileOp(whileOp); + if (auto ifOp = dyn_cast(op)) + return rewriteIfOp(ifOp); + OpBuilder rewriter(op); + Attribute encoding = *layouts[op->getResult(0)].encodings.begin(); + if (auto convertOp = dyn_cast(op)) { + Attribute srcEncoding = convertOp.getSrc().getType().getEncoding(); + auto it = layouts.find(convertOp.getSrc()); + if (it != layouts.end()) + srcEncoding = *(it->second.encodings.begin()); + Value src = getValueAs(convertOp.getSrc(), srcEncoding); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + auto cvt = rewriter.create(op->getLoc(), newType, src); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (canFoldIntoConversion(op, encoding)) { + Operation *newOp = rewriter.clone(*op); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + auto cvt = rewriter.create(op->getLoc(), newType, + newOp->getResult(0)); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (op->hasTrait() || + op->hasTrait() || + isa(op)) { + Operation *newOp = cloneElementwise(rewriter, op, encoding); + for (auto [oldResult, newResult] : + llvm::zip(op->getResults(), newOp->getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + return newOp; + } + llvm::report_fatal_error("unexpected op in rewrite"); + return nullptr; +} + +bool canBeRemat(Operation *op) { + if (isa(op)) + return !isExpensiveLoadOrStore(op); + if (isa(op)) + return false; + if (isa(op)) + return false; + + return true; +} + +void LayoutRematerialization::updateRematMapping( + SmallVector> &values) { + for (auto [old, newV] : values) { + auto it = mappedValues.find(old); + if (it != mappedValues.end()) { + Attribute encoding = it->second; + auto rematIt = rematMapping.find({old, it->second}); + assert(rematIt != rematMapping.end()); + Value replacedValue = rematIt->second; + rematMapping.erase(rematIt); + mappedValues.erase(it); + // Loop through the replacement value to find the new version of remat + // value. This should be okay as the number of values should be small. + for (auto [before, after] : values) { + if (before == replacedValue) { + replacedValue = after; + break; + } + } + rematMapping[{newV, encoding}] = replacedValue; + mappedValues[newV] = encoding; + } + } +} + +void LayoutRematerialization::rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp, + IRMapping &mapping) { + SetVector opsToRewrite; + // Keep track of yield operands that need to be duplicated. + DenseMap> yieldOperandsMap; + // Keep these around to remove them from the slice after our collection pass + // This ensures we don't duplicate them during an for rewrite or causing the + // for/yield to fall out of sync + SetVector valuesWithExistingRemat; + for (Value v : slice) { + auto layoutIt = layout.find(v); + assert(layoutIt != layout.end()); + // If we already have a remat value for this value, use it. + if (hasRematValue(v, layoutIt->second)) { + mapping.map(v, getRematValue(v, layoutIt->second)); + valuesWithExistingRemat.insert(v); + continue; + } + if (v.getDefiningOp()) { + opsToRewrite.insert(v.getDefiningOp()); + if (auto ifOp = v.getDefiningOp()) { + unsigned operandIdx = cast(v).getResultNumber(); + opsToRewrite.insert(ifOp.thenYield().getOperation()); + yieldOperandsMap[ifOp.thenYield()].push_back(operandIdx); + opsToRewrite.insert(ifOp.elseYield().getOperation()); + yieldOperandsMap[ifOp.elseYield()].push_back(operandIdx); + } + } else { + BlockArgument blockArg = cast(v); + Operation *parentOp = blockArg.getOwner()->getParentOp(); + if (auto loopOp = cast(parentOp)) { + opsToRewrite.insert(loopOp.getOperation()); + OpOperand *operand = loopOp.getTiedLoopYieldedValue(blockArg); + auto yieldOp = blockArg.getOwner()->getTerminator(); + yieldOperandsMap[yieldOp].push_back(operand->getOperandNumber()); + opsToRewrite.insert(yieldOp); + } + } + } + slice.set_subtract(valuesWithExistingRemat); + opsToRewrite = multiRootTopologicalSort(opsToRewrite); + + // replaceAllUsesWith calls delayed until after initial rewrite. + // This is required for slice.count(value) to work mid rewrite. + SmallVector> replacements; + + SmallVector deadOps; + IRRewriter builder(slice.begin()->getContext()); + for (Operation *op : opsToRewrite) { + if (auto forOp = dyn_cast(op)) { + // Keep a mapping of the operands index to the new operands index. + SmallVector> argMapping; + SmallVector newOperands; + for (auto arg : forOp.getRegionIterArgs()) { + if (slice.count(arg)) { + OpOperand &initVal = *forOp.getTiedLoopInit(arg); + argMapping.push_back(std::make_pair( + forOp.getTiedLoopResult(&initVal).getResultNumber(), + forOp.getInitArgs().size() + newOperands.size())); + newOperands.push_back(mapping.lookup(initVal.get())); + } + } + // Create a new for loop with the new operands. + scf::ForOp newForOp = replaceForOpWithNewSignature( + builder, forOp, newOperands, replacements); + deadOps.push_back(forOp.getOperation()); + Block &loopBody = *newForOp.getBody(); + for (auto m : argMapping) { + mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second)); + int numIndVars = newForOp.getNumInductionVars(); + mapping.map(loopBody.getArgument(m.first + numIndVars), + loopBody.getArgument(m.second + numIndVars)); + LLVM_DEBUG({ + DBGS() << "mapping forOp " + << loopBody.getArgument(m.first + numIndVars) << " to " + << loopBody.getArgument(m.second + numIndVars) << '\n'; + }); + // The result is not in the layout/slice, the argument is. + Value oldArg = loopBody.getArgument(m.first + numIndVars); + addRematValue(newForOp.getResult(m.first), layout[oldArg], + newForOp.getResult(m.second)); + addRematValue(oldArg, layout[oldArg], + loopBody.getArgument(m.second + numIndVars)); + } + continue; + } + if (auto ifOp = dyn_cast(op)) { + SmallVector newTypes; + for (auto res : ifOp.getResults()) { + if (slice.count(res)) { + auto it = layout.find(res); + assert(it != layout.end()); + + auto oldType = cast(res.getType()); + auto newType = RankedTensorType::get( + oldType.getShape(), oldType.getElementType(), it->second); + newTypes.push_back(newType); + } + } + scf::IfOp newIfOp = + replaceIfOpWithNewSignature(builder, ifOp, newTypes, replacements); + unsigned oldIdx = 0; + unsigned newIdx = ifOp.getNumResults(); + for (auto res : ifOp.getResults()) { + if (slice.count(res)) { + // Why can't we use res instead of ifOp.getResult(oldIdx)? + mapping.map(ifOp.getResult(oldIdx), newIfOp.getResult(newIdx)); + addRematValue(ifOp.getResult(oldIdx), layout[res], + newIfOp.getResult(newIdx)); + ++newIdx; + } + ++oldIdx; + } + deadOps.push_back(ifOp.getOperation()); + continue; + } + builder.setInsertionPoint(op); + if (auto yieldOp = dyn_cast(op)) { + auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); + SmallVector operandsToRewrite = yieldOperandsMap[op]; + // Sort so that operands are added in the same order as the new scf + // results/arguments. + std::sort(operandsToRewrite.begin(), operandsToRewrite.end()); + for (int operandIdx : operandsToRewrite) { + yieldOperands.push_back(mapping.lookup(yieldOp.getOperand(operandIdx))); + } + builder.create(op->getLoc(), yieldOperands); + op->erase(); + continue; + } + if (isa(op)) { + Operation *newOp = builder.clone(*op); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), + layout[op->getResult(0)]); + auto cvt = builder.create(op->getLoc(), newType, + newOp->getResult(0)); + mapping.map(op->getResult(0), cvt.getResult()); + addRematValue(op->getResult(0), layout[op->getResult(0)], + cvt.getResult()); + continue; + } + Operation *newOp = builder.clone(*op, mapping); + for (auto [old, newV] : llvm::zip(op->getResults(), newOp->getResults())) { + auto it = layout.find(old); + if (it == layout.end()) + continue; + auto newType = RankedTensorType::get( + cast(old.getType()).getShape(), + cast(old.getType()).getElementType(), it->second); + newV.setType(newType); + addRematValue(old, it->second, newV); + } + } + // Check mapping and see if there are existing convertOps on the old Argument + convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getSrc())); + opToDelete.insert(convertOp); + + updateRematMapping(replacements); + for (auto &kv : replacements) { + builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); + } + + for (Operation *op : deadOps) + opToDelete.insert(op); +} + +void LayoutRematerialization::rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp) { + IRMapping mapping; + rewriteSlice(slice, layout, convertOp, mapping); +} + +LogicalResult getRematerializableSlice( + Value root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation = nullptr) { + LogicalResult result = getConvertBackwardSlice(root, slice, rootEncoding, + layout, stopPropagation); + if (result.failed() || slice.empty()) + return failure(); + + // Check if all the operations in the slice can be rematerialized. + for (Value v : slice) { + if (Operation *op = v.getDefiningOp()) { + if (!canBeRemat(op)) + return failure(); + } + } + return success(); +} + +void LayoutRematerialization::backwardRematerialization() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + backwardRematerialization(convertOp); + } +} + +void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertOnTopOfExtOrBroadcast(convertOp); + } +} + +void LayoutRematerialization::backwardRematerialization( + ConvertLayoutOp convertOp) { + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristic to accommodate fused attention + RankedTensorType targetType = convertOp.getType(); + if (isa(targetType.getEncoding())) + return; + Value oldV = convertOp->getOperand(0); + LDBG("check backward remat with source " << oldV << " encoding " + << targetType.getEncoding()); + // Check to see if there are existing remat'ed values for the pair of oldValue + // and encoding. + if (hasRematValue(oldV, targetType.getEncoding())) { + // Replace it with the remat'ed value. + Value newV = getRematValue(oldV, targetType.getEncoding()); + convertOp.replaceAllUsesWith(newV); + opToDelete.insert(convertOp); + LDBG("found remat'ed value" << newV); + return; + } + + // 1. Take a backward slice of all the tensor dependencies that can be + // rematerialized. + SetVector slice; + DenseMap layout; + LogicalResult result = getRematerializableSlice( + convertOp.getSrc(), targetType.getEncoding(), slice, layout); + if (result.failed()) { + LDBG(" getRematerializableSlice failed"); + return; + } + + LLVM_DEBUG({ + DBGS() << " remat convert op " << convertOp << '\n'; + for (Value v : slice) + DBGS() << " " << v << '\n'; + }); + // 2. Rewrite the slice. + rewriteSlice(slice, layout, convertOp); +} + +// For convert left we try to hoist them above type extension to reduce the cost +// of the convert. +void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( + ConvertLayoutOp convertOp) { + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristics to accommodate fused attention + RankedTensorType targetType = convertOp.getType(); + if (mlir::isa(targetType.getEncoding())) + return; + +#ifdef __ILUVATAR__ + auto srcType = convertOp.getOperand().getType().cast(); + auto srcSliceLayout = srcType.getEncoding().dyn_cast(); + auto tgtSliceLayout = targetType.getEncoding().dyn_cast(); + if (srcSliceLayout && tgtSliceLayout) { + if (srcSliceLayout.getNoWarpReduce() || tgtSliceLayout.getNoWarpReduce()) + return; + } +#endif + + auto isExtOrBroadcastOp = [](Operation *op) { + if (isa(op)) { + return true; + } + if (auto fpToFpOp = dyn_cast(op)) { + auto srcType = cast(fpToFpOp.getOperand().getType()); + return getElementBitWidth(srcType) < + getElementBitWidth(fpToFpOp.getType()); + } + return false; + }; + // 1. Take a backward slice of all the tensor dependencies. + SetVector slice; + DenseMap layout; + LogicalResult result = + getRematerializableSlice(convertOp.getSrc(), targetType.getEncoding(), + slice, layout, isExtOrBroadcastOp); + if (result.failed()) + return; + + Operation *extOrBroadcatOp = nullptr; + unsigned sliceSize = slice.size(); + for (unsigned i = 0; i < sliceSize; i++) { + Value v = slice[i]; + Operation *op = v.getDefiningOp(); + if (!op) + continue; + if (isExtOrBroadcastOp(op)) { + SetVector tempSlice; + DenseMap tempLayout; + std::optional srcEncoding = inferSrcEncoding(op, layout[v]); + if (!srcEncoding) + return; + LogicalResult result = getRematerializableSlice( + op->getOperand(0), *srcEncoding, tempSlice, tempLayout); + // If we can rematerialize the rest of the ext slice we can ignore this + // ext as it won't need a convert. + if (result.succeeded()) { + slice.insert(tempSlice.begin(), tempSlice.end()); + layout.insert(tempLayout.begin(), tempLayout.end()); + continue; + } + // Only apply it if there is a single ext op otherwise we would have to + // duplicate the convert. + if (extOrBroadcatOp != nullptr) + return; + extOrBroadcatOp = op; + } + } + + if (extOrBroadcatOp == nullptr) + return; + Attribute dstEncoding = layout[extOrBroadcatOp->getResult(0)]; + std::optional srcEncoding = + inferSrcEncoding(extOrBroadcatOp, dstEncoding); + if (!srcEncoding) + return; + // Move the convert before the ext op and rewrite the slice. + OpBuilder builder(extOrBroadcatOp); + auto tensorType = + cast(extOrBroadcatOp->getOperand(0).getType()); + auto newType = RankedTensorType::get( + tensorType.getShape(), tensorType.getElementType(), *srcEncoding); + auto newConvertOp = builder.create( + convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0)); + Operation *newExtOrBroadcast = builder.clone(*extOrBroadcatOp); + newExtOrBroadcast->setOperand(0, newConvertOp.getResult()); + auto oldExtOrBroadcastType = + cast(extOrBroadcatOp->getResult(0).getType()); + Type newExtOrBroadcasrType = RankedTensorType::get( + oldExtOrBroadcastType.getShape(), oldExtOrBroadcastType.getElementType(), + dstEncoding); + newExtOrBroadcast->getResult(0).setType(newExtOrBroadcasrType); + IRMapping mapping; + mapping.map(extOrBroadcatOp->getResult(0), newExtOrBroadcast->getResult(0)); + slice.remove(extOrBroadcatOp->getResult(0)); + // 3. Rewrite the slice. + rewriteSlice(slice, layout, convertOp, mapping); +} + +void backwardRematerialization(ModuleOp module) { + module.walk([](FuncOp funcOp) { + LayoutRematerialization layoutRemat(funcOp); + layoutRemat.backwardRematerialization(); + layoutRemat.cleanup(); + }); +} + +void hoistConvert(ModuleOp module) { + SmallVector convertOps; + module.walk([](FuncOp funcOp) { + LayoutRematerialization layoutRemat(funcOp); + layoutRemat.hoistConvertOnTopOfExtOrBroadcast(); + layoutRemat.cleanup(); + }); +} +} // namespace + +class TritonGPURemoveLayoutConversionsPass + : public impl::TritonGPURemoveLayoutConversionsBase< + TritonGPURemoveLayoutConversionsPass> { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + // 1. Propagate layout forward starting from "anchor" ops. + m.walk([](FuncOp funcOp) { + LayoutPropagation layoutPropagation(funcOp); + layoutPropagation.initAnchorLayout(); + layoutPropagation.propagateLayout(); + layoutPropagation.resolveConflicts(); + layoutPropagation.rewrite(); + }); + + LLVM_DEBUG({ + DBGS() << "Module after propagating layouts forward:\n"; + m.dump(); + }); + + RewritePatternSet cleanUpPatterns(context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, context); + if (applyPatternsAndFoldGreedily(m, std::move(cleanUpPatterns)).failed()) { + signalPassFailure(); + } + + LLVM_DEBUG({ + DBGS() << "Module after canonicalizing:\n"; + m.dump(); + }); + + // 2. For remaining convert ops, try to rematerialize the slice of producer + // operation to avoid having to convert. + backwardRematerialization(m); + LLVM_DEBUG({ + DBGS() << "Module after backward remat:\n"; + m.dump(); + }); + + // 3. For remaining converts, try to hoist them above cast generating larger + // size types in order to reduce the cost of the convert op. + hoistConvert(m); + LLVM_DEBUG({ + DBGS() << "Module after hoisting converts:\n"; + m.dump(); + }); + + RewritePatternSet decomposePatterns(context); + decomposePatterns.add(context); +#ifdef __ILUVATAR__ + decomposePatterns.add(context); +#endif + if (applyPatternsAndFoldGreedily(m, std::move(decomposePatterns)) + .failed()) { + signalPassFailure(); + } + LLVM_DEBUG({ + DBGS() << "Module after decomposing dot-converts:\n"; + m.dump(); + }); + + // 4. Apply clean up patterns to remove remove dead convert and dead code + // generated by the previous transformations. + RewritePatternSet cleanUpPatterns2(context); + populateForOpDeadArgumentElimination(cleanUpPatterns2); + scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + scf::IfOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + if (applyPatternsAndFoldGreedily(m, std::move(cleanUpPatterns2)).failed()) { + signalPassFailure(); + } + LLVM_DEBUG({ + DBGS() << "Module after final cleanups:\n"; + m.dump(); + }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp new file mode 100644 index 000000000..bff277c59 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp @@ -0,0 +1,140 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUREORDERINSTRUCTIONS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static bool willIncreaseRegisterPressure(Operation *op) { + if (isa(op)) + return true; + auto cvt = dyn_cast(op); + if (!cvt) + return false; + if (mlir::isa( + cvt.getType().getEncoding())) + return true; + return false; +} + +class TritonGPUReorderInstructionsPass + : public impl::TritonGPUReorderInstructionsBase< + TritonGPUReorderInstructionsPass> { +public: + TritonGPUReorderInstructionsPass() = default; + + Operation *getFirstUse(Operation *op) { + std::vector users; + for (auto user : op->getUsers()) { + if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) + users.push_back(ancestor); + } + auto minOpIt = std::min_element(users.begin(), users.end(), + [](mlir::Operation *a, mlir::Operation *b) { + return a->isBeforeInBlock(b); + }); + return minOpIt != users.end() ? *minOpIt : nullptr; + } + + void runOnOperation() override { + ModuleOp m = getOperation(); + mlir::DominanceInfo dom(m); + // sink conversion after the last dealloc + // before the first use ancestor in its block + m.walk([&](triton::gpu::ConvertLayoutOp op) { + auto curr = mlir::Block::iterator(op); + for (; &*curr != getFirstUse(op); curr++) + if (isa(&*curr)) + op->moveAfter(&*curr); + }); + // Sink conversions into loops when they will increase + // register pressure + DenseMap opToMove; + auto moveAfter = [](Operation *lhs, Operation *rhs) { + lhs->moveAfter(rhs); + }; + m.walk([&](Operation *op) { + if (!willIncreaseRegisterPressure(op)) + return; + auto user_begin = op->user_begin(); + auto user_end = op->user_end(); + if (std::distance(user_begin, user_end) != 1) + return; + if (user_begin->getParentOfType() == + op->getParentOfType()) + return; + opToMove.insert({op, *user_begin}); + }); + for (auto &kv : opToMove) + kv.first->moveBefore(kv.second); + // Move alloc(load) immediately after dependent load + m.walk([&](triton::gpu::LocalAllocOp op) { + if (!op.getSrc()) + return; + Operation *argOp = op.getSrc().getDefiningOp(); + if (!argOp) + return; + moveAfter(op, argOp); + }); + // Move transpositions just after their definition + opToMove.clear(); + m.walk([&](triton::TransOp op) { + Operation *argOp = op.getSrc().getDefiningOp(); + if (!argOp) + return; + moveAfter(op, argOp); + }); + // Move `dot` operand so that conversions to opIdx=1 happens after + // conversions to opIdx=0 + m.walk([&](triton::gpu::LocalLoadOp op) { + auto dstEncoding = mlir::dyn_cast( + op.getType().getEncoding()); + if (!dstEncoding) + return; + int opIdx = dstEncoding.getOpIdx(); + if (opIdx != 1) + return; + if (!op->hasOneUse()) + return; + auto dotUser = dyn_cast(*op->user_begin()); + if (!dotUser) + return; + auto AOp = + dotUser.getOperand(0).getDefiningOp(); + if (!AOp) + return; + // Check that the conversion to OpIdx=1 happens before and can be moved + // after the conversion to OpIdx=0. + if (!dom.dominates(op.getOperation(), AOp.getOperation())) + return; + moveAfter(op, AOp); + }); + return; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Utility.cpp new file mode 100644 index 000000000..a5c3d73f5 --- /dev/null +++ b/third_party/iluvatar/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -0,0 +1,984 @@ +#include "triton/Analysis/Utility.h" + +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" +#define DEBUG_TYPE "ttg-utility" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { + +using namespace triton; + +SmallVector mmaVersionToInstrShape(int version, + const ArrayRef &shape, + TensorOrMemDesc type, + int numWarps) { + if (version == 1) + return {16, 16}; + else if (version == 2) { + auto rank = shape.size(); + SmallVector ret(rank, 1); + ret[rank - 1] = 8; + ret[rank - 2] = 16; + return ret; + } else if (version == 3) { + unsigned k = 256 / type.getElementTypeBitWidth(); + if (shape[0] % 64 != 0 || shape[1] % 8 != 0) { + assert(false && "type not supported"); + return {0, 0, 0}; + } + auto eltType = type.getElementType(); + SmallVector validN; + + // MMAv3 with larger instruction shape is preferred. + if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FNUZ() || + eltType.isF16() || eltType.isBF16() || eltType.isF32()) { + validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, + 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, + 80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); + } + + if (eltType.isInteger(8)) { + validN.assign({224, 208, 192, 176, 160, 144, 128, 112, 96, 80, 64, 48, 32, + 24, 16, 8}); + } + + unsigned m = 16; + unsigned mWarps = std::max(shape[0] / m, 1); + unsigned nWarps = std::max(numWarps / mWarps, 1); + unsigned maxN = std::max(shape[1] / nWarps, 8); + for (auto n : validN) { + if (shape[1] % n == 0 && n <= maxN) { + return {m, n, k}; + } + } + + assert(false && "type not supported"); + return {0, 0, 0}; + } else { + assert(false && "version not supported"); + return {0, 0}; + } +} + +bool isLoadFromTensorPtr(triton::LoadOp op) { + return mlir::triton::isTensorPointerType(op.getPtr().getType()); +} + +SmallVector argSort(const SmallVector &arr) { + SmallVector ret(arr.size()); + std::iota(ret.begin(), ret.end(), 0); + std::stable_sort(ret.begin(), ret.end(), + [&](unsigned x, unsigned y) { return arr[x] > arr[y]; }); + return ret; +} + +Value getMemAccessPtr(Operation *op) { + if (auto ld = dyn_cast(op)) + return ld.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto copy = dyn_cast(op)) + return copy.getSrc(); + if (auto store = dyn_cast(op)) + return store.getPtr(); + return nullptr; +} + +unsigned getElementBitWidth(RankedTensorType type) { + auto typeForMem = + isa(type.getElementType()) + ? cast(type.getElementType()).getPointeeType() + : type.getElementType(); + return typeForMem.getIntOrFloatBitWidth(); +} + +unsigned getNumElementsPerThread(Operation *op, SmallVector order, + ModuleAxisInfoAnalysis &axisInfoAnalysis) { + Value val = getMemAccessPtr(op); + auto ty = cast(val.getType()); + auto shapePerCTA = triton::gpu::getShapePerCTA(ty); + AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); + unsigned elemNumBits = getElementBitWidth(ty); + unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); + unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]); + unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); + unsigned maxContig = + std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]); + unsigned alignment = std::min(maxMultiple, maxContig); + // For int64, we have to use this + unsigned currPerThread = std::min(alignment, 128 / elemNumBits); +#ifdef __ILUVATAR__ + if (elemNumBits <= 32) + currPerThread = std::min(alignment, 32 / elemNumBits); +#endif + LDBG("elemNumBytes: " << elemNumBytes + << ", divisibility: " << maxMultipleBytes + << ", contig: " << valInfo.getContiguity(order[0]) + << ", alignment: " << alignment); + return currPerThread; +} + +//===----------------------------------------------------------------------===// +// GraphDumper +//===----------------------------------------------------------------------===// + +GraphDumper::NodeInfo GraphDumper::onValue(Value value) const { + return {{"shape", "box"}, {"style", "filled"}, {"fillcolor", "white"}}; +} + +GraphDumper::NodeInfo GraphDumper::onOperation(Operation *op) const { + return {{"shape", "ellipse"}, {"style", "filled"}, {"fillcolor", "white"}}; +} + +std::string GraphDumper::dump(triton::FuncOp func) const { + llvm::SetVector values; + llvm::SetVector operations; + + func.walk([&](Operation *op) { + operations.insert(op); + for (Value operand : op->getOperands()) + values.insert(operand); + for (Value result : op->getResults()) + values.insert(result); + }); + + std::ostringstream oss; + oss << "// Generated by Triton GraphDumper\n" + << "\n" + << "digraph {\n"; + + oss << " // Value Nodes\n"; + for (Value value : values) + oss << " " << emitValueNode(value) << "\n"; + oss << "\n"; + + oss << " // Operation Nodes\n"; + for (Operation *op : operations) + oss << " " << emitOperationNode(op) << "\n"; + oss << "\n"; + + oss << " // Edges\n"; + for (Operation *op : operations) { + for (Value operand : op->getOperands()) + oss << " " << emitEdge(getUniqueId(operand), getUniqueId(op)) << "\n"; + for (Value result : op->getResults()) + oss << " " << emitEdge(getUniqueId(op), getUniqueId(result)) << "\n"; + } + + oss << "}\n"; + return oss.str(); +} + +void GraphDumper::dumpToFile(triton::FuncOp func, + const std::string &filename) const { + std::ofstream ofs(filename); + ofs << dump(func); +} + +std::string GraphDumper::getShapeStr(const Type &type) const { + std::ostringstream oss; + oss << "["; + if (auto tensorTy = dyn_cast(type)) { + auto shape = tensorTy.getShape(); + for (unsigned i = 0; i < shape.size(); ++i) { + if (i > 0) + oss << ", "; + oss << shape[i]; + } + } + oss << "]"; + return oss.str(); +} + +std::string GraphDumper::getUniqueId(Value value) const { + std::ostringstream oss; + oss << value.getImpl(); + return oss.str(); +} + +std::string GraphDumper::getUniqueId(Operation *op) const { + std::ostringstream oss; + oss << op; + return oss.str(); +} + +std::string GraphDumper::emitNode(const std::string &id, + const GraphDumper::NodeInfo info) const { + std::ostringstream oss; + oss << "\"" << id << "\" ["; + for (auto it = info.begin(); it != info.end(); ++it) { + if (it != info.begin()) + oss << ", "; + oss << it->first << " = \"" << it->second << "\""; + } + oss << "];"; + return oss.str(); +} + +std::string GraphDumper::emitEdge(const std::string &srcId, + const std::string &destId) const { + std::ostringstream oss; + oss << "\"" << srcId << "\" -> \"" << destId << "\";"; + return oss.str(); +} + +std::string GraphDumper::emitValueNode(Value value) const { + NodeInfo info = onValue(value); + if (info.find("label") == info.end()) { + std::string shapeStr = getShapeStr(value.getType()); + if (auto arg = mlir::dyn_cast(value)) + info["label"] = + "BlockArg" + std::to_string(arg.getArgNumber()) + " " + shapeStr; + else + info["label"] = shapeStr; + } + return emitNode(getUniqueId(value), info); +} + +std::string GraphDumper::emitOperationNode(Operation *op) const { + NodeInfo info = onOperation(op); + if (info.find("label") == info.end()) + info["label"] = op->getName().getStringRef().str(); + return emitNode(getUniqueId(op), info); +} + +//===----------------------------------------------------------------------===// +// GraphLayoutMarker +//===----------------------------------------------------------------------===// + +GraphDumper::NodeInfo GraphLayoutMarker::onValue(Value value) const { + std::string color = getColor(value.getType()); + return {{"shape", "box"}, {"style", "filled"}, {"fillcolor", color}}; +} + +std::string GraphLayoutMarker::getColor(const Type &type) const { + if (auto tensorTy = dyn_cast(type)) { + auto layout = tensorTy.getEncoding(); + if (isa(layout)) + return "green"; + else if (isa(layout)) + return "yellow"; + else if (isa(layout)) + return "lightslateblue"; + else if (isa(layout)) + return "orange"; + else if (isa(layout)) + return "orangered"; + else { + llvm::report_fatal_error("Unrecognized layout"); + return "unknown"; + } + } else { + return "white"; + } +} +// -------------------------------------------------------------------------- // + +static std::optional inferDstEncoding(triton::ReduceOp op, + Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get(op->getContext(), op.getAxis(), + encoding, op.getNoWarpReduce()); +} + +static std::optional inferDstEncoding(triton::ExpandDimsOp op, + Attribute encoding) { + auto sliceEncoding = mlir::dyn_cast(encoding); + if (!sliceEncoding) + return std::nullopt; + if (op.getAxis() != sliceEncoding.getDim()) + return std::nullopt; + return sliceEncoding.getParent(); +} + +static std::optional inferDstEncoding(JoinOp op, Attribute srcEnc) { + Attribute dstEnc; + if (srcEnc.getDialect() + .getRegisteredInterface() + ->inferJoinOpEncoding(srcEnc, dstEnc, + /*loc=*/std::nullopt) + .succeeded()) { + return dstEnc; + } + return std::nullopt; +} + +static std::optional inferDstEncoding(SplitOp op, Attribute srcEnc) { + Attribute dstEnc; + if (srcEnc.getDialect() + .getRegisteredInterface() + ->inferSplitOpEncoding(srcEnc, dstEnc, + /*loc=*/std::nullopt) + .succeeded()) { + return dstEnc; + } + return std::nullopt; +} + +static std::optional inferSrcEncoding(triton::ReduceOp op, + Attribute encoding) { + auto sliceEncoding = mlir::dyn_cast(encoding); + if (!sliceEncoding) + return std::nullopt; + if (op.getAxis() != sliceEncoding.getDim()) + return std::nullopt; + return sliceEncoding.getParent(); +} + +static std::optional inferSrcEncoding(triton::ExpandDimsOp op, + Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get(op->getContext(), op.getAxis(), + encoding, false); + // FIXME: Shall we support noWarpReduce filed for ExpandDimsOp? +} + +static std::optional inferSrcEncoding(JoinOp op, Attribute dstEnc) { + // Split is the inverse of join. + Attribute srcEnc; + if (dstEnc.getDialect() + .getRegisteredInterface() + ->inferSplitOpEncoding(dstEnc, srcEnc, /*loc=*/std::nullopt) + .succeeded()) { + return srcEnc; + } + return std::nullopt; +} + +static std::optional inferSrcEncoding(SplitOp op, Attribute dstEnc) { + // Join is the inverse of split. + Attribute srcEnc; + if (dstEnc.getDialect() + .getRegisteredInterface() + ->inferJoinOpEncoding(dstEnc, srcEnc, /*loc=*/std::nullopt) + .succeeded()) { + return srcEnc; + } + return std::nullopt; +} + +static std::optional +inferTransOpDstEncoding(Attribute srcEnc, ArrayRef order) { + // Simply forward to the existing inferTransOpEncoding function. + Attribute retEncoding; + if (succeeded( + srcEnc.getDialect() + .getRegisteredInterface() + ->inferTransOpEncoding(srcEnc, order, retEncoding))) { + return retEncoding; + } + return std::nullopt; +} + +static std::optional inferDstEncoding(triton::TransOp op, + Attribute encoding) { + return inferTransOpDstEncoding(encoding, op.getOrder()); +} + +static std::optional inferSrcEncoding(triton::TransOp op, + Attribute encoding) { + // We want to solve for srcEnc in + // transpose(srcEnc, order) -> dstEnc. + // Given the identity + // transpose(transpose(x, order), inverse(order)) == x, + // we can see this is equivalent to + // transpose(dstEnc, inverse(order)) -> srcEnc. + return inferTransOpDstEncoding(encoding, + triton::inversePermutation(op.getOrder())); +} + +static std::optional +inferReshapeOpDstEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, bool allowReorder) { + // We don't do anything smart to allow-reorder reshapes here. They are + // handled in OptimizeThreadLocality. + if (allowReorder) + return std::nullopt; + + Attribute dstEnc; + if (succeeded( + srcEnc.getDialect() + .getRegisteredInterface() + ->inferReshapeOpNoReorderEncoding( + srcShape, srcEnc, dstShape, dstEnc, /*loc=*/std::nullopt))) { + return dstEnc; + } + return std::nullopt; +} + +static std::optional inferDstEncoding(triton::ReshapeOp op, + Attribute encoding) { + return inferReshapeOpDstEncoding(op.getSrc().getType().getShape(), encoding, + op.getType().getShape(), + op.getAllowReorder()); +} + +static std::optional inferSrcEncoding(triton::ReshapeOp op, + Attribute encoding) { + // The encoding of x given the encoding of y in `reshape(x) -> y` is the same + // as the encoding of x given the encoding of y in `reshape(y) -> x`. It's an + // invariant of inferReshapeOpNoReorderEncoding that it's symmetric in this + // way. + return inferReshapeOpDstEncoding(op.getType().getShape(), encoding, + op.getSrc().getType().getShape(), + op.getAllowReorder()); +} + +std::optional inferSrcEncoding(Operation *op, Attribute encoding) { + if (isa(op)) { + // Scan only supports blocked encoding at the moment. + if (!isa(encoding)) + return std::nullopt; + } + if (op->hasTrait() || + op->hasTrait() || + op->hasTrait() || + isa(op)) { + return encoding; + } + + if (auto reduceOp = dyn_cast(op)) + return inferSrcEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferSrcEncoding(expand, encoding); + if (auto join = dyn_cast(op)) + return inferSrcEncoding(join, encoding); + if (auto split = dyn_cast(op)) + return inferSrcEncoding(split, encoding); + if (auto trans = dyn_cast(op)) + return inferSrcEncoding(trans, encoding); + if (auto reshape = dyn_cast(op)) + return inferSrcEncoding(reshape, encoding); + + return std::nullopt; +} + +std::optional inferDstEncoding(Operation *op, Attribute encoding) { + if (isa(op)) { + if (!isa(encoding)) + return std::nullopt; + } + if (op->hasTrait() || + op->hasTrait() || + op->hasTrait() || + isa(op)) + return encoding; + if (auto reduceOp = dyn_cast(op)) + return inferDstEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferDstEncoding(expand, encoding); + if (auto join = dyn_cast(op)) + return inferDstEncoding(join, encoding); + if (auto split = dyn_cast(op)) + return inferDstEncoding(split, encoding); + if (auto trans = dyn_cast(op)) + return inferDstEncoding(trans, encoding); + if (auto reshape = dyn_cast(op)) + return inferDstEncoding(reshape, encoding); + + return std::nullopt; +} + +bool isSingleValue(Value value) { + // Don't consider load as expensive if it is loading a scalar. + if (auto tensorTy = dyn_cast(value.getType())) + return tensorTy.getNumElements() == 1; + // TODO: Handle other cases. + // For example, when ptr is a tensor of single value. + // It means that ptr is a resultant of broadcast or generated through + // a chain of broadcast and other operations. + // Rematerialize it without considering contiguous memory access pattern is + // fine. + return true; +} + +bool isExpensiveLoadOrStore(Operation *op) { + // Case 1: Pointer of tensor is always expensive + auto operandType = op->getOperand(0).getType(); + if (triton::isTensorPointerType(operandType)) + return true; + // Case 2a: A size 1 tensor is not expensive since all threads will load the + // same + if (isSingleValue(op->getOperand(0))) + return false; + // Case 2b: Tensor of pointers has more threads than elements + // we can presume a high hit-rate that makes it cheap to load + auto ptrType = cast(op->getOperand(0).getType()); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + if (ptrType.getNumElements() < numWarps * threadsPerWarp) + return false; + return true; +} + +bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { + if (!op) + return true; + if (isa(op)) + return isExpensiveLoadOrStore(op); + if (isa(op)) + return triton::gpu::isExpensiveCat(cast(op), targetEncoding); + if (isa(op)) + return true; + if (isa( + op)) + return true; + return false; +} + +bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { + if (isa(op)) + return !triton::gpu::isExpensiveCat(cast(op), + targetEncoding); + if (auto convert = dyn_cast(op)) { + if (mlir::isa(targetEncoding)) { + auto srcEncoding = convert.getSrc().getType().getEncoding(); + if (targetEncoding != srcEncoding) + return false; + } + return true; + } + + if (auto reshape = dyn_cast(op)) { + auto reshapeDstType = reshape.getType(); + RankedTensorType newDstType = + RankedTensorType::get(reshapeDstType.getShape(), + reshapeDstType.getElementType(), targetEncoding); + return reshape.getAllowReorder() && + !reshape.getEfficientLayout().has_value() && + !triton::gpu::isExpensiveView(reshape.getSrc().getType(), + newDstType); + } + return isa(op); +} + +scf::ForOp replaceForOpWithNewSignature( + RewriterBase &rewriter, scf::ForOp loop, ValueRange newIterOperands, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(loop); + + // Create a new loop before the existing one, with the extra operands. + auto operands = llvm::to_vector<4>(loop.getInitArgs()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + scf::ForOp newLoop = rewriter.create( + loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), + operands); + newLoop->setAttrs(loop->getAttrs()); + newLoop.getBody()->erase(); + newLoop.getRegion().getBlocks().splice( + newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); + for (Value operand : newIterOperands) + newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); + + for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( + loop.getNumResults()))) + replacements.push_back(it); + return newLoop; +} + +scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, + ValueRange newIterOperands) { + SmallVector> replacements; + auto newForOp = replaceForOpWithNewSignature(rewriter, loop, newIterOperands, + replacements); + for (auto &kv : replacements) { + rewriter.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); + } + return newForOp; +} + +scf::IfOp replaceIfOpWithNewSignature( + RewriterBase &rewriter, scf::IfOp ifOp, TypeRange newResultTypes, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(ifOp); + + // Create a new loop before the existing one, with the extra operands. + auto resultTypes = llvm::to_vector<4>(ifOp.getResults().getTypes()); + resultTypes.append(newResultTypes.begin(), newResultTypes.end()); + scf::IfOp newIf = rewriter.create( + ifOp.getLoc(), resultTypes, ifOp.getCondition(), /*withElse=*/true); + newIf->setAttrs(ifOp->getAttrs()); + + rewriter.inlineBlockBefore(ifOp.thenBlock(), newIf.thenBlock(), + newIf.thenBlock()->begin()); + rewriter.inlineBlockBefore(ifOp.elseBlock(), newIf.elseBlock(), + newIf.elseBlock()->begin()); + + for (auto it : llvm::zip(ifOp.getResults(), + newIf.getResults().take_front(ifOp.getNumResults()))) + replacements.push_back(it); + return newIf; +} + +Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, + IRMapping &mapping) { + Operation *newOp = rewriter.clone(*op, mapping); + // if input types haven't changed, we're done + bool preserveTypes = + std::all_of(op->operand_begin(), op->operand_end(), [&](Value v) { + return !mapping.contains(v) || + v.getType() == mapping.lookup(v).getType(); + }); + if (preserveTypes) + return newOp; + + if (newOp->getNumResults() == 0) + return newOp; + auto origType = dyn_cast(op->getResult(0).getType()); + auto argType = dyn_cast(newOp->getOperand(0).getType()); + if (!origType || !argType) + return newOp; + auto newType = RankedTensorType::get( + origType.getShape(), origType.getElementType(), argType.getEncoding()); + newOp->getResult(0).setType(newType); + auto typeInfer = dyn_cast(newOp); + if (typeInfer) { + SmallVector newTypes; + auto success = typeInfer.inferReturnTypes( + newOp->getContext(), newOp->getLoc(), newOp->getOperands(), + newOp->getAttrDictionary(), newOp->getPropertiesStorage(), + newOp->getRegions(), newTypes); + if (succeeded(success)) { + for (size_t i = 0; i < newTypes.size(); i++) + newOp->getResult(i).setType(newTypes[i]); + } + } + return newOp; +} + +// Check if the convert will be a no-op in codegen. +static bool isFreeConvert(Operation *op) { + auto convertOp = dyn_cast(op); + if (!convertOp) + return false; + return isMmaToMmaShortcut(convertOp.getSrc().getType(), convertOp.getType()); +} + +LogicalResult +getConvertBackwardSlice(Value root, SetVector &slice, + Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation) { + DenseSet> seen; + SmallVector> queue; + + auto enqueue = [&](Value operand, Attribute encoding) { + auto x = std::make_pair(operand, encoding); + if (!seen.insert(x).second) { + return; // Already enqueued, skip + } + queue.push_back(x); + }; + enqueue(root, rootEncoding); + + while (!queue.empty()) { + auto [currentValue, encoding] = queue.back(); + queue.pop_back(); + if (!isa(currentValue.getType())) + continue; + // Skip propagating through for op results for now. + // TODO: enable this based on needs. + if (currentValue.getDefiningOp()) + return failure(); + slice.insert(currentValue); + if (layout.find(currentValue) != layout.end()) { + if (layout[currentValue] != encoding) + return failure(); + } + layout[currentValue] = encoding; + + if (auto ifOp = currentValue.getDefiningOp()) { + auto results = ifOp.getResults(); + unsigned argIdx = mlir::cast(currentValue).getResultNumber(); + + auto thenValue = ifOp.thenYield().getOperand(argIdx); + auto elseValue = ifOp.elseYield().getOperand(argIdx); + + enqueue(thenValue, encoding); + enqueue(elseValue, encoding); + + continue; + } + if (auto *definingOp = currentValue.getDefiningOp()) { + // If the op has multiple results we need to update all results layout. + for (Value result : definingOp->getResults()) { + if (result == currentValue || !isa(result.getType())) + continue; + enqueue(result, encoding); + } + if (!isFreeConvert(definingOp) && + canFoldIntoConversion(definingOp, encoding)) + continue; + if (stopPropagation && stopPropagation(definingOp)) + continue; + if (isa(definingOp)) + return failure(); + for (Value operand : definingOp->getOperands()) { + auto srcEncoding = inferSrcEncoding(definingOp, encoding); + if (!srcEncoding) + return failure(); + enqueue(operand, *srcEncoding); + } + continue; + } + auto blockArg = cast(currentValue); + Block *block = blockArg.getOwner(); + Operation *parentOp = block->getParentOp(); + if (auto forOp = dyn_cast(parentOp)) { + OpOperand *initOperand = forOp.getTiedLoopInit(blockArg); + Value yieldOperand = forOp.getBody()->getTerminator()->getOperand( + blockArg.getArgNumber() - forOp.getNumInductionVars()); + enqueue(initOperand->get(), encoding); + enqueue(yieldOperand, encoding); + continue; + } + // TODO: add support for WhileOp and other region types. + return failure(); + } + return success(); +} + +// TODO(thomas): this is duplicated with what is in GPUToLLVM +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape, + ArrayRef order) { + unsigned rank = shape.size(); + assert(rank == order.size()); + auto reordered = triton::applyPermutation(shape, order); + auto reorderedMultiDim = delinearize(b, loc, linear, reordered); + SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape) { + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + if (rank == 1) { + multiDim[0] = linear; + } else { + Value remained = linear; + for (auto &&en : llvm::enumerate(shape.drop_back())) { + auto dimSize = b.create(loc, en.value(), 32); + multiDim[en.index()] = b.create(loc, remained, dimSize); + remained = b.create(loc, remained, dimSize); + } + multiDim[rank - 1] = remained; + } + return multiDim; +} + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order) { + return linearize(b, loc, triton::applyPermutation(multiDim, order), + triton::applyPermutation(shape, order)); +} + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape) { + auto rank = multiDim.size(); + Value linear = b.create(loc, 0, 32); + if (rank > 0) { + linear = multiDim.back(); + for (auto [dim, dimShape] : + llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { + Value dimSize = b.create(loc, dimShape, 32); + linear = b.create( + loc, b.create(loc, linear, dimSize), dim); + } + } + return linear; +} + +bool isPureUnaryInlineAsm(Operation *op) { + auto inlineAsmOp = dyn_cast(op); + if (!inlineAsmOp) + return false; + return op->getNumOperands() == 1 && op->getNumResults() == 1 && + inlineAsmOp.getPure(); +} + +int getNVIDIAComputeCapability(Operation *module) { + assert(module->hasAttr(triton::AttrTargetName) && + "Expected a target attribute on the module operation"); + + StringAttr targetAttr = + cast(module->getAttr(triton::AttrTargetName)); + + StringRef ref = targetAttr.strref(); + assert(ref.starts_with("cuda:") && + "expected target attribute to be prefixed with \"cuda:\""); + + StringRef capabilityStr = ref.drop_front(5); // drop the "cuda:" + int computeCapability; + bool parseError = capabilityStr.getAsInteger(10, computeCapability); + assert(!parseError && + "invalid compute capability string in target attribute"); + + return computeCapability; +} + +namespace { + +/// Detect dead arguments in scf.for op by assuming all the values are dead and +/// propagate liveness property. +struct ForOpDeadArgElimination : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const final { + Block &block = *forOp.getBody(); + auto yieldOp = cast(block.getTerminator()); + // Assume that nothing is live at the beginning and mark values as live + // based on uses. + DenseSet aliveValues; + SmallVector queue; + // Helper to mark values as live and add them to the queue of value to + // propagate if it is the first time we detect the value as live. + auto markLive = [&](Value val) { + if (!forOp->isAncestor(val.getParentRegion()->getParentOp())) + return; + if (aliveValues.insert(val).second) + queue.push_back(val); + }; + // Mark all yield operands as live if the associated forOp result has any + // use. + for (auto result : llvm::enumerate(forOp.getResults())) { + if (!result.value().use_empty()) + markLive(yieldOp.getOperand(result.index())); + } + if (aliveValues.size() == forOp.getNumResults()) + return failure(); + // Operations with side-effects are always live. Mark all theirs operands as + // live. + block.walk([&](Operation *op) { + if (!isa(op) && !wouldOpBeTriviallyDead(op)) { + for (Value operand : op->getOperands()) + markLive(operand); + } + }); + // Propagate live property until reaching a fixed point. + while (!queue.empty()) { + Value value = queue.pop_back_val(); + if (auto nestedFor = value.getDefiningOp()) { + auto result = mlir::cast(value); + OpOperand &forOperand = *nestedFor.getTiedLoopInit(result); + markLive(forOperand.get()); + auto nestedYieldOp = + cast(nestedFor.getBody()->getTerminator()); + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + continue; + } + if (auto nestedIf = value.getDefiningOp()) { + auto result = mlir::cast(value); + for (scf::YieldOp nestedYieldOp : + {nestedIf.thenYield(), nestedIf.elseYield()}) { + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + } + continue; + } + if (Operation *def = value.getDefiningOp()) { + // TODO: support while ops. + if (isa(def)) + return failure(); + for (Value operand : def->getOperands()) + markLive(operand); + continue; + } + // If an argument block is live then the associated yield operand and + // forOp operand are live. + auto arg = mlir::cast(value); + if (auto forOwner = dyn_cast(arg.getOwner()->getParentOp())) { + if (arg.getArgNumber() < forOwner.getNumInductionVars()) + continue; + unsigned iterIdx = arg.getArgNumber() - forOwner.getNumInductionVars(); + Value yieldOperand = + forOwner.getBody()->getTerminator()->getOperand(iterIdx); + markLive(yieldOperand); + markLive(forOwner.getInitArgs()[iterIdx]); + } + } + SmallVector deadArg; + for (auto yieldOperand : llvm::enumerate(yieldOp->getOperands())) { + if (aliveValues.contains(yieldOperand.value())) + continue; + if (yieldOperand.value() == block.getArgument(yieldOperand.index() + 1)) + continue; + + // The yield operand might live outside the loop, e.g. + // %init = ... + // %x = ... + // %y = for iter_args(%unused = %init) { + // yield %x + // } + // + // In this case, the loop returns %x if it runs 1 or more times, and + // otherwise it returns %init. We cowardly refuse to remove this operand + // from the yield. (We could, but we'd need to prove that the loop runs 0 + // or >=1 times.) + // + // As a special case, if it doesn't matter whether the loop runs 0 or >=1 + // times (because the loop returns the same value in both cases) then we + // can still mark the operand as dead. This occurs in the above example + // when %init is the same as %x. + if (!forOp->isAncestor( + yieldOperand.value().getParentRegion()->getParentOp()) && + yieldOperand.value() != forOp.getInitArgs()[yieldOperand.index()]) + continue; + + deadArg.push_back(yieldOperand.index()); + } + if (deadArg.empty()) + return failure(); + rewriter.modifyOpInPlace(forOp, [&]() { + // For simplicity we just change the dead yield operand to use the + // associated argument and leave the operations and argument removal to + // dead code elimination. + for (unsigned deadArgIdx : deadArg) { + BlockArgument arg = block.getArgument(deadArgIdx + 1); + yieldOp.setOperand(deadArgIdx, arg); + } + }); + return success(); + } +}; + +} // namespace + +void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +} // namespace mlir diff --git a/third_party/iluvatar/lib/Target/CMakeLists.txt b/third_party/iluvatar/lib/Target/CMakeLists.txt new file mode 100644 index 000000000..39d31dc9b --- /dev/null +++ b/third_party/iluvatar/lib/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/third_party/iluvatar/lib/Target/LLVMIR/CMakeLists.txt b/third_party/iluvatar/lib/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 000000000..f2f9adf8f --- /dev/null +++ b/third_party/iluvatar/lib/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1,28 @@ +add_triton_library(TritonLLVMIR + LLVMDIScope.cpp + LLVMIRBreakPhiStruct.cpp + + DEPENDS + LLVMIRIncGen + + LINK_LIBS + ${CMAKE_DL_LIBS} + PUBLIC + MLIRArithToLLVM + MLIRBuiltinToLLVMIRTranslation + MLIRIndexToLLVM + MLIRIR + MLIRLLVMDialect + MLIRLLVMToLLVMIRTranslation + MLIRNVVMToLLVMIRTranslation + MLIRROCDLToLLVMIRTranslation + MLIRSCFToControlFlow + MLIRSupport + MLIRTargetLLVMIRExport + TritonGPUToLLVM + ) + +set_source_files_properties( + LLVMIRTranslation.cpp + PROPERTIES + COMPILE_FLAGS "-D__BUILD_DIR__=\\\"${CMAKE_BINARY_DIR}\\\"") diff --git a/third_party/iluvatar/lib/Target/LLVMIR/LLVMDIScope.cpp b/third_party/iluvatar/lib/Target/LLVMIR/LLVMDIScope.cpp new file mode 100644 index 000000000..bcd56b684 --- /dev/null +++ b/third_party/iluvatar/lib/Target/LLVMIR/LLVMDIScope.cpp @@ -0,0 +1,161 @@ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "triton/Target/LLVMIR/Passes.h" +#include "llvm/BinaryFormat/Dwarf.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Path.h" + +//===----------------------------------------------------------------------===// +// This file implements a pass to add debug info scope to LLVM operations, and +// is inspired by the DIScopeForLLVMFuncOpPass in LLVM/MLIR. Different from the +// DIScopeForLLVMFuncOpPass, this pass also handles inlined functions. +//===----------------------------------------------------------------------===// + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton/Target/LLVMIR/Passes.h.inc" + +namespace { + +/// Attempt to extract a filename for the given loc. +FileLineColLoc extractFileLoc(Location loc) { + if (auto fileLoc = dyn_cast(loc)) + return fileLoc; + if (auto nameLoc = dyn_cast(loc)) + return extractFileLoc(nameLoc.getChildLoc()); + if (auto opaqueLoc = dyn_cast(loc)) + return extractFileLoc(opaqueLoc.getFallbackLocation()); + if (auto fusedLoc = dyn_cast(loc)) + return extractFileLoc(fusedLoc.getLocations().front()); + if (auto callerLoc = dyn_cast(loc)) + return extractFileLoc(callerLoc.getCaller()); + StringAttr unknownFile = mlir::StringAttr::get(loc.getContext(), ""); + return mlir::FileLineColLoc::get(unknownFile, 0, 0); +} + +/// Add a debug info scope to LLVMFuncOp that are missing it. +struct LLVMDIScopePass : public LLVMDIScopeBase { + LLVMDIScopePass() = default; + + void setSubprogramAttr(LLVM::LLVMFuncOp funcOp) { + Location loc = funcOp.getLoc(); + if (loc->findInstanceOf>()) + return; + + MLIRContext *context = &getContext(); + + // To find a DICompileUnitAttr attached to a parent (the module for + // example), otherwise create a default one. + LLVM::DICompileUnitAttr compileUnitAttr; + if (ModuleOp module = funcOp->getParentOfType()) { + auto fusedCompileUnitAttr = + module->getLoc() + ->findInstanceOf>(); + if (fusedCompileUnitAttr) + compileUnitAttr = fusedCompileUnitAttr.getMetadata(); + } + + // Filename, line and colmun to associate to the function. + LLVM::DIFileAttr fileAttr; + int64_t line = 1, col = 1; + FileLineColLoc fileLoc = extractFileLoc(loc); + if (!fileLoc && compileUnitAttr) { + fileAttr = compileUnitAttr.getFile(); + } else if (!fileLoc) { + fileAttr = LLVM::DIFileAttr::get(context, "", ""); + } else { + line = fileLoc.getLine(); + col = fileLoc.getColumn(); + StringRef inputFilePath = fileLoc.getFilename().getValue(); + fileAttr = LLVM::DIFileAttr::get( + context, llvm::sys::path::filename(inputFilePath), + llvm::sys::path::parent_path(inputFilePath)); + } + auto subroutineTypeAttr = + LLVM::DISubroutineTypeAttr::get(context, llvm::dwarf::DW_CC_normal, {}); + + // Figure out debug information (`subprogramFlags` and `compileUnitAttr`) to + // attach to the function definition / declaration. External functions are + // declarations only, and are defined in a different compile unit, so mark + // them appropriately in `subprogramFlags`, and set an empty + // `compileUnitAttr`. + DistinctAttr distinctId; + auto subprogramFlags = LLVM::DISubprogramFlags::Optimized; + if (!funcOp.isExternal()) { + distinctId = mlir::DistinctAttr::create(mlir::UnitAttr::get(context)); + if (!compileUnitAttr) { + compileUnitAttr = LLVM::DICompileUnitAttr::get( + context, distinctId, llvm::dwarf::DW_LANG_C, fileAttr, + StringAttr::get(context, "triton"), + /*isOptimized=*/true, LLVM::DIEmissionKind::LineTablesOnly); + } + subprogramFlags = subprogramFlags | LLVM::DISubprogramFlags::Definition; + } else { + compileUnitAttr = {}; + } + + StringAttr funcNameAttr = funcOp.getNameAttr(); + // Note that scopeline is set differently from LLVM's + // DIScopeForLLVMFuncOpPass. I don't find reasons why scopeline should be + // the column offset + auto subprogramAttr = LLVM::DISubprogramAttr::get( + context, distinctId, compileUnitAttr, fileAttr, funcNameAttr, + funcNameAttr, fileAttr, + /*line=*/line, + /*scopeline=*/line, subprogramFlags, subroutineTypeAttr); + funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr)); + } + + // Get a nested loc for inlined functions + Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr, + Location calleeLoc) { + auto calleeFileName = extractFileLoc(calleeLoc).getFilename(); + auto context = op->getContext(); + LLVM::DIFileAttr calleeFileAttr = LLVM::DIFileAttr::get( + context, llvm::sys::path::filename(calleeFileName), + llvm::sys::path::parent_path(calleeFileName)); + auto lexicalBlockFileAttr = LLVM::DILexicalBlockFileAttr::get( + context, scopeAttr, calleeFileAttr, /*discriminator=*/0); + Location loc = calleeLoc; + if (mlir::isa(calleeLoc)) { + auto nestedLoc = mlir::cast(calleeLoc).getCallee(); + loc = getNestedLoc(op, lexicalBlockFileAttr, nestedLoc); + } + return FusedLoc::get(context, {loc}, lexicalBlockFileAttr); + } + + void setLexicalBlockFileAttr(Operation *op) { + auto opLoc = op->getLoc(); + if (auto callSiteLoc = dyn_cast(opLoc)) { + auto callerLoc = callSiteLoc.getCaller(); + auto calleeLoc = callSiteLoc.getCallee(); + LLVM::DIScopeAttr scopeAttr; + // We assemble the full inline stack so the parent of this loc must be a + // function + auto funcOp = op->getParentOfType(); + auto funcOpLoc = mlir::cast(funcOp.getLoc()); + scopeAttr = mlir::cast(funcOpLoc.getMetadata()); + auto loc = + CallSiteLoc::get(getNestedLoc(op, scopeAttr, calleeLoc), callerLoc); + op->setLoc(loc); + } + } + + void runOnOperation() override { + getOperation()->walk([&](Operation *op) -> void { + if (isa(op)) + setSubprogramAttr(cast(op)); + else + setLexicalBlockFileAttr(op); + }); + } +}; + +} // end anonymous namespace + +std::unique_ptr mlir::createLLVMDIScopePass() { + return std::make_unique(); +} diff --git a/third_party/iluvatar/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp b/third_party/iluvatar/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp new file mode 100644 index 000000000..44afcfd21 --- /dev/null +++ b/third_party/iluvatar/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +/// Implements a trivial pass breaking up 1 level deep structure in phi nodes. +/// This handles the common case generated by Triton and allow better +/// optimizations down the compiler pipeline. +//===----------------------------------------------------------------------===// +#include "LLVMPasses.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +static bool processPhiStruct(PHINode *phiNode) { + StructType *STy = dyn_cast(phiNode->getType()); + if (!STy) + return false; + IRBuilder<> builder(phiNode); + unsigned numOperands = phiNode->getNumIncomingValues(); + unsigned numScalarEl = STy->getNumElements(); + Value *newStruct = UndefValue::get(STy); + builder.SetInsertPoint(phiNode->getParent()->getFirstNonPHI()); + llvm::IRBuilderBase::InsertPoint insertInsertPt = builder.saveIP(); + for (unsigned i = 0; i < numScalarEl; i++) { + builder.SetInsertPoint(phiNode); + PHINode *newPhiNode = + builder.CreatePHI(STy->getElementType(i), numOperands); + for (unsigned j = 0; j < numOperands; ++j) { + Value *operand = phiNode->getIncomingValue(j); + builder.SetInsertPoint(phiNode->getIncomingBlock(j)->getTerminator()); + newPhiNode->addIncoming(builder.CreateExtractValue(operand, i), + phiNode->getIncomingBlock(j)); + } + builder.restoreIP(insertInsertPt); + newStruct = builder.CreateInsertValue(newStruct, newPhiNode, i); + insertInsertPt = builder.saveIP(); + } + phiNode->replaceAllUsesWith(newStruct); + return true; +} + +static bool runOnFunction(Function &F) { + bool Changed = false; + SmallVector PhiNodes; + for (BasicBlock &BB : F) { + for (Instruction &inst : BB) { + if (PHINode *phiNode = dyn_cast(&inst)) { + Changed |= processPhiStruct(phiNode); + continue; + } + break; + } + } + return Changed; +} + +PreservedAnalyses BreakStructPhiNodesPass::run(Function &F, + FunctionAnalysisManager &AM) { + + bool b = runOnFunction(F); + return b ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/third_party/iluvatar/lib/Target/LLVMIR/LLVMPasses.h b/third_party/iluvatar/lib/Target/LLVMIR/LLVMPasses.h new file mode 100644 index 000000000..1dcdb2992 --- /dev/null +++ b/third_party/iluvatar/lib/Target/LLVMIR/LLVMPasses.h @@ -0,0 +1,16 @@ +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/CodeGen.h" + +namespace llvm { + +// Pass to pre-process LLVM IR before optimization and break up phi of struct. +// Breaking up those phis into elementary types allows better optimizations +// downstream. +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; + +} // namespace llvm diff --git a/third_party/iluvatar/lib/Tools/CMakeLists.txt b/third_party/iluvatar/lib/Tools/CMakeLists.txt new file mode 100644 index 000000000..4b021da33 --- /dev/null +++ b/third_party/iluvatar/lib/Tools/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(TritonTools + LinearLayout.cpp + + DEPENDS + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMDialect + f2reduce +) diff --git a/third_party/iluvatar/lib/Tools/LinearLayout.cpp b/third_party/iluvatar/lib/Tools/LinearLayout.cpp new file mode 100644 index 000000000..75e530db5 --- /dev/null +++ b/third_party/iluvatar/lib/Tools/LinearLayout.cpp @@ -0,0 +1,427 @@ +#include "triton/Tools/LinearLayout.h" + +#include +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "third_party/f2reduce/f2reduce.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir::triton { + +namespace { +using BasesT = LinearLayout::BasesT; +using llvm::Twine; + +BasesT makeBasesMap( + ArrayRef>>> bases) { + BasesT ret; + for (const auto &[inDim, inDimBases] : bases) { + ret[inDim] = inDimBases; + } + return ret; +} + +std::string stringifyBases(const BasesT &bases, + ArrayRef outDimNames) { + std::string ret; + + if (bases.empty()) + return "(empty layout)\n"; + + // TODO: Add spaces for alignment. + for (const auto &[inDim, inDimBases] : bases) { + if (inDimBases.empty()) { + ret += " - " + inDim.str() + " is a size 1 dimension\n"; + continue; + } + + ret += " - " + + join(llvm::seq(inDimBases.size()), "\n ", + [&, &inDim = inDim, &inDimBases = inDimBases](int i) { + return inDim.str() + "=" + std::to_string(1 << i) + " -> (" + + join(inDimBases[i], ", ") + ")"; + }) + + "\n"; + } + ret += "where out dims are: [" + + join(outDimNames, ", ", [](StringAttr s) { return s.str(); }) + "]\n"; + return ret; +} + +BasesT validateBases(BasesT bases, ArrayRef outDimNames) { + if (bases.empty()) + return bases; + + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + if (llvm::any_of(basis, [](int32_t b) { return b < 0; })) { + llvm::report_fatal_error( + "Invalid bases passed to LinearLayout. Expected all basis " + "values to be non-negative, but found a negative value for " + "in dimension '" + + Twine(inDim) + "'. Full list of bases:\n" + + stringifyBases(bases, outDimNames)); + } + } + } + + // Check that the bases all have length equal to outDimNames.size(). + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + if (basis.size() != outDimNames.size()) { + llvm::report_fatal_error( + "Invalid bases passed to LinearLayout. Expect all bases to have " + "the same size, equal to outDimNames.size() (" + + Twine(outDimNames.size()) + + "). But this failed for in dimension '" + Twine(inDim) + + "'. Full list of bases:\n" + stringifyBases(bases, outDimNames)); + } + } + } + + return bases; +} + +// Compute the rank of the matrix formed by taking the bases for the given +// outDim as columns. In other words, finds the number of linearly-independent +// bases for this output dimension. +int getMatrixRank(const LinearLayout &layout, StringAttr outDim) { + // Suppose we have a layout specified by the following key values. + // + // L(0,1) = 0b01 + // L(0,2) = 0b10 + // L(1,0) = 0b10 + // L(2,0) = 0b11 + // + // We will create one column per key value. The max bit width of these values + // is 2, so our matrix will have 2 rows. The final matrix will be + // + // | ↑ ↑ ↑ ↑ | | 0b0111 | + // | L(0,1) L(0,2) L(1,0) L(2,0) | = | 0b1001 | + // | ↓ ↓ ↓ ↓ | + int numRows = layout.getOutDimSizeLog2(outDim); + + int numCols = 0; + for (StringAttr inDim : layout.getInDimNames()) { + numCols += layout.getInDimSizeLog2(inDim); + } + + if (numCols == 0 || numRows == 0) + return 0; + + // Don't handle giant LLs. This makes some things easier; for example, each + // row can be a single uint64_t. + assert(numCols <= 64 && "LinearLayout too large"); + assert(numRows <= 64 && "LinearLayout too large"); + + // Note that `new int[n]()` is zero-initialized, whereas `new int[n]` is not. + std::unique_ptr m(new uint64_t[numRows]()); + + // Fill in the matrix. + int c = 0; + for (StringAttr inDim : layout.getInDimNames()) { + for (int i = 0; i < layout.getInDimSizeLog2(inDim); i++) { + uint64_t basis = layout.getBasis(inDim, i, outDim); + for (int j = 0; j < numRows; j++) { + m[j] |= ((basis >> j) & 1) << c; + } + c++; + } + } + + // stride is specified in number of 64-bit words per row. + f2reduce::inplace_rref_strided(m.get(), numRows, numCols, /*stride=*/1); + + // The rank of the reduced matrix is simply the number of nonzero rows. + int rank = 0; + for (int i = 0; i < numRows; i++) { + if (m[i] != 0) + rank++; + } + return rank; +} + +// Check that the given layout is surjective, i.e. that every `out` coordinate +// can be reached by some `in` coordinate. +// +// It's sufficient to check each output dimension indepedently. Still, +// it's prohibitively slow to calculate this naively. +// +// Thankfully, this is equivalent to checking that the number of +// linearly-independent bases for outDim d is equal to getOutDimSizeLog2(d). +// This can be computed by finding the rank of the matrix whose columns are +// those bases. We can compute the rank of our matrix using Gaussian +// elimination, which runs in O(n^3) for an n x n matrix. Our matrix size is +// log(product(inDimSize)) x log(outDimSize), and we do this numOutDims times, +// so this should be plenty fast overall. +void validateSurjectivity(const LinearLayout &layout) { + for (const auto &outDim : layout.getOutDimNames()) { + unsigned rank = getMatrixRank(layout, outDim); + unsigned expectedRank = layout.getOutDimSizeLog2(outDim); + if (rank != expectedRank) { + llvm::report_fatal_error( + "Invalid bases passed to LinearLayout. Expected bases to be " + "surjective, i.e. all possible output coordinates can be reached " + "by some input coordinates. But this failed for output dimension " + + Twine(outDim) + ", where we got rank " + Twine(rank) + + " instead of expected rank " + Twine(expectedRank) + + ". Full list of bases:\n" + + Twine(stringifyBases(layout.getBases(), layout.getOutDimNames()))); + } + } +} + +template +void assertDimsEqualIgnoringOrder(T &&a, U &&b) { + llvm::DenseSet as(a.begin(), a.end()); + llvm::DenseSet bs(b.begin(), b.end()); + if (as != bs) { + llvm::report_fatal_error("Dimensions must match, ignoring order, but they " + "don't. Got dims: [" + + Twine(triton::join(a, ", ")) + "] and [" + + triton::join(b, ", ") + "]"); + } +} + +} // anonymous namespace + +LinearLayout::LinearLayout(BasesT bases, ArrayRef outDimNames) + : bases(validateBases(std::move(bases), outDimNames)), + outDimNames(outDimNames.begin(), outDimNames.end()) { + validateSurjectivity(*this); +} + +LinearLayout::LinearLayout( + ArrayRef>>> bases, + ArrayRef outDimNames) + : LinearLayout(makeBasesMap(bases), outDimNames) {} + +/*static*/ LinearLayout LinearLayout::identity1D(int32_t size, + StringAttr inDimName, + StringAttr outDimName) { + if (size == 0) + return LinearLayout::empty(); + + assert(llvm::isPowerOf2_32(size)); + std::vector> powersOf2; + for (int32_t i = 1; i < size; i *= 2) { + powersOf2.emplace_back().push_back(i); + } + return LinearLayout({{inDimName, std::move(powersOf2)}}, {outDimName}); +} + +/*static*/ LinearLayout LinearLayout::zeros1D(int32_t size, + StringAttr inDimName, + StringAttr outDimName) { + if (size == 0) + return LinearLayout::empty(); + + assert(llvm::isPowerOf2_32(size)); + std::vector> zeros; + for (int i = 0; i < llvm::Log2_32(size); i++) { + zeros.emplace_back().push_back(0); + } + return LinearLayout({{inDimName, zeros}}, {outDimName}); +} + +int32_t LinearLayout::getOutDimIndex(StringAttr outDim) const { + // Sadly SetVector doesn't provide an O(1) way to do this. + for (int i = 0; i < outDimNames.size(); ++i) { + if (outDimNames[i] == outDim) { + return i; + } + } + llvm::report_fatal_error("outDim " + Twine(outDim) + " is not in layout\n" + + toString()); +} + +int32_t LinearLayout::getInDimSizeLog2(StringAttr inDim) const { + auto it = bases.find(inDim); + assert(it != bases.end()); + return it->second.size(); +} + +int32_t LinearLayout::getOutDimSizeLog2(StringAttr outDim) const { + // TODO(jlebar): Cache this? + int32_t outDimIdx = getOutDimIndex(outDim); + int32_t max = 0; + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + max = std::max(max, basis[outDimIdx]); + } + } + return max == 0 ? 0 : llvm::Log2_32(max) + 1; +} + +LinearLayout LinearLayout::transposeIns(ArrayRef newInDims) const { + assertDimsEqualIgnoringOrder(newInDims, getInDimNames()); + + BasesT newBases; + for (const auto &inDim : newInDims) { + newBases[inDim] = bases.find(inDim)->second; + } + return LinearLayout(std::move(newBases), outDimNames.getArrayRef()); +} + +LinearLayout +LinearLayout::transposeOuts(ArrayRef newOutDims) const { + assertDimsEqualIgnoringOrder(newOutDims, getOutDimNames()); + + std::vector permutation; + for (const auto &outDim : newOutDims) { + permutation.push_back(getOutDimIndex(outDim)); + } + + BasesT newBases; + for (const auto &[inDim, inDimBases] : bases) { + auto &newInDimBases = newBases[inDim]; + for (const auto &basis : inDimBases) { + std::vector newBasis; + for (int32_t i : permutation) { + newBasis.push_back(basis[i]); + } + newInDimBases.push_back(std::move(newBasis)); + } + } + return LinearLayout(std::move(newBases), newOutDims); +} + +LinearLayout operator*(LinearLayout inner, LinearLayout outer) { + // Check that elements common to both outerDimsRange and innerDimsRange appear + // in the same relative order. + auto checkCommonDims = [&](auto outerDimsRange, auto innerDimsRange) { + llvm::DenseSet outerDims(outerDimsRange.begin(), + outerDimsRange.end()); + llvm::DenseSet innerDims(innerDimsRange.begin(), + innerDimsRange.end()); + + std::vector outerCommonDims; + for (StringAttr dim : outerDimsRange) { + if (innerDims.contains(dim)) { + outerCommonDims.push_back(dim); + } + } + + std::vector innerCommonDims; + for (StringAttr dim : innerDimsRange) { + if (outerDims.contains(dim)) { + innerCommonDims.push_back(dim); + } + } + + if (outerCommonDims != innerCommonDims) { + llvm::report_fatal_error( + "Cannot multiply layouts. All in/out dimensions common to both " + "layouts must appear in the same relative order, but they " + "don't.\nOuter:\n" + + Twine(outer.toString()) + "\nInner:\n" + inner.toString()); + } + }; + + // Check that dims common to outer and inner have the same relative order. + checkCommonDims(outer.getInDimNames(), inner.getInDimNames()); + checkCommonDims(outer.getOutDimNames(), inner.getOutDimNames()); + + // Get the sizeLog2 of all input and output dimensions we're going to + // consider, in order. `inner` is more minor, so its dimensions come first. + llvm::MapVector inDimSizes; + llvm::SetVector outDimNames; + for (const auto &layout : {inner, outer}) { + for (StringAttr inDim : layout.getInDimNames()) { + inDimSizes[inDim] += layout.getInDimSizeLog2(inDim); + } + for (StringAttr outDim : layout.getOutDimNames()) { + outDimNames.insert(outDim); + } + } + BasesT allBases; + for (auto [inDimName, inDimSize] : inDimSizes) { + std::vector> &inDimBases = allBases[inDimName]; + + // Fill with zeros. + inDimBases = std::vector>( + inDimSize, std::vector(outDimNames.size(), 0)); + + for (auto [outDimIdx, outDimName] : llvm::enumerate(outDimNames)) { + if (inner.hasInDim(inDimName) && inner.hasOutDim(outDimName)) { + for (int i = 0; i < inner.getInDimSizeLog2(inDimName); i++) { + inDimBases[i][outDimIdx] = inner.getBasis(inDimName, i, outDimName); + } + } + if (outer.hasInDim(inDimName) && outer.hasOutDim(outDimName)) { + int offset = + inner.hasInDim(inDimName) ? inner.getInDimSizeLog2(inDimName) : 0; + int shift = inner.hasOutDim(outDimName) + ? inner.getOutDimSizeLog2(outDimName) + : 0; + for (int i = 0; i < outer.getInDimSizeLog2(inDimName); i++) { + inDimBases[offset + i][outDimIdx] = + outer.getBasis(inDimName, i, outDimName) << shift; + } + } + } + } + + return LinearLayout(std::move(allBases), outDimNames.getArrayRef()); +} + +SmallVector> +LinearLayout::apply(ArrayRef> ins) const { + assertDimsEqualIgnoringOrder(llvm::make_first_range(ins), getInDimNames()); + + SmallVector> ret; + for (StringAttr outDim : getOutDimNames()) { + int32_t outVal = 0; + for (auto &[inDim, val] : ins) { + for (int i = 0; i < getInDimSizeLog2(inDim); i++) { + if (val & (1 << i)) + outVal ^= getBasis(inDim, i, outDim); + } + } + ret.push_back({outDim, outVal}); + } + return ret; +} + +LinearLayout LinearLayout::compose(const LinearLayout &outer) const { + assertDimsEqualIgnoringOrder(getOutDimNames(), outer.getInDimNames()); + + BasesT newBases; + for (const auto &[inDim, inDimBases] : bases) { + auto &newInDimBases = newBases[inDim]; + for (const auto &basis : inDimBases) { + SmallVector> bases; + for (auto [outDim, b] : llvm::zip(getOutDimNames(), basis)) { + bases.push_back({outDim, b}); + } + auto newBases = outer.apply(bases); + auto newBasesRange = llvm::make_second_range(newBases); + newInDimBases.push_back( + std::vector(newBasesRange.begin(), newBasesRange.end())); + } + } + return LinearLayout(std::move(newBases), outer.getOutDimNames()); +} + +bool operator==(LinearLayout lhs, LinearLayout rhs) { + // llvm::MapVector doesn't have an operator== :(. + if (lhs.getOutDimNames() != rhs.getOutDimNames()) + return false; + if (lhs.bases.size() != rhs.bases.size()) + return false; + for (auto it1 = lhs.bases.begin(), it2 = rhs.bases.begin(); + it1 != lhs.bases.end(); ++it1, ++it2) { + if (*it1 != *it2) + return false; + } + return true; +} + +std::string LinearLayout::toString() const { + return stringifyBases(bases, getOutDimNames()); +} + +} // namespace mlir::triton diff --git a/third_party/iluvatar/python/src/interpreter.cc b/third_party/iluvatar/python/src/interpreter.cc new file mode 100644 index 000000000..6ab7c6c75 --- /dev/null +++ b/third_party/iluvatar/python/src/interpreter.cc @@ -0,0 +1,435 @@ +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +namespace { + +enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED }; + +enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX }; + +std::map mem_semantic_map = { + {MemSemantic::ACQUIRE_RELEASE, __ATOMIC_ACQ_REL}, + {MemSemantic::ACQUIRE, __ATOMIC_ACQUIRE}, + {MemSemantic::RELEASE, __ATOMIC_RELEASE}, + {MemSemantic::RELAXED, __ATOMIC_RELAXED}, +}; + +// Use compiler builtin atomics instead of std::atomic which requires +// each variable to be declared as atomic. +// Currently work for clang and gcc. +template T atomic_cmp(T *ptr, T val, int order) { + auto cmp = [](T old, T val) { + if constexpr (is_min) { + return old > val; + } else { + return old < val; + } + }; + // First load + T old_val = __atomic_load_n(ptr, order); + while (cmp(old_val, val)) { + if (__atomic_compare_exchange(ptr, &old_val, &val, false, order, order)) { + break; + } + } + return old_val; +} + +template T atomic_fadd(T *ptr, T val, int order) { + T old_val; + T new_val; + // First load + // Load ptr as if uint32_t or uint64_t and then memcpy to T + if constexpr (sizeof(T) == 4) { + uint32_t tmp = __atomic_load_n(reinterpret_cast(ptr), order); + std::memcpy(&old_val, &tmp, sizeof(T)); + } else if constexpr (sizeof(T) == 8) { + uint64_t tmp = __atomic_load_n(reinterpret_cast(ptr), order); + std::memcpy(&old_val, &tmp, sizeof(T)); + } else { + throw std::invalid_argument("Unsupported data type"); + } + while (true) { + new_val = old_val + val; + if (__atomic_compare_exchange(ptr, &old_val, &new_val, false, order, + order)) { + break; + } + } + return old_val; +} + +class AtomicOp { +public: + AtomicOp(const uint64_t *ptr, size_t numel, int order) + : ptr(ptr), numel(numel), order(order) {} + + void apply() { + for (size_t i = 0; i < numel; ++i) { + applyAt(reinterpret_cast(ptr[i]), i); + } + } + + virtual ~AtomicOp() = default; + +protected: + virtual void applyAt(void *, size_t i) = 0; + + const uint64_t *ptr; + size_t numel; + int order; +}; + +template class AtomicRMWOpBase : public AtomicOp { +public: + AtomicRMWOpBase(const uint64_t *ptr, const void *val, void *ret, + const bool *mask, size_t numel, int order) + : AtomicOp(ptr, numel, order), val(val), ret(ret), mask(mask) {} + +protected: + void applyAt(void *loc, size_t i) override final { + if (mask[i]) { + *(static_cast(ret) + i) = + applyAtMasked(static_cast(loc), + *(static_cast(val) + i), order); + } + } + + virtual DType applyAtMasked(DType *loc, const DType value, int order) = 0; + + const void *val; + void *ret; + const bool *mask; +}; + +template +class AtomicRMWOp : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_add(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return atomic_fadd(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_and(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_or(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_xor(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return atomic_cmp(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return atomic_cmp(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_exchange_n(loc, value, order); + } +}; + +class AtomicCASOp : public AtomicOp { +public: + AtomicCASOp(const uint64_t *ptr, void *expected, const void *desired, + size_t itemsize, size_t numel, int order) + : AtomicOp(ptr, numel, order), expected(expected), desired(desired), + itemsize(itemsize) {} + +protected: + void applyAt(void *loc, size_t i) override { + // Atomic operations perform bitwise comparison, so it's safe to + // use number of bytes (itemsize) to determine the type of pointers + if (itemsize == 1) { + uint8_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else if (itemsize == 2) { + uint16_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else if (itemsize == 4) { + uint32_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else if (itemsize == 8) { + uint64_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else { + // The ‘__atomic’ builtins can be used with any integral scalar or pointer + // type that is 1, 2, 4, or 8 bytes in length. 16-byte integral types are + // also allowed if ‘__int128’ (see 128-bit Integers) is supported by the + // architecture. + // https://gcc.gnu.org/onlinedocs/gcc/_005f_005fatomic-Builtins.html + throw std::invalid_argument("Invalid byte size"); + } + } + +private: + void *expected; + const void *desired; + size_t itemsize; +}; + +// This is a workaround because explicit template parameter list for lambdas is +// a C++20 extension: +// auto try_make_op = [&]() { +// if (dtype.is(pybind11::dtype::of())) { +// atomic_op = std::make_unique>(ptr, val, ret, mask, +// numel, order); +// } +// }; +template struct OpCreator { + pybind11::dtype dtype; + const uint64_t *ptr; + const void *val; + void *ret; + const bool *mask; + size_t numel; + int order; + std::unique_ptr &atomic_op; + + template void create() { + if (!atomic_op && dtype.is(pybind11::dtype::of())) { + atomic_op = std::make_unique>(ptr, val, ret, mask, + numel, order); + } + } +}; + +template +std::unique_ptr +makeAtomicRMWOp(pybind11::dtype dtype, const uint64_t *ptr, const void *val, + void *ret, const bool *mask, size_t numel, int order) { + // Iterate over all supported data types, make one that matches, and return + std::unique_ptr atomic_op; + OpCreator try_make_op{dtype, ptr, val, ret, + mask, numel, order, atomic_op}; + + (try_make_op.template create(), ...); + if (!atomic_op) { + throw std::invalid_argument("Unsupported data type"); + } + // Make it a unique_ptr + return atomic_op; +} + +} // namespace + +void init_triton_interpreter(py::module &&m) { + using ret = py::return_value_policy; + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "RMW_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX) + .export_values(); + + m.def("load", + [](py::array_t ptr, py::array_t mask, py::array other, + py::dtype ret_dtype) -> py::array { + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_others = other.reshape({numel}); + for (size_t i = 0; i < ptr.size(); ++i) { + if (reshaped_mask.at(i)) + memcpy(ret.mutable_data(i), + reinterpret_cast(reshaped_ptr.at(i)), + ret_dtype.itemsize()); + else + memcpy(ret.mutable_data(i), reshaped_others.data(i), + ret_dtype.itemsize()); + } + return ret.reshape(shape); + }); + + m.def("store", + [](py::array_t ptr, py::array value, py::array_t mask) { + int numel = ptr.size(); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_value = value.reshape({numel}); + for (size_t i = 0; i < ptr.size(); ++i) { + if (reshaped_mask.at(i)) { + memcpy(reinterpret_cast(reshaped_ptr.mutable_at(i)), + reshaped_value.data(i), value.dtype().itemsize()); + } + } + }); + + m.def("atomic_rmw", + [](RMWOp rmw_op, py::array_t ptr, py::array val, + py::array_t mask, MemSemantic sem) -> py::array { + int order = mem_semantic_map[sem]; + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + auto ret_dtype = val.dtype(); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_val = val.reshape({numel}); + auto *ptr_data = reshaped_ptr.data(); + auto *mask_data = reshaped_mask.data(); + auto *val_data = static_cast(reshaped_val.data()); + auto *ret_data = static_cast(ret.mutable_data()); + + std::unique_ptr atomic_op; + +#define MAKE_ATOMIC_RMW_OP(OP_NAME, ...) \ + case OP_NAME: \ + atomic_op = makeAtomicRMWOp( \ + ret_dtype, ptr_data, val_data, ret_data, mask_data, numel, order); \ + break; + + switch (rmw_op) { + MAKE_ATOMIC_RMW_OP(RMWOp::ADD, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::FADD, float, double) + MAKE_ATOMIC_RMW_OP(RMWOp::AND, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::OR, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::XOR, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::MAX, int32_t, int64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::UMAX, uint32_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::MIN, int32_t, int64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::UMIN, uint32_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::XCHG, int32_t, uint32_t, int64_t, + uint64_t) + default: + throw std::invalid_argument("Unsupported RMW operation"); + } + +#undef MAKE_ATOMIC_RMW_OP + + atomic_op->apply(); + return ret.reshape(shape); + }); + + m.def("atomic_cas", + [](py::array_t ptr, py::array &cmp, py::array &val, + MemSemantic sem) -> py::array { + int order = mem_semantic_map[sem]; + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + auto ret_dtype = cmp.dtype(); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array reshaped_cmp = cmp.reshape({numel}); + py::array reshaped_val = val.reshape({numel}); + auto itemsize = cmp.itemsize(); + memcpy(static_cast(ret.mutable_data()), + static_cast(reshaped_cmp.data()), + itemsize * numel); + AtomicCASOp(reshaped_ptr.data(), ret.mutable_data(), + static_cast(reshaped_val.data()), itemsize, + numel, order) + .apply(); + return ret.reshape(shape); + }); +} diff --git a/third_party/iluvatar/python/src/ir.cc b/third_party/iluvatar/python/src/ir.cc new file mode 100644 index 000000000..cbe0c81b6 --- /dev/null +++ b/third_party/iluvatar/python/src/ir.cc @@ -0,0 +1,1648 @@ +#include +#include +#include + +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Transforms/LocationSnapshot.h" +#include "mlir/Transforms/Passes.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace { + +namespace py = pybind11; +using namespace mlir; +using namespace triton; + +// A custom op builder that keeps track of the last location +class TritonOpBuilder { +public: + TritonOpBuilder(MLIRContext *context) { + builder = std::make_unique(context); + lastLoc = std::make_unique(builder->getUnknownLoc()); + } + + OpBuilder &getBuilder() { return *builder; } + + bool isLineInfoEnabled() { return lineInfoEnabled; } + + void setLastLoc(Location loc) { + if (lineInfoEnabled) + lastLoc = std::make_unique(loc); + } + + void setLastLoc(const std::string &fileName, int line, int column) { + auto context = builder->getContext(); + setLastLoc(FileLineColLoc::get(context, fileName, line, column)); + } + + Location getLastLoc() { + assert(lastLoc); + return *lastLoc; + } + + void setInsertionPointToStart(Block &block) { + if (!block.empty()) + setLastLoc(block.begin()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToStart(&block); + } + + void setInsertionPointToEnd(Block &block) { + if (!block.empty()) + setLastLoc(block.back().getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToEnd(&block); + } + + void setInsertionPointAfter(Operation &op) { + setLastLoc(op.getLoc()); + builder->setInsertionPointAfter(&op); + } + + void restoreInsertionPoint(OpBuilder::InsertPoint pt) { + if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) + setLastLoc(pt.getPoint()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->restoreInsertionPoint(pt); + } + + template OpTy create(Args &&...args) { + auto loc = getLastLoc(); + return builder->create(loc, std::forward(args)...); + } + + // Overload to create or fold a single result operation. + template + std::enable_if_t(), Value> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + + // Overload to create or fold a zero result operation. + template + std::enable_if_t(), OpTy> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + +private: + std::unique_ptr builder; + std::unique_ptr lastLoc; + bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); +}; + +std::string locationToString(Location loc) { + std::string str; + llvm::raw_string_ostream os(str); + loc.print(os); + os.flush(); // Make sure all the content is dumped into the 'str' string + return str; +} + +void outputWarning(Location loc, const std::string &msg) { + std::string locStr = locationToString(loc); + + PyErr_WarnEx(PyExc_UserWarning, (locStr + ": " + msg).c_str(), + /*stack_level=*/2); +} + +} // anonymous namespace + +/*****************************************************************************/ +/* Python bindings for ir */ +/*****************************************************************************/ + +void init_triton_ir(py::module &&m) { + using ret = py::return_value_policy; + using namespace pybind11::literals; + + py::enum_(m, "PADDING_OPTION", py::module_local()) + .value("PAD_ZERO", PaddingOption::PAD_ZERO) + .value("PAD_NAN", PaddingOption::PAD_NAN) + .export_values(); + + py::enum_(m, "CACHE_MODIFIER", py::module_local()) + .value("NONE", CacheModifier::NONE) + .value("CA", CacheModifier::CA) + .value("CG", CacheModifier::CG) + .value("WB", CacheModifier::WB) + .value("CS", CacheModifier::CS) + .value("WT", CacheModifier::WT) + .export_values(); + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "MEM_SYNC_SCOPE", py::module_local()) + .value("GPU", MemSyncScope::GPU) + .value("CTA", MemSyncScope::CTA) + .value("SYSTEM", MemSyncScope::SYSTEM) + .export_values(); + + py::enum_(m, "EVICTION_POLICY", py::module_local()) + .value("NORMAL", EvictionPolicy::NORMAL) + .value("EVICT_FIRST", EvictionPolicy::EVICT_FIRST) + .value("EVICT_LAST", EvictionPolicy::EVICT_LAST) + .export_values(); + + py::enum_(m, "ATOMIC_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX); + + py::enum_(m, "ROUNDING_MODE", py::module_local()) + .value("RTZ", RoundingMode::RTZ) + .value("RTNE", RoundingMode::RTNE); + + py::enum_(m, "PROPAGATE_NAN", py::module_local()) + .value("NONE", PropagateNan::NONE) + .value("ALL", PropagateNan::ALL); + + py::enum_(m, "INPUT_PRECISION", py::module_local()) + .value("TF32", InputPrecision::TF32) + .value("TF32x3", InputPrecision::TF32x3) + .value("IEEE", InputPrecision::IEEE) + .export_values(); + + py::class_(m, "context", py::module_local()).def(py::init<>()); + + m.def("load_dialects", [](MLIRContext &context) { + DialectRegistry registry; + registry.insert(); + registerBuiltinDialectTranslation(registry); + registerLLVMDialectTranslation(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_(m, "type", py::module_local()) + .def("is_integer", + [](Type &self, unsigned width) { return self.isInteger(width); }) + .def("is_fp16", &Type::isF16) + .def("__str__", [](Type &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "function_type", py::module_local()) + .def("param_types", [](FunctionType &self) { + return std::vector(self.getInputs().begin(), + self.getInputs().end()); + }); + + py::class_(m, "location", py::module_local()) + .def("__str__", [](Location &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "value", py::module_local()) + .def("set_attr", + [](Value &self, std::string &name, Attribute &attr) -> void { + if (Operation *definingOp = self.getDefiningOp()) + definingOp->setAttr(name, attr); + else { + auto arg = mlir::cast(self); + int id = arg.getArgNumber(); + std::string attrName = name + "_arg" + std::to_string(id); + Block *owner = arg.getOwner(); + if (owner->isEntryBlock() && + !isa(owner->getParentOp())) { + owner->getParentOp()->setAttr(attrName, attr); + } + } + }) + .def("get_context", &Value::getContext) + .def("replace_all_uses_with", + [](Value &self, Value &newValue) { + self.replaceAllUsesWith(newValue); + }) + .def("get_type", &Value::getType) + .def("id", [](Value &self) { + // The Value is identified by and compared with + // other Values via the underlying ValueImpl + return (uint64_t)self.getImpl(); + }); + + py::class_(m, "op_result", py::module_local()); + + py::class_(m, "block_argument", py::module_local()); + + py::class_(m, "region", py::module_local()) + .def("get_parent_region", &Region::getParentRegion, ret::reference) + .def("size", [](Region &self) { return self.getBlocks().size(); }) + .def("empty", &Region::empty) + .def("id", [](Region &self) { return (uint64_t)&self; }); + + py::class_(m, "block", py::module_local()) + .def("arg", + [](Block &self, int index) -> BlockArgument { + if (index >= self.getNumArguments()) + throw pybind11::index_error("Block argument index out of range"); + return self.getArgument(index); + }) + .def("add_argument", + [](Block &self, Type ty) { + auto loc = UnknownLoc::get(ty.getContext()); + self.addArgument(ty, loc); + }) + .def("get_num_arguments", &Block::getNumArguments) + .def("get_argument", &Block::getArgument) + .def("dump", &Block::dump) + .def("move_before", + [](Block &self, Block &dst) { self.moveBefore(&dst); }) + .def("insert_before", &Block::insertBefore) + .def("get_parent", &Block::getParent, ret::reference) + .def("merge_block_before", + [](Block &self, Block &dst) { + // ref: RewriterBase::mergeBlocks() + if (self.getNumArguments() != 0) + throw std::runtime_error( + "This block has arguments, don't merge"); + dst.getOperations().splice(dst.begin(), self.getOperations()); + self.dropAllUses(); + self.erase(); + }) + .def("replace_use_in_block_with", + [](Block &self, Value &v, Value &newVal) { + v.replaceUsesWithIf(newVal, [&](OpOperand &operand) { + Operation *user = operand.getOwner(); + Block *currentBlock = user->getBlock(); + while (currentBlock) { + if (currentBlock == &self) + return true; + // Move up one level + currentBlock = + currentBlock->getParent()->getParentOp()->getBlock(); + } + return false; + }); + }) + .def("__str__", + [](Block &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return str; + }) + .def("has_terminator", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("has_return", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("erase", [](Block &self) { self.erase(); }) + .def("id", [](Block &self) { return (uint64_t)&self; }); + + py::class_(m, "attribute", py::module_local()); + py::class_(m, "integer_attr", py::module_local()); + py::class_(m, "bool_attr", py::module_local()); + + // Ops + py::class_(m, "OpState", py::module_local()) + .def("set_attr", + [](OpState &self, std::string &name, Attribute &attr) -> void { + self->setAttr(name, attr); + }) + .def("get_num_results", + [](OpState &self) -> unsigned { return self->getNumResults(); }) + .def("get_result", + [](OpState &self, unsigned idx) -> Value { + if (idx >= self->getNumResults()) + throw pybind11::index_error("Op result index out of range"); + return self->getResult(idx); + }) + .def( + "get_region", + [](OpState &self, unsigned idx) -> Region & { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self->getRegion(idx); + }, + ret::reference) + .def( + "get_body", + [](scf::ForOp &self, unsigned idx) -> Block * { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self.getBody(idx); + }, + ret::reference) + .def("dump", [](OpState &self) { self->dump(); }) + .def("__str__", + [](OpState &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + self->print(os, printingFlags); + return str; + }) + .def("append_operand", + [](OpState &self, Value &val) { + self->insertOperands(self->getNumOperands(), val); + }) + .def("verify", [](OpState &self) -> bool { + return succeeded(verify(self.getOperation())); + }); + // scf Ops + py::class_(m, "ForOp", py::module_local()) + .def("get_induction_var", &scf::ForOp::getInductionVar); + + py::class_(m, "IfOp", py::module_local()) + .def("get_then_block", &scf::IfOp::thenBlock, ret::reference) + .def("get_else_block", &scf::IfOp::elseBlock, ret::reference) + .def("get_then_yield", &scf::IfOp::thenYield) + .def("get_else_yield", &scf::IfOp::elseYield); + py::class_(m, "YieldOp", py::module_local()); + py::class_(m, "WhileOp", py::module_local()) + .def("get_before", &scf::WhileOp::getBefore, ret::reference) + .def("get_after", &scf::WhileOp::getAfter, ret::reference); + py::class_(m, "ConditionOp", py::module_local()); + + py::class_>( + m, "operation", py::module_local()) + .def("get_name", + [](Operation &self) { + llvm::StringRef opName = self.getName().getStringRef(); + return opName.str(); + }) + .def("get_num_operands", &Operation::getNumOperands) + .def("get_operand", &Operation::getOperand) + .def("get_num_results", &Operation::getNumResults) + .def("get_result", &Operation::getResult) + .def("get_num_regions", &Operation::getNumRegions) + .def("get_region", &Operation::getRegion, ret::reference) + .def("get_block", &Operation::getBlock, ret::reference) + .def("get_str_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }) + .def("get_flat_symbol_ref_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }); + + // dynamic_attr is used to transfer ownership of the MLIR context to the + // module + py::class_(m, "module", py::module_local(), + py::dynamic_attr()) + .def("dump", &ModuleOp::dump) + .def("str", + [](ModuleOp &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + self.print(os, printingFlags); + return str; + }) + .def("push_back", + [](ModuleOp &self, FuncOp &funcOp) -> void { + self.push_back(funcOp); + }) + .def("has_function", + [](ModuleOp &self, std::string &funcName) -> bool { + if (self.lookupSymbol(funcName)) + return true; + return false; + }) + .def("get_function", + [](ModuleOp &self, std::string &funcName) -> FuncOp { + return self.lookupSymbol(funcName); + }) + .def("get_int_attr", + [](ModuleOp &self, std::string name) -> py::object { + auto ret = self->getAttrOfType(name); + if (!ret) + return py::none(); + return py::int_(ret.getInt()); + }) + .def("create_location_snapshot", + [](ModuleOp &self, const std::string &fileName) -> void { + generateLocationsFromIR(/*raw_ostream=*/llvm::nulls(), + /*fileName=*/fileName, + /*op=*/self, /*flags=*/{}); + }) + .def("walk", + [](ModuleOp &self, const std::function &fn) { + self.walk(fn); + }); + + m.def("make_attr", [](const std::vector &values, MLIRContext &context) { + return mlir::cast(DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(values.size())}, + IntegerType::get(&context, 32)), + values)); + }); + + m.def( + "parse_mlir_module", + [](const std::string &inputFilename, MLIRContext &context) { + // parse module + OwningOpRef module = + parseSourceFile(inputFilename, &context); + if (!module) + throw std::runtime_error("Parse MLIR file failed."); + return module->clone(); + }, + ret::take_ownership); + + py::class_(m, "function", py::module_local()) + // .def_property_readonly("attrs", &ir::function::attrs) + // .def("add_attr", &ir::function::add_attr); + .def("args", + [](FuncOp &self, unsigned idx) -> BlockArgument { + if (idx >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); + return self.getArgument(idx); + }) + .def( + "add_entry_block", + [](FuncOp &self) -> Block * { return self.addEntryBlock(); }, + ret::reference) + .def( + "set_arg_attr", + [](FuncOp &self, int arg_no, const std::string &name, int val) { + // set arg attributes "name" to value "val" + auto attrTy = IntegerType::get(self.getContext(), 32); + self.setArgAttr(arg_no, name, IntegerAttr::get(attrTy, val)); + }, + ret::reference) + // .def("has_attr", &::FuncOp::hasAttr) + .def("finalize", + [](FuncOp &self) -> void { + // Remove dead code + // 1. Unreachable code after return + self.walk([&](Block *block) { + Operation *retOp = nullptr; + // It's better to not use walk here because we only want to + // check operations in the current block + for (auto &op : block->getOperations()) { + if (isa(op)) + if (retOp == nullptr) { + retOp = &op; + break; + } + } + if (retOp && retOp != &block->back()) { + auto pos = retOp->getIterator(); + pos++; + auto *newBlock = block->splitBlock(pos); + newBlock->erase(); + } + }); + // 2. Check if the result of tl.advance is used + self.walk([&](Operation *op) { + if (isa(op) && op->getResult(0).use_empty()) + outputWarning(op->getLoc(), "The result of tl.advance is not " + "being used. Note that tl.advance " + "does not have any side effects. " + "To move the block pointer, you " + "need to assign the result of " + "tl.advance to a variable."); + }); + }) + .def_property_readonly("type", &FuncOp::getFunctionType) + .def("reset_type", &FuncOp::setType); + + py::class_(m, "InsertPoint", py::module_local()); + + py::class_(m, "builder", py::module_local(), + py::dynamic_attr()) + .def(py::init()) + // getters + .def("create_module", + [](TritonOpBuilder &self) -> ModuleOp { + return self.create(); + }) + // insertion block/point + .def("set_insertion_point_to_start", + [](TritonOpBuilder &self, Block &block) -> void { + self.setInsertionPointToStart(block); + }) + .def("set_insertion_point_to_end", + [](TritonOpBuilder &self, Block &block) { + self.setInsertionPointToEnd(block); + }) + .def("set_insertion_point_after", + [](TritonOpBuilder &self, Operation &op) { + self.setInsertionPointAfter(op); + }) + .def( + "get_insertion_block", + [](TritonOpBuilder &self) -> Block * { + return self.getBuilder().getInsertionBlock(); + }, + ret::reference) + .def("get_insertion_point", + [](TritonOpBuilder &self) { + return self.getBuilder().saveInsertionPoint(); + }) + .def("restore_insertion_point", + [](TritonOpBuilder &self, OpBuilder::InsertPoint pt) { + self.restoreInsertionPoint(pt); + }) + // Attr + .def("get_bool_attr", + [](TritonOpBuilder &self, bool value) { + return self.getBuilder().getBoolAttr(value); + }) + .def("get_int32_attr", + [](TritonOpBuilder &self, int32_t value) { + return self.getBuilder().getI32IntegerAttr(value); + }) + // Use arith.ConstantOp to create constants + // Constants + .def("get_int1", + [](TritonOpBuilder &self, bool v) -> Value { + return Value(self.create( + v, self.getBuilder().getI1Type())); + }) + .def("get_int8", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_int16", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_int32", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_int64", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_uint8", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_uint16", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_uint32", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_uint64", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_bf16", + [](TritonOpBuilder &self, float v) -> Value { + auto type = self.getBuilder().getBF16Type(); + return self.create( + APFloat(type.getFloatSemantics(), std::to_string(v)), type); + }) + .def("get_fp16", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF16FloatAttr(v)); + }) + .def("get_fp32", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF32FloatAttr(v)); + }) + .def("get_fp64", + [](TritonOpBuilder &self, double v) -> Value { + return self.create( + self.getBuilder().getF64FloatAttr(v)); + }) + .def("get_null_value", + [](TritonOpBuilder &self, Type type) -> Value { + if (auto floatTy = dyn_cast(type)) + return self.create( + APFloat(floatTy.getFloatSemantics(), 0), floatTy); + else if (auto intTy = dyn_cast(type)) + return self.create(0, intTy); + else + throw std::runtime_error("Not implemented"); + }) + .def("get_all_ones_value", + [](TritonOpBuilder &self, Type type) -> Value { + uint64_t val = 0xFFFFFFFFFFFFFFFF; + if (auto intTy = dyn_cast(type)) + return self.create(val, intTy); + else + throw std::runtime_error("Not implemented"); + }) + + // Types + .def("get_void_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getNoneType(); + }) + .def("get_int1_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI1Type(); + }) // or ret::copy? + .def("get_int8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_int16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(16); + }) + .def("get_int32_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI32Type(); + }) + .def("get_int64_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI64Type(); + }) + .def("get_fp8e4nv_ty", + // TODO: fp8e4nv is using Float8E4M3FNUZType, which + // does not seem right. It should use FloatE4M3FNType + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b15_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_fp8e5_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e5b16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_half_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF16Type(); + }) + .def("get_bf16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getBF16Type(); + }) + .def("get_float_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF32Type(); + }) + .def("get_double_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF64Type(); + }) + .def("get_ptr_ty", + [](TritonOpBuilder &self, Type &type, int addrSpace) -> Type { + return PointerType::get(type, addrSpace); + }) + .def("get_block_ty", + [](TritonOpBuilder &self, Type &elementType, + std::vector &shape) -> Type { + return RankedTensorType::get(shape, elementType); + }) + .def("get_function_ty", + [](TritonOpBuilder &self, std::vector inTypes, + std::vector outTypes) -> Type { + return self.getBuilder().getFunctionType(inTypes, outTypes); + }) + // locs + .def("set_loc", + [](TritonOpBuilder &self, Location loc) { self.setLastLoc(loc); }) + .def("set_loc", + [](TritonOpBuilder &self, const std::string &fileName, int line, + int column) { self.setLastLoc(fileName, line, column); }) + .def("get_loc", + [](TritonOpBuilder &self) -> Location { return self.getLastLoc(); }) + + // Ops + .def("get_or_insert_function", + [](TritonOpBuilder &self, ModuleOp &module, std::string &funcName, + Type &funcType, std::string &visibility, + bool noinline) -> FuncOp { + if (Operation *funcOperation = module.lookupSymbol(funcName)) + return llvm::dyn_cast(funcOperation); + if (auto funcTy = dyn_cast(funcType)) { + llvm::SmallVector attrs = { + NamedAttribute( + self.getBuilder().getStringAttr("sym_visibility"), + self.getBuilder().getStringAttr(visibility)), + NamedAttribute(self.getBuilder().getStringAttr("noinline"), + self.getBuilder().getBoolAttr(noinline))}; + return self.create(funcName, funcTy, attrs); + } + throw std::invalid_argument("invalid function type"); + }) + .def( + "create_block", + [](TritonOpBuilder &self) -> Block * { + Region *parent = self.getBuilder().getBlock()->getParent(); + return self.getBuilder().createBlock(parent); + }, + ret::reference) + .def( + "create_block_with_parent", + [](TritonOpBuilder &self, Region &parent, + std::vector &argTypes) -> Block * { + // TODO: update arg loc + auto loc = self.getBuilder().getUnknownLoc(); + llvm::SmallVector argLocs(argTypes.size(), loc); + return self.getBuilder().createBlock(&parent, {}, argTypes, + argLocs); + }, + ret::reference) + .def( + "new_block", + [](TritonOpBuilder &self) -> Block * { return new Block(); }, + ret::reference) + // Function + .def("ret", + [](TritonOpBuilder &self, std::vector &vals) -> OpState { + return self.create(vals); + }) + .def("call", + [](TritonOpBuilder &self, FuncOp &func, std::vector &args) + -> OpState { return self.create(func, args); }) + // Unstructured control flow + .def("create_cond_branch", + [](TritonOpBuilder &self, Value condition, Block *trueDest, + Block *falseDest) -> OpState { + return self.create(condition, trueDest, + falseDest); + }) + .def("create_branch", + [](TritonOpBuilder &self, Block *dest, std::vector &args) + -> OpState { return self.create(dest, args); }) + // Structured control flow + .def("create_for_op", + [](TritonOpBuilder &self, Value &lb, Value &ub, Value &step, + std::vector &initArgs) -> scf::ForOp { + return self.create(lb, ub, step, initArgs); + }) + .def("create_if_op", + [](TritonOpBuilder &self, std::vector &retTypes, + Value &condition, bool withElse) -> scf::IfOp { + return self.create(retTypes, condition, withElse); + }) + .def("create_yield_op", + [](TritonOpBuilder &self, std::vector &yields) + -> scf::YieldOp { return self.create(yields); }) + .def("create_while_op", + [](TritonOpBuilder &self, std::vector &retTypes, + std::vector &initArgs) -> scf::WhileOp { + return self.create(retTypes, initArgs); + }) + .def("create_condition_op", + [](TritonOpBuilder &self, Value &cond, + std::vector &args) -> scf::ConditionOp { + return self.create(cond, args); + }) + + // miscellaneous + .def("create_make_range", + [](TritonOpBuilder &self, int start, int end) -> Value { + auto retType = RankedTensorType::get( + {end - start}, self.getBuilder().getI32Type()); + return self.create(retType, start, end); + }) + + // Cast instructions + // Conversions for custom FP types (FP8 and non-standard rounding modes) + .def("create_fp_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType, + std::optional roundingMode) -> Value { + if (roundingMode.has_value()) + return self.create( + dstType, src, + RoundingModeAttr::get(self.getBuilder().getContext(), + roundingMode.value())); + else + return self.create(dstType, src); + }) + // Conversions for standard LLVM builtin types + .def("create_bitcast", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_si_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_ui_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_si", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_ui", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_ext", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_trunc", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_int_cast", + [](TritonOpBuilder &self, Value &src, Type &dstType, + bool isSigned) -> Value { + // get element type if necessary + Type srcType = src.getType(); + auto srcTensorType = dyn_cast(srcType); + auto dstTensorType = dyn_cast(dstType); + Type srcEltType = srcType; + Type dstEltType = dstType; + if (dstTensorType && srcTensorType) { + dstEltType = dstTensorType.getElementType(); + srcEltType = srcTensorType.getElementType(); + } + unsigned srcWidth = srcEltType.getIntOrFloatBitWidth(); + unsigned dstWidth = dstEltType.getIntOrFloatBitWidth(); + if (srcWidth == dstWidth) + return self.create(dstType, src); + else if (srcWidth > dstWidth) + return self.create(dstType, src); + else if (isSigned) + return self.create(dstType, src); + else + return self.create(dstType, src); + }) + .def("create_to_index", + [](TritonOpBuilder &self, Value &input) -> Value { + return self.create( + self.getBuilder().getIndexType(), input); + }) + .def("create_index_to_si", + [](TritonOpBuilder &self, Value &input) -> Value { + return self.create( + self.getBuilder().getI64Type(), input); + }) + .def("create_fmul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_frem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fadd", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fsub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_mul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_umulhi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_udiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_srem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_urem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_add", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_fma", + [](TritonOpBuilder &self, Value &a, Value &b, Value &c) -> Value { + return Value(self.create(a, b, c)); + }) + .def("create_shl", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_lshr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_ashr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minimumf follows the torch.minimum convention and returns NaN if either + // operand is NaN + .def("create_minimumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minnumf follows the torch.fmin convention and returns the non-NaN + // operand + .def("create_minnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maximumf follows the torch.maximum convention and returns NaN if either + // operand is NaN + .def("create_maximumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maxnumf follows the torch.fmax convention and returns the non-NaN + // operand + .def("create_maxnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_clampf", + [](TritonOpBuilder &self, Value &input, Value &min, Value &max, + PropagateNan propagateNan) -> Value { + return Value(self.create(input, min, max, propagateNan)); + }) + .def("create_precise_sqrt", + [](TritonOpBuilder &self, Value &input) -> Value { + return Value(self.create(input)); + }) + .def("create_precise_divf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // AddPtr (similar to GEP) + .def("create_addptr", + [](TritonOpBuilder &self, Value &ptr, Value &offset) -> Value { + return self.create(ptr.getType(), ptr, offset); + }) + // Comparison (int) + .def("create_icmpSLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sle, lhs, + rhs); + }) + .def("create_icmpSLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::slt, lhs, + rhs); + }) + .def("create_icmpSGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sge, lhs, + rhs); + }) + .def("create_icmpSGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sgt, lhs, + rhs); + }) + .def("create_icmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ule, lhs, + rhs); + }) + .def("create_icmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ult, lhs, + rhs); + }) + .def("create_icmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::uge, lhs, + rhs); + }) + .def("create_icmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ugt, lhs, + rhs); + }) + .def("create_icmpEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::eq, lhs, + rhs); + }) + .def("create_icmpNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ne, lhs, + rhs); + }) + // Comparison (float) + .def("create_fcmpOLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLT, lhs, + rhs); + }) + .def("create_fcmpOGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGT, lhs, + rhs); + }) + .def("create_fcmpOLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLE, lhs, + rhs); + }) + .def("create_fcmpOGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGE, lhs, + rhs); + }) + .def("create_fcmpOEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OEQ, lhs, + rhs); + }) + .def("create_fcmpONE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ONE, lhs, + rhs); + }) + .def("create_fcmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULT, lhs, + rhs); + }) + .def("create_fcmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGT, lhs, + rhs); + }) + .def("create_fcmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULE, lhs, + rhs); + }) + .def("create_fcmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGE, lhs, + rhs); + }) + .def("create_fcmpUEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UEQ, lhs, + rhs); + }) + .def("create_fcmpUNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UNE, lhs, + rhs); + }) + // // Logical + .def("create_and", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_xor", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_or", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + // Input/Output + .def("create_load", + [](TritonOpBuilder &self, Value &ptrs, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_store", + [](TritonOpBuilder &self, Value &ptrs, Value &value, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, value, cacheModifier, evictionPolicy); + }) + .def("create_tensor_pointer_load", + [](TritonOpBuilder &self, Value &ptr, + std::vector &boundaryCheck, + std::optional paddingOption, + CacheModifier cacheModifier, EvictionPolicy evictionPolicy, + bool isVolatile) -> Value { + return self.create(ptr, boundaryCheck, paddingOption, + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_tensor_pointer_store", + [](TritonOpBuilder &self, Value &ptr, Value &val, + std::vector &boundaryCheck, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptr, val, boundaryCheck, cacheModifier, + evictionPolicy); + }) + .def("create_masked_load", + [](TritonOpBuilder &self, Value &ptrs, Value &mask, + std::optional &other, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, mask, other.value_or(Value()), + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_masked_store", + [](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, val, mask, cacheModifier, + evictionPolicy); + }) + .def("create_descriptor_load", + [](TritonOpBuilder &self, Value &desc_ptr, + std::vector &indices, Type type, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> Value { + return self.create( + type, desc_ptr, indices, cacheModifier, evictionPolicy); + }) + .def("create_descriptor_store", + [](TritonOpBuilder &self, Value &desc_ptr, Value value, + std::vector &indices) -> void { + self.create(desc_ptr, value, + indices); + }) + .def("create_reshape", + [](TritonOpBuilder &self, Value &arg, std::vector &shape, + bool allowReorder) -> Value { + auto argType = + cast(arg.getType()).getElementType(); + return self.create( + RankedTensorType::get(shape, argType), arg, allowReorder); + }) + .def("create_expand_dims", + [](TritonOpBuilder &self, Value &arg, int axis) -> Value { + auto argType = dyn_cast(arg.getType()); + auto argEltType = argType.getElementType(); + std::vector retShape = argType.getShape(); + retShape.insert(retShape.begin() + axis, 1); + return self.create( + RankedTensorType::get(retShape, argEltType), arg, axis); + }) + .def("create_cat", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + auto lhsType = dyn_cast(lhs.getType()); + auto rhsType = dyn_cast(rhs.getType()); + if (!(lhsType.getShape().size() == 1 && + rhsType.getShape().size() == 1)) + throw std::invalid_argument( + "shape not supported by cat. Expecting rank-1 inputs"); + std::vector shape{lhsType.getShape()[0] + + rhsType.getShape()[0]}; + return self.create( + RankedTensorType::get(shape, lhsType.getElementType()), lhs, + rhs); + }) + .def("create_join", + [](TritonOpBuilder &self, Value &a, Value &b) -> Value { + return self.create(a, b); + }) + .def("create_split", + [](TritonOpBuilder &self, Value &a) -> std::vector { + auto op = self.create(a); + return std::vector(op->result_begin(), op->result_end()); + }) + // Implements tl.trans and tl.permute. + .def("create_trans", + [](TritonOpBuilder &self, Value &arg, + std::vector &order) -> Value { + auto argType = dyn_cast(arg.getType()); + auto argEltType = argType.getElementType(); + auto retShape = applyPermutation(argType.getShape(), order); + return self.create( + RankedTensorType::get(retShape, argEltType), arg, order); + }) + .def("create_broadcast", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + if (auto argType = dyn_cast(arg.getType())) + return self.createOrFold( + RankedTensorType::get(shape, argType.getElementType()), arg); + throw std::invalid_argument( + "arg is not of RankedTensorType, use create_splat"); + }) + .def("create_splat", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + auto argType = arg.getType(); + auto ret = self.createOrFold( + RankedTensorType::get(shape, argType), arg); + return ret; + }) + // // atomic + .def("create_atomic_cas", + [](TritonOpBuilder &self, Value &ptr, Value &cmp, Value &val, + MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, ptr, cmp, val, sem, + scope); + }) + .def("create_atomic_rmw", + [](TritonOpBuilder &self, RMWOp rmwOp, Value &ptr, Value &val, + Value &mask, MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, rmwOp, ptr, val, mask, + sem, scope); + }) + // External + .def("create_extern_elementwise", + [](TritonOpBuilder &self, const std::string &libName, + const std::string &libPath, const std::string &symbol, + std::vector &argList, Type retType, bool isPure) -> Value { + return self.create(retType, argList, libName, + libPath, symbol, isPure); + }) + // Built-in instruction + .def("create_get_program_id", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create( + self.getBuilder().getI32Type(), + ProgramIDDimAttr::get(self.getBuilder().getContext(), + ProgramIDDim(axis))); + }) + .def("create_get_num_programs", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create( + self.getBuilder().getI32Type(), + ProgramIDDimAttr::get(self.getBuilder().getContext(), + ProgramIDDim(axis))); + }) + .def("create_dot", + [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b, + mlir::Value &c, InputPrecision inputPrecision, + int maxNumImpreciseAcc) -> mlir::Value { + return self.create(c.getType(), a, b, c, inputPrecision, + maxNumImpreciseAcc); + }) + .def("create_floor", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_ceil", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_cos", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sin", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_erf", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_rsqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_fabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_iabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_reduce", + [](TritonOpBuilder &self, std::vector operands, + int axis) -> OpState { + return self.create(operands, axis, false); + }) + .def("create_reduce_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_scan", + [](TritonOpBuilder &self, std::vector operands, int axis, + bool reverse) -> OpState { + return self.create(operands, axis, reverse); + }) + .def("create_scan_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_ptr_to_int", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_int_to_ptr", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_select", + [](TritonOpBuilder &self, Value &condition, Value &trueValue, + Value &falseValue) -> Value { + return self.create(condition, trueValue, + falseValue); + }) + .def("create_inline_asm", + [](TritonOpBuilder &self, const std::string &inlineAsm, + const std::string &constraints, const std::vector &values, + const std::vector &types, bool isPure, + int pack) -> OpState { + return self.create( + types, inlineAsm, constraints, isPure, pack, values); + }) + .def("create_print", + [](TritonOpBuilder &self, const std::string &prefix, bool hex, + const std::vector &values) -> void { + self.create( + StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(prefix)), + hex, values); + }) + .def("create_assert", + [](TritonOpBuilder &self, Value &condition, + const std::string &message, const std::string &fileName, + const std::string &funcName, unsigned lineNo) -> void { + auto messageAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(message)); + auto fileNameAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(fileName)); + auto funcNameAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(funcName)); + auto lineNoAttr = self.getBuilder().getI32IntegerAttr(lineNo); + self.create(condition, messageAttr, fileNameAttr, + funcNameAttr, lineNoAttr); + }) + // Undef + .def("create_undef", + [](TritonOpBuilder &self, Type &type) -> Value { + return self.create(type); + }) + .def("create_histogram", + [](TritonOpBuilder &self, Value operand, int numBins) -> Value { + return self.create( + RankedTensorType::get( + {static_cast(numBins)}, + IntegerType::get(operand.getContext(), 32)), + operand); + }) + // Force GPU barrier + .def("create_barrier", + [](TritonOpBuilder &self) { self.create(); }) + // Make a block pointer (tensor pointer in Triton IR) + .def("create_make_block_ptr", + [](TritonOpBuilder &self, Value &base, std::vector &shape, + std::vector &strides, std::vector &offsets, + std::vector &tensorShape, + std::vector &order) -> Value { + return self.create(base, shape, strides, offsets, + tensorShape, order); + }) + // Advance a block pointer + .def("create_advance", + [](TritonOpBuilder &self, Value &ptr, + std::vector &offsets) -> Value { + return self.create(ptr.getType(), ptr, offsets); + }); + + py::class_(m, "pass_manager", py::module_local()) + .def(py::init()) + .def("enable_debug", + [](PassManager &self) { + auto *context = self.getContext(); + bool haveDiagnostics = + ::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS"); + bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); + if (haveDiagnostics || haveDump) { + context->disableMultithreading(); + } + if (haveDiagnostics) { + context->printOpOnDiagnostic(true); + context->printStackTraceOnDiagnostic(true); + context->getDiagEngine().registerHandler([](Diagnostic &diag) { + llvm::outs() << diag << "\n"; + return success(); + }); + } + if (haveDump) { + auto printingFlags = OpPrintingFlags(); + printingFlags.elideLargeElementsAttrs(16); + printingFlags.enableDebugInfo(); + auto printAlways = [](Pass *, Operation *) { return true; }; + self.enableIRPrinting( + /*shouldPrintBeforePass=*/printAlways, + /*shouldPrintAfterPass=*/printAlways, + /*printModuleScope=*/true, + /*printAfterOnlyOnChange=*/false, + /*printAfterOnlyOnFailure*/ true, llvm::dbgs(), + printingFlags); + } + }) + .def("run", [](PassManager &self, ModuleOp &mod) { + // TODO: maybe dump module to file and print error for better + // diagnostics + + auto reproducerPath = + triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); + if (!reproducerPath.empty()) { + auto anchorName = self.getOpAnchorName(); + auto passes = self.getPasses(); + Operation *op = mod.getOperation(); + makeReproducer(anchorName, passes, op, reproducerPath); + } + + if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { + ::llvm::DebugFlag = true; + } + + if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); + !debugOnly.empty()) { + llvm::SmallVector split; + llvm::SmallVector storage; + llvm::SmallVector debugTypes; + + StringRef(debugOnly.c_str()).split(split, ','); + llvm::transform(split, std::back_inserter(debugTypes), + [&storage](StringRef str) { + // StringRefs are not always null-terminated. + // The purpose for this storage pattern is to + // produce a collection of C-strings that are. + storage.push_back(str.str()); + return storage.back().c_str(); + }); + + ::llvm::DebugFlag = true; + ::llvm::setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); + } + + bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); + if (haveTiming) { + self.enableTiming(); + } + + if (failed(self.run(mod.getOperation()))) + throw std::runtime_error("PassManager::run failed"); + }); +} + +void init_triton_env_vars(py::module &m) { + m.def("get_cache_invalidating_env_vars", + []() -> std::map { + std::map ret; + for (const auto &envVar : CACHE_INVALIDATING_ENV_VARS) { + auto strVal = triton::tools::getStrEnv(envVar); + if (strVal.empty()) + continue; + auto boolV = triton::tools::isEnvValueBool(strVal); + if (boolV.has_value()) + ret[envVar] = boolV.value() ? "true" : "false"; + else + ret[envVar] = strVal; + } + return ret; + }); +} diff --git a/third_party/iluvatar/python/src/llvm.cc b/third_party/iluvatar/python/src/llvm.cc new file mode 100644 index 000000000..b03438827 --- /dev/null +++ b/third_party/iluvatar/python/src/llvm.cc @@ -0,0 +1,406 @@ +#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Pass.h" +#include "llvm/Passes/OptimizationLevel.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/StandardInstrumentations.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include +#include +#include + +namespace py = pybind11; + +namespace llvm { +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; +} // namespace llvm + +using namespace llvm; + +std::string translateLLVMIRToASM(llvm::Module &module, + const std::string &triple, + const std::string &proc, + const std::string &features, + const std::vector &flags, + bool enable_fp_fusion, bool isObject) { + using namespace mlir; + // options + auto options = llvm::cl::getRegisteredOptions(); + for (std::string flag : flags) { + auto *shortPtr = static_cast *>(options[flag]); + assert(shortPtr); + shortPtr->setValue(true); + } + if (triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + auto optIt = options.find("print-after-all"); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + bool disableLLVMOpt = triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); + if (!disableLLVMOpt) { + // Check to see if we are passing a list of flags to disable optimizations. + auto flagList = triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (auto flag : split) { + auto optIt = options.find(flag); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + } + } + + // inline everything + for (llvm::Function &f : module.functions()) + if (!f.hasFnAttribute(llvm::Attribute::NoInline)) + f.addFnAttr(llvm::Attribute::AlwaysInline); + // verify and store llvm + llvm::legacy::PassManager pm; + pm.add(llvm::createAlwaysInlinerLegacyPass()); + pm.add(llvm::createVerifierPass()); + + const bool enabledTiming = triton::tools::getBoolEnv("LLVM_ENABLE_TIMING"); + if (enabledTiming) { + llvm::TimePassesIsEnabled = true; + llvm::TimePassesPerRun = true; + } + + pm.run(module); + + SmallString<0> timePassesStr; + raw_svector_ostream reportStream(timePassesStr); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } + // module->print(llvm::outs(), nullptr); + + // create machine + module.setTargetTriple(triple); + std::string error; + auto target = + llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error); + llvm::TargetOptions opt; + if (enable_fp_fusion) + opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; + opt.UnsafeFPMath = false; + opt.NoInfsFPMath = false; + opt.NoNaNsFPMath = true; + opt.TrapUnreachable = true; + std::unique_ptr machine{target->createTargetMachine( + module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, + std::nullopt, + disableLLVMOpt ? llvm::CodeGenOptLevel::None + : llvm::CodeGenOptLevel::Aggressive)}; + // set data layout + module.setDataLayout(machine->createDataLayout()); + // emit machine code + std::string result; + { + llvm::raw_string_ostream stream(result); + llvm::buffer_ostream pstream(stream); + for (llvm::Function &f : module.functions()) + f.addFnAttr(llvm::Attribute::AlwaysInline); + llvm::legacy::PassManager pass; + // emit + auto fileType = isObject ? llvm::CodeGenFileType::ObjectFile + : llvm::CodeGenFileType::AssemblyFile; + machine->addPassesToEmitFile(pass, pstream, nullptr, fileType); + pass.run(module); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } + } + return result; +} + +using ret = py::return_value_policy; + +void init_triton_llvm(py::module &&m) { + + py::class_(m, "context", py::module_local()) + .def(py::init<>()); + + py::class_(m, "function_list") + .def( + "__iter__", + [](llvm::Module::FunctionListType &s) { + return py::make_iterator(s.begin(), s.end()); + }, + py::keep_alive<0, 1>()); + + // Module Flag behavior. See + // https://llvm.org/doxygen/classllvm_1_1Module.html#a0a5c55e12c97b80021330fe82b642293 + // for details. + py::class_(m, "module_flag_behavior", + py::module_local()); + m.attr("MODULE_FLAG_BEHAVIOR_ERROR") = llvm::Module::Error; + m.attr("MODULE_FLAG_BEHAVIOR_WARNING") = llvm::Module::Warning; + m.attr("MODULE_FLAG_BEHAVIOR_REQUIRE") = llvm::Module::Require; + m.attr("MODULE_FLAG_BEHAVIOR_OVERRIDE") = llvm::Module::Override; + m.attr("MODULE_FLAG_BEHAVIOR_APPEND") = llvm::Module::Append; + m.attr("MODULE_FLAG_BEHAVIOR_APPEND_UNIQUE") = llvm::Module::AppendUnique; + m.attr("MODULE_FLAG_BEHAVIOR_MAX") = llvm::Module::Max; + m.attr("MODULE_FLAG_BEHAVIOR_MIN") = llvm::Module::Min; + + py::class_(m, "module", py::module_local()) + .def( + "__str__", + [](llvm::Module *self) { + std::string str; + llvm::raw_string_ostream os(str); + os << *self; + return os.str(); + }, + ret::take_ownership) + .def( + "get_functions", + [](llvm::Module *mod) -> llvm::Module::FunctionListType & { + // Note: Backends assume that we are compiling exactly one kernel + // (i.e. one function that's that's called by the CPU) and that it's + // the first function in this list. + return mod->getFunctionList(); + }, + ret::reference_internal) + .def("add_flag", + [](llvm::Module *mod, llvm::Module::ModFlagBehavior behavior, + std::string &key, uint32_t value) { + return mod->addModuleFlag(behavior, key, value); + }); + + py::class_(m, "function", py::module_local()) + .def_property_readonly( + "name", [](llvm::Function *fn) { return fn->getName().str(); }) + .def("set_calling_conv", &llvm::Function::setCallingConv) + .def("add_fn_attr", [](llvm::Function *fn, std::string &name, + std::string &val) { fn->addFnAttr(name, val); }) + + // Sets the nvvm.maxreg property on the given function. + .def("set_nvvm_maxnreg", + [](llvm::Function *fn, int maxnreg) { + auto op = MDNode::get( + fn->getContext(), + { + ValueAsMetadata::get(fn), + MDString::get(fn->getContext(), "maxnreg"), + ConstantAsMetadata::get(ConstantInt::get( + Type::getInt32Ty(fn->getContext()), maxnreg)), + }); + fn->getParent() + ->getOrInsertNamedMetadata("nvvm.annotations") + ->addOperand(op); + }) + // External functions that are definitions (i.e. not declarations) are + // kernel functions. + .def("is_declaration", &llvm::Function::isDeclaration) + .def("is_external_linkage", [](llvm::Function *fn) { + return fn->getLinkage() == llvm::GlobalValue::ExternalLinkage; + }); + + // optimization levels + py::class_(m, "optimization_level", + py::module_local()); + m.attr("OPTIMIZE_O0") = llvm::OptimizationLevel::O0; + m.attr("OPTIMIZE_O1") = llvm::OptimizationLevel::O1; + m.attr("OPTIMIZE_O2") = llvm::OptimizationLevel::O2; + m.attr("OPTIMIZE_O3") = llvm::OptimizationLevel::O3; + m.attr("OPTIMIZE_Os") = llvm::OptimizationLevel::Os; + m.attr("OPTIMIZE_Oz") = llvm::OptimizationLevel::Oz; + + m.def( + "to_module", + [](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) { + return mlir::translateModuleToLLVMIR(mod, ctx); + }, + py::keep_alive<0, 2>()); + + m.def( + "optimize_module", + [](llvm::Module *mod, const llvm::OptimizationLevel &opt, + const std::string triple) { + if (mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT")) + return; + // Check to see if we are passing a list of flags to disable + // optimizations. + auto flagList = mlir::triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + auto options = llvm::cl::getRegisteredOptions(); + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (auto flag : split) { + auto optIt = options.find(flag); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + } + using namespace llvm; + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + + PassInstrumentationCallbacks *instrCbPtr = nullptr; + PassInstrumentationCallbacks passInstrCb; + StandardInstrumentations standardInstr(mod->getContext(), + /*DebugLogging*/ true); + if (mlir::triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + auto optMap = llvm::cl::getRegisteredOptions(); + auto optIt = optMap.find("print-after-all"); + if (optIt != optMap.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + standardInstr.registerCallbacks(passInstrCb, &mam); + instrCbPtr = &passInstrCb; + } + + PipelineTuningOptions tuningOptions; + tuningOptions.LoopUnrolling = true; + tuningOptions.LoopInterleaving = true; + tuningOptions.LoopVectorization = true; + // TODO: currently we run SLP vectorizer with an empty target machine. + // This cause the vectorizer to create larger vector which could be bad. + // Disabling it would currently cause regressions as this pass also + // applies some scheduling that helps performance in some cases. We + // should work on using NVPTX target instead and address the performance + // regressions with some scheduling solution. + // tuningOptions.SLPVectorization = true; + tuningOptions.SLPVectorization = true; + + if (!triple.empty()) + mod->setTargetTriple(triple.c_str()); + + PassBuilder pb(nullptr /*targetMachine*/, tuningOptions, std::nullopt, + instrCbPtr); + + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + pb.registerVectorizerStartEPCallback( + [&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) { + // Triton generates large structure of scalars which may pessimise + // optimizations, we run a pass to break up phi of struct to make + // sure all the struct are removed for the following passes. + fpm.addPass(BreakStructPhiNodesPass()); + fpm.addPass(InstCombinePass()); + }); + mpm.addPass(pb.buildPerModuleDefaultPipeline(opt)); + mpm.run(*mod, mam); + }, + py::arg("mod"), py::arg("opt"), py::arg("triple") = ""); + + m.def( + "translate_to_asm", + [](std::string llvmIR, std::string triple, std::string proc, + std::string features, std::vector flags, + bool enable_fp_fusion, bool isObject) -> py::object { + std::string obj; + { + // when allow_threads goes out of scope, gil will be released + py::gil_scoped_release allow_threads; + // create LLVM module from C++ + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + obj = translateLLVMIRToASM(*module, triple, proc, features, flags, + enable_fp_fusion, isObject); + } + if (isObject) + return py::bytes(obj); + else + return py::str(obj); + }, + ret::take_ownership); + + m.def("init_targets", []() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + }); + }); + + m.def("link_extern_libs", [](llvm::Module *dstMod, + const std::vector &paths) { + if (paths.empty()) + return; + + LLVMContext &ctx = dstMod->getContext(); + llvm::Linker linker(*dstMod); + for (const std::string &path : paths) { + llvm::SMDiagnostic err; + std::unique_ptr libMod = llvm::parseIRFile(path, err, ctx); + if (!libMod) { + std::string message = "Failed to parse library at " + path; + throw std::invalid_argument(message); + } + libMod->setTargetTriple(dstMod->getTargetTriple()); + libMod->setDataLayout(dstMod->getDataLayout()); + + std::unordered_set externalFns; + for (llvm::Function &fn : libMod->functions()) { + if (!fn.isDeclaration()) + externalFns.insert(fn.getName().str()); + } + + if (linker.linkInModule(std::move(libMod), + llvm::Linker::Flags::LinkOnlyNeeded)) { + std::string message = "Failed to link library at " + path; + throw std::invalid_argument(message); + } + + // Mark linked-in functions as internal because backends use external + // linkage as a signifier of kernel functions. + for (llvm::Function &fn : dstMod->functions()) { + if (externalFns.count(fn.getName().str())) { + fn.setLinkage(llvm::GlobalValue::InternalLinkage); + } + } + } + }); +} diff --git a/third_party/iluvatar/python/src/main.cc b/third_party/iluvatar/python/src/main.cc new file mode 100644 index 000000000..867e558ec --- /dev/null +++ b/third_party/iluvatar/python/src/main.cc @@ -0,0 +1,70 @@ +#include +namespace py = pybind11; + +#include "python/src/plugin.h" + +using BackendInitFunc = void (*)(pybind11::module &&); + +BackendInitFunc load_backend_init_func(const char *backend_name) { + const std::string func_name = std::string("init_triton_") + backend_name; + void *symbol = load_backend_symbol(backend_name, func_name.c_str()); + return reinterpret_cast(symbol); +} + +#define FOR_EACH_1(MACRO, X) MACRO(X) +#define FOR_EACH_2(MACRO, X, ...) MACRO(X) FOR_EACH_1(MACRO, __VA_ARGS__) +#define FOR_EACH_3(MACRO, X, ...) MACRO(X) FOR_EACH_2(MACRO, __VA_ARGS__) +#define FOR_EACH_4(MACRO, X, ...) MACRO(X) FOR_EACH_3(MACRO, __VA_ARGS__) + +#define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__, FOR_EACH_RSEQ_N()) +#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__) +#define FOR_EACH_ARG_N(_1, _2, _3, _4, N, ...) N +#define FOR_EACH_RSEQ_N() 4, 3, 2, 1, 0 + +#define CONCATENATE(x, y) CONCATENATE1(x, y) +#define CONCATENATE1(x, y) x##y + +#define FOR_EACH(MACRO, ...) \ + CONCATENATE(FOR_EACH_, FOR_EACH_NARG_HELPER(__VA_ARGS__))(MACRO, __VA_ARGS__) +#define FOR_EACH_NARG_HELPER(...) FOR_EACH_NARG(__VA_ARGS__) + +// New macro to remove parentheses +#define REMOVE_PARENS(...) __VA_ARGS__ + +// Intermediate macro to ensure correct expansion +#define FOR_EACH_P_INTERMEDIATE(MACRO, ...) FOR_EACH(MACRO, __VA_ARGS__) + +// Modified FOR_EACH to handle parentheses +#define FOR_EACH_P(MACRO, ARGS_WITH_PARENS) \ + FOR_EACH_P_INTERMEDIATE(MACRO, REMOVE_PARENS ARGS_WITH_PARENS) + +#define DECLARE_BACKEND(name) void init_triton_##name(pybind11::module &&m); + +// #define INIT_BACKEND(name) init_triton_##name(m.def_submodule(#name)); +#define INIT_BACKEND(name) \ + do { \ + try { \ + static auto func = load_backend_init_func(#name); \ + func(m.def_submodule(#name)); \ + } catch (const std::exception &e) { \ + std::cerr << "Failed to load backend " #name ": " << e.what() \ + << std::endl; \ + } \ + } while (0); + +void init_triton_env_vars(pybind11::module &m); +void init_triton_ir(pybind11::module &&m); +void init_triton_llvm(pybind11::module &&m); +void init_triton_interpreter(pybind11::module &&m); +void init_triton_passes(pybind11::module &&m); +FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE) + +PYBIND11_MODULE(libtriton, m) { + m.doc() = "Python bindings to the C++ Triton API"; + init_triton_env_vars(m); + init_triton_ir(m.def_submodule("ir")); + init_triton_passes(m.def_submodule("passes")); + init_triton_interpreter(m.def_submodule("interpreter")); + init_triton_llvm(m.def_submodule("llvm")); + FOR_EACH_P(INIT_BACKEND, TRITON_BACKENDS_TUPLE) +} diff --git a/third_party/iluvatar/python/src/passes.cc b/third_party/iluvatar/python/src/passes.cc new file mode 100644 index 000000000..513e811d2 --- /dev/null +++ b/third_party/iluvatar/python/src/passes.cc @@ -0,0 +1,90 @@ +#include "mlir/Transforms/Passes.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Target/LLVMIR/Passes.h" +#include +#include + +namespace py = pybind11; + +void init_triton_analysis(py::module &&m) { + py::class_(m, "allocation", py::module_local()) + .def(py::init()); + py::class_(m, "membar", py::module_local()) + .def(py::init()) + .def("run", &mlir::ModuleMembarAnalysis::run); +} + +void init_triton_passes_common(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_sccp", createSCCPPass); + ADD_PASS_WRAPPER_0("add_symbol_dce", createSymbolDCEPass); + ADD_PASS_WRAPPER_0("add_inliner", createInlinerPass); + ADD_PASS_WRAPPER_0("add_canonicalizer", createCanonicalizerPass); + ADD_PASS_WRAPPER_0("add_cse", createCSEPass); + ADD_PASS_WRAPPER_0("add_licm", createLoopInvariantCodeMotionPass); +} + +void init_triton_passes_ttir(py::module &&m) { + using namespace mlir::triton; + ADD_PASS_WRAPPER_0("add_combine", createCombineOpsPass); + ADD_PASS_WRAPPER_0("add_reorder_broadcast", createReorderBroadcastPass); + ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer", + createRewriteTensorPointerPass); + ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", + createConvertTritonToTritonGPUPass, const std::string &, + int, int, int); +} + +void init_triton_passes_ttgpuir(py::module &&m) { + using namespace mlir::triton::gpu; + ADD_PASS_WRAPPER_0("add_coalesce", createTritonGPUCoalesce); + ADD_PASS_WRAPPER_0("add_optimize_thread_locality", + createTritonGPUOptimizeThreadLocality); + ADD_PASS_OPTION_WRAPPER_1("add_pipeline", createTritonGPUPipeline, int); + ADD_PASS_WRAPPER_0("add_prefetch", createTritonGPUPrefetch); + ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul); + ADD_PASS_WRAPPER_0("add_reorder_instructions", + createTritonGPUReorderInstructions); + ADD_PASS_WRAPPER_0("add_f32_dot_tc", createTritonGPUF32DotTC); + ADD_PASS_OPTION_WRAPPER_1("add_optimize_dot_operands", + createTritonGPUOptimizeDotOperands, bool); + ADD_PASS_WRAPPER_0("add_remove_layout_conversions", + createTritonGPURemoveLayoutConversions); + ADD_PASS_WRAPPER_0("add_reduce_data_duplication", + createTritonGPUReduceDataDuplication); + ADD_PASS_WRAPPER_0("add_allocate_shared_memory", + createAllocateSharedMemoryPass); + ADD_PASS_WRAPPER_0("add_combine_tensor_select_and_if", + createTritonGPUCombineTensorSelectAndIf); +} + +void init_triton_passes_convert(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_scf_to_cf", createConvertSCFToCFPass); + ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); + ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); + ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); +} + +void init_triton_passes_llvmir(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_di_scope", createLLVMDIScopePass); +} + +void init_triton_passes(py::module &&m) { + init_triton_analysis(m.def_submodule("analysis")); + init_triton_passes_common(m.def_submodule("common")); + init_triton_passes_convert(m.def_submodule("convert")); + init_triton_passes_ttir(m.def_submodule("ttir")); + init_triton_passes_ttgpuir(m.def_submodule("ttgpuir")); + init_triton_passes_llvmir(m.def_submodule("llvmir")); +} diff --git a/third_party/iluvatar/python/src/passes.h b/third_party/iluvatar/python/src/passes.h new file mode 100644 index 000000000..46801d802 --- /dev/null +++ b/third_party/iluvatar/python/src/passes.h @@ -0,0 +1,40 @@ +#define ADD_PASS_WRAPPER_0(name, builder) \ + m.def(name, [](mlir::PassManager &pm) { pm.addPass(builder()); }) + +#define ADD_PASS_WRAPPER_1(name, builder, ty0) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder(val0)); }) + +#define ADD_PASS_WRAPPER_2(name, builder, ty0, ty1) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ + pm.addPass(builder(val0, val1)); \ + }) + +#define ADD_PASS_WRAPPER_3(name, builder, ty0, ty1, ty2) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ + pm.addPass(builder(val0, val1, val2)); \ + }) + +#define ADD_PASS_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \ + ty3 val3) { pm.addPass(builder(val0, val1, val2, val3)); }) + +#define ADD_PASS_OPTION_WRAPPER_1(name, builder, ty0) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder({val0})); }) + +#define ADD_PASS_OPTION_WRAPPER_2(name, builder, ty0, ty1) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ + pm.addPass(builder({val0, val1})); \ + }) + +#define ADD_PASS_OPTION_WRAPPER_3(name, builder, ty0, ty1, ty2) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ + pm.addPass(builder({val0, val1, val2})); \ + }) + +#define ADD_PASS_OPTION_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3) { \ + pm.addPass(builder({val0, val1, val2, val3})); \ + }) diff --git a/third_party/iluvatar/python/test/unit/conftest.py b/third_party/iluvatar/python/test/unit/conftest.py new file mode 100644 index 000000000..7a02d322b --- /dev/null +++ b/third_party/iluvatar/python/test/unit/conftest.py @@ -0,0 +1,12 @@ +# content of conftest.py + +import pytest + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default='cuda') + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") diff --git a/third_party/iluvatar/python/test/unit/hopper/__init__.py b/third_party/iluvatar/python/test/unit/hopper/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/iluvatar/python/test/unit/hopper/test_experimental_tma.py b/third_party/iluvatar/python/test/unit/hopper/test_experimental_tma.py new file mode 100644 index 000000000..b20f75bc5 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/hopper/test_experimental_tma.py @@ -0,0 +1,130 @@ +import numpy as np +import pytest +import torch +import tempfile + +import triton +import triton.language as tl + + +def test_descriptor_load_ttgir(): + if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: + pytest.skip("Test requires Hopper target.") + return + device = "cuda" + SIZE = 128 + + x = torch.randn(SIZE, dtype=torch.float32, device=device) + desc = np.empty(SIZE, dtype=np.int8) + triton.runtime.driver.active.utils.fill_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size(), desc) + size_in_bytes = SIZE * x.element_size() + + ir = f""" + #blocked = #triton_gpu.blocked<{{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}}> + #shared = #triton_gpu.shared<{{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}}> + module attributes {{"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ + %c0_i32 = arith.constant 0 : i32 + %0 = tt.make_range {{end = {SIZE} : i32, start = 0 : i32}} : tensor<{SIZE}xi32, #blocked> + %1 = triton_gpu.local_alloc : () -> !tt.memdesc<{SIZE}xf32, #shared, mutable> + %2 = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared, mutable> + triton_nvidia_gpu.init_barrier %2, 1 : <1xi64, #shared, mutable> + %true = arith.constant 1 : i1 + triton_nvidia_gpu.barrier_expect %2, {size_in_bytes}, %true : <1xi64, #shared, mutable> + triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0_i32] %1, %2, %true : , <1xi64, #shared, mutable> -> <{SIZE}xf32, #shared, mutable> + triton_nvidia_gpu.wait_barrier %2, %c0_i32 : <1xi64, #shared, mutable> + %3 = triton_gpu.local_load %1 : !tt.memdesc<{SIZE}xf32, #shared, mutable> -> tensor<{SIZE}xf32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<{SIZE}x!tt.ptr, #blocked> + %5 = tt.addptr %4, %0 : tensor<{SIZE}x!tt.ptr, #blocked>, tensor<{SIZE}xi32, #blocked> + tt.store %5, %3 : tensor<{SIZE}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + desc = torch.tensor(desc, device=device) + z_tri = torch.empty_like(x) + kernel[(1, 1, 1)](z_tri, desc) + assert torch.equal(x, z_tri) + + +def test_experimetal_descriptor_load(): + if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: + pytest.skip("Test requires Hopper target.") + return + device = "cuda" + SIZE = 128 + + @triton.jit + def kernel(Z, desc, SIZE: tl.constexpr): + off_desc = 0 + off = tl.arange(0, SIZE) + x = tl._experimental_descriptor_load(desc, [off_desc], [SIZE], Z.dtype.element_ty) + tl.store(Z + off, x) + + x = torch.randn(SIZE, dtype=torch.float32, device=device) + desc = np.empty(SIZE, dtype=np.int8) + triton.runtime.driver.active.utils.fill_1d_tma_descriptor(x.data_ptr(), SIZE, SIZE, x.element_size(), desc) + desc = torch.tensor(desc, device=device) + z_tri = torch.empty_like(x) + kernel[(1, )](z_tri, desc, SIZE=SIZE, num_warps=4) + assert torch.equal(x, z_tri) + + +@triton.jit +def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # + M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + offs_k = 0 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], tl.float16) + b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], tl.float16) + accumulator = tl.dot(a, b, acc=accumulator) + offs_k += BLOCK_SIZE_K + accumulator = accumulator.to(tl.float16) + tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn]) + + +@pytest.mark.parametrize("num_stages", [1, 4]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)]) +def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K): + if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 9: + pytest.skip("Test requires Hopper target.") + return + device = "cuda" + M, N, K = 8192, 8192, 1024 + torch.manual_seed(42) + A = torch.randn((M, K), dtype=torch.float16, device=device) + B = torch.randn((K, N), dtype=torch.float16, device=device) + C = torch.empty((M, N), dtype=torch.float16, device=device) + TMA_SIZE = 128 + desc_a = np.empty(TMA_SIZE, dtype=np.int8) + desc_b = np.empty(TMA_SIZE, dtype=np.int8) + desc_c = np.empty(TMA_SIZE, dtype=np.int8) + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size(), + desc_a) + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size(), + desc_b) + triton.runtime.driver.active.utils.fill_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size(), + desc_c) + + desc_a = torch.tensor(desc_a, device=device) + desc_b = torch.tensor(desc_b, device=device) + desc_c = torch.tensor(desc_c, device=device) + kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, + 1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps=8, + num_stages=num_stages) + ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16) + torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) + if BLOCK_M >= 64 and BLOCK_N >= 64: + assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"] diff --git a/third_party/iluvatar/python/test/unit/hopper/test_flashattention.py b/third_party/iluvatar/python/test/unit/hopper/test_flashattention.py new file mode 100644 index 000000000..fc8db664c --- /dev/null +++ b/third_party/iluvatar/python/test/unit/hopper/test_flashattention.py @@ -0,0 +1,467 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) +""" + +# import numpy as np +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel(Q, K, V, sm_scale, # + L, M, # + Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, D0, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + + # TODO: may replace with TMA store without range offset + # initialize offsets for store + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_prev = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + stride_qh_2d = stride_qh // stride_qm // stride_qk + + q_tile_ptr = tl.make_block_ptr( + base=Q, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_tile_ptr = tl.make_block_ptr( + base=K, + shape=(D0, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(off_hz * stride_qh_2d, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + v_tile_ptr = tl.make_block_ptr( + base=V, + shape=(D0, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(off_hz * stride_qh_2d, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + out_tile_ptr = tl.make_block_ptr( + base=Out, + shape=(D0, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # load q: it will stay in SRAM throughout + q = tl.load(q_tile_ptr) + + # loop over k, v and update accumulators + for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + # -- compute qk ---- + k = tl.load(k_tile_ptr, boundary_check=(0, 1)) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + # compute new m + m_curr = tl.maximum(tl.max(qk, 1), m_prev) + # correct old l + l_prev *= tl.exp(m_prev - m_curr) + # attention weights + p = tl.exp(qk - m_curr[:, None]) + l_curr = tl.sum(p, 1) + l_prev + # rescale operands of matmuls + l_rcp = 1. / l_curr + p *= l_rcp[:, None] + acc *= (l_prev * l_rcp)[:, None] + # update acc + p = p.to(tl.float16) + v = tl.load(v_tile_ptr, boundary_check=(0, 1)) + acc += tl.dot(p, v) + # update m_i and l_i + l_prev = l_curr + m_prev = m_curr + # update pointers + k_tile_ptr = tl.advance(k_tile_ptr, [BLOCK_N, 0]) + v_tile_ptr = tl.advance(v_tile_ptr, [BLOCK_N, 0]) + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + l_ptrs = L + off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(l_ptrs, l_prev) + tl.store(m_ptrs, m_prev) + + acc = acc.to(tl.float16) + tl.store(out_tile_ptr, acc, boundary_check=(0, 1)) + + +@triton.jit +def _bwd_preprocess(Out, DO, L, # + NewDO, Delta, # + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + denom = tl.load(L + off_m).to(tl.float32) + # compute + do = do / denom[:, None] + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + +@triton.jit +def _bwd_kernel(Q, K, V, sm_scale, Out, DO, # + DQ, DK, DV, # + L, M, # + D, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + Z, H, N_CTX, D0, # + num_block, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + # init tile_ptr + stride_qz_2d = stride_qz // stride_qm // stride_qk + stride_qh_2d = stride_qh // stride_qm // stride_qk + + q_tile_ptr = tl.make_block_ptr( + base=Q, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_tile_ptr = tl.make_block_ptr( + base=K, + shape=(D0, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + v_tile_ptr = tl.make_block_ptr( + base=V, + shape=(D0, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + do_tile_ptr = tl.make_block_ptr( + base=DO, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + dq_tile_ptr = tl.make_block_ptr( + base=DQ, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + dk_tile_ptr = tl.make_block_ptr( + base=DK, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + dv_tile_ptr = tl.make_block_ptr( + base=DV, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # offset pointers for batch/head + DQ += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_tile_ptr, boundary_check=(0, 1)) + v = tl.load(v_tile_ptr, boundary_check=(0, 1)) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_tile_ptr, boundary_check=(0, 1)) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, tl.trans(k)) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + m = tl.load(m_ptrs + offs_m_curr) + p = tl.exp(qk * sm_scale - m[:, None]) + # compute dv + do = tl.load(do_tile_ptr, boundary_check=(0, 1)) + dv += tl.dot(tl.trans(p.to(tl.float16)), do) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, tl.trans(v)) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds.to(tl.float16)), q) + # compute dq + dq = tl.load(dq_tile_ptr) + dq += tl.dot(ds.to(tl.float16), k) + tl.store(dq_tile_ptr, dq) + # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_tile_ptr = tl.advance(q_tile_ptr, [BLOCK_M, 0]) + do_tile_ptr = tl.advance(do_tile_ptr, [BLOCK_M, 0]) + dq_tile_ptr = tl.advance(dq_tile_ptr, [BLOCK_M, 0]) + q_tile_ptr = tl.advance(q_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0]) + do_tile_ptr = tl.advance(do_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0]) + dq_tile_ptr = tl.advance(dq_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0]) + # increment tile pointers + k_tile_ptr = tl.advance(k_tile_ptr, [BLOCK_M, 0]) + v_tile_ptr = tl.advance(v_tile_ptr, [BLOCK_M, 0]) + # write-back + tl.store(dv_tile_ptr, dv.to(tl.float16), boundary_check=(0, 1)) + tl.store(dk_tile_ptr, dk.to(tl.float16), boundary_check=(0, 1)) + dv_tile_ptr = tl.advance(dv_tile_ptr, [BLOCK_M, 0]) + dk_tile_ptr = tl.advance(dk_tile_ptr, [BLOCK_M, 0]) + + +empty = torch.empty(128, device="cuda") + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, sm_scale): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + D0 = q.shape[0] * q.shape[1] * q.shape[2] + _fwd_kernel[grid]( + q, k, v, sm_scale, # + L, m, # + o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], q.shape[2], D0, # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk, # + num_warps=num_warps, num_stages=2) + + ctx.save_for_backward(q, k, v, o, L, m) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + return o + + @staticmethod + def backward(ctx, do): + BLOCK = 128 + q, k, v, o, l, m = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + do_scaled = torch.empty_like(do) + delta = torch.empty_like(l) + D0 = q.shape[0] * q.shape[1] * q.shape[2] + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + o, do, l, # + do_scaled, delta, # + BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL) + _bwd_kernel[(ctx.grid[1], )]( + q, k, v, ctx.sm_scale, # + o, do_scaled, # + dq, dk, dv, # + l, m, # + delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + q.shape[0], q.shape[1], q.shape[2], D0, # + ctx.grid[0], # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=ctx.BLOCK_DMODEL, # + num_warps=8, num_stages=1) + return dq, dk, dv, None + + +attention = _attention.apply + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ + (4, 48, 128, 64), + (4, 48, 256, 64), + (4, 48, 512, 64), + (4, 48, 1024, 64), + (4, 48, 2048, 64), + (4, 48, 4096, 64), + # (4, 48, 8192, 64), out of memory +]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+") +def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() + sm_scale = 0.2 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + for z in range(Z): + for h in range(H): + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).half() + # p = torch.exp(p) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # triton implementation + tri_out = attention(q, k, v, sm_scale) + # print(ref_out) + # print(tri_out) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) + torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=0) + torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=0) + torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=0) + + +try: + from flash_attn.flash_attn_interface import flash_attn_func + HAS_FLASH = True +except BaseException: + HAS_FLASH = False + +BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +# vary seq length for fixed head and batch=4 +configs = [ + triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(10, 14)], + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', + args={ + 'H': N_HEADS, + 'BATCH': BATCH, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + }, + ) for mode in ['fwd', 'bwd'] +] + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] + warmup = 25 + rep = 100 + if provider == "triton": + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + sm_scale = 1.3 + fn = lambda: attention(q, k, v, sm_scale) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + if provider == "flash": + lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) + fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +# only works on post-Ampere GPUs right now +# bench_flash_attention.run(save_path='.', print_data=True) diff --git a/third_party/iluvatar/python/test/unit/hopper/test_gemm.py b/third_party/iluvatar/python/test/unit/hopper/test_gemm.py new file mode 100644 index 000000000..88b39b57f --- /dev/null +++ b/third_party/iluvatar/python/test/unit/hopper/test_gemm.py @@ -0,0 +1,458 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import itertools +import os +import re + +import pytest +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@triton.jit +def matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr # + ): + a_block_ptr = tl.make_block_ptr( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), + order=(1, 0), + ) + b_block_ptr = tl.make_block_ptr( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), + order=(0, 1), + ) + a = tl.load(a_block_ptr) + b = tl.load(b_block_ptr) + + c = tl.dot(a, b) + + if FLOAT16_OUTPUT: + c = c.to(tl.float16) + + if USE_TMA_EPILOGUE: + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + tl.store(c_block_ptr, c) + else: + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, c) + + +@pytest.mark.parametrize( + 'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE', + itertools.chain(*[[ + # numCTAs = 1, no TMA multicast: + [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], + [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], + [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], + [64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], + [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + # static mask, cluster 4x1 + [256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], + [256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], + # dynamic mask, cluster 2x2 + [128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], + [128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], + # small M, N + [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], + ] for USE_TMA_EPILOGUE in [True, False]])) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") +def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE): + if is_hip() and NUM_CTAS > 1: + pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") + + if (TRANS_A): + a = torch.randn((K, M), device='cuda', dtype=torch.float16).T + else: + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + if (TRANS_B): + b = torch.randn((N, K), device='cuda', dtype=torch.float16).T + else: + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + if OUTPUT_TYPE == "float16": + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + else: + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + + matmul_no_scf_kernel[(1, 1)]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # + num_warps=NUM_WARPS, # + num_ctas=NUM_CTAS, # + FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # + USE_TMA_EPILOGUE=USE_TMA_EPILOGUE) + a_f32 = a.to(torch.float32) + b_f32 = b.to(torch.float32) + golden = torch.matmul(a_f32, b_f32) + torch.set_printoptions(profile="full") + assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) + + +@triton.jit +def matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_wm, stride_wn, # + stride_zm, stride_zn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, # + out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, # + ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, # + DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, # + A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, # + B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, # + W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr, # + Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr # + ): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + block_offset_m = pid_m * BLOCK_M + block_offset_n = pid_n * BLOCK_N + + a_tile_ptr = tl.make_block_ptr( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), + block_shape=(BLOCK_M, BLOCK_K), + order=(A_ORDER_0, A_ORDER_1), + ) + b_tile_ptr = tl.make_block_ptr( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), + block_shape=(BLOCK_K, BLOCK_N), + order=(B_ORDER_0, B_ORDER_1), + ) + # for chain-dot, BLOCK_N must always be equal to N, and each program loads the whole W matrix + w_tile_ptr = tl.make_block_ptr( + base=w_ptr, + shape=(N, N), + strides=(stride_wm, stride_wn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_N), + order=(W_ORDER_0, W_ORDER_1), + ) + z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + offs_m = block_offset_m + tl.arange(0, BLOCK_M) + offs_n = block_offset_n + tl.arange(0, BLOCK_N) + z_ptrs = z_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn + bias_ptrs = bias_ptr + offs_m[:, None] * stride_zm + offs_n[None, :] * stride_zn + mask = (offs_m < M)[:, None] & (offs_n < N)[None, :] + + for k in range(0, K, BLOCK_K): + a = tl.load(a_tile_ptr, boundary_check=(0, 1)) + b = tl.load(b_tile_ptr, boundary_check=(0, 1)) + z += tl.dot(a, b) + a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_K]) + b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_K, 0]) + + z = z.to(out_dtype) + + if ADD_MATRIX: + z += tl.load(bias_ptrs, mask=mask) + if ADD_ROWS: + ZRs = bias_ptr + offs_m * stride_zm + z += tl.load(ZRs)[:, None] + if ADD_COLS: + ZCs = bias_ptr + offs_n * stride_zn + z += tl.load(ZCs)[None, :] + if DO_SOFTMAX: + max = tl.max(z, 1) + z = z - max[:, None] + num = tl.exp(z.to(tl.float32)).to(max.dtype) + den = tl.sum(num, 1) + z = num / den[:, None] + if CHAIN_DOT: + w = tl.load(w_tile_ptr) + z = tl.dot(z.to(w.dtype), w) + z = z.to(out_dtype) + + if USE_TMA_STORE: + z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn), + offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), + order=(Z_ORDER_0, Z_ORDER_1)) + tl.store(z_block_ptr, z, boundary_check=(0, 1)) + else: + tl.store(z_ptrs, z, mask=mask) + + +@pytest.mark.parametrize( + 'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', + [ + # corner shapes + (128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) + for shape_w_c in [ + [4096, 1, 1024, False, False, True], + [2048, 204, 1000, True, False, True], + [4096, 1, 1024, False, False, False], + [2048, 204, 1000, True, False, False], + ] + for out_dtype in ['float16', 'float32'] # + for use_tma_store in [False, True] # + ] + [ + # softmax epilogue + (*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ + [64, 64, 16, 4, 1, 64, 64, 64], + [128, 128, 64, 4, 1, None, None, None], + [16, 16, 64, 4, 1, 16, 16, 64], + [64, 64, 32, 8, 1, 64, 64, 64], + [128, 128, 64, 4, 1, 128, 128, 128], + ] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for + trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] + ] + [ + # loop over epilogues besides of softmax + (*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ + [64, 64, 16, 4, 1, 128, 128, 64], + *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], + # for chain-dot + [128, 128, 64, 4, 1, None, None, None], + [64, 64, 16, 4, 1, None, None, None], + # small BLOCK_M and BLOCK_K + [16, 16, 64, 4, 1, 128, 128, 64], + *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], + # repeat + [64, 64, 32, 8, 1, 128, 256, 64], + [64, 64, 16, 8, 2, 128, 128, 64], + # irregular shape + [128, 128, 64, 4, 1, 500, 200, 128], + [128, 128, 64, 4, 2, 513, 193, 192], + ] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' + ] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in + [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( + epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) + ] + [ + # loop over tile shapes and transpose combinations + (*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ + [64, 64, 32, 4, 1, 128, 256, 64], + [128, 128, 16, 4, 4, 512, 256, 64], + [128, 256, 32, 4, 8, 256, 256, 192], + [512, 256, 32, 4, 8, 1024, 256, 192], + # BLOCK_K >= 128 + [64, 128, 128, 4, 1, 512, 256, 256], + [128, 128, 128, 4, 1, 256, 256, 192], + [128, 128, 128, 4, 2, 256, 256, 192], + # small BLOCK_M and BLOCK_K + [16, 32, 32, 4, 1, 128, 256, 64], + [32, 32, 16, 4, 1, 256, 256, 192], + [16, 32, 64, 4, 4, 512, 256, 64], + ] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in + [False, True] for trans_output in [False, True] for num_stages in [3] + ] + [ + # loop over instr shapes & pipeline stages + (64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) + for n in [16, 32, 64, 128, 256] + for trans_output in [False] + for out_dtype in ['float32'] + for use_tma_store in [False] + for num_stages in [2, 4, 5, 7] + ] + [ + # irregular shapes + (*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ + [128, 128, 64, 4, 1], + [256, 128, 64, 4, 2], + [128, 128, 128, 4, 2], + ] for shape in [ + [512, 360, 1024], + [360, 4096, 512], + ] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in + [3, 4] + ]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") +def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, + out_dtype, USE_TMA_STORE, NUM_STAGES): + if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ + '16-32-64-4-4-512-256-64-True-False', + '16-32-64-4-4-512-256-64-True-True', + '16-32-64-4-4-512-256-64-False-False', + '16-32-64-4-4-512-256-64-False-True', + ]: + pytest.skip('shapePerCTA[1] < 16 not supported') + + if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ + '16-32-64-4-1-256-256-256-False', + '16-32-64-4-2-256-256-256-False', + '16-32-64-4-2-256-256-256-True', + '16-32-64-8-2-256-256-256-False', + '16-32-64-8-2-256-256-256-True', + ]: + pytest.skip('Known legacy issue, ldmatrix can only support x4') + + if is_hip() and NUM_CTAS > 1: + pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") + + if epilogue == 'add-rows' and NUM_CTAS > 1: + pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') + + M = BLOCK_M if M is None else M + N = BLOCK_N if N is None else N + K = BLOCK_K if K is None else K + + if (TRANS_A): + a = torch.randn((K, M), device='cuda', dtype=torch.float16).T + a_order = [0, 1] + else: + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + a_order = [1, 0] + + if (TRANS_B): + b = torch.randn((N, K), device='cuda', dtype=torch.float16).T + b_order = [0, 1] + else: + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + b_order = [1, 0] + + if out_dtype == 'float16' and epilogue != 'softmax': + # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will + # fail with the following error: 'llvm.fmul' op requires the same type + # for all operands and results + out_dtype = tl.float16 + torch_out_dtype = torch.float16 + else: + out_dtype = tl.float32 + torch_out_dtype = torch.float32 + + # avoid out of memory + if epilogue in ['add-matrix', 'add-rows', 'add-cols']: + if (TRANS_OUTPUT): + bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T + else: + bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) + else: + bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) + + # for chain-dot only + w = torch.randn((N, N), device='cuda', dtype=torch.float16).T + w_order = [0, 1] + + if (TRANS_OUTPUT): + z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T + z_order = [0, 1] + else: + z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) + z_order = [1, 0] + + # torch result + a_f32 = a.to(torch.float32) + b_f32 = b.to(torch.float32) + dot = torch.matmul(a_f32, b_f32) + + def process_epilogue(d, bias, w, epilogue): + if epilogue == 'add-matrix': + ref = d + bias + elif epilogue == 'add-rows': + ref = d + bias[:, 0][:, None] + elif epilogue == 'add-cols': + ref = d + bias[0, :][None, :] + elif epilogue == 'softmax': + num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) + denom = torch.sum(num, dim=-1, keepdims=True) + ref = num / denom + # ref = torch.softmax(d, 1) + elif epilogue == 'chain-dot': + ref = torch.matmul(d, w.to(torch.float32)) + else: + ref = d + return ref + + golden = process_epilogue(dot, bias, w, epilogue) + + def grid(META): + return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) + + pgm = matmul_kernel[grid]( + a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_wm=w.stride(0), stride_wn=w.stride(1), # + stride_zm=z.stride(0), stride_zn=z.stride(1), # + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # + out_dtype=out_dtype, # + USE_TMA_STORE=USE_TMA_STORE, # + ADD_MATRIX=epilogue == 'add-matrix', # + ADD_ROWS=epilogue == 'add-rows', # + ADD_COLS=epilogue == 'add-cols', # + DO_SOFTMAX=epilogue == 'softmax', # + CHAIN_DOT=epilogue == 'chain-dot', # + A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # + B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # + W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # + Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # + num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) + + torch.set_printoptions(profile="full") + golden = torch.nn.functional.normalize(golden) + z = torch.nn.functional.normalize(z) + assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) + + # check is cuda backend specific + if is_hip(): + return + + disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() + if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: + ptx = pgm.asm['ptx'] + wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) + assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) diff --git a/third_party/iluvatar/python/test/unit/hopper/test_gemm_fusion.py b/third_party/iluvatar/python/test/unit/hopper/test_gemm_fusion.py new file mode 100644 index 000000000..3b07a1c8b --- /dev/null +++ b/third_party/iluvatar/python/test/unit/hopper/test_gemm_fusion.py @@ -0,0 +1,175 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def gemm_fusion_kernel(A, B, C, E, # + M, N, K, # + stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr): + pid = tl.program_id(0) + + a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_K), order=(1, 0)) + c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_K), order=(1, 0)) + e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + + acc_e = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32) + a = tl.load(a_tile_ptr) + for i in range(0, N, BLOCK_N): + b = tl.load(b_tile_ptr) + o_ab = tl.dot(a, tl.trans(b)) + c = tl.load(c_tile_ptr) + o_ab = o_ab.to(tl.float16) + acc_e += tl.dot(o_ab, c) + b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_N, 0]) + c_tile_ptr = tl.advance(c_tile_ptr, [BLOCK_N, 0]) + + acc_e = acc_e.to(tl.float16) + tl.store(e_tile_ptr, acc_e) + + +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="not passed on ampere") +def test_gemm_fusion(): + M, N, K = 4096, 4096, 64 + BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64 + A = torch.empty((M, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + B = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + C = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + E = torch.empty((M, K), dtype=torch.float16, device='cuda') + ref_out = torch.matmul(torch.matmul(A, B.T), C) + num_warps = 4 + grid = (triton.cdiv(M, BLOCK_M), 1) + gemm_fusion_kernel[grid]( + A, B, C, E, M, N, K, # + A.stride(0), A.stride(1), # + B.stride(0), B.stride(1), # + C.stride(0), C.stride(1), # + E.stride(0), E.stride(1), # + BLOCK_M, BLOCK_N, BLOCK_K, # + num_warps=num_warps) + + torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0) + + +@triton.jit +def batched_gemm_fusion(Q, K, V, Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, NH, N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + q_tile_ptr = tl.make_block_ptr( + base=Q, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_qz, stride_qh, stride_qm, stride_qk), + offsets=(off_hz // NH, off_hz % NH, start_m, 0), + block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) + k_tile_ptr = tl.make_block_ptr( + base=K, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_kz, stride_kh, stride_kn, stride_kk), + offsets=(off_hz // NH, off_hz % NH, 0, 0), + block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) + v_tile_ptr = tl.make_block_ptr( + base=V, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_vz, stride_vh, stride_vk, stride_vn), + offsets=(off_hz // NH, off_hz % NH, 0, 0), + block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) + o_tile_ptr = tl.make_block_ptr( + base=Out, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_oz, stride_oh, stride_om, stride_on), + offsets=(off_hz // NH, off_hz % NH, start_m, 0), + block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) + + q = tl.load(q_tile_ptr, boundary_check=(0, 1, 2, 3)) + q = tl.reshape(q, (BLOCK_M, BLOCK_DMODEL), can_reorder=True) + for i in range(0, N_CTX, BLOCK_N): + k = tl.load(k_tile_ptr, boundary_check=(0, 1, 2, 3)) + k = tl.reshape(k, (BLOCK_N, BLOCK_DMODEL), can_reorder=True) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + + p = qk.to(tl.float16) + v = tl.load(v_tile_ptr, boundary_check=(0, 1, 2, 3)) + v = tl.reshape(v, (BLOCK_N, BLOCK_DMODEL), can_reorder=True) + acc += tl.dot(p, v) + + k_tile_ptr = tl.advance(k_tile_ptr, [0, 0, BLOCK_N, 0]) + v_tile_ptr = tl.advance(v_tile_ptr, [0, 0, BLOCK_N, 0]) + + acc = tl.reshape(acc, (1, 1, BLOCK_M, BLOCK_DMODEL), can_reorder=True) + acc = acc.to(tl.float16) + tl.store(o_tile_ptr, acc) + + +@pytest.mark.skip(reason="don't support 4d across stack, left for future") +def test_batched_gemm_fusion(): + Z = 4 + NH = 48 + H = 64 + N_CTX = 2048 + BLOCK_M, BLOCK_N, BLOCK_DMODEL = 128, 128, H + torch.manual_seed(20) + A = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + B = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + C = torch.empty((Z, NH, N_CTX, H), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + E = torch.empty_like(A) + BT = B.transpose(-1, -2) + ref_out = torch.matmul(torch.matmul(A, BT), C) + num_warps = 4 + grid = (triton.cdiv(N_CTX, BLOCK_M), B * NH) + batched_gemm_fusion[grid]( + A, B, C, E, # + A.stride(0), A.stride(1), A.stride(2), A.stride(3), # + B.stride(0), B.stride(1), B.stride(2), B.stride(3), # + C.stride(0), C.stride(1), C.stride(2), C.stride(3), # + E.stride(0), E.stride(1), E.stride(2), E.stride(3), # + Z, NH, N_CTX, # + BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps) + + torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0) diff --git a/third_party/iluvatar/python/test/unit/hopper/test_mixed_io.py b/third_party/iluvatar/python/test/unit/hopper/test_mixed_io.py new file mode 100644 index 000000000..68ee474a4 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/hopper/test_mixed_io.py @@ -0,0 +1,81 @@ +import pytest +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + +dtype_mapping = { + 'float16': torch.float16, + 'float32': torch.float32, +} + + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x_block_ptr = tl.make_block_ptr(base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + x = tl.load(x_block_ptr, boundary_check=(0, ), padding_option='zero') + + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + +@pytest.mark.parametrize('SIZE,BLOCK_SIZE,dtype_str', + [(98432, 1024, dtype_str) for dtype_str in ['float16', 'float32']]) +def test_add(SIZE, BLOCK_SIZE, dtype_str): + dtype = dtype_mapping[dtype_str] + output = torch.empty(SIZE, device='cuda', dtype=dtype) + x = torch.randn(SIZE, device='cuda', dtype=dtype) + y = torch.randn(SIZE, device='cuda', dtype=dtype) + + def grid(meta): + return (triton.cdiv(SIZE, meta['BLOCK_SIZE']), ) + + add_kernel[grid](x, y, output, SIZE, BLOCK_SIZE=BLOCK_SIZE) + + output_torch = x + y + torch.set_printoptions(profile='full') + assert_close(output, output_torch, rtol=1e-2, atol=1e-3, check_dtype=False) + + +@triton.jit +def load_reduce_kernel( + x_ptr, + y_ptr, + stride_xm, + stride_xn, + stride_y, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + x_ptr = tl.make_block_ptr(base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + x = tl.load(x_ptr) + y = tl.max(x, axis=1) + tl.store(y_ptr + tl.arange(0, BLOCK_M), y) + + +@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str', [(128, 64, dtype_str) for dtype_str in ['float16']]) +def test_load_reduce(BLOCK_M, BLOCK_N, dtype_str): + dtype = dtype_mapping[dtype_str] + x = torch.randn((BLOCK_M, BLOCK_N), device='cuda', dtype=dtype) + y = torch.empty((BLOCK_M, ), device='cuda', dtype=dtype) + + load_reduce_kernel[(1, )](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N) + + golden = x.max(dim=1)[0] + torch.set_printoptions(profile='full') + assert_close(y, golden, rtol=1e-2, atol=1e-3, check_dtype=False) diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py b/third_party/iluvatar/python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py similarity index 100% rename from python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py rename to third_party/iluvatar/python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py b/third_party/iluvatar/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py similarity index 100% rename from python/test/unit/hopper/test_persistent_warp_specialized_gemm.py rename to third_party/iluvatar/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py diff --git a/third_party/iluvatar/python/test/unit/hopper/test_tma_store_gemm.py b/third_party/iluvatar/python/test/unit/hopper/test_tma_store_gemm.py new file mode 100644 index 000000000..b2fc3e874 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/hopper/test_tma_store_gemm.py @@ -0,0 +1,91 @@ +# Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, +# publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import pytest +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +@triton.jit +def matmul_tma_load_store( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + OUTPUT_F16: tl.constexpr # +): + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + a = tl.load(a_block_ptr) + b = tl.load(b_block_ptr) + + c = tl.dot(a, b) + if OUTPUT_F16: + c = c.to(tl.float16) + + tl.store(c_block_ptr, c) + + +@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_F16', [ + [64, 64, 16, 1, 4, False, True, False], + [64, 64, 16, 1, 4, False, True, True], + [128, 64, 32, 1, 4, False, True, False], + [128, 64, 32, 1, 4, False, True, True], + [64, 128, 32, 1, 4, False, True, False], + [64, 128, 32, 1, 4, False, True, True], + [128, 128, 64, 1, 4, False, True, False], + [128, 128, 64, 1, 4, False, True, True], +]) +def test_tma_load_store(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_F16): + if (TRANS_A): + a = torch.randn((K, M), device='cuda', dtype=torch.float16).T + else: + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + if (TRANS_B): + b = torch.randn((N, K), device='cuda', dtype=torch.float16).T + else: + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + if OUTPUT_F16: + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + matmul_tma_load_store[(1, 1)]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # + num_warps=NUM_WARPS, num_ctas=NUM_CTAS, # + OUTPUT_F16=OUTPUT_F16) + golden = torch.matmul(a, b) + torch.set_printoptions(profile="full") + assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) diff --git a/python/test/unit/language/assert_helper.py b/third_party/iluvatar/python/test/unit/language/assert_helper.py similarity index 100% rename from python/test/unit/language/assert_helper.py rename to third_party/iluvatar/python/test/unit/language/assert_helper.py diff --git a/python/test/unit/operators/conftest.py b/third_party/iluvatar/python/test/unit/language/conftest.py similarity index 100% rename from python/test/unit/operators/conftest.py rename to third_party/iluvatar/python/test/unit/language/conftest.py diff --git a/third_party/iluvatar/python/test/unit/language/print_helper.py b/third_party/iluvatar/python/test/unit/language/print_helper.py new file mode 100644 index 000000000..e032792f3 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/language/print_helper.py @@ -0,0 +1,125 @@ +import sys +import uuid + +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +def get_current_target_warp_size(): + return triton.runtime.driver.active.get_current_target().warp_size + + +@triton.jit +def kernel_device_print(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_print("x: ", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_device_print_hex(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_print("x: ", x, hex=True) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_print(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # Triton should add a space after this prefix. + print("x:", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_device_print_large( + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32) + # Triton should change this prefix to "x: ". + tl.device_print("x ", x) + + +@triton.jit +def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK, ), 1, tl.int32) + print("", x, y) + + +@triton.jit +def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK, ), 1, tl.int32) + tl.device_print("", x, y) + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit +def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr): + # This function takes an extra value as a tl.constexpr so this kernel is not + # cached. This way the static print is run every time. + x = tl.load(X + tl.arange(0, BLOCK)) + tl.static_print("", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_no_arg_print(): + print("", tl.program_id(0)) + + +@triton.jit +def kernel_print_no_arg(): + print("no arg") + + +@triton.jit +def kernel_print_pointer(X, Y, BLOCK: tl.constexpr): + tl.device_print("ptr ", X + tl.arange(0, BLOCK)) + + +def test_print(func: str, data_type: str): + N = 128 # This value should match with test_print in test_subprocess.py. + # TODO(antiagainst): Currently the warp count is chosen to make sure wedon't have multiple + # threads printing duplicated messages due to broadcasting. Improve print op lowering logic + # to filter out duplicated data range. + num_warps = N // get_current_target_warp_size() + + x = torch.arange(0, N, dtype=torch.int32, device='cuda').to(getattr(torch, data_type)) + y = torch.zeros((N, ), dtype=x.dtype, device="cuda") + if func == "device_print": + kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "print": + kernel_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_large": + kernel_device_print_large[(1, 2)](BLOCK_M=64, num_warps=num_warps, BLOCK_N=N) + elif func == "print_multiple_args": + kernel_print_multiple_args[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_multiple_args": + kernel_device_print_multiple_args[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "static_print": + kernel_static_print[(1, )](x, y, num_warps=num_warps, BLOCK=N, PLACEHOLDER=uuid.uuid4()) + elif func == "no_arg_print": + kernel_no_arg_print[(1, )](num_warps=num_warps) + elif func == "print_no_arg": + kernel_print_no_arg[(1, )](num_warps=num_warps) + elif func == "device_print_hex": + kernel_device_print_hex[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_pointer": + kernel_print_pointer[(1, )](x, y, num_warps=num_warps, BLOCK=N) + else: + assert f"Unknown kernel: {func}" + + if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \ + func != "print_multiple_args" and func != "device_print_multiple_args" and \ + func != "device_print_pointer": + assert_close(y, x) + + +if __name__ == "__main__": + test_print(sys.argv[1], sys.argv[2]) diff --git a/third_party/iluvatar/python/test/unit/language/test_annotations.py b/third_party/iluvatar/python/test/unit/language/test_annotations.py new file mode 100644 index 000000000..0c1f065a1 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/language/test_annotations.py @@ -0,0 +1,49 @@ +from __future__ import annotations +import torch +import triton +import triton.language as tl +import pytest + + +def annotated_function(return_type=None, **arg_types): + """A decorator to add annotations to a function.""" + + def decorator(func): + func.__annotations__ = {**arg_types, 'return': return_type} + return func + + return decorator + + +# Test integer annotations +@pytest.mark.parametrize(("signed", "width"), [ + (signed, width) for signed in [False, True]\ + for width in [8, 16, 32, 64] +] + [(False, 1)] + ) +def test_int_annotation(signed, width, device): + + @triton.jit + @annotated_function(X=torch.tensor, v=f"tl.{'' if signed else 'u'}int{width}") + def _kernel(X, v): + tl.store(X, v) + + h = _kernel[(1, )](torch.empty(1, device=device), 3) + pfx = 'si' if signed else 'ui' + assert f'%arg1: i{width}' in h.asm["ttir"] + assert f'arith.{pfx}tofp' in h.asm["ttir"] + + +# Test that unknown annotations do not emit an error +def test_unknown_annotation(device): + + @triton.jit + def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr): + pass + + x = torch.empty(1, device=device) + _kernel[(1, )](x, x.shape[0], 32) + try: + _kernel[(1, )](x.shape[0], x.shape[0], 32) + except AttributeError: + pass diff --git a/third_party/iluvatar/python/test/unit/language/test_block_pointer.py b/third_party/iluvatar/python/test/unit/language/test_block_pointer.py new file mode 100644 index 000000000..c932131c9 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/language/test_block_pointer.py @@ -0,0 +1,100 @@ +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr): + pid = tl.program_id(0) + # We only copy half of the data to see if the padding works + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option) + tl.store(b_block_ptr, a, boundary_check=(0, )) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtypes_str, n, padding_option", [ # + (dtypes_str, n, padding) + for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("float16", "float16"), ("int16", "float16")) + for n in (64, 128, 256, 512, 1024) + for padding in ("zero", "nan") # +]) +def test_block_copy(dtypes_str, n, padding_option, device): + src_dtype_str = dtypes_str[0] + dst_dtype_str = dtypes_str[0] + src_dtype = getattr(torch, src_dtype_str) + dst_dtype = getattr(torch, dst_dtype_str) + if src_dtype_str in ("bool", "int16"): + if padding_option == "nan": + pytest.skip("Padding with NaN is not supported for integer types") + a = torch.randint(0, 2, (n, ), device=device, dtype=src_dtype) + else: + a = torch.randn((n, ), device=device, dtype=src_dtype) + b = torch.zeros((n, ), device=device, dtype=dst_dtype) + + grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), ) + block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option) + a.to(dst_dtype) + assert torch.all(a[0:n // 2] == b[0:n // 2]) + if padding_option == "zero": + assert torch.all(b[n // 2:n] == 0) + else: + assert torch.all(torch.isnan(b[n // 2:n])) + + +@triton.jit +def matmul_no_scf_with_advance_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr # +): + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), order=(1, 0)) + # Below two lines are just for testing negative offsets for the `advance` API, which could be removed + a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K)) + a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K)) + a = tl.load(a_block_ptr, boundary_check=(1, ), padding_option="zero") + b = tl.load(b_block_ptr, boundary_check=(0, ), padding_option="zero") + + c = tl.dot(a, b) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, c) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, num_warps", [ # + (shape, num_warps) for shape in [ + [64, 64, 16], + [64, 64, 32], + [64, 64, 64], + ] for num_warps in [4, 8] +]) +def test_block_ptr_matmul_no_scf(shape, num_warps, device): + m, n, k = shape + a = torch.randn((m, k), device=device, dtype=torch.float16) + b = torch.randn((k, n), device=device, dtype=torch.float16) + c = torch.empty((m, n), device=device, dtype=torch.float32) + + grid = lambda META: (1, ) + matmul_no_scf_with_advance_kernel[grid]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=m, N=n, K=k, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, # + num_warps=num_warps) + golden = torch.matmul(a, b) + torch.testing.assert_close(c, golden, check_dtype=False) diff --git a/third_party/iluvatar/python/test/unit/language/test_compile_errors.py b/third_party/iluvatar/python/test/unit/language/test_compile_errors.py new file mode 100644 index 000000000..0531f8ebc --- /dev/null +++ b/third_party/iluvatar/python/test/unit/language/test_compile_errors.py @@ -0,0 +1,304 @@ +import pytest + +import triton +import triton.language as tl +from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure +import traceback + + +def test_err_undefined_variable(): + + @triton.jit + def kernel(): + a += 1 # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert "is not defined" in str(e.value), "error should mention the undefined variable" + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_binary_operator(): + + @triton.jit + def kernel(): + 0 + "a" + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert "at 2:4:" in str(e.value), "error should point to the 0" + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_static_assert(): + + @triton.jit + def kernel(): + tl.static_assert(isinstance(0, tl.tensor)) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert isinstance(e.value, CompileTimeAssertionFailure) + assert e.value.__cause__ is None + assert "at 2:4:" in str(e.value), "error should point to the static_assert call" + assert "" not in str(e.value) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_unary_op(): + # Currently Triton can't evaluate `not` of a tuple at compile time. That's + # ok, but the error message needs to point to the correct spot. + @triton.jit + def kernel(): + not (0, 0) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert e.value.__cause__ is None + assert "at 2:4:" in str(e.value), "error should point to the `not`" + assert "" not in str(e.value) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_binary_op(): + + @triton.jit + def kernel(): + 1.0 << 1 + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert "at 2:4:" in str(e.value), "error should point to the 1.0" + assert "" not in str(e.value) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +# This has to be defined as a top-level function; jit'ed functions can't call +# nested functions. +@triton.jit +def nested_call(): + xyz # noqa + + +def test_err_in_nested_call(): + + @triton.jit + def kernel(): + # this is a comment to push nested_call() onto the next line + nested_call() + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + inner = e.value.__cause__ + outer = e.value + assert "at 2:4:" in str(inner), "error should point to xyz" + assert "" not in str(inner) + + assert "at 3:4" in str(outer), "error should point to the nested_call" + assert "" not in str(outer) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_builtin(): + + # The root error here comes from core.py. Make sure the stacktrace reflects + # this. + @triton.jit + def kernel(): + tl.expand_dims(None, -1) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + inner = e.value.__cause__ + outer = e.value + assert "/core.py" in '\n'.join(traceback.format_tb(inner.__traceback__)), "error should point inside core.py" + + assert "at 2:4:" in str(outer), "error should point to expand_dims call" + assert "" not in str(outer) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +@triton.jit +def two_returns(): + return tl.arange(0, 4) + return tl.arange(0, 8) + + +def test_two_returns_no_err(): + # This program is valid; `a` has shape (10,). + @triton.jit + def kernel(): + a = two_returns() + a + tl.arange(0, 4) # only works if we took the first return + + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +@triton.jit +def returns_branched_on_constexpr(N: tl.constexpr): + if N == 0: + return tl.arange(0, 4) + # Ideally this would work even without the `else`, but we're not that smart + # yet. + else: + return tl.arange(0, 8) + + +def test_returns_branched_on_constexpr(): + + @triton.jit + def kernel1(N: tl.constexpr): + a = returns_branched_on_constexpr(N) + a + tl.arange(0, 4) + + triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0})) + + @triton.jit + def kernel2(N: tl.constexpr): + a = returns_branched_on_constexpr(N) + a + tl.arange(0, 8) + + triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1})) + + +@triton.jit +def returns_branched_on_non_constexpr(N: int): + if N == 0: + return tl.arange(0, 4) + else: + return tl.arange(0, 8) + + +def test_returns_branched_on_non_constexpr(): + + @triton.jit + def kernel(N: int): + returns_branched_on_non_constexpr(N) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={})) + + try: + assert "at 2:4:" in str(e.value), "error should point to the function call" + assert "at 5:8:" in str(e.value.__cause__), "error should point to the second `return`" + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_power_of_two_shapes(): + + @triton.jit + def kernel(): + tl.arange(2, 7) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + assert str(e.value.__cause__) == "arange's range must be a power of 2" + + +def test_power_of_two_shapes_2(): + + @triton.jit + def kernel(): + tl.full((33, ), 0, dtype=tl.int64) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + assert str(e.value.__cause__) == "Shape element 0 must be a power of 2" + + +def test_captured_var_access(): + + CAPTURED = 42 + + @triton.jit + def kernel(): + a = CAPTURED # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + assert "CAPTURED is not defined" in str(e.value) + + +GLOBAL = 42 + + +def test_global_var_access(): + + @triton.jit + def kernel(): + a = GLOBAL # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + assert "global variable" in str(e.value) + + +CONSTEXPR_ANNOTATED_GLOBAL: tl.constexpr = 42 + + +def test_constexpr_annotated_global_var_access(): + + @triton.jit + def kernel(): + a = CONSTEXPR_ANNOTATED_GLOBAL # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +CONSTEXPR_GLOBAL = tl.constexpr(42) + + +def test_constexpr_global_var_access(): + + @triton.jit + def kernel(): + a = CONSTEXPR_GLOBAL # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +TYPE_ALIAS = tl.pointer_type(tl.int32) + + +def test_global_type_alias_access(): + + @triton.jit + def kernel(): + a = TYPE_ALIAS # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +def test_global_access_in_fn_default_arg(): + + @triton.jit + def kernel(a=GLOBAL): + pass + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={0: "i32"}, constants={})) diff --git a/third_party/iluvatar/python/test/unit/language/test_conversions.py b/third_party/iluvatar/python/test/unit/language/test_conversions.py new file mode 100644 index 000000000..113469926 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/language/test_conversions.py @@ -0,0 +1,354 @@ +# fmt: off + + +import os +import numpy as np +import torch +import pytest +import triton +import triton.language as tl + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + +def is_cuda(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda" + +def is_hip(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip" + +def is_on_mi300(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942') + +def matching_int(dtype): + if dtype.primitive_bitwidth == 8: + return torch.int8 + elif dtype.primitive_bitwidth == 16: + return torch.int16 + elif dtype.primitive_bitwidth == 32: + return torch.int32 + elif dtype.primitive_bitwidth == 64: + return torch.int64 + else: + raise ValueError('unsupported number of bits') + +@triton.jit +def type_convert_triton(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr): + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(src + idxs) + y = x.to(dst.dtype.element_ty, fp_downcast_rounding=rounding) + tl.store(dst + idxs, y) + + +def launch_type_convert_triton(src, src_dtype, dst_dtype, device, rounding=None, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device) + type_convert_triton[(src.shape[0] // BLOCK_SIZE,)](triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE) + return dst + + +@triton.jit +def exhaustive_populate(dst, offset, BLOCK_SIZE : tl.constexpr, force_odd : tl.constexpr, output_bits : tl.constexpr, max_repr : tl.constexpr): + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + vals = (idxs + offset).to(tl.uint32) + + # pseudorandom permutation: + multiplier = vals << 1 + multiplier += 3511 + vals *= multiplier + + if force_odd: + vals *= 2 + vals += 1 + + if (output_bits == 8): + vals &= 0xff + avals = vals & 0x7f + elif (output_bits == 16): + vals &= 0xffff + avals = vals & 0x7fff + elif (output_bits == 32): + avals = vals & 0x7fffffff + + vals = tl.where(avals <= max_repr, vals, 0) + + if (output_bits == 8): + vals = vals.to(tl.uint8) + elif (output_bits == 16): + vals = vals.to(tl.uint16) + + vals = vals.to(dst.dtype.element_ty, bitcast=True) + tl.store(dst + idxs, vals) + + +def launch_exhaustive_populate(dst_dtype, offset, numel, force_odd, output_bits, max_repr, device, BLOCK_SIZE=4096): + + assert(numel % BLOCK_SIZE == 0) + dst = torch.empty((numel,), dtype=matching_int(dst_dtype), device=device) + exhaustive_populate[(numel // BLOCK_SIZE,)](triton.reinterpret(dst, dst_dtype), offset, BLOCK_SIZE, force_odd, output_bits, max_repr) + # 0x80 in float8e4b8 or float8e5b16 represents inf/nan. We don't need to have that + # as input to the conversion kernels. + if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16: + dst = torch.where(dst == 0x80, 0, dst) + return dst + + +@triton.jit +def arbitrary_fp32_downcast(x, rounding : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + tl.static_assert(x.dtype == tl.float32, "input must be float32") + numbits_dst : tl.constexpr = 1 + exponent_bits + mantissa_bits + tl.static_assert((numbits_dst == 8) or (numbits_dst == 16), "numbits_dst must be 8 or 16") + + x = x.to(tl.uint32, bitcast=True) + + mantissa = (x & 0x7fffff) + exponent = ((x >> 23) & 0xff).to(tl.int32) + mantissa = tl.where(exponent == 0, mantissa, mantissa + 0x800000).to(tl.int32) + exponent = tl.where(exponent == 0, exponent, exponent - 1) + + sign = (x >> 31) + + exponent = exponent + exponent_bias - 127 + adjustment : tl.constexpr = 0.5 ** (23 - mantissa_bits) + mantissa = mantissa.to(tl.float32) * adjustment + + # make exponent nonnegative: + mantissa = tl.where(exponent > -16, mantissa, 0.0) # destination has fewer than 16 mantissa bits, so safe + exponent = tl.where(exponent > -16, exponent, 0) + mantissa = tl.where(exponent > -8, mantissa, mantissa * 0.00390625) + exponent = tl.where(exponent > -8, exponent, exponent + 8) + mantissa = tl.where(exponent > -4, mantissa, mantissa * 0.0625) + exponent = tl.where(exponent > -4, exponent, exponent + 4) + mantissa = tl.where(exponent > -2, mantissa, mantissa * 0.25) + exponent = tl.where(exponent > -2, exponent, exponent + 2) + mantissa = tl.where(exponent > -1, mantissa, mantissa * 0.5) + exponent = tl.where(exponent > -1, exponent, exponent + 1) + + if rounding == 'rtne': + # Bring the value to the range [2 ** 23, 2 ** 24] + # where the representable floats map exactly to integers. + # Addition has RTNE semantics. + mantissa += 0x800000 + # Bring the value back to the original range. + mantissa -= 0x800000 + mantissa = mantissa.to(tl.int32) + elif rounding == 'rtz': + mantissa = mantissa.to(tl.int32) + else: + raise ValueError('unrecognized rounding mode') + + # Reassemble output floating-point representation: + exponent = exponent.to(tl.uint32) + y = (sign << (exponent_bits + mantissa_bits)) + (exponent << mantissa_bits) + mantissa + if numbits_dst == 8: + y = y.to(tl.uint8) + elif numbits_dst == 16: + y = y.to(tl.uint16) + return y + + +@triton.jit +def downcast_emulated(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + tl.static_assert(src.dtype.element_ty == tl.float32, "src dtype must be float32") + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + idxs) + y = arbitrary_fp32_downcast(x, rounding, exponent_bits, mantissa_bits, exponent_bias) + y = y.to(dst.dtype.element_ty, bitcast=True) + tl.store(dst + idxs, y) + + +def launch_downcast_emulated(src, src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device) + downcast_emulated[(src.shape[0] // BLOCK_SIZE,)]( + triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias) + # 0x80 in float8e4b8 or float8e5b16 represents inf/nan. downcast_emulated kernel will + # convert -0. in higher precision to 0x80 and thus need to fix the result to 0. + if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16: + dst = torch.where(dst == 0x80, 0, dst) + return dst + + +@triton.jit +def upcast_emulated(src, dst, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + exponent_compensator : tl.constexpr = 2.0 ** (127 - exponent_bias) + + numbits_src : tl.constexpr = 1 + exponent_bits + mantissa_bits + tl.static_assert((numbits_src == 8) or (numbits_src == 16), "numbits_src must be 8 or 16") + tl.static_assert(dst.dtype.element_ty == tl.float32, "dst dtype must be float32") + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(src + idxs) + + if numbits_src == 8: + x = x.to(tl.uint8, bitcast=True) + elif numbits_src == 16: + x = x.to(tl.uint16, bitcast=True) + + x = x.to(tl.uint32) + + mantissa_mask : tl.constexpr = (1 << mantissa_bits) - 1 + exponent_mask : tl.constexpr = (1 << exponent_bits) - 1 + + mantissa = x & mantissa_mask + exponent = (x >> mantissa_bits) & exponent_mask + sign = (x >> (numbits_src - 1)) + + y = (sign << 31) | (exponent << 23) | (mantissa << (23 - mantissa_bits)) + y = y.to(tl.float32, bitcast=True) + y = y * exponent_compensator + + tl.store(dst + idxs, y) + + +def launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=torch.int32, device=device) + upcast_emulated[(src.shape[0] // BLOCK_SIZE,)](src, triton.reinterpret(dst, tl.float32), BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias) + return dst + + +def downcast_test(src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, max_repr, offset, device): + + src = launch_exhaustive_populate(src_dtype, offset << 24, 2**24, False, src_dtype.primitive_bitwidth, max_repr, device) + dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device, rounding=rounding) + src = launch_type_convert_triton(src, src_dtype, tl.float32, device=device) + + dst2 = launch_downcast_emulated(src, tl.float32, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device=device) + + dst = launch_upcast_emulated(dst, exponent_bits, mantissa_bits, exponent_bias, device=device) + dst2 = launch_upcast_emulated(dst2, exponent_bits, mantissa_bits, exponent_bias, device=device) + + if not (torch.equal(dst, dst2)): + print('Error!!!') + + dst = dst.cpu().detach().numpy() + dst2 = dst2.cpu().detach().numpy() + src = src.cpu().detach().numpy() + + print(src[dst != dst2][0]) + print(dst[dst != dst2][0]) + print(dst2[dst != dst2][0]) + print(hex(src.view(np.uint32)[dst != dst2][0])) + print(hex(dst.view(np.uint32)[dst != dst2][0])) + print(hex(dst2.view(np.uint32)[dst != dst2][0])) + print('') + raise ValueError('%d elements mismatch' % (dst != dst2).sum()) + + +def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bias, max_repr, device): + + numbits_src = exponent_bits + mantissa_bits + 1 + + src = launch_exhaustive_populate(src_dtype, 0, 65536, False, numbits_src, max_repr, device=device) + + dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device) + dst = launch_type_convert_triton(dst, dst_dtype, tl.float32, device=device) + + dst2 = launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device=device) + + assert(torch.equal(dst, dst2)) + + +@pytest.mark.parametrize("src_dtype, dst_dtype", [ + ('float16', 'float32'), + ('bfloat16', 'float32'), + + # ('float8e5', 'float16'), + # ('float8e5', 'bfloat16'), + # ('float8e5', 'float32'), + + # ('float8e4b15', 'float16'), + # ('float8e4b15', 'bfloat16'), # Unsupported conversion from f8E4M3B11FNUZ to bf16 + # ('float8e4b15', 'float32'), + + # ('float8e4nv', 'float16'), + # ('float8e4nv', 'bfloat16'), + # ('float8e4nv', 'float32'), + + # ('float8e4b8', 'float32'), + # ('float8e4b8', 'float16'), + + # ('float8e5b16', 'float32'), + # ('float8e5b16', 'float16'), +]) +def test_typeconvert_upcast(src_dtype, dst_dtype, device): + + if src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip("float8e4nv upcast tests only supported on NVGPU with compute capability 9.0+") + + if src_dtype in ('float8e4nv', 'float8e4b15') and is_hip(): + pytest.skip(f"{src_dtype} upcast tests not supported on ROCm") + + if src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_on_mi300()): + pytest.skip("{src_dtype} upcast tests only supported on AMDGPU MI300") + + # dtype : (exponent_bits, mantissa_bits, exponent_bias, max_repr) + stuff = { + 'float8e4b15': (4, 3, 15, 0x7e), + 'float8e4nv': (4, 3, 7, 0x7e), + 'float8e5': (5, 2, 15, 0x7b), + 'float8e4b8': (4, 3, 8, 0x7f), + 'float8e5b16': (5, 2, 16, 0x7f), + 'float16': (5, 10, 15, 0x7bff), + 'bfloat16': (8, 7, 127, 0x7f7f), + }[src_dtype] + + upcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), *stuff, device=device) + +@pytest.mark.parametrize("src_dtype, dst_dtype, rounding, max_repr", [ + ('float32', 'float16', 'rtne', 0x477fe000), + ('float32', 'float16', 'rtz', 0x477fe000), + ('float32', 'bfloat16', 'rtne', 0x7f7f0000), + # ('float32', 'bfloat16', 'rtz', 0x7f7f0000), + # ('float32', 'float8e5', 'rtne', 0x47600000), + # ('float32', 'float8e5', 'rtz', 0x47600000), + # ('float32', 'float8e4nv', 'rtne', 0x43e00000), + # ('float32', 'float8e4b8', 'rtne', 0x43700000), + # ('float32', 'float8e5b16', 'rtne', 0x47600000), + # ('float32', 'float8e4b15', 'rtne', 0x3fe00000), # Skip, no HW rtne conversion from f32 to f8e4b15 + + # ('bfloat16', 'float8e5', 'rtne', 0x4760), + # ('bfloat16', 'float8e4nv', 'rtne', 0x43e0), + + # ('float16', 'float8e5', 'rtne', 0x7b00), + # ('float16', 'float8e4nv', 'rtne', 0x5f00), + + # ('bfloat16', 'float8e5b16', 'rtne', 0x4760), + # ('bfloat16', 'float8e4b8', 'rtne', 0x4370), + + # ('float16', 'float8e5b16', 'rtne', 0x7b00), + # ('float16', 'float8e4b8', 'rtne', 0x5b80), +]) +def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device): + + if src_dtype != 'float32' and is_cuda() and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+") + + if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.get_device_capability(0) < (9, 0)): + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+") + + if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_on_mi300()): + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300") + + # dtype : (exponent_bits, mantissa_bits, exponent_bias) + stuff = { + 'float16': (5, 10, 15), + 'bfloat16': (8, 7, 127), + 'float8e5': (5, 2, 15), + 'float8e4b15': (4, 3, 15), + 'float8e4nv': (4, 3, 7), + 'float8e4b8': (4, 3, 8), + 'float8e5b16': (5, 2, 16), + }[dst_dtype] + + for i in range(256): + downcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), rounding, *stuff, max_repr, i, device=device) diff --git a/third_party/iluvatar/python/test/unit/language/test_core.py b/third_party/iluvatar/python/test/unit/language/test_core.py new file mode 100644 index 000000000..a011d5981 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/language/test_core.py @@ -0,0 +1,5455 @@ +# flake8: noqa: F821,F841 +import itertools +import re +from typing import Optional, Union +import math +import textwrap +import tempfile + +import numpy as np +import pytest +import torch +import os +import inspect +from numpy.random import RandomState + +import triton +import triton.language as tl +from triton.runtime.jit import TensorWrapper, reinterpret +from triton.runtime.build import is_corex + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def is_cuda(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def is_hip(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "hip" + + +int_dtypes = ['int8', 'int16', 'int32', 'int64'] +uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] +if is_corex(): + float_dtypes = ['float16', 'float32'] +else: + float_dtypes = ['float16', 'float32', 'float64'] +dtypes = int_dtypes + uint_dtypes + float_dtypes +dtypes_with_bfloat16 = dtypes + ['bfloat16'] +torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2'] +torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16'] + +# TODO: enable multiple cta cluster testing. +# num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1] +num_ctas_list = [1] + +GPU_DIALECT = "triton_gpu" +if is_interpreter(): + THREADS_PER_WARP = 1 +elif is_hip() or is_corex(): + THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size +else: + THREADS_PER_WARP = 32 + + +def _bitwidth(dtype: str) -> int: + # ex.: "int64" -> 64 + return int(re.search(r'(\d+)$', dtype).group(1)) + + +def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): + """ + Override `rs` if you're calling this function twice and don't want the same + result for both calls. + """ + if isinstance(shape, int): + shape = (shape, ) + if rs is None: + rs = RandomState(seed=17) + if dtype_str in int_dtypes + uint_dtypes: + iinfo = np.iinfo(getattr(np, dtype_str)) + low = iinfo.min if low is None else max(low, iinfo.min) + high = iinfo.max if high is None else min(high, iinfo.max) + dtype = getattr(np, dtype_str) + x = rs.randint(low, high, shape, dtype=dtype) + x[x == 0] = 1 # Workaround. Never return zero so tests of division don't error out. + return x + elif dtype_str and 'float8' in dtype_str: + x = rs.randint(20, 40, shape, dtype=np.int8) + return x + elif dtype_str in float_dtypes: + return rs.normal(0, 1, shape).astype(dtype_str) + elif dtype_str == 'bfloat16': + return (rs.normal(0, 1, shape).astype('float32').view('uint32') & np.uint32(0xffff0000)).view('float32') + elif dtype_str in ['bool', 'int1', 'bool_']: + return rs.normal(0, 1, shape) > 0.0 + else: + raise RuntimeError(f'Unknown dtype {dtype_str}') + + +def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torch.Tensor]: + ''' + Note: We need dst_type because the type of x can be different from dst_type. + For example: x is of type `float32`, dst_type is `bfloat16`. + If dst_type is None, we infer dst_type from x. + ''' + t = x.dtype.name + if t in uint_dtypes: + signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16" + x_signed = x.astype(getattr(np, signed_type_name)) + return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t)) + else: + if dst_type and 'float8' in dst_type: + return reinterpret(torch.tensor(x, device=device), getattr(tl, dst_type)) + if t == 'float32' and dst_type == 'bfloat16': + return torch.tensor(x, device=device).bfloat16() + return torch.tensor(x, device=device) + + +def torch_dtype_name(dtype) -> str: + if isinstance(dtype, triton.language.dtype): + return dtype.name + elif isinstance(dtype, torch.dtype): + # 'torch.int64' -> 'int64' + m = re.match(r'^torch\.(\w+)$', str(dtype)) + return m.group(1) + else: + raise TypeError(f'not a triton or torch dtype: {type(dtype)}') + + +def to_numpy(x): + if isinstance(x, TensorWrapper): + return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype))) + elif isinstance(x, torch.Tensor): + if x.dtype is torch.bfloat16: + return x.cpu().float().numpy() + return x.cpu().numpy() + else: + raise ValueError(f"Not a triton-compatible tensor: {x}") + + +def patch_kernel(template, to_replace): + if is_interpreter(): + local_namespace = {} + src = textwrap.dedent(inspect.getsource(template.fn)) + for k, v in to_replace.items(): + src = src.replace(k, v) + exec(src, globals(), local_namespace) + return local_namespace[template.fn.__name__] + else: + kernel = triton.JITFunction(template.fn) + for key, value in to_replace.items(): + kernel.src = kernel.src.replace(key, value) + return kernel + + +def check_cuda_or_hip(device): + # CUDA and HIP both use pytorch device 'cuda'. Other backends like Intel + # GPU do not. + if device not in ['cuda']: + pytest.skip("Only for cuda") + + +def check_type_supported(dtype, device): + ''' + skip test if dtype is not supported on the current device + ''' + if device in ['cuda']: + cc = torch.cuda.get_device_capability() + if not is_corex(): + if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): + pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") + if cc[0] < 9 and dtype in {tl.float8e4nv, "float8e4nv", "float8_e4m3fn"}: + pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90") + if is_interpreter(): + if dtype in [tl.bfloat16, "bfloat16", torch.bfloat16]: + pytest.skip("bfloat16 is not supported in the interpreter") + + +class MfmaLayout: + + def __init__(self, version, warps_per_cta, instr_shape, is_transposed): + self.version = version + self.warps_per_cta = warps_per_cta + self.instr_shape = instr_shape + self.is_transposed = is_transposed + + def __str__(self): + return f"#{GPU_DIALECT}.amd_mfma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA = {self.warps_per_cta}, instrShape={self.instr_shape}, isTransposed = {str(self.is_transposed).lower()}}}>" + + +class WmmaLayout: + + def __init__(self, warps_per_cta): + self.warps_per_cta = warps_per_cta + + def __str__(self): + return f"#{GPU_DIALECT}.amd_wmma<{{warpsPerCTA = {self.warps_per_cta}}}>" + + +class MmaLayout: + + def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape): + self.version = version + self.warps_per_cta = warps_per_cta + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + self.instr_shape = instr_shape + + def __str__(self): + return f"#{GPU_DIALECT}.nvidia_mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" + + +class BlockedLayout: + + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): + self.sz_per_thread = size_per_thread + self.threads_per_warp = threads_per_warp + self.warps_per_cta = warps_per_cta + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +class SharedLayout: + + def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order): + self.vec = vec + self.per_phase = per_phase + self.max_phase = max_phase + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +def is_layout_applicable(layout) -> bool: + if isinstance(layout, (BlockedLayout, SharedLayout)): + return True + elif is_cuda(): + return isinstance(layout, MmaLayout) + elif is_hip(): + target_arch = triton.runtime.driver.active.get_current_target().arch + if "gfx11" in target_arch: + # RDNA 3 + return isinstance(layout, WmmaLayout) + elif any(arch for arch in ["gfx8", "gfx9"] if arch in target_arch): + # CDNA 1, 2, 3 + return isinstance(layout, MfmaLayout) + else: + return False + else: + return True + + +def filter_layouts(layouts): + return [l for l in layouts if is_layout_applicable(l)] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) +def test_empty_kernel(dtype_x, device): + SIZE = 128 + + @triton.jit + def kernel(X, SIZE: tl.constexpr): + pass + + check_type_supported(dtype_x, device) + x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x) + kernel[(1, )](x, SIZE=SIZE, num_warps=4) + + +# generic test functions +def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda', num_ctas=1): + check_type_supported(dtype_x, device) # early return if dtype_x is not supported + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) + # inputs + x = numpy_random(SIZE, dtype_str=dtype_x) + if 'log' in expr: + x = np.abs(x) + 0.01 + # reference result + z_ref = eval(expr if numpy_expr is None else numpy_expr) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_x) + kernel[(1, )](Z=z_tri, X=x_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) + # compare + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + + +def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: + """ + Given two dtype strings, returns the numpy dtype Triton thinks binary + operations on the two types should return. Returns None if the return value + matches numpy. This is generally needed because Triton and pytorch return + narrower floating point types than numpy in mixed operations, and because + Triton follows C/C++ semantics around mixed signed/unsigned operations, and + numpy/pytorch do not. + """ + overrides = { + ('float16', 'int16'): np.float16, + ('float16', 'int32'): np.float16, + ('float16', 'int64'): np.float16, + ('float16', 'uint16'): np.float16, + ('float16', 'uint32'): np.float16, + ('float16', 'uint64'): np.float16, + ('int8', 'uint8'): np.uint8, + ('int8', 'uint16'): np.uint16, + ('int8', 'uint32'): np.uint32, + ('int8', 'uint64'): np.uint64, + ('int16', 'uint16'): np.uint16, + ('int16', 'uint32'): np.uint32, + ('int16', 'uint64'): np.uint64, + ('int32', 'uint32'): np.uint32, + ('int32', 'uint64'): np.uint64, + ('int64', 'uint64'): np.uint64, + } + key = (a, b) if a < b else (b, a) + return overrides.get(key) + + +def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, + y_low=None, y_high=None, test_broadcast=True): + check_type_supported(dtype_x, device) # early return if dtype_x is not supported + check_type_supported(dtype_y, device) + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_broadcast_lhs(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_broadcast_rhs(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + replacements = {'GENERATE_TEST_HERE': expr} + kernel = patch_kernel(kernel, replacements) + kernel_broadcast_lhs = patch_kernel(kernel_broadcast_lhs, replacements) + kernel_broadcast_rhs = patch_kernel(kernel_broadcast_rhs, replacements) + + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high) + if mode_x == 'nan': + x[:] = float('nan') + if mode_y == 'nan': + y[:] = float('nan') + + def do_test(x, y, kernel_fn): + # reference result + z_ref = eval(expr if numpy_expr is None else numpy_expr) + dtype_z = _binary_op_dtype_override(dtype_x, dtype_y) + if dtype_z is not None: + z_ref = z_ref.astype(dtype_z) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + y_tri = to_triton(y, device=device, dst_type=dtype_y) + z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) + kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) + err_msg = f"{expr}, {kernel_fn.__name__}" + np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=1e-3, rtol=0.01) + + do_test(x, y, kernel) + if test_broadcast: + do_test(x[:1].reshape(()), y, kernel_broadcast_lhs) + do_test(x, y[:1].reshape(()), kernel_broadcast_rhs) + + +def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: + # The result of x % y is ill-conditioned if x % y is much smaller than x. + # pytorch/CUDA has slightly different (probably better) rounding on + # remainders than stock LLVM. We currently don't expect to match it + # bit-for-bit. + return (dtype_x, dtype_y) in [ + ('int32', 'bfloat16'), + ('int32', 'float16'), + ('int32', 'float32'), + ('int64', 'bfloat16'), + ('int64', 'float16'), + ('int64', 'float32'), + ('int64', 'float64'), + ('uint16', 'bfloat16'), + ('uint16', 'float16'), + ('uint16', 'float32'), + ('uint32', 'bfloat16'), + ('uint32', 'float16'), + ('uint32', 'float32'), + ('uint64', 'bfloat16'), + ('uint64', 'float16'), + ('uint64', 'float32'), + ('uint64', 'float64'), + ] + + +def test_dtype_codegen(): + for dtype in dtypes_with_bfloat16: + full_name = f"triton.language.{dtype}" + assert repr(eval(full_name)) == full_name + + +# --------------- +# test binary ops +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['+', '-', '*', '/', '%'] + for dtype_x in dtypes_with_bfloat16 + for dtype_y in dtypes_with_bfloat16 +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f' x {op} y' + if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + numpy_expr = 'np.fmod(x, y)' + elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', + 'bfloat16'): + # Triton promotes 16-bit floating-point / and % to 32-bit because there + # are no native div or FRem operations on float16. Since we have to + # convert anyway, we may as well take the accuracy bump. + numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)' + elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): + with pytest.raises(AssertionError, match="Not equal to tolerance"): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + elif (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or + (dtype_x in uint_dtypes and dtype_y in int_dtypes))): + with pytest.raises(triton.TritonError, match='Cannot use .* because they have different signedness'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + else: + _test_binary( + dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, + # fails with values where fmod(x, y) is roughly zero, but happens to + # pass with the random values chosen for non-broadcast tests + test_broadcast=(op != "%")) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) +def test_addptr(dtype, order, device): + check_type_supported(dtype, device) + + @triton.jit + def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): + offs = tl.arange(0, SIZE) + if ORDER == 0: + tl.store(y + offs, tl.load(x + offs)) + else: + tl.store(offs + y, tl.load(offs + x)) + + SIZE = 1024 + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + x_tri = to_triton(x, dst_type=dtype, device=device) + y_tri = to_triton(y, dst_type=dtype, device=device) + y = x + kernel[ + 1, + ](x_tri, y_tri, order, SIZE) + np.testing.assert_allclose(y, to_numpy(y_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y", [ # + (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes +] + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_floordiv(dtype_x, dtype_y, num_ctas, device): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + expr = 'x // y' + numpy_expr = '((x - np.fmod(x, y)) / y)' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +def test_unsigned_name_mangling(device): + # Test that uint32 and int32 are mangled differently by the compiler + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(O1, O2, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + out1 = tl.abs(x) # uint32 -> nop + out2 = tl.abs(-y) # int32 -> should have an effect + tl.store(O1 + off, out1) + tl.store(O2 + off, out2) + + dtype_x = 'uint32' + dtype_y = 'int32' + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs) + # reference result + expect = (np.abs(x), np.abs(-y)) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + y_tri = to_triton(y, device=device, dst_type=dtype_y) + actual = tuple(to_triton(np.empty_like(e), device=device) for e in expect) + kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4) + + # Bitwise op, so expect exact equality + assert (expect[0] == to_numpy(actual[0])).all() + assert (expect[1] == to_numpy(actual[1])).all() + + +# test bitwise ops +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['&', '|', '^'] + for dtype_x in dtypes + dtypes_with_bfloat16 + for dtype_y in dtypes + dtypes_with_bfloat16 +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + if 'float' in dtype_x + dtype_y: + # The CompilationError must have been caused by a C++ exception with this text. + with pytest.raises(triton.TritonError, match='invalid operands of type'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device, num_ctas=num_ctas) + else: + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['<<', '>>'] + for dtype_x in int_dtypes + uint_dtypes + for dtype_y in int_dtypes + uint_dtypes +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_shift_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f'x {op} y' + bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y)) + if dtype_x.startswith('int'): + dtype_z = f'int{bw}' + else: + dtype_z = f'uint{bw}' + numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, y_low=0, y_high=bw) + + +# --------------- +# test compare ops +# --------------- +ops = ['==', '!=', '>', '<', '>=', '<='] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "dtype_x, dtype_y, op, mode_x, mode_y", + # real + [(dtype_x, dtype_y, op, 'real', 'real') for op in ops for dtype_x in dtypes for dtype_y in dtypes] + # NaNs + + [('float32', 'float32', op, mode_x, mode_y) + for op in ops + for mode_x, mode_y in [('nan', 'real'), ('real', 'nan'), ('nan', 'nan')]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device, num_ctas=num_ctas) + + +# --------------- +# test broadcast +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16) +def test_broadcast(dtype, device): + check_type_supported(dtype, device) + + @triton.jit + def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :]) + y = tl.load(y_ptr + offset2) + _, y_broadcasted = tl.broadcast(x, y) + tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted) + + M = 32 + N = 64 + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype, rs=rs) + y = numpy_random(N, dtype_str=dtype, rs=rs) + _, y_broadcasted_np = np.broadcast_arrays(x, y) + + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype) + + broadcast_kernel[(1, )](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) + assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() + + +# ---------- +# test slice +# ---------- + + +@pytest.mark.interpreter +def test_slice(device): + + @triton.jit + def slice_kernel(XBLOCK: tl.constexpr): + data = tl.arange(0, XBLOCK) + tl.static_assert(data.shape == [XBLOCK]) + + t = data[None, :] + tl.static_assert(t.shape == [1, XBLOCK]) + + t = data[None, :, None] + tl.static_assert(t.shape == [1, XBLOCK, 1]) + + scalar = tl.full([], 1, tl.int32) + tl.static_assert(scalar.shape == []) + + t = scalar[None] + tl.static_assert(t.shape == [1]) + + t = scalar[None, None] + tl.static_assert(t.shape == [1, 1]) + + slice_kernel[(1, )](XBLOCK=32) + + +# ------------------ +# test invalid slice +# ------------------ + + +@pytest.mark.interpreter +def test_invalid_slice(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + dst[10:] + + with pytest.raises(triton.TritonError, match='unsupported tensor index'): + _kernel[(1, )](dst=dst) + + +# ---------------- +# test expand_dims +# ---------------- +@pytest.mark.interpreter +def test_expand_dims(device): + + @triton.jit + def expand_dims_kernel(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 0) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, 1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -2) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, (0, -1)) + tl.static_assert(t.shape == [1, N, 1]) + + t = tl.expand_dims(offset1, (0, 1, 3)) + tl.static_assert(t.shape == [1, 1, N, 1]) + + t = tl.expand_dims(offset1, (-4, 2, -1)) + tl.static_assert(t.shape == [1, N, 1, 1]) + + t = tl.expand_dims(offset1, (3, 1, 2)) + tl.static_assert(t.shape == [N, 1, 1, 1]) + + scalar = tl.sum(offset1) + tl.static_assert(scalar.shape == []) + t = tl.expand_dims(scalar, 0) + tl.static_assert(t.shape == [1]) + + t = tl.expand_dims(scalar, -1) + tl.static_assert(t.shape == [1]) + + # N is a scalar that's not even a tl.tensor -- this should work too. + t = tl.expand_dims(N, -1) + tl.static_assert(t.shape == [1]) + + N = 32 + dummy_tensor = torch.empty((), device=device) + expand_dims_kernel[(1, )](dummy_tensor, N) + + +@pytest.mark.interpreter +def test_expand_dims_error_cases(device): + + @triton.jit + def dim_out_of_range1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, -2) + t = tl.expand_dims(offset1, -3) + + @triton.jit + def dim_out_of_range2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 1) + t = tl.expand_dims(offset1, 2) + + @triton.jit + def dim_out_of_range3(dummy, N: tl.constexpr): + offset1 = tl.arange(0, 1) + scalar = tl.sum(offset1) + + t = tl.expand_dims(scalar, 1) + + @triton.jit + def duplicate_dim1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, 0)) + + @triton.jit + def duplicate_dim2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, -3)) + + N = 32 + dummy_tensor = torch.empty((), device=device) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range1[(1, )](dummy_tensor, N) + assert "invalid axis -3" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range2[(1, )](dummy_tensor, N) + assert "invalid axis 2" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range3[(1, )](dummy_tensor, N) + assert "invalid axis 1" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + duplicate_dim1[(1, )](dummy_tensor, N) + assert re.search(r"duplicate axes, normalized axes = \[0, 0\]", str(exc_info.value.__cause__)) + + with pytest.raises(triton.TritonError) as exc_info: + duplicate_dim2[(1, )](dummy_tensor, N) + assert re.search(r"duplicate axes, normalized axes = \[0, 0\]", str(exc_info.value.__cause__)) + + +# ---------------------------- +# test invalid program id axis +# ---------------------------- +@pytest.mark.interpreter +def test_invalid_pid_axis(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + pid = tl.program_id(20) + + with pytest.raises(triton.TritonError) as exc_info: + _kernel[(1, )](dst) + assert re.search(r"program_id axis must be 0, 1, or 2 but got 20", str(exc_info.value.__cause__)) + + +# --------------- +# test where +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where(dtype, num_ctas, device): + select_ptrs = False + if dtype == "*int32": + dtype = "int64" + select_ptrs = True + check_type_supported(dtype, device) + + @triton.jit + def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + TEST_POINTERS: tl.constexpr, TEST_SCALAR_POINTERS: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + decide = tl.load(cond_ptr + offsets, mask=mask) + if TEST_SCALAR_POINTERS: + ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr) + output = tl.load(ptr + offsets, mask=mask) + else: + if TEST_POINTERS: + a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t) + b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t) + else: + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + output = tl.where(decide, a, b) + tl.store(output_ptr + offsets, output, mask=mask) + + SIZE = 1_000 + rs = RandomState(17) + cond = numpy_random(SIZE, 'bool', rs) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + z = np.where(cond, x, y) + + cond_tri = to_triton(cond, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device=device, dst_type=dtype) + + grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']), ) + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=False, num_ctas=num_ctas) + assert (z == to_numpy(z_tri)).all() + if select_ptrs: + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=True) + z = np.where(cond[0], x, y) + assert (z == to_numpy(z_tri)).all() + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where_broadcast(num_ctas, device): + + @triton.jit + def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + + mask = tl.load(cond_ptr + yoffsets) + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + + @triton.jit + def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + mask = 0 + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + + SIZE = 32 + dtype = 'float32' + rs = RandomState(17) + x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs) + mask = numpy_random(SIZE, 'bool', rs=rs) + z = np.where(mask, x, 0) + cond_tri = to_triton(mask, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype) + z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device=device, dst_type=dtype) + where_kernel[(1, )](cond_tri, x_tri, z_tri, SIZE) + assert (z == to_numpy(z_tri)).all() + where_scalar_condition[(1, )](x_tri, z_tri, SIZE, num_ctas=num_ctas) + z = np.where(0, x, 0) + assert (z == to_numpy(z_tri)).all() + + +# --------------- +# test unary ops +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, expr", + [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x') + for dtype_x in int_dtypes]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_unary_op(dtype_x, expr, num_ctas, device): + _test_unary(dtype_x, expr, device=device, num_ctas=num_ctas) + + +# ---------------- +# test math ops +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, expr, x", + [(dtype_x, expr, x) + for dtype_x in (["float32", "float64"] if not is_corex() else ["float32"]) + for expr in ['exp', 'log', 'cos', 'sin', 'exp2', 'log2', 'sqrt', 'floor', 'ceil'] + for x in ['x', '3.0']]) +def test_math_op(dtype_x, expr, x, device): + _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [dtype for dtype in (["float32", "float64"] if not is_corex() else ["float32"])]) +def test_math_erf_op(dtype, device): + check_type_supported(dtype, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = tl.math.erf(x) + tl.store(Z + off, z) + + torch_dtype = torch.float32 if dtype == "float32" else torch.float64 + x = torch.randn(SIZE, dtype=torch_dtype, device=device) + z_ref = torch.erf(x) + z_tri = torch.zeros_like(x) + kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_tri, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) +def test_math_fma_op(dtype, device): + check_type_supported(dtype, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, Y, W, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + w = tl.load(W + off) + z = tl.math.fma(x, y, w) + tl.store(Z + off, z) + + torch_dtype = torch.float32 if dtype == "float32" else torch.float64 + x = torch.randn(SIZE, dtype=torch_dtype, device=device) + y = torch.randn(SIZE, dtype=torch_dtype, device=device) + w = torch.randn(SIZE, dtype=torch_dtype, device=device) + z_ref = x * y + w + z_tri = torch.zeros_like(x) + kernel[(1, )](z_tri, x, y, w, SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_tri, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("expr", ["tl.math.fdiv(x, y)", "tl.math.div_rn(x, y)"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_math_divide_op(expr, num_ctas, device): + numpy_expr = "x / y" + dtype = "float32" + _test_binary(dtype, dtype, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +# ------------- +# test precise math +# ------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("expr_prec, expr_ref", [('tl.math.sqrt_rn(x)', 'tl.math.sqrt(x)'), + ('tl.math.div_rn(x,y)', 'x / y')]) +#[('tl.math.sqrt_rn(x)', 'tl.math.sqrt(x.to(tl.float64)).to(tl.float32)'), +# ('tl.math.div_rn(x,y)', '(x.to(tl.float64) / y.to(tl.float64)).to(tl.float32)')]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_precise_math(expr_prec, expr_ref, num_ctas, device): + + @triton.jit + def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + prec = PREC_CALC + ref = REF_CALC + tl.store(OUT + tl.arange(0, BLOCK), prec) + tl.store(OUT_REF + tl.arange(0, BLOCK), ref) + + shape = (128, ) + out = torch.zeros(shape, dtype=torch.float32, device=device) + out_ref = torch.zeros(shape, dtype=torch.float32, device=device) + + x = torch.randn(shape, dtype=torch.float32, device=device) + y = torch.randn(shape, dtype=torch.float32, device=device) + + if (expr_prec.count('sqrt') > 0): + x = torch.abs(x) + + if (expr_prec.count('div') > 0): + y += 1e-6 + + kernel = patch_kernel(kernel, {'PREC_CALC': expr_prec, 'REF_CALC': expr_ref}) + + kernel[(1, )](x, y, out, out_ref, BLOCK=shape[0], num_ctas=num_ctas) + assert torch.all(out == out_ref) # bitwise exact + + +# ---------------- +# test abs +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) +def test_abs(dtype_x, device): + _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4nv, tl.float8e5]) +def test_abs_fp8(in_dtype, device): + if is_hip(): + pytest.skip('test_abs_fp8 not supported on HIP.') + + @triton.jit + def abs_kernel(X, Z, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = tl.abs(x) + tl.store(Z + off, z) + + f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device=device) + # f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan + all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width + f8_tensor[all_exp_ones] = 0 + f8 = triton.reinterpret(f8_tensor, in_dtype) + n_elements = f8_tensor.numel() + out_f8 = torch.empty_like(f8_tensor) + abs_kernel[(1, )](f8, triton.reinterpret(out_f8, in_dtype), n_elements) + + f32_tensor = convert_float_to_float32(f8_tensor, in_dtype) + expect = f32_tensor.abs() + actual_f8 = convert_float_to_float32(out_f8, in_dtype) + torch.testing.assert_close(actual_f8, expect, equal_nan=True) + + +# ---------------- +# test passing shapes as individual params rather than tuples +# ---------------- + + +@pytest.mark.interpreter +def test_shapes_as_params(device): + + @triton.jit + def kernel(): + a = tl.arange(0, 32).expand_dims(-1).broadcast_to(32, 32) + tl.static_assert(a.shape == [tl.constexpr(32), tl.constexpr(32)]) + + a = tl.arange(0, 32).reshape(4, 8).permute(1, 0) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4)]) + + a = tl.arange(0, 32).reshape(4, 8).reshape(32) + tl.static_assert(a.shape == [tl.constexpr(32)]) + + a = tl.arange(0, 64).reshape(2, 4, 8).trans(2, 1, 0) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + + a = tl.arange(0, 64).view(2, 4, 8) + tl.static_assert(a.shape == [tl.constexpr(2), tl.constexpr(4), tl.constexpr(8)]) + + kernel[(1, )]() + + +# ---------------- +# test transpose +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) +def test_transpose(dtype_x, device): + check_type_supported(dtype_x, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + off2d = off[None, :] + (tl.arange(0, 2) * SIZE)[:, None] + x = tl.load(X + off2d) + z = x.T + tl.store(Z + off2d.T, z) + + x = numpy_random([SIZE, 2], dtype_str=dtype_x) + z_ref = x.T + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x) + kernel[(1, )](z_tri, x_tri, SIZE=SIZE) + np.testing.assert_allclose(z_ref, to_numpy(z_tri)) + + +# ---------------- +# test indexing +# ---------------- + + +def make_ptr_str(name, shape): + rank = len(shape) + offsets = [] + stride = 1 + for i in reversed(range(rank)): + idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)]) + offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}'] + stride *= shape[i] + return f"{name} + {' + '.join(offsets)}" + + +# TODO: handle `%4 = triton_gpu.convert_layout %3 : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`` +@pytest.mark.parametrize("expr, dtype_str", [(f'x[{s}]', d) + for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] + for d in ['int32', 'uint32', 'uint16']]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_index1d(expr, dtype_str, num_ctas, device): + rank_x = expr.count(':') + rank_y = expr.count(',') + 1 + shape_x = [32 for _ in range(rank_x)] + shape_z = [32 for _ in range(rank_y)] + shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)] + shape_z_dim_mismatch = [64 for _ in range(rank_y)] + + # Triton kernel + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + m = tl.arange(0, SIZE) + n = tl.arange(0, SIZE) + x = tl.load(X_PTR_EXPR) + z = GENERATE_TEST_HERE + tl.store(Z_PTR_EXPR, z) + + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + 'GENERATE_TEST_HERE': expr, + } + return patch_kernel(kernel, to_replace) + + kernel_match = generate_kernel(shape_x, shape_z) + kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch) + kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch) + + # torch result + x = numpy_random(shape_x, dtype_str=dtype_str) + y = np.zeros(shape_z, dtype=getattr(np, dtype_str)) + z_ref = eval(expr) + y + # triton result + z_tri = to_triton(np.empty_like(z_ref), device=device) + x_tri = to_triton(x, device=device) + kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) + # compare + assert (z_ref == to_numpy(z_tri)).all() + + def catch_compilation_error(kernel): + try: + kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0], num_ctas=num_ctas) + except triton.CompilationError as e: + np.testing.assert_(True) + except BaseException: + np.testing.assert_(False) + + catch_compilation_error(kernel_dim_mismatch) + catch_compilation_error(kernel_rank_mismatch) + + +# --------------- +# test tuples +# --------------- + + +@triton.jit +def tuples_fn(a, b): + return a + b, \ + a - b, \ + a * b + + +@pytest.mark.interpreter +def test_tuples(device): + + @triton.jit + def with_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = tuples_fn(x, y) + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + @triton.jit + def without_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = x + y, x - y, x * y + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + x = torch.tensor([1.3], device=device, dtype=torch.float32) + y = torch.tensor([1.9], device=device, dtype=torch.float32) + a_tri = torch.tensor([0], device=device, dtype=torch.float32) + b_tri = torch.tensor([0], device=device, dtype=torch.float32) + c_tri = torch.tensor([0], device=device, dtype=torch.float32) + for kernel in [with_fn, without_fn]: + kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1) + a_ref, b_ref, c_ref = x + y, x - y, x * y + assert a_tri == a_ref + assert b_tri == b_ref + assert c_tri == c_ref + + +@triton.jit(noinline=True) +def noinline_simple_fn(x, y, Z): + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_graph_fn1(x): + return x + 1 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn2(y): + return y + 2 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn(x, y, Z): + t0 = noinline_call_graph_fn1(x) + t1 = noinline_call_graph_fn2(y) + z = t0 + t1 + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_shared_fn(x, y, Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + x + y + tl.store(Z + offs, z) + + +@triton.jit(noinline=True) +def noinline_dynamic_fn(x, y, Z): + if x >= 1: + x = noinline_call_graph_fn1(x) + else: + x = noinline_call_graph_fn2(x) + if y >= 2: + y = noinline_call_graph_fn2(y) + else: + y = noinline_call_graph_fn1(y) + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_multi_values_fn(x, y): + return x + 1, y + 2 + + +@triton.jit(noinline=True) +def noinline_multi_values_fn(x, y, Z): + x, y = noinline_call_multi_values_fn(x, y) + z = x + y + tl.store(Z, z) + + +@pytest.mark.skip(reason="compiler do not support func call in llvmir until 2025 Q2") +@pytest.mark.interpreter +@pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) +def test_noinline(mode, device): + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + GENERATE_TEST_HERE(x, y, Z) + + func_name = f'noinline_{mode}_fn' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': func_name}) + x = torch.tensor([1.0], device=device, dtype=torch.float32) + y = torch.tensor([2.0], device=device, dtype=torch.float32) + if mode == "shared": + z = torch.ones((16, 16), device=device, dtype=torch.float32) + else: + z = torch.tensor([0.0], device=device, dtype=torch.float32) + kernel[(1, )](x, y, z, num_warps=1) + if mode == "simple": + assert torch.equal(z, x + y) + elif mode == "call_graph" or mode == "dynamic" or mode == "multi_values": + assert torch.equal(z, x + 1 + y + 2) + elif mode == "shared": + ref = torch.full((16, 16), 16, device=device, dtype=torch.float32) + assert torch.equal(z, ref + x + y) + + +# --------------- +# test atomics +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize( + "op, dtype_x_str, mode, sem", + itertools.chain.from_iterable([[ + ('add', 'float16', mode, sem), + ('add', 'uint32', mode, sem), + ('add', 'int32', mode, sem), + ('add', 'float32', mode, sem), + #('add', 'uint64', mode, sem), + ('add', 'int64', mode, sem), + #('add', 'float64', mode, sem), + ('max', 'uint32', mode, sem), + ('max', 'int32', mode, sem), + ('max', 'float32', mode, sem), + #('max', 'uint64', mode, sem), + #('max', 'int64', mode, sem), + #('max', 'float64', mode, sem), + ('min', 'uint32', mode, sem), + ('min', 'int32', mode, sem), + ('min', 'float32', mode, sem), + #('min', 'uint64', mode, sem), + #('min', 'int64', mode, sem), + #('min', 'float64', mode, sem), + ] + for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos'] + for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']])) +def test_atomic_rmw(op, dtype_x_str, mode, sem, device): + if is_interpreter(): + if dtype_x_str == 'float16': + pytest.skip("Only test atomic float16 ops on GPU") + + n_programs = 5 + + # triton kernel + @triton.jit + def kernel(X, Z): + pid = tl.program_id(0) + x = tl.load(X + pid) + old = GENERATE_TEST_HERE + tl.static_assert(old.dtype == x.dtype) + + sem_arg = sem if sem is None else f'"{sem}"' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'}) + numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op] + max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min + min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max + neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op] + + # triton result + rs = RandomState(17) + x = np.array([2**i for i in range(n_programs)], dtype=getattr(np, dtype_x_str)) + if mode == 'all_neg': + x = -np.abs(x) + if mode == 'all_pos': + x = np.abs(x) + if mode == 'min_neg': + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = -np.max(np.abs(x)) - 1 + if mode == 'max_pos': + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = np.max(np.abs(x)) + 1 + x_tri = to_triton(x, device=device) + + z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device) + h = kernel[(n_programs, )](x_tri, z_tri) + # torch result + z_ref = numpy_op(x).astype(getattr(np, dtype_x_str)) + # compare + exact = op not in ['add'] + if exact: + assert z_ref.item() == to_numpy(z_tri).item() + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + sem_str = "acq_rel" if sem is None else sem + if not is_cuda() or is_corex(): + return + assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_rmw_predicate(num_ctas, device): + + @triton.jit + def kernel(X): + val = tl.program_id(0) + if val < 64: + tl.atomic_max(X, val) + + x = torch.zeros((1, ), device=device, dtype=torch.int32) + kernel[(4096, )](x, num_ctas=num_ctas) + assert x.item() == 63 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, axis, num_ctas", [(shape, axis, num_ctas) + for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] + for axis in [0, 1] + for num_ctas in num_ctas_list]) +def test_tensor_atomic_rmw(shape, axis, num_ctas, device): + shape0, shape1 = shape + # triton kernel + + @triton.jit + def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :]) + z = tl.sum(x, axis=AXIS) + if AXIS == 1: + tl.atomic_add(Z + off0, z) + else: + tl.atomic_add(Z + off1, z) + + rs = RandomState(17) + x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs) + # reference result + z_ref = np.sum(x, axis=axis, keepdims=False) + # triton result + x_tri = to_triton(x, device=device) + z_shape = (shape0, ) if axis == 1 else (shape1, ) + z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device) + kernel[(1, )](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_rmw_block(num_ctas, device): + shape = (8, 8) + + @triton.jit + def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + offs = off0[:, None] * SHAPE1 + off1[None, :] + val = offs.to(tl.float32) + x = X + offs + tl.atomic_min(x, val) + + x = torch.ones((8, 8), device=device, dtype=torch.float32) + kernel[(2, )](x, shape[0], shape[1], num_ctas=num_ctas) + assert torch.min(x).item() == 0.0 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_cas(sem, num_ctas, device): + # 1. make sure that atomic_cas changes the original value (Lock) + @triton.jit + def change_value(Lock): + tl.atomic_cas(Lock, 0, 1) + + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + change_value[(1, )](Lock) + + assert (Lock[0] == 1) + + # 2. only one block enters the critical section + @triton.jit + def serialized_add(data, Lock, SEM: tl.constexpr): + ptrs = data + tl.arange(0, 128) + while tl.atomic_cas(Lock, 0, 1, SEM) == 1: + pass + + tl.store(ptrs, tl.load(ptrs) + 1.0) + + # release lock + tl.atomic_xchg(Lock, 0) + + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + data = torch.zeros((128, ), device=device, dtype=torch.float32) + ref = torch.full((128, ), 2000.0) + h = serialized_add[(2000, )](data, Lock, SEM=sem, num_ctas=num_ctas) + sem_str = "acq_rel" if sem is None else sem + np.testing.assert_allclose(to_numpy(data), to_numpy(ref)) + if not is_cuda() or is_corex(): + return + assert f"atom.global.{sem_str}" in h.asm["ptx"] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_cas(sem, num_ctas, device): + + @triton.jit + def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr, USE_INT64: tl.constexpr = True): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + if USE_INT64: + dtype = tl.int64 + else: + dtype = tl.int32 + t1 = tl.full((BLOCK_SIZE, ), 0, dtype=dtype) + t2 = tl.full((BLOCK_SIZE, ), 2, dtype=dtype) + tl.atomic_cas(X + offsets, t1, t2, sem=sem) + + if is_corex(): + dtype = torch.int32 + else: + dtype = torch.int64 + X = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], device=device, dtype=dtype) + Y = torch.tensor([2, 1, 2, 1, 2, 1, 2, 1], device=device, dtype=dtype) + + change_value[(2, )](X, 4, sem, not is_corex()) + assert (torch.equal(X, Y)) + + +# --------------- +# test cast +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", + [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ + ('float32', 'bfloat16', False, 1024), + ('bfloat16', 'float32', False, 1024), + ('float32', 'int32', True, 1024), + ('float32', 'int1', False, 1024), + ('int8', 'bfloat16', False, 1024), + ] + [(f'uint{x}', f'int{x}', True, 1024) + for x in [8, 16, 32, 64]] + [(f'int{x}', f'uint{x}', True, 1024) + for x in [8, 16, 32, 64]] + + (([(dtype_x, dtype_z, False, size) + for dtype_x in torch_float8_dtypes + for dtype_z in ["float16", "float32", "bfloat16"] + for size in [1024, 32]] # + + [(dtype_x, dtype_z, False, size) + for dtype_z in torch_float8_dtypes + for dtype_x in ["float16", "float32", "bfloat16"] + for size in [1024, 32]]) if torch.__version__ >= "2.1" and not is_corex() else [])) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): + # CUDA: bfloat16 on cc < 80 will not be tested + # Interpreter: Only bfloat16 <-> float32 is supported + if not is_interpreter() or \ + (is_interpreter() and not ((dtype_z == 'bfloat16' and dtype_x == 'float32') + or (dtype_z == 'float32' and dtype_x == 'bfloat16'))): + check_type_supported(dtype_x, device) + check_type_supported(dtype_z, device) + + if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"): + pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.') + + torch.manual_seed(0) + # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. + if dtype_x.startswith('bfloat'): + x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device) + elif dtype_x.startswith('float8'): + x_tri = torch.randn(size, dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_x)) + else: + x = numpy_random(size, dtype_str=dtype_x, low=-10, high=10) * 10 + # Triton clamps negative values to zero, while numpy wraps around + # intmax, so avoid negatives for now. + # TODO: figure out which one should actually be happening, and test it + if dtype_z in uint_dtypes: + x = np.absolute(x) + x_tri = to_triton(x, device=device) + if 'float' in dtype_z and 'float' in dtype_x: + # make sure we use values that can be represented in both types + x_tri = x_tri.to(getattr(torch, dtype_z)).to(getattr(torch, dtype_x)) + # triton kernel + + @triton.jit + def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr): + x_ptr = X + tl.arange(0, SIZE) + z_ptr = Z + tl.arange(0, SIZE) + x = tl.load(x_ptr) + + # Depending on the value of ARG_HASH (a "random" number determined by + # the test parameters), spell the cast one of three different ways. + if ARG_HASH % 3 == 0: + z = x.to(Z.dtype.element_ty, bitcast=BITCAST) + elif ARG_HASH % 3 == 1: + z = x.cast(Z.dtype.element_ty, bitcast=BITCAST) + else: + z = tl.cast(x, Z.dtype.element_ty, bitcast=BITCAST) + + tl.store(z_ptr, z) + + # "Random" number used inside the kernel to determine how we spell the cast. + # This way we don't have to increase the number of tests. + arg_hash = hash((dtype_x, dtype_z, bitcast, size, num_ctas)) + + dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_' + # triton result + if dtype_z.startswith('bfloat'): + z_tri = torch.empty((size, ), dtype=getattr(torch, dtype_z), device=device) + elif dtype_z.startswith('float8'): + z_tri = torch.empty((size, ), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z)) + else: + z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) + kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1, num_ctas=num_ctas) + # torch result + if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith( + 'float8') or dtype_x.startswith('float8'): + assert bitcast is False + z_ref = x_tri.to(z_tri.dtype) + if dtype_z.startswith('float8') and device not in ['cuda']: + t = z_ref.byte() ^ z_tri.byte() + torch.testing.assert_close(torch.zeros_like(t, dtype=torch.uint8), t) + else: + torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0) + else: + if bitcast: + z_ref = x.view(getattr(np, dtype_z_np)) + else: + z_ref = x.astype(getattr(np, dtype_z_np)) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0, atol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, num_warps", + [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]]) +def test_cat(dtype_str, num_warps, device): + check_type_supported(dtype_str, device) + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.cat(x, y, can_reorder=True) + tl.store(Z + tl.arange(0, 2 * N), z) + + x = torch.arange(0, 128, device=device).to(getattr(torch, dtype_str)) + y = torch.arange(-128, 0, device=device).to(getattr(torch, dtype_str)) + z_ref = torch.cat([x, y], dim=0).sum() + z = torch.zeros((256, ), dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](x, y, z, N=128, num_warps=num_warps) + assert z.sum() == z_ref + # check if there's no duplicate value in z + assert z.unique().size(0) == z.size(0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", list(torch_dtypes)) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_store_constant(dtype_str, num_ctas, device): + check_type_supported(dtype_str, device) + """Tests that boolean True is stored as 1""" + + @triton.jit + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + output = GENERATE_TEST_HERE + tl.store(output_ptr + offsets, output, mask=mask) + + triton_dtype_str = 'uint8' if dtype_str == 'bool' else dtype_str + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.zeros([BLOCK_SIZE], dtype=tl.{triton_dtype_str}) + 1'}) + block_size = 128 + ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device) + output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) + + assert torch.all(output == ref) + + +def test_load_store_same_ptr(device): + + @triton.jit() + def kernel(in_out_ptr): + pid = tl.program_id(axis=0) + x = tl.load(in_out_ptr + pid) + out = x * 2 + tl.store(in_out_ptr + pid, out) + + for _ in range(1000): + x = torch.ones((65536, ), device=device, dtype=torch.float32) + if is_hip(): + kernel[(65536, )](x, num_warps=16) # threads per Warp for ROCM is 64 + else: + kernel[(65536, )](x, num_warps=32) + assert torch.all(x == 2) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ['int32']) +def test_umulhi(dtype_str, device): + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.umulhi(x, y) + tl.store(Z + tl.arange(0, N), z) + + def umulhi32(a, b): + # Convert to 64-bit unsigned integers to prevent overflow + a_64 = a.astype(np.int64) + b_64 = b.astype(np.int64) + + # Perform the multiplication in 64-bit + product_64 = a_64 * b_64 + + # Shift right by 32 bits to get the high part of the product + result_high_32 = product_64 >> 32 + return result_high_32 + + rs = RandomState(17) + N = 128 + x = numpy_random((N, ), dtype_str=dtype_str, rs=rs, low=0) + x_tri = to_triton(x, device=device) + y = numpy_random((N, ), dtype_str=dtype_str, rs=rs, low=0) + y_tri = to_triton(y, device=device) + z_tri = torch.zeros_like(x_tri) + kernel[(1, )](x_tri, y_tri, z_tri, N=N) + + z_ref = umulhi32(x, y) + np.testing.assert_equal(z_ref, to_numpy(z_tri)) + + +@pytest.mark.interpreter +def test_join(device): + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.join(x, y) + tl.store(Z + tl.arange(0, N)[:, None] * 2 + tl.arange(0, 2)[None, :], z) + + x = torch.arange(0, 128, device=device).to(torch.int32) + y = torch.arange(-128, 0, device=device).to(torch.int32) + z_ref = torch.stack([x, y], dim=-1) + z = torch.zeros_like(z_ref) + kernel[(1, )](x, y, z, N=128) + + np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) + + +@pytest.mark.interpreter +def test_join_scalars(device): + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + z = tl.join(x, y) + tl.static_assert(z.shape == [2]) + tl.store(Z + tl.arange(0, 2), z) + + x = torch.full([1], 42, device=device).to(torch.int32) + y = torch.full([1], 100, device=device).to(torch.int32) + z = torch.zeros([2], device=device) + kernel[(1, )](x, y, z) + + np.testing.assert_equal([42, 100], to_numpy(z)) + + +@pytest.mark.interpreter +def test_join_with_mma(device): + + @triton.jit + def kernel(X, Z): + x = tl.load(X + 16 * tl.arange(0, 32)[:, None] + tl.arange(0, 16)[None, :]) # (32,16) + x2 = tl.join(x, 2 * x) # (32,16,2) + x3 = tl.reshape(x2, (32, 32)) + z = tl.dot(x3, x3) # (32,32) + tl.store(Z + 32 * tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :], z) + + x = torch.arange(0, 32 * 16, device=device, dtype=torch.float32).reshape((32, 16)) + r = torch.stack([x, 2 * x], dim=-1).reshape((32, 32)) + z_ref = torch.matmul(r, r) + z = torch.zeros_like(z_ref) + kernel[(1, )](x, z) + + torch.testing.assert_close(z, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("debug", [False, True]) +def test_interleave(device, debug): + + @triton.jit(debug=debug) + def kernel(Z, N: tl.constexpr): + z = tl.interleave(tl.arange(0, N), tl.arange(N, 2 * N)) + tl.store(Z + tl.arange(0, 2 * N), z) + + x = torch.arange(0, 128, device=device).to(torch.int32) + y = torch.arange(128, 256, device=device).to(torch.int32) + z_ref = torch.stack([x, y], dim=-1).reshape(256) + z = torch.zeros_like(z_ref) + kernel[(1, )](z, N=128) + + np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) + + +@pytest.mark.interpreter +def test_interleave_scalars(device): + + @triton.jit + def kernel(X, Y, Z): + z = tl.interleave(X, Y) + tl.static_assert(z.shape == [tl.constexpr(2)]) + tl.store(Z + tl.arange(0, 2), z) + + z = torch.zeros(2, device=device) + kernel[(1, )](10, 20, z) + + np.testing.assert_equal([10, 20], to_numpy(z)) + + +@pytest.mark.interpreter +def test_split(device): + + @triton.jit + def kernel(X, Z1, Z2, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + x1 = tl.reshape(x, (N // 2, 2)) + z1, z2 = tl.split(x1) + tl.store(Z1 + tl.arange(0, N // 2), z1) + tl.store(Z2 + tl.arange(0, N // 2), z2) + + x = torch.arange(0, 256, device=device).to(torch.int32).reshape((128, 2)) + z1_ref, z2_ref = (x[:, 0], x[:, 1]) + z1 = torch.zeros_like(z1_ref) + z2 = torch.zeros_like(z2_ref) + kernel[(1, )](x, z1, z2, N=256) + + np.testing.assert_equal(to_numpy(z1_ref), to_numpy(z1)) + np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) + + +@pytest.mark.interpreter +def test_split_to_scalar(device): + + @triton.jit + def kernel(X, Z1, Z2): + offs = tl.arange(0, 2) + x = tl.load(X + offs) + z1, z2 = tl.split(x) + tl.static_assert(isinstance(z1, tl.tensor)) + tl.static_assert(isinstance(z2, tl.tensor)) + tl.static_assert(z1.shape == []) + tl.static_assert(z2.shape == []) + tl.store(Z1, z1) + tl.store(Z2, z2) + + N = 2 + x = torch.arange(0, N, device=device).reshape(N // 2, 2) + z1_ref, z2_ref = (x[:, 0], x[:, 1]) + z1 = torch.zeros_like(z1_ref) + z2 = torch.zeros_like(z2_ref) + kernel[(1, )](x, z1, z2) + + np.testing.assert_equal(to_numpy(z1_ref), to_numpy(z1)) + np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) + + +def convert_float_to_float32(fp: torch.tensor, dtype=None): + if not dtype: + dtype = getattr(tl, torch_dtype_name(fp.dtype)) + + fp = fp.view(getattr(torch, f"int{dtype.primitive_bitwidth}")) + exp_width = dtype.primitive_bitwidth - dtype.fp_mantissa_width - 1 + exp_bias = dtype.exponent_bias + sign = ((fp >> (dtype.primitive_bitwidth - 1)) & 0x01).int() + exp = ((fp >> dtype.fp_mantissa_width) & ((1 << exp_width) - 1)).int() + frac = (fp & ((1 << dtype.fp_mantissa_width) - 1)).int() + + output = torch.where( + exp == 0, + # subnormal + ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (frac / (2.0**dtype.fp_mantissa_width)), + # normal + ((-1.0)**sign) * (2.0**(exp - exp_bias)) * (1.0 + frac / (2.0**dtype.fp_mantissa_width))).float() + + extended_exp = ( + (1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width + # special cases, exp is 0b11..1 + if dtype in [tl.float8e4nv, tl.float8e4b15]: + # float8e4m3nv does not have infinities + output[fp == 0b01111111] = torch.nan + output[fp == 0b11111111] = torch.nan + else: + output = torch.where(exp == (1 << exp_width) - 1, + ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp + | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))) # + .view(torch.float32), output) + return output + + +@pytest.mark.interpreter +@pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16]) +def test_convert_float16_to_float32(in_dtype, device): + """Tests that check convert_float_to_float32 function""" + check_type_supported(in_dtype, device) + + f16_input = torch.tensor(range(-int(2**(16 - 1)), int(2**(16 - 1))), dtype=torch.int16).view(in_dtype) + f32_output = convert_float_to_float32(f16_input) + + nan = f16_input.isnan() + assert torch.all(f32_output[nan].isnan()) + inf = f16_input.isinf() + assert torch.all(f32_output[inf].isinf()) + other = torch.logical_not(torch.logical_or(nan, inf)) + assert torch.all(f16_input[other] == f32_output[other]) + + +def serialize_fp8(np_data, in_dtype): + return np_data + + +# inverse of `serialize_fp8` + + +def deserialize_fp8(np_data, in_dtype): + return np_data + + +# --------------- +# test reduce +# --------------- + + +@pytest.mark.interpreter +def test_max_returns_zero(device): + # Simple test with a tl.max call that returns 0. The interpreter had a bug + # where it didn't handle this correctly. + @triton.jit + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + z = tl.max(x) + tl.store(Z, z) + + BLOCK = 128 + x = torch.zeros((BLOCK, ), device=device) + z = torch.ones((1, ), device=device) + + kernel[(1, )](x, z, BLOCK=BLOCK) + assert z[0] == 0 + + +def get_reduced_dtype(dtype_str, op): + if op in ('argmin', 'argmax'): + return 'int32' + if dtype_str == 'bfloat16': + return 'float32' + return dtype_str + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [ + 'min', + 'max', + 'min-with-indices', + 'max-with-indices', + 'argmin-tie-break-left', + 'argmax-tie-break-left', + 'sum', +] for dtype in dtypes_with_bfloat16 for shape in [32, 64, 128, 512]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce1d(op, dtype_str, shape, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + + # triton kernel + @triton.jit + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + GENERATE_TEST_HERE + tl.store(Z, z) + + if 'with-indices' in op: + patch = f'z, _ = tl.{op.split("-")[0]}(x, axis=0, return_indices=True)' + elif 'arg' in op: + tie_break_left = 'tie-break-left' in op + patch = f'z = tl.{op.split("-")[0]}(x, axis=0, tie_break_left={tie_break_left})' + else: + patch = f'z = tl.{op}(x, axis=0)' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': patch}) + # input + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random((shape, ), dtype_str=dtype_str, rs=rs) + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + 'max-with-indices': np.max, + 'min-with-indices': np.min, + 'argmin-tie-break-fast': np.argmin, + 'argmin-tie-break-left': np.argmin, + 'argmax-tie-break-fast': np.argmax, + 'argmax-tie-break-left': np.argmax, + }[op] + if 'tie-break-left' in op: + x[3:10] = numpy_op(x) + x_tri = to_triton(x, device=device) + # numpy result + z_dtype_str = 'int32' if op in ('argmin', 'argmax') else dtype_str + z_tri_dtype_str = z_dtype_str + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + z_tri_dtype_str = 'bfloat16' + else: + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # triton result + z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas) + z_tri = to_numpy(z_tri) + # compare + if op == 'sum': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + if op in ('argmin', 'argmax'): + # argmin and argmax can have multiple valid indices. + # so instead we compare the values pointed by indices + np.testing.assert_equal(x[z_ref], x[z_tri]) + else: + np.testing.assert_equal(z_ref, z_tri) + + +# TODO: [Qingyi] Fix argmin / argmax +reduce_configs1 = [(op, dtype, (1, 1024), axis, False) + for dtype in dtypes_with_bfloat16 + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [1]] + +# shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory +# exceeds the limit of 99KB +reduce2d_shapes = [(2, 32), (4, 32), (4, 128)] +# TODO: fix and uncomment +# , (32, 64), (64, 128)] +if is_corex() or is_cuda() and 'V100' in torch.cuda.get_device_name(0): + reduce2d_shapes += [(128, 256) and (32, 1024)] + +reduce_configs2 = [(op, 'float32', shape, axis, False) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce2d_shapes + for axis in [0, 1]] + [(op, 'float32', [16, 32], None, False) for op in ['min', 'max', 'sum']] + +reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)] +reduce_configs3 = [(op, 'float32', shape, axis, False) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce3d_shapes + for axis in [0, 1, 2]] +invalid_config = [('sum', 'float32', (32, 32), axis, False) for axis in [2, 3]] +negative_config = [('sum', 'float32', (32, 32), -1, False)] +keep_dims_2d_configs = [(op, 'float32', (32, 32), axis, True) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [0, 1]] + [(op, 'float32', (32, 32), None, True) for op in ['min', 'max', 'sum']] +keep_dims_3d_configs = [(op, 'float32', (32, 2, 16), axis, True) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [0, 1, 2]] + [(op, 'float32', (32, 2, 16), None, True) + for op in ['min', 'max', 'sum']] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + + negative_config + keep_dims_2d_configs + keep_dims_3d_configs) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + + @triton.jit + def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, + AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + range_k = tl.arange(0, BLOCK_K) + if IS_3D: + x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + + range_k[None, None, :]) + else: + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + z = GENERATE_TEST_HERE + + z_ptr = Z + if KEEP_DIMS and AXIS is None: + if IS_3D: + z_ptr = z_ptr[None, None, None, :] + else: + z_ptr = z_ptr[None, None, :] + if IS_3D: + if AXIS == 0: + z_ptr = Z + range_n[:, None] * BLOCK_K + range_k[None, :] + elif AXIS == 1 or AXIS == -2: + z_ptr = Z + range_m[:, None] * BLOCK_K + range_k[None, :] + elif AXIS == 2 or AXIS == -1: + z_ptr = Z + range_m[:, None] * BLOCK_N + range_n[None, :] + else: + if AXIS == 0: + z_ptr = Z + range_n + elif AXIS == 1 or AXIS == -1: + z_ptr = Z + range_m + if KEEP_DIMS and AXIS is not None: + z_ptr = tl.expand_dims(z_ptr, axis=AXIS) + tl.store(z_ptr, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS, keep_dims=KEEP_DIMS)'}) + # input + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + x_tri = to_triton(x, device=device) + numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax}[op] + z_dtype_str = get_reduced_dtype(dtype_str, op) + z_tri_dtype_str = z_dtype_str + + # numpy result + # Silence numpy error on axis out of bounds, to give triton a chance to fail + np_axis = axis if axis is not None and axis < len(shape) else None + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_tri_dtype_str = 'bfloat16' + z_ref = numpy_op(x, axis=np_axis, keepdims=keep_dims).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + else: + z_ref = numpy_op(x, axis=np_axis, keepdims=keep_dims).astype(getattr(np, z_dtype_str)) + + # triton result + z_shape = z_ref.shape + z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + BLOCK_K = 1 if len(shape) == 2 else shape[2] + IS_3D = bool(len(shape) == 3) + if axis is not None and axis >= len(shape): + with pytest.raises(triton.TritonError): + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + KEEP_DIMS=keep_dims, num_ctas=num_ctas) + return + else: + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + KEEP_DIMS=keep_dims, num_ctas=num_ctas) + + z_tri = to_numpy(z_tri) + + # compare + if op == 'sum': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + if op in ('argmin', 'argmax'): + # argmin and argmax can have multiple valid indices. + # so instead we compare the values pointed by indices + z_ref_index = z_ref + z_tri_index = z_tri + if not keep_dims: + z_ref_index = np.expand_dims(z_ref, axis=axis) + z_tri_index = np.expand_dims(z_tri, axis=axis) + z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis) + z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis) + np.testing.assert_equal(z_ref_value, z_tri_value) + else: + np.testing.assert_equal(z_ref, z_tri) + + +scan2d_shapes = [(8, 32), (16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)] + +scan_configs = [(op, type, shape, axis, reverse, num_warps) + for num_warps in [4, 16] + for type in ['int32', 'float32', 'bfloat16'] + for axis in [1, 0] + for reverse in [True, False] + for shape in scan2d_shapes + for op in ['cumsum', 'cumprod', 'get_first_element', 'linear_recurrence', 'cummax', 'roll']] +negative_config = [('cumsum', 'float32', (32, 32), -1, False, 4)] + +# TODO: beflow 2 configs can not pass somehow. +#failed_config = [('linear_recurrence', 'float32', (4, 32), 0, True, 1), +# ('linear_recurrence', 'int32', (4, 32), 0, True, 1)] + + +@triton.jit +# trivial associative but not commutative function +def get_first_element(a, b): + return a + + +# Compute x_i = a_i * x_{i-1} + b_i +@triton.jit +def linear_recurrence(a1, b1, a2, b2): + return a1 * a2, b1 * a2 + b2 + + +@triton.jit +def cummax(v0, i0, v1, i1): + gt = v0 > v1 + return tl.where(gt, v0, v1), tl.where(gt, i0, i1) + + +@triton.jit +def roll(a1, b1_last, b1_cur, a2, b2_last, b2_cur): + return a1 + a2, tl.where(a2 == 1, b1_cur, 0) + b2_last, b2_cur + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op, dtype_str, shape, axis, reverse, num_warps", scan_configs + negative_config) +def test_scan2d(op, dtype_str, shape, axis, reverse, num_warps, device): + check_type_supported(dtype_str, device) + if dtype_str == 'bfloat16': + if op == 'cummax': + pytest.skip("bfloat16 compare not suppoted before sm90") + if op == 'linear_recurrence': + pytest.skip("Skipping linear_recurrence scan on bfloat16 due to accuracy issues") + numpy_dtype_str = 'float32' if dtype_str == 'bfloat16' else dtype_str + + # triton kernel + @triton.jit + def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + y = tl.load(Y + range_m[:, None] * BLOCK_N + range_n[None, :]) + GENERATE_TEST_HERE + tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z) + + if op == 'cumsum' or op == 'cumprod': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'z = tl.{op}(x, axis={axis}, reverse={reverse})'}) + elif op == 'get_first_element': + kernel = patch_kernel( + kernel, + {'GENERATE_TEST_HERE': f'z = tl.associative_scan(x, axis={axis}, combine_fn={op}, reverse={reverse})'}) + elif op == 'cummax': + rg = "range_m[:, None]" if axis == 0 else "range_n[None, :]" + rg = f"tl.broadcast_to({rg}.to(tl.int64), [BLOCK_M, BLOCK_N])" + kernel = patch_kernel(kernel, { + 'GENERATE_TEST_HERE': + f'_, z = tl.associative_scan((x, {rg}), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + elif op == 'roll': + assert op == 'roll' + kernel = patch_kernel( + kernel, { + 'GENERATE_TEST_HERE': + f'_, z, _ = tl.associative_scan((1 + 0* x, 0 * x, x), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + else: + assert op == 'linear_recurrence' + kernel = patch_kernel(kernel, { + 'GENERATE_TEST_HERE': + f'_, z = tl.associative_scan((x, y), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + # input + rs = RandomState(17) + if op == 'linear_recurrence' and dtype_str in int_dtypes: + # If the numbers are too large the op will overflow + # We sample numbers in -1, 0, 1 + x = rs.randint(-1, 2, shape, dtype=dtype_str) + y = rs.randint(-1, 2, shape, dtype=dtype_str) + else: + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + # y is just used in linear_recurrence + y = numpy_random(shape, dtype_str=dtype_str, rs=rs) + x_in = x + if reverse: + x_in = np.flip(x, axis) + z = np.empty_like(x) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + y_tri = to_triton(y, device=device, dst_type=dtype_str) + if op == 'cumsum' or op == 'cumprod': + numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op] + z_ref = numpy_op(x_in, axis=axis).astype(getattr(np, numpy_dtype_str)) + if reverse: + z_ref = np.flip(z_ref, axis) + + elif op == 'cummax': + # NumPy does not have cummax + z = z.astype(np.int64) + z_ref = torch.cummax(torch.from_numpy(x_in.copy()), axis=axis).indices.numpy() + if reverse: + z_ref = x_in.shape[axis] - np.flip(z_ref, axis) - 1 + elif op == 'roll': + ROLL = 1 + z_ref = np.roll(x_in.copy(), ROLL, axis=axis) + if axis == 0: + z_ref[:ROLL] = 0 + else: + z_ref[:, :ROLL] = 0 + + if reverse: + z_ref = np.flip(z_ref, axis) + elif op == 'linear_recurrence': + # Simplify to the axis=1 case + x_ref = x.T if axis == 0 else x + y_ref = y.T if axis == 0 else y + if reverse: + x_ref = np.flip(x_ref, 1) + y_ref = np.flip(y_ref, 1) + + result = [] + for x_refi, y_refi in zip(x_ref, y_ref): + li = [] + acc = 0 + for xi, yi in zip(x_refi, y_refi): + acc = xi * acc + yi + li.append(acc) + result.append(li) + z_ref = np.array(result) + if reverse: + z_ref = np.flip(z_ref, 1) + + if axis == 0: + z_ref = z_ref.T + else: + assert op == 'get_first_element' + z_ref = x + if axis == 0: + if reverse: + z_ref[:-1] = x[-1] + else: + z_ref[1:] = x[0] + else: + if reverse: + z_ref[:, :-1] = x[:, -1:] + else: + z_ref[:, 1:] = x[:, 0:1] + + # triton result + # we don't cast the `fp32 = bf16 op bf16` result to bfloat16 to alleviate accuracy issues + z_tri = to_triton(z, device=device) + kernel[(1, )](x_tri, y_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) + + z_tri = to_numpy(z_tri) + # compare + if dtype_str not in int_dtypes: + if op == 'cumprod': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01, atol=1e-3) + else: + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + np.testing.assert_equal(z_ref, z_tri) + + +scan_layouts = [ + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), +] + +# --------------- +# test histogram +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[2048, 2], [1024, 8], [1024, 128], [256, 512], [32, 512], [8, 512], [8, 2]]) +def test_histogram(M, N, device): + + @triton.jit + def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + offset1) + z = tl.histogram(x, N) + tl.store(z_ptr + offset2, z) + + torch.manual_seed(17) + x = torch.randint(0, N, (M, ), device=device, dtype=torch.int32) + z = torch.empty(N, dtype=torch.int32, device=device) + # torch.histc does not work when the input type is not float and the device is CPU + # https://github.com/pytorch/pytorch/issues/74236 + # This is a workload by converting the input to float + z_torch = torch.histc(x.float(), bins=N, min=0, max=N - 1) + histogram_kernel[(1, )](x, z, M=M, N=N) + assert (z_torch == z).all() + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op", ['sum', 'max', 'min']) +@pytest.mark.parametrize("BLOCK_N", [32, 64, 128]) +@pytest.mark.parametrize("N", [512, 1024, 2048]) +@pytest.mark.parametrize("num_pid_n", [2, 4]) +def test_optimize_thread_locality(op, BLOCK_N, N, num_pid_n, device): + + @triton.jit + def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.constexpr): + start_m = tl.program_id(0) + pid_n = tl.program_id(1) + local = INITIALIZE_PATCH + off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + for start_n in range(pid_n, tl.cdiv(N, BLOCK_N), NUM_PID_N): + off_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * N + off_n[None, :] + x = tl.load(Xs) + local = ACCUMULATE_PATCH + tl.store(Y + off_m * NUM_PID_N + pid_n, local) + # the following segfaults AMD backend following #3492 + # really unclear why; the llvm-ir and kernel arguments are + # identical ! + # tl.store(Y + off_m * tl.num_programs(1) + pid_n, local) + + initialize_patch = { + 'sum': 'tl.zeros([BLOCK_M], dtype=tl.float32)', + 'max': 'tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)', + 'min': 'tl.full([BLOCK_M], float("inf"), dtype=tl.float32)', + }[op] + reduce_patch = { + 'sum': 'local + tl.sum(x, axis=1)', + 'max': 'tl.maximum(local, tl.max(x, axis=1))', + 'min': 'tl.minimum(local, tl.min(x, axis=1))', + }[op] + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + }[op] + kernel = patch_kernel(kernel, {'ACCUMULATE_PATCH': reduce_patch, 'INITIALIZE_PATCH': initialize_patch}) + torch.manual_seed(0) + BLOCK_M = 32 + x = torch.randn((BLOCK_M, N), dtype=torch.float32, device=device) + y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device=device) + h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N, NUM_PID_N=num_pid_n) + #if not is_interpreter(): + # assert h.asm['ttgir'].count( + # '"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work" + y_ref = numpy_op(x.cpu().numpy(), axis=1, keepdims=True) + y_tri = numpy_op(y.cpu().numpy(), axis=1, keepdims=True) + np.testing.assert_allclose(y_tri, y_ref, rtol=0.01, atol=1e-3) + + +@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) +@pytest.mark.parametrize("src_layout", scan_layouts) +@pytest.mark.parametrize("axis", [0, 1]) +def test_scan_layouts(M, N, src_layout, axis, device): + + ir = f""" + #blocked = {src_layout} + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, noWarpReduce=false, parent = #blocked}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, noWarpReduce=false, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> + %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, noWarpReduce=false, parent = #blocked}}>> + %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, noWarpReduce=false, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> + %7 = tt.broadcast %4 : tensor<{M}x1x!tt.ptr, #blocked> -> tensor<{M}x{N}x!tt.ptr, #blocked> + %8 = tt.broadcast %6 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %10 = tt.load %9 : tensor<{M}x{N}xi32, #blocked> + %11 = "tt.scan"(%10) <{{axis = {axis} : i32, reverse = false}}> ({{ + ^bb0(%arg2: i32, %arg3: i32): + %16 = arith.addi %arg2, %arg3 : i32 + tt.scan.return %16 : i32 + }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> + %13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %14 = tt.broadcast %13 : tensor<{M}x1x!tt.ptr, #blocked> -> tensor<{M}x{N}x!tt.ptr, #blocked> + %15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + tt.store %15, %11 : tensor<{M}x{N}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + rs = RandomState(17) + x = rs.randint(-100, 100, (M, N)).astype('int32') + + z = np.zeros((M, N)).astype('int32') + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + kernel[(1, 1, 1)](x_tri, z_tri) + + z_ref = np.cumsum(x, axis=axis) + + np.testing.assert_equal(z_ref, z_tri.cpu().numpy()) + + +layouts = [ + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 4], [THREADS_PER_WARP // 16, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 2], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], + instr_shape=[16, 16, 16]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], + instr_shape=[16, 32, 16]), + MfmaLayout(version=(2, 0), warps_per_cta=[2, 2], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[2, 2], instr_shape=[32, 32], is_transposed=True), + MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=True), + MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=True), + WmmaLayout(warps_per_cta=[2, 2]), + WmmaLayout(warps_per_cta=[4, 1]), + WmmaLayout(warps_per_cta=[1, 4]), +] + + +@pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [64, 64], [32, 128], [32, 32], [16, 16]]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d']) +@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"]) +@pytest.mark.parametrize("reduce_op", ["sum", "max"]) +def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device): + if isinstance(src_layout, + (MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]): + pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape") + if is_hip() and isinstance(src_layout, MfmaLayout) and ((M, N) == (128, 128)): + pytest.skip("Skipping test because it runs out of shared memory") + if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024: + pytest.skip("Skipping sum reduction on float16 due to accuracy issues") + if epilogue_kind == 'expand_reduce2d' and isinstance(src_layout, MmaLayout): + pytest.skip( + "Currently MmaLayout combined with slice encoding and reduce op trigger device illegal memory access") + + if isinstance(src_layout, MmaLayout) and src_layout.version == 3: + src_layout[2] = 16 if dtype_str == "float16" else 8 + + ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str] + arith_op = { + "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"}, # + "sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"} + }[reduce_op][dtype_str] + numpy_op = {"max": np.max, "sum": np.sum}[reduce_op] + rdims_1d = f"{N}" if axis == 0 else f"{M}" + rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" + store_range = "%7" if axis == 0 else "%1" + blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]) + num_warps = src_layout.warps_per_cta[0] * src_layout.warps_per_cta[1] + if num_warps == 8: + blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, 2], [0, 1], [1, 1], [1, 1], [0, 1]) + one_d_layout = BlockedLayout([1], [THREADS_PER_WARP], [4], [0], [1], [1], [0]) + + expanded_shape = f"1x{N}" if axis == 0 else f"{M}x1" + other_axis = 1 - axis + epilogue = { + "reduce1d": + f""" + %14 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked> + %15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked> + %16 = {GPU_DIALECT}.convert_layout %13 : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, noWarpReduce=false, parent = #src}}>> -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, noWarpReduce=false, parent = #blocked}}>> + %17 = tt.expand_dims %16 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, noWarpReduce=false, parent = #blocked}}>> -> tensor<{rdims_2d}x{ty}, #blocked> + tt.store %15, %17 : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked> + tt.return + }} + }} + """, "reduce2d": + f""" + %14 = "tt.reduce"(%13) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = 0 : i32, noWarpReduce=false}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, noWarpReduce=false, parent = #src}}>>) -> {ty} + tt.store %arg2, %14 : !tt.ptr<{ty}> + tt.return + }} + }} + """, "expand_reduce2d": + f""" + %14 = tt.expand_dims %13 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, noWarpReduce=false, parent = #src}}>> -> tensor<{expanded_shape}x{ty}, #src> + %15 = "tt.reduce"(%14) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = {other_axis} : i32, noWarpReduce=false}} : (tensor<{expanded_shape}x{ty}, #src>) -> (tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, noWarpReduce=false, parent = #src}}>>) + %16 = triton_gpu.convert_layout %15 : tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, noWarpReduce=false, parent = #src}}>> -> tensor<1x{ty}, #one_d_layout> + %17 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<1x!tt.ptr<{ty}>, #one_d_layout> + tt.store %17, %16 : tensor<1x!tt.ptr<{ty}>, #one_d_layout> + tt.return + }} + }} + """ + }[epilogue_kind] + + ir = f""" + #blocked = {blocked} + #src = {src_layout} + #one_d_layout = {one_d_layout} + module attributes {{"triton_gpu.num-warps" = {num_warps} : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, noWarpReduce=false, parent = #blocked}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, noWarpReduce=false, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %2 = tt.splat %arg1 : i32 -> tensor<{M}x1xi32, #blocked> + %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked> + %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked> + %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, noWarpReduce=false, parent = #blocked}}>> + %7 = tt.expand_dims %6 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, noWarpReduce=false, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked> + %9 = tt.broadcast %7 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 : tensor<{M}x{N}x{ty}, #blocked> + %12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src> + %13 = "tt.reduce"(%12) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = {axis} : i32, noWarpReduce=false}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, noWarpReduce=false, parent = #src}}>> + """ + epilogue + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10) + reduce2d = 'reduce2d' in epilogue_kind + z_shape = (1, 1) if reduce2d else (1, N) if axis == 0 else (M, 1) + z = np.zeros(z_shape).astype(dtype_str) + + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + pgm = kernel[(1, 1, 1)](x_tri, x_tri.stride(0), z_tri) + z_ref = numpy_op(x) if reduce2d else numpy_op(x, axis=axis, keepdims=True) + + if dtype_str == 'float16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + + +layouts = [ + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]) +] + + +@pytest.mark.parametrize("M", [32, 64, 128, 256]) +@pytest.mark.parametrize("src_layout", layouts) +def test_store_op(M, src_layout, device): + + ir = f""" + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "{GPU_DIALECT}.num-ctas" = 1 : i32, "{GPU_DIALECT}.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> + %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> + %3 = tt.load %2 : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 1 : i32}} : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> -> tensor<{M}x1xf32, #src> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> + %6 = tt.expand_dims %5 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #src> + %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr, #src>, tensor<{M}x1xi32, #src> + tt.store %8, %4 : tensor<{M}x1x!tt.ptr, #src> + tt.return + }} + }} + """ + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + store_kernel = triton.compile(f.name) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, 1)).astype('float32') + y = np.zeros((M, 1), dtype='float32') + x_tri = torch.tensor(x, device=device) + y_tri = torch.tensor(y, device=device) + + pgm = store_kernel[(1, 1, 1)](x_tri, y_tri) + y_ref = x + np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +layouts = [ + # TODO (lixun): Add MfmaLayout + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]) +] + + +@pytest.mark.parametrize("M", [64, 128, 256]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("src_dim", [0, 1]) +@pytest.mark.parametrize("dst_dim", [0, 1]) +def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): + + ir = f""" + #dst = {dst_layout} + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, noWarpReduce=false, parent = #src}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, noWarpReduce=false, parent = #src}}>> + %2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, noWarpReduce=false, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, noWarpReduce=false, parent = #src}}>> + %3 = tt.load %2 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, noWarpReduce=false, parent = #src}}>> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, noWarpReduce=false, parent = #dst}}>> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, noWarpReduce=false, parent = #dst}}>> + %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, noWarpReduce=false, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, noWarpReduce=false, parent = #dst}}>> + %7 = {GPU_DIALECT}.convert_layout %3 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, noWarpReduce=false, parent = #src}}>> -> tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, noWarpReduce=false, parent = #dst}}>> + tt.store %6, %7 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, noWarpReduce=false, parent = #dst}}>> + tt.return + }} + }} + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, )).astype('int32') + y = np.zeros((M, ), dtype='int32') + x_tri = torch.tensor(x, device=device) + y_tri = torch.tensor(y, device=device) + pgm = kernel[(1, 1, 1)](x_tri, y_tri) + y_ref = x + np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +@triton.jit +def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): + delta = mean_2 - mean_1 + new_weight = weight_1 + weight_2 + w2_over_w = weight_2 / new_weight + return ( + mean_1 + delta * w2_over_w, + m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w, + new_weight, + ) + + +layouts = [ + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + # [HIP] TO DO: some tests are flaky with the layout, so turn off them for now. + # BlockedLayout([1, 4], [1, THREADS_PER_WARP], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]) +] + + +@pytest.mark.parametrize("M, N", [[128, 128], [256, 128], [256, 256], [128, 256]]) +@pytest.mark.parametrize("src_layout", layouts) +@pytest.mark.parametrize("op", ["sum", "max"]) +@pytest.mark.parametrize("first_axis", [0, 1]) +def test_chain_reduce(M, N, src_layout, op, device, first_axis): + + op_str = "" + if op == "sum": + op_str = """ + %13 = arith.addi %arg2, %arg3 : i32 + tt.reduce.return %13 : i32""" + elif op == "max": + op_str = """ + %13 = arith.cmpi "sgt", %arg2, %arg3 : i32 + %14 = arith.select %13, %arg2, %arg3 : i32 + tt.reduce.return %14 : i32""" + ir = f""" + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src> + %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, noWarpReduce=false, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, noWarpReduce=false, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %5 = tt.broadcast %2 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %6 = tt.broadcast %4 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %10 = tt.load %9 : tensor<{M}x{N}xi32, #src> + %11 = "tt.reduce"(%10) ({{ + ^bb0(%arg2: i32, %arg3: i32): + {op_str} + }}) {{axis = {first_axis} : i32, noWarpReduce=false}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, noWarpReduce=false, parent = #src}}>> + %12 = "tt.reduce"(%11) ({{ + ^bb0(%arg2: i32, %arg3: i32): + {op_str} + }}) {{axis = 0 : i32, noWarpReduce=false}} : (tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, noWarpReduce=false, parent = #src}}>>) -> i32 + tt.store %arg1, %12 : !tt.ptr + tt.return + }} + }} + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, N)).astype('int32') + + z = np.zeros((1, )).astype('int32') + + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + pgm = kernel[(1, 1, 1)](x_tri, z_tri) + if op == "sum": + z_ref = np.sum(x) + elif op == "max": + z_ref = np.max(x) + + np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +@pytest.mark.interpreter +def test_generic_reduction(device): + + @triton.jit + def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr): + xindex = tl.arange(0, BLOCK) + x = tl.load(X + xindex) + mean = x + m2 = tl.zeros_like(x) + weight = tl.full(x.shape, 1, x.dtype) + (mean, m2, weight) = tl.reduce((mean, m2, weight), 0, _welford_combine) + tl.store(out_mean, mean) + tl.store(out_var, m2 / weight) + + SIZE = 512 + x = torch.rand(SIZE, device=device) + out_mean = torch.empty((), device=device) + out_var = torch.empty((), device=device) + + var_mean_kernel[(1, )](x, out_mean, out_var, BLOCK=SIZE) + + expect_var, expect_mean = torch.var_mean(x, dim=0, correction=0) + torch.testing.assert_close(out_mean, expect_mean) + torch.testing.assert_close(out_var, expect_var) + + +# --------------- +# test permute +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) + # TODO: bfloat16 + for dtype in ['float8e4b15', 'float16', 'float32'] + for shape in [(64, 64), (128, 128)] + for perm in [(1, 0)]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_permute(dtype_str, shape, perm, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + if is_hip() and shape == (128, 128) and dtype_str == 'float32': + pytest.skip("TODO Out of LDS for float32 with shape 128x128") + + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + tl.store(Zs, tl.load(Xs)) + + # input + x = numpy_random(shape, dtype_str=dtype_str) + # triton result + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), z_tri, z_tri.stride(1), z_tri.stride(0), + BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), + x_tri.stride(0), z_tri_contiguous, z_tri_contiguous.stride(0), + z_tri_contiguous.stride(1), BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + # numpy result + if dtype_str == 'float8e4b15': + ty = tl.float8e4b15 + z_ref = serialize_fp8(deserialize_fp8(x, ty).T.copy(), ty) + z_tri = z_tri.base + z_tri_contiguous = z_tri_contiguous.base + else: + z_ref = x.transpose(*perm) + # compare + np.testing.assert_allclose(to_numpy(z_tri), z_ref) + np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref) + + if not is_cuda() or is_corex(): + return + + # parse ptx to make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + ptx = pgm_contiguous.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["int32", "int8"]) +@pytest.mark.parametrize("shape", [(2, 4), (16, 16)]) +@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1]))) +def test_trans_2d(dtype_str, shape, perm, device): + + @triton.jit + def kernel(In, Out, in_shape1: tl.constexpr, in_shape2: tl.constexpr, ou_shape1: tl.constexpr, + ou_shape2: tl.constexpr, trans1: tl.constexpr, trans2: tl.constexpr): + in_offs = tl.arange(0, in_shape1)[:, None] * in_shape2 + tl.arange(0, in_shape2)[None, :] + ou_offs = tl.arange(0, ou_shape1)[:, None] * ou_shape2 + tl.arange(0, ou_shape2)[None, :] + tl.store(Out + ou_offs, tl.permute(tl.load(In + in_offs), (trans1, trans2))) + + input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device=device).reshape(shape) + expected = torch.permute(input, perm) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device) + + kernel[(1, )](input, actual, *shape, *[shape[i] for i in perm], *perm) + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["int32", "int8"]) +@pytest.mark.parametrize("shape", [(2, 2, 8, 64), (4, 4, 4, 4)]) +@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1, 2, 3]))) +def test_trans_4d(dtype_str, shape, perm, device): + + @triton.jit + def kernel(In, Out, # + in_shape1: tl.constexpr, in_shape2: tl.constexpr, in_shape3: tl.constexpr, in_shape4: tl.constexpr, + ou_shape1: tl.constexpr, ou_shape2: tl.constexpr, ou_shape3: tl.constexpr, ou_shape4: tl.constexpr, + trans1: tl.constexpr, trans2: tl.constexpr, trans3: tl.constexpr, trans4: tl.constexpr): + in_ptr = tl.make_block_ptr( + base=In, + shape=(in_shape1, in_shape2, in_shape3, in_shape4), + strides=(in_shape4 * in_shape3 * in_shape2, in_shape4 * in_shape3, in_shape4, 1), + offsets=(0, 0, 0, 0), + block_shape=(in_shape1, in_shape2, in_shape3, in_shape4), + order=(3, 2, 1, 0), + ) + out_ptr = tl.make_block_ptr( + base=Out, + shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4), + strides=(ou_shape4 * ou_shape3 * ou_shape2, ou_shape4 * ou_shape3, ou_shape4, 1), + offsets=(0, 0, 0, 0), + block_shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4), + order=(3, 2, 1, 0), + ) + tl.store(out_ptr, tl.load(in_ptr).permute((trans1, trans2, trans3, trans4))) + + input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device=device).reshape(shape) + expected = torch.permute(input, perm) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device) + + kernel[(1, )](input, actual, *shape, *[shape[i] for i in perm], *perm, num_warps=8) + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +# --------------- +# test dot +# --------------- + + +def convert_fp8_to_fp32(x, device, dtype_str): + if dtype_str == 'float8e4nv': + return torch.tensor(x, device=device).view(torch.float8_e4m3fn).to(torch.float32) + elif dtype_str == 'float8e5': + return torch.tensor(x, device=device).view(torch.float8_e5m2).to(torch.float32) + assert "Unsupported float8 dtype" + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack", + [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1) + for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] + for input_precision in ['tf32', 'tf32x3', 'ieee'] + for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')] + if not (input_precision != 'ieee' and (in_dtype in ['float16']))] + + [(*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype, kpack) + for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4], + [32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] + for input_precision in ["ieee" if is_hip() else "tf32"] + for col_a in [True, False] + for col_b in [True, False] + for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', + 'float32')] + for kpack in [1, 2 if is_hip() else 1]] + [(64, 64, 64, 4, col_a, col_b, 'none', 'ieee', 'float32', 'float32', 1) + for col_a in [True, False] + for col_b in [True, False]] + + [(64, 64, 64, 4, False, False, 'chain-dot', 'ieee', 'bfloat16', 'float32', 1)] + + [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1) + for float8_type in ["float8e5", "float8e4nv"]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, num_ctas, device): + if is_interpreter(): + if in_dtype == 'bfloat16': + pytest.skip("bfloat16 is not supported in the interpreter") + else: + if is_cuda(): + capability = torch.cuda.get_device_capability() + if is_corex(): + if in_dtype == 'float8e4nv' or in_dtype == "float8e5": + pytest.skip("float8e5/float8e4nv not supported on iluvatar devices") + if out_dtype != "float32": + pytest.skip("iluvatar devices only support out_dtype==float32") + else: + if capability[0] < 7: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + if capability[0] < 8: + if capability[1] == 0 and in_dtype == 'int8': + pytest.skip("Only test int8 on devices with sm >= 75") + if input_precision != "ieee": + pytest.skip("Only test tf32 on devices with sm >= 80") + if capability[0] == 7: + if (M, N, K, num_warps) in [(128, 256, 32, 8), (64, 128, 128, 4), (64, 128, 128, 2)]: + pytest.skip("shared memory out of resource") + if out_dtype == 'float16': + # TODO: support out_dtype=float16 for tl.dot on V100 + pytest.skip("Only test out_dtype=float16 on devices with sm >=80") + if capability[0] < 9 and in_dtype == 'float8e4nv': + pytest.skip("float8e4nv not supported on sm <= 80") + if is_hip() and (in_dtype == 'float8e4nv' or in_dtype == 'float8e5'): + pytest.skip("float8e4nv and float8e5 not supported on HIP") + if is_hip() and (input_precision != "ieee"): + pytest.skip(f"{input_precision} not supported on HIP") + if is_hip() and (kpack == 2 and in_dtype == 'int8' and K < 64): + pytest.skip("kpack too large for K") + if not is_hip() and kpack == 2: + pytest.skip("Skip duplicated tests on nv path") + + torch.backends.cuda.matmul.allow_tf32 = input_precision == "tf32" + + if num_ctas > 1 and in_dtype == 'int8': + # FIXME: mma v2 with num_ctas > 1 does not work + pytest.skip() + + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, stride_wl, Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ADD_MATRIX: tl.constexpr, + ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, INPUT_PRECISION: tl.constexpr, DO_SOFTMAX: tl.constexpr, + CHAIN_DOT: tl.constexpr, COL_A: tl.constexpr, COL_B: tl.constexpr, out_dtype: tl.constexpr = tl.float32): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_l = tl.arange(0, BLOCK_N) + off_k = tl.arange(0, BLOCK_K) + Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk + Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn + Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + x = tl.load(Xs) + y = tl.load(Ys) + z = tl.dot(x, y, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + if ADD_MATRIX: + z += tl.load(Zs) + if ADD_ROWS: + ZRs = Z + off_m * stride_zm + z += tl.load(ZRs)[:, None] + if ADD_COLS: + ZCs = Z + off_n * stride_zn + z += tl.load(ZCs)[None, :] + if DO_SOFTMAX: + max = tl.max(z, 1) + z = z - max[:, None] + num = tl.exp(z.to(tl.float32)).to(max.dtype) + den = tl.sum(num, 1) + z = num / den[:, None] + if CHAIN_DOT: + w = tl.load(Ws) + z = tl.dot(z.to(w.dtype), w, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + tl.store(Zs, z) + + # input + rs = RandomState(17) + if col_a: + x = numpy_random((K, M), dtype_str=in_dtype, rs=rs).T + else: + x = numpy_random((M, K), dtype_str=in_dtype, rs=rs) + if col_b: + y = numpy_random((N, K), dtype_str=in_dtype, rs=rs).T + else: + y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) + w = numpy_random((N, N), dtype_str=in_dtype, rs=rs) + if 'int' not in in_dtype and 'float8' not in in_dtype: + x *= .1 + y *= .1 + if in_dtype == 'float32' and input_precision == "tf32": + x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') + y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') + w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32') + x_tri = to_triton(x, device=device, dst_type=in_dtype) + y_tri = to_triton(y, device=device, dst_type=in_dtype) + w_tri = to_triton(w, device=device, dst_type=in_dtype) + # triton result + if out_dtype == 'int8': + z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs) + else: + z = 1 + numpy_random((M, N), dtype_str=in_dtype, rs=rs) * .1 + + z_tri = to_triton(z, device=device) + if epilogue == 'trans': + z_tri = torch.as_strided(z_tri, (M, N), [1, M]) + + if out_dtype == 'int8': + out_dtype = tl.int8 + elif out_dtype == 'float16' and epilogue != 'softmax': + # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will + # fail with the following error: 'llvm.fmul' op requires the same type + # for all operands and results + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + + kern_kwargs = { + 'COL_A': col_a, 'COL_B': col_b, 'BLOCK_M': M, 'BLOCK_K': K, 'BLOCK_N': N, 'ADD_MATRIX': + epilogue == 'add-matrix', 'ADD_ROWS': epilogue == 'add-rows', 'ADD_COLS': epilogue == 'add-cols', 'DO_SOFTMAX': + epilogue == 'softmax', 'CHAIN_DOT': epilogue == 'chain-dot', 'INPUT_PRECISION': input_precision, 'num_warps': + num_warps, 'num_ctas': num_ctas, 'out_dtype': out_dtype + } + + if is_hip(): + kern_kwargs['kpack'] = kpack + + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri, + w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), **kern_kwargs) + + if epilogue == 'softmax' and (in_dtype != 'float32' or input_precision == "tf32"): + if not is_cuda() or is_corex(): + pass + else: + ptx = pgm.asm["ptx"] + start = ptx.find("shfl.sync.bfly") + end = ptx.find("cvt.rn.f16.f32") + red_code = ptx[start:end] + assert len(red_code) > 0 + + # skip this check on hopper because there are some functions whose name contain "shared" in ptx. + # TODO: we should eliminate these unused functions in ptx code. + if not (capability[0] >= 9): + assert "shared" not in red_code + assert "bar.sync" not in red_code + # torch result + if in_dtype == 'int8': + z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32) + elif 'float8' in in_dtype: + x = convert_fp8_to_fp32(x, device, in_dtype) + y = convert_fp8_to_fp32(y, device, in_dtype) + z_ref = to_numpy(torch.matmul(x, y)) + else: + z_ref = np.matmul(x, y) + + if epilogue == 'add-matrix': + z_ref += z + if epilogue == 'add-rows': + z_ref += z[:, 0][:, None] + if epilogue == 'add-cols': + z_ref += z[0, :][None, :] + if epilogue == 'softmax': + num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True)) + denom = np.sum(num, axis=-1, keepdims=True) + z_ref = num / denom + if epilogue == 'chain-dot': + if 'float8' in in_dtype: + w = to_numpy(convert_fp8_to_fp32(w, device, in_dtype)) + z_ref = np.matmul(z_ref, w) + # compare + if in_dtype == 'float32': + # XXX: Somehow there's a larger difference when we use float32 + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + elif out_dtype == tl.float16 or in_dtype == 'bfloat16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + # added atol, to loose precision for float16xfloat16->float32 case + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + if not is_cuda() or is_corex(): + return + # make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4): + # XXX: skip small sizes because they are not vectorized + assert 'ld.global.v4' in ptx + if 'float8' in in_dtype: + assert 'st.global.v2' in ptx + else: + assert 'st.global.v4' in ptx + if in_dtype == 'float32' and input_precision != "ieee": + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float32: + if capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float16: + if capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f16.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx) + elif in_dtype == 'int8': + if capability[0] == 7 and capability[1] == 5: # Turing + assert 'mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32' in ptx + else: + assert 'wgmma.mma_async.sync.aligned' in ptx or\ + 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx + elif in_dtype == "float8e5" and out_dtype == tl.float32: + if capability[0] == 9: + assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2' in ptx + elif in_dtype == "float8e4nv" and out_dtype == tl.float32: + if capability[0] == 9: + assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("B", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_warps", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("M, N, K", [(64, 64, 64), (32, 32, 32)]) +@pytest.mark.parametrize("in_dtype_str, out_dtype_str", [('int8', 'int8'), ('float16', 'float16'), + ('float16', 'float32'), ('float32', 'float32')]) +def test_dot3d(B, num_warps, M, N, K, in_dtype_str, out_dtype_str, device): + if is_corex(): + pytest.skip("iluvatar devices does not support dot3d now") + elif is_hip(): + # hip does not support tf32 precision, so use ieee for all tests + input_precision = "ieee" + if "gfx11" in triton.runtime.driver.active.get_current_target().arch: + if in_dtype_str == "float32": + pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d") + if out_dtype_str == "float16": + pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") + else: + input_precision = "tf32" if in_dtype_str == 'float32' else "ieee" + + @triton.jit + def kernel( + q_ptr, + k_ptr, + o_ptr, + stride_qb, + stride_qm, + stride_qk, + stride_kb, + stride_kk, + stride_kn, + stride_ob, + stride_om, + stride_on, + BLOCK_B: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + INPUT_PRECISION: tl.constexpr, + out_dtype: tl.constexpr = tl.float32, + ): + startm = tl.program_id(0) * BLOCK_M + startn = tl.program_id(1) * BLOCK_N + offs_b = tl.arange(0, BLOCK_B) + offs_m = startm + tl.arange(0, BLOCK_M) + offs_n = startn + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + q_ptrs = q_ptr + offs_b[:, None, None] * stride_qb + offs_m[None, :, None] * stride_qm + offs_k[ + None, None, :] * stride_qk + k_ptrs = k_ptr + offs_b[:, None, None] * stride_kb + offs_k[None, :, None] * stride_kk + offs_n[ + None, None, :] * stride_kn + q = tl.load(q_ptrs) + k = tl.load(k_ptrs) + qk = tl.dot(q, k, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + o_ptrs = o_ptr + offs_b[:, None, None] * stride_ob + offs_m[None, :, None] * stride_om + offs_n[ + None, None, :] * stride_on + tl.store(o_ptrs, qk) + + if out_dtype_str == 'int8': + out_dtype = tl.int8 + elif out_dtype_str == 'float16': + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + + rs = RandomState(17) + x = numpy_random((B, M, K), dtype_str=in_dtype_str, rs=rs) + y = numpy_random((B, K, N), dtype_str=in_dtype_str, rs=rs) + if in_dtype_str == 'int8': + out = numpy_random((B, M, N), dtype_str='int32', rs=rs) + else: + out = numpy_random((B, M, N), dtype_str=out_dtype_str, rs=rs) + + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + out_tri = to_triton(out, device=device) + + BLOCK_B = B + BLOCK_M, BLOCK_N = 32, 32 + BLOCK_K = K + + grid = ( + triton.cdiv(M, BLOCK_M), + triton.cdiv(N, BLOCK_N), + ) + kernel[grid]( + x_tri, + y_tri, + out_tri, + x_tri.stride(0), + x_tri.stride(1), + x_tri.stride(2), + y_tri.stride(0), + y_tri.stride(1), + y_tri.stride(2), + out_tri.stride(0), + out_tri.stride(1), + out_tri.stride(2), + BLOCK_B=BLOCK_B, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + INPUT_PRECISION=input_precision, + out_dtype=out_dtype, + num_warps=num_warps, + ) + + if in_dtype_str == 'int8': + out_ref = np.matmul(x.astype(np.float32), y.astype(np.float32)).astype(np.int32) + else: + out_ref = np.matmul(x, y) + np.testing.assert_allclose(out_ref, to_numpy(out_tri), rtol=0.01, atol=1e-2) + + +@pytest.mark.interpreter +def test_max_num_imprecise_acc(device): + + if not hasattr(torch, 'float8_e5m2'): + pytest.skip(f"torch {torch.__version__} does not support float8_e5m2") + + if is_cuda(): + capability = torch.cuda.get_device_capability() + if capability != (9, 0): + return + + @triton.jit + def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + MAX_NUM_IMPRECISE_ACC: tl.constexpr): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_k = tl.arange(0, BLOCK_K) + x = tl.load(X + off_m[:, None] * BLOCK_K + off_k[None, :]) + y = tl.load(Y + off_k[:, None] * BLOCK_N + off_n[None, :]) + z = tl.load(Z + off_m[:, None] * BLOCK_N + off_n[None, :]) + z = tl.dot(x, y, acc=z, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC) + tl.store(Z + off_m[:, None] * BLOCK_N + off_n[None, :], z) + + M, N, K, num_warps, MAX_NUM_IMPRECISE_ACC = 128, 128, 128, 4, 64 + x = torch.zeros((M, K), dtype=torch.float8_e5m2, device=device) + y = torch.zeros((K, N), dtype=torch.float8_e5m2, device=device) + z = torch.zeros((M, N), dtype=torch.float32, device=device) + h = kernel[(1, 1)](x, y, z, M, N, K, MAX_NUM_IMPRECISE_ACC, num_warps=num_warps) + if not is_cuda(): + return + assert h.asm["ptx"].count("add.f32") == (M * N) // (32 * num_warps) * (K / MAX_NUM_IMPRECISE_ACC) + + +@pytest.mark.parametrize('in_dtype', ['float32']) +def test_dot_mulbroadcasted(in_dtype, device): + if is_cuda(): + capability = torch.cuda.get_device_capability() + if not is_corex(): + if capability[0] < 8: + pytest.skip("Requires sm >= 80 to run") + + @triton.jit + def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BM: tl.constexpr, BN: tl.constexpr, + BK: tl.constexpr): + pidn = tl.program_id(1) + pidm = tl.program_id(0) + offm = tl.arange(0, BM)[:, None] + offn = tl.arange(0, BN)[None, :] + offak = tl.arange(0, BK)[None, :] + offbk = tl.arange(0, BK)[:, None] + acc = tl.full((BM, BN), 0.0, tl.float32) + for ridx5 in range(0, K // BK): + x = tl.load(X + ((pidm * K * BM) + (offm * K) + (ridx5 * BK) + offak)) + y = tl.load(Y + ((pidn * BN) + (offbk * N) + (ridx5 * N * BK) + offn)) + x = tl.expand_dims(x, axis=2) + y = tl.expand_dims(y, axis=0) + t = tl.sum(x * y, axis=1) + acc = t + acc + tl.store(Z + ((pidm * BM * N) + (pidn * BN) + (offm * N) + offn), acc) + + M, N, K = 256, 192, 160 + BM, BN, BK = 128, 32, 32 + rs = RandomState(17) + x = numpy_random((M, K), dtype_str=in_dtype, rs=rs) + y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) + x = x * 0.1 + y = y * 0.1 + z = numpy_random((M, N), dtype_str=in_dtype, rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(z, device=device) + grid = M // BM, N // BN + h = kernel[grid](z_tri, x_tri, y_tri, M, N, K, BM, BN, BK) + z_ref = np.matmul(x, y) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), atol=0.01) + + if not is_cuda() or is_corex(): + return + assert "tt.dot" in h.asm['ttir'] + # When using MMAv3, we will not pipeline the load op for Y, as the loaded + # value is in rowmajor. But MMAv3 requires its second operand is in colmajor + # because transpose is not supported for MMAv3 with float32 input. + if capability[0] >= 9: + assert re.search(r"triton_gpu.async_wait %.* {num = 1 : i32}", h.asm["ttgir"]) is not None + else: + assert re.search(r"triton_gpu.async_wait %.* {num = 2 : i32}", h.asm["ttgir"]) is not None + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) +@pytest.mark.parametrize("shape", [(), (1, ), (128, )]) +def test_full(dtype_str, shape, device): + if dtype_str in uint_dtypes and not hasattr(torch, dtype_str): + # PyTorch only has unsigned 8, but not 16, 32, or 64 + dtype = getattr(torch, dtype_str[1:]) # uintx -> intx + else: + dtype = getattr(torch, dtype_str) + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + @triton.jit + def kernel_static(out): + a = GENERATE_TEST_HERE + tl.static_assert(a.shape == SHAPE) + out_ptr = out + tl.arange(0, 128)[:] + tl.store(out_ptr, a) + + @triton.jit + def kernel_dynamic(out, val, dtype: tl.constexpr): + a = tl.full(SHAPE, val, dtype) + tl.static_assert(a.shape == SHAPE) + out_ptr = out + tl.arange(0, 128)[:] + tl.store(out_ptr, a) + + kernel_static_patched = patch_kernel(kernel_static, { + 'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})", + 'SHAPE': str(list(shape)), + }) + out_static = torch.zeros((128), dtype=dtype, device=device) + kernel_static_patched[(1, )](out_static) + assert torch.all(out_static == 2) + + kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))}) + out_dynamic = torch.zeros((128), dtype=dtype, device=device) + kernel_dynamic_patched[(1, )](out_dynamic, 2, getattr(triton.language, dtype_str)) + assert torch.all(out_dynamic == 2) + + +@pytest.mark.parametrize("literal, dtype_str", [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"), ('float("inf")', "f32"), + ('float("-inf")', "f32"), ('float("nan")', "f32"), + ('float("-nan")', "f32"), (0., "f32"), (5, "i32"), (2**40, "i64")]) +def test_constexpr(literal, dtype_str, device): + + @triton.jit + def kernel(out_ptr): + val = GENERATE_TEST_HERE + tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val) + + kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"}) + out = torch.zeros((1, ), dtype=torch.float32, device=device) + h = kernel_patched[(1, )](out) + assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None + + +@triton.jit +def pass_const(a, b, choose_b): + if choose_b: + return b + else: + return a + + +@pytest.mark.parametrize("choose_const", [True, False]) +@pytest.mark.parametrize("constexpr", [True, False]) +@pytest.mark.parametrize("mode", ["direct", "call", "ternary", "if"]) +def test_const(device, choose_const, constexpr, mode): + + @triton.jit(do_not_specialize=["choose_const"]) + def kernel(in_ptr: tl.const, out, c_out: tl.const, choose_const, n_elems: tl.int32, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elems + val = tl.load(in_ptr + offsets, mask=mask) + LOSE_TAIL + tl.store(final_out + offsets, val, mask=mask) + + @triton.jit + def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.constexpr, n_elems: tl.int32, + BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elems + val = tl.load(in_ptr + offsets, mask=mask) + LOSE_TAIL + tl.store(final_out + offsets, val, mask=mask) + + if mode == "direct": + if choose_const: + LOSE_TAIL = "final_out = c_out" + else: + LOSE_TAIL = "final_out = out" + elif mode == "call": + LOSE_TAIL = "final_out = pass_const(out, c_out, choose_const)" + elif mode == "ternary": + LOSE_TAIL = "final_out = c_out if choose_const else out" + elif mode == "if": + LOSE_TAIL = """ + if choose_const: + final_out = c_out + else: + final_out = out +""" + + SIZE = 128 + input = torch.randn((SIZE, ), dtype=torch.float32, device=device) + output = torch.zeros((SIZE, ), dtype=torch.float32, device=device) + patched_kernel = patch_kernel(kernel_constexpr if constexpr else kernel, {'LOSE_TAIL': LOSE_TAIL, 'CONSTEXPR': ''}) + + expect_fail = (not constexpr and mode != "direct") or choose_const + if expect_fail: + with pytest.raises(triton.CompilationError) as exc_info: + patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) + else: + patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) + assert torch.all(input == output) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ['float32', 'float16']) +def test_dot_without_load(dtype_str, device): + + @triton.jit + def _kernel(out): + a = GENERATE_TEST_HERE + b = GENERATE_TEST_HERE + c = tl.dot(a, b) + out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(out_ptr, c) + + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"}) + a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) + b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) + out_ref = torch.matmul(a, b) + out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](out) + assert torch.all(out == out_ref) + + +# --------------- +# test arange +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("start", [0, 1, 7, 16]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_arange(start, num_ctas, device): + BLOCK = 128 + z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) + + @triton.jit + def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): + off = tl.arange(0, BLOCK) + val = tl.arange(START, END) + tl.store(z + off, val) + + _kernel[(1, )](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK, num_ctas=num_ctas) + z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device) + np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref)) + + +# --------------- +# test load +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, size, size_diff, other", [(dtype_str, size, size_diff, other) + for dtype_str in torch_dtypes + for size in [128, 512] + for size_diff in [0, 1, 2, 3, 4] + for other in [0, 1]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_masked_load(dtype_str, size, size_diff, other, num_ctas, device): + dtype = getattr(torch, dtype_str) + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + input_size = size - size_diff + output_size = size + if dtype_str == 'bool': + input = torch.randint(0, 2, (input_size, ), dtype=dtype, device=device) + elif dtype_str in int_dtypes or dtype_str in uint_dtypes: + input = torch.randint(0, 127, (input_size, ), dtype=dtype, device=device) + else: + input = torch.rand(input_size, dtype=dtype, device=device) + output = torch.zeros((output_size, ), dtype=dtype, device=device) + + @triton.jit + def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): + in_offsets = tl.arange(0, out_size) + # Load inputs. + x = GENERATE_TEST_HERE + # Store output + output_offsets = tl.arange(0, out_size) + tl.store(out_ptr + output_offsets, x) + + mask_str = f"mask=in_offsets < in_size, other={other}" if size_diff > 0 else "None" + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"}) + kernel[(1, )](input, output, input_size, output_size, num_ctas=num_ctas) + + reference_out = torch.cat((input, torch.full((size_diff, ), other, dtype=dtype, device=device))) + torch.testing.assert_close(output, reference_out) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +@pytest.mark.parametrize("mask_val", [True, False]) +@pytest.mark.parametrize("other_val", [0, 1]) +def test_masked_load_scalar(num_ctas, mask_val, other_val, device): + input_val = 4.0 + size = 128 + dtype = torch.float32 + input = torch.full((size, ), input_val, dtype=dtype, device=device) + output = torch.zeros((size, ), dtype=dtype, device=device) + + @triton.jit + def kernel(in_ptr, out_ptr, size: tl.constexpr, mask: tl.constexpr, other: tl.constexpr): + offsets = tl.arange(0, size) + x = tl.load(in_ptr + offsets, mask=mask, other=other) + tl.store(out_ptr + offsets, x) + + kernel[(1, )](input, output, size, mask_val, other_val, num_ctas=num_ctas) + + if mask_val: + reference_out = torch.full((size, ), input_val, dtype=dtype, device=device) + else: + reference_out = torch.full((size, ), other_val, dtype=dtype, device=device) + + torch.testing.assert_close(output, reference_out) + + +# Testing masked loads with an intermate copy to shared memory run. +# FIXME: Shape too small for ldmatrix when num_ctas=4 +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +def test_masked_load_shared_memory(dtype, device): + + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + M = 32 + N = 32 + K = 16 + + in1 = torch.rand((M, K), dtype=dtype, device=device) + in2 = torch.rand((K, N), dtype=dtype, device=device) + out = torch.zeros((M, N), dtype=dtype, device=device) + + @triton.jit + def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_numel, in2_numel, out_numel, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): + + M_offsets = tl.arange(0, M) + N_offsets = tl.arange(0, N) + K_offsets = tl.arange(0, K) + + in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :] + in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :] + + # Load inputs. + x = tl.load(in1_ptr + in_offsets, mask=in_offsets < M * K) + w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < K * N) + + # Without a dot product the memory doesn't get promoted to shared. + o = tl.dot(x, w, out_dtype=tl.float32) + + # Store output + output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :] + tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N) + + pgm = _kernel[(1, )](in1, in2, out, in1.stride()[0], in2.stride()[0], out.stride()[0], in1.numel(), in2.numel(), + out.numel(), M=M, N=N, K=K) + + reference_out = torch.matmul(in1, in2) + torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("cache", ["", ".ca", ".cg"]) +def test_load_cache_modifier(cache, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, CACHE: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets, cache_modifier=CACHE) + tl.store(dst + offsets, x) + + pgm = _kernel[(1, )](dst, src, CACHE=cache) + if not is_cuda() or is_corex(): + return + + ptx = pgm.asm['ptx'] + if cache == '': + assert 'ld.global.ca' not in ptx + assert 'ld.global.cg' not in ptx + if cache == '.cg': + assert 'ld.global.cg' in ptx + assert 'ld.global.ca' not in ptx + if cache == '.ca': + assert 'ld.global.ca' in ptx + assert 'ld.global.cg' not in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("N", [16, 10, 11, 1024]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_vectorization(N, num_ctas, device): + block_size = 1024 * num_ctas + src = torch.empty(block_size, device=device) + dst = torch.empty(block_size, device=device) + + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + pgm = _kernel[(1, )](dst, src, N=N, BLOCK_SIZE=block_size) + + if not is_cuda() or is_corex(): + return + + ptx = pgm.asm["ptx"] + if N % 16 == 0: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.b32" in ptx + # np.testing.assert_allclose(dst, src[:N]) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("has_hints", [False, True]) +def test_vectorization_hints(has_hints, device): + src = torch.empty(1024, device=device) + dst = torch.empty(1024, device=device) + off = torch.zeros(1, device=device, dtype=torch.int32) + + @triton.jit + def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offsets = offsets + tl.load(off) + if HINT: + tl.max_contiguous(tl.multiple_of(offsets, 1024), 1024) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + pgm = _kernel[(1, )](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints) + if not is_cuda() or is_corex(): + return + + ptx = pgm.asm["ptx"] + if has_hints: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.v4.b32" not in ptx + + +# --------------- +# test store +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs", ".wt"]) +def test_store_cache_modifier(cache, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, CACHE: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets) + tl.store(dst + offsets, x, cache_modifier=CACHE) + + if not is_cuda() or is_corex(): + return + pgm = _kernel[(1, )](dst, src, CACHE=cache) + ptx = pgm.asm['ptx'] + if cache == '': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.wb': + assert 'st.global.wb' in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cg': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cs': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' in ptx + assert 'st.global.wt' not in ptx + if cache == '.wt': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' in ptx + + +# --------------- +# test default +# --------------- +# TODO: can't be local to test_default + + +@triton.jit +def _impl(value=10): + return value + + +@pytest.mark.interpreter +def test_default(device): + value = 5 + ret0 = torch.zeros(1, dtype=torch.int32, device=device) + ret1 = torch.zeros(1, dtype=torch.int32, device=device) + + @triton.jit + def _kernel(ret0, ret1, value=3): + tl.store(ret0, _impl()) + tl.store(ret1, _impl(value)) + + _kernel[(1, )](ret0, ret1, value) + assert ret0.item() == 10 + assert ret1.item() == value + + _kernel[(1, )](ret0, ret1) + assert ret0.item() == 10 + assert ret1.item() == 3 + + +# --------------- +# test noop +# ---------------- + + +@pytest.mark.interpreter +def test_noop(device): + + @triton.jit + def kernel(x): + pass + + x = to_triton(numpy_random((1, ), dtype_str='int32'), device=device) + kernel[(1, )](x) + + +@pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned']) +def test_pointer_arguments(device): + + @triton.jit + def kernel(x): + pass + + pin_memory = 'pinned' in device + x = torch.empty(1024, device=device.split('_')[0], pin_memory=pin_memory) + if device == "cpu": + with pytest.raises(ValueError): + kernel[(1, )](x) + else: + kernel[(1, )](x) + + +@pytest.mark.parametrize("value, value_type", [(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'), + (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')]) +def test_value_specialization(value: int, value_type: str, device) -> None: + + def repr(specialization): + spec_type = specialization.signature["VALUE"] + return f"kernel_{spec_type}" + + @triton.jit(repr=repr) + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device=device) + h = kernel[(1, )](value, x) + assert value_type in h.name + + +# -------------------- +# value specialization +# -------------------- + + +@pytest.mark.parametrize("value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]) +def test_value_specialization_overflow(value: int, overflow: bool, device) -> None: + + @triton.jit + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device=device) + + if overflow: + with pytest.raises(OverflowError): + kernel[(1, )](value, x) + else: + kernel[(1, )](value, x) + + +# ---------------- +# test constexpr +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) +@pytest.mark.parametrize("is_lhs_constexpr", [False, True]) +@pytest.mark.parametrize("is_rhs_constexpr", [True, False]) +def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr, device): + + @triton.jit + def kernel(Z, X, Y): + x = tl.load(X) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z, z) + + if op in ['<<', '>>', '&', '^', '|']: # int op + x_str = "3" if is_lhs_constexpr else "x" + y_str = "4" if is_rhs_constexpr else "y" + x = numpy_random((1, ), dtype_str="int32") + + # NOTE: bitshifting beyond bitwidth can lead to undefined behavior + if op in ['<<', '>>']: + y = numpy_random((1, ), dtype_str="int32", low=0, high=_bitwidth("int32")) + else: + y = numpy_random((1, ), dtype_str="int32") + else: + x_str = "3.14" if is_lhs_constexpr else "x" + y_str = "4.13" if is_rhs_constexpr else "y" + x = numpy_random((1, ), dtype_str="float32") + y = numpy_random((1, ), dtype_str="float32") + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"}) + z = np.array(eval(f"{x_str} {op} {y_str}")) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(np.empty((1, ), dtype=z.dtype), device=device) + kernel[(1, )](z_tri, x_tri, y_tri) + np.testing.assert_allclose(z, to_numpy(z_tri), rtol=1e-3) + + +@pytest.mark.interpreter +def test_constexpr_shape(device): + + @triton.jit + def kernel(X): + off = tl.arange(0, 128 + 128) + tl.store(X + off, off) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) + kernel[(1, )](x_tri) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) + + +@pytest.mark.interpreter +def test_constexpr_scalar_shape(device): + + @triton.jit + def kernel(X, s): + off = tl.arange(0, 256) + val = off % (256 // s) + tl.store(X + off, val) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) + kernel[(1, )](x_tri, 32) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8) + + +reshape_list = [((64, ), (8, 8)), ((2, 32), (16, 4)), ((512, ), (2, 2, 2, 2, 2, 2, 2, 2, 2)), ((64, 32), (16, 8, 16))] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("formats", reshape_list) +def test_reshape(formats, device): + in_format, out_format = formats + + @triton.jit + def kernel(Z, X, out_tuple: tl.constexpr): + x = tl.load(X_PTR_EXPR) + z = tl.reshape(x, out_tuple) + tl.store(Z_PTR_EXPR, z) + + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + } + return patch_kernel(kernel, to_replace) + + x = numpy_random(in_format, dtype_str="int32") + z = x.reshape(out_format) + x_tri = to_triton(x, device=device) + patched_kernel = generate_kernel(in_format, out_format) + z_tri = to_triton(np.empty(out_format, dtype=np.int32), device=device) + patched_kernel[(1, )](z_tri, x_tri, out_format) + np.testing.assert_equal(z, to_numpy(z_tri)) + + +def test_reshape_err(device): + + @triton.jit + def kernel(): + x = tl.arange(0, 8 * 8) + y = tl.reshape(x, (8 * 4, )) + + with pytest.raises(triton.CompilationError) as exc_info: + kernel[(1, )]() + + assert "reshape" in str(exc_info.value) + + +def test_trans_reshape(device): + + @triton.jit + def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.constexpr): + + in_block_ptr = tl.make_block_ptr( + base=in_base_ptr, + shape=(IN_SHAPE0, IN_SHAPE1), + strides=(IN_SHAPE1, 1), + offsets=(0, 0), + block_shape=(IN_SHAPE0, IN_SHAPE1), + order=(1, 0), + ) + x = tl.load(in_block_ptr) + x = tl.reshape(x, (32, 4, 4, 2)) + x = tl.permute(x, (1, 2, 3, 0)) + x = tl.reshape(x, (IN_SHAPE0 * IN_SHAPE1, )) + tl.store(out_base_ptr + tl.arange(0, IN_SHAPE0 * IN_SHAPE1), x) + + shape = (32, 32) + input = torch.arange(math.prod(shape), dtype=torch.int32, device=device).reshape(shape) + expected = torch.permute(input, (1, 0)) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=torch.int32, device=device) + + k = kernel[(1, )](input, actual, shape[0], shape[1]) + assert k.asm['ttgir'].count( + 'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +# ------------- +# test call +# ------------- + + +@triton.jit +def val_multiplier(val, i): + return val * i + + +@triton.jit(noinline=True) +def val_multiplier_noinline(val, i): + return val * i + + +@triton.jit +def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr): + pid = tl.program_id(axis=0) + offsets = pid * 128 + tl.arange(0, 128) + mask = offsets < n_elements + vec = tl.load(ptr + offsets, mask=mask) + for i in range(1, rep): + if type == "inline": + vec = val_multiplier(vec, i) + else: + vec = val_multiplier_noinline(vec, i) + tl.store(ptr + offsets, vec, mask=mask) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("type", ["inline", "noinline"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_call(type, num_ctas, device): + + @triton.jit + def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): + vecmul_kernel(ptr, n_elements, num1, type) + vecmul_kernel(ptr, n_elements, num2, type) + + size = 1024 + rand_val = numpy_random((size, ), dtype_str="float32") + rand_val_tri = to_triton(rand_val, device=device) + err_msg = "" + try: + kernel[(size // 128, )](rand_val_tri, size, 3, 5, type, num_ctas=num_ctas) + except Exception as e: + err_msg = str(e) + + if type == "noinline" and not is_interpreter(): + assert err_msg != "" + else: + ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4 + np.testing.assert_equal(to_numpy(rand_val_tri), ans) + + +# ------------- +# test if +# ------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("if_type", [ + "if", "if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", + "if_and_static" +]) +def test_if(if_type, device): + + @triton.jit + def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr): + pid = tl.program_id(0) + cond = tl.load(Cond) + if IfType == "if": + if pid % 2 == 0: # eq + tl.store(Ret, tl.load(XTrue)) + elif 1 == pid % 2: # req + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_dynamic": + val = tl.load(XTrue) if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_constexpr": + val = 3.14 if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_void": + tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_static": + tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_dynamic": + if BoolVar and (1 != pid % 2 and pid % 2 != 1): # rne and ne + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_static": + if StaticVaue != 0 and StaticVaue != 0: + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + + cond = torch.ones(1, dtype=torch.int32, device=device) + x_true = torch.tensor([3.14], dtype=torch.float32, device=device) + x_false = torch.tensor([1.51], dtype=torch.float32, device=device) + ret = torch.zeros(1, dtype=torch.float32, device=device) + + kernel[(1, )](cond, x_true, x_false, ret, if_type, True, 1) + assert torch.equal(ret, x_true) + + +def test_num_warps_pow2(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + pass + + with pytest.raises(AssertionError, match='must be a power of 2'): + _kernel[(1, )](dst=dst, num_warps=3) + _kernel[(1, )](dst=dst, num_warps=1) + _kernel[(1, )](dst=dst, num_warps=2) + _kernel[(1, )](dst=dst, num_warps=4) + + +@pytest.mark.skip +@pytest.mark.interpreter +@pytest.mark.parametrize("func_str", ['sqrt', 'rsqrt', 'exp', 'exp2', 'log', 'log2', 'sin', 'cos']) +def test_unary_math(func_str, device): + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.FUNC_STR(x) + tl.store(Y + tl.arange(0, BLOCK), y) + + kernel = patch_kernel(kernel, {'FUNC_STR': func_str}) + + shape = (128, ) + x = torch.randn(shape, dtype=torch.float32, device=device) + if func_str in ['sqrt', 'rsqrt']: + x = torch.abs(x) + if func_str in ['log', 'log2']: + x = torch.max(x, torch.tensor(1e-6, dtype=torch.float32, device=device)) + y = torch.zeros(shape, dtype=torch.float32, device=device) + + kernel[(1, )](x, y, BLOCK=shape[0]) + torch.allclose(getattr(torch, func_str)(x), y, rtol=1e-3) + + +# ----------------------- +# test inline asm +# ----------------------- + + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm(num_ctas, device): + if not is_cuda() or is_corex(): + pytest.skip("test_inline_asm is only supported in CUDA") + + @triton.jit + def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + s = tl.full([BLOCK], n, tl.int32) + z = tl.inline_asm_elementwise("shf.l.wrap.b32 $0, $1, $2, $3;", "=r,r, r, r", [x, y, s], dtype=tl.int32, + is_pure=True, pack=1) + tl.store(Z + tl.arange(0, BLOCK), z) + + shape = (128, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint32', rs=rs) + y = numpy_random(shape, dtype_str='uint32', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + n = 17 + z_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, z_tri, n, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = (y << n) | (x >> (32 - n)) + # compare + np.testing.assert_equal(y_ref, to_numpy(z_tri)) + + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm_packed(num_ctas, device): + if not is_cuda() or is_corex(): + pytest.skip("test_inline_asm is only supported in CUDA") + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # shift 4x8bits values together. + y = tl.inline_asm_elementwise( + "and.b32 $0, $1, 0x1F1F1F1F; \ + shl.b32 $0, $0, 3;", "=r,r", [ + x, + ], dtype=tl.int8, is_pure=True, pack=4) + tl.store(Y + tl.arange(0, BLOCK), y) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + + +@pytest.mark.parametrize('num_ctas', num_ctas_list) +def test_inline_asm_with_pointers(num_ctas, device): + if not is_cuda() or is_corex(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x_ptrs = X + tl.arange(0, BLOCK) + y_ptrs = Y + tl.arange(0, BLOCK) + tl.inline_asm_elementwise( + "ld.global.b8 $0, [$1]; \ + shl.b32 $0, $0, 3; \ + st.global.b8 [$2], $0;", "=r,l,l", [x_ptrs, y_ptrs], dtype=tl.int8, is_pure=False, + pack=1) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + + +def test_inline_asm_multiple_outputs(device): + if not is_cuda() or is_corex(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + # C = A - B + # D = B - A + (c, d) = tl.inline_asm_elementwise( + asm=""" + sub.u32 $0, $2, $3; // C = A - B + sub.u32 $1, $3, $2; // D = B - A + """, + constraints=( + # 2 output registers: $0=C and $1=D. + "=r,=r," + # 2 input registers: $2=A and $3=B. + "r,r"), + args=[a, b], + dtype=(tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint32', rs=rs) + B = numpy_random(shape, dtype_str='uint32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A - B + D_ref = B - A + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +def test_inline_asm_packed_multiple_outputs(device): + if not is_cuda() or is_corex(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint8', rs=rs) + B = numpy_random(shape, dtype_str='float32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='int32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='float32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A.astype(np.int32) + D_ref = np.maximum(A.astype(np.float32), B) + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +# ----------------------- +# test control flow +# ----------------------- + + +@pytest.mark.parametrize("lo, hi, iv", [(2**35, 2**35 + 20, 1), (2**35, 2**35 + 20, 2), (2**35, 2**35 + 20, 3), + (15, -16, -1), (15, -16, -2), (15, -16, -3), (-18, -22, -1), (22, 18, -1)]) +def test_for_iv(lo, hi, iv, device): + + @triton.jit + def kernel(Out, lo, hi, iv: tl.constexpr): + acc = 0 + acc = acc.to(tl.int64) + for i in range(lo, hi, iv): + acc += i + tl.store(Out, acc) + + lo = 2**35 + hi = 2**35 + 20 + out = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + kernel[(1, )](out, lo, hi, iv) + assert out[0] == sum(range(lo, hi, iv)) + + +@pytest.mark.interpreter +def test_if_else(device): + + @triton.jit + def kernel(Cond, TrueVal, FalseVal, Out): + if tl.load(Cond): + val = tl.load(TrueVal) + else: + val = tl.load(FalseVal) + tl.store(Out, val) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + true_val = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + false_val = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + cond = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + # True + cond[0] = True + kernel[(1, )](cond, true_val, false_val, out) + assert to_numpy(out)[0] == true_val[0] + # False + cond[0] = False + kernel[(1, )](cond, true_val, false_val, out) + assert to_numpy(out)[0] == false_val[0] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("mode", ["dynamic", "static"]) +def test_if_return(mode, device): + + @triton.jit + def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr): + if mode == "dynamic": + if tl.load(ExitEarly): + tl.store(Out, 0) + return + else: + if cond: + tl.store(Out, 0) + return + tl.store(Out, 1) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + exit_early = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + # exit early path taken + exit_early[0] = 1 + kernel[(1, )](exit_early, out, True, mode) + assert to_numpy(out)[0] == 0 + # exit early path not taken + exit_early[0] = 0 + kernel[(1, )](exit_early, out, False, mode) + assert to_numpy(out)[0] == 1 + + +@triton.jit +def add_fn(x): + return x + 1 + + +@triton.jit(noinline=True) +def add_fn_noinline(x): + return x + 1 + + +@triton.jit +def add_fn_return(x, pid): + if pid == 0: + return x + 1 + else: + return x + 2 + + +@triton.jit +def add_fn_expr(Out, x): + tl.store(Out, x) + + +@triton.jit +def add_fn_static_cond(x, cond: tl.constexpr): + if cond == "": + return x + else: + return x + 1 + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "call_type", + ["attribute", "attribute_jit", "jit", "jit_if", "jit_expr", "jit_static_cond", "jit_noinline", "jit_extern"]) +def test_if_call(call_type, device): + + @triton.jit + def kernel(Out, call_type: tl.constexpr): + pid = tl.program_id(0) + o = tl.load(Out) + if call_type == "attribute": + # call attribute + if pid == 0: + a = o + a = a.to(tl.int32).to(tl.int32) + 1 + o = a + elif call_type == "attribute_jit": + # call attribute and jit function + if pid == 0: + a = o + a = tl.load(Out + add_fn(a) - 1).to(tl.int32) + 1 + o = a + elif call_type == "jit": + if pid == 0: + # regular function call + a = o + a = add_fn(a) + o = a + elif call_type == "jit_if": + # function without end_if block + if pid == 0: + a = o + a = add_fn_return(a, pid) + o = a + elif call_type == "jit_if_exp": + # ifexp expression + if pid == 0: + a = o + a = add_fn(a) if pid == 0 else add_fn_return(a, pid) + o = a + elif call_type == "jit_expr": + # call without return + if pid == 0: + a = o + 1 + add_fn_expr(Out, a) + o = a + elif call_type == "jit_static_cond": + if pid == 0: + a = o + 1 + add_fn_static_cond(o, call_type) + o = a + elif call_type == "jit_noinline": + if pid == 0: + a = o + 1 + add_fn_noinline(a) + o = a + elif call_type == "jit_extern": + if pid == 0: + a = o + 1 + tl.cdiv(a, a) + o = a + + tl.store(Out, o) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + kernel[(1, )](out, call_type) + assert to_numpy(out)[0] == 1 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("_cond1", [True, False]) +@pytest.mark.parametrize("_cond2", [True, False]) +@pytest.mark.parametrize("_cond3", [True, False]) +def test_nested_if_else_return(_cond1, _cond2, _cond3, device): + + @triton.jit + def kernel(Cond1, Cond2, Cond3, Val1, Val2, Val3, Out): + val = 0 + if tl.load(Cond1): + if tl.load(Cond2): + val = tl.load(Val1) + else: + return + else: + if tl.load(Cond3): + val = tl.load(Val2) + else: + val = tl.load(Val3) + tl.store(Out, val) + + out = to_triton(np.full((1, ), -1, dtype=np.int32), device=device) + cond1 = to_triton(np.full((1, ), _cond1, dtype=np.int32), device=device) + cond2 = to_triton(np.full((1, ), _cond2, dtype=np.int32), device=device) + cond3 = to_triton(np.full((1, ), _cond3, dtype=np.int32), device=device) + val1 = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + val2 = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + val3 = to_triton(np.full((1, ), 3, dtype=np.int32), device=device) + kernel[(1, )](cond1, cond2, cond3, val1, val2, val3, out) + targets = { + (True, True, True): val1[0], + (True, True, False): val1[0], + (True, False, True): out[0], + (True, False, False): out[0], + (False, True, True): val2[0], + (False, True, False): val3[0], + (False, False, True): val2[0], + (False, False, False): val3[0], + } + assert out[0] == targets[(_cond1, _cond2, _cond3)] + + +@pytest.mark.interpreter +def test_while(device): + + @triton.jit + def kernel(InitI, Bound, CutOff, OutI, OutInitI, OutJ): + init_i = tl.load(InitI) + curr_i = init_i + j = 0 + # Check that init_i is not updated by the loop + while j < tl.load(Bound): + curr_i = curr_i + (j == tl.load(CutOff)) + j += 1 + tl.store(OutInitI, init_i) + tl.store(OutI, curr_i) + tl.store(OutJ, j) + + out_i = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + out_j = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + init_i = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + out_init_i = to_triton(np.full((1, ), 0, dtype=np.int32), device=device) + bound = to_triton(np.full((1, ), 10, dtype=np.int32), device=device) + cut_off = to_triton(np.full((1, ), 5, dtype=np.int32), device=device) + kernel[(1, )](init_i, bound, cut_off, out_i, out_init_i, out_j) + assert out_init_i[0] == init_i[0] + assert out_i[0] == init_i[0] + 1 + assert out_j[0] == bound[0] + + +@pytest.mark.interpreter +def test_nested_while(device): + + @triton.jit + def nested_while(data, countPtr): + for i in range(10): + count = tl.load(countPtr) + while count > 0: + tl.store(data, tl.load(data) + 1.0) + count = count - 2 + + counter = torch.tensor([8], dtype=torch.int32, device=device) + data = torch.zeros((1, ), device=device, dtype=torch.float32) + nested_while[(1, )](data, counter) + assert data[0] == 40 + + +# ----------------------- +# test extra +# ----------------------- + + +def test_num_threads(device): + if is_hip(): + pytest.skip("test_num_threads is not supported in HIP") + + @triton.jit + def kernel(Out): + num_threads: tl.constexpr = tl.extra.cuda.num_threads() + offs = tl.arange(0, num_threads) + tl.store(Out + offs, 1) + + num_threads = 256 + out = to_triton(np.zeros((num_threads, ), dtype=np.int32), device=device) + kernel[(1, )](out, num_warps=num_threads // 32) + assert torch.sum(out) == 256 + + +def test_globaltimer(device): + if is_hip() or is_corex(): + pytest.skip("test_globaltimer is not supported in HIP and COREX") + check_cuda_or_hip(device) + + @triton.jit + def kernel(Out1, Out2): + start = tl.extra.cuda.globaltimer() + off = tl.arange(0, 128) + for i in range(10000): + tl.store(Out1 + off, tl.load(Out1 + off) + 1) + end = tl.extra.cuda.globaltimer() + tl.store(Out2, end - start) + + out1 = to_triton(np.zeros((128, ), dtype=np.int64), device=device) + out2 = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + h = kernel[(1, )](out1, out2) + assert out2[0] > 0 + assert h.asm["ptx"].count("%globaltimer") == 2 + + +def test_smid(device): + if is_hip() or is_corex(): + pytest.skip("test_smid is not supported in HIP and COREX") + check_cuda_or_hip(device) + + @triton.jit + def kernel(Out): + tl.store(Out + tl.program_id(0), tl.extra.cuda.smid()) + + out = to_triton(np.zeros((1024, ), dtype=np.int32), device=device) + h = kernel[(out.shape[0], )](out) + assert out.sort()[0].unique().shape[0] > 0 + assert h.asm["ptx"].count("%smid") == 1 + + +# ----------------------- +# test layout conversions +# ----------------------- +# TODO: backend should be tested separately + +layouts = [ + BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [2, THREADS_PER_WARP // 2], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([8, 1], [16, THREADS_PER_WARP // 16], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), +] + +intermediate_layouts = [ + None, + SharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), +] + + +def compute_rep_shape(layout): + if type(layout) is BlockedLayout: + warp_shape = np.multiply(layout.sz_per_thread, layout.threads_per_warp) + rep_shape = np.multiply(warp_shape, layout.warps_per_cta) + return rep_shape + else: + assert False, "TODO: support compute_rep_shape for layout " + str(type(layout)) + + +# This function gives a lower bound approximation of scratch buffer shape for convert_layout operation +def compute_scratch_buffer_shape(src_layout, dst_layout, shape): + src_rep_shape = compute_rep_shape(src_layout) + dst_rep_shape = compute_rep_shape(dst_layout) + full_scratch_shape = np.maximum(src_rep_shape, dst_rep_shape) + return np.minimum(full_scratch_shape, shape) + + +@pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]]) +@pytest.mark.parametrize("dtype", ['float16']) +@pytest.mark.parametrize("src_layout", layouts) +@pytest.mark.parametrize("interm_layout", intermediate_layouts) +@pytest.mark.parametrize("dst_layout", layouts) +def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): + if (M == 1 or N == 1) and interm_layout: + # TODO(jlebar): These OOB accesses don't even hit an assert in the + # compiler, and some of them return the wrong result instead of + # crashing! + pytest.skip("Out of bound access when maxPhase > 1") + if str(src_layout) == str(dst_layout): + pytest.skip() + if is_hip(): + try: + scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N)) + except AssertionError: + pytest.skip("Can't compute scratch buffer size") + lds_size = 65536 + # consider int32 dtype in scratch buffer size, + # because it is the largest dtype used in convert_layout in this test + int32_size = 4 + # skip even if scratch buffer equal to lds_size, because real scratch buffer is typically larger due to padding + if scratch_shape[0] * scratch_shape[1] * int32_size >= lds_size: + pytest.skip("Scratch buffer is too large") + + layouts = f""" + #src = {src_layout} + #dst = {dst_layout} + """ if interm_layout is None else f""" + #src = {src_layout} + #interm = {interm_layout} + #dst = {dst_layout} + """ + + conversion = f""" + %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ if interm_layout is None else f""" + %15 = triton_gpu.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !tt.memdesc<{M}x{N}xi32, #interm> + %16 = triton_gpu.local_load %15 : !tt.memdesc<{M}x{N}xi32, #interm> -> tensor<{M}x{N}xi32, #src> + %17 = triton_gpu.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !tt.memdesc<{M}x{N}xf16, #interm> + %18 = triton_gpu.local_load %17 : !tt.memdesc<{M}x{N}xf16, #interm> -> tensor<{M}x{N}xf16, #src> + + %12 = triton_gpu.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ + + if is_corex(): + cc = torch.cuda.get_device_capability() + CC_INT = cc[0] * 10 + cc[1] + ir = layouts + f""" + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32, triton_gpu.target = "cuda:{CC_INT}"}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, noWarpReduce=false, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, noWarpReduce=false, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}xf16, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} + }} + """ + else: + ir = layouts + f""" + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, noWarpReduce=false, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, noWarpReduce=false, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}xf16, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} + }} + """ + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x, device=device) + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) + + assert torch.equal(z, x) + + +mma_pairs = [ + [ + MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 1), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 1), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 1), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 1), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + # Mma -> mma support is TODO on Hopper (and Volta) + # [ + # MmaLayout((3, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], + # [ + # MmaLayout((3, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], + # [ + # MmaLayout((3, 1), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 1), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], + # [ + # MmaLayout((3, 1), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 1), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], +] + + +@pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]]) +@pytest.mark.parametrize("dtype", ['float16']) +@pytest.mark.parametrize("mma_pair", mma_pairs) +def test_convertmma2mma(M, N, mma_pair, dtype, device): + if is_hip(): + pytest.skip("test_mma2mma is not supported in HIP") + + src_layout, _ = mma_pair + num_warps = np.cumprod(src_layout.warps_per_cta)[-1] + + def do_test(src_layout, dst_layout): + layouts = f""" + #src = {src_layout} + #dst = {dst_layout} + """ + + conversion = f""" + %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ + + ir = layouts + f""" + module attributes {{"triton_gpu.num-warps" = {num_warps} : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, noWarpReduce=false, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, noWarpReduce=false, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, noWarpReduce=false, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}xf16, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} + }} + """ + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x) + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) + + assert torch.equal(z, x) + + do_test(mma_pair[0], mma_pair[1]) + do_test(mma_pair[1], mma_pair[0]) + + +@pytest.mark.interpreter +def test_load_scalar_with_mask(device): + + @triton.jit + def kernel(Input, Index, Out, N: int): + index = tl.load(Index) + scalar = tl.load(Input + index, mask=index < N, other=0) + tl.store(Out, scalar, mask=index < N) + + Index = torch.tensor([0], dtype=torch.int32, device=device) + Input = torch.tensor([0], dtype=torch.int32, device=device) + Out = torch.empty_like(Index, device=device) + kernel[(1, )](Input, Index, Out, Index.numel()) + assert Out.data[0] == 0 + + +# This test is used to test our own PTX codegen for float16 and int16 conversions +# maybe delete it later after ptxas has been fixed +@pytest.mark.parametrize("dtype_str", ['float16', 'int16']) +def test_ptx_cast(dtype_str, device): + + @triton.jit + def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + rbase = tl.arange(0, RBLOCK)[None, :] + x0 = xindex + _tmp4 = (tl.zeros([XBLOCK, RBLOCK], dtype) - 10000).to(dtype) + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + r1 = rindex + tmp0 = tl.load(in_ptr0 + (r1 + (197 * x0)), rmask & xmask).to(dtype) + tmp1 = 2 + tmp2 = tmp0 * tmp1 + tmp3 = tmp2.to(dtype) + tmp5 = _tmp4 < tmp3 + _tmp4 = tl.where(rmask & xmask & tmp5, tmp3, _tmp4) + tl.store(out_ptr2 + (r1 + (197 * x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask) + + torch.manual_seed(123) + if dtype_str == 'int16': + torch_dtype = torch.int16 + triton_dtype = tl.int32 + else: + torch_dtype = torch.float16 + triton_dtype = tl.float32 + + s0 = 4 + buf11 = -torch.ones((6 * s0, 197, 197), device=device, dtype=torch_dtype) + buf14 = -torch.ones((s0, 6, 197, 197), device=device, dtype=torch_dtype) + kernel[(4728, )](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2) + assert buf14.to(torch.float32).mean() == -2.0 + + +# ----------------------- +# test fp8 -> fp32 dot +# ----------------------- + + +def f8_to_f16(x, dtype): + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + +@triton.jit +def matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + low_precision_acc: tl.constexpr, # + num_pipeline_stages: tl.constexpr = 3 # +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_pipeline_stages): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_ptrs, accumulator) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv', 'float8e4b15']) +@pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128]) +def test_fp8_dot_acc(in_type_str, low_precision_acc, device): + if is_hip(): + pytest.skip('test_fp8_dot_acc for HIP currently broken in upstream.') + if is_cuda(): + cc = torch.cuda.get_device_capability() + if is_corex(): + pytest.skip('test_fp8_dot_acc for iluvatar does not support.') + else: + if cc[0] >= 9 and in_type_str == "float8e4b15": + pytest.skip("Dot op does not support fp8e4b15 on CUDA arch >= 90") + check_type_supported(in_type_str, device) + M, N, K = 128, 256, 256 + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 128 + A = numpy_random((M, K), dtype_str=in_type_str) + B = numpy_random((K, N), dtype_str=in_type_str) + C = torch.empty((M, N), dtype=torch.float32, device=device) + num_warps = 8 + a = to_triton(A, device=device, dst_type=in_type_str) + b = to_triton(B, device=device, dst_type=in_type_str) + grid = (triton.cdiv(M, BLOCK_M), 1) + matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), C.stride(1), + BLOCK_M, BLOCK_N, BLOCK_K, low_precision_acc, num_warps=num_warps) + torch_a = torch.from_numpy(A).to(device=device) + th_a = f8_to_f16(torch_a, in_type_str) + torch_b = torch.from_numpy(B).to(device=device) + th_b = f8_to_f16(torch_b, in_type_str) + ref_out = torch.matmul(th_a, th_b).to(torch.float32) + if in_type_str == 'float8e4nv': + torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01) + elif low_precision_acc > 32: + torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) + else: + torch.testing.assert_close(ref_out, C) + + +# ----------------------- +# test enable_fp_fusion +# ----------------------- + + +@pytest.mark.parametrize("enable_fp_fusion", [False, True]) +def test_enable_fp_fusion(enable_fp_fusion, device): + if is_hip(): + pytest.skip( + 'test_enable_fp_fusion for HIP currently broken in https://github.com/triton-lang/triton. Use https://github.com/ROCmSoftwarePlatform/triton' + ) + + # Sequential multiply add can be fused by backend + @triton.jit + def mul_add(data): + ptrs = data + tl.arange(0, 128) + tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0) + + data = torch.randn((128, ), device=device, dtype=torch.float32) + h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion) + + if not is_cuda() or is_corex(): + return + found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None + assert found_fma == enable_fp_fusion + + +# ----------------------- +# test propagate_nan +# ----------------------- + + +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +@pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL']) +@pytest.mark.parametrize("func", ['minimum', 'maximum', 'clamp']) +def test_propagate_nan(dtype, propagate_nan, func, device): + + @triton.jit + def kernel(A, B, C, propagate_nan: tl.constexpr, func: tl.constexpr): + if func == 'clamp': + tl.store( + C, + getattr(tl, func)(tl.load(A), -tl.load(B), tl.load(B), + propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + else: + tl.store(C, + getattr(tl, func)(tl.load(A), tl.load(B), propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + + for mode in ['A', 'B', 'both']: + if func == 'clamp' and mode == 'B': + # clamp does not guarantee propagation from 'min' and 'max' args + continue + A = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) + if mode == 'A' or mode == 'both': A[0] = torch.nan + B = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) + if mode == 'B' or mode == 'both': B[0] = torch.nan + C = torch.zeros_like(A, device=device, dtype=getattr(torch, dtype)) + kernel[(1, )](A, B, C, propagate_nan, func) + + if mode == 'both' or propagate_nan == 'ALL': + assert torch.isnan(C[0]) + else: + assert not torch.isnan(C[0]) + + +# ----------------------- +# test clamp +# ----------------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +def test_clamp(dtype, device): + + @triton.jit + def kernel(x_ptr, min_ptr, max_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + min = tl.load(min_ptr + off, mask=mask) + max = tl.load(max_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, min, max), mask=mask) + ref_val = tl.minimum(tl.maximum(x, min), max) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + a = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + b = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + min = torch.min(a, b) + max = torch.max(a, b) + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, min, max, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# Test for symmetric clamp(x, -limit, limit), as it may go through optimized +# codegen in the backends +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +def test_clamp_symmetric(dtype, device): + + @triton.jit + def kernel(x_ptr, limit_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + limit = tl.load(limit_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, -limit, limit), mask=mask) + ref_val = tl.minimum(tl.maximum(x, -limit), limit) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + limit = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)).abs() + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, limit, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# ----------------------- +# test iterators +# ----------------------- + + +@pytest.mark.interpreter +def test_static_range(device): + + @triton.jit + def loop_kernel(Z, N: tl.constexpr, step: tl.constexpr): + acc = 0 + for i in tl.static_range(0, N, step=step): + acc += i + tl.store(Z, acc) + + N = 100 + step = 7 + Out = torch.empty(1, dtype=torch.int32, device=device) + loop_kernel[(1, )](Out, N, step) + Acc = torch.tensor([0], dtype=torch.int32, device=device) + for i in range(0, N, step): + Acc += i + assert (Out == Acc).all(), (Out, Acc) + + +@pytest.mark.interpreter +def test_tl_range(device): + if is_hip(): + pytest.skip("test_tl_range is not supported in HIP") + M, N, K = 64, 64, 512 + BLOCK_M, BLOCK_N, BLOCK_K = M, N, 64 + a = torch.randn((M, K), device=device, dtype=torch.float16) + b = torch.randn((K, N), device=device, dtype=torch.float16) + c = torch.empty((M, N), dtype=torch.float32, device=device) + pgm = matmul_kernel[ + 1, + ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N, + BLOCK_K, 0, num_pipeline_stages=5) + ref_out = torch.matmul(a, b).to(torch.float32) + if is_interpreter(): + # GPU invokes tensor core for float16 matmul, which is not supported in interpreter. + # Thus we use a higher tolerance + torch.testing.assert_close(ref_out, c, rtol=1e-2, atol=1e-1) + else: + torch.testing.assert_close(ref_out, c, rtol=1e-3, atol=1e-3) + if device in ['cuda']: + capability = torch.cuda.get_device_capability() + if capability[0] >= 8: + ptx = pgm.asm['ptx'] + # check that the loop got pipelined with the right number of stages. + assert 'cp.async.wait_group 0x6' in ptx + + +@triton.jit(noinline=True) +def maxnreg_noinline1(X): + tl.store(X, 0) + + +@triton.jit(noinline=True) +def maxnreg_noinline2(X): + tl.store(X, 0) + + +@pytest.mark.skip(reason="compiler do not support func call in llvmir until 2025 Q2") +def test_maxnreg(device): + assert not is_interpreter(), "this test won't work with the interpreter" + if is_hip(): + pytest.skip('maxnreg only works on CUDA') + + # triton kernel + @triton.jit + def kernel(X): + maxnreg_noinline1(X) + tl.store(X, 0) + maxnreg_noinline2(X) + + X = torch.empty(1, dtype=torch.int32, device=device) + k = kernel[(1, )](X, maxnreg=42) + + # Ensure that .maxnreg is set on the kernel function (marked with .entry) + # and not on either of the noinline functions (marked with .func). + try: + assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) + assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) + except AssertionError: + print("Failing ptx:\n", k.asm["ptx"]) + raise + + +@pytest.mark.interpreter +def test_temp_var_in_loop(device): + + @triton.jit + def temp_in_loop(Z, N: tl.constexpr, BLOCK: tl.constexpr): + acc = tl.full((BLOCK, ), 0, dtype=tl.int32) + for i in range(N): + if i == 0: + temp = tl.full((BLOCK, ), 2, dtype=tl.int32) + acc = temp + else: + acc += tl.full((BLOCK, ), 1, dtype=tl.int32) + # re-use the temp variable and make sure to check that it isn't creating incorrect IR. + temp = tl.full((BLOCK, ), 1, dtype=tl.int32) + acc += temp + z = Z + tl.arange(0, BLOCK) + tl.store(z, acc) + + N = 10 + BLOCK = 32 + out = torch.empty((BLOCK, ), dtype=torch.int32, device=device) + temp_in_loop[(1, )](out, N, BLOCK) + acc = torch.full((BLOCK, ), 0, dtype=torch.int32, device=device) + for i in range(N): + if i == 0: + temp = torch.full((BLOCK, ), 2, dtype=torch.int32, device=device) + acc = temp + else: + acc += torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + temp = torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + acc += temp + assert (acc == out).all() diff --git a/third_party/iluvatar/python/test/unit/language/test_decorator.py b/third_party/iluvatar/python/test/unit/language/test_decorator.py new file mode 100644 index 000000000..66371ba60 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/language/test_decorator.py @@ -0,0 +1,48 @@ +import torch + +import triton +import triton.language as tl +import pytest + + +def test_decorator_with_def(device): + + def triton_heuristics_pointwise(**kwargs): + + def decorator(func): + return func + + return decorator + + # "def" might appear in a decorator call, e.g. a hash string argument. + # This test makes sure the compiler can find the right position of function + # definition. + @triton_heuristics_pointwise(inductor_meta={'backend_hash': 'def0aeffabe53b3f8'}, ) + @triton.jit + def kernel(): + pass + + try: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + except Exception as e: + pytest.fail(f"triton compile failed with error: {e}") + + +def test_triton_heuristic(device): + N = 1023 + src = torch.empty(N, device=device) + dst = torch.zeros(N, device=device) + + @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], warmup=1, rep=1) + @triton.heuristics({'EVEN_N': lambda nargs: nargs['N'] % 2 == 0}) # test kwargs + @triton.heuristics({'EVEN_src': lambda nargs: nargs['src'].data_ptr() % 2 == 0}) # test args + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr, EVEN_N: tl.constexpr, EVEN_src: tl.constexpr): + tl.store(dst, EVEN_N) + tl.store(dst + 1, EVEN_src) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + assert dst[0].item() == 0.0 + assert dst[1].item() == 1.0 + assert _kernel.base_fn.__name__ == "_kernel" diff --git a/third_party/iluvatar/python/test/unit/language/test_iluvatar_bf16.py b/third_party/iluvatar/python/test/unit/language/test_iluvatar_bf16.py new file mode 100644 index 000000000..2bf5ef57d --- /dev/null +++ b/third_party/iluvatar/python/test/unit/language/test_iluvatar_bf16.py @@ -0,0 +1,91 @@ +# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Licensed under the MIT License + +import torch + +import triton +import triton.language as tl +from triton.runtime.build import is_corex + +import pytest + +if not is_corex(): + float_dtypes = [torch.float16, torch.float32, torch.float64] +else: + float_dtypes = [torch.float16, torch.float32] +dtypes = float_dtypes +dtypes_with_bfloat16 = dtypes + [torch.bfloat16] + + +def patch_kernel(template, to_replace): + kernel = triton.JITFunction(template.fn) + for key, value in to_replace.items(): + kernel.src = kernel.src.replace(key, value) + return kernel + + +# --------------- +# test binary ops +# --------------- + + +@pytest.mark.parametrize("dtype_x, dtype_y, op", + [(torch.bfloat16, torch.bfloat16, op) for op in ['+', '-', '*', '/', '%']]) +def test_bin_op(dtype_x, dtype_y, op, device='cuda'): + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + x = torch.rand(SIZE, dtype=dtype_x, device=device) + 1.0 + y = torch.rand(SIZE, dtype=dtype_y, device=device) + 1.0 + + expr = f' x {op} y' + + # reference result + z_torch = eval(expr) + + # triton result + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) + z_tri = torch.empty(SIZE, dtype=z_torch.dtype, device=device) + kernel[(1, )](z_tri, x, y, SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_torch, z_tri, rtol=0.0, atol=0.0) + + +# --------------- +# test cast +# --------------- +@pytest.mark.parametrize("dtype_x, dtype_z", [ + (torch.float32, torch.bfloat16), + (torch.bfloat16, torch.float32), + (torch.int32, torch.bfloat16), + (torch.bfloat16, torch.int32), +]) +def test_cast(dtype_x, dtype_z, device='cuda'): + SIZE = 1 + # triton kernel + @triton.jit + def kernel(X, Z, SIZE: tl.constexpr): + x_ptr = X + tl.arange(0, SIZE) + z_ptr = Z + tl.arange(0, SIZE) + x = tl.load(x_ptr) + z = x.to(Z.dtype.element_ty) + tl.store(z_ptr, z) + + if dtype_x in [torch.int32]: + x = torch.randint(low=0, high=10, size=(SIZE, ), dtype=dtype_x, device=device) + elif dtype_x in [torch.bfloat16, torch.float32]: + x = torch.rand(SIZE, dtype=dtype_x, device=device) + # reference result + z_torch = x.to(dtype_z) + # triton result + z_tri = torch.empty(SIZE, dtype=dtype_z, device=device) + kernel[(1, )](x, z_tri, SIZE) + torch.testing.assert_close(z_torch, z_tri, rtol=0.0, atol=0.0) diff --git a/third_party/iluvatar/python/test/unit/language/test_line_info.py b/third_party/iluvatar/python/test/unit/language/test_line_info.py new file mode 100644 index 000000000..6421c7309 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/language/test_line_info.py @@ -0,0 +1,171 @@ +import subprocess +import tempfile + +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def kernel_single(X, + Y, + BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def device_inline(x): + return x + x + + +@triton.jit +def kernel_call(X, + Y, + BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = device_inline(x) + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit(noinline=True) +def device_noinline(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = x + x + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit +def kernel_call_noinline(X, Y, BLOCK: tl.constexpr): + device_noinline(X, Y, BLOCK) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK": 128}, num_warps=4), + ], + key=[], +) +@triton.jit +def kernel_autotune(X, Y, SIZE: tl.constexpr, BLOCK: tl.constexpr): + for i in range(0, SIZE, BLOCK): + x = tl.load(X + i + tl.arange(0, BLOCK)) + tl.store(Y + i + tl.arange(0, BLOCK), x) + + +# AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +# Since the + symbol will take effect in the dot op after combination, +# it seems making sense to annotate with the same line as dot. +@triton.jit +def kernel_dot_combine(x): + c = tl.full((32, 32), 4, dtype=tl.int8) + a = (tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :]).to(tl.int8) + d = tl.dot(a, a) + d = d + c + tl.device_print("", d) + + +def get_disassembler_command_and_debug_line_format(): + """Gets backend specific disassembler information. + + Returns a tuple: (object file kind, disassembler tool command, + debug line anchor, debug line file and line number separator). + """ + backend = triton.runtime.driver.active.get_current_target().backend + + if backend == "cuda": + from triton.backends.nvidia.compiler import _path_to_binary + nvdisasm, _ = _path_to_binary("nvdisasm") + return ("cubin", [nvdisasm, "-g"], "## File", ",") + + if backend == "hip": + import shutil + # Try to find llvm-objdump from the current PATH to disassmble hsaco. + tool = shutil.which("llvm-objdump") + if tool is not None: + return ("hsaco", [tool, "-D", "-l", "--arch=amdgcn"], ";", ":") + raise RuntimeError("llvm-objdump not found in PATH") + + raise RuntimeError(f"unknown backend {backend}") + + +def extract_file_lines(command, anchor, separator, asm): + fd, path = tempfile.mkstemp() + with open(fd, 'wb') as cubin: + cubin.write(asm) + asm = subprocess.check_output(command + [path]).decode("utf-8") + file_lines = [] + lines = asm.splitlines() + for line in lines: + # We are looking for an anchor string and a separator between the file name and line number. + if anchor in line and separator in line: + entries = line[line.index(anchor):].split(separator) + if len(entries) == 2 and all(len(e) != 0 for e in entries): + file_lines.append((entries[0].strip(), entries[1].strip())) + return file_lines + + +def check_file_lines(file_lines, file_name, lineno, should_contain=True): + """ + Check if the file name and line number is in the file_lines + + Args: + file_lines: list of (file_name, line_number) + file_name: file name + lineno: line number, -1 means do not check line number + should_contain: whether the file name and line number should be in the file_lines + """ + for file, line in file_lines: + if lineno == -1: + if file_name in file: + return True + if file_name in file and str(lineno) in line: + return should_contain + return not should_contain + + +func_types = ["single", "call", "call_noinline", "autotune", "dot_combine"] + + +@pytest.mark.parametrize("func", func_types) +def test_line_info(func: str): + try: + obj_kind, command, anchor, separator = get_disassembler_command_and_debug_line_format() + except BaseException: + pytest.skip("disassembler is not available") + + shape = (128, ) + kernel_info = {} + if func == "single": + kernel_info = kernel_single.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,)) + elif func == "call": + kernel_info = kernel_call.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,)) + elif func == "call_noinline": + kernel_info = kernel_call_noinline.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,)) + elif func == "autotune": + kernel_info = kernel_autotune.warmup(torch.float32, torch.float32, SIZE=shape[0], grid=(1,))[0] + elif func == "dot_combine": + kernel_info = kernel_dot_combine.warmup(20, grid=(1,)) + + file_lines = extract_file_lines(command, anchor, separator, kernel_info.asm[obj_kind]) + if func == "single": + assert (check_file_lines(file_lines, "test_line_info.py", 15)) + assert (check_file_lines(file_lines, "test_line_info.py", 16)) + elif func == "call": + assert (check_file_lines(file_lines, "test_line_info.py", 28)) + assert (check_file_lines(file_lines, "test_line_info.py", 21)) + assert (check_file_lines(file_lines, "test_line_info.py", 30)) + elif func == "call_noinline": + assert (check_file_lines(file_lines, "test_line_info.py", 42)) + assert (check_file_lines(file_lines, "test_line_info.py", 35)) + assert (check_file_lines(file_lines, "test_line_info.py", 36)) + assert (check_file_lines(file_lines, "test_line_info.py", 37)) + elif func == "autotune": + assert (check_file_lines(file_lines, "test_line_info.py", 53)) + assert (check_file_lines(file_lines, "test_line_info.py", 54)) + assert (check_file_lines(file_lines, "test_line_info.py", 55)) + elif func == "dot_combine": + assert (check_file_lines(file_lines, "test_line_info.py", 65)) + assert (check_file_lines(file_lines, "test_line_info.py", 66, should_contain=False)) diff --git a/third_party/iluvatar/python/test/unit/language/test_random.py b/third_party/iluvatar/python/test/unit/language/test_random.py new file mode 100644 index 000000000..e0e59b069 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/language/test_random.py @@ -0,0 +1,255 @@ +import numpy as np +import pytest +import scipy.stats +import torch + +import triton +import triton.language as tl + +##################################### +# Reference Philox Implementation +##################################### + + +class PhiloxConfig: + + def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): + self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) + self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE) + self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE) + self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE) + self.DTYPE = DTYPE + + +# This is better for GPU +PHILOX_32 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B9, + PHILOX_KEY_B=0xBB67AE85, + PHILOX_ROUND_A=0xD2511F53, + PHILOX_ROUND_B=0xCD9E8D57, + DTYPE=np.uint32, +) + +# This is what numpy implements +PHILOX_64 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B97F4A7C15, + PHILOX_KEY_B=0xBB67AE8584CAA73B, + PHILOX_ROUND_A=0xD2E7470EE14C6C93, + PHILOX_ROUND_B=0xCA5A826395121157, + DTYPE=np.uint64, +) + + +class CustomPhilox4x: + + def __init__(self, seed, config): + self._config = config + seed = self._into_pieces(seed) + self._key = np.array(seed[:2], dtype=self._dtype) + self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype) + + @property + def _dtype(self): + return self._config.DTYPE + + def _into_pieces(self, n, pad=4): + res = [] + while len(res) < pad: + res.append(np.array(n, dtype=self._dtype)) + n >>= (np.dtype(self._dtype).itemsize * 8) + assert n == 0 + return tuple(res) + + def _multiply_low_high(self, a, b): + low = a * b + high = int(a) * int(b) + high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype) + return low, high + + def _single_round(self, counter, key): + lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0]) + lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2]) + ret0 = hi1 ^ counter[1] ^ key[0] + ret1 = lo1 + ret2 = hi0 ^ counter[3] ^ key[1] + ret3 = lo0 + return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) + + def _raise_key(self, key): + pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B] + return key + np.array(pk, dtype=self._dtype) + + def random_raw(self): + counter = self._counter + key = self._key + for _ in range(10): + counter = self._single_round(counter, key) + key = self._raise_key(key) + self.advance(1) + return counter + + def advance(self, n_steps): + self._counter[0] += n_steps + assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets" + + +class CustomPhilox(CustomPhilox4x): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.buffer = [] + + def random_raw(self): + if len(self.buffer) == 0: + self.buffer = list(super().random_raw())[::-1] + return int(self.buffer.pop()) + + +##################################### +# Unit Tests +##################################### + +BLOCK: tl.constexpr = 1024 + +# test generation of random uint32 + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in ['10', '4,53', '400'] + for seed in [0, 42, 124, 54, 0xffffffff, 0x0000000fcafeb0ba] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_randint(size, seed, device, dtype, const_seed): + size = list(map(int, size.split(','))) + torch_dtype = getattr(torch, dtype) + numpy_dtype = getattr(np, f"u{dtype}") + config = {'int32': PHILOX_32, 'int64': PHILOX_64}[dtype] + + @triton.jit + def kernel(X, N, seed): + pid = tl.program_id(0).to(X.dtype.element_ty) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr): + pid = tl.program_id(0).to(X.dtype.element_ty) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch_dtype, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + if const_seed: + const_kernel[grid](x, N, seed=seed) + else: + kernel[grid](x, N, seed) + out_tri = x.cpu().numpy().astype(numpy_dtype).flatten().tolist() + # reference result + gen = CustomPhilox4x(seed, config=config) + out_ref = [gen.random_raw()[0] for _ in out_tri] + assert out_tri == out_ref + + +# test uniform PRNG + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in [100000] + for seed in [0, 42, 124, 54] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_rand(size, seed, dtype, device, const_seed): + + @triton.jit + def kernel(X, N, seed, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + if const_seed: + const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype)) + else: + kernel[grid](x, N, seed, dtype=getattr(tl, dtype)) + assert all((x >= 0) & (x <= 1)) + assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 + + +# test normal PRNG + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in [100000] + for seed in [0, 42, 124, 54] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_randn(size, seed, dtype, device, const_seed): + + @triton.jit + def kernel(X, N, seed, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randn(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randn(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + if const_seed: + const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype)) + else: + kernel[grid](x, N, seed, dtype=getattr(tl, dtype)) + assert abs(x.mean()) < 1e-2 + assert abs(x.std() - 1) < 1e-2 + + +# tl.rand() should never produce >=1.0 + + +@pytest.mark.interpreter +@pytest.mark.parametrize('dtype', ['int32', 'int64']) +def test_rand_limits(dtype, device): + + @triton.jit + def kernel(input, output, n: tl.constexpr): + idx = tl.arange(0, n) + x = tl.load(input + idx) + y = tl.random.uint_to_uniform_float(x) + tl.store(output + idx, y) + + torch_dtype = getattr(torch, dtype) + min_max_int = torch.tensor([ + torch.iinfo(torch_dtype).min, + torch.iinfo(torch_dtype).max, + ], dtype=torch_dtype, device=device) + output = torch.empty(2, dtype=torch.float32, device=device) + kernel[(1, )](min_max_int, output, 2) + + assert output[0] == output[1] + assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0 diff --git a/third_party/iluvatar/python/test/unit/language/test_reproducer.py b/third_party/iluvatar/python/test/unit/language/test_reproducer.py new file mode 100644 index 000000000..a045e8f30 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/language/test_reproducer.py @@ -0,0 +1,42 @@ +import os +import shutil + +import pytest + +import torch +import triton +import re + + +@triton.jit +def triton_(): + return + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda") +def test_reproducer(): + tmpdir = ".tmp" + reproducer = 'triton-reproducer.mlir' + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + if os.path.exists(reproducer): + os.remove(reproducer) + os.environ["TRITON_CACHE_DIR"] = tmpdir + os.environ["TRITON_REPRODUCER_PATH"] = reproducer + triton_[(1, )]() + foundPipeline = "" + with open(reproducer, 'r') as f: + line = f.read() + if 'pipeline:' in line: + foundPipeline = line + if 0 == len(foundPipeline): + raise Exception("Failed to find pipeline info in reproducer file.") + + ttgir_to_llvm_pass = re.compile("convert-triton-{{.*}}gpu-to-llvm") + if ttgir_to_llvm_pass.search(foundPipeline): + raise Exception("Failed to find triton passes in pipeline") + # cleanup + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + if os.path.exists(reproducer): + os.remove(reproducer) diff --git a/third_party/iluvatar/python/test/unit/language/test_standard.py b/third_party/iluvatar/python/test/unit/language/test_standard.py new file mode 100644 index 000000000..017ff36f8 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/language/test_standard.py @@ -0,0 +1,75 @@ +import triton +import pytest +import torch +import triton.language as tl + +from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random + +# --------------- +# test maximum/minimum ops +# --------------- + + +# TODO: Tests with unsigned integers failed at compilation stage. +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", int_dtypes + uint_dtypes + float_dtypes + ["bfloat16"]) +@pytest.mark.parametrize("op", ["maximum", "minimum"]) +def test_maximum_minium(dtype, op, device): + expr = f'tl.{op}(x, y)' + numpy_expr = f'np.{op}(x, y)' + _test_binary(dtype, dtype, expr, numpy_expr, device=device) + + +# --------------- +# test sort op +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32']) +def test_sort(M, N, descending, dtype_str, device): + + @triton.jit + def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr): + offx = tl.arange(0, M) + offy = tl.arange(0, N) * M + off2d = offx[None, :] + offy[:, None] + x = tl.load(X + off2d) + x = tl.sort(x, descending=descending) + tl.store(Z + off2d, x) + + x = numpy_random((N, M), dtype_str=dtype_str) + x = torch.from_numpy(x).to(device) + y = torch.sort(x, descending=descending)[0] + z = torch.empty_like(x) + sort_kernel[(1, )](x, z, N, M, descending, num_warps=8) + assert (y == z).all(), (y, z) + + +# --------------- +# test flip op +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32']) +def test_flip(M, N, dtype_str, device): + + @triton.jit + def flip_kernel(X, Z, N: tl.constexpr, M: tl.constexpr): + offx = tl.arange(0, M) + offy = tl.arange(0, N) * M + off2d = offx[None, :] + offy[:, None] + x = tl.load(X + off2d) + x = tl.flip(x) + tl.store(Z + off2d, x) + + x = numpy_random((N, M), dtype_str=dtype_str) + x = torch.from_numpy(x).to(device) + y = torch.flip(x, (1, )) + z = torch.empty_like(x, device=device) + flip_kernel[(1, )](x, z, N, M, num_warps=8) + assert (y == z).all(), (y, z) diff --git a/third_party/iluvatar/python/test/unit/language/test_subprocess.py b/third_party/iluvatar/python/test/unit/language/test_subprocess.py new file mode 100644 index 000000000..d5c188c5a --- /dev/null +++ b/third_party/iluvatar/python/test/unit/language/test_subprocess.py @@ -0,0 +1,167 @@ +import itertools +import os +import subprocess +import sys +from collections import Counter +from triton.runtime.build import is_corex + +import pytest + +dir_path = os.path.dirname(os.path.realpath(__file__)) +print_path = os.path.join(dir_path, "print_helper.py") +assert_path = os.path.join(dir_path, "assert_helper.py") + +# TODO: bfloat16 after LLVM-15 +assert_types = ["device_assert", "device_assert_passes", "assert", "static_assert", "no_debug", "double_assert"] +nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]] +if is_corex(): + torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32"] +else: + torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"] + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +# TODO: Print with multiple operands + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func_type, data_type", [("device_print", data_type) for data_type in torch_types] + [ + ("print", "int32"), + ("static_print", "int32"), + ("no_arg_print", "int32"), + ("print_no_arg", "int32"), + ("device_print_large", "int32"), + ("print_multiple_args", "int32"), + ("device_print_multiple_args", "int32"), + ("device_print_hex", "int16"), + ("device_print_hex", "int32"), + ("device_print_hex", "int64"), + ("device_print_pointer", "int32"), +]) +def test_print(func_type: str, data_type: str): + proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=False) + outs, err = proc.communicate() + assert proc.returncode == 0 + + if is_interpreter() and func_type != "static_assert": + # Interpreter uses a different format for device_print + # Only check if there's no error + assert err == b'' + return + + outs = [line for line in outs.decode("UTF-8").split("\n") if line] + # The total number of elements in the 1-D tensor to print. + N = 128 + + # Format is + # pid (, , ) idx (, , ...) (operand ) + expected_lines = Counter() + if func_type == "print" or func_type == "device_print": + for i in range(N): + line = f"pid (0, 0, 0) idx ({i:3}) x: {i}" + if data_type.startswith("float"): + line += ".000000" + expected_lines[line] = 1 + elif func_type == "device_print_hex": + for i in range(N): + line = f"pid (0, 0, 0) idx ({i:3}) x: 0x" + if data_type == "int16": + line += f"{i:04x}" + if data_type == "int32": + line += f"{i:08x}" + if data_type == "int64": + line += f"{i:016x}" + expected_lines[line] = 1 + elif func_type == "static_print": + expected_lines[f" int32[constexpr[{N}]]"] = 1 + elif func_type == "no_arg_print": + expected_lines["pid (0, 0, 0) idx (): 0"] = N + elif func_type == "print_no_arg": + expected_lines["pid (0, 0, 0) no arg"] = N + elif func_type == "device_print_large": + for i, j, k in itertools.product(range(2), range(64), range(N)): + expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1 + elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args": + for i in range(N): + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1 + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1 + elif func_type == "device_print_pointer": + for i in range(N): + expected_lines[f"pid (0, 0, 0) idx ({i:3}) ptr: 0x"] = 1 + + actual_lines = Counter() + for line in outs: + # Trim the exact pointer address in the output--they can change per run. + line = (line.split(':')[0] + ": 0x") if func_type == "device_print_pointer" else line + actual_lines[line] += 1 + + diff = Counter(actual_lines) + diff.subtract(expected_lines) + for line, delta in diff.items(): + if delta == 0: + continue + print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)') + assert all(delta == 0 for delta in diff.values()) + + +@pytest.mark.parametrize("func_type", assert_types) +def test_assert(func_type: str): + # The total number of elements in the 1-D tensor to assert on. + N = 128 + + os.environ["TRITON_DEBUG"] = "1" + proc = subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, + shell=False) + # _, errs = proc.communicate() + outs, errs = proc.communicate() + errs = outs if is_corex() else errs + errs = errs.splitlines() + num_errs = 0 + for err in errs: + if "x != 0" in err.decode("utf-8", errors="ignore"): + num_errs += 1 + + # Check for segfaults. + assert all("segmentation fault" not in line.decode("utf-8", errors="ignore").lower() for line in errs) + + os.environ["TRITON_DEBUG"] = "0" + if func_type == "static_assert" or func_type == "device_assert_passes": + assert num_errs == 0 + else: + assert num_errs == N - 1 + + +@pytest.mark.parametrize("caller_type, callee_type", nested_types) +def test_assert_nested(caller_type, callee_type): + # The total number of elements in the 1-D tensor to assert on. + N = 128 + + proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=False) + # _, errs = proc.communicate() + outs, errs = proc.communicate() + errs = outs if is_corex() else errs + errs = errs.splitlines() + num_errs = 0 + for err in errs: + if "x != 0" in err.decode("utf-8", errors="ignore"): + num_errs += 1 + if caller_type == "none": + if callee_type == "true": + assert num_errs == N - 1 + else: + assert num_errs == 0 + elif caller_type == "true": + if callee_type == "false": + assert num_errs == 0 + else: + assert num_errs == N - 1 + elif caller_type == "false": + if callee_type == "true": + assert num_errs == N - 1 + else: + assert num_errs == 0 diff --git a/third_party/iluvatar/python/test/unit/operators/conftest.py b/third_party/iluvatar/python/test/unit/operators/conftest.py new file mode 100644 index 000000000..091f9ea41 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/operators/conftest.py @@ -0,0 +1,5 @@ +# content of conftest.py + + +def pytest_configure(config): + config.addinivalue_line("markers", "interpreter: indicate whether interpreter supports the test") diff --git a/python/test/unit/operators/test_blocksparse.py b/third_party/iluvatar/python/test/unit/operators/test_blocksparse.py similarity index 100% rename from python/test/unit/operators/test_blocksparse.py rename to third_party/iluvatar/python/test/unit/operators/test_blocksparse.py diff --git a/python/test/unit/operators/test_cross_entropy.py b/third_party/iluvatar/python/test/unit/operators/test_cross_entropy.py similarity index 100% rename from python/test/unit/operators/test_cross_entropy.py rename to third_party/iluvatar/python/test/unit/operators/test_cross_entropy.py diff --git a/third_party/iluvatar/python/test/unit/operators/test_dot_trans.py b/third_party/iluvatar/python/test/unit/operators/test_dot_trans.py new file mode 100644 index 000000000..190d81eda --- /dev/null +++ b/third_party/iluvatar/python/test/unit/operators/test_dot_trans.py @@ -0,0 +1,178 @@ +# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Licensed under the MIT License + +import pytest +import torch +import os + +import triton +import triton.language as tl +from torch.testing import assert_close + +torch.manual_seed(0) + + +@pytest.mark.parametrize('M, N, K, AT, BT, ACol, BCol, num_warps, disable_sme, dataType', + [(M, N, K, AT, BT, ACol, BCol, num_warps, disable_sme, dataType) + for M in [32, 64, 128] + for N in [32, 64] + for K in [32, 64] + for AT in [False, True] + for BT in [False, True] + for ACol in [False, True] + for BCol in [False, True] + for num_warps in [1, 2, 4] + for disable_sme in ["0", "1"] + for dataType in ["float16", "bfloat16"]]) +def test_sme_and_swizzle_layout_trans(M, N, K, AT, BT, ACol, BCol, num_warps, disable_sme, dataType, device='cuda'): + + @triton.jit + def kernel( + A, + B, + C, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + A_T: tl.constexpr, + B_T: tl.constexpr, + ): + off_m = tl.arange(0, BLOCK_M) + off_mk = tl.arange(0, BLOCK_K) + if A_T: + off_m = tl.arange(0, BLOCK_K) + off_mk = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_nk = tl.arange(0, BLOCK_K) + if B_T: + off_n = tl.arange(0, BLOCK_K) + off_nk = tl.arange(0, BLOCK_N) + off_cm = tl.arange(0, BLOCK_M) + off_cn = tl.arange(0, BLOCK_N) + a = A + off_m[:, None] * stride_am + off_mk[None, :] * stride_ak + b = B + off_nk[:, None] * stride_bk + off_n[None, :] * stride_bn + C = C + off_cm[:, None] * stride_cm + off_cn[None, :] * stride_cn + x = tl.load(a) + y = tl.load(b) + if A_T: + x = tl.trans(x) + if B_T: + y = tl.trans(y) + z = tl.dot(x, y) + tl.store(C, z) + + os.environ['TRITON_DISABLE_SME'] = disable_sme #when disable_sme=1, this test swizzle trans + #run test + dataType = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dataType] + a = .1 * torch.randn((K, M) if (AT ^ ACol) else (M, K), device='cuda', dtype=dataType) + b = .1 * torch.randn((N, K) if (BT ^ BCol) else (K, N), device='cuda', dtype=dataType) + + tt_c = .1 * torch.randn((M, N), device='cuda', dtype=dataType) + tt_a = a + tt_b = b + + if ACol: + tt_a = a.t() + if BCol: + tt_b = b.t() + + # triton result + kernel[(1, 1)](tt_a, tt_b, tt_c, tt_a.stride(0), tt_a.stride(1), tt_b.stride(0), tt_b.stride(1), tt_c.stride(0), + tt_c.stride(1), BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, A_T=AT, B_T=BT, num_warps=num_warps) + + th_a = a.t() if (AT ^ ACol) else a + th_b = b.t() if (BT ^ BCol) else b + #torch result + th_c = torch.matmul(th_a, th_b) + assert_close(tt_c, th_c, atol=1e-2, rtol=0) + + +@pytest.mark.parametrize('M, N, K, AT, BT, CT, num_warps, dataType', [(M, N, K, AT, BT, CT, num_warps, dataType) + for M in [32, 64, 128] + for N in [32, 64] + for K in [32, 64] + for AT in [False, True] + for BT in [False, True] + for CT in [False, True] + for num_warps in [1, 2, 4] + for dataType in ["float16", "bfloat16"]]) +def test_multi_dot_trans(M, N, K, AT, BT, CT, num_warps, dataType, device='cuda'): + + @triton.jit + def kernel( + A, + B, + C, + D, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_dm, + stride_dn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + A_T: tl.constexpr, + B_T: tl.constexpr, + C_T: tl.constexpr, + ): + off_m = tl.arange(0, BLOCK_M) + off_mk = tl.arange(0, BLOCK_K) + if A_T: + off_m = tl.arange(0, BLOCK_K) + off_mk = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_nk = tl.arange(0, BLOCK_K) + if B_T: + off_n = tl.arange(0, BLOCK_K) + off_nk = tl.arange(0, BLOCK_N) + off_cm = tl.arange(0, BLOCK_M) + off_cn = tl.arange(0, BLOCK_N) + if C_T: + off_cm = tl.arange(0, BLOCK_N) + off_cn = tl.arange(0, BLOCK_M) + off_dn = tl.arange(0, BLOCK_N) + a = A + off_m[:, None] * stride_am + off_mk[None, :] * stride_ak + b = B + off_nk[:, None] * stride_bk + off_n[None, :] * stride_bn + c = C + off_cm[:, None] * stride_cm + off_cn[None, :] * stride_cn + x = tl.load(a) + y = tl.load(b) + w = tl.load(c) + if A_T: + x = tl.trans(x) + if B_T: + y = tl.trans(y) + if C_T: + w = tl.trans(w) + z = tl.dot(x, y) + z = z.to(C.dtype.element_ty) + p = tl.dot(tl.trans(z), w) + D = D + off_dn[:, None] * stride_dm + off_dn[None, :] * stride_dn + tl.store(D, p) + + #run test + dataType = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dataType] + a = .1 * torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dataType) + b = .1 * torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dataType) + c = .1 * torch.randn((N, M) if CT else (M, N), device='cuda', dtype=dataType) + d = .1 * torch.randn((N, N), device='cuda', dtype=dataType) + # triton result + kernel[(1, 1)](a, b, c, + d, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), d.stride(0), + d.stride(1), BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, A_T=AT, B_T=BT, C_T=CT, num_warps=num_warps) + ta = a.t() if AT else a + tb = b.t() if BT else b + tc = c.t() if CT else c + #torch result + th_c = torch.matmul(torch.matmul(ta, tb).t(), tc) + assert_close(d, th_c, atol=1e-2, rtol=0) diff --git a/third_party/iluvatar/python/test/unit/operators/test_flash_attention.py b/third_party/iluvatar/python/test/unit/operators/test_flash_attention.py new file mode 100644 index 000000000..7c251868d --- /dev/null +++ b/third_party/iluvatar/python/test/unit/operators/test_flash_attention.py @@ -0,0 +1,119 @@ +import pytest +import torch +import os + +import triton +import triton.ops +from triton.runtime.build import is_corex + + +@pytest.mark.interpreter +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ # + (2, 4, 512, 16), + (2, 4, 512, 32), + (2, 4, 512, 64), + (2, 4, 512, 128), +]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('seq_par', [True, False]) +def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par, device): + capability = torch.cuda.get_device_capability() + if not is_corex(): + if capability[0] < 8: + pytest.skip("Flash attention only supported for compute capability >= 80") + if dtype == torch.bfloat16 and os.environ.get("TRITON_INTERPRET", "0") == "1": + pytest.skip("Flash attention bfloat16 not supported in interpreter mode") + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=device).normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=device).normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=device).normal_(mean=0., std=0.5).requires_grad_() + sm_scale = 0.5 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device=device)) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).to(dtype) + # p = torch.exp(p) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # # triton implementation + tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + atol = 1e-1 if dtype == torch.bfloat16 else 1e-2 + torch.testing.assert_close(torch.nn.functional.normalize(torch.flatten(ref_out), dim=0), + torch.nn.functional.normalize(torch.flatten(tri_out), dim=0), atol=atol, rtol=0) + torch.testing.assert_close(torch.nn.functional.normalize(torch.flatten(ref_dv), dim=0), + torch.nn.functional.normalize(torch.flatten(tri_dv), dim=0), atol=atol, rtol=0) + torch.testing.assert_close(torch.nn.functional.normalize(torch.flatten(ref_dk), dim=0), + torch.nn.functional.normalize(torch.flatten(tri_dk), dim=0), atol=atol, rtol=0) + torch.testing.assert_close(torch.nn.functional.normalize(torch.flatten(ref_dq), dim=0), + torch.nn.functional.normalize(torch.flatten(tri_dq), dim=0), atol=atol, rtol=0) + + +try: + from flash_attn.flash_attn_interface import flash_attn_func + HAS_FLASH = True +except BaseException: + HAS_FLASH = False + +BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +# vary seq length for fixed head and batch=4 +configs = [ + triton.testing.Benchmark( + x_names=['N_CTX'], x_vals=[2**i for i in range(10, 14)], line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), styles=[('red', '-'), ('blue', '-')], ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{casual}-{seq_par}', args={ + 'H': N_HEADS, + 'BATCH': BATCH, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + 'casual': casual, + 'seq_par': seq_par, + }) for mode in ['fwd', 'bwd'] for casual in [True, False] for seq_par in [True, False] +] + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, casual, seq_par, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] + warmup = 25 + rep = 100 + sm_scale = 1.3 + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + if provider == "triton": + fn = lambda: triton.ops.attention(q, k, v, casual, sm_scale, seq_par) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + if provider == "flash": + lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + fn = lambda: flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=sm_scale, causal=casual) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +# only works on post-Ampere GPUs right now +# bench_flash_attention.run(save_path='.', print_data=True) diff --git a/python/test/unit/operators/test_inductor.py b/third_party/iluvatar/python/test/unit/operators/test_inductor.py similarity index 100% rename from python/test/unit/operators/test_inductor.py rename to third_party/iluvatar/python/test/unit/operators/test_inductor.py diff --git a/third_party/iluvatar/python/test/unit/operators/test_matmul.py b/third_party/iluvatar/python/test/unit/operators/test_matmul.py new file mode 100644 index 000000000..fcb6b838f --- /dev/null +++ b/third_party/iluvatar/python/test/unit/operators/test_matmul.py @@ -0,0 +1,209 @@ +import itertools + +import pytest +import torch + +import triton +import triton.language as tl +import triton.ops +from triton.runtime.build import is_corex + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@pytest.mark.parametrize( + "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, INPUT_PRECISION, F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE", + itertools.chain( + *[[ + # 1 warp + (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + # 2 warp + (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + # 4 warp + (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + # 8 warp + (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + # variable input + (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, None, True, None, None), + ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]], + # n-stage + *[[ + (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, None, True, None, None), + (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, None, True, None, None), + (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, None, True, None, None), + ] + for DTYPE in ["float16", "bfloat16", "float32"] + for AT in [False, True] + for BT in [False, True] + for STAGES in [4]], + # tf32x3 + *[[ + (16, 16, 16, 1, 1, 2, 32, 32, 80, AT, BT, "float32", "float32", "tf32x3", True, None, None), + (64, 32, 64, 1, 2, 2, 128, 64, 128, AT, BT, "float32", "float32", "tf32x3", True, None, None), + (128, 64, 16, 1, 4, 2, 256, 128, 80, AT, BT, "float32", "float32", "tf32x3", True, None, None), + (256, 128, 32, 1, 8, 2, 512, 256, 160, AT, BT, "float32", "float32", "tf32x3", True, None, None), + (128, 128, 32, 1, 4, 2, 256, 256, 160, AT, BT, "float32", "float32", "tf32x3", True, None, None), + ] for AT in [False, True] for BT in [False, True]], + # mixed-precision + *[[ + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, FASTACCUM, None, None), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, FASTACCUM, None, None), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, None, FASTACCUM, None, None), + ] for ADTYPE, BDTYPE in [ + ("float8e4nv", "float8e5"), + ("float8e4nv", "float8e4nv"), + ("float8e5", "float8e4nv"), + ("float8e5", "float8e5"), + ("float8e4b15", "float8e4b15"), + ("float8e4nv", "float16"), + ("float16", "float8e5"), + ("int8", "bfloat16"), + ("float16", "int8"), + ("float16", "float32"), + ("float32", "float16"), + ("bfloat16", "float32"), + ("float32", "bfloat16"), + ] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]], + # mixed-precision block layout + *[[ + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, True, None, None), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, True, None, None), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, None, True, None, None), + ] for ADTYPE, BDTYPE in [ + ("float8e4nv", "float16"), + ("float16", "float8e5"), + ("float16", "float32"), + ("float32", "float16"), + ("bfloat16", "float32"), + ("float32", "bfloat16"), + ] for AT in [False, True] for BT in [False, True]], + # acc-out-dtype and output_dtype + *[[ + (32, 32, 32, 1, 1, 2, None, None, None, False, False, "float16", "float16", None, True, ACC_DTYPE, + OUTPUT_DTYPE), + (128, 256, 32, 1, 8, 2, None, None, None, False, False, "float16", "float16", None, True, ACC_DTYPE, + OUTPUT_DTYPE), + # ] for ACC_DTYPE in [None, "float16", "float32"] for OUTPUT_DTYPE in [None, "float16", "float32"]], + ] for ACC_DTYPE in [None, "float32"] for OUTPUT_DTYPE in [None, "float16", "float32"]], + ), +) +def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, INPUT_PRECISION, + F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE): + capability = torch.cuda.get_device_capability() + if is_corex(): + if "float8" in ADTYPE or "float8" in BDTYPE: + pytest.skip("Iluvatar devices do not support float8 for now") + if (ADTYPE == "int8" and BDTYPE == "bfloat16") or (ADTYPE == "float16" and BDTYPE == "int8"): + pytest.skip("Iluvatar devices do not support this for now") + else: + if capability[0] < 7: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + if capability[0] < 8 and (ADTYPE == "bfloat16" or BDTYPE == "bfloat16"): + pytest.skip("Only test bfloat16 on devices with sm >= 80") + if capability[0] < 9 and (ADTYPE == "float8e4nv" or BDTYPE == "float8e4nv"): + pytest.skip("Only test float8e4nv on devices with sm >= 90") + if (ADTYPE == "bfloat16" or BDTYPE == "bfloat16") and SPLIT_K != 1: + pytest.skip("bfloat16 matmuls don't allow split_k for now") + torch.manual_seed(0) + # nuke kernel decorators -- will set meta-parameters manually + kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K} + pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_() + configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)] + kernel = triton.ops._matmul.kernel + kernel.configs = configs + # kernel.run = kernel.run.run.run + + # get matrix shape + M = BLOCK_M if M is None else M + N = BLOCK_N if N is None else N + K = BLOCK_K * SPLIT_K if K is None else K + + def is_fp8(dtype): + return "float8" in dtype + + def f8_to_f16(x, dtype): + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty_strided(x.shape, x.stride(), dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + def upcast_if_fp8(x, dtype): + if is_fp8(dtype): + return f8_to_f16(x, dtype) + return x + + def init_input(m, n, dtype, acc_dtype): + if 'float8' in dtype: + ewidth = {'float8e4b15': 4, 'float8e4nv': 4, 'float8e5': 5}[dtype] + sign = torch.randint(2, size=(m, n), device="cuda", dtype=torch.int8) * 128 + val = torch.randint(2**3 - 1, size=(m, n), device="cuda", dtype=torch.int8) << 7 - ewidth + return sign | val + if dtype == "int8": + return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8) + # Use small range of values to prevent numerical issues. + min_exp = -4 if acc_dtype == "float16" else -10 + exponents = torch.randint(min_exp, 0, size=(m, n)) + ret = (2.**exponents).to(getattr(torch, dtype)).to("cuda") + return ret + + if is_hip(): + if INPUT_PRECISION == 'tf32x3' or is_fp8(ADTYPE) or is_fp8(BDTYPE): + pytest.skip("fp8 inputs or tf32x3 precison does not have native support on hip") + # allocate/transpose inputs + a = init_input(M, K, ADTYPE, ACC_DTYPE) + b = init_input(K, N, BDTYPE, ACC_DTYPE) + a = a if not AT else a.T.contiguous().T + b = b if not BT else b.T.contiguous().T + # run test + th_a = upcast_if_fp8(a, ADTYPE) + th_b = upcast_if_fp8(b, BDTYPE) + ab_dtype = triton.ops.get_higher_dtype(th_a.dtype, th_b.dtype) + acc_dtype = getattr(torch, ACC_DTYPE) if ACC_DTYPE else ab_dtype + output_dtype = getattr(torch, OUTPUT_DTYPE) if OUTPUT_DTYPE else ab_dtype + th_c = torch.matmul(th_a.to(output_dtype), th_b.to(output_dtype)) + try: + if is_fp8(ADTYPE): + a = triton.reinterpret(a, getattr(tl, ADTYPE)) + if is_fp8(BDTYPE): + b = triton.reinterpret(b, getattr(tl, BDTYPE)) + tt_c = triton.ops.matmul(a, b, acc_dtype if ACC_DTYPE else None, INPUT_PRECISION, F8_FASTACCUM, output_dtype) + torch.testing.assert_close(th_c, tt_c) + except triton.OutOfResources as e: + pytest.skip(str(e)) diff --git a/third_party/iluvatar/python/test/unit/runtime/test_autotuner.py b/third_party/iluvatar/python/test/unit/runtime/test_autotuner.py new file mode 100644 index 000000000..a0e41a46d --- /dev/null +++ b/third_party/iluvatar/python/test/unit/runtime/test_autotuner.py @@ -0,0 +1,140 @@ +import torch + +import triton +import triton.language as tl +import pytest +from triton.runtime.build import is_corex + + +@pytest.mark.parametrize('use_cuda_graph', [False, True]) +def test_kwargs(use_cuda_graph: bool): + N = 1024 + src = torch.empty(N, device='cuda') + dst = torch.empty(N, device='cuda') + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N) + _kernel[grid](dst=dst, src=src, N=N) + + +def test_restore(): + N = 1024 + src = torch.zeros(N, device='cuda') + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + @triton.autotune(configs=configs, key=['N'], restore_value=['src'], warmup=1, rep=1) + @triton.jit + def _kernel(src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + 1 + tl.store(src + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](src, N) + triton.testing.assert_close(src, torch.ones_like(src)) + + +@pytest.mark.skip(reason="iluvatar") +def test_hooks(): + # Autotuner's pre- and post- hooks should be called the same number of times + N = 4096 + src = torch.zeros(N, device='cuda') + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 4096}), triton.Config(kwargs={'BLOCK_SIZE': 32})] + + values = {"counter": 0, "has_exception": False} + + def _pre_hook(*args, **kwargs): + values["counter"] += 1 + + def _post_hook(*args, exception): + values["counter"] -= 1 + if exception is not None: + values["has_exception"] = True + assert values["counter"] == 0 + + @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, pre_hook=_pre_hook, post_hook=_post_hook) + @triton.heuristics({"N_STAGES": lambda nargs: 100 if nargs['N'] == 4096 else 4}) + @triton.jit + def _kernel(src, N, N_STAGES: tl.constexpr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + max_iters = tl.cdiv(N, BLOCK_SIZE) + for _ in tl.range(max_iters, num_stages=N_STAGES): + x = tl.load(src + offsets, mask=offsets < N) + tl.store(src + offsets, x, mask=offsets < N) + offsets += BLOCK_SIZE + + _kernel[(1, )](src, N) + + # On NVIDIA GPUs: + # The tunning knob `num_stages` can be set by users. + # This will cause out of resources when N_STAGES = 100 + # shared memory bytes = N_STAGES * BLOCK_SIZE * sizeof(float) + # On AMD GPUs: + # `num_stages` is a fixed value of 2, so it won't cause out of resources + if triton.runtime.driver.active.get_current_target().backend == "cuda": + assert values["has_exception"] is True + else: + assert values["has_exception"] is False + + +@pytest.mark.skip(reason="iluvatar") +@pytest.mark.parametrize('with_perf_model', [False, True]) +def test_prune_configs(with_perf_model: bool): + # N = 1024 + if is_corex(): + if with_perf_model: + N = 512 + else: + N = 1024 + src = torch.empty(N, device='cuda') + dst = torch.empty(N, device='cuda') + records = {} + + def early_config_prune(configs, named_args, **kwargs): + records['run_early_config_prune'] = True + if "N" in kwargs and kwargs["N"] == 1024: + records['capture_kwargs'] = True + if "dst" in named_args and "src" in named_args and len(named_args) == 2: + records['capture_named_args'] = True + return [configs[0]] + + def perf_model(*args, **kwargs): + records['run_perf_model'] = True + return kwargs['BLOCK_SIZE'] + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + if with_perf_model: + prune_configs_by = {'perf_model': perf_model, 'top_k': 1} + else: + prune_configs_by = {'early_config_prune': early_config_prune} + + @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, warmup=1, rep=1) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + torch.testing.assert_close(src, dst) + if with_perf_model: + assert len(records) == 1 + assert records['run_perf_model'] + else: + assert len(records) == 3 + assert records['run_early_config_prune'] + assert records['capture_kwargs'] + assert records['capture_named_args'] diff --git a/third_party/iluvatar/python/test/unit/runtime/test_bindings.py b/third_party/iluvatar/python/test/unit/runtime/test_bindings.py new file mode 100644 index 000000000..c48ba9b4a --- /dev/null +++ b/third_party/iluvatar/python/test/unit/runtime/test_bindings.py @@ -0,0 +1,81 @@ +import triton +import triton.language as tl + +import torch + + +@triton.jit +def add_helper(x, y): + return x + y + + +@triton.jit +def add_kernel( + in_ptr0, + in_ptr1, + n_elements, + out_ptr, + BLOCK_SIZE: "tl.constexpr", +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = add_helper(x, y) + tl.store(out_ptr + offsets, output, mask=mask) + + +def test_module_walk(): + """ + Test the MLIR bindings exposed for the out-ot-tree walk. + """ + + def walk_fn(op): + name = op.get_name() + for i in range(op.get_num_results()): + op.get_result(i).id() + for i in range(op.get_num_operands()): + op.get_operand(i).id() + for i in range(op.get_num_regions()): + op.get_region(i).id() + block = op.get_block() + if block is not None: + block.id() + for i in range(block.get_num_arguments()): + block.get_argument(i) + if name == "tt.func": + op.get_str_attr("sym_name") + if name == "tt.call": + op.get_flat_symbol_ref_attr("callee") + + kernel = add_kernel + args = [ + torch.empty((32, 32), device="cuda"), # in_ptr0 + torch.empty((32, 32), device="cuda"), # in_ptr1 + 1024, # n_elements + torch.empty((32, 32), device="cuda"), # out_ptr + 16, # BLOCK_SIZE + ] + src = triton.compiler.compiler.ASTSource( + fn=kernel, + signature={i: kernel._type_of(kernel._key_of(arg)) + for i, arg in enumerate(args) + if i not in kernel.constexprs}, + constants={i: arg + for i, arg in enumerate(args) + if not isinstance(arg, torch.Tensor)}, + attrs=kernel._get_config(*args, ), + ) + + context = triton._C.libtriton.ir.context() + target = triton.runtime.driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) + options = backend.parse_options(dict()) + codegen_fns = dict() + triton._C.libtriton.ir.load_dialects(context) + backend.load_dialects(context) + + ttir_module = src.make_ir(options, codegen_fns, context) + ttir_module.walk(walk_fn) diff --git a/third_party/iluvatar/python/test/unit/runtime/test_cache.py b/third_party/iluvatar/python/test/unit/runtime/test_cache.py new file mode 100644 index 000000000..4387c47b0 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/runtime/test_cache.py @@ -0,0 +1,537 @@ +import importlib.util +import itertools +import os +import shutil +import tempfile + +import pytest +import torch + +import triton +import triton.language as tl +from triton.runtime.jit import JITFunction +from triton.runtime.build import is_corex + +tmpdir = ".tmp" + + +@triton.jit +def function_1(i): + i = i + 1 + i = function_2(i) + return i + + +@triton.jit +def function_2(i): + i = i + 1 + return i + + +@triton.jit +def combine_fn(a, b): + return COMBINE_OP # noqa: F821 + + +@triton.jit +def kernel(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit(do_not_specialize=["i"]) +def kernel_nospec(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit +def kernel_with_combine_fn(X, BLOCK: tl.constexpr): + i = tl.arange(0, BLOCK) + i = REDUCE_OR_SCAN(i, 0, combine_fn) # noqa: F821 + tl.store(X, i) + + +def apply_src_change(target, old, new): + kernel.hash = None + function_1.hash = None + function_2.hash = None + function_1.src = function_1.src.replace(old, new) + target.src = target.src.replace(old, new) + ret = target.cache_key + target.src = target.src.replace(new, old) + return ret + + +def test_nochange(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 1') + assert baseline == updated + + +def test_toplevel_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2') + assert baseline != updated + + +def test_nested1_change(): + baseline = kernel.cache_key + updated = apply_src_change(function_1, 'i + 1', 'i + 2') + assert baseline != updated + + +def test_combine_fn_change(): + # Test that tl.reduce and associative_scan calls include + # the combine_fn in the hash + + orig_combine_fn_src = combine_fn.src + orig_kernel_src = kernel_with_combine_fn.src + seen_keys = set() + + for reduce_or_scan, combine_op in itertools.product( + ["tl.reduce", "tl.associative_scan"], + ["a + b", "a * b"], + ): + combine_fn.src = orig_combine_fn_src.replace("COMBINE_OP", combine_op) + kernel_with_combine_fn.src = orig_kernel_src.replace("REDUCE_OR_SCAN", reduce_or_scan) + try: + key = kernel_with_combine_fn.cache_key + finally: + combine_fn.src = orig_combine_fn_src + kernel_with_combine_fn.src = orig_kernel_src + + kernel_with_combine_fn.hash = None + combine_fn.hash = None + + assert key not in seen_keys + seen_keys.add(key) + + +def write_and_load_module(code, num_extra_lines): + with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f: + f.write(('# extra line\n' * num_extra_lines) + code) + f.flush() + spec = importlib.util.spec_from_file_location("module.name", f.name) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_changed_line_numbers_invalidate_cache(): + from textwrap import dedent + code = dedent(""" + import triton + @triton.jit + def test_kernel(i): + i = i + 1 + """) + orig_mod = write_and_load_module(code, 0) + orig_cache_key = orig_mod.test_kernel.cache_key + + updated_mod = write_and_load_module(code, 1) + updated_cache_key = updated_mod.test_kernel.cache_key + assert orig_cache_key != updated_cache_key + + +def reset_tmp_dir(): + os.environ["TRITON_CACHE_DIR"] = tmpdir + if os.path.exists(tmpdir): + # https://stackoverflow.com/questions/303200/how-do-i-remove-delete-a-folder-that-is-not-empty + shutil.rmtree(tmpdir, ignore_errors=True) + + +def test_reuse(): + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + reset_tmp_dir() + x = torch.empty(1, dtype=torch.int32, device='cuda') + for i in range(10): + kernel[(1, )](x, 1, BLOCK=1024) + assert counter == 1 + + +@pytest.mark.parametrize('mode', ['enable', 'disable']) +def test_specialize(mode): + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + reset_tmp_dir() + x = torch.empty(1, dtype=torch.int32, device='cuda') + function = {'enable': kernel, 'disable': kernel_nospec}[mode] + target = {'enable': 3, 'disable': 1}[mode] + for i in [1, 2, 4, 8, 16, 32]: + function[(1, )](x, i, BLOCK=512) + assert counter == target + + +def test_annotation(): + + @triton.jit + def kernel(X, i: tl.int32): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + + device = torch.cuda.current_device() + kernel[(1, )](x, 1) + kernel[(1, )](x, 8) + kernel[(1, )](x, 16) + kernel[(1, )](x, 17) + assert len(kernel.cache[device]) == 3 + + +GLOBAL_DEFAULT_ARG = 1 + + +def test_kernel_default_arg(): + global GLOBAL_DEFAULT_ARG + + @triton.jit + def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + kernel[(1, )](x) + assert x == torch.ones_like(x) + + # Changing the global variable should not change the default argument in + # `kernel`. That value gets set at the time the function is declared. + GLOBAL_DEFAULT_ARG = 2 + kernel[(1, )](x) + assert x == torch.ones_like(x) + + device = torch.cuda.current_device() + assert len(kernel.cache[device]) == 1 + + +GLOBAL_VAR: tl.constexpr = 1 + + +def test_kernel_global_var_change(): + global GLOBAL_VAR + + @triton.jit + def kernel(X): + tl.store(X, GLOBAL_VAR) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + kernel[(1, )](x) + assert x == torch.ones_like(x) + + GLOBAL_VAR = 2 + with pytest.raises(RuntimeError) as e: + kernel[(1, )](x) + + assert "global variable" in str(e.value).lower() + + +GLOBAL = 42 # noqa + + +def test_local_shadows_global(): + global GLOBAL + + @triton.jit + def kernel(): + _, GLOBAL = 0, 0 # noqa + a = GLOBAL # noqa + + # No error because the `GLOBAL` we're modifying is not the same `GLOBAL` as + # inside the kernel. + GLOBAL = 42 + kernel[(1, )]() + GLOBAL = 43 + kernel[(1, )]() + + +CONSTEXPR_GLOBAL: tl.constexpr = 42 + + +def test_local_does_not_shadow_global(): + global CONSTEXPR_GLOBAL + + @triton.jit + def kernel(): + a = CONSTEXPR_GLOBAL # noqa + _, CONSTEXPR_GLOBAL = 0, 0 # noqa + + CONSTEXPR_GLOBAL = 42 + kernel[(1, )]() + CONSTEXPR_GLOBAL = 43 + + # Error because the `CONSTEXPR_GLOBAL` we're modifying is the same + # `CONSTEXPR_GLOBAL` that's read inside `kernel`. (Alternatively, we could + # make this kernel an error altogether, as it is if it's a pure Python + # function -- the fact that we store to `CONSTEXPR_GLOBAL` inside the kernel + # makes the first read a read of the local variable, which doesn't exist + # yet.) + with pytest.raises(RuntimeError): + kernel[(1, )]() + + +CONFLICTING_GLOBAL: tl.constexpr = 0 + + +@triton.jit +def conflicting_global_inner(): + a = CONFLICTING_GLOBAL # noqa + + +def test_conflicting_global_in_inner_function(): + global CONFLICTING_GLOBAL + + @triton.jit + def kernel1(): + a = CONFLICTING_GLOBAL # noqa + conflicting_global_inner() + + @triton.jit + def kernel2(): + a = CONFLICTING_GLOBAL #noqa + conflicting_global_inner() + + kernel1[(1, )]() + + # This should be an error because kernel2 calls conflicting_global_inner, + # which saw a value for 42 for the global when it was first compiled. + CONFLICTING_GLOBAL = 1 + + with pytest.raises(RuntimeError) as e: + kernel2[(1, )]() + + assert "Global variable CONFLICTING_GLOBAL has value" in str(e.value) + + +def test_use_builtin(): + + @triton.jit + def kernel(): + a = float(0) # noqa + + # No error about the value of `float` changing. + kernel[(1, )]() + kernel[(1, )]() + + +def test_no_cache_module_as_global(): + + @triton.jit + def kernel(): + tl.arange(0, 16) + + kernel[(1, )]() + # `tl` should not be entered into used_global_vals + assert not kernel.used_global_vals + + +BUILTIN_AS_GLOBAL = tl.int32 + + +def test_cache_builtin_as_global(): + global BUILTIN_AS_GLOBAL + + @triton.jit + def kernel(): + x = BUILTIN_AS_GLOBAL # noqa + + kernel[(1, )]() + + BUILTIN_AS_GLOBAL = tl.int64 + with pytest.raises(RuntimeError) as e: + kernel[(1, )]() + + assert "global variable" in str(e.value).lower() + + +@triton.jit +def no_cache_callable_inner(): + pass + + +def test_no_cache_callable(): + + @triton.jit + def kernel(): + no_cache_callable_inner() + + kernel[(1, )]() + # `no_cache_callable_inner` should not be entered into used_global_vals. + assert not kernel.used_global_vals + + +def test_constexpr_not_callable() -> None: + + @triton.jit + def kernel(X, c: tl.constexpr): + tl.store(X, 2) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + error = False + try: + kernel[(1, )](x, c="str") + except BaseException: + error = True + assert error is False + # try and catch + try: + kernel[(1, )](x, c=tl.abs) + except BaseException: + error = True + assert error is True + + +def test_jit_warmup_cache() -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + args = [ + torch.randn(32, dtype=torch.float32, device="cuda"), + torch.randn(32, dtype=torch.float32, device="cuda"), + torch.randn(32, dtype=torch.float32, device="cuda"), + 32, + ] + device = torch.cuda.current_device() + assert len(kernel_add.cache[device]) == 0 + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 1 + kernel_add.warmup(*args, grid=(1, )) + assert len(kernel_add.cache[device]) == 2 + kernel_add.warmup(*args, grid=(1, )) + assert len(kernel_add.cache[device]) == 2 + + +def test_jit_debug() -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + device = torch.cuda.current_device() + assert len(kernel_add.cache[device]) == 0 + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 1 + kernel_add.debug = False + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 2 + kernel_add.debug = True + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 3 + bins = list(kernel_add.cache[device].values()) + assert bins[2].asm['ttir'] != bins[1].asm['ttir'] + + +@triton.jit +def add_fn(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + +def test_jit_noinline() -> None: + if is_corex(): + pytest.skip("not supported for iluvatar devices") + + @triton.jit + def kernel_add_device(a, b, o, N: tl.constexpr): + add_fn(a, b, o, N) + + device = torch.cuda.current_device() + assert len(kernel_add_device.cache[device]) == 0 + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add_device.cache[device]) == 1 + bins = list(kernel_add_device.cache[device].values()) + inline_ttir = bins[0].asm['ttir'] + add_fn.noinline = True + add_fn.hash = None + kernel_add_device.hash = None + kernel_add_device.cache[device].clear() + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add_device.cache[device]) == 1 + bins = list(kernel_add_device.cache[device].values()) + noinline_ttir = bins[0].asm['ttir'] + assert inline_ttir != noinline_ttir + + +def test_memory_leak() -> None: + + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 10 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) + + +def test_preload() -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + @triton.jit + def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx)) + + device = torch.cuda.current_device() + + # get the serialized specialization data + specialization_data = None + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + specialization_data = kwargs["compile"]["specialization_data"] + + JITFunction.cache_hook = cache_hook + pre_compile = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + hash = pre_compile.hash + assert specialization_data is not None + + # clear the cache + reset_tmp_dir() + kernel_add.cache[device].clear() + + # preload the kernel + kernel_preload = kernel_add.preload(specialization_data) + assert kernel_preload.hash == hash + assert len(kernel_add.cache[device]) == 1 + + # we should hit the cache and not compile anything + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + final_kernel = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + JITFunction.cache_hook = None + assert counter == 0 + assert len(kernel_add.cache[device]) == 1 + assert final_kernel.hash == hash + + # test that we can't preload a mismatched kernel + with pytest.raises(RuntimeError, match="Specialization data is for"): + kernel_sub.preload(specialization_data) diff --git a/third_party/iluvatar/python/test/unit/runtime/test_driver.py b/third_party/iluvatar/python/test/unit/runtime/test_driver.py new file mode 100644 index 000000000..de00082f5 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/runtime/test_driver.py @@ -0,0 +1,14 @@ +import sys + +import triton + + +def test_is_lazy(): + from importlib import reload + reload(sys.modules["triton.runtime.driver"]) + reload(sys.modules["triton.runtime"]) + mod = sys.modules[triton.runtime.driver.__module__] + assert isinstance(triton.runtime.driver.active, getattr(mod, "LazyProxy")) + assert triton.runtime.driver.active._obj is None + utils = triton.runtime.driver.active.utils # noqa: F841 + assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase")) diff --git a/third_party/iluvatar/python/test/unit/runtime/test_jit.py b/third_party/iluvatar/python/test/unit/runtime/test_jit.py new file mode 100644 index 000000000..5892494c4 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/runtime/test_jit.py @@ -0,0 +1,42 @@ +import itertools +import pytest +import torch + +import triton +import triton.language as tl + + +def test_pre_call_hooks(device): + + @triton.jit + def add_kernel( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + class MyTensor(torch.Tensor): + pass + + def my_hook(*args, **kwargs): + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, MyTensor): + raise Exception("MyTensor is not allowed") + + add_kernel.add_pre_run_hook(my_hook) + + x = torch.randn(4, device=device) + y = MyTensor(x) + out = torch.zeros_like(x) + with pytest.raises(Exception): + add_kernel[(4, )](x, y, out, 4, 4) diff --git a/third_party/iluvatar/python/test/unit/runtime/test_launch.py b/third_party/iluvatar/python/test/unit/runtime/test_launch.py new file mode 100644 index 000000000..f17c05674 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/runtime/test_launch.py @@ -0,0 +1,134 @@ +import gc +# import importlib +# import os +# import sys +# import tempfile +# import textwrap +# import time +import tracemalloc + +import torch + +import triton +import triton.language as tl + +# from typing import Tuple + + +def test_metadata() -> None: + + used_hook = False + + def _launch_metadata(grid, kernel, args): + ret = dict() + ret["grid"] = grid + ret["value"] = args["x"] + return ret + + def hook(launch_metadata): + nonlocal used_hook + metadata = launch_metadata.get() + assert metadata["grid"] == (1, 3, 2) + assert metadata["value"] == 6 + used_hook = True + + @triton.jit(launch_metadata=_launch_metadata) + def kernel(x): + pass + + # launch kernel + triton.compiler.CompiledKernel.launch_enter_hook = hook + kernel[(1, 3, 2)](6) + triton.compiler.CompiledKernel.launch_enter_hook = None + assert used_hook + + +def test_memory_leak() -> None: + + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 10 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) + + tracemalloc.start() + try: + inp = torch.randn(10, device='cuda') + out = torch.randn(10, device='cuda') + kernel[(10, )](inp, out, 10, XBLOCK=16) + gc.collect() + begin, _ = tracemalloc.get_traced_memory() + for _ in range(100): + kernel[(10, )](inp, out, 10, XBLOCK=16) + gc.collect() + end, _ = tracemalloc.get_traced_memory() + assert end - begin < 30000 + finally: + tracemalloc.stop() + + +# LATENCY_THRESHOLD_US = 46 + +# def test_kernel_launch_latency() -> None: +# def define_kernel(kernel_name: str, num_tensor_args: int) -> str: +# arg_str = ",".join([f"arg{i}: torch.Tensor" for i in range(num_tensor_args)]) +# arg_str += ", n_elements: int, BLOCK_SIZE: tl.constexpr" +# func_str = f""" +# import torch + +# import triton +# import triton.language as tl + +# @triton.jit +# def {kernel_name}({arg_str}): +# pass +# """ +# with tempfile.NamedTemporaryFile(mode="w+t", suffix=".py", delete=False) as temp_file: +# temp_file.write(textwrap.dedent(func_str)) +# temp_file_path = temp_file.name + +# return temp_file_path + +# def import_kernel(file_path, kernel_name): +# directory, filename = os.path.split(file_path) +# module_name, _ = os.path.splitext(filename) +# sys.path.insert(0, directory) + +# module = importlib.import_module(module_name) +# kernel = getattr(module, kernel_name) +# return kernel + +# def empty(*kernel_args: Tuple[torch.Tensor]): +# first_arg = kernel_args[0] +# n_elements = first_arg.numel() +# grid = (triton.cdiv(n_elements, 1024),) +# device = torch.cuda.current_device() +# # Warmup +# empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device) +# torch.cuda.synchronize() +# # Measure launch overhead at steady state +# num_runs = 1000 +# start_time = time.time() +# for i in range(num_runs): +# empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device) +# end_time = time.time() +# latency_us = (end_time - start_time) / num_runs * 1e6 + +# assert latency_us < LATENCY_THRESHOLD_US, "Kernel launch time has increased!" + +# num_tensor_args = 40 +# kernel_name = 'empty_kernel' +# file_path = define_kernel(kernel_name, num_tensor_args) +# empty_kernel = import_kernel(file_path, kernel_name) + +# # Initialize random tensors for the empty_kernel +# torch.manual_seed(0) +# size = 1024 +# kernel_args = (torch.rand(size, device='cuda') for i in range(num_tensor_args)) + +# # Run empty, which would run empty_kernel internally +# empty(*kernel_args) diff --git a/third_party/iluvatar/python/test/unit/runtime/test_subproc.py b/third_party/iluvatar/python/test/unit/runtime/test_subproc.py new file mode 100644 index 000000000..333d1f929 --- /dev/null +++ b/third_party/iluvatar/python/test/unit/runtime/test_subproc.py @@ -0,0 +1,73 @@ +import multiprocessing +import os +import shutil + +import torch + +import triton +import triton.language as tl +from triton.compiler import ASTSource + +tmpdir = ".tmp" + +target = triton.runtime.driver.active.get_current_target() + + +def reset_tmp_dir(): + os.environ["TRITON_CACHE_DIR"] = tmpdir + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + + +def compile_fn(attrs, capability): + + @triton.jit + def kernel_sub(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777) + + src = ASTSource( + fn=kernel_sub, + constants={3: 32}, + signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, + attrs=attrs, + ) + triton.compile(src=src, target=target) + + +def test_compile_in_subproc() -> None: + major, minor = torch.cuda.get_device_capability(0) + cc = major * 10 + minor + config = triton.compiler.AttrsDescriptor(tuple(range(4)), ()) + + multiprocessing.set_start_method('fork') + proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) + proc.start() + proc.join() + assert proc.exitcode == 0 + + +def compile_fn_dot(attrs, capability): + + @triton.jit + def kernel_dot(Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + tl.store(Z + offs, z) + + src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict()) + triton.compile(src=src, target=target) + + +def test_compile_in_forked_subproc() -> None: + reset_tmp_dir() + major, minor = torch.cuda.get_device_capability(0) + capability = major * 10 + minor + config = triton.compiler.AttrsDescriptor(tuple(range(1)), ()) + + assert multiprocessing.get_start_method() == 'fork' + proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability)) + proc.start() + proc.join() + assert proc.exitcode == 0 diff --git a/third_party/iluvatar/python/triton/_C/include b/third_party/iluvatar/python/triton/_C/include new file mode 120000 index 000000000..b85a40983 --- /dev/null +++ b/third_party/iluvatar/python/triton/_C/include @@ -0,0 +1 @@ +../../../include/ \ No newline at end of file diff --git a/third_party/iluvatar/python/triton/__init__.py b/third_party/iluvatar/python/triton/__init__.py new file mode 100644 index 000000000..a5f77f91e --- /dev/null +++ b/third_party/iluvatar/python/triton/__init__.py @@ -0,0 +1,73 @@ +"""isort:skip_file""" +__version__ = '3.1.0' + +# --------------------------------------- +# Note: import order is significant here. + +# submodules +from .runtime import ( + autotune, + Config, + heuristics, + JITFunction, + KernelInterface, + reinterpret, + TensorWrapper, + OutOfResources, + InterpreterError, + MockTensor, +) +from .runtime.jit import jit +from .compiler import compile, CompilationError +from .errors import TritonError + +from . import language +from . import testing +from . import tools + +__all__ = [ + "autotune", + "cdiv", + "CompilationError", + "compile", + "Config", + "heuristics", + "impl", + "InterpreterError", + "jit", + "JITFunction", + "KernelInterface", + "language", + "MockTensor", + "next_power_of_2", + "ops", + "OutOfResources", + "reinterpret", + "runtime", + "TensorWrapper", + "TritonError", + "testing", + "tools", +] + +# ------------------------------------- +# misc. utilities that don't fit well +# into any specific module +# ------------------------------------- + + +def cdiv(x: int, y: int): + return (x + y - 1) // y + + +def next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n diff --git a/third_party/iluvatar/python/triton/backends b/third_party/iluvatar/python/triton/backends new file mode 120000 index 000000000..13a83a85c --- /dev/null +++ b/third_party/iluvatar/python/triton/backends @@ -0,0 +1 @@ +../../../../python/triton/backends \ No newline at end of file diff --git a/third_party/iluvatar/python/triton/compiler/__init__.py b/third_party/iluvatar/python/triton/compiler/__init__.py new file mode 100644 index 000000000..ce0cfedfc --- /dev/null +++ b/third_party/iluvatar/python/triton/compiler/__init__.py @@ -0,0 +1,4 @@ +from .compiler import CompiledKernel, ASTSource, compile, AttrsDescriptor, make_backend, LazyDict +from .errors import CompilationError + +__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"] diff --git a/third_party/iluvatar/python/triton/compiler/code_generator.py b/third_party/iluvatar/python/triton/compiler/code_generator.py new file mode 100644 index 000000000..de7355e18 --- /dev/null +++ b/third_party/iluvatar/python/triton/compiler/code_generator.py @@ -0,0 +1,1308 @@ +import ast +import inspect +import re +import sys +import warnings +import os +import textwrap +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from .. import language +from .._C.libtriton import ir +from ..language import constexpr, tensor, str_to_ty +from ..runtime.jit import _normalize_ty +# ideally we wouldn't need any runtime component +from ..runtime import JITFunction +from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) +from types import ModuleType + + +def mangle_ty(ty): + if ty.is_ptr(): + return 'P' + mangle_ty(ty.element_ty) + if ty.is_int(): + SIGNED = language.dtype.SIGNEDNESS.SIGNED + prefix = 'i' if ty.int_signedness == SIGNED else 'u' + return prefix + str(ty.int_bitwidth) + if ty.is_floating(): + return str(ty) + if ty.is_block(): + elt = mangle_ty(ty.scalar) + shape = '_'.join(map(str, ty.shape)) + return f'{elt}S{shape}S' + if ty.is_void(): + return 'V' + assert False, "Unsupported type" + + +def mangle_fn(name, arg_tys, constants): + # doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) + mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + # [ and ] are not allowed in LLVM identifiers + mangled_constants = mangled_constants.replace('[', '_').replace(']', '_') + ret = f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + + +def _is_triton_tensor(o: Any) -> bool: + return isinstance(o, tensor) + + +def _is_constexpr(o: Any) -> bool: + return isinstance(o, constexpr) + + +def _is_triton_scalar(o: Any) -> bool: + return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1) + + +def _is_list_like(o: Any) -> bool: + return isinstance(o, (list, tuple)) + + +def _unwrap_if_constexpr(o: Any): + return o.value if isinstance(o, constexpr) else o + + +def _check_fn_args(node, fn, args): + if fn.noinline: + for idx, arg in enumerate(args): + if not _is_constexpr(arg) and not _is_triton_scalar(arg): + raise UnsupportedLanguageConstruct( + fn.src, node, + f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}' + ) + + +def _get_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITFunction): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + lines, begin_line = inspect.getsourcelines(base_fn.fn) + # Match the following pattern: + # @triton.autotune(...) <- foo.__code__.co_firstlineno + # @triton.heuristics(...) + # @triton.jit + # def foo(...): <- this line is the first line + for idx, line in enumerate(lines): + if line.strip().startswith("def "): + begin_line += idx + break + return file_name, begin_line + + +_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels + + +class enter_sub_region: + + def __init__(self, generator): + self.generator = generator + + def __enter__(self): + # record lscope & local_defs in the parent scope + self.liveins = self.generator.lscope.copy() + self.prev_defs = self.generator.local_defs.copy() + self.generator.local_defs = {} + self.insert_block = self.generator.builder.get_insertion_block() + self.insert_point = self.generator.builder.get_insertion_point() + return self.liveins, self.insert_block + + def __exit__(self, *args, **kwargs): + self.generator.builder.restore_insertion_point(self.insert_point) + self.generator.lscope = self.liveins + self.generator.local_defs = self.prev_defs + + +# Check if the given syntax node has an "early" return +class ContainsReturnChecker(ast.NodeVisitor): + + def __init__(self, gscope): + self.gscope = gscope + + def _visit_stmts(self, body) -> bool: + for s in body: + if self.visit(s): + return True + return False + + def _visit_function(self, fn) -> bool: + # Currently we only support JITFunctions defined in the global scope + if isinstance(fn, JITFunction) and not fn.noinline: + fn_node = fn.parse() + return ContainsReturnChecker(self.gscope).visit(fn_node) + return False + + def generic_visit(self, node) -> bool: + ret = False + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + ret = ret or self.visit(item) + elif isinstance(value, ast.AST): + ret = ret or self.visit(value) + return ret + + def visit_Attribute(self, node: ast.Attribute) -> bool: + # If the left part is a name, it's possible that + # we call triton native function or a jit function from another module. + # If the left part is not a name, it must return a tensor or a constexpr + # whose methods do not contain return statements + # e.g., (tl.load(x)).to(y) + # So we only check if the expressions within value have return or not + if isinstance(node.value, ast.Name): + if node.value.id in self.gscope: + value = self.gscope[node.value.id] + fn = getattr(value, node.attr) + return self._visit_function(fn) + return False + return self.visit(node.value) + + def visit_Name(self, node: ast.Name) -> bool: + if type(node.ctx) == ast.Store: + return False + if node.id in self.gscope: + fn = self.gscope[node.id] + return self._visit_function(fn) + return False + + def visit_Return(self, node: ast.Return) -> bool: + return True + + def visit_Assign(self, node: ast.Assign) -> bool: + # There couldn't be an early return + # x = ... + return False + + def visit_AugAssign(self, node: ast.AugAssign) -> bool: + # There couldn't be an early return + # x += ... + return False + + def visit_Module(self, node: ast.Module) -> bool: + return self._visit_stmts(node.body) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: + return self._visit_stmts(node.body) + + def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return + ret = self._visit_stmts(node.body) + if node.orelse: + ret = ret or self._visit_stmts(node.orelse) + return ret + + def visit_IfExp(self, node: ast.IfExp) -> bool: + return self.visit(node.body) or self.visit(node.orelse) + + def visit_Call(self, node: ast.Call) -> bool: + return self.visit(node.func) + + +class CodeGenerator(ast.NodeVisitor): + + def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, + codegen_fns, debug=None, module=None, is_kernel=False, function_types: Optional[Dict] = None, + noinline=False, file_name: Optional[str] = None, begin_line=0): + self.context = context + self.builder = ir.builder(context) + self.file_name = file_name + # node.lineno starts from 1, so we need to subtract 1 + self.begin_line = begin_line - 1 + self.builder.set_loc(file_name, begin_line, 0) + self.builder.options = options + # dict of functions provided by the backend. Below are the list of possible functions: + # Convert custom types not natively supported on HW. + # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) + self.builder.codegen_fns = codegen_fns + self.module = self.builder.create_module() if module is None else module + self.function_ret_types = {} if function_types is None else function_types + self.prototype = prototype + self.gscope = gscope + self.lscope = dict() + self.attributes = attributes + self.constants = constants + self.jit_fn = jit_fn + self.function_name = function_name + self.is_kernel = is_kernel + self.cur_node = None + self.debug = options.debug if debug is None else debug + self.noinline = noinline + self.scf_stack = [] + self.ret_type = None + # SSA-construction + # name => language.tensor + self.local_defs: Dict[str, tensor] = {} + self.dereference_name: Callable[[str], Any] = self._define_name_lookup() + self.fn = None + # Are we currently visiting an ast.arg's default value? These have some + # special handling. + self.visiting_arg_default_value = False + + builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)} + builtin_namespace.update(( + ('print', language.core.device_print), + ('min', language.minimum), + ('max', language.maximum), + )) + + def _unsupported(self, node, message): + return UnsupportedLanguageConstruct(self.jit_fn.src, node, message) + + def _is_constexpr_global(self, name): + absent_marker = object() + val = self.gscope.get(name, absent_marker) + if val is absent_marker: + return False + + if _is_constexpr(val): + return True + + if a := self.gscope.get("__annotations__", {}).get(name): + return _normalize_ty(a) == "constexpr" + + return False + + def _define_name_lookup(self): + + def local_lookup(name: str, absent): + # this needs to be re-fetched from `self` every time, because it gets switched occasionally + return self.lscope.get(name, absent) + + def global_lookup(name: str, absent): + val = self.gscope.get(name, absent) + # The high-level rule is that only constexpr globals are allowed. + # But actually a bunch of other things, such as module imports, are + # technically Python globals. We have to allow these too! + if (val is absent # + or name in self.builtin_namespace # + or type(val) == ModuleType # + or isinstance(val, JITFunction) # + or getattr(val, "__triton_builtin__", False) # + or getattr(val, "__module__", "").startswith("triton.language") # + or isinstance(val, language.dtype) # + or self._is_constexpr_global(name) # + # Allow accesses to globals while visiting an ast.arg + # because you should be able to do + # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... + or self.visiting_arg_default_value # + or os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"): + return val + raise NameError( + textwrap.dedent(f"""\ + Cannot access global variable {name} from within @jit'ed + function. Triton kernels can only access global variables that + are annotated as constexpr (`x: triton.language.constexpr = 42` + or `x = triton.language.constexpr(42)`). Alternatively, set the + envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not + promise to support this forever.""").replace("\n", " ")) + + absent_marker = object() + + def name_lookup(name: str) -> Any: + absent = absent_marker + for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get: + value = lookup_function(name, absent) + if value is not absent: + return value + raise NameError(f'{name} is not defined') + + return name_lookup + + def set_value(self, name: str, value: Union[tensor, constexpr]) -> None: + ''' This function: + called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. store tensor in self.lvalue + ''' + self.lscope[name] = value + self.local_defs[name] = value + + def _get_insertion_point_and_loc(self): + # XXX: this is a hack to get the location of the insertion point. + # The insertion point's location could be invalid sometimes, + # so we need to explicitly set the location + loc = self.builder.get_loc() + ip = self.builder.get_insertion_point() + return ip, loc + + def _set_insertion_point_and_loc(self, ip, loc): + self.builder.restore_insertion_point(ip) + self.builder.set_loc(loc) + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + # Ensure that stmts is iterable + if not _is_list_like(stmts): + stmts = [stmts] + for stmt in stmts: + self.visit(stmt) + + # Stop parsing as soon as we hit a `return` statement; everything + # after this is dead code. + if isinstance(stmt, ast.Return): + break + + def visit_Module(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_List(self, node): + ctx = self.visit(node.ctx) + assert ctx is None + elts = [self.visit(elt) for elt in node.elts] + return elts + + # By design, only non-kernel functions can return + def visit_Return(self, node): + ret_value = self.visit(node.value) + # ret_block = self.builder.create_block() + # post_ret_block = self.builder.create_block() + # self.builder.create_branch(ret_block) + # self.builder.set_insertion_point_to_end(ret_block) + if ret_value is None: + self.builder.ret([]) + ret_ty = language.void + elif isinstance(ret_value, tuple): + ret_values = [language.core._to_tensor(v, self.builder) for v in ret_value] + ret_types = [v.type for v in ret_values] + self.builder.ret([v.handle for v in ret_values]) + ret_ty = tuple(ret_types) + else: + ret = language.core._to_tensor(ret_value, self.builder) + self.builder.ret([ret.handle]) + ret_ty = ret.type + # self.builder.create_branch(post_ret_block) + # self.builder.set_insertion_point_to_end(post_ret_block) + + if self.ret_type is None: + self.ret_type = ret_ty + elif self.ret_type != ret_ty: + raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}') + + def visit_FunctionDef(self, node): + arg_names, kwarg_names = self.visit(node.args) + if self.fn: + raise self._unsupported(node, "nested function definition is not supported.") + # initialize defaults + for i, default_value in enumerate(node.args.defaults): + arg_node = node.args.args[-i - 1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + self.visit(init_node) + finally: + self.visiting_arg_default_value = False + + # initialize function + visibility = "public" if self.is_kernel else "private" + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, + self.prototype.to_ir(self.builder), visibility, self.noinline) + self.module.push_back(self.fn) + entry = self.fn.add_entry_block() + arg_values = [] + idx = 0 + for i, arg_name in enumerate(arg_names): + if i in self.constants: + cst = self.constants[i] + if not _is_constexpr(cst): + cst = constexpr(self.constants[i]) + arg_values.append(cst) + continue + else: + if i in self.attributes: + for name, value in self.attributes[i]: + self.fn.set_arg_attr(idx, name, value) + arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) + idx += 1 + + insert_pt = self.builder.get_insertion_block() + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + self.builder.set_insertion_point_to_start(entry) + # visit function body + self.visit_compound_statement(node.body) + # finalize function + if self.ret_type is None or self.ret_type == language.void: + self.ret_type = language.void + self.builder.ret([]) + else: + # update return type + if isinstance(self.ret_type, tuple): + self.prototype.ret_types = list(self.ret_type) + self.fn.reset_type(self.prototype.to_ir(self.builder)) + else: + self.prototype.ret_types = [self.ret_type] + self.fn.reset_type(self.prototype.to_ir(self.builder)) + if insert_pt: + self.builder.set_insertion_point_to_end(insert_pt) + # Remove dead code + self.fn.finalize() + + def visit_arguments(self, node): + arg_names = [] + for arg in node.args: + arg_names += [self.visit(arg)] + kwarg_names = self.visit(node.kwarg) + return arg_names, kwarg_names + + def visit_arg(self, node): + ast.NodeVisitor.generic_visit(self, node) + return node.arg + + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = self.visit(node.value) + # constexpr + if annotation == constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' + f' constexpr cannot be reassigned.') + if not _is_constexpr(value): + value = constexpr(value) + self.lscope[target] = value + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def visit_Assign(self, node): + _names = [] + for target in node.targets: + _names += [self.visit(target)] + if len(_names) > 1: + raise self._unsupported(node, "simultaneous multiple assignment is not supported.") + names = _names[0] + values = self.visit(node.value) + if not _is_list_like(names): + names = [names] + if not _is_list_like(values): + values = [values] + native_nontensor_types = (language.dtype, ) + for name, value in zip(names, values): + # by default, constexpr are assigned into python variable + value = _unwrap_if_constexpr(value) + if value is not None and \ + not _is_triton_tensor(value) and \ + not isinstance(value, native_nontensor_types): + value = language.core._to_tensor(value, self.builder) + self.set_value(name, value) + + def visit_AugAssign(self, node): + name = node.target.id + lhs = ast.Name(id=name, ctx=ast.Load()) + rhs = ast.BinOp(lhs, node.op, node.value) + assign = ast.Assign(targets=[node.target], value=rhs) + self.visit(assign) + return self.dereference_name(name) + + def visit_Name(self, node): + if type(node.ctx) == ast.Store: + return node.id + return self.dereference_name(node.id) + + def visit_Store(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Load(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Tuple(self, node): + args = [self.visit(x) for x in node.elts] + return tuple(args) + + def _apply_binary_method(self, method_name, lhs, rhs): + # TODO: raise something meaningful if getattr fails below, esp for reverse method + if _is_triton_tensor(lhs): + return getattr(lhs, method_name)(rhs, _builder=self.builder) + if _is_triton_tensor(rhs): + reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name) + return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder) + return getattr(lhs, method_name)(rhs) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + method_name = self._method_name_for_bin_op.get(type(node.op)) + if method_name is None: + raise self._unsupported(node, + "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bin_op: Dict[Type[ast.operator], str] = { + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', + } + + def visit_then_else_blocks(self, node, liveins, then_block, else_block): + # then block + self.builder.set_insertion_point_to_start(then_block) + self.visit_compound_statement(node.body) + then_block = self.builder.get_insertion_block() + then_defs = self.local_defs.copy() + # else block + else_defs = {} + if node.orelse: + self.builder.set_insertion_point_to_start(else_block) + self.lscope = liveins.copy() + self.local_defs = {} + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else_block = self.builder.get_insertion_block() + + # update block arguments + names = [] + ret_types = [] + ir_ret_types = [] + # variables in livein whose value is updated in `if` + for name in liveins: + # check type + for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: + if name in defs: + assert defs[name].type == liveins[name].type, \ + f'initial value for `{name}` is of type {liveins[name].type}, '\ + f'but the {block_name} block redefines it as {defs[name].type}' + if name in then_defs or name in else_defs: + names.append(name) + ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type) + ir_ret_types.append(then_defs[name].handle.get_type() if name in + then_defs else else_defs[name].handle.get_type()) + # variable defined in then but not in else + if name in then_defs and name not in else_defs: + else_defs[name] = liveins[name] + # variable defined in else but not in then + if name in else_defs and name not in then_defs: + then_defs[name] = liveins[name] + # variables that are both in then and else but not in liveins + # TODO: could probably be cleaned up + for name in then_defs.keys() & else_defs.keys(): + if name in names: + continue + then_ty = then_defs[name].type + else_ty = else_defs[name].type + assert then_ty == else_ty, \ + f'mismatched type for {name} between then block ({then_ty}) '\ + f'and else block ({else_ty})' + names.append(name) + ret_types.append(then_ty) + ir_ret_types.append(then_defs[name].handle.get_type()) + + return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types + + def visit_if_top_level(self, cond, node): + has_endif_block = True + with enter_sub_region(self) as sr: + liveins, ip_block = sr + then_block = self.builder.create_block() + else_block = self.builder.create_block() + # create basic-block after conditional + endif_block = self.builder.create_block() + # create branch + self.builder.set_insertion_point_to_end(ip_block) + self.builder.create_cond_branch(cond.handle, then_block, else_block) + # visit then and else blocks + then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # then terminator + self.builder.set_insertion_point_to_end(then_block) + if then_block.has_return() and else_block.has_return(): + has_endif_block = False + endif_block.erase() + if not then_block.has_terminator() and has_endif_block: + self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) + # else terminator + self.builder.set_insertion_point_to_end(else_block) + if not else_block.has_terminator() and has_endif_block: + self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) + if has_endif_block: + for ty in ir_ret_types: + endif_block.add_argument(ty) + if has_endif_block: + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + for i, name in enumerate(names): + new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) + self.set_value(name, new_tensor) + + # TODO: refactor + def visit_if_scf(self, cond, node): + with enter_sub_region(self) as sr: + liveins, _ = sr + ip, last_loc = self._get_insertion_point_and_loc() + then_block = self.builder.create_block() + else_block = self.builder.create_block() if node.orelse else None + then_defs, else_defs, then_block, else_block, names, ret_types, _ = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create if op + self._set_insertion_point_and_loc(ip, last_loc) + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + if len(names) > 0: + self.builder.create_yield_op([then_defs[n].handle for n in names]) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + if len(names) > 0: + self.builder.create_yield_op([else_defs[n].handle for n in names]) + # update values + for i, name in enumerate(names): + new_tensor = language.core.tensor(if_op.get_result(i), ret_types[i]) + self.set_value(name, new_tensor) + + def visit_If(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + contains_return = ContainsReturnChecker(self.gscope).visit(node) + if self.scf_stack and contains_return: + raise self._unsupported( + node, "Cannot have `return` statements inside `while` or `for` statements in triton " + "(note that this also applies to `return` statements that are inside functions " + "transitively called from within `while`/`for` statements)") + elif self.scf_stack or not contains_return: + self.visit_if_scf(cond, node) + else: + self.visit_if_top_level(cond, node) + else: + cond = _unwrap_if_constexpr(cond) + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + if cond: + self.visit_compound_statement(node.body) + else: + self.visit_compound_statement(node.orelse) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + # TODO: Deal w/ more complicated return types (e.g tuple) + with enter_sub_region(self): + ip, last_loc = self._get_insertion_point_and_loc() + + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + then_val = language.core._to_tensor(self.visit(node.body), self.builder) + then_block = self.builder.get_insertion_block() + + else_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(else_block) + # do not need to reset lscope since + # ternary expressions cannot define new variables + else_val = language.core._to_tensor(self.visit(node.orelse), self.builder) + else_block = self.builder.get_insertion_block() + + self._set_insertion_point_and_loc(ip, last_loc) + + assert then_val.type == else_val.type, \ + f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' + ret_type = then_val.type + + ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else [] + if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([then_val.handle]) + + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + else_block.merge_block_before(if_op.get_else_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([else_val.handle]) + return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None + else: + cond = _unwrap_if_constexpr(cond) + + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + if cond: + return self.visit(node.body) + else: + return self.visit(node.orelse) + + def visit_Pass(self, node): + pass + + def visit_Compare(self, node): + if not (len(node.comparators) == 1 and len(node.ops) == 1): + raise self._unsupported(node, "simultaneous multiple comparison is not supported") + lhs = self.visit(node.left) + rhs = self.visit(node.comparators[0]) + lhs_value = _unwrap_if_constexpr(lhs) + rhs_value = _unwrap_if_constexpr(rhs) + if type(node.ops[0]) == ast.Is: + return constexpr(lhs_value is rhs_value) + if type(node.ops[0]) == ast.IsNot: + return constexpr(lhs_value is not rhs_value) + method_name = self._method_name_for_comp_op.get(type(node.ops[0])) + if method_name is None: + raise self._unsupported( + node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = { + ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__' + } + + def visit_UnaryOp(self, node): + operand = self.visit(node.operand) + fn = self._method_name_for_unary_op.get(type(node.op)) + if fn is None: + raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.") + if _is_triton_tensor(operand): + return getattr(operand, fn)(_builder=self.builder) + try: + return getattr(operand, fn)() + except AttributeError: + raise self._unsupported( + node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}") + + _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = { + ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' + } + + def visit_While(self, node): + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # loop body (the after region) + # loop_block = self.builder.create_block() + dummy = self.builder.create_block() + self.builder.set_insertion_point_to_start(dummy) + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + dummy.erase() + + # collect loop-carried values + names = [] + ret_types = [] + init_args = [] + for name in loop_defs: + if name in liveins: + # We should not def new constexpr + assert _is_triton_tensor(loop_defs[name]), f'cannot reassign constxpr {name} in the loop' + assert _is_triton_tensor(liveins[name]), f'cannot reasign constexpr {name} in the loop' + assert loop_defs[name].type == liveins[name].type, \ + f'Loop-carried variable {name} has initial type {liveins[name].type} '\ + f'but is re-assigned to {loop_defs[name].type} in loop! '\ + f'Please make sure that the type stays consistent.' + + # these are loop-carried values + names.append(name) + ret_types.append(loop_defs[name].type) + init_args.append(liveins[name]) + + self._set_insertion_point_and_loc(ip, last_loc) + while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], + [arg.handle for arg in init_args]) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before(), + [ty.to_ir(self.builder) for ty in ret_types]) + self.builder.set_insertion_point_to_start(before_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(before_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + cond = self.visit(node.test) + self.builder.set_insertion_point_to_end(before_block) + # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... + self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after(), + [ty.to_ir(self.builder) for ty in ret_types]) + + # generate loop body + self.builder.set_insertion_point_to_start(after_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(after_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + yields = [] + for name in loop_defs: + if name in liveins: + yields.append(loop_defs[name]) + self.builder.create_yield_op([y.handle for y in yields]) + + # WhileOp defines new values, update the symbol table (lscope, local_defs) + for i, name in enumerate(names): + new_def = language.core.tensor(while_op.get_result(i), ret_types[i]) + self.lscope[name] = new_def + self.local_defs[name] = new_def + + for stmt in node.orelse: + assert False, "Not implemented" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Subscript(self, node): + assert node.ctx.__class__.__name__ == "Load" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + if _is_triton_tensor(lhs): + return lhs.__getitem__(slices, _builder=self.builder) + return lhs[slices] + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] + + def visit_For(self, node): + IteratorClass = self.visit(node.iter.func) + iter_args = [self.visit(arg) for arg in node.iter.args] + iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords) + if IteratorClass == language.static_range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + static_range = range(iterator.start.value, iterator.end.value, iterator.step.value) + for i in static_range: + self.lscope[node.target.id] = constexpr(i) + self.visit_compound_statement(node.body) + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return + num_stages = None + if IteratorClass is language.range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iterator.start + ub = iterator.end + step = iterator.step + num_stages = iterator.num_stages + elif IteratorClass is range: + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0)) + ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0]) + step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) + else: + raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + # handle negative constant step (not supported by scf.for in MLIR) + negative_step = False + if _is_constexpr(step) and step.value < 0: + step = constexpr(-step.value) + negative_step = True + lb, ub = ub, lb + lb = language.core._to_tensor(lb, self.builder) + ub = language.core._to_tensor(ub, self.builder) + step = language.core._to_tensor(step, self.builder) + # induction variable type + if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): + raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") + iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype) + iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype) + iv_ir_type = iv_type.to_ir(self.builder) + iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED + # lb/ub/step might be constexpr, we need to cast them to tensor + lb = lb.handle + ub = ub.handle + step = step.handle + # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index + lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed) + ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed) + step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed) + # Create placeholder for the loop induction variable + iv = self.builder.create_undef(iv_ir_type) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # create loop body block + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) + # dry visit loop body + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + block.erase() + + # If a variable (name) is defined in both its parent & itself, then it's + # a loop-carried variable. (They must be of the same type) + init_args = [] + yields = [] + names = [] + for name in self.local_defs: + if name in liveins: + assert _is_triton_tensor(self.local_defs[name]), f'{name} is not tensor' + assert _is_triton_tensor(liveins[name]) + assert self.local_defs[name].type == liveins[name].type, \ + f'Loop-carried variable {name} has initial type {liveins[name].type} '\ + f'but is re-assigned to {self.local_defs[name].type} in loop! '\ + f'Please make sure that the type stays consistent.' + + names.append(name) + init_args.append(language.core._to_tensor(liveins[name], self.builder)) + yields.append(language.core._to_tensor(self.local_defs[name], self.builder)) + + # create ForOp + self._set_insertion_point_and_loc(ip, last_loc) + for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) + if num_stages is not None: + for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) + + self.scf_stack.append(node) + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + # reset local scope to not pick up local defs from the previous dry run. + self.lscope = liveins.copy() + self.local_defs = {} + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type)) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + yields = [] + for name in self.local_defs: + if name in liveins: + yields.append(language.core._to_tensor(self.local_defs[name], self.builder)) + + # create YieldOp + if len(yields) > 0: + self.builder.create_yield_op([y.handle for y in yields]) + for_op_region = for_op.get_body(0).get_parent() + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + + # update induction variable with actual value, and replace all uses + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + iv = for_op.get_induction_var() + if negative_step: + iv = self.builder.create_sub(ub, iv) + iv = self.builder.create_add(iv, lb) + self.lscope[node.target.id].handle.replace_all_uses_with(iv) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + # update lscope & local_defs (ForOp defines new values) + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_result(i), yields[i].type)) + + for stmt in node.orelse: + assert False, "Don't know what to do with else after for" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Slice(self, node): + lower = self.visit(node.lower) + upper = self.visit(node.upper) + step = self.visit(node.step) + return slice(lower, upper, step) + + def visit_Index(self, node): + return self.visit(node.value) + + def visit_keyword(self, node) -> Tuple[str, Any]: + return node.arg, self.visit(node.value) + + def visit_Assert(self, node) -> Any: + if not self.debug: + return + test = self.visit(node.test) + msg = self.visit(node.msg) if node.msg is not None else "" + # Convert assert to triton's device_assert which happens on the device + return language.core.device_assert(test, msg, _builder=self.builder) + + def call_JitFunction(self, fn: JITFunction, args, kwargs): + args = inspect.getcallargs(fn.fn, *args, **kwargs) + args = [args[name] for name in fn.arg_names] + args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args] + # generate function def + attributes = dict() + constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in args if arg is not None] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + prototype = language.function_type([], arg_types) + gscope = fn.__globals__ + # If the callee is not set, we use the same debug setting as the caller + file_name, begin_line = _get_fn_file_line(fn) + debug = self.debug if fn.debug is None else fn.debug + generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, + jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, + noinline=fn.noinline, file_name=file_name, begin_line=begin_line, + options=self.builder.options, codegen_fns=self.builder.codegen_fns, debug=debug) + try: + generator.visit(fn.parse()) + except Exception as e: + # Wrap the error in the callee with the location of the call. + raise CompilationError(self.jit_fn.src, self.cur_node, None) from e + + callee_ret_type = generator.ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + call_op = self.builder.call(symbol, arg_vals) + if call_op.get_num_results() == 0 or callee_ret_type is None: + return None + elif call_op.get_num_results() == 1: + return tensor(call_op.get_result(0), callee_ret_type) + else: + # should return a tuple of tl.tensor + results = [] + for i in range(call_op.get_num_results()): + results.append(tensor(call_op.get_result(i), callee_ret_type[i])) + return tuple(results) + + def visit_Call(self, node): + fn = _unwrap_if_constexpr(self.visit(node.func)) + static_implementation = self.statically_implemented_functions.get(fn) + if static_implementation is not None: + return static_implementation(self, node) + + kws = dict(self.visit(keyword) for keyword in node.keywords) + args = [self.visit(arg) for arg in node.args] + if fn is language.core.device_assert: # TODO: this should not be so hardcoded + if not self.debug: + return + if isinstance(fn, JITFunction): + _check_fn_args(node, fn, args) + return self.call_JitFunction(fn, args, kws) + if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn): + extra_kwargs = dict(_builder=self.builder) + sig = inspect.signature(fn) + if '_generator' in sig.parameters: + extra_kwargs['_generator'] = self + try: + return fn(*args, **extra_kwargs, **kws) + except Exception as e: + # Normally when we raise a CompilationError, we raise it as + # `from None`, because the original fileline from the exception + # is not relevant (and often points into code_generator.py + # itself). But when calling a function, we raise as `from e` to + # preserve the traceback of the original error, which may e.g. + # be in core.py. + raise CompilationError(self.jit_fn.src, node, None) from e + + if fn in self.builtin_namespace.values(): + args = map(_unwrap_if_constexpr, args) + return fn(*args, **kws) + + def visit_Constant(self, node): + return constexpr(node.value) + + def visit_BoolOp(self, node: ast.BoolOp): + if len(node.values) != 2: + raise self._unsupported( + node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") + lhs = self.visit(node.values[0]) + rhs = self.visit(node.values[1]) + method_name = self._method_name_for_bool_op.get(type(node.op)) + if method_name is None: + raise self._unsupported( + node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} + + if sys.version_info < (3, 8): + + def visit_NameConstant(self, node): + return constexpr(node.value) + + def visit_Num(self, node): + return constexpr(node.n) + + def visit_Str(self, node): + return constexpr(ast.literal_eval(node)) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + if _is_triton_tensor(lhs): + if node.attr == "T": + return language.semantic.permute(lhs, (1, 0), builder=self.builder) + return getattr(lhs, node.attr) + + def visit_Expr(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_NoneType(self, node): + return None + + def visit_JoinedStr(self, node): + values = list(node.values) + for i, value in enumerate(values): + if isinstance(value, ast.Constant): + values[i] = str(value.value) + elif isinstance(value, ast.FormattedValue): + conversion_code = value.conversion + evaluated = self.visit(value.value) + if not _is_constexpr(evaluated): + raise self._unsupported( + node, + "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + + str(type(evaluated))) + values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) + else: + raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) + return ''.join(values) + + def visit(self, node): + if node is None: + return + with warnings.catch_warnings(): + # The ast library added visit_Constant and deprecated some other + # methods but we can't move to that without breaking Python 3.6 and 3.7. + warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 + warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 + last_node = self.cur_node + last_loc = self.builder.get_loc() + self.cur_node = node + if hasattr(node, 'lineno') and hasattr(node, 'col_offset'): + self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset) + last_loc = self.builder.get_loc() + try: + ret = super().visit(node) + except CompilationError: + raise + except Exception as e: + # Wrap the error in a CompilationError which contains the source + # of the @jit function. + raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None + + # Reset the location to the last one before the visit + if last_loc: + self.cur_node = last_node + self.builder.set_loc(last_loc) + return ret + + def generic_visit(self, node): + raise self._unsupported(node, "unsupported AST node type: {}".format(type(node).__name__)) + + def execute_static_assert(self, node: ast.Call) -> None: + arg_count = len(node.args) + if not (0 < arg_count <= 2) or len(node.keywords): + raise TypeError("`static_assert` requires one or two positional arguments only") + + passed = _unwrap_if_constexpr(self.visit(node.args[0])) + if not isinstance(passed, bool): + raise NotImplementedError( + "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values" + ) + if not passed: + if arg_count == 1: + message = "" + else: + try: + message = self.visit(node.args[1]) + except Exception as e: + message = "" + + raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message)) + return None + + def static_executor(python_fn): + + def ret(self, node: ast.Call): + kws = { + name: _unwrap_if_constexpr(value) + for name, value in (self.visit(keyword) for keyword in node.keywords) + } + args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args] + return constexpr(python_fn(*args, **kws)) + + return ret + + statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = { + language.core.static_assert: execute_static_assert, + language.core.static_print: static_executor(print), + int: static_executor(int), + len: static_executor(len), + } + + +def kernel_suffix(signature, specialization): + # suffix format: + # <'c' if equal to 1><'d' if divisible by 16><'e' if divisible by 8> + suffix = '' + for i, _ in enumerate(signature): + suffix += str(i) + if i in specialization.equal_to_1: + suffix += 'c' + if i in specialization.divisible_by_16: + suffix += 'd' + if i in specialization.divisible_by_8: + suffix += 'e' + return suffix + + +def ast_to_ttir(fn, specialization, context, options, codegen_fns): + attrs = specialization.attrs + # create kernel prototype + cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in specialization.constants.items()} + # visit kernel AST + gscope = fn.__globals__.copy() + function_name = fn.repr(specialization) + tys = list(specialization.signature.values()) + new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1} + new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16} + for k in attrs.divisible_by_8: + attr = new_attrs[k] if k in new_attrs else [] + attr.append(("tt.max_divisibility", 8)) + new_attrs[k] = attr + + all_constants = constants.copy() + all_constants.update(new_constants) + arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] + file_name, begin_line = _get_fn_file_line(fn) + + prototype = language.function_type([], arg_types) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, + jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name, + begin_line=begin_line, options=options, codegen_fns=codegen_fns) + generator.visit(fn.parse()) + + ret = generator.module + # module takes ownership of the context + ret.context = context + return ret diff --git a/third_party/iluvatar/python/triton/compiler/compiler.py b/third_party/iluvatar/python/triton/compiler/compiler.py new file mode 100644 index 000000000..ed7586280 --- /dev/null +++ b/third_party/iluvatar/python/triton/compiler/compiler.py @@ -0,0 +1,424 @@ +from __future__ import annotations +import hashlib +import json +from .._C.libtriton import get_cache_invalidating_env_vars, ir +from ..backends import backends +from ..backends.compiler import GPUTarget +from .. import __version__ +from ..runtime.autotuner import OutOfResources +from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager +from ..runtime.driver import driver +from ..runtime.jit import JITFunction +# TODO: this shouldn't be here +from dataclasses import dataclass +from .code_generator import ast_to_ttir +from pathlib import Path +import re +import functools +import os + + +@dataclass +class AttrsDescriptor: + divisible_by_16: set = None + equal_to_1: set = None + divisible_by_8: set = None + + def __post_init__(self): + if self.divisible_by_16 is None: + self.divisible_by_16 = set() + if self.equal_to_1 is None: + self.equal_to_1 = set() + if self.divisible_by_8 is None: + self.divisible_by_8 = set() + + def to_dict(self): + return { + 'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1), 'divisible_by_8': + list(self.divisible_by_8) + } + + @staticmethod + def from_dict(data): + return AttrsDescriptor(divisible_by_16=set(data.get('divisible_by_16', + [])), equal_to_1=set(data.get('equal_to_1', [])), + divisible_by_8=set(data.get('divisible_by_8', []))) + + def hash(self): + key = str([sorted(x) for x in self.__dict__.values()]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, +# and any following whitespace +# - (public\s+)? : optionally match the keyword public and any following whitespace +# - (@\w+) : match an @ symbol followed by one or more word characters +# (letters, digits, or underscores), and capture it as group 1 (the function name) +# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing +# zero or more arguments separated by commas, and capture it as group 2 (the argument list) +# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 +mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" +ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" +prototype_pattern = { + "ttir": mlir_prototype_pattern, + "ttgir": mlir_prototype_pattern, + "ptx": ptx_prototype_pattern, +} + +mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+),?' +ptx_arg_type_pattern = r"\.param\s+\.(\w+)" +arg_type_pattern = { + "ttir": mlir_arg_type_pattern, + "ttgir": mlir_arg_type_pattern, + "ptx": ptx_arg_type_pattern, +} + + +def convert_type_repr(x): + # Currently we only capture the pointer type and assume the pointer is on global memory. + # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) + if match is not None: + return '*' + convert_type_repr(match.group(1)) + return x + + +def _get_num_warps_from_ir_str(src: str): + ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' + # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if + # e.g. someone has an instruction (not module) attribute named "num-warps". + num_warps_matches = re.findall(ttgir_num_warps_pattern, src) + assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" + num_warps = int(num_warps_matches[0]) + return num_warps + + +class ASTSource: + + def __init__(self, fn, signature, constants=None, attrs=None) -> None: + self.fn = fn + self.ext = "ttir" + self.name = fn.__name__ + self.signature = signature + self.constants = constants + self.attrs = attrs + if isinstance(self.signature, str): + self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + if self.constants is None: + self.constants = dict() + if self.attrs is None: + self.attrs = AttrsDescriptor() + + def hash(self): + sorted_sig = [v for k, v in sorted(self.signature.items())] + # Note - we stringify the keys here to allow sorting to work for cases + # where constants have mixed int/str keys. + sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) + key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, context): + return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns) + + def parse_options(self): + return dict() + + +class IRSource: + + def __init__(self, path): + self.path = path + path = Path(path) + self.ext = path.suffix[1:] + self.src = path.read_text() + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + + def hash(self): + return hashlib.sha256(self.src.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, context): + module = ir.parse_mlir_module(self.path, context) + module.context = context + return module + + def parse_options(self): + if self.ext == "ttgir": + return {'num_warps': _get_num_warps_from_ir_str(self.src)} + return dict() + + +@functools.lru_cache() +def triton_key(): + import pkgutil + TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + # compiler + path_prefixes = [ + (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."), + (os.path.join(TRITON_PATH, "backends"), "triton.backends."), + ] + for path, prefix in path_prefixes: + for lib in pkgutil.walk_packages([path], prefix=prefix): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + + # backend + libtriton_hash = hashlib.sha256() + with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + libtriton_hash.update(chunk) + contents.append(libtriton_hash.hexdigest()) + # language + language_path = os.path.join(TRITON_PATH, 'language') + for lib in pkgutil.iter_modules([language_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + return f'{__version__}' + '-'.join(contents) + + +def parse(full_name, ext, context): + if ext == "ttir" or ext == "ttgir": + module = ir.parse_mlir_module(full_name, context) + module.context = context + return module + if ext == "llir" or ext == "ptx": + return Path(full_name).read_text() + if ext == "cubin": + return Path(full_name).read_bytes() + + +def filter_traceback(e: BaseException): + """ + Removes code_generator.py and related files from tracebacks. + + These are uninteresting to the user -- "just show me *my* code!" + """ + if e.__cause__ is not None: + filter_traceback(e.__cause__) + if e.__context__ is not None: + filter_traceback(e.__context__) + + # If a user has a file that matches one of these, they're out of luck. + BAD_FILES = [ + "/triton/compiler/code_generator.py", + "/ast.py", + ] + + tb = e.__traceback__ + frames = [] + while tb is not None: + if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)): + frames.append(tb) + tb = tb.tb_next + + for (cur_frame, next_frame) in zip(frames, frames[1:]): + cur_frame.tb_next = next_frame + + if not frames: + e.__traceback__ = None + else: + frames[-1].tb_next = None + e.__traceback__ = frames[0] + + +def compile(src, target=None, options=None): + if target is None: + target = driver.active.get_current_target() + assert isinstance(target, GPUTarget), "target must be of GPUTarget type" + backend = make_backend(target) + ir_source = not isinstance(src, ASTSource) + # create backend + if ir_source: + assert isinstance(src, str), "source must be either AST or a filepath" + src = IRSource(src) + extra_options = src.parse_options() + options = backend.parse_options(dict(options or dict(), **extra_options)) + # create cache manager + env_vars = get_cache_invalidating_env_vars() + key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" + hash = hashlib.sha256(key.encode("utf-8")).hexdigest() + if not ir_source and isinstance(src.fn, JITFunction): + src.fn.hash_cache_file = hash + fn_cache_manager = get_cache_manager(hash) + # For dumping/overriding only hash the source as we want it to be independent of triton + # core changes to make it easier to track kernels by hash. + enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" + enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1" + fn_override_manager = get_override_manager(src.hash()) if enable_override else None + fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None + metadata_filename = f"{src.name}.json" + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = metadata_group.get(metadata_filename) + always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1" + if not always_compile and metadata_path is not None: + # cache hit! + metadata = json.loads(Path(metadata_path).read_text()) + return CompiledKernel(src, metadata_group, hash) + # initialize metadata + metadata = { + "hash": hash, + "target": target, + **options.__dict__, + **env_vars, + } + # run compilation pipeline and populate metadata + stages = dict() + backend.add_stages(stages, options) + first_stage = list(stages.keys()).index(src.ext) + # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. + if ir_source: + first_stage += 1 + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + codegen_fns = backend.get_codegen_implementation() + try: + module = src.make_ir(options, codegen_fns, context) + except Exception as e: + filter_traceback(e) + raise + use_ttgir_loc = os.environ.get("USE_TTGIR_LOC", "0") == "1" + for ext, compile_ir in list(stages.items())[first_stage:]: + next_module = compile_ir(module, metadata) + ir_filename = f"{src.name}.{ext}" + metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) + if fn_dump_manager is not None: + fn_dump_manager.put(next_module, ir_filename) + if (fn_override_manager is not None and fn_override_manager.has_file(ir_filename)): + print(f"\nOverriding kernel with file {ir_filename}") + full_name = fn_override_manager.get_file(ir_filename) + next_module = parse(full_name, ext, context) + # use an env variable to parse ttgir from file + if use_ttgir_loc and ext == "ttgir": + ttgir_full_name = fn_cache_manager.get_file(ir_filename) + next_module.create_location_snapshot(ttgir_full_name) + print(f"Create new locations for {ttgir_full_name}") + module = next_module + if not ir_source: + src.fn.so_path = driver.active.get_cache_path() + # write-back metadata + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) + # return handle to compiled kernel + return CompiledKernel(src, metadata_group, hash) + + +def make_backend(target): + actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)] + if len(actives) != 1: + raise RuntimeError( + f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.") + return actives[0](target) + + +class LazyDict: + + def __init__(self, data): + self.data = data + self.extras = [] + + def get(self) -> None: + for func, args in self.extras: + self.data = self.data | func(*args) + self.extras.clear() + return self.data + + def add(self, func, args): + self.extras.append((func, args)) + + +class CompiledKernel: + + # Hooks for external tools to monitor the execution of triton kernels + # TODO: move out of this namespace since it's a runtime thing + launch_enter_hook = None + launch_exit_hook = None + + def __init__(self, src, metadata_group, hash): + from collections import namedtuple + metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) + metadata = json.loads(metadata_path.read_text()) + metadata['cluster_dims'] = tuple(metadata['cluster_dims']) + # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. + target = metadata['target'] + metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size']) + KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys()))) + self.metadata = KernelMetadata(**metadata) + backend = make_backend(self.metadata.target) + self.packed_metadata = backend.pack_metadata(self.metadata) + self.src = src + self.hash = hash + self.name = self.metadata.name + # stores the text of each level of IR that was generated during compilation + asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] + binary_ext = backend.binary_ext + self.asm = { + file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text() + for file in asm_files + } + self.kernel = self.asm[binary_ext] + # binaries are lazily initialized + # because it involves doing runtime things + # (e.g., checking amount of shared memory on current device) + self.module = None + self.function = None + + def _init_handles(self): + if self.module is not None: + return + device = driver.active.get_current_device() + # create launcher + self.run = driver.active.launcher_cls(self.src, self.metadata) + # not enough shared memory to run the kernel + max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] + if self.metadata.shared > max_shared: + raise OutOfResources(self.metadata.shared, max_shared, "shared memory") + # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` + self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( + self.name, self.kernel, self.metadata.shared, device) + + def __getattribute__(self, name): + if name == 'run': + self._init_handles() + return super().__getattribute__(name) + + def launch_metadata(self, grid, stream, *args): + if CompiledKernel.launch_enter_hook is None: + return None + ret = LazyDict({"name": self.name, "function": self.function, "stream": stream}) + if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None: + return ret + arg_dict = {} + arg_idx = 0 + for i, arg_name in enumerate(self.src.fn.arg_names): + if i in self.src.fn.constexprs: + arg_dict[arg_name] = self.src.constants[arg_name] + else: + arg_dict[arg_name] = args[arg_idx] + arg_idx += 1 + ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict)) + return ret + + def __getitem__(self, grid): + self._init_handles() + + def runner(*args, stream=None): + if stream is None: + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + launch_metadata = self.launch_metadata(grid, stream, *args) + self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata, + CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args) + + return runner diff --git a/third_party/iluvatar/python/triton/compiler/errors.py b/third_party/iluvatar/python/triton/compiler/errors.py new file mode 100644 index 000000000..39e6c4dfb --- /dev/null +++ b/third_party/iluvatar/python/triton/compiler/errors.py @@ -0,0 +1,51 @@ +import ast +from typing import Optional +from ..errors import TritonError + + +class CompilationError(TritonError): + """Base class for all errors raised during compilation""" + source_line_count_max_in_message = 12 + + def _format_message(self) -> str: + node = self.node + if self.src is None: + source_excerpt = " " + else: + if hasattr(node, 'lineno'): + source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:] + if source_excerpt: + source_excerpt.append(' ' * node.col_offset + '^') + source_excerpt = '\n'.join(source_excerpt) + else: + source_excerpt = " " + else: + source_excerpt = self.src + + message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, source_excerpt) if hasattr( + node, 'lineno') else source_excerpt + if self.error_message: + message += '\n' + self.error_message + return message + + def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None): + self.src = src + self.node = node + self.error_message = error_message + self.message = self._format_message() + + def __str__(self): + return self.message + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return type(self), (self.src, self.node, self.error_message) + + +class CompileTimeAssertionFailure(CompilationError): + """Specific exception for failed tests in `static_assert` invocations""" + pass + + +class UnsupportedLanguageConstruct(CompilationError): + pass diff --git a/third_party/iluvatar/python/triton/compiler/make_launcher.py b/third_party/iluvatar/python/triton/compiler/make_launcher.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/iluvatar/python/triton/errors.py b/third_party/iluvatar/python/triton/errors.py new file mode 100644 index 000000000..3a0a86355 --- /dev/null +++ b/third_party/iluvatar/python/triton/errors.py @@ -0,0 +1,5 @@ +"""Base class for all errors raised by Triton""" + + +class TritonError(Exception): + ... diff --git a/third_party/iluvatar/python/triton/language/__init__.py b/third_party/iluvatar/python/triton/language/__init__.py new file mode 100644 index 000000000..168dccfea --- /dev/null +++ b/third_party/iluvatar/python/triton/language/__init__.py @@ -0,0 +1,284 @@ +"""isort:skip_file""" +# Import order is significant here. + +from . import math +from . import extra +from .standard import ( + argmax, + argmin, + cdiv, + cumprod, + cumsum, + flip, + interleave, + max, + min, + ravel, + sigmoid, + softmax, + sort, + sum, + swizzle2d, + xor_sum, + zeros, + zeros_like, +) +from .core import ( + PropagateNan, + TRITON_MAX_TENSOR_NUMEL, + _experimental_descriptor_load, + _experimental_descriptor_store, + advance, + arange, + associative_scan, + atomic_add, + atomic_and, + atomic_cas, + atomic_max, + atomic_min, + atomic_or, + atomic_xchg, + atomic_xor, + bfloat16, + block_type, + broadcast, + broadcast_to, + cat, + cast, + clamp, + const, + const_pointer_type, + constexpr, + debug_barrier, + device_assert, + device_print, + dot, + dtype, + expand_dims, + float16, + float32, + float64, + float8e4b15, + float8e4nv, + float8e4b8, + float8e5, + float8e5b16, + full, + function_type, + histogram, + inline_asm_elementwise, + int1, + int16, + int32, + int64, + int8, + join, + load, + make_block_ptr, + max_constancy, + max_contiguous, + maximum, + minimum, + multiple_of, + num_programs, + permute, + pi32_t, + pointer_type, + program_id, + range, + reduce, + reshape, + split, + static_assert, + static_print, + static_range, + store, + tensor, + trans, + uint16, + uint32, + uint64, + uint8, + view, + void, + where, +) +from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, + ceil) +from .random import ( + pair_uniform_to_normal, + philox, + philox_impl, + rand, + rand4x, + randint, + randint4x, + randn, + randn4x, + uint_to_uniform_float, +) + +__all__ = [ + "PropagateNan", + "TRITON_MAX_TENSOR_NUMEL", + "_experimental_descriptor_load", + "_experimental_descriptor_store", + "abs", + "advance", + "arange", + "argmax", + "argmin", + "associative_scan", + "atomic_add", + "atomic_and", + "atomic_cas", + "atomic_max", + "atomic_min", + "atomic_or", + "atomic_xchg", + "atomic_xor", + "bfloat16", + "block_type", + "broadcast", + "broadcast_to", + "builtin", + "cat", + "cast", + "cdiv", + "ceil", + "clamp", + "const", + "const_pointer_type", + "constexpr", + "cos", + "cumprod", + "cumsum", + "debug_barrier", + "device_assert", + "device_print", + "div_rn", + "dot", + "dtype", + "erf", + "exp", + "exp2", + "expand_dims", + "extra", + "fdiv", + "flip", + "float16", + "float32", + "float64", + "float8e4b15", + "float8e4nv", + "float8e4b8", + "float8e5", + "float8e5b16", + "floor", + "fma", + "full", + "function_type", + "histogram", + "inline_asm_elementwise", + "interleave", + "int1", + "int16", + "int32", + "int64", + "int8", + "ir", + "join", + "load", + "log", + "log2", + "make_block_ptr", + "math", + "max", + "max_constancy", + "max_contiguous", + "maximum", + "min", + "minimum", + "multiple_of", + "num_programs", + "pair_uniform_to_normal", + "permute", + "philox", + "philox_impl", + "pi32_t", + "pointer_type", + "program_id", + "rand", + "rand4x", + "randint", + "randint4x", + "randn", + "randn4x", + "range", + "ravel", + "reduce", + "reshape", + "rsqrt", + "sigmoid", + "sin", + "softmax", + "sort", + "split", + "sqrt", + "sqrt_rn", + "static_assert", + "static_print", + "static_range", + "store", + "sum", + "swizzle2d", + "tensor", + "trans", + "triton", + "uint16", + "uint32", + "uint64", + "uint8", + "uint_to_uniform_float", + "umulhi", + "view", + "void", + "where", + "xor_sum", + "zeros", + "zeros_like", +] + + +def str_to_ty(name): + if name[0] == "*": + name = name[1:] + if name[0] == "k": + name = name[1:] + ty = str_to_ty(name) + return const_pointer_type(ty) + ty = str_to_ty(name) + return pointer_type(ty) + tys = { + "fp8e4nv": float8e4nv, + "fp8e4b8": float8e4b8, + "fp8e5": float8e5, + "fp8e5b16": float8e5b16, + "fp8e4b15": float8e4b15, + "fp16": float16, + "bf16": bfloat16, + "fp32": float32, + "fp64": float64, + "i1": int1, + "i8": int8, + "i16": int16, + "i32": int32, + "i64": int64, + "u1": int1, + "u8": uint8, + "u16": uint16, + "u32": uint32, + "u64": uint64, + "B": int1, + } + return tys[name] diff --git a/third_party/iluvatar/python/triton/language/core.py b/third_party/iluvatar/python/triton/language/core.py new file mode 100644 index 000000000..f2d3266e9 --- /dev/null +++ b/third_party/iluvatar/python/triton/language/core.py @@ -0,0 +1,2621 @@ +from __future__ import annotations + +from warnings import warn +from contextlib import contextmanager +from enum import Enum +from functools import partial, wraps +import typing +from typing import Union, Callable, List, Sequence, TypeVar, Optional +import builtins +from ..runtime.jit import jit +import inspect +import os + +from .._C.libtriton import ir +from . import semantic + +T = TypeVar('T') + +TRITON_MAX_TENSOR_NUMEL = 1048576 + +TRITON_BUILTIN = "__triton_builtin__" + +PropagateNan = ir.PROPAGATE_NAN + + +def builtin(fn: T) -> T: + """Mark a function as a builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + raise ValueError("Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + setattr(wrapper, TRITON_BUILTIN, True) + + return wrapper + + +def _tensor_member_fn(fn: T) -> T: + """Decorator that adds this free function as a member fn on class tensor. + + When called as a member function on class tensor, the first argument to `fn` + is `self`, i.e. the tensor object. + + If there are multiple decorators on a function, you probably want this one + to be the highest one (i.e. furthest from the function's `def`), so it's + applied last. + + Unfortunately you still need to add a type stub to the body of class tensor + in order for pytype to know about it. + """ + assert callable(fn) + orig_sig = inspect.signature(fn) + # Does fn take args other than _builder, _generator, and the tensor itself? + has_args = len(orig_sig.parameters.keys() - {"_builder", "_generator"}) > 1 + + if not fn.__doc__: + fn.__doc__ = "" + fn.__doc__ += f""" + This function can also be called as a member function on :py:class:`tensor`, + as :code:`x.{fn.__name__}({"..." if has_args else ""})` instead of + :code:`{fn.__name__}(x{", ..." if has_args else ""})`. + """ + + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + # Match the signature of `fn`, but change the first arg to `self` so the + # docs are a little less weird. + new_params = list(orig_sig.parameters.values()) + new_params[0] = new_params[0].replace(name='self') + new_sig = orig_sig.replace(parameters=new_params) + wrapper.__signature__ = new_sig + wrapper.__doc__ = f"Forwards to :py:func:`{fn.__name__}` free function" + # If fn is a builtin, mark the wrapper as a builtin too. + if is_builtin(fn): + setattr(wrapper, TRITON_BUILTIN, True) + + setattr(tensor, fn.__name__, wrapper) + return fn + + +def _unwrap_iterable(x): + """Returns x[0] if x has one element and x[0] is iterable.""" + if len(x) == 1: + # Determine whether x[0] is iterable. + # + # You might want to use collections.abc.Iterable instead of this + # try/except block. Unfortunately, this doesn't work with constexpr. + # + # The problem is that abc.Iterable checks for __iter__ on the *class*. + # But we want constexpr to expose an __iter__ method if and only if the + # wrapped *object* (i.e. self.value) is iterable. Therefore there's no + # right answer for whether the class constexpr defines __iter__, and + # abc.Iterable doesn't work (at least not without some metaclass magic). + try: + iter(x[0]) + return x[0] + except TypeError: + pass + + return x + + +def is_builtin(fn) -> bool: + """Is this a registered triton builtin function?""" + return getattr(fn, TRITON_BUILTIN, False) + + +@builtin +def to_tensor(x, _builder=None): + return _to_tensor(x, _builder) + + +def _to_tensor(x, builder): + if isinstance(x, bool): + return tensor(builder.get_int1(x), int1) + # Note: compile-time const integers are represented by unsigned values + elif isinstance(x, int): + if -2**31 <= x < 2**31: + return tensor(builder.get_int32(x), int32) + elif 2**31 <= x < 2**32: + return tensor(builder.get_uint32(x), uint32) + elif -2**63 <= x < 2**63: + return tensor(builder.get_int64(x), int64) + elif 2**63 <= x < 2**64: + return tensor(builder.get_uint64(x), uint64) + else: + raise RuntimeError(f'Nonrepresentable integer {x}.') + elif isinstance(x, float): + min_float32 = 2**-126 + max_float32 = (2 - 2**-23) * 2**127 + abs_x = __builtins__['abs'](x) + if abs_x == float("inf") or\ + abs_x == 0.0 or \ + x != x or \ + min_float32 <= abs_x <= max_float32: + return tensor(builder.get_fp32(x), float32) + else: + return tensor(builder.get_fp64(x), float64) + + elif isinstance(x, constexpr): + return _to_tensor(x.value, builder) + elif isinstance(x, tensor): + return x + assert False, f"cannot convert {x} of type {type(x)} to tensor" + + +class dtype: + SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64'] + STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + def __init__(self, name): + if hasattr(name, 'value'): + name = name.value + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8e4b15': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e4nv': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 7 + elif name == 'fp8e4b8': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 8 + elif name == 'fp8e5': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e5b16': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 16 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.primitive_bitwidth = 16 + self.exponent_bias = 15 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.primitive_bitwidth = 16 + self.exponent_bias = 127 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.primitive_bitwidth = 32 + self.exponent_bias = 127 + elif name == 'fp64': + self.fp_mantissa_width = 53 + self.primitive_bitwidth = 64 + self.exponent_bias = 1023 + else: + raise RuntimeError(f'Unsupported floating-point type {name}') + elif name == 'void': + self.primitive_bitwidth = 0 + + def is_fp8(self): + return 'fp8' in self.name + + def is_fp8e4nv(self): + return self.name == 'fp8e4nv' + + def is_fp8e4b8(self): + return self.name == 'fp8e4b8' + + def is_fp8e4b15(self): + return self.name == 'fp8e4b15' + + def is_fp8e5(self): + return self.name == 'fp8e5' + + def is_fp8e5b16(self): + return self.name == 'fp8e5b16' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_standard_floating(self): + return self.name in dtype.STANDARD_FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int_unsigned(self): + return self.name in dtype.UINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + @staticmethod + def is_dtype(type_str): + return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES + + @staticmethod + def is_void(): + raise RuntimeError("Not implemented") + + @staticmethod + def is_block(): + return False + + @staticmethod + def is_ptr(): + return False + + @staticmethod + def is_const(): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __ne__(self, other: dtype): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.name, )) + + @property + def scalar(self): + return self + + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + def __str__(self): + return self.name + + def codegen_name(self): + if self.name.startswith("fp"): + return "float" + self.name[2:] + elif self.name.startswith("bf"): + return "bfloat" + self.name[2:] + else: + return self.name + + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + """Output of repr needs to be an evaluatable expression""" + return f'triton.language.{self.codegen_name()}' + + +# Some functions have a param named `dtype`, which shadows the `dtype` class. +# We can't change the param name because it is part of function's public API. +# Declare an alias so those functions can still reference the dtype class. +_DtypeClass = dtype + + +class pointer_type(dtype): + + def __init__(self, element_ty: dtype, address_space: int = 1): + if not isinstance(element_ty, dtype): + raise TypeError(f'element_ty is a {type(element_ty).__name__}.') + self.element_ty = element_ty + self.address_space = address_space + + self.name = f'pointer<{element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.pointer_type: + return builder.get_ptr_ty(self.element_ty.to_ir(builder), 1) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_ptr(self): + return True + + def __eq__(self, other: pointer_type) -> bool: + if not isinstance(other, pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space + + def __ne__(self, other: pointer_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self + + +class const_pointer_type(pointer_type): + + def __init__(self, element_ty: dtype, address_space: int = 1): + super().__init__(element_ty, address_space) + + def __str__(self): + return f'const_pointer<{self.element_ty}>' + + def is_const(self): + return True + + def __eq__(self, other) -> bool: + if not isinstance(other, const_pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space + + +class block_type(dtype): + + def __init__(self, element_ty: dtype, shape: List): + self.element_ty = element_ty + + # Note that block_type's shape is a list of int + # while tensor's shape is a list of constexpr. + + # shape can be empty ([]) when an input is a 0D tensor. + if not shape: + raise TypeError('0d block_type is forbidden') + if isinstance(shape[0], constexpr): + shape = [s.value for s in shape] + + self.shape = shape + self.numel = 1 + for s in self.shape: + self.numel *= s + if self.numel > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") + + self.name = f'<{self.shape}, {self.element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> List[int]: + return self.shape + + def __eq__(self, other: block_type) -> bool: + if not isinstance(other, block_type): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + def __ne__(self, other: block_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self.element_ty + + +class function_type(dtype): + + def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: + self.ret_types = ret_types + self.param_types = param_types + + def __str__(self): + return f'fn ({self.param_types}) -> {self.ret_types}' + + def to_ir(self, builder: ir.builder): + ir_param_types = [ty.to_ir(builder) for ty in self.param_types] + ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] + return builder.get_function_ty(ir_param_types, ret_types) + + +# scalar types +void = dtype('void') +int1 = dtype('int1') +int8 = dtype('int8') +int16 = dtype('int16') +int32 = dtype('int32') +int64 = dtype('int64') +uint8 = dtype('uint8') +uint16 = dtype('uint16') +uint32 = dtype('uint32') +uint64 = dtype('uint64') +float8e5 = dtype('fp8e5') +float8e5b16 = dtype('fp8e5b16') +float8e4nv = dtype('fp8e4nv') +float8e4b8 = dtype('fp8e4b8') +float8e4b15 = dtype('fp8e4b15') +float16 = dtype('fp16') +bfloat16 = dtype('bf16') +float32 = dtype('fp32') +float64 = dtype('fp64') +# pointer types +pi32_t = pointer_type(int32) + + +def get_int_dtype(bitwidth: int, signed: bool) -> dtype: + if bitwidth == 1: + return int1 + elif bitwidth == 8 and signed: + return int8 + elif bitwidth == 8 and not signed: + return uint8 + elif bitwidth == 16 and signed: + return int16 + elif bitwidth == 16 and not signed: + return uint16 + elif bitwidth == 32 and signed: + return int32 + elif bitwidth == 32 and not signed: + return uint32 + elif bitwidth == 64 and signed: + return int64 + elif bitwidth == 64 and not signed: + return uint64 + else: + raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}') + + +# ----------------------- +# constexpr +# ----------------------- + + +class const: + """ + This class is used as a type annotation to mark pointers to constant data. + The `store` function cannot be called with a pointer to const. Constness + is part of the pointer type and the usual Triton type consistency rules + apply. For example you cannot have a function that returns constant pointer + in one return statement and non-constant pointer in another. + """ + pass + + +class constexpr: + """ + This class is used to store a value that is known at compile-time. + """ + + def __init__(self, value): + if isinstance(value, constexpr): + self.value = value.value + else: + self.value = value + + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + def __index__(self): + return self.value + + # In interpreter mode, constant values are not wrapped in constexpr, + # and therefore do not have a .value attribute. + # As a result, from here and below, we need to call the _constexpr_to_value + # function to obtain either constexpr.value or the value itself. + def __add__(self, other): + return constexpr(self.value + _constexpr_to_value(other)) + + def __radd__(self, other): + return constexpr(_constexpr_to_value(other) + self.value) + + def __sub__(self, other): + return constexpr(self.value - _constexpr_to_value(other)) + + def __rsub__(self, other): + return constexpr(_constexpr_to_value(other) - self.value) + + def __mul__(self, other): + return constexpr(self.value * _constexpr_to_value(other)) + + def __mod__(self, other): + return constexpr(self.value % _constexpr_to_value(other)) + + def __rmul__(self, other): + return constexpr(_constexpr_to_value(other) * self.value) + + def __truediv__(self, other): + return constexpr(self.value / _constexpr_to_value(other)) + + def __rtruediv__(self, other): + return constexpr(_constexpr_to_value(other) / self.value) + + def __floordiv__(self, other): + return constexpr(self.value // _constexpr_to_value(other)) + + def __rfloordiv__(self, other): + return constexpr(_constexpr_to_value(other) // self.value) + + def __gt__(self, other): + return constexpr(self.value > _constexpr_to_value(other)) + + def __rgt__(self, other): + return constexpr(_constexpr_to_value(other) > self.value) + + def __ge__(self, other): + return constexpr(self.value >= _constexpr_to_value(other)) + + def __rge__(self, other): + return constexpr(_constexpr_to_value(other) >= self.value) + + def __lt__(self, other): + return constexpr(self.value < _constexpr_to_value(other)) + + def __rlt__(self, other): + return constexpr(_constexpr_to_value(other) < self.value) + + def __le__(self, other): + return constexpr(self.value <= _constexpr_to_value(other)) + + def __rle__(self, other): + return constexpr(_constexpr_to_value(other) <= self.value) + + def __eq__(self, other): + return constexpr(self.value == _constexpr_to_value(other)) + + def __ne__(self, other): + return constexpr(self.value != _constexpr_to_value(other)) + + def __bool__(self): + return bool(self.value) + + def __neg__(self): + return constexpr(-self.value) + + def __and__(self, other): + return constexpr(self.value & _constexpr_to_value(other)) + + def logical_and(self, other): + return constexpr(self.value and _constexpr_to_value(other)) + + def __or__(self, other): + return constexpr(self.value | _constexpr_to_value(other)) + + def __xor__(self, other): + return constexpr(self.value ^ _constexpr_to_value(other)) + + def logical_or(self, other): + return constexpr(self.value or _constexpr_to_value(other)) + + def __pos__(self): + return constexpr(+self.value) + + def __invert__(self): + return constexpr(~self.value) + + def __pow__(self, other): + return constexpr(self.value**_constexpr_to_value(other)) + + def __rpow__(self, other): + return constexpr(_constexpr_to_value(other)**self.value) + + def __rshift__(self, other): + return constexpr(self.value >> _constexpr_to_value(other)) + + def __lshift__(self, other): + return constexpr(self.value << _constexpr_to_value(other)) + + def __not__(self): + return constexpr(not self.value) + + def __iter__(self): + return iter(self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + + +CONSTEXPR_0 = constexpr(0) + + +def check_bit_width(value, shift_value): + if isinstance(value, tensor) and isinstance(shift_value, constexpr): + bitwidth = value.type.scalar.primitive_bitwidth + if shift_value.value >= bitwidth: + warn( + f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type '{value.dtype}'. This may result in undefined behavior." + ) + + +class tensor: + """Represents an N-dimensional array of values or pointers. + + :code:`tensor` is the fundamental data structure in Triton programs. Most + functions in :py:mod:`triton.language` operate on and return tensors. + + Most of the named member functions here are duplicates of the free functions + in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is + equivalent to :code:`x.sqrt()`. + + :code:`tensor` also defines most of the magic/dunder methods, so you can + write :code:`x+y`, :code:`x << 2`, etc. + + .. rubric:: Constructors + .. + For some reason Sphinx includes __init__ before printing the full table + of methods. Not what I want, but I can't figure out how to fix it. Give + it its own section so it looks intentional. :) + """ + + def __init__(self, handle, type: dtype): + """Not called by user code.""" + # IR handle + self.handle = handle + # Block shape + self.shape = type.shape if type.is_block() else () + self.numel = 1 + for s in self.shape: + self.numel *= s + self.numel = constexpr(self.numel) + self.type = type # Tensor type (can be block_type) + # Following the practice in pytorch, dtype is scalar type + self.dtype = type.scalar + self.shape = [constexpr(s) for s in self.shape] + + def __str__(self) -> str: + # ex. "float32[16, 32]" + return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']' + + @builtin + def __add__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.add(self, other, _builder) + + @builtin + def __radd__(self, other, _builder=None): + return self.__add__(other, _builder=_builder) + + @builtin + def __sub__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.sub(self, other, _builder) + + @builtin + def __rsub__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.sub(other, self, _builder) + + @builtin + def __mul__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mul(self, other, _builder) + + @builtin + def __rmul__(self, other, _builder=None): + return self.__mul__(other, _builder=_builder) + + @builtin + def __truediv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.truediv(self, other, _builder) + + @builtin + def __rtruediv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.truediv(other, self, _builder) + + @builtin + def __floordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(self, other, _builder) + + @builtin + def __rfloordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(other, self, _builder) + + @builtin + def __mod__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mod(self, other, _builder) + + @builtin + def __rmod__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mod(other, self, _builder) + + # unary operators + @builtin + def __neg__(self, _builder=None): + return semantic.minus(self, _builder) + + @builtin + def __invert__(self, _builder=None): + return semantic.invert(self, _builder) + + # bitwise operators + + @builtin + def __and__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.and_(self, other, _builder) + + @builtin + def __rand__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.and_(other, self, _builder) + + @builtin + def __or__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.or_(self, other, _builder) + + @builtin + def __ror__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.or_(other, self, _builder) + + @builtin + def __xor__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.xor_(self, other, _builder) + + @builtin + def __rxor__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.xor_(other, self, _builder) + + @builtin + def __lshift__(self, other, _builder=None): + check_bit_width(self, other) + other = _to_tensor(other, _builder) + return semantic.shl(self, other, _builder) + + @builtin + def __rlshift__(self, other, _builder=None): + check_bit_width(other, self) + other = _to_tensor(other, _builder) + return semantic.shl(other, self, _builder) + + @builtin + def __rshift__(self, other, _builder=None): + check_bit_width(self, other) + other = _to_tensor(other, _builder) + if self.dtype.is_int_signed(): + return semantic.ashr(self, other, _builder) + else: + return semantic.lshr(self, other, _builder) + + @builtin + def __rrshift__(self, other, _builder=None): + check_bit_width(other, self) + other = _to_tensor(other, _builder) + if self.dtype.is_int_signed(): + return semantic.ashr(other, self, _builder) + else: + return semantic.lshr(other, self, _builder) + + # > + @builtin + def __gt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_than(self, other, _builder) + + @builtin + def __rgt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_than(other, self, _builder) + + # >= + @builtin + def __ge__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_equal(self, other, _builder) + + @builtin + def __rge__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_equal(other, self, _builder) + + # < + @builtin + def __lt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_than(self, other, _builder) + + @builtin + def __rlt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_than(other, self, _builder) + + # <= + @builtin + def __le__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_equal(self, other, _builder) + + @builtin + def __rle__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_equal(other, self, _builder) + + # == + @builtin + def __eq__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.equal(self, other, _builder) + + @builtin + def __req__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.equal(other, self, _builder) + + @builtin + def __ne__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.not_equal(self, other, _builder) + + @builtin + def __rne__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.not_equal(other, self, _builder) + + @builtin + def logical_and(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.logical_and(self, other, _builder) + + @builtin + def logical_or(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.logical_or(self, other, _builder) + + # note: __not__ isn't actually a magic method in python + # but it's ok because our ASTVisitor handles it + @builtin + def __not__(self, _builder=None): + return semantic.not_(self, _builder) + + @builtin + def __getitem__(self, slices, _builder=None): + if isinstance(slices, (slice, constexpr)) or slices is None: + slices = [slices] + ret = self + for dim, sl in enumerate(slices): + if sl is None or isinstance(sl, constexpr) and sl.value is None: + ret = semantic.expand_dims(ret, dim, _builder) + elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None: + pass + else: + raise ValueError(f"unsupported tensor index: {sl}") + return ret + + @property + def T(self): + """Transposes a 2D tensor.""" + assert False, "Transposition must be created by the AST Visitor" + + @builtin + def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None): + """ + Alias for :py:func:`tensor.cast`. + """ + # Triton doesn't like core functions calling other core functions, so we + # just copy-paste the implementation of cast here. It's not too bad. + if isinstance(bitcast, constexpr): + bitcast = bitcast.value + if bitcast: + return semantic.bitcast(self, dtype, _builder) + return semantic.cast(self, dtype, _builder, fp_downcast_rounding) + + # Type stubs for functions added by the _tensor_member_fn decorator. + # (Unfortunately these can't be created automatically.) + # + # We couldn't write these definitions out even if we wanted to, because some + # of these functions are defined in standard.py. + def broadcast_to(self, *shape) -> tensor: + ... + + def trans(self, *dims) -> tensor: + ... + + def permute(self, *dims) -> tensor: + ... + + def split(self) -> tuple[tensor, tensor]: + ... + + def view(self, *shape) -> tensor: + ... + + def reshape(self, *shape) -> tensor: + ... + + def expand_dims(self, axis) -> tensor: + ... + + def cast(self, dtype, fp_downcast_rounding=None, bitcast=False) -> tensor: + ... + + def store(self, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="") -> tensor: + ... + + def advance(self, offsets) -> tensor: + ... + + def atomic_cas(self, cmp, val, sem=None, scope=None) -> tensor: + ... + + def atomic_xchg(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_add(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_max(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_min(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_and(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_or(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_xor(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def exp(self) -> tensor: + ... + + def log(self) -> tensor: + ... + + def cos(self) -> tensor: + ... + + def sin(self) -> tensor: + ... + + def sqrt(self) -> tensor: + ... + + def rsqrt(self) -> tensor: + ... + + def abs(self) -> tensor: + ... + + def reduce(self, axis, combine_fn, keep_dims=False) -> tensor: + ... + + def associative_scan(self, axis, combine_fn, reverse=False) -> tensor: + ... + + def histogram(self, num_bins) -> tensor: + ... + + def cdiv(self, div) -> tensor: + ... + + def sigmoid(self) -> tensor: + ... + + def softmax(self, ieee_rounding=False) -> tensor: + ... + + def ravel(self) -> tensor: + ... + + def max(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmax(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def min(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def xor_sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def cumsum(self, axis=0, reverse=False) -> tensor: + ... + + def cumprod(self, axis=0, reverse=False) -> tensor: + ... + + def sort(self, dim: constexpr = None, descending: constexpr = CONSTEXPR_0) -> tensor: + ... + + def flip(self, dim=None) -> tensor: + ... + + +def get_bool_env_var(var_name): + v = os.getenv(var_name, "0") + return v == "1" or v == "true" or v == "on" + + +# ----------------------- +# SPMD Programming Model +# ----------------------- +def _constexpr_to_value(v): + if isinstance(v, constexpr): + return v.value + return v + + +@builtin +def program_id(axis, _builder=None): + """ + Returns the id of the current program instance along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + # if axis == -1: + # pid0 = program_id(0, _builder) + # pid1 = program_id(1, _builder) + # pid2 = program_id(2, _builder) + # npg0 = num_programs(0, _builder) + # npg1 = num_programs(0, _builder) + # return pid0 + pid1*npg0 + pid2*npg0*npg1 + axis = _constexpr_to_value(axis) + return semantic.program_id(axis, _builder) + + +@builtin +def num_programs(axis, _builder=None): + """ + Returns the number of program instances launched along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + axis = _constexpr_to_value(axis) + return semantic.num_programs(axis, _builder) + + +# ----------------------- +# Block Initialization +# ----------------------- + + +@builtin +def arange(start, end, _builder=None): + """ + Returns contiguous values within the half-open interval :code:`[start, + end)`. :code:`end - start` must be less than or equal to + :code:`TRITON_MAX_TENSOR_NUMEL = 131072` + + :param start: Start of the interval. Must be a power of two. + :type start: int32 + :param end: End of the interval. Must be a power of two greater than + :code:`start`. + :type end: int32 + """ + start = _constexpr_to_value(start) + end = _constexpr_to_value(end) + return semantic.arange(start, end, _builder) + + +def _shape_check_impl(shape): + shape = _constexpr_to_value(shape) + for i, d in enumerate(shape): + if isinstance(d, int): + d = constexpr(d) + if not isinstance(d, constexpr): + raise TypeError(f"Shape element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + if d.value & (d.value - 1) != 0: + raise ValueError(f"Shape element {i} must be a power of 2") + return [_constexpr_to_value(x) for x in shape] + + +@builtin +def full(shape, value, dtype, _builder=None): + """ + Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :value value: A scalar value to fill the array with + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + shape = _shape_check_impl(shape) + value = _constexpr_to_value(value) + dtype = _constexpr_to_value(dtype) + return semantic.full(shape, value, dtype, _builder) + + +# ----------------------- +# Shape Manipulation +# ----------------------- + + +@builtin +def broadcast(input, other, _builder=None): + """ + Tries to broadcast the two given blocks to a common compatible shape. + + :param input: The first input tensor. + :type input: Block + :param other: The second input tensor. + :type other: Block + """ + return semantic.broadcast_impl_value(input, other, _builder) + + +@_tensor_member_fn +@builtin +def broadcast_to(input, *shape, _builder=None): + """ + Tries to broadcast the given tensor to a new :code:`shape`. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + :type shape: + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + broadcast_to(x, (32, 32)) + broadcast_to(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.broadcast_impl_shape(input, shape, _builder) + + +@_tensor_member_fn +@builtin +def trans(input: tensor, *dims, _builder=None): + """ + Permutes the dimensions of a tensor. + + If no permutation is specified, tries to do a (1,0) permutation, i.e. tries + to transpose a 2D tensor. + + :param input: The input tensor. + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + trans(x, (2, 1, 0)) + trans(x, 2, 1, 0) + + :py:func:`permute` is equivalent to this function, except it doesn't + have the special case when no permutation is specified. + """ + if not dims: + dims = (1, 0) + return semantic.permute(input, dims, _builder) + + +@_tensor_member_fn +@builtin +def permute(input, *dims, _builder=None): + """ + Permutes the dimensions of a tensor. + + :param input: The input tensor. + :type input: Block + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + permute(x, (2, 1, 0)) + permute(x, 2, 1, 0) + + :py:func:`trans` is equivalent to this function, except when + :code:`dims` is empty, it tries to do a (1,0) permutation. + """ + dims = _unwrap_iterable(dims) + return semantic.permute(input, dims, _builder) + + +@builtin +def cat(input, other, can_reorder=False, _builder=None): + """ + Concatenate the given blocks + + :param input: The first input tensor. + :type input: + :param other: The second input tensor. + :type other: + :param reorder: Compiler hint. If true, the compiler is + allowed to reorder elements while concatenating inputs. Only use if the + order does not matter (e.g., result is only used in reduction ops) + """ + return semantic.cat(input, other, can_reorder, _builder) + + +@builtin +def join(a, b, _builder=None): + """ + Join the given tensors in a new, minor dimension. + + For example, given two tensors of shape (4,8), produces a new tensor of + shape (4,8,2). Given two scalars, returns a tensor of shape (2). + + The two inputs are broadcasted to be the same shape. + + If you want to join more than two elements, you can use multiple calls to + this function. This reflects the constraint in Triton that tensors must + have power-of-two sizes. + + join is the inverse of split. + + :param a: The first input tensor. + :type a: Tensor + :param b: The second input tensor. + :type b: Tensor + """ + return semantic.join(a, b, _builder) + + +@jit +def _take_first(a, b): + return a + + +@_tensor_member_fn +@builtin +def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]: + """ + Split a tensor in two along its last dim, which must have size 2. + + For example, given a tensor of shape (4,8,2), produces two tensors of shape + (4,8). Given a tensor of shape (2), returns two scalars. + + If you want to split into more than two pieces, you can use multiple calls + to this function (probably plus calling reshape). This reflects the + constraint in Triton that tensors must have power-of-two sizes. + + split is the inverse of join. + + :param a: The tensor to split. + :type a: Tensor + """ + # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars. + # But semantic.split can only handle returning tensors. Work around this by + # expanding the input to shape [1,2] and then reducing the result. + was_rank_1 = len(a.shape) == 1 + if was_rank_1: + a = semantic.expand_dims(a, 0, _builder) + + out_lhs, out_rhs = semantic.split(a, _builder) + + if was_rank_1: + # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar. + out_lhs = typing.cast(tensor, reduce(out_lhs, None, _take_first, _builder=_builder, _generator=_generator)) + out_rhs = typing.cast(tensor, reduce(out_rhs, None, _take_first, _builder=_builder, _generator=_generator)) + + return out_lhs, out_rhs + + +@_tensor_member_fn +@builtin +def view(input, *shape, _builder=None): + """ + Returns a tensor with the same elements as `input` but a different shape. + The order of the elements may not be preserved. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + view(x, (32, 32)) + view(x, 32, 32) + """ + warn("view is deprecated, please use reshape with can_reorder being true.") + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.reshape(input, shape, can_reorder=True, builder=_builder) + + +@_tensor_member_fn +@builtin +def reshape(input, *shape, can_reorder=False, _builder=None): + """ + Returns a tensor with the same number of elements as input but with the + provided shape. + + :param input: The input tensor. + :type input: Block + :param shape: The new shape. + + :code:`shape ` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + reshape(x, (32, 32)) + reshape(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.reshape(input, shape, can_reorder, _builder) + + +def _wrap_axis(axis, ndim): + if not (-ndim <= axis < ndim): + raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}") + + return axis if axis >= 0 else axis + ndim + + +@_tensor_member_fn +@builtin +def expand_dims(input, axis, _builder=None): + """ + Expand the shape of a tensor, by inserting new length-1 dimensions. + + Axis indices are with respect to the resulting tensor, so + ``result.shape[axis]`` will be 1 for each axis. + + :param input: The input tensor. + :type input: tl.tensor + :param axis: The indices to add new axes + :type axis: int | Sequence[int] + + """ + input = _to_tensor(input, _builder) + axis = _constexpr_to_value(axis) + axes = list(axis) if isinstance(axis, Sequence) else [axis] + new_ndim = len(input.shape) + len(axes) + axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] + + if len(set(axes)) != len(axes): + raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}") + + ret = input + for a in sorted(axes): + ret = semantic.expand_dims(ret, a, _builder) + return ret + + +@_tensor_member_fn +@builtin +def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None): + """ + Casts a tensor to the given :code:`dtype`. + + :param dtype: The target data type. + :param fp_downcast_rounding: The rounding mode for downcasting + floating-point values. This parameter is only used when self is a + floating-point tensor and dtype is a floating-point type with a + smaller bitwidth. Supported values are :code:`"rtne"` (round to + nearest, ties to even) and :code:`"rtz"` (round towards zero). + :param bitcast: If true, the tensor is bitcasted to the given + :code:`dtype`, instead of being numerically casted. + """ + input = _to_tensor(input, _builder) + if isinstance(bitcast, constexpr): + bitcast = bitcast.value + if bitcast: + return semantic.bitcast(input, dtype, _builder) + return semantic.cast(input, dtype, _builder, fp_downcast_rounding) + + +# ----------------------- +# Linear Algebra +# ----------------------- + + +@builtin +def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32, + _builder=None): + """ + Returns the matrix product of two blocks. + + The two blocks must be two-dimensional and have compatible inner dimensions. + + :param input: The first tensor to be multiplied. + :type input: 2D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second tensor to be multiplied. + :type other: 2D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param input_precision: How to exercise the Tensor Cores for f32 x f32. If + the device does not have Tensor Cores or the inputs are not of dtype f32, + this option is ignored. For devices that do have tensor cores, the + default precision is tf32. + :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Avaliable options for amd: :code:`"ieee"`. + :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32". + Only one of :code:`input_precision` and :code:`allow_tf32` can be + specified (i.e. at least one must be :code:`None`). + """ + assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" + if input_precision is None: + supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions + default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee" + input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision) + + input_precision = _constexpr_to_value(input_precision) + out_dtype = _constexpr_to_value(out_dtype) + max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) + return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder) + + +# ----------------------- +# Non-Atomic Memory Operations +# ----------------------- + + +@builtin +def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="", + volatile=False, _builder=None): + """ + Return a tensor of data whose values are loaded from memory at location defined by `pointer`: + + (1) If `pointer` is a single element pointer, a scalar is be loaded. In + this case: + + - `mask` and `other` must also be scalars, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional tensor is loaded. In this case: + + - `mask` and `other` are implicitly broadcast to `pointer.shape`, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a + tensor is loaded. In this case: + + - `mask` and `other` must be None, and + - `boundary_check` and `padding_option` can be specified to control + the behavior of out-of-bound access. + + :param pointer: Pointer to the data to be loaded + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]` + (must be `None` with block pointers) + :type mask: Block of `triton.int1`, optional + :param other: if `mask[idx]` is false, return `other[idx]` + :type other: Block, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param padding_option: should be one of {"", "zero", "nan"}, do padding while out of bound + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + :param volatile: changes volatile option in NVIDIA PTX + :type volatile: bool, optional + """ + # `mask` and `other` can be constexpr + mask = _constexpr_to_value(mask) + other = _constexpr_to_value(other) + if mask is not None: + mask = _to_tensor(mask, _builder) + if other is not None: + other = _to_tensor(other, _builder) + padding_option = _constexpr_to_value(padding_option) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + volatile = _constexpr_to_value(volatile) + return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy, + volatile, _builder) + + +@builtin +def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None): + """ + Experimental feature to access TMA descriptors loads. This is an escape hatch to easily exercise TTGIR operations. + This will be removed in the future and shouldn't be used in production code. + + This loads a tensor of data based on the descriptor and offsets. + """ + type = block_type(dtype, shape) + return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder) + + +@builtin +def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None): + """ + Experimental feature to access TMA descriptors stores. This is an escape hatch to easily exercise TTGIR operations. + This will be removed in the future and shouldn't be used in production code. + + This stores a tensor of data based on the descriptor and offsets. + """ + return semantic.descriptor_store(desc_pointer, value, offsets, _builder) + + +@_tensor_member_fn +@builtin +def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None): + """ + Store a tensor of data into memory locations defined by `pointer`. + + (1) If `pointer` is a single element pointer, a scalar is stored. In + this case: + + - `mask` must also be scalar, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional block is stored. In this case: + + - `mask` is implicitly broadcast to `pointer.shape`, and + - `boundary_check` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a block + of data is stored. In this case: + + - `mask` must be None, and + - `boundary_check` can be specified to control the behavior of out-of-bound access. + + `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`. + + :param pointer: The memory location where the elements of `value` are stored + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param value: The tensor of elements to be stored + :type value: Block + :param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]` + :type mask: Block of triton.int1, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + """ + # `value` can be constexpr + value = _to_tensor(value, _builder) + mask = _constexpr_to_value(mask) + if mask is not None: + mask = _to_tensor(mask, _builder) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder) + + +@builtin +def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None): + """ + Returns a pointer to a block in a parent tensor + + :param base: The base pointer to the parent tensor + :param shape: The shape of the parent tensor + :param strides: The strides of the parent tensor + :param offsets: The offsets to the block + :param block_shape: The shape of the block + :param order: The order of the original data format + """ + return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder) + + +@_tensor_member_fn +@builtin +def advance(base, offsets, _builder=None): + """ + Advance a block pointer + + :param base: the block pointer to advance + :param offsets: the offsets to advance, a tuple by dimension + """ + return semantic.advance(base, offsets, _builder) + + +# ----------------------- +# Atomic Memory Operations +# ----------------------- + + +def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = f""" + Performs an atomic {name} at the memory location specified by :code:`pointer`. + + Return the data stored at :code:`pointer` before the atomic operation. + + :param pointer: The memory locations to operate on + :type pointer: Block of dtype=triton.PointerDType""" + if has_cmp: + docstr += """ + :param cmp: The values expected to be found in the atomic object + :type cmp: Block of dtype=pointer.dtype.element_ty""" + docstr += """ + :param val: The values with which to perform the atomic operation + :type val: Block of dtype=pointer.dtype.element_ty + :param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default), + "ACQUIRE", "RELEASE", or "RELAXED") + :type sem: str + :param scope: Scope of threads that observe synchronizing effect of the + atomic operation ("GPU" (default), "CTA", or "SYSTEM") + :type scope: str + """ + func.__doc__ = docstr + return func + + return _decorator + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("compare-and-swap", has_cmp=True) +def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None): + cmp = _to_tensor(cmp, _builder) + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("exchange") +def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("add") +def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_add(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("max") +def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_max(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("min") +def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_min(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical and") +def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_and(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical or") +def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_or(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical xor") +def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder) + + +# ----------------------- +# Conditioning +# ----------------------- + + +@builtin +def where(condition, x, y, _builder=None): + """ + Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + + Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. + + If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead. + + The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`. + :code:`x` and :code:`y` must have the same data type. + + :param condition: When True (nonzero), yield x, otherwise yield y. + :type condition: Block of triton.bool + :param x: values selected at indices where condition is True. + :param y: values selected at indices where condition is False. + """ + condition = _to_tensor(condition, _builder) + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + return semantic.where(condition, x, y, _builder) + + +# ----------------------- +# Math +# ----------------------- + + +@builtin +def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Computes the element-wise minimum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + y = _promote_bfloat16_to_float32(y, _builder=_builder) + propagate_nan = _constexpr_to_value(propagate_nan) + return semantic.minimum(x, y, propagate_nan, _builder) + + +@builtin +def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Computes the element-wise maximum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + y = _promote_bfloat16_to_float32(y, _builder=_builder) + propagate_nan = _constexpr_to_value(propagate_nan) + return semantic.maximum(x, y, propagate_nan, _builder) + + +@builtin +def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Clamps the input tensor :code:`x` within the range [min, max]. + Behavior when :code:`min` > :code:`max` is undefined. + + :param x: the input tensor + :type x: Block + :param min: the lower bound for clamping + :type min: Block + :param max: the upper bound for clamping + :type max: Block + :param propagate_nan: whether to propagate NaN values. Applies only to the :code:`x` tensor. + If either :code:`min` or :code:`max` is NaN, the result is undefined. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + min = _to_tensor(min, _builder) + max = _to_tensor(max, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + min = _promote_bfloat16_to_float32(min, _builder=_builder) + max = _promote_bfloat16_to_float32(max, _builder=_builder) + + propagate_nan = _constexpr_to_value(propagate_nan) + + return semantic.clamp(x, min, max, propagate_nan, _builder) + + +# ----------------------- +# Reductions +# ----------------------- + + +def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :param axis: the dimension along which the reduction should be done + :param keep_dims: if true, keep the reduced dimensions with length 1""" + if return_indices_arg is not None: + docstr += f""" + :param {return_indices_arg}: if true, return index corresponding to the {name} value""" + if tie_break_arg is not None: + docstr += f""" + :param {tie_break_arg}: if true, return the left-most indices in case of ties for values that aren't NaN""" + + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@contextmanager +def _insertion_guard(builder): + ip = builder.get_insertion_point() + yield + builder.restore_insertion_point(ip) + + +@_tensor_member_fn +@builtin +def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None): + """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis` + + :param input: the input tensor, or tuple of tensors + :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :param keep_dims: if true, keep the reduced dimensions with length 1 + + """ + if isinstance(input, tensor): + return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(reduce_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = reduce_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_reduce_ret(*handles) + + def expand_ndims(t, ndims): + for _ in builtins.range(ndims): + t = expand_dims(t, 0, _builder=_builder) + return t + + axis = _constexpr_to_value(axis) + keep_dims = _constexpr_to_value(keep_dims) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + ret = semantic.reduction(input, axis, make_combine_region, _builder) + if keep_dims: + if axis is not None: + ret = tuple(expand_dims(t, axis, _builder=_builder) for t in ret) + else: + ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret) + return ret + + +@builtin +def _promote_bfloat16_to_float32(t, _builder=None): + scalar_ty = t.type.scalar + + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is bfloat16: + return t.to(float32, _builder=_builder) + return t + + +@builtin +def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None): + axis = _constexpr_to_value(axis) + n = input.shape[axis] + index = arange(0, n, _builder=_builder) + + if len(input.shape) > 1: + # Broadcast index across the non-reduced axes + axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))] + del axes_to_expand[axis] + index = expand_dims(index, axes_to_expand, _builder=_builder) + index = broadcast_to(index, input.shape, _builder=_builder) + + rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, + _generator=_generator) + return rvalue, rindices + + +# ----------------------- +# Scans +# ----------------------- + + +def _add_scan_docstr(name: str) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :param axis: the dimension along which the scan should be done""" + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@_tensor_member_fn +@builtin +def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _generator=None): + """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry + + :param input: the input tensor, or tuple of tensors + :param axis: the dimension along which the reduction should be done + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :param reverse: apply the associative scan in the reverse direction along axis. + + """ + if isinstance(input, tensor): + return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(scan_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = scan_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_scan_ret(*handles) + + axis = _constexpr_to_value(axis) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + return semantic.associative_scan(input, axis, make_combine_region, reverse, _builder) + + +@_tensor_member_fn +@builtin +def histogram(input, num_bins, _builder=None, _generator=None): + """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0. + + :param input: the input tensor + :param num_bins: number of histogram bins + + """ + num_bins = _constexpr_to_value(num_bins) + return semantic.histogram(input, num_bins, _builder) + + +# ----------------------- +# Compiler Hint Ops +# ----------------------- + + +@builtin +def debug_barrier(_builder=None): + ''' + Insert a barrier to synchronize all threads in a block. + ''' + return semantic.debug_barrier(_builder) + + +@builtin +def multiple_of(input, values, _builder=None): + """ + Let the compiler know that the values in :code:`input` are all multiples of :code:`value`. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.multiple_of(input, values) + + +@builtin +def max_contiguous(input, values, _builder=None): + """ + Let the compiler know that the `value` first values in :code:`input` are contiguous. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_contiguous(input, values) + + +@builtin +def max_constancy(input, values, _builder=None): + """ + Let the compiler know that the `value` first values in :code:`input` are constant. + + e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal, + for example [0, 0, 0, 0, 1, 1, 1, 1]. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_constancy(input, values) + + +# ----------------------- +# Debugging functions +# ----------------------- + + +@builtin +def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None): + ''' + Print the values at compile time. The parameters are the same as the builtin :code:`print`. + + NOTE: Calling the Python builtin :code:`print` is not the same as calling this, it instead maps to :code:`device_print`, + which has special requirements for the arguments. + + .. highlight:: python + .. code-block:: python + + tl.static_print(f"{BLOCK_SIZE=}") + ''' + pass + + +@builtin +def static_assert(cond, msg="", _builder=None): + ''' + Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable + is set. + + .. highlight:: python + .. code-block:: python + + tl.static_assert(BLOCK_SIZE == 1024) + ''' + pass + + +@builtin +def device_print(prefix, *args, hex=False, _builder=None): + ''' + Print the values at runtime from the device. String formatting does not work for runtime values, so you should + provide the values you want to print as arguments. The first value must be a string, all following values must + be scalars or tensors. + + Calling the Python builtin :code:`print` is the same as calling this function, and the requirements for the arguments will match + this function (not the normal requirements for :code:`print`). + + .. highlight:: python + .. code-block:: python + + tl.device_print("pid", pid) + print("pid", pid) + + On CUDA, printfs are streamed through a buffer of limited size (on one host, + we measured the default as 6912 KiB, but this may not be consistent across + GPUs and CUDA versions). If you notice some printfs are being dropped, you + can increase the buffer size by calling + + .. highlight:: python + .. code-block:: python + + triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes) + + CUDA may raise an error if you try to change this value after running a + kernel that uses printfs. The value set here may only affect the current + device (so if you have multiple GPUs, you'd need to call it multiple times). + + :param prefix: a prefix to print before the values. This is required to be a string literal. + :param args: the values to print. They can be any tensor or scalar. + :param hex: print all values as hex instead of decimal + ''' + import string + prefix = _constexpr_to_value(prefix) + assert isinstance(prefix, str), f"{prefix} is not string" + b_ascii = True + for ch in prefix: + if ch not in string.printable: + b_ascii = False + break + assert b_ascii, f"{prefix} is not an ascii string" + new_args = [] + for arg in args: + new_args.append(_to_tensor(arg, _builder)) + return semantic.device_print(prefix, new_args, hex, _builder) + + +@builtin +def device_assert(cond, msg="", _builder=None): + ''' + Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG` + is set to a value besides :code:`0` in order for this to have any effect. + + Using the Python :code:`assert` statement is the same as calling this function, except that the second argument + must be provided and must be a string, e.g. :code:`assert pid == 0, "pid != 0"`. The environment variable must + be set for this :code:`assert` statement to have any effect. + + .. highlight:: python + .. code-block:: python + + tl.device_assert(pid == 0) + assert pid == 0, f"pid != 0" + + :param cond: the condition to assert. This is required to be a boolean tensor. + :param msg: the message to print if the assertion fails. This is required to be a string literal. + ''' + msg = _constexpr_to_value(msg) + import inspect + frame = inspect.currentframe() + module = inspect.getmodule(frame) + # The triton function module doesn't have the name attribute. + # We use this trick to find the caller. + while hasattr(module, "__name__"): + frame = frame.f_back + module = inspect.getmodule(frame) + lineno = 0 + func_name = 'unknown' + file_name = 'unknown' + if frame is not None and frame.f_back is not None: + func_name = frame.f_code.co_name + file_name = frame.f_back.f_code.co_filename + # TODO: The line number currently indicates the line + # where the triton function is called but not where the + # device_assert is called. Need to enhance this. + lineno = frame.f_back.f_lineno + return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder) + + +@builtin +def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]], + is_pure: bool, pack: int, _builder=None): + ''' + Execute inline assembly over a tensor. Essentially, this is :code:`map` + where the function is inline assembly. + + The input tensors :code:`args` are implicitly broadcasted to the same shape. + + :code:`dtype` can be a tuple of types, in which case the output is a + tuple of tensors. + + Each invocation of the inline asm processes :code:`pack` elements at a + time. Exactly which set of inputs a block receives is unspecified. + Input elements of size less than 4 bytes are packed into 4-byte + registers. + + This op does not support empty :code:`dtype` -- the inline asm must + return at least one tensor, even if you don't need it. You can work + around this by returning a dummy tensor of arbitrary type; it shouldn't + cost you anything if you don't use it. + + Example using + [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + assembly: + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor + b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + :param asm: assembly to run. Must match target's assembly format. + :param constraints: asm constraints in + [LLVM format](https://llvm.org/docs/LangRef.html#inline-asm-constraint-string) + :param args: the input tensors, whose values are passed to the asm block + :param dtype: the element type(s) of the returned tensor(s) + :param is_pure: if true, the compiler assumes the asm block has no side-effects + :param pack: the number of elements to be processed by one instance of inline assembly + :param _builder: the builder + :return: one tensor or a tuple of tensors of the given dtypes + ''' + asm = _constexpr_to_value(asm) + constraints = _constexpr_to_value(constraints) + pack = _constexpr_to_value(pack) + is_pure = _constexpr_to_value(is_pure) + + # Wrap `dtype` in a tuple if it's not already. + try: + iter(dtype) # type: ignore + has_multiple_outputs = True + except TypeError: + has_multiple_outputs = False + dtype = (dtype, ) # type: ignore + + dtype = typing.cast(Sequence[_DtypeClass], dtype) + + res_tys = dtype + if dispatch_args := [_to_tensor(arg, _builder) for arg in args]: + bin_op_type_checking = partial( + semantic.binary_op_type_checking_impl, + builder=_builder, + arithmetic_check=False, + allow_lhs_ptr=True, + allow_rhs_ptr=True, + ) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = bin_op_type_checking(item, broadcast_arg) + if broadcast_arg.shape: + # Change the shape of each argument based on the broadcast shape + for i, item in enumerate(dispatch_args): + dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg) + res_tys = [block_type(dt, broadcast_arg.shape) for dt in dtype] + handles = [t.handle for t in dispatch_args] + call = _builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(_builder) for ty in res_tys], is_pure, pack) + + if not has_multiple_outputs: + return tensor(call.get_result(0), res_tys[0]) + return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys)) + + +# ----------------------- +# Iterators +# ----------------------- + + +class static_range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.static_range(10): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + """ + + def __init__(self, arg1, arg2=None, step=None): + assert isinstance(arg1, constexpr) + if step is None: + self.step = constexpr(1) + else: + assert isinstance(step, constexpr) + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + assert isinstance(arg2, constexpr) + self.start = arg1 + self.end = arg2 + + def __iter__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + +class range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.range(10, num_stages=3): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + :param num_stages: pipeline the loop into this many stages (so there are + :code:`num_stages` iterations of the loop in flight at once). + + Note this is subtly different than passing :code:`num_stages` as a + kernel argument. The kernel argument only pipelines loads that feed + into :code:`dot` operations, while this attribute tries to pipeline most + (though not all) loads in this loop. + """ + + def __init__(self, arg1, arg2=None, step=None, num_stages=None): + if step is None: + self.step = constexpr(1) + else: + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + self.start = arg1 + self.end = arg2 + self.num_stages = num_stages + + def __iter__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + +# ----------------------- +# Extern functions +# ----------------------- + + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, + is_pure: bool, _builder=None): + ''' + Dispatch a function to a library + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :param _builder: the builder + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." + f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_type = arg_type_symbol_dict[arg_types][1] + if ret_shape: + ret_type = block_type(ret_type, ret_shape) + return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type) + + +@builtin +def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, + _builder=None): + ''' + Dispatch an elementwise function to a library + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param is_pure: whether the function is pure + :param _builder: the builder + :return: the return value of the function + ''' + dispatch_args = args.copy() + all_scalar = True + ret_shape = None + arg_types = [] + for i in builtins.range(len(dispatch_args)): + dispatch_args[i] = _to_tensor(dispatch_args[i], _builder) + arg_types.append(dispatch_args[i].dtype) + if dispatch_args[i].type.is_block(): + all_scalar = False + if len(arg_types) > 0: + arg_types = tuple(arg_types) + arithmetic_check = True + # If there's a type tuple that is not supported by the library, we will do arithmetic check + if arg_types in arg_type_symbol_dict: + arithmetic_check = False + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder, + arithmetic_check=arithmetic_check) + # Change the shape of each argument based on the broadcast shape + for i in builtins.range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder, + arithmetic_check=arithmetic_check) + if not all_scalar: + ret_shape = broadcast_arg.shape + func = _builder.create_extern_elementwise + return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder) + + +def binary_op_type_legalization(lhs, rhs, builder): + ''' + Convert both operands to a single common type + :param lhs: the left operand + :param rhs: the right operand + :param builder: the builder + ''' + return semantic.binary_op_type_checking_impl(lhs, rhs, builder) + + +def extern(fn): + """A decorator for external functions.""" + return builtin(fn) diff --git a/third_party/iluvatar/python/triton/language/extra/__init__.py b/third_party/iluvatar/python/triton/language/extra/__init__.py new file mode 100644 index 000000000..14e1778d2 --- /dev/null +++ b/third_party/iluvatar/python/triton/language/extra/__init__.py @@ -0,0 +1,4 @@ +from . import cuda +from . import hip + +__all__ = ['cuda', 'hip'] diff --git a/third_party/iluvatar/python/triton/language/extra/cuda/__init__.py b/third_party/iluvatar/python/triton/language/extra/cuda/__init__.py new file mode 100644 index 000000000..45fd25b65 --- /dev/null +++ b/third_party/iluvatar/python/triton/language/extra/cuda/__init__.py @@ -0,0 +1,11 @@ +from . import libdevice + +from .utils import (globaltimer, num_threads, num_warps, smid, convert_custom_float8_sm70, convert_custom_float8_sm80) + +__all__ = [ + "libdevice", + #"globaltimer", + "num_threads", "num_warps", + #"smid", + "convert_custom_float8_sm70", "convert_custom_float8_sm80" +] diff --git a/python/triton/language/extra/cuda/libdevice.py b/third_party/iluvatar/python/triton/language/extra/cuda/libdevice.py similarity index 100% rename from python/triton/language/extra/cuda/libdevice.py rename to third_party/iluvatar/python/triton/language/extra/cuda/libdevice.py diff --git a/python/triton/language/extra/cuda/utils.py b/third_party/iluvatar/python/triton/language/extra/cuda/utils.py similarity index 100% rename from python/triton/language/extra/cuda/utils.py rename to third_party/iluvatar/python/triton/language/extra/cuda/utils.py diff --git a/third_party/iluvatar/python/triton/language/extra/hip/__init__.py b/third_party/iluvatar/python/triton/language/extra/hip/__init__.py new file mode 100644 index 000000000..229b57d87 --- /dev/null +++ b/third_party/iluvatar/python/triton/language/extra/hip/__init__.py @@ -0,0 +1,3 @@ +from . import libdevice + +__all__ = ["libdevice"] diff --git a/python/triton/language/extra/hip/libdevice.py b/third_party/iluvatar/python/triton/language/extra/hip/libdevice.py similarity index 100% rename from python/triton/language/extra/hip/libdevice.py rename to third_party/iluvatar/python/triton/language/extra/hip/libdevice.py diff --git a/third_party/iluvatar/python/triton/language/extra/libdevice.py b/third_party/iluvatar/python/triton/language/extra/libdevice.py new file mode 100644 index 000000000..625cf3957 --- /dev/null +++ b/third_party/iluvatar/python/triton/language/extra/libdevice.py @@ -0,0 +1,1213 @@ +from .cuda import libdevice as cuda_libdevice +from .hip import libdevice as hip_libdevice +from triton.language import core +from functools import wraps +from typing import TypeVar + +T = TypeVar('T') + + +def dispatch(fn: T) -> T: + """Dispatch a function to a correct implementation.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + _backend = kwargs["_builder"].options.backend_name + if _backend == 'cuda': + _curr_libdevice_module = cuda_libdevice + elif _backend == 'hip': + _curr_libdevice_module = hip_libdevice + else: + raise RuntimeError('unknown backend') + + try: + _impl = getattr(_curr_libdevice_module, fn.__name__) + except AttributeError: + raise RuntimeError(f'`{_backend}` does not provide support for `{fn.__name__}` extra function') + + return _impl(*args, **kwargs) + + return wrapper + + +@core.extern +@dispatch +def clz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def popc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def byte_perm(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def mulhi(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul24(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def brev(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sad(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def abs(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def floor(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp64h(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rsqrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ceil(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def trunc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp2(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def saturatef(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fma_rn(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fma_rz(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fma_rd(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fma_ru(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fast_dividef(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def add_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def hiloint2double(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def double2loint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2hiint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int_as_float(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float_as_int(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint_as_float(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float_as_uint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def longlong_as_double(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double_as_longlong(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_sinf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_cosf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_log2f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_logf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_expf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_tanf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_exp10f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_log10f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_powf(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def hadd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rhadd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rsqrt_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ffs(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def llrint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def nearbyint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isnan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def signbit(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def copysign(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def finitef(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isinf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def nextafter(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sin(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cos(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sinpi(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cospi(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def tan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log2(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp10(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cosh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sinh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def tanh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def atan2(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def atan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def asin(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def acos(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log10(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log1p(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def acosh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def asinh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def atanh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def expm1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def hypot(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rhypot(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def norm3d(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def rnorm3d(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + ... + + +@core.extern +@dispatch +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + ... + + +@core.extern +@dispatch +def cbrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcbrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def j0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def j1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def y0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def y1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def yn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def jn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def cyl_bessel_i0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cyl_bessel_i1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfcx(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfcinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def normcdfinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def normcdf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def lgamma(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ldexp(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def scalbn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def fmod(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def remainder(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def fma(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def pow(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def tgamma(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def round(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def llround(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fdim(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def ilogb(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def logb(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isfinited(arg0, _builder=None): + ... diff --git a/third_party/iluvatar/python/triton/language/math.py b/third_party/iluvatar/python/triton/language/math.py new file mode 100644 index 000000000..de5b5be6b --- /dev/null +++ b/third_party/iluvatar/python/triton/language/math.py @@ -0,0 +1,250 @@ +from . import core +from . import semantic +from functools import wraps +from typing import List + +T = core.TypeVar('T') + + +def _check_dtype(dtypes: List[str]) -> T: + """ + We're following libdevice's convention to check accepted data types for math functions. + It is not a good practice to support all data types as accelerators/GPUs don't support + many float16 and bfloat16 math operations. + We should let the users know that they are using and invoke explicit cast to convert + the data type to the supported one. + """ + + def wrapper(fn): + + @wraps(fn) + def check(*args, **kwargs): + # concatenate args and kwargs + all_args = list(args) + list(kwargs.values()) + for arg in [a for a in all_args if isinstance(a, core.tensor)]: + if arg.type.scalar.name not in dtypes: + raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}") + return fn(*args, **kwargs) + + return check + + return wrapper + + +def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`. + + :param x: the input values + :type x: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x` and :code:`y`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + :param z: the input values + :type z: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@core.builtin +@_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"]) +@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") +def umulhi(x, y, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential") +@core._tensor_member_fn +def exp(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_exp(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential (base 2)") +@core._tensor_member_fn +def exp2(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_exp2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("natural logarithm") +@core._tensor_member_fn +def log(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_log(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("logarithm (base 2)") +@core._tensor_member_fn +def log2(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_log2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("cosine") +@core._tensor_member_fn +def cos(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_cos(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("sine") +@core._tensor_member_fn +def sin(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_sin(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("fast square root") +@core._tensor_member_fn +def sqrt(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_1arg_docstr("precise square root (rounding to nearest)") +@core._tensor_member_fn +def sqrt_rn(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_precise_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("inverse square root") +@core._tensor_member_fn +def rsqrt(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_rsqrt(x.handle), x.type) + + +@core.builtin +@_add_math_1arg_docstr("absolute value") +@core._tensor_member_fn +def abs(x, _builder=None): + x = core._to_tensor(x, _builder) + dtype = x.dtype + if dtype.is_fp8e4b15(): + mask = core.full(x.shape, 0x7F, core.int8, _builder=_builder) + return core.tensor(_builder.create_and(x.handle, mask.handle), x.type) + elif dtype.is_floating(): + return core.tensor(_builder.create_fabs(x.handle), x.type) + elif dtype.is_int_signed(): + return core.tensor(_builder.create_iabs(x.handle), x.type) + elif dtype.is_int_unsigned(): + return x # no-op + else: + assert False, f"Unexpected dtype {dtype}" + + +@core.builtin +@_add_math_2arg_docstr("fast division") +def fdiv(x, y, ieee_rounding=False, _builder=None): + ieee_rounding = core._constexpr_to_value(ieee_rounding) + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + return semantic.fdiv(x, y, ieee_rounding, _builder) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_2arg_docstr("precise division (rounding to nearest)") +def div_rn(x, y, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("error function") +@core._tensor_member_fn +def erf(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_erf(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("floor") +@core._tensor_member_fn +def floor(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_floor(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("ceil") +@core._tensor_member_fn +def ceil(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_ceil(x.handle), x.type) + + +@core.builtin +@_add_math_3arg_docstr("fused multiply-add") +def fma(x, y, z, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + z = core._to_tensor(z, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + z, x = core.binary_op_type_legalization(z, x, _builder) + z, y = core.binary_op_type_legalization(z, y, _builder) + return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type) diff --git a/third_party/iluvatar/python/triton/language/random.py b/third_party/iluvatar/python/triton/language/random.py new file mode 100644 index 000000000..430aeb09e --- /dev/null +++ b/third_party/iluvatar/python/triton/language/random.py @@ -0,0 +1,207 @@ +from ..runtime.jit import jit +from . import core as tl +from . import math + +N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox + +# ------------------- +# randint +# ------------------- + + +@jit +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1). + """ + if c0.dtype == tl.uint32: + PHILOX_KEY_A: tl.constexpr = 0x9E3779B9 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE85 + PHILOX_ROUND_A: tl.constexpr = 0xD2511F53 + PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57 + else: + tl.static_assert(c0.dtype == tl.uint64, "dtype not supported in philox_impl") + PHILOX_KEY_A: tl.constexpr = 0x9E3779B97F4A7C15 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE8584CAA73B + PHILOX_ROUND_A: tl.constexpr = 0xD2E7470EE14C6C93 + PHILOX_ROUND_B: tl.constexpr = 0xCA5A826395121157 + + for _ in tl.static_range(n_rounds): + # for _ in range(n_rounds): + # update random state + A = PHILOX_ROUND_A + B = PHILOX_ROUND_B + _c0, _c2 = c0, c2 + c0 = math.umulhi(B, _c2) ^ c1 ^ k0 + c2 = math.umulhi(A, _c0) ^ c3 ^ k1 + c1 = B * _c2 + c3 = A * _c0 + # raise key + k0 = k0 + PHILOX_KEY_A + k1 = k1 + PHILOX_KEY_B + return c0, c1, c2, c3 + + +@jit +def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + seed = tl.to_tensor(seed) + c0 = tl.to_tensor(c0) + c1 = tl.to_tensor(c1) + c2 = tl.to_tensor(c2) + c3 = tl.to_tensor(c3) + seed = seed.to(tl.uint64) + if tl.constexpr(c0.dtype.primitive_bitwidth) == 32: + int_dtype = tl.uint32 + seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) + seed_lo = (seed & 0xffffffff).to(tl.uint32) + else: + tl.static_assert(tl.constexpr(c0.dtype.primitive_bitwidth) == 64, "bitwidth not supported in philox") + int_dtype = tl.uint64 + seed_hi = tl.full((1, ), 0, dtype=int_dtype) + seed_lo = seed + c0 = c0.to(int_dtype, bitcast=True) + c1 = c1.to(int_dtype, bitcast=True) + c2 = c2.to(int_dtype, bitcast=True) + c3 = c3.to(int_dtype, bitcast=True) + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + +@jit +def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns a single + block of random :code:`int32`. + + If you need multiple streams of random numbers, + using `randint4x` is likely to be faster than calling `randint` 4 times. + + :param seed: The seed for generating random numbers. + :param offset: The offsets to generate random numbers for. + """ + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret + + +@jit +def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns four + blocks of random :code:`int32`. + + This is the maximally efficient entry point + to Triton's Philox pseudo-random number generator. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + # _0 = tl.zeros(offset.shape, offset.dtype) + _0 = offset * 0 + return philox(seed, offset, _0, _0, _0, n_rounds) + + +# ------------------- +# rand +# ------------------- + +# @jit +# def uint32_to_uniform_float(x): +# """ +# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). +# """ +# two_to_the_minus_32: tl.constexpr = 2.328306e-10 +# return x * two_to_the_minus_32 + + +@jit +def uint_to_uniform_float(x): + """ + Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1). + """ + # TODO: fix frontend issues and cleanup + # conditions can be simplified + # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1) + if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32): + # maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + x = x.to(tl.int32, bitcast=True) + scale = 4.6566127342e-10 + else: + tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64)) + x = x.to(tl.int64, bitcast=True) + scale = 1.0842020432385337e-19 + x = tl.where(x < 0, -x - 1, x) + return x * scale + + +@jit +def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + source = randint(seed, offset, n_rounds) + return uint_to_uniform_float(source) + + +@jit +def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offsets` block, + returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + u3 = uint_to_uniform_float(i3) + u4 = uint_to_uniform_float(i4) + return u1, u2, u3, u4 + + +# ------------------- +# randn +# ------------------- + + +@jit +def pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = tl.maximum(1.0e-7, u1) + th = 6.283185307179586 * u2 + r = math.sqrt(-2.0 * math.log(u1)) + return r * math.cos(th), r * math.sin(th) + + +@jit +def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, _, _ = randint4x(seed, offset, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + n1, _ = pair_uniform_to_normal(u1, u2) + return n1 + + +@jit +def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + u1, u2, u3, u4 = rand4x(seed, offset, n_rounds) + n1, n2 = pair_uniform_to_normal(u1, u2) + n3, n4 = pair_uniform_to_normal(u3, u4) + return n1, n2, n3, n4 diff --git a/third_party/iluvatar/python/triton/language/semantic.py b/third_party/iluvatar/python/triton/language/semantic.py new file mode 100644 index 000000000..851fda6be --- /dev/null +++ b/third_party/iluvatar/python/triton/language/semantic.py @@ -0,0 +1,1666 @@ +from __future__ import annotations # remove after python 3.11 + +from typing import List, Optional, Sequence, Tuple, TypeVar + +from .._C.libtriton import ir +from . import core as tl +from . import math + +T = TypeVar('T') + + +class IncompatibleTypeErrorImpl(Exception): + + def __init__(self, type_a, type_b): + self.type_a = type_a + self.type_b = type_b + self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() + super(IncompatibleTypeErrorImpl, self).__init__(self.message) + + +# ===----------------------------------------------------------------------===## +# Programming Model +# ===----------------------------------------------------------------------===## + + +def program_id(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}") + return tl.tensor(builder.create_get_program_id(axis), tl.int32) + + +def num_programs(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}") + return tl.tensor(builder.create_get_num_programs(axis), tl.int32) + + +# ===----------------------------------------------------------------------===// +# Implicit Casting Utilities +# ===----------------------------------------------------------------------===// + + +def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: + a_rank = a_ty.int_bitwidth + b_rank = b_ty.int_bitwidth + a_sn = a_ty.int_signedness + b_sn = b_ty.int_signedness + # Rules for signedness taken from "Usual arithmetic conversions" on + # https://en.cppreference.com/w/c/language/conversion. + if a_sn == b_sn: + return a_ty if a_rank > b_rank else b_ty + elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return a_ty if a_rank >= b_rank else b_ty + elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return b_ty if b_rank >= a_rank else a_ty + raise TypeError(f"unexpected signedness {a_sn} and {b_sn}") + + +def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype: + # 1) if one operand is double, the other is implicitly + # converted to double + if a_ty.is_fp64() or b_ty.is_fp64(): + return tl.float64 + # 2) if one operand is float, the other is implicitly + # converted to float + if a_ty.is_fp32() or b_ty.is_fp32(): + return tl.float32 + # 3 ) if one operand is half, the other is implicitly converted to half + # unless we're doing / or %, which do not exist natively in PTX for fp16. + # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp + if a_ty.is_fp16() or b_ty.is_fp16(): + if div_or_mod: + return tl.float32 + else: + return tl.float16 + # 4) return bf16 only if both operands are of bf16 + if a_ty.is_bf16() or b_ty.is_bf16(): + if div_or_mod: + return tl.float32 + if a_ty.is_bf16() and b_ty.is_bf16(): + return tl.bfloat16 + return tl.float32 + if not a_ty.is_int() or not b_ty.is_int(): + raise TypeError(f"unexpected type {a_ty} and {b_ty}") + # 5 ) both operands are integer and undergo + # integer promotion + if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: + raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + + " because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + return integer_promote_impl(a_ty, b_ty) + + +# ===----------------------------------------------------------------------===// +# Binary Operators +# ===----------------------------------------------------------------------===// + + +def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: + if type_a.is_ptr(): + if not allow_ptr_a: + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + U* with T != U + if type_b.is_ptr() and (type_a != type_b): + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + float + if type_b.is_floating(): + raise IncompatibleTypeErrorImpl(type_a, type_b) + + +def binary_op_type_checking_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, allow_lhs_ptr=False, + allow_rhs_ptr=False, arithmetic_check=True, + div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]: + # implicit broadcasting + lhs, rhs = broadcast_impl_value(lhs, rhs, builder) + # implicit typecasting + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) + check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) + if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): + ret_sca_ty = computation_type_impl(lhs_sca_ty, rhs_sca_ty, div_or_mod) + lhs = cast(lhs, ret_sca_ty, builder) + rhs = cast(rhs, ret_sca_ty, builder) + return lhs, rhs + + +def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr(): + raise TypeError("cannot add pointers together") + + # offset + ptr + # ptr + offset + if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): + input, other = other, input + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type) + # float + float + elif input_scalar_ty.is_floating(): + return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) + # int + int + elif input_scalar_ty.is_int(): + return tl.tensor(builder.create_add(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, False) + scalar_ty = input.type.scalar + # ptr - offset + if scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type) + # float - float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) + # int - int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def mul(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float * float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type) + # * int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def truediv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float / int + if input_scalar_ty.is_floating() and other_scalar_ty.is_int(): + other = cast(other, input_scalar_ty, builder) + # int / float + elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): + input = cast(input, other_scalar_ty, builder) + # int / int (cast to tl.float32) + elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): + input = cast(input, tl.float32, builder) + other = cast(other, tl.float32, builder) + # float / float (cast to the highest exponent type) + elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): + if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: + other = cast(other, input_scalar_ty, builder) + else: + input = cast(input, other_scalar_ty, builder) + # unreachable + else: + raise TypeError(f"unexpected type {input_scalar_ty}") + return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) + + +def floordiv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): + ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) + input = cast(input, ret_ty, builder) + other = cast(other, ret_ty, builder) + if ret_ty.is_int_signed(): + return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def fdiv(input: tl.tensor, other: tl.tensor, ieee_rounding: bool, builder: ir.builder) -> tl.tensor: + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): + raise TypeError("both operands of fdiv must have floating scalar type") + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True) + ret = builder.create_fdiv(input.handle, other.handle) + return tl.tensor(ret, input.type) + + +def mod(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float % float + if scalar_ty.is_floating(): + # input - input.div(other, rounding_mode="floor") * other + ret = sub(input, mul(math.floor(fdiv(input, other, False, builder), _builder=builder), other, builder), builder) + return ret + # % int + elif scalar_ty.is_int(): + if scalar_ty.int_signedness != other_scalar_ty.int_signedness: + raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " + "because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +############## +# other arithmetic ops +############## + + +def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_minimumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_minnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_minsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_minui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + +def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_maximumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_maxnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_maxsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_maxui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + +def clamp(x: tl.tensor, min: tl.tensor, max: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + min, max = binary_op_type_checking_impl(min, max, builder) + x, min = binary_op_type_checking_impl(x, min, builder) + x, max = binary_op_type_checking_impl(x, max, builder) + + dtype = x.dtype + if dtype.is_floating(): + return tl.tensor(builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}. Only floating point clamp is supported") + + +############## +# bitwise ops +############## + + +def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False) + input_sca_ty = input.type.scalar + other_sca_ty = other.type.scalar + if not input_sca_ty.is_int() or not other_sca_ty.is_int(): + raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty) + ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty) + if ret_sca_ty != input_sca_ty: + input = cast(input, ret_sca_ty, builder) + if ret_sca_ty != other_sca_ty: + other = cast(other, ret_sca_ty, builder) + return input, other + + +def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_and(input.handle, other.handle), input.type) + + +def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_or(input.handle, other.handle), input.type) + + +def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) + + +def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return and_(input, other, builder) + + +def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return or_(input, other, builder) + + +def not_(input: tl.tensor, builder: ir.builder): + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + return invert(input, builder) + + +def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type) + + +def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type) + + +def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_shl(input.handle, other.handle), input.type) + + +# ===----------------------------------------------------------------------===// +# Unary Operators +# ===----------------------------------------------------------------------===// + + +def plus(input: tl.tensor) -> tl.tensor: + return input + + +def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") + _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return sub(_0, input, builder) + + +def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): + raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") + _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return xor_(input, _1, builder) + + +# ===----------------------------------------------------------------------===// +# Comparison Operators +# ===----------------------------------------------------------------------===// +def _bool_like(v: tl.tensor) -> tl.block_type: + if not v.type.is_block(): + return tl.int1 + shape = v.type.shape + return tl.block_type(tl.int1, shape) + + +def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float > float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input)) + # > int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float >= float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input)) + # >= int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +# ===----------------------------------------------------------------------===// +# Block Creation +# ===----------------------------------------------------------------------===// + + +def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("arange's arguments must be of type tl.constexpr") + is_start_int64 = bool(start >> 32) + is_end_int64 = bool(end >> 32) + if is_start_int64 or is_end_int64: + raise ValueError("arange must fit in int32") + if end <= start: + raise ValueError("arange's end argument must be greater than the start argument") + range = end - start + if (range & (range - 1)) != 0: + raise ValueError("arange's range must be a power of 2") + shape = [range] + ret_ty = tl.block_type(tl.int32, shape) + return tl.tensor(builder.create_make_range(start, end), ret_ty) + + +def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + if isinstance(value, tl.tensor): + assert value.numel.value == 1, "only accepts size-1 tensor" + value = cast(value, dtype, builder) + else: + # scalar + if dtype is None: + raise ValueError("dtype must be specified when value is not a tensor") + if value == 0: + value = builder.get_null_value(dtype.to_ir(builder)) + else: + get_value_fn = getattr(builder, f"get_{dtype.name}") + value = get_value_fn(value) + value = tl.tensor(value, dtype) + + return splat(value, shape, builder) + + +# ===----------------------------------------------------------------------===// +# Shape Manipulation +# ===----------------------------------------------------------------------===// + + +def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + assert not value.type.is_block(), "Cannot splat a block tensor" + if len(shape) == 0: + return value + ret_ty = tl.block_type(value.dtype, shape) + return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) + + +def reshape(input: tl.tensor, dst_shape: List[int], can_reorder: bool, builder: ir.builder) -> tl.tensor: + numel = 1 + for s in dst_shape: + numel *= s + if input.type.numel != numel: + raise ValueError("reshape() cannot change total number of elements in tensor") + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_reshape(input.handle, dst_shape, can_reorder), ret_ty) + + +def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + dst_shape = [tl._constexpr_to_value(x) for x in input.shape] + dst_shape.insert(axis, 1) + + if not input.type.is_block(): + return splat(input, shape=dst_shape, builder=builder) + + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty) + + +def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor: + assert can_reorder, "current implementation of `cat` always may reorder elements" + assert len(lhs.shape) == 1 + ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) + return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type) + + +def join(a: tl.tensor, b: tl.tensor, builder: ir.builder) -> tl.tensor: + a, b = broadcast_impl_value(a, b, builder) + + # The IR can't handle joining two scalars, so upcast them to 1D tensors, + # then downcast the result. + was_rank_1 = a.shape == [] + if was_rank_1: + a = expand_dims(a, 0, builder) + b = expand_dims(b, 0, builder) + + if isinstance(a.shape[-1], tl.constexpr): + two = tl.constexpr(2) + else: + two = 2 + new_shape = a.shape + [two] + + ret_type = tl.block_type(a.type.scalar, new_shape) + ret = tl.tensor(builder.create_join(a.handle, b.handle), ret_type) + + if was_rank_1: + ret = reshape(ret, [2], can_reorder=False, builder=builder) + + return ret + + +def split(a: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + assert (len(a.shape) > 0) + assert (tl._constexpr_to_value(a.shape[-1]) == 2) + + new_shape = a.shape[:-1] + ret_type = tl.block_type(a.type.scalar, new_shape) + outLHS, outRHS = builder.create_split(a.handle) + return ( + tl.tensor(outLHS, ret_type), + tl.tensor(outRHS, ret_type), + ) + + +def permute(input: tl.tensor, dims: Tuple[int], builder: ir.builder) -> tl.tensor: + if len(input.shape) != len(dims): + raise ValueError("permute dims must have the same length as input shape") + if sorted(tl._constexpr_to_value(d) for d in dims) != list(range(len(dims))): + raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}") + + ret_type = tl.block_type(input.type.scalar, [input.shape[d] for d in dims]) + return tl.tensor(builder.create_trans(input.handle, dims), ret_type) + + +def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + if not input.type.is_block(): + ret_ty = tl.block_type(input.type, shape) + return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) + src_shape = input.type.get_block_shapes() + if len(src_shape) != len(shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + for i, item in enumerate(src_shape): + if shape[i] != item and item != 1: + raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" + f" must match the existing size ({item}) at non-singleton dimension" + f" {i}: {src_shape}, {shape}") + ret_ty = tl.block_type(input.type.scalar, shape) + return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) + + +def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: + lhs_ty = lhs.type + rhs_ty = rhs.type + + # make_shape_compatible(block, scalar) + if lhs_ty.is_block() and not rhs_ty.is_block(): + rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape) + rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty) + # make_shape_compatible(scalar, block) + elif not lhs_ty.is_block() and rhs_ty.is_block(): + lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape) + lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty) + # make_shape_compatible(block, block) + elif lhs_ty.is_block() and rhs_ty.is_block(): + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + + if len(lhs_shape) < len(rhs_shape): + # Add new axes to lhs + for _ in range(len(lhs_shape), len(rhs_shape)): + lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), + tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) + lhs_ty = lhs.type + lhs_shape = lhs_ty.get_block_shapes() + elif len(rhs_shape) < len(lhs_shape): + # Add new axes to rhs + for _ in range(len(rhs_shape), len(lhs_shape)): + rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), + tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) + rhs_ty = rhs.type + rhs_shape = rhs_ty.get_block_shapes() + assert len(rhs_shape) == len(lhs_shape) + + ret_shape = [] + for i, left in enumerate(lhs_shape): + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif (right == 1) or (right == left): + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + if lhs_shape != ret_shape: + ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) + lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty) + if rhs_shape != ret_shape: + ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) + rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty) + # (scalar, scalar) => returns original blocks + return lhs, rhs + + +####### +# cast +####### + + +def _str_to_rounding_mode(rounding_mode: Optional[str]): + if rounding_mode is None: + return None + if rounding_mode == 'rtne': + return ir.ROUNDING_MODE.RTNE + if rounding_mode == 'rtz': + return ir.ROUNDING_MODE.RTZ + raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.") + + +def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + return cast(input, dst_ty, builder) + # Bitcast + src_bits = src_sca_ty.primitive_bitwidth + dst_bits = dst_sca_ty.primitive_bitwidth + if src_bits != dst_bits: + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " + "data-type of size " + str(dst_bits)) + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + +def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, + fp_downcast_rounding: Optional[str] = None) -> tl.tensor: + src_ty = input.type + if isinstance(dst_ty, tl.constexpr): + dst_ty = dst_ty.value + if isinstance(fp_downcast_rounding, tl.constexpr): + fp_downcast_rounding = fp_downcast_rounding.value + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + + # For fp downcasting default rounding mode should be RTNE, for all other conversions it should + # not be set + fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) + use_custom_rounding = False + if dst_sca_ty.is_floating() and src_sca_ty.is_floating( + ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: + if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE + elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True + else: + if fp_downcast_rounding is not None: + raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + + if (src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()): + assert builder.options.allow_fp8e4nv, "fp8e4nv data type is not supported on CUDA arch < 89" + + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): + assert builder.codegen_fns.get( + "convert_custom_types") is not None, "target doesn't provide conversion for this type." + return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + # and non-default rounding modes for downcasting + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ + use_custom_rounding: + return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) + + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + else: + return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) + + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif dst_sca_ty.is_int_signed(): + return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) + if bitwidth == 1: + return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder) + + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to pointer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + assert False, f'cannot cast {input} to {dst_ty}' + + +# ===----------------------------------------------------------------------===// +# Memory Operators +# ===----------------------------------------------------------------------===// + + +def _str_to_load_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".ca": + cache = ir.CACHE_MODIFIER.CA + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def _str_to_store_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".wb": + cache = ir.CACHE_MODIFIER.WB + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cs": + cache = ir.CACHE_MODIFIER.CS + elif cache_modifier == ".wt": + cache = ir.CACHE_MODIFIER.WT + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def _str_to_eviction_policy(eviction_policy): + eviction = ir.EVICTION_POLICY.NORMAL # default + if eviction_policy: + if eviction_policy == "evict_last": + eviction = ir.EVICTION_POLICY.EVICT_LAST + elif eviction_policy == "evict_first": + eviction = ir.EVICTION_POLICY.EVICT_FIRST + else: + raise ValueError(f"Eviction policy {eviction_policy} not supported") + return eviction + + +def _str_to_padding_option(padding_option): + padding = None # default + if padding_option: + if padding_option == "zero": + padding = ir.PADDING_OPTION.PAD_ZERO + elif padding_option == "nan": + padding = ir.PADDING_OPTION.PAD_NAN + else: + raise ValueError(f"Padding option {padding_option} not supported") + return padding + + +def _str_to_sem(sem_option): + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + if sem_option: + if sem_option == "acquire": + sem = ir.MEM_SEMANTIC.ACQUIRE + elif sem_option == "release": + sem = ir.MEM_SEMANTIC.RELEASE + elif sem_option == "acq_rel": + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + elif sem_option == "relaxed": + sem = ir.MEM_SEMANTIC.RELAXED + else: + raise ValueError(f"Memory semantic {sem_option} not supported") + return sem + + +def _str_to_scope(scope_option): + scope = ir.MEM_SYNC_SCOPE.GPU + if scope_option: + if scope_option == "gpu": + scope = ir.MEM_SYNC_SCOPE.GPU + elif scope_option == "cta": + scope = ir.MEM_SYNC_SCOPE.CTA + elif scope_option == "sys": + scope = ir.MEM_SYNC_SCOPE.SYSTEM + else: + raise ValueError(f"Memory semantic {scope_option} not supported") + return scope + + +def _canonicalize_boundary_check(boundary_check, block_shape): + if boundary_check: + if not hasattr(boundary_check, "__iter__"): + boundary_check = [boundary_check] + boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check] + for dim in boundary_check: + assert isinstance(dim, int) and 0 <= dim < len(block_shape) + assert len(boundary_check) > 0 + assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`" + return sorted(boundary_check) + return () + + +def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a block pointer: `pointer_type>` + # Block pointer can not have `mask` and `other` arguments + if mask is not None or other is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" + if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN: + raise ValueError("Padding option `nan` is not supported for integer block pointers") + + # `dst_ty` is de-referenced type of the pointer type + dst_ty = ptr.type.element_ty + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) + + # Build IR + return tl.tensor( + builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) + + +def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") + + # Check `mask`, `other`, `boundary_check`, and `padding` arguments + if mask is None and other is not None: + raise ValueError("`other` cannot be provided without `mask`") + if padding or boundary_check: + raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of" + "pointers or loading a scalar. Because the compiler does not know the boundary; please " + "use block pointers (defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `mask` and `other` + if not ptr.type.is_block(): + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + if other and other.type.is_block(): + raise ValueError("Other argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `other` into the same shape as `ptr` + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if other is not None: + other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder) + + # Get `pointer_type` and `elt_ty` + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast `other` into `ele_ty` type + if other is not None: + other = cast(other, elt_ty, builder) + + # Create loaded result type `dst_ty` + if ptr.type.is_block(): + shape = ptr.type.get_block_shapes() + dst_ty = tl.block_type(elt_ty, shape) + else: + # Load by de-referencing the pointer of scalar + dst_ty = elt_ty + + # Build IR + if mask is None: + return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) + else: + return tl.tensor( + builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, + is_volatile), dst_ty) + + +def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple, + padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool, + builder: ir.builder) -> tl.tensor: + # Cache, eviction and padding options + cache = _str_to_load_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + padding = _str_to_padding_option(padding_option) + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Load by a block pointer: `pointer_type>` + return _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + else: + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + + +def descriptor_load(desc_ptr: tl.tensor, offsets, cache_modifier: str, eviction_policy: str, type, + builder: ir.builder) -> tl.tensor: + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + x = builder.create_descriptor_load(desc_ptr.handle, offsets, type.to_ir(builder), + _str_to_load_cache_modifier(cache_modifier), + _str_to_eviction_policy(eviction_policy)) + return tl.tensor(x, type) + + +def descriptor_store(desc_ptr: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + return tl.tensor(builder.create_descriptor_store(desc_ptr.handle, value.handle, offsets), tl.void) + + +def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a block pointer: `pointer_type>` + # Block pointers can not have the `mask` argument + if mask is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + # Check same shape and element type + block_shape = ptr.type.element_ty.get_block_shapes() + if not val.type.is_block(): + val = broadcast_impl_shape(val, block_shape, builder) + assert val.type.is_block(), "Value argument must be block type or a scalar" + assert block_shape == val.type.get_block_shapes( + ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" + assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch" + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, block_shape) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + return tl.tensor(builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction), + tl.void) + + +def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`") + + # Check `boundary_check` argument + if boundary_check: + raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a " + "scalar. Because the compiler does not know the boundary; please use block pointers " + "(defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `val` and `mask` + if not ptr.type.is_block(): + if val.type.is_block(): + raise ValueError("Value argument cannot be block type if pointer argument is not a block") + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `val` into the same shape as `ptr` + if ptr.type.is_block(): + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + if not mask: + return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void) + if not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") + return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void) + + +def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str, + eviction_policy: str, builder: ir.builder) -> tl.tensor: + # Cache and eviction options + cache = _str_to_store_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + + if ptr.type.is_const() or ptr.type.scalar.is_const(): + raise ValueError("Cannot store to a constant pointer") + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Store by a block pointer: `pointer_type>` + return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder) + else: + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder) + + +######### +# atomic +######### + + +def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + element_ty = ptr.type.scalar.element_ty + if element_ty.primitive_bitwidth not in [16, 32, 64]: + raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) + + +def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_const() or ptr.type.element_ty.is_const(): + raise ValueError("Cannot store to a constant pointer") + element_ty = ptr.type.scalar.element_ty + if element_ty is tl.float16 and op != 'add': + raise ValueError("atomic_" + op + " does not support fp16") + if element_ty in [tl.int1, tl.int8, tl.int16]: + raise ValueError("atomic_" + op + " does not support " + str(element_ty)) + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if val is not None: + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + val = cast(val, ptr.type.scalar.element_ty, builder) + if not mask: + mask_ir = builder.get_int1(True) + mask_ty = tl.int1 + if ptr.type.is_block(): + mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes()) + mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) + mask = tl.tensor(mask_ir, mask_ty) + return ptr, val, mask + + +def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_max for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + # for float + # return atomic_smax(i_ptr, i_val) if val >= 0 + # return atomic_umin(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_max not supported for dtype {sca_ty}") + + zero = full([], 0.0, sca_ty, builder) + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = bitcast(val, i_type, builder) + i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = bitcast(val, ui_type, builder) + ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle, + and_(mask, neg, builder).handle, sem, scope), ui_val.type) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) + + +def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_min for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + # for float + # return atomic_smin(i_ptr, i_val) if val >= 0 + # return atomic_umax(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_min not supported for dtype {sca_ty}") + + zero = full([], 0.0, sca_ty, builder) + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = bitcast(val, i_type, builder) + i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = bitcast(val, ui_type, builder) + ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle, + and_(mask, neg, builder).handle, sem, scope), ui_ptr.type) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) + + +def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD + if sca_ty.is_int64(): + # Split it into low and high 32 bits and cast them to int32 + low_mask = full([], 0xFFFFFFFF, tl.int32, builder) + val_low = and_(val, low_mask, builder) + val_low_int32 = cast(val_low, tl.int32, builder) + + _32 = full([], 32, sca_ty, builder) + val_shr = lshr(val, _32, builder) + val_high = and_(val_shr, low_mask, builder) + val_high_int32 = cast(val_high, tl.int32, builder) + + # Split the pointer into two addresses for low and high parts + addr_low = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) + one_int32 = full(addr_low.shape, 1, tl.int32, builder) + addr_high = builder.create_addptr(addr_low.handle, one_int32.handle) + + # Perform atomic addition for the low 32 bits + if ptr.type.is_block(): + sum_ty = tl.block_type(tl.int32, ptr.type.get_block_shapes()) + else: + sum_ty = tl.int32 + old_value_low = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.ADD, addr_low.handle, val_low_int32.handle, mask.handle, sem, scope), + sum_ty) + + # Check for unsigned overflow in the low part and perform atomic addition for the high 32 bits + sum_low = add(old_value_low, val_low_int32, builder) + overflow = tl.tensor(builder.create_icmpULT(sum_low.handle, val_low_int32.handle), + _bool_like(sum_low)) # treat as unsigned + _1 = full([], 1, tl.int32, builder) + _0 = full([], 0, tl.int32, builder) + val_high_adjusted = add(val_high_int32, where(overflow, _1, _0, builder), builder) + old_value_high = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.ADD, addr_high, val_high_adjusted.handle, mask.handle, sem, scope), + sum_ty) + + # Combine the high and low results back into a 64-bit integer, treat low value as unisigned. + old_value_low_int64 = tl.tensor(builder.create_int_cast(old_value_low.handle, tl.int64.to_ir(builder), False), + tl.int64) + old_value_high_int64 = cast(old_value_high, tl.int64, builder) + old_value_high_shifted = shl(old_value_high_int64, _32, builder) + old_value = or_(old_value_high_shifted, old_value_low_int64, builder) + return old_value + else: + op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD + return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + +def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +# ===----------------------------------------------------------------------===// +# Linear Algebra +# ===----------------------------------------------------------------------===// + + +def _str_to_dot_input_precision(input_precision, builder): + assert input_precision.lower() in builder.options.allowed_dot_input_precisions, \ + f"input_precision must be one of {builder.options.allowed_dot_input_precisions}. Got {input_precision}" + input_precision = input_precision.upper() + if input_precision == "TF32X3": + input_precision = "TF32x3" + return getattr(ir.INPUT_PRECISION, input_precision) + + +def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int, + out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + + def assert_dtypes_valid(lhs_dtype, rhs_dtype, options): + if not options.allow_fp8e4nv: + assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv( + ), "Dot op does not support fp8e4nv on CUDA arch < 90" + if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): + return + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + else: + if lhs_dtype.is_int() or rhs_dtype.is_int(): + assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})" + assert lhs_dtype.is_int8() or lhs_dtype.is_uint8( + ), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" + elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8(): + if options.allow_fp8e4b15: + allowed_types = ['fp8e4nv', 'fp8e5', 'fp8e4b15'] + else: + allowed_types = ['fp8e4nv', 'fp8e5'] + + def _validate_dtype(dtype, allowed_types, operand_name): + if not any(getattr(dtype, f'is_{dtype_name}')() for dtype_name in allowed_types): + supported_types = ', '.join(allowed_types) + raise AssertionError(f"Only supports {supported_types}. {operand_name} ({dtype})") + + _validate_dtype(lhs_dtype, allowed_types, "First operand") + _validate_dtype(rhs_dtype, allowed_types, "Second operand") + else: + assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1( + ), f"Unsupported dtype {lhs_dtype}" + assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1( + ), f"Unsupported dtype {rhs_dtype}" + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + + assert lhs.type.is_block() and rhs.type.is_block() + assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options) + if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): + lhs = cast(lhs, tl.float16, builder) + rhs = cast(rhs, tl.float16, builder) + + if input_precision is None: + input_precision = builder.options.default_dot_input_precision + + input_precision = _str_to_dot_input_precision(input_precision, builder) + + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + assert lhs.shape[-1].value == rhs.shape[ + -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})" + assert lhs.shape[-2].value >= 16 and lhs.shape[-1].value >= 16 \ + and rhs.shape[-1].value >= 16, \ + f"All non-batch values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!" + if lhs.type.scalar.is_int(): + assert lhs.type.scalar == tl.int8, "only int8 supported!" + # TODO: This is CUDA specific, check if ROCm has the same limitation + assert lhs.shape[1].value >= 32, "small blocks not supported!" + _0 = builder.get_int32(0) + ret_scalar_ty = tl.int32 + elif out_dtype.is_bf16(): + raise ValueError( + "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") + elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): + _0 = builder.get_fp32(0) + ret_scalar_ty = tl.float32 + else: + _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0) + ret_scalar_ty = out_dtype + + M = lhs.type.shape[-2] + N = rhs.type.shape[-1] + B = lhs.type.shape[0] if lhs_rank == 3 else None + ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N]) + if acc is None: + acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + + # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 + if max_num_imprecise_acc is None: + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default + else: + max_num_imprecise_acc = 0 + + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), + ret_ty) + + +# ===----------------------------------------------------------------------===// +# Indexing +# ===----------------------------------------------------------------------===// + + +def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: + condition = cast(condition, tl.int1, builder) + if condition.type.is_block(): + condition, x = broadcast_impl_value(condition, x, builder) + x, y = broadcast_impl_value(x, y, builder) + condition, x = broadcast_impl_value(condition, x, builder) + + x, y = binary_op_type_checking_impl(x, y, builder, True, True) + if not condition.type.is_block(): + condition, _ = broadcast_impl_value(condition, x, builder) + ret_ty = x.type + return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) + + +# ===----------------------------------------------------------------------===// +# Reduction +# ===----------------------------------------------------------------------=== + + +def wrap_tensor(x, scalar_ty, ret_shape): + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = scalar_ty + return tl.tensor(x, res_ty) + + +def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]: + if axis is None: + inputs = tuple(reshape(t, [t.numel.value], can_reorder=True, builder=builder) for t in inputs) + axis = 0 + # get result shape + shape = inputs[0].type.shape + rank = len(shape) + assert axis < rank, f"reduction axis must be < inputs rank ({rank})" + ret_shape = [s for i, s in enumerate(shape) if i != axis] + assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape" + + reduce_op = builder.create_reduce([t.handle for t in inputs], axis) + region_builder_fn(reduce_op) + reduce_op.verify() + + return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs))) + + +# ===----------------------------------------------------------------------=== +# Associative Scan +# ===----------------------------------------------------------------------=== + + +def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, reverse: bool, + builder: ir.builder) -> Tuple[tl.tensor, ...]: + shape = inputs[0].type.shape + rank = len(shape) + + assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})" + + if axis < 0: + axis += rank + + for t in inputs: + assert t.type.shape == shape, "all scan inputs must have the same shape" + + scan_op = builder.create_scan([t.handle for t in inputs], axis, reverse) + region_builder_fn(scan_op) + scan_op.verify() + + return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs))) + + +# ===----------------------------------------------------------------------=== +# Histogram +# ===----------------------------------------------------------------------=== + + +def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor: + assert len(input.shape) == 1, "histogram only supports 1D input" + assert input.dtype.is_int(), "histogram only supports integer input" + return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, (num_bins, ))) + + +## + + +def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: + if max(1, len(x.shape)) != len(values): + raise ValueError("Shape of input to multiple_of does not match the length of values") + x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) + return x + + +def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_contiguous does not match the length of values") + x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context())) + return x + + +def max_constancy(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_constancy does not match the length of values") + x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context())) + return x + + +def debug_barrier(builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_barrier(), tl.void) + + +def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor: + # It makes sense visually for prefix to end in ": "; make it so. Also, + # non-empty prefixes should start with " ". + if not prefix.endswith(" ") and args: + prefix += " " + if not prefix.endswith(": ") and args: + prefix = prefix[:-1] + ": " + if len(prefix) > 2 and not prefix.startswith(" "): + prefix = " " + prefix + + new_args = [arg.handle for arg in args] + return tl.tensor(builder.create_print(prefix, hex, new_args), tl.void) + + +def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor: + cond_ty = cond.type + if not cond_ty.is_block(): + cond_ty = tl.block_type(cond_ty.scalar, (1, )) + cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty) + return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void) + + +def _convert_elem_to_ir_value(builder, elem, require_i64): + if isinstance(elem, int): + elem = tl.constexpr(elem) + if isinstance(elem, tl.constexpr): + if require_i64: + assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int64(elem.value) + else: + assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int32(elem.value) + elif isinstance(elem, tl.tensor): + assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets" + assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets" + if elem.dtype != tl.int64 and require_i64: + return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed()) + elif elem.dtype != tl.int32 and not require_i64: + assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \ + "add a `.to(tl.int32)` or use regular indexing for 64 bit support" + return elem.handle + assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}" + + +def _convert_to_ir_values(builder, list_like, require_i64=True): + if hasattr(list_like, "__iter__"): + return [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in list_like] + return [_convert_elem_to_ir_value(builder, list_like, require_i64)] + + +def make_block_ptr(base: tl.tensor, shape, strides, offsets, block_shape, order, builder: ir.builder) -> tl.tensor: + # Convert dynamic arguments to IR values + # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t` + shape = _convert_to_ir_values(builder, shape) + strides = _convert_to_ir_values(builder, strides) + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Check `base` type + if not base.type.is_ptr() or base.type.element_ty.is_block(): + raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)") + + # Treat `pointer_type` as `pointer_type` + if base.type.element_ty == tl.int1: + base = cast(base, tl.pointer_type(tl.int8, base.type.address_space), builder) + + # Check whether `block_shape` is static + if not hasattr(block_shape, "__iter__"): + block_shape = [block_shape] + block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape] + assert all(isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape), \ + "Expected a list of constant integers (`int32_t` range) in `block_shape`" + + # Check `order` + if not hasattr(order, "__iter__"): + order = [order] + order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order] + assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order" + + # Must have same length + assert all(len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]), \ + "Expected shape/strides/offsets/block_shape to have the same length" + + # Build value, the type is: + # `pointer_type>` in Python + # `tt.ptr>` in MLIR + handle = builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order) + return tl.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape))) + + +def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + # Convert dynamic offsets to IR values + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Advanced block pointer type is the same as before + return tl.tensor(builder.create_advance(base.handle, offsets), base.type) diff --git a/third_party/iluvatar/python/triton/language/standard.py b/third_party/iluvatar/python/triton/language/standard.py new file mode 100644 index 000000000..de30cf260 --- /dev/null +++ b/third_party/iluvatar/python/triton/language/standard.py @@ -0,0 +1,441 @@ +from __future__ import annotations + +from ..runtime.jit import jit +from . import core +from . import math + +# constexpr utilities (triton metaprogramming sucks) + + +def _unwrap_if_constexpr(o): + return o.value if isinstance(o, core.constexpr) else o + + +def _log2(i: core.constexpr): + log2 = 0 + n = i.value + while n > 1: + n >>= 1 + log2 += 1 + return core.constexpr(log2) + + +def _is_power_of_two(i: core.constexpr): + n = i.value + return core.constexpr((n & (n - 1)) == 0 and n != 0) + + +# ----------------------- +# Standard library +# ----------------------- + + +@core._tensor_member_fn +@jit +def cdiv(x, div): + """ + Computes the ceiling division of :code:`x` by :code:`div` + + :param x: the input number + :type x: Block + :param div: the divisor + :param div: Block + """ + return (x + div - 1) // div + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("sigmoid") +def sigmoid(x): + return 1 / (1 + math.exp(-x)) + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("softmax") +def softmax(x, ieee_rounding=False): + z = x - max(x, 0) + num = math.exp(z) + den = sum(num, 0) + return math.fdiv(num, den, ieee_rounding) + + +@core._tensor_member_fn +@jit +def ravel(x): + """ + Returns a contiguous flattened view of :code:`x`. + + :param x: the input tensor + :type x: Block + """ + return core.reshape(x, [x.numel], can_reorder=True) + + +@jit +def swizzle2d(i, j, size_i, size_j, size_g): + """ + Transforms indices of a row-major :code:`size_i * size_j` matrix into those + of one where the indices are col-major for each group of :code:`size_g` + rows. + + For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will + transform :: + + [[0 , 1 , 2 , 3 ], + [4 , 5 , 6 , 7 ], + [8 , 9 , 10, 11], + [12, 13, 14, 15]] + + into :: + + [[0, 2, 4 , 6 ], + [1, 3, 5 , 7 ], + [8, 10, 12, 14], + [9, 11, 13, 15]] + """ + # "unrolled index in array" + ij = i * size_j + j + # number of elements in `size_g` groups + # of `size_j` columns + size_gj = size_g * size_j + # index of the group in which (i,j) is + group_id = ij // size_gj + # row-index of the first element of this group + off_i = group_id * size_g + # last group may have fewer rows + size_g = core.minimum(size_i - off_i, size_g) + # new row and column indices + new_i = off_i + (ij % size_g) + new_j = (ij % size_gj) // size_g + return new_i, new_j + + +@jit +def zeros(shape, dtype): + """ + Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + return core.full(shape, 0, dtype) + + +@jit +def zeros_like(input): + """ + Creates a tensor of zeros with the same shape and type as a given tensor. + """ + return zeros(input.shape, input.dtype) + + +# max and argmax + + +@jit +def _argmax_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + gt = value1 > value2 or tie + v_ret = core.where(gt, value1, value2) + i_ret = core.where(gt, index1, index2) + return v_ret, i_ret + + +@jit +def _argmax_combine_tie_break_left(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, True) + + +@jit +def _argmax_combine_tie_break_fast(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_max(a, b): + return core.maximum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32): + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left") +def argmax(input, axis, tie_break_left=True, keep_dims=False): + (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +# min and argmin + + +@jit +def _argmin_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + lt = value1 < value2 or tie + value_ret = core.where(lt, value1, value2) + index_ret = core.where(lt, index1, index2) + return value_ret, index_ret + + +@jit +def _argmin_combine_tie_break_left(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, True) + + +@jit +def _argmin_combine_tie_break_fast(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_min(a, b): + return core.minimum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < 32: + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left") +def argmin(input, axis, tie_break_left=True, keep_dims=False): + _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +@jit +def _sum_combine(a, b): + return a + b + + +# sum + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("sum") +def sum(input, axis=None, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims) + + +@jit +def _xor_combine(a, b): + return a ^ b + + +# xor sum + + +@core._tensor_member_fn +@core.builtin +@core._add_reduction_docstr("xor sum") +def xor_sum(input, axis=None, keep_dims=False, _builder=None, _generator=None): + scalar_ty = input.type.scalar + if not scalar_ty.is_int(): + raise ValueError("xor_sum only supported for integers") + + input = core._promote_bfloat16_to_float32(input, _builder=_builder) + return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims, _builder=_builder, _generator=_generator) + + +# cumsum + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumsum") +def cumsum(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _sum_combine, reverse) + + +# cumprod + + +@jit +def _prod_combine(a, b): + return a * b + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumprod") +def cumprod(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _prod_combine, reverse) + + +# sort + + +@jit +def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr): + n_outer: core.constexpr = x.numel >> n_dims + shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)] + y = core.reshape(x, shape) + # slice left/right with 'stride' 2**(n_dims - i - 1) + mask = core.arange(0, 2)[None, :, None] + left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape) + right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape) + left = core.reshape(left, x.shape) + right = core.reshape(right, x.shape) + # actual compare-and-swap + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + ileft = left.to(idtype, bitcast=True) + iright = right.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + ret = ix ^ core.where((left > right) ^ flip, ileft ^ iright, zeros_like(ix)) + return ret.to(x.dtype, bitcast=True) + + +@jit +def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr): + ''' + order_type 0 == ascending + order_type 1 == descending + order_type 2 == alternating + ''' + n_outer: core.constexpr = x.numel >> n_dims + core.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage] + flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape) + else: + flip = order + # perform `stage` rounds of `compare-and-swap` + for i in core.static_range(stage): + x = _compare_and_swap(x, flip, i + (n_dims - stage), n_dims) + return x + + +@core._tensor_member_fn +@jit +def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0): + # handle default dimension or check that it is the most minor dim + _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim + core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") + # iteratively run bitonic merge-sort steps + n_dims: core.constexpr = _log2(x.shape[_dim]) + for i in core.static_range(1, n_dims + 1): + x = _bitonic_merge(x, i, 2 if i < n_dims else descending, n_dims) + return x + + +# flip + + +def _get_flip_dim(dim, shape): + dim = _unwrap_if_constexpr(dim) + shape = _unwrap_if_constexpr(shape) + if dim is None: + dim = len(shape) - 1 + assert dim == len(shape) - 1, "Currently only support flipping the last dimension" + return core.constexpr(dim) + + +@core._tensor_member_fn +@jit +def flip(x, dim=None): + """ + Flips a tensor `x` along the dimension `dim`. + + :param x: the first input tensor + :type x: Block + :param dim: the dimension to flip along (currently only final dimension supported) + :type dim: int + """ + core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)])) + core.static_assert(_is_power_of_two(x.numel)) + # # reshape the tensor to have all dimensions be 2. + # # TODO: We shouldn't have to change the dimensions not sorted. + steps: core.constexpr = _log2(x.numel) + start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)]) + y = core.reshape(x, [2] * steps) + y = core.expand_dims(y, start) + flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2)) + for i in core.static_range(start, steps): + flip2 = flip + for j in core.static_range(0, steps + 1): + if j != i and j != i + 1: + flip2 = core.expand_dims(flip2, j) + y = sum(y * flip2, i + 1, keep_dims=True) + x = core.reshape(y, x.shape) + return x + + +@jit +def interleave(a, b): + """ + Interleaves the values of two tensors along their last dimension. + + The two tensors must have the same shape. + + Equivalent to `tl.join(a, b).reshape(a.shape[-1:] + [2 * a.shape[-1]])` + """ + c = core.join(a, b) + + assert isinstance(c.shape, list) + if len(c.shape) == 1: + # We must have interleaved two scalars. + return c + else: + # This `else` is necessary because Triton's AST parser doesn't + # understand that if we take the `if` above we definitely don't run this + # `else`. + return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]]) diff --git a/third_party/iluvatar/python/triton/ops/__init__.py b/third_party/iluvatar/python/triton/ops/__init__.py new file mode 100644 index 000000000..4ab86ab2a --- /dev/null +++ b/third_party/iluvatar/python/triton/ops/__init__.py @@ -0,0 +1,11 @@ +# from .conv import _conv, conv +from . import blocksparse +from .cross_entropy import _cross_entropy, cross_entropy +from .flash_attention import attention +from .matmul import _matmul, get_higher_dtype, matmul +from .bmm_matmul import _bmm, bmm + +__all__ = [ + "blocksparse", "_cross_entropy", "cross_entropy", "_matmul", "matmul", "_bmm", "bmm", "attention", + "get_higher_dtype" +] diff --git a/python/triton/ops/blocksparse/__init__.py b/third_party/iluvatar/python/triton/ops/blocksparse/__init__.py similarity index 100% rename from python/triton/ops/blocksparse/__init__.py rename to third_party/iluvatar/python/triton/ops/blocksparse/__init__.py diff --git a/python/triton/ops/blocksparse/matmul.py b/third_party/iluvatar/python/triton/ops/blocksparse/matmul.py similarity index 100% rename from python/triton/ops/blocksparse/matmul.py rename to third_party/iluvatar/python/triton/ops/blocksparse/matmul.py diff --git a/python/triton/ops/blocksparse/softmax.py b/third_party/iluvatar/python/triton/ops/blocksparse/softmax.py similarity index 100% rename from python/triton/ops/blocksparse/softmax.py rename to third_party/iluvatar/python/triton/ops/blocksparse/softmax.py diff --git a/third_party/iluvatar/python/triton/ops/bmm_matmul.py b/third_party/iluvatar/python/triton/ops/bmm_matmul.py new file mode 100644 index 000000000..c47b6893d --- /dev/null +++ b/third_party/iluvatar/python/triton/ops/bmm_matmul.py @@ -0,0 +1,178 @@ +# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Licensed under the MIT License + +import torch + +import triton +import triton.language as tl +from .matmul_perf_model import early_config_prune, estimate_matmul_time + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + for num_stages in [1]: + # TODO support block size 16 for MFMA dot op + for block_m in [16, 32] if torch.version.hip is None and not hasattr(torch, "corex") else [32, 64]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 4 if block_n <= 64 else 8 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + #for split_k in [2, 4, 8, 16]: + # configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + # num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + +def get_configs_compute_bound(): + configs = [] + for block_m in [64, 128, 256]: + for block_n in [64, 128, 256]: + for block_k in [32, 64, 128]: + num_warps = 8 if block_n <= 64 else 16 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=1, num_warps=num_warps)) + return configs + + +@triton.autotune( + configs=[] + get_configs_compute_bound() + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={'early_config_prune': early_config_prune, 'perf_model': estimate_matmul_time, 'top_k': 10}, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % args['BLOCK_K'] == 0, +}) +@triton.jit +def _bmm_kernel( + A, + B, + C, + M, + N, + K, + stride_aq, + stride_am, + stride_ak, + stride_bq, + stride_bk, + stride_bn, + stride_cq, + stride_cm, + stride_cn, + dot_out_dtype: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, +): + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + + idx_q = tl.program_id(1) # batch dimension for BMM + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q * stride_aq) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q * stride_bq) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_q = tl.program_id(1) # batch dimension for BMM + idx_m = rm[:, None] + idx_n = rn[None, :] + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn + idx_q * stride_cq) + mask = (idx_m < M) & (idx_n < N) + # handles write-back with reduction-splitting + tl.store(C, acc, mask=mask) + + +class _bmm(torch.autograd.Function): + kernel = _bmm_kernel + + _locks = {} + + @staticmethod + def _call(a, b, dot_out_dtype): + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + + #only MR support Trans layout + if hasattr(torch, "corex"): + capability = torch.cuda.get_device_capability(device) + capability = capability[0] * 10 + capability[1] + if (capability < 71): + if a.stride(0) >= 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) >= 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[0] == b.shape[0], "incompatible dimensions" + assert a.shape[2] == b.shape[1], "incompatible dimensions" + B, M, K = a.shape + _, _, N = b.shape + # allocates output + c = torch.empty((B, M, N), device=device, dtype=a.dtype) + if dot_out_dtype is None: + if a.dtype in [torch.float16, torch.float32, torch.bfloat16]: + dot_out_dtype = tl.float32 + else: + dot_out_dtype = tl.int32 + else: + assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype" + if dot_out_dtype == torch.float16: + dot_out_dtype = tl.float16 + elif dot_out_dtype in [torch.float32, torch.bfloat16]: + dot_out_dtype = tl.float32 + else: + dot_out_dtype = tl.int32 + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), B, 1) + _bmm_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), a.stride(2), b.stride(0), b.stride(1), + b.stride(2), c.stride(0), c.stride(1), c.stride(2), dot_out_dtype=dot_out_dtype, GROUP_M=8) + return c + + @staticmethod + def forward(ctx, a, b, dot_out_dtype=None): + return _bmm._call(a, b, dot_out_dtype=dot_out_dtype) + + +bmm = _bmm.apply diff --git a/python/triton/ops/cross_entropy.py b/third_party/iluvatar/python/triton/ops/cross_entropy.py similarity index 100% rename from python/triton/ops/cross_entropy.py rename to third_party/iluvatar/python/triton/ops/cross_entropy.py diff --git a/third_party/iluvatar/python/triton/ops/flash_attention.py b/third_party/iluvatar/python/triton/ops/flash_attention.py new file mode 100644 index 000000000..b693e17ac --- /dev/null +++ b/third_party/iluvatar/python/triton/ops/flash_attention.py @@ -0,0 +1,472 @@ +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) + +Sequence Parallel implementation inspired by HazyResearch +(see https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py) +""" + +import torch +import triton + +from .. import cdiv, jit +from .. import language as tl +from triton.runtime.build import is_corex + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@jit +def _fwd_kernel(Q, K, V, sm_scale, # + L, # + Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, # + Z_H_N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + IS_CAUSAL: tl.constexpr # + ): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + vk_offset = qvk_offset // stride_qm + + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(BLOCK_DMODEL, Z_H_N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, vk_offset), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(vk_offset, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # credits to: Adam P. Goucher (https://github.com/apgoucher): + # scale sm_scale by 1/log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + + offs_k = tl.arange(0, BLOCK_DMODEL) + Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + q = tl.load(Q_ptrs) + + q = (q * qk_scale).to(K.dtype.element_ty) + lo = 0 + hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_CAUSAL: + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk += tl.dot(q, k) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p.to(V.dtype.element_ty), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # write back l and m + acc = acc / l_i[:, None] + l_ptrs = L + off_hz * N_CTX + offs_m + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(vk_offset + start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + tl.store(O_block_ptr, acc.to(K.dtype.element_ty)) + + +@jit +def _bwd_preprocess( + Out, + DO, + Delta, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + # compute + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_m, delta) + + +@jit +def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, # + Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + SEQUENCE_PARALLEL: tl.constexpr, # + CAUSAL: tl.constexpr, # + MMA_V3: tl.constexpr # + ): + if CAUSAL: + lo = start_n * BLOCK_M + else: + lo = 0 + + Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm + DQ_offset = off_z * stride_qz + off_h * stride_qh + K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn + V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn + if SEQUENCE_PARALLEL: + DQ_offset += stride_dqa * start_n + DQ_offset = DQ_offset // stride_qm + + Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0)) + K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0)) + V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0)) + DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0)) + DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0)) + DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0)) + DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0)) + + # initialize row/col offsets + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + l_ptrs = L + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(Q_block_ptr) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + if CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf")) + else: + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + qk *= qk_scale + l_i = tl.load(l_ptrs + offs_m_curr) + p = tl.math.exp2(qk - l_i[:, None]) + # compute dv + do = tl.load(DO_block_ptr) + dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp = tl.dot(do, tl.trans(v)) + # compute ds = p * (dp - delta[:, None]) + ds = (p * (dp - Di[:, None]) * sm_scale).to(Q.dtype.element_ty) + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds), q) + # compute dq + if not SEQUENCE_PARALLEL: + dq = tl.load(DQ_block_ptr) + dq += tl.dot(ds, k) + tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty)) + elif SEQUENCE_PARALLEL: + dq = tl.dot(ds, k) + tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty)) + + # increment pointers + DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0)) + Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0)) + DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0)) + # write-back + tl.store(DV_block_ptr, dv.to(V.dtype.element_ty)) + tl.store(DK_block_ptr, dk.to(K.dtype.element_ty)) + + +@jit +def _bwd_kernel(Q, K, V, sm_scale, # + Out, DO, # + DQ, DK, DV, # + L, # + D, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + Z_H_N_CTX, # + SQ_Z_H_N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + SEQUENCE_PARALLEL: tl.constexpr, # + CAUSAL: tl.constexpr, # + MMA_V3: tl.constexpr # + ): + qk_scale = sm_scale * 1.44269504 + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + + Q_block_ptr = tl.make_block_ptr( + base=Q, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + DO_block_ptr = tl.make_block_ptr( + base=DO, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + if SEQUENCE_PARALLEL: + DQ_block_ptr = tl.make_block_ptr( + base=DQ, + shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + else: + DQ_block_ptr = tl.make_block_ptr( + base=DQ, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + + DK_block_ptr = tl.make_block_ptr( + base=DK, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + DV_block_ptr = tl.make_block_ptr( + base=DV, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + + num_block_n = tl.cdiv(N_CTX, BLOCK_N) + if not SEQUENCE_PARALLEL: + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block_n, # + BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, # + BLOCK_N=BLOCK_N, # + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, # + CAUSAL=CAUSAL, # + MMA_V3=MMA_V3 # + ) + else: + start_n = tl.program_id(1) + _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block_n, # + BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, # + BLOCK_N=BLOCK_N, # + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, # + CAUSAL=CAUSAL, # + MMA_V3=MMA_V3 # + ) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): + # only support for Ampere now + capability = torch.cuda.get_device_capability() + if is_corex(): + BLOCK_M = 64 + BLOCK_N = 64 + num_stages = 1 + else: + if capability[0] < 8: + raise RuntimeError("Flash attention currently only supported for compute capability >= 80") + BLOCK_M = 128 + BLOCK_N = 64 + num_stages = 4 + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + _fwd_kernel[grid]( + q, k, v, sm_scale, # + L, # + o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + q.shape[0] * q.shape[1] * q.shape[2], # + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, # + IS_CAUSAL=causal, # + num_warps=num_warps, # + num_stages=num_stages # + ) + + ctx.save_for_backward(q, k, v, o, L) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + ctx.causal = causal + ctx.sequence_parallel = sequence_parallel + return o + + @staticmethod + def backward(ctx, do): + capability = torch.cuda.get_device_capability() + MMA_V3 = capability[0] >= 9 + # otherwise shared memory out of resource + BLOCK = 128 if not is_corex() else 64 # FIXME: currently BLOCK=128 has issues, BLOCK=64 works for common cases. + num_warps = 16 if is_corex() and ctx.BLOCK_DMODEL > 64 else 8 + + if is_hip(): + # Bwd pass runs out of shared memory on HIP with larger block size. + BLOCK = 64 + + q, k, v, o, L = ctx.saved_tensors + sequence_parallel = ctx.sequence_parallel + seq_len_kv = k.shape[2] + do = do.contiguous() + if sequence_parallel: + replicas = cdiv(seq_len_kv, BLOCK) + new_dq_shape = (replicas, ) + q.shape + dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype) + else: + dq = torch.zeros_like(q, dtype=q.dtype) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + delta = torch.empty_like(L) + _bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )]( + o, + do, + delta, + BLOCK_M=BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + _bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)]( + q, k, v, ctx.sm_scale, # + o, do, # + dq, dk, dv, # + L, # + delta, # + o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + q.shape[0] * q.shape[1] * q.shape[2], # + cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, # + BLOCK_DMODEL=ctx.BLOCK_DMODEL, # + SEQUENCE_PARALLEL=sequence_parallel, # + CAUSAL=ctx.causal, # + MMA_V3=MMA_V3, # + num_warps=num_warps, # + num_stages=1 # + ) + + if len(dq.shape) == 5: + dq = dq.sum(dim=0) + return dq, dk, dv, None, None, None + + +attention = _attention.apply diff --git a/third_party/iluvatar/python/triton/ops/matmul.py b/third_party/iluvatar/python/triton/ops/matmul.py new file mode 100644 index 000000000..d26fbeaaf --- /dev/null +++ b/third_party/iluvatar/python/triton/ops/matmul.py @@ -0,0 +1,244 @@ +import torch + +from .. import Config, autotune, cdiv, heuristics, jit +from .. import language as tl +from .matmul_perf_model import early_config_prune, estimate_matmul_time + +_ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32] + + +def upcast_if_fp8(a): + if "fp8" in str(a): + return torch.float16 + return a + + +def get_higher_dtype(a, b): + a = upcast_if_fp8(a) + b = upcast_if_fp8(b) + if a is b: + return a + + assert a in _ordered_datatypes + assert b in _ordered_datatypes + + for d in _ordered_datatypes: + if a is d: + return b + if b is d: + return a + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + if hasattr(torch, "corex"): + return configs + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append( + Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + +def get_configs_compute_bound(): + configs = [] + if hasattr(torch, "corex"): + for block_m in [32, 64, 128, 256]: + for block_n in [32, 64, 128, 256]: + for block_k in [32, 64, 128, 256]: + # for num_stages in [1, 2]: + for num_stages in [1, 2]: + num_warps = 16 if block_m >= 128 or block_n >= 128 or block_k >= 128 else 8 + configs.append( + Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + return configs + + +def get_nv_configs(): + configs = [] + if hasattr(torch, "corex"): + return configs + configs = [ + # basic configs for compute-bound matmuls + Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + return configs + + +@autotune( + configs=get_nv_configs() + get_configs_io_bound() + get_configs_compute_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10, + }, +) +@heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@jit +def _kernel(A, B, C, M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + acc_dtype: tl.constexpr, # + input_precision: tl.constexpr, # + fp8_fast_accum: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr # + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + if AB_DTYPE is not None: + a = a.to(AB_DTYPE) + b = b.to(AB_DTYPE) + if fp8_fast_accum: + acc = tl.dot(a, b, acc, out_dtype=acc_dtype, input_precision=input_precision) + else: + acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +class _matmul(torch.autograd.Function): + kernel = _kernel + + _locks = {} + + @staticmethod + def _call(a, b, acc_dtype, input_precision, fp8_fast_accum, output_dtype): + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + # common type between a and b + ab_dtype = get_higher_dtype(a.dtype, b.dtype) + + # allocates output + if (output_dtype is None): + output_dtype = ab_dtype + + c = torch.empty((M, N), device=device, dtype=output_dtype) + + # Allowed types for acc_type given the types of a and b. + supported_acc_dtypes = { + torch.float16: (torch.float32, torch.float16), torch.bfloat16: (torch.float32, torch.bfloat16), + torch.float32: (torch.float32, ), torch.int8: (torch.int32, ) + } + + if acc_dtype is None: + acc_dtype = supported_acc_dtypes[ab_dtype][0] + else: + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" + assert acc_dtype in supported_acc_dtypes[a.dtype], "acc_dtype not compatible with the type of a" + assert acc_dtype in supported_acc_dtypes[b.dtype], "acc_dtype not compatible with the type of b" + + def to_tl_type(ty): + return getattr(tl, str(ty).split(".")[-1]) + + acc_dtype = to_tl_type(acc_dtype) + ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) + + # Tensor cores support input with mixed float8 types. + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]: + ab_dtype = None + # launch kernel + grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _kernel[grid]( + a, b, c, M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + acc_dtype=acc_dtype, # + input_precision=input_precision, # + fp8_fast_accum=fp8_fast_accum, # + GROUP_M=8, AB_DTYPE=ab_dtype) + return c + + @staticmethod + def forward(ctx, a, b, acc_dtype=None, input_precision=None, fp8_fast_accum=True, output_dtype=None): + return _matmul._call(a, b, acc_dtype=acc_dtype, input_precision=input_precision, fp8_fast_accum=fp8_fast_accum, + output_dtype=output_dtype) + + +matmul = _matmul.apply diff --git a/third_party/iluvatar/python/triton/ops/matmul_perf_model.py b/third_party/iluvatar/python/triton/ops/matmul_perf_model.py new file mode 100644 index 000000000..56c055a05 --- /dev/null +++ b/third_party/iluvatar/python/triton/ops/matmul_perf_model.py @@ -0,0 +1,180 @@ +import functools +import heapq + +import torch + +from .. import cdiv +from ..runtime import driver +from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi) + + +@functools.lru_cache() +def get_clock_rate_in_khz(): + try: + return nvsmi(['clocks.max.sm'])[0] * 1e3 + except FileNotFoundError: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + return pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM) * 1e3 + + +def get_tensorcore_tflops(device, num_ctas, num_warps, dtype): + ''' return compute throughput in TOPS ''' + total_warps = num_ctas * min(num_warps, 4) + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs + tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops( + dtype, get_clock_rate_in_khz(), device) + return tflops + + +def get_simd_tflops(device, num_ctas, num_warps, dtype): + ''' return compute throughput in TOPS ''' + total_warps = num_ctas * min(num_warps, 4) + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs + tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device) + return tflops + + +def get_tflops(device, num_ctas, num_warps, dtype): + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8 and dtype == torch.float32: + return get_simd_tflops(device, num_ctas, num_warps, dtype) + return get_tensorcore_tflops(device, num_ctas, num_warps, dtype) + + +def estimate_matmul_time( + # backend, device, + num_warps, num_stages, # + A, B, C, # + M, N, K, # + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, # + debug=False, **kwargs # +): + ''' return estimated running time in ms + = max(compute, loading) + store ''' + device = torch.cuda.current_device() + dtype = A.dtype + dtsize = A.element_size() + + if (K % (BLOCK_K * SPLIT_K) != 0): + return float('inf') + num_cta_m = cdiv(M, BLOCK_M) + num_cta_n = cdiv(N, BLOCK_N) + num_cta_k = SPLIT_K + num_ctas = num_cta_m * num_cta_n * num_cta_k + + # If the input is smaller than the block size + M, N = max(M, BLOCK_M), max(N, BLOCK_N) + + # time to compute + total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS + tput = get_tflops(device, num_ctas, num_warps, dtype) + compute_ms = total_ops / tput + + # time to load data + num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"] + active_cta_ratio = min(1, num_ctas / num_sm) + active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate + active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5% + dram_bw = get_dram_gbps(device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s + l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) + # assume 80% of (following) loads are in L2 cache + load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1)) + load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1) + load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1)) + load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1) + # total + total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB + total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) + # loading time in ms + load_ms = total_dram / dram_bw + total_l2 / l2_bw + + # estimate storing time + store_bw = dram_bw * 0.6 # :o + store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB + if SPLIT_K == 1: + store_ms = store_c_dram / store_bw + else: + reduce_bw = store_bw + store_ms = store_c_dram / reduce_bw + # c.zero_() + zero_ms = M * N * 2 / (1024 * 1024) / store_bw + store_ms += zero_ms + + # total_time_ms = max(compute_ms, load_ms) + store_ms + total_time_ms = compute_ms + load_ms + store_ms + if debug: + print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, ' + f'loading time: {load_ms}ms, store time: {store_ms}ms, ' + f'Activate CTAs: {active_cta_ratio*100}%') + return total_time_ms + + +def early_config_prune(configs, named_args, **kwargs): + device = torch.cuda.current_device() + capability = torch.cuda.get_device_capability() + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + dtsize = named_args['A'].element_size() + dtype = named_args['A'].dtype + + # 1. make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages + + max_shared_memory = driver.active.utils.get_device_properties(device)["max_shared_mem"] + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory <= max_shared_memory: + pruned_configs.append(config) + configs = pruned_configs + + # Some dtypes do not allow atomic_add + if dtype not in [torch.float16, torch.float32]: + configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1] + + # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) + configs_map = {} + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages + + key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) + if key in configs_map: + configs_map[key].append((config, num_stages)) + else: + configs_map[key] = [(config, num_stages)] + + pruned_configs = [] + for k, v in configs_map.items(): + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k + if capability[0] >= 8 and not hasattr(torch, "corex"): + # compute cycles (only works for ampere GPUs) + mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16) + mma_cycles = mmas / min(4, num_warps) * 8 + + ldgsts_latency = 300 # Does this matter? + optimal_num_stages = ldgsts_latency / mma_cycles + + # nearest stages, prefer large #stages + nearest = heapq.nsmallest( + 2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) + + for n in nearest: + pruned_configs.append(n[0]) + else: # Volta & Turing only supports num_stages <= 2 + if hasattr(torch, "corex"): + for stage in range(len(v)): + random_config = v[stage][0] + random_config.num_stages = v[stage][1] + pruned_configs.append(random_config) + else: + random_config = v[0][0] + random_config.num_stages = 2 + pruned_configs.append(random_config) + return pruned_configs diff --git a/third_party/iluvatar/python/triton/runtime/__init__.py b/third_party/iluvatar/python/triton/runtime/__init__.py new file mode 100644 index 000000000..0b3979d28 --- /dev/null +++ b/third_party/iluvatar/python/triton/runtime/__init__.py @@ -0,0 +1,23 @@ +from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics) +from .cache import RedisRemoteCacheBackend, RemoteCacheBackend +from .driver import driver +from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret +from .errors import OutOfResources, InterpreterError + +__all__ = [ + "autotune", + "Autotuner", + "Config", + "driver", + "Heuristics", + "heuristics", + "InterpreterError", + "JITFunction", + "KernelInterface", + "MockTensor", + "OutOfResources", + "RedisRemoteCacheBackend", + "reinterpret", + "RemoteCacheBackend", + "TensorWrapper", +] diff --git a/third_party/iluvatar/python/triton/runtime/autotuner.py b/third_party/iluvatar/python/triton/runtime/autotuner.py new file mode 100644 index 000000000..f6ce2bdc9 --- /dev/null +++ b/third_party/iluvatar/python/triton/runtime/autotuner.py @@ -0,0 +1,455 @@ +from __future__ import annotations + +import builtins +import os +import time +import inspect +from typing import Dict + +import json +import os +import hashlib + +from ..testing import do_bench, do_bench_cudagraph +from .jit import KernelInterface +from .errors import OutOfResources +from .cache import default_cache_dir + + +def build_best_config_hash(args_names, key): + cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir()) + hasher = hashlib.sha256() + hasher.update(f"{'_'.join(args_names) + str(key)}\n".encode()) + cfg_hash = hasher.hexdigest() + cfg_hash_dir = os.path.join(cache_dir, cfg_hash) + cfg_hash_file = os.path.splitext(cfg_hash)[0] + ".best_config" + cfg_hash_file = os.path.join(cfg_hash_dir, cfg_hash_file) + return cfg_hash_dir, cfg_hash_file + + +def load_best_config(args_names, key): + _, cfg_hash_file = build_best_config_hash(args_names, key) + if os.path.exists(cfg_hash_file): + with open(cfg_hash_file) as fd: + best_config = json.loads(fd.read()) + num_warps = best_config.pop('num_warps') if 'num_warps' in best_config else 4 + num_stages = best_config.pop('num_stages') if 'num_stages' in best_config else 1 + return best_config, num_warps, num_stages + return None + + +def save_best_config(cfg, args_names, key): + cfg_hash_dir, cfg_hash_file = build_best_config_hash(args_names, key) + if os.path.exists(cfg_hash_dir): + return + os.makedirs(cfg_hash_dir, exist_ok=True) + with open(cfg_hash_file, "w") as fd: + fd.write(json.dumps({ + **cfg.kwargs, + "num_warps": cfg.num_warps, + "num_stages": cfg.num_stages, + })) + + +class Autotuner(KernelInterface): + + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=None, + post_hook=None, + prune_configs_by: Dict = None, + warmup=25, + rep=100, + use_cuda_graph=False, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. + """ + if not configs: + self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.cache = {} + self.arg_names = arg_names + + # Reset to zero or restore values + self.reset_idx = [] + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + self.restore_idx = [] + if restore_value is not None: + self.restore_idx = [arg_names.index(k) for k in restore_value] + + # Hook to reset or restore for required tensors + self.pre_hook = lambda args, reset_only=False: 0 + self.post_hook = lambda args, exception: 0 + if pre_hook: + self.pre_hook = pre_hook + elif (len(self.reset_idx) > 0 or len(self.restore_idx) > 0): + + def _pre_hook(args, reset_only=False): + for i in self.reset_idx: + args[i].zero_() + if not reset_only: + self.restore_copies = [args[i].clone() for i in self.restore_idx] + + self.pre_hook = _pre_hook + + if post_hook: + self.post_hook = post_hook + elif len(self.restore_idx) > 0: + + def _post_hook(args, exception): + for i, j in enumerate(self.restore_idx): + args[j].copy_(self.restore_copies[i]) + self.restore_copies = [] + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune) + + self.fn = fn + self.base_fn = fn + while not inspect.isfunction(self.base_fn): + self.base_fn = self.base_fn.fn + self.num_warmups = warmup + self.num_reps = rep + import torch + self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available() + # cache_fn_map fmt: {"fn_cache_key: [hash_cache_file_0, hash_cache_file_1, ...], [so_path_0, so_path_1, ...]]"} + self.cache_fn_map = dict() + + def _bench(self, *args, config, **meta): + from ..compiler.errors import CompileTimeAssertionFailure + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(args) + try: + self.fn.run( + *args, + **current, + ) + except Exception as e: + try: + self.post_hook(args, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + + self.post_hook(args, exception=None) + + try: + if self.use_cuda_graph: + import torch + with torch.cuda.stream(torch.cuda.Stream()): + bench_res = do_bench_cudagraph(kernel_call, rep=self.num_reps, return_mode="median") + return bench_res + bench_results = do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8)) + except (OutOfResources, CompileTimeAssertionFailure): + bench_results = float("inf") if self.use_cuda_graph else [float("inf"), float("inf"), float("inf")] + + cache_key = str(self.get_jit_func().cache_key) + check_key = self.cache_fn_map.get(str(cache_key), None) + if not check_key: + self.cache_fn_map.setdefault(cache_key, [[], []]) + hash_cache_file = str(self.get_jit_func().hash_cache_file) + so_path = '' + if self.get_jit_func().so_path: + so_path = self.get_jit_func().so_path.split('/')[-2] + self.cache_fn_map[cache_key][0].append(hash_cache_file) + self.cache_fn_map[cache_key][1].append(so_path) + return bench_results + + def get_jit_func(self): + if hasattr(self.fn, "cache_key"): + # for autotune + jit + return self.fn + elif hasattr(self.fn.fn, "cache_key"): + # for autotune + heuristics + jit + return self.fn.fn + else: + msg = f'Current {self.fn} or {self.fn.fn} has no attribute cache_key.' + raise RuntimeError(msg) + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + used_cached_result = True + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = [] + for name in self.arg_names: + if name in all_args: + _args.append(all_args[name]) + key = [_args[i] for i in self.key_idx] + divisibility = 16 + for arg in args: + if hasattr(arg, "data_ptr"): + key.append(arg.dtype) + key.append(arg.data_ptr() % divisibility == 0) + elif isinstance(arg, int): + key.append(arg) + key = tuple(key) + if key not in self.cache: + load_config = load_best_config(self.arg_names, key) + if load_config: + best_config, num_warps, num_stages = load_config + config = Config(best_config, num_warps, num_stages) + self.cache[key] = config + self.pre_hook(args, reset_only=True) + else: + # prune configs + used_cached_result = False + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + list_keys = list(timings.keys()) + best_key_index = list_keys.index(builtins.min(timings, key=timings.get)) + save_best_config(self.cache[key], self.arg_names, key) + self.pre_hook(args, reset_only=True) + self.configs_timings = timings + cache_key = str(self.get_jit_func().cache_key) + check_key = self.cache_fn_map.get(cache_key, None) + if check_key: + best_cache_file = self.cache_fn_map[cache_key][0][best_key_index] + best_so_path = self.cache_fn_map[cache_key][1][best_key_index] + ck_list = [best_cache_file, best_so_path] + for i in range(len(ck_list)): + for tmp_key in check_key[i]: + if ck_list[i] != tmp_key: + del_cache_file = os.path.join( + os.environ.get('TRITON_CACHE_DIR', default_cache_dir()), tmp_key) + import shutil + shutil.rmtree(del_cache_file, ignore_errors=True) + self.cache_fn_map.clear() + + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: + print(f"Triton autotuning for function {self.base_fn.__name__} finished after " + f"{self.bench_time:.2f}s; best config selected: {self.best_config};") + if config.pre_hook is not None: + config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()}) + ret = self.fn.run( + *args, + **kwargs, + **config.all_kwargs(), + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.all_kwargs(), + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for config in self.prune_configs(kwargs): + ret.append(self.fn.warmup( + *args, + **kwargs, + **config.all_kwargs(), + )) + self.nargs = None + return ret + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type kwargs: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_ctas: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :type maxnreg: Optional[int] + :ivar maxnreg: maximum number of registers one thread can use. Corresponds + to ptx .maxnreg directive. Not supported on all platforms. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + """ + + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, maxnreg=None, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.maxnreg = maxnreg + self.pre_hook = pre_hook + + def all_kwargs(self): + return { + **self.kwargs, **{ + k: v + for (k, v) in ( + ("num_warps", self.num_warps), + ("num_ctas", self.num_ctas), + ("num_stages", self.num_stages), + ("maxnreg", self.maxnreg), + ) if v is not None + } + } + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + res.append(f"maxnreg: {self.maxnreg}") + return ", ".join(res) + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, + warmup=25, rep=100, use_cuda_graph=False): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + + If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to + :code:`"1"`, Triton will print a message to stdout after autotuning each + kernel, including the time spent autotuning and the best configuration. + + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param pre_hook: a function that will be called before the kernel is called. + This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. + 'args': a list of arguments passed to the kernel. + 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. + :type pre_hook: lambda args, reset_only + :param post_hook: a function that will be called after the kernel is called. + This overrides the default post_hook used for 'restore_value'. + 'args': a list of arguments passed to the kernel. + 'exception': the exception raised by the kernel in case of a compilation or runtime error. + :type post_hook: lambda args, exception + :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25. + :type warmup: int + :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100. + :type rep: int + """ + + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph) + + return decorator + + +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + :type values: dict[str, Callable[[list[Any]], Any]] + """ + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/third_party/iluvatar/python/triton/runtime/build.py b/third_party/iluvatar/python/triton/runtime/build.py new file mode 100644 index 000000000..66cb2539a --- /dev/null +++ b/third_party/iluvatar/python/triton/runtime/build.py @@ -0,0 +1,86 @@ +import contextlib +import sys +import io +import sysconfig +import os +import shutil +import subprocess +import setuptools + + +def is_corex(): + import torch + return hasattr(torch, "corex") and torch.corex == True + + +@contextlib.contextmanager +def quiet(): + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = io.StringIO(), io.StringIO() + try: + yield + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + + +def _build(name, src, srcdir, library_dirs, include_dirs, libraries): + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + # TODO: support more things here. + clang = shutil.which("clang") + gcc = shutil.which("gcc") + if is_corex(): + cc = clang if clang is not None else gcc + else: + cc = gcc if gcc is not None else clang + if cc is None: + raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + include_dirs = include_dirs + [srcdir, py_include_dir] + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so] + cc_cmd += [f'-l{lib}' for lib in libraries] + cc_cmd += [f"-L{dir}" for dir in library_dirs] + cc_cmd += [f"-I{dir}" for dir in include_dirs] + ret = subprocess.check_call(cc_cmd) + if ret == 0: + return so + # fallback on setuptools + extra_compile_args = [] + # extra arguments + extra_link_args = [] + # create extension module + ext = setuptools.Extension( + name=name, + language='c', + sources=[src], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args + ['-O3'], + extra_link_args=extra_link_args, + library_dirs=library_dirs, + libraries=libraries, + ) + # build extension module + args = ['build_ext'] + args.append('--build-temp=' + srcdir) + args.append('--build-lib=' + srcdir) + args.append('-q') + args = dict( + name=name, + ext_modules=[ext], + script_args=args, + ) + with quiet(): + setuptools.setup(**args) + return so diff --git a/third_party/iluvatar/python/triton/runtime/cache.py b/third_party/iluvatar/python/triton/runtime/cache.py new file mode 100644 index 000000000..bd3c29b99 --- /dev/null +++ b/third_party/iluvatar/python/triton/runtime/cache.py @@ -0,0 +1,281 @@ +import importlib +import json +import os +import uuid +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Dict, List, Optional +import hashlib + + +def default_cache_dir(): + return os.path.join(Path.home(), ".triton", "cache") + + +def default_override_dir(): + return os.path.join(Path.home(), ".triton", "override") + + +def default_dump_dir(): + return os.path.join(Path.home(), ".triton", "dump") + + +class CacheManager(ABC): + + def __init__(self, key): + pass + + @abstractmethod + def get_file(self, filename) -> Optional[str]: + pass + + @abstractmethod + def put(self, data, filename, binary=True) -> str: + pass + + @abstractmethod + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + pass + + @abstractmethod + def put_group(self, filename: str, group: Dict[str, str]): + pass + + +class FileCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = default_dump_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") + + def _make_path(self, filename) -> str: + return os.path.join(self.cache_dir, filename) + + def has_file(self, filename) -> bool: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + return os.path.exists(self._make_path(filename)) + + def get_file(self, filename) -> Optional[str]: + if self.has_file(filename): + return self._make_path(filename) + else: + return None + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + grp_filename = f"__grp__{filename}" + if not self.has_file(grp_filename): + return None + grp_filepath = self._make_path(grp_filename) + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + # Invalid group data. + if child_paths is None: + return None + result = {} + for c, p in child_paths.items(): + if os.path.exists(p): + result[c] = p + return result + + # Note a group of pushed files as being part of a group + def put_group(self, filename: str, group: Dict[str, str]) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + grp_contents = json.dumps({"child_paths": group}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename, binary=False) + + def put(self, data, filename, binary=True) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + binary = isinstance(data, bytes) + if not binary: + data = str(data) + assert self.lock_path is not None + filepath = self._make_path(filename) + # Random ID to avoid any collisions + rnd_id = str(uuid.uuid4()) + # we use the PID in case a bunch of these around so we can see what PID made it + pid = os.getpid() + # use tempfile to be robust against program interruptions + temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}" + mode = "wb" if binary else "w" + with open(temp_path, mode) as f: + f.write(data) + # Replace is guaranteed to be atomic on POSIX systems if it succeeds + # so filepath cannot see a partial write + os.replace(temp_path, filepath) + return filepath + + +class RemoteCacheBackend: + """ + A backend implementation for accessing a remote/distributed cache. + """ + + def __init__(self, key: str): + pass + + @abstractmethod + def get(self, filenames: List[str]) -> Dict[str, bytes]: + pass + + @abstractmethod + def put(self, filename: str, data: bytes): + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend): + + def __init__(self, key): + import redis + self._key = key + self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}") + self._redis = redis.Redis( + host=os.environ.get("TRITON_REDIS_HOST", "localhost"), + port=int(os.environ.get("TRITON_REDIS_PORT", 6379)), + ) + + def _get_key(self, filename: str) -> str: + return self._key_fmt.format(key=self._key, filename=filename) + + def get(self, filenames: List[str]) -> Dict[str, str]: + results = self._redis.mget([self._get_key(f) for f in filenames]) + return {filename: result for filename, result in zip(filenames, results) if result is not None} + + def put(self, filename: str, data: bytes) -> Dict[str, bytes]: + self._redis.set(self._get_key(filename), data) + + +class RemoteCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + # Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`. + remote_cache_manager = os.environ["TRITON_REMOTE_CACHE_BACKEND"] + module_path, clz_nme = remote_cache_manager.split(":") + module = importlib.import_module(module_path) + remote_cache_cls = getattr(module, clz_nme) + self._backend = remote_cache_cls(key) + + self._override = override + self._dump = dump + + # Use a `FileCacheManager` to materialize remote cache paths locally. + self._file_cache_manager = FileCacheManager(key, override=override, dump=dump) + + def _materialize(self, filename: str, data: bytes): + # We use a backing `FileCacheManager` to provide the materialized data. + return self._file_cache_manager.put(data, filename, binary=True) + + def get_file(self, filename: str) -> Optional[str]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_file(filename) + + # We always check the remote cache backend -- even if our internal file- + # based cache has the item -- to make sure LRU accounting works as + # expected. + results = self._backend.get([filename]) + if len(results) == 0: + return None + (_, data), = results.items() + return self._materialize(filename, data) + + def put(self, data, filename: str, binary=True) -> str: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put(data, filename, binary=binary) + + if not isinstance(data, bytes): + data = str(data).encode("utf-8") + self._backend.put(filename, data) + return self._materialize(filename, data) + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_group(filename) + + grp_filename = f"__grp__{filename}" + grp_filepath = self.get_file(grp_filename) + if grp_filepath is None: + return None + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + + result = None + + # Found group data. + if child_paths is not None: + result = {} + for child_path, data in self._backend.get(child_paths).items(): + result[child_path] = self._materialize(child_path, data) + + return result + + def put_group(self, filename: str, group: Dict[str, str]): + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put_group(filename, group) + + grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename) + + +__cache_cls = FileCacheManager +__cache_cls_nme = "DEFAULT" + + +def get_cache_manager(key) -> CacheManager: + import os + + user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None) + global __cache_cls + global __cache_cls_nme + + if user_cache_manager is not None and user_cache_manager != __cache_cls_nme: + module_path, clz_nme = user_cache_manager.split(":") + module = importlib.import_module(module_path) + __cache_cls = getattr(module, clz_nme) + __cache_cls_nme = user_cache_manager + + return __cache_cls(key) + + +def get_override_manager(key) -> CacheManager: + return __cache_cls(key, override=True) + + +def get_dump_manager(key) -> CacheManager: + return __cache_cls(key, dump=True) + + +def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): + # Get unique key for the compiled code + signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} + key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}" + for kw in kwargs: + key = f"{key}-{kwargs.get(kw)}" + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + return key diff --git a/third_party/iluvatar/python/triton/runtime/driver.py b/third_party/iluvatar/python/triton/runtime/driver.py new file mode 100644 index 000000000..c3b97a764 --- /dev/null +++ b/third_party/iluvatar/python/triton/runtime/driver.py @@ -0,0 +1,60 @@ +from ..backends import backends +from ..backends import DriverBase + + +def _create_driver(): + actives = [x.driver for x in backends.values() if x.driver.is_active()] + if len(actives) != 1: + raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.") + return actives[0]() + + +class LazyProxy: + + def __init__(self, init_fn): + self._init_fn = init_fn + self._obj = None + + def _initialize_obj(self): + if self._obj is None: + self._obj = self._init_fn() + + def __getattr__(self, name): + self._initialize_obj() + return getattr(self._obj, name) + + def __setattr__(self, name, value): + if name in ["_init_fn", "_obj"]: + super().__setattr__(name, value) + else: + self._initialize_obj() + setattr(self._obj, name, value) + + def __delattr__(self, name): + self._initialize_obj() + delattr(self._obj, name) + + def __repr__(self): + if self._obj is None: + return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>" + return repr(self._obj) + + def __str__(self): + self._initialize_obj() + return str(self._obj) + + +class DriverConfig: + + def __init__(self): + self.default = LazyProxy(_create_driver) + self.active = self.default + + def set_active(self, driver: DriverBase): + self.active = driver + + def reset_active(self): + self.active = self.default + + +driver = DriverConfig() diff --git a/third_party/iluvatar/python/triton/runtime/errors.py b/third_party/iluvatar/python/triton/runtime/errors.py new file mode 100644 index 000000000..4dce91767 --- /dev/null +++ b/third_party/iluvatar/python/triton/runtime/errors.py @@ -0,0 +1,26 @@ +from ..errors import TritonError +from typing import Optional + + +class InterpreterError(TritonError): + + def __init__(self, error_message: Optional[str] = None): + self.error_message = error_message + + def __str__(self) -> str: + return self.error_message or "" + + +class OutOfResources(TritonError): + + def __init__(self, required, limit, name): + self.required = required + self.limit = limit + self.name = name + + def __str__(self) -> str: + return f"out of resource: {self.name}, Required: {self.required}, Hardware limit: {self.limit}. Reducing block sizes or `num_stages` may help." + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return (type(self), (self.required, self.limit, self.name)) diff --git a/third_party/iluvatar/python/triton/runtime/interpreter.py b/third_party/iluvatar/python/triton/runtime/interpreter.py new file mode 100644 index 000000000..a82832ecf --- /dev/null +++ b/third_party/iluvatar/python/triton/runtime/interpreter.py @@ -0,0 +1,1127 @@ +import inspect +from typing import Tuple + +import math +import numpy as np + +import triton +import triton.language as tl +from dataclasses import dataclass +from .errors import InterpreterError +from functools import partial +from .._C.libtriton import interpreter as _interpreter +from .._C.libtriton import ir as _ir + + +class TensorHandle: + + def __init__(self, data, dtype): + ''' + data: numpy array + dtype: triton type, either pointer_type or scalar_type. + we don't store block_type here because the shape information is already availale in the data field + attr: a dictionary of attributes + ''' + self.data = data + self.dtype = dtype + self.attr = {} + + def __bool__(self): + return bool(self.data.all()) + + def get_element_ty(self): + dtype = self.dtype + while hasattr(dtype, "element_ty"): + dtype = dtype.element_ty + return dtype + + def clone(self): + return TensorHandle(self.data.copy(), self.dtype) + + def set_attr(self, key, value): + self.attr[key] = value + + +class BlockPointerHandle: + + def __init__(self, base, shape, strides, offsets, tensor_shape, order): + self.base = base + self.shape = shape + self.strides = strides + self.offsets = offsets + self.tensor_shape = tensor_shape + self.order = order + + def materialize_pointers(self, boundary_check): + dtype_tt = self.base.get_element_ty() + n_bytes = dtype_tt.primitive_bitwidth // 8 + tensor_shape = self.tensor_shape + ptrs = np.broadcast_to(self.base.data, self.tensor_shape) + masks = np.ones(self.tensor_shape, dtype=bool) + for dim in range(len(tensor_shape)): + bcast_dims = [1] * len(tensor_shape) + bcast_dims[dim] = tensor_shape[dim] + off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims) + ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64) + if dim in boundary_check: + masks = np.logical_and(masks, off < self.shape[dim].data) + ptrs = TensorHandle(ptrs, self.base.dtype.scalar) + return ptrs, masks + + +@dataclass(frozen=True) +class InterpreterOptions: + extern_libs: dict = None + debug: bool = False + arch: str = None + allow_fp8e4nv: bool = True + allow_fp8e4b15: bool = True + default_dot_input_precision: str = "tf32" + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + max_num_imprecise_acc_default: int = 0 + + +def _get_signed_np_dtype(dtype): + if dtype == np.uint8: + return np.int8 + if dtype == np.uint16: + return np.int16 + if dtype == np.uint32: + return np.int32 + if dtype == np.uint64: + return np.int64 + return dtype + + +def _get_np_dtype(tt_dtype): + if isinstance(tt_dtype, tl.pointer_type): + return np.dtype(np.uint64) + np_types = { + tl.int1: np.dtype(bool), + tl.float16: np.dtype(np.float16), + tl.float32: np.dtype(np.float32), + tl.float64: np.dtype(np.float64), + tl.int8: np.dtype(np.int8), + tl.uint8: np.dtype(np.uint8), + tl.int16: np.dtype(np.int16), + tl.uint16: np.dtype(np.uint16), + tl.int32: np.dtype(np.int32), + tl.uint32: np.dtype(np.uint32), + tl.int64: np.dtype(np.int64), + tl.uint64: np.dtype(np.uint64), + # bfloat16 types are stored as uint16 + tl.bfloat16: np.dtype(np.uint16), + # float8 types are stored as uint8 + tl.float8e5: np.dtype(np.uint8), + tl.float8e5b16: np.dtype(np.uint8), + tl.float8e4nv: np.dtype(np.uint8), + tl.float8e4b8: np.dtype(np.uint8), + tl.float8e4b15: np.dtype(np.uint8), + } + if isinstance(tt_dtype, tl.block_type): + if isinstance(tt_dtype.element_ty, tl.pointer_type): + return np.dtype(np.uint64) + return np_types[tt_dtype.element_ty] + return np_types[tt_dtype] + + +def _convert_float(input, input_dtype, output_dtype, rounding_mode): + input_uint_dtype = getattr(np, f"uint{input_dtype.primitive_bitwidth}") + output_unint_dtype = getattr(np, f"uint{output_dtype.primitive_bitwidth}") + input_bin = np.frombuffer(input.tobytes(), dtype=input_uint_dtype) + sign = (input_bin >> (input_dtype.primitive_bitwidth - 1)) & 0x01 + input_exponent_width = input_dtype.primitive_bitwidth - input_dtype.fp_mantissa_width - 1 + output_exponent_width = output_dtype.primitive_bitwidth - output_dtype.fp_mantissa_width - 1 + significand = input_bin & ((1 << input_dtype.fp_mantissa_width) - 1) + bias_input = input_dtype.exponent_bias + bias_output = output_dtype.exponent_bias + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + subnormal_index = exponent == 0 + if np.any(subnormal_index): + # Credit to Phil: phil@openai.com + # subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # convert it to normal repr: ((-1.0)**sign) * (2.0**(1 + m0 - exp_bias)) * (1 + 2^(m1 - m0) + ... + 2^(mn - m0)) + bit_pos = np.zeros_like(input_bin, dtype=np.int32) + # Find the most significant bit of the mantissa in the significand + for i in range(input_dtype.fp_mantissa_width): + bit_index = ((significand >> i) & 0x01) + # pos should be >= 1 + bit_pos[bit_index == 1] = input_dtype.fp_mantissa_width - i + zero_significand_index = significand == 0 + exponent[subnormal_index] = 1 - bit_pos[subnormal_index] + # 0 significand and subnormal should be treated as 0 + exponent[zero_significand_index & subnormal_index] = bias_input - bias_output + significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & ( + (1 << input_dtype.fp_mantissa_width) - 1) + # Prevent overflow and underflow + exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1)) + exponent_output = exponent_output.astype(output_unint_dtype) + sign_output = sign.astype(output_unint_dtype) + if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth: # Downcast + significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + if rounding_mode == _ir.ROUNDING_MODE.RTNE: # Round to nearst even + # find the cut-off bit + cut_off = significand & (1 << (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width - 1)) + significand_output = significand_output + (cut_off > 0) + significand_output = significand_output.astype(output_unint_dtype) + else: # Upcast + significand_output = (significand.astype(output_unint_dtype) << + (output_dtype.fp_mantissa_width - input_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + subnormal_index = exponent_output == 0 + if np.any(subnormal_index): # underflow + # normal repr: ((-1.0)**sign) * (2.0**(exp - exp_bias_input)) * (1 + 2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # shift = (1 - exp_bias_output) - (exp - exp_bias_input) + # convert it to subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias_output)) * (2^(-shift) + 2^(m0 - shift) + 2^(m1 - shift) + ... + 2^(mn - shift)) + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + non_zero_exponent_index = exponent != 0 + # If the original exponent is not zero, we still need to shift the significand and consider the 1.0 part in mantissa + subnormal_index = subnormal_index & non_zero_exponent_index + shift = np.zeros_like(input_bin, dtype=np.int32) + shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input) + significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | ( + 1 << (output_dtype.fp_mantissa_width - shift[subnormal_index])) + output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | ( + exponent_output << output_dtype.fp_mantissa_width) | significand_output + return output.reshape(input.shape) + + +def _erf(x): + # Numpy does not support erf + return math.erf(x) + + +def _umulhi_64(a, b): + # Numpy does not support 128-bit multiplication + # So we have to implement it manually + return (int(a) * int(b)) >> 64 + + +np_erf_fp32 = np.vectorize(_erf, otypes=[np.float32]) +np_erf_fp64 = np.vectorize(_erf, otypes=[np.float64]) +np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64]) + + +class ExtraFunctions: + + @staticmethod + def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _builder): + return tl.tensor(_builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty) + + +class InterpreterBuilder: + ir_sem_to_interpreter_sem = { + _ir.MEM_SEMANTIC.ACQUIRE: _interpreter.MEM_SEMANTIC.ACQUIRE, + _ir.MEM_SEMANTIC.RELEASE: _interpreter.MEM_SEMANTIC.RELEASE, + _ir.MEM_SEMANTIC.RELAXED: _interpreter.MEM_SEMANTIC.RELAXED, + _ir.MEM_SEMANTIC.ACQUIRE_RELEASE: _interpreter.MEM_SEMANTIC.ACQUIRE_RELEASE, + } + + ir_rmw_op_to_interpreter_rmw_op = { + _ir.ATOMIC_OP.ADD: _interpreter.RMW_OP.ADD, + _ir.ATOMIC_OP.FADD: _interpreter.RMW_OP.FADD, + _ir.ATOMIC_OP.MIN: _interpreter.RMW_OP.MIN, + _ir.ATOMIC_OP.UMIN: _interpreter.RMW_OP.UMIN, + _ir.ATOMIC_OP.MAX: _interpreter.RMW_OP.MAX, + _ir.ATOMIC_OP.UMAX: _interpreter.RMW_OP.UMAX, + _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND, + _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR, + _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR, + _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG, + } + + def __init__(self) -> None: + self.arch = None + self.options = InterpreterOptions() + self.codegen_fns = {} + self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types + + def set_grid_idx(self, x, y, z): + if not x < self.grid_dim[0]: + raise ValueError("x >= grid_dim[0]") + if not y < self.grid_dim[1]: + raise ValueError("y >= grid_dim[1]") + if not z < self.grid_dim[2]: + raise ValueError("z >= grid_dim[2]") + self.grid_idx = (x, y, z) + + def set_grid_dim(self, nx, ny, nz): + self.grid_dim = (nx, ny, nz) + + # constants + + def get_half_ty(self): + return tl.float16 + + def get_bf16_ty(self): + return tl.bfloat16 + + def get_float_ty(self): + return tl.float32 + + def get_double_ty(self): + return tl.float64 + + def get_int8_ty(self): + return tl.int8 + + def get_uint8_ty(self): + return tl.uint8 + + def get_int16_ty(self): + return tl.int16 + + def get_uint16_ty(self): + return tl.uint16 + + def get_int32_ty(self): + return tl.int32 + + def get_uint32_ty(self): + return tl.uint32 + + def get_int64_ty(self): + return tl.int64 + + def get_uint64_ty(self): + return tl.uint64 + + def get_fp8e4nv_ty(self): + return tl.float8e4nv + + def get_fp8e4b15_ty(self): + return tl.float8e4b15 + + def get_fp8e4b8_ty(self): + return tl.float8e4b8 + + def get_fp8e5_ty(self): + return tl.float8e5 + + def get_fp8e5b16_ty(self): + return tl.float8e5b16 + + def get_ptr_ty(self, elt_ty, addr_space): + return tl.pointer_type(elt_ty, addr_space) + + def get_block_ty(self, dtype, shape): + return tl.block_type(dtype, shape) + + def get_int1(self, value): + return TensorHandle(np.array([value], dtype=np.bool_), tl.int1) + + def get_uint8(self, value): + return TensorHandle(np.array([value], dtype=np.uint8), tl.uint8) + + def get_int8(self, value): + return TensorHandle(np.array([value], dtype=np.int8), tl.int8) + + def get_uint16(self, value): + return TensorHandle(np.array([value], dtype=np.uint16), tl.uint16) + + def get_int16(self, value): + return TensorHandle(np.array([value], dtype=np.int16), tl.int16) + + def get_uint32(self, value): + return TensorHandle(np.array([value], dtype=np.uint32), tl.uint32) + + def get_int32(self, value): + return TensorHandle(np.array([value], dtype=np.int32), tl.int32) + + def get_uint64(self, value): + return TensorHandle(np.array([value], dtype=np.uint64), tl.uint64) + + def get_int64(self, value): + return TensorHandle(np.array([value], dtype=np.int64), tl.int64) + + def get_fp16(self, value): + return TensorHandle(np.array([value], dtype=np.float16), tl.float16) + + def get_fp32(self, value): + return TensorHandle(np.array([value], dtype=np.float32), tl.float32) + + def get_fp64(self, value): + return TensorHandle(np.array([value], dtype=np.float64), tl.float64) + + def get_null_value(self, type): + return TensorHandle(np.array([0], dtype=_get_np_dtype(type)), type) + + # programming model + def create_get_program_id(self, axis): + if self.grid_idx is None: + raise ValueError("grid_idx is None") + return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32) + + def create_get_num_programs(self, axis): + return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32) + + # memory ops + def create_load(self, ptr, _0, _1, is_volatile): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + other = None + return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile) + + def create_store(self, ptr, val, _0, _1): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + return self.create_masked_store(ptr, val, mask, None, None) + + def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if other is None: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np) + return TensorHandle(ret, dtype_tt) + + def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy): + return _interpreter.store(ptrs.data, value.data, mask.data) + + # casting ops + def cast_impl(self, src, dst_type): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + if (src_element_type == tl.bfloat16 and dst_element_type == tl.float32) or \ + (src_element_type == tl.float32 and dst_element_type == tl.bfloat16): + data = _convert_float(src.data, src_element_type, dst_element_type, None).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + else: + return TensorHandle(src.data.astype(_get_np_dtype(dst_type)), dst_type.scalar) + + create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type) + + def create_fp_to_fp(self, src, dst_type, rounding_mode): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + data = _convert_float(src.data, src_element_type, dst_element_type, rounding_mode).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + + def create_bitcast(self, src, dst_type): + return TensorHandle(src.data.view(_get_np_dtype(dst_type)), dst_type.scalar) + + # binary operators + def binary_op(self, lhs, rhs, op): + return TensorHandle(op(lhs.data, rhs.data), lhs.dtype.scalar) + + create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_sdiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + create_udiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift) + create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minimumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maximumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and) + create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor) + create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or) + + def create_idiv(self, lhs, rhs): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + return TensorHandle((lhs.data - np.fmod(lhs.data, rhs.data)) // rhs.data, lhs.dtype.scalar) + + def create_ashr(self, lhs, rhs): + # Triton's rshift operator depends on the signedness of the left operand + lhs_dtype = _get_signed_np_dtype(lhs.data.dtype) + rhs_dtype = _get_signed_np_dtype(rhs.data.dtype) + lhs.data = lhs.data.astype(lhs_dtype) + rhs.data = rhs.data.astype(rhs_dtype) + return self.binary_op(lhs, rhs, np.right_shift) + + def create_umulhi(self, lhs, rhs): + dtype = lhs.data.dtype + if dtype == np.int64 or dtype == np.uint64: + return TensorHandle(np_umulhi_u64(lhs.data, rhs.data), lhs.dtype.scalar) + else: + compute_dtype = getattr(np, f"uint{dtype.itemsize * 8 * 2}") + lhs_data = lhs.data.astype(compute_dtype) + rhs_data = rhs.data.astype(compute_dtype) + ret_data = np.multiply(lhs_data, rhs_data) >> (dtype.itemsize * 8) + return TensorHandle(ret_data.astype(dtype), lhs.dtype.scalar) + + # ternary functions + def ternary_op(self, lhs, rhs, other, op): + return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype.scalar) + + create_clampf = lambda self, arg, lo, hi, propagate_nans: self.ternary_op(arg, lo, hi, np.clip) + create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where) + + def create_fma(self, x, y, z): + return TensorHandle(x.data * y.data + z.data, z.dtype.scalar) + + # unary functions + def unary_op(self, arg, op): + return TensorHandle(op(arg.data), arg.dtype.scalar) + + def create_fabs(self, arg): + # Mask out the sign bit based on the primitive length + dtype_tt = arg.dtype + mask_bitwidth = dtype_tt.primitive_bitwidth - 1 + np_uint_dtype = getattr(np, f"uint{dtype_tt.primitive_bitwidth}") + data = arg.data.view(np_uint_dtype) + mask = (1 << mask_bitwidth) - 1 + ret = (data & mask).view(_get_np_dtype(dtype_tt)) + return TensorHandle(ret, arg.dtype.scalar) + + create_cos = lambda self, arg: self.unary_op(arg, np.cos) + create_exp = lambda self, arg: self.unary_op(arg, np.exp) + create_exp2 = lambda self, arg: self.unary_op(arg, np.exp2) + create_iabs = lambda self, arg: self.unary_op(arg, np.abs) + create_floor = lambda self, arg: self.unary_op(arg, np.floor) + create_ceil = lambda self, arg: self.unary_op(arg, np.ceil) + create_log = lambda self, arg: self.unary_op(arg, np.log) + create_log2 = lambda self, arg: self.unary_op(arg, np.log2) + create_precise_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sin = lambda self, arg: self.unary_op(arg, np.sin) + + def create_erf(self, arg): + ret = np_erf_fp32(arg.data) if arg.data.dtype == np.float32 else np_erf_fp64(arg.data) + return TensorHandle(ret, arg.dtype.scalar) + + def create_rsqrt(self, arg): + return TensorHandle(1 / np.sqrt(arg.data), arg.dtype.scalar) + + # tensor operators + create_reshape = lambda self, arg, shape, allow_reorder: TensorHandle(arg.data.reshape(shape), arg.dtype.scalar) + + def create_trans(self, arg, perm): + return TensorHandle(np.transpose(arg.data, perm), arg.dtype.scalar) + + def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc): + a_data = a.data + b_data = b.data + if (a.dtype.primitive_bitwidth == 8 and a.dtype.is_floating()) or \ + (b.dtype.primitive_bitwidth == 8 and b.dtype.is_floating()): + a_data = _convert_float(a_data, a.dtype, tl.float16, None).view(np.float16) + b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16) + return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar) + + def create_make_range(self, start, stop): + return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32) + + def create_histogram(self, data, bins): + return TensorHandle(np.histogram(data.data, bins=bins, range=(0, bins))[0], tl.int32) + + # pointer arithmetic + + def create_addptr(self, ptr, offset): + dtype_tt = ptr.get_element_ty() + element_bitwidth = dtype_tt.primitive_bitwidth + # int1's bitwidth is 1, but we need to use 8 for pointer arithmetic + element_bytewidth = max(1, element_bitwidth // 8) + return TensorHandle(ptr.data + element_bytewidth * offset.data.astype(np.uint64), ptr.dtype) + + def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, + is_volatile): + ptrs, masks = ptr.materialize_pointers(boundary_check) + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if padding_option is None: + other = None + elif padding_option == _ir.PADDING_OPTION.PAD_ZERO: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + elif padding_option == _ir.PADDING_OPTION.PAD_NAN: + other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt) + else: + raise ValueError(f"unsupported padding option {padding_option}") + return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile) + + def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): + ptrs, masks = ptr.materialize_pointers(boundary_check) + return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy) + + def create_expand_dims(self, arg, axis): + return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype.scalar) + + def create_broadcast(self, arg, shape): + return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype.scalar) + + def create_int_to_ptr(self, val, dst_ty): + return TensorHandle(val.data.astype(np.uint64), dst_ty.scalar) + + def create_ptr_to_int(self, val, dst_ty): + return TensorHandle(val.data.astype(np.uint64), dst_ty.scalar) + + def create_cat(self, lhs, rhs): + return TensorHandle(np.concatenate([lhs.data, rhs.data]), lhs.dtype.scalar) + + def create_join(self, lhs, rhs): + # Triton only supports joining two original tensors into a new one along the last axis + return TensorHandle(np.stack([lhs.data, rhs.data], axis=-1), lhs.dtype.scalar) + + def create_split(self, val): + # Triton only supports splitting the original tensor into two along the last axis + return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar)) + + def create_splat(self, arg, shape): + if isinstance(arg.dtype, tl.block_type): + return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + else: # scalar + return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + + def create_atomic_cas(self, ptr, cmp, val, sem, scope): + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_cas(ptr.data, cmp.data, val.data, sem), cmp.dtype.scalar) + + def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem, scope): + if rmwOp not in self.ir_rmw_op_to_interpreter_rmw_op: + raise ValueError(f"unsupported rmwOp {rmwOp}") + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + rmwOp = self.ir_rmw_op_to_interpreter_rmw_op[rmwOp] + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar) + + def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): + raise NotImplementedError("extern_elementwise not supported in interpreter mode") + + def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): + raise NotImplementedError("inline_asm not supported in interpreter mode") + + def create_print(self, prefix, hex, values): + # Interpreter's device_print function has a different format than Triton's device_print + msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})" + if prefix: + msg += f" {prefix}" + if hex: + np.set_printoptions(formatter={'all': lambda x: f"0x{x:02x}"}) + for value in values: + print(msg + f" {value.data}") + if hex: + np.set_printoptions(formatter=None) + + def create_assert(self, condition, message, fileName, funcName, lineNo): + # Interpreter's device_assert function has a different format than Triton's device_assert + assert condition, f"{message} in {fileName}:{funcName}:{lineNo}" + + def create_barrier(self): + # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter + pass + + def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order): + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in offsets] + return BlockPointerHandle(base, shape, strides, new_offsets, tensor_shape, order) + + def create_advance(self, ptr, offsets): + if len(ptr.offsets) != len(offsets): + raise ValueError("len(ptr.offsets) != len(offsets)") + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in ptr.offsets] + ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order) + for i in range(len(offsets)): + ret.offsets[i].data += offsets[i].data + return ret + + def get_all_ones_value(self, type): + np_type = _get_np_dtype(type) + if "int" in np_type.name: + return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar) + else: + raise TypeError(f"unsupported type {type}") + + +def _patch_attr(obj, name, member, builder): + new_member = lambda *args, member=member, **kwargs: (member(*args, ** + {k: v + for k, v in kwargs.items() + if k != "_builder"}, _builder=builder)) + setattr(obj, name, new_member) + + +def _patch_builtin(pkg, builder): + for name, member in inspect.getmembers(pkg): + if tl.core.is_builtin(member): + _patch_attr(pkg, name, member, builder) + + +def _patch_lang_tensor(tensor): + + def _get_bool(self): + data = self.handle.data + # in triton, only scalars can be converted to booleans + # here we need this hack because all scalars are tensors + return bool(data) if data.size == 1 else True + + def _get_transpose(self): + return tl.core.tensor(TensorHandle(np.transpose(self.handle.data), self.handle.dtype), self.dtype.scalar) + + tensor.__index__ = lambda self: int(self.handle.data) + tensor.__bool__ = lambda self: _get_bool(self) + tensor.__repr__ = lambda self: repr(self.handle.data) + tensor.__str__ = lambda self: str(self.handle.data) + tensor.T = property(_get_transpose) + + +class ReduceScanOpIneterface: + + def __init__(self, axis, combine_fn): + self.axis = axis + self.combine_fn = combine_fn + + def check_axis(self, shape, axis): + if axis is not None and axis >= len(shape): + raise ValueError(f"axis {axis} out of bounds for shape {shape}") + + def check_tensor(self, input): + for arg in input: + if not isinstance(arg, tl.core.tensor): + raise ValueError(f"input must be a tensor, got {type(arg)}") + self.check_axis(arg.shape, self.axis) + + def to_tensor(self, ret, dtype): + if hasattr(ret, "shape") and ret.shape: + ret_type = tl.block_type(dtype, ret.shape) + else: + ret = np.array([ret], dtype=_get_np_dtype(dtype)) + ret_type = dtype + return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type) + + def apply(self, input): + if not isinstance(input, tuple): + input = (input, ) + self.check_tensor(input) + return self.apply_impl(input) + + def apply_impl(self, input): + raise NotImplementedError("apply_impl not implemented") + + +class ReduceOps(ReduceScanOpIneterface): + + def __init__(self, axis, combine_fn, keep_dims): + super().__init__(axis, combine_fn) + self.keep_dims = keep_dims + + def unravel(self, input, axis): + ret = [] + for data in input: + if axis is not None: + ret.append(data) + else: + axis = 0 + ret.append(self.to_tensor(data.handle.data.flatten(), data.dtype)) + return tuple(ret), axis + + def generic_reduce(self, input): + original_axis = self.axis + input, axis = self.unravel(input, self.axis) + input_data = [] + output_data = [] + input_shape = input[0].handle.data.shape + output_shape = input_shape[0:axis] + input_shape[axis + 1:] + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype)) + # Reduce on axis + for i in range(input_data[0].size): + # Recover input_index from i using input_shape + input_index = np.unravel_index(i, input_shape) + output_index = input_index[0:axis] + input_index[axis + 1:] + input_tuple = tuple(self.to_tensor(d[input_index], input[ii].dtype) for ii, d in enumerate(input_data)) + if input_index[axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][output_index] = input_tuple[j].handle.data.item() + else: + acc_tuple = tuple(self.to_tensor(o[output_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *input_tuple) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][output_index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + if self.keep_dims: + if original_axis is not None: + data = np.expand_dims(data, axis) + else: + for _ in range(len(input_shape)): + data = np.expand_dims(data, 0) + + elif original_axis is None: + # Take a scalar + data = data.item() + ret.append(self.to_tensor(data, input[i].dtype)) + return ret[0] if len(ret) == 1 else tuple(ret) + + def min_max(self, input, val_reduce_op, idx_reduce_op=None): + # If input is a tuple, it must be (val, index), and we only take val + input = input[0] if isinstance(input, tuple) else input + val = None + idx = None + if val_reduce_op: + val = self.to_tensor(val_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + if idx_reduce_op: + idx = self.to_tensor(idx_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), tl.int32) + if val is not None and idx is not None: + return val, idx + elif val is not None: + return val + elif idx is not None: + return idx + else: + raise ValueError("val_reduce_op and idx_reduce_op are both None") + + def sum(self, input): + return self.to_tensor(np.sum(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + + def apply_impl(self, input): + if self.combine_fn == tl.standard._argmin_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=np.argmin) + elif self.combine_fn == tl.standard._argmax_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax) + elif self.combine_fn == tl.standard._elementwise_max: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=None) + elif self.combine_fn == tl.standard._elementwise_min: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=None) + elif self.combine_fn == tl.standard._sum_combine: + return self.sum(input[0]) + else: + # Fall back to the slow mode + return self.generic_reduce(input) + + +class ScanOps(ReduceScanOpIneterface): + + def __init__(self, axis, combine_fn, reverse): + super().__init__(axis, combine_fn) + self.reverse = reverse + + def cumsum(self, input): + return [self.to_tensor(np.cumsum(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def cumprod(self, input): + return [self.to_tensor(np.cumprod(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def generic_scan(self, input): + input_data = [] + output_data = [] + shape = input[0].handle.data.shape + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(shape, dtype=arg.handle.data.dtype)) + # Scan on axis + for i in range(input_data[0].size): + # Recover index from i using shape + index = np.unravel_index(i, shape) + data = tuple(self.to_tensor(d[index], input[ii].dtype) for ii, d in enumerate(input_data)) + if index[self.axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][index] = data[j].handle.data.item() + else: + prev_index = tuple(index[i] - 1 if i == self.axis else index[i] for i in range(len(index))) + acc_tuple = tuple(self.to_tensor(o[prev_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *data) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + ret.append(self.to_tensor(data, input[i].dtype)) + return ret + + def apply_impl(self, input): + new_input = [] + if self.reverse: + for arg in input: + new_input.append(self.to_tensor(np.flip(arg.handle.data, axis=self.axis), arg.dtype)) + else: + new_input = input + if self.combine_fn == tl.standard._sum_combine: + ret = self.cumsum(new_input[0]) + elif self.combine_fn == tl.standard._prod_combine: + ret = self.cumprod(new_input[0]) + else: + # Fall back to the slow mode + ret = self.generic_scan(new_input) + if self.reverse: + for arg in ret: + arg.handle.data = np.flip(arg.handle.data, axis=self.axis) + return len(ret) == 1 and ret[0] or tuple(ret) + + +def _patch_reduce_scan(): + # Because interpreter doesn't support region_builder_fn, we cannot patch the builder + # to use the new reduce and scan functions. + # Instead, we need to patch reduce and reduce functions in tl and tl.core + def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs): + return ReduceOps(axis, combine_fn, keep_dims).apply(input) + + def _new_scan(input, axis, combine_fn, reverse=False, **kwargs): + return ScanOps(axis, combine_fn, reverse).apply(input) + + tl.reduce = _new_reduce + tl.associative_scan = _new_scan + tl.core.reduce = _new_reduce + tl.core.associative_scan = _new_scan + + +def _patch_lang_core(lang): + + def _new_to_ir(self, builder): + # We need to specify signedness for integer types in the numpy mode + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name == 'int8': + return builder.get_int8_ty() + elif self.name == 'uint8': + return builder.get_uint8_ty() + elif self.name == 'int16': + return builder.get_int16_ty() + elif self.name == 'uint16': + return builder.get_uint16_ty() + elif self.name == 'int32': + return builder.get_int32_ty() + elif self.name == 'uint32': + return builder.get_uint32_ty() + elif self.name == 'int64': + return builder.get_int64_ty() + elif self.name == 'uint64': + return builder.get_uint64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + # can't just map lang.static_range to `range`, because `tl.static_range` + # can get `step` passed by keyword + def _new_range(arg1, arg2=None, step=None, **kwargs): + if step is None: + step = 1 + if arg2 is None: + start, end = 0, arg1 + else: + start, end = arg1, arg2 + return range(start, end, step) + + def _new_static_assert(cond, msg=""): + assert cond, msg + + def _set_attr(input, values, name): + # skip non tensor types. This may happen for induction variables. + if not isinstance(input, tl.tensor): + return input + # Unwrap constexpr + values = [values] if not isinstance(values, (list, tuple)) else values + values = [v.value if isinstance(v, tl.constexpr) else v for v in values] + if len(values) != max(1, len(input.shape)): + raise ValueError(f"len(values) != len(input.shape) for {name}") + input.handle.set_attr(name, values) + return input + + lang.range = _new_range + lang.static_range = _new_range + lang.static_assert = _new_static_assert + lang.static_print = print + lang.dtype.to_ir = _new_to_ir + lang.multiple_of = partial(_set_attr, name="tt.divisiblity") + lang.max_contiguous = partial(_set_attr, name="tt.contiguity") + lang.max_constancy = partial(_set_attr, name="tt.constancy") + + _patch_reduce_scan() + + +def _patch_lang(fn): + lang = [value for _, value in fn.__globals__.items() if value in [tl, tl.core]] + assert len(lang) == 1, "triton.language must be visible from within jit'd function" + _patch_builtin(lang[0], interpreter_builder) + _patch_builtin(lang[0].tensor, interpreter_builder) + if lang[0] == tl: + _patch_builtin(lang[0].math, interpreter_builder) + _patch_lang_tensor(lang[0].tensor) + _patch_lang_core(lang[0]) + + +# TODO: wrap everything in triton tensors +def _implicit_cvt(arg): + if isinstance(arg, int): + ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + dtype = np.int32 + if -2**31 <= arg < 2**31: + dtype = np.int32 + elif 2**31 <= arg < 2**32: + dtype = np.uint32 + elif -2**63 <= arg < 2**63: + dtype = np.int64 + elif 2**63 <= arg < 2**64: + dtype = np.uint64 + else: + raise ValueError(f"Unsupported integer value {arg}") + handle = TensorHandle(np.array([arg], dtype=dtype), ty) + return tl.tensor(handle, ty) + if hasattr(arg, "data_ptr"): + ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) + return tl.tensor(handle, ty) + return arg + + +interpreter_builder = InterpreterBuilder() + +# These keywords are not supported by the interpreter +RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"] + + +class GridExecutor: + + def __init__(self, fn, arg_names, grid): + from .jit import _normalize_ty # TODO: modularize + + self.fn = fn + self.arg_names = arg_names + self.grid = grid + __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} + self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"] + + def _init_args_hst(self, args_dev, kwargs): + args_hst = [] + for arg in args_dev: + if hasattr(arg, "data_ptr"): + args_hst.append(arg.cpu()) + else: + args_hst.append(arg) + # Process keyword arguments + kwargs_hst = {} + for key, value in kwargs.items(): + if hasattr(value, "data_ptr"): + kwargs_hst[key] = value.cpu() + else: + kwargs_hst[key] = value + return args_hst, kwargs_hst + + def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst): + for arg_dev, arg_hst in zip(args_dev, args_hst): + if hasattr(arg_dev, "data_ptr"): + arg_dev.data.copy_(arg_hst.to(arg_dev.device).data) + + # Restore keyword arguments + for key, kwarg_dev in kwargs.items(): + kwarg_hst = kwargs_hst[key] + if hasattr(kwarg_dev, "data_ptr"): + kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data) + + def __call__(self, *args_dev, **kwargs): + # removes reserved keywords from kwargs + kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} + if kwargs.pop("warmup", False): + return + # copy arguments to the host + args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) + # remaps core language functions to interpreted ones + _patch_lang(self.fn) + # we need to copy arguments to the host for the interpreter + # implicitly convert tensor arguments to their base pointers + args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst) + args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()} + # iterate through grid + grid = self.grid(args) if callable(self.grid) else self.grid + assert len(grid) <= 3, "grid must have at most 3 dimensions" + grid = grid + (1, ) * (3 - len(grid)) + interpreter_builder.set_grid_dim(*grid) + try: + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + interpreter_builder.set_grid_idx(x, y, z) + self.fn(**args) + except Exception as e: + raise InterpreterError(repr(e)) from e + # copy arguments back to propagate side-effects + self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst) + + +class InterpretedFunction: + + def __init__(self, fn) -> None: + self.fn = fn + + def run(*args, **kwargs): + grid = kwargs["grid"] + return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs) + + self.run = run + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + + @property + def __name__(self): + return self.fn.__name__ + + def __getitem__(self, grid): + return GridExecutor(self.fn, self.arg_names, grid) + + def __call__(self, *args, **kwargs): + # This is a device function call + _patch_lang(self.fn) + try: + return self.fn(*args, **kwargs) + except Exception as e: + raise InterpreterError(repr(e)) from e diff --git a/third_party/iluvatar/python/triton/runtime/jit.py b/third_party/iluvatar/python/triton/runtime/jit.py new file mode 100644 index 000000000..296f92cdf --- /dev/null +++ b/third_party/iluvatar/python/triton/runtime/jit.py @@ -0,0 +1,1049 @@ +from __future__ import annotations, division +import ast +import hashlib +import inspect +import itertools +import os +import re +import textwrap +from collections import defaultdict +from functools import cached_property +from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple +from ..runtime.driver import driver +from types import ModuleType +from triton.runtime.build import is_corex + +TRITON_MODULE = __name__[:-len(".runtime.jit")] + +T = TypeVar("T") + +import torch + + +def get_disable_sme(): + disable_sme = os.getenv("TRITON_DISABLE_SME", default="0") + cc = torch.cuda.get_device_capability() + cc = cc[0] * 10 + cc[1] + if cc == 70: # for ivcore10 + disable_sme = "1" + + return disable_sme + + +def get_corex_sme(args, constexpr_indices, enable_sme=True): + can_use_sme = 0 + shape_info = '' + if not enable_sme: + return can_use_sme, shape_info + import torch + if not (hasattr(torch, "corex") and torch.corex == True): + return can_use_sme, shape_info + close_sme = get_disable_sme() + if close_sme == "1": + return can_use_sme, shape_info + index = 0 + shape_info = '' + for i, arg in enumerate(args): + if i in constexpr_indices: + continue + if (isinstance(arg, int) and arg == 1): + continue + if torch.is_tensor(arg) and arg.dtype in [torch.float16, torch.float32, torch.bfloat16, torch.int8 + ] and arg.dim() >= 2: + dim_M = arg.shape[-2] + dim_K = arg.shape[-1] + shape_info += '_' + str(dim_M) + '_' + str(dim_K) + if dim_M == 1 or dim_K == 1: + index += 1 + continue + sme_dim = 64 / arg.element_size() + if (arg.is_contiguous() and dim_K % sme_dim == 0) or \ + (not arg.is_contiguous() and dim_M % sme_dim == 0): + can_use_sme = (1 << index) | can_use_sme + index += 1 + return can_use_sme, shape_info + + +# ----------------------------------------------------------------------------- +# Dependencies Finder +# ----------------------------------------------------------------------------- + + +class DependenciesFinder(ast.NodeVisitor): + """ + This AST visitor is used to find dependencies of a JITFunction. This can + be used to invalidate a JITFunction's hash when its source code -- or + that of its dependencies -- changes. + + This visitor also keeps track of the global variables touched by the + JITFunction. When we launch the kernel, we check that these have the same + values as they did when we ran this visitor. If not, we raise an error (or + otherwise we could recompile). + """ + + def __init__(self, name, globals, src) -> None: + super().__init__() + self.name = name + self.hasher = hashlib.sha256(src.encode("utf-8")) + + # This function's __globals__ dict. + self.globals = globals + + # Python builtins that can be accessed from Triton kernels. + self.supported_python_builtins = { + 'float', + 'getattr', + 'int', + 'isinstance', + 'len', + 'list', + 'max', + 'min', + 'print', + 'range', + } + + # used_global_vals tells us which global variables are used by this + # function and all those it transitively calls, plus the values of those + # variables when each function was initially run. (That is, if A calls + # C, and B calls C, then the values for C in used_global_vals will be + # from the first time C was run, either by A or B.) + # + # Each function may have a different __globals__ dict, so the global + # variable `foo` may actually have a different value in the different + # functions. Thus this map is actually + # (var_name, id(__globals__)) -> (var_value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + self.visiting_arg_default_value = False + + @property + def ret(self): + return self.hasher.hexdigest() + + def visit_Name(self, node): + if type(node.ctx) == ast.Store: + return node.id + + if node.id in self.local_names: + # The global name is hidden by the local name. + return None + + val = self.globals.get(node.id, None) + + # Only keep track of "interesting" global variables, that non-evil users + # might change. Don't consider functions, modules, builtins, etc. This + # helps keep the list of vars we have to check small. + if (val is not None # + # Python default arguments are resolved only once, when the + # function is defined. So if you do `foo(a=A)` and the value of + # A changes, foo will still use the old value of A. + and not self.visiting_arg_default_value + # It would be pretty evil if someone did `import x` and then + # `x = blah`. + and type(val) != ModuleType + # It would be pretty evil if we used function `foo` inside of + # `bar` and then someone did `foo = baz`. + and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) # + and node.id not in self.supported_python_builtins # + ): + self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals) + + return val + + def visit_Tuple(self, node): + # We need to explicitly return the tuple values so that visit_Assign can + # access them in the case of `a, b = ...`. + return [self.visit(elt) for elt in node.elts] + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE): + return None + return getattr(lhs, node.attr) + + def visit_Call(self, node): + + def is_triton_builtin(func): + if inspect.isbuiltin(node.func): + return True + module = getattr(func, "__module__", "") + return module.startswith(TRITON_MODULE) + + func = self.visit(node.func) + assert func is None or is_triton_builtin(func) or isinstance( + func, JITFunction + ), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this' + + # Traverse arguments as well as node.func so we can find JITFunctions + # passed to tl.reduce or tl.associative_scan as the combine_fn + for obj in itertools.chain( + (func, ), + map(self.visit, node.args), + (self.visit(kw.value) for kw in node.keywords), + ): + if not isinstance(obj, JITFunction): + continue + if is_triton_builtin(obj): + continue + + func_cache_key = obj.cache_key + + # Merge our used_global_vals with those of the called function, + # after checking that all overlapping values are consistent. + for k in self.used_global_vals.keys() & obj.used_global_vals.keys(): + var_name, _ = k + v1, _ = self.used_global_vals[k] + v2, _ = obj.used_global_vals[k] + if v1 != v2: + raise RuntimeError( + f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed." + ) + + self.used_global_vals.update(obj.used_global_vals) + + noinline = str(getattr(obj, "noinline", False)) + + key = func_cache_key + noinline + self.hasher.update(key.encode("utf-8")) + + def visit_FunctionDef(self, node): + # Save the local name, which may hide the global name. + self.local_names = {arg.arg for arg in node.args.args} + self.generic_visit(node) + + def visit_arguments(self, node): + # The purpose of this function is to visit everything in `arguments` + # just like `generic_visit`, except when we're visiting default values + # (i.e. the `foo` part of `def fn(x = foo)`), we set + # self.visiting_arg_default_value = True. This allows visit_Name to be + # aware that we're inside function default values, which have special + # semantics. + + # According to the AST docs, the arguments node has the following structure. + # + # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, + # expr* kw_defaults, arg? kwarg, expr* defaults) + def visit_defaults(defaults): + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + for expr in defaults: + if expr is not None: + self.visit(expr) + finally: + self.visiting_arg_default_value = False + + for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs): + self.visit(arg) + + visit_defaults(node.kw_defaults) + + if node.kwarg is not None: + self.visit(node.kwarg) + + visit_defaults(node.defaults) + + def visitAssnTarget(self, node): + # Target is either a single string, or a list of strings (if the assn + # target is a tuple). + target = self.visit(node) + if isinstance(target, list): + self.local_names |= set(target) + else: + self.local_names.add(target) + + def visit_Assign(self, node): + if len(node.targets) != 1: + # TODO(jlebar): I don't actually know how to hit this. You don't + # get it from `a, b = ...` -- in that case, node.targets is a single + # Tuple, and in fact we *do* need to handle that case if we want + # existing code to work. + raise TypeError("Simultaneous multiple assignment is not supported.") + + self.visitAssnTarget(node.targets[0]) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_AnnAssign(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_For(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's fine. + self.generic_visit(node) + + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + + +def _normalize_ty(ty) -> str: + if isinstance(ty, type): + return ty.__name__ + elif isinstance(ty, str): + return ty + return repr(ty) + + +class KernelParam: + """Represents a parameter (name plus metadata) to a @jit'ed function.""" + + def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool): + self.num = num + self._param = param + self.do_not_specialize = do_not_specialize + + @cached_property + def name(self): + return self._param.name + + @cached_property + def annotation(self): + if not self._param.annotation or self._param.annotation == inspect.Parameter.empty: + return "" + return _normalize_ty(self._param.annotation) + + @cached_property + def annotation_type(self): + annotation = self.annotation + for ty1, ty2 in [("uint", 'u'), ("int", 'i')]: + width = annotation[annotation.find(ty1) + len(ty1):] + if width and ty1 in annotation: + return f"{ty2}{width}" + if annotation == "bool": + return "u1" + return "" + + @cached_property + def is_constexpr(self): + return "constexpr" in self.annotation + + @cached_property + def is_const(self): + return "const" in self.annotation and not self.is_constexpr + + @property + def default(self): + return self._param.default + + @property + def has_default(self): + return self._param.default != inspect.Parameter.empty + + +def compute_spec_key(v): + + if hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0): + return "D" + elif isinstance(v, int): + # bool is a subclass of int, so we don't check explicitly above. + if (v % 16 == 0): + return "D" + elif v == 1: + return "1" + return "N" + + +dtype2str = {} + + +def mangle_type(arg, is_const=False): + + if arg is None: + return "none" + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + else: + # dtypes are hashable so we can memoize this mapping: + dsk = (arg.dtype, is_const) + res = dtype2str.get(dsk, None) + if res is None: + res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]] + dtype2str[dsk] = res + return res + + +class KernelInterface(Generic[T]): + run: T + + def __getitem__(self, grid) -> T: + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. + """ + return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) + + +def serialize_specialization_data(name, signature, constants, attrs, options, key): + constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} + import json + obj = { + 'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options': + options.__dict__, 'key': key + } + serialized_obj = json.dumps(obj) + return serialized_obj + + +def create_function_from_signature(sig, kparams): + """ + Equivalent to sig.bind followed by apply_defaults. This generates a + native Python function (using exec) which can be memoized on a per-kernel + basis to avoid having to run these expensive functions -- which constitute + much of the kernel launch overhead -- every time we run the kernel. + """ + + assert len(sig.parameters) == len(kparams) + + # Create the function argument list and the dict entries for the return statement + func_args = [] + dict_entries = [] + constexpr_vals = [] + non_constexpr_vals = [] + signature_types = [] + specialisations = [] + + for ((name, sp), kp) in zip(sig.parameters.items(), kparams): + if sp.default is inspect.Parameter.empty: + func_args.append(name) + dict_entries.append(f"'{name}': {name}") + else: + func_args.append(f"{name}=default_{name}") + dict_entries.append(f"'{name}': {name}") + if kp.is_constexpr: + constexpr_vals.append(name) + else: + non_constexpr_vals.append(name) + if not kp.do_not_specialize: + specialisations.append('compute_spec_key(%s)' % name) + if kp.annotation_type: + signature_types.append('"%s"' % kp.annotation_type) + else: + signature_types.append('mangle_type(%s, %s)' % (name, 'True' if kp.is_const else 'False')) + + cache_key = ''.join([x + ', ' for x in signature_types + specialisations]) + constexpr_vals = ''.join([x + ', ' for x in constexpr_vals]) + non_constexpr_vals = ''.join([x + ', ' for x in non_constexpr_vals]) + + func_args.append('**excess_kwargs') + + # Join all arguments into a function definition string + args_str = ', '.join(func_args) + dict_str = ', '.join(dict_entries) + func_body = "def dynamic_func(%s):\n return {%s}, (%s), (%s), (%s), excess_kwargs" % ( + args_str, dict_str, cache_key, constexpr_vals, non_constexpr_vals) + + # Prepare defaults to be inserted into function namespace + func_namespace = { + f"default_{name}": param.default + for name, param in sig.parameters.items() + if param.default is not inspect.Parameter.empty + } + + func_namespace['mangle_type'] = mangle_type + func_namespace['compute_spec_key'] = compute_spec_key + + # Execute the function string in func_namespace to create the function + exec(func_body, func_namespace) + + # Extract the newly created function from the namespace + return func_namespace['dynamic_func'] + + +type_canonicalisation_dict = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8_e4m3fn": "fp8e4nv", + "float8e4b8": "fp8e4b8", + "float8_e4m3fnuz": "fp8e4b8", + "float8_e5m2": "fp8e5", + "float8e5b16": "fp8e5b16", + "float8_e5m2fnuz": "fp8e5b16", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", +} + +for v in list(type_canonicalisation_dict.values()): + type_canonicalisation_dict[v] = v + + +class JITFunction(KernelInterface[T]): + # Hook for inspecting compiled functions and modules + cache_hook = None + divisibility = 16 + divisibility_8 = 8 + + @staticmethod + def _key_of(arg): + if hasattr(arg, "dtype"): + return arg.dtype + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + elif arg is None: + return None + else: + raise TypeError(f"Unsupported type {type(arg)} for {arg}") + + @staticmethod + def _device_of(arg): + if hasattr(arg, "device") and hasattr(arg.device, "type"): + return arg.device.type + else: + return "" + + @staticmethod + def _pinned_memory_of(arg): + if hasattr(arg, "is_pinned") and callable(arg.is_pinned): + return arg.is_pinned() + else: + return False + + @staticmethod + def _spec_of(arg): + if hasattr(arg, "data_ptr"): + return arg.data_ptr() % JITFunction.divisibility == 0 + elif isinstance(arg, int): + return (arg % 16 == 0, arg % JITFunction.divisibility_8 == 0, arg == 1) + return (arg is None, ) + + def _get_config(self, *args): + from ..compiler import AttrsDescriptor + + def is_divisible_by_16(x): + if hasattr(x, "data_ptr"): + return x.data_ptr() % JITFunction.divisibility == 0 + elif isinstance(x, int): + return x % JITFunction.divisibility == 0 + if x is None: + return True + return False + + def is_divisible_by_8(x): + if isinstance(x, int): + if is_corex(): + return x % 64 == 0 + return x % JITFunction.divisibility_8 == 0 + if x is None: + return True + return False + + divisible_by_16 = { + param.num + for param, arg in zip(self.params, args) + if is_divisible_by_16(arg) and not param.do_not_specialize + } + divisible_by_8 = { + param.num + for param, arg in zip(self.params, args) + if is_divisible_by_8(arg) and not param.do_not_specialize + } + equal_to_1 = { + param.num + for param, arg in zip(self.params, args) + if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize + } + # folded equal_to_1 and None + # TODO: method to collect all folded args + return AttrsDescriptor(tuple(divisible_by_16), tuple(equal_to_1), tuple(divisible_by_8)) + # return _triton.code_gen.instance_descriptor(divisible_by_16, + # equal_to_1) + + @staticmethod + def _type_of(key, is_const=False): + # `None` is nullptr. Implicitly convert to *i8. + if key is None: + return "*i8" + elif isinstance(key, str): + return key + + dtype_str = str(key).split(".")[-1] + dtype_str = type_canonicalisation_dict[dtype_str] + const_str = "*k" if is_const else "*" + return const_str + dtype_str + + def _make_constants(self, constexpr_key): + constants = dict(zip(self.constexprs, constexpr_key)) + return constants + + def _call_hook( + self, + key, + signature, + device, + constants, + options, + configs, + ): + if JITFunction.cache_hook is None: + return False + + name = self.fn.__name__ + module = self.fn.__module__ + arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])]) + repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})" + + class JitFunctionInfo: + + def __init__(self, module, name, jit_function): + self.module = module + self.name = name + self.jit_function = jit_function + pass + + specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key) + + kwargs = { + 'signature': signature, + 'device': device, + 'constants': constants, + 'num_warps': options.num_warps, + 'num_ctas': options.num_ctas, + 'num_stages': options.num_stages, + 'enable_fp_fusion': options.enable_fp_fusion, + 'extern_libs': options.extern_libs, + 'configs': configs, + 'specialization_data': specialization_data, + } + + return JITFunction.cache_hook( + key=key, + repr=repr, + fn=JitFunctionInfo(module, name, self), + compile={"key": key, **kwargs}, + is_manual_warmup=False, + already_compiled=False, + ) + + def add_pre_run_hook(self, hook): + ''' + Add a hook that will be executed prior to the execution of run + function with args and kwargs passed into the kernel + ''' + assert callable(hook) + self.pre_run_hooks.append(hook) + + def create_binder(self): + """ + Precompute as much as possible. + """ + from ..compiler import CompiledKernel, compile, ASTSource, make_backend + self.CompiledKernel = CompiledKernel + self.compile = compile + self.ASTSource = ASTSource + self.make_backend = make_backend + self.binder = create_function_from_signature(self.signature, self.params) + self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr] + self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr] + self.specialised_indices = [ + i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr) + ] + + def run(self, *args, grid, warmup, **kwargs): + # parse options + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + kwargs["debug"] = self.debug + + # Execute pre run hooks with args and kwargs + for hook in self.pre_run_hooks: + hook(*args, **kwargs) + + if self.binder is None: + self.create_binder() + + bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs) + + # compute cache key + target = driver.active.get_current_target() + backend = self.make_backend(target) + options = backend.parse_options(kwargs) + options.use_sme, shape_info = get_corex_sme(args, self.constexpr_indices, options.enable_sme) + if not shape_info: + for arg in args: + if torch.is_tensor(arg): + shape_info += '_' + '_'.join(str(_) for _ in list(arg.shape)) + key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs)) + str((options.use_sme, shape_info)) + kernel = self.cache[device].get(key, None) + + pinned_memory_flags = [self._pinned_memory_of(arg) for arg in args] + device_types = [self._device_of(arg) for arg in args] + device_types = [_device_type for _device_type in device_types if _device_type != ""] + is_cpu = device_types and all(device_type == "cpu" for device_type in device_types) + is_pinned_memory = any(pinned_memory_flag for pinned_memory_flag in pinned_memory_flags) + if is_cpu and not is_pinned_memory: + raise ValueError("Cannot find backend for cpu") + + if kernel is None: + # Kernel is not cached; we have to compile. + + # deprecated arguments + assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" + assert "device" not in kwargs, "device option is deprecated; current device will be used" + assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + for k in excess_kwargs: + if k not in options.__dict__: + raise KeyError("Keyword argument %s was specified but unrecognised" % k) + + bound_vals = tuple(bound_args.values()) + + # `None` is nullptr. Implicitly convert to *i8. This needs to be + # done here rather than when we build the signature as otherwise + # the kernel cache key could not distinguish between byte pointers + # and None arguments, resulting in a downstream mismatch: + sigkeys = [self.params[i].name for i in self.non_constexpr_indices] + sigvals = sig_and_spec[:len(sigkeys)] + signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} + + configs = (self._get_config(*bound_vals), ) + constants = { + p.name: v + for (v, p) in zip(bound_vals, self.params) + if p.is_constexpr or p.num in configs[0].equal_to_1 or v is None + } + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index {i} is not supported") + + if self._call_hook(key, signature, device, constants, options, configs): + return None + # compile the kernel + src = self.ASTSource(self, signature, constants, configs[0]) + kernel = self.compile( + src, + target=target, + options=options.__dict__, + ) + self.cache[device][key] = kernel + + # Check that used global values have not changed. + not_present = object() + for (name, globals_dict_id), (val, globals_dict) in self.used_global_vals.items(): + if (newVal := globals_dict.get(name, not_present)) != val: + raise RuntimeError( + f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}") + + if not warmup: + # canonicalize grid + assert grid is not None + if callable(grid): + # Arguments are passed as a dict to `grid`, by contract. + # TODO(jlebar): In the new launch API, pass the compiler flags as a + # second parameter to `grid`. + grid = grid(bound_args) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + + # launch kernel + launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) + kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, + self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals) + return kernel + + def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None, repr=None, + launch_metadata=None): + do_not_specialize = do_not_specialize if do_not_specialize else [] + + self.fn = fn + self.module = fn.__module__ + self.version = version + self.signature = inspect.signature(fn) + self.do_not_specialize = do_not_specialize + self.starting_line_number = inspect.getsourcelines(fn)[1] + self.repr = lambda _: fn.__name__ if repr is None else repr(_) + self.launch_metadata = launch_metadata + + self.binder = None + + self.params = [] + for i, param in enumerate(self.signature.parameters.values()): + dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize) + self.params.append(KernelParam(i, param, dns)) + + # function source code (without decorators) + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():] + # cache of just-in-time compiled kernels + self.cache = defaultdict(dict) + self.hash = None + + # Map of global variables used by the function and any functions it + # transitively calls, plus their values. The values are collected when + # the function is first compiled. Then every time we run the function, + # we check that the values of the globals match what's expected, + # otherwise we raise an error. + # + # Different functions can have different __globals__ maps, so the map + # key is actually (var name, id(__globals__)), and the map value is + # (value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel = None + self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug + self.noinline = noinline + + # TODO(jlebar): Remove uses of these fields outside this file, then + # remove the fields here. + self.arg_names = [p.name for p in self.params] + self.constexprs = [p.num for p in self.params if p.is_constexpr] + + # Hooks that will be called prior to executing "run" + self.pre_run_hooks = [] + + # reuse docs of wrapped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + # use to record fn cache files + self.hash_cache_file = None + self.so_path = None + + @property + def cache_key(self): + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + str(self.starting_line_number) + self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items())) + return self.hash + + def warmup(self, *args, grid, **kwargs): + return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) + + def preload(self, specialization_data): + from ..compiler import AttrsDescriptor, compile, ASTSource + import json + import triton.language as tl + device = driver.active.get_current_device() + deserialized_obj = json.loads(specialization_data) + if deserialized_obj['name'] != self.fn.__name__: + raise RuntimeError( + f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}") + constants = { + key: tl.dtype(value) if tl.dtype.is_dtype(value) else value + for key, value in deserialized_obj['constants'].items() + } + signature = dict(deserialized_obj['signature'].items()) + src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) + options = { + key: tuple(value) if isinstance(value, list) else value + for key, value in deserialized_obj['options'].items() + } + key = deserialized_obj['key'] + kernel = compile(src, None, options) + self.cache[device][key] = kernel + return kernel + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Our unit tests do this, for example. + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + def __setattr__(self, name, value): + super(JITFunction, self).__setattr__(name, value) + # - when `.src` attribute is set, cache path needs + # to be reinitialized + if name == "src": + self.hash = None + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__name__})" + + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + + +@overload +def jit(fn: T) -> JITFunction[T]: + ... + + +@overload +def jit( + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Callable[[T], JITFunction[T]]: + ... + + +def jit( + fn: Optional[T] = None, + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, arguments are + implicitly converted to pointers if they have a :code:`.data_ptr()` method + and a `.dtype` attribute. + + :note: This function will be compiled and run on the GPU. It will only have access to: + + * python primitives, + * builtins within the triton package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + + def decorator(fn: T) -> JITFunction[T]: + assert callable(fn) + if os.getenv("TRITON_INTERPRET", "0") == "1": + from .interpreter import InterpretedFunction + return InterpretedFunction(fn) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + debug=debug, + noinline=noinline, + repr=repr, + launch_metadata=launch_metadata, + ) + + if fn is not None: + return decorator(fn) + + else: + return decorator + + +# ----------------------------------------------------------------------------- +# Utilities for mocking tensors +# ----------------------------------------------------------------------------- + + +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) + """ + + @staticmethod + def wrap_dtype(arg): + if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch": + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + @staticmethod + def data_ptr(): + return 0 # optimistically assumes multiple of 16 + + +class TensorWrapper: + + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.data = base.data + self.device = base.device + self.shape = self.base.shape + + def data_ptr(self): + return self.base.data_ptr() + + def stride(self, i): + return self.base.stride(i) + + def __str__(self) -> str: + return f"TensorWrapper[{self.dtype}]({self.base})" + + def element_size(self): + return self.base.element_size() + + def cpu(self): + return TensorWrapper(self.base.cpu(), self.dtype) + + def copy_(self, other): + self.base.copy_(other.base) + + def to(self, device): + return TensorWrapper(self.base.to(device), self.dtype) + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif hasattr(tensor, "data_ptr"): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f"Cannot reinterpret a {type(tensor)}.") diff --git a/third_party/iluvatar/python/triton/testing.py b/third_party/iluvatar/python/triton/testing.py new file mode 100644 index 000000000..d1806e570 --- /dev/null +++ b/third_party/iluvatar/python/triton/testing.py @@ -0,0 +1,500 @@ +import functools +import os +import subprocess +import sys +from contextlib import contextmanager +from typing import Any, Dict, List +from . import language as tl +from triton.runtime.build import is_corex + + +def nvsmi(attrs): + attrs = ','.join(attrs) + cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] + out = subprocess.check_output(cmd) + ret = out.decode(sys.stdout.encoding).split(',') + ret = [int(x) for x in ret] + return ret + + +def do_bench_cudagraph(fn, rep=20, grad_to_none=None, return_mode="mean"): + """ + Benchmark the runtime of the provided function. + + :param fn: Function to benchmark + :type fn: Callable + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + """ + import torch + assert return_mode in ["min", "max", "mean", "median"] + + if torch.cuda.current_stream() == torch.cuda.default_stream(): + raise RuntimeError("Cannot capture graph in default stream. Please use side stream in benchmark code.") + # warmup + fn() + # step 1 - we estimate the amount of time the kernel call takes + # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point + # but it is probably good enough + if grad_to_none is not None: + for x in grad_to_none: + x.detach_() + x.requires_grad_(True) + x.grad = None + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) + n_repeat = max(1, int(rep / estimate_ms)) + # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize + # host overhead + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for i in range(n_repeat): + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + fn() + torch.cuda.synchronize() + # measure time and return + ret = [] + n_retries = 10 + for i in range(n_retries): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + ret += [start_event.elapsed_time(end_event) / n_repeat] + times = torch.tensor(ret) + return getattr(torch, return_mode)(times).item() + + +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean", + device_type="cuda"): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float] + :param fast_flush: Use faster kernel to flush L2 between measurements + :type fast_flush: bool + """ + assert return_mode in ["min", "max", "mean", "median"] + import torch + + di = torch._dynamo.device_interface.get_interface_for_device(device_type) + + fn() + di.synchronize() + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + if fast_flush: + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=device_type) + else: + cache = torch.empty(int(256e6), dtype=torch.int8, device=device_type) + + # Estimate the runtime of the function + start_event = di.Event(enable_timing=True) + end_event = di.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + di.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + start_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + di.synchronize() + times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + return getattr(torch, return_mode)(times).item() + + +def assert_close(x, y, atol=None, rtol=None, err_msg=''): + import numpy as np + import torch + + # canonicalize arguments to be tensors + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) + # absolute tolerance + if atol is None: + atol = 1e-2 + atol = atol(x.dtype) if callable(atol) else atol + # relative tolerance hook + if rtol is None: + rtol = 0. + rtol = rtol(x.dtype) if callable(rtol) else rtol + # we use numpy instead of pytorch + # as it seems more memory efficient + # pytorch tends to oom on large tensors + if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() + y = y.cpu().detach().numpy() + # we handle size==1 case separately as we can + # provide better error message there + if x.size > 1 or y.size > 1: + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True) + return + if not np.allclose(x, y, atol=atol, rtol=rtol): + raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') + + +class Benchmark: + """ + This class is used by the :code:`perf_report` function to generate line plots with a concise API. + """ + + def __init__( + self, + x_names: List[str], + x_vals: List[Any], + line_arg: str, + line_vals: List[Any], + line_names: List[str], + plot_name: str, + args: Dict[str, Any], + xlabel: str = '', + ylabel: str = '', + x_log: bool = False, + y_log: bool = False, + color=None, + styles=None, + ): + """ + Constructor. + x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list + of scalars and there are multiple x_names, all arguments will have the same value. + If x_vals is a list of tuples/lists, each element should have the same length as + x_names. + + :param x_names: Name of the arguments that should appear on the x axis of the plot. + :type x_names: List[str] + :param x_vals: List of values to use for the arguments in :code:`x_names`. + :type x_vals: List[Any] + :param line_arg: Argument name for which different values correspond to different lines in the plot. + :type line_arg: str + :param line_vals: List of values to use for the arguments in :code:`line_arg`. + :type line_vals: List[Any] + :param line_names: Label names for the different lines. + :type line_names: List[str] + :param plot_name: Name of the plot. + :type plot_name: str + :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark. + :type args: Dict[str, Any] + :param xlabel: Label for the x axis of the plot. + :type xlabel: str, optional + :param ylabel: Label for the y axis of the plot. + :type ylabel: str, optional + :param x_log: Whether the x axis should be log scale. + :type x_log: bool, optional + :param y_log: Whether the y axis should be log scale. + :type y_log: bool, optional + """ + self.x_names = x_names + self.x_vals = x_vals + self.x_log = x_log + self.line_arg = line_arg + self.line_vals = line_vals + self.line_names = line_names + self.y_log = y_log + self.styles = styles + # plot info + self.xlabel = xlabel + self.ylabel = ylabel + self.plot_name = plot_name + self.args = args + + +class Mark: + + def __init__(self, fn, benchmarks): + self.fn = fn + self.benchmarks = benchmarks + + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, + save_precision=6, **kwrags): + import os + + import matplotlib.pyplot as plt + import pandas as pd + y_mean = bench.line_names + y_min = [f'{x}-min' for x in bench.line_names] + y_max = [f'{x}-max' for x in bench.line_names] + x_names = list(bench.x_names) + df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max) + for x in bench.x_vals: + # x can be a single value or a sequence of values. + if not isinstance(x, (list, tuple)): + x = [x for _ in x_names] + + if len(x) != len(x_names): + raise ValueError(f"Expected {len(x_names)} values, got {x}") + x_args = dict(zip(x_names, x)) + + row_mean, row_min, row_max = [], [], [] + for y in bench.line_vals: + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags) + try: + y_mean, y_min, y_max = ret + except TypeError: + y_mean, y_min, y_max = ret, None, None + row_mean += [y_mean] + row_min += [y_min] + row_max += [y_max] + df.loc[len(df)] = list(x) + row_mean + row_min + row_max + + if bench.plot_name: + plt.figure() + ax = plt.subplot() + # Plot first x value on x axis if there are multiple. + first_x = x_names[0] + for i, y in enumerate(bench.line_names): + y_min, y_max = df[y + '-min'], df[y + '-max'] + col = bench.styles[i][0] if bench.styles else None + sty = bench.styles[i][1] if bench.styles else None + ax.plot(df[first_x], df[y], label=y, color=col, ls=sty) + if not y_min.isnull().all() and not y_max.isnull().all(): + y_min = y_min.astype(float) + y_max = y_max.astype(float) + ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col) + ax.legend() + ax.set_xlabel(bench.xlabel or first_x) + ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) + ax.set_xscale("log" if bench.x_log else "linear") + ax.set_yscale("log" if bench.y_log else "linear") + if show_plots: + plt.show() + if save_path: + plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) + df = df[x_names + bench.line_names] + if diff_col and df.shape[1] == 2: + col0, col1 = df.columns.tolist() + df['Diff'] = df[col1] - df[col0] + + if print_data: + print(bench.plot_name + ':') + print(df.to_string()) + if save_path: + df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f", + index=False) + return df + + def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs): + has_single_bench = isinstance(self.benchmarks, Benchmark) + benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks + result_dfs = [] + if save_path: + # Create directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + html = open(os.path.join(save_path, "results.html"), "w") + html.write("\n") + for bench in benchmarks: + result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs)) + if save_path: + html.write(f"\n") + if save_path: + html.write("\n") + html.close() + if return_df: + if has_single_bench: + return result_dfs[0] + else: + return result_dfs + return None + + +def perf_report(benchmarks): + """ + Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value. + + :param benchmarks: Benchmarking configurations. + :type benchmarks: List of :class:`Benchmark` + """ + wrapper = lambda fn: Mark(fn, benchmarks) + return wrapper + + +def get_dram_gbps(device=None): + ''' return DRAM bandwidth in GB/s ''' + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"] # in kHz + capability = torch.cuda.get_device_capability() + if capability[0] == 8: + mem_clock_khz = 1800000 + bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"] + bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s + return bw_gbps + + +def get_max_tensorcore_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8: + assert dtype == torch.float16 or is_corex() + ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores + else: + if dtype in [torch.float32, torch.int32]: + ops_per_sub_core = 256 + elif dtype in [torch.float16, torch.bfloat16, torch.int16]: + ops_per_sub_core = 512 + elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]: + ops_per_sub_core = 1024 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops + + +# create decorator that wraps test function into +# a cuda-memcheck system call + + +def cuda_memcheck(**target_kwargs): + + def decorator(test_fn): + + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + import psutil + ppid_name = psutil.Process(os.getppid()).name() + run_cuda_memcheck = target_kwargs.items() <= kwargs.items() + if run_cuda_memcheck and ppid_name != "cuda-memcheck": + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} + assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" + test_id = kwargs['request'].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" + assert "ERROR SUMMARY: 0 errors" in str(out.stdout) + else: + test_fn(*args, **kwargs) + + return wrapper + + return decorator + + +@contextmanager +def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): + try: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ]) + cur_sm_clock = nvsmi(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi(["clocks.current.memory"])[0] + assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" + assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" + tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock + gbps = 640 * 2 * ref_mem_clock * 1e-3 + yield tflops, gbps + finally: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) + + +def get_max_simd_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + if dtype == torch.float32: + ops_per_sub_core = 32 # 2*16 + elif dtype == torch.float16: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + else: + if dtype == torch.float32: + ops_per_sub_core = 32 + elif dtype in [torch.float16, torch.bfloat16]: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops diff --git a/third_party/iluvatar/python/triton/tools/__init__.py b/third_party/iluvatar/python/triton/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/iluvatar/python/triton/tools/build_extern.py b/third_party/iluvatar/python/triton/tools/build_extern.py new file mode 100644 index 000000000..8f0168d59 --- /dev/null +++ b/third_party/iluvatar/python/triton/tools/build_extern.py @@ -0,0 +1,365 @@ +import argparse +import subprocess +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + + +class Symbol: + _name: str + _op_name: str + _ret_type: str + _arg_names: List[str] + _arg_types: List[str] + + def __init__( + self, + name: str, + op_name: str, + ret_type: str, + arg_names: List[str], + arg_types: List[str], + ) -> None: + ''' + A symbol is a function declaration. + :param name: name of the symbol + :param op_name: name of the operation + :param ret_type: return type of the operation + :param arg_names: names of the arguments + :param arg_types: types of the arguments + ''' + self._name = name + self._op_name = op_name + self._ret_type = ret_type + self._arg_names = list(arg_names) + self._arg_types = list(arg_types) + + @property + def name(self) -> str: + return self._name + + @property + def op_name(self) -> str: + return self._op_name + + @property + def ret_type(self) -> str: + return self._ret_type + + @property + def arg_names(self) -> List[str]: + return self._arg_names + + @property + def arg_types(self) -> List[str]: + return self._arg_types + + +def convert_type(type_str) -> Optional[str]: + if type_str == "i32": + return "int32" + elif type_str == "u32": + return "uint32" + elif type_str == "i64": + return "int64" + elif type_str == "u64": + return "uint64" + elif type_str == "float": + return "fp32" + elif type_str == "double": + return "fp64" + else: + # ignore other types, such as pointer types + return None + + +def to_unsigned(type_str) -> str: + if type_str == "int32": + return "uint32" + elif type_str == "int64": + return "uint64" + else: + return type_str + + +class ExternLibrary(ABC): + _name: str + _path: str + _symbols: Dict[str, Symbol] + _format: bool + _grouping: bool + + def __init__( + self, + name: str, + path: str, + format: bool = True, + grouping: bool = True, + ) -> None: + ''' + Abstract class for extern library. + :param name: name of the library + :param path: path of the library + :param format: whether to format the generated stub file + ''' + self._name = name + self._path = path + self._symbols = {} + self._format = format + self._grouping = grouping + + @property + def name(self) -> str: + return self._name + + @property + def path(self) -> str: + return self._path + + @property + def symbols(self) -> Dict[str, Symbol]: + return self._symbols + + @property + def grouping(self) -> bool: + return self._grouping + + @abstractmethod + def parse_symbols(self, input_file) -> None: + pass + + @abstractmethod + def _output_stubs(self) -> str: + pass + + def generate_stub_file(self, output_dir) -> None: + file_str = self._output_stubs() + if file_str is None or len(file_str) == 0: + raise Exception("file_str is empty") + + output_file = f"{output_dir}/{self._name}.py" + with open(output_file, "w") as f: + f.write(file_str) + f.close() + if self._format: + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate() + subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() + + +class Libdevice(ExternLibrary): + _symbol_groups: Dict[str, List[Symbol]] + + def __init__(self, path) -> None: + ''' + Constructor for Libdevice. + :param path: path of the libdevice library + ''' + super().__init__("libdevice", path) + self._symbol_groups = {} + self.is_pure = True + + @staticmethod + def _extract_symbol(line) -> Optional[Symbol]: + # Extract symbols from line in the following format: + # "define [internal] @(,)" + entries = line.split("@") + ret_str = entries[0] + func_str = entries[1] + # Get ret_type, skip internal symbols + ret_strs = ret_str.split() + if ret_strs[1] == "internal": + return None + ret_type = convert_type(ret_strs[1]) + if ret_type is None: + return None + # Get function name + func_strs = func_str.split("(") + func_name = func_strs[0].replace("@", "") + op_name = func_name.replace("__nv_", "") + if 'ieee' in op_name: + return None + # Get arg_types + arg_strs = func_strs[1].split(",") + arg_types = [] + arg_names = [] + for i, arg_str in enumerate(arg_strs): + arg_type = convert_type(arg_str.split()[0]) + if arg_type is None: + return None + arg_name = 'arg' + str(i) + arg_types.append(arg_type) + arg_names.append(arg_name) + if op_name == "sad": + # Special case for sad, where the last argument is an unsigned int + arg_types[-1] = to_unsigned(arg_types[-1]) + elif op_name.startswith("u"): + # LLVM does not differentiate between signed and unsigned integer type. + # We have to convert the types to unsigned + ret_type = to_unsigned(ret_type) + for i, arg_type in enumerate(arg_types): + arg_types[i] = to_unsigned(arg_type) + return Symbol(func_name, op_name, ret_type, arg_names, arg_types) + + def _group_symbols(self) -> None: + symbol_set = {} + for symbol in self._symbols.values(): + op_name = symbol.op_name + symbol_set[op_name] = symbol + + # Group functions together by renaming. + renaming = { + 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': + 'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz': + 'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh', + 'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos', + 'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', + 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', + 'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf': + 'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2', + 'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll': + 'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru', + 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff': + 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', + 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f': + 'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax': + 'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min', + 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', + 'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24', + 'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': + 'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv', + 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', + 'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru', + 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', + 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt', + 'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit', + 'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd': + 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', + 'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn', + 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', + 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf': + 'yn' + } + + for symbol in self._symbols.values(): + op_name = symbol.op_name + if op_name in renaming: + op_name = renaming[op_name] + symbol._op_name = op_name + if op_name in self._symbol_groups: + self._symbol_groups[op_name].append(symbol) + else: + self._symbol_groups[op_name] = [symbol] + + def parse_symbols(self, input_file) -> None: + if len(self.symbols) > 0: + return + output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() + for line in output: + symbol = self._extract_symbol(line) + if symbol is None: + continue + self._symbols[symbol.name] = symbol + + self._group_symbols() + + def _output_stubs(self) -> str: + # Generate python functions in the following format: + # @extern.extern + # def (, _builder=None): + # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}} + # return core.extern_elementwise("libdevice", , , , _builder) + import_str = "from . import core\n" + + header_str = "" + func_str = "" + for symbols in self._symbol_groups.values(): + func_str += "@core.extern\n" + func_name_str = f"def {symbols[0].op_name}(" + for arg_name in symbols[0].arg_names: + func_name_str += f"{arg_name}, " + func_name_str += "_builder=None):\n" + + return_str = f"\treturn core.extern_elementwise(\"{self._name}\", libdevice_path(), [" + for arg_name in symbols[0].arg_names: + return_str += f"{arg_name}, " + return_str += "], \n" + + arg_type_symbol_dict_str = "{" + for symbol in symbols: + arg_type_symbol_dict_str += "(" + for arg_type in symbol.arg_types: + arg_type_symbol_dict_str += f'core.dtype("{arg_type}"),' + ret_type = f'core.dtype("{symbol.ret_type}")' + arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n" + arg_type_symbol_dict_str += "}" + + return_str += arg_type_symbol_dict_str + return_str += f", is_pure={self.is_pure}" + return_str += ", _builder=_builder)\n" + + func_str += func_name_str + return_str + "\n" + file_str = import_str + header_str + func_str + + return file_str + + +class LLVMDisassembler: + _path: str + _ll_file: str + + def __init__(self, path) -> None: + ''' + Invoke llvm-dis to disassemble the given file. + :param path: path to llvm-dis + ''' + self._path = path + self._ll_file = "/tmp/extern_lib.ll" + + def disasm(self, lib_path: str) -> None: + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate() + + @property + def ll_file(self) -> str: + return self._ll_file + + @property + def path(self) -> str: + return self._path + + +extern_libs = ["libdevice"] + + +def build( + llvm_dis_path: str, + lib_path: str, + lib_name: str, + output_dir: str, +) -> None: + ''' + Interface function to build the library file. + :param llvm_dis_path: path to the llvm-dis binary + :param lib_path: path to the external library file + :param lib_name: name of the library + :param output_dir: path to the output directory + ''' + if lib_name == "libdevice": + extern_lib = Libdevice(lib_path) + else: + raise Exception(f"Unknown extern library: {lib_name}") + + llvm_disassembler = LLVMDisassembler(llvm_dis_path) + llvm_disassembler.disasm(lib_path) + + extern_lib.parse_symbols(llvm_disassembler.ll_file) + extern_lib.generate_stub_file(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis") + parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library") + parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library") + parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/") + args = parser.parse_args() + + build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) diff --git a/third_party/iluvatar/python/triton/tools/compile.c b/third_party/iluvatar/python/triton/tools/compile.c new file mode 100644 index 000000000..971bf6191 --- /dev/null +++ b/third_party/iluvatar/python/triton/tools/compile.c @@ -0,0 +1,67 @@ +/* clang-format off */ +#include +#include +#include +#include +#include + + +// helpers to check for cuda errors +#define CUDA_CHECK(ans) {{\ + gpuAssert((ans), __FILE__, __LINE__);\ + }}\ + +static inline void gpuAssert(CUresult code, const char *file, int line) {{ + if (code != CUDA_SUCCESS) {{ + const char *prefix = "Triton Error [CUDA]: "; + const char *str; + cuGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + printf("%s\\n", err); + exit(code); + }} +}} + +// globals +#define CUBIN_NAME {kernel_name}_cubin +CUmodule {kernel_name}_mod = NULL; +CUfunction {kernel_name}_func = NULL; +unsigned char CUBIN_NAME[{bin_size}] = {{ {bin_data} }}; + + +void unload_{kernel_name}(void) {{ + CUDA_CHECK(cuModuleUnload({kernel_name}_mod)); +}} + +// TODO: some code duplication with `runtime/backend/cuda.c` +void load_{kernel_name}() {{ + int dev = 0; + void *bin = (void *)&CUBIN_NAME; + int shared = {shared}; + CUDA_CHECK(cuModuleLoadData(&{kernel_name}_mod, bin)); + CUDA_CHECK(cuModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}")); + // set dynamic shared memory if necessary + int shared_optin; + CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev)); + if (shared > 49152 && shared_optin > 49152) {{ + CUDA_CHECK(cuFuncSetCacheConfig({kernel_name}_func, CU_FUNC_CACHE_PREFER_SHARED)); + CUDA_CHECK(cuFuncSetAttribute({kernel_name}_func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin)) + }} +}} + +/* +{kernel_docstring} +*/ +CUresult {kernel_name}(CUstream stream, {signature}) {{ + if ({kernel_name}_func == NULL) + load_{kernel_name}(); + unsigned int gX = {gridX}; + unsigned int gY = {gridY}; + unsigned int gZ = {gridZ}; + void *args[{num_args}] = {{ {arg_pointers} }}; + // TODO: shared memory + if(gX * gY * gZ > 0) + return cuLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * 32, 1, 1, {shared}, stream, args, NULL); +}} diff --git a/third_party/iluvatar/python/triton/tools/compile.h b/third_party/iluvatar/python/triton/tools/compile.h new file mode 100644 index 000000000..d98b7063b --- /dev/null +++ b/third_party/iluvatar/python/triton/tools/compile.h @@ -0,0 +1,14 @@ +#ifndef TT_KERNEL_INCLUDES +#define TT_KERNEL_INCLUDES + +#include +#include +#include +#include + +#endif + +void unload_{kernel_name}(void); +void load_{kernel_name}(void); +// tt-linker: {kernel_name}:{full_signature}:{algo_info} +CUresult{_placeholder} {kernel_name}(CUstream stream, {signature}); diff --git a/third_party/iluvatar/python/triton/tools/compile.py b/third_party/iluvatar/python/triton/tools/compile.py new file mode 100644 index 000000000..872332b03 --- /dev/null +++ b/third_party/iluvatar/python/triton/tools/compile.py @@ -0,0 +1,145 @@ +import binascii +import hashlib +import importlib.util +import sys +from argparse import ArgumentParser +from pathlib import Path +from typing import List + +import triton +from triton.compiler.code_generator import kernel_suffix +from triton.backends.nvidia.driver import ty_to_cpp + +desc = """ +Triton ahead-of-time compiler: + +This program compiles the kernel with name `kernel-name` in the file at the +provided `path` into self-contained C source-code that embeds the `cubin` +data along with utilities to load, unload and launch the kernel. + +signature is provided as a list of (optionally divisibility-hinted) types +or constexpr values, e.g. + +`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` + +will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`. +Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16, +and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype. + +The resulting entry point will have signature + +CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2) + +Different such specialized entry points can be combined using the `linker.py` script. + +NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the python interpreter +used to run this `compile.py` script +""" + +if __name__ == "__main__": + + # command-line arguments + parser = ArgumentParser(description=desc) + parser.add_argument("path", + help="Path to Python source containing desired kernel in its scope. File will be executed.") + parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", + required=True) + parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel") + parser.add_argument("--num-stages", "-ns", type=int, default=3, + help="Number of stages (meta-parameter of the kernel)") + parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel") + parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename") + parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) + parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True) + args = parser.parse_args() + + out_name = args.out_name if args.out_name else args.kernel_name + out_path = args.out_path if args.out_path else Path(out_name) + + # execute python sources and extract functions wrapped in JITFunction + arg_path = Path(args.path) + sys.path.insert(0, str(arg_path.parent)) + spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + kernel = getattr(mod, args.kernel_name) + grid = args.grid.split(",") + assert len(grid) == 3 + + # validate and parse signature + signature = list(map(lambda s: s.strip(" "), args.signature.split(","))) + + def hash_signature(signature: List[str]): + m = hashlib.sha256() + m.update(" ".join(signature).encode()) + return m.hexdigest()[:8] + + meta_sig = f"warps{args.num_warps}xstages{args.num_stages}" + sig_hash = hash_signature(signature + [meta_sig]) + + def constexpr(s): + try: + ret = int(s) + return ret + except ValueError: + pass + try: + ret = float(s) + return ret + except ValueError: + pass + return None + + hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} + hints = {k: v for k, v in hints.items() if v is not None} + constants = {i: constexpr(s) for i, s in enumerate(signature)} + constants = {k: v for k, v in constants.items() if v is not None} + signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constants} + const_sig = 'x'.join([str(v) for v in constants.values()]) + doc_string = [f"{kernel.arg_names[i]}={constants[i]}" for i in constants.keys()] + doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] + + # compile ast into cubin + for h in hints.values(): + assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" + divisible_by_16 = [i for i, h in hints.items() if h == 16] + equal_to_1 = [i for i, h in hints.items() if h == 1] + attrs = triton.compiler.AttrsDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) + for i in equal_to_1: + constants.update({i: 1}) + src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) + opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} + ccinfo = triton.compile(src, options=opts) + arg_names = [] + arg_types = [] + for i in signature.keys(): + if i not in equal_to_1: + arg_names += [kernel.arg_names[i]] + arg_types += [signature[i]] + + # dump C stub code + suffix = kernel_suffix(signature.values(), attrs) + func_name = '_'.join([out_name, sig_hash, suffix]) + hex_ = str(binascii.hexlify(ccinfo.asm["cubin"]))[2:-1] + params = { + "kernel_name": func_name, + "triton_kernel_name": args.kernel_name, + "bin_size": len(hex_), + "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), + "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), + "full_signature": ", ".join([f"{ty_to_cpp(signature[i])} {kernel.arg_names[i]}" for i in signature.keys()]), + "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]), + "num_args": len(arg_names), + "kernel_docstring": doc_string, + "shared": ccinfo.metadata.shared, + "num_warps": args.num_warps, + "algo_info": '_'.join([const_sig, meta_sig]), + "gridX": grid[0], + "gridY": grid[1], + "gridZ": grid[2], + "_placeholder": "", + } + for ext in ['h', 'c']: + template_path = Path(__file__).parent / f"compile.{ext}" + with out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}").open("w") as fp: + fp.write(Path(template_path).read_text().format(**params)) diff --git a/third_party/iluvatar/python/triton/tools/disasm.py b/third_party/iluvatar/python/triton/tools/disasm.py new file mode 100644 index 000000000..1e309a2e4 --- /dev/null +++ b/third_party/iluvatar/python/triton/tools/disasm.py @@ -0,0 +1,142 @@ +# MIT License + +# Copyright (c) 2020 Da Yan @ HKUST + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import functools +import os +import re +import subprocess +import tempfile + +from ..common.backend import path_to_cuobjdump, path_to_nvdisasm + +FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') +SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') +FNAME_RE = re.compile(r'\s*Function : (\w+)\s*') +BRA_RE = re.compile(r'(.*BRA(?:\.U)? )(0x\w+);') + + +def parseCtrl(sline): + enc = int(SLINE_RE.match(sline).group(1), 16) + stall = (enc >> 41) & 0xf + yld = (enc >> 45) & 0x1 + wrtdb = (enc >> 46) & 0x7 + readb = (enc >> 49) & 0x7 + watdb = (enc >> 52) & 0x3f + + yld_str = 'Y' if yld == 0 else '-' + wrtdb_str = '-' if wrtdb == 7 else str(wrtdb) + readb_str = '-' if readb == 7 else str(readb) + watdb_str = '--' if watdb == 0 else f'{watdb:02d}' + return f'{watdb_str}:{readb_str}:{wrtdb_str}:{yld_str}:{stall:x}' + + +def processSassLines(fline, sline, labels): + asm = FLINE_RE.match(fline).group(1) + # Remove tailing space + if asm.endswith(" ;"): + asm = asm[:-2] + ";" + ctrl = parseCtrl(sline) + # BRA target address + if BRA_RE.match(asm) is not None: + target = int(BRA_RE.match(asm).group(2), 16) + if target in labels: + pass + else: + labels[target] = len(labels) + return (f'{ctrl}', f'{asm}') + + +@functools.lru_cache() +def get_sass(cubin_asm, fun=None): + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(cubin_asm) + sass = extract(path, fun) + finally: + os.remove(path) + return sass + + +def extract(file_path, fun): + cuobjdump, _ = path_to_cuobjdump() + nvdisasm, _ = path_to_nvdisasm() + os.environ["NVDISASM_PATH"] = nvdisasm + if fun is None: + sass_str = subprocess.check_output([cuobjdump, "-sass", file_path]) + else: + sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path]) + sass_lines = sass_str.splitlines() + line_idx = 0 + while line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + # format: + # function : + # .headerflags: ... + # /*0000*/ asmstr /*0x...*/ + # /*0x...*/ + + # Looking for new function header (function: ) + while FNAME_RE.match(line) is None: + line_idx += 1 + if line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + else: + return + + fname = FNAME_RE.match(line).group(1) + ret = '' + ret += f'Function:{fname}\n' + line_idx += 2 # bypass .headerflags + line = sass_lines[line_idx].decode() + # Remapping address to label + labels = {} # address -> label_idx + # store sass asm in buffer and them print them (for labels) + # (ctrl, asm) + asm_buffer = [] + while FLINE_RE.match(line) is not None: + # First line (Offset ASM Encoding) + fline = sass_lines[line_idx].decode() + line_idx += 1 + # Second line (Encoding) + sline = sass_lines[line_idx].decode() + line_idx += 1 + asm_buffer.append(processSassLines(fline, sline, labels)) + # peek the next line + line = sass_lines[line_idx].decode() + # Print sass + # label naming convention: LBB#i + for idx, (ctrl, asm) in enumerate(asm_buffer): + # Print label if this is BRA target + offset = idx * 16 + if offset in labels: + label_name = f'LBB{labels[offset]}' + ret += f'{label_name}:\n' + ret += ctrl + '\t' + # if this is BRA, remap offset to label + if BRA_RE.match(asm): + target = int(BRA_RE.match(asm).group(2), 16) + target_name = f'LBB{labels[target]}' + asm = BRA_RE.sub(rf'\1{target_name};', asm) + ret += asm + '\n' + ret += '\n' + return ret diff --git a/third_party/iluvatar/python/triton/tools/link.py b/third_party/iluvatar/python/triton/tools/link.py new file mode 100644 index 000000000..75a1157a5 --- /dev/null +++ b/third_party/iluvatar/python/triton/tools/link.py @@ -0,0 +1,322 @@ +from collections import defaultdict +from pathlib import Path +from typing import Sequence, Union + +from dataclasses import dataclass + + +def _exists(x): + return x is not None + + +class LinkerError(Exception): + pass + + +@dataclass +class KernelLinkerMeta: + orig_kernel_name: str + arg_names: Sequence[str] + arg_ctypes: Sequence[str] + sizes: Sequence[Union[int, None]] + sig_hash: str + triton_suffix: str + suffix: str + num_specs: int + """ number of specialized arguments """ + + +class HeaderParser: + + def __init__(self) -> None: + import re + + # [kernel_name, c signature] + self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)") + # [name, hash, suffix] + self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$") + # [(type, name)] + self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?") + # [d|c] + self.arg_suffix = re.compile("[c,d]") + + self.kernels = defaultdict(list) + + def extract_linker_meta(self, header: str): + for ln in header.splitlines(): + if ln.startswith("//"): + m = self.linker_directives.match(ln) + if _exists(m): + ker_name, c_sig, algo_info = m.group(1), m.group(2), m.group(3) + name, sig_hash, suffix = self._match_name(ker_name) + c_types, arg_names = self._match_c_sig(c_sig) + num_specs, sizes = self._match_suffix(suffix, c_sig) + self._add_kernel( + "_".join([name, algo_info]), + KernelLinkerMeta( + orig_kernel_name=name, + arg_names=arg_names, + arg_ctypes=c_types, + sizes=sizes, + sig_hash=sig_hash, + triton_suffix=suffix, + suffix=suffix, + num_specs=num_specs, + ), + ) + + def _match_name(self, ker_name: str): + m = self.kernel_name.match(ker_name) + if _exists(m): + name, sig_hash, suffix = m.group(1), m.group(2), m.group(3) + return name, sig_hash, suffix + raise LinkerError(f"{ker_name} is not a valid kernel name") + + def _match_c_sig(self, c_sig: str): + m = self.c_sig.findall(c_sig) + if len(m): + tys, args = [], [] + for ty, arg_name in m: + tys.append(ty) + args.append(arg_name) + return tys, args + + raise LinkerError(f"{c_sig} is not a valid argument signature") + + def _match_suffix(self, suffix: str, c_sig: str): + args = c_sig.split(",") + s2i = {"c": 1, "d": 16} + num_specs = 0 + sizes = [] + # scan through suffix, first find the index, + # then see if it is followed by d or c + for i in range(len(args)): + pos = suffix.find(str(i)) + if pos == -1: + raise LinkerError(f"{suffix} is not a valid kernel suffix") + pos += len(str(i)) + if self.arg_suffix.match(suffix, pos): + num_specs += 1 + sizes.extend([None] * (i - len(sizes))) + sizes.append(s2i[suffix[pos]]) + pos += 1 + if i < len(args) - 1: + suffix = suffix[pos:] + else: + sizes.extend([None] * (len(args) - len(sizes))) + return num_specs, sizes + + def _add_kernel(self, name: str, ker: KernelLinkerMeta): + if name in self.kernels: + last: KernelLinkerMeta = self.kernels[name][-1] + + for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes): + if cur != new_: + raise LinkerError( + f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}" + ) + + self.kernels[name].append(ker) + + +def gen_signature_with_full_args(m): + return ", ".join([f"{ty} {arg}" for ty, arg in zip(m.arg_ctypes, m.arg_names)]) + + +def gen_signature(m): + arg_types = [ty for ty, hint in zip(m.arg_ctypes, m.sizes) if hint != 1] + arg_names = [arg for arg, hint in zip(m.arg_names, m.sizes) if hint != 1] + sig = ", ".join([f"{ty} {arg}" for ty, arg in zip(arg_types, arg_names)]) + return sig + + +# generate declarations of kernels with meta-parameter and constant values +def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + return f""" +CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}); +void load_{name}(); +void unload_{name}(); + """ + + +# generate declarations of kernels with meta-parameter and constant values +def make_global_decl(meta: KernelLinkerMeta) -> str: + return f""" +CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}); +CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id); +void load_{meta.orig_kernel_name}(); +void unload_{meta.orig_kernel_name}(); + """ + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_default_algo_kernel(meta: KernelLinkerMeta) -> str: + src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n" + src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different integer value hints +def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + src = f"// launcher for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n" + src += "\n" + + src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{") + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + cond_fn = ( # + lambda val, hint: f"({val} % {hint} == 0)" # + if hint == 16 # + else f"({val} == {hint})" # + if hint == 1 # + else None) + conds = " && ".join([ # + cond_fn(val, hint) # + for val, hint in zip(meta.arg_names, meta.sizes) # + if hint is not None + ]) + src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n" + ) # Edge case where no specializations hence no dispatching required + arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1] + src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n" + src += "\n" + src += " return CUDA_ERROR_INVALID_VALUE;\n" + src += "}\n" + + for mode in ["load", "unload"]: + src += f"\n// {mode} for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"void {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n" + src += f"void {mode}_{name}() {{" + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str: + src = f"CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n" + src += f" assert (algo_id < (int)sizeof({meta.orig_kernel_name}_kernels));\n" + src += f" return {meta.orig_kernel_name}_kernels[algo_id](stream, {', '.join(meta.arg_names)});\n" + src += "}\n" + return src + + +# generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values +def make_func_pointers(names: str, meta: KernelLinkerMeta) -> str: + # the table of hint dispatchers + src = f"typedef CUresult (*kernel_func_t)(CUstream stream, {gen_signature_with_full_args(meta)});\n" + src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n" + for name in names: + src += f" {name},\n" + src += "};\n" + return src + + +# generate definition for load/unload functions for kernels with different meta-parameter and constant values +def make_kernel_load_def(names: str, meta: KernelLinkerMeta) -> str: + src = "" + for mode in ["load", "unload"]: + src += f"void {mode}_{meta.orig_kernel_name}(void){{\n" + for name in names: + src += f" {mode}_{name}();\n" + src += "}\n\n" + return src + + +def make_get_num_algos_decl(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void);" + return src + + +def make_get_num_algos_def(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void){{\n" + src += f" return (int)(sizeof({meta.orig_kernel_name}_kernels) / sizeof({meta.orig_kernel_name}_kernels[0]));\n" + src += "}\n" + return src + + +desc = """ +Triton ahead-of-time linker: + +This program takes in header files generated by compile.py, and generates a +single entry-point responsible for dispatching the user's input to the right +kernel given the specializations that were compiled. + +Example usage: +python link.py /path/to/headers/*.h -o kernel_name +""" + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser(description=desc) + parser.add_argument( + "headers", + nargs="+", + help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)", + ) + parser.add_argument("--out", "-o", type=Path, help="Out filename") + parser.add_argument( + "--prefix", + type=str, + default="", + help="String to prefix kernel dispatcher names", + ) + args = parser.parse_args() + + # metadata + parser = HeaderParser() + includes = [] + for header in args.headers: + h_path = Path(header) + h_str = h_path.read_text() + includes.append(h_path.name) + parser.extract_linker_meta(h_str) + + # generate headers + algo_decls = [make_algo_decls(name, meta) for name, meta in parser.kernels.items()] + meta_lists = [meta for name, meta in parser.kernels.items()] + meta = meta_lists[0][0] + get_num_algos_decl = make_get_num_algos_decl(meta) + global_decl = make_global_decl(meta) + with args.out.with_suffix(".h").open("w") as fp: + out = "#include \n" + out += "\n".join(algo_decls) + out += "\n" + out += get_num_algos_decl + out += "\n" + out += global_decl + fp.write(out) + + # generate source + defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()] + names = [name for name in parser.kernels.keys()] + func_pointers_def = make_func_pointers(names, meta) + meta_const_def = make_kernel_meta_const_dispatcher(meta) + load_unload_def = make_kernel_load_def(names, meta) + get_num_algos_def = make_get_num_algos_def(meta) + default_algo_kernel = make_default_algo_kernel(meta) + with args.out.with_suffix(".c").open("w") as fp: + out = "" + out += "#include \n" + out += "#include \n" + out += "#include \n" + out += "\n" + out += "\n".join(defs) + out += "\n" + out += func_pointers_def + out += "\n" + out += get_num_algos_def + out += "\n" + out += meta_const_def + out += "\n" + out += load_unload_def + out += "\n" + out += default_algo_kernel + fp.write(out) diff --git a/third_party/mthreads/CMakeLists.txt b/third_party/mthreads/CMakeLists.txt new file mode 100644 index 000000000..6cbb306f5 --- /dev/null +++ b/third_party/mthreads/CMakeLists.txt @@ -0,0 +1,24 @@ +add_subdirectory(include) +add_subdirectory(lib) + +if(TRITON_BUILD_PYTHON_MODULE) + if(FLAGTREE_PLUGIN) + add_subdirectory(plugin) + add_triton_plugin(TritonMTHREADS + SHARED_LIB mthreadsTritonPlugin + ) + else() + find_library(mthreadsTritonPluginLib + NAMES + mthreadsTritonPlugin.so + PATHS + ${CMAKE_CURRENT_SOURCE_DIR} + REQUIRED + ) + add_triton_plugin(TritonMTHREADS + SHARED_LIB ${mthreadsTritonPluginLib} + ) + endif() +endif() + +add_subdirectory(bin) diff --git a/third_party/mthreads/backend/__init__.py b/third_party/mthreads/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/mthreads/backend/compiler.py b/third_party/mthreads/backend/compiler.py new file mode 100644 index 000000000..1f80fc56e --- /dev/null +++ b/third_party/mthreads/backend/compiler.py @@ -0,0 +1,234 @@ +from triton.backends.compiler import BaseBackend, GPUTarget +from triton._C.libtriton import ir, passes, llvm, mthreads + +from dataclasses import dataclass +import functools +from typing import Any, Tuple, Optional +import hashlib +import re +import tempfile +import signal +import os +import subprocess +from pathlib import Path +import shutil + + +def get_kernel_name(src: str, pattern: str) -> str: + assert src + for line in src.split('\n'): + line = line.strip() + if line.startswith(pattern): + return line.split()[-1] + + +@functools.lru_cache() +def get_musa_version(): + version = subprocess.check_output(["/usr/local/musa/bin/musa_toolkits_version"]).decode("utf-8") + return version + + +@functools.lru_cache(None) +def file_hash(path): + with open(path, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + +@dataclass(frozen=True) +class MUSAOptions: + num_warps: int = 4 + num_ctas: int = 1 + num_stages: int = 3 + # maxnreg corresponds to the ptx parameter .maxnreg, which controls the + # maximum number of 32-bit registers used by one thread. + maxnreg: Optional[int] = None + cluster_dims: tuple = (1, 1, 1) + capability: int = None + enable_fp_fusion: bool = True + allow_fp8e4nv: bool = False + allow_fp8e4b15: bool = False + default_dot_input_precision: str = "tf32" + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + max_num_imprecise_acc_default: bool = None + extern_libs: dict = None + debug: bool = False + backend_name: str = 'musa' + + def __post_init__(self): + extern_libs = {} if self.extern_libs is None else dict(self.extern_libs) + if not extern_libs.get('libdevice', None): + if self.capability >= 31: + default_libdir = "/usr/local/musa/mtgpu/bitcode/libdevice.31.bc" + else: + default_libdir = "/usr/local/musa/mtgpu/bitcode/libdevice.bc" + # here we add an new ENV: MUSA_LIBDEVICE_PATH for MUSA, + # which represents the path of libdevice.bc + musa_env_path = os.environ.get("MUSA_LIBDEVICE_PATH", default_libdir) + extern_libs['libdevice'] = musa_env_path + object.__setattr__(self, 'extern_libs', tuple(extern_libs.items())) + assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \ + "num_warps must be a power of 2" + + def hash(self): + hash_dict = dict(self.__dict__) + hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"])) + key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +class MUSABackend(BaseBackend): + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == 'musa' + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + self.capability = target.arch + self.warp_size = target.warp_size + assert isinstance(self.capability, int) + self.binary_ext = "mubin" + + def parse_options(self, opts) -> Any: + opts["capability"] = self.capability + opts["allow_fp8e4nv"] = self.capability >= 31 + args = {k: opts[k] for k in MUSAOptions.__dataclass_fields__.keys() if k in opts} + return MUSAOptions(**args) + + def pack_metadata(self, metadata): + return ( + metadata.num_warps, + metadata.num_ctas, + metadata.shared, + metadata.cluster_dims[0], + metadata.cluster_dims[1], + metadata.cluster_dims[2], + ) + + def get_codegen_implementation(self): + import triton.language.extra.musa as musa + codegen_fns = { + "convert_custom_types": None, + } + return codegen_fns + + def load_dialects(self, ctx): + mthreads.load_dialects(ctx) + + @staticmethod + def make_ttir(mod, metadata, opt): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_rewrite_tensor_pointer(pm) + passes.ttir.add_combine(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + pm.run(mod) + return mod + + @staticmethod + def make_ttgir(mod, metadata, opt, capability, warp_size): + # TTIR -> TTGIR + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.ttir.add_convert_to_ttgpuir(pm, f"musa:{capability}", opt.num_warps, warp_size, opt.num_ctas) + # optimize TTGIR + passes.ttgpuir.add_coalesce(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.common.add_cse(pm) + passes.ttgpuir.add_combine_tensor_select_and_if(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_reduce_data_duplication(pm) + passes.ttgpuir.add_reorder_instructions(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + passes.common.add_canonicalizer(pm) + pm.run(mod) + return mod + + @staticmethod + def make_llir(src, metadata, options, capability): + # warp-specialization mutates num_warps + num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") + if num_warp_groups is not None: + metadata["num_warps"] *= num_warp_groups + mod = src + # TritonGPU -> LLVM-IR (MLIR) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.convert.add_scf_to_cf(pm) + passes.convert.add_index_to_llvmir(pm) + passes.ttgpuir.add_allocate_shared_memory(pm) + mthreads.passes.ttgpuir.add_to_llvmir(pm, capability) + passes.convert.add_arith_to_llvmir(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + + if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": + passes.llvmir.add_di_scope(pm) + + # FIXME: shall we consider to use load/store with robust instruction to support ld/st with predicate + mthreads.passes.ttgpuir.add_mtgpu_builtin_func_to_llvmir(pm) + pm.run(mod) + + # LLVM-IR (MLIR) -> LLVM-IR (LLVM) + llvm.init_targets() + context = llvm.context() + + llvm_mod = llvm.to_module(mod, context) + mthreads.attach_datalayout(llvm_mod) + + if options.extern_libs: + paths = [path for (name, path) in options.extern_libs] + llvm.link_extern_libs(llvm_mod, paths) + + llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) + + # Get some metadata + metadata["shared"] = src.get_int_attr("triton_gpu.shared") + ret = str(llvm_mod) + del llvm_mod + del context + return ret + + @staticmethod + def make_mubin(src, metadata, opt, capability): + ''' + Translate TritonGPU module to MUSA binary code. + ''' + if (os.environ.get("LLVM_IR_ENABLE_DUMP", "0") == "1"): + print("// -----// LLVM IR") + print(src) + + opt_option = "-mtgpu-enable-const-calc=1" + if (os.environ.get("MUSA_ENABLE_LLC_OPT", "0") == "1"): + opt_option = "-mtgpu-opt-level=1" + + ret = mthreads.translate_llvmir_to_mubin(src, opt_option, capability, 0) + if (os.environ.get("MUSA_ASM_ENABLE_DUMP", "0") == "1"): + print("// -----// MTGPU ASM") + print(ret[0]) + + mubin_save_path = os.environ.get("MUBIN_SAVE_PATH", "") + if mubin_save_path != "": + mubin_file_name = os.path.join(mubin_save_path, "test.out") + shutil.copy2(ret[1], mubin_file_name) + + metadata["name"] = get_kernel_name(ret[0], pattern='.globl') + return ret + + def add_stages(self, stages, options): + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability, self.warp_size) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability) + stages["mubin"] = lambda src, metadata: self.make_mubin(src, metadata, options, self.capability) + + @functools.lru_cache() + def hash(self): + version = get_musa_version() + return f'{version}-{self.capability}' diff --git a/third_party/mthreads/backend/driver.c b/third_party/mthreads/backend/driver.c new file mode 100644 index 000000000..09d4a9b70 --- /dev/null +++ b/third_party/mthreads/backend/driver.c @@ -0,0 +1,177 @@ +#include "musa.h" +#include +#include +#define PY_SSIZE_T_CLEAN +#include + +// Raises a Python exception and returns false if code is not MUSA_SUCCESS. +static bool gpuAssert(MUresult code, const char *file, int line) { + if (code == MUSA_SUCCESS) + return true; + + const char *prefix = "Triton Error [MUSA]: "; + const char *str; + muGetErrorString(code, &str); + char err[1024] = {0}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + return false; +} + +// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. +#define MUSA_CHECK_AND_RETURN_NULL(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) \ + return NULL; \ + } while (0) + +// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. +#define MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) { \ + PyEval_RestoreThread(_save); \ + return NULL; \ + } \ + } while (0) + +// Used to check if functions exist in old MUSA driver versions. +#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \ + do { \ + if ((funcPointer) == NULL) { \ + (funcPointer) = (initializerFunction)(); \ + if ((funcPointer) == NULL) { \ + return NULL; \ + } \ + } \ + } while (0) + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + // Get device handle + MUdevice device; + muDeviceGet(&device, device_id); + + // create a struct to hold device properties + int max_shared_mem; + int max_num_regs; + int multiprocessor_count; + int warp_size; + int sm_clock_rate; + int mem_clock_rate; + int mem_bus_width; + MUSA_CHECK_AND_RETURN_NULL(muDeviceGetAttribute( + &max_shared_mem, MU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + MUSA_CHECK_AND_RETURN_NULL(muDeviceGetAttribute( + &max_num_regs, MU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device)); + MUSA_CHECK_AND_RETURN_NULL(muDeviceGetAttribute( + &multiprocessor_count, MU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); + MUSA_CHECK_AND_RETURN_NULL( + muDeviceGetAttribute(&warp_size, MU_DEVICE_ATTRIBUTE_WARP_SIZE, device)); + MUSA_CHECK_AND_RETURN_NULL(muDeviceGetAttribute( + &sm_clock_rate, MU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); + MUSA_CHECK_AND_RETURN_NULL(muDeviceGetAttribute( + &mem_clock_rate, MU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); + MUSA_CHECK_AND_RETURN_NULL(muDeviceGetAttribute( + &mem_bus_width, MU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", + max_shared_mem, "max_num_regs", max_num_regs, + "multiprocessor_count", multiprocessor_count, "warpSize", + warp_size, "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, "mem_bus_width", + mem_bus_width); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + int device; + if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, + &device)) { + return NULL; + } + MUfunction fun; + MUmodule mod; + int32_t n_regs = 0; + int32_t n_spills = 0; + // create driver handles + MUcontext pctx = 0; + + Py_BEGIN_ALLOW_THREADS; + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muCtxGetCurrent(&pctx)); + if (!pctx) { + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + muDevicePrimaryCtxRetain(&pctx, device)); + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muCtxSetCurrent(pctx)); + } + + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muModuleLoadData(&mod, data)); + // MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muModuleLoad(&mod, data)); + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + muModuleGetFunction(&fun, mod, name)); + // get allocated registers and spilled registers from the function + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + muFuncGetAttribute(&n_regs, MU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + muFuncGetAttribute(&n_spills, MU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + n_spills /= 4; + // set dynamic shared memory if necessary + int shared_optin; + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muDeviceGetAttribute( + &shared_optin, MU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + // supported based on QY2, PH1 is ok here + if (shared > 73728 && shared_optin > 73728) { + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + muFuncSetCacheConfig(fun, MU_FUNC_CACHE_PREFER_SHARED)); + int shared_total, shared_static; + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muDeviceGetAttribute( + &shared_total, MU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, + device)); + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(muFuncGetAttribute( + &shared_static, MU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); + MUSA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + muFuncSetAttribute(fun, MU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_optin - shared_static)); + } + Py_END_ALLOW_THREADS; + + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, + n_spills); +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load provided mubin into MUSA driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "musa_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_musa_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + + PyModule_AddFunctions(m, ModuleMethods); + + return m; +} diff --git a/third_party/mthreads/backend/driver.py b/third_party/mthreads/backend/driver.py new file mode 100644 index 000000000..efb718e9f --- /dev/null +++ b/third_party/mthreads/backend/driver.py @@ -0,0 +1,472 @@ +import functools +import os +import hashlib +import subprocess +import tempfile +import shutil +import sysconfig +from pathlib import Path +from triton.runtime.cache import get_cache_manager +from triton.backends.compiler import GPUTarget +from triton.backends.driver import GPUDriver + +dirname = os.path.dirname(os.path.realpath(__file__)) + + +@functools.lru_cache() +def musa_home_dir(): + return os.getenv("MUSA_HOME", default="/usr/local/musa") + + +@functools.lru_cache() +def musa_include_dir(): + musa_home = musa_home_dir() + return os.path.join(musa_home, "include") + + +@functools.lru_cache() +def libmusa_dirs(): + musa_home = musa_home_dir() + return os.path.join(musa_home, "lib") + + +def _build(name, src, srcdir): + musa_lib_dir = libmusa_dirs() + mu_include_dir = musa_include_dir() + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + # TODO: support more things here. + clang = shutil.which("clang") + # `musa.h` has some fucking issues recently, which introduce c++ style code into `musa.h`, we have to use `g++` but not `gcc` until these issues fixed by musa team. + gcc = shutil.which("g++") + cc = gcc if gcc is not None else clang + if cc is None: + raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + + cc_cmd = [ + cc, src, "-O3", f"-I{mu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", + f"-L{musa_lib_dir}", "-lmusa", "-o", so + ] + # cc_cmd += [f"-L{dir}" for dir in musa_lib_dir] + ret = subprocess.check_call(cc_cmd) + + if ret == 0: + return so + # Backup source file and cmd. + dst = os.path.join("/tmp", os.path.basename(src)) + with open(dst, 'w') as f: + f.write("// " + ' '.join(cc_cmd) + "\n" + open(src).read()) + raise RuntimeError(f"Failed to compile stub for {name}. Source file and compile cmd backup to {dst}") + ''' + # fallback on setuptools + extra_compile_args = [] + library_dirs = musa_lib_dir + include_dirs = [srcdir, mu_include_dir] + libraries = ['cuda'] + # extra arguments + extra_link_args = [] + # create extension module + ext = setuptools.Extension( + name=name, + language='c', + sources=[src], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args + ['-O3'], + extra_link_args=extra_link_args, + library_dirs=library_dirs, + libraries=libraries, + ) + # build extension module + args = ['build_ext'] + args.append('--build-temp=' + srcdir) + args.append('--build-lib=' + srcdir) + args.append('-q') + args = dict( + name=name, + ext_modules=[ext], + script_args=args, + ) + with quiet(): + setuptools.setup(**args) + return so + ''' + + +def compile_module_from_src(src, name): + key = hashlib.sha256(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + so = _build(name, src_path, tmpdir) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# ------------------------ +# Utils +# ------------------------ + + +class MusaUtils(object): + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(MusaUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "musa_utils") + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + + +# ------------------------ +# Launcher +# ------------------------ + + +def ty_to_cpp(ty): + if ty[0] == '*': + return "MUdeviceptr" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + +def make_launcher(constants, signature, ids, warp_size): + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + + def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return ty_to_cpp(ty) + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "uint32_t": "I", + "int32_t": "i", + "uint64_t": "K", + "int64_t": "L", + }[ty] + + args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiKKOOOO" + args_format + args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + + # generate glue code + params = [i for i in signature.keys() if i not in constants] + src = f""" +#include \"musa.h\" +#include +#include +#include + +static inline void gpuAssert(MUresult code, const char *file, int line) +{{ + if (code != MUSA_SUCCESS) + {{ + const char* prefix = "Triton Error [CUDA]: "; + const char* str; + muGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); + }} +}} + +#define MUSA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + +typedef MUresult (*muLaunchKernelEx_t)(const MUlaunchConfig* config, MUfunction f, void** kernelParams, void** extra); + +static muLaunchKernelEx_t getLaunchKernelExHandle() {{ + // Open the shared library + void* handle = dlopen("libmusa.so", RTLD_LAZY); + if (!handle) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to open libmusa.so"); + return NULL; + }} + // Clear any existing error + dlerror(); + muLaunchKernelEx_t muLaunchKernelExHandle = (muLaunchKernelEx_t)dlsym(handle, "muLaunchKernelEx"); + // Check for errors + const char *dlsym_error = dlerror(); + if (dlsym_error) {{ + PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve muLaunchKernelEx from libmusa.so"); + return NULL; + }} + return muLaunchKernelExHandle; +}} + +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, MUstream stream, MUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + void *params[] = {{ {', '.join(f"&arg{i}" for i in params)} }}; + if (gridX*gridY*gridZ > 0) {{ + if (num_ctas == 1) {{ + MUSA_CHECK(muLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0)); + }} else {{ + MUlaunchAttribute launchAttr[2]; + launchAttr[0].id = MU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launchAttr[0].value.clusterDim.x = clusterDimX; + launchAttr[0].value.clusterDim.y = clusterDimY; + launchAttr[0].value.clusterDim.z = clusterDimZ; + launchAttr[1].id = MU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + launchAttr[1].value.clusterSchedulingPolicyPreference = MU_CLUSTER_SCHEDULING_POLICY_SPREAD; + MUlaunchConfig config; + config.gridDimX = gridX * clusterDimX; + config.gridDimY = gridY * clusterDimY; + config.gridDimZ = gridZ * clusterDimZ; + config.blockDimX = {warp_size} * num_warps; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared_memory; + config.hStream = stream; + config.attrs = launchAttr; + config.numAttrs = 2; + static muLaunchKernelEx_t muLaunchKernelExHandle = NULL; + if (muLaunchKernelExHandle == NULL) {{ + muLaunchKernelExHandle = getLaunchKernelExHandle(); + }} + MUSA_CHECK(muLaunchKernelExHandle(&config, function, params, 0)); + }} + }} +}} + +typedef struct _DevicePtrInfo {{ + MUdeviceptr dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret); + if(!ptr_info.dev_ptr) + return ptr_info; + uint64_t dev_ptr; + int status = muPointerGetAttribute(&dev_ptr, MU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); + if (status == MUSA_ERROR_INVALID_VALUE) {{ + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); + ptr_info.valid = false; + }} + ptr_info.dev_ptr = dev_ptr; + Py_DECREF(ret); // Thanks ChatGPT! + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {args_list})) {{ + return NULL; + }} + + int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; + if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); + return NULL; + }} + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + Py_BEGIN_ALLOW_THREADS; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (MUstream)_stream, (MUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); + Py_END_ALLOW_THREADS; + if (PyErr_Occurred()) {{ + return NULL; + }} + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src + + +class MusaLauncher(object): + + def __init__(self, src, metadata): + ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + src = make_launcher(constants, signature, ids, metadata.target.warp_size) + mod = compile_module_from_src(src, "__triton_launcher") + self.launch = mod.launch + + def __call__(self, *args, **kwargs): + self.launch(*args, **kwargs) + + +class MusaDriver(GPUDriver): + + def __init__(self): + super().__init__() + self.utils = MusaUtils() # TODO: make static + self.launcher_cls = MusaLauncher + + self.get_device_capability = self._get_device_capability + self.get_current_stream = self._get_current_stream + self.get_current_device = self._get_current_device + self.set_current_device = self._set_current_device + + def _get_device_capability(self, device): + return torch_musa.get_device_capability(device) + + def _get_current_stream(self, idx): + try: + # return torch_musa._MUSAC._musa_getCurrentStream(idx) + return torch_musa._MUSAC._musa_getCurrentRawStream(idx) + except ImportError: + return torch_musa.current_stream(idx).musa_stream + + def _get_current_device(self): + """ + Get current device + """ + return torch_musa.current_device() + + def _set_current_device(self, device): + """ + Set current device as the given device + """ + torch_musa.set_device(device) + + def get_current_target(self): + device = self.get_current_device() + warp_size = 128 + capability = self.get_device_capability(device) + if capability[0] > 2: + warp_size = 32 + capability = capability[0] * 10 + capability[1] + return GPUTarget("musa", capability, warp_size) + + @staticmethod + def is_active(): + try: + import torch + import torch_musa + return torch.musa.is_available() + except: + return False + + +if MusaDriver.is_active(): + import torch + import torch_musa diff --git a/third_party/mthreads/backend/musa_testing.py b/third_party/mthreads/backend/musa_testing.py new file mode 100644 index 000000000..b8cd8f42c --- /dev/null +++ b/third_party/mthreads/backend/musa_testing.py @@ -0,0 +1,74 @@ +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float] + :param fast_flush: Use faster kernel to flush L2 between measurements + :type fast_flush: bool + """ + assert return_mode in ["min", "max", "mean", "median"] + import torch + + fn() + torch.musa.synchronize() + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + if fast_flush: + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='musa') + else: + cache = torch.empty(int(256e6), dtype=torch.int8, device='musa') + + # Estimate the runtime of the function + start_event = torch.musa.Event(enable_timing=True) + end_event = torch.musa.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.musa.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + start_event = [torch.musa.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [torch.musa.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + torch.musa.synchronize() + times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + return getattr(torch, return_mode)(times).item() diff --git a/third_party/mthreads/bin/CMakeLists.txt b/third_party/mthreads/bin/CMakeLists.txt new file mode 100644 index 000000000..0b68f75fa --- /dev/null +++ b/third_party/mthreads/bin/CMakeLists.txt @@ -0,0 +1,95 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) + +add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED) + +# TODO: what's this? +llvm_update_compile_flags(triton-opt) +target_link_libraries(triton-opt PRIVATE + TritonLLVMIR + TritonAnalysis + TritonTransforms + TritonGPUTransforms + MLIRGPUToROCDLTransforms + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + # MLIR core + MLIROptLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(triton-opt) +set_target_properties(triton-opt PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/ +) + +add_llvm_executable(triton-reduce triton-reduce.cpp PARTIAL_SOURCES_INTENDED) +mlir_check_all_link_libraries(triton-reduce) + +llvm_update_compile_flags(triton-reduce) +target_link_libraries(triton-reduce PRIVATE + TritonLLVMIR + TritonAnalysis + TritonTransforms + TritonGPUTransforms + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + # MLIR core + MLIRReduceLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(triton-reduce) +set_target_properties(triton-reduce PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/ +) + +add_llvm_executable(triton-lsp triton-lsp.cpp PARTIAL_SOURCES_INTENDED) +mlir_check_all_link_libraries(triton-lsp) + +llvm_update_compile_flags(triton-lsp) +target_link_libraries(triton-lsp PRIVATE + TritonAnalysis + TritonTransforms + TritonGPUTransforms + ${dialect_libs} + ${conversion_libs} + ${triton_libs} + # MLIR core + MLIRLspServerLib + MLIRPass + MLIRTransforms +) + +mlir_check_all_link_libraries(triton-lsp) +set_target_properties(triton-lsp PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/ +) + + +add_llvm_executable(triton-llvm-opt + triton-llvm-opt.cpp + + PARTIAL_SOURCES_INTENDED + DEPENDS + intrinsics_gen + SUPPORT_PLUGINS + ) +target_link_libraries(triton-llvm-opt PRIVATE + TritonLLVMIR + + LLVMAnalysis + LLVMCore + LLVMSupport + LLVMOption + LLVMCodeGen + ) +export_executable_symbols_for_plugins(triton-llvm-opt) +set_target_properties(triton-llvm-opt PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin/ +) diff --git a/third_party/mthreads/bin/RegisterTritonDialects.h b/third_party/mthreads/bin/RegisterTritonDialects.h new file mode 100644 index 000000000..26e274048 --- /dev/null +++ b/third_party/mthreads/bin/RegisterTritonDialects.h @@ -0,0 +1,53 @@ +#pragma once + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Target/LLVMIR/Passes.h" + +#include "mlir/Dialect/LLVMIR/MTGPUDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/InitAllPasses.h" + +#include +#include +#include + +#include "python/src/plugin.h" + +using BackendRegisterFunc = void (*)(); + +BackendRegisterFunc load_backend_register_func(const char *backend_name, + const char *func_name) { + void *symbol = load_backend_symbol(backend_name, func_name); + return reinterpret_cast(symbol); +} + +inline void registerTritonDialects(mlir::DialectRegistry ®istry) { + mlir::registerAllPasses(); + mlir::registerTritonPasses(); + + mlir::triton::gpu::registerTritonGPUPasses(); + mlir::triton::registerConvertTritonToTritonGPUPass(); + mlir::triton::registerAllocateSharedMemoryPass(); + mlir::registerLLVMDIScope(); + + // TODO(mthreads): registerMthreadsPasses is not working currently, + // since both libtriton.so and mthreadsTritonPlugin.so are linked the + // MLIRPass.a + auto backend_register_func = + load_backend_register_func("mthreads", "registerMthreadsPasses"); + backend_register_func(); + + // TODO: register Triton & TritonGPU passes + registry.insert(); +} diff --git a/third_party/mthreads/bin/triton-llvm-opt.cpp b/third_party/mthreads/bin/triton-llvm-opt.cpp new file mode 100644 index 000000000..1ec804cb5 --- /dev/null +++ b/third_party/mthreads/bin/triton-llvm-opt.cpp @@ -0,0 +1,121 @@ +/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir +/// passes. +#include "lib/Target/LLVMIR/LLVMPasses.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/SystemUtils.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/TargetParser/Triple.h" +#include + +using namespace llvm; + +static cl::opt InputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +static cl::opt OutputFilename("o", + cl::desc("Override output filename"), + cl::value_desc("filename")); + +static cl::opt ClDataLayout("data-layout", + cl::desc("data layout string to use"), + cl::value_desc("layout-string"), + cl::init("")); +static cl::opt + TargetTriple("mtriple", cl::desc("Override target triple for module")); + +static cl::opt + BreakStructPhiNodes("break-struct-phi-nodes", + llvm::cl::desc("run pass to break phi struct"), + cl::init(false)); + +namespace { +static std::function makeOptimizingPipeline() { + return [](Module *m) -> Error { + PipelineTuningOptions tuningOptions; + PassBuilder pb(nullptr, tuningOptions); + + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + llvm::FunctionPassManager fpm; + if (BreakStructPhiNodes) + fpm.addPass(BreakStructPhiNodesPass()); + mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm))); + mpm.run(*m, mam); + return Error::success(); + }; +} +} // namespace + +int main(int argc, char **argv) { + InitLLVM X(argc, argv); + cl::ParseCommandLineOptions( + argc, argv, "llvm .bc -> .bc modular optimizer and analysis printer\n"); + + LLVMContext Context; + SMDiagnostic Err; + + // Load the input module... + auto SetDataLayout = [](StringRef, StringRef) -> std::optional { + if (ClDataLayout.empty()) + return std::nullopt; + return ClDataLayout; + }; + std::unique_ptr M; + M = parseIRFile(InputFilename, Err, Context, ParserCallbacks(SetDataLayout)); + if (!M) { + Err.print(argv[0], errs()); + return 1; + } + // If we are supposed to override the target triple or data layout, do so now. + if (!TargetTriple.empty()) + M->setTargetTriple(Triple::normalize(TargetTriple)); + auto optPipeline = makeOptimizingPipeline(); + if (auto err = optPipeline(M.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + } + + if (verifyModule(*M, &errs())) { + errs() << argv[0] << ": " << InputFilename + << ": error: input module is broken!\n"; + return 1; + } + + // Write to standard output. + std::unique_ptr Out; + // Default to standard output. + if (OutputFilename.empty()) + OutputFilename = "-"; + std::error_code EC; + sys::fs::OpenFlags Flags = sys::fs::OF_TextWithCRLF; + Out.reset(new ToolOutputFile(OutputFilename, EC, Flags)); + if (EC) { + errs() << EC.message() << '\n'; + return 1; + } + Out->os() << *M << "\n"; + Out->keep(); + return 0; +} diff --git a/third_party/mthreads/bin/triton-lsp.cpp b/third_party/mthreads/bin/triton-lsp.cpp new file mode 100644 index 000000000..b185b0374 --- /dev/null +++ b/third_party/mthreads/bin/triton-lsp.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + mlir::MLIRContext context(registry); + return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry)); +} diff --git a/third_party/mthreads/bin/triton-opt.cpp b/third_party/mthreads/bin/triton-opt.cpp new file mode 100644 index 000000000..2d2570771 --- /dev/null +++ b/third_party/mthreads/bin/triton-opt.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-opt/MlirOptMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + return mlir::asMainReturnCode(mlir::MlirOptMain( + argc, argv, "Triton (GPU) optimizer driver\n", registry)); +} diff --git a/third_party/mthreads/bin/triton-reduce.cpp b/third_party/mthreads/bin/triton-reduce.cpp new file mode 100644 index 000000000..8235f8fc8 --- /dev/null +++ b/third_party/mthreads/bin/triton-reduce.cpp @@ -0,0 +1,11 @@ +#include "./RegisterTritonDialects.h" + +#include "mlir/Tools/mlir-reduce/MlirReduceMain.h" + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + registerTritonDialects(registry); + + mlir::MLIRContext context(registry); + return mlir::failed(mlir::mlirReduceMain(argc, argv, context)); +} diff --git a/third_party/mthreads/include/CMakeLists.txt b/third_party/mthreads/include/CMakeLists.txt new file mode 100644 index 000000000..109c292fe --- /dev/null +++ b/third_party/mthreads/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(triton) diff --git a/third_party/mthreads/include/triton/Analysis/Alias.h b/third_party/mthreads/include/triton/Analysis/Alias.h new file mode 100644 index 000000000..a06df5ae2 --- /dev/null +++ b/third_party/mthreads/include/triton/Analysis/Alias.h @@ -0,0 +1,96 @@ +#ifndef TRITON_ANALYSIS_ALIAS_H +#define TRITON_ANALYSIS_ALIAS_H + +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/ADT/DenseSet.h" + +namespace mlir { + +class AliasInfo { +public: + AliasInfo() = default; + AliasInfo(Value value) { insert(value); } + + void insert(Value value) { allocs.insert(value); } + + const DenseSet &getAllocs() const { return allocs; } + + bool operator==(const AliasInfo &other) const { + return allocs == other.allocs; + } + + /// The pessimistic value state of a value without alias + static AliasInfo getPessimisticValueState(MLIRContext *context = nullptr) { + return AliasInfo(); + } + static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); } + + /// The union of both arguments + static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs); + + void print(raw_ostream &os) const { + llvm::interleaveComma(allocs, os, [&](Value alloc) { alloc.print(os); }); + } + +private: + /// The set of allocated values that are aliased by this lattice. + /// For now, we only consider aliased value produced by the following + /// situations: + /// 1. values returned by scf.yield + /// 2. block arguments in scf.for + /// Example: + /// alloc v1 alloc v2 + /// | | + /// |--------------| |------------| + /// scf.for v3 scf.for v4 scf.for v5 + /// | + /// scf.yield v6 + /// + /// v1's alloc [v1] + /// v2's alloc [v2] + /// v3's alloc [v1] + /// v4's alloc [v1, v2] + /// v5's alloc [v2] + /// v6's alloc [v1] + /// + /// Therefore, v1's liveness range is the union of v3, v4, and v6 + /// v2's liveness range is the union of v4 and v5. + DenseSet allocs; +}; + +//===----------------------------------------------------------------------===// +// Shared Memory Alias Analysis +//===----------------------------------------------------------------------===// +class SharedMemoryAliasAnalysis + : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +public: + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::SparseForwardDataFlowAnalysis; + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + + /// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use. + /// Given two values, returns their aliasing behavior. + AliasResult alias(Value lhs, Value rhs); + + /// Returns the modify-reference behavior of `op` on `location`. + ModRefResult getModRef(Operation *op, Value location); + + void setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged( + lattice, lattice->join( + AliasInfo::getPessimisticValueState(lattice->getPoint()))); + } + + /// Computes if the alloc set of the results are changed. + void + visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_ALIAS_H diff --git a/third_party/mthreads/include/triton/Analysis/Allocation.h b/third_party/mthreads/include/triton/Analysis/Allocation.h new file mode 100644 index 000000000..92f63eb48 --- /dev/null +++ b/third_party/mthreads/include/triton/Analysis/Allocation.h @@ -0,0 +1,257 @@ +#ifndef TRITON_ANALYSIS_ALLOCATION_H +#define TRITON_ANALYSIS_ALLOCATION_H + +#include "triton/Analysis/Utility.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/raw_ostream.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include +#include + +namespace mlir { + +namespace triton { +class AllocationAnalysis; + +SmallVector +getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, + unsigned &outVec); +SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op); + +} // namespace triton + +/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h +/// A class that represents an interval, specified using a start and an end +/// values: [Start, End). +template class Interval { +public: + Interval() {} + Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); } + T start() const { return Start; } + T end() const { return End; } + T size() const { return End - Start; } + bool contains(T Addr) const { return Start <= Addr && Addr < End; } + bool intersects(const Interval &R) const { + return Start < R.End && R.Start < End; + } + bool operator==(const Interval &R) const { + return Start == R.Start && End == R.End; + } + bool operator!=(const Interval &R) const { return !(*this == R); } + bool operator<(const Interval &R) const { + return std::make_pair(Start, End) < std::make_pair(R.Start, R.End); + } + +private: + T Start = std::numeric_limits::min(); + T End = std::numeric_limits::max(); +}; + +template Interval(T, T) -> Interval; + +class Allocation { +public: + /// A unique identifier for shared memory buffers + using BufferId = size_t; + using BufferIdSetT = DenseSet; + using FuncAllocMapT = CallGraph::FuncDataMapT; + + static constexpr BufferId InvalidBufferId = + std::numeric_limits::max(); + + Allocation() = default; + /// Creates a new Allocation analysis that computes the shared memory + /// information for all associated shared memory values. + explicit Allocation(Operation *operation) : operation(operation) {} + + /// Runs allocation analysis on the given top-level operation. + void run(FuncAllocMapT &funcAllocMap); + + /// Returns the operation this analysis was constructed from. + Operation *getOperation() const { return operation; } + + /// Returns the offset of the given buffer in the shared memory. + size_t getOffset(BufferId bufferId) const { + return bufferSet.at(bufferId).offset; + } + + /// Returns the size of the given buffer in the shared memory. + size_t getAllocatedSize(BufferId bufferId) const { + return bufferSet.at(bufferId).size; + } + + /// Returns the allocated interval of the given buffer. + Interval getAllocatedInterval(BufferId bufferId) const { + auto &buffer = bufferSet.at(bufferId); + return Interval(buffer.offset, buffer.offset + buffer.size); + } + + /// Returns the buffer id of the given value. + /// This interface only returns the allocated buffer id. + /// If you want to get all the buffer ids that are associated with the given + /// value, including alias buffers, use getBufferIds. + BufferId getBufferId(Value value) const { + if (valueBuffer.count(value)) { + return valueBuffer.lookup(value)->id; + } else { + return InvalidBufferId; + } + } + + /// Returns all the buffer ids of the given value, including alias buffers. + BufferIdSetT getBufferIds(Value value) const { + BufferIdSetT bufferIds; + auto allocBufferId = getBufferId(value); + if (allocBufferId != InvalidBufferId) + bufferIds.insert(allocBufferId); + for (auto *buffer : aliasBuffer.lookup(value)) { + if (buffer->id != InvalidBufferId) + bufferIds.insert(buffer->id); + } + return bufferIds; + } + + /// Returns the scratch buffer id of the given value. + BufferId getBufferId(Operation *operation) const { + if (opScratch.count(operation)) { + return opScratch.lookup(operation)->id; + } else if (opVirtual.count(operation)) { + return opVirtual.lookup(operation)->id; + } else { + return InvalidBufferId; + } + } + + /// Returns if the given buffer is a virtual buffer. + bool isVirtualBuffer(BufferId bufferId) const { + return bufferSet.at(bufferId).kind == BufferT::BufferKind::Virtual; + } + + /// Returns the size of total shared memory allocated + size_t getSharedMemorySize() const { return sharedMemorySize; } + +private: + /// A class that represents a shared memory buffer + struct BufferT { + /// Explicit: triton_gpu.local_alloc + /// Scratch: triton_gpu.convert_layout + /// Virtual: triton.call + enum class BufferKind { Explicit, Scratch, Virtual }; + + /// MT: thread-safe + inline static std::atomic nextId = 0; + + BufferKind kind; + BufferId id; + size_t size; + size_t alignment; + size_t offset; + + bool operator==(const BufferT &other) const { return id == other.id; } + bool operator<(const BufferT &other) const { return id < other.id; } + + BufferT() : BufferT(BufferKind::Explicit, 0) {} + BufferT(BufferKind kind, size_t size, size_t alignment = 4, + size_t offset = 0) + : kind(kind), id(nextId++), size(size), alignment(alignment), + offset(offset) {} + + size_t setOffsetAligned(size_t newOffset) { + return offset = llvm::alignTo(newOffset, alignment); + } + }; + + /// Op -> Scratch Buffer + using OpScratchMapT = DenseMap; + /// Value -> Explicit Buffer + using ValueBufferMapT = llvm::MapVector; + /// Value -> Alias Buffer + using AliasBufferMapT = llvm::MapVector>; + /// BufferId -> Buffer + using BufferSetT = std::map; + +private: + template + void addBuffer(KeyType &key, Args &&...args) { + auto buffer = BufferT(Kind, std::forward(args)...); + bufferSet[buffer.id] = std::move(buffer); + if constexpr (Kind == BufferT::BufferKind::Explicit) { + valueBuffer[key] = &bufferSet[buffer.id]; + } else if constexpr (Kind == BufferT::BufferKind::Virtual) { + opVirtual[key] = &bufferSet[buffer.id]; + } else { + opScratch[key] = &bufferSet[buffer.id]; + } + } + + void addAlias(Value value, Value alloc) { + aliasBuffer[value].insert(valueBuffer[alloc]); + } + +private: + Operation *operation = nullptr; + OpScratchMapT opScratch; + OpScratchMapT opVirtual; + ValueBufferMapT valueBuffer; + AliasBufferMapT aliasBuffer; + BufferSetT bufferSet; + size_t sharedMemorySize = 0; + + friend class triton::AllocationAnalysis; +}; + +/// Static analysis that computes the allocation of shared memory buffers +/// of the entire call graph. +/// The allocation is performed in a post-order walk of the call graph. +/// Each call op is treated like convert_layout that allocates a scratch buffer. +/// At each call, we compute the start offset of the scratch buffer and pass it +/// as an argument to the callee. +class ModuleAllocation : public CallGraph { +public: + using FuncOffsetMapT = DenseMap; + + explicit ModuleAllocation(ModuleOp moduleOp) + : CallGraph(moduleOp) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp); + if (inserted) + iter->second.run(funcMap); + }); + } + + size_t getSharedMemorySize() { + size_t size = 0; + for (auto funcOp : getRoots()) { + auto *alloc = getFuncData(funcOp); + size = std::max(size, alloc->getSharedMemorySize()); + } + return size; + } + + size_t getSharedMemorySize(FunctionOpInterface funcOp) { + return getFuncData(funcOp)->getSharedMemorySize(); + } + + void setFunctionSharedMemoryValue(FunctionOpInterface funcOp, Value value) { + sharedMemoryValue[funcOp] = value; + } + + Value getFunctionSharedMemoryBase(FunctionOpInterface funcOp) { + return sharedMemoryValue[funcOp]; + } + +private: + FuncOffsetMapT sharedMemoryValue; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_ALLOCATION_H diff --git a/third_party/mthreads/include/triton/Analysis/AxisInfo.h b/third_party/mthreads/include/triton/Analysis/AxisInfo.h new file mode 100644 index 000000000..22a7ed554 --- /dev/null +++ b/third_party/mthreads/include/triton/Analysis/AxisInfo.h @@ -0,0 +1,215 @@ +#ifndef TRITON_ANALYSIS_AXISINFO_H +#define TRITON_ANALYSIS_AXISINFO_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include +#include + +namespace mlir::triton { + +//===----------------------------------------------------------------------===// +// AxisInfo +//===----------------------------------------------------------------------===// + +/// This lattice value represents known information on the axes of a lattice. +class AxisInfo { +public: + typedef SmallVector DimVectorT; + +public: + AxisInfo() : AxisInfo({}, {}, {}) {} + + AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy) + : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {} + + AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy, + std::optional constantValue) + : contiguity(contiguity), divisibility(divisibility), + constancy(constancy), constantValue(constantValue) { + assert(divisibility.size() == contiguity.size()); + assert(constancy.size() == contiguity.size()); + } + + // contiguity[d] is the length of the shortest sequence of contiguous integers + // along dimension d. + // + // If we have an array of N elements with a contiguity value C, then the array + // can be divided into a list of N/C sequences of C contiguous elements. + // Since we have N = 2^k, C must be a power of two. + // + // For example, the 2D array + // + // [[10, 11, 12, 13, 18, 19, 20, 21], + // [20, 21, 22, 23, 28, 29, 30, 31]] + // + // has contiguity [1, 4], and + // + // [[12, 16, 20, 24], + // [13, 17, 21, 25], + // [14, 18, 22, 26], + // [15, 19, 23, 27], + // [18, 22, 26, 30], + // [19, 23, 27, 31]] + // + // has contiguity [2, 1]. + int64_t getContiguity(size_t dim) const { return contiguity[dim]; } + const DimVectorT &getContiguity() const { return contiguity; } + + // divisibility[d] is the largest power of two that divides the first element + // of all groups of length contiguity[d] along dimension d. + // + // For example, + // + // [[10, 11, 12, 13, 18, 19, 20, 21], + // [20, 21, 22, 23, 28, 29, 30, 31]] + // + // has divisibility [1, 2], and + // + // [[12, 16, 20, 24], + // [13, 17, 21, 25], + // [14, 18, 22, 26], + // [15, 19, 23, 27]] + // + // has divisibility [4, 1]. + // + // On the other hand, + // + // [0, 1, 2, 0, 4, 5, 6, 7] + // + // has divisibility 1 because its contiguity is 1. + int64_t getDivisibility(size_t dim) const { return divisibility[dim]; } + const DimVectorT &getDivisibility() const { return divisibility; } + + // constancy[d] is the length of the shortest sequence of repeating integers + // along dimension d. + // + // This is particularly useful to infer the contiguity of operations (e.g. + // add) involving a constant. + // + // If we have an array of N elements, with a constancy value C, then the array + // can be divided into a list of N/C sequences of C elements with the same + // value. Since we have N = 2^k, C must be a power of two. + // + // For example + // + // [[8, 8, 8, 8, 12, 12, 12, 12], + // [16, 16, 16, 16, 20, 20, 20, 20]] + // + // has constancy [1, 4]. + int64_t getConstancy(size_t dim) const { return constancy[dim]; } + const DimVectorT &getConstancy() const { return constancy; } + + int getRank() const { return contiguity.size(); } + + std::optional getConstantValue() const { return constantValue; } + + template + static void + initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity, + DimVectorT *divisibility, DimVectorT *constancy); + + bool operator==(const AxisInfo &other) const { + return contiguity == other.contiguity && + divisibility == other.divisibility && constancy == other.constancy && + constantValue == other.constantValue; + } + + static AxisInfo getPessimisticValueState(Value value); + + // The gcd of both arguments for each dimension + static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs); + + void print(raw_ostream &os) const { + auto print = [&](StringRef name, DimVectorT vec) { + os << name << " = ["; + llvm::interleaveComma(vec, os); + os << "]"; + }; + print("contiguity", contiguity); + print(", divisibility", divisibility); + print(", constancy", constancy); + os << ", constant_value = "; + if (constantValue) + os << *constantValue; + else + os << ""; + } + +private: + DimVectorT contiguity; + DimVectorT divisibility; + DimVectorT constancy; + + // The constant value of the lattice if we can infer it. + std::optional constantValue; +}; + +// Module level axis info analysis based on the call graph, assuming that we do +// not have recursive functions. +// +// Since each function will be called multiple times, we need to calculate the +// axis info based on the axis info of all the callers. In the future, we can +// perform optimization using function cloning so that each call site will have +// unique axis info. +using AxisInfoMapT = DenseMap; +class ModuleAxisInfoAnalysis : public CallGraph { +public: + explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp) + : CallGraph(moduleOp) { + SmallVector funcs; + for (auto root : getRoots()) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + funcs.push_back(funcOp); + funcMap.try_emplace(funcOp, AxisInfoMapT{}); + }); + } + SetVector sortedFuncs(funcs.begin(), funcs.end()); + SymbolTableCollection symbolTable; + for (auto funcOp : llvm::reverse(sortedFuncs)) { + initialize(funcOp); + funcOp.walk([&](CallOpInterface callOp) { + auto callee = + dyn_cast(callOp.resolveCallable(&symbolTable)); + update(callOp, callee); + }); + } + } + + AxisInfo *getAxisInfo(Value value) { + auto funcOp = + value.getParentRegion()->getParentOfType(); + auto *axisInfoMap = getFuncData(funcOp); + if (!axisInfoMap) { + return nullptr; + } + auto it = axisInfoMap->find(value); + if (it == axisInfoMap->end()) { + return nullptr; + } + return &(it->second); + } + + unsigned getPtrContiguity(Value ptr); + unsigned getPtrAlignment(Value ptr); + unsigned getMaskAlignment(Value mask); + +private: + void initialize(FunctionOpInterface funcOp); + void update(CallOpInterface callOp, FunctionOpInterface funcOp); +}; + +} // namespace mlir::triton + +#endif diff --git a/third_party/mthreads/include/triton/Analysis/Membar.h b/third_party/mthreads/include/triton/Analysis/Membar.h new file mode 100644 index 000000000..43bd5d15b --- /dev/null +++ b/third_party/mthreads/include/triton/Analysis/Membar.h @@ -0,0 +1,154 @@ +#ifndef TRITON_ANALYSIS_MEMBAR_H +#define TRITON_ANALYSIS_MEMBAR_H + +#include "Allocation.h" +#include "llvm/ADT/SmallPtrSet.h" + +#include + +namespace mlir { + +class OpBuilder; + +struct BlockInfo { + using BufferIdSetT = Allocation::BufferIdSetT; + using IntervalSetT = std::set>; + + IntervalSetT syncReadIntervals; + IntervalSetT syncWriteIntervals; + + BlockInfo() = default; + + /// Unions two BlockInfo objects. + BlockInfo &join(const BlockInfo &other) { + syncReadIntervals.insert(other.syncReadIntervals.begin(), + other.syncReadIntervals.end()); + syncWriteIntervals.insert(other.syncWriteIntervals.begin(), + other.syncWriteIntervals.end()); + return *this; + } + + /// Returns true if intervals in two BlockInfo objects are intersected. + bool isIntersected(const BlockInfo &other) const { + return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals) || + /*WAR*/ + isIntersected(syncReadIntervals, other.syncWriteIntervals) || + /*WAW*/ + isIntersected(syncWriteIntervals, other.syncWriteIntervals); + } + + /// Clears the intervals because a barrier is inserted. + void sync() { + syncReadIntervals.clear(); + syncWriteIntervals.clear(); + } + + /// Compares two BlockInfo objects. + bool operator==(const BlockInfo &other) const { + return syncReadIntervals == other.syncReadIntervals && + syncWriteIntervals == other.syncWriteIntervals; + } + + bool operator!=(const BlockInfo &other) const { return !(*this == other); } + +private: + bool isIntersected(const IntervalSetT &lhsIntervalSet, + const IntervalSetT &rhsIntervalSet) const { + for (auto &lhs : lhsIntervalSet) + for (auto &rhs : rhsIntervalSet) + if (lhs.intersects(rhs)) + return true; + return false; + } +}; + +//===----------------------------------------------------------------------===// +// Shared Memory Barrier Analysis +//===----------------------------------------------------------------------===// +class MembarAnalysis { +public: + using FuncBlockInfoMapT = CallGraph::FuncDataMapT; + /// Creates a new Membar analysis that generates the shared memory barrier + /// in the following circumstances: + /// - RAW: If a shared memory write is followed by a shared memory read, and + /// their addresses are intersected, a barrier is inserted. + /// - WAR: If a shared memory read is followed by a shared memory write, and + /// their addresses are intersected, a barrier is inserted. + /// The following circumstances do not require a barrier: + /// - WAW: not possible because overlapped memory allocation is not allowed. + /// - RAR: no write is performed. + /// Temporary storage of operations such as Reduce are considered as both + /// a shared memory read. If the temporary storage is written but not read, + /// it is considered as the problem of the operation itself but not the membar + /// analysis. + MembarAnalysis() = default; + explicit MembarAnalysis(Allocation *allocation) : allocation(allocation) {} + + /// Runs the membar analysis to the given operation, inserts a barrier if + /// necessary. + void run(FuncBlockInfoMapT &funcBlockInfoMap); + +private: + /// Applies the barrier analysis based on the SCF dialect, in which each + /// region has a single basic block only. + /// Example: + /// region1 + /// op1 + /// op2 (scf.if) + /// region2 + /// op3 + /// op4 + /// region3 + /// op5 + /// op6 + /// op7 + /// TODO: Explain why we don't use ForwardAnalysis: + void resolve(FunctionOpInterface funcOp, FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder); + + /// Updates the BlockInfo operation based on the operation. + void update(Operation *operation, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder); + + /// Collects the successors of the terminator + void visitTerminator(Operation *operation, SmallVector &successors); + + void insertBarrier(Operation *operation, OpBuilder *builder); + +private: + Allocation *allocation = nullptr; +}; + +/// Postorder traversal on the callgraph to insert membar instructions +/// of each function. +/// Each function maintains a BlockInfo map that includes all potential buffers +/// after returning. This way users do not have to explicitly insert membars +/// before and after function calls, but might be a bit conservative. +class ModuleMembarAnalysis : public CallGraph { +public: + ModuleMembarAnalysis(ModuleAllocation *moduleAllocation) + : CallGraph(moduleAllocation->getModuleOp()), + moduleAllocation(moduleAllocation) {} + + void run() { + walk( + // Pre-order walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order walk callback + [&](FunctionOpInterface funcOp) { + auto *allocation = moduleAllocation->getFuncData(funcOp); + auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo()); + if (inserted) { + MembarAnalysis analysis(allocation); + analysis.run(funcMap); + } + }); + } + +private: + ModuleAllocation *moduleAllocation; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_MEMBAR_H diff --git a/third_party/mthreads/include/triton/Analysis/Utility.h b/third_party/mthreads/include/triton/Analysis/Utility.h new file mode 100644 index 000000000..7b215f267 --- /dev/null +++ b/third_party/mthreads/include/triton/Analysis/Utility.h @@ -0,0 +1,366 @@ +#ifndef TRITON_ANALYSIS_UTILITY_H +#define TRITON_ANALYSIS_UTILITY_H + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir { + +inline bool isZeroConst(Value v) { + auto constantOp = v.getDefiningOp(); + if (!constantOp) + return false; + if (auto denseAttr = dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + if (auto denseAttr = + dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + return false; +} + +class ReduceOpHelper { +public: + explicit ReduceOpHelper(triton::ReduceOp op) + : op(op.getOperation()), axis(op.getAxis()) { + auto firstTy = cast(op.getOperands()[0].getType()); + srcShape = firstTy.getShape(); + srcEncoding = firstTy.getEncoding(); + srcElementTypes = op.getElementTypes(); + + for (const auto &t : op.getInputTypes()) { + if (t.getShape() != srcShape) { + op.emitError() << "shape mismatch"; + } + if (t.getEncoding() != srcEncoding) { + op.emitError() << "encoding mismatch"; + } + } + } + + ArrayRef getSrcShape() { return srcShape; } + + Attribute getSrcLayout() { return srcEncoding; } + + triton::ReduceOp getOperation() { return op; } + + bool isReductionOnLayoutFastAxis(); + + unsigned getThreadOffsetOnReductionAxis(); + + bool isWarpSynchronous(); + + unsigned getInterWarpSize(); + + unsigned getIntraWarpSize(); + + unsigned getInterWarpSizeWithUniqueData(); + + unsigned getIntraWarpSizeWithUniqueData(); + + unsigned getThreadsReductionAxis(); + + SmallVector getScratchConfig(); + + SmallVector getOrderWithAxisAtBeginning(); + + unsigned getScratchSizeInBytes(); + + bool isSupportedLayout(); + + bool isReduceWithinCTA(); + + unsigned getAxis() { return axis; } + +private: + triton::ReduceOp op; + ArrayRef srcShape; + Attribute srcEncoding; + SmallVector srcElementTypes; + int axis; +}; + +class ScanLoweringHelper { +public: + explicit ScanLoweringHelper(triton::ScanOp op) : scanOp(op) { + auto firstTy = cast(op.getOperands()[0].getType()); + srcShape = firstTy.getShape(); + srcEncoding = firstTy.getEncoding(); + srcElementTypes = op.getElementTypes(); + + for (const auto &t : op.getInputTypes()) { + if (t.getShape() != srcShape) { + op.emitError() << "shape mismatch"; + } + if (t.getEncoding() != srcEncoding) { + op.emitError() << "encoding mismatch"; + } + } + } + // Return true if the lowering of the scan op is supported. + bool isSupported(); + // Return the number of elements per thread along axis dim. + unsigned getAxisNumElementsPerThread(); + // Return the number of elements per thread along non-axis dims. + unsigned getNonAxisNumElementsPerThread(); + // Return the number of threads per warp along non-axis dims. + unsigned getNonAxisNumThreadsPerWarp(); + // Return the flat numbers of threads computing independent scan results. + unsigned getNonAxisNumThreadsPerCTA(); + // Return the number of warps per CTA along axis dim. + unsigned getAxisNumWarps(); + // Return the number of warps per CTA along axis dim with unique data. + unsigned getAxisNumWarpsWithUniqueData(); + // Return the number of threads per warp along axis dim. + unsigned getAxisNumThreadsPerWarp(); + // Return the number of threads per warp along axis dim with unique data. + unsigned getAxisNumThreadsPerWarpWithUniqueData(); + // Return the number of blocks along axis dim. + unsigned getAxisNumBlocks(); + // Return the number of blocks along non axis dim. + unsigned getNonAxisNumBlocks(); + // Return the size of the scratch space needed for scan lowering. + unsigned getScratchSizeInBytes(); + // Return the number of elements of the scratch space needed for scan + // lowering. + unsigned getScratchSizeInElems(); + + // Stride between contiguous element along axis dim. + unsigned getAxisElementStride(); + // Stride between contiguous threads along axis dim. + unsigned getAxisThreadStride(); + // Stride between contiguous blocks along axis dim. + unsigned getAxisBlockStride(); + + Location getLoc() { return scanOp.getLoc(); } + unsigned getAxis() { return scanOp.getAxis(); } + bool getReverse() { return scanOp.getReverse(); } + triton::gpu::BlockedEncodingAttr getEncoding(); + llvm::ArrayRef getShape() { return srcShape; } + unsigned getNumOperands() { return scanOp.getNumOperands(); } + SmallVector getElementTypes() { return srcElementTypes; } + Attribute getSrcLayout() { return srcEncoding; } + Region &getCombineOp(); + +private: + triton::ScanOp scanOp; + Attribute srcEncoding; + llvm::ArrayRef srcShape; + SmallVector srcElementTypes; +}; + +// Decomposes a reshape into simpler pieces. +// +// As an example, suppose we have a reshape from [4,4,4] to [2,2,8,2]. +// You might explain what this does as follows. +// +// - Split the first input dimension into [2,2]. +// - Take the remaining two input dimensions, merge them into a single [16] +// dim, and then split that into [8,2]. +// +// In general, a reshape can be described a sequence of smushing one or more +// input dimensions together and then breaking them apart into one or more +// output dimensions. So we could represent the example above as follows. +// +// [ +// ([0], [0, 1]), # input dim [0] -> output dims [0, 1] +// ([1, 2], [2, 3]), # input dims [1, 2] -> output dims [2, 3] +// ] +// +// Notice that the input dims (first tuple elems) appear in sequential order if +// you read left-to-right-top-to-bottom, and so do the output dims. +// +// This function returns the above decomposition. +SmallVector, SmallVector>> +getReshapeDecomposition(ArrayRef srcShape, ArrayRef dstShape); + +bool maybeSharedAllocationOp(Operation *op); + +bool supportMFMA(triton::DotOp op); + +bool supportWMMA(triton::DotOp op); + +bool supportMMA(triton::DotOp op, int version); + +bool supportMMA(Value value, int version); + +bool isSingleValue(Value value); + +bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); + +bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); + +bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy); + +// Return true if the src and dst layout match. +bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, + RankedTensorType dstTy); + +// TODO: Move utility functions that belong to ConvertLayoutOp to class +// ConvertLayoutOpHelper in the future +bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout); + +/// Multi-root DAG topological sort. +/// Performs a topological sort of the Operation in the `toSort` SetVector. +/// Returns a topologically sorted SetVector. +/// It is faster than mlir::topologicalSort because it prunes nodes that have +/// been visited before. +SetVector +multiRootTopologicalSort(const SetVector &toSort); + +/// This uses the toplogicalSort above +SetVector +multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr, + TransitiveFilter forwardFilter = nullptr); + +/// Create a basic DataFlowSolver with constant and dead code analysis included. +std::unique_ptr createDataFlowSolver(); + +/// This class represents a call graph for a given ModuleOp and holds +/// data of type T associated with each FunctionOpInterface. +template class CallGraph { +public: + using FuncDataMapT = DenseMap; + + /// Constructor that builds the call graph for the given moduleOp. + explicit CallGraph(ModuleOp moduleOp) : moduleOp(moduleOp) { build(); } + + /// Walks the call graph and applies the provided update functions + /// to the edges and nodes. + template + void walk(UpdateEdgeFn updateEdgeFn, UpdateNodeFn updateNodeFn) { + DenseSet visited; + for (auto root : roots) { + doWalk(root, visited, updateEdgeFn, + updateNodeFn); + } + } + + /// Retrieves the data associated with a function + T *getFuncData(FunctionOpInterface funcOp) { + if (funcMap.count(funcOp)) { + return &funcMap[funcOp]; + } + return nullptr; + } + + /// Getters + ModuleOp getModuleOp() const { return moduleOp; } + SmallVector getRoots() const { return roots; } + size_t getNumFunctions() const { return funcMap.size(); } + + /// Returns true if the given function is a root. + bool isRoot(FunctionOpInterface funcOp) const { + return llvm::is_contained(roots, funcOp); + } + + /// Maps the data and the graph nodes associated with a funcOp to a + /// targetFuncOp. + template + void mapFuncOp(FROM funcOp, TO targetFuncOp) { + // Iterate over graph and replace + for (auto &kv : graph) { + for (auto &edge : kv.second) { + if (edge.second == funcOp) { + edge.second = targetFuncOp; + } + } + } + graph[targetFuncOp] = graph[funcOp]; + // Replace in roots + for (auto it = roots.begin(); it != roots.end(); ++it) { + if (*it == funcOp) { + *it = targetFuncOp; + break; + } + } + // Replace in funcMap + funcMap[targetFuncOp] = funcMap[funcOp]; + } + + /// Maps the graph edges associated with a callOp to a targetCallOp. + template + void mapCallOp(FROM callOp, TO targetCallOp) { + // Iterate over graph and replace + for (auto &kv : graph) { + for (auto &edge : kv.second) { + if (edge.first == callOp) { + edge.first = targetCallOp; + } + } + } + } + +private: + void build() { + SymbolTableCollection symbolTable; + DenseSet visited; + // Build graph + moduleOp.walk([&](Operation *op) { + auto caller = op->getParentOfType(); + if (auto callOp = dyn_cast(op)) { + auto *callee = callOp.resolveCallable(&symbolTable); + auto funcOp = dyn_cast_or_null(callee); + if (funcOp) { + graph[caller].emplace_back( + std::pair(callOp, funcOp)); + visited.insert(funcOp); + } + } + }); + // Find roots + moduleOp.walk([&](FunctionOpInterface funcOp) { + if (!visited.count(funcOp)) { + roots.push_back(funcOp); + } + }); + } + + template + void doWalk(FunctionOpInterface funcOp, + DenseSet &visited, UpdateEdgeFn updateEdgeFn, + UpdateNodeFn updateNodeFn) { + if (visited.count(funcOp)) { + llvm::report_fatal_error("Cycle detected in call graph"); + } + if constexpr (UpdateNodeOrder == WalkOrder::PreOrder) { + updateNodeFn(funcOp); + } + for (auto [callOp, callee] : graph[funcOp]) { + if constexpr (UpdateEdgeOrder == WalkOrder::PreOrder) { + updateEdgeFn(callOp, callee); + } + doWalk(callee, visited, updateEdgeFn, + updateNodeFn); + if constexpr (UpdateEdgeOrder == WalkOrder::PostOrder) { + updateEdgeFn(callOp, callee); + } + } + if constexpr (UpdateNodeOrder == WalkOrder::PostOrder) { + updateNodeFn(funcOp); + } + visited.erase(funcOp); + } + +protected: + ModuleOp moduleOp; + DenseMap>> + graph; + FuncDataMapT funcMap; + SmallVector roots; +}; +// Create a basic DataFlowSolver with constant and dead code analysis included. +std::unique_ptr createDataFlowSolver(); + +triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v); + +} // namespace mlir + +#endif // TRITON_ANALYSIS_UTILITY_H diff --git a/third_party/mthreads/include/triton/CMakeLists.txt b/third_party/mthreads/include/triton/CMakeLists.txt new file mode 100644 index 000000000..27c703b3c --- /dev/null +++ b/third_party/mthreads/include/triton/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Conversion) +add_subdirectory(Dialect) +add_subdirectory(Target) diff --git a/third_party/mthreads/include/triton/Conversion/CMakeLists.txt b/third_party/mthreads/include/triton/Conversion/CMakeLists.txt new file mode 100644 index 000000000..730f5cadd --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonGPUToLLVM) +add_subdirectory(TritonToTritonGPU) diff --git a/third_party/mthreads/include/triton/Conversion/MLIRTypes.h b/third_party/mthreads/include/triton/Conversion/MLIRTypes.h new file mode 100644 index 000000000..fadba413f --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/MLIRTypes.h @@ -0,0 +1,42 @@ +#ifndef TRITON_CONVERSION_MLIR_TYPES_H +#define TRITON_CONVERSION_MLIR_TYPES_H + +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// This file redefines some common MLIR types for easy usage. +namespace mlir { +namespace triton { +namespace type { + +// Integer types +inline Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); } +inline Type i16Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 16); } +inline Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); } +inline Type u32Ty(MLIRContext *ctx) { + return IntegerType::get(ctx, 32, IntegerType::Unsigned); +} +inline Type u1Ty(MLIRContext *ctx) { + return IntegerType::get(ctx, 1, IntegerType::Unsigned); +} + +// Float types +inline Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); } +inline Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); } +inline Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); } +inline Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); } + +inline bool isFloat(Type type) { + return type.isF32() || type.isF64() || type.isF16() || type.isF128() || + type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() || + type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() || + type.isFloat8E5M2FNUZ(); +} + +inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); } + +} // namespace type +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_MLIR_TYPES_H diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h new file mode 100644 index 000000000..00ec88089 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h @@ -0,0 +1,27 @@ +#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ +#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ + +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include +#include + +namespace mlir { +class ConversionPatternRewriter; +class Location; + +namespace triton { +using llvm::StringRef; + +inline std::string strJoin(llvm::ArrayRef strs, + llvm::StringRef delimiter) { + return llvm::join(strs.begin(), strs.end(), delimiter); +} + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..93f8374e5 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonGPUToLLVM) +add_public_tablegen_target(TritonGPUConversionPassIncGen) diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h new file mode 100644 index 000000000..5203ffff9 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h @@ -0,0 +1,232 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H +#define TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::triton { + +namespace gpu { + +SmallVector reorderValues(const SmallVector &values, Type inType, + Type ouType); + +SmallVector unpackI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter); + +SmallVector packI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter); + +Type getElementType(Value value); + +class MultipleOperandsRange + : public iterator_range>::iterator> { + using ContainerT = SmallVector>; + +public: + using iterator_range::iterator_range; + ContainerT::reference operator[](ContainerT::size_type idx) { + return begin()[idx]; + } + ContainerT::const_reference operator[](ContainerT::size_type idx) const { + return begin()[idx]; + } + ContainerT::size_type size() const { return end() - begin(); } +}; + +// Base pattern for elementwise conversion using ConcreteT. Unpacks individual +// elements from a `!llvm.struct` via `llvm.extactvalue`, calls +// ConcreteT::createDestOps on each element, and packs them back into an +// `!llvm.struct` using `llvm.insertvalue`. +// +// Also supports processing the inputs in a vectorized form by consuming and +// producing multiple operand sets in ConcreteT::createDestOps. +template +class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { +public: + using OpAdaptor = typename SourceOp::Adaptor; + + explicit ElementwiseOpConversionBase( + LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit), + axisAnalysisPass(axisAnalysisPass) {} + + // Try to deduplicate the resultVals based on the + // constancy properties of the result discovered by + // the axis analysis pass. If possible, redundant + // computation is eliminated. + SmallVector maybeDeduplicate(SourceOp op, + SmallVector resultVals) const { + if (!isMemoryEffectFree(op)) + // the op has side effects: can't dedup + return resultVals; + SmallVector results = op->getResults(); + if (results.size() == 0 || results.size() > 1) + // there must be exactly 1 result + return resultVals; + Value result = results[0]; + Type type = result.getType(); + if (!type) + return resultVals; + RankedTensorType rtType = dyn_cast(type); + if (!rtType) + // the result must be a tensor + return resultVals; + Attribute encoding = rtType.getEncoding(); + if (!encoding) + // encoding not available + return resultVals; + Attribute baseEncoding = encoding; + if (isa(baseEncoding)) + // TODO: this logic seems incorrect for mfma layout. Skip for now. + // We saw mismatches for some flash-attention tests on AMD backend. + // Note that this logic works for sliced layout whose parent is + // mfma layout. Therefore, this is not combined with the following check. + return resultVals; + while (auto sliced = dyn_cast(baseEncoding)) + baseEncoding = sliced.getParent(); + if (isa(baseEncoding)) { + // TODO: this logic seems incorrect for mma layout. Skip for now. + // The following test crashes and some other miscompile: + // test_core::test_fp8_dot_acc + return resultVals; + } + + SmallVector elemsPerThread = getElemsPerThread(rtType); + int rank = elemsPerThread.size(); + if (product(elemsPerThread) != resultVals.size()) + return resultVals; + AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(result); + if (!axisInfo) + // axis info (e.g., constancy) not available + return resultVals; + SmallVector contigPerThread = getContigPerThread(encoding); + if (rank != contigPerThread.size()) + return resultVals; + + SmallVector constancy = axisInfo->getConstancy(); + if (rank != constancy.size()) + return resultVals; + bool hasConstancy = false; + for (int i = 0; i < rank; ++i) { + if (constancy[i] > contigPerThread[i]) { + if (constancy[i] % contigPerThread[i] != 0) + // constancy is not evenly covered by contigPerThread + return resultVals; + // can't move the values across different + // "contigPerThread"-sized blocks + constancy[i] = contigPerThread[i]; + } + if (elemsPerThread[i] < 1 || constancy[i] < 1) + return resultVals; + if (!(elemsPerThread[i] % constancy[i] == 0 || + constancy[i] % elemsPerThread[i] == 0)) + // either the constancy along each dimension must fit + // into the elemsPerThread or the other way around + return resultVals; + if (constancy[i] > 1) + hasConstancy = true; + } + if (!hasConstancy) + // nothing to deduplicate + return resultVals; + + if (rank > 1) { + // reorder the shape and constancy vectors by the axis order: + // from the fastest-changing to the smallest-changing axis + SmallVector order = getOrder(encoding); + if (rank != order.size()) + return resultVals; + elemsPerThread = applyPermutation(elemsPerThread, order); + constancy = applyPermutation(constancy, order); + } + + SmallVector strides(rank, 1); + for (int i = 1; i < rank; ++i) { + strides[i] = strides[i - 1] * elemsPerThread[i - 1]; + } + SmallVector dedupResultVals; + dedupResultVals.reserve(resultVals.size()); + for (int i = 0; i < resultVals.size(); ++i) { + // each coordinate of the orig_idx is "coarsened" using the + // constancy along this dimension: the resulting dedup_idx + // points to the reused value in the original resultsVal + int orig_idx = i; + int dedup_idx = 0; + for (int j = 0; j < rank; ++j) { + int coord_j = orig_idx % elemsPerThread[j]; + dedup_idx += (coord_j / constancy[j] * constancy[j]) * strides[j]; + orig_idx /= elemsPerThread[j]; + } + dedupResultVals.push_back(resultVals[dedup_idx]); + } + + return dedupResultVals; + } + LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resultTy = op.getType(); + Location loc = op->getLoc(); + // element type + auto resultElementTy = getElementTypeOrSelf(resultTy); + Type elemTy = this->getTypeConverter()->convertType(resultElementTy); + SmallVector> allOperands; + for (auto operand : adaptor.getOperands()) { + auto argTy = op->getOperand(0).getType(); + auto subOperands = unpackLLElements(loc, operand, rewriter); + subOperands = unpackI32(subOperands, argTy, rewriter, loc, + this->getTypeConverter()); + allOperands.resize(subOperands.size()); + for (auto v : llvm::enumerate(subOperands)) + allOperands[v.index()].push_back(v.value()); + } + if (allOperands.size() == 0) + allOperands.push_back({}); + + SmallVector resultVals; + for (auto it = allOperands.begin(), end = allOperands.end(); it != end;) { + auto curr = static_cast(this)->createDestOps( + op, adaptor, rewriter, elemTy, MultipleOperandsRange(it, end), loc); + if (curr.size() == 0) + return failure(); + for (auto v : curr) { + if (!static_cast(v)) + return failure(); + resultVals.push_back(v); + } + it += curr.size(); + } + if (op->getNumOperands() > 0) { + auto argTy = op->getOperand(0).getType(); + resultVals = reorderValues(resultVals, argTy, resultTy); + } + resultVals = maybeDeduplicate(op, resultVals); + resultVals = + packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); + Value view = packLLElements(loc, this->getTypeConverter(), resultVals, + rewriter, resultTy); + rewriter.replaceOp(op, view); + + return success(); + } + +protected: + ModuleAxisInfoAnalysis &axisAnalysisPass; +}; + +} // namespace gpu + +} // namespace mlir::triton +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Passes.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Passes.h new file mode 100644 index 000000000..b013f2628 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Passes.h @@ -0,0 +1,32 @@ +#ifndef TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PASSES_H +#define TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +#define GEN_PASS_DECL +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" + +namespace gpu { +std::unique_ptr> createAllocateSharedMemoryPass(); + +} // namespace gpu + +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Passes.td b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Passes.td new file mode 100644 index 000000000..700dcd6b4 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Passes.td @@ -0,0 +1,11 @@ +#ifndef TRITONCOMMONGPU_CONVERSION_PASSES +#define TRITONCOMMONGPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def AllocateSharedMemory : Pass<"allocate-shared-memory", "mlir::ModuleOp"> { + let summary = "Add metadata for shared memory allocation"; + let constructor = "mlir::triton::gpu::createAllocateSharedMemoryPass()"; +} + +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h new file mode 100644 index 000000000..d1494fd7e --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -0,0 +1,104 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H + +#include "TargetInfoBase.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Analysis/AxisInfo.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::BlockedEncodingAttr; + +namespace SharedToDotOperandFMA { +Value convertLayout(int opIdx, Value val, Value llVal, + BlockedEncodingAttr dLayout, Value thread, Location loc, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); +} +LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); +namespace mlir { +namespace triton { + +constexpr int patternBenefitDefault = 1; +constexpr int patternBenefitPrioritizeOverLLVMConversions = 10; +constexpr int patternBenefitClampOptimizedPattern = 20; +constexpr int patternBenefitConvertLayoutOptimizedPattern = 20; + +void populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateMakeRangeOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateViewOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateMinMaxFOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + bool hwNanPropagationSupported, + PatternBenefit benefit); +void populateClampFOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateHistogramOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); +void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); +void populateScanOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, int numWarps, + PatternBenefit benefit); + +void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Patterns.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Patterns.h new file mode 100644 index 000000000..934501ad3 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Patterns.h @@ -0,0 +1,32 @@ +#ifndef TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PATTERNS_H +#define TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PATTERNS_H + +#include + +namespace mlir { +class ModuleOp; +class RankedTensorType; + +namespace triton::gpu { + +/// Replaces `blocked -> dot_op` with `blocked -> shared -> dot_op` in the given +/// |module| op because the codegen doesn't handle `blocked -> dot_op` directly. +void decomposeBlockedToDotLayoutConversion(ModuleOp module); + +/// Replaces `splat -> shared` with `splat -> blocked -> shared` in the given +/// |module| op. +void decomposeSplatOpToSharedLayoutConversion(ModuleOp module); + +/// Replaces `mma/mfma -> dot_op` with `mma/mfma -> blocked -> dot_op` in the +/// given |module| op, but bypass the decomposition if |shortcutFn| returns +/// true. +using ShortcutFn = std::function; +template +void decomposeTensorCoreToDotLayoutConversion(ModuleOp module, + ShortcutFn shortcutFn); + +} // namespace triton::gpu + +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h new file mode 100644 index 000000000..d03f6b862 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -0,0 +1,66 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H + +#include "triton/Conversion/MLIRTypes.h" + +namespace mlir::triton { +class TargetInfoBase { +public: + virtual bool supportMaximumMinimum() const = 0; + + virtual Value getClusterCTAId(RewriterBase &rewriter, Location loc) const = 0; + + virtual Value ballot(ConversionPatternRewriter &rewriter, Location loc, + Type type, Value cmp) const = 0; + + virtual void storeShared(ConversionPatternRewriter &rewriter, Location loc, + Value ptr, Value val, Value pred) const = 0; + virtual Value loadShared(ConversionPatternRewriter &rewriter, Location loc, + const TypeConverter *converter, Value ptr, + Type elemTy, Value pred) const = 0; + + virtual Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const = 0; + virtual Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const = 0; + virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const = 0; + virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, + Value val, Value i) const = 0; + + virtual Value programId(ConversionPatternRewriter &rewriter, Location loc, + ModuleOp moduleOp, int axis) const = 0; + + virtual bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce) const = 0; + + virtual bool processReplicaUsingStMatrix( + ConversionPatternRewriter &rewriter, Location loc, Value smemBase, + SmallVector &vals, RankedTensorType srcTy, Type elemTy, + ArrayRef paddedRepShape, ArrayRef origRepShape, + ArrayRef outOrd, unsigned accumNumReplicates, + int swizzleByteWidth = 0) const = 0; + + virtual std::string getMulhiFuncName(Type resultElementTy) const = 0; + // Emits LLVM code with |rewriter| to print a message following the given + // format from the device. |formatStrStart| is the pointer to the start of + // the format string global variable; |args| are the arguments to fill + // placeholders in the format string. + virtual void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args) const = 0; + // Emits LLVM code with |rewriter| to perform assertion failure with the given + // |message| from the given |func| in |file|. + virtual void assertFail(ConversionPatternRewriter &rewriter, Location loc, + StringRef message, StringRef file, StringRef func, + int line) const = 0; + + // Whether to enable linear layout. This is a per-backend temporary escape + // hatch to disable linear layout while figuring out issues. Eventually we + // want to enable linear layout everywhere and delete this control. + virtual bool enableLinearLayout() const { return true; } + + virtual ~TargetInfoBase() {} +}; +} // namespace mlir::triton +#endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h new file mode 100644 index 000000000..ab9d0ebf8 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h @@ -0,0 +1,26 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter { +public: + using TypeConverter::convertType; + + TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis = nullptr); + + Type getElementTypeForStruct(TensorOrMemDesc type); + Type convertTritonPointerType(triton::PointerType type); + Type convertTritonTensorType(RankedTensorType type); + Type convertMemDescType(MemDescType type); + Type convertAsyncToken(triton::gpu::AsyncTokenType type); +}; + +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Utility.h new file mode 100644 index 000000000..4de0e4c8c --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -0,0 +1,1598 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H + +#include + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/StrUtil.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/ErrorHandling.h" + +#define DEBUG_TYPE "ttgpu_to_llvm" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::triton; + +// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive +// Operators +#define inttofloat(...) rewriter.create(loc, __VA_ARGS__) +#define inttoptr(...) rewriter.create(loc, __VA_ARGS__) +#define ptrtoint(...) rewriter.create(loc, __VA_ARGS__) +#define zext(...) rewriter.create(loc, __VA_ARGS__) +#define sext(...) rewriter.create(loc, __VA_ARGS__) +#define fpext(...) rewriter.create(loc, __VA_ARGS__) +#define trunc(...) rewriter.create(loc, __VA_ARGS__) +#define udiv(...) rewriter.create(loc, __VA_ARGS__) +#define urem(...) rewriter.create(loc, __VA_ARGS__) +#define add(...) rewriter.create(loc, __VA_ARGS__) +#define sub(...) rewriter.create(loc, __VA_ARGS__) +#define fadd(...) rewriter.create(loc, __VA_ARGS__) +#define mul(...) rewriter.create(loc, __VA_ARGS__) +#define fmul(...) rewriter.create(loc, __VA_ARGS__) +#define smax(...) rewriter.create(loc, __VA_ARGS__) +#define umax(...) rewriter.create(loc, __VA_ARGS__) +#define fmax(...) rewriter.create(loc, __VA_ARGS__) +#define smin(...) rewriter.create(loc, __VA_ARGS__) +#define umin(...) rewriter.create(loc, __VA_ARGS__) +#define fmin(...) rewriter.create(loc, __VA_ARGS__) +#define shl(...) rewriter.create(loc, __VA_ARGS__) +#define lshr(...) rewriter.create(loc, __VA_ARGS__) +#define and_(...) rewriter.create(loc, __VA_ARGS__) +#define xor_(...) rewriter.create(loc, __VA_ARGS__) +#define or_(...) rewriter.create(loc, __VA_ARGS__) +#define bitcast(val__, type__) \ + rewriter.create(loc, type__, val__) +#define addrspacecast(...) \ + rewriter.create(loc, __VA_ARGS__) +#define gep(...) rewriter.create(loc, __VA_ARGS__) +#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__) +#define insert_val(...) rewriter.create(loc, __VA_ARGS__) +#define extract_val(...) rewriter.create(loc, __VA_ARGS__) +#define insert_element(...) \ + rewriter.create(loc, __VA_ARGS__) +#define extract_element(...) \ + rewriter.create(loc, __VA_ARGS__) +#define load(...) rewriter.create(loc, __VA_ARGS__) +#define store(...) rewriter.create(loc, __VA_ARGS__) +#define fcmp_ogt(lhs, rhs) \ + rewriter.create(loc, rewriter.getI1Type(), \ + LLVM::FCmpPredicate::ogt, lhs, rhs) +#define fcmp_olt(lhs, rhs) \ + rewriter.create(loc, rewriter.getI1Type(), \ + LLVM::FCmpPredicate::olt, lhs, rhs) +#define fcmp_eq(lhs, rhs) \ + rewriter.create(loc, rewriter.getI1Type(), \ + LLVM::FCmpPredicate::oeq, lhs, rhs) +#define icmp_eq(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__) +#define icmp_ne(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__) +#define icmp_slt(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__) +#define icmp_sle(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::sle, __VA_ARGS__) +#define icmp_sgt(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::sgt, __VA_ARGS__) +#define icmp_sge(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::sge, __VA_ARGS__) +#define icmp_ult(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::ult, __VA_ARGS__) +#define icmp_ule(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::ule, __VA_ARGS__) +#define icmp_ugt(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::ugt, __VA_ARGS__) +#define icmp_uge(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::uge, __VA_ARGS__) +#define select(...) rewriter.create(loc, __VA_ARGS__) +#define address_of(...) rewriter.create(loc, __VA_ARGS__) +#define barrier() rewriter.create(loc) +#define undef(...) rewriter.create(loc, __VA_ARGS__) +#define null(...) rewriter.create(loc, __VA_ARGS__) +#define call(...) rewriter.create(loc, __VA_ARGS__) + +// Types +#define int_ty(width) rewriter.getIntegerType(width) +#define i64_ty rewriter.getIntegerType(64) +#define i32_ty rewriter.getIntegerType(32) +#define i16_ty rewriter.getIntegerType(16) +#define i32_ty rewriter.getIntegerType(32) +#define i64_ty rewriter.getIntegerType(64) +#define ui32_ty rewriter.getIntegerType(32, false) +#define ui64_ty rewriter.getIntegerType(64, false) +#define f16_ty rewriter.getF16Type() +#define bf16_ty rewriter.getBF16Type() +#define i8_ty rewriter.getIntegerType(8) +#define i1_ty rewriter.getI1Type() +#define f32_ty rewriter.getF32Type() +#define f64_ty rewriter.getF64Type() +#define vec_ty(type, num) VectorType::get(num, type) +#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx) +#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__) +#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count) + +// Constants +#define f16_val(...) LLVM::createConstantF16(loc, rewriter, __VA_ARGS__) +#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__) +#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__) +#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__) +#define i64_val(...) LLVM::createConstantI64(loc, rewriter, __VA_ARGS__) +#define int_val(width, val) \ + LLVM::createLLVMIntegerConstant(rewriter, loc, width, val) +#define tid_val() getThreadId(rewriter, loc) + +// Attributes +#define i32_arr_attr(...) rewriter.getI32ArrayAttr({__VA_ARGS__}) +#define i64_arr_attr(...) rewriter.getI64ArrayAttr({__VA_ARGS__}) +#define str_attr(str) ::mlir::StringAttr::get(ctx, (str)) + +namespace mlir { +namespace triton { + +// Delinearize supposing order is [0, 1, .. , n] +template +llvm::SmallVector getMultiDimIndexImpl(T linearIndex, + llvm::ArrayRef shape) { + // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c} + size_t rank = shape.size(); + T accMul = product(shape.drop_back()); + T linearRemain = linearIndex; + llvm::SmallVector multiDimIndex(rank); + for (int i = rank - 1; i >= 0; --i) { + multiDimIndex[i] = linearRemain / accMul; + linearRemain = linearRemain % accMul; + if (i != 0) { + accMul = accMul / shape[i - 1]; + } + } + return multiDimIndex; +} + +template +llvm::SmallVector getMultiDimIndex(T linearIndex, llvm::ArrayRef shape, + llvm::ArrayRef order) { + size_t rank = shape.size(); + assert(rank == order.size()); + auto reordered = applyPermutation(shape, order); + auto reorderedMultiDim = getMultiDimIndexImpl(linearIndex, reordered); + llvm::SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +// Linearize supposing order is [0, 1, .. , n] +template +T getLinearIndexImpl(llvm::ArrayRef multiDimIndex, llvm::ArrayRef shape) { + assert(multiDimIndex.size() == shape.size()); + // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c} + size_t rank = shape.size(); + T accMul = product(shape.drop_back()); + T linearIndex = 0; + for (int i = rank - 1; i >= 0; --i) { + linearIndex += multiDimIndex[i] * accMul; + if (i != 0) { + accMul = accMul / shape[i - 1]; + } + } + return linearIndex; +} + +template +T getLinearIndex(llvm::ArrayRef multiDimIndex, llvm::ArrayRef shape, + llvm::ArrayRef order) { + assert(shape.size() == order.size()); + return getLinearIndexImpl(applyPermutation(multiDimIndex, order), + applyPermutation(shape, order)); +} + +namespace gpu { +Type getFunctionType(Type resultType, ValueRange operands); + +LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter, + Operation *op, StringRef funcName, + Type funcType, StringRef libname = "", + StringRef libpath = ""); +} // namespace gpu + +} // namespace triton + +namespace LLVM { +using namespace mlir::triton; + +Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v); + +/// Create a 64-bit integer constant. +Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v); + +/// Create a 16-bit float constant. +Value createConstantF16(Location loc, OpBuilder &rewriter, float v); + +/// Create a 32-bit float constant. +Value createConstantF32(Location loc, OpBuilder &rewriter, float v); + +/// Create a 64-bit float constant. +Value createConstantF64(Location loc, OpBuilder &rewriter, double v); + +/// Create NaN constant of specified type. +Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type); + +/// Create an index type constant. +Value createIndexConstant(OpBuilder &builder, Location loc, + const TypeConverter *converter, int64_t value); + +/// Create an integer constant of \param width bits. +Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, + int64_t value); + +/// Helper function to get strides from a given shape and its order +SmallVector getStridesFromShapeAndOrder(ArrayRef shape, + ArrayRef order, + Location loc, + RewriterBase &rewriter); +struct SharedMemoryObject { + Value base; // i32 ptr. The start address of the shared memory object after + // the initial allocation or the last slicing operation. + Type baseElemType; + // We need to store strides as Values, not integers, because the + // extract_slice instruction can take a slice at arbitrary offsets. + // Take $a[16:32, 16:32] as an example; though we know the stride of $a[0] is + // 32, we need to let the instruction that uses $a be aware of that. + // Otherwise, when we use $a, we only know that the shape of $a is 16x16. If + // we store strides into an attribute array of integers, the information + // cannot pass through block argument assignment because attributes are + // associated with operations, not Values. + // TODO(Keren): We may need to figure out a way to store strides as integers + // if we want to support more optimizations. + SmallVector + strides; // i32 int. The strides of the shared memory object. + SmallVector offsets; // i32 int. + // Offsets are applied at the last slicing operation. + // We can use offsets to recover the previous base. + // The offsets are zero at the initial allocation. + + SharedMemoryObject(Value base, Type baseElemType, ArrayRef strides, + ArrayRef offsets) + : base(base), baseElemType(baseElemType), + strides(strides.begin(), strides.end()), + offsets(offsets.begin(), offsets.end()) {} + + SharedMemoryObject(Value base, Type baseElemType, ArrayRef shape, + ArrayRef order, Location loc, + RewriterBase &rewriter) + : base(base), baseElemType(baseElemType) { + strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter); + offsets.append(order.size(), i32_val(0)); + } + + SmallVector getStrides() const { return strides; } + SmallVector getOffsets() const { return offsets; } + Value getBase() const { return base; } + Type getBaseElemType() const { return baseElemType; } + + SmallVector getElems() const { + SmallVector elems; + elems.push_back(base); + elems.append(strides.begin(), strides.end()); + elems.append(offsets.begin(), offsets.end()); + return elems; + } + + SmallVector getTypes() const { + SmallVector types; + types.push_back(base.getType()); + types.append(strides.size(), IntegerType::get(base.getContext(), 32)); + types.append(offsets.size(), IntegerType::get(base.getContext(), 32)); + return types; + } + + Value getCSwizzleOffset(int order) const { + assert(order >= 0 && order < strides.size()); + return offsets[order]; + } + + Value getBaseBeforeSlice(int order, Location loc, + ConversionPatternRewriter &rewriter) const { + Value cSwizzleOffset = getCSwizzleOffset(order); + Value offset = sub(i32_val(0), cSwizzleOffset); + Type type = base.getType(); + return gep(type, baseElemType, base, offset); + } +}; + +SharedMemoryObject +getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, + ConversionPatternRewriter &rewriter); + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape, + ArrayRef order); + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + unsigned linear, ArrayRef shape); + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape); + +Value linearize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDim, ArrayRef shape, + ArrayRef order); + +Value linearize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDim, ArrayRef shape); + +Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, + StringRef key, StringRef content); + +// Given an elemId which represents the index of an element from the list of +// elements that are in the thread's registers (i.e. total of +// numel(sizePerThread)), it calculates the multi dim offset of the element in +// the smem buffer. Recall that the smem buffer will only store a single replica +// when converting distributed to distributed layout. Also, a replica is the +// smallest CTA tile that is common between input and output layouts. +SmallVector getMultiDimOffset(Attribute layout, Location loc, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + unsigned elemId, RankedTensorType type, + ArrayRef multiDimCTAInRepId, + ArrayRef shapePerCTATile); + +// Given a multiDimOffset, this function wraps around each dimension to be +// within shape. +SmallVector getWrappedMultiDimOffset( + ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDimOffset, ArrayRef shape, + SmallVector shapePerCTATile, SmallVector shapePerCTA); + +inline bool isKernel(FunctionOpInterface funcOp) { + return funcOp.getVisibility() == SymbolTable::Visibility::Public; +} + +inline Value getStackPointer(PatternRewriter &rewriter, + FunctionOpInterface funcOp) { + auto mod = funcOp->getParentOfType(); + LLVM::GlobalOp globalBase = nullptr; + mod.walk([&](LLVM::GlobalOp op) { + if (op.getSymName() == "global_smem") + globalBase = op; + }); + assert(globalBase); + if (isKernel(funcOp)) + return rewriter.create(funcOp.getLoc(), globalBase); + else + return funcOp.getArgument(funcOp.getNumArguments() - 1); +} + +inline Value getSharedMemoryBase(Location loc, + ConversionPatternRewriter &rewriter, + Operation *op) { + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); + FunctionOpInterface func = + op->template getParentOfType(); + assert(op->hasAttr("allocation.offset")); + size_t offset = cast(op->getAttr("allocation.offset")) + .getValue() + .getZExtValue(); + Value offVal = i32_val(offset); + Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); + return base; +} +} // namespace LLVM + +/* ------------------------------------ */ +// Returns CTA level thread idx +inline Value getThreadIdInCTA(RewriterBase &rewriter, Location loc) { + Value tid = + rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x); + return rewriter.create(loc, i32_ty, tid); +} + +// Returns CTA level thread idx. +inline Value getThreadId(RewriterBase &rewriter, Location loc) { + Value tid = getThreadIdInCTA(rewriter, loc); + auto mod = rewriter.getBlock()->getParent()->getParentOfType(); + return tid; +} + +// ----------------------------------------------------------------------- +// Shared memory utilities +// ----------------------------------------------------------------------- +using LLVM::getMultiDimIndex; +using LLVM::SharedMemoryObject; +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::SharedMemoryObject; +using ::mlir::triton::gpu::AMDMfmaEncodingAttr; +using ::mlir::triton::gpu::AMDWmmaEncodingAttr; +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::CTALayoutAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; + +inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef offsets, + ArrayRef strides) { + assert(offsets.size() == strides.size()); + Value ret = i32_val(0); + for (auto [offset, stride] : llvm::zip(offsets, strides)) { + ret = add(ret, mul(offset, stride)); + } + return ret; +} + +// ----------------------------------------------------------------------- +// Blocked layout indices +// ----------------------------------------------------------------------- + +// "Applies" the given layout by computing layout(indices) and returning the +// resulting Values. +// +// In other words, this generates LLVM-dialect MLIR code to "run" the layout +// function. +SmallVector> +applyLinearLayout(Location loc, RewriterBase &rewriter, + const LinearLayout &layout, + ArrayRef> indices); + +inline SmallVector +emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter, + const BlockedEncodingAttr &blockedLayout, + RankedTensorType type) { + MLIRContext *ctx = rewriter.getContext(); + auto shape = type.getShape(); + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(triton::gpu::getWarpSize(blockedLayout)); + Value laneId = urem(threadId, warpSize); + Value warpId = udiv(threadId, warpSize); + auto sizePerThread = blockedLayout.getSizePerThread(); + auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); + auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); + auto order = blockedLayout.getOrder(); + auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape); + unsigned rank = shape.size(); + + // delinearize threadId to get the base index + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, warpsPerCTA, order); + SmallVector multiDimThreadId = + delinearize(rewriter, loc, laneId, threadsPerWarp, order); + + SmallVector multiDimBase(rank); + for (unsigned k = 0; k < rank; ++k) { + // Wrap around multiDimWarpId/multiDimThreadId in case + // shapePerCTATile[k] > shapePerCTA[k] + auto maxWarps = + ceil(shapePerCTA[k], sizePerThread[k] * threadsPerWarp[k]); + auto maxThreads = ceil(shapePerCTA[k], sizePerThread[k]); + multiDimWarpId[k] = urem(multiDimWarpId[k], i32_val(maxWarps)); + multiDimThreadId[k] = urem(multiDimThreadId[k], i32_val(maxThreads)); + // multiDimBase[k] = (multiDimThreadId[k] + + // multiDimWarpId[k] * threadsPerWarp[k]) * + // sizePerThread[k]; + Value threadsPerWarpK = i32_val(threadsPerWarp[k]); + Value sizePerThreadK = i32_val(sizePerThread[k]); + multiDimBase[k] = + mul(sizePerThreadK, + add(multiDimThreadId[k], mul(multiDimWarpId[k], threadsPerWarpK))); + } + + return multiDimBase; +} + +inline SmallVector> +emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout, + RankedTensorType type) { + auto ctx = type.getContext(); + auto shape = type.getShape(); + auto sizePerThread = blockedLayout.getSizePerThread(); + auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); + auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); + auto order = blockedLayout.getOrder(); + auto shapePerCTATile = getShapePerCTATile(blockedLayout); + auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape); + + unsigned rank = shape.size(); + SmallVector tilesPerDim(rank); + for (unsigned k = 0; k < rank; ++k) + tilesPerDim[k] = ceil(shapePerCTA[k], shapePerCTATile[k]); + + unsigned elemsPerThread = triton::gpu::getTotalElemsPerThread(type); + unsigned totalSizePerThread = product(sizePerThread); + SmallVector> reorderedOffset(elemsPerThread); + for (unsigned n = 0; n < elemsPerThread; ++n) { + unsigned linearNanoTileId = n / totalSizePerThread; + unsigned linearNanoTileElemId = n % totalSizePerThread; + SmallVector multiDimNanoTileId = + getMultiDimIndex(linearNanoTileId, tilesPerDim, order); + SmallVector multiDimNanoTileElemId = + getMultiDimIndex(linearNanoTileElemId, sizePerThread, order); + for (unsigned k = 0; k < rank; ++k) { + unsigned reorderedMultiDimId = + (multiDimNanoTileId[k] * + (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) + + multiDimNanoTileElemId[k]) % + shapePerCTA[k]; + + reorderedOffset[n].push_back(reorderedMultiDimId); + } + } + + return reorderedOffset; +} + +// ----------------------------------------------------------------------- +// Mma layout indices +// ----------------------------------------------------------------------- + +inline SmallVector +emitBaseIndexWithinCTAForMmaLayoutV1(Location loc, RewriterBase &rewriter, + const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto wpt = mmaLayout.getWarpsPerCTA(); + static constexpr std::array fpw{{2, 2, 1}}; + auto [isARow, isBRow, isAVec4, isBVec4, _] = + mmaLayout.decodeVoltaLayoutStates(); + + Value thread = getThreadId(rewriter, loc); + auto *ctx = thread.getContext(); + Value _1 = i32_val(1); + Value _2 = i32_val(2); + Value _4 = i32_val(4); + Value _16 = i32_val(16); + Value _32 = i32_val(32); + Value _fpw0 = i32_val(fpw[0]); + Value _fpw1 = i32_val(fpw[1]); + + // A info + auto aRep = mmaLayout.getMMAv1Rep(0); + auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0); + // B info + auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1); + auto bRep = mmaLayout.getMMAv1Rep(1); + + SmallVector rep({aRep[0], bRep[1]}); + SmallVector spw({aSpw[0], bSpw[1]}); + SmallVector shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]}); + + Value lane = urem(thread, _32); + Value warp = udiv(thread, _32); + + Value warp0 = urem(warp, i32_val(wpt[0])); + Value warp12 = udiv(warp, i32_val(wpt[0])); + Value warp1 = urem(warp12, i32_val(wpt[1])); + + // warp offset + Value offWarpM = mul(warp0, i32_val(spw[0])); + Value offWarpN = mul(warp1, i32_val(spw[1])); + // quad offset + Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0); + Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1); + // pair offset + Value offPairM = udiv(urem(lane, _16), _4); + offPairM = urem(offPairM, _fpw0); + offPairM = mul(offPairM, _4); + Value offPairN = udiv(urem(lane, _16), _4); + offPairN = udiv(offPairN, _fpw0); + offPairN = urem(offPairN, _fpw1); + offPairN = mul(offPairN, _4); + offPairM = mul(offPairM, i32_val(rep[0] / 2)); + offQuadM = mul(offQuadM, i32_val(rep[0] / 2)); + offPairN = mul(offPairN, i32_val(rep[1] / 2)); + offQuadN = mul(offQuadN, i32_val(rep[1] / 2)); + // quad pair offset + Value offLaneM = add(offPairM, offQuadM); + Value offLaneN = add(offPairN, offQuadN); + // a, b offset + Value offsetAM = add(offWarpM, offLaneM); + Value offsetBN = add(offWarpN, offLaneN); + // m indices + Value offsetCM = add(and_(lane, _1), offsetAM); + // n indices + Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN))); + return {offsetCM, offsetCN}; +} + +inline SmallVector> +emitOffsetForMmaLayoutV1(const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + + auto [isARow, isBRow, isAVec4, isBVec4, _] = + mmaLayout.decodeVoltaLayoutStates(); + + // TODO: seems like the pattern below to get `rep`/`spw` appears quite often + // A info + auto aRep = mmaLayout.getMMAv1Rep(0); + auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0); + // B info + auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1); + auto bRep = mmaLayout.getMMAv1Rep(1); + + auto wpt = mmaLayout.getWarpsPerCTA(); + static constexpr std::array fpw{{2, 2, 1}}; + SmallVector rep({aRep[0], bRep[1]}); + SmallVector spw({aSpw[0], bSpw[1]}); + SmallVector shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]}); + + SmallVector idxM; + for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0]) + for (unsigned mm = 0; mm < rep[0]; ++mm) + idxM.push_back(m + mm * 2); + + SmallVector idxN; + for (int n = 0; n < shape[1]; n += shapePerCTA[1]) { + for (int nn = 0; nn < rep[1]; ++nn) { + idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1]); + idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1); + } + } + + SmallVector> ret; + for (unsigned x1 : idxN) { // N + for (unsigned x0 : idxM) { // M + SmallVector idx(2); + idx[0] = x0; // M + idx[1] = x1; // N + ret.push_back(std::move(idx)); + } + } + return ret; +} + +inline SmallVector> +emitOffsetForMmaLayoutV2(const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + SmallVector> ret; + + auto rank = shape.size(); + for (unsigned i = 0; i < shapePerCTA[rank - 2]; + i += getShapePerCTATile(mmaLayout)[rank - 2]) { + for (unsigned j = 0; j < shapePerCTA[rank - 1]; + j += getShapePerCTATile(mmaLayout)[rank - 1]) { + if (rank == 3) { + ret.push_back({0, i, j}); + ret.push_back({0, i, j + 1}); + ret.push_back({0, i + 8, j}); + ret.push_back({0, i + 8, j + 1}); + } else { + ret.push_back({i, j}); + ret.push_back({i, j + 1}); + ret.push_back({i + 8, j}); + ret.push_back({i + 8, j + 1}); + } + } + } + return ret; +} + +// Note that this may return a null Value for one or more dimensions. This is +// valid only if you're going to slice off the relevant dimension. +inline SmallVector +emitBaseIndexWithinCTAForMmaLayoutV2V3(Location loc, RewriterBase &rewriter, + const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto _warpsPerCTA = mmaLayout.getWarpsPerCTA(); + auto rank = shape.size(); + assert(rank == 2 || rank == 3); + auto warpOrder = triton::gpu::getWarpOrder(mmaLayout); + ArrayRef instrShape = mmaLayout.getInstrShape(); + SmallVector warpsPerCTA; + for (unsigned i = 0; i < rank; ++i) + warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(32); + Value laneId = urem(threadId, warpSize); + Value warpId = udiv(threadId, warpSize); + + uint32_t repM = + (_warpsPerCTA[rank - 2] * instrShape[rank - 2]) / shapePerCTA[rank - 2]; + uint32_t repN = + (_warpsPerCTA[rank - 1] * instrShape[rank - 1]) / shapePerCTA[rank - 1]; + + uint32_t warpsM; + if (repM > 1) + warpsM = _warpsPerCTA[rank - 2] / repM; + else + warpsM = shape[rank - 2] / instrShape[rank - 2]; + + uint32_t warpsN; + if (repN > 1) + warpsN = _warpsPerCTA[rank - 1] / repN; + else + warpsN = shape[rank - 1] / instrShape[rank - 1]; + + SmallVector multiDimWarpId(rank); + multiDimWarpId = delinearize(rewriter, loc, warpId, _warpsPerCTA, warpOrder); + Value warpIdM = urem(multiDimWarpId[rank - 2], i32_val(warpsM)); + Value warpIdN = urem(multiDimWarpId[rank - 1], i32_val(warpsN)); + + Value offWarpM = mul(warpIdM, i32_val(instrShape[rank - 2])); + Value offWarpN = mul(warpIdN, i32_val(instrShape[rank - 1])); + + SmallVector multiDimBase(rank); + if (rank == 3) + multiDimBase[0] = multiDimWarpId[0]; + + // warpsM/N may be 0, in which case warpIDM/N is poison (division by 0), which + // will cause LLVM to eliminate all ops that depend on the poison value. This + // *can* be okay, if the bad dimension is filtered out by a slice layout. So + // we rely on the caller to check. Worst case we crash, which is better than + // silently producing bad code. + if (warpsM != 0) + multiDimBase[rank - 2] = add(udiv(laneId, i32_val(4)), offWarpM); + if (warpsN != 0) + multiDimBase[rank - 1] = + add(mul(i32_val(2), urem(laneId, i32_val(4))), offWarpN); + + return multiDimBase; +} + +inline SmallVector> +emitOffsetForMmaLayoutV3(const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + SmallVector> ret; + ArrayRef instrShape = mmaLayout.getInstrShape(); + + for (unsigned i = 0; i < shapePerCTA[0]; + i += getShapePerCTATile(mmaLayout)[0]) { + for (unsigned j = 0; j < shapePerCTA[1]; + j += getShapePerCTATile(mmaLayout)[1]) { + for (unsigned k = 0; k < instrShape[1]; k += 8) { + ret.push_back({i, j + k}); + ret.push_back({i, j + k + 1}); + ret.push_back({i + 8, j + k}); + ret.push_back({i + 8, j + k + 1}); + } + } + } + return ret; +} + +inline SmallVector +emitBaseIndexForMfmaLayout(Location loc, RewriterBase &rewriter, + const AMDMfmaEncodingAttr &mfmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto rank = shape.size(); + assert(rank == 2 || rank == 3); + auto _warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + SmallVector warpsPerCTA; + for (unsigned i = 0; i < rank; ++i) + warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); + unsigned mDim = mfmaLayout.getMDim(); + unsigned nDim = mfmaLayout.getNDim(); + assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout)); + Value effectiveWarpSize = warpSize; + if (mDim == 4 && nDim == 4) { + const int uniqueValuesPerWarp = 4; + effectiveWarpSize = i32_val(uniqueValuesPerWarp); + } + Value laneId = urem(threadId, effectiveWarpSize); + Value warpId = udiv(threadId, warpSize); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, _warpsPerCTA, + triton::gpu::getWarpOrder(mfmaLayout)); + if (shape[rank - 2] >= mDim) { + assert(shape[rank - 2] % mDim == 0); + multiDimWarpId[rank - 2] = + urem(multiDimWarpId[rank - 2], + i32_val(ceil(shape[rank - 2], mDim))); + } + if (shape[rank - 1] >= nDim) { + assert(shape[rank - 1] % nDim == 0); + multiDimWarpId[rank - 1] = + urem(multiDimWarpId[rank - 1], + i32_val(ceil(shape[rank - 1], nDim))); + } + Value offWarp0 = mul(multiDimWarpId[rank - 2], i32_val(mDim)); + Value offWarp1 = mul(multiDimWarpId[rank - 1], i32_val(nDim)); + + SmallVector multiDimBase(rank); + if (mfmaLayout.getIsTransposed()) { + multiDimBase[rank - 1] = + add(mul(i32_val(4), udiv(laneId, i32_val(mDim))), offWarp1); + multiDimBase[rank - 2] = add(urem(laneId, i32_val(mDim)), offWarp0); + } else { + multiDimBase[rank - 2] = + add(mul(i32_val(4), udiv(laneId, i32_val(nDim))), offWarp0); + multiDimBase[rank - 1] = add(urem(laneId, i32_val(nDim)), offWarp1); + } + // TODO(Lixun): It is assumed when rank = 3, warpsPerCTA is set to + // {numWarps, 1, 1}. We need to generalize the offset computation. + if (rank == 3) { + assert(_warpsPerCTA[1] == 1 && _warpsPerCTA[2] == 1); + multiDimBase[0] = urem(warpId, i32_val(shape[0])); + } + return multiDimBase; +} + +inline void emitMfmaOffsetForCTA(const AMDMfmaEncodingAttr &mfmaLayout, + SmallVector> &offsets, + unsigned bOff, unsigned ctaOffsetX, + unsigned ctaOffsetY) { + auto mDim = mfmaLayout.getMDim(); + auto nDim = mfmaLayout.getNDim(); + assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + // MFMA output tile consists of repeated "dot operand B" layout groups along + // row axis. This variable defines number of these groups. + DenseMap groups{{4, 1}, {16, 1}, {32, 4}}; + unsigned numGroups = groups.at(std::min(mDim, nDim)); + const unsigned elemsPerThreadPerGroup = 4; + auto warpSize = getWarpSize(mfmaLayout); + assert(warpSize == 64); + auto shapePerCta = getShapePerCTATile(mfmaLayout); + auto rank = shapePerCta.size(); + SmallVector elemOff(rank, 0); + for (unsigned block = 0; block < numGroups; block++) { + unsigned rowOrColOffset = + block * elemsPerThreadPerGroup * warpSize / std::min(mDim, nDim); + for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) { + if (mfmaLayout.getIsTransposed()) { + elemOff[rank - 2] = ctaOffsetX * shapePerCta[rank - 2]; + elemOff[rank - 1] = + ctaOffsetY * shapePerCta[rank - 1] + elem + rowOrColOffset; + } else { + elemOff[rank - 2] = + ctaOffsetX * shapePerCta[rank - 2] + elem + rowOrColOffset; + elemOff[rank - 1] = ctaOffsetY * shapePerCta[rank - 1]; + } + if (rank == 3) + elemOff[0] = bOff; + offsets.push_back(elemOff); + } + } +} + +inline SmallVector> +emitOffsetForMfmaLayout(const AMDMfmaEncodingAttr &mfmaLayout, + RankedTensorType type) { + auto tensorShape = type.getShape(); + SmallVector> offsets; + auto shapePerCTA = getShapePerCTA(mfmaLayout, tensorShape); + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + auto rank = type.getRank(); + SmallVector numReps(rank); + unsigned mDim = mfmaLayout.getMDim(); + unsigned nDim = mfmaLayout.getNDim(); + assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + SmallVector shapePerWarp(rank, 1); + shapePerWarp[rank - 2] = mDim; + shapePerWarp[rank - 1] = nDim; + for (unsigned d = 0; d < rank; ++d) { + unsigned inPerCTA = std::min(tensorShape[d], shapePerCTA[d]); + unsigned inPerWarp = ceil(inPerCTA, warpsPerCTA[d]); + numReps[d] = ceil(inPerWarp, shapePerWarp[d]); + } + + unsigned repBatch = rank == 3 ? numReps[0] : 1; + auto warpsPerBatch = + rank == 3 ? std::min(tensorShape[0], warpsPerCTA[0]) : 1; + + for (unsigned b = 0; b < repBatch; ++b) { + for (unsigned i = 0; i < numReps[rank - 2]; ++i) { + for (unsigned j = 0; j < numReps[rank - 1]; ++j) { + emitMfmaOffsetForCTA(mfmaLayout, offsets, b * warpsPerBatch, i, j); + } + } + } + return offsets; +} + +inline void emitWmmaOffsetForCTA(const AMDWmmaEncodingAttr &wmmaLayout, + SmallVector> &offsets, + unsigned ctaBatchOffset, unsigned ctaOffsetX, + unsigned ctaOffsetY) { + const unsigned elemsPerThreadPerGroup = 8; + auto warpSize = getWarpSize(wmmaLayout); + assert(warpSize == 32); + auto shapePerCta = getShapePerCTATile(wmmaLayout); + auto rank = shapePerCta.size(); + assert(rank == 2 || rank == 3); + SmallVector elemOffset(rank, 0); + if (rank == 3) + elemOffset[0] = ctaBatchOffset; + for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) { + elemOffset[rank - 2] = ctaOffsetX * shapePerCta[rank - 2] + 2 * elem; + elemOffset[rank - 1] = ctaOffsetY * shapePerCta[rank - 1]; + offsets.push_back(elemOffset); + } +} + +inline SmallVector +emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter, + const AMDWmmaEncodingAttr &wmmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto _warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + auto rank = _warpsPerCTA.size(); + assert(rank == 2 || rank == 3); + SmallVector warpsPerCTA; + for (unsigned i = 0; i < rank; ++i) + warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); + auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(triton::gpu::getWarpSize(wmmaLayout)); + Value laneId = + urem(threadId, i32_val(triton::gpu::getWarpSize(wmmaLayout) / 2)); + Value threadIdPerWarp = urem(threadId, warpSize); + + Value warpId = udiv(threadId, warpSize); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, _warpsPerCTA, + triton::gpu::getWarpOrder(wmmaLayout)); + if (shape[rank - 2] >= mnkDim[0]) { + assert(shape[rank - 2] % mnkDim[0] == 0); + multiDimWarpId[rank - 2] = + urem(multiDimWarpId[rank - 2], + i32_val(ceil(shape[rank - 2], mnkDim[0]))); + } + if (shape[rank - 1] >= mnkDim[1]) { + assert(shape[rank - 1] % mnkDim[1] == 0); + multiDimWarpId[rank - 1] = + urem(multiDimWarpId[rank - 1], + i32_val(ceil(shape[rank - 1], mnkDim[1]))); + } + Value offWarp0 = mul(multiDimWarpId[rank - 2], i32_val(mnkDim[0])); + Value offWarp1 = mul(multiDimWarpId[rank - 1], i32_val(mnkDim[1])); + + SmallVector multiDimBase(rank); + + multiDimBase[rank - 2] = + add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0); + multiDimBase[rank - 1] = add(laneId, offWarp1); + + // TODO: It is assumed when rank = 3, warpsPerCTA is set to + // {numWarps, 1, 1}. We need to generalize the offset computation. + if (rank == 3) { + assert(_warpsPerCTA[1] == 1 && _warpsPerCTA[2] == 1); + multiDimBase[0] = urem(warpId, i32_val(shape[0])); + } + return multiDimBase; +} + +inline SmallVector> +emitOffsetForWmmaLayout(const AMDWmmaEncodingAttr &wmmaLayout, + RankedTensorType type) { + auto tensorShape = type.getShape(); + SmallVector> offsets; + auto shapePerCTA = getShapePerCTA(wmmaLayout, tensorShape); + auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + + auto rank = tensorShape.size(); + assert(rank == 2 || rank == 3); + + SmallVector numWarpsPerDim(rank, 1); + auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); + SmallVector shapePerWarp(rank, 1); + shapePerWarp[rank - 2] = mnkDim[0]; + shapePerWarp[rank - 1] = mnkDim[1]; + for (unsigned d = 0; d < rank; ++d) { + unsigned inPerCTA = std::min(tensorShape[d], shapePerCTA[d]); + unsigned inPerWarp = ceil(inPerCTA, warpsPerCTA[d]); + numWarpsPerDim[d] = ceil(inPerWarp, shapePerWarp[d]); + } + + unsigned repBatch = rank == 3 ? numWarpsPerDim[0] : 1; + unsigned repM = numWarpsPerDim[rank - 2]; + unsigned repN = numWarpsPerDim[rank - 1]; + auto warpsPerBatch = + rank == 3 ? std::min(tensorShape[0], warpsPerCTA[0]) : 1; + + for (unsigned b = 0; b < repBatch; ++b) { + for (unsigned i = 0; i < repM; ++i) { + for (unsigned j = 0; j < repN; ++j) { + emitWmmaOffsetForCTA(wmmaLayout, offsets, b * warpsPerBatch, i, j); + } + } + } + return offsets; +} + +inline SmallVector> +emitOffsetForLayout(Attribute layout, RankedTensorType type); + +inline SmallVector> +emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout, + RankedTensorType type) { + auto parentEncoding = sliceLayout.getParent(); + unsigned dim = sliceLayout.getDim(); + auto parentShape = sliceLayout.paddedShape(type.getShape()); + RankedTensorType parentTy = + RankedTensorType::get(parentShape, type.getElementType(), parentEncoding); + auto parentOffsets = emitOffsetForLayout(parentEncoding, parentTy); + if (parentOffsets.empty()) + return {}; + + SmallVector> resultOffsets; + std::set> uniqueOffsets; + + for (unsigned i = 0; i < parentOffsets.size(); ++i) { + SmallVector offsets(parentOffsets[i].begin(), + parentOffsets[i].end()); + offsets.erase(offsets.begin() + dim); + if (auto [it, inserted] = uniqueOffsets.insert(offsets); inserted) { + resultOffsets.push_back(offsets); + } + } + + // It can happen that after deduplicating elements above, resultOffsets has + // fewer than getTotalElementsPerThread() elements. In that case repeat the + // sequence. + int elemsPerThread = triton::gpu::getTotalElemsPerThread(type); + assert(resultOffsets.size() > 0); + assert(elemsPerThread % resultOffsets.size() == 0); + int numRepeats = elemsPerThread / resultOffsets.size(); + SmallVector> ret; + for (int i = 0; i < numRepeats; ++i) { + for (unsigned j = 0; j < resultOffsets.size(); ++j) { + ret.push_back(SmallVector(resultOffsets[j])); + } + } + return ret; +} + +// ----------------------------------------------------------------------- +// Get offsets / indices for any layout +// ----------------------------------------------------------------------- + +inline SmallVector emitCTAOffsetForLayout(Location loc, + RewriterBase &rewriter, + const TargetInfoBase &target, + Attribute layout, + ArrayRef shape) { + unsigned rank = shape.size(); + SmallVector CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout); + SmallVector CTASplitNum = triton::gpu::getCTASplitNum(layout); + SmallVector CTAOrder = triton::gpu::getCTAOrder(layout); + SmallVector shapePerCTA = + triton::gpu::getShapePerCTA(CTASplitNum, shape); + + // Delinearize clusterCTAId + Value clusterCTAId = target.getClusterCTAId(rewriter, loc); + SmallVector multiDimClusterCTAId = + delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); + + // CTA Wrapping + for (unsigned i = 0; i < rank; ++i) { + // This wrapping rule must be consistent with getShapePerCTA + unsigned splitNum = std::min(shape[i], CTASplitNum[i]); + multiDimClusterCTAId[i] = urem(multiDimClusterCTAId[i], i32_val(splitNum)); + } + + SmallVector CTAOffset(rank); + for (unsigned i = 0; i < rank; ++i) + CTAOffset[i] = mul(multiDimClusterCTAId[i], i32_val(shapePerCTA[i])); + + return CTAOffset; +} + +inline SmallVector +emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, bool withCTAOffset) { + auto shape = type.getShape(); + + SmallVector baseIndex; + RewriterBase::InsertionGuard guard(rewriter); + SmallVector result; + if (auto blockedLayout = mlir::dyn_cast(layout)) { + result = emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter, + blockedLayout, type); + } else if (auto mmaLayout = mlir::dyn_cast(layout)) { + if (mmaLayout.isVolta()) + result = + emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter, mmaLayout, type); + if (mmaLayout.isAmpere() || mmaLayout.isHopper()) + result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter, mmaLayout, + type); + } else if (auto mfmaLayout = mlir::dyn_cast(layout)) { + result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type); + } else if (auto wmmaLayout = mlir::dyn_cast(layout)) { + result = emitBaseIndexForWmmaLayout(loc, rewriter, wmmaLayout, type); + } else if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(type.getShape()); + RankedTensorType parentTy = + RankedTensorType::get(parentShape, type.getElementType(), parentLayout); + result = emitBaseIndexForLayoutImpl(loc, rewriter, target, parentLayout, + parentTy, withCTAOffset); + result.erase(result.begin() + sliceLayout.getDim()); + // CTAOffset has been added in emitBaseIndexForLayout of parentLayout + return result; + } else { + llvm_unreachable("unsupported emitBaseIndexForLayout"); + } + if (withCTAOffset) { + auto CTAOffset = + emitCTAOffsetForLayout(loc, rewriter, target, layout, shape); + assert(CTAOffset.size() == result.size() && "Rank mismatch"); + for (unsigned k = 0; k < result.size(); ++k) { + // Individual elements of `result` may be null. In the caller + // (emitBaseIndexForLayout), we assert that all such dimensions are sliced + // off. + if (!result[k]) + continue; + result[k] = add(result[k], CTAOffset[k]); + } + } + return result; +} + +inline SmallVector +emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, bool withCTAOffset) { + SmallVector idx = emitBaseIndexForLayoutImpl( + loc, rewriter, target, layout, type, withCTAOffset); + + // Check that any null values were sliced out. + for (Value v : idx) { + if (!v) { + llvm::errs() << "Failed to generate indexing code, possibly due to bad " + "#mma layout. Please rerun your program with " + "MLIR_ENABLE_DUMP=1 and file a bug." + << "\nloc: " << loc << "\nlayout: " << layout + << "\ntype: " << type << "\nwithCTAOffset: " << withCTAOffset + << "\n"; + llvm::report_fatal_error("Failed to generate indexing code"); + } + } + + return idx; +} + +inline SmallVector> +emitOffsetForLayout(Attribute layout, RankedTensorType type) { + if (auto blockedLayout = dyn_cast(layout)) + return emitOffsetForBlockedLayout(blockedLayout, type); + if (auto mmaLayout = dyn_cast(layout)) { + if (mmaLayout.isVolta()) + return emitOffsetForMmaLayoutV1(mmaLayout, type); + if (mmaLayout.isAmpere()) + return emitOffsetForMmaLayoutV2(mmaLayout, type); + if (mmaLayout.isHopper()) + return emitOffsetForMmaLayoutV3(mmaLayout, type); + } + if (auto mfmaLayout = mlir::dyn_cast(layout)) { + return emitOffsetForMfmaLayout(mfmaLayout, type); + } + if (auto wmmaLayout = mlir::dyn_cast(layout)) { + return emitOffsetForWmmaLayout(wmmaLayout, type); + } + if (auto sliceLayout = mlir::dyn_cast(layout)) + return emitOffsetForSliceLayout(sliceLayout, type); + llvm_unreachable("unsupported emitOffsetForLayout"); +} + +// Eventually this will become the only emitIndices function. +std::optional>> +emitIndicesUsingLinearLayouts(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, bool withCTAOffset); + +// Emit indices calculation within each ConversionPattern, and returns a +// [elemsPerThread X rank] index matrix. +inline SmallVector> +emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + Attribute layout, RankedTensorType type, bool withCTAOffset, + bool allowLL = true) { + // Eventually the LinearLayout path will be the only one. For now we allow + // both paths so we can test that they produce the same results. + if (allowLL && target.enableLinearLayout()) { + std::optional>> llOffsets = + emitIndicesUsingLinearLayouts(loc, rewriter, target, layout, type, + withCTAOffset); + if (llOffsets.has_value()) + return *llOffsets; + } + + // step 1, delinearize threadId to get the base index + auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, target, layout, + type, withCTAOffset); + // step 2, get offset of each element + auto offset = emitOffsetForLayout(layout, type); + // step 3, add offset to base, and reorder the sequence + // of indices to guarantee that elems in the same + // sizePerThread are adjacent in order + auto shape = type.getShape(); + unsigned rank = shape.size(); + unsigned elemsPerThread = offset.size(); + SmallVector> multiDimIdx(elemsPerThread, + SmallVector(rank)); + for (unsigned n = 0; n < elemsPerThread; ++n) + for (unsigned k = 0; k < rank; ++k) + multiDimIdx[n][k] = add(multiDimBase[k], i32_val(offset[n][k])); + + return multiDimIdx; +} + +/* ---------------- */ +/* ---------------- */ +inline DenseMap getSwizzledSharedPtrs( + Location loc, const TargetInfoBase &target, unsigned inVec, + RankedTensorType srcTy, triton::gpu::SharedEncodingAttr resSharedLayout, + Type resElemTy, SharedMemoryObject smemObj, RewriterBase &rewriter, + SmallVectorImpl &offsetVals, SmallVectorImpl &srcStrides) { + // This utility computes the pointers for accessing the provided swizzled + // shared memory layout `resSharedLayout`. More specifically, it computes, + // for all indices (row, col) of `srcEncoding` such that idx % inVec = 0, + // the pointer: ptr[(row, col)] = base + (rowOff * strides[ord[1]] + + // colOff) where : + // phase = (row // perPhase) % maxPhase + // rowOff = row + // colOff = colOffSwizzled + colOffOrdered + // colOffSwizzled = ((col // outVec) ^ phase) * outVec + // colOffOrdered = (col % outVec) // minVec * minVec + // + // Note 1: + // ------- + // Because swizzling happens at a granularity of outVec, we need to + // decompose the offset into a swizzled factor and a non-swizzled + // (ordered) factor + // + // Note 2: + // ------- + // If we have x, y, z of the form: + // x = 0b00000xxxx + // y = 0byyyyy0000 + // z = 0b00000zzzz + // then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y + // This means that we can use some immediate offsets for shared memory + // operations. + auto dstPtrTy = ptr_ty(rewriter.getContext(), 3); + auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides); + Value dstPtrBase = gep(dstPtrTy, resElemTy, smemObj.base, dstOffset); + + auto srcEncoding = srcTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto srcShapePerCTA = triton::gpu::getShapePerCTA(srcTy); + unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); + // swizzling params as described in TritonGPUAttrDefs.td + unsigned outVec = resSharedLayout.getVec(); + unsigned perPhase = resSharedLayout.getPerPhase(); + unsigned maxPhase = resSharedLayout.getMaxPhase(); + // Order + auto inOrder = triton::gpu::getOrder(srcEncoding); + auto outOrder = triton::gpu::getOrder(resSharedLayout); + assert(maxPhase == 1 || + outVec * maxPhase <= srcShape[outOrder[0]] && + "Swizzling would generate out of bounds memory accesses"); + // Tensor indices held by the current thread, as LLVM values + auto srcIndices = emitIndices(loc, rewriter, target, srcEncoding, srcTy, + /*withCTAOffset=*/false); + // Swizzling with leading offsets (e.g. Hopper GMMA) + unsigned swizzlingByteWidth = 0; + if (resSharedLayout.getHasLeadingOffset()) { + if (perPhase == 4 && maxPhase == 2) + swizzlingByteWidth = 32; + else if (perPhase == 2 && maxPhase == 4) + swizzlingByteWidth = 64; + else if (perPhase == 1 && maxPhase == 8) + swizzlingByteWidth = 128; + else + llvm::report_fatal_error("Unsupported shared layout."); + } + unsigned numElemsPerSwizzlingRow = + swizzlingByteWidth * 8 / resElemTy.getIntOrFloatBitWidth(); + Value numElemsPerSwizzlingRowVal = i32_val(numElemsPerSwizzlingRow); + unsigned leadingDimOffset; + if (outOrder.size() >= 2) { + leadingDimOffset = numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]]; + } else { + leadingDimOffset = numElemsPerSwizzlingRow; + } + + Value leadingDimOffsetVal = i32_val(leadingDimOffset); + // Return values + DenseMap ret; + // cache for non-immediate offsets + DenseMap cacheCol, cacheRow; + unsigned minVec = std::min(outVec, inVec); + Value strideRow = outOrder.size() >= 2 ? srcStrides[outOrder[1]] : i32_val(0); + Value strideCol = srcStrides[outOrder[0]]; + LDBG("getSwizzledSharedPtrs: perPhase = " + << perPhase << " maxPhase = " << maxPhase << " minVec = " << minVec + << " inVec = " << inVec << " outVec = " << outVec << " strideRow " + << strideRow << " strideCol " << strideCol); + for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) { + Value offset = i32_val(0); + // Extract multi dimensional index for current element + auto idx = srcIndices[elemIdx]; + Value idxCol = idx[outOrder[0]]; // contiguous dimension + Value idxRow; + if (outOrder.size() >= 2) { + idxRow = idx[outOrder[1]]; // discontiguous dimension + } else { + idxRow = i32_val(0); + } + // compute phase = (row // perPhase) % maxPhase + Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase)); + // extract dynamic/static offset for immediate offsetting + unsigned immedateOffCol = 0; + unsigned immedateOffRow = 0; + if (leadingDimOffset) { + // hopper + offset = + mul(udiv(idxCol, numElemsPerSwizzlingRowVal), leadingDimOffsetVal); + // Shrink by swizzling blocks + idxCol = urem(idxCol, numElemsPerSwizzlingRowVal); + strideRow = numElemsPerSwizzlingRowVal; + } + if (auto add = dyn_cast_or_null(idxCol.getDefiningOp())) { + if (auto _cst = dyn_cast_or_null( + add.getRhs().getDefiningOp())) { + unsigned cst = + cast(_cst.getValue()).getValue().getSExtValue(); + unsigned key = cst % (outVec * maxPhase); + cacheCol.insert({key, idxCol}); + idxCol = cacheCol[key]; + immedateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase); + } + } + if (auto add = dyn_cast_or_null(idxRow.getDefiningOp())) { + if (auto _cst = dyn_cast_or_null( + add.getRhs().getDefiningOp())) { + unsigned cst = + mlir::cast(_cst.getValue()).getValue().getSExtValue(); + unsigned key = cst % (perPhase * maxPhase); + cacheRow.insert({key, idxRow}); + idxRow = cacheRow[key]; + immedateOffRow = cst / (perPhase * maxPhase) * (perPhase * maxPhase); + } + } + // row offset is simply row index + Value rowOff = mul(idxRow, strideRow); + // because swizzling happens at a granularity of outVec, we need to + // decompose the offset into a swizzled factor and a non-swizzled + // (ordered) factor: colOffSwizzled = ((col // outVec) ^ phase) * outVec + // colOffOrdered = (col % outVec) // minVec * minVec + Value colOffSwizzled = xor_(udiv(idxCol, i32_val(outVec)), phase); + colOffSwizzled = mul(colOffSwizzled, i32_val(outVec)); + Value colOffOrdered = urem(idxCol, i32_val(outVec)); + colOffOrdered = udiv(colOffOrdered, i32_val(minVec)); + colOffOrdered = mul(colOffOrdered, i32_val(minVec)); + Value colOff = add(colOffSwizzled, colOffOrdered); + // compute non-immediate offset + if (outOrder.size() == 3) + offset = add(offset, mul(idx[outOrder[2]], srcStrides[outOrder[2]])); + offset = add(offset, add(rowOff, mul(colOff, strideCol))); + Value currPtr = gep(dstPtrTy, resElemTy, dstPtrBase, offset); + // compute immediate offset + Value immediateOff; + if (outOrder.size() >= 2) { + immediateOff = + add(mul(i32_val(immedateOffRow), strideRow), i32_val(immedateOffCol)); + } else { + immediateOff = i32_val(immedateOffCol); + } + + ret[elemIdx] = gep(dstPtrTy, resElemTy, currPtr, immediateOff); + } + return ret; +} + +inline SmallVector loadSharedToDistributed( + Value dst, Value src, SharedMemoryObject smemObj, Type elemTy, Location loc, + ConversionPatternRewriter &rewriter, const TargetInfoBase &target) { + auto dstTy = cast(dst.getType()); + auto dstShape = dstTy.getShape(); + assert(dstShape.size() <= 2 && "Unexpected rank of loadSharedToDistributed"); + auto srcTy = cast(src.getType()); + auto dstDistributedLayout = dstTy.getEncoding(); + if (auto mmaLayout = dyn_cast(dstDistributedLayout)) { + assert((!mmaLayout.isVolta()) && + "ConvertLayout Shared->MMAv1 is not supported yet"); + } + auto srcSharedLayout = + cast(srcTy.getEncoding()); + auto srcElemTy = srcTy.getElementType(); + auto dstElemTy = dstTy.getElementType(); + LDBG("loadSharedToDistributed elemTy " << elemTy << " srcElemTy " << srcElemTy + << " dstElemTy " << dstElemTy); + auto inOrd = triton::gpu::getOrder(srcSharedLayout); + auto outOrd = triton::gpu::getOrder(dstDistributedLayout); + unsigned outVec = inOrd == outOrd + ? triton::gpu::getUniqueContigPerThread( + dstDistributedLayout, dstShape)[outOrd[0]] + : 1; + + // If the shmem layout is not swizzled, we can trivially vectorize loads + // across the whole width of the most-minor dimension of the shape, because + // Triton requires all the dims are powers of 2. + unsigned inVec = srcSharedLayout.getMaxPhase() == 1 + ? srcTy.getShape()[inOrd[0]] + : srcSharedLayout.getVec(); + unsigned minVec = std::min(outVec, inVec); + unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy); + SmallVector offsetVals = {smemObj.strides.size(), i32_val(0)}; + + DenseMap sharedPtrs = + getSwizzledSharedPtrs(loc, target, outVec, dstTy, srcSharedLayout, elemTy, + smemObj, rewriter, offsetVals, smemObj.strides); + assert(outElems % minVec == 0 && "Unexpected number of elements"); + unsigned numVecs = outElems / minVec; + auto wordTy = vec_ty(elemTy, minVec); + SmallVector outVals(outElems); + for (unsigned i = 0; i < numVecs; ++i) { + Value smemAddr = sharedPtrs[i * minVec]; + smemAddr = bitcast(smemAddr, ptr_ty(rewriter.getContext(), 3)); + auto valVec = load(wordTy, smemAddr); + valVec.setAlignment(minVec * elemTy.getIntOrFloatBitWidth() / 8); + for (unsigned v = 0; v < minVec; ++v) { + Value currVal = extract_element(elemTy, valVec, i32_val(v)); + outVals[i * minVec + v] = currVal; + } + } + return outVals; +} + +inline void storeDistributedToShared(Value src, ArrayRef inVals, + ArrayRef dstStrides, Value dst, + Value smemBase, Type elemTy, Location loc, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &target) { + auto srcTy = cast(src.getType()); + auto srcShape = srcTy.getShape(); + auto rank = srcShape.size(); + assert(rank <= 3 && "Unexpected rank of storeDistributedToShared"); + auto dstTy = cast(dst.getType()); + auto srcDistributedLayout = srcTy.getEncoding(); + if (auto mmaLayout = dyn_cast(srcDistributedLayout)) { + assert((!mmaLayout.isVolta()) && + "ConvertLayout MMAv1->Shared is not supported yet"); + } + auto dstSharedLayout = + cast(dstTy.getEncoding()); + auto dstElemTy = dstTy.getElementType(); + auto inOrd = triton::gpu::getOrder(srcDistributedLayout); + auto outOrd = dstSharedLayout.getOrder(); + unsigned inVec = inOrd == outOrd + ? triton::gpu::getUniqueContigPerThread( + srcDistributedLayout, srcShape)[inOrd[0]] + : 1; + // If the shmem layout is not swizzled, we can trivially vectorize stores + // across the whole width of the most-minor dimension of the shape, because + // Triton requires all the dims are powers of 2. + unsigned outVec = dstSharedLayout.getMaxPhase() == 1 + ? dstTy.getShape()[inOrd[0]] + : dstSharedLayout.getVec(); + unsigned minVec = std::min(outVec, inVec); + unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); + auto wordTy = vec_ty(elemTy, minVec); + Value word; + + SmallVector srcStrides(dstStrides); + SmallVector offsetVals(rank, i32_val(0)); + SharedMemoryObject smemObj(smemBase, elemTy, srcStrides, offsetVals); + + DenseMap sharedPtrs = + getSwizzledSharedPtrs(loc, target, inVec, srcTy, dstSharedLayout, elemTy, + smemObj, rewriter, offsetVals, srcStrides); + LDBG("storeDistributedToShared: numElems = " << numElems << " minVec = " + << minVec << " " << wordTy); + for (unsigned i = 0; i < numElems; ++i) { + if (i % minVec == 0) + word = undef(wordTy); + word = insert_element(wordTy, word, inVals[i], i32_val(i % minVec)); + if (i % minVec == minVec - 1) { + Value smemAddr = sharedPtrs[i / minVec * minVec]; + smemAddr = bitcast(smemAddr, ptr_ty(rewriter.getContext(), 3)); + store(word, smemAddr) + .setAlignment(minVec * elemTy.getIntOrFloatBitWidth() / 8); + } + } +} + +inline Value +getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, + ConversionPatternRewriter &rewriter) { + auto elems = smemObj.getElems(); + auto types = smemObj.getTypes(); + auto structTy = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types); + // pack into struct + Value llvmStruct = rewriter.create(loc, structTy); + for (const auto &v : llvm::enumerate(elems)) { + assert(v.value() && "can not insert null values"); + llvmStruct = insert_val(structTy, llvmStruct, v.value(), v.index()); + } + return llvmStruct; +} + +inline SmallVector +unpackLLElements(Location loc, Value llvmStruct, + ConversionPatternRewriter &rewriter) { + assert(bool(llvmStruct) && "can not unpack null values"); + if (llvmStruct.getType().isIntOrIndexOrFloat() || + isa(llvmStruct.getType()) || + isa(llvmStruct.getType())) + return {llvmStruct}; + ArrayRef types = + cast(llvmStruct.getType()).getBody(); + SmallVector results(types.size()); + for (unsigned i = 0; i < types.size(); ++i) { + Type type = types[i]; + results[i] = extract_val(type, llvmStruct, i); + } + return results; +} + +inline Value packLLElements(Location loc, + const LLVMTypeConverter *typeConverter, + ValueRange resultVals, + ConversionPatternRewriter &rewriter, Type type) { + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (!structType) { + assert(resultVals.size() == 1); + return *resultVals.begin(); + } + + auto elementTypes = structType.getBody(); + if (elementTypes.size() != resultVals.size()) { + emitError(loc) << " size mismatch when packing elements for LLVM struct" + << " expected " << elementTypes.size() << " but got " + << resultVals.size(); + } + Value llvmStruct = rewriter.create(loc, structType); + for (const auto &v : llvm::enumerate(resultVals)) { + if (!v.value()) { + emitError(loc) + << "cannot insert null values into struct, but tried to insert" + << v.value(); + } + if (v.value().getType() != elementTypes[v.index()]) { + LDBG("type " << type << " structType " << structType); + LDBG("value " << v.value()); + emitError(loc) << "invalid element type in packLLEElements. Expected " + << elementTypes[v.index()] << " but got " + << v.value().getType(); + } + llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index()); + } + return llvmStruct; +} + +inline bool isLayoutMmaV1(Attribute layout) { + bool isMmaV1 = false; + if (auto mmaLayout = dyn_cast(layout)) { + isMmaV1 = mmaLayout.isVolta(); + } + if (auto sliceLayout = dyn_cast(layout)) { + isMmaV1 = isa(sliceLayout.getParent()) && + cast(sliceLayout.getParent()).isVolta(); + } + return isMmaV1; +} + +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt b/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt new file mode 100644 index 000000000..99d90c4d7 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonGPU) +add_public_tablegen_target(TritonConversionPassIncGen) diff --git a/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/Passes.h b/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/Passes.h new file mode 100644 index 000000000..e159406b3 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/Passes.h @@ -0,0 +1,15 @@ +#ifndef TRITON_CONVERSION_PASSES_H +#define TRITON_CONVERSION_PASSES_H + +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/Passes.td b/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/Passes.td new file mode 100644 index 000000000..84150fe67 --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/Passes.td @@ -0,0 +1,37 @@ +#ifndef TRITON_CONVERSION_PASSES +#define TRITON_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> { + let summary = "Convert Triton to TritonGPU"; + let description = [{ + + }]; + let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + // TODO: Does this pass depend on SCF? + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect"]; + + let options = [ + Option<"numWarps", "num-warps", + "int32_t", /*default*/"4", + "number of warps">, + + Option<"threadsPerWarp", "threads-per-warp", + "int32_t", /*default*/"32", + "number of threads per warp">, + Option<"numCTAs", "num-ctas", + "int32_t", /*default*/"1", + "number of ctas in a cga">, + Option<"target", "target", + "std::string", /*default*/"\"\"", + "the GPU target, e.g., cuda:80, hip:gfx942"> + ]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h b/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h new file mode 100644 index 000000000..d3da1394e --- /dev/null +++ b/third_party/mthreads/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h @@ -0,0 +1,31 @@ +#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H +#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H + +#include +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps"; +constexpr static char AttrNumCTAsName[] = "triton_gpu.num-ctas"; +constexpr static char AttrTargetName[] = "triton_gpu.target"; + +constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp"; + +// Create the pass with numWarps passed from cl::opt. +std::unique_ptr> createConvertTritonToTritonGPUPass(); + +// Create the pass with numWarps set explicitly. +std::unique_ptr> +createConvertTritonToTritonGPUPass(const std::string &target, int numWarps, + int threadsPerWarp = 32, int numCTAs = 1); + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/CMakeLists.txt new file mode 100644 index 000000000..27cb65ce5 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Triton) +add_subdirectory(TritonGPU) diff --git a/third_party/mthreads/include/triton/Dialect/Triton/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/Triton/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..f682f54a1 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,27 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS TritonDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS TritonTypes.td) +mlir_tablegen(Types.h.inc -gen-typedef-decls) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs) + +set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td) +mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) + +set(LLVM_TARGET_DEFINITIONS TritonTypeInterfaces.td) +mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs) + +add_public_tablegen_target(TritonTableGen) diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/Dialect.h b/third_party/mthreads/include/triton/Dialect/Triton/IR/Dialect.h new file mode 100644 index 000000000..b1f1597c5 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/Dialect.h @@ -0,0 +1,83 @@ +#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITON_IR_DIALECT_H_ + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h.inc" +#include "triton/Dialect/Triton/IR/OpsEnums.h.inc" +#include "triton/Dialect/Triton/IR/Traits.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/Triton/IR/Ops.h.inc" + +namespace mlir { +namespace triton { + +struct GlobalMemory : public SideEffects::Resource::Base { + StringRef getName() final { return ""; } +}; + +class DialectInferLayoutInterface + : public DialectInterface::Base { +public: + DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {} + + virtual LogicalResult + inferTransOpEncoding(Attribute operandEncoding, ArrayRef order, + Attribute &resultEncoding) const = 0; + + virtual LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding) const = 0; + + virtual LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional location) const = 0; + + // Note: This function only verifies the operand encoding. It doesn't infer + // the result encoding. + virtual LogicalResult + inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute retEncoding, + std::optional location) const = 0; + + // Tries to compute the encoding for the result of a reshape operation that + // makes the reshape a "nop", i.e. the same GPU threads contain the same + // elements as before the reshape. Note that this is not always possible (in + // which case you'd need to choose a different layout for the input to the + // reshape). + virtual LogicalResult + inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const = 0; + + virtual LogicalResult + inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const = 0; + + virtual LogicalResult + inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const = 0; + + // Verify that the encoding are compatible to be used together in a dot + // operation + virtual LogicalResult + verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, + Attribute operandEncodingB) const = 0; +}; + +} // namespace triton +} // namespace mlir + +#endif // TRITON_IR_DIALECT_H_ diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/Interfaces.h b/third_party/mthreads/include/triton/Dialect/Triton/IR/Interfaces.h new file mode 100644 index 000000000..f8f3a6f74 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/Interfaces.h @@ -0,0 +1,9 @@ +#ifndef TRITON_IR_INTERFACES_H_ +#define TRITON_IR_INTERFACES_H_ + +#include "mlir/IR/OpDefinition.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/Traits.h b/third_party/mthreads/include/triton/Dialect/Triton/IR/Traits.h new file mode 100644 index 000000000..f34a0fd59 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/Traits.h @@ -0,0 +1,120 @@ +#ifndef TRITON_IR_TRAITS_H_ +#define TRITON_IR_TRAITS_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LogicalResult.h" + +#include + +namespace mlir { +namespace OpTrait { + +// These functions are out-of-line implementations of the methods in the +// corresponding trait classes. This avoids them being template +// instantiated/duplicated. +namespace impl { +// The rationale for this trait is to prevent users from creating programs +// that would have catastrophic register pressure and cause the compiler to +// hang. +// Since H100 has 256KB registers, we should allow users to create tensors +// of size up to 256K elements. It will spill for datatypes wider than 1B, +// but we probably should limit number of elements (rather than bytes) to +// keep specs simple +int constexpr maxTensorNumElements = 1048576; + +LogicalResult verifyTensorSize(Operation *op); +LogicalResult verifyTensorLayouts(Operation *op); + +LogicalResult verifySameOperandsEncoding(Operation *op, + bool allowTensorPointerType = false); + +LogicalResult +verifySameOperandsAndResultEncoding(Operation *op, + bool allowTensorPointerType = false); + +LogicalResult verifySameLoadStoreOperandsShape(Operation *op); + +LogicalResult verifySameLoadStoreOperandsAndResultShape(Operation *op); + +} // namespace impl + +template +class TensorSizeTrait : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyTensorSize(op); + } +}; + +// Trait applied to all Triton MLIR ops. Checks that the layouts of tensors are +// valid. +template +class VerifyTensorLayoutsTrait + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyTensorLayouts(op); + } +}; + +template +class SameOperandsAndResultEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultEncoding(op); + } +}; + +template +class SameOperandsEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsEncoding(op); + } +}; + +template +class SameLoadStoreOperandsShape + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameLoadStoreOperandsShape(op); + } +}; + +template +class SameLoadStoreOperandsAndResultShape + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameLoadStoreOperandsAndResultShape(op); + } +}; + +template +class SameLoadStoreOperandsEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsEncoding(op, + /*allowTensorPointerType=*/true); + } +}; + +template +class SameLoadStoreOperandsAndResultEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultEncoding( + op, /*allowTensorPointerType=*/true); + } +}; + +} // namespace OpTrait +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonAttrDefs.td new file mode 100644 index 000000000..adfeaff6f --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -0,0 +1,121 @@ +#ifndef TRITON_ATTR_DEFS +#define TRITON_ATTR_DEFS + +include "mlir/IR/EnumAttr.td" + +// Attributes for LoadOp and StoreOp +def TT_CacheModifierAttr : I32EnumAttr< + "CacheModifier", "", + [ + I32EnumAttrCase<"NONE", 1, "none">, + I32EnumAttrCase<"CA", 2, "ca">, + I32EnumAttrCase<"CG", 3, "cg">, + I32EnumAttrCase<"WB", 4, "wb">, + I32EnumAttrCase<"CS", 5, "cs">, + I32EnumAttrCase<"WT", 6, "wt">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_MemSemanticAttr : I32EnumAttr< + "MemSemantic", "", + [ + I32EnumAttrCase<"RELAXED", 1, "relaxed">, + I32EnumAttrCase<"ACQUIRE", 2, "acquire">, + I32EnumAttrCase<"RELEASE", 3, "release">, + I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_EvictionPolicyAttr : I32EnumAttr< + "EvictionPolicy", "", + [ + I32EnumAttrCase<"NORMAL", 1, "evict_normal">, + I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">, + I32EnumAttrCase<"EVICT_LAST", 3, "evict_last"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_PaddingOptionAttr : I32EnumAttr< + "PaddingOption", "", + [ + I32EnumAttrCase<"PAD_ZERO", 1, "zero">, + // We can not set the string value to "NAN" because it is a keyword in C++ + I32EnumAttrCase<"PAD_NAN", 2, "nan"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +// atomic +def TT_AtomicRMWAttr : I32EnumAttr< + "RMWOp", "", + [ + I32EnumAttrCase<"AND", 1, "and">, + I32EnumAttrCase<"OR", 2, "or">, + I32EnumAttrCase<"XOR", 3, "xor">, + I32EnumAttrCase<"ADD", 4, "add">, + I32EnumAttrCase<"FADD", 5, "fadd">, + I32EnumAttrCase<"MAX", 6, "max">, + I32EnumAttrCase<"MIN", 7, "min">, + I32EnumAttrCase<"UMAX", 8, "umax">, + I32EnumAttrCase<"UMIN", 9, "umin">, + I32EnumAttrCase<"XCHG", 10, "exch"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_MemSyncScopeAttr : I32EnumAttr< + "MemSyncScope", "", + [ + I32EnumAttrCase<"GPU", 1, "gpu">, + I32EnumAttrCase<"CTA", 2, "cta">, + I32EnumAttrCase<"SYSTEM", 3, "sys">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// Program ID dimensions. +def TT_ProgramDim : I32EnumAttr< + "ProgramIDDim", "", + [ + I32EnumAttrCase<"X", 0, "x">, + I32EnumAttrCase<"Y", 1, "y">, + I32EnumAttrCase<"Z", 2, "z">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// Rounding mode. +def TT_RoundingModeAttr : I32EnumAttr< + "RoundingMode", "", + [ + I32EnumAttrCase<"RTZ", 0, "rtz">, + I32EnumAttrCase<"RTNE", 1, "rtne">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// PropagateNan. +def TT_PropagateNanAttr : I32EnumAttr< + "PropagateNan", "", + [ + I32EnumAttrCase<"NONE", 0, "none">, + I32EnumAttrCase<"ALL", 0xFFFF, "all">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// InputPrecision +def TT_InputPrecisionAttr : I32EnumAttr< + "InputPrecision", "", + [ + I32EnumAttrCase<"TF32", 0, "tf32">, + I32EnumAttrCase<"TF32x3", 1, "tf32x3">, + I32EnumAttrCase<"IEEE", 2, "ieee"> + ]>{ + let cppNamespace = "::mlir::triton"; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonDialect.td b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonDialect.td new file mode 100644 index 000000000..c917538c7 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonDialect.td @@ -0,0 +1,46 @@ +#ifndef TRITON_DIALECT +#define TRITON_DIALECT + +include "mlir/IR/OpBase.td" + +def Triton_Dialect : Dialect { + let name = "tt"; + + let cppNamespace = "::mlir::triton"; + + let summary = "The Triton IR in MLIR"; + + let description = [{ + Triton Dialect. + + Dependent Dialects: + * Arith: + * addf, addi, andi, cmpf, cmpi, divf, fptosi, ... + * Math: + * exp, sin, cos, log, ... + * StructuredControlFlow: + * for, if, while, yield, condition + * ControlFlow: + * br, cond_br + }]; + + let dependentDialects = [ + "arith::ArithDialect", + "math::MathDialect", + "scf::SCFDialect", + "cf::ControlFlowDialect" + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + }]; + + let hasConstantMaterializer = 1; + let useDefaultTypePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +include "triton/Dialect/Triton/IR/TritonTypes.td" + + +#endif // TRITON_DIALECT diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonInterfaces.td b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonInterfaces.td new file mode 100644 index 000000000..cfc7d0032 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonInterfaces.td @@ -0,0 +1,15 @@ +#ifndef TRITON_INTERFACES +#define TRITON_INTERFACES + +include "mlir/IR/OpBase.td" + +def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">; +def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">; +def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">; +def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">; +def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">; +def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAndResultShape">; +def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">; +def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">; + +#endif // TRITON_INTERFACES diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonOps.td new file mode 100644 index 000000000..a8ab6caa2 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonOps.td @@ -0,0 +1,1149 @@ +#ifndef TRITON_OPS +#define TRITON_OPS + +include "triton/Dialect/Triton/IR/TritonDialect.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface +include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface +include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface +include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface +include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" + + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + +// +// Op Base +// +class TT_Op traits = []> : + Op { +} + +// +// Cast Ops +// +// Use cast ops in arith: +// bitcast +// fptoui, fptosi, uitofp, sitofp, +// extf, tructf, +// extui, extsi, tructi +def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast int64 to pointer"; + + let arguments = (ins TT_I64Like:$src); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast pointer to int64"; + + let arguments = (ins TT_PtrLike:$src); + + let results = (outs TT_I64Like:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +// arith.bitcast doesn't support pointers +def TT_BitcastOp : TT_Op<"bitcast", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast between types of the same bitwidth"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + // TODO: Add verifier +} + +def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Floating point casting for custom types"; + + let description = [{ + Floating point casting for custom types (F8), and non-default rounding modes. + + F8 <-> FP16, BF16, FP32, FP64 + }]; + + let arguments = ( + ins TT_FloatTensor:$src, + OptionalAttr:$rounding + ); + + let results = (outs TT_FloatTensor:$result); + + let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)"; + + let hasVerifier = 1; +} + +// +// Arithmetic Ops +// + +def TT_ClampFOp : TT_Op<"clampf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Clamp operation for floating point types"; + + let description = [{ + Clamp operation for floating point types. + + The operation takes three arguments: x, min, and max. It returns a tensor of the same shape as x with its values clamped to the range [min, max]. + }]; + + let arguments = ( + ins + TT_FloatLike:$x, + TT_FloatLike:$min, + TT_FloatLike:$max, + TT_PropagateNanAttr:$propagateNan + ); + + let results = (outs TT_FloatLike:$result); + + // List $propagateNan explicitly rather than relying on attr-dict to pick it + // up, because if it's inside attr-dict, its value will be printed as a + // number rather than as a meaningful string. + let assemblyFormat = "$x `,` $min `,` $max `,` `propagateNan` `=` $propagateNan attr-dict `:` type($result)"; +} + +// +// Math Ops +// + +def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Precise sqrt for floating point types"; + + let description = [{ + Precise sqrt for floating point types. + }]; + + let arguments = (ins TT_FloatLike:$x); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x attr-dict `:` type($x)"; +} + +def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Precise div for floating point types"; + + let description = [{ + Precise div for floating point types. + }]; + + let arguments = (ins TT_FloatLike:$x, TT_FloatLike:$y); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; +} + +def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Most significant N bits of the 2N-bit product of two integers"; + + let description = [{ + Most significant N bits of the 2N-bit product of two integers. + }]; + + let arguments = (ins TT_IntLike:$x, TT_IntLike:$y); + + let results = (outs TT_IntLike:$result); + + let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; +} + +// +// Pointer Arith Ops +// +def TT_AddPtrOp : TT_Op<"addptr", + [Pure, + Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { + let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)"; +} + +def TT_AdvanceOp : TT_Op<"advance", + [Pure, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { + let summary = "Advance a tensor pointer by offsets"; + + let arguments = (ins TT_TensorPtr:$ptr, Variadic:$offsets); + + let results = (outs TT_TensorPtr:$result); + + let assemblyFormat = "$ptr `,` `[` $offsets `]` attr-dict `:` type($result)"; +} + +// +// Load/Store Ops +// +def TT_LoadOp : TT_Op<"load", [ + SameLoadStoreOperandsAndResultShape, + SameLoadStoreOperandsAndResultEncoding, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result matches ptr type", "ptr", "result", "getPointeeType($_self)">, + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 1) || std::equal_to<>()">, + TypesMatchWith<"other matches ptr type", "ptr", "other", "getPointeeType($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "Load from a tensor of pointers or from a tensor pointer"; + + let arguments = ( + ins + AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, + Optional:$mask, + Optional:$other, + + DefaultValuedAttr{}">:$boundaryCheck, + OptionalAttr:$padding, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile + ); + + let results = (outs TT_Type:$result); + + let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor pointer with boundary check and padding + OpBuilder<(ins "Value":$ptr, "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor of pointers or a pointer to a scalar with mask and other + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A utility function to build the operation with all attributes + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, + "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)> + ]; + + // Specify `cacheModifier` and `evictionPolicy` explicitly in the + // assemblyFormat instead of as part of attr-dict so that they get printed + // as strings rather than opaque integers. + // + // Note there's no comma between `other` and `cacheModifier` and between + // `cacheModifier` and `evictionPolicy`. This is due to an apparent + // limitation in the MLIR custom-format parser. In oilist, the initial + // keywords of each clause have to be unique, so they can't be `,`. + // + // Even if we gave up on order-independence and used vanilla optional + // clauses, the format (`,` `foo` `=` $foo^)? (`,` `bar` `=` $bar^)? will + // not match the string ", bar = 0" because after the initial comma (first + // token of the first optional clause) we expect to see "foo". + let assemblyFormat = [{ + $ptr (`,` $mask^)? (`,` $other^)? + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` type($ptr) + }]; + + let hasCanonicalizer = 1; +} + +def TT_StoreOp : TT_Op<"store", [ + SameLoadStoreOperandsShape, + SameLoadStoreOperandsEncoding, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"value type matches ptr type", "ptr", "value", + "getPointeeType($_self)">, + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", + "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "Store by a tensor of pointers or by a tensor pointer"; + + let arguments = ( + ins + AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, + TT_Type:$value, + Optional:$mask, + DefaultValuedAttr{}">:$boundaryCheck, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$value, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)>, + // A tensor pointer with boundary check + OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef":$boundaryCheck, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)> + ]; + + // Specify cacheModifier and evictionPolicy explicitly, instead of leaving + // them in attr-dict, because this way their values get printed as strings, + // rather than as opaque integers. + // + // Note there are no commas between mask, cacheModifier, and evictionPolicy, + // due to limitations in MLIR's asm parser. + let assemblyFormat = [{ + $ptr `,` $value (`,` $mask^)? + oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + attr-dict `:` type($ptr) + }]; + + let hasCanonicalizer = 1; +} + +// +// Atomic Ops +// +def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [ + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"ptr type matches value type", "val", "ptr", + "getPointerTypeSameShape($_self)">, + TypesMatchWith<"mask type matches value type", + "val", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "atomic rmw"; + + let description = [{ + load data at $ptr, do $rmw_op with $val, and store result to $ptr. + + return old value at $ptr + }]; + + let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr, + TT_Type:$val, Optional:$mask, + TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); + + let results = (outs TT_Type:$result); + + // Explicitly list $atomic_rmw_op, $sem, and $scope rather than relying on + // attr-dict so they're printed as strings rather than opaque integers. + let assemblyFormat = [{ + $atomic_rmw_op `,` $sem `,` $scope `,` $ptr `,` $val (`,` $mask^)? attr-dict `:` + functional-type(operands, $result) + }]; +} + +def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding]> { + let summary = "atomic cas"; + + let description = [{ + compare $cmp with data $old at location $ptr, + + if $old == $cmp, store $val to $ptr, + + else store $old to $ptr, + + return $old + }]; + + let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val, + TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); + + let results = (outs TT_Type:$result); + + // Explicitly list $sem and $scope rather than relying on attr-dict so + // they're printed as strings rather than opaque integers. + let assemblyFormat = [{ + $sem `,` $scope `,` $ptr `,` $cmp `,` $val attr-dict `:` + functional-type(operands, $result) + }]; +} + +// +// Shape Manipulation Ops +// +def TT_SplatOp : TT_Op<"splat", [Pure, + SameOperandsAndResultElementType, + SameOperandsAndResultEncoding]> { + let summary = "splat"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasFolder = 1; +} + +def TT_ExpandDimsOp : TT_Op<"expand_dims", [Pure, + DeclareOpInterfaceMethods, + SameOperandsAndResultElementType]> { + let summary = "expand_dims"; + + let arguments = (ins TT_Tensor:$src, I32Attr:$axis); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizeMethod = 1; + let hasFolder = 1; +} + +def TT_ReshapeOp : TT_Op<"reshape", [Pure, + SameOperandsAndResultElementType]> { + let summary = "reinterpret a tensor to a different shape. It may change elements order if the attribute is set."; + let description = [{ + reinterpret a tensor to a different shape. + + If allow_reorder is set the compiler is free to change the order of + elements to generate more efficient code. + + If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason. + The compiler is still free to change it for better performance. + }]; + let arguments = (ins TT_Tensor:$src, BoolAttr:$allow_reorder, OptionalAttr:$efficient_layout); + let results = (outs TT_Tensor:$result); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + let hasCanonicalizeMethod = 1; + let hasFolder = 1; + let hasVerifier = 1; + let builders = [ + OpBuilder<(ins "Type":$type, "Value":$src, "bool":$allow_reorder), + [{ + build($_builder, $_state, type, src, allow_reorder, /*efficient_layout=*/UnitAttr()); + }]>]; +} + +def TT_BroadcastOp : TT_Op<"broadcast", [Pure, + SameOperandsAndResultElementType, + SameOperandsAndResultEncoding]> { + let summary = "broadcast a tensor"; + + let description = [{ + For a given tensor, broadcast changes one or more dimensions with size 1 + to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot + change the size of a non-1 dimension. + }]; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizeMethod = 1; + let hasFolder = 1; +} + +// cat is not `pure` because it may reorder elements +def TT_CatOp : TT_Op<"cat", [NoMemoryEffect, + SameTypeOperands, + SameOperandsAndResultElementType]> { + let summary = "concatenate 2 tensors"; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; +} + +def TT_JoinOp : TT_Op<"join", [ + NoMemoryEffect, SameTypeOperands, + DeclareOpInterfaceMethods, +]> { + let summary = "join two tensors along a new, minor dimension"; + let description = [{ + For example, if the two input tensors are 4x8xf32, returns a tensor of + shape 4x8x2xf32. + + Because Triton tensors always have a power-of-two number of elements, + the two input tensors must have the same shape. + }]; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + let results = (outs TT_Tensor:$result); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; +} + +def TT_SplitOp : TT_Op<"split", [ + NoMemoryEffect, + DeclareOpInterfaceMethods, + TypesMatchWith<"outLHS and outRHS types match", + "outLHS", "outRHS", "$_self">, +]> { + let summary = "splits a tensor into two, along its last dimension"; + let description = [{ + The input must be a tensor whose last dimension has size 2. Returns two + tensors, src[..., 0] and src[..., 1]. + + For example, if the input shape is 4x8x2xf32, returns two tensors of + shape 4x8xf32. + }]; + + let arguments = (ins TT_Tensor:$src); + let results = (outs TT_Tensor:$outLHS, TT_Tensor:$outRHS); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($outLHS)"; +} + +def TT_TransOp : TT_Op<"trans", [Pure, + DeclareOpInterfaceMethods, + SameOperandsAndResultElementType]> { + + let summary = "rearrange the dimensions of a tensor"; + let description = [{ + For example, given a tensor x with shape [1,2,4], transpose(x) with + order=[2,0,1] rearranges the tensor to have shape [4,1,2]. + + Although this op is called "trans", it implements both tl.trans() and + tl.permute(). ("permute" might be a better name, but it's called "trans" + because originally it only supported 2D tensors.) + + ## Implementation note on encodings: + + In the TritonGPU dialect (and probably others), an encoding is chosen for + this op's output so it's a nop from the perspective of code generation. + + For example, suppose tensor x has an encoding such that GPU thread [i,j,k] + has a register containing element [i,j,k] of the tensor. Now we transpose + x with order [2,1,0], i.e. we reverse the order of its dimensions. In + TritonGPU, we will choose a layout for the output of the transpose so that + GPU thread [i,j,k] has element [k,j,i] of transpose(x). But this is the + same element it had before! All we've done is "rename" the element that + thread [i,j,k] has. + + The "real" transpose -- i.e. moving data between GPU threads -- occurs in + convertLayout ops that appear before and/or after the operation. + + We do this so that you can chain multiple data-movement ops (e.g. + transpose+reshape+concat) without going to shared memory after each one. + }]; + + let arguments = ( + ins TT_TensorOrMemDesc:$src, + DenseI32ArrayAttr:$order + ); + + let results = (outs TT_TensorOrMemDesc:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasFolder = 1; + let hasVerifier = 1; +} + +// +// SPMD Ops +// +def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> { + let arguments = (ins TT_ProgramDim:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "$axis attr-dict `:` type($result)"; + + let extraClassDeclaration = [{ + int32_t getAxisAsInt() { + return static_cast(getAxis()); + } + }]; +} + +def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> { + let arguments = (ins TT_ProgramDim:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "$axis attr-dict `:` type($result)"; + + let extraClassDeclaration = [{ + int32_t getAxisAsInt() { + return static_cast(getAxis()); + } + }]; +} + +// +// Dot Op +// +def TT_DotOp : TT_Op<"dot", [Pure, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot"; + + let description = [{ + $d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TC + when the inputs are f32. It can be one of: tf32, tf32x3, ieee. + tf32: use TC with tf32 ops. + tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp + ieee: don't use TC, implement dot in software. + If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored. + }]; + + let arguments = ( + ins + TT_TensorOrMemDesc:$a, + TT_TensorOrMemDesc:$b, + TT_FpIntTensor:$c, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc + ); + + let results = (outs TT_FpIntTensor:$d); + + // attr-dict prints enums as integers. To get inputPrecision printed as a + // string, we need to specify it explicitly. + let assemblyFormat = [{ + $a`,` $b`,` $c (`,` `inputPrecision` `=` $inputPrecision^)? attr-dict `:` + type($a) `*` type($b) `->` type($d) + }]; + let hasVerifier = 1; +} + +// +// Reduce Op +// +def TT_ReduceOp: TT_Op<"reduce", + [Pure, + SameOperandsEncoding, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Reduction using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$srcs, "int":$axis)>, + ]; + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + }]; +} + +def TT_ReduceReturnOp: TT_Op<"reduce.return", + [HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for reduce operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + +// +// Scan Op +// +def TT_ScanOp: TT_Op<"scan", + [Pure, + SameOperandsAndResultEncoding, + SameOperandsAndResultShape, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Associative scan using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis, BoolAttr:$reverse); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$srcs, "int":$axis, "bool":$reverse)>, + ]; + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + }]; +} + +def TT_ScanReturnOp: TT_Op<"scan.return", + [HasParent<"ScanOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for scan operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + + +// +// External Elementwise op +// +def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise, + SameOperandsAndResultEncoding, + SameVariadicOperandSize, + DeclareOpInterfaceMethods]> { + + let description = [{ + call an external function $symbol implemented in $libpath/$libname with $args + return $libpath/$libname:$symbol($args...) + }]; + + let arguments = (ins Variadic:$srcs, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; +} + +// +// Make Range Op +// +def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> { + let summary = "make range"; + + let description = [{ + Returns an 1D int32 tensor. + + Values span from $start to $end (exclusive), with step = 1 + }]; + + // WARNING: MLIR generates getStart()/getEnd() functions which return + // uint32_t, even though these arguments are to be interpreted as *signed* + // int32 values. If this matters, use get{Start,End}Attr().getInt(), which + // return int64_t. + let arguments = (ins I32Attr:$start, I32Attr:$end); + + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = "attr-dict `:` type($result)"; + + let hasFolder = 1; + let hasVerifier = 1; +} + +// +// ElementwiseInlineAsm Op +// +def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [ + Elementwise, + SameOperandsAndResultEncoding, + DeclareOpInterfaceMethods +]> { + let summary = "inline assembly applying an elementwise operation to a group of packed elements."; + let description = [{ + Runs an inline asm block to generate one or more tensors. + + The asm block is given `packed_element` elements at a time. Exactly which + elems it receives is unspecified. + }]; + + let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic>:$args); + let results = (outs Variadic:$result); + + let assemblyFormat = [{ + $asm_string attr-dict ($args^ `:` type($args))? `->` type($result) + }]; + + let hasVerifier = 1; +} + +// +// Histogram Op +// +def TT_HistogramOp : TT_Op<"histogram", [Pure]> { + let summary = "return a histgram of the inputs."; + let description = [{ + Return the histogram of the input tensor. The number of bins is equal to + the dimension of the output tensor. Each bins has a width of 1 and bins + start at 0. + }]; + + let arguments = (ins TT_IntTensor:$src); + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = [{ + $src attr-dict `:` type($src) `->` type($result) + }]; +} + +// +// Print Op +// +def TT_PrintOp : TT_Op<"print", [MemoryEffects<[MemWrite]>]>, + Arguments<(ins StrAttr:$prefix, BoolAttr:$hex, Variadic>:$args)> { + let summary = "Device-side print, as in CUDA for debugging"; + let description = [{ + `tt.print` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed. + format are generated automatically from the arguments. + }]; + let assemblyFormat = [{ + $prefix attr-dict (`:` $args^ `:` type($args))? + }]; +} + +// +// Assert Op +// +def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> { + let summary = "Device-side assert, as in CUDA for correctness checking"; + let description = [{ + `tt.assert` takes a condition tensor, a message string, a file string, a function string, and a line number. + If the condition is false, the message is printed, and the program is aborted. + }]; + let arguments = (ins TT_Tensor:$condition, StrAttr:$message, StrAttr:$file, StrAttr:$func, I32Attr:$line); + let assemblyFormat = "$condition `,` $message `,` $file `,` $func `,` $line attr-dict `:` type($condition)"; +} + +// +// Make Tensor Pointer Op +// +def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr", + [Pure, + SameVariadicOperandSize, + TypesMatchWith<"infer pointer type from the result type", + "result", "base", + "getPointerType(getElementTypeOfTensorPointerType($_self))">]> { + let summary = "Make a tensor pointer type with meta information of the parent tensor and the block specified"; + + let description = [{ + `tt.make_tensor_ptr` takes both meta information of the parent tensor and the block tensor, then it returns a + pointer to the block tensor, e.g. returns a type of `tt.ptr>`. + }]; + + // TODO(Chenggang): unify the integer types. Currently we cannot do that due to hardware constraints. + let arguments = (ins + TT_Ptr:$base, + Variadic:$shape, + Variadic:$strides, + Variadic:$offsets, + DenseI32ArrayAttr:$order + ); + + let results = (outs TT_TensorPtr:$result); + + // TODO(Keren): define a custom assembly format for this op because the result type cannot be printed correctly + // Add additional `[]` to increase readability and split variadic lists + let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)"; + + let builders = [ + OpBuilder<(ins + "Value":$base, + "ValueRange":$shape, + "ValueRange":$strides, + "ValueRange":$offsets, + "ArrayRef":$tensorShape, + "ArrayRef":$order + )> + ]; +} + +// The following ops, including `call`, `func`, and `return` are copied and modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +// We could revert it back once MLIR has a better inliner interface. +// +// Function Ops +// +def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpInterfaceMethods]> { + let summary = "call operation"; + let description = [{ + The `tt.call` operation represents a direct call to a function that is + within the same symbol scope as the call. The operands and result types of + the call must match the specified function type. The callee is encoded as a + symbol reference attribute named "callee". + + Example: + + ```mlir + %2 = tt.call @my_add(%0, %1) : (f32, f32) -> f32 + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", SymbolRefAttr::get(callee)); + $_state.addTypes(callee.getFunctionType().getResults()); + }]>, + OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", callee); + $_state.addTypes(results); + }]>, + OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, SymbolRefAttr::get(callee), results, operands); + }]>, + OpBuilder<(ins "StringRef":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), + results, operands); + }]>]; + + let extraClassDeclaration = [{ + FunctionType getCalleeType() { + return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); + } + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + /// Set the callee for this operation. + void setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); + } + + // Required by CallOpInterface. + MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + +def FuncOp : TT_Op<"func", [AffineScope, AutomaticAllocationScope, CallableOpInterface, FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface]> { + let summary = "An operation with a name containing a single `SSACFG` region"; + let description = [{ + Operations within the function cannot implicitly capture values defined + outside of the function, i.e. Functions are `IsolatedFromAbove`. All + external references must use function arguments or attributes that establish + a symbolic connection (e.g. symbols referenced by name via a string + attribute like SymbolRefAttr). An external function declaration (used when + referring to a function declared in some other module) has no body. While + the MLIR textual form provides a nice inline syntax for function arguments, + they are internally represented as “block arguments” to the first block in + the region. + + Only dialect attribute names may be specified in the attribute dictionaries + for function arguments, results, or the function itself. + + Example: + + ```mlir + // External function definitions. + tt.func @abort() + tt.func @scribble(i32, i64, memref) -> f64 + + // A function that returns its argument twice: + tt.func @count(%x: i64) -> (i64, i64) + attributes {fruit: "banana"} { + return %x, %x: i64, i64 + } + + // A function with an argument attribute + tt.func @example_fn_arg(%x: i32 {swift.self = unit}) + + // A function with a result attribute + tt.func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64}) + + // A function with an attribute + tt.func @example_fn_attr() attributes {dialectName.attrName = false} + ``` + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs, + CArg<"ArrayRef", "{}">:$argAttrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // CallableOpInterface + //===------------------------------------------------------------------===// + + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } + + /// Returns the results types that the callable region produces when + /// executed. + ArrayRef getCallableResults() { return getFunctionType().getResults(); } + + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or + /// null if there are none. + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } + + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + bool isDeclaration() { return isExternal(); } + }]; + let hasCustomAssemblyFormat = 1; +} + +def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable, */ReturnLike, Terminator]> { + let summary = "Function return operation"; + let description = [{ + The `tt.return` operation represents a return operation within a function. + The operation takes variable number of operands and produces no results. + The operand number and types must match the signature of the function + that contains the operation. + + Example: + + ```mlir + tt.func @foo() : (i32, f8) { + ... + tt.return %0, %1 : i32, f8 + } + ``` + }]; + + let arguments = (ins Variadic:$srcs); + + let builders = [OpBuilder<(ins), [{ + build($_builder, $_state, std::nullopt); + }]>]; + + let assemblyFormat = "attr-dict ($srcs^ `:` type($srcs))?"; + let hasVerifier = 1; +} + + +def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [ + MemoryEffects<[MemRead]>]> { + let summary = "Load from descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA load operation on targets supporting it. + `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + The destination tensor type and shape must match the descriptor otherwise the result is undefined. + + This is an escape hatch and is only there for testing/experimenting. + This op will be removed in the future. + }]; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + Variadic:$indices, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $desc_ptr `[` $indices `]` + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` qualified(type($desc_ptr)) `->` type($result) + }]; +} + +def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [ + MemoryEffects<[MemWrite]>]> { + let summary = "store value based on descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA store operation on targets supporting it. + `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + The shape and types of `src` must match the descriptor otherwise the result is undefined. + + This is an escape hatch and is only there for testing/experimenting. + This op will be removed in the future. + }]; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + TT_Tensor:$src, + Variadic:$indices + ); + + let assemblyFormat = [{ + $desc_ptr `[` $indices `]` `,` $src + attr-dict `:` qualified(type($desc_ptr)) `,` type($src) + }]; +} + +#endif // Triton_OPS diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td new file mode 100644 index 000000000..e3aed2262 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td @@ -0,0 +1,24 @@ +#ifndef TRITON_TYPE_INTERFACES +#define TRITON_TYPE_INTERFACES + +include "mlir/IR/OpBase.td" + +// Interface dynamically attached to RankedTensorType and MemDescType. +def TT_TensorOrMemDesc : TypeInterface<"TensorOrMemDesc"> { + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod<"Returns the encoding of the tensor or memory descriptor", + "mlir::Attribute", "getEncoding", (ins)>, + InterfaceMethod<"Returns element type", + "mlir::Type", "getElementType", (ins)>, + InterfaceMethod<"Returns the type shape", + "llvm::ArrayRef", "getShape", (ins)>, + InterfaceMethod<"Returns the tensor or buffer rank", + "int64_t", "getRank", (ins)>, + InterfaceMethod<"Returns the element type bit width", + "int64_t", "getElementTypeBitWidth", (ins)>, + + ]; +} + +#endif // TRITON_TYPE_INTERFACES diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonTypes.td b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonTypes.td new file mode 100644 index 000000000..fd5af9cc8 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -0,0 +1,140 @@ +#ifndef TRITON_TYPES +#define TRITON_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "triton/Dialect/Triton/IR/TritonDialect.td" + +// +// Types +// +class TritonTypeDef traits = []> + : TypeDef { + // Used by printer/parser + let mnemonic = _mnemonic; +} + +// Floating-point Type +def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; +def TT_FloatTensor : RankedTensorOf<[TT_Float]>; +def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>; + +// Boolean Type +// TT_Bool -> I1 +def TT_BoolTensor : RankedTensorOf<[I1]>; +def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>; + +// Integer Type +def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">; +def TT_IntTensor : RankedTensorOf<[TT_Int]>; +def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>; + +// I32 Type +// TT_I32 -> I32 +// TT_I32Tensor -> I32Tensor +def TT_I32Like : AnyTypeOf<[I32, I32Tensor]>; + +// I64 Type +// TT_I64 -> I64 +// TT_I64Tensor -> I64Tensor +def TT_I64Like : AnyTypeOf<[I64, I64Tensor]>; + +// Pointer Type in TableGen +class TT_PtrOf pointeeTypes> : + DialectType($_self)">, + Concat<"[](::mlir::Type pointeeType) { return ", + SubstLeaves<"$_self", "pointeeType", AnyTypeOf.predicate>, + "; }(::mlir::cast<::mlir::triton::PointerType>($_self).getPointeeType())">]>, + "ptr", "::mlir::triton::PointerType">; + +// Pointer Type in C++ (corresponding to `TT_PtrOf`) +def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> { + let summary = "Pointer type (`::mlir::triton::PointerType`) in Triton IR type system"; + + let description = [{ + Pointer type in Triton IR type system, which could be pointing to scalars or tensors. + }]; + + let parameters = (ins "Type":$pointeeType, "int":$addressSpace); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "Type":$pointeeType, + "int":$addressSpace + ), [{ + return $_get(pointeeType.getContext(), pointeeType, addressSpace); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +// Scalar Pointer Type: `ptr<>` +def TT_Ptr : TT_PtrOf<[AnyType]>; + +// Tensor of Pointer Type: `tensor>` +def TT_PtrTensor : RankedTensorOf<[TT_Ptr]>; + +// Tensor of Pointer Type or Pointer type: `tensor>` or `ptr<>` +def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>; + +// Tensor Type +def TT_FpIntTensor : RankedTensorOf<[TT_Float, TT_Int]>; +def TT_Tensor : RankedTensorOf<[TT_Float, TT_Int, TT_Ptr]>; + +// Pointer Type to Tensor Type: `ptr>` +def TT_TensorPtr : TT_PtrOf<[TT_Tensor]>; + +// Any Type in Triton IR +def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike, TT_TensorPtr]>; + +// Memory descriptor type. +def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> { + let summary = "memory descriptor type (`::mlir::triton::MemDescType`) in Triton IR type system"; + + let description = [{ + Memory descriptor contains a base pointer (scalar) and a descriptor of the memory. + If mutable memory is false that means the memory is constant and can only be allocated and stored once. + A constant memory allocation is different than a tensor as it can have multiple views and the descriptor + can be changed without changing the underlying memory. + }]; + + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "Type":$elementType, + "Attribute":$encoding, + "bool":$mutable_memory + ); + let extraClassDeclaration = [{ + MemDescType cloneWith(std::optional> shape, + Type elementType) const { + return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding()); + } + + bool hasRank() const { return true; } + }]; + let builders = [ + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, /*mutableMemory=*/false); + }]>, + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding, + "bool":$mutableMemory + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, mutableMemory); + }]> + ]; + let hasCustomAssemblyFormat = 1; +} + + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/Types.h b/third_party/mthreads/include/triton/Dialect/Triton/IR/Types.h new file mode 100644 index 000000000..bf1967f1b --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/Types.h @@ -0,0 +1,39 @@ +#ifndef TRITON_IR_TYPES_H_ +#define TRITON_IR_TYPES_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/Types.h.inc" + +#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.h.inc" + +namespace mlir { + +namespace triton { + +bool isTensorPointerType(Type type); + +bool isTensorOrTensorPointerType(Type type); + +unsigned getPointeeBitWidth(Type type); + +Type getPointeeType(Type type); + +Type getPointerType(Type type); + +Type getElementTypeOfTensorPointerType(Type type); + +Type getI1SameShape(Type type); + +Type getI32SameShape(Type type); + +Type getPointerTypeSameShape(Type type); + +} // namespace triton + +} // namespace mlir + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/Utility.h b/third_party/mthreads/include/triton/Dialect/Triton/IR/Utility.h new file mode 100644 index 000000000..0ef597147 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/Utility.h @@ -0,0 +1,190 @@ +#ifndef TRITON_IR_UTILITY_H_ +#define TRITON_IR_UTILITY_H_ + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include +#include + +namespace mlir { + +template SmallVector convertType(ArrayRef in) { + SmallVector out; + for (const auto &i : in) + out.push_back(T(i)); + return out; +} + +template +SmallVector convertType(const VecU &in) { + return convertType(ArrayRef(in)); +} + +template Int product(llvm::ArrayRef arr) { + return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{}); +} +template auto product(const VecT &vec) { + return product(llvm::ArrayRef(vec)); +} + +// TODO(jlebar): Rename to ceilOfRatio. +template Int ceil(Int m, Int n) { return (m + n - 1) / n; } + +/// Get the highest power of 2 divisor of an integer. +template T highestPowOf2Divisor(T n) { + if (n == 0) { + return (static_cast(1) << (sizeof(T) * 8 - 2)); + } + return (n & (~(n - 1))); +} + +/// Get the next power of 2 for an integer (or the integer itself if it is a +/// power of 2). +template T nextPowOf2(T n) { + if (n == 0) { + return 1; + } + n--; + for (unsigned i = 1; i < sizeof(T) * 8; i <<= 1) { + n |= n >> i; + } + return n + 1; +} + +namespace triton { + +// Many functions here have two overloads, fn(ArrayRef) and fn(const VecT&). +// This is helpful because C++ won't both convert a vector to ArrayRef *and* +// infer the proper type T in one step. So without the second overload, we +// would have to explicitly convert most arguments to ArrayRef at the callsite. + +template +SmallVector applyPermutation(ArrayRef vec, ArrayRef permutation) { + static_assert(std::is_integral_v); + assert(vec.size() == permutation.size()); + + // Check that `permutation` is actually a permutation. +#ifndef NDEBUG + SmallVector sortedPerm(permutation); + llvm::sort(sortedPerm); + for (U i = 0; i < static_cast(sortedPerm.size()); i++) { + assert(sortedPerm[i] == i); + } +#endif + + SmallVector ret; + ret.reserve(vec.size()); + for (const U &i : permutation) { + ret.push_back(vec[i]); + } + return ret; +} + +template +auto applyPermutation(const VecT &vec, const PermT &permutation) { + return applyPermutation(ArrayRef(vec), ArrayRef(permutation)); +} + +template +[[nodiscard]] SmallVector inversePermutation(ArrayRef permutation) { + // Check that `permutation` is actually a permutation. +#ifndef NDEBUG + SmallVector sortedPerm(permutation); + llvm::sort(sortedPerm); + for (int i = 0; i < sortedPerm.size(); ++i) { + assert(sortedPerm[i] == i); + } +#endif + + SmallVector ret(permutation.size()); + for (int i = 0; i < permutation.size(); ++i) { + ret[permutation[i]] = i; + } + return ret; +} + +template +[[nodiscard]] auto inversePermutation(const VecT &permutation) { + return inversePermutation(ArrayRef(permutation)); +} + +template +[[nodiscard]] SmallVector gather(ArrayRef elems, ArrayRef indices) { + SmallVector ret; + ret.reserve(indices.size()); + for (const U &i : indices) { + ret.push_back(elems[i]); + } + return ret; +} + +template +[[nodiscard]] auto gather(const VecT &elems, const IdxT &indices) { + return gather(ArrayRef(elems), ArrayRef(indices)); +} + +// Is `vec` [0, 1, ..., n]? Returns true on empty list. +template bool isIota(ArrayRef vec) { + static_assert(std::is_integral_v); + for (T i = 0; i < vec.size(); ++i) { + if (vec[i] != i) { + return false; + } + } + return true; +} + +template bool isIota(const VecT &vec) { + return isIota(ArrayRef(vec)); +} + +// Is `vals` some permutation of the numbers 0..(vals.size()-1)? +template bool isPermutationOfIota(ArrayRef vals) { + SmallVector sorted(vals); + llvm::sort(sorted); + return isIota(sorted); +} + +template bool IsPermutationOfIota(const VecT &vec) { + return isPermutationOfIota(ArrayRef(vec)); +} + +// Is `vec` [i, i+1, ..., i+n]? Returns true on empty list. +template bool isConsecutive(ArrayRef vec) { + static_assert(std::is_integral_v); + for (int i = 1; i < vec.size(); i++) { + if (vec[i] != vec[i - 1] + 1) { + return false; + } + } + return true; +} + +template bool isConsecutive(const VecT &vec) { + return isConsecutive(ArrayRef(vec)); +} + +// LLVM's STLExtras.h provides a bunch of functions that work over ranges, but +// it's missing min/max_element until +// https://github.com/llvm/llvm-project/commit/fab2bb8b makes it into Triton. +// TODO(jlebar): Remove this once we have the LLVM helpers. +template auto min_element(R &&Range) { + return std::min_element(llvm::adl_begin(Range), llvm::adl_end(Range)); +} +template +auto min_element(R &&Range, Compare &&C) { + return std::min_element(llvm::adl_begin(Range), llvm::adl_end(Range), + std::forward(C)); +} +template auto max_element(R &&Range) { + return std::max_element(llvm::adl_begin(Range), llvm::adl_end(Range)); +} +template +auto max_element(R &&Range, Compare &&C) { + return std::max_element(llvm::adl_begin(Range), llvm::adl_end(Range), + std::forward(C)); +} + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/Triton/Transforms/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 000000000..372a9ec11 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton) +add_public_tablegen_target(TritonTransformsIncGen) diff --git a/third_party/mthreads/include/triton/Dialect/Triton/Transforms/Passes.h b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/Passes.h new file mode 100644 index 000000000..fde54fe17 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/Passes.h @@ -0,0 +1,21 @@ +#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { + +std::unique_ptr createCombineOpsPass(); + +std::unique_ptr createReorderBroadcastPass(); +std::unique_ptr createRewriteTensorPointerPass(); + +} // namespace triton + +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +} // namespace mlir + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/Triton/Transforms/Passes.td b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/Passes.td new file mode 100644 index 000000000..4ebff63fa --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/Triton/Transforms/Passes.td @@ -0,0 +1,44 @@ +#ifndef TRITON_PASSES +#define TRITON_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonCombineOps : Pass { + let summary = "combine ops"; + let description = [{ + dot(a, b, 0) + c => dot(a, b, c) + + addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1)) + + select(cond, load(ptrs, broadcast(cond), ???), other) => + load(ptrs, broadcast(cond), other) + }]; + + let constructor = "mlir::triton::createCombineOpsPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect"]; +} + +def TritonReorderBroadcast : Pass { + let summary = "Moves broadcast and splat after elementwise operations"; + let description = [{ + elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) + }]; + let constructor = "mlir::triton::createReorderBroadcastPass()"; + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + +def TritonRewriteTensorPointer : Pass { + let summary = "Rewrite load/stores with tensor pointers into legacy load/stores"; + let description = [{ + This pass rewrites all load/store semantics initiated by a `tt.make_tensor_ptr` and `tt.advance` into legacy + semantics. After this pass, `tt.make_tensor_ptr` and `tt.advance` will disappear, and it generates logics to compute + the pointer/mask/other for each load/store. + }]; + + let constructor = "mlir::triton::createRewriteTensorPointerPass()"; + + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Attributes.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Attributes.h new file mode 100644 index 000000000..a99ddfc17 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Attributes.h @@ -0,0 +1,10 @@ +#ifndef TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ +#define TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ + +#include "mlir/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc" + +#endif // TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..73c9401c1 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt @@ -0,0 +1,21 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu) +add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(TritonGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td) +mlir_tablegen(TritonGPUAttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(TritonGPUAttrInterfaces.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonGPUAttrDefsIncGen) diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Dialect.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Dialect.h new file mode 100644 index 000000000..5ae7848a0 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -0,0 +1,127 @@ +#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" + +// TritonGPU depends on Triton +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" +#include "triton/Dialect/TritonGPU/IR/Types.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonGPU/IR/Ops.h.inc" + +namespace mlir { +namespace triton { +namespace gpu { + +struct SharedMemory : public SideEffects::Resource::Base { + StringRef getName() final { return ""; } +}; + +unsigned getTotalElemsPerThread(Type type); + +unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape, + Type eltTy); + +SmallVector getElemsPerThread(Type type); + +// Returns the number of threads per warp that may have access to replicated +// elements. If you want non-replicated threads, use +// getThreadsPerWarpWithUniqueData. +SmallVector getThreadsPerWarp(Attribute layout); + +unsigned getWarpSize(Attribute layout); + +// Returns the number of warps per CTA that may have access to replicated +// elements. If you want non-replicated warps, use getWarpsPerCTAWithUniqueData. +SmallVector getWarpsPerCTA(Attribute layout); + +SmallVector getSizePerThread(Attribute layout); + +// Returns the number of contiguous elements that each thread +// has access to, on each dimension of the tensor. E.g. +// for a blocked layout with sizePerThread = [1, 4], returns [1, 4], +// regardless of the shape of the tensor. +SmallVector getContigPerThread(Attribute layout); + +// Returns the number of non-replicated contiguous elements that each thread +// has access to, on each dimension of the tensor. For a blocked layout +// with sizePerThread = [1, 4] and tensor shape = [128, 1], the elements +// for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], returns [1, +// 1]. Whereas for a tensor shape [128, 128], the elements for thread 0 would be +// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], returns [1, 4]. +SmallVector getUniqueContigPerThread(Attribute layout, + ArrayRef tensorShape); + +// Returns the number of threads per warp that have access to non-replicated +// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1, +// 1], threadsPerWarp = [2, 16] and tensor shape = [2, 2], threads 0, 1, 16, 17 +// have access to the full tensor, whereas the other threads have access to +// replicated elements, so this function returns [2, 2]. +SmallVector +getThreadsPerWarpWithUniqueData(Attribute layout, + ArrayRef tensorShape); + +// Returns the number of warps per CTA that have access to non-replicated +// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1, +// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2], +// returns [1, 1], since the first warp has access to the full tensor, whereas +// the other warps have access to replicated elements. +SmallVector +getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape); + +SmallVector getWarpOrder(Attribute layout); + +SmallVector getOrder(Attribute layout); + +CTALayoutAttr getCTALayout(Attribute layout); + +SmallVector getCTAsPerCGA(Attribute layout); + +SmallVector getCTASplitNum(Attribute layout); + +SmallVector getCTAOrder(Attribute layout); + +/* The difference between ShapePerCTATile and ShapePerCTA: + * (1) ShapePerCTATile is defined by SizePerThread * ThreadsPerWarp * + * WarpsPerCTA in each dimension and is independent from the tensor shape. + * (2) ShapePerCTA is defined by shape / CTASplitNum in each dimension. + * (3) In the implementation of emitIndices, ShapePerCTATile will + * be replicated or wrapped to fit ShapePerCTA. + */ +SmallVector +getShapePerCTATile(Attribute layout, + ArrayRef tensorShape = ArrayRef()); + +SmallVector getShapePerCTA(ArrayRef CTASplitNum, + ArrayRef shape); +SmallVector getShapePerCTA(Attribute layout, ArrayRef shape); +SmallVector getShapePerCTA(Type type); + +unsigned getNumWarpsPerCTA(Attribute layout); + +unsigned getNumCTAs(Attribute layout); + +bool isaDistributedLayout(Attribute layout); + +bool isExpensiveCat(CatOp cat, Attribute targetEncoding); + +// Return true if a view between the two types cannot be implemented as a no-op. +bool isExpensiveView(Type srcType, Type dstType); + +// Return a blocked encoding where the shape is distributed contiguously amongst +// the threads, warps, CTAs with 1 element per threads. +triton::gpu::BlockedEncodingAttr +getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, + int numWarps, int threadsPerWarp, int numCTAs); + +} // namespace gpu +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h new file mode 100644 index 000000000..d4f274742 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -0,0 +1,37 @@ +// Conversions from TritonGPU layouts (e.g. BlockedEncodingAttr) to +// LinearLayout. + +#include + +#include "triton/Tools/LinearLayout.h" + +namespace mlir::triton::gpu { + +// - BlockedEncodingAttrs have the following input dimensions. +// +// "register": elements in one thread +// "lane": threads in a warp +// "warp": warps in a block/CTA +// "block": blocks in a cluster +// +// - An n-dimensional SharedEncodingAttr has the following input dimensions. +// +// "offset": the n'th element in the allocation, within a particular block +// "block": blocks in a cluster +// +// All layouts have the following output dimensions. +// +// "dimi" for i in 0..n-1: the location in the n'th logical dimension of the +// output tensor. These also are not reordered according to the layout's +// `order`. +// +// You can flatten the input or output dimensions into a single dimension using +// LinearLayout::flattenIns/Outs(). +// +// Returns std::nullopt if the given layout can't be converted to an LL. +// TODO(jlebar): Remove the std::optional once all layouts are supported. +// +std::optional toLinearLayout(ArrayRef shape, + Attribute layout); + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td new file mode 100644 index 000000000..ae23f9d13 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -0,0 +1,1301 @@ +#ifndef TRITONGPU_ATTRDEFS +#define TRITONGPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" + +//===----------------------------------------------------------------------===// +// TritonGPU Attribute Definitions +//===----------------------------------------------------------------------===// +def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let methods = [ + InterfaceMethod<"Return total element size per thread.", + "unsigned", + "getTotalElemsPerThread", + (ins "ArrayRef":$tensorShape, + "Type":$eltTy)>, + + InterfaceMethod<"Return element size per thread in each dimension.", + "SmallVector", + "getElemsPerThread", + (ins "ArrayRef":$tensorShape, + "Type":$eltTy)>, + ]; +} + +class TritonGPU_Attr traits = [], + Dialect dialect = TritonGPU_Dialect, + string baseCppClass = "::mlir::Attribute"> + : AttrDef { + + let description = [{ +TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines +how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function +\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding +to the indices of the CUDA threads allowed to access some data at index $i$. + +For example, let us consider the layout function: +\mathcal{L}(0, 0) = {0, 4} +\mathcal{L}(0, 1) = {1, 5} +\mathcal{L}(1, 0) = {2, 6} +\mathcal{L}(1, 1) = {3, 7} + +Then, attaching $\mathcal{L} to a tensor $T$ would mean that: +- T[0,0] is owned by both cuda thread 0 and 4 +- T[0,1] is owned by both cuda thread 1 and 5 +- T[1,0] is owned by both cuda thread 2 and 6 +- T[1,1] is owned by both cuda thread 3 and 7 + +Right now, Triton implements two main classes of layouts: shared, and distributed. + }]; + let attrName = "triton.gpu." # attrMnemonic; + + code extraBaseClassDeclaration = [{ + unsigned getTotalElemsPerThread(ArrayRef shape, Type eltTy) const; + SmallVector getElemsPerThread(ArrayRef shape, Type eltTy) const; + ::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const; + }]; +} + +//===----------------------------------------------------------------------===// +// CTA Layout +//===----------------------------------------------------------------------===// + +def CTALayoutAttr : TritonGPU_Attr<"CTALayout", "cta_layout"> { + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$CTAsPerCGA, + ArrayRefParameter<"unsigned">:$CTASplitNum, + ArrayRefParameter<"unsigned">:$CTAOrder + ); + + let description = [{ +Describes how blocks are distributed among the cooperate thread arrays (aka +CTAs, aka thread blocks) in a cooperate thread group (aka CTG, aka thread group +cluster). CGAs were introduced in Hopper (sm90). + +The tensor is divided up into CTASplitNum pieces, which are distributed among +the CTAsPerCGA thread blocks. Each CTA processes a subtensor of shape +`tensor_shape / CTASplitNum`. + +Example 0: The tensor shape is [64, 128] and, there are two CTAs, each +processing half the tensor [64, 64]. Then CTAsPerCGA = [1, 2] and +CTASplitNum = [1, 2]. + +Example 1: The tensor shape is [64, 128] and, there are two CTAs, both +processing the complete tensor [64, 128]. This happens when multicast is +enabled. In this case, CTAsPerCTA = [1, 2] but CTASplitNum = [1, 1]. + +Example 2: Consider a matmul AxB=C, where A=[M,K], B=[K,N], C=[M,N]. The +CTAsPerCGA for A, B, C are the same, [SplitM, SplitN], but the CTASplitNum are +different. CTASplitNum_A = [SplitM, 1], which means multicast on dim1, +CTASplitNum_B = [1, SplitN], which means multicast on dim0, CTASplitNum_C = +[SplitM, SplitN] which means no multicast. + +Currently programs with multiple CTAs per CGA are an experimental feature in +Triton, not enabled by default. + +You can leave off the CTALayout properties in the textual IR and Triton will +fill in the "default" CTALayout of CTAsPerCGA = CTASplitNum = [1...1]. In +addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to +[n-1,...,0] (it doesn't matter in this case). + }]; + + // CTALayout::get canonicalizes CTAOrder to [n,n-1,...,0] if CTAsPerCGA is + // [1...1]. The CTAOrder doesn't matter in this case. + // + // This is a little weird because if you write textual IR with a one order and + // then print it back out, you might get a different order. But it seems this + // is the best way to canonicalize an attribute in MLIR. + let builders = [ + AttrBuilder<(ins "ArrayRef":$CTAsPerCGA, + "ArrayRef":$CTASplitNum, + "ArrayRef":$CTAOrder), [{ + if (llvm::all_of(CTAsPerCGA, [](unsigned x) { return x == 1; })) { + SmallVector order; + for (int i = CTAsPerCGA.size() - 1; i >= 0; --i) + order.push_back(i); + return $_get(context, CTAsPerCGA, CTASplitNum, order); + } + return $_get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + }]>, + ]; + + let extraClassDeclaration = [{ + SmallVector getElemsPerThread(ArrayRef shape, Type eltTy) const { + llvm::report_fatal_error( + "Unsupported getElemsPerThread in CTALayoutAttr."); + } + unsigned getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { + llvm::report_fatal_error( + "Unsupported getTotalElemsPerThread in CTALayoutAttr."); + } + + static CTALayoutAttr getDefault(MLIRContext *context, int rank) { + SmallVector CTAsPerCGA(rank, 1); + SmallVector CTASplitNum(rank, 1); + SmallVector CTAOrder; + for (int i = rank - 1; i >= 0; --i) + CTAOrder.push_back(i); + return get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + } + }]; + + let genVerifyDecl = 1; + let skipDefaultBuilders = 1; +} + +//===----------------------------------------------------------------------===// +// Shared Layout Encoding +//===----------------------------------------------------------------------===// + +def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding", "shared_encoding"> { + let mnemonic = "shared"; + + let description = [{ +An encoding for tensors whose elements may be simultaneously accessed by +different cuda threads in the programs, via shared memory. In other words, +for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}. + +In order to avoid shared memory bank conflicts, elements may be swizzled. +Here are some examples. In all cases, the input tensor is [0, 1, ..., n-1]. + +1. Basic swizzling + + #shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3], // xor with 0 + [ 5, 4, 7, 6], // xor with 1 + [10, 11, 8, 9], // xor with 2 + [15, 14, 13, 12] // xor with 3 + +Here elements of row r are xor'ed with r (or more properly, in[r][c] -> +out[r][c^r]). + +2. Multiple rows per phase + + #shared<{vec=1, perPhase=2, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 4, 5, 6, 7], + [ 9, 8, 11, 10], // phase 1 (xor with 1) + [13, 12, 15, 14] + +Elements of row r are xor'ed with r/2. In other words, perPhase=2 +means that pairs of 2 rows get the same swizzling. + +3. Max-phase applied + + $shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 5, 4, 7, 6], // phase 1 (xor with 1) + [ 8, 9, 10, 11], // phase 0 + [13, 12, 15, 14], // phase 1 + [16, 17, 18, 19], // ... + [21, 20, 23, 22], + [24, 25, 26, 27], + [29, 28, 31, 30] + +Elements of row r are xor'ed with (r/2) % 2. In other words, maxPhase=m has the +effect of limiting the maximum value of the xor to m-1. + +4. Max-phase and per-phase + + #shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 4, 5, 6, 7], // phase 0 + [ 9, 8, 11, 10], // phase 1 (xor with 1) + [13, 12, 15, 14], // phase 1 + [16, 17, 18, 19], // phase 0 + [20, 21, 22, 23], // phase 0 + [25, 24, 27, 26], // phase 1 + [29, 28, 31, 30]] // phase 1 + +Here the xor value (the "phase", I guess?) changes every perPhase rows, up to a +maximum value of maxPhase-1. In other words, elements of row r are xor'ed with +(r/2) % 2. + +5. Adding vec + + #shared<{vec=2, perPhase=1, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3, 4, 5, 6, 7], + [10, 11, 8, 9, 14, 15, 12, 13], + [20, 21, 22, 23, 16, 17, 18, 19], + [30, 31, 28, 29, 26, 27, 24, 25] + +When vec=2, elements are swizzled in pairs of 2. In other words, the element at +(r,c) has value + + ((c / 2) ^ r) * 2 + (c % 2). + +For MMAv3 eg Hopper GMMA, hasLeadingOffset should be true. In this case, +when the matrix is stored in shared memory, there will be an offset not +only in the stride dimension, but also in the leading dimension. For example, +a matrix of size 16x128 and data type I8 is stored in the shared memory with +64B-swizzle mode. The offset of the element with index (0, 64) will be 16*64, +compared to 1*64 when the hasLeadingOffset is false. + }]; + + // swizzle info: vec, perPhase, maxPhase + // order: the fastest-changing axis first + let parameters = ( + ins + "unsigned":$vec, + "unsigned":$perPhase, + "unsigned":$maxPhase, + ArrayRefParameter<"unsigned">:$order, + "CTALayoutAttr":$CTALayout, + "bool":$hasLeadingOffset + ); + + let builders = [ + AttrBuilder<(ins "unsigned":$vec, + "unsigned":$perPhase, + "unsigned":$maxPhase, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout), [{ + bool hasLeadingOffset = false; // default value + return $_get(context, vec, perPhase, maxPhase, order, CTALayout, hasLeadingOffset); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$typeWidthInBit), [{ + bool needTrans = false; // default value + return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans); + }]>, + + // TODO(jlebar): This should not be an overload of + // SharedEncodingAttr::get(). It's misleading, because it does a bunch of + // nontrivial work based on the given dotOpEnc. + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$typeWidthInBit, + "bool":$needTrans), [{ + + // ---- begin GFX908/GFX90A ---- + if (auto mfmaEnc = mlir::dyn_cast(dotOpEnc.getParent())) { + int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0; + if (needTrans) + kDimNum = 1 - kDimNum; + bool isKDimInner = (order[0] == kDimNum); + if (isKDimInner) { + const int numBanks = 32; + const int bankBitWidth = 32; + const int SIMDWidth = 16; + + // number of inner dimension rows per one pattern repeat + int innerDimLength = shape[order[0]]; + int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; + + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + // vecSize is set to kWidth of the dotop layout + int vecSize = dotOpEnc.getKWidth(); + int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize); + + // TODO (zhanglx): figure out better parameters for mfma4 + if (mfmaEnc.getMDim() == 4) + maxPhase = 4; + + return get(context, vecSize, perPhase, maxPhase, order, CTALayout); + } else { + // Do not swizzle in case k dimension is not innermost. + // In this case accesses will go in different banks even without swizzling. + return get(context, 1, 1, 1, order, CTALayout); + } + } + + // ---- begin GFX11 ---- + if (mlir::isa(dotOpEnc.getParent())) { + if (dotOpEnc.getOpIdx() == 0) { + const int numBanks = 32; + const int bankBitWidth = 32; + + // number of inner dimension rows per one pattern repeat + int innerDimLength = shape[order[0]]; + int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; + + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit; + int maxPhase = 16 / perPhase; + + return get(context, vecSize, perPhase, maxPhase, order, CTALayout); + } else { + // Do not swizzle in case k dimension is not innermost. + // In this case accesses will go in different banks even without swizzling. + return get(context, 1, 1, 1, order, CTALayout); + } + } + + + auto mmaEnc = mlir::dyn_cast(dotOpEnc.getParent()); + + if(!mmaEnc) + return get(context, 1, 1, 1, order, CTALayout); + + int opIdx = dotOpEnc.getOpIdx(); + auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + + // number of rows per phase + + // index of the inner dimension in `order` + unsigned inner = (opIdx == 0) ? 0 : 1; + + // ---- begin Volta ---- + if (mmaEnc.isVolta()) { + int perPhase = 128 / (shapePerCTA[order[0]] * (typeWidthInBit / 8)); + perPhase = std::max(perPhase, 1); + bool is_row = order[0] != 0; + bool is_vec4 = opIdx == 0 ? !is_row && (shapePerCTA[order[0]] <= 16) : + is_row && (shapePerCTA[order[0]] <= 16); + int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) : + ((is_row && !is_vec4) ? 2 : 1); + int rep = 2 * pack_size; + int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase; + int vec = 2 * rep; + return get(context, vec, perPhase, maxPhase, order, CTALayout); + } + + // ---- begin Ampere ---- + if (mmaEnc.isAmpere()) { + int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()); + perPhase = std::max(perPhase, 1); + std::vector matShape = {8, 8, 4 * dotOpEnc.getKWidth()}; + int vecWidth = 32 / typeWidthInBit; + if (vecWidth != dotOpEnc.getKWidth() && order[0] == inner) { + perPhase = std::max(perPhase, 2 * vecWidth); + } + int rank = order.size(); + // --- handle A operand --- + if (opIdx == 0) { // compute swizzling for A operand + int m = (needTrans) ? matShape[2] : matShape[0]; + int k = (needTrans) ? matShape[0] : matShape[2]; + int vec = (order[0] == rank-1) ? k : m; + int mmaStride = (order[0] == rank-1) ? m : k; + int maxPhase = mmaStride / perPhase; + return get(context, vec, perPhase, maxPhase, order, CTALayout); + } + + // --- handle B operand --- + if (opIdx == 1) { + // we compute vec and maxPhase m, n and k size of the mma + // instruction. when matmul operands is transposed, we should + // consider that to get m, n and k. + int n = needTrans ? matShape[2] : matShape[1]; + int k = needTrans ? matShape[1] : matShape[2]; + int vec = (order[0] == rank-1) ? n : k; + int mmaStride = (order[0] == rank-1) ? k : n; + int maxPhase = mmaStride / perPhase; + return get(context, vec, perPhase, maxPhase, order, CTALayout); + } + + llvm_unreachable("invalid operand index"); + } + + // ---- begin version 3 ---- + if (mmaEnc.isHopper()) { + llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr" + " is Hopper has not been implemented yet"); + return $_get(context, 1, 1, 1, order, CTALayout, true); + } + + // ---- not implemented ---- + llvm_unreachable("unsupported swizzling for provided MMA version"); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CTALayout, bitwidth); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy, + "bool":$needTrans), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans); + }]>, + + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy), [{ + auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + + int32_t eleBitWidth = eltTy.getIntOrFloatBitWidth(); + int32_t vec = 128 / eleBitWidth, perPhase = 1, maxPhase = 1; + + // get proper shared memory swizzling mode from the contiguous dimension + // size of the origin blocked layout. + auto contigDimSizeInByte = shapePerCTA[order[0]] * eleBitWidth / 8; + if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) { + perPhase = 1; + maxPhase = 8; + } else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) { + perPhase = 2; + maxPhase = 4; + } else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) { + perPhase = 4; + maxPhase = 2; + } else { + llvm_unreachable("unsupported shared memory layout for MMAv3"); + } + + return $_get(context, vec, perPhase, maxPhase, order, CTALayout, true); + }]> + ]; + + let extraClassDeclaration = extraBaseClassDeclaration; + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// Distributed Layout Encoding +//===----------------------------------------------------------------------===// +def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let description = [{ +The Distributed encoding describes the layout L with the 4-level compute hierarchy on GPU. +It is abstracted from the top to the bottom as CTAs Per CGA->Warps Per CTA->Threads Per Warp->Values Per Thread. + +For CTAs Per CGA and Warps Per CTA level, the linear id is distributed contiguously with the shape and order. +For example, for a shape/order pair defines a distribution layout +shape = [4, 4] +order = [0, 1] // The fastest-changing axis first +-> +layout = [0 4 8 12] + [1 5 9 13] + [2 6 10 14] + [3 7 11 15] + +For the Threads Per Warp and Values Per Thread level, the linear id distribution is variant for each sub-class encoding. + }]; + + let methods = [ + // Interface for the meta information about the multiple thread hierarchy. + InterfaceMethod<"Get the shape of the CTAs per CGA.", + "SmallVector", + "getCTAsPerCGA">, + + InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first", + "SmallVector", + "getCTAOrder">, + + InterfaceMethod<"Get the shape of the warps per CTA.", + "SmallVector", + "getWarpsPerCTA">, + + InterfaceMethod<"Get the order of the warps per CTA. The fastest-changing axis first", + "SmallVector", + "getWarpOrder">, + + InterfaceMethod<"Get the shape of the threads per warp", + "SmallVector", + "getThreadsPerWarp">, + + InterfaceMethod<"Get the order of the threads per warp. The fastest-changing axis first", + "SmallVector", + "getThreadOrder">, + + InterfaceMethod<"Get the shape of the values per thread.", + "SmallVector", + "getSizePerThread">, + + InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.", + "SmallVector", + "getCTASplitNum">, + + InterfaceMethod<"Gets the shape of the encoding's tile, e.g. sizePerThread * threadsPerWarp * warpsPerCTA", + "SmallVector", + "getShapePerCTATile", + (ins "ArrayRef":$tensorShape)>, + + InterfaceMethod<"Gets the number of contiguous elements per thread.", + "SmallVector", + "getContigPerThread">, + ]; +} + +class DistributedEncoding traits = [], + Dialect dialect = TritonGPU_Dialect> + : TritonGPU_Attr { + + let description = [{ +Distributed encodings have a layout function L that is entirely characterized +by a d-dimensional tensor T. Note that L doesn't need to have the same shape +(or even the same rank) as the tensor it is encoding. + +The layout function \mathcal{L} of this layout is then defined, for an +index `i` \in Z^d, as follows: + +\mathcal{L}(T)[i_d] = L[(i_d + k_d*T.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*T.shape[d] < L.shape[d] + +Intuitively, when the tensor dim size T.shape[d] is larger than the layout +dim size L.shape[d], on that particular dim, we distribute values from the +tensor to threads mapped in the layout in a "wrapped around" manner, with +each thread owning multiple values. + +OTOH, when the tensor dim size T.shape[d] is smaller than the layout +dim size L.shape[d], on that particular dim, we distribute values from the +tensor to threads mapped in the layout in a "broadcasted" manner, with +each value owned by multiple threads. + +For example, for a tensor/layout pair +T = [x x x x x x x x] + [x x x x x x x x] +L = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] + +Then the data of T would be distributed as follow between the 16 CUDA threads: +L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, + {4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ] + }]; + + code extraDistributedDeclaration = extraBaseClassDeclaration # [{ + SmallVector getCTAsPerCGA() const; + SmallVector getCTAOrder() const; + SmallVector getCTASplitNum() const; + SmallVector getWarpsPerCTA() const; + SmallVector getWarpOrder() const; + SmallVector getThreadsPerWarp() const; + SmallVector getThreadOrder() const; + + SmallVector getSizePerThread() const; + SmallVector getShapePerCTATile(ArrayRef tensorShape = ArrayRef()) const; + }]; +} + +//===----------------------------------------------------------------------===// +// Blocked Layout Encoding +//===----------------------------------------------------------------------===// + +def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding", "blocked_encoding"> { + let mnemonic = "blocked"; + + let description = [{ +An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout +used to promote memory coalescing in LoadInst and StoreInst. +It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which +specify the amount of elements owned by each CUDA thread, warp and CTA respectively. + +Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows: + +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] + +for + +#triton_gpu.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {1, 1} + CTASplitNum = {1, 1} +}> + +Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) as follows: + +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +for + +#triton_gpu.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {1, 1} + CTASplitNum = {1, 1} +}> + +Example 3, A row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) and +4 CTAs (taking 2x2 for example) as follows: + +CTA [0,0] CTA [0,1] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] + +CTA [1,0] CTA [1,1] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +for + +#triton_gpu.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {2, 2} + CTASplitNum = {2, 2} +}> +}]; + + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$sizePerThread__, + ArrayRefParameter<"unsigned">:$threadsPerWarp__, + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first + + // CTALayout is optional in the textual IR. If omitted, we infer it to be a + // single CTA (so CTAsPerCGA = [1,...,1], CTASplitNum = [1,...,1], + // CTAOrder=[n,n-1,...,0]). + "CTALayoutAttr":$CTALayout + ); + let genVerifyDecl = 1; + + let builders = [ + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$sizePerThread, + "ArrayRef":$order, + "unsigned":$numWarps, + "unsigned":$numThreadsPerWarp, + "CTALayoutAttr":$CTALayout), [{ + unsigned rank = sizePerThread.size(); + SmallVector threadsPerWarp(rank); + SmallVector warpsPerCTA(rank); + SmallVector shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + + unsigned remainingLanes = numThreadsPerWarp; + unsigned remainingThreads = numWarps * numThreadsPerWarp; + unsigned remainingWarps = numWarps; + unsigned prevLanes = 1; + unsigned prevWarps = 1; + + // starting from the contiguous dimension + for (unsigned d = 0; d < rank - 1; ++d) { + unsigned i = order[d]; + unsigned threadsPerCTA = std::clamp(remainingThreads, 1, shapePerCTA[i] / sizePerThread[i]); + threadsPerWarp[i] = std::clamp(threadsPerCTA, 1, remainingLanes); + warpsPerCTA[i] = std::clamp(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps); + remainingWarps /= warpsPerCTA[i]; + remainingLanes /= threadsPerWarp[i]; + remainingThreads /= threadsPerCTA; + prevLanes *= threadsPerWarp[i]; + prevWarps *= warpsPerCTA[i]; + } + + // Expand the last dimension to fill the remaining lanes and warps + threadsPerWarp[order[rank - 1]] = numThreadsPerWarp / prevLanes; + warpsPerCTA[order[rank - 1]] = numWarps / prevWarps; + + return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout); + }]>, + + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$sizePerThread, + "ArrayRef":$order, + "unsigned":$numWarps, + "unsigned":$numThreadsPerWarp, + "unsigned":$numCTAs), [{ + unsigned rank = sizePerThread.size(); + SmallVector CTAsPerCGA(rank); + SmallVector CTASplitNum(rank); + ArrayRef CTAOrder = order; + + unsigned remainingCTAs = numCTAs; + + // starting from the most strided dimension + for (int d = rank - 1; d >= 0; --d) { + unsigned i = order[d]; + CTAsPerCGA[i] = std::clamp(remainingCTAs, 1, shape[i] / sizePerThread[i]); + CTASplitNum[i] = CTAsPerCGA[i]; + remainingCTAs /= CTAsPerCGA[i]; + } + + CTAsPerCGA[rank - 1] *= remainingCTAs; // wrap at CTA level + + CTALayoutAttr CTALayout = CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + return get(context, shape, sizePerThread, order, numWarps, numThreadsPerWarp, CTALayout); + }]> + ]; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + SliceEncodingAttr squeeze(int axis); + + SmallVector getContigPerThread() { + // Block encoding is dense stride layout. The elements per thread are contiguous. + return getSizePerThread(); + }; + }]; + + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// MMA Layout Encoding +//===----------------------------------------------------------------------===// +// TODO: MMAv1 and MMAv2 should be two instances of the same class +def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + let methods = [ + + InterfaceMethod<"Return whether the layout support reduction op.", + "bool", + "supportReduction">, + + InterfaceMethod<"Return shape per CTA.", + "SmallVector", + "getShapePerCTATileForDotOperands", + (ins "ArrayRef":$tensorShape, + "unsigned":$opIdx)>, + + InterfaceMethod<"Return total element size per thread for dot operands.", + "unsigned", + "getTotalElemsPerThreadForOperands", + (ins "ArrayRef":$tensorShape, + "Type":$eltTy, + "unsigned":$kWidth, + "unsigned":$opIdx)>, + + InterfaceMethod<"Return size per thread for dot operands.", + "SmallVector", + "getSizePerThreadForOperands", + (ins "unsigned":$opIdx)>, + ]; +} + +def AMDMfmaEncodingAttr : DistributedEncoding<"AMDMfmaEncoding", "amd_mfma_encoding", [MmaEncodingTrait]> { + let mnemonic = "amd_mfma"; + + let description = [{ +An encoding for tensors that have been produced by MFMA matrix core instructions, +available on AMD Instinct GPUs of CDNA architectures. + +It is characterized by the following parameters: +- `versionMajor` and `versionMinor` indicates the GPU architecture: + - 1.0: gfx908, i.e. MI100 + - 2.0: gfx90a: i.e. MI200, MI210, MI250 + - 3.0: gfx940, gfx941, gfx942: MI300 +- `warpsPerCTA` indicates the wave layout in the workgroup. +- `MDim` and `NDim` indicate the dimension of the output of the mfma instruction. +- `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout +without going to shared memory. This is used in the case of chained dot (E.g. Flash-Attention kernel). + +Example 1: +Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and MDim=NDim=32. +The data will be distributed between threads as follows: + + wave 0 wave 1 +-----------------/\-------------- -----------------/\-------------- +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] + +Example 2: +Suppose we have a tensor with a shape of [16, 32], warpsPerCTA set to [1, 2] and MDim=NDim=16. +The data will be distributed between threads as follows: + + wave 0 wave 1 +-----------------/\------------- ------------------/\--------------- +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] + +Example 3: +Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and nonKDim set to 4. +The data will be distributed between threads as follows(note that each element is duploicated in 16 threads): +Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and MDim=NDim=4. +The data will be distributed between threads as follows(note that each element is duplicated in 16 threads): + +M N -> wave 0 wave 2 +| --------------------------/\-------------------------- ------------------------------/\------------------------------ +V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + wave 1 wave 3 + --------------------------/\-------------------------- ------------------------------/\------------------------------ + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] +}]; + + let parameters = ( + ins + "unsigned": $versionMajor, + "unsigned": $versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + "unsigned":$MDim, + "unsigned":$NDim, + "bool":$isTransposed, + "CTALayoutAttr":$CTALayout + ); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool supportReduction() const { + return true; + } + SmallVector getSizePerThreadForOperands(unsigned opIdx) const; + SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; + unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + SmallVector getMFMAInstrShapeForOperands(int kWidth, int opIdx) const; + SmallVector getMFMARepForOperands(ArrayRef operandShape, int kWidth, int opIdx) const; + + SmallVector getContigPerThread() { + auto rank = getWarpsPerCTA().size(); + SmallVector contigPerThread(rank, 1); + if (getIsTransposed()) + contigPerThread[rank - 1] = 4; + else + contigPerThread[rank - 2] = 4; + return contigPerThread; + }; + + }]; + + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; +} + +def AMDWmmaEncodingAttr : DistributedEncoding<"AMDWmmaEncoding", "amd_wmma_encoding", [MmaEncodingTrait]> { + let mnemonic = "amd_wmma"; + + let description = [{ +An important limitation of WMMA for layout is a shape for tiles proccessed +by a single wave. It is [16, 16]. +This encoding assumes specific access to matrix elements by threads. + +Example: +Suppose we have a tensor with shape [32, 48], `warpsPerCTA` set to [2, 3]. + + wave 0 [16, 16] wave 1 [16, 16] wave 2 [16, 16] +-----------/\---------- -----------/\---------- -----------/\---------- +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +... ... ... +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + + wave 3 [16, 16] wave 4 [16, 16] wave 5 [16, 16] +-----------/\---------- -----------/\---------- -----------/\---------- +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +... ... ... +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + }]; + + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + "CTALayoutAttr":$CTALayout + ); + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool supportReduction() const { + return true; + } + SmallVector getSizePerThreadForOperands(unsigned opIdx) const; + SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; + unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + SmallVector getWMMAElemsPerInstrForOperands() const; + SmallVector getWMMARepForOperands(ArrayRef operandShape, + Type elemType, int kWidth, int opIdx) const; + static SmallVector getMNKDimPerWMMAInstr(); + + SmallVector getContigPerThread() { + auto rank = getWarpsPerCTA().size(); + SmallVector contigPerThread(rank, 1); + return contigPerThread; + }; + }]; +} + +def NvidiaMmaEncodingAttr : DistributedEncoding<"NvidiaMmaEncoding", "nvidia_mma_encoding", [MmaEncodingTrait]> { + let mnemonic = "nvidia_mma"; + + let description = [{ +An encoding for tensors that have been produced by tensor cores. + +It is characterized by two parameters: +- A 'versionMajor' which specifies the generation the tensor cores + whose output is being partitioned: + - 1 for first-gen tensor cores (Volta), and + - 2 for second-gen tensor cores (Turing/Ampere). +- A 'versionMinor' which indicates the specific layout of a tensor core + generation, e.g. for Volta, there might be multiple kinds of layouts + annotated by 0,1,2 and so on. +- A `blockTileSize` to indicate how data should be partitioned between warps. + +// -------------------------------- version = 1 --------------------------- // + +For first-gen tensor cores, the implicit warpTileSize is [16, 16]. +Note: the layout is different from the recommended in PTX ISA +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html +(mma.884 section, FP32 accumulator). + +For example, when versionMinor=1, the matrix L corresponding to +blockTileSize=[32,16] is: + + warp 0 +--------------------------------/\------------------------------- +[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ] +[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ] +[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ] +[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ] +[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ] +[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ] +[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ] +[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ] +[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ] +[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ] +[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ] +[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ] +[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ] +[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ] +[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ] +[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ] + + warp 1 = warp0 + 32 +--------------------------------/\------------------------------- +[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ] +[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ] +[ ............................................................... ] + + +// -------------------------------- version = 2 --------------------------- // + +For second-gen tensor cores, the implicit warpTileSize is [16, 8]. +Information about this layout can be found in the official PTX documentation +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html +(mma.16816 section, FP32 accumulator). + +For example, the matrix L corresponding to blockTileSize=[32,16] is: + warp 0 warp 2 +-----------------/\------------- ----------------/\------------- +[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35 +[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39 +[ .............................. .............................. +[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63 +[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35 +[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39 +[ .............................. .............................. +[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63 + + warp 1 warp 3 +----------------/\------------- ----------------/\------------- +[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99 +[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103 +[ .............................. ............................... +[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127 +[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99 +[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103 +[ .............................. ............................... +[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127 + +}]; + + let parameters = ( + ins + "unsigned":$versionMajor, + "unsigned":$versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + "CTALayoutAttr":$CTALayout, + ArrayRefParameter<"unsigned">:$instrShape + ); + + let builders = [ + // Specially for MMAV1(Volta) + AttrBuilder<(ins "int":$versionMajor, + "int":$numWarps, + "CTALayoutAttr":$CTALayout, + "ArrayRef":$instrShape, + "ArrayRef":$shapeC, + "bool":$isARow, + "bool":$isBRow, + "bool":$isAVec4, + "bool":$isBVec4, + "int":$id), [{ + assert(versionMajor == 1 && "This builder is specially for versionMajor==1"); + // 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4] + int versionMinor = (isARow * (1<<0)) |\ + (isBRow * (1<<1)) |\ + (isAVec4 * (1<<2)) |\ + (isBVec4 * (1<<3)); + + // TODO: Share code with + // DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the + // rep,spw and fpw. + SmallVector wpt({1, 1}); + SmallVector wpt_nm1; + + SmallVector rep(2), spw(2); + std::array fpw{{2, 2, 1}}; + int packSize0 = (isARow || isAVec4) ? 1 : 2; + rep[0] = 2 * packSize0; + spw[0] = fpw[0] * 4 * rep[0]; + + int packSize1 = (isBRow && !isBVec4) ? 2 : 1; + rep[1] = 2 * packSize1; + spw[1] = fpw[1] * 4 * rep[1]; + + do { + wpt_nm1 = wpt; + if (wpt[0] * wpt[1] < numWarps) + wpt[0] = std::clamp(wpt[0] * 2, 1, shapeC[0] / spw[0]); + if (wpt[0] * wpt[1] < numWarps) + wpt[1] = std::clamp(wpt[1] * 2, 1, shapeC[1] / spw[1]); + } while (wpt_nm1 != wpt); + + return $_get(context, versionMajor, versionMinor, wpt, CTALayout, instrShape); + }]>, + + + AttrBuilder<(ins "int":$versionMajor, + "int":$numWarps, + "CTALayoutAttr":$CTALayout, + "ArrayRef":$instrShape, + "ArrayRef":$shapeA, + "ArrayRef":$shapeB, + "ArrayRef":$shapeC, + "bool":$isARow, + "bool":$isBRow, + "int":$id), [{ + assert(versionMajor == 1 && "This builder is specially for versionMajor==1"); + bool isAVec4 = !isARow && (shapeA[isARow] <= 16); + bool isBVec4 = isBRow && (shapeB[isBRow] <= 16); + return get(context, versionMajor, numWarps, CTALayout, instrShape, shapeC, isARow, isBRow, isAVec4, isBVec4, id); + }]> + ]; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool isVolta() const; + bool isTuring() const; + bool isAmpere() const; + bool isHopper() const; + + unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef shape) const; + + // Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor + std::tuple decodeVoltaLayoutStates() const; + + // Number of bits in versionMinor to hold the ID of the MMA encoding instance. + // Here 5 bits can hold 32 IDs in a single module. + static constexpr int numBitsToHoldMmaV1ID{5}; + + // For MMA v1, method `getMMAv1IsRow` returns whether e.g. the a operand is used + // in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation + // section 9.7.13.4.1 for more details. + bool getMMAv1IsRow(int opIdx) const; + bool getMMAv1IsVec4(int opIdx) const; + int getMMAv1NumOuter(ArrayRef shape, int opIdx) const; + SmallVector getMMAv1Rep(int opIdx) const; + SmallVector getMMAv1ShapePerWarp(int opIdx) const; + int getMMAv1Vec(int opIdx) const; + SmallVector getMMAv2Rep(ArrayRef shape, + int bitwidth, int opIdx) const; + + bool supportReduction() const { + if (isAmpere() || isHopper()) { + return true; + } + return false; + }; + SmallVector getSizePerThreadForOperands(unsigned opIdx) const; + SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; + unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + + SmallVector getContigPerThread() { + assert(isVolta() || isAmpere() || isHopper()); + auto rank = getWarpsPerCTA().size(); + SmallVector contigPerThread(rank, 1); + contigPerThread[rank - 1] = 2; + return contigPerThread; + }; + + }]; + + let hasCustomAssemblyFormat = 1; +} + +def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> { + let mnemonic = "slice"; + + let description = [{ + Given a `parent` layout and a `dim`, squeezes the given `dim` in the `parent` + layout and distributes values in a tensor T according to the new layout. + + For example, given + + T = [x x x x x x x x] + L_parent = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] (with 16 CUDA threads) + + With dim = 0, squeezing out dim 0, we have + L = [{0,4,8,12}, {1,5,9,13}, {2,6,10,14}, {3,7,11,15} ] + + Then the data of T would be distributed as follow between the 16 CUDA threads: + L(T) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ] + + With dim = 1, squeezing out dim 1, we have + L = [ {0,1,2,3}, {4,5,6,7}, {8,9,10,11}, {12,13,14,15} ] + + Then the data of T would be distributed as follow between the 16 CUDA threads: + L = [ {0,1,2,3}, {4,5,6,7}, ..., {12,13,14,15}, {0,1,2,3}, ..., {12,13,14,15} ] + + This is useful for constructing the inverse layout of an expand_dims operation + during some optimization passes. + }]; + + let parameters = ( + ins + "unsigned":$dim, + // TODO: constraint here to only take distributed encodings + "Attribute":$parent + ); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + template + SmallVector paddedShape(ArrayRef shape) const; + + SmallVector getContigPerThread() { + auto parentLayout = mlir::cast(getParent()); + auto parentContigPerThread = parentLayout.getContigPerThread(); + parentContigPerThread.erase(parentContigPerThread.begin() + getDim()); + return parentContigPerThread; + }; + }]; + + let hasCustomAssemblyFormat = 1; +} + +def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> { + let mnemonic = "dot_op"; + + let description = [{ +In the TritonGPU dialect, given `d = tt.dot a, b, c` tt.dot's operands a and b +must be of DotOperandEncodingAttr layout, if the dot is MMA v1 or v2 (i.e. +pre-Hopper). For MMA v3, the operands are *almost always* in a regular shared +encoding, but sometimes the LHS is also a dot-operand encoding. + +a's opIdx is 0, b's opIdx is 1. + +The parent field is the layout of d. + +kWidth defines number of consecutive elements stored by one thread along k dimension. +Some layouts do not use this parameter, either because they have a fixed number of +elements along the K dim, or they use all elements of the tensor along the K dim. + }]; + + let parameters = ( + ins + "unsigned":$opIdx, + "Attribute":$parent, + DefaultValuedParameter<"unsigned", "0">:$kWidth + ); + + let builders = [ + // Specially for MMAV1(Volta) + AttrBuilder<(ins "unsigned":$opIdx, + "Attribute":$parent, + "Type":$eltTy), [{ + NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast(parent); + if (!parentAttr || !parentAttr.isAmpere()) + return $_get(context, opIdx, parent, 0); + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + unsigned MMAv2kWidth = 32 / bitwidth; + return $_get(context, opIdx, parent, MMAv2kWidth); + }]> + ]; + + let assemblyFormat = "`<` `{` struct(params) `}` `>`"; + let genVerifyDecl = 1; + let extraClassDeclaration = extraDistributedDeclaration # [{ + SmallVector getContigPerThread() { + return getSizePerThread(); + }; + }]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td new file mode 100644 index 000000000..10f2c8c68 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -0,0 +1,54 @@ +#ifndef TRITONGPU_DIALECT +#define TRITONGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonGPU_Dialect : Dialect { + let name = "triton_gpu"; + + let cppNamespace = "::mlir::triton::gpu"; + + let hasOperationAttrVerify = 1; + + let description = [{ + Triton GPU Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect", + "mlir::gpu::GPUDialect", + "tensor::TensorDialect", + ]; + + let extraClassDeclaration = [{ + static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; } + static int getNumWarps(ModuleOp mod) { + if (!mod->hasAttr("triton_gpu.num-warps")) + llvm::report_fatal_error( + "TritonGPU module should contain a triton_gpu.num-warps attribute"); + return cast(mod->getAttr("triton_gpu.num-warps")).getInt(); + } + static int getNumCTAs(ModuleOp mod) { + if (!mod->hasAttr("triton_gpu.num-ctas")) + return 1; + return cast(mod->getAttr("triton_gpu.num-ctas")).getInt(); + } + void registerTypes(); + + static std::string getThreadsPerWarpAttrName() { return "triton_gpu.threads-per-warp"; } + + static int getThreadsPerWarp(ModuleOp mod) { + Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp"); + if(!threadsPerWarp) { + return 32; + } + return cast(threadsPerWarp).getInt(); + } + }]; + + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h new file mode 100644 index 000000000..0ee2cfeca --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h @@ -0,0 +1,6 @@ +#ifndef TRITON_GPU_DIALECT_INTERFACES_H +#define TRITON_GPU_DIALECT_INTERFACES_H + +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc" + +#endif // TRITON_GPU_DIALECT_INTERFACES_H diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td new file mode 100644 index 000000000..2530009cb --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -0,0 +1,235 @@ +#ifndef TRITONGPU_OPS +#define TRITONGPU_OPS + +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/ViewLikeInterface.td" + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">; + +class TTG_Op traits = []> : + Op { +} + +def TTG_ConvertLayoutOp : TTG_Op<"convert_layout", + [SameOperandsAndResultShape, + SameOperandsAndResultElementType, + Pure]> { + let summary = "convert layout"; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let hasCanonicalizer = 1; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TTG_AsyncWaitOp : TTG_Op<"async_wait"> { + let summary = "async wait"; + + let arguments = (ins Variadic:$asyncToken, I32Attr:$num); + + let results = (outs TTG_AsyncToken:$retToken); + + let assemblyFormat = "$asyncToken attr-dict"; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 80; + } + }]; +} + +def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> { + let summary = "async commit group"; + + let results = (outs TTG_AsyncToken:$asyncToken); + let arguments = (ins Variadic:$inputTokens); + + let assemblyFormat = [{ + $inputTokens attr-dict + }]; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 80; + } + }]; +} + +def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + TypesMatchWith<"infer mask type from src type", + "src", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 3) || std::equal_to<>()">, + TypesMatchWith<"infer other type from src type", + "src", "other", "getPointeeType($_self)", + "($_op.getOperands().size() <= 4) || std::equal_to<>()"> +]> { + let summary = "copy data from global memory to local memory asynchronously"; + + let description = [{ + This operation copies data from global memory to local memory asynchronously. + This is analogue to tt.load except the data are copied to local memory pointed + by by the memory descriptor instread of a distributed tensor. The rest of the + operands are the same as tt.load. + }]; + + let arguments = ( + ins TT_PtrTensor:$src, + TT_MemDescType:$result, + Optional:$mask, + Optional:$other, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile + ); + + let builders = [ + OpBuilder<(ins "Value":$src, "Value":$result, + "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + ]; + + let results = (outs TTG_AsyncToken:$token); + + let extraClassDeclaration = [{ + static DenseSet getEligibleLoadByteWidth(int computeCapability) { + DenseSet validLoadBytes; + if (computeCapability >= 80) { + validLoadBytes = {4, 8, 16}; + } + return validLoadBytes; + } + }]; + + // Specify cacheModifier and evictionPolicy explicitly, instead of leaving + // them in attr-dict, because this way their values get printed as strings, + // rather than as opaque integers. + // + // Note there are no commas between other, cacheModifier, and evictionPolicy, + // due to limitations in MLIR's asm parser. + let assemblyFormat = [{ + $src `,` $result (`mask` $mask^)? (`other` $other^)? + oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + attr-dict `:` type($src) `->` type($result) + }]; +} + + +// Allocate shared memory +def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods]> { + let summary = "allocate tensor"; + let description = [{ + This operation allocates buffer in shared memory and return a descriptor + containing the address and a view of the buffer. + + Explicitly deallocating a buffer is optional; see local_dealloc. + }]; + let arguments = (ins Optional:$src); + + let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}]; + + let results = (outs TT_MemDescType:$result); +} + +// Deallocate shared memory +def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree]>]> { + let summary = "dealloc buffer"; + + let description = [{ + This operation deallocates a buffer explicitly. Using the buffer after this + operation is undefined. + + This operation is optional. If you don't explicitly dealloc a buffer, the + compiler assumes it's deallocated at the first point that post-dominates all + uses of the alloc. + + Because we assume a memdesc is dead at the first point that post-dominates + its uses, ops that wait for an async operation on a memdesc to complete + (such as triton_nvidia_gpu.dot_wait) should also take the memdesc as an + operand. + }]; + + let arguments = (ins TT_MemDescType:$src); + + // Use qualified() otherwise "!tt.memdesc" is printed as "". + let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}]; +} + +def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> { + let summary = "take a subview of the descriptor."; + + let description = [{ + This operation returns a new descriptor representing a subview of the buffer. + It doesn't affect the underlying memory. The subview can be rank-reduced. + + For example, suppose that + - the input shape is 2x4x16xf16, + - the output shape is 4x4xf16, and + - offsets = [1, 0, 4]. + + Then in Python syntax, the subview covers input[1][0:4][4:8]. + }]; + let arguments = ( + ins TT_MemDescType:$src, Variadic:$offsets); + + // Use qualified() otherwise "!tt.memdesc" is printed as "". + let assemblyFormat = [{$src `[` $offsets `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}]; + + let results = (outs TT_MemDescType:$result); + + let hasVerifier = 1; +} + +def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods]> { + let summary = "Load a buffer from local memory into a distributed tensor"; + + let description = [{ + Load a tensor from the local memory descriptor into a distributed tensor. + }]; + let arguments = (ins TT_MemDescType:$src, Optional :$token); + + let builders = [ + OpBuilder<(ins "Type":$retType, "Value":$src), + [{ + build($_builder, $_state, retType, src, /*token=*/static_cast(nullptr)); + }]>]; + + // Use qualified() otherwise "!tt.memdesc" is printed as "". + let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}]; + + let results = (outs TT_Tensor:$result); +} + +def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods]> { + let summary = "Store a distributed tensor into a buffer in local memory"; + + let description = [{ + Store a distributed tensor into a buffer in local memory. + }]; + let arguments = (ins TT_Tensor:$src, TT_MemDescType:$dst); + + // Use qualified() otherwise "!tt.memdesc" is printed as "". + let assemblyFormat = [{ + $src `,` $dst attr-dict `:` type($src) `->` qualified(type($dst)) + }]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td new file mode 100644 index 000000000..6765ac40c --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td @@ -0,0 +1,36 @@ +#ifndef TRITONGPU_TYPES +#define TRITONGPU_TYPES + +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "mlir/IR/AttrTypeBase.td" + +class TTG_TypeDef traits = []> + : TypeDef { + let mnemonic = _mnemonic; +} + +def TTG_TokenType : TTG_TypeDef<"Token", "token"> { + let parameters = (ins "int32_t":$type); + + let builders = [ + TypeBuilder<(ins "unsigned":$type), [{ + return $_get($_ctxt, type); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +def TTG_AsyncToken : TTG_TypeDef<"AsyncToken", + "async.token", []> { + let summary = "async token type"; + let description = [{ + `ttg.async.token` is a type returned by an asynchronous operation. + It is used to establish an SSA-based link between async operations + and operations that group or synchronize the async operations. + }]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Types.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Types.h new file mode 100644 index 000000000..edf37fef6 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/Types.h @@ -0,0 +1,10 @@ +#ifndef TRITONGPU_IR_TYPES_H_ +#define TRITONGPU_IR_TYPES_H_ + +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/Types.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..6be94d1a8 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGPU) +add_public_tablegen_target(TritonGPUTransformsIncGen) diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Passes.h new file mode 100644 index 000000000..98c137d86 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -0,0 +1,21 @@ +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { +namespace gpu { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +} // namespace gpu +} // namespace triton +} // namespace mlir +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Passes.td new file mode 100644 index 000000000..9f41727de --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -0,0 +1,71 @@ +#ifndef TRITONGPU_PASSES +#define TRITONGPU_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> { + let summary = "coalesce"; + + let description = [{ + TODO + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; +} + + +def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions", "mlir::ModuleOp"> { + let summary = "remove superfluous layout conversions"; + + let description = [{ + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + +} + +def TritonGPUOptimizeThreadLocality : Pass<"tritongpu-optimize-thread-locality", "mlir::ModuleOp"> { + let summary = "Reduce the cost of synchronization between threads in an SM"; + + let description = [{ + Today, this optimizes reduction yielded by loop to be thread-local until after the loop completes. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> { + let summary = "Reorder instructions"; + + let description = "This pass reorder instructions so as to (1) decrease register pressure (e.g., by moving " + "conversions from shared memory before their first use) and (2) promote LLVM instruction " + "order more friendly to `ptxas`."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUReduceDataDuplication: Pass<"tritongpu-reduce-data-duplication", "mlir::ModuleOp"> { + let summary = "Reduce data duplication in register by decomposing convert[distributed -> dotOperand] " + "into convert[distributed -> shared -> dotOperand]"; + + let description = "Decomposing conversions this way makes it possible to use CSE and reuse #shared tensors"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUCombineTensorSelectAndIf: Pass<"tritongpu-combine-tensor-select-and-if", "mlir::ModuleOp"> { + let summary = "Combine tensor select and if"; + + let description = "For select instruction that uses the same condidtion as the if instruction in the same block " + "this pass combines the select into the if instruction, making the select operands returned by the " + "then/else yields."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +#endif diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h new file mode 100644 index 000000000..fbfa235fc --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// +// Defines utilities to use while converting to the TritonGPU dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +class TritonGPUTypeConverter : public TypeConverter { +public: + TritonGPUTypeConverter(MLIRContext *context, int numWarps, int threadsPerWarp, + int numCTAs); + int getNumWarps() const { return numWarps; } + int getThreadsPerWarp() const { return threadsPerWarp; } + int getNumCTAs() const { return numCTAs; } + +private: + MLIRContext *context; + int numWarps; + int threadsPerWarp; + int numCTAs; +}; + +class TritonGPUConversionTarget : public ConversionTarget { + +public: + explicit TritonGPUConversionTarget(MLIRContext &ctx, + TritonGPUTypeConverter &typeConverter); +}; + +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Utility.h new file mode 100644 index 000000000..114c18142 --- /dev/null +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -0,0 +1,177 @@ +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include +#include + +namespace mlir { + +namespace triton { +class ModuleAxisInfoAnalysis; +class LoadOp; +class StoreOp; +class FuncOp; +namespace gpu { +class SharedEncodingAttr; +} +} // namespace triton + +SmallVector mmaVersionToInstrShape(int version, + const ArrayRef &shape, + TensorOrMemDesc type, + int numWarps); + +/// Returns true if the Load uses block pointer. +bool isLoadFromTensorPtr(triton::LoadOp op); + +// Return an array of indices enumerating the elements of 'arr' in descending +// order (so that result[i] is the index of the i-th largest element of 'arr') +SmallVector argSort(const SmallVector &arr); + +// Return the operand used to access the memory in the operation +Value getMemAccessPtr(Operation *op); + +// Return bitwidth of tensor element +unsigned getElementBitWidth(RankedTensorType type); + +// Calculate the optimal number of elements per thread for a given operation +// along an axis with greatest continuity. +unsigned +getNumElementsPerThread(Operation *op, SmallVector order, + triton::ModuleAxisInfoAnalysis &axisInfoAnalysis); + +/* Dump Triton IR in graphviz dot format. + * + * You can override `onValue` and `onOperation` in a subclass to mark + * specific Values and Operations. The below subclass + * GraphLayoutMarker is an example. + * + * Default NodeInfo for Value nodes: + * {{"shape": "box"}, + * {"style", "filled"}, + * {"fillcolor", "white"}, + * {"label", shapeStr}} + * + * Default NodeInfo for Operation nodes: + * {{"shape": "ellipse"}, + * {"style", "filled"}, + * {"fillcolor", "white"}, + * {"label", operationName}} + * + * If the key "label" is not set by `onValue` or `onOperation`, default labels + * will be generated. For Value node, the default label is the shape string and + * for Operation node, it is the operation name. + * + * Reference: + * https://graphviz.org/doc/info/shapes.html + * https://graphviz.org/doc/info/colors.html + * + * Usage: + * C++: GraphDumper().dumpToFile(func, "func.dot"); + * Shell: dot -Tjpg func.dot -o func.jpg + */ +class GraphDumper { +public: + using NodeInfo = std::map; + + // Override this function to mark specific Values + virtual NodeInfo onValue(Value value) const; + // Override this function to mark specific Operations + virtual NodeInfo onOperation(Operation *op) const; + + std::string dump(triton::FuncOp func) const; + void dumpToFile(triton::FuncOp func, const std::string &filename) const; + +protected: + std::string getShapeStr(const Type &type) const; + + std::string getUniqueId(Value value) const; + std::string getUniqueId(Operation *op) const; + + std::string emitNode(const std::string &id, const NodeInfo style) const; + std::string emitEdge(const std::string &srcId, + const std::string &destId) const; + + std::string emitValueNode(Value value) const; + std::string emitOperationNode(Operation *op) const; +}; + +/* A subclass of GraphDumper that marks different layout kinds in different + * colors.*/ +class GraphLayoutMarker : public GraphDumper { +public: + NodeInfo onValue(Value value) const override; + +protected: + std::string getColor(const Type &type) const; +}; + +// Infers the encoding of the result of op given the source encoding. +std::optional inferDstEncoding(Operation *op, Attribute encoding); + +// Infers the encoding of the source of op given the result encoding. +std::optional inferSrcEncoding(Operation *op, Attribute encoding); + +bool isExpensiveLoadOrStore(Operation *op); + +bool canFoldIntoConversion(Operation *op, Attribute targetEncoding); + +// Replace ForOp with a new ForOp with extra operands. The YieldOp is not +// updated and needs to be updated separately for the loop to be correct. +scf::ForOp replaceForOpWithNewSignature( + RewriterBase &rewriter, scf::ForOp loop, ValueRange newIterOperands, + SmallVectorImpl> &replacements); +scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, + ValueRange newIterOperands); + +// Replace IfOp with a new IfOp with extra results operands. The YieldOp is not +// updated and needs to be updated separately for the bodies to be correct. +scf::IfOp replaceIfOpWithNewSignature( + RewriterBase &rewriter, scf::IfOp loop, TypeRange newResultTypes, + SmallVectorImpl> &replacements); + +Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, + IRMapping &mapping); + +// Get backward slice of tensor values starting from the root node along with +// encoding propagation. +LogicalResult getConvertBackwardSlice( + Value root, SetVector &slice, Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation = nullptr); + +// Populate pattern to remove dead cycles in ForOp. +void populateForOpDeadArgumentElimination(RewritePatternSet &patterns); + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape, + ArrayRef order); + +SmallVector delinearize(OpBuilder &b, Location loc, unsigned linear, + ArrayRef shape); + +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape); +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order); + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape); + +// Return true if the op is a pure elementwise_inline_asm op with a single +// operand and single result. +bool isPureUnaryInlineAsm(Operation *op); + +// read the compute capability from the module attributes +int getNVIDIAComputeCapability(Operation *module); + +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/third_party/mthreads/include/triton/Target/CMakeLists.txt b/third_party/mthreads/include/triton/Target/CMakeLists.txt new file mode 100644 index 000000000..39d31dc9b --- /dev/null +++ b/third_party/mthreads/include/triton/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/third_party/mthreads/include/triton/Target/LLVMIR/CMakeLists.txt b/third_party/mthreads/include/triton/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 000000000..1f6c1b351 --- /dev/null +++ b/third_party/mthreads/include/triton/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name LLVMIR) +add_public_tablegen_target(LLVMIRIncGen) diff --git a/third_party/mthreads/include/triton/Target/LLVMIR/Passes.h b/third_party/mthreads/include/triton/Target/LLVMIR/Passes.h new file mode 100644 index 000000000..27ecb5c3d --- /dev/null +++ b/third_party/mthreads/include/triton/Target/LLVMIR/Passes.h @@ -0,0 +1,17 @@ +#ifndef TRITON_TARGET_LLVM_IR_PASSES_H +#define TRITON_TARGET_LLVM_IR_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +/// Create a pass to add DIScope +std::unique_ptr createLLVMDIScopePass(); + +/// Generate the code for registering conversion passes. +#define GEN_PASS_REGISTRATION +#include "triton/Target/LLVMIR/Passes.h.inc" + +} // namespace mlir + +#endif // TRITON_TARGET_LLVM_IR_PASSES_H diff --git a/third_party/mthreads/include/triton/Target/LLVMIR/Passes.td b/third_party/mthreads/include/triton/Target/LLVMIR/Passes.td new file mode 100644 index 000000000..999b0b889 --- /dev/null +++ b/third_party/mthreads/include/triton/Target/LLVMIR/Passes.td @@ -0,0 +1,15 @@ +#ifndef TRITON_TARGET_LLVMIR_PASSES +#define TRITON_TARGET_LLVMIR_PASSES + +include "mlir/Pass/PassBase.td" + +def LLVMDIScope: Pass<"enable-line-info", "mlir::ModuleOp"> { + let summary = "Materialize LLVM line info"; + let description = [{ + This pass materializes line mapping information for LLVM IR dialect operations. + }]; + + let constructor = "mlir::createLLVMDIScopePass()"; +} + +#endif diff --git a/third_party/mthreads/include/triton/Tools/LinearLayout.h b/third_party/mthreads/include/triton/Tools/LinearLayout.h new file mode 100644 index 000000000..fb2680241 --- /dev/null +++ b/third_party/mthreads/include/triton/Tools/LinearLayout.h @@ -0,0 +1,532 @@ +#ifndef TRITON_TOOLS_LINEARLAYOUT_H +#define TRITON_TOOLS_LINEARLAYOUT_H + +#include +#include +#include +#include +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir::triton { + +// # High-level overview of linear layouts +// +// The idea for linear layouts is due to Adam P. Goucher. +// +// In Triton, a linear layout (LL) is a function that maps from a "hardware +// location" to a "logical tensor index". +// +// For example, suppose we have a 2D tensor T stored in GPU registers. T's +// layout is the function that, given a "hardware location" tuple of (thread-id, +// warp-id), returns an index (x,y) into T. In other words, if L(t,w) = (x,y) +// is our linear layout func, then a register in thread t in warp w contains the +// value T[x,y]. +// +// The key fact about LLs is, the mapping from (t,w) to (x,y) is not arbitrary. +// We only need to specify the value of L(t,w) at certain special points +// (namely, the values L(t,0) and L(0,w) where t and w are powers of 2), and +// from those we can compute all the other values of L. +// +// Here's an example LL where we have 4 warps and 4 threads per warp, and the +// tensor T has shape 4x4. We define the function L by choosing the values of +// L(0,1), L(0,2), L(1,0), and L(2,0). Our choices are shown below. +// +// t/w 0 1 2 3 +// 0 ? (0,1) (0,2) ? +// L(t,w) = 1 (1,1) ? ? ? +// 2 (2,2) ? ? ? +// 3 ? ? ? ? +// +// You only need to specify these four values to define the whole linear layout. +// These special values are called the "basis vectors" or "bases" of the layout. +// We complete the table by xor'ing together the bases, according to the +// following rule. (I write "⊕" for xor.) +// +// L(t1 ⊕ t2, w1 ⊕ w2) = L(t1, w1) ⊕ L(t2, w2) (linearity rule). +// +// The linearity rule plus our four choices allows us to fill in the whole +// table. Here's how we might compute some of the values. +// +// L(0,0) = L(1 ⊕ 1, 0 ⊕ 0) = L(1,0) ⊕ L(1,0) = (1,1) ⊕ (1,1) = (0,0) +// L(0,3) = L(0 ⊕ 0, 2 ⊕ 1) = L(0,2) ⊕ L(0,1) = (0,2) ⊕ (0,1) = (0,3) +// L(3,0) = L(2 ⊕ 1, 0 ⊕ 0) = L(2,0) ⊕ L(1,0) = (2,2) ⊕ (1,1) = (3,3) +// L(3,3) = L(3 ⊕ 0, 0 ⊕ 3) = L(3,0) ⊕ L(0,3) = (3,3) ⊕ (0,3) = (3,0). +// +// (Notice it's a consequence of the linearity rule that L(0,0) = (0,0), no +// matter what values we chose for the table.) +// +// The whole table looks like this. +// +// t/w 0 1 2 3 +// 0 (0,0) (0,1) (0,2) (0,3) +// L(t,w) = 1 (1,1) (1,0) (1,3) (1,2) +// 2 (2,2) (2,3) (2,0) (2,1) +// 3 (3,3) (3,2) (3,1) (3,0). +// +// Careful readers will recognize this as a classic "swizzled" layout where +// (t, w) -> (t, w ⊕ t). To go from this formula to an LL, you only need to +// compute the results at input points (0,1), (0,2), (1,0), and (2,0). + +// Indeed the whole point of LLs is that they allow us to specify transposed and +// swizzled layouts as a "general case". Instead of a layout class for +// registers in a thread, and another layout for registers in a thread but in +// MMAv2 order, and so on, all of these can be represented by different LLs. +// This gets rid of special cases and lets us write more general code. +// +// In this example, L was a 2D -> 2D function, but LLs are general MD -> ND +// functions. In practice, a GPU register layout usually has input dims (reg, +// thread-id, warp-id, block-id), where reg represents the fact that one thread +// may store values for the tensor in multiple registers. +// +// To summarize, a linear layout is a function from tuples of integers to tuples +// of integers. We specify some key values of the function, and then we can +// compute all the other values using the linearity rule. +// +// Here are the key things you can do with linear layout objects. +// +// 1. Given an LL, construct a new LL by modifying it or combining it with +// another LL. +// +// 2. "Apply" an LL, i.e. use it to map an input index to an output index. +// A function for this that uses LLVM-dialect MLIR as its input and output +// lives in TritonGPUToLLVM.h. +// +// 3. Convert an existing Triton layout (e.g. BlockedLayoutAttr) to an LL. +// These functions live in TritonGPU/LinearLayoutConversions.h. During +// TTGIR -> LLVM codegen, we convert Triton layouts to linear layouts and +// then apply them. In the future, we intend to remove the Triton layouts +// entirely. +// +// # Examples of linear layouts +// +// 1. The 1D identity layout. This maps L(x) = x. +// +// Recall that our bases are the values of L(x) where x is a power of two. +// So for e.g. an 8-element layout, we have L(1) = 1, L(2) = 2, L(4) = 4, and +// therefore our bases are [1, 2, 4]. +// +// 2. The 1D zeros layout. This maps L(x) = 0. +// +// For an 8-element layout, we have L(1) = L(2) = L(4) = 0, so our bases are +// [0, 0, 0]. +// +// 3. A 2D -> 2D identity layout. Our basis vectors are the values of L(x,0) +// and L(0,y) where x and y are powers of two. The bases are +// +// - L(0,1) = (0,1) +// - L(0,2) = (0,2) +// - L(1,0) = (1,0) +// - L(2,0) = (2,0). +// +// 4. A 2D -> 2D transpose layout. For a 4x4 layout, we have: +// +// - L(0,1) = (1,0) +// - L(0,2) = (2,0) +// - L(1,0) = (0,1) +// - L(2,0) = (0,2). +// +// 5. A 1D -> 1D "transpose" layout. Consider the 16-element layout that maps +// +// x = 0 1 2 3 4 5 6 7 8 9 A B C D E F +// L(x) = 0 4 8 C 1 5 9 D 2 6 A E 3 7 B F. +// +// The bases are [L(1), L(2), L(4), L(8)] = [4, 8, 1, 2]. You can also think +// of this as a rearrangement of the 1D identity layout [1, 2, 4, 8]. +// +// 6. A 2D -> 1D broadcasted layout. L(x,y) = x. For a 4x4 -> 4 layout, our +// bases are +// +// - L(0,1) = 0 +// - L(0,2) = 0 +// - L(1,0) = 1 +// - L(2,0) = 2. +// +// # Implementation notes +// +// ## Dimension order +// +// An LL's input and output dimensions have an order. This order only affects +// the reshapeIns/Outs operations, where the layout is logically flattened +// according to the dimension order and then chopped up again. +// +// ## Surjectivity +// +// We require that all output values are covered by some input value, i.e. the +// function L is surjective. But multiple input values can map to the same +// output value. This represents the idea that the same logical tensor element +// can be stored in multiple places in the hardware. +// +// ## Why map hardware loc -> tensor index and not the other way around? +// +// In Triton, a linear layout usually tells us which logical tensor value is +// stored at a particular place in the hardware. For example, an LL might map +// the tuple (thread-id, warp-id, block-id) to a 2D index into a tensor, (x,y), +// meaning that the register at (t,w,b) has value tensor[x,y]. Or it might map +// from a shared memory (offset, block) to a tensor index. +// +// It might seem more natural to go the other way around, from tensor index to +// place in the hardware. But a particular tensor[x,y] value might be stored in +// more than one place in the hardware, so if we went in this direction, the +// layout would no longer be a proper function. This would complicate +// everything else. +// +// # Optional mathematical background: Linear functions over GF(2) +// +// (You shouldn't need to understand this math to use linear layouts, but it +// helps with the implementation.) +// +// One way to define a linear function is to say it's any function F that can be +// written as +// +// L(a) = a1 * B1 + a2 * B2 + ... + aM * BM, +// +// where +// +// - a is a vector [a1...aM], and ai is a scalar in some field 𝔽 (for +// example, ai might be a real number), and +// - each Bj is a vector [b1j, b1j, ..., bNj] of N scalars in 𝔽. +// +// We can also write this as a matrix-vector product Ba, where +// +// - a is the column vector [a1, ..., aM] and +// +// - B is the matrix formed by concatenating the column vectors B1, ..., BM: +// +// | ↑ ↑ ↑ | +// B = | B1, B2, ..., BM| +// | ↓ ↓ ↓ | +// +// |b11, b12, ..., b1M| +// |b21, b22, ..., b2M| +// = | ↓ ↓ ↓ | +// |bN1, bN2, ..., bNM|. +// +// Usually when we do linear algebra, the field 𝔽 from which `ai` and `bij` are +// drawn is the real or complex numbers. But in linear layouts, we let 𝔽 be a +// different field: GF(2). +// +// GF(2) is the two-element field of bits. To define a field, I need to give +// you the set of elements and also addition and multiplication operations. For +// GF(2) the elements are simply {0,1}. We define addition as xor, and +// multiplication as binary `and`. +// +// Here's an example of a 4x4 matrix-vector multiply where the elements are in +// GF(2). I'm using ⊕ to represent GF(2)'s addition operation (i.e xor) and × +// to represent multiplication (i.e. binary `and`). +// +// | 1 0 0 0 | | 0 | | 1 | | 0 | | 0 | | 0 | +// | 0 1 1 0 | | 1 | = | 0 | × 0 ⊕ | 1 | × 1 ⊕ | 1 | × 1 ⊕ | 0 | × 0 +// | 0 0 1 1 | | 1 | | 0 | | 0 | | 1 | | 1 | +// | 0 0 1 1 | | 0 | | 0 | | 0 | | 1 | | 1 | +// +// | 0 | | 0 | +// = | 1 | ⊕ | 1 | +// | 0 | | 1 | +// | 0 | | 1 | +// +// | 0 | +// = | 0 |. +// | 1 | +// | 1 | +// +// This works, but it's cumbersome. It's more compact to think of the vector +// `a` as an M-bit integer, and each column Bi of the matrix B as an N-bit +// integer. Here's the same matrix-vector product written this way. +// +// = | 1 2 14 12 | × 6 +// = | 1 2 14 12 | × 0b0110 +// = (1 × 0) ⊕ (2 × 1) ⊕ (14 × 1) ⊕ (12 × 0) +// = 2 ⊕ 14 +// = 12. +// +// And we confirm that our answer of 12 is equal to the binary value 0b1100 we +// got before. +// +// Notice that the function F(a) is fully specified by the matrix B, and that +// the four columns of B tell us the values of F at power-of-two values for `a`, +// namely F(1), F(2), F(4), and F(8). In other words, we specify four results +// of F(x) (we call these the function's "basis vectors" or its "bases") and we +// can then compute any other value by xor'ing together subsets of the bases. +// +// In the case of a 1D -> 1D layout, the implementation of an LL is +// straightforward from the mathematical description. If the LL is +// higher-dimensional, we can "stack" the bit vectors to create 1D vectors. +// For example, if we have a 2D LL and we're given input tuple (0b0011, 0b1100), +// we can treat this like a 1D input 0b0011'1100 and then do the regular 1D LL +// computation. Similarly we can "unstack" the output from 1D to ND. +// +// The linearity rule presented earlier is perhaps misleading at this point. In +// the 1D view of things, we really only need +// +// L(x ⊕ y) = L(x) ⊕ L(y) (1D linearity rule), +// +// which is part of the definition of L being a linear function. The new 1D +// linearity rule plus stacking/unstacking is equivalent to the earlier +// N-dimensional linearity rule. +// +// That's all we need in order to define linear layouts mathematically! +// +// # Comaprison to Nvidia CuTe +// +// (Note, I'm not an expert on CuTe; this is my best understanding.) +// +// CuTe is a programmatic layout system that's part of Nvidia CUTLASS; see +// https://github.com/NVIDIA/cutlass/blob/629f465/media/docs/cute/00_quickstart.md +// +// LLs and CuTe solve similar problems. Before CuTe, CUTLASS v2 had many +// handcrafted layouts, "RowMajor", "VoltaTensorOpMultiplicandCongruous", etc, +// see https://www.youtube.com/watch?v=QLdUML5MCfE&t=574s. Each of these was a +// special case. CUTLASS v3 introduced CuTe layouts, which are programmable and +// subsume all of these special cases. The CUTLASS folks say this simplified +// CUTLASS, in the same way that we hope LLs will simplify Triton. +// +// Like CuTe layouts, LLs are also programmable and composible. But there are +// also some differences. +// +// - Dimensions in LLs are named; CuTe dimensions are numbered. +// - CuTe layouts can be nested; LLs cannot be. (Nesting doesn't give CuTe +// layouts additional power; any nested layout can be flattened.) +// - CuTe layouts support non-power-of-two shapes; LLs do not. In particular +// this means that LLs cannot represent padded layouts. +// - In CuTe, swizzling is a separate step applied after specifying a layout. +// In LLs, swizzling is part of the layout itself. +// - The structure of LLs allows us to programmatically search for layouts that +// satisfy certain requirements, for example a shared layout that doesn't +// have bank conflicts when read into a particular register layout. CuTe +// expects a human to choose the layout using their brain. +// - CuTe emits code that is in the critical path of your CPU and GPU programs, +// therefore it needs to be fast. It uses C++ template magic to specialize +// on known-sized dimensions, and so on. LLs themselves do not need to be +// fast; only the emitted `apply` code is on the critical path. +// - CuTe requires a CUDA compiler such as nvcc; LLs do not. +// +class LinearLayout { +private: + // bases[inDim][i] = L(0, ..., inDim=2^i, ..., 0). All other values of L are + // computed by xor'ing bases together, using the linearity rule. In addition: + // + // - Each inDim has the same set of outDims, in the same order. + // - The order of dims is minor-to-major, although this only affects reshape. + llvm::MapVector /*size=getNumOutDims()*/> + /*size=getInDimSizeLog2(inDim)*/> + bases; + + llvm::SetVector outDimNames; + +public: + using BasesT = decltype(bases); + + // The 0-dimensional layout that maps everything to 0. This is useful as a + // starting point when doing something like + // + // LinearLayout ret = LinearLayout::empty(); + // for (...) ret *= ...; + // return ret; + static LinearLayout empty() { return LinearLayout(BasesT{}, {}); } + + // Creates a 1D -> 1D layout that's the identity function, i.e. L(x) = x + // for x in [0, size). + static LinearLayout identity1D(int32_t size, StringAttr inDim, + StringAttr outDim); + + // Creates a 1D -> 1D layout that maps every input value to 0, i.e. L(x) = 0 + // for x in [0, size). + static LinearLayout zeros1D(int32_t size, StringAttr inDim, + StringAttr outDim); + + // Creates a LinearLayout from a list of bases. These are interpreted + // according to the rules written for the member variable `bases`. + explicit LinearLayout(BasesT bases, ArrayRef outDimNames); + + // Construct a LinearLayout from an explicit list of bases. (This constructor + // is needed because llvm::MapVector does not have a constructor that accepts + // an initializer_list.) + // + // For example, given these bases + // + // L(in1=1, in2=0) = (out1=0, out2=1) + // L(in1=2, in2=0) = (out1=0, out2=2) + // L(in1=0, in2=1) = (out1=0, out2=4) + // L(in1=0, in2=2) = (out1=0, out2=8) + // L(in1=0, in2=4) = (out1=1, out2=1) + // + // we can use this constructor to build an equivalent LL: + // + // LinearLayout({ + // {"in1", {/*L(in1=1)=*/{0,1}, /*L(in1=2)=*/{0,2}}}, + // {"in2", {/*L(in2=1)=*/{0,4}, /*L(in2=2)=*/{0,8}, /*L(in2=4)=*/{1,1}}}, + // }, + // {"out1", "out2"}) + explicit LinearLayout( + ArrayRef>>> bases, + ArrayRef outDimNames); + + const BasesT &getBases() const { return bases; } + + // Get the pos'th basis vector for the inDim -> outDim mapping. + // getBasis(inDim, pos) = L(0, ..., inDim = 2^pos, ..., 0). + ArrayRef getBasis(StringAttr inDim, int32_t pos) const { + auto it = bases.find(inDim); + assert(it != bases.end()); + assert(pos < it->second.size()); + return it->second[pos]; + } + + int32_t getBasis(StringAttr inDim, int32_t pos, StringAttr outDim) const { + return getBasis(inDim, pos)[getOutDimIndex(outDim)]; + ; + } + + // These are in minor-to-major order, although if you don't flatten the dims + // (e.g. by reshaping) then the order doesn't really affect anything. + auto getInDimNames() const { return llvm::make_first_range(bases); } + ArrayRef getOutDimNames() const { + return outDimNames.getArrayRef(); + } + + // Gets the position that this outDim occupies in getOutDimNames(). Asserts + // if the dim is not present. + int32_t getOutDimIndex(StringAttr outDim) const; + + bool hasInDim(StringAttr inDim) const { return bases.contains(inDim); } + bool hasOutDim(StringAttr outDim) const { + return outDimNames.contains(outDim); + } + + int32_t getNumInDims() const { return bases.size(); } + int32_t getNumOutDims() const { return outDimNames.size(); } + + // Asserts if the dimension is not present. + int32_t getInDimSizeLog2(StringAttr inDim) const; + int32_t getInDimSize(StringAttr inDim) const { + return 1 << getInDimSizeLog2(inDim); + } + + // getOutDimSize(dim) == s means that there exists an input value that will + // produce each output value in [0,s). + // + // For example, if our bases are + // + // L(in0=1) = 1 + // L(in0=2) = 4 + // L(in1=1) = 2 + // L(in1=2) = 8 + // + // then the largest value we can produce is L(3,3) = 1 ⊕ 4 ⊕ 2 ⊕ 8 = 15 (and + // indeed we can produce all values in [0,16) by xor'ing subsets of the bases + // 1,2,4,8), so getOutDimSize(out_dim0) == 16. + // + // Asserts if the dimension is not present. + int32_t getOutDimSizeLog2(StringAttr outDim) const; + int32_t getOutDimSize(StringAttr outDim) const { + return 1 << getOutDimSizeLog2(outDim); + } + + // Reorders the in/out dimensions of the layout. This is mostly cosmetic + // (affecting e.g. the order of getIn/OutDimNames), but it also affects the + // behavior of reshape. + [[nodiscard]] LinearLayout + transposeIns(ArrayRef newInDimOrder) const; + [[nodiscard]] LinearLayout + transposeOuts(ArrayRef newOutDimOrder) const; + + // Creates a new layout which, roughly speaking, is equivalent to one where + // every element of the `outer` layout is replaced by a full instance of the + // `inner` layout. + // + // Examples: + // + // - empty() is the multiplicative identity: + // + // L * empty() == empty() * L == L. + // + // - Multiplying two identity1D layouts with disjoint in/out dimensions gives + // a 2D identity layout: + // + // identity1D(4, "i1", "o1") * identity1D(8, "i2", "o2") => + // L(i1,i2) = (i1,i2), + // + // with in-dims ("i1", "i2") and out-dims ("o1", "o2"), in that order. + // + // - If out-dims overlap, they are combined, as in the following examples. + // + // - identity1D(4, "i", "o") * identity1D(2, "i", "o") == + // identity1D(8, "i", "o") + // + // - identity1D(4, "i", "o") * zeros1D(2, "i", "o") => L(x) = x % 4 + // for x in [0,8). + // + // - zeros1D(2, "i", "o") * identity1D(4, "i", "o") => L(x) = x / 2 + // for x in [0,8). + // + // - identity1D(4, "i", "o1") * identity1D(8, "i", "o2") => + // L(x) = (x % 4, x / 4) for x in [0,32). + // + // Notice that this operation is not commutative. It's also not associative. + // TODO(jlebar): Can I modify the definition to make it associative? Pretty + // confusing if not. If I can't, add an example. + // + // Requires: Any in/out dimensions which are in both outer and inner appear in + // the same relative order. + friend LinearLayout operator*(LinearLayout inner, LinearLayout outer); + LinearLayout &operator*=(LinearLayout outer) { + *this = *this * outer; + return *this; + } + + // Computes and returns L(x, y, z). + // + // If you want to apply the layout to mlir Values instead of integers, that + // function lives in TritonGPUToLLVM/Utility.h. + SmallVector> + apply(ArrayRef> ins) const; + + // Creates a new layout which is equivalent to running this layout, then + // running `outer`. That is, + // + // - let this layout be L(x), and + // - let `outer` be O(x). + // - Then compose(outer) returns the layout (O∘L)(x), aka O(L(x)). + // + // Requires: The output dimensions of this layout equal the input dimensions + // of outer (order doesn't matter). + [[nodiscard]] LinearLayout compose(const LinearLayout &outer) const; + + // TODO(jlebar): Not yet implemented. + // [[nodiscard]] LinearLayout reshapeIns( + // std::vector> + // newInDims) const; + + // TODO(jlebar): Not yet implemented. + // [[nodiscard]] LinearLayout reshapeOuts( + // std::vector> + // newOutDims) const; + + std::string toString() const; + + friend bool operator==(LinearLayout lhs, LinearLayout rhs); + friend bool operator!=(LinearLayout lhs, LinearLayout rhs) { + return !(lhs == rhs); + } +}; + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const LinearLayout &layout) { + os << layout.toString(); + return os; +} + +inline std::ostream &operator<<(std::ostream &os, const LinearLayout &layout) { + os << layout.toString(); + return os; +} + +} // namespace mlir::triton + +#endif diff --git a/third_party/mthreads/include/triton/Tools/StrUtil.h b/third_party/mthreads/include/triton/Tools/StrUtil.h new file mode 100644 index 000000000..8b59f7d2b --- /dev/null +++ b/third_party/mthreads/include/triton/Tools/StrUtil.h @@ -0,0 +1,54 @@ +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir::triton { + +// Better version of llvm::join. This one works when T is an integer or any +// other type which defines operator<<(raw_ostream). +template +std::string join(C &&container, llvm::StringRef sep = ", ") { + std::string ret; + llvm::raw_string_ostream s(ret); + for (const auto &elem : container) { + if (!ret.empty()) + s << sep; + s << elem; + } + return ret; +} + +// Joins a container of elements into a string, using `sep` as a separator. +// +// fn is called to transform each element of the container before it's added to +// the string. fn must have one of the following two signatures. +// +// - void fn(llvm::raw_ostream&, E), where E is the element type of the +// container, or +// - T fn(E), where T is a type which can be passed to +// raw_ostream::operator<<. +// +template +std::string join(C &&container, llvm::StringRef sep, Fn &&fn) { + std::string ret; + llvm::raw_string_ostream s(ret); + for (const auto &elem : container) { + if (!ret.empty()) + s << sep; + + if constexpr (std::is_invocable_v) { + static_assert( + std::is_void_v< + std::invoke_result_t>); + fn(s, elem); + } else { + s << fn(elem); + } + } + return ret; +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/include/triton/Tools/Sys/GetEnv.hpp b/third_party/mthreads/include/triton/Tools/Sys/GetEnv.hpp new file mode 100644 index 000000000..d272cf97f --- /dev/null +++ b/third_party/mthreads/include/triton/Tools/Sys/GetEnv.hpp @@ -0,0 +1,87 @@ +#ifndef TRITON_TOOLS_SYS_GETENV_HPP +#define TRITON_TOOLS_SYS_GETENV_HPP + +#include +#include +#include +#include +#include +#include + +namespace mlir::triton { + +inline const std::set CACHE_INVALIDATING_ENV_VARS = { + // clang-format off + "AMDGCN_ENABLE_DUMP", + "DISABLE_FAST_REDUCTION", + "DISABLE_LLVM_OPT", + "DISABLE_MMA_V3", + "DISABLE_PTXAS_OPT", + "LLVM_IR_ENABLE_DUMP", + "LLVM_ENABLE_TIMING", + "MLIR_ENABLE_DIAGNOSTICS", + "MLIR_ENABLE_DUMP", + "MLIR_ENABLE_TIMING", + "TRITON_DISABLE_LINE_INFO", + "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE", + "TRITON_ENABLE_LLVM_DEBUG", + "TRITON_LLVM_DEBUG_ONLY", + "USE_TTGIR_LOC", + "NVPTX_ENABLE_DUMP", + // musa backend + "MUSA_LLVMIR_ENABLE_DUMP", + "MUASM_ENABLE_DUMP", + "MTCC_ENABLE_ASM_BIN_PATH", + "MTCC_BIN_PATH", + "MUSA_ENABLE_FP8_BURST2", + // clang-format on +}; + +inline const std::set CACHE_NEUTRAL_ENV_VARS = { + "TRITON_REPRODUCER_PATH", +}; + +namespace tools { + +inline void assertIsRecognized(const std::string &env) { + bool is_invalidating = CACHE_INVALIDATING_ENV_VARS.find(env.c_str()) != + CACHE_INVALIDATING_ENV_VARS.end(); + bool is_neutral = + CACHE_NEUTRAL_ENV_VARS.find(env.c_str()) != CACHE_NEUTRAL_ENV_VARS.end(); + std::string errmsg = env + "is not recognized. " + "Please add it to triton/tools/sys/getenv.hpp"; + assert((is_invalidating || is_neutral) && errmsg.c_str()); +} + +inline std::string getStrEnv(const std::string &env) { + assertIsRecognized(env); + const char *cstr = std::getenv(env.c_str()); + if (!cstr) + return ""; + std::string result(cstr); + return result; +} + +// return value of a cache-invalidating boolean environment variable +inline bool getBoolEnv(const std::string &env) { + assertIsRecognized(env); + const char *s = std::getenv(env.c_str()); + std::string str(s ? s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return str == "on" || str == "true" || str == "1"; +} + +inline std::optional isEnvValueBool(std::string str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + if (str == "on" || str == "true" || str == "1") + return true; + if (str == "off" || str == "false" || str == "0") + return false; + return std::nullopt; +} +} // namespace tools +} // namespace mlir::triton + +#endif diff --git a/third_party/mthreads/lib/Analysis/Alias.cpp b/third_party/mthreads/lib/Analysis/Alias.cpp new file mode 100644 index 000000000..dde554319 --- /dev/null +++ b/third_party/mthreads/lib/Analysis/Alias.cpp @@ -0,0 +1,64 @@ +#include "triton/Analysis/Alias.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir { + +AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) { + if (lhs == rhs) + return lhs; + AliasInfo ret; + for (auto value : lhs.allocs) { + ret.insert(value); + } + for (auto value : rhs.allocs) { + ret.insert(value); + } + return ret; +} + +void SharedMemoryAliasAnalysis::visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { + AliasInfo aliasInfo; + bool pessimistic = true; + // These ops may allocate a new shared memory buffer. + auto result = op->getResult(0); + + // Only LocalAllocOp creates a new buffer. + if (isa(op)) { + aliasInfo.insert(result); + pessimistic = false; + } else if (isa(op)) { + // extract_slice %src + // trans %src + aliasInfo = AliasInfo(operands[0]->getValue()); + pessimistic = false; + } else { + assert(!isa(result.getType()) && + "unknown operation creating memory descriptor"); + } + + if (pessimistic) { + return setAllToEntryStates(results); + } + // Join all lattice elements + for (auto *result : results) + propagateIfChanged(result, result->join(aliasInfo)); +} + +AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) { + // TODO: implement + return AliasResult::MayAlias; +} + +ModRefResult SharedMemoryAliasAnalysis::getModRef(Operation *op, + Value location) { + // TODO: implement + return ModRefResult::getModAndRef(); +} + +} // namespace mlir diff --git a/third_party/mthreads/lib/Analysis/Allocation.cpp b/third_party/mthreads/lib/Analysis/Allocation.cpp new file mode 100644 index 000000000..a129cb194 --- /dev/null +++ b/third_party/mthreads/lib/Analysis/Allocation.cpp @@ -0,0 +1,645 @@ +#include "triton/Analysis/Allocation.h" + +#include +#include +#include + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Alias.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" + +using ::mlir::triton::gpu::AMDMfmaEncodingAttr; +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getContigPerThread; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getShapePerCTATile; +using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::getUniqueContigPerThread; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::SharedEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; + +namespace mlir { + +//===----------------------------------------------------------------------===// +// Shared Memory Allocation Analysis +//===----------------------------------------------------------------------===// +namespace triton { + +// Bitwidth of pointers +constexpr int kPtrBitWidth = 64; + +static std::pair, SmallVector> +getCvtOrder(Attribute srcLayout, Attribute dstLayout) { + auto srcMmaLayout = mlir::dyn_cast(srcLayout); + auto srcDotLayout = mlir::dyn_cast(srcLayout); + auto dstMmaLayout = mlir::dyn_cast(dstLayout); + auto dstDotLayout = mlir::dyn_cast(dstLayout); + + assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) && + "mma -> mma layout conversion is only supported on Ampere"); + + // mma or dot layout does not have an order, so the order depends on the + // layout of the other operand. + auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout) + : getOrder(srcLayout); + auto outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout) + : getOrder(dstLayout); + + return {inOrd, outOrd}; +} + +SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) { + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + if (shouldUseDistSmem(srcLayout, dstLayout)) { + // TODO: padding to avoid bank conflicts + return convertType(getShapePerCTA(srcTy)); + } + + if (isMfmaToDotShortcut(srcTy, dstTy)) + return {}; + + // MmaToDotShortcut and MmaToMmaShortcut doesn't use shared mem + if (auto srcMmaLayout = mlir::dyn_cast(srcLayout)) { + if (mlir::isa(dstLayout)) { + if (isMmaToDotShortcut(srcTy, dstTy)) { + return {}; + } + } else if (auto dstMmaLayout = + mlir::dyn_cast(dstLayout)) { + if (isMmaToMmaShortcut(srcTy, dstTy)) { + return {}; + } + } + } + + assert(srcLayout && dstLayout && "Unexpected layout in getRepShape()"); + + auto srcShapePerCTA = getShapePerCTA(srcTy); + auto dstShapePerCTA = getShapePerCTA(dstTy); + auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); + auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape()); + + unsigned rank = dstTy.getRank(); + SmallVector repShape(rank); + for (unsigned d = 0; d < rank; ++d) { + repShape[d] = + std::max(std::min(srcShapePerCTA[d], srcShapePerCTATile[d]), + std::min(dstShapePerCTA[d], dstShapePerCTATile[d])); + } + return repShape; +} + +SmallVector +getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, + unsigned &outVec) { + auto repShape = getRepShapeForCvtLayout(op); + if (repShape.empty()) + return repShape; + auto rank = repShape.size(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + assert(!isMfmaToDotShortcut(srcTy, dstTy)); + + auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout); + unsigned srcContigPerThread = + getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]]; + unsigned dstContigPerThread = + getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]]; + // TODO: Fix the legacy issue that ourOrd[0] == 0 always means + // that we cannot do vectorization. + unsigned innerDim = rank - 1; + inVec = outOrd[0] != innerDim ? 1 + : inOrd[0] != innerDim ? 1 + : srcContigPerThread; + outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread; + + // For conversions to MmaV1 (Nvidia V100), this inVec is hardcoded in the + // codegen. + if (auto mma = mlir::dyn_cast(srcLayout)) { + if (mma.getVersionMajor() == 1) { + inVec = srcContigPerThread; + } else if (mlir::isa(dstLayout)) { + // when storing from mma layout and loading in blocked layout vectorizing + // the load back gives better performance even if there is a + // transposition. + outVec = dstContigPerThread; + } + } + + if (rank <= 1) + return repShape; + // pad the last dimension + unsigned paddedDim = rank - 1; + if (auto dstBlockedLayout = mlir::dyn_cast(dstLayout)) { + paddedDim = dstBlockedLayout.getOrder()[0]; + } + unsigned pad = std::max(inVec, outVec); + repShape[paddedDim] += pad; + return repShape; +} + +// TODO: extend beyond scalars +SmallVector getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) { + SmallVector smemShape; + if (isa(op.getPtr().getType())) { + // do nothing or just assert because shared memory is not used in tensor up + // to now + } else { + // need only bytes for scalar + // always vec = 1 and elemsPerThread = 1 for scalar? + smemShape.push_back(1); + } + return smemShape; +} + +SmallVector getScratchConfigForAtomicCAS(triton::AtomicCASOp op) { + return SmallVector{1}; +} + +class AllocationAnalysis { +public: + AllocationAnalysis(Operation *operation, + Allocation::FuncAllocMapT *funcAllocMap, + Allocation *allocation) + : operation(operation), funcAllocMap(funcAllocMap), + allocation(allocation) { + run(); + } + +private: + using BufferT = Allocation::BufferT; + + /// Value -> Liveness Range + /// Use MapVector to ensure determinism. + using BufferRangeMapT = llvm::MapVector>; + /// Nodes -> Nodes + using GraphT = DenseMap>; + + void run() { + getValuesAndSizes(); + resolveLiveness(); + computeOffsets(); + } + + /// Initializes explicitly defined shared memory values for a given operation. + void getExplicitValueSize(Operation *op) { + // Values returned from scf.yield will not be allocated even though they + // have the shared encoding. + // For example: %a = scf.if -> yield + // %a must be allocated elsewhere by other operations. + // FIXME(Keren): extract and insert are always alias for now + if (!maybeSharedAllocationOp(op)) + return; + + // XXX(Keren): Why this hard-coded alignment? + size_t kAlignment = 8; + for (Value result : op->getResults()) { + if (auto alloc = result.getDefiningOp()) { + // Bytes could be a different value once we support padding or other + // allocation policies. + auto allocType = alloc.getType(); + auto shapePerCTA = triton::gpu::getShapePerCTA(allocType); + auto bytes = product(shapePerCTA) * + allocType.getElementTypeBitWidth() / 8; + + // XXX(Keren): magic numbers 256 and 1024 + // benzh@maybe alignment should be passed in. + // Software swizzling calculates phase based on offset, while hardware + // swizzling do that based on physical address. Thus only by setting the + // alignment to 1024 can ensure the correctness.  + if (bytes > 256) + kAlignment = 1024; + allocation->addBuffer(result, bytes, + kAlignment); + } + } + } + + template + void maybeAddScratchBuffer(Operation *op, unsigned bytes, + unsigned alignment) { + if (bytes > 0) + allocation->addBuffer(op, bytes, alignment); + } + + template + void maybeAddScratchBuffer(Operation *op, unsigned bytes) { + if (bytes > 0) + allocation->addBuffer(op, bytes); + } + + /// Initializes temporary shared memory for a given operation. + void getScratchValueSize(Operation *op) { + const size_t scratchAlignment = 128; + if (auto reduceOp = dyn_cast(op)) { + ReduceOpHelper helper(reduceOp); + unsigned bytes = helper.getScratchSizeInBytes(); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } else if (auto scanOp = dyn_cast(op)) { + ScanLoweringHelper helper(scanOp); + unsigned bytes = helper.getScratchSizeInBytes(); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } else if (auto histogram = dyn_cast(op)) { + auto dstTy = histogram.getType(); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp( + op->getParentOfType()); + auto bytes = std::max(dstTy.getNumElements(), threadsPerWarp) * + std::max(8, dstTy.getElementTypeBitWidth()) / 8; + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } else if (auto cvtLayout = dyn_cast(op)) { + auto srcTy = cvtLayout.getSrc().getType(); + auto dstTy = cvtLayout.getType(); + auto srcEncoding = srcTy.getEncoding(); + auto dstEncoding = dstTy.getEncoding(); + if (mlir::isa(srcEncoding) || + mlir::isa(dstEncoding)) { + // Conversions from/to shared memory do not need scratch memory. + return; + } + // ConvertLayoutOp with both input/output non-shared_layout + // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's + // also possible to realize it with other approaches in restricted + // conditions, such as warp-shuffle + unsigned inVec = 0; + unsigned outVec = 0; + auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto bytes = + isa(srcTy.getElementType()) + ? elems * kPtrBitWidth / 8 + : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } else if (auto atomicRMWOp = dyn_cast(op)) { + auto value = op->getOperand(0); + // only scalar requires scratch memory + // make it explicit for readability + if (dyn_cast(value.getType())) { + // nothing to do + } else { + auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto elemTy = + cast(value.getType()).getPointeeType(); + auto bytes = + isa(elemTy) + ? elems * kPtrBitWidth / 8 + : elems * std::max(8, elemTy.getIntOrFloatBitWidth()) / 8; + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } + } else if (auto atomicCASOp = dyn_cast(op)) { + // only scalar requires scratch memory + // make it explicit for readability + auto value = op->getOperand(0); + if (dyn_cast(value.getType())) { + // nothing to do + } else { + auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto elemTy = + cast(value.getType()).getPointeeType(); + auto bytes = isa(elemTy) + ? elems * kPtrBitWidth / 8 + : elems * elemTy.getIntOrFloatBitWidth() / 8; + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } + } else if (auto callOp = dyn_cast(op)) { + auto callable = callOp.resolveCallable(); + auto funcOp = dyn_cast(callable); + auto *funcAlloc = &(*funcAllocMap)[funcOp]; + auto bytes = funcAlloc->getSharedMemorySize(); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } + } + + void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) { + dataflow::Lattice *latticeElement = + analysis.getLatticeElement(value); + if (latticeElement) { + AliasInfo &info = latticeElement->getValue(); + if (!info.getAllocs().empty()) { + for (auto alloc : info.getAllocs()) { + allocation->addAlias(value, alloc); + } + } + } + } + + /// Extract all shared memory values and their sizes + void getValuesAndSizes() { + // Get the alloc values + operation->walk([&](Operation *op) { + getExplicitValueSize(op); + getScratchValueSize(op); + }); + // Get the alias values + std::unique_ptr solver = createDataFlowSolver(); + SharedMemoryAliasAnalysis *aliasAnalysis = + solver->load(); + if (failed(solver->initializeAndRun(operation))) { + // TODO: return error instead of bailing out.. + llvm_unreachable("failed to run SharedMemoryAliasAnalysis"); + } + operation->walk([&](Operation *op) { + for (auto operand : op->getOperands()) { + getValueAlias(operand, *aliasAnalysis); + } + for (auto value : op->getResults()) { + getValueAlias(value, *aliasAnalysis); + } + }); + } + + /// Computes the liveness range of the allocated value. + /// Each buffer is allocated only once. + void resolveExplicitBufferLiveness( + function_ref(Value value)> getLiveness) { + for (auto valueBufferIter : allocation->valueBuffer) { + auto value = valueBufferIter.first; + auto *buffer = valueBufferIter.second; + bufferRange[buffer] = getLiveness(value); + } + } + + /// Extends the liveness range by unionizing the liveness range of the aliased + /// values because each allocated buffer could be an alias of others, if block + /// arguments are involved. + void resolveAliasBufferLiveness( + function_ref(Value value)> getLiveness) { + for (auto aliasBufferIter : allocation->aliasBuffer) { + auto value = aliasBufferIter.first; + auto buffers = aliasBufferIter.second; + auto range = getLiveness(value); + for (auto *buffer : buffers) { + auto minId = range.start(); + auto maxId = range.end(); + if (bufferRange.count(buffer)) { + // Extend the allocated buffer's range + minId = std::min(minId, bufferRange[buffer].start()); + maxId = std::max(maxId, bufferRange[buffer].end()); + } + bufferRange[buffer] = Interval(minId, maxId); + } + } + } + + /// Computes the liveness range of scratched buffers. + /// Some operations may have a temporary buffer that is not explicitly + /// allocated, but is used to store intermediate results. + void resolveScratchBufferLiveness( + const DenseMap &operationId) { + // Analyze liveness of scratch buffers and virtual buffers. + auto processScratchMemory = [&](const auto &container) { + for (auto opScratchIter : container) { + // Any scratch memory's live range is the current operation's live + // range. + auto *op = opScratchIter.first; + auto *buffer = opScratchIter.second; + bufferRange.insert({buffer, Interval(operationId.lookup(op), + operationId.lookup(op) + 1)}); + } + }; + processScratchMemory(allocation->opScratch); + processScratchMemory(allocation->opVirtual); + } + + /// Resolves liveness of all values involved under the root operation. + void resolveLiveness() { + // Assign an ID to each operation using post-order traversal. + // To achieve the correct liveness range, the parent operation's ID + // should be greater than each of its child operation's ID . + // Example: + // ... + // %5 = triton.convert_layout %4 + // %6 = scf.for ... iter_args(%arg0 = %0) -> (i32) { + // %2 = triton.convert_layout %5 + // ... + // scf.yield %arg0 + // } + // For example, %5 is defined in the parent region and used in + // the child region, and is not passed as a block argument. + // %6 should should have an ID greater than its child operations, + // otherwise %5 liveness range ends before the child operation's liveness + // range ends. + DenseMap operationId; + operation->walk( + [&](Operation *op) { operationId[op] = operationId.size(); }); + + // Analyze liveness of explicit buffers + Liveness liveness(operation); + auto getValueLivenessRange = [&](Value value) { + auto liveOperations = liveness.resolveLiveness(value); + auto minId = std::numeric_limits::max(); + auto maxId = std::numeric_limits::min(); + std::for_each(liveOperations.begin(), liveOperations.end(), + [&](Operation *liveOp) { + if (operationId[liveOp] < minId) { + minId = operationId[liveOp]; + } + if ((operationId[liveOp] + 1) > maxId) { + maxId = operationId[liveOp] + 1; + } + }); + return Interval(minId, maxId); + }; + + resolveExplicitBufferLiveness(getValueLivenessRange); + resolveAliasBufferLiveness(getValueLivenessRange); + resolveScratchBufferLiveness(operationId); + } + + /// Computes the shared memory offsets for all related values. + /// Paper: Algorithms for Compile-Time Memory Optimization + /// (https://dl.acm.org/doi/pdf/10.5555/314500.315082) + void computeOffsets() { + SmallVector buffers; + for (auto bufferIter : bufferRange) { + buffers.emplace_back(bufferIter.first); + } + + calculateStarts(buffers); + + // NOTE: The original paper doesn't consider interference between + // the bumped ranges. Buffers that previously do not interfere with + // could interfere after offset bumping if their liveness ranges overlap. + // Therefore, we rerun the interference graph algorithm after bumping so + // that we regroup the buffers and color them again. Since we always + // increase the buffer offset and keep reducing conflicts, we will + // eventually reach a fixed point. + GraphT interference; + buildInterferenceGraph(buffers, interference); + do { + allocate(buffers, interference); + buildInterferenceGraph(buffers, interference); + } while (!interference.empty()); + } + + /// Computes the initial shared memory offsets. + void calculateStarts(const SmallVector &buffers) { + // v = values in shared memory + // t = triplet of (size, start, end) + // shared memory space + // - + // | *******t4 + // | /|\ v2 inserts t4, t5, and t6 + // | | + // | ******t5 ************t6 + // | ^^^^^v2^^^^^^ + // | | *********************t2 + // | \|/ v2 erases t1 + // | ******t1 ^^^^^^^^^v1^^^^^^^^^ ************t3 + // |---------------------------------------------| liveness range + // 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ... + // If the available triple's range is less than a given buffer range, + // we won't know if there has been an overlap without using graph coloring. + // Start -> Liveness Range + using TripleMapT = std::multimap>; + TripleMapT tripleMap; + tripleMap.insert(std::make_pair(0, Interval())); + SmallVector xBuffers = buffers; + while (!xBuffers.empty()) { + auto tripleIt = tripleMap.begin(); + auto offset = tripleIt->first; + auto range = tripleIt->second; + tripleMap.erase(tripleIt); + auto bufferIt = + std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) { + auto xRange = bufferRange[buffer]; + bool res = xRange.intersects(range); + for (auto val : tripleMap) + res = res && + !val.second.intersects(xRange); // only one buffer intersect + return res; + }); + if (bufferIt != xBuffers.end()) { + auto buffer = *bufferIt; + auto xSize = buffer->size; + auto xRange = bufferRange.lookup(buffer); + // TODO(Keren): A buffer's size shouldn't be determined here, have to + // clean it up + size_t alignOffset = buffer->setOffsetAligned(offset); + tripleMap.insert({alignOffset + xSize, + Interval{std::max(range.start(), xRange.start()), + std::min(range.end(), xRange.end())}}); + // We could either insert (range.start, xRange.start) or (range.start, + // xRange.end), both are correct and determine the potential buffer + // offset, and the graph coloring algorithm will solve the interference, + // if any + if (range.start() < xRange.start()) + tripleMap.insert({offset, Interval{range.start(), xRange.end()}}); + if (xRange.end() < range.end()) + tripleMap.insert({offset, Interval{xRange.start(), range.end()}}); + xBuffers.erase(bufferIt); + } + } + } + + /// Builds a graph of all shared memory values. Edges are created between + /// shared memory values that are overlapping. + void buildInterferenceGraph(const SmallVector &buffers, + GraphT &interference) { + // Reset interference graph + interference.clear(); + for (auto x : buffers) { + for (auto y : buffers) { + if (x == y) + continue; + auto xStart = x->offset; + auto yStart = y->offset; + auto xSize = x->size; + auto ySize = y->size; + Interval xSizeRange = {xStart, xStart + xSize}; + Interval ySizeRange = {yStart, yStart + ySize}; + auto xOpRange = bufferRange.lookup(x); + auto yOpRange = bufferRange.lookup(y); + if (xOpRange.intersects(yOpRange) && + xSizeRange.intersects(ySizeRange)) { + interference[x].insert(y); + } + } + } + } + + /// Finalizes shared memory offsets considering interference. + void allocate(const SmallVector &buffers, + const GraphT &interference) { + // Reset shared memory size + allocation->sharedMemorySize = 0; + // First-fit graph coloring + // Neighbors are nodes that interfere with each other. + // We color a node by finding the index of the first available + // non-neighboring node or the first neighboring node without any color. + // Nodes with the same color do not interfere with each other. + DenseMap colors; + for (auto value : buffers) { + colors[value] = (value == buffers[0]) ? 0 : -1; + } + SmallVector available(buffers.size()); + for (auto x : buffers) { + std::fill(available.begin(), available.end(), true); + for (auto y : interference.lookup(x)) { + int color = colors[y]; + if (color >= 0) { + available[color] = false; + } + } + auto it = std::find(available.begin(), available.end(), true); + colors[x] = std::distance(available.begin(), it); + } + // Finalize allocation + // color0: [0, 7), [0, 8), [0, 15) -> [0, 7), [0, 8), [0, 15) + // color1: [7, 9) -> [0 + 1 * 15, 9 + 1 * 15) -> [15, 24) + // color2: [8, 12) -> [8 + 2 * 15, 12 + 2 * 15) -> [38, 42) + // TODO(Keren): We are wasting memory here. + // Nodes with color2 can actually start with 24. + for (auto x : buffers) { + size_t newOffset = 0; + for (auto y : interference.lookup(x)) { + newOffset = std::max(newOffset, y->offset + y->size); + } + if (colors.lookup(x) != 0) + x->setOffsetAligned(newOffset); + allocation->sharedMemorySize = + std::max(allocation->sharedMemorySize, x->offset + x->size); + } + } + +private: + Operation *operation; + Allocation::FuncAllocMapT *funcAllocMap; + Allocation *allocation; + BufferRangeMapT bufferRange; +}; + +} // namespace triton + +void Allocation::run(FuncAllocMapT &funcAllocMap) { + triton::AllocationAnalysis(getOperation(), &funcAllocMap, this); +} + +} // namespace mlir diff --git a/third_party/mthreads/lib/Analysis/AxisInfo.cpp b/third_party/mthreads/lib/Analysis/AxisInfo.cpp new file mode 100644 index 000000000..49d559618 --- /dev/null +++ b/third_party/mthreads/lib/Analysis/AxisInfo.cpp @@ -0,0 +1,1316 @@ +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#define DEBUG_TYPE "axis-info" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir::triton { +namespace { + +int64_t gcdImpl(int64_t a, int64_t b, int64_t *x, int64_t *y) { + // Base Case + if (a == 0) { + *x = 0; + *y = 1; + return b; + } + int64_t x1, y1; // To store results of recursive call + int64_t gcd = gcdImpl(b % a, a, &x1, &y1); + // Update x and y using results of + // recursive call + *x = y1 - (b / a) * x1; + *y = x1; + return gcd; +} + +int64_t gcd(int64_t a, int64_t b) { + if (a == 0) + return b; + if (b == 0) + return a; + int64_t x, y; + return gcdImpl(a, b, &x, &y); +} + +constexpr int log2Int(int64_t num) { + return (num > 1) ? 1 + log2Int(num / 2) : 0; +} + +// If lhs * rhs overflows, return max value possible value for the type +int64_t multiplyDivisor(int64_t lhs, int64_t rhs) { + int64_t maxDivisor = highestPowOf2Divisor(0); + if (lhs > maxDivisor / rhs) + return maxDivisor; + return lhs * rhs; +} + +class AxisInfoVisitor { +public: + AxisInfoVisitor() = default; + virtual ~AxisInfoVisitor() = default; + + static bool isContiguousDim(const AxisInfo &info, ArrayRef shape, + int dim) { + return info.getContiguity(dim) == shape[dim]; + } + + static bool isConstantDim(const AxisInfo &info, ArrayRef shape, + int dim) { + return info.getConstancy(dim) == shape[dim]; + } + + virtual AxisInfo + getAxisInfo(Operation *op, + ArrayRef *> operands) = 0; + + virtual bool match(Operation *op) = 0; +}; + +// Base class for all operations +template class AxisInfoVisitorImpl : public AxisInfoVisitor { +public: + using AxisInfoVisitor::AxisInfoVisitor; + + AxisInfo + getAxisInfo(Operation *op, + ArrayRef *> operands) final { + return getAxisInfo(cast(op), operands); + } + + bool match(Operation *op) final { return isa(op); } + + virtual AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) = 0; +}; + +// Binary operations +template +class BinaryOpVisitorImpl : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + assert(operands.size() == 2 && "Expected two operands"); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + auto constantValue = getConstantValue(op, lhsInfo, rhsInfo); + for (auto d = 0; d < rank; ++d) { + if (constantValue.has_value()) { + contiguity.push_back(1); + constancy.push_back( + std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + highestPowOf2Divisor(constantValue.value())); + } else { + contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d)); + constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d)); + divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d)); + } + } + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } + +protected: + virtual int64_t getContiguity(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual int64_t getDivisibility(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) { + return {}; + } +}; + +class AxisInfoVisitorList { +public: + template > + void append() { + (visitors.emplace_back(std::make_unique()), ...); + } + + AxisInfo apply(Operation *op, + ArrayRef *> operands) { + for (auto &visitor : visitors) + if (visitor->match(op)) + return visitor->getAxisInfo(op, operands); + return AxisInfo(); + } + +private: + std::vector> visitors; +}; + +class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +private: + AxisInfoVisitorList visitors; + + void setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged( + lattice, + lattice->join(AxisInfo::getPessimisticValueState(lattice->getPoint()))); + } + + void visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &successor, + ArrayRef *> argLattices, + unsigned firstIndex) override { + if (auto forOp = dyn_cast(op)) { + visitForOpInductionVar(forOp, argLattices); + } else { + setAllToEntryStates(argLattices.take_front(firstIndex)); + setAllToEntryStates(argLattices.drop_front( + firstIndex + successor.getSuccessorInputs().size())); + } + } + +public: + AxisInfoAnalysis(DataFlowSolver &solver); + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + using FuncAxisInfoMapT = DenseMap; + + void visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; + void + visitForOpInductionVar(scf::ForOp op, + ArrayRef *> argLattices); +}; + +template +class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + return operands[0]->getValue(); + } +}; + +class MakeRangeOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::MakeRangeOp op, + ArrayRef *> operands) override { + auto start = op.getStart(); + auto end = op.getEnd(); + return AxisInfo(/*contiguity=*/{end - start}, + /*divisibility=*/{highestPowOf2Divisor(start)}, + /*constancy=*/{1}); + } +}; + +template +class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto intAttr = dyn_cast(op.getValue()); + auto boolAttr = dyn_cast(op.getValue()); + if (intAttr || boolAttr) { + int64_t value{}; + if (intAttr) + value = intAttr.getValue().getZExtValue(); + else + value = boolAttr.getValue() ? 1 : 0; + return AxisInfo(/*contiguity=*/{1}, + /*divisibility=*/{highestPowOf2Divisor(value)}, + /*constancy=*/{1}, + /*knownConstantValue=*/{value}); + } + // TODO: generalize to dense attr + auto splatAttr = dyn_cast(op.getValue()); + if (splatAttr && splatAttr.getElementType().isIntOrIndex()) { + int64_t value = splatAttr.template getSplatValue().getZExtValue(); + TensorType ty = cast(splatAttr.getType()); + return AxisInfo( + /*contiguity=*/AxisInfo::DimVectorT(ty.getRank(), 1), + /*divisibility=*/ + AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)), + /*constancy=*/ + AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()), + /*knownConstantValue=*/{value}); + } + return AxisInfo(); + } +}; + +template +class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return std::max(gcd(lhs.getConstancy(dim), rhs.getContiguity(dim)), + gcd(lhs.getContiguity(dim), rhs.getConstancy(dim))); + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs = k * d_lhs = k * k' * gcd(d_lhs, d_rhs) + // rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs) + // lhs + rhs = k * d_lhs + p * d_rhs = (k * d_lhs + p * d_rhs) * + // gcd(d_lhs, d_rhs) + auto rhsDivisibility = rhs.getDivisibility(dim); + if constexpr (std::is_same_v) { + // %ptr = addptr %lhs, %rhs + // is equivalent to + // %0 = mul %rhs, %elemSize + // %ptr = add %lhs, %0 + // The result will still be contiguous in terms of elements but not bytes + // For example: + // addptr [16] : !ptr, [0, 1, 2, 3] : i32 -> !ptr + // returns: + // [16, 20, 24, 28] : !ptr + // with element locations: + // [4, 5, 6, 7] + // It is "strided contiguous" with a divisilibity of 16 bytes + auto rank = lhs.getRank(); + auto elemSize = std::max( + 1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8); + rhsDivisibility = multiplyDivisor(rhs.getDivisibility(dim), elemSize); + } + return gcd(lhs.getDivisibility(dim), rhsDivisibility); + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) { + if constexpr (std::is_same_v || + std::is_same_v) { + return {lhs.getConstantValue().value() + + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() - + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + auto rank = lhs.getRank(); + auto elemSize = std::max( + 1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8); + auto rhsValue = rhs.getConstantValue().value() * elemSize; + return {lhs.getConstantValue().value() + rhsValue}; + } + } + return {}; + } +}; + +class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + // lhs * 1 = lhs + auto lhsContiguity = + rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1 + ? lhs.getContiguity(dim) + : 1; + // 1 * rhs = rhs + auto rhsContiguity = + lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1 + ? rhs.getContiguity(dim) + : 1; + return std::max(lhsContiguity, rhsContiguity); + } + + int64_t getConstancy(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && + !(rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1)) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + auto rhsDivisibility = rhs.getDivisibility(dim); + if (rhs.getContiguity(dim) > 1 && + !(lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1)) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + rhsDivisibility = 1; + } + return multiplyDivisor(lhsDivisibility, rhsDivisibility); + } + + std::optional getConstantValue(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() * rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs / 1 = lhs + return rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1 + ? lhs.getContiguity(dim) + : 1; + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + // Case 1: both lhs and rhs are constants. + auto constancy = gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + // Case 2: lhs contiguous, rhs constant. + // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n + // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p + // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p), + // ..., (d_lhs * k + n) / (d_rhs * p) + // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0, + // the minimal constancy is gcd(d_lhs, d_rhs). + // Since gcd(d_lhs, d_rhs) maybe > len(lhs), + // we need to use another gcd to get the actual constancy. + if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) && + AxisInfoVisitor::isConstantDim(rhs, shape, dim)) { + constancy = std::max(constancy, gcd(lhs.getContiguity(dim), + gcd(lhs.getDivisibility(dim), + rhs.getDivisibility(dim)))); + } + return constancy; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // Case 1: lhs is 0 + if (lhs.getConstantValue().has_value() && + lhs.getConstantValue().value() == 0) + return lhs.getDivisibility(dim); + // Case 2: rhs is 1 + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return lhs.getDivisibility(dim); + // otherwise: return 1 + return 1; + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() / rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return BinaryOpVisitorImpl::getContiguity(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + int64_t contiguity = 1; + // lhs contiguous, rhs constant + // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n + // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p + // lhs % rhs = d_lhs * k % (d_rhs * p), (d_lhs * k + 1) % (d_rhs * p), + // ..., (d_lhs * k + n) % (d_rhs * p) + // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0, + // The minimal contiguity is gcd(d_lhs, d_rhs). + // Since gcd(d_lhs, d_rhs) maybe > len(lhs), + // we need to use another gcd to get the actual contiguity. + if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) && + AxisInfoVisitor::isConstantDim(rhs, shape, dim)) { + contiguity = std::max(contiguity, gcd(lhs.getContiguity(dim), + gcd(lhs.getDivisibility(dim), + rhs.getDivisibility(dim)))); + } + return contiguity; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs: d_lhs * k = gcd(d_lhs, d_rhs) * k' * k = gcd(d_lhs, d_rhs) * k'' + // rhs: d_rhs * p = gcd(d_lhs, d_rhs) * p' * p = gcd(d_lhs, d_rhs) * p'' + // lhs = gcd(d_lhs, d_rhs) * k'' = gcd(d_lhs, d_rhs) * d + r + // r must be divisible by gcd(d_lhs, d_rhs) + return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim)); + }; + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + // lhs % 1 = 0 + return rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1 + ? shape[dim] + : gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() % rhs.getConstantValue().value()}; + else if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return {0}; + return {}; + } +}; + +class SplatOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::SplatOp op, + ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + TensorType retTy = cast(_retTy); + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + for (int d = 0; d < retTy.getRank(); ++d) { + contiguity.push_back(1); + divisibility.push_back(opInfo.getDivisibility(0)); + constancy.push_back(retTy.getShape()[d]); + } + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +class LoadOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::LoadOp op, + ArrayRef *> operands) override { + // If pointers and mask both have constancy properties, those properties + // will also extend to output. + AxisInfo ptrInfo = operands[0]->getValue(); + std::optional maskInfo; + if (operands.size() > 1) { + maskInfo = operands[1]->getValue(); + } + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + + for (int d = 0; d < ptrInfo.getRank(); ++d) { + contiguity.push_back(1); + divisibility.push_back(1); + constancy.push_back( + gcd(ptrInfo.getConstancy(d), + maskInfo.has_value() ? maskInfo->getConstancy(d) : 0)); + } + + return AxisInfo(contiguity, divisibility, constancy); + } +}; + +class ExpandDimsOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::ExpandDimsOp op, + ArrayRef *> operands) override { + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity = opInfo.getContiguity(); + AxisInfo::DimVectorT divisibility = opInfo.getDivisibility(); + AxisInfo::DimVectorT constancy = opInfo.getConstancy(); + int64_t newDivisibility = 1; + if (opInfo.getConstantValue().has_value()) { + // The tensor is constant, same as ConstantOpAxisInfoVisitor + newDivisibility = highestPowOf2Divisor(opInfo.getConstantValue().value()); + } else if (opInfo.getRank()) { + // Otherwise, calculate the GCD as the new divisibility + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + newDivisibility = + opInfo.getContiguity(0) > 1 ? 1 : opInfo.getDivisibility(0); + for (int d = 1; d < opInfo.getRank(); ++d) { + newDivisibility = + gcd(newDivisibility, + opInfo.getContiguity(d) > 1 ? 1 : opInfo.getDivisibility(d)); + } + } + contiguity.insert(contiguity.begin() + op.getAxis(), 1); + divisibility.insert(divisibility.begin() + op.getAxis(), newDivisibility); + constancy.insert(constancy.begin() + op.getAxis(), 1); + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +class BroadcastOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::BroadcastOp op, + ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + Type _opTy = *op->operand_type_begin(); + TensorType retTy = cast(_retTy); + TensorType opTy = cast(_opTy); + ArrayRef retShape = retTy.getShape(); + ArrayRef opShape = opTy.getShape(); + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + for (int d = 0; d < retTy.getRank(); ++d) { + contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d)); + divisibility.push_back(opInfo.getDivisibility(d)); + constancy.push_back(opShape[d] == 1 ? retShape[d] + : opInfo.getConstancy(d)); + } + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +template +class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return AxisInfo(); + auto shape = resTy.getShape(); + short rank = resTy.getRank(); + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + + AxisInfo::DimVectorT contiguity, divisibility, constancy; + std::optional constantValue; + for (short d = 0; d < rank; ++d) { + int64_t constHint = 1; + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + constHint = lhsInfo.getConstancy(d); + constantValue = + compare(getPredicate(op), lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value()) + ? 1 + : 0; + } else { + // Case 1: lhs and rhs are both partial constants + constHint = gcd(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)); + if ((gtPredicate(getPredicate(op)) || lePredicate(getPredicate(op))) && + AxisInfoVisitor::isConstantDim(lhsInfo, shape, d)) { + // Case 2: lhs all constant, rhs all contiguous + // NOTE: + // lhs: 4 4 4 4 + // rhs: 4 5 6 7 + // lhs eq rhs: 1, 0, 0, 0 + // lhs ne rhs: 0, 1, 1, 1 + // lhs lt rhs: 0, 1, 1, 1 + // lhs le rhs: 1, 1, 1, 1 + // lhs ge rhs: 1, 0, 0, 0 + // lhs gt rhs: 0, 0, 0, 0 + constHint = std::max(constHint, gcd(rhsInfo.getContiguity(d), + gcd(lhsInfo.getDivisibility(d), + rhsInfo.getDivisibility(d)))); + } else if ((ltPredicate(getPredicate(op)) || + gePredicate(getPredicate(op))) && + AxisInfoVisitor::isConstantDim(rhsInfo, shape, d)) { + // Case 3: lhs all contiguous, rhs all constant + // NOTE + // lhs: 4 5 6 7 + // rhs: 4 4 4 4 + // lhs eq rhs: 1, 0, 0, 0 + // lhs ne rhs: 0, 1, 1, 1 + // lhs le rhs: 1, 0, 0, 0 + // lhs lt rhs: 0, 0, 0, 0 + // lhs gt rhs: 0, 1, 1, 1 + // lhs ge rhs: 1, 1, 1, 1 + constHint = std::max(constHint, gcd(lhsInfo.getContiguity(d), + gcd(lhsInfo.getDivisibility(d), + rhsInfo.getDivisibility(d)))); + } + } + + constancy.push_back(constHint); + divisibility.push_back(1); + contiguity.push_back(1); + } + + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } + +private: + static arith::CmpIPredicate getPredicate(arith::CmpIOp op) { + return op.getPredicate(); + } + + static bool gtPredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sgt || + predicate == arith::CmpIPredicate::ugt; + } + + static bool gePredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sge || + predicate == arith::CmpIPredicate::uge; + } + + static bool ltPredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::slt || + predicate == arith::CmpIPredicate::ult; + } + + static bool lePredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sle || + predicate == arith::CmpIPredicate::ule; + } + + static bool compare(arith::CmpIPredicate predicate, int64_t lhs, + int64_t rhs) { + switch (predicate) { + case arith::CmpIPredicate::eq: + return lhs == rhs; + case arith::CmpIPredicate::ne: + return lhs != rhs; + case arith::CmpIPredicate::slt: + return lhs < rhs; + case arith::CmpIPredicate::sle: + return lhs <= rhs; + case arith::CmpIPredicate::sgt: + return lhs > rhs; + case arith::CmpIPredicate::sge: + return lhs >= rhs; + case arith::CmpIPredicate::ult: + return (uint64_t)lhs < (uint64_t)rhs; + case arith::CmpIPredicate::ule: + return (uint64_t)lhs <= (uint64_t)rhs; + case arith::CmpIPredicate::ugt: + return (uint64_t)lhs > (uint64_t)rhs; + case arith::CmpIPredicate::uge: + return (uint64_t)lhs >= (uint64_t)rhs; + default: + break; + } + llvm_unreachable("unknown comparison predicate"); + } +}; + +template +class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto condConstancy = operands[0]->getValue().getConstancy(); + auto lhsInfo = operands[1]->getValue(); + auto rhsInfo = operands[2]->getValue(); + auto rank = lhsInfo.getRank(); + + AxisInfo::DimVectorT contiguity, divisibility, constancy; + std::optional constantValue; + if (operands[0]->getValue().getConstantValue().has_value()) { + if (operands[0]->getValue().getConstantValue() == 0) { + contiguity = rhsInfo.getContiguity(); + divisibility = rhsInfo.getDivisibility(); + constancy = rhsInfo.getConstancy(); + constantValue = rhsInfo.getConstantValue(); + } else { + contiguity = lhsInfo.getContiguity(); + divisibility = lhsInfo.getDivisibility(); + constancy = lhsInfo.getConstancy(); + constantValue = lhsInfo.getConstantValue(); + } + } else { + // The condition can be either a tensor or i1. + // If i1 is used as the condition, the entire tensor of either + // lhs or rhs is selected. + bool i1Cond = isa(op.getOperand(0).getType()); + for (auto d = 0; d < rank; ++d) { + if (i1Cond) { + constancy.push_back( + std::min(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + contiguity.push_back( + std::min(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d))); + } else { + constancy.push_back( + std::min(gcd(lhsInfo.getConstancy(d), condConstancy[d]), + gcd(rhsInfo.getConstancy(d), condConstancy[d]))); + contiguity.push_back( + std::min(gcd(lhsInfo.getContiguity(d), condConstancy[d]), + gcd(rhsInfo.getContiguity(d), condConstancy[d]))); + if (contiguity.back() == lhsInfo.getContiguity(d) && + contiguity.back() == rhsInfo.getContiguity(d)) { + // Contiguity not changed + divisibility.push_back( + gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + } else { + // Contiguity changed, we cannot use only divisibility. + // For example, the following example should have contiguity 2 and + // divisibility 2 + // [[0, 1], [4, 5]] + // [[16, 17, 18, 19]] + divisibility.push_back( + std::min(gcd(lhsInfo.getDivisibility(d), contiguity.back()), + gcd(rhsInfo.getDivisibility(d), contiguity.back()))); + } + } + } + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value() && + lhsInfo.getConstantValue() == rhsInfo.getConstantValue()) + constantValue = lhsInfo.getConstantValue(); + } + + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } +}; + +template +class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) { + if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() & + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() | + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() ^ + rhs.getConstantValue().value()}; + } + } + return {}; + } +}; + +class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 0) + return lhs.getContiguity(dim); + else + return 1; + } + + int64_t getDivisibility(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + auto shift = rhs.getConstantValue().has_value() + ? rhs.getConstantValue().value() + : rhs.getDivisibility(dim); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + auto numBits = log2Int(lhsDivisibility); + return multiplyDivisor(lhsDivisibility, 1 << shift); + } + + int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() << rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 0) + return lhs.getContiguity(dim); + else + return 1; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto shift = rhs.getConstantValue().has_value() + ? rhs.getConstantValue().value() + : rhs.getDivisibility(dim); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + return std::max(1, lhsDivisibility / (1 << shift)); + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() >> rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + std::optional constantValue; + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + if constexpr (std::is_same_v || + std::is_same_v) { + constantValue = {std::max(lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value())}; + } else if constexpr (std::is_same_v || + std::is_same_v) { + constantValue = {std::min(lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value())}; + } + return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1), + /*knownDivisibility=*/AxisInfo::DimVectorT(rank, 1), + /*knownConstancy=*/AxisInfo::DimVectorT(rank, 1), + /*constantValue=*/constantValue); + } else { + AxisInfo::DimVectorT contiguity, divisibility, constancy; + for (auto d = 0; d < rank; ++d) { + constancy.push_back( + std::min(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + contiguity.push_back( + std::min(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d))); + } + return AxisInfo(contiguity, divisibility, constancy, std::nullopt); + } + } +}; + +//===----------------------------------------------------------------------===// +// AxisInfoAnalysis +//===----------------------------------------------------------------------===// + +AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) + : dataflow::SparseForwardDataFlowAnalysis>( + solver) { + // UnrealizedConversionCast: + // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is + // in the process of a PartialConversion, where UnrealizedConversionCast + // may exist + visitors.append, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor>(); + // TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp + // when scf.for supports integer induction variables + visitors.append(); + visitors.append, + ConstantOpAxisInfoVisitor>(); + visitors.append, + AddSubOpAxisInfoVisitor, + AddSubOpAxisInfoVisitor, + AddSubOpAxisInfoVisitor>(); + visitors.append(); + visitors.append, + DivOpAxisInfoVisitor>(); + visitors.append, + RemOpAxisInfoVisitor>(); + visitors.append(); + visitors.append(); + visitors.append(); + visitors.append>(); + visitors.append, + LogicalOpAxisInfoVisitor, + LogicalOpAxisInfoVisitor>(); + visitors.append>(); + visitors.append, + ShROpAxisInfoVisitor>(); + visitors.append, + MaxMinOpAxisInfoVisitor, + MaxMinOpAxisInfoVisitor, + MaxMinOpAxisInfoVisitor>(); + visitors.append(); +} + +void AxisInfoAnalysis::visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { + // TODO: For sure not the right way to do this + // but why is scf.if not initialized otherwise? + for (auto op : operands) + if (op->getValue().getRank() == 0) + setToEntryState((dataflow::Lattice *)op); + AxisInfo curr = visitors.apply(op, operands); + if (curr.getRank() == 0) + return setAllToEntryStates(results); + // override with hint + auto newContiguity = curr.getContiguity(); + auto newDivisibility = curr.getDivisibility(); + auto newConstancy = curr.getConstancy(); + if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) { + auto vals = cast(attr).getValues(); + newContiguity = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) { + auto vals = cast(attr).getValues(); + newDivisibility = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.constancy")) { + auto vals = cast(attr).getValues(); + newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + curr = AxisInfo(newContiguity, newDivisibility, newConstancy, + curr.getConstantValue()); + // join all lattice elements + for (auto *result : results) + propagateIfChanged(result, result->join(curr)); +} + +void AxisInfoAnalysis::visitForOpInductionVar( + scf::ForOp op, ArrayRef *> argLattices) { + auto lb = getLatticeElementFor(op, op.getLowerBound())->getValue(); + auto step = getLatticeElementFor(op, op.getStep())->getValue(); + + AxisInfo::DimVectorT knownContiguity(1, 1); + AxisInfo::DimVectorT knownDivisibility(1, 1); + AxisInfo::DimVectorT knownConstancy(1, 1); + knownDivisibility[0] = gcd(lb.getDivisibility(0), step.getDivisibility(0)); + auto inductionVar = + AxisInfo(knownContiguity, knownDivisibility, knownConstancy); + (void)argLattices[0]->join(inductionVar); +} + +} // anonymous namespace + +template +void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp, + DimVectorT *contiguity, + DimVectorT *divisibility, + DimVectorT *constancy) { + // liast of attributes that we care about + SmallVector> retVecs; + retVecs.push_back({contiguity, "tt.contiguity"}); + retVecs.push_back({divisibility, "tt.divisibility"}); + retVecs.push_back({constancy, "tt.constancy"}); + // initialize attributes one by one + for (auto [vec, attrName] : retVecs) { + Attribute attr = funcOp.getArgAttr(argNumber, attrName); + if (auto int_attr = dyn_cast_or_null(attr)) + *vec = DimVectorT(contiguity->size(), int_attr.getValue().getZExtValue()); + if (auto dense_attr = dyn_cast_or_null(attr)) { + auto vals = dense_attr.getValues(); + *vec = DimVectorT(vals.begin(), vals.end()); + } + } +} + +/*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) { + auto rank = 1; + if (TensorType ty = dyn_cast(value.getType())) + rank = ty.getRank(); + if (triton::PointerType ty = dyn_cast(value.getType())) + if (TensorType elemTy = dyn_cast(ty.getPointeeType())) + rank = elemTy.getRank(); + + DimVectorT knownContiguity(rank, 1); + DimVectorT knownDivisibility(rank, 1); + DimVectorT knownConstancy(rank, 1); + + BlockArgument blockArg = dyn_cast(value); + + if (blockArg && blockArg.getOwner()->isEntryBlock()) { + Operation *op = blockArg.getOwner()->getParentOp(); + if (auto fun = dyn_cast(op)) + initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, + &knownContiguity, &knownDivisibility, + &knownConstancy); + // llvm codegen check alignment to generate vector load/store + // would be nice if this wasn't the case + else if (auto fun = dyn_cast(op)) + initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, + &knownContiguity, &knownDivisibility, + &knownConstancy); + } else if (Operation *op = value.getDefiningOp()) { + if (isa(op)) { + // scf::ForOp, scf::IfOp, scf::WhileOp + // Control flow operations are initialized with "unknown" state: + // the maximum possible divisibility, contiguity, and constancy. + knownDivisibility = DimVectorT(rank, highestPowOf2Divisor(0)); + knownConstancy = DimVectorT(rank, highestPowOf2Divisor(0)); + knownContiguity = DimVectorT(rank, highestPowOf2Divisor(0)); + } + // Other operations are conservatively initialized with the lowest possible + // divisibility, contiguity, and constancy unless they have specified. + if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) { + auto vals = cast(attr).getValues(); + knownDivisibility = DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) { + auto vals = cast(attr).getValues(); + knownContiguity = DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.constancy")) { + auto vals = cast(attr).getValues(); + knownConstancy = DimVectorT(vals.begin(), vals.end()); + } + } + + return AxisInfo(knownContiguity, knownDivisibility, knownConstancy); +} + +/*static*/ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) { + // If one argument is not initialized, return the other. + if (lhs.getRank() == 0) + return rhs; + if (rhs.getRank() == 0) + return lhs; + DimVectorT contiguity; + DimVectorT divisibility; + DimVectorT constancy; + for (auto d = 0; d < lhs.getRank(); ++d) { + contiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d))); + divisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); + constancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d))); + } + std::optional constantValue; + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value() && + lhs.getConstantValue() == rhs.getConstantValue()) + constantValue = lhs.getConstantValue(); + return AxisInfo(contiguity, divisibility, constancy, constantValue); +} + +unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + auto layout = tensorTy.getEncoding(); + + // Here order should be ordered by contiguous first, so the first element + // should have the largest contiguous. + auto order = triton::gpu::getOrder(layout); + unsigned align = getPtrAlignment(ptr); + + auto uniqueContigPerThread = + triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape()); + assert(order[0] < uniqueContigPerThread.size() && + "Unexpected uniqueContigPerThread size"); + unsigned contiguity = uniqueContigPerThread[order[0]]; + LDBG("getPtrContiguity uniqueContigPerThread = " << contiguity); + contiguity = std::min(align, contiguity); + + return contiguity; +} + +unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + auto *axisInfo = getAxisInfo(ptr); + if (!axisInfo) + return 1; + auto layout = tensorTy.getEncoding(); + auto order = triton::gpu::getOrder(layout); + auto maxMultipleBytes = axisInfo->getDivisibility(order[0]); + auto maxContig = axisInfo->getContiguity(order[0]); + auto elemNumBits = triton::getPointeeBitWidth(tensorTy); + auto elemNumBytes = std::max(elemNumBits / 8, 1); + auto maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1); + unsigned alignment = std::min(maxMultiple, maxContig); + LDBG("getPtrAlignment order[0] " + << order[0] << " maxMultipleBytes = " << maxMultipleBytes + << " maxContig = " << maxContig << " elemNumBits = " << elemNumBits + << " maxMultiple = " << maxMultiple << " alignment " << alignment); + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo->print(os); + LDBG("-- " << axisStr); + }); + return alignment; +} + +unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { + auto tensorTy = dyn_cast(mask.getType()); + if (!tensorTy) + return 1; + auto *axisInfo = getAxisInfo(mask); + if (!axisInfo) + return 1; + auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding()); + auto alignment = std::max(axisInfo->getConstancy(maskOrder[0]), 1); + LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment " + << alignment); + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo->print(os); + LDBG("-- " << axisStr); + }); + return alignment; +} + +void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp) { + std::unique_ptr solver = createDataFlowSolver(); + AxisInfoAnalysis *analysis = solver->load(); + if (failed(solver->initializeAndRun(funcOp))) + return; + auto *axisInfoMap = getFuncData(funcOp); + auto updateAxisInfoMap = [&](Value value) { + auto axisInfo = analysis->getLatticeElement(value)->getValue(); + AxisInfo curAxisInfo; + if (axisInfoMap->count(value)) { + curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value)); + } else { + curAxisInfo = axisInfo; + } + (*axisInfoMap)[value] = curAxisInfo; + }; + funcOp.walk([&](Operation *op) { + for (auto value : op->getResults()) { + updateAxisInfoMap(value); + } + }); + funcOp.walk([&](Block *block) { + for (auto value : block->getArguments()) { + updateAxisInfoMap(value); + } + }); +} + +void ModuleAxisInfoAnalysis::update(CallOpInterface callOp, + FunctionOpInterface callee) { + auto caller = callOp->getParentOfType(); + auto *axisInfoMap = getFuncData(caller); + for (auto entry : llvm::enumerate(callOp->getOperands())) { + auto index = entry.index(); + auto value = entry.value(); + auto setAttrFn = [&](StringRef attrName, int64_t prevValue) { + auto curValue = highestPowOf2Divisor(0); + if (callee.getArgAttrOfType(index, attrName)) { + curValue = + callee.getArgAttrOfType(index, attrName).getInt(); + } + auto attr = IntegerAttr::get(IntegerType::get(callee.getContext(), 64), + gcd(prevValue, curValue)); + callee.setArgAttr(index, attrName, attr); + }; + auto axisInfo = axisInfoMap->lookup(value); + assert(axisInfo.getRank() == 1 && "only scalar arguments are supported"); + setAttrFn("tt.contiguity", axisInfo.getContiguity(0)); + setAttrFn("tt.divisibility", axisInfo.getDivisibility(0)); + setAttrFn("tt.constancy", axisInfo.getConstancy(0)); + } +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Analysis/CMakeLists.txt b/third_party/mthreads/lib/Analysis/CMakeLists.txt new file mode 100644 index 000000000..fd9dccead --- /dev/null +++ b/third_party/mthreads/lib/Analysis/CMakeLists.txt @@ -0,0 +1,17 @@ +add_triton_library(TritonAnalysis + AxisInfo.cpp + Allocation.cpp + Membar.cpp + Alias.cpp + Utility.cpp + + DEPENDS + TritonTableGen + TritonGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRAnalysis + MLIRLLVMDialect + TritonIR + TritonGPUIR +) diff --git a/third_party/mthreads/lib/Analysis/Membar.cpp b/third_party/mthreads/lib/Analysis/Membar.cpp new file mode 100644 index 000000000..407a5ae15 --- /dev/null +++ b/third_party/mthreads/lib/Analysis/Membar.cpp @@ -0,0 +1,178 @@ +#include "triton/Analysis/Membar.h" +#include "triton/Analysis/Alias.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include + +namespace mlir { + +void MembarAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) { + FunctionOpInterface funcOp = + dyn_cast(allocation->getOperation()); + OpBuilder builder(funcOp.getContext()); + resolve(funcOp, &funcBlockInfoMap, &builder); +} + +void MembarAnalysis::resolve(FunctionOpInterface funcOp, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) { + // Initialize the blockList + DenseMap inputBlockInfoMap; + DenseMap outputBlockInfoMap; + std::deque blockList; + funcOp.walk([&](Block *block) { + for (auto &op : block->getOperations()) { + // Check if the operation belongs to scf dialect, if so, we need to + // throw an error + if (op.getDialect()->getNamespace() == "scf") { + llvm::report_fatal_error( + "scf dialect is not supported in membar. Please lower it " + "to cf dialect first."); + return; + } + } + if (block->isEntryBlock()) + blockList.emplace_back(block); + }); + + // A fixed point algorithm + while (!blockList.empty()) { + auto *block = blockList.front(); + blockList.pop_front(); + // Make a copy of the inputblockInfo but not update + auto inputBlockInfo = inputBlockInfoMap[block]; + SmallVector successors; + for (auto &op : block->getOperations()) { + if (op.hasTrait()) { + visitTerminator(&op, successors); + } else { + update(&op, &inputBlockInfo, funcBlockInfoMap, builder); + } + } + // Get the reference because we want to update if it changed + if (outputBlockInfoMap.count(block) && + inputBlockInfo == outputBlockInfoMap[block]) { + // If we have seen the block before and the inputBlockInfo is the same as + // the outputBlockInfo, we skip the successors + continue; + } + // Update the current block + outputBlockInfoMap[block].join(inputBlockInfo); + // Update the successors + for (auto *successor : successors) { + inputBlockInfoMap[successor].join(outputBlockInfoMap[block]); + blockList.emplace_back(successor); + } + } + + // Update the final dangling buffers that haven't been synced + auto &funcBlockInfo = (*funcBlockInfoMap)[funcOp]; + funcOp.walk([&](Block *block) { + block->walk([&](triton::ReturnOp returnOp) { + funcBlockInfo.join(outputBlockInfoMap[block]); + }); + }); +} + +void MembarAnalysis::visitTerminator(Operation *op, + SmallVector &successors) { + if (auto branchInterface = dyn_cast(op)) { + Block *parentBlock = branchInterface->getBlock(); + successors.append(std::begin(parentBlock->getSuccessors()), + std::end(parentBlock->getSuccessors())); + return; + } + // Otherwise, it could be a return op + if (op->hasTrait()) + return; + llvm_unreachable("Unknown terminator encountered in membar analysis"); +} + +void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) { + OpBuilder::InsertionGuard g(*builder); + auto barrierOp = builder->create(op->getLoc()); +} + +void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) { + if (isa(op)) { + // If the current op is a barrier, we sync previous reads and writes + blockInfo->sync(); + return; + } + + if (isa(op) && + !isa(op->getNextNode())) { + // If the current op is an async wait and the next op is not a barrier we + // insert a barrier op and sync + builder->setInsertionPointAfter(op); + insertBarrier(op, builder); + blockInfo->sync(); + return; + } + + BlockInfo curBlockInfo; + if (isa(op)) { + // Inter-function dependencies + auto callOpInterface = dyn_cast(op); + if (auto callee = + dyn_cast(callOpInterface.resolveCallable())) + curBlockInfo = funcBlockInfoMap->lookup(callee); + } else { + // Intra-function dependencies + if (auto memoryEffectOpInterface = dyn_cast(op)) { + // Explicit buffer + SmallVector> + effectInstances; + memoryEffectOpInterface.getEffects(effectInstances); + for (auto effectInstance : effectInstances) { + if (auto value = effectInstance.getValue()) { + for (auto bufferId : allocation->getBufferIds(value)) { + if (bufferId != Allocation::InvalidBufferId) { + if (isa(effectInstance.getEffect())) + curBlockInfo.syncWriteIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + else if (isa(effectInstance.getEffect())) + curBlockInfo.syncReadIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + } + } + } + } + } + // XXX(Keren): This is a hack as we cannot set side effects for dot ops, but + // on hopper they do have side effects. Need to clean it up + if (auto dotOp = dyn_cast(op)) { + for (auto value : dotOp.getOperands()) { + for (auto bufferId : allocation->getBufferIds(value)) { + if (bufferId != Allocation::InvalidBufferId) + curBlockInfo.syncReadIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + } + } + } + // Scratch buffer is considered as both shared memory write & read + auto bufferId = allocation->getBufferId(op); + if (bufferId != Allocation::InvalidBufferId) { + curBlockInfo.syncWriteIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + curBlockInfo.syncReadIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + } + } + + if (blockInfo->isIntersected(curBlockInfo)) { + builder->setInsertionPoint(op); + insertBarrier(op, builder); + blockInfo->sync(); + } + // Update the region info, even if barrier is inserted, we have to maintain + // the current op's read/write buffers. + blockInfo->join(curBlockInfo); +} +} // namespace mlir diff --git a/third_party/mthreads/lib/Analysis/Utility.cpp b/third_party/mthreads/lib/Analysis/Utility.cpp new file mode 100644 index 000000000..739ed6a00 --- /dev/null +++ b/third_party/mthreads/lib/Analysis/Utility.cpp @@ -0,0 +1,936 @@ +#include "triton/Analysis/Utility.h" + +#include + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace mlir { +namespace { + +using namespace triton; +using namespace triton::gpu; + +int getParentAxis(Attribute layout, int axis) { + if (auto sliceEncoding = dyn_cast(layout)) { + axis = axis < sliceEncoding.getDim() ? axis : axis + 1; + return getParentAxis(sliceEncoding.getParent(), axis); + } + return axis; +} + +SmallVector getParentOrder(Attribute layout) { + if (auto sliceEncoding = mlir::dyn_cast(layout)) { + return getParentOrder(sliceEncoding.getParent()); + } + return getOrder(layout); +} + +} // namespace + +// TODO(jlebar): Move this class into namespace triton. +bool ReduceOpHelper::isReductionOnLayoutFastAxis() { + return getParentAxis(getSrcLayout(), axis) == + getParentOrder(getSrcLayout())[0]; +} + +SmallVector ReduceOpHelper::getOrderWithAxisAtBeginning() { + auto srcLayout = getSrcLayout(); + auto order = getOrder(srcLayout); + auto it = std::find(order.begin(), order.end(), axis); + // delete the axis from order + order.erase(it); + // insert axis at the beginning of order + order.insert(order.begin(), axis); + return order; +} + +// Thread offset is the thread index offset of two adjacent threads on the +// reduction axis within the warp. +unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { + auto srcLayout = getSrcLayout(); + + // If the reduction axis is the fast axis of the parent layout + if (isReductionOnLayoutFastAxis()) { + return 1; + } + + unsigned threadOffset = 1; + if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + auto parentLayout = sliceLayout.getParent(); + auto threadsPerWarp = getThreadsPerWarp(parentLayout); + threadOffset = threadsPerWarp[sliceLayout.getDim()]; + } else { + auto threadsPerWarp = getThreadsPerWarp(srcLayout); + auto order = getOrder(srcLayout); + for (unsigned i = 0; i < order.size(); i++) { + if (order[i] == axis) + break; + threadOffset *= threadsPerWarp[order[i]]; + } + } + return threadOffset; +} + +// Cases where distributed shared memory is not required in ConvertLayout: +// (1) numCTAs == 1 +// (2) numCTAs > 1 but srcCTALayout == dstCTALayout +// TODO: Case with SliceLayout as srcLayout and numCTAs > 1 is to be implemented +// in the future +bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) { + unsigned numCTAs = getNumCTAs(srcLayout); + assert(numCTAs == getNumCTAs(dstLayout) && + "Invalid layout conversion: the numbers of CTAs of src and dst " + "layouts are different"); + + // Case (1): Never use dsmem when numCTAs == 1 + if (numCTAs == 1) + return false; + + // Case where CTAsPerCGA of srcLayout in the sliced dim is not 1 is not + // implemented yet + if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] != 1) + llvm::report_fatal_error("Layout conversion to be implemented"); + } + + // Case where CTAsPerCGA of dstLayout in the sliced dim is not 1 is supported + if (auto sliceLayout = mlir::dyn_cast(dstLayout)) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] != 1) + return true; + } + + // The above two branches make sure that it is legal to call getCTALayout of + // srcLayout and dstLayout + + // Case (2): Do not use dsmem when srcCTALayout == dstCTALayout + auto srcCTALayout = getCTALayout(srcLayout); + auto dstCTALayout = getCTALayout(dstLayout); + if (srcCTALayout == dstCTALayout) + return false; + + // Dsmem access is required when srcCTALayout != dstCTALayout + return true; +} + +unsigned ReduceOpHelper::getInterWarpSize() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + unsigned sizeIntraWarps = getIntraWarpSize(); + return std::min(srcReduceDimSize / sizeIntraWarps, + getWarpsPerCTA(getSrcLayout())[axis]); +} + +unsigned ReduceOpHelper::getIntraWarpSize() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + return std::min(srcReduceDimSize, getThreadsPerWarp(getSrcLayout())[axis]); +} + +unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + unsigned sizeIntraWarps = getIntraWarpSizeWithUniqueData(); + return std::min( + srcReduceDimSize / sizeIntraWarps, + getWarpsPerCTAWithUniqueData(getSrcLayout(), getSrcShape())[axis]); +} + +unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + unsigned elementPerThreads = + getUniqueContigPerThread(getSrcLayout(), getSrcShape())[axis]; + return std::min( + srcReduceDimSize / elementPerThreads, + getThreadsPerWarpWithUniqueData(getSrcLayout(), getSrcShape())[axis]); +} + +unsigned ReduceOpHelper::getThreadsReductionAxis() { + auto srcLayout = getSrcLayout(); + auto srcShape = getSrcShape(); + return getThreadsPerWarpWithUniqueData(srcLayout, srcShape)[axis] * + getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis]; +} + +bool ReduceOpHelper::isWarpSynchronous() { + auto srcLayout = getSrcLayout(); + auto srcShape = getSrcShape(); + return getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] == 1; +} + +SmallVector ReduceOpHelper::getScratchConfig() { + SmallVector smemShape; + // that case doesn't need inter-warp communication + if (isWarpSynchronous()) + return {0, 0}; + + smemShape = convertType(getSrcShape()); + smemShape[axis] = getInterWarpSizeWithUniqueData(); + + return smemShape; +} + +unsigned ReduceOpHelper::getScratchSizeInBytes() { + auto smemShape = getScratchConfig(); + auto elems = product(smemShape); + + unsigned bytesPerElem = 0; + for (const auto &ty : srcElementTypes) { + bytesPerElem += ceil(ty.getIntOrFloatBitWidth(), 8); + } + return bytesPerElem * elems; +} + +bool ReduceOpHelper::isReduceWithinCTA() { + auto axis = getAxis(); + auto srcLayout = getSrcLayout(); + auto CTASplitNum = getCTASplitNum(srcLayout); + assert(axis < CTASplitNum.size()); + return CTASplitNum[axis] == 1; +} + +bool ReduceOpHelper::isSupportedLayout() { + // Layout optimization passes such as PlanCTAPass and + // RemoveLayoutConversionPass should avoid cross-CTA reduction + if (!isReduceWithinCTA()) { + return false; + } + + auto srcLayout = getSrcLayout(); + if (isa(srcLayout)) { + return true; + } + if (auto mmaLayout = dyn_cast(srcLayout)) { + return mmaLayout.supportReduction(); + } + if (auto sliceLayout = dyn_cast(srcLayout)) { + return true; + } + return false; +} + +unsigned ScanLoweringHelper::getAxisNumElementsPerThread() { + return getEncoding().getSizePerThread()[getAxis()]; +} + +unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() { + SmallVector sizePerThreads = getContigPerThread(getEncoding()); + sizePerThreads[getAxis()] = 1; + return product(sizePerThreads); +} + +Region &ScanLoweringHelper::getCombineOp() { return scanOp.getCombineOp(); } + +unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() { + return getThreadsPerWarp(getEncoding())[getAxis()]; +} + +unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() { + return getThreadsPerWarpWithUniqueData(getEncoding(), getShape())[getAxis()]; +} + +unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() { + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + threadsPerWarp[getAxis()] = 1; + return product(threadsPerWarp); +} + +// Return the flat numbers of threads computing independent scan results. +unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() { + unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp(); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + warpsPerCTA[getAxis()] = 1; + unsigned numParallelWarpsPerCTA = product(warpsPerCTA); + return numParallelThreadsPerWarp * numParallelWarpsPerCTA; +} + +unsigned ScanLoweringHelper::getAxisNumWarps() { + return getWarpsPerCTA(getEncoding())[getAxis()]; +} + +unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() { + return getWarpsPerCTAWithUniqueData(getEncoding(), getShape())[getAxis()]; +} + +unsigned ScanLoweringHelper::getAxisNumBlocks() { + auto sizePerThreads = getSizePerThread(getEncoding()); + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + unsigned axis = getAxis(); + return ceil( + getShape()[axis], + (sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis])); +} + +unsigned ScanLoweringHelper::getNonAxisNumBlocks() { + auto sizePerThreads = getSizePerThread(getEncoding()); + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + unsigned axis = getAxis(); + unsigned numBlocks = 1; + for (unsigned i = 0; i < sizePerThreads.size(); i++) { + if (i == axis) + continue; + numBlocks *= + ceil(getShape()[i], (sizePerThreads[i] * threadsPerWarp[i] * + warpsPerCTA[i])); + } + return numBlocks; +} + +bool ScanLoweringHelper::isSupported() { + // TODO: Support the following cases: + // 1. Scan on non-blocking encodings + if (!isa(getEncoding())) + return false; + return true; +} + +unsigned ScanLoweringHelper::getScratchSizeInElems() { + auto mod = scanOp->getParentOfType(); + unsigned numWarps = TritonGPUDialect::getNumWarps(mod); + unsigned numNonAxisElementsPerWarp = + getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread(); + unsigned numElements = numWarps * numNonAxisElementsPerWarp * + getAxisNumBlocks() * getNonAxisNumBlocks(); + return numElements; +} + +unsigned ScanLoweringHelper::getScratchSizeInBytes() { + unsigned axisNumWarps = getAxisNumWarpsWithUniqueData(); + if (axisNumWarps == 1) + return 0; + unsigned elementSizeInBytes = 0; + for (const auto &ty : srcElementTypes) { + elementSizeInBytes += ceil(ty.getIntOrFloatBitWidth(), 8); + } + return elementSizeInBytes * getScratchSizeInElems(); +} + +SmallVector, SmallVector>> +getReshapeDecomposition(ArrayRef srcShape, + ArrayRef dstShape) { + SmallVector, SmallVector>> ret; + + if (srcShape.empty()) { + assert(dstShape.empty()); + return ret; + } + ret.push_back({}); + + int srcIdx = 0; + int dstIdx = 0; + int srcNElems = 1; + int dstNElems = 1; + while (srcIdx < srcShape.size() || dstIdx < dstShape.size()) { + if (srcNElems < dstNElems || // + (srcIdx < srcShape.size() && srcNElems == 1) || + (srcIdx < srcShape.size() && srcShape[srcIdx] == 1)) { + assert(srcIdx < srcShape.size()); + srcNElems *= srcShape[srcIdx]; + ret.back().first.push_back(srcIdx); + srcIdx++; + } else if (dstNElems < srcNElems || + (dstIdx < dstShape.size() && dstShape[dstIdx] == 1)) { + assert(dstIdx < dstShape.size()); + dstNElems *= dstShape[dstIdx]; + ret.back().second.push_back(dstIdx); + dstIdx++; + } else { + ret.push_back({}); + srcNElems = 1; + dstNElems = 1; + } + } + return ret; +} + +BlockedEncodingAttr ScanLoweringHelper::getEncoding() { + return cast(srcEncoding); +} + +unsigned ScanLoweringHelper::getAxisElementStride() { + auto order = getOrder(getEncoding()); + unsigned stride = 1; + for (unsigned dim : order) { + if (dim == getAxis()) + return stride; + stride *= getContigPerThread(getEncoding())[dim]; + } + llvm_unreachable("Axis not found in order"); +} + +unsigned ScanLoweringHelper::getAxisThreadStride() { + auto order = getOrder(getEncoding()); + unsigned stride = 1; + for (unsigned dim : order) { + if (dim == getAxis()) + return stride; + stride *= getEncoding().getThreadsPerWarp()[dim]; + } + llvm_unreachable("Axis not found in order"); +} + +unsigned ScanLoweringHelper::getAxisBlockStride() { + auto order = getOrder(getEncoding()); + unsigned stride = 1; + auto sizePerThreads = getSizePerThread(getEncoding()); + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + for (unsigned dim : order) { + if (dim == getAxis()) + return stride; + stride *= ceil(getShape()[dim], sizePerThreads[dim] * + threadsPerWarp[dim] * + warpsPerCTA[dim]); + } + llvm_unreachable("Axis not found in order"); +} + +bool maybeSharedAllocationOp(Operation *op) { + // TODO(Keren): This function can be replaced by adding + // MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to + // query the memory effects of the op. + auto *dialect = op->getDialect(); + return dialect && + (dialect->getTypeID() == TypeID::get() || + dialect->getTypeID() == TypeID::get() || + dialect->getTypeID() == TypeID::get() || + dialect->getTypeID() == TypeID::get()); +} + +static bool supportMFMAGranularity(int m, int n, int k) { + // these limitations are dtype dependent, in future we may relax them + const static std::pair mfmaTypes[2] = {{32, 8}, {16, 16}}; + for (const auto &mfmaType : mfmaTypes) { + auto [granularityMN, granularityK] = mfmaType; + if (m % granularityMN != 0 || n % granularityMN != 0) + continue; + if (k % granularityK != 0) + continue; + return true; + } + return false; +} + +bool supportMFMATypes(Type a, Type b) { + if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth()) + return false; + + auto F8E4M3FNUZ = TypeID::get(); + auto F8E5M2FNUZ = TypeID::get(); + auto F16 = TypeID::get(); + auto BF16 = TypeID::get(); + auto F32 = TypeID::get(); + auto Int = TypeID::get(); + DenseSet> supportedTypes = { + {F32, F32}, + {F16, F16}, + {BF16, BF16}, + {F8E4M3FNUZ, F8E4M3FNUZ}, + {F8E4M3FNUZ, F8E5M2FNUZ}, + {F8E5M2FNUZ, F8E4M3FNUZ}, + {F8E5M2FNUZ, F8E5M2FNUZ}, + {Int, Int}}; + + if (!supportedTypes.contains({a.getTypeID(), b.getTypeID()})) + return false; + + if (a.isIntOrIndex() && a.getIntOrFloatBitWidth() != 8) + return false; + return true; +} + +bool supportMFMA(triton::DotOp op) { + auto aTy = cast(op.getA().getType()); + auto bTy = cast(op.getB().getType()); + + auto aElemTy = aTy.getElementType(); + auto bElemTy = bTy.getElementType(); + + if (!supportMFMATypes(aElemTy, bElemTy)) + return false; + + auto aShape = aTy.getShape(); + auto bShape = bTy.getShape(); + + auto rank = aShape.size(); + assert(bShape.size() == rank); + auto M = aShape[rank - 2]; + auto N = bShape[rank - 1]; + auto K = aShape[rank - 1]; + assert(K == bShape[rank - 2]); + if (!supportMFMAGranularity(M, N, K)) + return false; + + return true; +} + +static bool supportWMMAGranularity(int m, int n, int k) { + return m % 16 == 0 && n % 16 == 0 && k % 16 == 0; +} + +static bool supportWMMATypes(Type a, Type b, Type c, Type d) { + if (a != b || c != d) + return false; + auto aWidth = a.getIntOrFloatBitWidth(); + auto cWidth = c.getIntOrFloatBitWidth(); + if (a.isIntOrIndex()) { + if (!c.isIntOrIndex()) + return false; + bool aValid = aWidth <= 8; + bool cValid = cWidth <= 32; + return aValid && cValid; + } else if (isa(a) && isa(c)) { + if (a.isBF16()) + return c.isBF16() || c.isF32(); + if (a.isF16()) + return c.isF16() || c.isF32(); + return aWidth <= cWidth && aWidth <= 16; + } + return false; +} + +bool supportWMMA(triton::DotOp op) { + auto aTy = cast(op.getA().getType()); + auto bTy = cast(op.getB().getType()); + auto cTy = cast(op.getC().getType()); + auto dTy = cast(op.getResult().getType()); + + auto aElemTy = aTy.getElementType(); + auto bElemTy = bTy.getElementType(); + auto cElemTy = cTy.getElementType(); + auto dElemTy = dTy.getElementType(); + + if (!supportWMMATypes(aElemTy, bElemTy, cElemTy, dElemTy)) + return false; + + auto aShape = aTy.getShape(); + auto bShape = bTy.getShape(); + + auto rank = aShape.size(); + assert(bShape.size() == rank); + assert(aShape[rank - 1] == bShape[rank - 2]); + if (!supportWMMAGranularity(aShape[rank - 2], bShape[rank - 1], + aShape[rank - 1])) + return false; + + return true; +} + +bool supportMMA(triton::DotOp op, int version) { + // Refer to mma section for the data type supported by Volta and Hopper + // Tensor Core in + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + auto aElemTy = op.getA().getType().getElementType(); + auto bElemTy = op.getB().getType().getElementType(); + if (version == 3) { + if (triton::tools::getBoolEnv("DISABLE_MMA_V3")) + return false; + auto retType = op.getType(); + auto retShapePerCTA = getShapePerCTA(retType); + auto rank = retShapePerCTA.size(); + auto mod = op->getParentOfType(); + int numWarps = TritonGPUDialect::getNumWarps(mod); + if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 && + retShapePerCTA[rank - 1] % 8 == 0 && + (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() || + aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() || + aElemTy.isF32()))) { + return false; + } + // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op. + if (op.getMaxNumImpreciseAcc() < 32 && + (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) && + cast(op.getType()).getElementType().isF32()) { + return false; + } + } + if (aElemTy.isF32() && bElemTy.isF32()) { + return op.getInputPrecision() == InputPrecision::TF32 && version >= 2; + } + return supportMMA(op.getA(), version) && supportMMA(op.getB(), version); +} + +bool supportMMA(Value value, int version) { + // Tell whether a DotOp support MMA by the operand type(either $a or $b). + // We cannot get both the operand types(in TypeConverter), here we assume the + // types of both the operands are identical here. + assert((version == 1 || version == 2 || version == 3) && + "Unexpected MMA layout version found"); + auto elemTy = cast(value.getType()).getElementType(); + // FP8 is not natively supported on all mma versions but it can always be + // promoted to fp16 therefore we can always support it. + bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() || + elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ(); + return isFP8 || elemTy.isF16() || elemTy.isBF16() || + (elemTy.isF32() && version >= 2) || + (elemTy.isInteger(8) && version >= 2); +} + +bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + auto mfmaLayout = dyn_cast(srcLayout); + auto dotOperandLayout = dyn_cast(dstLayout); + if (mfmaLayout == nullptr || dotOperandLayout == nullptr) + return false; + // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is + // improved. In addition, we can enable this shortcut for regular MFMA + // layout when opIdx == 1. + return mfmaLayout.getWarpsPerCTA()[1] == 1 && + dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && + dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] && + dotOperandLayout.getParent() == mfmaLayout && + (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) && + (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()); +} + +static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) { + auto src = dyn_cast(srcEncoding); + auto dst = dyn_cast(dstEncoding); + if (!src || !dst) + return false; + // when #mma = MmaEncoding + return src && dst && src.getVersionMajor() == 3 && + src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 && + dst.getWarpsPerCTA()[1] == 1; +} + +bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { + return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding()); +} + +// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases. +bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, + RankedTensorType dstTy) { + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + auto mmaLayout = cast(srcLayout); + auto dotOperandLayout = cast(dstLayout); + int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth(); + auto ans = mmaLayout.getVersionMajor() == 3 && + dotOperandLayout.getOpIdx() == 0 && + isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) && + (elementTypeSize == 16 || elementTypeSize == 8); + return ans; +} + +bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { + if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) + return true; + // dot_op = #mma + // when #mma = MmaEncoding + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + auto mmaLayout = mlir::cast(srcLayout); + auto dotOperandLayout = mlir::cast(dstLayout); + return mmaLayout.getVersionMajor() == 2 && + mmaLayout.getWarpsPerCTA()[1] == 1 && + dotOperandLayout.getOpIdx() == 0 && + dotOperandLayout.getParent() == mmaLayout && + !srcTy.getElementType().isF32(); +} + +namespace { + +/// A data structure similar to SetVector but maintains +/// a deque instead of a vector to allow for efficient +/// push_back and pop_front operations. +/// Using SetVector doesn't suffice our needs because +/// it only pushes and pops from the back. +/// For example, if we have a queue like this: +/// 0->4 1->2->3 +/// ^-------- +/// where 3 depends on 4, once we pop 3, we found +/// 4 is not ready, so we check 2 and push 3 back +/// to the queue. +struct DFSSubgraphState { + DFSSubgraphState() : set(), deque() {} + DenseSet set; + std::deque deque; + + bool push_back(Operation *op) { + if (set.insert(op).second) { + deque.push_back(op); + return true; + } + return false; + } + + Operation *pop_front() { + Operation *op = deque.front(); + deque.pop_front(); + set.erase(op); + return op; + } + + bool empty() { return deque.empty(); } +}; + +/// DFS post-order implementation that maintains a global count to work across +/// multiple invocations, to help implement topological sort on multi-root DAGs. +/// We traverse all operations but only record the ones that appear in +/// `toSort` for the final result. +struct DFSState { + DFSState(const SetVector &set) : toSort(set), seen() {} + const SetVector &toSort; + SmallVector topologicalCounts; + DenseSet seen; + + /// We mark each op as ready if all its operands and parents ops are seen. If + /// an op is ready, we add it to the queue. Otherwise, we keep adding its + /// operands to the ancestors set. + /// We always want an op to be scheduled after all its parents to handle + /// correctly cases with scf operations. + void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph, + SmallVector &readyQueue) { + bool ready = true; + for (Value operand : op->getOperands()) { + auto def = operand.getDefiningOp(); + if (def && !seen.count(def)) { + subGraph.push_back(def); + ready = false; + } + } + Operation *parent = op->getParentOp(); + while (parent) { + if (!seen.count(parent)) { + subGraph.push_back(parent); + ready = false; + } + parent = parent->getParentOp(); + } + if (ready) + readyQueue.push_back(op); + } +}; + +void dfsPostorder(Operation *root, DFSState *state) { + DFSSubgraphState subGraph; + subGraph.push_back(root); + SmallVector ops; + while (!subGraph.empty()) { + // Nodes in the ready queue are ready to be processed. + // Meaning that either their operands are all seen or they have null + // operands. + SmallVector readyQueue; + auto *current = subGraph.pop_front(); + state->addToReadyQueue(current, subGraph, readyQueue); + while (!readyQueue.empty()) { + Operation *current = readyQueue.pop_back_val(); + if (!state->seen.insert(current).second) + continue; + ops.push_back(current); + for (Value result : current->getResults()) { + for (Operation *op : result.getUsers()) + state->addToReadyQueue(op, subGraph, readyQueue); + } + for (Region ®ion : current->getRegions()) { + for (Operation &op : region.getOps()) + state->addToReadyQueue(&op, subGraph, readyQueue); + } + } + } + + for (Operation *op : llvm::reverse(ops)) { + if (state->toSort.count(op) > 0) + state->topologicalCounts.push_back(op); + } +} + +} // namespace + +SetVector +multiRootTopologicalSort(const SetVector &toSort) { + if (toSort.empty()) { + return toSort; + } + + // Run from each root with global count and `seen` set. + DFSState state(toSort); + for (auto *s : toSort) { + assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); + dfsPostorder(s, &state); + } + + // Reorder and return. + SetVector res; + for (auto it = state.topologicalCounts.rbegin(), + eit = state.topologicalCounts.rend(); + it != eit; ++it) { + res.insert(*it); + } + return res; +} + +SetVector multiRootGetSlice(Operation *op, + TransitiveFilter backwardFilter, + TransitiveFilter forwardFilter) { + SetVector slice; + slice.insert(op); + + unsigned currentIndex = 0; + SetVector backwardSlice; + SetVector forwardSlice; + while (currentIndex != slice.size()) { + auto *currentOp = (slice)[currentIndex]; + // Compute and insert the backwardSlice starting from currentOp. + backwardSlice.clear(); + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = backwardFilter; + getBackwardSlice(currentOp, &backwardSlice, opt); + slice.insert(backwardSlice.begin(), backwardSlice.end()); + + // Compute and insert the forwardSlice starting from currentOp. + forwardSlice.clear(); + getForwardSlice(currentOp, &forwardSlice, forwardFilter); + slice.insert(forwardSlice.begin(), forwardSlice.end()); + ++currentIndex; + } + return multiRootTopologicalSort(slice); +} + +namespace { +// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis +// interacts with constant propagation, but SparseConstantPropagation +// doesn't seem to be sufficient. +class ConstantAnalysis : public DataFlowAnalysis { +public: + using DataFlowAnalysis::DataFlowAnalysis; + + LogicalResult initialize(Operation *top) override { + WalkResult result = top->walk([&](Operation *op) { + if (failed(visit(op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return success(!result.wasInterrupted()); + } + + LogicalResult visit(ProgramPoint point) override { + Operation *op = point.get(); + Attribute value; + if (matchPattern(op, m_Constant(&value))) { + auto *constant = getOrCreate>( + op->getResult(0)); + propagateIfChanged(constant, constant->join(dataflow::ConstantValue( + value, op->getDialect()))); + return success(); + } + // Dead code analysis requires every operands has initialized ConstantValue + // state before it is visited. + // https://github.com/llvm/llvm-project/blob/2ec1aba2b69faa1de5f71832a48e25aa3b5d5314/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp#L322 + // That's why we need to set all operands to unknown constants. + setAllToUnknownConstants(op->getResults()); + for (Region ®ion : op->getRegions()) { + for (Block &block : region.getBlocks()) + setAllToUnknownConstants(block.getArguments()); + } + return success(); + } + +private: + /// Set all given values as not constants. + void setAllToUnknownConstants(ValueRange values) { + dataflow::ConstantValue unknownConstant(nullptr, nullptr); + for (Value value : values) { + auto *constant = + getOrCreate>(value); + propagateIfChanged(constant, constant->join(unknownConstant)); + } + } +}; +} // namespace + +std::unique_ptr createDataFlowSolver() { + auto solver = std::make_unique(); + solver->load(); + solver->load(); + return solver; +} + +static MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) { + + if (auto makeTensorPtrOp = dyn_cast(op)) { + return makeTensorPtrOp; + } + + if (auto advanceOp = dyn_cast(op)) { + return getMakeTensorPtrOp(advanceOp.getPtr()); + } + + if (auto branch = dyn_cast(op)) { + auto idx = cast(v).getResultNumber(); + llvm::SmallVector yieldOps; + op->walk([&](Operation *op) { + if (auto yieldOp = dyn_cast(op)) + yieldOps.push_back(yieldOp); + }); + + // benzh@ if multi yields, all yields operand should come from same arg. + Value newValue = yieldOps[0].getOperands()[idx]; + return getMakeTensorPtrOp(newValue); + } + + llvm_unreachable("Unable to getMakeTensorPtr()"); +} + +MakeTensorPtrOp getMakeTensorPtrOp(Value v) { + using BranchOps = llvm::SetVector>; + llvm::DenseMap blockToCFOps; + auto moduleOp = + v.getParentBlock()->getParentOp()->getParentOfType(); + + moduleOp.walk([&](Operation *op) { + if (auto br = dyn_cast(op)) { + Block *block = br.getDest(); + blockToCFOps[block].insert({op, -1}); + } + if (auto condBr = dyn_cast(op)) { + Block *blockT = condBr.getTrueDest(); + Block *blockF = condBr.getFalseDest(); + blockToCFOps[blockT].insert({condBr, 1}); + blockToCFOps[blockF].insert({condBr, 0}); + } + }); + + if (Operation *definingOp = v.getDefiningOp()) + return getMakeTensorPtrOpImpl(definingOp, v); + + // If there is no defining op, v must be a BlockArgument. + BlockArgument arg = cast(v); + unsigned argNum = arg.getArgNumber(); + Operation *argOwner = arg.getOwner()->getParentOp(); + + if (auto forOp = dyn_cast(argOwner)) + return getMakeTensorPtrOp( + forOp.getOperand(argNum + forOp.getNumControlOperands() - 1)); + if (auto funcOp = dyn_cast(argOwner)) { + Block *block = arg.getOwner(); + Operation *op; + int tOrF; + std::tie(op, tOrF) = blockToCFOps[block][0]; + if (auto br = dyn_cast(op)) + return getMakeTensorPtrOp(br.getDestOperands()[argNum]); + if (auto condBr = dyn_cast(op)) + return getMakeTensorPtrOp(tOrF ? condBr.getTrueDestOperands()[argNum] + : condBr.getFalseDestOperands()[argNum]); + return getMakeTensorPtrOp(argOwner->getOperand(argNum)); + } + llvm_unreachable("Unable to getMakeTensorPtr()"); +} + +} // namespace mlir diff --git a/third_party/mthreads/lib/CMakeLists.txt b/third_party/mthreads/lib/CMakeLists.txt new file mode 100644 index 000000000..c58b7fa0a --- /dev/null +++ b/third_party/mthreads/lib/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(Analysis) +add_subdirectory(Conversion) +add_subdirectory(Dialect) +add_subdirectory(Target) +add_subdirectory(Tools) diff --git a/third_party/mthreads/lib/Conversion/CMakeLists.txt b/third_party/mthreads/lib/Conversion/CMakeLists.txt new file mode 100644 index 000000000..143a4375a --- /dev/null +++ b/third_party/mthreads/lib/Conversion/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonToTritonGPU) +add_subdirectory(TritonGPUToLLVM) diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp new file mode 100644 index 000000000..aae9faf0e --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp @@ -0,0 +1,69 @@ +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_ALLOCATESHAREDMEMORY +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace { + +struct AllocateSharedMemory + : public mlir::triton::impl::AllocateSharedMemoryBase< + AllocateSharedMemory> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + ModuleAllocation allocation(mod); + + mod.walk([&](FunctionOpInterface funcOp) { + funcOp.walk([&](Operation *op) { + auto *funcAllocation = allocation.getFuncData(funcOp); + auto oBufferId = funcAllocation->getBufferId(op); + int offset = -1; + if (oBufferId != Allocation::InvalidBufferId) + offset = funcAllocation->getOffset(oBufferId); + else if (op->getNumResults() == 1) { + Value value = op->getResult(0); + auto vBufferId = funcAllocation->getBufferId(value); + if (vBufferId != Allocation::InvalidBufferId) + offset = funcAllocation->getOffset(vBufferId); + } + if (offset == -1) + return; + op->setAttr("allocation.offset", + IntegerAttr::get(IntegerType::get(ctx, 32), offset)); + }); + }); + mod->setAttr("triton_gpu.shared", + mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), + allocation.getSharedMemorySize())); + } +}; + +} // namespace + +namespace mlir { + +namespace triton { + +namespace gpu { + +std::unique_ptr> createAllocateSharedMemoryPass() { + return std::make_unique(); +} + +} // namespace gpu + +} // namespace triton + +} // namespace mlir diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp new file mode 100644 index 000000000..a3f55f1e7 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp @@ -0,0 +1,80 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; + +struct AssertOpConversion : public ConvertOpToLLVMPattern { + explicit AssertOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + auto elems = unpackLLElements(loc, adaptor.getCondition(), rewriter); + auto elemTy = elems[0].getType(); + Value condition = int_val(elemTy.getIntOrFloatBitWidth(), 0); + for (auto elem : elems) { + if (elemTy.isSignedInteger() || elemTy.isSignlessInteger()) { + condition = + or_(condition, + icmp_eq(elem, rewriter.create( + loc, elemTy, rewriter.getZeroAttr(elemTy)))); + } else { + assert(false && "Unsupported type for assert"); + return failure(); + } + } + llAssert(op, condition, adaptor.getMessage(), adaptor.getFile(), + adaptor.getFunc(), adaptor.getLine(), rewriter); + rewriter.eraseOp(op); + return success(); + } + // op: the op at which the assert is inserted. Unlike printf, we need to + // know about the op to split the block. + void llAssert(Operation *op, Value condition, StringRef message, + StringRef file, StringRef func, int line, + ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + auto ctx = rewriter.getContext(); + auto loc = op->getLoc(); + // #block1 + // if (condition) { + // #block2 + // __assertfail(message); + // } + // #block3 + Block *prevBlock = op->getBlock(); + + Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator()); + rewriter.setInsertionPointToStart(ifBlock); + targetInfo.assertFail(rewriter, loc, message, file, func, line); + + // Split a block after the call. + Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator()); + rewriter.setInsertionPointToEnd(ifBlock); + rewriter.create(loc, thenBlock); + rewriter.setInsertionPointToEnd(prevBlock); + rewriter.create(loc, condition, ifBlock, thenBlock); + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateAssertOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..8820d2147 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,43 @@ +if (FLAGTREE_BACKEND) + set(NVGPUIR "") +else() + set(NVGPUIR "NVGPUIR") +endif() + +add_triton_library(TritonGPUToLLVM + ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp + DotOpToLLVM/FMA.cpp + TypeConverter.cpp + Utility.cpp + ElementwiseOpToLLVM.cpp + MemoryOpToLLVM.cpp + AssertOpToLLVM.cpp + ViewOpToLLVM.cpp + MakeRangeOpToLLVM.cpp + HistogramOpToLLVM.cpp + AllocateSharedMemory.cpp + ReduceOpToLLVM.cpp + ScanOpToLLVM.cpp + ConvertLayoutOpToLLVM.cpp + ControlFlowOpToLLVM.cpp + FuncOpToLLVM.cpp + SPMDOpToLLVM.cpp + DecomposeUnsupportedConversions.cpp + PrintOpToLLVM.cpp + + DEPENDS + TritonGPUConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRGPUDialect + MLIRGPUToNVVMTransforms + MLIRGPUToROCDLTransforms + MLIRGPUTransforms + TritonAnalysis + TritonIR + TritonGPUIR + TritonGPUTransforms + ${NVGPUIR} +) diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp new file mode 100644 index 000000000..147b2736f --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp @@ -0,0 +1,141 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct ReturnOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = op->getParentOfType(); + if (funcOp->hasAttr("nvvm.kernel") || funcOp->hasAttr("mtgpu.kernel")) { + // A GPU kernel + if (op.getNumOperands() > 0) { + return rewriter.notifyMatchFailure( + op, "Kernel functions do not support return with operands"); + } + rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), + op->getAttrs()); + } else { + // A device function + LLVM::ReturnOp newOp; + if (adaptor.getOperands().size() < 2) { + // Single or no return value. + newOp = + rewriter.create(op.getLoc(), adaptor.getOperands()); + } else { + // Pack the results into a struct. + auto packedResultsTy = this->getTypeConverter()->packFunctionResults( + funcOp.getResultTypes()); + Value packedResults = + rewriter.create(op.getLoc(), packedResultsTy); + auto loc = op.getLoc(); + for (auto it : llvm::enumerate(adaptor.getOperands())) { + packedResults = insert_val(packedResultsTy, packedResults, it.value(), + it.index()); + } + newOp = rewriter.create(op.getLoc(), packedResults); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + } + return success(); + } +}; + +// CallOpInterfaceLowering is adapted from +// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 +struct CallOpConversion : public ConvertOpToLLVMPattern { + CallOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto promotedOperands = promoteOperands(callOp, adaptor, rewriter); + auto newCallOp = + convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter); + if (!newCallOp) + return failure(); + auto results = getCallOpResults(callOp, newCallOp, rewriter); + rewriter.replaceOp(callOp, results); + return success(); + } + +private: + SmallVector + promoteOperands(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Get the last argument of the caller, which is the current stack pointer + // of shared memory and append it to the operands of the callOp. + auto loc = callOp.getLoc(); + auto caller = callOp->getParentOfType(); + auto promotedOperands = this->getTypeConverter()->promoteOperands( + callOp.getLoc(), /*opOperands=*/callOp->getOperands(), + adaptor.getOperands(), rewriter); + if (!caller->hasAttr("allocation.offset")) { + auto base = LLVM::getStackPointer(rewriter, caller); + promotedOperands.push_back(base); + return promotedOperands; + } + promotedOperands.push_back( + LLVM::getSharedMemoryBase(callOp->getLoc(), rewriter, callOp)); + return promotedOperands; + } + + LLVM::CallOp + convertCallOpToLLVMCallOp(triton::CallOp callOp, + ArrayRef promotedOperands, + ConversionPatternRewriter &rewriter) const { + // Pack the result types into a struct. + Type packedResult = nullptr; + unsigned numResults = callOp.getNumResults(); + auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); + + if (numResults != 0) { + if (!(packedResult = + this->getTypeConverter()->packFunctionResults(resultTypes))) + return nullptr; + } + auto newCallOp = rewriter.create( + callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), + promotedOperands, callOp->getAttrs()); + return newCallOp; + } + + SmallVector + getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp, + ConversionPatternRewriter &rewriter) const { + auto numResults = callOp.getNumResults(); + SmallVector results; + if (numResults < 2) { + // If < 2 results, packing did not do anything and we can just return. + results.append(newCallOp.result_begin(), newCallOp.result_end()); + } else { + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + results.push_back(rewriter.create( + callOp.getLoc(), newCallOp->getResult(0), i)); + } + } + return results; + } +}; + +} // namespace + +void mlir::triton::populateControlFlowOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp new file mode 100644 index 000000000..94894ceb1 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -0,0 +1,324 @@ +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" + +using mlir::isLayoutMmaV1; +using mlir::LLVM::getMultiDimOffset; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::LLVM::getStridesFromShapeAndOrder; +using mlir::LLVM::getWrappedMultiDimOffset; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getShapePerCTATile; +using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::isaDistributedLayout; +using ::mlir::triton::gpu::SharedEncodingAttr; + +namespace { + +struct LocalLoadOpConversion + : public ConvertOpToLLVMPattern { +public: + LocalLoadOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MemDescType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + // TODO: do we need to check if src is shared ? + if (isa(srcLayout) && isaDistributedLayout(dstLayout)) { + return lowerSharedToDistributed(op, adaptor, getTypeConverter(), + rewriter); + } + if (isa(dstLayout) && + isa( + cast(dstLayout).getParent())) { + return lowerSharedToDotOpFMA(op, adaptor, getTypeConverter(), rewriter); + } + return failure(); + } + +private: + LogicalResult + lowerSharedToDotOpFMA(triton::gpu::LocalLoadOp op, + triton::gpu::LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + RankedTensorType dstTy = op.getType(); + Attribute dstLayout = dstTy.getEncoding(); + auto dotLayout = cast(dstLayout); + auto blockedLayout = cast( + cast(dstLayout).getParent()); + auto thread = getThreadId(rewriter, loc); + Value res = SharedToDotOperandFMA::convertLayout( + dotLayout.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout, + thread, loc, getTypeConverter(), rewriter); + rewriter.replaceOp(op, res); + return success(); + } + LogicalResult + lowerSharedToDistributed(triton::gpu::LocalLoadOp op, + triton::gpu::LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getResult().getType(); + auto dstShape = dstTy.getShape(); + assert(dstShape.size() <= 2 && + "Unexpected rank of ConvertLayout(shared->blocked)"); + auto srcSharedLayout = cast(srcTy.getEncoding()); + auto dstLayout = dstTy.getEncoding(); + auto inOrd = getOrder(srcSharedLayout); + + auto smemObj = getSharedMemoryObjectFromStruct( + loc, adaptor.getSrc(), + typeConverter->convertType(srcTy.getElementType()), rewriter); + auto elemTy = typeConverter->convertType(dstTy.getElementType()); + + auto srcStrides = + getStridesFromShapeAndOrder(srcTy.getShape(), inOrd, loc, rewriter); + + SmallVector outVals = + loadSharedToDistributed(op.getResult(), op.getSrc(), smemObj, elemTy, + loc, rewriter, targetInfo); + + Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +struct ConvertLayoutOpConversion + : public ConvertOpToLLVMPattern { +public: + ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + if (isSupported(srcLayout, dstLayout)) { + return lowerDistributedToDistributed(op, adaptor, rewriter); + } + return failure(); + } + +private: + bool isSupported(Attribute srcLayout, Attribute dstLayout) const { + return isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout) && + !isLayoutMmaV1(srcLayout) && !isLayoutMmaV1(dstLayout); + } + // shared memory rd/st for blocked or mma layout with data padding + void processReplica(Location loc, ConversionPatternRewriter &rewriter, + bool stNotRd, RankedTensorType type, + ArrayRef numCTAsEachRep, + ArrayRef multiDimRepId, unsigned vec, + ArrayRef paddedRepShape, + ArrayRef origRepShape, + ArrayRef outOrd, SmallVector &vals, + Value smemBase) const { + auto accumNumCTAsEachRep = product(numCTAsEachRep); + auto layout = type.getEncoding(); + auto rank = type.getRank(); + auto sizePerThread = getSizePerThread(layout); + auto accumSizePerThread = product(sizePerThread); + SmallVector numCTATiles(rank); + auto shapePerCTATile = getShapePerCTATile(layout); + auto shapePerCTA = getShapePerCTA(layout, type.getShape()); + auto order = getOrder(layout); + for (unsigned d = 0; d < rank; ++d) { + numCTATiles[d] = ceil(shapePerCTA[d], shapePerCTATile[d]); + } + auto elemTy = type.getElementType(); + bool isInt1 = elemTy.isInteger(1); + bool isPtr = isa(elemTy); + auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy); + if (isInt1) + elemTy = IntegerType::get(elemTy.getContext(), 8); + else if (isPtr) + elemTy = IntegerType::get(elemTy.getContext(), 64); + + auto llvmElemTy = getTypeConverter()->convertType(elemTy); + + for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) { + auto multiDimCTAInRepId = + getMultiDimIndex(ctaId, numCTAsEachRep, order); + SmallVector multiDimCTAId(rank); + for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) { + auto d = it.index(); + multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value(); + } + + auto linearCTAId = + getLinearIndex(multiDimCTAId, numCTATiles, order); + // TODO: This is actually redundant index calculation, we should + // consider of caching the index calculation result in case + // of performance issue observed. + for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { + SmallVector multiDimOffset = + getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, type, + multiDimCTAInRepId, shapePerCTATile); + SmallVector multiDimOffsetWrapped = getWrappedMultiDimOffset( + rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile, + shapePerCTA); + Value offset = linearize(rewriter, loc, multiDimOffsetWrapped, + paddedRepShape, outOrd); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset); + auto vecTy = vec_ty(llvmElemTy, vec); + ptr = bitcast(ptr, ptr_ty(rewriter.getContext(), 3)); + if (stNotRd) { + Value valVec = undef(vecTy); + for (unsigned v = 0; v < vec; ++v) { + auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v]; + if (isInt1) + currVal = zext(llvmElemTy, currVal); + else if (isPtr) + currVal = ptrtoint(llvmElemTy, currVal); + valVec = insert_element(vecTy, valVec, currVal, i32_val(v)); + } + store(valVec, ptr); + } else { + Value valVec = load(vecTy, ptr); + for (unsigned v = 0; v < vec; ++v) { + Value currVal = extract_element(llvmElemTy, valVec, i32_val(v)); + if (isInt1) + currVal = icmp_ne(currVal, + rewriter.create( + loc, i8_ty, rewriter.getI8IntegerAttr(0))); + else if (isPtr) + currVal = inttoptr(llvmElemTyOrig, currVal); + vals[elemId + linearCTAId * accumSizePerThread + v] = currVal; + } + } + } + } + } + // blocked/mma -> blocked/mma. + // Data padding in shared memory to avoid bank conflict. + LogicalResult + lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto typeConverter = getTypeConverter(); + RankedTensorType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + if (product(srcTy.getShape()) == 1) { + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector outVals(getTotalElemsPerThread(dstTy), inVals[0]); + Value result = + packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + return success(); + } + + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + smemBase = bitcast(smemBase, elemPtrTy); + auto shape = dstTy.getShape(); + unsigned rank = dstTy.getRank(); + SmallVector numReplicates(rank); + SmallVector inNumCTAsEachRep(rank); + SmallVector outNumCTAsEachRep(rank); + SmallVector inNumCTAs(rank); + SmallVector outNumCTAs(rank); + auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); + auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape); + auto shapePerCTA = getShapePerCTA(srcLayout, shape); + + for (unsigned d = 0; d < rank; ++d) { + unsigned inPerCTA = + std::min(shapePerCTA[d], srcShapePerCTATile[d]); + unsigned outPerCTA = + std::min(shapePerCTA[d], dstShapePerCTATile[d]); + unsigned maxPerCTA = std::max(inPerCTA, outPerCTA); + numReplicates[d] = ceil(shapePerCTA[d], maxPerCTA); + inNumCTAsEachRep[d] = maxPerCTA / inPerCTA; + outNumCTAsEachRep[d] = maxPerCTA / outPerCTA; + assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0); + inNumCTAs[d] = ceil(shapePerCTA[d], inPerCTA); + outNumCTAs[d] = ceil(shapePerCTA[d], outPerCTA); + } + // Potentially we need to store for multiple CTAs in this replication + auto accumNumReplicates = product(numReplicates); + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + unsigned inVec = 0; + unsigned outVec = 0; + auto origRepShape = getRepShapeForCvtLayout(op); + auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec); + + unsigned outElems = getTotalElemsPerThread(dstTy); + auto outOrd = getOrder(dstLayout); + SmallVector outVals(outElems); + + for (unsigned repId = 0; repId < accumNumReplicates; ++repId) { + auto multiDimRepId = + getMultiDimIndex(repId, numReplicates, outOrd); + if (repId != 0) { + barrier(); + } + auto successful = targetInfo.processReplicaUsingStMatrix( + rewriter, loc, smemBase, vals, srcTy, + getTypeConverter()->convertType(srcTy.getElementType()), + paddedRepShape, origRepShape, outOrd, accumNumReplicates); + if (!successful) { + processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, + multiDimRepId, inVec, paddedRepShape, origRepShape, + outOrd, vals, smemBase); + } + barrier(); + processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep, + multiDimRepId, outVec, paddedRepShape, origRepShape, + outOrd, outVals, smemBase); + } + + Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; +} // namespace + +void mlir::triton::populateConvertLayoutOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp new file mode 100644 index 000000000..b7bd5fbc3 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -0,0 +1,234 @@ +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using ValueTable = std::map, Value>; +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::LLVM::getStridesFromShapeAndOrder; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getContigPerThread; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::isaDistributedLayout; +using ::mlir::triton::gpu::SharedEncodingAttr; + +SmallVector +getThreadIds(Value threadId, ArrayRef shapePerCTATile, + ArrayRef sizePerThread, ArrayRef order, + ConversionPatternRewriter &rewriter, Location loc) { + int dim = order.size(); + SmallVector threadIds(dim); + for (unsigned k = 0; k < dim - 1; k++) { + Value dimK = i32_val(shapePerCTATile[order[k]] / sizePerThread[order[k]]); + Value rem = urem(threadId, dimK); + threadId = udiv(threadId, dimK); + threadIds[order[k]] = rem; + } + Value dimK = i32_val(shapePerCTATile[order[dim - 1]]); + threadIds[order[dim - 1]] = urem(threadId, dimK); + return threadIds; +} + +// Get shapePerCTATile for M or N axis. +int getShapePerCTATileForMN(BlockedEncodingAttr layout, bool isM) { + auto order = layout.getOrder(); + auto shapePerCTATile = getShapePerCTATile(layout); + + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + return isM ? mShapePerCTATile : nShapePerCTATile; +} + +// Get sizePerThread for M or N axis. +int getSizePerThreadForMN(BlockedEncodingAttr layout, bool isM) { + auto order = layout.getOrder(); + auto sizePerThread = getSizePerThread(layout); + + int mSizePerThread = + order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int nSizePerThread = + order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + return isM ? mSizePerThread : nSizePerThread; +} + +Value getStructFromValueTable(ArrayRef vals, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter, + Type elemTy) { + SmallVector elemTypes(vals.size(), elemTy); + SmallVector elems; + elems.reserve(vals.size()); + for (auto &val : vals) { + elems.push_back(val); + } + MLIRContext *ctx = elemTy.getContext(); + Type structTy = struct_ty(elemTypes); + return packLLElements(loc, typeConverter, elems, rewriter, structTy); +} + +ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA, + int sizePerThread, + ConversionPatternRewriter &rewriter, + Location loc, + const LLVMTypeConverter *typeConverter, + Type type) { + ValueTable res; + auto elems = unpackLLElements(loc, val, rewriter); + int index = 0; + for (unsigned k = 0; k < K; ++k) { + for (unsigned m = 0; m < n0; m += shapePerCTA) + for (unsigned mm = 0; mm < sizePerThread; ++mm) { + res[{m + mm, k}] = elems[index++]; + } + } + return res; +} + +Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, + Location loc, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto aTensorTy = cast(A.getType()); + auto aLayout = cast(aTensorTy.getEncoding()); + auto aShapePerCTA = getShapePerCTA(aTensorTy); + + auto aOrder = aLayout.getOrder(); + auto order = dLayout.getOrder(); + + bool isARow = aOrder[0] == 1; + + auto aSmem = getSharedMemoryObjectFromStruct( + loc, llA, typeConverter->convertType(aTensorTy.getElementType()), + rewriter); + Value strideAM = aSmem.strides[0]; + Value strideAK = aSmem.strides[1]; + Value strideA0 = isARow ? strideAK : strideAM; + Value strideA1 = isARow ? strideAM : strideAK; + int aNumPtr = 8; + int K = aShapePerCTA[1]; + int M = aShapePerCTA[0]; + + auto shapePerCTATile = getShapePerCTATile(dLayout); + auto sizePerThread = getSizePerThread(dLayout); + + Value _0 = i32_val(0); + + Value mContig = i32_val(sizePerThread[order[1]]); + + // threadId in blocked layout + auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, + rewriter, loc); + Value threadIdM = threadIds[0]; + + Value offA0 = isARow ? _0 : mul(threadIdM, mContig); + Value offA1 = isARow ? mul(threadIdM, mContig) : _0; + SmallVector aOff(aNumPtr); + for (int i = 0; i < aNumPtr; ++i) { + aOff[i] = add(mul(offA0, strideA0), mul(offA1, strideA1)); + } + auto elemTy = typeConverter->convertType(aTensorTy.getElementType()); + + Type ptrTy = ptr_ty(rewriter.getContext(), 3); + SmallVector aPtrs(aNumPtr); + for (int i = 0; i < aNumPtr; ++i) + aPtrs[i] = gep(ptrTy, elemTy, aSmem.base, aOff[i]); + + SmallVector vas; + + int mShapePerCTATile = getShapePerCTATileForMN(dLayout, true /*isM*/); + int mSizePerThread = getSizePerThreadForMN(dLayout, true /*isM*/); + + for (unsigned k = 0; k < K; ++k) + for (unsigned m = 0; m < M; m += mShapePerCTATile) + for (unsigned mm = 0; mm < mSizePerThread; ++mm) { + Value offset = + add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK)); + Value pa = gep(ptrTy, elemTy, aPtrs[0], offset); + Value va = load(elemTy, pa); + vas.emplace_back(va); + } + + return getStructFromValueTable(vas, rewriter, loc, typeConverter, elemTy); +} + +Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, + Location loc, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto bTensorTy = cast(B.getType()); + auto bLayout = cast(bTensorTy.getEncoding()); + auto bShapePerCTA = getShapePerCTA(bTensorTy); + + auto bOrder = bLayout.getOrder(); + auto order = dLayout.getOrder(); + + bool isBRow = bOrder[0] == 1; + + auto bSmem = getSharedMemoryObjectFromStruct( + loc, llB, typeConverter->convertType(bTensorTy.getElementType()), + rewriter); + Value strideBN = bSmem.strides[1]; + Value strideBK = bSmem.strides[0]; + Value strideB0 = isBRow ? strideBN : strideBK; + Value strideB1 = isBRow ? strideBK : strideBN; + int bNumPtr = 8; + int K = bShapePerCTA[0]; + int N = bShapePerCTA[1]; + + auto shapePerCTATile = getShapePerCTATile(dLayout); + auto sizePerThread = getSizePerThread(dLayout); + + Value _0 = i32_val(0); + + Value nContig = i32_val(sizePerThread[order[0]]); + + // threadId in blocked layout + auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, + rewriter, loc); + Value threadIdN = threadIds[1]; + + Value offB0 = isBRow ? mul(threadIdN, nContig) : _0; + Value offB1 = isBRow ? _0 : mul(threadIdN, nContig); + SmallVector bOff(bNumPtr); + for (int i = 0; i < bNumPtr; ++i) { + bOff[i] = add(mul(offB0, strideB0), mul(offB1, strideB1)); + } + auto elemTy = typeConverter->convertType(bTensorTy.getElementType()); + + Type ptrTy = ptr_ty(rewriter.getContext(), 3); + SmallVector bPtrs(bNumPtr); + for (int i = 0; i < bNumPtr; ++i) + bPtrs[i] = gep(ptrTy, elemTy, bSmem.base, bOff[i]); + + SmallVector vbs; + + int nShapePerCTATile = getShapePerCTATileForMN(dLayout, false /*isM*/); + int nSizePerThread = getSizePerThreadForMN(dLayout, false /*isM*/); + + for (unsigned k = 0; k < K; ++k) + for (unsigned n = 0; n < N; n += nShapePerCTATile) + for (unsigned nn = 0; nn < nSizePerThread; ++nn) { + Value offset = + add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK)); + Value pb = gep(ptrTy, elemTy, bPtrs[0], offset); + Value vb = load(elemTy, pb); + vbs.emplace_back(vb); + } + + return getStructFromValueTable(vbs, rewriter, loc, typeConverter, elemTy); +} + +namespace SharedToDotOperandFMA { +Value convertLayout(int opIdx, Value val, Value llVal, + BlockedEncodingAttr dLayout, Value thread, Location loc, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + if (opIdx == 0) + return loadAFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter); + else + return loadBFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter); +} +} // namespace SharedToDotOperandFMA diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp new file mode 100644 index 000000000..690155ee5 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -0,0 +1,116 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Patterns.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +static void addAttrs(Operation *op, ArrayRef attrs) { + for (const NamedAttribute attr : attrs) + op->setAttr(attr.getName(), attr.getValue()); +} + +} // namespace + +namespace mlir::triton::gpu { + +void decomposeSplatOpToSharedLayoutConversion(ModuleOp module) { + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module); + module.walk([&](triton::SplatOp splatOp) -> void { + auto dstType = cast(splatOp.getType()); + auto shared = + dyn_cast(dstType.getEncoding()); + if (shared) { + OpBuilder builder(splatOp); + SmallVector sizePerThread(dstType.getRank(), 1); + auto newType = RankedTensorType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::BlockedEncodingAttr::get( + module.getContext(), dstType.getShape(), sizePerThread, + getOrder(shared), numWarps, threadsPerWarp, numCTAs)); + auto newSplat = builder.create(splatOp.getLoc(), newType, + splatOp.getSrc()); + auto newConvert = builder.create( + splatOp.getLoc(), dstType, newSplat.getResult()); + splatOp.replaceAllUsesWith(newConvert.getResult()); + splatOp.erase(); + } + }); +} + +template +void decomposeTensorCoreToDotLayoutConversion(ModuleOp module, + ShortcutFn shortcutFn) { + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module); + + module.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cast(cvtOp.getSrc().getType()); + auto dstType = cast(cvtOp.getType()); + auto srcMma = dyn_cast(srcType.getEncoding()); + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (srcMma && dstDotOp && !shortcutFn(srcType, dstType)) { + auto tmpType = RankedTensorType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::BlockedEncodingAttr::get( + module.getContext(), srcType.getShape(), getSizePerThread(srcMma), + getOrder(srcMma), numWarps, threadsPerWarp, numCTAs)); + auto tmp = builder.create( + cvtOp.getLoc(), tmpType, cvtOp.getSrc()); + addAttrs(tmp, cvtOp->getAttrs()); + auto newConvert = builder.create( + cvtOp.getLoc(), dstType, tmp); + addAttrs(newConvert, cvtOp->getAttrs()); + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + } + }); +} + +template void decomposeTensorCoreToDotLayoutConversion< + triton::gpu::NvidiaMmaEncodingAttr>(ModuleOp, ShortcutFn); +template void + decomposeTensorCoreToDotLayoutConversion( + ModuleOp, ShortcutFn); + +void decomposeBlockedToDotLayoutConversion(ModuleOp module) { + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module); + module.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cast(cvtOp.getSrc().getType()); + auto dstType = cast(cvtOp.getType()); + auto srcBlocked = + dyn_cast(srcType.getEncoding()); + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (srcBlocked && dstDotOp) { + auto tmpType = MemDescType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::SharedEncodingAttr::get( + module.getContext(), dstDotOp, srcType.getShape(), + srcBlocked.getOrder(), srcBlocked.getCTALayout(), + srcType.getElementType())); + auto tmp = builder.create( + cvtOp.getLoc(), tmpType, cvtOp.getSrc()); + addAttrs(tmp, cvtOp->getAttrs()); + auto newConvert = builder.create(cvtOp.getLoc(), + dstType, tmp); + addAttrs(newConvert, cvtOp->getAttrs()); + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + } + }); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp new file mode 100644 index 000000000..114974b3c --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -0,0 +1,121 @@ +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; + +using ValueTableFMA = std::map, Value>; + +static ValueTableFMA +getValueTableFromStructFMA(Value val, int K, int n0, int shapePerCTATile, + int sizePerThread, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter, Type type) { + ValueTableFMA res; + auto elems = unpackLLElements(loc, val, rewriter); + int index = 0; + for (unsigned k = 0; k < K; ++k) { + for (unsigned m = 0; m < n0; m += shapePerCTATile) + for (unsigned mm = 0; mm < sizePerThread; ++mm) { + res[{m + mm, k}] = elems[index++]; + } + } + return res; +} + +static Value extendfp16Andbf16(Location loc, Value v, + ConversionPatternRewriter &rewriter) { + if (v.getType() != f16_ty && v.getType() != bf16_ty) + return v; + if (v.getType() == f16_ty) + return rewriter.create(loc, f32_ty, v); + if (v.getType() == bf16_ty) { + auto as_int16 = bitcast(v, i16_ty); + auto as_int32 = zext(i32_ty, as_int16); + auto shifted = shl(i32_ty, as_int32, i32_val(16)); + return bitcast(shifted, f32_ty); + } + llvm_unreachable("unreachable"); + return nullptr; +} + +LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + + auto A = op.getA(); + auto B = op.getB(); + auto C = op.getC(); + auto D = op.getResult(); + + auto aTensorTy = cast(A.getType()); + auto bTensorTy = cast(B.getType()); + auto dTensorTy = cast(D.getType()); + + auto aShapePerCTA = getShapePerCTA(aTensorTy); + auto bShapePerCTA = getShapePerCTA(bTensorTy); + + BlockedEncodingAttr dLayout = + cast(dTensorTy.getEncoding()); + auto order = dLayout.getOrder(); + auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); + + Value llA = adaptor.getA(); + Value llB = adaptor.getB(); + + auto sizePerThread = getSizePerThread(dLayout); + auto shapePerCTATile = getShapePerCTATile(dLayout); + + int K = aShapePerCTA[1]; + int M = aShapePerCTA[0]; + int N = bShapePerCTA[1]; + + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int mSizePerThread = + order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int nSizePerThread = + order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + + auto has = + getValueTableFromStructFMA(llA, K, M, mShapePerCTATile, mSizePerThread, + rewriter, loc, typeConverter, aTensorTy); + auto hbs = + getValueTableFromStructFMA(llB, K, N, nShapePerCTATile, nSizePerThread, + rewriter, loc, typeConverter, bTensorTy); + + SmallVector ret = cc; + bool isCRow = order[0] == 1; + + for (unsigned k = 0; k < K; k++) { + for (unsigned m = 0; m < M; m += mShapePerCTATile) + for (unsigned n = 0; n < N; n += nShapePerCTATile) + for (unsigned mm = 0; mm < mSizePerThread; ++mm) + for (unsigned nn = 0; nn < nSizePerThread; ++nn) { + int mIdx = m / mShapePerCTATile * mSizePerThread + mm; + int nIdx = n / nShapePerCTATile * nSizePerThread + nn; + + int z = isCRow + ? mIdx * N / nShapePerCTATile * mSizePerThread + nIdx + : nIdx * M / mShapePerCTATile * nSizePerThread + mIdx; + // FIXME: ph1 support fp16 and bf16 fma, no need to convert + Value a = extendfp16Andbf16(loc, has[{m + mm, k}], rewriter); + Value b = extendfp16Andbf16(loc, hbs[{n + nn, k}], rewriter); + Value c = extendfp16Andbf16(loc, ret[z], rewriter); + ret[z] = rewriter.create(loc, a, b, c); + } + } + + auto res = packLLElements(loc, typeConverter, ret, rewriter, dTensorTy); + rewriter.replaceOp(op, res); + + return success(); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp new file mode 100644 index 000000000..0287207be --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -0,0 +1,839 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir::triton::gpu; + +namespace mlir::triton::gpu { + +Type getElementType(Value value) { + auto type = value.getType(); + if (auto tensorType = dyn_cast(type)) + return tensorType.getElementType(); + return type; +} +// MMA encoding has a different order depending on the element's bit width; +// reorder if we're in this case. +SmallVector reorderValues(const SmallVector &values, Type inType, + Type ouType) { + auto inTensorTy = dyn_cast(inType); + auto ouTensorTy = dyn_cast(ouType); + if (!inTensorTy || !ouTensorTy) + return values; + auto inEncoding = dyn_cast(inTensorTy.getEncoding()); + auto ouEncoding = dyn_cast(ouTensorTy.getEncoding()); + assert(inEncoding == ouEncoding); + if (!inEncoding) + return values; + // If the parent of the dot operand is in block encoding, we don't need to + // reorder elements + auto parentEncoding = dyn_cast(ouEncoding.getParent()); + if (!parentEncoding) + return values; + size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth(); + size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth(); + auto ouEltTy = ouTensorTy.getElementType(); + if (inBitWidth == ouBitWidth) + return values; + if (inBitWidth == 16 && ouBitWidth == 32) { + SmallVector ret; + for (unsigned i = 0; i < values.size(); i += 8) { + ret.push_back(values[i]); + ret.push_back(values[i + 1]); + ret.push_back(values[i + 4]); + ret.push_back(values[i + 5]); + ret.push_back(values[i + 2]); + ret.push_back(values[i + 3]); + ret.push_back(values[i + 6]); + ret.push_back(values[i + 7]); + } + return ret; + } + if (inBitWidth == 8 && ouBitWidth == 16) { + SmallVector ret; + for (unsigned i = 0; i < values.size(); i += 16) { + ret.push_back(values[i + 0]); + ret.push_back(values[i + 1]); + ret.push_back(values[i + 2]); + ret.push_back(values[i + 3]); + ret.push_back(values[i + 8]); + ret.push_back(values[i + 9]); + ret.push_back(values[i + 10]); + ret.push_back(values[i + 11]); + ret.push_back(values[i + 4]); + ret.push_back(values[i + 5]); + ret.push_back(values[i + 6]); + ret.push_back(values[i + 7]); + ret.push_back(values[i + 12]); + ret.push_back(values[i + 13]); + ret.push_back(values[i + 14]); + ret.push_back(values[i + 15]); + } + return ret; + } + llvm_unreachable("unimplemented code path"); +} + +SmallVector unpackI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter) { + auto tensorTy = dyn_cast(srcTy); + if (!tensorTy) + return inValues; + auto encoding = dyn_cast(tensorTy.getEncoding()); + if (!(encoding && isa(encoding.getParent()))) + return inValues; + SmallVector outValues; + for (auto v : inValues) { + // cast i32 to appropriate eltType vector and extract elements + auto eltType = typeConverter->convertType(tensorTy.getElementType()); + auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth()); + auto vec = bitcast(v, vecType); + for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) { + outValues.push_back(extract_element(vec, i32_val(i))); + } + } + return outValues; +} + +SmallVector packI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter) { + auto tensorTy = dyn_cast(srcTy); + if (!tensorTy) + return inValues; + auto encoding = dyn_cast(tensorTy.getEncoding()); + if (!(encoding && isa(encoding.getParent()))) + return inValues; + SmallVector outValues; + auto eltType = typeConverter->convertType(tensorTy.getElementType()); + int vecWidth = 32 / eltType.getIntOrFloatBitWidth(); + auto vecType = vec_ty(eltType, vecWidth); + for (int i = 0; i < inValues.size(); i += vecWidth) { + Value vec = undef(vecType); + for (int j = 0; j < vecWidth; j++) { + vec = insert_element(vec, inValues[i + j], i32_val(j)); + } + outValues.push_back(bitcast(vec, i32_ty)); + } + return outValues; +} + +int getNumElementsPerThreads(Type type, + const LLVMTypeConverter *typeConverter) { + int numElemsPerThread = 1; + auto tensorTy = dyn_cast(type); + if (!tensorTy) + return numElemsPerThread; + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (structType) { + numElemsPerThread = structType.getBody().size(); + } + auto encoding = dyn_cast(tensorTy.getEncoding()); + if (!(encoding && isa(encoding.getParent()))) + return numElemsPerThread; + auto eltType = tensorTy.getElementType(); + assert(eltType.getIntOrFloatBitWidth() <= 32 && + "Only support element type with bit width <= 32 in dot operand mma " + "layout"); + // dot operand data are packed into i32 elements so use the following formula + // to get the number of elements per thread. + return (32 / eltType.getIntOrFloatBitWidth()) * numElemsPerThread; +} + +} // namespace mlir::triton::gpu + +namespace { +struct AddPtrOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = op.getType(); + auto typeConverter = getTypeConverter(); + auto resultTensorTy = dyn_cast(resultTy); + if (resultTensorTy) { + unsigned elems = getTotalElemsPerThread(resultTy); + Type elemTy = typeConverter->convertType( + cast(resultTensorTy.getElementType()).getPointeeType()); + Type ptrTy = typeConverter->convertType(resultTensorTy.getElementType()); + auto ptrs = unpackLLElements(loc, adaptor.getPtr(), rewriter); + auto offsets = unpackLLElements(loc, adaptor.getOffset(), rewriter); + SmallVector resultVals(elems); + for (unsigned i = 0; i < elems; ++i) { + resultVals[i] = gep(ptrTy, elemTy, ptrs[i], offsets[i]); + } + Value view = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, view); + } else { + assert(isa(resultTy)); + auto resultPtrTy = typeConverter->convertType(resultTy); + auto resultElemTy = typeConverter->convertType( + cast(resultTy).getPointeeType()); + Value result = + gep(resultPtrTy, resultElemTy, adaptor.getPtr(), adaptor.getOffset()); + rewriter.replaceOp(op, result); + } + return success(); + } +}; + +struct CmpIOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + SmallVector createDestOps(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, + MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create( + loc, elemTy, ArithCmpIPredicateToLLVM(op.getPredicate()), + operands[0][0], operands[0][1])}; + } + + static LLVM::ICmpPredicate + ArithCmpIPredicateToLLVM(arith::CmpIPredicate predicate) { + switch (predicate) { +#define __PRED_ENUM(item__) \ + case arith::CmpIPredicate::item__: \ + return LLVM::ICmpPredicate::item__ + + __PRED_ENUM(eq); + __PRED_ENUM(ne); + __PRED_ENUM(sgt); + __PRED_ENUM(sge); + __PRED_ENUM(slt); + __PRED_ENUM(sle); + __PRED_ENUM(ugt); + __PRED_ENUM(uge); + __PRED_ENUM(ult); + __PRED_ENUM(ule); + +#undef __PRED_ENUM + } + llvm_unreachable("Unknown arith::CmpIPredicate"); + } +}; + +struct CmpFOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + static SmallVector + createDestOps(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Type elemTy, + MultipleOperandsRange operands, Location loc) { + return {rewriter.create( + loc, elemTy, ArithCmpFPredicateToLLVM(op.getPredicate()), + operands[0][0], operands[0][1])}; + } + + static LLVM::FCmpPredicate + ArithCmpFPredicateToLLVM(arith::CmpFPredicate predicate) { + switch (predicate) { +#define __PRED_ENUM(item__, item1__) \ + case arith::CmpFPredicate::item__: \ + return LLVM::FCmpPredicate::item1__ + + __PRED_ENUM(OEQ, oeq); + __PRED_ENUM(ONE, one); + __PRED_ENUM(OGT, ogt); + __PRED_ENUM(OGE, oge); + __PRED_ENUM(OLT, olt); + __PRED_ENUM(OLE, ole); + __PRED_ENUM(ORD, ord); + __PRED_ENUM(UEQ, ueq); + __PRED_ENUM(UGT, ugt); + __PRED_ENUM(UGE, uge); + __PRED_ENUM(ULT, ult); + __PRED_ENUM(ULE, ule); + __PRED_ENUM(UNE, une); + __PRED_ENUM(UNO, uno); + __PRED_ENUM(AlwaysTrue, _true); + __PRED_ENUM(AlwaysFalse, _false); + +#undef __PRED_ENUM + } + llvm_unreachable("Unknown arith::CmpFPredicate"); + } +}; + +struct MulhiUIOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + explicit MulhiUIOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + targetInfo(targetInfo) {} + + SmallVector createDestOps(MulhiUIOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + + Type resultElementTy = getElementTypeOrSelf(op.getResult().getType()); + assert(resultElementTy.isInteger(32) || resultElementTy.isInteger(64)); + + auto funcName = targetInfo.getMulhiFuncName(resultElementTy); + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + return { + rewriter.create(loc, funcOp, operands[0]).getResult()}; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +struct ExternElementwiseOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + typedef typename Base::OpAdaptor OpAdaptor; + + SmallVector createDestOps(ExternElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + StringRef funcName = op.getSymbol(); + if (funcName.empty()) + llvm::errs() << "ExternElementwiseOpConversion"; + + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp( + rewriter, op, funcName, funcType, op.getLibname(), op.getLibpath()); + return { + rewriter.create(loc, funcOp, operands[0]).getResult()}; + } +}; + +template +struct ElementwiseOpConversion + : public ElementwiseOpConversionBase< + SourceOp, ElementwiseOpConversion> { + using Base = + ElementwiseOpConversionBase>; + using Base::Base; + using OpAdaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + SmallVector createDestOps(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create(loc, elemTy, operands[0], + adaptor.getAttributes().getValue())}; + } +}; + +struct ElementwiseInlineAsmOpConversion + : public ConvertOpToLLVMPattern { + using Base = ConvertOpToLLVMPattern; + + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + typedef typename Base::OpAdaptor OpAdaptor; + + // If operand size is smaller than 32 bits, pack in groups of 32 bits. + SmallVector packOperands(ElementwiseInlineAsmOp op, + MultipleOperandsRange operands, + ConversionPatternRewriter &rewriter, + Location loc) const { + SmallVector packedOperands; + unsigned numPackedElements = op.getPackedElement(); + for (int i = 0, e = op.getNumOperands(); i < e; i++) { + Type elemTy = getElementType(op.getOperand(i)); + unsigned bitWidth = + elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : 64; + unsigned numElementPerReg = bitWidth < 32 ? 32 / bitWidth : 1; + numElementPerReg = std::min(numElementPerReg, numPackedElements); + for (int j = 0; j < numPackedElements; j += numElementPerReg) { + if (numElementPerReg == 1) { + packedOperands.push_back(operands[j][i]); + continue; + } + Type t = + vec_ty(getTypeConverter()->convertType(elemTy), numElementPerReg); + Value packed = undef(t); + for (int k = 0; k < numElementPerReg; k++) { + packed = insert_element(packed, operands[j + k][i], i32_val(k)); + } + packedOperands.push_back(packed); + } + } + return packedOperands; + } + + SmallVector> + createDestOps(ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + MultipleOperandsRange operands, Location loc) const { + auto ctx = op->getContext(); + + if (operands.size() % op.getPackedElement() != 0) + llvm::report_fatal_error("Inline asm op has more packed elements than " + "number of elements per thread."); + + // Pack elems smaller than 32 bits into 32-bit registers. + SmallVector packedOperands = + packOperands(op, operands, rewriter, loc); + + // Types returned by the LLVM asm op. If there's more than one, they'll be + // wrapped in a struct. + SmallVector asmRetTypes; + for (auto result : op.getResult()) { + auto ty = getTypeConverter()->convertType(getElementType(result)); + + // Pack return elements into 32-bits. + unsigned bitWidth = ty.isIntOrFloat() ? ty.getIntOrFloatBitWidth() : 64; + unsigned numElemsPerReg = + std::min(bitWidth < 32 ? 32 / bitWidth : 1, op.getPackedElement()); + assert(op.getPackedElement() % numElemsPerReg == 0); + if (numElemsPerReg > 1) { + ty = vec_ty(ty, numElemsPerReg); + } + for (unsigned i = 0; i < op.getPackedElement() / numElemsPerReg; i++) { + asmRetTypes.push_back(ty); + } + } + Type asmRetType = + asmRetTypes.size() > 1 ? struct_ty(asmRetTypes) : asmRetTypes[0]; + + Value asmResults = + rewriter + .create( + loc, asmRetType, + /*operands=*/packedOperands, + /*asm_string=*/op.getAsmString(), + /*constraints=*/op.getConstraints(), + /*has_side_effects=*/!op.getPure(), + /*is_align_stack=*/false, + /*asm_dialect=*/ + LLVM::AsmDialectAttr::get(rewriter.getContext(), + LLVM::AsmDialect::AD_ATT), + /*operand_attrs=*/ArrayAttr()) + ->getResult(0); + + // asmResults is a flat struct; pack its values into + // [return_value][op.getPackedElement()]. + SmallVector> ret(op->getNumResults()); + for (int i = 0; i < op->getNumResults(); i++) { + int structIdx = 0; + for (int j = 0; j < op.getPackedElement(); j++) { + Value val; + if (asmRetTypes.size() > 1) { + val = + extract_val(asmResults, i * op.getPackedElement() + structIdx++); + } else { + val = asmResults; + } + if (auto vectorTy = dyn_cast(val.getType())) { + for (int k = 0; k < vectorTy.getNumElements(); k++) { + ret[i].push_back(extract_element(val, i32_val(k))); + } + j += vectorTy.getNumElements() - 1; + } else { + ret[i].push_back(val); + } + } + } + return ret; + } + + LogicalResult + matchAndRewrite(ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + // Layout is unpackedOperands[operand][elem]. + SmallVector> unpackedOperands; + for (auto operand : adaptor.getOperands()) { + auto argTy = op->getOperand(0).getType(); + auto subOperands = unpackLLElements(loc, operand, rewriter); + unpackedOperands.push_back( + unpackI32(subOperands, argTy, rewriter, loc, getTypeConverter())); + } + + int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(), + getTypeConverter()); + + // These are checked by the verifier, so we don't need to raise a nice + // error. + assert(all_of(unpackedOperands, [&](auto &operands) { + return operands.size() == numElemsPerThread; + })); + if (numElemsPerThread % op.getPackedElement() != 0) { + // Pad with the undef for each operand to have a multiple of + // op.getPackedElement() elements. + int numPaddedValue = + op.getPackedElement() - numElemsPerThread % op.getPackedElement(); + for (auto &operands : unpackedOperands) { + for (int i = 0; i < numPaddedValue; i++) { + operands.push_back(undef(operands[0].getType())); + } + } + } + + // Run the inline asm op on each block of elements. + // + // Layout is unpackedResults[result_idx][elem]. + // + // This loop always runs at least once, even when the asm has no input + // elements. + SmallVector> unpackedResults(op->getNumResults()); + for (unsigned i = 0; i < numElemsPerThread; i += op.getPackedElement()) { + // Block of elements to process with one call to the inline asm. This is + // ordered opposite `unpackedResults`: The outer dim is + // op.getPackedElement(), and the inner dim is the operand. + SmallVector> block(op.getPackedElement()); + for (auto &os : unpackedOperands) { + for (int j = 0; j < op.getPackedElement(); j++) { + block[j].push_back(os[i + j]); + } + } + auto cur = createDestOps(op, adaptor, rewriter, block, loc); + assert(cur.size() == unpackedResults.size()); + for (unsigned j = 0; j < cur.size(); j++) { + unpackedResults[j].insert(unpackedResults[j].end(), cur[j].begin(), + cur[j].end()); + } + } + for (auto &results : unpackedResults) { + results.resize(numElemsPerThread); + } + // Reorder and pack the results. + SmallVector outs; + for (int i = 0; i < unpackedResults.size(); i++) { + // We reordered all the inputs so they match operand 0. Reorder the + // outputs accordingly. + if (op->getNumOperands() > 0) { + unpackedResults[i] = reorderValues( + unpackedResults[i], /*inType=*/op->getOperand(0).getType(), + /*ouType=*/op->getResult(i).getType()); + } + auto packed = packI32(unpackedResults[i], op->getResult(i).getType(), + rewriter, loc, getTypeConverter()); + outs.push_back(packLLElements(loc, getTypeConverter(), packed, rewriter, + op->getResult(i).getType())); + } + + rewriter.replaceOp(op, outs); + return success(); + } +}; + +struct AbsIOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(math::AbsIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create(loc, elemTy, operands[0][0], + /*is_int_min_poison=*/false)}; + } +}; + +struct AbsFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(math::AbsFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + if (llvm::isa(elemTy)) { + // Mask out the sign bit + auto num_bits = + getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); + assert(num_bits <= 16); + auto mask = (1u << (num_bits - 1u)) - 1u; + auto maskAttr = rewriter.getIntegerAttr(elemTy, mask); + auto maskConst = rewriter.create(loc, maskAttr); + return {and_(operands[0][0], maskConst)}; + } + + return {rewriter.create(loc, elemTy, operands[0][0])}; + } +}; +/// The lowering of index_cast becomes an integer conversion since index +/// becomes an integer. If the bit width of the source and target integer +/// types is the same, just erase the cast. If the target type is wider, +/// sign-extend the value, otherwise truncate it. +struct IndexCastOpLowering + : public ElementwiseOpConversionBase { + using Base = + ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto inElemTy = + this->getTypeConverter()->convertType(getElementType(op.getIn())); + unsigned targetBits = elemTy.getIntOrFloatBitWidth(); + unsigned sourceBits = inElemTy.getIntOrFloatBitWidth(); + + if (targetBits == sourceBits) + return {operands[0][0]}; + if (targetBits < sourceBits) + return {rewriter.replaceOpWithNewOp(op, elemTy, + operands[0][0])}; + return { + rewriter.replaceOpWithNewOp(op, elemTy, operands[0][0])}; + } +}; + +struct SelectOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + std::array llvmOperands; + if (operands[0].size() == 2) { + // Case of scalar condition with tensor operands. + assert(op.getCondition().getType().isInteger(1)); + llvmOperands = {adaptor.getCondition(), operands[0][0], operands[0][1]}; + } else { + llvmOperands = {operands[0][0], operands[0][1], operands[0][2]}; + } + return {rewriter.create( + loc, llvmOperands[1].getType(), llvmOperands, + adaptor.getAttributes().getValue())}; + } +}; +template +struct MinMaxFOpConversion + : ElementwiseOpConversionBase> { + using Base = ElementwiseOpConversionBase>; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + static_assert(std::is_same::value || + std::is_same::value, + "OpTy must be arith::MinimumFOp or arith::MaximumFOp"); + + // Choose the destination op based on the OpTy. + using DestOpNanProp = + typename std::conditional::value, + LLVM::MinimumOp, LLVM::MaximumOp>::type; + using DestOpNoNanProp = + typename std::conditional::value, + LLVM::MinNumOp, LLVM::MaxNumOp>::type; + + explicit MinMaxFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + bool hwNanPropagationSupported, + PatternBenefit benefit = 1) + : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, + benefit), + hwNanPropagationSupported(hwNanPropagationSupported) {} + + SmallVector createDestOps(OpTy op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + if (hwNanPropagationSupported) { + return {rewriter.create(loc, elemTy, operands[0][0], + operands[0][1])}; + } + // Handle workaround for NaN propagation, i.e. software emulation of NaN + // propagation. If any of the operands is NaN, return NaN. + auto lhs = operands[0][0]; + auto rhs = operands[0][1]; + auto lhsIsNan = + rewriter.create(loc, LLVM::FCmpPredicate::une, lhs, lhs); + auto rhsIsNan = + rewriter.create(loc, LLVM::FCmpPredicate::une, rhs, rhs); + auto isNan = rewriter.create(loc, lhsIsNan, rhsIsNan); + auto nonNanRes = rewriter.create(loc, elemTy, lhs, rhs); + + auto nan = LLVM::createNaNConstant(loc, rewriter, elemTy); + + // Select the result based on the isNan flag. + return {rewriter.create(loc, isNan, nan, nonNanRes)}; + } + +private: + bool hwNanPropagationSupported; +}; + +struct ClampFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + explicit ClampFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + targetInfo(targetInfo) {} + + SmallVector createDestOps(ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + // Clip pattern not found, use min/max. + if (op.getPropagateNan() == PropagateNan::ALL) { + if (targetInfo.supportMaximumMinimum()) { + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + return {rewriter.create(loc, v, operands[0][2])}; + } + // On pre-80 compute capability, we need to handle NaN propagation + // manually. We need to check only the first operand for clamp. + auto lhs = operands[0][0]; + auto isNan = rewriter.create(loc, LLVM::FCmpPredicate::une, + lhs, lhs); + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + auto nonNanRes = rewriter.create(loc, v, operands[0][2]); + auto nan = LLVM::createNaNConstant(loc, rewriter, elemTy); + // Select the result based on the isNan flag. + return {rewriter.create(loc, isNan, nan, nonNanRes)}; + } + + // No NaN propagation. + assert(op.getPropagateNan() == PropagateNan::NONE); + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + return {rewriter.create(loc, v, operands[0][2])}; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMinMaxFOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, bool hwNanPropagationSupported, + PatternBenefit benefit) { + patterns.add>( + typeConverter, axisInfoAnalysis, hwNanPropagationSupported, benefit); + patterns.add>( + typeConverter, axisInfoAnalysis, hwNanPropagationSupported, benefit); +} + +void mlir::triton::populateClampFOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, axisInfoAnalysis, targetInfo, + benefit); +} + +void mlir::triton::populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit) { +#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + + POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp) + POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp) + POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp) + POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp) + POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp) + POPULATE_UNARY_OP(math::FloorOp, math::FloorOp) + POPULATE_UNARY_OP(math::CeilOp, math::CeilOp) + POPULATE_UNARY_OP(math::LogOp, math::LogOp) + POPULATE_UNARY_OP(math::Log2Op, math::Log2Op) + POPULATE_UNARY_OP(math::CosOp, math::CosOp) + POPULATE_UNARY_OP(math::SinOp, math::SinOp) + POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp) + POPULATE_UNARY_OP(math::RsqrtOp, math::RsqrtOp) + POPULATE_UNARY_OP(math::ExpOp, math::ExpOp) + POPULATE_UNARY_OP(math::Exp2Op, math::Exp2Op) + POPULATE_UNARY_OP(math::ErfOp, math::ErfOp) + POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp) + POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp) + POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp) +#undef POPULATE_UNARY_OP + +#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + + POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // - + POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // + + POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // * + POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp) + POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp) + POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // % + POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp) + POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp) + POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // & + POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // | + POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^ + POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // << + POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> + POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> + // fmin (return non-NaN if either op is non-NaN) + POPULATE_BINARY_OP(arith::MinNumFOp, LLVM::MinNumOp) + // fmax (return non-NaN if either op is non-NaN) + POPULATE_BINARY_OP(arith::MaxNumFOp, LLVM::MaxNumOp) + POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin + POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax + POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin + POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax +#undef POPULATE_BINARY_OP + + patterns.add>( + typeConverter, axisInfoAnalysis, benefit); + + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, targetInfo, + benefit); + patterns.add(typeConverter, axisInfoAnalysis, + benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 000000000..47f40ebec --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,118 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +/// FuncOp legalization pattern that converts MemRef arguments to pointers to +/// MemRef descriptors (LLVM struct data types) containing all the MemRef type +/// information. +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, int numWarps, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), numWarps(numWarps) {} + + /// Only retain those attributes that are not constructed by + /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument + /// attributes. + static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result) { + + for (const auto &attr : op->getAttrs()) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == op.getFunctionTypeAttrName() || + attr.getName() == "std.varargs" || + (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) + continue; + result.push_back(attr); + } + } + + triton::FuncOp amendFuncOp(triton::FuncOp funcOp, + ConversionPatternRewriter &rewriter) const { + // Push back a variable that indicates the current stack pointer of shared + // memory to the function arguments. + auto loc = funcOp.getLoc(); + auto ctx = funcOp->getContext(); + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); + // 1. Modify the function type to add the new argument. + auto funcTy = funcOp.getFunctionType(); + auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); + amendedInputTy.push_back(ptrTy); + auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, + funcTy.getResults()); + // 2. Modify the argument attributes to add the new argument. + SmallVector amendedAttrs; + filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); + auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedAttrs.push_back(rewriter.getNamedAttr( + funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); + // 3. Add a new argument to the region + auto amendedFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); + auto ®ion = funcOp.getBody(); + region.addArgument(ptrTy, loc); + rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), + amendedFuncOp.end()); + return amendedFuncOp; + } + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Prevent LLVM's inliner to inline this function + auto amendedFuncOp = funcOp; + if (!LLVM::isKernel(funcOp)) + amendedFuncOp = amendFuncOp(funcOp, rewriter); + + LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( + amendedFuncOp, rewriter, *getTypeConverter()); + if (!newFuncOp) { + return failure(); + } + + auto ctx = funcOp->getContext(); + + if (LLVM::isKernel(funcOp)) { + // Set an attribute to indicate this function is a kernel entry. + newFuncOp->setAttr("nvvm.kernel", + rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + newFuncOp.setLinkage(LLVM::Linkage::External); + } else { + // The noinline attribute will be used by the LLVM codegen to prevent + // inlining. + // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 + newFuncOp.setPassthroughAttr( + ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); + rewriter.eraseOp(amendedFuncOp); + newFuncOp.setLinkage(LLVM::Linkage::Internal); + } + // Set an attribute for maxntidx, it could be used in latter LLVM codegen + // for `nvvm.annotation` metadata. + newFuncOp->setAttr("nvvm.maxntid", + rewriter.getDenseI32ArrayAttr(32 * numWarps)); + rewriter.eraseOp(funcOp); + return success(); + } + +private: + int numWarps{0}; +}; + +} // namespace + +void mlir::triton::populateFuncOpConversionPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, + PatternBenefit benefit) { + patterns.add(typeConverter, numWarps, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp new file mode 100644 index 000000000..acf940b3e --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp @@ -0,0 +1,212 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +static int log2Int(int64_t num) { return (num > 1) ? 1 + log2Int(num / 2) : 0; } + +// Compute a histogram within a warp. This uses an algorithm by @apgoucher +// that does the following: +// Create a ballot for each bit of the bin index (there +// are only log2(num_bins) of these) and then apply bitwise operations to get +// the indicator functions for the bins owned by this particular thread, and +// only popcount those. +static SmallVector computeWarpLevelHistogram( + Location loc, RankedTensorType srcType, SmallVector &srcValues, + int numBins, int numThreadPerWarp, Value threadId, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo) { + assert(numBins % numThreadPerWarp == 0 && + "numBins must be divisible by numThreadPerWarp"); + Value zero = i32_val(0); + int numBits = log2Int(numBins); + int numBitsLaneId = log2Int(numThreadPerWarp); + unsigned numElementsPerThreads = triton::gpu::getTotalElemsPerThread(srcType); + unsigned numThreadWithUniqueData = + triton::gpu::getThreadsPerWarpWithUniqueData(srcType.getEncoding(), + srcType.getShape())[0]; + // The histogram is distributed across threads, each thread owns `numBins / + // numThreadPerWarp` bins. + SmallVector warpLevelHistogram(numBins / numThreadPerWarp, zero); + for (int i = 0; i < numElementsPerThreads; ++i) { + Value value = srcValues[i]; + SmallVector ballotBits; + for (int j = 0; j < numBits; ++j) { + Value bitSet = and_(value, i32_val(1 << j)); + Value cmp = icmp_ne(bitSet, zero); + Value bit = + targetInfo.ballot(rewriter, loc, int_ty(numThreadPerWarp), cmp); + ballotBits.push_back(bit); + } + uint64_t fullMaskValue = + numThreadPerWarp == 32 ? 0xFFFFFFFF : 0xFFFFFFFFFFFFFFFF; + Value fullMask = int_val(numThreadPerWarp, fullMaskValue); + Value mask = fullMask; + // If not all threads have unique data, mask out the redundant ones. + if (numThreadWithUniqueData < numThreadPerWarp) { + mask = int_val(numThreadPerWarp, (1ULL << numThreadWithUniqueData) - 1); + } + for (int i = 0; i < numBitsLaneId; i++) { + Value updateMask = select(icmp_ne(and_(threadId, i32_val(1 << i)), zero), + int_val(numThreadPerWarp, 0), fullMask); + mask = + and_(mask, xor_(ballotBits[i + numBits - numBitsLaneId], updateMask)); + } + // at this point, 'mask' tells you which elements are in a bin owned by this + // thread. + for (int k = 0; k < warpLevelHistogram.size(); k++) { + Value binMask = mask; + for (int j = 0; j < numBits - numBitsLaneId; j++) { + Value updateMask = + int_val(numThreadPerWarp, ((k & (1 << j)) ? 0 : fullMaskValue)); + binMask = and_(binMask, xor_(ballotBits[j], updateMask)); + } + // at this point, 'bin_mask' tells you which elements are in the kth bin + // owned by this thread. + Value bitCount = rewriter.create( + loc, int_ty(numThreadPerWarp), binMask); + if (numThreadPerWarp > 32) + bitCount = trunc(i32_ty, bitCount); + warpLevelHistogram[k] = add(warpLevelHistogram[k], bitCount); + } + } + return warpLevelHistogram; +} + +static void atomicAdd(Value ptr, Value val, Location loc, + ConversionPatternRewriter &rewriter) { + rewriter.create(loc, LLVM::AtomicBinOp::add, ptr, val, + LLVM::AtomicOrdering::monotonic); +} + +static SmallVector computeCrossWarpHistogram( + Location loc, ConversionPatternRewriter &rewriter, RankedTensorType srcType, + Value baseSharedMemPtr, const SmallVector &warpLevelHistogram, + int numBins, int numThreadPerWarp, const SmallVector &indices, + Value threadId, int numWarps) { + SmallVector histogramValues; + unsigned numWarpsWithUniqueData = + mlir::triton::gpu::getWarpsPerCTAWithUniqueData(srcType.getEncoding(), + srcType.getShape())[0]; + Value laneId = and_(threadId, i32_val(numThreadPerWarp - 1)); + // Initialize the shared memory with zeros. + int64_t numElementPerThread = + ceil(numBins, numThreadPerWarp * numWarps); + for (int i = 0; i < numElementPerThread; ++i) { + Value offset = add(threadId, i32_val((i * numWarps * numThreadPerWarp))); + offset = urem(offset, i32_val(numBins)); + Value sharedMemPtr = + gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); + store(i32_val(0), sharedMemPtr); + } + barrier(); + Block *afterAtomics = nullptr; + // If some warps have replicated data we need to skip those warps when + // accumulating. + if (numWarpsWithUniqueData < numWarps) { + Block *currentBlock = rewriter.getInsertionBlock(); + afterAtomics = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *atomicBlock = rewriter.createBlock(afterAtomics); + rewriter.setInsertionPointToEnd(currentBlock); + Value cond = + icmp_ult(threadId, i32_val(numWarpsWithUniqueData * numThreadPerWarp)); + rewriter.create(loc, cond, atomicBlock, afterAtomics); + rewriter.setInsertionPointToStart(atomicBlock); + } + // Apply atomic add to update the histogram in shared memory. + for (int i = 0; i < warpLevelHistogram.size(); ++i) { + Value warpLevelHistogramValue = warpLevelHistogram[i]; + Value offset = + add(mul(laneId, i32_val(warpLevelHistogram.size())), i32_val(i)); + Value sharedMemPtr = + gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); + atomicAdd(sharedMemPtr, warpLevelHistogramValue, loc, rewriter); + } + if (afterAtomics) { + rewriter.create(loc, afterAtomics); + rewriter.setInsertionPointToStart(afterAtomics); + } + barrier(); + // load the histogram to register with the right layout. + for (Value index : indices) { + Value sharedMemPtr = + gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, index); + Value val = load(i32_ty, sharedMemPtr); + histogramValues.push_back(val); + } + return histogramValues; +} + +namespace { +struct HistogramOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + explicit HistogramOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(triton::HistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = adaptor.getSrc(); + auto typeConverter = getTypeConverter(); + SmallVector srcValues = unpackLLElements(loc, input, rewriter); + int numBins = op.getType().getDimSize(0); + auto mod = op->getParentOfType(); + int numThreadsPerWarp = + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + assert(numThreadsPerWarp == 32 || + numThreadsPerWarp == 64 && + "Only supports 32 or 64 threads per warp"); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + // Pad out the bins so that we have at least one bin per thread within a + // warp. + numBins = std::max(numBins, numThreadsPerWarp); + Value threadId = getThreadId(rewriter, loc); + auto srcType = op.getSrc().getType(); + // First compute a warp local histogram based on values owned by each warps. + SmallVector warpLevelHistogram = computeWarpLevelHistogram( + loc, srcType, srcValues, numBins, numThreadsPerWarp, threadId, rewriter, + targetInfo); + + // Then use atomic to update the histogram in shared memory. + // TODO: we could skip this for cases with num_warps=1 as long as we can + // generate the right layout. Currently the warp level histogram generates + // data in the default blocked layout. + Value baseSharedMemPtr = + LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + auto dstType = op.getType(); + Attribute dstEncoding = dstType.getEncoding(); + auto indices = emitIndices(op.getLoc(), rewriter, targetInfo, dstEncoding, + dstType, true); + SmallVector innerDimIndices; + for (int i = 0; i < indices.size(); ++i) + innerDimIndices.push_back(indices[i][0]); + SmallVector histogramValue = computeCrossWarpHistogram( + loc, rewriter, srcType, baseSharedMemPtr, warpLevelHistogram, numBins, + numThreadsPerWarp, innerDimIndices, threadId, numWarps); + + Value results = packLLElements(loc, typeConverter, histogramValue, rewriter, + op.getType()); + rewriter.replaceOp(op, results); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; +} // namespace + +void mlir::triton::populateHistogramOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp new file mode 100644 index 000000000..43120c791 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp @@ -0,0 +1,53 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +struct MakeRangeOpConversion + : public ConvertOpToLLVMPattern { + MakeRangeOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + RankedTensorType ty = op.getType(); + auto shape = ty.getShape(); + auto layout = ty.getEncoding(); + auto elemTy = ty.getElementType(); + assert(elemTy.isInteger(32)); + Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.getStart()); + auto idxs = emitIndices(loc, rewriter, targetInfo, layout, ty, true); + unsigned elems = idxs.size(); + SmallVector retVals(elems); + // TODO: slice layout has more elements than expected. + // Unexpected behavior for make range, but generally OK when followed by + // expand dims + broadcast. very weird behavior otherwise potentially. + for (const auto &multiDim : llvm::enumerate(idxs)) { + assert(multiDim.value().size() == 1); + retVals[multiDim.index()] = add(multiDim.value()[0], start); + } + auto typeConverter = getTypeConverter(); + Value result = packLLElements(loc, typeConverter, retVals, rewriter, ty); + rewriter.replaceOp(op, result); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMakeRangeOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp new file mode 100644 index 000000000..12ab6684c --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -0,0 +1,145 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// blocked -> shared. +// Swizzling in shared memory to avoid bank conflict. Normally used for +// A/B operands of dots. +void lowerDistributedToShared(Location loc, Value src, Value dst, + Value adaptorSrc, + const SharedMemoryObject &smemObj, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) { + auto srcTy = cast(src.getType()); + auto dstTy = cast(dst.getType()); + auto dstShapePerCTA = triton::gpu::getShapePerCTA(dstTy); + auto srcLayout = srcTy.getEncoding(); + auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); + assert(srcTy.getShape().size() <= 2 || + (srcTy.getShape().size() == 3 && outOrd[2] == 0) && + "Unexpected rank of ConvertLayout(blocked->shared)"); + auto elemTy = typeConverter->convertType(srcTy.getElementType()); + + auto smemBase = smemObj.getBase(); + int32_t elemSize = elemTy.getIntOrFloatBitWidth(); + unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); + auto dstStrides = smemObj.getStrides(); + auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); + storeDistributedToShared(src, inVals, dstStrides, dst, smemBase, elemTy, loc, + rewriter, targetInfo); +} + +struct LocalAllocOpConversion + : public ConvertOpToLLVMPattern { + LocalAllocOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + auto resultTy = cast(op.getType()); + auto typeConverter = getTypeConverter(); + auto sharedLayout = + cast(resultTy.getEncoding()); + auto order = sharedLayout.getOrder(); + // Workaround for 3D tensors + // TODO: we need to modify the pipeline pass to give a proper shared + // encoding to 3D tensors + SmallVector newOrder; + if (resultTy.getShape().size() != order.size()) { + for (auto i = 0; i < order.size(); ++i) + newOrder.push_back(order[i] + 1); + newOrder.push_back(0); + } else { + newOrder = SmallVector(order.begin(), order.end()); + } + + auto llvmElemTy = typeConverter->convertType(resultTy.getElementType()); + auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape()); + auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA, + newOrder, loc, rewriter); + // If there is an initial tensor, store it into the shared memory. + if (op.getSrc()) { + lowerDistributedToShared(loc, op.getSrc(), op.getResult(), + adaptor.getSrc(), smemObj, typeConverter, + rewriter, targetInfo); + } + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +struct LocalDeallocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::LocalDeallocOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::LocalDeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +struct LocalStoreOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern; + + LocalStoreOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value memDescVal = op.getDst(); + auto llvmElemTy = + getTypeConverter()->convertType(op.getDst().getType().getElementType()); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter); + lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(), + adaptor.getSrc(), smemObj, getTypeConverter(), + rewriter, targetInfo); + rewriter.eraseOp(op); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMemoryOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp new file mode 100644 index 000000000..32c7835c2 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp @@ -0,0 +1,243 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace { + +// The input print op contains: +// - a "prefix" (string) specified by the user, and +// - one or more "operands" (tensors). +// +// For each operand, we print all of the values contained in this GPU thread, +// one per line, along with the index of the value in its tensor. +struct PrintOpConversion : public ConvertOpToLLVMPattern { + explicit PrintOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + auto getPid = [&](int axis) { + return targetInfo.programId(rewriter, loc, + op->getParentOfType(), axis); + }; + std::array pid = {getPid(0), getPid(1), getPid(2)}; + + // Simple printf of a string without any tensors. + if (op.getNumOperands() == 0) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << "pid (" << getFormatSubstr(pid[0]) << ", " + << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")" + << op.getPrefix(); + llPrintf(formatStr, {pid[0], pid[1], pid[2]}, rewriter); + rewriter.eraseOp(op); + return success(); + } + + for (size_t i = 0; i < op.getNumOperands(); i++) { + // Elements of the tensor that are resident in this GPU thread. + auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); + + // Get the indices of `elems` within the tensor. Note that if `elems` + // has an "interesting" layout, then these will not be in any + // particularly nice order. + + // Extract the shape of the tensor being printed and use it to figure + // out how many digits we need for each of the dimensions. + SmallVector dimWidths; + SmallVector> indices; + if (auto rankedTy = + dyn_cast(op.getOperand(i).getType())) { + indices = emitIndices(loc, rewriter, targetInfo, rankedTy.getEncoding(), + rankedTy, true); + for (int64_t dim : rankedTy.getShape()) { + if (dim > 0) { + dimWidths.push_back(static_cast(std::ceil(std::log10(dim)))); + } else { + dimWidths.push_back(0); + } + } + } else { + // We're printing a scalar. + assert(elems.size() == 1); + indices.push_back({}); + } + + if (!elems.empty()) { + printTensor(op.getPrefix(), /*operand=*/i, + /*numOperands=*/op.getNumOperands(), elems, pid, indices, + dimWidths, op.getHex(), rewriter); + } + } + rewriter.eraseOp(op); + return success(); + } + + void printTensor(StringRef prefixStr, size_t operand, size_t numOperands, + ArrayRef elems, std::array pid, + ArrayRef> indices, + ArrayRef dimWidths, bool hex, + ConversionPatternRewriter &rewriter) const { + assert(!elems.empty()); + assert(elems.size() == indices.size()); + assert(dimWidths.size() == indices.front().size()); + + size_t rank = dimWidths.size(); + + // Format is: + // pid (, , ) idx (, , ...) (operand ) + // where we leave off "(operand )" if there's only one operand. + // + // The Python wrapper munges `prefix` so that it prints nicely (e.g. starts + // with " " and ends with ": "). + + Value formatStrValue; + int formatStrByteCount = 0; + for (int i = 0; i < elems.size(); i++) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + + // nvptx printf can only accept 32 args; if we pass more than that, it + // will print garbage for the trailing args. + constexpr int kMaxPrintfOperands = 32; + SmallVector printfOperands; + + // TODO(jlebar): We really should pad the pid, but because the max pid is + // not known at compile-time, this would require nontrivial device-side + // work. + os << "pid ("; + for (int j = 0; j < pid.size(); j++) { + if (j != 0) { + os << ", "; + } + os << getFormatSubstr(pid[j]); + printfOperands.push_back(pid[j]); + } + os << ") "; + + // If `rank` is large enough, we could end up exceeding + // kMaxPrintfOperands. In that case, just truncate the index. + // (Subtract 2 because we're going to add two operands after the index.) + int maxAllowedRank = kMaxPrintfOperands - printfOperands.size() - 2; + + os << "idx ("; + const auto &index = indices[i]; + for (size_t dim = 0; dim < index.size(); dim++) { + if (dim != 0) { + os << ", "; + } + if (dim == maxAllowedRank) { + os << "... (truncated)"; + break; + } + os << getFormatSubstr(index[dim], /*hex=*/false, + /*width=*/dimWidths[dim]); + printfOperands.push_back(index[dim]); + } + os << ")" << prefixStr; + + if (numOperands > 1) { + os << "(operand " << operand << ") "; + } + + auto elem = elems[i]; + os << getFormatSubstr(elem, hex); + printfOperands.push_back(elem); + + // It's the same format string each iteration, but it's a lot easier if we + // construct the format string at the same time as we populate + // printfOperands. But we don't want to create BLOCK_SIZE duplicate + // strings, so we cache the Value. + if (i == 0) { + formatStrValue = + llPrintf(formatStr, printfOperands, rewriter, &formatStrByteCount); + } else { + targetInfo.printf(rewriter, formatStrValue, formatStrByteCount, + printfOperands); + } + } + } + + std::string getFormatSubstr(Value value, bool hex = false, + std::optional width = std::nullopt) const { + Type type = value.getType(); + if (isa(type)) { + return "%p"; + } + // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the + // type (so 4 for fp16, 8 for int32, 16 for int64). + if (hex) { + // Ignore `width` for `hex` values, pad to typeWidth. + std::string ret = + "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4); + if (type.getIntOrFloatBitWidth() > 32) { + ret += "ll"; + } + ret += "x"; + return ret; + } + + std::string prefix = "%"; + if (width.has_value()) { + prefix += std::to_string(*width); + } else if (hex) { + prefix += "0"; + prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4); + } + + if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { + return prefix + "f"; + } else if (type.isSignedInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + "lli"; + else + return prefix + "i"; + } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + "llu"; + else + return prefix + "u"; + } + assert(false && "not supported type"); + return ""; + } + + // Returns a Value for the format string, which you can reuse. Writes the byte + // count for the string to |formatStrByteCount| if not null. + Value llPrintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter, + int *formatStrByteCount = nullptr) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), + rewriter, "printfFormat_", msgNewline); + targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); + if (formatStrByteCount) + *formatStrByteCount = msgNewline.size_in_bytes(); + return msgValue; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populatePrintOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp new file mode 100644 index 000000000..4d036c21a --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -0,0 +1,436 @@ +#include "ReduceScanCommon.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getTotalElemsPerThread; + +namespace { +struct ReduceOpConversion + : public ConvertTritonGPUReduceScanToLLVMPattern { +public: + ReduceOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertTritonGPUReduceScanToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ReduceOpHelper helper(op); + assert(helper.isSupportedLayout() && + "Unexpected srcLayout in ReduceOpConversion"); + Location loc = op->getLoc(); + + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + std::map, SmallVector> accs; + std::map, SmallVector> indices; + // First reduce all the values along axis within each thread. + reduceWithinThreads(helper, srcValues, accs, indices, rewriter); + + // Then reduce across threads within a warp. + reduceWithinWarps(helper, accs, rewriter); + + if (helper.isWarpSynchronous()) { + // If all the values to be reduced are within the same warp there is + // nothing left to do. + packResults(helper, accs, rewriter); + return success(); + } + + // Compute a shared memory base per operand. + auto smemShape = helper.getScratchConfig(); + + SmallVector smemBases = + getSmemBases(op, product(smemShape), rewriter); + + storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter); + + sync(rewriter, loc, op); + + // The second round of shuffle reduction + // now the problem size: sizeInterWarps, s1, s2, .. , sn + // where sizeInterWarps is 2^m + // + // Each thread needs to process: + // elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads + accumulatePartialReductions(helper, smemBases, rewriter); + + // We could avoid this barrier in some of the layouts, however this is not + // the general case. + // TODO: optimize the barrier in case the layouts are accepted. + sync(rewriter, loc, op); + + // set output values + loadReductionAndPackResult(helper, smemShape, smemBases, rewriter); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; + + void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, + SmallVector &acc, ValueRange cur, bool isFirst) const { + if (isFirst) { + acc = SmallVector(cur.begin(), cur.end()); + return; + } + + // Create a new copy of the reduce block, and inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + rewriter.cloneRegionBefore(combineOp, &parent.front()); + auto &newReduce = parent.front(); + auto returnOp = dyn_cast(newReduce.getTerminator()); + + llvm::SmallVector combineArgs(2 * acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; + } + + rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), + combineArgs); + + auto results = returnOp.getResult(); + for (unsigned i = 0; i < acc.size(); ++i) { + acc[i] = results[i]; + } + + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + } + + SmallVector> + unpackInputs(Location loc, triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getTotalElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = unpackLLElements(loc, operands[i], rewriter); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } + } + return srcValues; + } + + void sync(ConversionPatternRewriter &rewriter, Location loc, + triton::ReduceOp op) const { + barrier(); + } + + // Reduce along op axis for elements that are in the same thread. The + // accumulated value is stored in accs. + void reduceWithinThreads( + ReduceOpHelper &helper, SmallVector> &srcValues, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + RankedTensorType operandType = op.getInputTypes()[0]; + // Assumes offsets don't actually depend on type + SmallVector> offsets = + emitOffsetForLayout(helper.getSrcLayout(), operandType); + + // Thread X might hold the same input value in two registers. Get the + // indices in `offsets` that hold unique values, and only accumualte over + // those. + llvm::MapVector, int> uniqueOffsets; + for (int i = 0; i < offsets.size(); ++i) { + uniqueOffsets.insert({offsets[i], i}); + } + + unsigned srcElems = getTotalElemsPerThread(operandType); + auto *combineOp = &op.getCombineOp(); + auto srcIndices = emitIndices(op.getLoc(), rewriter, targetInfo, + helper.getSrcLayout(), operandType, true); + // reduce within threads + for (const auto &[_, i] : uniqueOffsets) { + SmallVector key = offsets[i]; + key[op.getAxis()] = 0; + bool isFirst = accs.find(key) == accs.end(); + accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); + if (isFirst) + indices[key] = srcIndices[i]; + } + } + + // Apply warp reduction across the given number of contiguous lanes using op + // region and the accumulator values as source. + void warpReduce(ConversionPatternRewriter &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce, unsigned interleave) const { + auto success = + targetInfo.warpReduce(rewriter, loc, acc, op, numLaneToReduce); + if (success) + return; + for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) { + SmallVector shfl(acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave); + } + accumulate(rewriter, op.getCombineOp(), acc, shfl, false); + } + } + + // Reduce across threads within each warp. + void + reduceWithinWarps(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData(); + unsigned threadOffsetOnReductionAxis = + helper.getThreadOffsetOnReductionAxis(); + for (auto it : accs) { + const SmallVector &key = it.first; + SmallVector &acc = accs[key]; + warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps, + threadOffsetOnReductionAxis); + } + } + + // Pack the accumulator values and replace the reduce op with the result. + void packResults(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + unsigned axis = op.getAxis(); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (auto resultTy = + dyn_cast(op.getResult()[i].getType())) { + auto resultLayout = cast(resultTy.getEncoding()); + unsigned resultElems = getTotalElemsPerThread(resultTy); + SmallVector> resultOffset = + emitOffsetForLayout(resultLayout, resultTy); + SmallVector resultVals; + for (int j = 0; j < resultElems; j++) { + auto key = resultOffset[j]; + key.insert(key.begin() + axis, 0); + resultVals.push_back(accs[key][i]); + } + results[i] = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, resultTy); + } else + results[i] = accs.begin()->second[i]; + } + rewriter.replaceOp(op, results); + } + + SmallVector + getMultiDimWarpId(ReduceOpHelper &helper, Value &warpId, Location &loc, + ConversionPatternRewriter &rewriter) const { + auto srcLayout = helper.getSrcLayout(); + auto srcShape = helper.getSrcShape(); + auto order = triton::gpu::getWarpOrder(srcLayout); + SmallVector multiDimWarpId; + + // 2x2 warps with slice dim = 0, warpId = 2 ends up writing at the same + // address as warpId = 0 since the warpsPerCTA is [1, 2], need to figure out + // a way to properly delinearize warpId in the slice case + if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(parentLayout); + auto parentOrder = triton::gpu::getWarpOrder(parentLayout); + multiDimWarpId = + delinearize(rewriter, loc, warpId, parentWarpsPerCTA, parentOrder); + multiDimWarpId.erase(multiDimWarpId.begin() + sliceLayout.getDim()); + } else { + SmallVector warpsPerCTA = + triton::gpu::getWarpsPerCTA(srcLayout); + warpsPerCTA[helper.getAxis()] = triton::gpu::getWarpsPerCTAWithUniqueData( + srcLayout, srcShape)[helper.getAxis()]; + multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, order); + } + return multiDimWarpId; + } + + void storeWarpReduceToSharedMemory( + ReduceOpHelper &helper, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + Value threadId = getThreadId(rewriter, loc); + auto srcLayout = helper.getSrcLayout(); + Value warpSize = i32_val(triton::gpu::getWarpSize(srcLayout)); + Value warpId = udiv(threadId, warpSize); + Value laneId = urem(threadId, warpSize); + auto srcShape = helper.getSrcShape(); + unsigned axis = op.getAxis(); + auto smemShape = helper.getScratchConfig(); + + auto threadsPerWarp = + triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape); + auto order = getOrder(srcLayout); + SmallVector multiDimLaneId = + delinearize(rewriter, loc, laneId, threadsPerWarp, order); + Value laneIdAxis = multiDimLaneId[axis]; + Value zero = i32_val(0); + Value laneZero = icmp_eq(laneIdAxis, zero); + + SmallVector multiDimWarpId = + getMultiDimWarpId(helper, warpId, loc, rewriter); + Value warpIdAxis = multiDimWarpId[axis]; + + auto smemOrder = helper.getOrderWithAxisAtBeginning(); + for (auto it : accs) { + const SmallVector &key = it.first; + SmallVector &acc = it.second; + + SmallVector writeIdx = indices[key]; + writeIdx[axis] = warpIdAxis; + Value writeOffset = + linearize(rewriter, loc, writeIdx, smemShape, smemOrder); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + Value writePtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, + smemBases[i], writeOffset); + targetInfo.storeShared(rewriter, loc, writePtr, acc[i], laneZero); + } + } + } + + // Load the reduction of each warp and accumulate them to a final value and + // store back to shared memory. + void accumulatePartialReductions(ReduceOpHelper &helper, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + auto srcLayout = helper.getSrcLayout(); + auto smemShape = helper.getScratchConfig(); + unsigned elems = product(smemShape); + unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData(); + Location loc = op.getLoc(); + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(triton::gpu::getWarpSize(srcLayout)); + Value laneId = urem(threadId, warpSize); + Value zero = i32_val(0); + + auto mod = op.getOperation()->getParentOfType(); + unsigned numThreads = + product(triton::gpu::getWarpsPerCTA(srcLayout)) * + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + unsigned elemsPerThread = std::max(elems / numThreads, 1); + Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); + Value readOffset = threadId; + for (unsigned round = 0; round < elemsPerThread; ++round) { + SmallVector acc(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, + smemBases[i], readOffset); + acc[i] = targetInfo.loadShared(rewriter, loc, getTypeConverter(), + readPtr, elemTy, threadIsNeeded); + } + warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */); + // only the first thread in each sizeInterWarps is writing + Value writeOffset = readOffset; + SmallVector writePtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + writePtrs[i] = gep(ptr_ty(rewriter.getContext(), 3), elemTy, + smemBases[i], writeOffset); + } + + Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); + Value laneIdModSizeInterWarpsIsZero = + icmp_eq(laneIdModSizeInterWarps, zero); + Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero); + + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + targetInfo.storeShared(rewriter, loc, writePtrs[i], acc[i], pred); + } + + if (round != elemsPerThread - 1) { + readOffset = add(readOffset, i32_val(numThreads)); + } + } + } + + // Load the final reduction from shared memory and replace the reduce result + // with it. + void loadReductionAndPackResult(ReduceOpHelper &helper, + SmallVector smemShape, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + auto srcLayout = helper.getSrcLayout(); + auto axis = op.getAxis(); + auto smemOrder = helper.getOrderWithAxisAtBeginning(); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + if (auto resultTy = + dyn_cast(op.getResult()[i].getType())) { + // nd-tensor where n >= 1 + auto resultLayout = cast(resultTy.getEncoding()); + unsigned resultElems = getTotalElemsPerThread(resultTy); + auto resultIndices = emitIndices(loc, rewriter, targetInfo, + resultLayout, resultTy, true); + auto resultShape = resultTy.getShape(); + auto resultCTATile = getShapePerCTATile(resultLayout, resultShape); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (size_t j = 0; j < resultElems; ++j) { + SmallVector readIdx = resultIndices[j]; + readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0)); + for (size_t resultIdx = 0, resultDim = resultShape.size(); + resultIdx < resultDim; ++resultIdx) { + auto smemIdx = resultIdx < op.getAxis() ? resultIdx : resultIdx + 1; + if (resultCTATile[resultIdx] > smemShape[smemIdx] || + resultShape[resultIdx] > smemShape[smemIdx]) { + // When srcShape smaller then src sizePerThread, only srcShape + // elements is accumulated in smem. Modulo smemShape effectively + // replicates srcShape elements to src sizePerThread. + readIdx[smemIdx] = + urem(readIdx[smemIdx], i32_val(smemShape[smemIdx])); + } + } + Value readOffset = + linearize(rewriter, loc, readIdx, smemShape, smemOrder); + Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, + smemBases[i], readOffset); + resultVals[j] = load(elemTy, readPtr); + } + + results[i] = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, resultTy); + } else { + // 0d-tensor -> scalar + results[i] = load(elemTy, smemBases[i]); + } + } + rewriter.replaceOp(op, results); + } +}; +} // namespace + +void mlir::triton::populateReduceOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h new file mode 100644 index 000000000..5604c8b0b --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h @@ -0,0 +1,83 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCESCANCOMMON_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCESCANCOMMON_H + +// TODO: refactor so that it doesn't fail if Allocation.h +// is included after utility.h (due to conflict in `store` macro +// and +#include "triton/Analysis/Allocation.h" + +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +// +#include "mlir/IR/TypeUtilities.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include +#include + +#define DEBUG_TYPE "ttgpu_to_llvm" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::SharedMemoryObject; +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::CTALayoutAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; + +namespace mlir::triton { +class ReduceOp; +class ScanOp; +} // namespace mlir::triton + +template +class ConvertTritonGPUReduceScanToLLVMPattern + : public ConvertOpToLLVMPattern { +public: + // Make sure the class is only instantiated with Reduce and Scan + static_assert(std::is_same_v || + std::is_same_v); + + using ConvertOpToLLVMPattern::getTypeConverter; + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + // Return the pointee type of the shared memory pointer for operand i. + Type getElementType(SourceOp op, int i) const { + auto ty = op.getInputTypes()[i].getElementType(); + return getTypeConverter()->convertType(ty); + } + + // Helper to compute the smem bases in both reductions and scans + SmallVector getSmemBases(SourceOp op, unsigned elems, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + // indices will store the index of the op operands in descending order + // of their bitwidths + std::vector indices(op.getNumOperands()); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) { + return op.getElementTypes()[i].getIntOrFloatBitWidth() > + op.getElementTypes()[j].getIntOrFloatBitWidth(); + }); + // Assign base index to each operand in their order in indices + std::map indexToBase; + indexToBase[indices[0]] = + LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + for (unsigned i = 1; i < op.getNumOperands(); ++i) { + indexToBase[indices[i]] = gep( + ptr_ty(rewriter.getContext(), 3), getElementType(op, indices[i - 1]), + indexToBase[indices[i - 1]], i32_val(elems)); + } + // smemBases[k] is the base pointer for the k-th operand + SmallVector smemBases(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + smemBases[i] = indexToBase[i]; + } + return smemBases; + } +}; + +#endif diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 000000000..972fc5592 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,38 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct GetProgramIdOpConversion + : public ConvertOpToLLVMPattern { + explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value programId = targetInfo.programId(rewriter, op->getLoc(), + op->getParentOfType(), + op.getAxisAsInt()); + rewriter.replaceOp(op, programId); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp new file mode 100644 index 000000000..675bf5a34 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -0,0 +1,589 @@ +#include + +#include "ReduceScanCommon.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::getTotalElemsPerThread; + +// apply combine region to acc and cur and accumulate it into acc +// TODO(Lezcano) This is now duplicated with ReduceOpConversion::reduce. +// Deduplicate +static SmallVector accumulate(ConversionPatternRewriter &rewriter, + Region &combineOp, ValueRange acc, + ValueRange cur) { + // Allows for passing an unitialized acc and use cur as the neutral element + if (acc.size() == 0) { + return cur; + } + assert(cur.size() == acc.size()); + // Create a new copy of the reduce block, and inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + rewriter.cloneRegionBefore(combineOp, &parent.front()); + auto &newScan = parent.front(); + auto returnOp = dyn_cast(newScan.getTerminator()); + + SmallVector combineArgs(2 * acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; + } + + rewriter.inlineBlockBefore(&newScan, &*rewriter.getInsertionPoint(), + combineArgs); + SmallVector results; + llvm::transform(returnOp.getResult(), std::back_inserter(results), + [&](Value res) { return rewriter.getRemappedValue(res); }); + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + return results; +} + +// Scan a contiguous elements within a thread and update `srcValues` in place. +static void +scanThreadContiguousElements(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper) { + // Depending on layout contiguous elements along axis dim may not be + // contiguous in srcValues. Keep track of what elements belong to the same + // chunk of contiguous elements. + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned numChunks = srcValues.size() / scanElementsPerThreads; + unsigned stride = helper.getAxisElementStride(); + SmallVector> accs(numChunks); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + // Change this into emitOffsetForLayout? + unsigned accIndex = (srcIndex % stride) + + ((srcIndex / stride) / scanElementsPerThreads) * stride; + + accs[accIndex] = accumulate(rewriter, helper.getCombineOp(), accs[accIndex], + srcValues[srcIndex]); + srcValues[srcIndex] = accs[accIndex]; + } +} + +// Apply a scan across threads of the warp for the last element of each +// contiguous group of elements. +static void warpScan(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, Value laneIdAxis) { + Location loc = helper.getLoc(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + // Reduce within warps. + SmallVector acc = srcValues[srcIndex]; + for (unsigned i = 1; i <= scanDim / 2; i <<= 1) { + SmallVector shfl(acc.size()); + for (unsigned j = 0; j < acc.size(); ++j) { + shfl[j] = targetInfo.shuffleUp(rewriter, loc, acc[j], i * threadStride); + } + SmallVector tempAcc = + accumulate(rewriter, helper.getCombineOp(), shfl, acc); + Value mask = icmp_slt(laneIdAxis, i32_val(i)); + for (unsigned j = 0; j < acc.size(); ++j) { + acc[j] = select(mask, acc[j], tempAcc[j]); + } + } + srcValues[srcIndex] = acc; + } +} + +// For each set of contiguous elements within a thread we store the partial +// reduction into shared memory. Each parallel scan and each warp will store its +// own partial reductions. The shared memory is organized as follow: +// ----------------------------------------------------------------- +// chunk 0: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 | +// chunk 1: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 | +static void storeWarpAccumulator(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId, SmallVector smemBases, + SmallVector smemTypes, + Value parallelLaneId, + const TargetInfoBase &targetInfo) { + Location loc = helper.getLoc(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned chunkId = 0; + unsigned elementStride = helper.getAxisElementStride(); + + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + auto lastElement = srcValues[srcIndex]; + Value mask = icmp_eq(laneId, i32_val(scanDim - 1)); + Value index = add(parallelLaneId, mul(warpId, i32_val(numParallelLane))); + index = add(index, i32_val(chunkId * numParallelLane * axisNumWarps)); + for (unsigned i = 0; i < lastElement.size(); ++i) { + Value writePtr = gep(ptr_ty(rewriter.getContext(), 3), smemTypes[i], + smemBases[i], index); + targetInfo.storeShared(rewriter, loc, writePtr, lastElement[i], mask); + } + chunkId++; + } +} + +// Read the partial reductions from shared memory from each chunk of contiguous +// elements for each warp and parallel scan. Then combine the partial reduction +// with the right elements. Within a given contiguous element chunk we update +// all the elements by accumulating the value from the last element of the +// reduced value from the previous lane. +static void AddPartialReduce(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, + SmallVector smemBases, + SmallVector smemTypes, Value warpId, + Value laneIdAxis, Value parallelLaneId) { + Location loc = helper.getLoc(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); + Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); + Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); + struct Accumulator { + SmallVector acc; + SmallVector maskedAcc; + }; + unsigned numScanBlocks = helper.getAxisNumBlocks(); + unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); + assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * + scanElementsPerThreads == + srcValues.size()); + SmallVector accumulators(numParallelBlocks * + parallelElementsPerThread); + unsigned chunkId = 0; + unsigned blockStride = helper.getAxisBlockStride(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + // Accumulate the partial reduction from shared memory. Decide which + // accumulator to combine based on whether the elements belong to the same + // dimension along axis. + unsigned blockId = chunkId / parallelElementsPerThread; + unsigned parallelBlockId = + blockId % blockStride + + ((blockId / blockStride) / numScanBlocks) * blockStride; + unsigned accumulatorIndex = chunkId % parallelElementsPerThread + + parallelBlockId * parallelElementsPerThread; + Accumulator &accumulator = accumulators[accumulatorIndex]; + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + for (unsigned i = 0; i < axisNumWarps; ++i) { + Value index = add(parallelLaneId, i32_val(numParallelLane * + (i + chunkId * axisNumWarps))); + SmallVector partialReduce(helper.getNumOperands()); + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + auto elemTy = smemTypes[j]; + Value ptr = + gep(ptr_ty(rewriter.getContext(), 3), elemTy, smemBases[j], index); + partialReduce[j] = load(elemTy, ptr); + } + + if (accumulator.acc.size() == 0) { + accumulator.acc = partialReduce; + accumulator.maskedAcc = partialReduce; + continue; + } + accumulator.acc = accumulate(rewriter, helper.getCombineOp(), + accumulator.acc, partialReduce); + Value mask = icmp_slt(warpId, i32_val(i + 1)); + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + accumulator.maskedAcc[j] = + select(mask, accumulator.maskedAcc[j], accumulator.acc[j]); + } + } + auto temp = accumulate(rewriter, helper.getCombineOp(), + accumulator.maskedAcc, srcValues[srcIndex]); + if (axisBlockId == 0) { + // For the first warp and first chunk we don't have anything to + // accumulate. + auto val = srcValues[srcIndex]; + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + temp[i] = select(maskFirstWarp, val[i], temp[i]); + } + } + srcValues[srcIndex] = temp; + // Update the rest of the contiguous elements. + SmallVector lastElement(helper.getNumOperands()); + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + auto elem = targetInfo.shuffleUp(rewriter, loc, temp[i], threadStride); + lastElement[i] = select(maskFirstLane, accumulator.maskedAcc[i], elem); + } + for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + auto laneValue = srcValues[srcIndex - i * elementStride]; + laneValue = + accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); + if (axisBlockId == 0) { + // For the first warp and first chunk we don't have anything to + // accumulate. + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + laneValue[j] = + select(maskFirstThread, + srcValues[srcIndex - i * elementStride][j], laneValue[j]); + } + } + srcValues[srcIndex - i * elementStride] = laneValue; + } + // For the next chunk start back from the value containing the + // accumulated value of all the warps. + accumulator.maskedAcc = accumulator.acc; + chunkId++; + } +} + +static void AddPartialReduceOneWarp(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, Value warpId, + Value laneIdAxis, Value laneIdLast) { + Location loc = helper.getLoc(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); + Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); + Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); + unsigned numScanBlocks = helper.getAxisNumBlocks(); + unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); + assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * + scanElementsPerThreads == + srcValues.size()); + SmallVector> accumulators(numParallelBlocks * + parallelElementsPerThread); + unsigned chunkId = 0; + unsigned blockStride = helper.getAxisBlockStride(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + unsigned blockId = chunkId / parallelElementsPerThread; + unsigned parallelBlockId = + blockId % blockStride + + ((blockId / blockStride) / numScanBlocks) * blockStride; + unsigned accumulatorIndex = chunkId % parallelElementsPerThread + + parallelBlockId * parallelElementsPerThread; + auto &accumulator = accumulators[accumulatorIndex]; + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + if (axisBlockId == 0) // First chunk and first block + accumulator = srcValues[srcIndex]; + else + srcValues[srcIndex] = accumulate(rewriter, helper.getCombineOp(), + accumulator, srcValues[srcIndex]); + // Update the rest of the contiguous elements. + auto lastElement = srcValues[srcIndex]; + if (scanDim > 1) { + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + lastElement[i] = targetInfo.shuffleUp( + rewriter, loc, srcValues[srcIndex][i], threadStride); + lastElement[i] = select(maskFirstLane, accumulator[i], lastElement[i]); + if (numScanBlocks > 1) + // Update accumulator with the value from the last lane. + accumulator[i] = targetInfo.shuffleIdx( + rewriter, loc, srcValues[srcIndex][i], laneIdLast); + } + } else if (numScanBlocks > 1) { + accumulator = srcValues[srcIndex]; + } + for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + auto laneValue = srcValues[srcIndex - i * elementStride]; + laneValue = + accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); + if (axisBlockId == 0) { + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + // For the first warp and first chunk we don't have anything to + // accumulate. + laneValue[j] = + select(maskFirstThread, + srcValues[srcIndex - i * elementStride][j], laneValue[j]); + } + } + srcValues[srcIndex - i * elementStride] = laneValue; + } + // For the next chunk start back from the value containing the + // accumulated value of all the warps. + chunkId++; + } +} + +namespace { +struct ScanOpConversion + : public ConvertTritonGPUReduceScanToLLVMPattern { +public: + using ConvertTritonGPUReduceScanToLLVMPattern< + triton::ScanOp>::ConvertTritonGPUReduceScanToLLVMPattern; + explicit ScanOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertTritonGPUReduceScanToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (succeeded(emitFastScan(op, adaptor, rewriter))) + return success(); + return failure(); + } + +private: + const TargetInfoBase &targetInfo; + SmallVector getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value laneId) const; + SmallVector getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value warpId) const; + std::tuple + getDelinearizedIds(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId) const; + LogicalResult emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; +}; + +SmallVector +ScanOpConversion::getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value laneId) const { + auto loc = helper.getLoc(); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto order = triton::gpu::getOrder(srcEncoding); + return delinearize(rewriter, loc, laneId, threadsPerWarp, order); +} + +SmallVector +ScanOpConversion::getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value warpId) const { + auto loc = helper.getLoc(); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto warpOrder = triton::gpu::getWarpOrder(srcEncoding); + return delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); +} + +// Break up the threadId into lane and warp id along the scan dimension and +// compute a flat id for the parallel dimensions. +std::tuple +ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId) const { + auto loc = helper.getLoc(); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto order = triton::gpu::getOrder(srcEncoding); + auto warpOrder = triton::gpu::getWarpOrder(srcEncoding); + SmallVector multiDimLaneId = + delinearize(rewriter, loc, laneId, threadsPerWarp, order); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); + + Value laneIdAxis = multiDimLaneId[axis]; + Value warpIdAxis = multiDimWarpId[axis]; + + multiDimLaneId[axis] = i32_val(0); + threadsPerWarp[axis] = 1; + Value laneIdParallel = + linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, order); + multiDimWarpId[axis] = i32_val(0); + warpsPerCTA[axis] = 1; + Value warpIdParallel = + linearize(rewriter, loc, multiDimWarpId, warpsPerCTA, warpOrder); + Value flatIdParallel = + add(laneIdParallel, + mul(warpIdParallel, i32_val(helper.getNonAxisNumThreadsPerWarp()))); + return std::make_tuple(laneIdAxis, warpIdAxis, flatIdParallel); +} + +SmallVector> +unpackInputs(Location loc, triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter) { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getTotalElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = unpackLLElements(loc, operands[i], rewriter); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } + } + return srcValues; +} + +// Flip the srcValues. Both reverses the chunks and reverses the lanes. +// Lane reversal is done with a butterfly shuffle flip (divide and flip). +SmallVector> +flipSrcValues(Location loc, triton::ScanOp op, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + SmallVector> srcValues, int iWarpSize) { + SmallVector> values(srcValues.size()); + for (int i = 0; i < srcValues.size(); ++i) { + int revIndex = srcValues.size() - i - 1; + for (unsigned j = 0; j < op.getNumOperands(); ++j) { + for (unsigned k = iWarpSize / 2; k >= 1; k = k / 2) { + srcValues[revIndex][j] = + targetInfo.shuffleXor(rewriter, loc, srcValues[revIndex][j], k); + } + values[i].push_back(srcValues[revIndex][j]); + } + } + return values; +} + +// Lowering using warp shuffle operations to do warp level scan. +LogicalResult +ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + ScanLoweringHelper helper(op); + auto loc = helper.getLoc(); + if (!helper.isSupported()) + return failure(); + + Value threadId = getThreadId(rewriter, loc); + auto mod = op->getParentOfType(); + unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + Value warpSize = i32_val(iWarpSize); + Value warpId = udiv(threadId, warpSize); + Value laneId = urem(threadId, warpSize); + + auto [laneIdAxis, warpIdAxis, flatIdParallel] = + getDelinearizedIds(rewriter, helper, laneId, warpId); + auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + warpIdAxis = urem(warpIdAxis, i32_val(axisNumWarps)); + auto srcValues = + unpackInputs(loc, op, adaptor, rewriter, *getTypeConverter()); + + // For the reverse option we apply flip(scan(flip()) in + // order to avoid having a separate code path in the reverse direction. + // We do this by 1) reversing chunks, 2) reversing lanes, 3) reversing + // warp ids and then undoing this below. + // (Note: Tried pretty hard to get shflDownSync to work but I ended up + // having to add a lot of the complex cross warp code (if rev switch + // first/last etc). Reverse first seems more maintainable.) + if (op.getReverse()) { + warpIdAxis = sub(i32_val(axisNumWarps - 1), warpIdAxis); + srcValues = + flipSrcValues(loc, op, rewriter, targetInfo, srcValues, iWarpSize); + } + + // Scan contiguous elements in a thread and update `srcValues`. + scanThreadContiguousElements(srcValues, rewriter, helper); + // Apply warp level scan to the last element of each chunk of contiguous + // elements. + warpScan(srcValues, rewriter, targetInfo, helper, laneIdAxis); + + if (axisNumWarps > 1) { + // Slow path for the case where there are multiple warps with unique data on + // the axis. + auto elems = helper.getScratchSizeInElems(); + SmallVector smemBases = getSmemBases(op, elems, rewriter); + SmallVector smemTypes(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + smemTypes[i] = getElementType(op, i); + } + + // Store the partial reducing for each warp into shared memory. + storeWarpAccumulator(srcValues, rewriter, helper, laneIdAxis, warpIdAxis, + smemBases, smemTypes, flatIdParallel, targetInfo); + barrier(); + // Read back the partial reduction of each warp and accumulate them based on + // warpId. Then update each chunk of contiguous elements by adding the + // accumulated value from the previous lane. + AddPartialReduce(srcValues, rewriter, targetInfo, helper, smemBases, + smemTypes, warpIdAxis, laneIdAxis, flatIdParallel); + } else if (srcValues.size() > 1) { + // Fast path for the case where there is only one warp with unique data on + // the axis. + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + auto multiDimLaneId = getMultiDimLaneId(rewriter, helper, laneId); + multiDimLaneId[helper.getAxis()] = i32_val(scanDim - 1); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(helper.getEncoding()); + auto laneIdLast = linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, + triton::gpu::getOrder(helper.getEncoding())); + AddPartialReduceOneWarp(srcValues, rewriter, targetInfo, helper, warpIdAxis, + laneIdAxis, laneIdLast); + } // else axisNumWarps == 1 and srcValues.size() == 1, nothing to do. + + auto transpose = [](const SmallVector> &v) { + assert(v.size() > 0 && v[0].size() > 0); + auto ret = SmallVector>(v[0].size(), + SmallVector(v.size())); + for (int i = 0; i < v.size(); ++i) { + for (int j = 0; j < v[0].size(); ++j) { + ret[j][i] = v[i][j]; + } + } + return ret; + }; + + SmallVector results(op.getNumOperands()); + if (op.getReverse()) { + srcValues = + flipSrcValues(loc, op, rewriter, targetInfo, srcValues, iWarpSize); + } + + auto valuesTransposed = transpose(srcValues); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto resultTy = dyn_cast(op.getResult()[i].getType()); + results[i] = packLLElements(loc, getTypeConverter(), valuesTransposed[i], + rewriter, resultTy); + } + rewriter.replaceOp(op, results); + return success(); +} +} // namespace + +void mlir::triton::populateScanOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp new file mode 100644 index 000000000..53705c3b7 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -0,0 +1,133 @@ +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::SharedEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; + +TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( + MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis) + : LLVMTypeConverter(ctx, option, analysis) { + addConversion([&](triton::PointerType type) -> std::optional { + return convertTritonPointerType(type); + }); + addConversion([&](RankedTensorType type) -> std::optional { + return convertTritonTensorType(type); + }); + addConversion([&](MemDescType type) -> std::optional { + return convertMemDescType(type); + }); + addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional { + return convertAsyncToken(type); + }); + addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }); + addConversion([&](mlir::Float8E5M2Type type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }); + addConversion([&](mlir::Float8E5M2FNUZType type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }); +} + +Type TritonGPUToLLVMTypeConverter::convertTritonPointerType( + triton::PointerType type) { + auto ctx = type.getContext(); + auto pointeeType = type.getPointeeType(); + if (isa(pointeeType)) { + auto rankedTensorType = cast(pointeeType); + // struct { offset0, offset1, shape0, shape1, stride0, + // stride1, base_ptr}; + auto eleType = rankedTensorType.getElementType(); + auto shape = rankedTensorType.getShape(); + SmallVector types; + // offsets + for (size_t i = 0; i < shape.size(); ++i) + types.push_back(IntegerType::get(ctx, 32)); + // shapes, strides + for (size_t i = 0; i < 2 * shape.size(); ++i) + types.push_back(IntegerType::get(ctx, 64)); + + types.push_back(LLVM::LLVMPointerType::get(ctx, type.getAddressSpace())); + + return LLVM::LLVMStructType::getLiteral(ctx, types); + } + return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); +} + +Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct( + TensorOrMemDesc type) { + auto ctx = type.getContext(); + Attribute layout = type.getEncoding(); + Type elemTy = convertType(type.getElementType()); + auto dotOpLayout = mlir::dyn_cast(layout); + if (!dotOpLayout) + return elemTy; + auto mmaParent = + mlir::dyn_cast(dotOpLayout.getParent()); + if (!mmaParent || mmaParent.isHopper()) + return elemTy; + int bitwidth = elemTy.getIntOrFloatBitWidth(); + assert(bitwidth <= 32); + return IntegerType::get(ctx, 32); +} + +Type TritonGPUToLLVMTypeConverter::convertTritonTensorType( + RankedTensorType type) { + auto ctx = type.getContext(); + Attribute layout = type.getEncoding(); + SmallVector shape(type.getShape().begin(), type.getShape().end()); + Type eltType = getElementTypeForStruct(cast(type)); + + if (auto shared_layout = mlir::dyn_cast(layout)) { + SmallVector types; + // base ptr + auto ptrType = LLVM::LLVMPointerType::get(ctx, 3); + types.push_back(ptrType); + // shape dims + auto rank = type.getRank(); + // offsets + strides + for (auto i = 0; i < rank * 2; i++) { + types.push_back(IntegerType::get(ctx, 32)); + } + return LLVM::LLVMStructType::getLiteral(ctx, types); + } + + unsigned numElementsPerThread = getTotalElemsPerThread(type); + SmallVector types(numElementsPerThread, eltType); + return LLVM::LLVMStructType::getLiteral(ctx, types); +} + +Type TritonGPUToLLVMTypeConverter::convertMemDescType(MemDescType type) { + auto ctx = type.getContext(); + Attribute layout = type.getEncoding(); + SmallVector shape(type.getShape().begin(), type.getShape().end()); + SmallVector types; + // base ptr + auto ptrType = LLVM::LLVMPointerType::get(ctx, 3); + types.push_back(ptrType); + // shape dims + auto rank = type.getShape().size(); + // offsets + strides + for (auto i = 0; i < rank * 2; i++) { + types.push_back(IntegerType::get(ctx, 32)); + } + return LLVM::LLVMStructType::getLiteral(ctx, types); +} + +Type TritonGPUToLLVMTypeConverter::convertAsyncToken( + triton::gpu::AsyncTokenType type) { + return IntegerType::get(type.getContext(), 32); +} diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/Utility.cpp new file mode 100644 index 000000000..a80158a46 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -0,0 +1,619 @@ +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "llvm/ADT/STLExtras.h" + +namespace SharedToDotOperandMMAv1 { +using CoordTy = SmallVector; +using ValueTable = std::map, std::pair>; + +static SmallVector +getMNCoords(Value thread, Location loc, ConversionPatternRewriter &rewriter, + ArrayRef wpt, const NvidiaMmaEncodingAttr &mmaLayout, + ArrayRef shape, bool isARow, bool isBRow, bool isAVec4, + bool isBVec4) { + static constexpr std::array fpw{{2, 2, 1}}; + + auto *ctx = thread.getContext(); + Value _1 = i32_val(1); + Value _2 = i32_val(2); + Value _4 = i32_val(4); + Value _16 = i32_val(16); + Value _32 = i32_val(32); + Value _fpw0 = i32_val(fpw[0]); + Value _fpw1 = i32_val(fpw[1]); + + // A info + auto aRep = mmaLayout.getMMAv1Rep(0); + auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0); + // B info + auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1); + auto bRep = mmaLayout.getMMAv1Rep(1); + + SmallVector rep({aRep[0], bRep[1]}); + SmallVector spw({aSpw[0], bSpw[1]}); + SmallVector shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]}); + + Value lane = urem(thread, _32); + Value warp = udiv(thread, _32); + + Value warp0 = urem(warp, i32_val(wpt[0])); + Value warp12 = udiv(warp, i32_val(wpt[0])); + Value warp1 = urem(warp12, i32_val(wpt[1])); + + // warp offset + Value offWarpM = mul(warp0, i32_val(spw[0])); + Value offWarpN = mul(warp1, i32_val(spw[1])); + // quad offset + Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0); + Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1); + // pair offset + Value offPairM = udiv(urem(lane, _16), _4); + offPairM = urem(offPairM, _fpw0); + offPairM = mul(offPairM, _4); + Value offPairN = udiv(urem(lane, _16), _4); + offPairN = udiv(offPairN, _fpw0); + offPairN = urem(offPairN, _fpw1); + offPairN = mul(offPairN, _4); + + // sclare + offPairM = mul(offPairM, i32_val(rep[0] / 2)); + offQuadM = mul(offQuadM, i32_val(rep[0] / 2)); + offPairN = mul(offPairN, i32_val(rep[1] / 2)); + offQuadN = mul(offQuadN, i32_val(rep[1] / 2)); + + // quad pair offset + Value offLaneM = add(offPairM, offQuadM); + Value offLaneN = add(offPairN, offQuadN); + // a, b offset + Value offsetAM = add(offWarpM, offLaneM); + Value offsetBN = add(offWarpN, offLaneN); + // m indices + Value offsetCM = add(and_(lane, _1), offsetAM); + SmallVector idxM; + for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0]) + for (unsigned mm = 0; mm < rep[0]; ++mm) + idxM.push_back(add(offsetCM, i32_val(m + mm * 2))); + + // n indices + Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN))); + SmallVector idxN; + for (int n = 0; n < shape[1]; n += shapePerCTA[1]) { + for (int nn = 0; nn < rep[1]; ++nn) { + idxN.push_back(add( + offsetCN, i32_val(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1]))); + idxN.push_back( + add(offsetCN, + i32_val(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1))); + } + } + + SmallVector> axes({idxM, idxN}); + + // product the axis M and axis N to get coords, ported from + // generator::init_idx method from triton2.0 + + // TODO[Superjomn]: check the order. + SmallVector coords; + for (Value x1 : axes[1]) { // N + for (Value x0 : axes[0]) { // M + SmallVector idx(2); + idx[0] = x0; // M + idx[1] = x1; // N + coords.push_back(std::move(idx)); + } + } + + return coords; // {M,N} in row-major +} +} // namespace SharedToDotOperandMMAv1 +namespace mlir { + +namespace triton::gpu { +Type getFunctionType(Type resultType, ValueRange operands) { + SmallVector operandTypes(operands.getTypes()); + return LLVM::LLVMFunctionType::get(resultType, operandTypes); +} + +LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter, + Operation *op, StringRef funcName, + Type funcType, + StringRef libname /*= ""*/, + StringRef libpath /*= ""*/) { + using LLVM::LLVMFuncOp; + + auto funcAttr = StringAttr::get(op->getContext(), funcName); + Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + if (funcOp) + return cast(*funcOp); + + Operation *parent = op; + if (!isa(op)) + parent = op->getParentOfType(); + OpBuilder b(parent); + auto ret = b.create(op->getLoc(), funcName, funcType); + ret.getOperation()->setAttr("libname", + StringAttr::get(op->getContext(), libname)); + ret.getOperation()->setAttr("libpath", + StringAttr::get(op->getContext(), libpath)); + return ret; +} +} // namespace triton::gpu + +SmallVector> +applyLinearLayout(Location loc, RewriterBase &rewriter, + const LinearLayout &layout, + ArrayRef> indices) { + assert(layout.getNumInDims() == indices.size()); + for (auto [inDimName, idx] : indices) { + assert(layout.hasInDim(inDimName) && "Invalid inDimName"); + } + + // This function can emit a lot of MLIR code, which ultimately makes + // compilation slow. (We think this shouldn't be the case -- it's not *that* + // much code -- but we're not clear on how to fix the slowness, which happens + // in the bowels of MLIR.) + // + // As a result we go through some contortions to avoid emitting code where + // possible. + + // Manually constant-fold the layout where possible. + SmallVector> constantIns; + for (auto [inDimName, idx] : indices) { + if (auto constant = dyn_cast(idx.getDefiningOp())) { + constantIns.push_back( + {inDimName, cast(constant.getValue()).getInt()}); + } else { + constantIns.push_back({inDimName, 0}); + } + } + SmallVector constantComponent = + llvm::to_vector(llvm::make_second_range(layout.apply(constantIns))); + + Value zero = i32_val(0); + SmallVector> outIndices; + for (auto [i, outDimName] : llvm::enumerate(layout.getOutDimNames())) { + if (constantComponent[i] == 0) + outIndices.push_back({outDimName, zero}); + else + outIndices.push_back({outDimName, i32_val(constantComponent[i])}); + } + + for (auto [inDimName, idx] : indices) { + if (isa(idx.getDefiningOp())) { + continue; + } + + int nBits = layout.getInDimSizeLog2(inDimName); + for (int i = 0; i < nBits; i++) { + Value bit = and_(idx, i32_val(1 << i)); + Value bit_is_zero = icmp_eq(bit, zero); + for (auto &[outDimName, outIdx] : outIndices) { + int32_t basis = layout.getBasis(inDimName, i, outDimName); + if (basis == 0) + continue; + outIdx = xor_(outIdx, select(bit_is_zero, zero, i32_val(basis))); + } + } + } + + return outIndices; +} + +std::optional>> +emitIndicesUsingLinearLayouts(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, bool withCTAOffset) { + MLIRContext *ctx = rewriter.getContext(); + auto shape = type.getShape(); + + std::optional ll = triton::gpu::toLinearLayout(shape, layout); + if (!ll.has_value()) { + return std::nullopt; + } + + // TODO(jlebar): We could add strong typing if we wanted; for now this is + // "stringly typed". + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + + Value threadId = getThreadId(rewriter, loc); + Value threadsPerWarp = i32_val(ll->getInDimSize(kLane)); + Value laneId = urem(threadId, threadsPerWarp); + Value warpId = udiv(threadId, threadsPerWarp); + Value blockId = + withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0); + unsigned rank = shape.size(); + SmallVector> ret; + for (unsigned reg = 0; reg < ll->getInDimSize(str_attr("register")); reg++) { + auto idxs = applyLinearLayout(loc, rewriter, *ll, + {{kRegister, i32_val(reg)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, blockId}}); + assert(idxs.size() == rank); + for (unsigned k = 0; k < rank; ++k) { + assert(idxs[k].first == str_attr("dim" + std::to_string(k))); + } + ret.push_back(llvm::to_vector(llvm::make_second_range(idxs))); + } + + return ret; +} + +namespace LLVM { +using namespace mlir::triton; +using mlir::triton::gpu::getOrder; +using mlir::triton::gpu::getSizePerThread; + +Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) { + auto i32ty = rewriter.getIntegerType(32); + return rewriter.create(loc, i32ty, + IntegerAttr::get(i32ty, v)); +} + +Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v) { + auto i64ty = rewriter.getIntegerType(64); + return rewriter.create(loc, i64ty, + IntegerAttr::get(i64ty, v)); +} + +Value createConstantF16(Location loc, OpBuilder &rewriter, float v) { + auto type = type::f16Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF16FloatAttr(v)); +} + +Value createConstantF32(Location loc, OpBuilder &rewriter, float v) { + auto type = type::f32Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF32FloatAttr(v)); +} + +Value createConstantF64(Location loc, OpBuilder &rewriter, double v) { + auto type = type::f64Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF64FloatAttr(v)); +} + +Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type) { + if (!isa(type)) { + llvm::report_fatal_error("Creating NaN constant for non-float type!"); + } + return rewriter.create( + loc, type, APFloat::getNaN(cast(type).getFloatSemantics())); +} + +// Create an index type constant. +Value createIndexConstant(OpBuilder &builder, Location loc, + const TypeConverter *converter, int64_t value) { + Type ty = converter->convertType(builder.getIndexType()); + return builder.create(loc, ty, + builder.getIntegerAttr(ty, value)); +} + +// Create an integer constant of \param width bits. +Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, + int64_t value) { + Type ty = builder.getIntegerType(width); + return builder.create(loc, ty, + builder.getIntegerAttr(ty, value)); +} + +SharedMemoryObject +getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, + ConversionPatternRewriter &rewriter) { + ArrayRef types = + cast(llvmStruct.getType()).getBody(); + SmallVector elems(types.size()); + for (unsigned i = 0; i < types.size(); ++i) { + Type type = types[i]; + elems[i] = extract_val(type, llvmStruct, i); + } + + auto rank = (elems.size() - 1) / 2; + return {/*base=*/elems[0], + /*baseElemType=*/elemTy, + /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank}, + /*offsets=*/{elems.begin() + 1 + rank, elems.end()}}; +} + +SmallVector getStridesFromShapeAndOrder(ArrayRef shape, + ArrayRef order, + Location loc, + RewriterBase &rewriter) { + auto rank = shape.size(); + SmallVector strides(rank); + int64_t stride = 1; + for (auto idx : order) { + strides[idx] = i32_val(stride); + stride *= shape[idx]; + } + return strides; +} + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape, + ArrayRef order) { + unsigned rank = shape.size(); + assert(rank == order.size()); + auto reordered = applyPermutation(shape, order); + SmallVector reorderedMultiDim(rank); + if (auto constantOp = linear.getDefiningOp()) { + unsigned intVal = mlir::cast(constantOp.getValue()) + .getValue() + .getSExtValue(); + reorderedMultiDim = delinearize(rewriter, loc, intVal, reordered); + } else { + reorderedMultiDim = delinearize(rewriter, loc, linear, reordered); + } + SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + unsigned linear, ArrayRef shape) { + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + unsigned remained = linear; + for (auto &&en : llvm::enumerate(shape)) { + unsigned dimSize = en.value(); + multiDim[en.index()] = i32_val(remained % dimSize); + remained = remained / dimSize; + } + return multiDim; +} + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape) { + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + Value remained = linear; + for (auto &&en : llvm::enumerate(shape)) { + Value dimSize = i32_val(en.value()); + multiDim[en.index()] = urem(remained, dimSize); + remained = udiv(remained, dimSize); + } + return multiDim; +} + +Value linearize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDim, ArrayRef shape, + ArrayRef order) { + return linearize(rewriter, loc, applyPermutation(multiDim, order), + applyPermutation(shape, order)); +} + +Value linearize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDim, ArrayRef shape) { + auto rank = multiDim.size(); + Value linear = i32_val(0); + if (rank > 0) { + linear = multiDim.back(); + for (auto [dim, dimShape] : + llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { + Value dimSize = i32_val(dimShape); + linear = add(mul(linear, dimSize), dim); + } + } + return linear; +} + +Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, + StringRef key, StringRef content) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + auto ctx = moduleOp.getContext(); + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (key + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + + llvm::SmallString<64> contentStr(content); + size_t contentSize = contentStr.size_in_bytes(); + auto globalType = LLVM::LLVMArrayType::get(i8_ty, contentSize); + + LLVM::GlobalOp global; + { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + global = rewriter.create( + UnknownLoc::get(ctx), globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, + rewriter.getStringAttr(contentStr)); + } + + Value zero = i32_val(0); + Type globalPtrType = LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()); + Value globalPtr = rewriter.create( + UnknownLoc::get(ctx), globalPtrType, global.getSymName()); + Value stringStart = + gep(ptr_ty(ctx), i8_ty, globalPtr, SmallVector({zero})); + return stringStart; +} + +SmallVector getMultiDimOffset(Attribute layout, Location loc, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + unsigned elemId, RankedTensorType type, + ArrayRef multiDimCTAInRepId, + ArrayRef shapePerCTATile) { + auto shape = type.getShape(); + unsigned rank = shape.size(); + if (auto blockedLayout = dyn_cast(layout)) { + auto multiDimOffsetFirstElem = emitBaseIndexForLayout( + loc, rewriter, targetInfo, blockedLayout, type, false); + SmallVector multiDimOffset(rank); + SmallVector multiDimElemId = getMultiDimIndex( + elemId, getSizePerThread(layout), getOrder(layout)); + for (unsigned d = 0; d < rank; ++d) { + multiDimOffset[d] = + add(multiDimOffsetFirstElem[d], + i32_val(multiDimCTAInRepId[d] * shapePerCTATile[d] + + multiDimElemId[d])); + } + return multiDimOffset; + } + if (auto sliceLayout = mlir::dyn_cast(layout)) { + unsigned dim = sliceLayout.getDim(); + auto parentEncoding = sliceLayout.getParent(); + auto parentSizePerThread = getSizePerThread(parentEncoding); + auto parentShape = sliceLayout.paddedShape(shape); + auto parentTy = RankedTensorType::get(parentShape, type.getElementType(), + parentEncoding); + auto offsets = emitOffsetForLayout(layout, type); + auto parentOffset = emitOffsetForLayout(parentEncoding, parentTy); + SmallVector idxs; + for (SmallVector off : offsets) { + off.insert(off.begin() + dim, 0); + auto it = std::find(parentOffset.begin(), parentOffset.end(), off); + idxs.push_back(std::distance(parentOffset.begin(), it)); + } + auto multiDimOffsetParent = getMultiDimOffset( + parentEncoding, loc, rewriter, targetInfo, idxs[elemId], parentTy, + sliceLayout.paddedShape(multiDimCTAInRepId), + sliceLayout.paddedShape(shapePerCTATile)); + SmallVector multiDimOffset(rank); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d == dim) + continue; + unsigned slicedD = d < dim ? d : (d - 1); + multiDimOffset[slicedD] = multiDimOffsetParent[d]; + } + return multiDimOffset; + } + if (auto mmaLayout = mlir::dyn_cast(layout)) { + assert(rank == 2 || + (rank == 3 && mmaLayout.isAmpere()) && "Unexpected rank"); + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + auto instrShape = mmaLayout.getInstrShape(); + SmallVector mmaColIdx(2); + SmallVector mmaRowIdx(2); + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(32); + Value laneId = urem(threadId, warpSize); + Value warpId = udiv(threadId, warpSize); + // TODO: fix the bug in MMAEncodingAttr document + SmallVector multiDimWarpId(2); + auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); + auto warpOrder = triton::gpu::getWarpOrder(mmaLayout); + multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); + Value _1 = i32_val(1); + Value _2 = i32_val(2); + Value _4 = i32_val(4); + Value _8 = i32_val(8); + Value _16 = i32_val(16); + if (mmaLayout.isAmpere() || mmaLayout.isHopper()) { + multiDimWarpId[rank - 1] = urem( + multiDimWarpId[rank - 1], + i32_val(ceil(shapePerCTA[rank - 1], instrShape[rank - 1]))); + multiDimWarpId[rank - 2] = urem( + multiDimWarpId[rank - 2], + i32_val(ceil(shapePerCTA[rank - 2], instrShape[rank - 2]))); + + Value mmaGrpId = udiv(laneId, _4); + Value mmaGrpIdP8 = add(mmaGrpId, _8); + Value mmaThreadIdInGrp = urem(laneId, _4); + Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2); + Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1); + Value rowWarpOffset = + mul(multiDimWarpId[rank - 2], i32_val(instrShape[rank - 2])); + mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset); + mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset); + Value colWarpOffset = + mul(multiDimWarpId[rank - 1], i32_val(instrShape[rank - 1])); + mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset); + mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset); + } else if (mmaLayout.isVolta()) { + // Volta doesn't follow the pattern here. + } else { + llvm_unreachable("Unexpected MMALayout version"); + } + + SmallVector multiDimOffset(rank); + if (mmaLayout.isHopper()) { + unsigned elemIdRem4 = elemId % 4; + unsigned nGrpId = elemId / 4; + multiDimOffset[0] = elemIdRem4 < 2 ? mmaRowIdx[0] : mmaRowIdx[1]; + multiDimOffset[1] = elemIdRem4 % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1]; + multiDimOffset[1] = add(multiDimOffset[1], i32_val(8 * nGrpId)); + multiDimOffset[0] = add(multiDimOffset[0], i32_val(multiDimCTAInRepId[0] * + shapePerCTATile[0])); + multiDimOffset[1] = add(multiDimOffset[1], i32_val(multiDimCTAInRepId[1] * + shapePerCTATile[1])); + } else if (mmaLayout.isAmpere()) { + if (rank == 3) + multiDimOffset[0] = + add(multiDimWarpId[0], + i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0])); + multiDimOffset[rank - 2] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1]; + multiDimOffset[rank - 1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1]; + multiDimOffset[rank - 2] = + add(multiDimOffset[rank - 2], i32_val(multiDimCTAInRepId[rank - 2] * + shapePerCTATile[rank - 2])); + multiDimOffset[rank - 1] = + add(multiDimOffset[rank - 1], i32_val(multiDimCTAInRepId[rank - 1] * + shapePerCTATile[rank - 1])); + } else if (mmaLayout.isVolta()) { + auto [isARow, isBRow, isAVec4, isBVec4, _] = + mmaLayout.decodeVoltaLayoutStates(); + auto coords = SharedToDotOperandMMAv1::getMNCoords( + threadId, loc, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout, shape, + isARow, isBRow, isAVec4, isBVec4); + return coords[elemId]; + } else { + llvm_unreachable("Unexpected MMALayout version"); + } + return multiDimOffset; + } + if (isa(layout)) { + auto multiDimBase = + emitBaseIndexForLayout(loc, rewriter, targetInfo, layout, type, false); + SmallVector> offsets; + assert(rank == 2); + SmallVector multiDimOffset(rank); + if (auto mfmaLayout = dyn_cast(layout)) { + emitMfmaOffsetForCTA(mfmaLayout, offsets, 0, multiDimCTAInRepId[0], + multiDimCTAInRepId[1]); + } else if (auto wmmaLayout = dyn_cast(layout)) { + emitWmmaOffsetForCTA(wmmaLayout, offsets, 0, multiDimCTAInRepId[0], + multiDimCTAInRepId[1]); + } + multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0])); + multiDimOffset[1] = add(multiDimBase[1], i32_val(offsets[elemId][1])); + return multiDimOffset; + } + llvm_unreachable("unexpected layout in getMultiDimOffset"); +} + +SmallVector getWrappedMultiDimOffset( + ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDimOffset, ArrayRef shape, + SmallVector shapePerCTATile, SmallVector shapePerCTA) { + unsigned rank = shape.size(); + SmallVector multiDimOffsetWrapped(rank); + for (unsigned d = 0; d < rank; ++d) { + if (shapePerCTATile[d] > shapePerCTA[d]) + multiDimOffsetWrapped[d] = urem(multiDimOffset[d], i32_val(shape[d])); + else + multiDimOffsetWrapped[d] = multiDimOffset[d]; + } + return multiDimOffsetWrapped; +} + +} // namespace LLVM +} // namespace mlir diff --git a/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp new file mode 100644 index 000000000..e0f6e9377 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -0,0 +1,398 @@ +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +namespace { +struct SplatOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a + // LLVM::StructType value. + // + // @elemType: the element type in operand. + // @resType: the return type of the Splat-like op. + // @constVal: a LLVM::ConstantOp or other scalar value. + static Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Location loc) { + auto tensorTy = cast(resType); + // Check the converted type for the tensor as depending on the encoding the + // converter may pick different element types. + auto srcType = typeConverter->convertType(tensorTy); + if (auto structTy = dyn_cast(srcType)) + srcType = structTy.getBody()[0]; + // If the type sizes don't match we need to pack constants. + if (srcType.isIntOrFloat() && constVal.getType().getIntOrFloatBitWidth() != + srcType.getIntOrFloatBitWidth()) { + unsigned cstBitWidth = constVal.getType().getIntOrFloatBitWidth(); + unsigned srcBitWidth = srcType.getIntOrFloatBitWidth(); + assert(cstBitWidth <= srcBitWidth && srcBitWidth % cstBitWidth == 0); + unsigned ratio = srcBitWidth / cstBitWidth; + Type intTy = IntegerType::get(elemType.getContext(), cstBitWidth); + VectorType vecType = VectorType::get(ratio, intTy); + Value intCst = bitcast(constVal, intTy); + Value vec = undef(vecType); + for (unsigned i = 0; i < ratio; ++i) + vec = insert_element(vecType, vec, intCst, int_val(32, i)); + constVal = vec; + } + auto llSrc = bitcast(constVal, srcType); + size_t elemsPerThread = getTotalElemsPerThread(tensorTy); + llvm::SmallVector elems(elemsPerThread, llSrc); + return packLLElements(loc, typeConverter, elems, rewriter, resType); + } + LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op->getLoc(); + auto src = adaptor.getSrc(); + auto typeConverter = getTypeConverter(); + auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src, + typeConverter, rewriter, loc); + rewriter.replaceOp(op, {llStruct}); + return success(); + } +}; +// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr), +// the logic is the same as triton::SplatOp, so the underlying implementation +// is reused. +struct ArithConstantSplatOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto value = op.getValue(); + if (!mlir::dyn_cast(value)) + return failure(); + auto loc = op->getLoc(); + LLVM::ConstantOp arithConstantOp; + auto values = mlir::dyn_cast(op.getValue()); + auto elemType = values.getElementType(); + Attribute val; + if (type::isFloat(elemType)) { + val = values.getValues()[0]; + } else if (type::isInt(elemType)) { + val = values.getValues()[0]; + } else { + llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: " + << value.getType() << "\n"; + return failure(); + } + auto constOp = rewriter.create(loc, elemType, val); + auto typeConverter = getTypeConverter(); + auto llStruct = SplatOpConversion::convertSplatLikeOp( + elemType, op.getType(), constOp, typeConverter, rewriter, loc); + rewriter.replaceOp(op, llStruct); + return success(); + } +}; +struct CatOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename CatOp::Adaptor; + explicit CatOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + unsigned elems = getTotalElemsPerThread(resultTy); + auto typeConverter = getTypeConverter(); + Type elemTy = typeConverter->convertType(resultTy.getElementType()); + SmallVector types(elems, elemTy); + // unpack input values + auto lhsVals = unpackLLElements(loc, adaptor.getLhs(), rewriter); + auto rhsVals = unpackLLElements(loc, adaptor.getRhs(), rewriter); + // concatenate (and potentially reorder) values + SmallVector retVals; + for (Value v : lhsVals) + retVals.push_back(v); + for (Value v : rhsVals) + retVals.push_back(v); + // pack and replace + Value ret = packLLElements(loc, typeConverter, retVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct JoinOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename JoinOp::Adaptor; + explicit JoinOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // We rely on the following invariants of this op (which are checked by its + // verifier): + // + // - The op has a blocked encoding. + // - The last dimension (the one we're joining) is also the most minor + // dimension. + // - The input and output encodings are the same, except the output has + // 2 elements per thread in the last dim. + // + // With these invariants, join is trivial: We just return the i'th element + // from lhs, followed by the i'th elem from rhs. + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + auto typeConverter = getTypeConverter(); + SmallVector lhsVals = + unpackLLElements(loc, adaptor.getLhs(), rewriter); + SmallVector rhsVals = + unpackLLElements(loc, adaptor.getRhs(), rewriter); + assert(lhsVals.size() == rhsVals.size()); + SmallVector joinedVals; + for (int i = 0; i < lhsVals.size(); i++) { + joinedVals.push_back(lhsVals[i]); + joinedVals.push_back(rhsVals[i]); + } + Value ret = + packLLElements(loc, typeConverter, joinedVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct SplitOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename SplitOp::Adaptor; + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // We rely on the following invariants of this op (which are checked by its + // verifier): + // + // - The op has a blocked encoding. + // - The last dimension (the one we're spliting) is also the most minor + // dimension, and has sizePerThread=2. + // + // With these invariants, split is trivial: Every other value goes into + // return value 0, and every other goes into return value 1. + Location loc = op->getLoc(); + auto typeConverter = getTypeConverter(); + SmallVector srcVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + assert(srcVals.size() % 2 == 0); + SmallVector outLhsVals; + SmallVector outRhsVals; + for (int i = 0; i < srcVals.size(); i += 2) { + outLhsVals.push_back(srcVals[i]); + outRhsVals.push_back(srcVals[i + 1]); + } + auto resultTy = cast(op.getResult(0).getType()); + Value retLhs = + packLLElements(loc, typeConverter, outLhsVals, rewriter, resultTy); + Value retRhs = + packLLElements(loc, typeConverter, outRhsVals, rewriter, resultTy); + rewriter.replaceOp(op, {retLhs, retRhs}); + return success(); + } +}; +struct ReshapeOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename ReshapeOp::Adaptor; + explicit ReshapeOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + if (triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType())) { + return emitOptionalError(loc, + "expensive view not supported on reshape op"); + } + auto resultTy = cast(op.getType()); + auto srcTy = cast(op.getSrc().getType()); + auto typeConverter = getTypeConverter(); + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + Value ret = packLLElements(loc, typeConverter, vals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct ExpandDimsOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename ExpandDimsOp::Adaptor; + explicit ExpandDimsOpConversion( + LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto typeConverter = getTypeConverter(); + auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto srcTy = cast(op.getSrc().getType()); + auto resultTy = cast(op.getType()); + auto srcLayout = dyn_cast(srcTy.getEncoding()); + if (!srcLayout) { + return emitOptionalError( + loc, "ExpandDimsOp only supports SliceEncodingAttr as its input"); + } + auto resultLayout = resultTy.getEncoding(); + auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); + auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); + std::map, Value> srcValues; + for (size_t i = 0; i < srcOffsets.size(); i++) { + srcValues[srcOffsets[i]] = srcVals[i]; + } + SmallVector resultVals; + for (size_t i = 0; i < resultOffsets.size(); i++) { + auto offset = resultOffsets[i]; + offset.erase(offset.begin() + srcLayout.getDim()); + resultVals.push_back(srcValues.at(offset)); + } + Value ret = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct TransOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + if (auto enc = dyn_cast(resultTy.getEncoding())) { + auto llvmElemTy = + getTypeConverter()->convertType(resultTy.getElementType()); + auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto dstSmemObj = SharedMemoryObject( + srcSmemObj.base, srcSmemObj.baseElemType, + /*strides=*/applyPermutation(srcSmemObj.strides, op.getOrder()), + /*offsets=*/applyPermutation(srcSmemObj.offsets, op.getOrder())); + auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } else if (auto enc = mlir::dyn_cast( + resultTy.getEncoding())) { + // If the dst encoding is blocked, then TransOp::inferReturnTypes + // ensures that: + // - the src encoding is also blocked, and + // - the translation from src to dst is just a "renaming" of the + // registers, i.e. each thread has exactly the same values. + // Thus the transpose op simply returns the same values it got. + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + Value ret = packLLElements(loc, this->getTypeConverter(), vals, rewriter, + resultTy); + rewriter.replaceOp(op, ret); + return success(); + } + return emitOptionalError(loc, "unsupported encoding for TransOp"); + } +}; +struct BroadcastOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Following the order of indices in the legacy code, a broadcast of: + // [s(0), s(1) ... s(k-1), 1, s(k+1), s(k+2) ... s(n-1)] + // => + // [s(0), s(1) ... s(k-1), s(k), s(k+1), s(k+2) ... s(n-1)] + // + // logically maps to a broadcast within a thread's scope: + // [cta(0)..cta(k-1), 1,cta(k+1)..cta(n-1),spt(0)..spt(k-1), + // 1,spt(k+1)..spt(n-1)] + // => + // [cta(0)..cta(k-1),cta(k),cta(k+1)..cta(n-1),spt(0)..spt(k-1),spt(k),spt(k+1)..spt(n-1)] + // + // regardless of the order of the layout + // + Location loc = op->getLoc(); + Value src = adaptor.getSrc(); + Value result = op.getResult(); + auto srcTy = cast(op.getSrc().getType()); + auto resultTy = cast(result.getType()); + auto srcLayout = srcTy.getEncoding(); + auto resultLayout = resultTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto resultShape = resultTy.getShape(); + unsigned rank = srcTy.getRank(); + auto typeConverter = getTypeConverter(); + assert(rank == resultTy.getRank()); + auto order = triton::gpu::getOrder(srcLayout); + auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); + auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); + SmallVector srcVals = unpackLLElements(loc, src, rewriter); + std::map, Value> srcValues; + for (size_t i = 0; i < srcOffsets.size(); i++) { + srcValues[srcOffsets[i]] = srcVals[i]; + } + SmallVector resultVals; + for (size_t i = 0; i < resultOffsets.size(); i++) { + auto offset = resultOffsets[i]; + for (size_t j = 0; j < srcShape.size(); j++) + if (srcShape[j] == 1) + offset[j] = 0; + resultVals.push_back(srcValues.at(offset)); + } + Value resultStruct = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct MemDescSubviewOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::MemDescSubviewOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // %dst = extract_slice %src[%offsets] + Location loc = op->getLoc(); + auto srcTy = op.getSrc().getType(); + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + + // newBase = base + offset + auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + SmallVector opOffsetVals = op.getOffsets(); + size_t destRank = op.getResult().getType().getRank(); + SmallVector offsetVals; + SmallVector strides; + int rankReduced = srcTy.getRank() - destRank; + for (int i = rankReduced; i < opOffsetVals.size(); i++) { + strides.push_back(smemObj.strides[i]); + offsetVals.push_back(opOffsetVals[i]); + } + // Compute the offset based on the original strides of the shared memory + // object + auto offset = dot(rewriter, loc, opOffsetVals, smemObj.strides); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + smemObj = + SharedMemoryObject(gep(elemPtrTy, llvmElemTy, smemObj.base, offset), + llvmElemTy, strides, offsetVals); + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; +} // namespace + +void mlir::triton::populateViewOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/CMakeLists.txt new file mode 100644 index 000000000..1b629ba16 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(TritonToTritonGPU + TritonGPUConversion.cpp + TritonToTritonGPUPass.cpp + + DEPENDS + TritonConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTransforms + TritonIR + TritonGPUIR + TritonGPUTransforms +) diff --git a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp new file mode 100644 index 000000000..34fb89954 --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -0,0 +1,123 @@ +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" + +#include +#include + +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +using namespace mlir::triton::gpu; + +// +// TypeConverter +// +TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, + int numWarps, int threadsPerWarp, + int numCTAs) + : context(context), numWarps(numWarps), threadsPerWarp(threadsPerWarp), + numCTAs(numCTAs) { + addConversion([](Type type) { return type; }); + + // Add encoding for tensor + addConversion([this](RankedTensorType tensorType) -> RankedTensorType { + // types with encoding are already in the right format + // TODO: check for layout encodings more specifically + if (tensorType.getEncoding()) + return tensorType; + ArrayRef shape = tensorType.getShape(); + triton::gpu::BlockedEncodingAttr encoding = + getDefaultBlockedEncoding(this->context, shape, this->numWarps, + this->threadsPerWarp, this->numCTAs); + return RankedTensorType::get(shape, tensorType.getElementType(), encoding); + }); + + // Add encoding for tensor pointer + addConversion([this](triton::PointerType ptrType) -> triton::PointerType { + // Check whether tensor pointer `tt.ptr>` + auto pointeeTensorType = + dyn_cast(ptrType.getPointeeType()); + if (pointeeTensorType == nullptr) + return ptrType; + + // Add layout into the tensor + auto convertedTensorType = convertType(pointeeTensorType); + return triton::PointerType::get(convertedTensorType, + ptrType.getAddressSpace()); + }); + + // + // Materializations + // + // This will be called when (newArgType != origArgType) + // This will create newArg, and map(origArg, newArg) + addArgumentMaterialization([&](OpBuilder &builder, + RankedTensorType tensorType, ValueRange inputs, + Location loc) -> std::optional { + llvm_unreachable("Argument rematerialization should not happen in Triton " + "-> TritonGPU conversion"); + return std::nullopt; + }); + + // If the origValue still has live user(s), use this to + // convert origValue to newValue + addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, + Location loc) -> std::optional { + llvm_unreachable("Source rematerialization should not happen in Triton -> " + "TritonGPU Conversion"); + return std::nullopt; + }); + + // This will be called when (desiredType != newOperandType) + // where, desiredType = typeConverter->convertType(origType) + // NOTE: only for remapped values. + addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, Location loc) { + auto cast = + builder.create(loc, tensorType, inputs); + return std::optional(cast.getResult()); + }); +} + +// +// TritonGPUConversion +// +TritonGPUConversionTarget::TritonGPUConversionTarget( + MLIRContext &context, TritonGPUTypeConverter &typeConverter) + : ConversionTarget(context) { + // TODO: we should also verify ops of TritonGPUDialect + addLegalDialect(); + + // Some ops from SCF are illegal + addIllegalOp(); + + addDynamicallyLegalDialect([&](Operation *op) { + bool hasLegalRegions = true; + for (auto ®ion : op->getRegions()) { + hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion); + } + if (hasLegalRegions && typeConverter.isLegal(op)) { + return true; + } + return false; + }); + + // We have requirements for the data layouts + addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { + Attribute aEncoding = + cast(dotOp.getA().getType()).getEncoding(); + Attribute bEncoding = + cast(dotOp.getB().getType()).getEncoding(); + if (aEncoding && isa(aEncoding) && + bEncoding && isa(bEncoding)) + return true; + return false; + }); +} diff --git a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp new file mode 100644 index 000000000..4aa2712ec --- /dev/null +++ b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -0,0 +1,821 @@ +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "llvm/ADT/APSInt.h" +#include + +#define GEN_PASS_CLASSES +#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// pass named attrs (e.g., tt.contiguity) from Triton to Triton +static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) { + for (const NamedAttribute attr : dictAttrs.getValue()) + if (!op->hasAttr(attr.getName())) + op->setAttr(attr.getName(), attr.getValue()); +} + +template struct GenericOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector retTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + retTypes))) + return failure(); + rewriter.replaceOpWithNewOp(op, retTypes, adaptor.getOperands(), + op->getAttrs()); + + return success(); + } +}; + +class ArithConstantPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = getTypeConverter()->convertType(op.getType()); + auto retShapedType = cast(retType); + auto value = dyn_cast(adaptor.getValue()); + if (dyn_cast(retShapedType)) { + assert(value); + if (value.getElementType().isInteger(1) && value.isSplat()) + // Workaround until https://reviews.llvm.org/D133743 is included. + value = + DenseElementsAttr::get(retShapedType, value.getSplatValue()); + else + // This is a hack. We just want to add encoding + value = value.reshape(retShapedType); + } + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retShapedType, value), + adaptor.getAttributes()); + return success(); + } +}; + +void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { + // -------------- + // Add legality and rewrite pattern rules for operations + // from the Arith dialect. The basic premise is that + // Arith operations require both inputs to have the same + // non-null encoding + // -------------- + MLIRContext *context = patterns.getContext(); + // TODO: there's probably a better way to avoid adding all ops one-by-one + patterns.add< + ArithConstantPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, // NegFOp + // Floating point + GenericOpPattern, GenericOpPattern, + // MaxMin + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + // Floating point + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + // Cmp + GenericOpPattern, GenericOpPattern, + // Select + GenericOpPattern, + // Cast Ops + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern>(typeConverter, context); +} + +void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { + MLIRContext *context = patterns.getContext(); + // Rewrite rule + patterns.add, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern>( + typeConverter, context); +} + +// +// Triton patterns +// +struct TritonExpandDimsPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Type retType = op.getType()); + RankedTensorType argType = + cast(adaptor.getSrc().getType()); + Attribute _argEncoding = argType.getEncoding(); + if (!_argEncoding) + return failure(); + auto argEncoding = cast(_argEncoding); + // return shape + auto retShape = argType.getShape().vec(); + retShape.insert(retShape.begin() + op.getAxis(), 1); + // return encoding + auto retSizePerThread = argEncoding.getSizePerThread(); + retSizePerThread.insert(retSizePerThread.begin() + op.getAxis(), 1); + auto retThreadsPerWarp = argEncoding.getThreadsPerWarp(); + retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.getAxis(), 1); + auto retWarpsPerCTA = argEncoding.getWarpsPerCTA(); + retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1); + SmallVector retOrder(retShape.size()); + std::iota(retOrder.begin(), retOrder.end(), 0); + + auto argCTALayout = argEncoding.getCTALayout(); + auto retCTAsPerCGA = insertOne(argCTALayout.getCTAsPerCGA(), op.getAxis()); + auto retCTASplitNum = + insertOne(argCTALayout.getCTASplitNum(), op.getAxis()); + auto retCTAOrder = insertOrder(argCTALayout.getCTAOrder(), op.getAxis()); + auto retCTALayout = triton::gpu::CTALayoutAttr::get( + getContext(), retCTAsPerCGA, retCTASplitNum, retCTAOrder); + + triton::gpu::BlockedEncodingAttr retEncoding = + triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread, + retThreadsPerWarp, retWarpsPerCTA, + retOrder, retCTALayout); + // convert operand to slice of return type + Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get( + getContext(), op.getAxis(), retEncoding); + RankedTensorType newArgType = RankedTensorType::get( + argType.getShape(), argType.getElementType(), newArgEncoding); + // construct new op + auto newSrc = rewriter.create( + op.getLoc(), newArgType, adaptor.getSrc()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, newSrc, adaptor.getAxis()), + adaptor.getAttributes()); + return success(); + } + +private: + template + SmallVector insertOne(ArrayRef vec, unsigned axis) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + axis, 1); + return res; + } + + // Example: order = [ 0, 2, 1, 3], dim = 2 + // resOrder = [2, 0, 3, 1, 4] + SmallVector insertOrder(ArrayRef order, + unsigned axis) const { + SmallVector resOrder(order.begin(), order.end()); + for (unsigned i = 0; i < resOrder.size(); ++i) + if (resOrder[i] >= axis) + ++resOrder[i]; + resOrder.insert(resOrder.begin(), axis); + return resOrder; + } +}; + +struct TritonDotPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType origType = op.getType(); + auto origShape = origType.getShape(); + auto typeConverter = getTypeConverter(); + int numWarps = typeConverter->getNumWarps(); + int threadsPerWarp = typeConverter->getThreadsPerWarp(); + int numCTAs = typeConverter->getNumCTAs(); + auto rank = origShape.size(); + SmallVector retSizePerThread(rank, 1); + auto numElements = product(origShape); + if (numElements / (numWarps * threadsPerWarp) >= 4) { + retSizePerThread[rank - 1] = 2; + retSizePerThread[rank - 2] = 2; + } + if (numElements / (numWarps * threadsPerWarp) >= 16) { + retSizePerThread[rank - 1] = 4; + retSizePerThread[rank - 2] = 4; + } + SmallVector retOrder(rank); + for (unsigned i = 0; i < rank; ++i) + retOrder[i] = rank - 1 - i; + Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get( + getContext(), origShape, retSizePerThread, retOrder, numWarps, + threadsPerWarp, numCTAs); + RankedTensorType retType = + RankedTensorType::get(origShape, origType.getElementType(), dEncoding); + // a & b must be of smem layout + auto aType = cast(adaptor.getA().getType()); + auto bType = cast(adaptor.getB().getType()); + Type aEltType = aType.getElementType(); + Type bEltType = bType.getElementType(); + Attribute aEncoding = aType.getEncoding(); + Attribute bEncoding = bType.getEncoding(); + if (!aEncoding || !bEncoding) + return failure(); + Value a = adaptor.getA(); + Value b = adaptor.getB(); + Value c = adaptor.getC(); + if (!mlir::isa(aEncoding)) { + Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( + getContext(), 0, dEncoding, aEltType); + auto dstType = + RankedTensorType::get(aType.getShape(), aEltType, encoding); + a = rewriter.create(a.getLoc(), dstType, a); + } + if (!mlir::isa(bEncoding)) { + Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( + getContext(), 1, dEncoding, bEltType); + auto dstType = + RankedTensorType::get(bType.getShape(), bEltType, encoding); + b = rewriter.create(b.getLoc(), dstType, b); + } + c = rewriter.create(c.getLoc(), retType, c); + + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, a, b, c, adaptor.getInputPrecision(), + adaptor.getMaxNumImpreciseAcc()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonCatPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // The cat op satisfy two conditions: + // 1. output.numel = lhs.numel + rhs.numel + // 2. output.total_elems_per_thread = + // next_power_of_2(lhs.total_elems_per_thread + rhs.total_elems_per_thread) + // For now, this behaves like generic, but this + // will evolve when we add support for `can_reorder=False`. + auto retType = cast( + this->getTypeConverter()->convertType(op.getType())); + auto retEncoding = + cast(retType.getEncoding()); + auto lhsType = adaptor.getLhs().getType(); + auto rhsType = adaptor.getRhs().getType(); + auto lhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(lhsType); + auto rhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(rhsType); + auto retTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(retType); + auto retShape = retType.getShape(); + auto retOrder = retEncoding.getOrder(); + auto retSizePerThread = retEncoding.getSizePerThread(); + auto retThreadsPerWarp = retEncoding.getThreadsPerWarp(); + auto retWarpsPerCTA = retEncoding.getWarpsPerCTA(); + // Get new retSizePerThread if ret elems per thread is not enough. + // We have to round it up to the next power of 2 due to triton's tensor size + // constraint. + auto newRetTotalElemsPerThread = + nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread); + auto newRetSizePerThread = retSizePerThread; + newRetSizePerThread[retOrder[0]] *= + newRetTotalElemsPerThread / retTotalElemsPerThread; + triton::gpu::BlockedEncodingAttr newRetEncoding = + triton::gpu::BlockedEncodingAttr::get( + getContext(), newRetSizePerThread, retThreadsPerWarp, + retWarpsPerCTA, retOrder, retEncoding.getCTALayout()); + auto newRetType = RankedTensorType::get(retShape, retType.getElementType(), + newRetEncoding); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, newRetType, adaptor.getOperands()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonJoinOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Simply rely on type inference for this op. (Notably, GenericOpPattern + // does not do this, instead it assigns the default layout to the ins and + // outs.) + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, adaptor.getLhs(), adaptor.getRhs()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonSplitOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto src = adaptor.getSrc(); + auto srcTy = cast(src.getType()); + auto srcEnc = dyn_cast(srcTy.getEncoding()); + int rank = srcEnc.getOrder().size(); + auto typeConverter = getTypeConverter(); + + // The operand to split must have: + // - a blocked layout, with + // - sizePerThread = 2 in the last dimension, + // - threadsPerWarp, warpsPerCTA, and CTAsPerCGA = 1 in the last dim, and + // - the last dimension minor. + // If that's not the case, add a convert before the split. + if (!srcEnc || srcEnc.getSizePerThread().back() != 2 || + srcEnc.getOrder().front() != rank - 1) { + // If we take the default encoding for the op's result (i.e. post-split) + // and add 1 to the end of each dim, that gives us what we want. Other + // than making a legal src encoding, our choice of layout doesn't matter; + // it'll get fixed by RemoveLayoutConversions. + auto defaultEnc = getDefaultBlockedEncoding( + getContext(), + cast(op.getResult(0).getType()).getShape(), + typeConverter->getNumWarps(), typeConverter->getThreadsPerWarp(), + typeConverter->getNumCTAs()); + + auto append = [&](ArrayRef vals, unsigned val) { + SmallVector res(vals); + res.push_back(val); + return res; + }; + auto prepend = [&](ArrayRef vals, unsigned val) { + SmallVector res; + res.push_back(val); + res.append(vals.begin(), vals.end()); + return res; + }; + + srcEnc = BlockedEncodingAttr::get( + getContext(), append(defaultEnc.getSizePerThread(), 2), + append(defaultEnc.getThreadsPerWarp(), 1), + append(defaultEnc.getWarpsPerCTA(), 1), + prepend(defaultEnc.getOrder(), rank - 1), + CTALayoutAttr::get(getContext(), + append(defaultEnc.getCTAsPerCGA(), 1), + append(defaultEnc.getCTASplitNum(), 1), + prepend(defaultEnc.getCTAOrder(), rank - 1))); + srcTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), + srcEnc); + src = rewriter.create(op.getLoc(), srcTy, src); + } + + addNamedAttrs(rewriter.replaceOpWithNewOp(op, src), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonTransPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = adaptor.getSrc(); + auto srcTy = cast(src.getType()); + auto srcEnc = srcTy.getEncoding(); + if (!srcEnc) + return failure(); + addNamedAttrs(rewriter.replaceOpWithNewOp(op, src, op.getOrder()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonBroadcastPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // This creates a tensor with the new shape but the argument's layout + LogicalResult + matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcType = cast(adaptor.getSrc().getType()); + auto srcEncoding = srcType.getEncoding(); + if (!srcEncoding) + return failure(); + Type retType = RankedTensorType::get( + op.getType().getShape(), op.getType().getElementType(), srcEncoding); + // Type retType = this->getTypeConverter()->convertType(op.getType()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, adaptor.getOperands()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonReducePattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newReduce = rewriter.create( + op.getLoc(), adaptor.getOperands(), adaptor.getAxis()); + addNamedAttrs(newReduce, adaptor.getAttributes()); + + auto &newCombineOp = newReduce.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newReduce.getResult()); + return success(); + } +}; + +struct TritonScanPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newScan = rewriter.create( + op.getLoc(), adaptor.getOperands(), adaptor.getAxis(), op.getReverse()); + addNamedAttrs(newScan, adaptor.getAttributes()); + + auto &newCombineOp = newScan.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newScan.getResult()); + return success(); + } +}; + +class TritonFuncOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getName(), op.getFunctionType()); + addNamedAttrs(newOp, adaptor.getAttributes()); + rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(), + newOp.getBody().end()); + if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter))) + return failure(); + + return success(); + } +}; + +class TritonCallOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getCallee(), op.getResultTypes(), adaptor.getOperands()); + addNamedAttrs(newOp, adaptor.getAttributes()); + return success(); + } +}; + +class TritonReturnOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp op, ReturnOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } +}; + +void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, unsigned numCTAs) { + MLIRContext *context = patterns.getContext(); + patterns.insert< // TODO: view should have custom pattern that views the + // layout + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + TritonBroadcastPattern, GenericOpPattern, + TritonCatPattern, TritonJoinOpPattern, TritonSplitOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, TritonReducePattern, + GenericOpPattern, TritonScanPattern, + GenericOpPattern, + GenericOpPattern, TritonExpandDimsPattern, + TritonTransPattern, TritonDotPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, TritonFuncOpPattern>(typeConverter, + context); +} + +// +// SCF patterns +// +// This is borrowed from ConvertForOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +struct SCFForPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + // Ref: ConvertForOpTypes + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + + // Now, update all the types. + + // Convert the types of block arguments within the given region. This + // replaces each block with a new block containing the updated signature. + // The entry block may have a special conversion if `entryConversion` is + // provided. On success, the new entry block to the region is returned for + // convenience. Otherwise, failure is returned. + if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), + *getTypeConverter()))) { + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + // Change the clone to use the updated operands. We could have cloned with + // a IRMapping, but this seems a bit more direct. + newOp->setOperands(adaptor.getOperands()); + // Update the result types to the new converted types. + SmallVector newResultTypes; + for (Type type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + + rewriter.replaceOp(op, newOp.getResults()); + + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +class SCFIfPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::IfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // TODO: Generalize this to any type conversion, not just 1:1. + // + // We need to implement something more sophisticated here that tracks which + // types convert to which other types and does the appropriate + // materialization logic. + // For example, it's possible that one result type converts to 0 types and + // another to 2 types, so newResultTypes would at least be the right size to + // not crash in the llvm::zip call below, but then we would set the the + // wrong type on the SSA values! These edge cases are also why we cannot + // safely use the TypeConverter::convertTypes helper here. + SmallVector newResultTypes; + for (auto type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + + // See comments in the ForOp pattern for why we clone without regions and + // then inline. + scf::IfOp newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), + newOp.getThenRegion().end()); + rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), + newOp.getElseRegion().end()); + + // Update the operands and types. + newOp->setOperands(adaptor.getOperands()); + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +class SCFWhilePattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + assert(converter); + SmallVector newResultTypes; + if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes))) + return failure(); + + auto newOp = rewriter.create(op.getLoc(), newResultTypes, + adaptor.getOperands()); + for (auto i : {0u, 1u}) { + auto &dstRegion = newOp.getRegion(i); + rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); + if (failed(rewriter.convertRegionTypes(&dstRegion, *converter))) + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +class SCFConditionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::ConditionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.modifyOpInPlace(op, + [&]() { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + +void populateSCFPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add, SCFForPattern, SCFIfPattern, + SCFWhilePattern, SCFConditionPattern>(typeConverter, context); +} + +// CF + +class CFBranchPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getSuccessor(), adaptor.getOperands()); + if (failed(rewriter.convertRegionTypes(newOp.getSuccessor()->getParent(), + *converter))) + return failure(); + return success(); + } +}; + +class CFCondBranchPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), op.getTrueDest(), + adaptor.getTrueDestOperands(), op.getFalseDest(), + adaptor.getFalseDestOperands()); + addNamedAttrs(newOp, adaptor.getAttributes()); + + if (failed(rewriter.convertRegionTypes(newOp.getTrueDest()->getParent(), + *converter))) + return failure(); + if (failed(rewriter.convertRegionTypes(newOp.getFalseDest()->getParent(), + *converter))) + return failure(); + return success(); + } +}; + +void populateCFPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add(typeConverter, context); +} +// + +class ConvertTritonToTritonGPU + : public ConvertTritonToTritonGPUBase { +public: + ConvertTritonToTritonGPU() = default; + // constructor with some parameters set explicitly. + ConvertTritonToTritonGPU(const std::string &target, int numWarps, + int threadsPerWarp, int numCTAs) { + this->numWarps = numWarps; + this->threadsPerWarp = threadsPerWarp; + this->numCTAs = numCTAs; + this->target = target; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + // type converter + TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp, + numCTAs); + TritonGPUConversionTarget target(*context, typeConverter); + // rewrite patterns + RewritePatternSet patterns(context); + // add rules + populateArithPatternsAndLegality(typeConverter, patterns, target); + populateMathPatternsAndLegality(typeConverter, patterns, target); + populateTritonPatterns(typeConverter, patterns, numCTAs); + // TODO: can we use + // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? + populateSCFPatterns(typeConverter, patterns); + populateCFPatterns(typeConverter, patterns); + + auto inti = llvm::APSInt(32, false); + auto i32_ty = IntegerType::get(mod->getContext(), 32); + + mod->setAttr( + AttrNumWarpsName, + IntegerAttr::get(i32_ty, llvm::APInt(32, numWarps.getValue()))); + mod->setAttr( + AttrNumThreadsPerWarp, + IntegerAttr::get(i32_ty, llvm::APInt(32, threadsPerWarp.getValue()))); + + mod->setAttr(AttrNumCTAsName, + IntegerAttr::get(i32_ty, llvm::APInt(32, numCTAs.getValue()))); + + if (this->target.getValue().empty()) { + mod.emitError("expected target specification to attach to the module op"); + return signalPassFailure(); + } + mod->setAttr(AttrTargetName, + StringAttr::get(context, this->target.getValue())); + + if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); + + // update layouts + // broadcast src => multicast, dst => broadcasted + // if (failed(target.refineLayouts(mod, numWarps))) + // return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +mlir::triton::createConvertTritonToTritonGPUPass(const std::string &target, + int numWarps, + int threadsPerWarp, + int numCTAs) { + return std::make_unique<::ConvertTritonToTritonGPU>(target, numWarps, + threadsPerWarp, numCTAs); +} + +std::unique_ptr> +mlir::triton::createConvertTritonToTritonGPUPass() { + return std::make_unique<::ConvertTritonToTritonGPU>(); +} diff --git a/third_party/mthreads/lib/Dialect/CMakeLists.txt b/third_party/mthreads/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..27cb65ce5 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Triton) +add_subdirectory(TritonGPU) diff --git a/third_party/mthreads/lib/Dialect/Triton/CMakeLists.txt b/third_party/mthreads/lib/Dialect/Triton/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/lib/Dialect/Triton/IR/CMakeLists.txt b/third_party/mthreads/lib/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..752daa7ff --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(TritonIR + Dialect.cpp + Ops.cpp + Traits.cpp + Types.cpp + + DEPENDS + TritonTableGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRArithDialect + MLIRMathDialect + MLIRSCFDialect +) diff --git a/third_party/mthreads/lib/Dialect/Triton/IR/Dialect.cpp b/third_party/mthreads/lib/Dialect/Triton/IR/Dialect.cpp new file mode 100644 index 000000000..8f46e8ca8 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/IR/Dialect.cpp @@ -0,0 +1,138 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/DialectImplementation.h" + +#include "mlir/Transforms/InliningUtils.h" +#include "triton/Dialect/Triton/IR/Dialect.cpp.inc" +#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; + +//===----------------------------------------------------------------------===// +// TritonDialect Dialect Interfaces +//===----------------------------------------------------------------------===// + +namespace { +struct TritonInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + auto funcOp = dyn_cast(callable); + if (!funcOp) + return true; + if (funcOp->hasAttr("noinline")) + return !funcOp->getAttrOfType("noinline").getValue(); + return true; + } + + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + + bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, + IRMapping &) const final { + return true; + } + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, Block *newDest) const final { + // Only return needs to be handled here. + auto returnOp = dyn_cast(op); + if (!returnOp) + return; + + // Replace the return with a branch to the dest. + OpBuilder builder(op); + builder.create(op->getLoc(), newDest, + returnOp.getOperands()); + op->erase(); + } + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { + // Only return needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } +}; + +struct TensorModel + : public TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getRank(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementTypeBitWidth(); + } +}; + +struct MemDescModel + : public TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getShape().size(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementType().getIntOrFloatBitWidth(); + } +}; + +} // namespace + +void TritonDialect::initialize() { + registerTypes(); + + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + >(); + + // We can also add interface here. + addInterfaces(); + + RankedTensorType::attachInterface(*getContext()); + MemDescType::attachInterface(*getContext()); +} + +Operation *TritonDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return arith::ConstantOp::materialize(builder, value, type, loc); +} diff --git a/third_party/mthreads/lib/Dialect/Triton/IR/Ops.cpp b/third_party/mthreads/lib/Dialect/Triton/IR/Ops.cpp new file mode 100644 index 000000000..ce4f97336 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/IR/Ops.cpp @@ -0,0 +1,982 @@ +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +namespace mlir { +namespace triton { + +void LoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), getPtr(), + triton::GlobalMemory::get()); + if (getIsVolatile()) + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); +} + +} // namespace triton +} // namespace mlir + +#define GET_OP_CLASSES +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + +// enum attribute definitions +#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc" + +namespace mlir { +namespace triton { + +//-- LoadOp -- +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + CacheModifier cache, EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, /*padding=*/std::nullopt, + cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck, + padding, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, CacheModifier cache, EvictionPolicy evict, + bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, other, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + auto paddingAttr = + padding.has_value() + ? PaddingOptionAttr::get(builder.getContext(), padding.value()) + : PaddingOptionAttr(); + LoadOp::build(builder, state, ptr, mask, other, + builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr, cache, + evict, isVolatile); +} + +// load(ptr, splat(1), ...) -> load(ptr, ...) +// load(ptr, splat(0), other, ...) -> other +struct CanonicalizeMaskedLoadPattern : public OpRewritePattern { + CanonicalizeMaskedLoadPattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const override { + auto mask = loadOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = + llvm::dyn_cast_or_null(mask.getDefiningOp()); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(), + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + } else { + // mask = splat(0) + + // If there's no "other", the value is "undef". Perhaps we want to + // optimize it in the future.x + auto otherVal = loadOp.getOther(); + if (!otherVal) + return failure(); + rewriter.replaceOp(loadOp, otherVal); + } + return success(); + } +}; + +void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- StoreOp -- +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + /*boundaryCheck=*/{}, cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, Value mask, CacheModifier cache, + EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, mask, /*boundaryCheck=*/{}, + cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, ArrayRef boundaryCheck, + CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + builder.getDenseI32ArrayAttr(boundaryCheck), cache, + evict); +} + +// store(ptr, value, splat(1), ...) -> store(ptr, value, ...) +// store(ptr, value, splat(0), ...) -> [none] +struct CanonicalizeMaskedStorePattern : public OpRewritePattern { + CanonicalizeMaskedStorePattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const override { + auto mask = storeOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = + llvm::dyn_cast_or_null(mask.getDefiningOp()); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(), + storeOp.getEvict()); + } else { + // mask = splat(0) + rewriter.eraseOp(storeOp); + } + return success(); + } +}; + +void StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- TransOp -- +OpFoldResult TransOp::fold(FoldAdaptor adaptor) { + // transpose(x, order=[0, 1, ...]) -> x + if (isIota(getOrder())) { + return getSrc(); + } + + // transpose(transpose(x)) -> transpose(x) + if (auto innerTrans = getSrc().getDefiningOp()) { + setOrder(applyPermutation(innerTrans.getOrder(), getOrder())); + setOperand(innerTrans.getSrc()); + return getResult(); + } + + return {}; +} + +LogicalResult TransOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the input + auto argTy = cast(operands[0].getType()); + auto order = properties.as()->order.asArrayRef(); + SmallVector retShape = applyPermutation(argTy.getShape(), order); + + auto retEltTy = argTy.getElementType(); + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = dyn_cast(&dialect); + if (inferLayoutInterface + ->inferTransOpEncoding(argEncoding, order, retEncoding) + .failed()) { + return failure(); + } + } + if (isa(argTy)) { + inferredReturnTypes.push_back( + MemDescType::get(retShape, retEltTy, retEncoding)); + } else { + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + } + return success(); +} + +LogicalResult TransOp::verify() { + // Check that the op's `order` attribute is a permutation of the right length. + auto srcTy = getSrc().getType(); + + ArrayRef order = getOrder(); + if (order.size() != srcTy.getRank()) { + return emitError("order must have the same size as the rank of the " + "operand and result"); + } + + SmallVector sortedOrder(order); + llvm::sort(sortedOrder); + for (int32_t i = 0; i < sortedOrder.size(); i++) { + if (sortedOrder[i] != i) { + return emitError("order must be a permutation of [0, ..., rank - 1]"); + } + } + + return success(); +} + +//-- DotOp -- +LogicalResult +DotOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + + // verify encodings + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc); + Dialect &dialect = aEnc.getDialect(); + auto interface = dyn_cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return failure(); + } + return success(); +} + +LogicalResult DotOp::verify() { + auto aTy = getA().getType(); + auto bTy = getB().getType(); + if (aTy.getElementType().getIntOrFloatBitWidth() != + bTy.getElementType().getIntOrFloatBitWidth()) + return emitError( + "element types of operands A and B must have same bit width"); + auto aEncoding = aTy.getEncoding(); + auto bEncoding = bTy.getEncoding(); + if (!aEncoding && !bEncoding) + return success(); + // Verify that the encodings are valid. + if (!aEncoding || !bEncoding) + return emitError("mismatching encoding between A and B operands"); + Dialect &dialect = aEncoding.getDialect(); + auto interface = cast(&dialect); + return interface->verifyDotOpEncodingCompatibility(getOperation(), aEncoding, + bEncoding); +} + +//-- MakeRangeOp -- +OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) { + // make_range(start, start + 1) -> constant(start) + if (adaptor.getStart() + 1 == adaptor.getEnd()) { + auto shapedType = cast(getType()); + return SplatElementsAttr::get(shapedType, adaptor.getStartAttr()); + } + return {}; +} + +LogicalResult MakeRangeOp::verify() { + int64_t start = getStartAttr().getInt(); + int64_t end = getEndAttr().getInt(); + if (start > end) { + return this->emitOpError() << "start must be less than or equal to end"; + } + auto ty = getType(); + if (ty.getShape().size() != 1) { + return this->emitOpError() << "return type must be a 1D tensor"; + } + if (end - start != ty.getShape()[0]) { + return this->emitOpError() + << "number of elements in returned tensor, " << ty.getShape()[0] + << ", must match size of range [" << start << ", " << end + << "), which has " << end - start << " elements"; + } + if (!ty.getElementType().isInteger(32)) { + return this->emitOpError() << "returned tensor must have i32 elements"; + } + return success(); +} + +//-- ReduceOp -- +static LogicalResult +inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy, + int axis, SmallVectorImpl &inferredReturnTypes) { + auto retShape = argTy.getShape().vec(); + retShape.erase(retShape.begin() + axis); + if (retShape.empty()) { + // 0d-tensor -> scalar + inferredReturnTypes.push_back(retEltTy); + } else { + // nd-tensor where n >= 1 + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = + dyn_cast(&dialect); + if (inferLayoutInterface + ->inferReduceOpEncoding(argEncoding, axis, retEncoding) + .failed()) { + llvm::report_fatal_error("failed to infer layout for ReduceOp"); + return failure(); + } + } + // create type + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + } + return success(); +} + +void ReduceOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, int axis) { + SmallVector inferredReturnTypes; + for (unsigned i = 0; i < operands.size(); ++i) { + auto argTy = cast(operands[i].getType()); + auto retEltTy = argTy.getElementType(); + (void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes); + } + + ReduceOp::build(builder, state, inferredReturnTypes, operands, axis); +} + +LogicalResult ReduceOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + for (auto arg : operands) { + auto argTy = cast(arg.getType()); + auto retEltTy = argTy.getElementType(); + if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes) + .failed()) { + return failure(); + } + } + return success(); +} + +// Helpers for Reductions and Scans +template LogicalResult verifyReduceScan(Op &op) { + if (op.getOperands().empty()) { + return op.emitOpError() << "must have at least 1 operand"; + } + if (op.getNumOperands() != op.getNumResults()) { + return op.emitOpError() << "must have the same number of inputs as outputs"; + } + + auto getElementType = [](Type ty) { + if (auto tensorType = dyn_cast(ty)) { + return tensorType.getElementType(); + } + return ty; + }; + + for (auto [opElemTy, resTy] : + llvm::zip(op.getElementTypes(), op.getResultTypes())) { + if (opElemTy != getElementType(resTy)) { + return op.emitOpError() << "operand types and result types must agree"; + } + } + return success(); +} + +template +static LogicalResult verifyRegionsImpl(Op &op) { + auto argElementTypes = op.getElementTypes(); + const auto &operands = op.getOperands(); + const auto numArgs = 2 * operands.size(); + auto &block = *op.getBody(); + if (block.getNumArguments() != numArgs) { + return op.emitOpError() << "nested block must take " << numArgs + << " arguments, but given block with " + << block.getNumArguments() << " arguments"; + } + unsigned i = 0; + const auto &blockArgTypes = block.getArgumentTypes(); + for (unsigned i = 0; i < numArgs; ++i) { + const auto &blockArgTy = blockArgTypes[i]; + const auto &argElemTy = argElementTypes[i % operands.size()]; + if (blockArgTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << blockArgTy; + } + } + + auto terminator = dyn_cast(block.getTerminator()); + if (!terminator) { + return op.emitOpError() + << "combine operation must be terminated " + << "with a ReduceReturnOp but got " << block.getTerminator(); + } + const auto &combineResults = terminator->getOperands(); + if (combineResults.size() != operands.size()) { + return op.emitOpError() + << "expected combine operation to return " << operands.size() + << " values but got " << combineResults.size(); + } + for (unsigned i = 0; i < combineResults.size(); ++i) { + const auto &resultTy = combineResults[i].getType(); + const auto &argElemTy = argElementTypes[i]; + if (resultTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << resultTy; + } + } + return success(); +} + +static llvm::SmallVector +getInputTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcTys; + srcTys.reserve(operands.size()); + for (const auto &ty : operands.getTypes()) { + srcTys.push_back(cast(ty)); + } + return srcTys; +} + +static llvm::SmallVector +getElementTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcElemTys; + srcElemTys.reserve(operands.size()); + for (const auto &op : operands) { + srcElemTys.push_back(cast(op.getType()).getElementType()); + } + return srcElemTys; +} + +LogicalResult ReduceOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ReduceOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ReduceOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ReduceOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } + +//-- ScanOp -- +void ScanOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, int axis, bool reverse) { + SmallVector inferredReturnTypes; + state.addAttribute("reverse", builder.getBoolAttr(reverse)); + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + ReduceOp::build(builder, state, inferredReturnTypes, operands, axis); +} + +LogicalResult +ScanOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + return success(); +} + +LogicalResult ScanOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ScanOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ScanOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ScanOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +unsigned ScanOp::getNumOperands() { return this->getOperands().size(); } + +//-- SplatOp -- +OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { + auto value = adaptor.getSrc(); + if (!value) + return {}; + auto shapedType = cast(getType()); + auto ret = SplatElementsAttr::get(shapedType, ArrayRef(value)); + return ret; +} + +//-- ExpandDimsOp -- +LogicalResult ExpandDimsOp::inferReturnTypes( + MLIRContext *context, std::optional loc, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // infer shape + auto arg = operands[0]; + auto argTy = cast(arg.getType()); + auto retShape = argTy.getShape().vec(); + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + retShape.insert(retShape.begin() + axis, 1); + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = dyn_cast(&dialect); + if (inferLayoutInterface + ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc) + .failed()) + return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp"); + } + // create type + auto argEltTy = argTy.getElementType(); + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, argEltTy, retEncoding)); + return success(); +} + +LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op, + PatternRewriter &rewriter) { + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + // expand_dims(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + // expand_dims(broadcast(x)) -> broadcast(expand_dims(x)) + // + // On its own this doesn't do much, but consider + // broadcast(expand_dims(broadcast)) + // -> broadcast(broadcast(expand_dims)) + // -> broadcast(expand_dims) + if (auto broadcast = dyn_cast(definingOp)) { + auto src = broadcast.getSrc(); + auto srcTy = src.getType(); + SmallVector newExpandShape(srcTy.getShape()); + newExpandShape.insert(newExpandShape.begin() + op.getAxis(), 1); + + // Infer the encoding of the new expand op, if encodings are present. + Attribute newExpandEnc; + if (auto srcEnc = srcTy.getEncoding()) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferExpandDimsOpEncoding(srcEnc, op.getAxis(), newExpandEnc, + op.getLoc()) + .failed()) { + return emitOptionalError(op.getLoc(), + "failed to infer layout for ExpandDimsOp"); + } + } + + auto newExpandTy = RankedTensorType::get( + newExpandShape, srcTy.getElementType(), newExpandEnc); + auto newExpand = rewriter.create(op.getLoc(), newExpandTy, + src, op.getAxis()); + auto newBroadcast = rewriter.create( + broadcast.getLoc(), op.getType(), newExpand.getResult()); + rewriter.replaceOp(op, {newBroadcast.getResult()}); + return success(); + } + + return failure(); +} + +template +static OpFoldResult foldViewLikeOp(ViewLikeOp op, Attribute value) { + if (!value) + return {}; + + auto shapedType = cast(op.getType()); + if (auto denseElemsAttr = dyn_cast(value)) { + if (denseElemsAttr.isSplat()) { + return denseElemsAttr.resizeSplat(shapedType); + } else { + return denseElemsAttr.reshape(shapedType); + } + } + return {}; +} + +OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) { + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +//-- ReshapeOp -- +template +LogicalResult canonicalizeViewOrBroadcast(OpType op, + PatternRewriter &rewriter) { + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + + // view(view) -> view + if (auto parentView = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, TypeRange({op.getType()}), + parentView->getOperands(), + parentView->getAttrs()); + return success(); + } + + // view(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + + return failure(); +} + +LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) { + if (!op.getAllowReorder() || op.getEfficientLayout().has_value()) + return failure(); + return canonicalizeViewOrBroadcast(op, rewriter); +} + +OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) { + // no-op + return getSrc(); + } + + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +LogicalResult ReshapeOp::verify() { + auto dstTy = getType(); + auto srcTy = getSrc().getType(); + if (getType().getNumElements() != srcTy.getNumElements()) { + return emitError( + "number of src and dst elements of reshape must be the same"); + } + + Attribute srcEnc = srcTy.getEncoding(); + Attribute dstEnc = dstTy.getEncoding(); + if (!!srcEnc != !!dstEnc) { + return emitError("Op requires that either (a) src and dst both have " + "encodings, or (b) neither does."); + } + + if (srcEnc && !getAllowReorder()) { + Attribute inferredDstEnc; + if (cast(&srcEnc.getDialect()) + ->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc, + dstTy.getShape(), inferredDstEnc, + getLoc()) + .failed()) { + return emitError("This reshape is impossible without reordering, but " + "reordering is not allowed. Try choosing a different " + "encoding for the input tensor (or allow reordering)."); + } + if (inferredDstEnc != dstEnc) { + return emitError("Expected result encoding ") + << inferredDstEnc << " but was " << dstEnc; + } + } + + return success(); +} + +//-- FpToFpOp -- +LogicalResult FpToFpOp::verify() { + auto dstType = getType().getElementType(); + auto srcType = getSrc().getType().getElementType(); + if ((dstType.getIntOrFloatBitWidth() < srcType.getIntOrFloatBitWidth()) && + (!getRounding().has_value())) { + return emitError("Rounding mode is required for FP downcast"); + } + return success(); +} + +//-- BroadcastOp -- +LogicalResult BroadcastOp::canonicalize(BroadcastOp op, + PatternRewriter &rewriter) { + return canonicalizeViewOrBroadcast(op, rewriter); +} + +OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) { + // no-op + return getSrc(); + } + + auto value = adaptor.getSrc(); + if (!value) + return {}; + + if (auto denseElemsAttr = dyn_cast(value)) { + auto shapedType = cast(getType()); + return denseElemsAttr.resizeSplat(shapedType); + } + return {}; +} + +//-- MakeTensorPtrOp -- +void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange shape, ValueRange strides, + ValueRange offsets, ArrayRef tensorShape, + ArrayRef order) { + // Get pointer type from `base` + auto pointerType = cast(base.getType()); + assert(pointerType != nullptr); + + // Build type `tt.ptr>` + auto tensorType = RankedTensorType::get( + SmallVector(tensorShape.begin(), tensorShape.end()), + pointerType.getPointeeType()); + auto result = PointerType::get(tensorType, 1); + + return build(builder, state, result, base, shape, strides, offsets, + builder.getDenseI32ArrayAttr(order)); +} + +// The following ops, including `call`, `func`, and `return` are copied and +// modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp +// We could revert it back once MLIR has a better inliner interface. +//-- FuncOp -- +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs, + ArrayRef argAttrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); + + if (argAttrs.empty()) + return; + assert(type.getNumInputs() == argAttrs.size()); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); +} + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(OpAsmPrinter &printer) { + function_interface_impl::printFunctionOp( + printer, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +// -- CallOp -- +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee attribute was specified. + auto fnAttr = (*this).getProperties().callee; + if (!fnAttr) + return emitOpError("requires a 'callee' symbol reference attribute"); + FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. + auto fnType = fn.getFunctionType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (getOperand(i).getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " + << getOperand(i).getType() << " for operand number " << i; + + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (getResult(i).getType() != fnType.getResult(i)) { + auto diag = emitOpError("result type mismatch at index ") << i; + diag.attachNote() << " op result types: " << getResultTypes(); + diag.attachNote() << "function result types: " << fnType.getResults(); + return diag; + } + + return success(); +} + +// -- ReturnOp -- +LogicalResult ReturnOp::verify() { + auto function = cast((*this)->getParentOp()); + + // The operand number and types must match the function signature. + const auto &results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" + << function.getName() << ") returns " << results.size(); + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match function result type (" + << results[i] << ")" + << " in function @" << function.getName(); + + return success(); +} + +// -- JoinOp -- +LogicalResult +JoinOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // These should have been checked by tablegen-generated code. + assert(operands.size() == 2); + assert(operands[0].getType() == operands[1].getType()); + assert(isa(operands[0].getType())); + assert(isa(operands[1].getType())); + + Value lhs = operands[0]; + Value rhs = operands[1]; + auto srcTy = cast(lhs.getType()); + + SmallVector retShape(srcTy.getShape()); + retShape.push_back(2); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferJoinOpEncoding(srcEnc, retEnc, location) + .failed()) { + return failure(); + } + } + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, srcTy.getElementType(), retEnc)); + return success(); +} + +// -- SplitOp -- +LogicalResult SplitOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // These should have been checked by tablegen-generated code. + assert(operands.size() == 1); + assert(isa(operands[0].getType())); + + Value src = operands[0]; + auto srcTy = cast(src.getType()); + auto srcShape = srcTy.getShape(); + + if (srcShape.empty() || srcShape.back() != 2) { + return emitOptionalError(location, + "last dimension of input tensor must be 2"); + } + ArrayRef retShape(srcShape.begin(), srcShape.end() - 1); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferSplitOpEncoding(srcEnc, retEnc, location) + .failed()) { + return failure(); + } + } + auto retTy = RankedTensorType::get(retShape, srcTy.getElementType(), retEnc); + inferredReturnTypes.push_back(retTy); + inferredReturnTypes.push_back(retTy); + return success(); +} + +// -- ElementwiseInlineAsmOp -- +void ElementwiseInlineAsmOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +LogicalResult ElementwiseInlineAsmOp::verify() { + if (getNumOperands() >= 1) { + auto tensorType = dyn_cast(getOperand(0).getType()); + size_t numInputElems = tensorType ? tensorType.getNumElements() : 0; + if (numInputElems % this->getPackedElement() != 0) { + return emitError("number of input elements ") + << numInputElems + << " must be a multiple of the op's packed_element attribute, " + << getPackedElement(); + } + } + return success(); +} + +// -- ExternElementwiseOp -- +void ExternElementwiseOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/Triton/IR/Traits.cpp b/third_party/mthreads/lib/Dialect/Triton/IR/Traits.cpp new file mode 100644 index 000000000..19729aee5 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/IR/Traits.cpp @@ -0,0 +1,239 @@ +#include "triton/Dialect/Triton/IR/Traits.h" + +#include + +#include "mlir/IR/TypeUtilities.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +namespace ttg = mlir::triton::gpu; + +static LogicalResult verifySameEncoding(Type typeA, Type typeB, + bool allowTensorPointerType) { + // TODO(Keren): the allowTensorPointerType argument is a hack to allow. + // The type checking code is kind of a mess with the current design. + auto getEncoding = [=](Type type) -> Attribute { + Attribute ret; + if (auto tensorType = dyn_cast(type)) { + ret = tensorType.getEncoding(); + } + if (!allowTensorPointerType) { + assert(!triton::isTensorPointerType(type)); + } + return ret; + }; + auto encodingA = getEncoding(typeA); + auto encodingB = getEncoding(typeB); + if (!encodingA || !encodingB) + return success(); + return encodingA == encodingB ? success() : failure(); +} + +LogicalResult +OpTrait::impl::verifySameOperandsEncoding(Operation *op, + bool allowTensorPointerType) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifySameEncoding(opType, type, allowTensorPointerType))) + return op->emitOpError() << "requires the same encoding for all operands"; + + return success(); +} + +LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding( + Operation *op, bool allowTensorPointerType) { + if (op->getNumOperands() == 0) + return success(); + + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto resultType : op->getResultTypes()) + if (failed(verifySameEncoding(resultType, type, allowTensorPointerType))) + return op->emitOpError() + << "requires the same encoding for all operands and results"; + + return verifySameOperandsEncoding(op, allowTensorPointerType); +} + +LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) { + for (auto opType : op->getOperandTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + if ((numElements & (numElements - 1)) != 0) + return op->emitError("Number of elements must be power-of-two, but ") + << *op << " doesn't follow the rule (" << numElements << ")" + << " elements"; + } + } + for (auto opType : op->getResultTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + if ((numElements & (numElements - 1)) != 0) + return op->emitError("Number of elements must be power-of-two, but ") + << *op << " doesn't follow the rule (" << numElements << ")" + << " elements"; + } + } + return success(); +} + +// Check that the Triton layouts on op's operands and return types are valid. +// For example, we check that the number of warps per block in a Triton GPU +// blocked layout matches that of its module. +// +// It's a little weird to check these properties of a layout only when the +// layout is used in an op, since most of the properties don't actually depend +// on the op. They do depend on the *module*, though, and a layout is attached +// to a module only by virtue of being used in one of the module's ops. +LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) { + auto module = op->getParentOfType(); + auto checkLayout = [&](Value val, auto makeErr) -> LogicalResult { + // Only ranked tensors can have layouts. + auto rankedTy = dyn_cast(val.getType()); + if (!rankedTy) + return success(); + + mlir::Attribute layout = rankedTy.getEncoding(); + if (!layout) + return success(); + + if (isa(layout)) + return makeErr() << "Shared layout is not allowed on tensor type."; + // TODO(jlebar): Currently this only checks blocked layouts, but other + // layouts also have invariants! + + // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr. + if (auto blocked = dyn_cast(layout)) { + // A different verifier should have checked that the layout itself is + // valid, including that threads-per-warp has the same rank as + // warps-per-block etc. + auto layoutRank = blocked.getThreadsPerWarp().size(); + if (layoutRank != rankedTy.getRank()) { + return makeErr() << layout << ".\nLayout has rank " << layoutRank + << ", but the tensor it's attached to has rank " + << rankedTy.getRank() << "."; + } + + int moduleThreadsPerWarp = + ttg::TritonGPUDialect::getThreadsPerWarp(module); + int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp()); + if (layoutThreadsPerWarp != moduleThreadsPerWarp) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutThreadsPerWarp + << " threads per warp, but the module specifies " + << moduleThreadsPerWarp << " threads per warp."; + } + + int moduleWarpsPerCTA = ttg::TritonGPUDialect::getNumWarps(module); + int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA()); + if (layoutWarpsPerCTA != moduleWarpsPerCTA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutWarpsPerCTA + << " warps per CTA, but the module specifies " + << moduleWarpsPerCTA << " warps per CTA."; + } + + if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) { + int moduleCTAsPerCGA = ttg::TritonGPUDialect::getNumCTAs(module); + int64_t layoutCTAsPerCGA = + product(blocked.getCTALayout().getCTAsPerCGA()); + if (layoutCTAsPerCGA != moduleCTAsPerCGA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutCTAsPerCGA + << " CTAs per CGA, but the module specifies " + << moduleCTAsPerCGA << " CTAs per CGA."; + } + } + } + + return success(); + }; + + for (size_t i = 0; i < op->getNumOperands(); i++) { + auto operand = op->getOperand(i); + auto err = checkLayout(operand, [&]() { + // Stringify the operand using `printAsOperand`. This prints e.g. "%42" + // rather than the full definition. + std::string operandStr; + llvm::raw_string_ostream os(operandStr); + // If we don't assume verified, dump() will recursively call this + // function! + operand.printAsOperand(os, OpPrintingFlags().assumeVerified()); + + return op->emitError("Operand ") + << i << " (" << operand << ") has an invalid layout: "; + }); + if (!err.succeeded()) + return err; + } + + for (size_t i = 0; i < op->getNumResults(); i++) { + auto result = op->getResult(i); + auto err = checkLayout(result, [&]() { + if (op->getNumResults() == 1) { + return op->emitError("Result has an invalid layout: "); + } else { + return op->emitError("Result ") << i << " has an invalid layout: "; + } + }); + if (!err.succeeded()) + return err; + } + + return success(); +} + +static ArrayRef getTypeShape(Type type) { + auto rankedType = dyn_cast(type); + if (auto ptrType = dyn_cast(type)) + rankedType = dyn_cast(ptrType.getPointeeType()); + return rankedType ? rankedType.getShape() : ArrayRef(); +} + +LogicalResult OpTrait::impl::verifySameLoadStoreOperandsShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() << "requires the same shape for all operands"; + + return success(); +} + +LogicalResult +OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : op->getResultTypes()) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() + << "requires the same shape for all operands and results"; + + return verifySameLoadStoreOperandsShape(op); +} diff --git a/third_party/mthreads/lib/Dialect/Triton/IR/Types.cpp b/third_party/mthreads/lib/Dialect/Triton/IR/Types.cpp new file mode 100644 index 000000000..0e1df5b74 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/IR/Types.cpp @@ -0,0 +1,171 @@ +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void TritonDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + >(); +} + +Type PointerType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + Type pointeeType; + if (parser.parseType(pointeeType)) + return Type(); + + int addressSpace = 1; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseInteger(addressSpace)) + return Type(); + } + + if (parser.parseGreater()) + return Type(); + + return PointerType::get(pointeeType, addressSpace); +} + +void PointerType::print(AsmPrinter &printer) const { + if (getAddressSpace() == 1) { + printer << "<" << getPointeeType() << ">"; + } else { + printer << "<" << getPointeeType() << ", " << getAddressSpace() << ">"; + } +} + +static constexpr llvm::StringRef kMutableMemory = "mutable"; + +Type MemDescType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + SmallVector dimensions; + if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false)) + return Type(); + + // Parse the element type. + Type elementType; + if (parser.parseType(elementType)) + return Type(); + + Attribute encoding; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseAttribute(encoding)) + return Type(); + } + bool mutableMemory = false; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseOptionalKeyword(kMutableMemory)) + return Type(); + mutableMemory = true; + } + if (parser.parseGreater()) + return Type(); + + return MemDescType::get(parser.getContext(), dimensions, elementType, + encoding, mutableMemory); +} + +void MemDescType::print(AsmPrinter &printer) const { + printer << "<"; + for (auto dim : getShape()) + printer << dim << "x"; + printer << getElementType(); + if (getEncoding()) + printer << ", " << getEncoding(); + if (getMutableMemory()) + printer << ", " << kMutableMemory; + printer << ">"; +} + +namespace mlir { + +namespace triton { + +unsigned getPointeeBitWidth(Type type) { + auto pointeeType = getPointeeType(type); + if (auto tensorTy = dyn_cast(pointeeType)) + return tensorTy.getElementType().getIntOrFloatBitWidth(); + return pointeeType.getIntOrFloatBitWidth(); +} + +Type getI1SameShape(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorTy = dyn_cast(type)) + return RankedTensorType::get(tensorTy.getShape(), i1Type, + tensorTy.getEncoding()); + return i1Type; +} + +Type getPointeeType(Type type) { + if (auto tensorTy = dyn_cast(type)) { + // Tensor of pointers + auto shape = tensorTy.getShape(); + auto ptrType = dyn_cast(tensorTy.getElementType()); + Type pointeeType = ptrType.getPointeeType(); + return RankedTensorType::get(shape, pointeeType, tensorTy.getEncoding()); + } else if (auto ptrType = dyn_cast(type)) { + // scalar pointer + Type pointeeType = ptrType.getPointeeType(); + return pointeeType; + } + return type; +} + +Type getI32SameShape(Type type) { + auto i32Type = IntegerType::get(type.getContext(), 32); + if (auto tensorTy = dyn_cast(type)) + return RankedTensorType::get(tensorTy.getShape(), i32Type, + tensorTy.getEncoding()); + return i32Type; +} + +Type getPointerTypeSameShape(Type type) { + if (auto tensorTy = dyn_cast(type)) { + Type elementType = tensorTy.getElementType(); + auto shape = tensorTy.getShape(); + PointerType ptrType = PointerType::get(elementType, 1); + return RankedTensorType::get(shape, ptrType, tensorTy.getEncoding()); + } else { + return PointerType::get(type, 1); + } +} + +Type getPointerType(Type type) { return PointerType::get(type, 1); } + +bool isTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + return isa(ptrType.getPointeeType()); + return false; +} + +bool isTensorOrTensorPointerType(Type type) { + return isa(type) || isTensorPointerType(type); +} + +Type getElementTypeOfTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + if (auto tensorTy = dyn_cast(ptrType.getPointeeType())) + return tensorTy.getElementType(); + return {}; +} + +} // namespace triton + +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/CMakeLists.txt b/third_party/mthreads/lib/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 000000000..298398750 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,18 @@ +set(LLVM_TARGET_DEFINITIONS Combine.td) +mlir_tablegen(TritonCombine.inc -gen-rewriters) +add_public_tablegen_target(TritonCombineIncGen) + +add_triton_library(TritonTransforms + Combine.cpp + ReorderBroadcast.cpp + RewriteTensorPointer.cpp + + DEPENDS + TritonTransformsIncGen + TritonCombineIncGen + + LINK_LIBS PUBLIC + MLIRPass + MLIRTransformUtils + TritonIR +) diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/Combine.cpp b/third_party/mthreads/lib/Dialect/Triton/Transforms/Combine.cpp new file mode 100644 index 000000000..c5d638754 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/Combine.cpp @@ -0,0 +1,255 @@ +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +#define GEN_PASS_CLASSES +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace mlir::triton { +namespace { + +bool isZero(Value val) { + if (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat())) + return true; + // broadcast(constant_0) + if (auto bc = val.getDefiningOp()) { + if (matchPattern(bc.getSrc(), m_Zero()) || + matchPattern(bc.getSrc(), m_AnyZeroFloat())) + return true; + } + return false; +} + +bool isBroadcastConstantCombinable(Attribute value) { + if (auto denseValue = dyn_cast(value)) { + return denseValue.isSplat(); + } + return isa(value); +} + +DenseElementsAttr getConstantValue(Builder &builder, Attribute value, + Value bcast_res) { + auto resType = cast(bcast_res.getType()); + DenseElementsAttr res; + if (auto denseValue = dyn_cast(value)) { + res = + DenseElementsAttr::get(resType, denseValue.getSplatValue()); + } else { + res = DenseElementsAttr::get(resType, value); + } + return res; +} + +bool isAddPtrOffsetCombinable(Value first, Value second) { + auto GetConstantIntValue = [](Value val) -> std::optional { + DenseElementsAttr constAttr; + auto defOp = val.getDefiningOp(); + if (defOp) { + if (auto splatOp = llvm::dyn_cast(defOp)) + val = splatOp.getSrc(); + else if (matchPattern(defOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto attr = constAttr.getSplatValue(); + // Check IntegerAttr + if (auto intAttr = dyn_cast_or_null(attr)) + return intAttr.getValue(); + } + } + + // Check constant value. + llvm::APInt intVal; + if (matchPattern(val, m_ConstantInt(&intVal))) + return intVal; + + return std::nullopt; + }; + + if (first.getType() == second.getType()) { + // Whether bitwidth of element type is equal to pointer + if (getElementTypeOrSelf(first.getType()).getIntOrFloatBitWidth() == 64) + return true; + + // first + second does not overflow + auto firstVal = GetConstantIntValue(first); + auto secondVal = GetConstantIntValue(second); + if (firstVal && secondVal) { + bool overflow = false; + auto resVal = firstVal->sadd_ov(*secondVal, overflow); + return !overflow; + } + } + return false; +} + +// TODO(csigg): remove after next LLVM integrate. +using FastMathFlags = arith::FastMathFlags; + +#include "TritonCombine.inc" + +// select(cond, load(ptrs, splat(cond), ???), other) +// => load(ptrs, splat(cond), other) +class CombineSelectMaskedLoadPattern : public RewritePattern { +public: + CombineSelectMaskedLoadPattern(MLIRContext *context) + : RewritePattern(arith::SelectOp::getOperationName(), 3, context, + {LoadOp::getOperationName()}) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto selectOp = llvm::dyn_cast(op); + if (!selectOp) + return failure(); + + Value trueValue = selectOp.getTrueValue(); + Value falseValue = selectOp.getFalseValue(); + Value condSelect = selectOp.getCondition(); + + auto *loadOpCandidate = trueValue.getDefiningOp(); + auto loadOp = llvm::dyn_cast_or_null(loadOpCandidate); + if (!loadOp) + return failure(); + + Value mask = loadOp.getMask(); + if (!mask) + return failure(); + + auto *splatOpCandidate = mask.getDefiningOp(); + auto splatOp = llvm::dyn_cast_or_null(splatOpCandidate); + if (!splatOp) + return failure(); + + auto splatCond = splatOp.getSrc(); + if (splatCond != condSelect) + return failure(); + + rewriter.replaceOpWithNewOp( + op, loadOp.getPtr(), loadOp.getMask(), /*other=*/falseValue, + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + return success(); + } +}; + +// sum(x[:, :, None] * y[None, :, :], 1) +// -> dot(x, y) +class CombineBroadcastMulReducePattern : public RewritePattern { +private: + static bool isAddF32(const Operation *op) { + if (auto addf = dyn_cast_or_null(op)) + return addf.getType().getIntOrFloatBitWidth() <= 32; + return false; + } + + static SmallVector getEqualIndices(ArrayRef x, + ArrayRef y) { + SmallVector res; + for (int i = 0; i < x.size(); ++i) + if (x[i] == y[i]) + res.push_back(i); + return res; + } + +public: + CombineBroadcastMulReducePattern(MLIRContext *context) + : RewritePattern(ReduceOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + auto reduceOp = llvm::dyn_cast(op); + if (!reduceOp) + return failure(); + // only support reduce with simple addition + Region &combineOp = reduceOp.getCombineOp(); + bool isReduceAdd = combineOp.hasOneBlock() && + combineOp.front().getOperations().size() == 2 && + isAddF32(&*combineOp.front().getOperations().begin()); + if (!isReduceAdd) + return failure(); + // operand of reduce has to be mul + auto mulOp = llvm::dyn_cast_or_null( + reduceOp.getOperand(0).getDefiningOp()); + if (!mulOp) + return failure(); + // mul operand has to be broadcast + auto broadcastLhsOp = llvm::dyn_cast_or_null( + mulOp.getOperand(0).getDefiningOp()); + if (!broadcastLhsOp) + return failure(); + auto broadcastRhsOp = llvm::dyn_cast_or_null( + mulOp.getOperand(1).getDefiningOp()); + if (!broadcastRhsOp) + return failure(); + // broadcast operand is expand dims + auto expandLhsOp = llvm::dyn_cast_or_null( + broadcastLhsOp.getSrc().getDefiningOp()); + if (!expandLhsOp) + return failure(); + auto expandRhsOp = llvm::dyn_cast_or_null( + broadcastRhsOp.getSrc().getDefiningOp()); + if (!expandRhsOp) + return failure(); + // get not-broadcast dimensions + int expandLhsAxis = expandLhsOp.getAxis(); + int expandRhsAxis = expandRhsOp.getAxis(); + if (expandLhsAxis != 2 || expandRhsAxis != 0) + return failure(); + auto broadcastLhsShape = + cast(broadcastLhsOp.getType()).getShape(); + auto broadcastRhsShape = + cast(broadcastLhsOp.getType()).getShape(); + if (broadcastLhsShape[2] < 16 || broadcastRhsShape[0] < 16) + return failure(); + Type newAccType = RankedTensorType::get( + {broadcastLhsShape[0], broadcastRhsShape[2]}, + cast(broadcastLhsOp.getSrc().getType()).getElementType()); + rewriter.setInsertionPoint(op); + auto newAcc = rewriter.create( + op->getLoc(), newAccType, + rewriter.create(op->getLoc(), + rewriter.getF32FloatAttr(0))); + rewriter.replaceOpWithNewOp(op, expandLhsOp.getSrc(), + expandRhsOp.getSrc(), newAcc, + InputPrecision::TF32, 0); + return success(); + } +}; + +class CombineOpsPass : public TritonCombineOpsBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + + // Dot Add %{ + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + // %} + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // anonymous namespace + +std::unique_ptr createCombineOpsPass() { + return std::make_unique(); +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/Combine.td b/third_party/mthreads/lib/Dialect/Triton/Transforms/Combine.td new file mode 100644 index 000000000..5a2fcecfa --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/Combine.td @@ -0,0 +1,54 @@ +#ifndef TRITON_PATTERNS +#define TRITON_PATTERNS + +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "triton/Dialect/Triton/IR/TritonOps.td" +include "mlir/IR/PatternBase.td" + + +// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) + +// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) +// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) +def CombineDotAddIPattern : Pat< + (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $overflow), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; +def CombineDotAddFPattern : Pat< + (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; + +def CombineDotAddIRevPattern : Pat< + (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $overflow), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; +def CombineDotAddFRevPattern : Pat< + (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; + +// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1)) +// Note: leave (sub %c0, %c0) canceling to ArithDialect +// (ref: ArithCanonicalization.td) +defvar DefOverflow = ConstantEnumCase; +def CombineAddPtrPattern : Pat< + (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1), + (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)), + [(Constraint> $idx0, $idx1)]>; + +// broadcast(cst) => cst +def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">; +def CombineBroadcastConstantPattern : Pat< + (TT_BroadcastOp:$bcast_res (Arith_ConstantOp $value)), + (Arith_ConstantOp (getConstantValue $value, $bcast_res), (location $bcast_res)), + [(Constraint> $value)]>; + +#endif diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp b/third_party/mthreads/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp new file mode 100644 index 000000000..43479a3d9 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp @@ -0,0 +1,247 @@ +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +// TODO(jlebar): Move this and all other generatede code into namespace +// mlir::triton. +#define GEN_PASS_DEF_TRITONREORDERBROADCAST +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace mlir::triton { +namespace { + +Operation *cloneWithNewArgsAndResultTypes(PatternRewriter &rewriter, + Operation *op, ValueRange newOperands, + TypeRange newTypes) { + OperationState newElementwiseState(op->getLoc(), op->getName()); + newElementwiseState.addOperands(newOperands); + newElementwiseState.addTypes(newTypes); + newElementwiseState.addAttributes(op->getAttrs()); + return rewriter.create(newElementwiseState); +} + +bool isSplat(Operation *op) { + if (auto splatOp = llvm::dyn_cast(op)) { + return true; + } + DenseElementsAttr constAttr; + return (matchPattern(op, m_Constant(&constAttr)) && constAttr.isSplat()); +} + +// elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) +struct MoveSplatAfterElementwisePattern + : public OpTraitRewritePattern { + + MoveSplatAfterElementwisePattern(MLIRContext *context) + : OpTraitRewritePattern(context) {} + + LogicalResult match(Operation *op) const override { + if (!isMemoryEffectFree(op)) { + return failure(); + } + + for (auto operand : op->getOperands()) { + auto definingOp = operand.getDefiningOp(); + if (!definingOp) + return failure(); + + if (!isSplat(definingOp)) { + return failure(); + } + } + return success(op->getNumOperands() > 0); + } + + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto operands = op->getOperands(); + + llvm::SmallVector scalarOperands(operands.size()); + for (unsigned iOp = 0; iOp < operands.size(); ++iOp) { + auto definingOp = operands[iOp].getDefiningOp(); + + DenseElementsAttr constAttr; + if (auto splatOp = llvm::dyn_cast(definingOp)) { + scalarOperands[iOp] = splatOp.getSrc(); + } else if (matchPattern(definingOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto value = constAttr.getSplatValue(); + scalarOperands[iOp] = arith::ConstantOp::materialize( + rewriter, value, constAttr.getElementType(), loc); + } else { + llvm_unreachable("Expected a splat"); + } + } + + auto resultTypes = op->getResultTypes(); + llvm::SmallVector scalarResultTys; + for (auto resultTy : resultTypes) { + auto elemTy = dyn_cast(resultTy).getElementType(); + scalarResultTys.push_back(elemTy); + } + + auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, scalarOperands, + scalarResultTys); + + for (unsigned iRes = 0; iRes < resultTypes.size(); ++iRes) { + auto newResult = rewriter.create(loc, resultTypes[iRes], + newOp->getResult(iRes)); + rewriter.replaceAllUsesWith(op->getResult(iRes), newResult); + } + } +}; + +// elementwise(broadcast(a)) => broadcast(elementwise(a)) +// This also generalizes to multiple arguments when the rest are splat-like +// Not handled: multiple broadcasted arguments +struct MoveBroadcastAfterElementwisePattern + : public OpTraitRewritePattern { + + MoveBroadcastAfterElementwisePattern(MLIRContext *context) + : OpTraitRewritePattern(context) {} + + LogicalResult match(Operation *op) const override { + if (!isMemoryEffectFree(op)) { + return failure(); + } + + auto operands = op->getOperands(); + bool seenBroadcast = false; + ArrayRef srcShape; + for (auto operand : operands) { + auto definingOp = operand.getDefiningOp(); + if (!definingOp) { + return failure(); + } + auto getSrcShape = [](BroadcastOp b) { + return b.getSrc().getType().getShape(); + }; + if (auto broadcastOp = llvm::dyn_cast(definingOp)) { + if (!seenBroadcast) { + seenBroadcast = true; + srcShape = getSrcShape(broadcastOp); + } else if (srcShape != getSrcShape(broadcastOp)) { + // If the broadcast have different types we cannot re-order. + return failure(); + } + } else if (!isSplat(definingOp)) { + // Not splat or broadcast + return failure(); + } + } + return success(seenBroadcast); + } + + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // Find broadcast op + auto operands = op->getOperands(); + BroadcastOp broadcastOp; + for (auto operand : operands) { + broadcastOp = operand.getDefiningOp(); + if (broadcastOp) { + break; + } + } + + auto srcTy = broadcastOp.getSrc().getType(); + auto srcShape = srcTy.getShape(); + auto srcEncoding = srcTy.getEncoding(); + + // Reshape operands to match srcShape + llvm::SmallVector newOperands; + for (auto operand : operands) { + auto definingOp = operand.getDefiningOp(); + if (auto broadcastSrcOp = llvm::dyn_cast(definingOp)) { + newOperands.push_back(broadcastSrcOp.getSrc()); + continue; + } + auto elemTy = + dyn_cast(operand.getType()).getElementType(); + auto newTy = RankedTensorType::get(srcShape, elemTy, srcEncoding); + if (auto splatOp = llvm::dyn_cast(definingOp)) { + auto newSplat = rewriter.create(loc, newTy, splatOp.getSrc()); + newOperands.push_back(newSplat); + continue; + } + DenseElementsAttr constAttr; + if (matchPattern(definingOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto scalarValue = constAttr.getSplatValue(); + auto splatValue = SplatElementsAttr::get(newTy, scalarValue); + auto newConstant = + rewriter.create(loc, newTy, splatValue); + newOperands.push_back(newConstant); + continue; + } + llvm_unreachable("Expected broadcast or splat"); + } + + // Reshape results to match srcShape + llvm::SmallVector newResultTypes; + auto resultTypes = op->getResultTypes(); + for (auto resultTy : resultTypes) { + auto elemTy = dyn_cast(resultTy).getElementType(); + newResultTypes.push_back( + RankedTensorType::get(srcShape, elemTy, srcEncoding)); + } + + // Create new op and broadcast results + auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, newOperands, + newResultTypes); + for (unsigned iRes = 0; iRes < newResultTypes.size(); ++iRes) { + auto newResult = rewriter.create(loc, resultTypes[iRes], + newOp->getResult(iRes)); + rewriter.replaceAllUsesWith(op->getResult(iRes), newResult); + } + } +}; + +template +class CanonicalizePattern : public OpRewritePattern { +public: + explicit CanonicalizePattern(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(OpType op, + PatternRewriter &rewriter) const override { + return OpType::canonicalize(op, rewriter); + } +}; + +class ReorderBroadcastPass + : public ::impl::TritonReorderBroadcastBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + + patterns.add>(context); + patterns.add>(context); + // elementwise(broadcast(a)) => broadcast(elementwise(a)) + patterns.add(context); + // elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) + patterns.add(context); + + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr createReorderBroadcastPass() { + return std::make_unique(); +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/third_party/mthreads/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp new file mode 100644 index 000000000..52f4ba0b3 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -0,0 +1,572 @@ +#include +#include + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace { + +/// An additional struct to record the meta information of operations +/// with tensor pointers +struct RewritedInfo { +private: + Value base; + SmallVector shape; + SmallVector strides; + SmallVector offsets; + ArrayRef tensorShape; + + // A cache to avoid generating the same offset with range + DenseMap cachedOffsetWithRange; + +public: + RewritedInfo() = default; + + RewritedInfo(const RewritedInfo &other) = default; + + RewritedInfo(Value base, const SmallVector &shape, + const SmallVector &strides, + const SmallVector &offsets, + const ArrayRef &tensorShape) + : base(base), shape(shape), strides(strides), offsets(offsets), + tensorShape(tensorShape) { + assert(shape.size() == strides.size() && shape.size() == offsets.size() && + shape.size() == tensorShape.size()); + } + + unsigned int length() const { return shape.size(); } + + Value getOffset(unsigned i) { return offsets[i]; } + + SmallVector getOffsets() { return offsets; } + + void setOffset(unsigned i, Value newOffset) { + offsets[i] = newOffset; + cachedOffsetWithRange.clear(); + } + + void setOffsets(const SmallVector &newOffsets) { + offsets = newOffsets; + cachedOffsetWithRange.clear(); + } + + Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc, + unsigned i) { + if (cachedOffsetWithRange.count(i)) + return cachedOffsetWithRange[i]; + + // Add range + auto indexI32RowType = + RankedTensorType::get({tensorShape[i]}, builder.getI32Type()); + auto indexRowType = + RankedTensorType::get({tensorShape[i]}, builder.getI64Type()); + Value splatOffset = + builder.create(loc, indexRowType, offsets[i]); + Value range = builder.create(loc, indexI32RowType, 0, + tensorShape[i]); + Value i64Range = builder.create(loc, indexRowType, range); + + // Expand dimensions + Value expandedResult = + builder.create(loc, splatOffset, i64Range); + for (int j = 0; j < tensorShape.size(); ++j) { + if (j == i) + continue; + expandedResult = + builder.create(loc, expandedResult, j); + } + + return cachedOffsetWithRange[i] = expandedResult; + } + + Value generatePtr(OpBuilder &builder, const Location &loc) { + assert(tensorShape.size() == offsets.size() && + tensorShape.size() == strides.size()); + auto indexTensorType = + RankedTensorType::get(tensorShape, builder.getI64Type()); + auto ptrType = cast(base.getType()); + auto ptrTensorType = RankedTensorType::get(tensorShape, ptrType); + + // Generate offsets per dimension + Value ptr = builder.create(loc, ptrTensorType, base); + for (unsigned i = 0; i < tensorShape.size(); ++i) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + + // We must splat strides into the expanded shape not a row for retaining + // the divisibility information given by strides + Value splatStride = builder.create( + loc, offsetWithRange.getType(), strides[i]); + Value offsetWithStride = + builder.create(loc, offsetWithRange, splatStride); + Value broadcasted = builder.create( + loc, indexTensorType, offsetWithStride); + + // Add to the pointer + ptr = builder.create(loc, ptrTensorType, ptr, + broadcasted); + } + + return ptr; + } + + Value generateMask(OpBuilder &builder, const Location &loc, + const std::optional> &boundaryCheck) { + if (!boundaryCheck.has_value()) + return {}; + + // Generate mask per dimension + auto maskTensorType = + RankedTensorType::get(tensorShape, builder.getI1Type()); + Value mask; + for (auto i : boundaryCheck.value()) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + + // Compare with lower bound + Value lowerBound = builder.create( + loc, 0, builder.getI64Type()); + Value splatLowerBound = builder.create( + loc, offsetWithRange.getType(), lowerBound); + Value cmpLower = builder.create( + loc, arith::CmpIPredicate::sge, offsetWithRange, splatLowerBound); + + // Compare with upper bound + Value splatUpperBound = builder.create( + loc, offsetWithRange.getType(), shape[i]); + Value cmpUpper = builder.create( + loc, arith::CmpIPredicate::slt, offsetWithRange, splatUpperBound); + + // And and broadcast + Value andResult = builder.create(loc, cmpLower, cmpUpper); + Value broadcasted = + builder.create(loc, maskTensorType, andResult); + + // And up all results + if (!mask) { + mask = broadcasted; + } else { + mask = builder.create(loc, mask, broadcasted); + } + } + + return mask; + } + + Value generateOther(OpBuilder &builder, const Location &loc, + const std::optional &padding) { + if (!padding.has_value()) + return Value(); + + // Create element attribute + auto elementType = + cast(base.getType()).getPointeeType(); + auto otherTensorType = RankedTensorType::get(tensorShape, elementType); + + // Set zero padding value + TypedAttr attr = + elementType.isIntOrIndex() + ? cast(builder.getIntegerAttr(elementType, 0)) + : cast(builder.getFloatAttr(elementType, 0)); + + // Float NaN padding case + if (padding.value() == triton::PaddingOption::PAD_NAN) { + assert(!elementType.isIntOrIndex()); + auto apNaN = llvm::APFloat::getNaN( + cast(attr).getValue().getSemantics()); + attr = builder.getFloatAttr(elementType, apNaN); + } + + // Create tensor + Value constant = builder.create(loc, attr); + return builder.create(loc, otherTensorType, constant); + } +}; + +} // namespace + +// TODO: this pass relies on assumptions of how block pointers are created and +// on pattern matches that walks the SSA links to find the base/strides. This is +// very fragile and to solve we should expose convert Ptr of tensor to a +// structure containins all values and not only offsets. +class RewriteTensorPointerPass + : public TritonRewriteTensorPointerBase { +private: + DenseMap rewritedInfo; + +public: + static bool needRewrite(Operation *op) { + return std::any_of(op->getOperands().begin(), op->getOperands().end(), + [](Value operand) { + return triton::isTensorPointerType(operand.getType()); + }); + } + + static SmallVector + generateNewOperands(const SmallVector &oldOperands, unsigned index, + const SmallVector &newValues) { + assert(index < oldOperands.size()); + SmallVector newOperands; + for (int i = 0; i < index; ++i) + newOperands.push_back(oldOperands[i]); + for (auto value : newValues) + newOperands.push_back(value); + for (auto i = index + 1; i < oldOperands.size(); ++i) + newOperands.push_back(oldOperands[i]); + return newOperands; + } + + Operation *rewriteMakeTensorPtrOp(OpBuilder &builder, + triton::MakeTensorPtrOp op, + std::stack &eraser) { + // Save info for later use + auto ptrType = cast(op.getType()); + auto tensorType = cast(ptrType.getPointeeType()); + + // Cast I32 offsets into I64 + SmallVector i64Offsets; + for (auto offset : op.getOffsets()) { + auto i64Offset = builder.create( + op.getLoc(), builder.getI64Type(), offset); + i64Offsets.push_back(i64Offset); + } + + // Save information + rewritedInfo[op.getResult()] = + RewritedInfo(op.getBase(), op.getShape(), op.getStrides(), i64Offsets, + tensorType.getShape()); + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteAdvanceOp(OpBuilder &builder, triton::AdvanceOp op, + std::stack &eraser) { + // Get info from previous results + assert(rewritedInfo.count(op.getPtr())); + auto info = rewritedInfo[op.getPtr()]; + + // Calculate new offsets + assert(info.length() == op.getOffsets().size()); + SmallVector newOffsets; + for (int i = 0; i < info.length(); ++i) { + Value i64Offset = builder.create( + op.getLoc(), builder.getI64Type(), op.getOffsets()[i]); + Value newOffset = builder.create( + op.getLoc(), info.getOffset(i), i64Offset); + newOffsets.push_back(newOffset); + } + + // Save info for later use + info.setOffsets(newOffsets); + rewritedInfo[op.getResult()] = info; + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteLoadStoreOp(OpBuilder &builder, Operation *op, + std::stack &eraser) { + assert(isa(op) || isa(op)); + + // We only have to rewrite load/stores with tensor pointers + auto ptr = op->getOperand(0); + if (!triton::isTensorPointerType(ptr.getType())) + return nullptr; + + // Get info from previous results + assert(rewritedInfo.count(ptr)); + auto info = rewritedInfo[ptr]; + + // Load/store with tensor pointers implicitly will check the bound while + // accessing memory, so we should set `mask` and `other` (according to the + // padding). Also note that load with tensor pointers do not have `mask` and + // `other` while building IR from Python AST + std::optional> boundaryCheck; + if (auto loadOp = dyn_cast(op)) { + assert(!loadOp.getMask() && !loadOp.getOther()); + boundaryCheck = loadOp.getBoundaryCheck(); + } else if (auto storeOp = dyn_cast(op)) { + assert(!storeOp.getMask()); + boundaryCheck = storeOp.getBoundaryCheck(); + } + + // Generate new `ptr`, `mask` and `other` + auto newPtr = info.generatePtr(builder, op->getLoc()); + auto newMask = info.generateMask(builder, op->getLoc(), boundaryCheck); + Value newOther; + if (auto loadOp = dyn_cast(op)) + newOther = info.generateOther(builder, op->getLoc(), loadOp.getPadding()); + + // Create a new operation + if (auto loadOp = dyn_cast(op)) { + auto newResult = builder.create( + loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(), + loadOp.getEvict(), loadOp.getIsVolatile()); + op->getResult(0).replaceAllUsesWith(newResult); + } else if (auto storeOp = dyn_cast(op)) { + builder.create(storeOp.getLoc(), newPtr, + storeOp.getValue(), newMask, + storeOp.getCache(), storeOp.getEvict()); + } + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteIfOp(OpBuilder &builder, scf::IfOp op, + std::stack &eraser) { + auto thenYieldOp = op.thenYield(); + assert(op.getNumResults() == thenYieldOp.getNumOperands()); + SmallVector results = thenYieldOp.getOperands(); + + // get new result types + SmallVector newRetTypes; + bool needRewrite = false; + for (unsigned i = 0; i < results.size(); ++i) { + if (!triton::isTensorPointerType(results[i].getType())) { + newRetTypes.push_back(results[i].getType()); + continue; + } + needRewrite = true; + auto makeTensorPtrOp = getMakeTensorPtrOp(results[i]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + auto info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + newRetTypes.push_back(builder.getI64Type()); + } + } + if (!needRewrite) + return op; + // create and clone new IfOp + bool hasElse = !op.getElseRegion().empty(); + scf::IfOp newOp = builder.create(op.getLoc(), newRetTypes, + op.getCondition(), hasElse); + IRMapping mapping; + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + mapping.map(op->getOperand(i), newOp->getOperand(i)); + } + auto rematerialize = [&](Block *block) { + for (Operation &opInIf : block->getOperations()) { + auto newOp = builder.clone(opInIf, mapping); + } + }; + builder.setInsertionPointToStart(newOp.thenBlock()); + rematerialize(op.thenBlock()); + if (hasElse) { + builder.setInsertionPointToStart(newOp.elseBlock()); + rematerialize(op.elseBlock()); + } + + // update rewritedInfo + unsigned oldResIdx = 0, newResIdx = 0; + while (oldResIdx < results.size()) { + if (!triton::isTensorPointerType(results[oldResIdx].getType())) { + oldResIdx++; + newResIdx++; + } else { + auto makeTensorPtrOp = getMakeTensorPtrOp(results[oldResIdx]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + auto info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + info.setOffset(j, newOp->getResult(newResIdx++)); + } + rewritedInfo[op.getResult(oldResIdx)] = info; + oldResIdx++; + } + } + + eraser.push(op); + return newOp; + } + + Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op, + std::stack &eraser) { + // Generate new iteration operands and set rewrited information + SmallVector oldIterOperands = llvm::to_vector(op.getInitArgs()); + SmallVector newIterOperands = llvm::to_vector(op.getInitArgs()); + for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size; + ++i, ++oldI) { + if (!triton::isTensorPointerType(newIterOperands[i].getType())) + continue; + + // Expand the tensor pointer into offsets + assert(rewritedInfo.count(newIterOperands[i])); + auto info = rewritedInfo[newIterOperands[i]]; + newIterOperands = + generateNewOperands(newIterOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + + // Rebuild the loop type + auto newForOp = builder.create(op.getLoc(), op.getLowerBound(), + op.getUpperBound(), op.getStep(), + newIterOperands); + + // Create value mapping. Note that for tensor pointers, we use identity + // mapping. It may refer to a value in the old loop, but we will rewrite it + // later + IRMapping mapping; + for (unsigned i = 0, oldI = 0, sz = op.getInitArgs().size(); oldI < sz; + ++i, ++oldI) { + auto oldRegionIterArg = op.getRegionIterArg(oldI); + if (triton::isTensorPointerType(oldRegionIterArg.getType())) { + // Pass rewrited info inside + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + mapping.map(oldRegionIterArg, oldRegionIterArg); + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getRegionIterArg(i + j)); + rewritedInfo[oldRegionIterArg] = info; + i += info.length() - 1; + } else { + mapping.map(oldRegionIterArg, newForOp.getRegionIterArg(i)); + } + } + mapping.map(op.getInductionVar(), newForOp.getInductionVar()); + + // Clone body + builder.setInsertionPointToStart(newForOp.getBody()); + for (auto &opInFor : *op.getBody()) { + auto *newOp = builder.clone(opInFor, mapping); + for (unsigned i = 0; i < opInFor.getNumResults(); ++i) + mapping.map(op->getResult(i), newOp->getResult(i)); + } + + // Replace later usages + assert(op.getNumResults() == op.getInitArgs().size()); + for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { + auto oldResult = op.getResult(oldI); + if (triton::isTensorPointerType(oldResult.getType())) { + // Pack new offsets into rewrited info + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getResult(i + j)); + i += info.length() - 1; + rewritedInfo[oldResult] = info; + } else { + oldResult.replaceAllUsesWith(newForOp.getResult(i)); + } + } + + // Erase later + eraser.push(op); + return newForOp; + } + + Operation *rewriteYieldOp(OpBuilder &builder, scf::YieldOp op, + std::stack &eraser) { + // Replace tensor pointers with offsets + SmallVector newOperands = op->getOperands(); + for (unsigned i = 0, size = op.getNumOperands(); i < size; ++i) { + if (!triton::isTensorPointerType(newOperands[i].getType())) + continue; + + assert(rewritedInfo.count(newOperands[i])); + auto info = rewritedInfo[newOperands[i]]; + newOperands = generateNewOperands(newOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + op->setOperands(newOperands); + + // No need to erase + return nullptr; + } + + Operation *rewriteOp(Operation *op, std::stack &eraser) { + OpBuilder builder(op); + + // Rewrite `make_tensor_ptr` and `advance` and make a tensor of pointers + // Rewriting functions return the next operation to visit, if there is no + // next one, simply return `nullptr` + std::pair rewrited; + if (auto makeTensorPtrOp = dyn_cast(op)) { + return rewriteMakeTensorPtrOp(builder, makeTensorPtrOp, eraser); + } else if (auto advanceOp = dyn_cast(op)) { + return rewriteAdvanceOp(builder, advanceOp, eraser); + } else if (isa(op) || isa(op)) { + return rewriteLoadStoreOp(builder, op, eraser); + } else if (op->getDialect()->getNamespace() == "scf" || + op->getDialect()->getNamespace() == "cf") { + if (auto ifOp = dyn_cast(op)) { + return rewriteIfOp(builder, ifOp, eraser); + } + if (!needRewrite(op)) + return op; + + if (auto forOp = dyn_cast(op)) { + return rewriteForOp(builder, forOp, eraser); + } else if (auto yieldOp = dyn_cast(op)) { + return rewriteYieldOp(builder, yieldOp, eraser); + } else { + llvm_unreachable("Currently we only support tensor pointer usages " + "inside a `scf::ForOp` or `scf::IfOp`, others such as " + "`scf::WhileOp`, `cf::BranchOp` or `cf::CondBranchOp` " + "are not supported yet"); + } + } + + // Otherwise return the original one + return op; + } + + void visitOperation(Operation *op, std::stack &eraser) { + for (auto ®ion : op->getRegions()) { + for (auto &block : region) { + // We need an extra copy because erasing operations may break the + // iterator behavior + SmallVector blockCopy; + for (auto &nestedOp : block) + blockCopy.push_back(&nestedOp); + + // Rewrite and recursively visit + for (auto &nestedOp : blockCopy) { + if (auto newOp = rewriteOp(nestedOp, eraser)) + visitOperation(newOp, eraser); + } + } + } + } + + void runOnOperation() override { + // NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because + // MLIR does not support one-multiple value mapping. For example, if we use + // `ConversionPatternRewriter`, we can not make a type converter, which + // converts `ptr` into multiple types `ptr<>, int64, int64, ...` + // (containing the base/offsets/strides...). What we can do is to convert + // `ptr` into a single type `Tuple, int64, int64, ...>`. But + // in this way, we also have to define `PackTuple` and `UnpackTuple` + // operations and make a canonicalization pass to optimize, which is much + // So here we recursively build the IR, to be specific, we have to rewrite + // `tt.make_tensor_ptr`, `tt.advance`, `tt.load`, `tt.store`, + // `scf.for` (tensor pointer usages may be in a loop fashion) + std::stack eraser; + visitOperation(getOperation(), eraser); + + // The operation could not be erased during visit, because they may have + // later usages, so we erase after visit + rewritedInfo.clear(); + while (!eraser.empty()) { + auto op = eraser.top(); + eraser.pop(); + op->erase(); + } + } +}; + +std::unique_ptr triton::createRewriteTensorPointerPass() { + return std::make_unique(); +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/CMakeLists.txt b/third_party/mthreads/lib/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/third_party/mthreads/lib/Dialect/TritonGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..b5dcdb5ea --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_triton_library(TritonGPUIR + Dialect.cpp + LinearLayoutConversions.cpp + Types.cpp + + DEPENDS + TritonGPUTableGen + TritonGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRGPUDialect + TritonIR + TritonTools +) diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/IR/Dialect.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/IR/Dialect.cpp new file mode 100644 index 000000000..69067b706 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -0,0 +1,2977 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/StrUtil.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/TypeSwitch.h" + +// Include TableGen'erated code +#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// Utility +namespace mlir { +namespace triton { + +static Type getI1SameShapeFromTensorOrTensorPtr(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorType = dyn_cast(type)) { + return RankedTensorType::get(tensorType.getShape(), i1Type, + tensorType.getEncoding()); + } else if (auto ptrType = dyn_cast(type)) { + Type pointeeType = ptrType.getPointeeType(); + if (auto tensorType = dyn_cast(pointeeType)) { + return RankedTensorType::get(tensorType.getShape(), i1Type, + tensorType.getEncoding()); + } + } + return Type(); +} + +namespace gpu { + +// TODO: Inheritance of layout attributes +// so that all distributed layouts implement +// these utilities + +unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape, + Type eltTy) { + if (auto tritonGPUAttr = mlir::dyn_cast(layout)) { + return tritonGPUAttr.getTotalElemsPerThread(shape, eltTy); + } else { + llvm::report_fatal_error("getTotalElemsPerThread not implemented"); + return 0; + } +} + +SmallVector getElemsPerThread(Attribute layout, + ArrayRef shape, Type eltTy) { + if (auto tritonGPUAttr = mlir::dyn_cast(layout)) { + return tritonGPUAttr.getElemsPerThread(shape, eltTy); + } else { + llvm::report_fatal_error("getElemsPerThread not implemented"); + return SmallVector(); + } +} + +SmallVector getElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || isa(type)) + return SmallVector(1, 1); + auto tensorType = cast(type); + return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape(), + tensorType.getElementType()); +} + +unsigned getTotalElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || isa(type)) + return 1; + auto tensorType = cast(type); + return getTotalElemsPerThread(tensorType.getEncoding(), tensorType.getShape(), + tensorType.getElementType()); +} + +SmallVector getThreadsPerWarp(Attribute layout) { + if (auto distributedLayout = dyn_cast(layout)) { + return distributedLayout.getThreadsPerWarp(); + } else { + llvm::report_fatal_error("getThreadsPerWarp not implemented"); + return SmallVector(); + } +} + +unsigned getWarpSize(Attribute layout) { + unsigned size = 1; + auto threadsPerWarp = getThreadsPerWarp(layout); + for (auto e : threadsPerWarp) { + size *= e; + } + return size; +} + +SmallVector +getThreadsPerWarpWithUniqueData(Attribute layout, + ArrayRef tensorShape) { + if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(tensorShape); + auto parentThreadsPerWarp = + getThreadsPerWarpWithUniqueData(parentLayout, parentShape); + SmallVector threadsPerWarp = parentThreadsPerWarp; + threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim()); + return threadsPerWarp; + } + auto threadsPerWarp = getThreadsPerWarp(layout); + assert(threadsPerWarp.size() == tensorShape.size() && + "layout and tensor shape must have the same rank"); + for (unsigned i = 0; i < threadsPerWarp.size(); i++) { + threadsPerWarp[i] = std::min(threadsPerWarp[i], tensorShape[i]); + } + + return threadsPerWarp; +} + +SmallVector getWarpsPerCTA(Attribute layout) { + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return distributedLayout.getWarpsPerCTA(); + } + + llvm::report_fatal_error("getWarpsPerCTA not implemented"); + return SmallVector(); +} + +SmallVector +getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape) { + if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(tensorShape); + auto parentWarpsPerCTA = + getWarpsPerCTAWithUniqueData(parentLayout, parentShape); + SmallVector warpsPerCTA = parentWarpsPerCTA; + warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim()); + return warpsPerCTA; + } + auto warpsPerCTA = getWarpsPerCTA(layout); + assert(warpsPerCTA.size() == tensorShape.size() && + "layout and tensor shape must have the same rank"); + for (unsigned i = 0; i < warpsPerCTA.size(); i++) { + auto sizePerWarp = + getSizePerThread(layout)[i] * getThreadsPerWarp(layout)[i]; + auto maxWarpsPerDim = ceil(tensorShape[i], sizePerWarp); + warpsPerCTA[i] = std::min(warpsPerCTA[i], maxWarpsPerDim); + } + + return warpsPerCTA; +} + +SmallVector getSizePerThread(Attribute layout) { + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return distributedLayout.getSizePerThread(); + } else { + llvm::report_fatal_error("getSizePerThread not implemented"); + return {}; + } +} + +SmallVector getContigPerThread(Attribute layout) { + if (auto distributedLayout = dyn_cast(layout)) { + return distributedLayout.getContigPerThread(); + } else { + llvm::report_fatal_error("getContigPerThread not implemented"); + return {}; + } +} + +SmallVector getUniqueContigPerThread(Attribute layout, + ArrayRef shape) { + // If slice layout, call recursively on parent layout, and drop + // sliced dim + if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(shape); + auto parentUniqueContigPerThread = + getUniqueContigPerThread(parentLayout, parentShape); + parentUniqueContigPerThread.erase(parentUniqueContigPerThread.begin() + + sliceLayout.getDim()); + return parentUniqueContigPerThread; + } + // Base case + auto rank = shape.size(); + SmallVector ret(rank); + auto contigPerThread = getContigPerThread(layout); + assert(contigPerThread.size() == rank && "Unexpected contigPerThread size"); + for (int d = 0; d < rank; ++d) { + ret[d] = std::min(shape[d], contigPerThread[d]); + } + return ret; +} + +SmallVector getShapePerCTATile(Attribute layout, + ArrayRef tensorShape) { + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return distributedLayout.getShapePerCTATile(tensorShape); + } else { + llvm::report_fatal_error("getShapePerCTATile not implemented"); + return SmallVector(); + } +} + +bool isExpensiveView(Type srcType, Type dstType) { + return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType); +} + +/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr. + * Erase dim and decrease all values larger than dim by 1. + * Example: order = [0, 2, 4, 3, 1], dim = 2 + * resOrder = [0, 3, 2, 1] + */ +static SmallVector eraseOrder(ArrayRef order, + unsigned dim) { + unsigned rank = order.size(); + assert(dim < rank && "Invalid dim to erase"); + SmallVector resOrder; + for (unsigned i : order) + if (i < dim) + resOrder.push_back(i); + else if (i > dim) + resOrder.push_back(i - 1); + return resOrder; +} + +SmallVector getWarpOrder(Attribute layout) { + auto order = getOrder(layout); + if (auto mmaLayout = dyn_cast(layout)) { + if (mmaLayout.isHopper()) { + // Hopper MMA instructions force a warp order of [0, 1]. See docs: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8 + auto it = std::find(order.begin(), order.end(), 0); + order.erase(it); + order.insert(order.begin(), 0); + } + } + return order; +} + +SmallVector getOrder(Attribute layout) { + if (auto blockedLayout = dyn_cast(layout)) { + return SmallVector(blockedLayout.getOrder().begin(), + blockedLayout.getOrder().end()); + } else if (auto mmaLayout = dyn_cast(layout)) { + auto distributedLayout = cast(layout); + auto rank = distributedLayout.getWarpsPerCTA().size(); + SmallVector order(rank); + for (auto i = 0; i < rank; ++i) + order[i] = rank - 1 - i; + if (auto mfmaLayout = dyn_cast(layout)) { + if (mfmaLayout.getIsTransposed()) { + std::swap(order[rank - 2], order[rank - 1]); + } + } + return order; + } else if (auto dotLayout = dyn_cast(layout)) { + auto rank = getWarpsPerCTA(dotLayout.getParent()).size(); + SmallVector order(rank); + for (auto i = 0; i < rank; ++i) + order[i] = rank - 1 - i; + return order; + } else if (auto sliceLayout = dyn_cast(layout)) { + SmallVector parentOrder = getOrder(sliceLayout.getParent()); + unsigned dim = sliceLayout.getDim(); + SmallVector order; + for (unsigned d : parentOrder) { + if (d == dim) + continue; + else if (d > dim) + order.push_back(d - 1); + else + order.push_back(d); + } + return order; + } else if (auto sharedLayout = mlir::dyn_cast(layout)) { + return SmallVector(sharedLayout.getOrder().begin(), + sharedLayout.getOrder().end()); + } else { + llvm::report_fatal_error("Unimplemented usage of getOrder"); + } + return {}; +}; + +CTALayoutAttr getCTALayout(Attribute layout) { + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return CTALayoutAttr::get( + layout.getContext(), getCTAsPerCGA(distributedLayout), + getCTASplitNum(distributedLayout), getCTAOrder(distributedLayout)); + } else if (auto sharedLayout = mlir::dyn_cast(layout)) + return sharedLayout.getCTALayout(); + else + llvm::report_fatal_error("Unimplemented usage of getCTALayout"); + return {}; +} + +SmallVector getCTAsPerCGA(Attribute layout) { + ArrayRef ref; + if (auto distributedLayout = mlir::dyn_cast(layout)) + return distributedLayout.getCTAsPerCGA(); + else if (mlir::isa(layout)) + return {1, 1}; + else if (auto sharedLayout = mlir::dyn_cast(layout)) + ref = sharedLayout.getCTALayout().getCTAsPerCGA(); + else + llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA"); + return SmallVector(ref.begin(), ref.end()); +} + +SmallVector getCTASplitNum(Attribute layout) { + SmallVector res; + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return distributedLayout.getCTASplitNum(); + } else if (mlir::isa(layout)) { + res.resize(2); + res[0] = res[1] = 1; + } else if (auto sharedLayout = mlir::dyn_cast(layout)) { + res.assign(sharedLayout.getCTALayout().getCTASplitNum().begin(), + sharedLayout.getCTALayout().getCTASplitNum().end()); + } else { + assert(false && "Unimplemented usage of getCTASplitNum"); + } + return res; +} + +SmallVector getCTAOrder(Attribute layout) { + SmallVector res; + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + res = distributedLayout.getCTAOrder(); + } else if (mlir::isa(layout)) { + return {0, 1}; + } else if (auto sharedLayout = mlir::dyn_cast(layout)) { + res = SmallVector(sharedLayout.getCTALayout().getCTAOrder()); + } else { + llvm::report_fatal_error("Unimplemented usage of getCTAOrder"); + } + return res; +} + +SmallVector getShapePerCTA(ArrayRef CTASplitNum, + ArrayRef shape) { + unsigned rank = shape.size(); + SmallVector shapePerCTA(rank); + for (unsigned i = 0; i < rank; ++i) { + // This wrapping rule must be consistent with emitCTAOffsetForLayout + unsigned splitNum = std::min(shape[i], CTASplitNum[i]); + shapePerCTA[i] = shape[i] / splitNum; + } + return shapePerCTA; +} + +SmallVector getShapePerCTA(Attribute layout, ArrayRef shape) { + if (auto sharedLayout = mlir::dyn_cast(layout)) { + // Special logic for pipeline pass, where shape is 3D and CTALayout is 2D. + // The first dim of shape is numStages. This is a work around, otherwise too + // many places would have to be modified in pipeline pass. Maybe we need to + // refactor this logic in the future. + auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum(); + if (shape.size() == CTASplitNum.size() + 1) { + auto res = getShapePerCTA(CTASplitNum, shape.drop_front()); + res.insert(res.begin(), shape.front()); + return res; + } + } + return getShapePerCTA(getCTASplitNum(layout), shape); +} + +SmallVector getShapePerCTA(Type type) { + auto tensorType = cast(type); + return getShapePerCTA(tensorType.getEncoding(), tensorType.getShape()); +} + +unsigned getNumWarpsPerCTA(Attribute layout) { + SmallVector warpsPerCTA; + if (auto blockedLayout = dyn_cast(layout)) + warpsPerCTA = blockedLayout.getWarpsPerCTA(); + else if (auto sliceLayout = dyn_cast(layout)) + return getNumWarpsPerCTA(sliceLayout.getParent()); + else if (auto mmaLayout = dyn_cast(layout)) { + // Use the distributed layout interface to get the number of warps per CTA. + auto distributedLayout = cast(layout); + warpsPerCTA = distributedLayout.getWarpsPerCTA(); + } else if (auto mfmaLayout = dyn_cast(layout)) + warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + else if (auto wmmaLayout = dyn_cast(layout)) + warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + else if (auto dotLayout = dyn_cast(layout)) + return getNumWarpsPerCTA(dotLayout.getParent()); + else if (auto sharedLayout = dyn_cast(layout)) + llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr"); + else + llvm::report_fatal_error("Unimplemented usage of getNumWarpsPerCTA"); + return product(warpsPerCTA); +} + +unsigned getNumCTAs(Attribute layout) { + return product(getCTAsPerCGA(layout)); +} + +bool isaDistributedLayout(Attribute layout) { + return isa(layout); +} + +template bool hasEncoding(Value value) { + auto type = value.getType(); + if (auto tensorType = dyn_cast(type)) { + auto encoding = tensorType.getEncoding(); + return encoding && isa(encoding); + } + return false; +} + +bool hasDotOperandEncoding(Value value) { + return hasEncoding(value); +} + +bool isExpensiveCat(CatOp cat, Attribute targetEncoding) { + // If the new elements per thread is less than the old one, we will need to do + // convert encoding that goes through shared memory anyway. So we consider it + // as expensive. + RankedTensorType tensorTy = cat.getType(); + auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy); + auto shape = tensorTy.getShape(); + auto elemTy = tensorTy.getElementType(); + auto newTotalElemsPerThread = + gpu::getTotalElemsPerThread(targetEncoding, shape, elemTy); + return newTotalElemsPerThread < totalElemsPerThread; +} + +LogicalResult CTALayoutAttr::verify( + function_ref emitError, ArrayRef CTAsPerCGA, + ArrayRef CTASplitNum, ArrayRef CTAOrder) { + if (CTAsPerCGA.size() != CTASplitNum.size() || + CTASplitNum.size() != CTAOrder.size()) { + return emitError() << "CTAsPerCGA, CTASplitNum, and CTAOrder must all have " + "the same rank."; + } + + if (!isPermutationOfIota(CTAOrder)) { + return emitError() + << "CTAOrder must be a permutation of 0..(rank-1), but was [" + << CTAOrder << "]"; + } + + return success(); +} + +LogicalResult +BlockedEncodingAttr::verify(function_ref emitError, + ArrayRef sizePerThread, + ArrayRef threadsPerWarp, + ArrayRef warpsPerCTA, + ArrayRef order, CTALayoutAttr CTALayout) { + if (sizePerThread.size() != threadsPerWarp.size() || + threadsPerWarp.size() != warpsPerCTA.size() || + warpsPerCTA.size() != order.size()) { + return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and " + "order must all have the same rank."; + } + + // Empty CTALayout is allowed, but if it's present its rank must match the + // BlockedEncodingAttr's rank. + if (CTALayout.getCTASplitNum().size() != 0 && + sizePerThread.size() != CTALayout.getCTASplitNum().size()) { + return emitError() << "BlockedEncodingAttr and CTALayout's fields must " + "have the same rank."; + } + if (!isPermutationOfIota(order)) { + return emitError() + << "order must be a permutation of 0..(rank-1), but was [" << order + << "]"; + } + return success(); +} + +// 1 element per thread +// order = reverse(arange(rank)) +triton::gpu::BlockedEncodingAttr +getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, + int numWarps, int threadsPerWarp, int numCTAs) { + int rank = shape.size(); + llvm::SmallVector order(rank); + std::iota(order.begin(), order.end(), 0); + std::reverse(order.begin(), order.end()); + llvm::SmallVector sizePerThread(rank, 1); + triton::gpu::BlockedEncodingAttr encoding = + triton::gpu::BlockedEncodingAttr::get(context, shape, sizePerThread, + order, numWarps, threadsPerWarp, + numCTAs); + return encoding; +} + +} // namespace gpu +} // namespace triton +} // namespace mlir + +static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr, + unsigned &value, StringRef desc) { + auto intAttr = mlir::dyn_cast(attr); + if (!intAttr) { + parser.emitError(parser.getNameLoc(), "expected an integer type in ") + << desc; + return failure(); + } + if (intAttr.getType().isSignedInteger()) { + int64_t attrVal = intAttr.getSInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else if (intAttr.getType().isSignlessInteger()) { + int64_t attrVal = intAttr.getInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else { + value = intAttr.getUInt(); + } + return success(); +} + +static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr, + bool &value, StringRef desc) { + auto boolAttr = mlir::dyn_cast(attr); + if (!boolAttr) { + parser.emitError(parser.getNameLoc(), "expected an bool type in ") << desc; + return failure(); + } + value = boolAttr.getValue(); + return success(); +} + +// parse an array of integers +static LogicalResult parseIntArrayAttr(AsmParser &parser, + const NamedAttribute &attr, + SmallVector &res, + StringRef desc) { + auto arrayAttr = mlir::dyn_cast(attr.getValue()); + if (!arrayAttr) { + parser.emitError(parser.getNameLoc(), "expected an array for ") << desc; + return failure(); + } + for (Attribute i : arrayAttr) { + unsigned value; + if (parseIntAttrValue(parser, i, value, desc).failed()) + return failure(); + res.push_back(value); + } + return success(); +}; + +static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr, + unsigned &value, StringRef desc) { + return parseIntAttrValue(parser, attr.getValue(), value, desc); +}; + +static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr, + bool &value, StringRef desc) { + return parseBoolAttrValue(parser, attr.getValue(), value, desc); +}; + +// Print the CTALayout if it's not equal to the default. +static void maybePrintCTALayout(mlir::MLIRContext *context, + mlir::AsmPrinter &printer, CTALayoutAttr layout, + unsigned rank) { + if (layout != CTALayoutAttr::getDefault(context, rank)) { + printer << ", CTAsPerCGA = [" << ArrayRef(layout.getCTAsPerCGA()) << "]" + << ", CTASplitNum = [" << ArrayRef(layout.getCTASplitNum()) << "]" + << ", CTAOrder = [" << ArrayRef(layout.getCTAOrder()) << "]"; + } +} + +//===----------------------------------------------------------------------===// +// Attribute methods +//===----------------------------------------------------------------------===// +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" + +SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) { + return SliceEncodingAttr::get(getContext(), axis, *this); +} +SmallVector +BlockedEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + size_t rank = shape.size(); + auto sizePerThread = getSizePerThread(); + auto warpsPerCTA = getWarpsPerCTA(); + auto threadsPerWarp = getThreadsPerWarp(); + auto shapePerCTA = getShapePerCTA(*this, shape); + assert(rank == sizePerThread.size() && + "unexpected rank in BlockedEncodingAttr::getElemsPerThread"); + SmallVector elemsPerThread(rank); + for (size_t i = 0; i < rank; ++i) { + unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i]; + elemsPerThread[i] = ceil(shapePerCTA[i], t) * sizePerThread[i]; + } + return elemsPerThread; +} +unsigned BlockedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + +// If we only had BlockedEncodingAttr, we could simply return ArrayRefs here. +// But we need to have a consistent interface with e.g. SliceEncodingAttr, which +// computes some of these fields. +SmallVector BlockedEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector BlockedEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector BlockedEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector BlockedEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector BlockedEncodingAttr::getWarpOrder() const { + return SmallVector(getOrder()); +} +SmallVector BlockedEncodingAttr::getThreadsPerWarp() const { + return SmallVector(getThreadsPerWarp__()); +} +SmallVector BlockedEncodingAttr::getThreadOrder() const { + return SmallVector(getOrder()); +} +SmallVector BlockedEncodingAttr::getSizePerThread() const { + return SmallVector(getSizePerThread__()); +} +SmallVector +BlockedEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + SmallVector shape; + for (unsigned d = 0, n = getOrder().size(); d < n; ++d) + shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] * + getWarpsPerCTA()[d]); + return shape; +} + +template +SmallVector SliceEncodingAttr::paddedShape(ArrayRef shape) const { + size_t rank = shape.size(); + unsigned dim = getDim(); + SmallVector retShape(rank + 1); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d < dim) + retShape[d] = shape[d]; + else if (d == dim) + retShape[d] = 1; + else + retShape[d] = shape[d - 1]; + } + return retShape; +} +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; + +SmallVector +SliceEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + auto parent = getParent(); + auto parentElemsPerThread = + ::getElemsPerThread(parent, paddedShape(shape), eltTy); + parentElemsPerThread.erase(parentElemsPerThread.begin() + getDim()); + return parentElemsPerThread; +} +unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} +SmallVector SliceEncodingAttr::getCTASplitNum() const { + SmallVector res = ::getCTASplitNum(getParent()); + res.erase(res.begin() + getDim()); + return res; +} +SmallVector SliceEncodingAttr::getCTAOrder() const { + auto parentCTAOrder = ::getCTAOrder(getParent()); + return eraseOrder(parentCTAOrder, getDim()); +} +SmallVector SliceEncodingAttr::getCTAsPerCGA() const { + auto parentCTAsPerCGA = ::getCTAsPerCGA(getParent()); + if (parentCTAsPerCGA[getDim()] == 1) { + parentCTAsPerCGA.erase(parentCTAsPerCGA.begin() + getDim()); + return parentCTAsPerCGA; + } + /* For getCTAsPerCGA of a slice layout, we have two choices: + * (1) Return CTAsPerCGA of its parent. This is not a perfect solution + * because the rank of the returned CTAsPerCGA does not match the rank of + * tensorShape. + * (2) Get CTAsPerCGA of its parent and erase the sliced dim. This is not a + * perfect solution because the product of the returned CTAsPerCGA might not + * match numCTAs. + * To avoid introducing inconsistencies to the shape and + * layout system, the usage of directly getting CTAsPerCGA of a slice layout + * in which the sliced dim is not 1 is banned. You should always consider + * slice layout as a special case and use getCTAsPerCGA(layout.getParent()) + * in the branch where layout is an instance of SliceEncodingAttr. This is + * inconvenient but safe. + */ + llvm::report_fatal_error( + "getCTAsPerCGA for SliceEncodingAttr is not well-defined"); +} +SmallVector SliceEncodingAttr::getWarpsPerCTA() const { + auto parent = getParent(); + auto parentWarpsPerCTA = ::getWarpsPerCTA(parent); + SmallVector warpsPerCTA = parentWarpsPerCTA; + warpsPerCTA.erase(warpsPerCTA.begin() + getDim()); + int32_t nextDim = getDim() < warpsPerCTA.size() ? getDim() : getDim() - 1; + warpsPerCTA[nextDim] *= parentWarpsPerCTA[getDim()]; + return warpsPerCTA; +} +SmallVector SliceEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector SliceEncodingAttr::getThreadsPerWarp() const { + auto parent = getParent(); + auto parentThreadsPerWarp = ::getThreadsPerWarp(parent); + SmallVector threadsPerWarp = parentThreadsPerWarp; + threadsPerWarp.erase(threadsPerWarp.begin() + getDim()); + int32_t nextDim = getDim() < threadsPerWarp.size() ? getDim() : getDim() - 1; + threadsPerWarp[nextDim] *= parentThreadsPerWarp[getDim()]; + return threadsPerWarp; +} +SmallVector SliceEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector SliceEncodingAttr::getSizePerThread() const { + auto sizePerThread = ::getSizePerThread(getParent()); + sizePerThread.erase(sizePerThread.begin() + getDim()); + return sizePerThread; +} +SmallVector +SliceEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + SmallVector shape = ::getShapePerCTATile(getParent(), tensorShape); + shape.erase(shape.begin() + getDim()); + return shape; +} + +// + +SmallVector +AMDMfmaEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + size_t rank = shape.size(); + assert((rank == 2 || rank == 3) && "Unexpected rank of mfma layout"); + + SmallVector elemsPerThread(rank); + auto nonKDim = getMDim(); + auto elemsPerThreadPerTile = (nonKDim == 16 ? 4 : 16); + if (rank == 3) + elemsPerThread[0] = ceil(shape[0], getWarpsPerCTA()[0]); + if (getIsTransposed()) { + unsigned elemsCol = + ceil(shape[rank - 1], nonKDim * getWarpsPerCTA()[rank - 1]) * + elemsPerThreadPerTile; + unsigned elemsRow = + ceil(shape[rank - 2], nonKDim * getWarpsPerCTA()[rank - 2]); + elemsPerThread[rank - 2] = elemsRow; + elemsPerThread[rank - 1] = elemsCol; + } else { + unsigned elemsCol = + ceil(shape[rank - 1], nonKDim * getWarpsPerCTA()[rank - 1]); + unsigned elemsRow = + ceil(shape[rank - 2], nonKDim * getWarpsPerCTA()[rank - 2]) * + elemsPerThreadPerTile; + elemsPerThread[rank - 2] = elemsRow; + elemsPerThread[rank - 1] = elemsCol; + } + return elemsPerThread; +} + +unsigned AMDMfmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + +// + +SmallVector +AMDWmmaEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + size_t rank = shape.size(); + assert((rank == 2 || rank == 3) && "Unexpected rank of wmma layout"); + + SmallVector elemsPerThread(rank); + auto mnkDim = getMNKDimPerWMMAInstr(); + auto elemsPerThreadPerTile = getSizePerThread(); + auto warpsPerCTA = getWarpsPerCTA(); + + if (rank == 3) + elemsPerThread[0] = ceil(shape[0], getWarpsPerCTA()[0]); + elemsPerThread[rank - 2] = + ceil(shape[rank - 2], mnkDim[0] * warpsPerCTA[rank - 2]) * + elemsPerThreadPerTile[rank - 2]; + elemsPerThread[rank - 1] = + ceil(shape[rank - 1], mnkDim[1] * warpsPerCTA[rank - 1]) * + elemsPerThreadPerTile[rank - 1]; + return elemsPerThread; +} + +unsigned AMDWmmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + +// + +SmallVector +NvidiaMmaEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + size_t rank = shape.size(); + assert(rank == 2 || + (rank == 3 && isAmpere()) && "Unexpected rank of mma layout"); + assert((isVolta() || isAmpere() || isHopper()) && + "For NvidiaMmaEncodingAttr only version 1~3 is supported"); + + auto shapePerCTA = getShapePerCTA(getCTALayout().getCTASplitNum(), shape); + + SmallVector elemsPerThread(rank); + if (isVolta()) { + auto [isARow, isBRow, isAVec4, isBVec4, id] = decodeVoltaLayoutStates(); + static constexpr std::array fpw{{2, 2}}; + unsigned packSize0 = (isARow || isAVec4) ? 1 : 2; + unsigned packSize1 = (isBRow && !isBVec4) ? 2 : 1; + unsigned repM = 2 * packSize0; + unsigned repN = 2 * packSize1; + unsigned spwM = fpw[0] * 4 * repM; + unsigned spwN = fpw[1] * 4 * repN; + unsigned wptM = getWarpsPerCTA()[0]; + unsigned wptN = getWarpsPerCTA()[1]; + unsigned resM = repM * std::max(1, shapePerCTA[0] / (spwM * wptM)); + unsigned resN = 2 * repN * std::max(1, shapePerCTA[1] / (spwN * wptN)); + elemsPerThread[0] = resM; + elemsPerThread[1] = resN; + } else if (isAmpere()) { + unsigned elemsRow = + ceil(shapePerCTA[rank - 2], 16 * getWarpsPerCTA()[rank - 2]) * + 2; + unsigned elemsCol = + ceil(shapePerCTA[rank - 1], 8 * getWarpsPerCTA()[rank - 1]) * + 2; + if (rank == 3) + elemsPerThread[0] = ceil(shapePerCTA[0], getWarpsPerCTA()[0]); + elemsPerThread[rank - 2] = elemsRow; + elemsPerThread[rank - 1] = elemsCol; + } else if (isHopper()) { + auto wpt = getWarpsPerCTA(); + auto instrMNK = getInstrShape(); + int repM = ceil(shapePerCTA[0], instrMNK[0] * wpt[0]); + int repN = ceil(shapePerCTA[1], instrMNK[1] * wpt[1]); + elemsPerThread[0] = 2 * repM; + elemsPerThread[1] = (instrMNK[1] / 4) * repN; + } else { + llvm_unreachable("Unexpected mma version"); + } + + return elemsPerThread; +} + +unsigned NvidiaMmaEncodingAttr::getElemsPerThreadOfOperand( + int opIdx, ArrayRef shape) const { + size_t rank = shape.size(); + assert(rank == 2 && "Unexpected rank of mma layout"); + auto shapePerCTA = getShapePerCTA(*this, shape); + int res = 0; + if (isVolta()) { + llvm_unreachable( + "getElemsPerThreadOfOperand() not supported for version 1"); + } else if (isAmpere()) { + llvm_unreachable( + "getElemsPerThreadOfOperand() not supported for version 2"); + } else if (isHopper()) { + auto wpt = getWarpsPerCTA(); + auto instrMNK = getInstrShape(); + if (opIdx == 0) { + int repM = ceil(shapePerCTA[0], instrMNK[0] * wpt[0]); + int repK = ceil(shapePerCTA[1], instrMNK[2]); + return 8 * repM * repK; + + } else if (opIdx == 1) { + int repK = ceil(shapePerCTA[0], instrMNK[2]); + int repN = ceil(shapePerCTA[1], instrMNK[1] * wpt[1]); + // benzh@ here need more check + return 4 * std::max(instrMNK[1] / 32, 1) * repK * repN; + } + } + return res; +} + +unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + +// + +SmallVector +SharedEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + llvm_unreachable("getElemsPerThread is not supported for shared layout"); + return SmallVector(); +} +unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + llvm_unreachable("getElemsPerThread is not supported for shared layout"); + return 0; +} + +SmallVector +DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + llvm_unreachable("getElemsPerThread is not supported for dot operand"); + return SmallVector(); +} + +unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + if (auto mmaParent = mlir::dyn_cast(getParent())) { + return mmaParent.getTotalElemsPerThreadForOperands(shape, eltTy, + getKWidth(), getOpIdx()); + } + if (auto blockedLayout = mlir::dyn_cast(getParent())) { + auto shapePerCTA = getShapePerCTA(*this, shape); + auto shapePerCTATile = ::getShapePerCTATile(blockedLayout); + auto order = blockedLayout.getOrder(); + auto sizePerThread = ::getSizePerThread(blockedLayout); + + int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0]; + int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0]; + + bool isM = getOpIdx() == 0; + + int mSizePerThread = + order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int nSizePerThread = + order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int sizePerThreadMN = isM ? mSizePerThread : nSizePerThread; + + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int shapePerCTAMNTile = isM ? mShapePerCTATile : nShapePerCTATile; + + return K * std::max(otherDim / shapePerCTAMNTile, 1) * sizePerThreadMN; + } + llvm_unreachable("unknown dot operand parent layout"); + return 0; +} +SmallVector DotOperandEncodingAttr::getCTAsPerCGA() const { + return ::getCTAsPerCGA(getParent()); +} +SmallVector DotOperandEncodingAttr::getCTAOrder() const { + return ::getCTAOrder(getParent()); +} +SmallVector DotOperandEncodingAttr::getCTASplitNum() const { + SmallVector res = ::getCTASplitNum(getParent()); + auto rank = res.size(); + assert(rank == 2 || rank == 3 && "Invalid dotLayout"); + + // Do not split CTA in K dimension + getOpIdx() == 0 ? res[rank - 1] = 1 : res[rank - 2] = 1; + return res; +} +SmallVector DotOperandEncodingAttr::getWarpsPerCTA() const { + auto parentLayout = getParent(); + assert(parentLayout && "DotOperandEncodingAttr must have a parent"); + if (auto distributedLayout = + mlir::dyn_cast(parentLayout)) { + return distributedLayout.getWarpsPerCTA(); + } else { + llvm::report_fatal_error( + "DotOperandEncodingAttr non-DistributedEncodingAttr parent not " + "supported yet"); + } +} +SmallVector DotOperandEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector DotOperandEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector DotOperandEncodingAttr::getShapePerCTATile( + ArrayRef tensorShape) const { + auto parentLayout = getParent(); + assert(parentLayout && "DotOperandEncodingAttr must have a parent"); + if (auto parentMmaLayout = mlir::dyn_cast(parentLayout)) { + return parentMmaLayout.getShapePerCTATileForDotOperands(tensorShape, + getOpIdx()); + } else { + llvm::report_fatal_error( + "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not " + "supported yet"); + } +} + +LogicalResult DotOperandEncodingAttr::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + unsigned opIdx, Attribute parent, unsigned kWidth) { + if (opIdx != 0 && opIdx != 1) { + return emitError() + << "triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: " + << opIdx; + } + if (!parent) { + return emitError() << "triton_gpu.dot_op parent paramenter cannot be null"; + } + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0 && !parentAttr.isAmpere()) + return emitError() << "triton_gpu.dot_op kWidth parameter can only be " + "non-zero for Ampere MMA parent"; + if (kWidth == 0 && parentAttr.isAmpere()) + return emitError() + << "triton_gpu.dot_op kWidth parameter is mandatory for " + "Ampere MMA parent"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + // TODO: remove this condition if new values are supported + if (kWidth != 16) + return emitError() << "triton_gpu.dot_op kWidth parameter supports " + "only 16 for WMMA parent"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth == 0) + return emitError() + << "triton_gpu.dot_op kWidth parameter is mandatory for " + "MFMA parent"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0) + return emitError() + << "triton_gpu.dot_op kWidth parameter is not supported " + "when the parent is a blocked layout"; + return success(); + } + + return emitError() << "triton_gpu.dot_op unexpected parent layout: " + << parent; +} + +//===----------------------------------------------------------------------===// +// Blocked Encoding +//===----------------------------------------------------------------------===// + +static std::optional getCTALayoutOrError( + AsmParser &parser, std::optional> CTAsPerCGA, + std::optional> CTASplitNum, + std::optional> CTAOrder, unsigned rank) { + if (CTAsPerCGA && CTASplitNum && CTAOrder) { + return CTALayoutAttr::get(parser.getContext(), *CTAsPerCGA, *CTASplitNum, + *CTAOrder); + } + if (!CTAsPerCGA && !CTASplitNum && !CTAOrder) { + return CTALayoutAttr::getDefault(parser.getContext(), rank); + } + parser.emitError(parser.getNameLoc(), "CTAsPerCGA, CTASplitNum, and CTAOrder " + "must all be present or all be absent"); + return std::nullopt; +} + +Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + SmallVector sizePerThread; + SmallVector threadsPerWarp; + SmallVector warpsPerCTA; + SmallVector order; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "sizePerThread") { + if (parseIntArrayAttr(parser, attr, sizePerThread, + "number of elements per thread") + .failed()) + return {}; + } else if (attr.getName() == "threadsPerWarp") { + if (parseIntArrayAttr(parser, attr, threadsPerWarp, + "number of threads per warp") + .failed()) + return {}; + } else if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, + "number of warps per CTA") + .failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/sizePerThread.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked(parser.getContext(), + sizePerThread, threadsPerWarp, + warpsPerCTA, order, *CTALayout); +} + +void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "sizePerThread = [" << ArrayRef(getSizePerThread()) << "]" + << ", threadsPerWarp = [" << ArrayRef(getThreadsPerWarp()) << "]" + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]" + << ", order = [" << getOrder() << "]"; + + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getSizePerThread().size()); + + printer << "}>"; +} + +//===----------------------------------------------------------------------===// +// MMA encoding +//===----------------------------------------------------------------------===// + +Attribute NvidiaMmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + SmallVector instrShape; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { + if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) + return {}; + } + if (attr.getName() == "versionMinor") { + if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) { + return {}; + } + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, *CTALayout, + instrShape); +} + +void NvidiaMmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() + << ", versionMinor = " << getVersionMinor() // + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]"; + + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + + printer << ", instrShape = [" << getInstrShape() << "]}>"; +} + +//===----------------------------------------------------------------------===// +// MFMA encoding +//===----------------------------------------------------------------------===// + +Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + SmallVector instrShape; + bool isTransposed; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { + if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) + return {}; + } + if (attr.getName() == "versionMinor") { + if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) + return {}; + } + if (attr.getName() == "isTransposed") { + if (parseBool(parser, attr, isTransposed, "isTransposed").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, + instrShape[0], instrShape[1], isTransposed, *CTALayout); +} + +void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() // + << ", versionMinor = " << getVersionMinor() // + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]" // + << ", instrShape = [" << ArrayRef{getMDim(), getNDim()} << "]" // + << ", isTransposed = " << getIsTransposed(); + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + printer << "}>"; +} + +LogicalResult +AMDMfmaEncodingAttr::verify(function_ref emitError, + unsigned versionMajor, unsigned versionMinor, + llvm::ArrayRef warpsPerCTA, + unsigned mDim, unsigned nDim, bool isTransposed, + mlir::triton::gpu::CTALayoutAttr) { + if (!(versionMajor >= 0 && versionMajor <= 3)) { + return emitError() << "major version must be in the [0, 3] range"; + } + if (versionMinor != 0) { + return emitError() << "minor version must be 0"; + } + if (!((mDim == 32 && nDim == 32) || (mDim == 16 && nDim == 16))) { + return emitError() + << "(M, N) cases other than (32, 32) or (16, 16) unimplemented"; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// WMMA encoding +//===----------------------------------------------------------------------===// + +Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + SmallVector warpsPerCTA; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked(parser.getContext(), + warpsPerCTA, *CTALayout); +} + +void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]"; + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + printer << "}>"; +} + +//===----------------------------------------------------------------------===// +// Sliced Encoding +//===----------------------------------------------------------------------===// + +Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + unsigned dim = mlir::cast(attrs.get("dim")).getInt(); + Attribute parent = attrs.get("parent"); + return parser.getChecked(parser.getContext(), dim, parent); +} + +void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "dim = " << getDim() << ", " + << "parent = " << getParent() << "}>"; +} + +//===----------------------------------------------------------------------===// +// Shared encoding +//===----------------------------------------------------------------------===// + +Attribute SharedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned vec = 0; + unsigned perPhase = 0; + unsigned maxPhase = 0; + SmallVector order; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + bool hasLeadingOffset = false; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "vec") { + if (parseUInt(parser, attr, vec, "vec").failed()) + return {}; + } else if (attr.getName() == "perPhase") { + if (parseUInt(parser, attr, perPhase, "perPhase").failed()) + return {}; + } else if (attr.getName() == "maxPhase") { + if (parseUInt(parser, attr, maxPhase, "maxPhase").failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } else if (attr.getName() == "hasLeadingOffset") { + if (parseBool(parser, attr, hasLeadingOffset, "hasLeadingOffset") + .failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/order.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked(parser.getContext(), vec, + perPhase, maxPhase, order, + *CTALayout, hasLeadingOffset); +} + +void SharedEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "vec = " << getVec() // + << ", perPhase = " << getPerPhase() + << ", maxPhase = " << getMaxPhase() // + << ", order = [" << getOrder() << "]"; + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getOrder().size()); + printer << ", hasLeadingOffset = " << getHasLeadingOffset() << "}>"; +} + +//===----------------------------------------------------------------------===// +// Mfma encoding +//===----------------------------------------------------------------------===// +// TODO: there is a lot of common code with MmaEncoding here + +SmallVector +AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = warpsPerCTA.size(); + SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); + shapePerCTATile[rank - 1] *= getMDim(); + shapePerCTATile[rank - 2] *= getNDim(); + return shapePerCTATile; +} + +SmallVector AMDMfmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector AMDMfmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector AMDMfmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector AMDMfmaEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector AMDMfmaEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector AMDMfmaEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector AMDMfmaEncodingAttr::getThreadsPerWarp() const { + unsigned rows, cols; + auto rank = ::getOrder(*this).size(); + SmallVector res(rank, 1); + if (getMDim() == 32) { + cols = 2; + rows = 32; + } else { + assert(getMDim() == 16); + cols = 4; + rows = 16; + } + if (getIsTransposed()) { + res[rank - 1] = cols; + res[rank - 2] = rows; + } else { + res[rank - 1] = rows; + res[rank - 2] = cols; + } + return res; +} + +SmallVector AMDMfmaEncodingAttr::getSizePerThread() const { + unsigned rows, cols; + auto rank = ::getOrder(*this).size(); + SmallVector res(rank, 1); + if (getMDim() == 32) { + rows = 16; + cols = 1; + } else if (getMDim() == 16) { + rows = 4; + cols = 1; + } else + llvm_unreachable("Unexpected mfma non-k dim"); + + if (getIsTransposed()) { + res[rank - 1] = rows; + res[rank - 2] = cols; + } else { + res[rank - 1] = cols; + res[rank - 2] = rows; + } + return res; +} + +SmallVector +AMDMfmaEncodingAttr::getMFMAInstrShapeForOperands(int kWidth, int opIdx) const { + unsigned mDim = getMDim(); + unsigned nDim = getNDim(); + assert((mDim == nDim) && (mDim == 32 || mDim == 16 || mDim == 4) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + constexpr int waveSize = 64; // MFMA is used on wave64 architectures only + int kGroups = -1; + if (mDim == nDim) + kGroups = waveSize / mDim; + if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64) + kGroups = 1; + int64_t kDim = kWidth * kGroups; + if (opIdx == 0) + return {mDim, kDim}; + else + assert(opIdx == 1); + return {kDim, nDim}; +} + +SmallVector +AMDMfmaEncodingAttr::getMFMARepForOperands(ArrayRef operandShape, + int kWidth, int opIdx) const { + auto operandTileShape = getMFMAInstrShapeForOperands(kWidth, opIdx); + auto rank = operandShape.size(); + auto warpsPerCTA = getWarpsPerCTA(); + int numRepBatch = + rank == 3 ? std::max(1, operandShape[0] / warpsPerCTA[0]) : 1; + if (opIdx == 0) + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / + (operandTileShape[0] * warpsPerCTA[rank - 2])), + std::max(1, operandShape[rank - 1] / operandTileShape[1])}; + else { + assert(opIdx == 1); + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / operandTileShape[0]), + std::max(1, operandShape[rank - 1] / (operandTileShape[1] * + warpsPerCTA[rank - 1]))}; + } +} + +unsigned AMDMfmaEncodingAttr::getTotalElemsPerThreadForOperands( + ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { + auto rep = getMFMARepForOperands(shape, kWidth, opIdx); + return product(rep) * kWidth; +} + +SmallVector +AMDMfmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { + if (opIdx == 0) { + return {4, 1}; + } else if (opIdx == 1) { + return {1, 4}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + return {}; + } +} + +SmallVector +AMDMfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, + int opIdx) const { + assert(getMDim() == 32 || getMDim() == 16); + auto parentShapePerCTATile = getShapePerCTATile(shape); + auto rank = parentShapePerCTATile.size(); + if (opIdx == 0) { + if (rank == 2) + return {parentShapePerCTATile[rank - 2], 32}; + else + return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32}; + } else if (opIdx == 1) { + if (rank == 2) + return {32, parentShapePerCTATile[rank - 1]}; + else + return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + } + llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1"); +} + +SmallVector +AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = warpsPerCTA.size(); + SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); + + auto mnkDim = getMNKDimPerWMMAInstr(); + shapePerCTATile[rank - 2] *= mnkDim[0]; + shapePerCTATile[rank - 1] *= mnkDim[1]; + return shapePerCTATile; +} +SmallVector AMDWmmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector AMDWmmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector AMDWmmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector AMDWmmaEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector AMDWmmaEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector AMDWmmaEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector AMDWmmaEncodingAttr::getThreadsPerWarp() const { + auto rank = getWarpsPerCTA().size(); + SmallVector threads(rank, 1); + auto mnkInstr = getMNKDimPerWMMAInstr(); + threads[rank - 2] = mnkInstr[0] / getSizePerThread()[rank - 2]; + threads[rank - 1] = mnkInstr[1] / getSizePerThread()[rank - 1]; + return threads; +} + +SmallVector AMDWmmaEncodingAttr::getSizePerThread() const { + auto rank = getWarpsPerCTA().size(); + SmallVector sizePerThread(rank, 1); + sizePerThread[rank - 2] = 8; + sizePerThread[rank - 1] = 1; + return sizePerThread; +} +SmallVector +AMDWmmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { + auto rank = getWarpsPerCTA().size(); + SmallVector sizePerThread(rank, 1); + if (opIdx == 0) { + sizePerThread[rank - 2] = 1; + sizePerThread[rank - 1] = 16; + } else if (opIdx == 1) { + sizePerThread[rank - 2] = 16; + sizePerThread[rank - 1] = 1; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + } + return sizePerThread; +} + +SmallVector +AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, + int opIdx) const { + auto parentShapePerCTA = getShapePerCTATile(shape); + auto rank = shape.size(); + assert(rank = 2); + if (opIdx == 0) { + return {parentShapePerCTA[0], static_cast(shape[1])}; + } else if (opIdx == 1) { + return {static_cast(shape[0]), parentShapePerCTA[1]}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + } +} + +unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperands( + ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { + auto rep = getWMMARepForOperands(shape, eltTy, kWidth, opIdx); + return product(rep) * kWidth; +} + +SmallVector +AMDWmmaEncodingAttr::getWMMAElemsPerInstrForOperands() const { + return {16, 16}; +} + +SmallVector +AMDWmmaEncodingAttr::getWMMARepForOperands(ArrayRef operandShape, + Type elemType, int kWidth, + int opIdx) const { + auto operandTileShape = getWMMAElemsPerInstrForOperands(); + assert(operandTileShape.size() == 2); + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = operandShape.size(); + assert(rank == 2 || rank == 3); + int numRepBatch = + rank == 3 ? std::max(1, operandShape[0] / warpsPerCTA[0]) : 1; + if (opIdx == 0) + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / + (operandTileShape[0] * warpsPerCTA[rank - 2])), + std::max(1, operandShape[rank - 1] / operandTileShape[1])}; + else { + assert(opIdx == 1); + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / operandTileShape[0]), + std::max(1, operandShape[rank - 1] / (operandTileShape[1] * + warpsPerCTA[rank - 1]))}; + } +} + +SmallVector AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr() { + // TODO: move magic numbers out of the code + return {16, 16, 16}; +} + +//===----------------------------------------------------------------------===// +// Mma encoding +//===----------------------------------------------------------------------===// + +bool NvidiaMmaEncodingAttr::isVolta() const { return getVersionMajor() == 1; } + +bool NvidiaMmaEncodingAttr::isTuring() const { + return getVersionMajor() == 2 && getVersionMinor() == 1; +} + +bool NvidiaMmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; } + +bool NvidiaMmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; } + +SmallVector NvidiaMmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector NvidiaMmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector NvidiaMmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector NvidiaMmaEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector NvidiaMmaEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector NvidiaMmaEncodingAttr::getThreadsPerWarp() const { + auto rank = getWarpsPerCTA().size(); + SmallVector res(rank, 1); + if (isVolta()) { + res[rank - 2] = 4; + res[rank - 1] = 8; + return res; + } + if (isAmpere()) { + res[rank - 2] = 8; + res[rank - 1] = 4; + return res; + } + if (isHopper()) { + res[rank - 2] = 8; + res[rank - 1] = 4; + return res; + } + llvm::report_fatal_error( + "getThreadsPerWarp not implemented for unknown Mma version "); +} +SmallVector NvidiaMmaEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector NvidiaMmaEncodingAttr::getSizePerThread() const { + auto rank = ::getOrder(*this).size(); + SmallVector res(rank, 1); + if (isAmpere()) { + res[rank - 2] = 2; + res[rank - 1] = 2; + return res; + } + if (isVolta()) { + res[rank - 2] = 1; + res[rank - 1] = 2; + return res; + } + if (isHopper()) { + auto instrShape = getInstrShape(); + // WGMMA instructions have an order of [0, 1] with 4 warps, each with 8 + // unique thread ids (32 in a warp group) per column. It is 1 warp wide with + // 4 unique thread ids in the row. So the size per thread is the instruction + // size divided by the number of unique thread ids. + return SmallVector{instrShape[0] * 4 / 32, instrShape[1] / 4}; + } + llvm_unreachable("Unexpected mma version"); +} + +SmallVector +NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + if (isAmpere()) { + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = warpsPerCTA.size(); + SmallVector shapePerCTATile(warpsPerCTA.begin(), + warpsPerCTA.end()); + shapePerCTATile[rank - 1] *= 8; + shapePerCTATile[rank - 2] *= 16; + return shapePerCTATile; + } + if (isVolta()) { + assert(!tensorShape.empty() && "Volta needs the tensorShape"); + if (tensorShape.size() == 1) // must be SliceEncoding + return {static_cast(tensorShape[0]), + static_cast(tensorShape[0])}; + return {static_cast(tensorShape[0]), + static_cast(tensorShape[1])}; + } + if (isHopper()) { + auto instrShape = getInstrShape(); + return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]}; + } + llvm::report_fatal_error("Unexpected MMA layout version found"); +} + +// Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor +std::tuple +NvidiaMmaEncodingAttr::decodeVoltaLayoutStates() const { + unsigned versionMinor = getVersionMinor(); + bool isARow = versionMinor & (1 << 0); + bool isBRow = versionMinor & (1 << 1); + bool isAVec4 = versionMinor & (1 << 2); + bool isBVec4 = versionMinor & (1 << 3); + + int id = 0; + for (int i = numBitsToHoldMmaV1ID - 1; i >= 0; --i) + id = (id << 1) + static_cast(versionMinor & (1 << (4 + i))); + + return std::make_tuple(isARow, isBRow, isAVec4, isBVec4, id); +} + +bool NvidiaMmaEncodingAttr::getMMAv1IsRow(int opIdx) const { + auto [isARow, isBRow, _0, _1, _2] = decodeVoltaLayoutStates(); + return opIdx == 0 ? isARow : isBRow; +} +bool NvidiaMmaEncodingAttr::getMMAv1IsVec4(int opIdx) const { + auto [_0, _1, isAVec4, isBVec4, _2] = decodeVoltaLayoutStates(); + return opIdx == 0 ? isAVec4 : isBVec4; +} +int NvidiaMmaEncodingAttr::getMMAv1NumOuter(ArrayRef shape, + int opIdx) const { + auto spw = getMMAv1ShapePerWarp(opIdx); + auto rep = getMMAv1Rep(opIdx); + auto warpsPerCTA = getWarpsPerCTA(); + if (opIdx == 0) { + return rep[0] * shape[0] / (spw[0] * warpsPerCTA[0]); + } else { + return rep[1] * shape[1] / (spw[1] * warpsPerCTA[1]); + } +} +SmallVector NvidiaMmaEncodingAttr::getMMAv1Rep(int opIdx) const { + auto [isARow, isBRow, isAVec4, isBVec4, _] = decodeVoltaLayoutStates(); + // A + if (opIdx == 0) { + int packSize = (isARow || isAVec4) ? 1 : 2; + return {2 * packSize, 0, 1}; + } + // B + else { + int packSize = (isBRow && !isBVec4) ? 2 : 1; + return {0, 2 * packSize, 1}; + } +} +SmallVector NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const { + auto rep = getMMAv1Rep(opIdx); + if (opIdx == 0) { + return {8 * rep[0], 0, 1}; + } else { + return {0, 8 * rep[1], 1}; + } +} +int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const { + return 2 * getMMAv1Rep(opIdx)[opIdx]; +} +SmallVector NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef shape, + int bitwidth, + int opIdx) const { + auto rank = shape.size(); + auto warpsPerCTA = getWarpsPerCTA(); + SmallVector shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth}; + int numRepBatch = + rank == 3 + ? std::max(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])) + : 1; + assert(isAmpere()); + + if (opIdx == 0) + return {numRepBatch, + std::max(1, shape[rank - 2] / + (shapePerWarp[1] * warpsPerCTA[rank - 2])), + std::max(1, shape[rank - 1] / shapePerWarp[3])}; + else { + assert(opIdx == 1); + return {numRepBatch, + std::max(1, shape[rank - 2] / shapePerWarp[3]), + std::max(1, shape[rank - 1] / (shapePerWarp[2] * + warpsPerCTA[rank - 1]))}; + } +} +unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands( + ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { + auto shapePerCTA = getShapePerCTA(*this, shape); + int warpsPerCTAM = getWarpsPerCTA()[0]; + int warpsPerCTAN = getWarpsPerCTA()[1]; + // H100 + if (isHopper()) { + return getTotalElemsPerThread(shape, eltTy); + } + // A100 + if (isAmpere()) { + auto rep = getMMAv2Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth(), opIdx); + if (opIdx == 0) + return 4 * rep[0] * rep[1] * rep[2]; + if (opIdx == 1) + return 4 * rep[0] * rep[1] * std::max(rep[2] / 2, 1); + } + // V100 + if (isVolta()) { + bool isRow = getMMAv1IsRow(opIdx); + bool isVec4 = getMMAv1IsVec4(opIdx); + if (opIdx == 0) { + int packSizeM = (isRow || isVec4) ? 1 : 2; + int repM = 2 * packSizeM; + int spwM = 2 * 4 * repM; + int numM = getMMAv1NumOuter(shape, opIdx); + int NK = shape[1]; + int vec = 2 * repM; + // Here we mimic the logic in loadA, the result cannot be calculated + // directly. + llvm::DenseSet> visited; + auto ld = [&](int m, int k) { + visited.insert({m, k}); + if (vec > 4) { + if (isRow) + visited.insert({m, k + 4}); + else + visited.insert({m + 1, k}); + } + }; + for (unsigned k = 0; k < NK; k += 4) + for (unsigned m = 0; m < numM / 2; ++m) + if (!visited.count({m, k})) + ld(m, k); + return visited.size() * 2; + } + if (opIdx == 1) { + int packSizeN = (isRow && !isVec4) ? 2 : 1; + int repN = 2 * packSizeN; + int spwN = 2 * 4 * repN; + int numN = getMMAv1NumOuter(shape, opIdx); + int vec = 2 * repN; + + int NK = shape[0]; + // Here we mimic the logic in loadA, the result cannot be calculated + // directly. + llvm::DenseSet> visited; + int elemsPerLd = vec > 4 ? 4 : 2; + auto ld = [&](int n, int k) { + visited.insert({n, k}); + if (vec > 4) { + if (isRow) + visited.insert({n + 1, k}); + else + visited.insert({n, k + 4}); + } + }; + + for (unsigned k = 0; k < NK; k += 4) + for (unsigned n = 0; n < numN / 2; ++n) { + if (!visited.count({n, k})) + ld(n, k); + } + + return visited.size() * 2; + } + } + llvm_unreachable("unknown mma layout"); +} +SmallVector +NvidiaMmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, + int opIdx) const { + assert(isAmpere() && "mmaLayout version = 1 is not implemented yet"); + auto parentShapePerCTATile = getShapePerCTATile(shape); + auto rank = parentShapePerCTATile.size(); + if (opIdx == 0) { + if (rank == 2) + return {parentShapePerCTATile[rank - 2], 16}; + else + return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 16}; + } else if (opIdx == 1) { + if (rank == 2) + return {16, parentShapePerCTATile[rank - 1]}; + else + return {parentShapePerCTATile[0], 16, parentShapePerCTATile[rank - 1]}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + } +} +SmallVector +NvidiaMmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { + assert(isAmpere() && "mmaLayout version = 1 is not implemented yet"); + if (opIdx == 0) { + return {2, 4}; + } else if (opIdx == 1) { + return {4, 1}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + return {}; + } +} + +//===----------------------------------------------------------------------===// +// DotOperand Encoding +//===----------------------------------------------------------------------===// +SmallVector DotOperandEncodingAttr::getThreadsPerWarp() const { + llvm::report_fatal_error( + "getThreadsPerWarp not implemented for DotOperandEncodingAttr"); +} +SmallVector DotOperandEncodingAttr::getSizePerThread() const { + auto parentLayout = getParent(); + assert(parentLayout && "DotOperandEncodingAttr must have a parent"); + if (auto parentMmaLayout = mlir::dyn_cast(parentLayout)) { + return parentMmaLayout.getSizePerThreadForOperands(getOpIdx()); + } else { + llvm::report_fatal_error( + "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not " + "supported yet"); + return {}; + } +} + +//===----------------------------------------------------------------------===// +// ASM Interface (i.e.: alias) +//===----------------------------------------------------------------------===// + +class TritonGPUOpAsmInterface : public OpAsmDialectInterface { +public: + using OpAsmDialectInterface::OpAsmDialectInterface; + + AliasResult getAlias(Attribute attr, raw_ostream &os) const override { + if (auto mmaAttr = mlir::dyn_cast(attr)) { + os << "mma"; + return AliasResult::FinalAlias; + } else if (auto sharedAttr = mlir::dyn_cast(attr)) { + os << "shared"; + return AliasResult::FinalAlias; + } else if (auto blockedAttr = mlir::dyn_cast(attr)) { + os << "blocked"; + return AliasResult::FinalAlias; + } /* else if (auto sliceAttr = dyn_cast(attr)) { + os << "slice"; + return AliasResult::FinalAlias; + } */ + return OpAsmDialectInterface::getAlias(attr, os); + } +}; + +struct TritonGPUInferLayoutInterface + : public triton::DialectInferLayoutInterface { + using DialectInferLayoutInterface::DialectInferLayoutInterface; + + LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding) const override { + resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis, + operandEncoding); + return success(); + } + + // Infer the encoding of a tt.trans(x) given the encoding of x. + // + // Our goal is to choose an encoding so that the trans is a "nop". For + // example, in a blocked encoding, the same GPU threads hold the same + // elements, they're just "renamed" -- what was element [i,j] of the tensor is + // now element [j,i], but that element is held by the same GPU thread. + // + // For most properties of the encoding, we let + // outputEnc.prop = inputEnc.prop * trans.order, + // where `x * y` means we apply permutation y to x. + // + // This works because prop[i] tells you something about the i'th dimension of + // the tensor. (For example, sizePerThread[2] == 4 means that one GPU thread + // contains 4 elements along dim 2 of the tensor.) The transpose reorders the + // dimensions according to the perm trans.order, so we achieve our goal of + // having a "nop" transpose by reordering the values in the prop the same way. + // + // The big exception to this is the encoding's `order`. + // + // An encoding's order is a list of dimensions, from fastest moving (most + // minor) to slowest moving. Thus enc.order[i] does not tell you something + // about the i'th dimension of the tensor, and it would be disasterously + // incorrect to do enc.order * trans.order. + // + // But! If we invert enc.order, it *does* meet this criterion. For example, + // if enc.order = [2,0,1], inverse(enc.order) = [1,2,0]. If you stare at it, + // you'll see that inverse(enc.order)[i] == j means that dimension i is the + // j'th most minor. Therefore we can safely permute *this* by trans.order. + // + // Thus we have + // + // outputEnc.order = inverse(inverse(inputEnc.order) * trans.order) + // = inverse(trans.order) * inputEnc.order. + // + LogicalResult inferTransOpEncoding(Attribute operandEncoding, + ArrayRef order, // trans order + Attribute &resultEncoding) const override { + // Note: inferFooOpEncoding should not crash if given invalid inputs, which + // happens when someone creates invalid IR. If we return failure() on + // error, then MLIR will generate a helpful error message. + + auto invOrder = inversePermutation(order); + SmallVector invOrderUnsigned(invOrder.begin(), invOrder.end()); + + auto permuteCTALayout = + [&](const CTALayoutAttr &layout) -> FailureOr { + auto n = order.size(); + if (layout.getCTAsPerCGA().size() != n || + layout.getCTASplitNum().size() != n || + layout.getCTAOrder().size() != n) { + return failure(); + } + + return CTALayoutAttr::get( + getDialect()->getContext(), + applyPermutation(layout.getCTAsPerCGA(), order), + applyPermutation(layout.getCTASplitNum(), order), + applyPermutation(invOrderUnsigned, layout.getCTAOrder())); + }; + + if (auto enc = mlir::dyn_cast(operandEncoding)) { + if (enc.getOrder().size() != order.size()) { + return failure(); + } + FailureOr ctaLayout = permuteCTALayout(enc.getCTALayout()); + if (failed(ctaLayout)) { + return failure(); + } + resultEncoding = SharedEncodingAttr::get( + getDialect()->getContext(), enc.getVec(), enc.getPerPhase(), + enc.getMaxPhase(), applyPermutation(invOrderUnsigned, enc.getOrder()), + *ctaLayout, enc.getHasLeadingOffset()); + return success(); + } + + if (auto enc = mlir::dyn_cast(operandEncoding)) { + auto n = order.size(); + if (enc.getSizePerThread().size() != n || + enc.getThreadsPerWarp().size() != n || + enc.getWarpsPerCTA().size() != n || enc.getOrder().size() != n) { + return failure(); + } + FailureOr ctaLayout = permuteCTALayout(enc.getCTALayout()); + if (failed(ctaLayout)) { + return failure(); + } + resultEncoding = BlockedEncodingAttr::get( + getDialect()->getContext(), + applyPermutation(enc.getSizePerThread(), order), + applyPermutation(enc.getThreadsPerWarp(), order), + applyPermutation(enc.getWarpsPerCTA(), order), + applyPermutation(invOrderUnsigned, enc.getOrder()), *ctaLayout); + return success(); + } + + return failure(); // unhandled encoding + } + + LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional location) const override { + auto sliceEncoding = mlir::dyn_cast(operandEncoding); + if (!sliceEncoding) + return emitOptionalError( + location, "ExpandDimsOp operand encoding must be SliceEncodingAttr"); + if (sliceEncoding.getDim() != axis) + return emitOptionalError( + location, "Incompatible slice dimension for ExpandDimsOp operand"); + resultEncoding = sliceEncoding.getParent(); + return success(); + } + + LogicalResult + inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute retEncoding, + std::optional location) const override { + auto mmaRetEncoding = mlir::dyn_cast(retEncoding); + if (mmaRetEncoding && mmaRetEncoding.isHopper()) { + auto dotOpEnc = mlir::dyn_cast(operandEncoding); + if (!mlir::isa(operandEncoding) && + !(opIdx == 0 && dotOpEnc && dotOpEnc.getOpIdx() == 0 && + mlir::isa(dotOpEnc.getParent()))) { + return emitOptionalError( + location, "unexpected operand layout for NvidiaMmaEncodingAttr v3"); + } + } else if (auto dotOpEnc = + mlir::dyn_cast(operandEncoding)) { + if (opIdx != dotOpEnc.getOpIdx()) + return emitOptionalError(location, "Wrong opIdx"); + if (retEncoding != dotOpEnc.getParent()) + return emitOptionalError(location, "Incompatible parent encoding"); + } else + return emitOptionalError( + location, "Dot's a/b's encoding should be of DotOperandEncodingAttr"); + return success(); + } + + LogicalResult + verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, + Attribute operandEncodingB) const override { + auto aEncoding = + mlir::dyn_cast(operandEncodingA); + auto bEncoding = + mlir::dyn_cast(operandEncodingB); + if (!aEncoding && !bEncoding) + return mlir::success(); + auto mmaAEncoding = + mlir::dyn_cast_or_null(aEncoding.getParent()); + if (mmaAEncoding && mmaAEncoding.isHopper()) + return success(); + // Verify that the encodings are valid. + if (!aEncoding || !bEncoding) + return op->emitError("mismatching encoding between A and B operands"); + if (aEncoding.getKWidth() != bEncoding.getKWidth()) + return op->emitError("mismatching kWidth between A and B operands"); + return success(); + } + + // Given a src shape + encoding and a dst shape, our goal is to compute a dst + // encoding that makes the reshape a "nop". That is, if GPU thread [x,y,z] + // contains elements [a,b,c,d] before the reshape, it contains those same + // elements after the reshape, they're just "renamed". + // + // A dst encoding that satisfies this property does not exist for all inputs. + // Here are some positive and negative examples. + // + // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so + // dim 1 is the fastest-changing in the dst, but the src has the opposite + // order. + // - OK: 2x2x32 order=[1,0,2] -> 4x32. We choose dst order [0,1]. + // What's important is that the 2x2 dimensions appear in major-to-minor + // order. + // - NOT OK: 32x32 sizePerThread=[2,2] -> 1024. Thread 0 in the src + // contains elements [(0,0), (0,1), (1,0), and (1,1)]. We cannot express + // this with an encoding based on the dst shape. + // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will + // contain the same elements as before. + // + // Users of this function require that it is symmetrical: if + // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) => + // srcEnc. + LogicalResult + inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const override { + auto src = mlir::dyn_cast(srcEnc); + if (!src) { + return emitOptionalError( + loc, "Non-reordering reshape only supports BlockedEncoding"); + } + + // Nop reshape; we can always infer an encoding. + if (srcShape == dstShape) { + dstEnc = srcEnc; + return success(); + } + + // default -> default encoding is always a nop. + auto context = srcEnc.getContext(); + int32_t numWarps = product(src.getWarpsPerCTA()); + int32_t threadsPerWarp = product(src.getThreadsPerWarp()); + int32_t numCTAs = product(src.getCTALayout().getCTAsPerCGA()); + if (srcEnc == getDefaultBlockedEncoding(context, srcShape, numWarps, + threadsPerWarp, numCTAs)) { + dstEnc = getDefaultBlockedEncoding(context, dstShape, numWarps, + threadsPerWarp, numCTAs); + return success(); + } + + // Feature flag to disable this routine while it's relatively new. + // TODO(jlebar): Remove this once we're confident in the code. + if (triton::tools::getBoolEnv( + "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE")) { + return failure(); + } + + // Cowardly refuse to handle encodings with multiple CTAs. CTAsPerCGA + // should be like the other fields in blocked encoding, but I'm not sure how + // to handle CTASplitNum. + if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) || + !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) { + return emitOptionalError( + loc, "Non-reordering reshape does not currently support multi-CTA " + "layouts other than the default layout."); + } + + // Cowardly refuse to handle encodings where shape[dim] is not divisible by + // sizePerThread[dim], threadsPerWarp[dim], and warpsPerCTA[dim]. (We make + // an exception if the block is larger than the shape.) + auto checkDivisibility = [&](StringRef name, ArrayRef subblock) { + for (int dim = 0; dim < srcShape.size(); dim++) { + if (srcShape[dim] >= subblock[dim] && + srcShape[dim] % subblock[dim] != 0) { + return emitOptionalError(loc, + "Can't do a non-reordering reshape because " + "the size of dimension ", + dim, " (", srcShape[dim], ")", + " is not divisible by ", name, "[", dim, "]", + " = ", subblock[dim]); + } + } + return success(); + }; + if (!succeeded( + checkDivisibility("sizePerThread", src.getSizePerThread())) || + !succeeded( + checkDivisibility("threadsPerWarp", src.getThreadsPerWarp())) || + !succeeded(checkDivisibility("warpsPerCTA", src.getWarpsPerCTA()))) { + return failure(); + } + + SmallVector, SmallVector>> decomp = + getReshapeDecomposition(srcShape, dstShape); + + // enc.order[i] == j means that dimension j is the enc.order[i]'th most + // minor. But what we usually want is the inverse: inverse(enc.order)[i] = j + // means that dimension i is the j'th most minor (larger means more major). + auto srcInvOrder = inversePermutation(src.getOrder()); + + // If src dims [a,b,c] are to be merged, then they must be consecutive in + // physical order, with `a` being the most major. + for (const auto &[srcDims, dstDims] : decomp) { + if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) { + return emitOptionalError(loc, + "Cannot do a non-reordering reshape given " + "this src encoding order. Dimensions [", + join(srcDims), + "] must be physically consecutive."); + } + } + + // If src dims [a,b,c] are to be merged, then `c` must fill up sizePerThread + // / threadsPerWarp / blocksPerCTA before `b` can have any non-1 values. + // Examples: + // + // - NOT OK: shape=[4,4,4], sizePerThread=[1,2,2]. + // The total sizePerThread for dim 2 is 2, which is less than dim 2's + // size of 4. Therefore dim 1 cannot have non-1 sizePerThread. + // + // - OK: shape=[4,4,4], sizePerThread=[1,2,4]. + // Dim 2's sizePerThread covers its whole size, so dim 1 is allowed to + // have non-1 sizePerThread. + // + // - NOT OK: shape=[4,4,4], sizePerThread=[2,1,4]. + // Dim 1's sizePerThread does not cover its whole size, so dim 0 is not + // allowed to have non-1 sizePerThread. + // + // - NOT OK: shape=[4,4,4], sizePerThread=[1,1,2], + // threadsPerWarp=[1,2,1]. + // Dim 2 has 2 elems per thread and 1 thread per warp. 2*1 is less than + // dim 2's size. Therefore dim 1 must have threadsPerWarp=1. + // + // In addition, the encoding's block can be larger than the shape, but only + // in the most-major dimension of each decomposed chunk, and only after + // we've "used up" the more minor dims. Examples: + // + // - OK: shape=[4,4,4], sizePerThread=[1,2,4], threadsPerWarp=[16,2,1], + // warpsPerCTA=[4,1,1]. + // The whole size of dims 0 and 1 are covered by sizePerThread * + // threadsPerWarp. Therefore dim 2 is allowed to have threadsPerWarp and + // warpsPerCTA larger than its size. + for (const auto &[srcDims, dstDims] : decomp) { + auto shapeRemaining = gather(srcShape, srcDims); + auto checkSubblock = [&, srcDims = srcDims](ArrayRef subblock) { + // Iterate minor-to-major (i==0 is most major). + for (int i = srcDims.size() - 1; i >= 0; i--) { + int dim = srcDims[i]; + if (subblock[dim] == 1) { + continue; + } + + // Check that more-minor dims all have 1 in shapeRemaining. + for (int j = i + 1; j < srcDims.size(); j++) { + if (shapeRemaining[j] != 1) { + return emitOptionalError( + loc, + "Invalid src encoding for non-reordering reshape. Must use " + "up sizePerThread / threadsPerWarp / warpsPerCTA for " + "more-minor dimensions before more major-dims can use them."); + } + } + + if (shapeRemaining[i] >= subblock[dim]) { + assert(shapeRemaining[i] % subblock[dim] == 0); // checked earlier + shapeRemaining[i] /= subblock[dim]; + } else { + shapeRemaining[i] = 0; + } + + // Is the block larger than the shape in this dimension? This is OK + // only if we're the most-major dimension of the chunk and in all + // future chunks, only this most-major dim has a non-1 size. + if (shapeRemaining[i] == 0 && i != 0) { + return emitOptionalError( + loc, + "Invalid src encoding for non-reordering reshape. Block " + "size in dimension ", + dim, + " is larger than the shape that dimension, but this is only " + "allowed for the most-major dimension of a reshape chunk"); + } + } + return success(); + }; + if (!succeeded(checkSubblock(src.getSizePerThread())) || + !succeeded(checkSubblock(src.getThreadsPerWarp())) || + !succeeded(checkSubblock(src.getWarpsPerCTA()))) { + return failure(); + } + } + + // Given e.g. src.getSizePerThread(), computeSubblockSize computes e.g. + // dst.getSizePerThread(). This should be called for each of sizePerThread, + // threadsPerWarp, and warpsPerCTA, in that order. + SmallVector dstShapeRemaining(dstShape); + auto computeSubblockSize = [&](ArrayRef srcSubblock, + SmallVector &dstSubblock, + StringRef fieldName) -> LogicalResult { + // The dst subblock is "filled up" greedily starting with the most minor + // dim. When we're done, we are left with a smaller shape, of size + // dstShape / dstSubblock, which we store in dstShapeRemaining and use for + // the next call to computeSubblockSize. + dstSubblock.resize(dstShape.size()); + for (const auto &[srcDims, dstDims] : decomp) { + int64_t subblockRemaining = product(gather(srcSubblock, srcDims)); + for (int i = dstDims.size() - 1; i >= 0; i--) { + auto &val = dstSubblock[dstDims[i]]; + auto &shapeRemaining = dstShapeRemaining[dstDims[i]]; + val = std::min(subblockRemaining, shapeRemaining); + + assert(shapeRemaining % val == 0); // Checked earlier. + subblockRemaining /= val; + shapeRemaining /= val; + } + + // If there are any elems remaining in the subblock, it must be because + // the block is larger than the shape. This excess goes into the + // most-major dim of the subblock. + dstSubblock[dstDims[0]] *= subblockRemaining; + } + return success(); + }; + + SmallVector dstSizePerThread; + SmallVector dstThreadsPerWarp; + SmallVector dstWarpsPerCTA; + if (!succeeded(computeSubblockSize(src.getSizePerThread(), dstSizePerThread, + "sizePerThread")) || + !succeeded(computeSubblockSize(src.getThreadsPerWarp(), + dstThreadsPerWarp, "threadsPerWarp")) || + !succeeded(computeSubblockSize(src.getWarpsPerCTA(), dstWarpsPerCTA, + "warpsPerCTA"))) { + return failure(); + } + + // Since we know that each set of srcDims is consecutive, we can + // meaningfully sort decomp by the physical order of the src dimensions, + // major-to-minor. This will also be the order of the dst dimensions. + llvm::sort(decomp, [&](const auto &a, const auto &b) { + const auto &[srcDimsA, dstDimsA] = a; + const auto &[srcDimsB, dstDimsB] = b; + return srcInvOrder[srcDimsA.front()] < srcInvOrder[srcDimsB.front()]; + }); + + // Compute the dst order. Make the dimensions appear in the same order as + // their corresponding src dimensions. + SmallVector dstInvOrder(dstShape.size()); + int i = 0; + for (const auto &[srcDims, dstDims] : decomp) { + for (auto dim : reverse(dstDims)) { + dstInvOrder[dim] = i++; + } + } + auto dstOrder = inversePermutation(dstInvOrder); + + // CTALayout can be all 1's because we bailed on multi-CTA layouts above. + auto CTALayout = CTALayoutAttr::get( + src.getContext(), + /*CTAsPerCGA=*/SmallVector(dstShape.size(), 1), + /*CTASplitNum=*/SmallVector(dstShape.size(), 1), + /*CTAOrder=*/llvm::to_vector(llvm::seq(dstShape.size()))); + + dstEnc = BlockedEncodingAttr::get(src.getContext(), dstSizePerThread, + dstThreadsPerWarp, dstWarpsPerCTA, + dstOrder, CTALayout); + + return success(); + } + + LogicalResult + inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const override { + auto enc = mlir::dyn_cast(srcEnc); + if (!enc) { + return emitOptionalError(loc, + "JoinOp can only operate on BlockedEncoding"); + } + + // JoinOp takes two tensors of shape AxBxC and generates a tensor of shape + // AxBxCx2. The encoding is the same as the input, but with 2 elems per + // thread in the new dimension. The new dimension is most-minor. + auto append = [](ArrayRef vals, int val) { + SmallVector ret(vals); + ret.push_back(val); + return ret; + }; + auto appendMinorDim = [](ArrayRef order) { + SmallVector ret(order); + ret.insert(ret.begin(), ret.size()); + return ret; + }; + dstEnc = BlockedEncodingAttr::get( + enc.getContext(), // + append(enc.getSizePerThread(), 2), // + append(enc.getThreadsPerWarp(), 1), // + append(enc.getWarpsPerCTA(), 1), // + appendMinorDim(enc.getOrder()), // + CTALayoutAttr::get(enc.getContext(), // + append(enc.getCTAsPerCGA(), 1), + append(enc.getCTASplitNum(), 1), + appendMinorDim(enc.getCTAOrder()))); + return success(); + } + + LogicalResult + inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const override { + auto enc = mlir::dyn_cast(srcEnc); + if (!enc) { + return emitOptionalError(loc, + "SplitOp can only operate on BlockedEncoding"); + } + + // SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of + // shape AxBxC. The input must have 2 elements per thread in the last + // dimension, which must be most-minor. The result encoding is the same as + // the input, but with the last dimension removed. + if (enc.getSizePerThread().back() != 2) { + return emitOptionalError(loc, + "SplitOp requires 2 elements per thread in the " + "last dimension of the input"); + } + if (enc.getThreadsPerWarp().back() != 1 || + enc.getWarpsPerCTA().back() != 1 || enc.getCTAsPerCGA().back() != 1) { + return emitOptionalError( + loc, "SplitOp requires threadsPerWarp, warpsPerCTA, " + "and CTAsPerCGA = 1 for the last dimension of the input"); + } + if (enc.getOrder().front() != enc.getOrder().size() - 1) { + return emitOptionalError( + loc, "SplitOp requires the last dimension to be most-minor in order"); + } + if (enc.getCTALayout().getCTAsPerCGA().back() != 1) { + return emitOptionalError( + loc, + "SplitOp requires the last dimension to be most-minor in CTAOrder"); + } + + dstEnc = BlockedEncodingAttr::get( + enc.getContext(), // + ArrayRef(enc.getSizePerThread()).drop_back(1), + ArrayRef(enc.getThreadsPerWarp()).drop_back(1), + ArrayRef(enc.getWarpsPerCTA()).drop_back(1), + ArrayRef(enc.getOrder()).drop_front(1), + CTALayoutAttr::get(enc.getContext(), // + ArrayRef(enc.getCTAsPerCGA()).drop_back(1), + ArrayRef(enc.getCTASplitNum()).drop_back(1), + ArrayRef(enc.getCTAOrder()).drop_front(1))); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Canonicalizer +//===----------------------------------------------------------------------===// + +// reshape(cvt) -> reshape +struct CanonicalizeConvertFromReshape + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::ReshapeOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + if (isExpensiveView(convert.getSrc().getType(), op.getType())) + return failure(); + if (!op.getAllowReorder() || op.getEfficientLayout().has_value()) + return failure(); + + rewriter.replaceOpWithNewOp( + op, op.getType(), convert.getSrc(), op.getAllowReorder()); + return mlir::success(); + } +}; + +// histogram(cvt) -> histogram +struct CanonicalizeConvertFromHistogram + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::HistogramOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), convert.getSrc()); + return mlir::success(); + } +}; + +// alloc(cvt) -> alloc +struct CanonicalizeConvertFromAlloc + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::gpu::LocalAllocOp op, + PatternRewriter &rewriter) const override { + if (!op.getSrc()) + return failure(); + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), convert.getSrc()); + return mlir::success(); + } +}; + +// local_store(cvt) -> local_store +struct CanonicalizeConvertFromLocalStore + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::gpu::LocalStoreOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + rewriter.replaceOpWithNewOp(op, convert.getSrc(), + op.getDst()); + return mlir::success(); + } +}; + +struct CanonicalizeConvertFromConvert + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(ConvertLayoutOp op, + PatternRewriter &rewriter) const override { + // Convert to the same layout is redundant. + if (op->getResultTypes() == op->getOperandTypes()) { + rewriter.replaceOp(op, op->getOperands()); + return success(); + } + + // We don't handle conversions to DotOperandEncodingAttr. This is a + // heuristic to accommodate fused attention. + auto srcType = op.getSrc().getType(); + auto dstType = op.getType(); + if (mlir::isa(dstType.getEncoding()) && + mlir::isa(srcType.getEncoding())) + return failure(); + + // for hopper MMAv3 + if (mlir::isa(dstType.getEncoding()) && + mlir::isa(srcType.getEncoding()) && + llvm::any_of(op.getResult().getUsers(), + [](Operation *dot) { return isa(dot); })) { + return failure(); + } + + Operation *arg = op.getSrc().getDefiningOp(); + if (!arg) + return failure(); + + // cvt(reshape) -> reshape + if (auto reshape = dyn_cast(arg)) { + if (!reshape.getAllowReorder() || + reshape.getEfficientLayout().has_value() || + isExpensiveView(reshape.getSrc().getType(), op.getType())) + return failure(); + + // In TritonGPUToLLVM phase, ViewOp is converted to unpacking and packing + // operations, which requires the element type to match between unpacking + // and packing. However, part of values with dot operand encoding will be + // packed/unpacked as i32 elements instead of the underlying element type. + // To avoid errors, skip this folding when either the operand or result + // of view has a dot operand encoding. + if (hasDotOperandEncoding(op->getOperand(0)) || + hasDotOperandEncoding(op->getResult(0))) + return failure(); + + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + reshape.getResult(), + reshape.getAllowReorder()); + return success(); + } + + // cvt(histogram) -> histogram + if (auto histogram = dyn_cast(arg)) { + // For histogram ops the input and output layouts are independent, so we + // can always fold convert into the histogram op. + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + histogram.getSrc()); + return success(); + } + + // cvt(local_load) -> local_load. + if (auto sharedLoad = dyn_cast(arg)) { + // Shared_load can load to any layout so we can always fold convert into + // it. + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + sharedLoad.getSrc()); + return success(); + } + + // cvt(cat) -> cat + if (auto cat = dyn_cast(arg)) { + if (isExpensiveCat(cat, op.getType().getEncoding())) + return failure(); + + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + cat.getOperands()); + return success(); + } + + // cvt(cvt(x, type1), type2) -> cvt(x, type2) + if (auto cvt = dyn_cast(arg)) { + auto srcType = op.getSrc().getType(); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes().front(), cvt.getSrc()); + return success(); + } + + // cvt(type1, splat(type2, x)) -> splat(type1, x) + if (auto splat = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + splat.getSrc()); + return success(); + } + + // cvt(type1, make_range(type2, x)) -> make_range(type1, x) + if (auto range = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), range.getStart(), range.getEnd()); + return success(); + } + + // cvt(type, constant) -> constant + if (auto cst = llvm::dyn_cast(arg)) + if (auto ret = dyn_cast(cst.getValue())) { + auto ty = cast(op->getResultTypes().front()); + auto newRet = + SplatElementsAttr::get(ty, ret.getSplatValue()); + rewriter.replaceOpWithNewOp(op, newRet); + return success(); + } + return failure(); + } +}; + +void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); +} + +// LocalAllocOp +void LocalAllocOp::getEffects( + SmallVectorImpl> + &effects) { + Operation *op = getOperation(); + // If allocation is immutable, mark it as no side effect allow things like + // CSE, DCE to work in early compiler passes. + // After the memory offset is computed, we attach the true side effect to the + // op. + if (!getType().getMutableMemory() && !op->hasAttr("allocation.offset")) + return; + effects.emplace_back(MemoryEffects::Allocate::get(), + mlir::triton::gpu::SharedMemory::get()); + if (getSrc()) + effects.emplace_back(MemoryEffects::Write::get(), getResult(), + mlir::triton::gpu::SharedMemory::get()); +} + +// LocalLoadOp +void LocalLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), getSrc(), + mlir::triton::gpu::SharedMemory::get()); +} + +// LocalStoreOp +void LocalStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), getDst(), + mlir::triton::gpu::SharedMemory::get()); +} + +// AsyncCopyGlobalToLocalOp +void AsyncCopyGlobalToLocalOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), getSrc(), + mlir::triton::GlobalMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), getResult(), + mlir::triton::gpu::SharedMemory::get()); +} + +LogicalResult MemDescSubviewOp::verify() { + auto srcTy = getSrc().getType(); + auto dstTy = getType(); + + if (srcTy.getElementType() != dstTy.getElementType()) { + return emitError("result element type must match desc element type"); + } + if (getOffsets().size() != srcTy.getRank()) { + return emitError("offsets must have the same rank as input"); + } + if (srcTy.getRank() < dstTy.getRank()) { + return emitError("result rank must be less than or equal to input rank"); + } + auto rankDiff = srcTy.getRank() - dstTy.getRank(); + for (int i = 0; i < dstTy.getRank(); i++) { + if (dstTy.getDimSize(i) > srcTy.getDimSize(i + rankDiff)) { + return emitError( + "result shape cannot be larger than input shape at dimension ") + << i; + } + } + + auto srcEnc = srcTy.getEncoding(); + auto dstEnc = dstTy.getEncoding(); + if (!!srcEnc != !!dstEnc) { + return emitError("src and result must both have or not have an encoding"); + } + + if (!isa(srcEnc)) { + return emitError("src encoding must be SharedEncodingAttr"); + } + if (!isa(dstEnc)) { + return emitError("result encoding must be SharedEncodingAttr"); + } + + // TODO(jlebar): Currently we generate illegal encodings, so we can't add a + // verifier for them. In particular, we use the same encoding for the src and + // dst of a subview op, when the subview removes a dimension. That generates + // an illegal shared encoding (because the size of `order` doesn't match the + // rank of the tensor), but it's not checked anywhere, and we believe the + // resulting code ultimately works. + + return success(); +} + +void TritonGPUDialect::initialize() { + registerTypes(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc" + >(); + addInterfaces(); + addInterfaces(); +} + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" + +// verify TritonGPU ops +LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // TODO: fill this. + return success(); +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp new file mode 100644 index 000000000..ae34598ae --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -0,0 +1,489 @@ +#include + +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir::triton::gpu { +namespace { + +// We use the following nomenclature in this file. +// +// - ctaLayout: A layout for one block, i.e. input dims (register, lane, warp). +// - cgaLayout: Arrangement of multiple blocks, i.e. input dims (block). +// +// Note that this is inconsistent with the type name CTALayoutAttr. That type +// is equivalent to our cgaLayout. +// +// IMO the type name is wrong. If we tried to be consistent anyway, then we'd +// have to rename ctaLayout to "warpLayout". I think that's more confusing than +// being inconsistent about "cgaLayout", especially when we have to consider the +// size of the warpLayout (surely that's not the "warpSize"). + +#define S(v) StringAttr::get(ctx, (v)) + +// Returns ["out0", "out1", ..., "out"]. +SmallVector standardOutDimNames(MLIRContext *ctx, int rank) { + SmallVector ret; + for (int i = 0; i < rank; i++) { + ret.push_back(S("dim" + llvm::Twine(i))); + } + return ret; +} + +// Returns a 1D -> ND layout that's equivalent to creating a 1D -> 1D mapping of +// size product(shape) and then reshaping to permute(shape, order). +LinearLayout identityND(StringAttr inDimName, ArrayRef shape, + ArrayRef order, + ArrayRef outDimNames) { + assert(shape.size() == order.size()); + + MLIRContext *ctx = inDimName.getContext(); + LinearLayout ret = LinearLayout::empty(); + for (int i = 0; i < shape.size(); i++) { + // Start with the most-minor dimension, which is order[0]. + int dim = order[i]; + ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]); + } + return ret; +} + +// Make a LinearLayout that maps a block-id to an N-dimensional index. +// +// The tensor is split up into CTAsPerCGA pieces, which are distributed among +// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. groups). +// +// See the nomenclature note at the top of the file for an explanation of why +// this is called makeCgaLayout when it accepts a CTALayoutAttr. +LinearLayout makeCgaLayout(CTALayoutAttr layout) { + MLIRContext *ctx = layout.getContext(); + StringAttr kBlock = S("block"); + + int rank = layout.getCTAOrder().size(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + LinearLayout ret = LinearLayout::empty(); + for (int i = 0; i < rank; i++) { + // Start with the most minor dimension, which is order[0]. + int dim = layout.getCTAOrder()[i]; + int split = layout.getCTASplitNum()[dim]; + int ctas = layout.getCTAsPerCGA()[dim]; + assert(ctas % split == 0); + ret *= LinearLayout::identity1D(split, kBlock, outDimNames[dim]) * + LinearLayout::zeros1D(ctas / split, kBlock, outDimNames[dim]); + } + + // Transpose to standard order (dim0, dim1, ...). + return ret.transposeOuts(outDimNames); +} + +// Shrinks the output set of a layout function while leaving the input set +// unchanged, by making high-order inputs in inDimName map to the same output. +// Attempts to shrink down to desiredSize, but this is not always possible just +// by modifying one the specified input dimension. +// +// We do this by making the most-major inputs to the layout map to 0. This +// effectively duplicates data along that input dimension. For example, this +// layout has out-dim size 32: +// +// L(register=1) = 8 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 16. +// +// If we shrink it to size 16 along the `lane` dimension, we set L(lane=2) to 0: +// +// L(register=1) = 8 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0. +// +// This means that lane=2 has the same data as lane=0. +// +// If we shrink to size 8 along the lane dimension, we set L(lane=1) = 0 as +// well. But when we do this, we have to remove bit 1 (the value of L(lane=1)) +// from all other bases: +// +// L(register=1) = 4 +// L(register=2) = 2 +// L(register=1) = 1 +// L(lane=1) = 0 +// L(lane=2) = 0. +// +// Note this only works because the bases are powers of two. I don't quite know +// what to do when they're not. +LinearLayout shrinkCodomain(const LinearLayout &layout, StringAttr inDimName, + StringAttr outDimName, int desiredSize) { + assert(llvm::isPowerOf2_32(desiredSize)); + int outDimIdx = layout.getOutDimIndex(outDimName); + int desiredZeros = + llvm::Log2_32(layout.getOutDimSize(outDimName) / desiredSize); + if (desiredZeros == 0) { + return layout; + } + + // Find the desiredZeros most-major basis vectors that are not already zero. + // These are the ones we will set to zero. + SmallVector basesToZero; + for (int i = layout.getInDimSizeLog2(inDimName) - 1; + i >= 0 && basesToZero.size() < desiredZeros; i--) { + int basis = layout.getBasis(inDimName, i, outDimName); + if (basis != 0) { + basesToZero.push_back(basis); + } + } + + // Bail if all the bases are already zero; nothing more we can do. + if (basesToZero.empty()) { + return layout; + } + + // The algorithm below only works because the bases are powers of two. I'm + // not sure what to do otherwise. + assert(llvm::all_of(basesToZero, + [&](int basis) { return llvm::isPowerOf2_32(basis); })); + + // We want to zero out the bases in `basesToZero`, and also "shift out" the + // corresponding bits from all other bases. For example if we remove the + // basis with value 8 = 0b100, then if another basis has value 26 = 0b11010, + // the 1 in its 3rd position gets removed and it becomes 10 = 0b1010. + // + // We could manually alter the bases in `layout` to achieve this, but it's + // perhaps simpler to use the linearity of LLs to our advantage. + // + // Consider the function O which is the identity map from out-dims to + // out-dims. We can easily calculate what happens when we remove the relevant + // bases from O. Call this new function O'. + // + // Because of linearity, removing the bases from L is equivalent to composing + // L with O'. So that's what we do below. + + // Construct the out-dims -> out-dims identity layout O. + LinearLayout outputIdentity = LinearLayout::empty(); + for (StringAttr dim : layout.getOutDimNames()) { + outputIdentity *= + LinearLayout::identity1D(layout.getOutDimSize(dim), dim, dim); + } + + // Modify O to remove the relevant bases. + // + // TODO(jlebar): I don't like manually modifying bases here. Perhaps this + // should be a function on LinearLayout. + LinearLayout::BasesT newBases = outputIdentity.getBases(); + llvm::sort(basesToZero); + for (int basis : basesToZero) { + int idx = llvm::Log2_32(basis); + for (int i = newBases[outDimName].size() - 1; i > idx; i--) { + newBases[outDimName][i][outDimIdx] = + newBases[outDimName][i - 1][outDimIdx]; + } + newBases[outDimName][idx][outDimIdx] = 0; + } + + // Construct O'. + LinearLayout transform(std::move(newBases), layout.getOutDimNames()); + + // Compose O' with L. + return layout.compose(transform); +} + +// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no +// larger than shape[d]. Do this without changing the size of the layout's +// inputs (i.e. leave its domain unchanged). +// +// This function is invariant to the order of the layout's input and output +// dimensions. +LinearLayout ensureLayoutNotLargerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + MLIRContext *ctx = shape.begin()->first.getContext(); + + // For the purposes of this function, "block" is the "most-minor" dimension. + // This is just a consequence of how legacy layouts work: We only put the same + // tensor element into two different blocks as a last resort, only after all + // the registers in all the lanes in all the warps in a block already have the + // same tensor element. + SmallVector inDimNames = { + S("block"), + S("register"), + S("lane"), + S("warp"), + }; + + LinearLayout ret = layout; + for (auto outDimName : layout.getOutDimNames()) { + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + if (actualSize <= desiredSize) { + continue; + } + assert(actualSize % desiredSize == 0); + // TODO: We claim this is invariant to the order of dims, so can we get rid + // of llvm::reverse? + for (StringAttr inDimName : llvm::reverse(inDimNames)) { + if (ret.hasInDim(inDimName)) { + ret = shrinkCodomain(ret, inDimName, outDimName, desiredSize); + } + } + assert(ret.getOutDimSize(outDimName) == desiredSize); + } + return ret; +} + +// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no +// smaller than shape[d]. Do this by increasing the size of the layout's inputs +// along the "register" dimension. +// +// This function is invariant to the order of the layout's input dimensions, but +// it cares about the order of the output dims, which should be minor-to-major. +LinearLayout ensureLayoutNotSmallerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + + MLIRContext *ctx = shape.begin()->first.getContext(); + StringAttr kRegister = S("register"); + + LinearLayout ret = layout; + for (StringAttr outDimName : layout.getOutDimNames()) { + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + assert(actualSize > desiredSize || desiredSize % actualSize == 0); + ret *= LinearLayout::identity1D(desiredSize / actualSize, kRegister, + outDimName); + assert(ret.getOutDimSize(outDimName) >= desiredSize); + } + return ret; +} + +// Combines the layout of a CTA (input dims [register, lane, warp]) with the +// layout of a CGA (i.e. a block), and ensures that the resulting layout has the +// given shape. +// +// See the nomenclature note at the top of the file for why the variable with +// type CTALayoutAttr is called cgaLayoutAttr. +LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, + CTALayoutAttr cgaLayoutAttr, + ArrayRef shape) { + int rank = shape.size(); + assert(ctaLayout.getNumOutDims() == rank); + assert(cgaLayoutAttr.getCTAOrder().size() == rank); + MLIRContext *ctx = cgaLayoutAttr.getContext(); + + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + llvm::SmallDenseMap labeledShape; + for (auto [dim, size] : llvm::zip(outDimNames, shape)) { + labeledShape[dim] = size; + } + + LinearLayout cgaLayout = + ensureLayoutNotLargerThan(makeCgaLayout(cgaLayoutAttr), labeledShape) + .transposeOuts(ctaLayout.getOutDimNames()); + + // Calculate the shape of the ctaLayout, which is `shape` divided by the + // cgaLayout's size. + llvm::SmallDenseMap ctaShape; + assert(ctaLayout.getOutDimNames() == cgaLayout.getOutDimNames()); + for (auto dim : ctaLayout.getOutDimNames()) { + ctaShape[dim] = + std::max(int64_t{1}, labeledShape[dim] / cgaLayout.getOutDimSize(dim)); + } + + ctaLayout = ensureLayoutNotSmallerThan(ctaLayout, ctaShape); + ctaLayout = ensureLayoutNotLargerThan(ctaLayout, ctaShape); + + LinearLayout ret = (ctaLayout * cgaLayout).transposeOuts(outDimNames); + for (auto dim : ret.getOutDimNames()) { + assert(ret.getOutDimSize(dim) == labeledShape[dim]); + } + return ret; +} + +LinearLayout blockedToLinearLayout(ArrayRef shape, + BlockedEncodingAttr blocked) { + assert(shape.size() == blocked.getOrder().size()); + + int rank = shape.size(); + MLIRContext *ctx = blocked.getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + const auto &order = blocked.getOrder(); + LinearLayout ctaLayout = + identityND(S("register"), blocked.getSizePerThread(), order, + outDimNames) * + identityND(S("lane"), blocked.getThreadsPerWarp(), order, outDimNames) * + identityND(S("warp"), blocked.getWarpsPerCTA(), order, outDimNames); + + return combineCtaCgaWithShape(ctaLayout, blocked.getCTALayout(), shape); +} + +LinearLayout ampereMmaToLinearLayout(ArrayRef shape, + NvidiaMmaEncodingAttr mma) { + int rank = shape.size(); + + assert(mma.isAmpere()); + assert(rank == 2 || rank == 3); + assert(mma.getInstrShape().size() == rank); + assert((rank == 2 && mma.getInstrShape() == ArrayRef({16, 8})) || + (rank == 3 && mma.getInstrShape() == ArrayRef({1, 16, 8}))); + + MLIRContext *ctx = mma.getContext(); + SmallVector dimNames = standardOutDimNames(ctx, rank); + + LinearLayout ctaLayout( + {{S("register"), {{1, 0}, {0, 8}}}, + {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}}, + llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); + + ctaLayout *= identityND( + S("warp"), mma.getWarpsPerCTA(), + llvm::to_vector(llvm::reverse(llvm::seq(rank))), dimNames); + + return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); +} + +LinearLayout hopperMmaToLinearLayout(ArrayRef shape, + NvidiaMmaEncodingAttr mma) { + int rank = shape.size(); + assert(mma.isHopper()); + assert(rank == 2); + + // wgmma operates on groups of 4 warps. + assert(product(mma.getWarpsPerCTA()) % 4 == 0); + + // Check that it's a known MMA layout. + assert(mma.getInstrShape().size() == 3); + int m = mma.getInstrShape()[0]; + int n = mma.getInstrShape()[1]; + int k = mma.getInstrShape()[2]; + assert(m == 16); + assert(n == 16 || n == 32 || n == 64 || n == 128 || n == 256); + assert(k == 8 || k == 16 || k == 32); + + MLIRContext *ctx = mma.getContext(); + LinearLayout ctaLayout( + {{S("register"), {{1, 0}, {0, 8}}}, + {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}}, + {S("dim1"), S("dim0")}); + + // Expand the `register` dimension so the size of dim1 matches `n`. + ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")), + S("register"), S("dim1")); + + // Expand the `warp` dimension according to warpsPerCTA. + // + // It's weird that this is order [0,1] when MMAv2's warpsPerCTA is [1,0], but + // this really does seem to be correct. + ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1}, + {S("dim0"), S("dim1")}) + .transposeOuts(ctaLayout.getOutDimNames()); + + return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); +} + +std::optional toLinearLayout(ArrayRef shape, + SliceEncodingAttr slice) { + MLIRContext *ctx = slice.getContext(); + + // First compute the linear layout for this layout's parent. + SmallVector parentShape(shape); + parentShape.insert(parentShape.begin() + slice.getDim(), 1); + std::optional parentLL = + triton::gpu::toLinearLayout(parentShape, slice.getParent()); + if (!parentLL) { + return std::nullopt; + } + + // Remove dimension slice.getDim() from the parent layout. + // + // 1. Construct a layout `transform` from parent-out-dims to slice-out-dims + // that removes the relevant out-dim. + // 2. Compute linearSlice = parent.compose(transform). Now linearSlice maps + // from parent in-dims to slice out-dims. + // 3. Fix up duplicate registers introduced by slicing. + auto outDimNames = standardOutDimNames(ctx, shape.size() + 1); + LinearLayout transform = LinearLayout::empty(); + for (auto [idx, outDim] : llvm::enumerate(parentLL->getOutDimNames())) { + if (idx == slice.getDim()) { + // Because we're multiplying by all zeros, we could replace outDimNames[0] + // with any other valid out-dim; the layout will be the same. + transform *= LinearLayout::zeros1D(parentLL->getOutDimSize(outDim), + outDim, outDimNames[0]); + } else { + transform *= LinearLayout::identity1D( + parentLL->getOutDimSize(outDim), outDim, + outDimNames[idx - (idx < slice.getDim() ? 0 : 1)]); + } + } + LinearLayout sliceLL = parentLL->compose(transform); + + // Step 3: Along the "register" dim, remove any all-zero bases. + auto bases = sliceLL.getBases(); + std::vector> newRegBases; + for (const auto &basis : bases[S("register")]) { + if (llvm::any_of(basis, [](int b) { return b != 0; })) { + newRegBases.push_back(basis); + } + } + bases[S("register")] = newRegBases; + + LinearLayout ret = LinearLayout(std::move(bases), sliceLL.getOutDimNames()); + + // Match a hack in the legacy code that ensures that the number of registers + // matches getTotalElemsPerThread. Yup: We just removed all the zeros, now + // we're (maybe) adding some back. :) + // + // TODO(jlebar): Once getTotalElemsPerThread uses LLs instead of the existing + // legacy code, I think we can remove this. + int expectedNumRegisters = getTotalElemsPerThread(RankedTensorType::get( + shape, IntegerType::get(ctx, 32) /*dummy type*/, slice)); + if (ret.getInDimSize(S("register")) != expectedNumRegisters) { + int extraZeros = expectedNumRegisters / ret.getInDimSize(S("register")); + // Our use of "dim0" here is arbitrary; because we're adding zeros, any + // output dimension would work. + ret *= LinearLayout::zeros1D(extraZeros, S("register"), S("dim0")); + } + return ret; +} + +} // anonymous namespace + +std::optional toLinearLayout(ArrayRef shape, + Attribute layout) { + if (auto blocked = dyn_cast(layout)) { + return blockedToLinearLayout(shape, blocked); + } + if (auto mma = dyn_cast(layout)) { + if (mma.isAmpere()) { + return ampereMmaToLinearLayout(shape, mma); + } + if (mma.isHopper()) { + return hopperMmaToLinearLayout(shape, mma); + } + } + if (auto slice = dyn_cast(layout)) { + return toLinearLayout(shape, slice); + } + + // TODO(jlebar): Other layouts + return std::nullopt; +} + +} // namespace mlir::triton::gpu diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/IR/Types.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/IR/Types.cpp new file mode 100644 index 000000000..77f673cc2 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/IR/Types.cpp @@ -0,0 +1,38 @@ +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton::gpu; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/Types.cpp.inc" + +Type TokenType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + int type = 1; + if (parser.parseInteger(type)) + return Type(); + + if (parser.parseGreater()) + return Type(); + + return TokenType::get(parser.getContext(), type); +} + +void TokenType::print(AsmPrinter &printer) const { + printer << "<" << getType() << ">"; +} + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void ::mlir::triton::gpu::TritonGPUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/TritonGPU/IR/Types.cpp.inc" + >(); +} diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..7fd678ef1 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -0,0 +1,20 @@ +add_triton_library(TritonGPUTransforms + Coalesce.cpp + CombineTensorSelectAndIf.cpp + ReduceDataDuplication.cpp + OptimizeThreadLocality.cpp + RemoveLayoutConversions.cpp + ReorderInstructions.cpp + Utility.cpp + + DEPENDS + TritonGPUTransformsIncGen + + LINK_LIBS PUBLIC + MLIRTransforms + MLIRTransformUtils + TritonAnalysis + TritonIR + TritonGPUIR + MLIRTransformUtils +) diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp new file mode 100644 index 000000000..06a7d963d --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -0,0 +1,198 @@ +#include +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tritongpu-coalesce" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUCOALESCE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +struct CoalescePass : public impl::TritonGPUCoalesceBase { + void + setCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op, + int numWarps, int threadsPerWarp, + llvm::MapVector &layoutMap) { + Value ptr = getMemAccessPtr(op); + auto refTensorType = cast(ptr.getType()); + + LDBG("Considering op: " << *op); + LLVM_DEBUG({ + DBGS() << "axis info of pointer: "; + axisInfoAnalysis.getAxisInfo(ptr)->print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity(); + SmallVector order = argSort(contiguity); + LDBG("order=[" << triton::join(order, ", ") << "]"); + + auto matchesShape = [&refTensorType](const Value &val) { + auto rttType = dyn_cast(val.getType()); + return rttType && rttType.getShape() == refTensorType.getShape(); + }; + + // The desired divisibility is the maximum divisibility among all dependent + // pointers which have the same shape and order as `ptr`. + llvm::SmallSetVector memAccessesSameOrder; + memAccessesSameOrder.insert(op); + if (ptr.getDefiningOp()) { + for (Operation *use : mlir::multiRootGetSlice(op)) { + Value val = getMemAccessPtr(use); + if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use)) + continue; + auto currOrder = + argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity()); + if (order == currOrder) { + LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use); + memAccessesSameOrder.insert(use); + } + } + } + + auto shapePerCTA = triton::gpu::getShapePerCTA(refTensorType); + LDBG("shapePerCTA=[" << triton::join(shapePerCTA, ", ") << "]"); + + int numElems = product(shapePerCTA); + int numThreads = numWarps * threadsPerWarp; + + unsigned perThread = getNumElementsPerThread(op, order, axisInfoAnalysis); + LDBG("perThread for op: " << perThread); + + for (Operation *opSameOrder : memAccessesSameOrder) { + if (opSameOrder == op) + continue; + unsigned currPerThread = + getNumElementsPerThread(opSameOrder, order, axisInfoAnalysis); + LDBG("perThread for opSameOrder: " << currPerThread); + perThread = std::max(perThread, currPerThread); + } + + perThread = std::min(perThread, std::max(numElems / numThreads, 1)); + LDBG("perThread: " << perThread); + + if (!dyn_cast(op)) { + // For ops that can result in a global memory write, we should enforce + // that each thread handles at most 128 bits, which is the widest + // available vectorized store op; otherwise, the store will have "gaps" + // in the memory write at the warp level, resulting in worse performance. + // For loads, we can expect that the gaps won't matter due to the L1 + // cache. + unsigned elemNumBits = getElementBitWidth(refTensorType); + perThread = std::min( + perThread, getNumElementsPerThread(op, order, axisInfoAnalysis)); + } + SmallVector sizePerThread(refTensorType.getRank(), 1); + sizePerThread[order[0]] = perThread; + + auto CTALayout = triton::gpu::getCTALayout(refTensorType.getEncoding()); + layoutMap[op] = triton::gpu::BlockedEncodingAttr::get( + &getContext(), refTensorType.getShape(), sizePerThread, order, numWarps, + threadsPerWarp, CTALayout); + } + + static Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = cast(type); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + } + + void coalesceOp(Attribute encoding, Operation *op) { + OpBuilder builder(op); + // Convert operands + // For load/store with tensor pointers, we don't have to change the + // operands' type, we do this by changing the outputs' type of + // `make_tensor_ptr` + SmallVector newArgs; + for (auto operand : op->getOperands()) { + auto tensorType = dyn_cast(operand.getType()); + if (tensorType && + !isa(tensorType.getEncoding())) { + Type newType = getNewType(tensorType, encoding); + newArgs.push_back(builder.create( + op->getLoc(), newType, operand)); + } else { + newArgs.push_back(operand); + } + } + + // Convert output types + SmallVector newTypes; + for (auto t : op->getResultTypes()) { + bool isAsync = isa(op); + newTypes.push_back(isAsync ? t : getNewType(t, encoding)); + } + + // Construct new op with the new encoding + Operation *newOp = + builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs, + newTypes, op->getAttrs()); + + // Cast the results back to the original layout + for (size_t i = 0; i < op->getNumResults(); i++) { + Value newResult = newOp->getResult(i); + if (newTypes[i] != op->getResultTypes()[i]) { + newResult = builder.create( + op->getLoc(), op->getResult(i).getType(), newResult); + } + op->getResult(i).replaceAllUsesWith(newResult); + } + op->erase(); + } + + void runOnOperation() override { + // Run axis info analysis + ModuleOp moduleOp = getOperation(); + ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // For each i/o operation, we determine what layout + // the pointers should have for best memory coalescing + llvm::MapVector layoutMap; + moduleOp.walk([&](Operation *curr) { + Value ptr = getMemAccessPtr(curr); + if (!ptr) + return; + // We only convert `tensor>` load/store + bool isPtrTensor = false; + if (auto tensorType = dyn_cast(ptr.getType())) + isPtrTensor = isa(tensorType.getElementType()); + if (!isPtrTensor) + return; + auto mod = curr->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp, + layoutMap); + }); + + // For each memory op that has a layout L1: + // 1. Create a coalesced memory layout L2 of the pointer operands + // 2. Convert all operands from layout L1 to layout L2 + // 3. Create a new memory op that consumes these operands and + // produces a tensor with layout L2 + // 4. Convert the output of this new memory op back to L1 + // 5. Replace all the uses of the original memory op by the new one + for (auto &kv : layoutMap) { + coalesceOp(kv.second, kv.first); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp new file mode 100644 index 000000000..16183b1af --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp @@ -0,0 +1,124 @@ +#include "mlir/IR/Dominance.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUCOMBINETENSORSELECTANDIF +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// Return true if the select could be merged into the If without breaking SSA +// rules. +static bool canMergeIntoIf(arith::SelectOp selectOp, scf::IfOp ifOp, + DominanceInfo &dom) { + // If needs to be dominated by the select. + if (!dom.dominates(selectOp.getOperation(), ifOp.getOperation())) { + return false; + } + // If needs to dominate all the select's users. + for (auto user : selectOp.getResult().getUsers()) { + if (!dom.dominates(ifOp, user)) { + return false; + } + } + return true; +} + +class CombineTensorSelectAndIfPass + : public impl::TritonGPUCombineTensorSelectAndIfBase< + CombineTensorSelectAndIfPass> { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + DominanceInfo dom(m); + + // Go over the arith.select ops, look if there is an if + // with the same condition. + llvm::MapVector> selectToIf; + m.walk([&](arith::SelectOp selectOp) { + // Look if there is an if in the same block, with the same condition. + auto *parentBlock = selectOp->getBlock(); + Value condition = selectOp.getOperand(0); + SetVector conditionUsers(condition.getUsers().begin(), + condition.getUsers().end()); + // sort the users in topological order. + conditionUsers = multiRootTopologicalSort(conditionUsers); + // Get condition's users + for (Operation *user : conditionUsers) { + auto ifOp = dyn_cast(user); + if (!ifOp || ifOp->getBlock() != parentBlock) + continue; + if (canMergeIntoIf(selectOp, ifOp, dom)) { + selectToIf[ifOp].push_back(selectOp); + break; + } + } + }); + + for (auto [ifOp, selectOps] : selectToIf) { + // Add new return value to the if (and create else block if necessary), + // then yield the select value in the then block and the else block. + OpBuilder builder(ifOp); + auto loc = ifOp.getLoc(); + // Create an scf::IfOp with extra return value. + SmallVector newResultTypes = {ifOp.getResultTypes().begin(), + ifOp.getResultTypes().end()}; + for (arith::SelectOp selectOp : selectOps) { + newResultTypes.push_back(selectOp.getResult().getType()); + } + auto newIfOp = builder.create( + loc, newResultTypes, ifOp.getCondition(), /*hasElse*/ true); + // Move the existing blocks to the new if. + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + + if (ifOp.elseBlock()) { + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + } else { + // Create an empty yield + auto yieldOp = newIfOp.getElseBodyBuilder().create(loc); + } + + SmallVector ifYieldOperands = newIfOp.thenYield().getOperands(); + SmallVector elseYieldOperands = newIfOp.elseYield().getOperands(); + for (arith::SelectOp selectOp : selectOps) { + Value thenValue = selectOp.getTrueValue(); + Value elseValue = selectOp.getFalseValue(); + ifYieldOperands.push_back(thenValue); + elseYieldOperands.push_back(elseValue); + } + // Update yields + auto updateYield = [&](scf::YieldOp yield, SmallVector &operands) { + builder.setInsertionPoint(yield); + builder.create(loc, operands); + yield.erase(); + }; + updateYield(newIfOp.thenYield(), ifYieldOperands); + updateYield(newIfOp.elseYield(), elseYieldOperands); + + int resultIdx = 0; + // Replace old if with the new one. + for (auto result : ifOp.getResults()) { + result.replaceAllUsesWith(newIfOp->getResult(resultIdx++)); + } + // Replace the select with the new return value. + for (arith::SelectOp selectOp : selectOps) { + selectOp.replaceAllUsesWith(newIfOp->getResult(resultIdx++)); + selectOp.erase(); + } + + ifOp.erase(); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp new file mode 100644 index 000000000..30211da08 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp @@ -0,0 +1,436 @@ +#include +#include + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUOPTIMIZETHREADLOCALITY +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { +// Change the destination layout of reshape ops allowing reorder when used by a +// reduction in order to minimize the amount of cross thread communication for +// the reduction. +struct OptimizeReshapeLayoutPattern + : public mlir::OpRewritePattern { + OptimizeReshapeLayoutPattern(mlir::MLIRContext *context) + : OpRewritePattern(context, 1) {} + + mlir::LogicalResult + matchAndRewrite(triton::ReshapeOp viewOp, + mlir::PatternRewriter &rewriter) const override { + if (!viewOp.getAllowReorder()) + return failure(); + std::optional reductionAxis; + for (Operation *user : viewOp.getResult().getUsers()) { + if (auto reduceOp = dyn_cast(user)) { + if (reductionAxis) { + if (reductionAxis != reduceOp.getAxis()) + return failure(); + } else { + reductionAxis = reduceOp.getAxis(); + } + } + } + if (!reductionAxis) + return failure(); + RankedTensorType tensorType = viewOp.getType(); + if (auto blocked = mlir::dyn_cast( + tensorType.getEncoding())) { + // If the layout already has all the elements along the reduction + // dimension in the same thread we can skip. + if (blocked.getThreadsPerWarp()[*reductionAxis] == 1 && + blocked.getWarpsPerCTA()[*reductionAxis] == 1 && + blocked.getCTAsPerCGA()[*reductionAxis] == 1) + return failure(); + } + ArrayRef shape = tensorType.getShape(); + llvm::SmallVector order; + for (int i : triton::gpu::getOrder(tensorType.getEncoding())) { + if (i != *reductionAxis) + order.push_back(i); + } + // Make the reduction axis last so that elements won't be distributed + // amongst threads along this dimension. + order.push_back(*reductionAxis); + llvm::SmallVector sizePerThread(shape.size(), 1); + auto mod = viewOp->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); + triton::gpu::BlockedEncodingAttr encoding = + triton::gpu::BlockedEncodingAttr::get(viewOp.getContext(), shape, + sizePerThread, order, numWarps, + threadsPerWarp, numCTAs); + if (encoding == tensorType.getEncoding()) + return failure(); + RankedTensorType newType = + RankedTensorType::get(shape, tensorType.getElementType(), encoding); + if (triton::gpu::isExpensiveView(viewOp.getSrc().getType(), newType)) + return failure(); + rewriter.setInsertionPointAfter(viewOp); + rewriter.modifyOpInPlace(viewOp, [&]() { + viewOp.getResult().setType(newType); + viewOp.setEfficientLayout(true); + }); + auto cvt = rewriter.create( + viewOp.getLoc(), tensorType, viewOp.getResult()); + rewriter.replaceAllUsesExcept(viewOp.getResult(), cvt.getResult(), cvt); + return mlir::success(); + } +}; + +} // namespace + +class TritonGPUOptimizeThreadLocalityPass + : public impl::TritonGPUOptimizeThreadLocalityBase< + TritonGPUOptimizeThreadLocalityPass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + + // First try to optimize the layout of existing views. + mlir::RewritePatternSet viewLayoutPatterns(&getContext()); + viewLayoutPatterns.add(&getContext()); + if (mlir::applyPatternsAndFoldGreedily(mod, std::move(viewLayoutPatterns)) + .failed()) { + signalPassFailure(); + } + + DenseSet reduceOps; + mod.walk([&](triton::ReduceOp reduce) -> void { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + auto srcEncoding = srcType.getEncoding(); + auto reductionOp = getReductionOp(reduce); + if (!reductionOp || + !isa( + reductionOp.value())) + return; + // TODO: relax this restriction + if (!(isa(srcEncoding) && rank > 1)) + return; + for (auto operand : reduce->getOperands()) { + auto def = operand.getDefiningOp(); + if (!isa(def)) + return; + } + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + // Not worth applying this optimization if there is only one element per + // thread on the reduction axis + if (elemsPerThread == 1) + return; + if (!reduce->hasOneUse()) + return; + Operation *user = *(reduce->getUsers().begin()); + if (!user->hasOneUse()) + return; + OpOperand &yieldOpOperand = *(user->getUses().begin()); + auto yieldOp = dyn_cast(yieldOpOperand.getOwner()); + if (!yieldOp) + return; + auto operandNumber = yieldOpOperand.getOperandNumber(); + Block *block = reduce->getBlock(); + Operation *parentOp = block->getParentOp(); + auto forOp = dyn_cast(parentOp); + if (!forOp) + return; + auto argNum = yieldOpOperand.getOperandNumber(); + auto oldAccum = forOp.getInitArgs()[argNum]; + auto cstOp = dyn_cast(oldAccum.getDefiningOp()); + if (!cstOp) + return; + reduceOps.insert(reduce); + }); + + IRRewriter builder(&getContext()); + for (auto reduce : reduceOps) { + builder.setInsertionPoint(reduce); + auto srcType = cast(reduce.getOperands()[0].getType()); + auto srcShape = srcType.getShape(); + auto srcEncoding = srcType.getEncoding(); + assert(isa(srcEncoding) && + "Thread locality optimization only supports blocked encoding"); + auto blocked = dyn_cast(srcEncoding); + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + auto rank = srcShape.size(); + // create new layouts + auto blocked3d = getThreadLocalityOptimizedEncoding(reduce); + auto viewOpTensorShape = getThreadLocalityOptimizedShape(reduce); + auto viewOpTensorType = RankedTensorType::get( + viewOpTensorShape, srcType.getElementType(), blocked3d); + auto slice2d = triton::gpu::SliceEncodingAttr::get(mod.getContext(), rank, + blocked3d); + // Get forOp + assert(reduce->hasOneUse()); + OpOperand &use = *(reduce->getUses().begin()); + auto operandNumber = use.getOperandNumber(); + auto oldUpdate = use.getOwner(); + assert(oldUpdate->getNumOperands() == 2); + auto accumOperandNumber = (operandNumber == 0) ? 1 : 0; + auto accumOperand = oldUpdate->getOperand(accumOperandNumber); + assert(isa(accumOperand)); + auto blockArg = dyn_cast(accumOperand); + auto blockArgNum = blockArg.getArgNumber(); + auto forOp = dyn_cast(blockArg.getOwner()->getParentOp()); + // get oldAccum + auto oldAccum = + forOp.getInitArgs()[blockArgNum - forOp.getNumInductionVars()]; + // get old loop user + Value loopResult = + forOp.getResult(blockArgNum - forOp.getNumInductionVars()); + assert(loopResult.hasOneUse()); + OpOperand &loopUse = *(loopResult.getUses().begin()); + Operation *loopUser = loopUse.getOwner(); + // get old loop yield + auto oldYield = cast(forOp.getBody()->getTerminator()); + // create newAccum initialization + auto newAccum = + createAccum(builder, reduce, oldAccum, viewOpTensorShape, slice2d); + // create new loop by copying the old for op signature and appending + // newAccum to the block arguments + auto newLoop = replaceForOpWithNewSignature( + builder, forOp, ValueRange{newAccum->getResult(0)}); + // create thread local reduction (also adds viewOps) + auto newReduce = createReduce(builder, reduce, viewOpTensorType); + + // create new accum update + auto newUpdate = createUpdate(builder, newLoop, newReduce, oldUpdate); + // create new yield + auto newYield = createYield(builder, newLoop, oldYield, + newUpdate->getResult(0), blockArgNum); + // create post loop reduction on the original reduce axis + auto newReduce2 = createPostLoopReduce(builder, newLoop, reduce); + // add convert_layout to get back to original layout, the result layout + // should now match the layout of the old accumulator (%cst) + Type destType = loopResult.getType(); + auto cvtLayout = createConvertLayout(builder, destType, newReduce2); + // incorporate the original accumulator value into the final result + auto finalOp = incorporateOriginalAccumulatorValue(builder, oldUpdate, + cvtLayout, oldAccum); + // Replace the old loop user with the final result + loopUser->setOperand(loopUse.getOperandNumber(), finalOp->getResult(0)); + + // cleanup + oldYield.erase(); + forOp.erase(); + } + }; + +private: + std::optional getReductionOp(triton::ReduceOp reduce) const { + auto numRegions = reduce->getNumRegions(); + if (numRegions != 1) + return std::nullopt; + Region ®ion = reduce->getRegion(0); + auto numBlocks = region.getBlocks().size(); + if (numBlocks != 1) + return std::nullopt; + Block &block = region.front(); + auto blockWithoutTerminator = block.without_terminator(); + auto blockSizeWithoutTerminator = std::distance( + blockWithoutTerminator.begin(), blockWithoutTerminator.end()); + if (blockSizeWithoutTerminator != 1) + return std::nullopt; + Operation *op = &block.front(); + return std::optional(op); + } + Operation *incorporateOriginalAccumulatorValue(OpBuilder &builder, + Operation *oldUpdate, + Operation *cvtLayout, + Value oldAccum) const { + builder.setInsertionPointAfter(cvtLayout); + IRMapping mapping; + mapping.map(oldUpdate->getOperand(0), oldAccum); + mapping.map(oldUpdate->getOperand(1), cvtLayout->getResult(0)); + auto finalOp = cloneWithInferType(builder, &(*oldUpdate), mapping); + return finalOp; + } + Operation *createConvertLayout(OpBuilder &builder, Type destType, + Operation *newReduce) const { + builder.setInsertionPointAfter(newReduce); + auto newCvt = builder.create( + newReduce->getLoc(), destType, newReduce->getResult(0)); + return newCvt; + } + + Operation *createPostLoopReduce(OpBuilder &builder, scf::ForOp &loop, + triton::ReduceOp &reduce) const { + auto resultIndex = + loop.getBody()->getNumArguments() - 1 - loop.getNumInductionVars(); + auto newLoopResult = loop.getResult(resultIndex); + builder.setInsertionPointAfter(loop); + IRMapping mapping; + mapping.map(*(reduce.getOperands().begin()), newLoopResult); + auto newReduce2 = cloneWithInferType(builder, &(*reduce), mapping); + return newReduce2; + } + + Operation *createYield(OpBuilder &builder, scf::ForOp &loop, + scf::YieldOp &oldYield, Value newUpdate, + int oldAccumBlockArgNum) const { + builder.setInsertionPoint(oldYield); + SmallVector yieldValues = llvm::to_vector(oldYield.getOperands()); + yieldValues[oldAccumBlockArgNum - 1] = + loop.getBody()->getArgument(oldAccumBlockArgNum); + yieldValues.push_back(newUpdate); + auto newYield = + builder.create(oldYield.getLoc(), yieldValues); + return newYield; + } + + Operation *createUpdate(OpBuilder &builder, scf::ForOp &loop, + Operation *newReduce, Operation *oldUpdate) const { + auto blockArgNum = loop.getBody()->getNumArguments() - 1; + auto newArg = loop.getBody()->getArgument(blockArgNum); + builder.setInsertionPointAfter(newReduce); + IRMapping mapping; + mapping.map(oldUpdate->getOperand(0), newArg); + mapping.map(oldUpdate->getOperand(1), newReduce->getResult(0)); + auto newUpdate = cloneWithInferType(builder, oldUpdate, mapping); + return newUpdate; + } + + Operation *createReduce(OpBuilder &builder, triton::ReduceOp reduce, + Type viewOpTensorType) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + builder.setInsertionPointAfter(reduce); + IRMapping mapping; + for (auto operand : reduce.getOperands()) { + auto viewOp = builder.create( + reduce.getLoc(), viewOpTensorType, operand, /*allowReorder=*/true); + viewOp.setEfficientLayout(true); + mapping.map(operand, viewOp); + } + + auto newReduce = cloneWithInferType(builder, &(*reduce), mapping); + newReduce->setAttr("axis", builder.getI32IntegerAttr(rank)); + auto typeInfer = dyn_cast(newReduce); + if (typeInfer) { + SmallVector newTypes; + auto success = typeInfer.inferReturnTypes( + newReduce->getContext(), newReduce->getLoc(), + newReduce->getOperands(), newReduce->getAttrDictionary(), + newReduce->getPropertiesStorage(), newReduce->getRegions(), newTypes); + if (succeeded(success)) { + for (size_t i = 0; i < newTypes.size(); i++) + newReduce->getResult(i).setType(newTypes[i]); + } + } + return newReduce; + } + + // Work around the lack of support for MaxNumFOp and MinNumFOp in + // arith::getNeutralElement. + std::optional getNeutralElement(Operation *op) const { + if (isa(op)) { + OpBuilder builder(op->getContext()); + + Type resultType = op->getResult(0).getType(); + const llvm::fltSemantics &semantic = + llvm::cast(resultType).getFloatSemantics(); + if (isa(op)) { + return builder.getFloatAttr( + resultType, APFloat::getInf(semantic, /*Negative=*/true)); + } + if (isa(op)) { + return builder.getFloatAttr( + resultType, APFloat::getInf(semantic, /*Negative=*/false)); + } + } else { + return mlir::arith::getNeutralElement(op); + } + llvm_unreachable("Unhandled reduction op"); + return std::nullopt; + } + + Operation *createAccum(OpBuilder &builder, triton::ReduceOp reduce, + Value &oldAccum, SmallVector &shape, + Attribute &slice2d) const { + // Drop the last dimension (thread locality dimension) + SmallVector accumShape(shape.begin(), shape.end() - 1); + auto elemType = cast(oldAccum.getType()).getElementType(); + // Create tensor type for the new accumulator + auto accumType = RankedTensorType::get(accumShape, elemType, slice2d); + // Create new accumulator + builder.setInsertionPointAfter(oldAccum.getDefiningOp()); + auto reductionOp = getReductionOp(reduce); + assert(reductionOp && "Processing a reduce that is not supported!"); + auto neutralVal = getNeutralElement(reductionOp.value()); + assert(neutralVal && "Could not find neutral value for reduction op!"); + auto denseAttr = DenseElementsAttr::get(accumType, neutralVal.value()); + auto newAccum = builder.create(oldAccum.getLoc(), + accumType, denseAttr); + return newAccum; + } + + SmallVector + getThreadLocalityOptimizedShape(triton::ReduceOp reduce) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto srcShape = srcType.getShape(); + auto rank = srcShape.size(); + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + auto viewOpTensorShape = insertValue(srcShape, rank, 1); + viewOpTensorShape[reduce.getAxis()] /= elemsPerThread; + viewOpTensorShape[rank] = elemsPerThread; + return viewOpTensorShape; + } + + Attribute getThreadLocalityOptimizedEncoding(triton::ReduceOp reduce) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + auto srcEncoding = srcType.getEncoding(); + auto blocked = dyn_cast(srcEncoding); + auto sizePerThread3d = + insertValue(blocked.getSizePerThread(), rank, + blocked.getSizePerThread()[reduce.getAxis()]); + sizePerThread3d[reduce.getAxis()] = 1; + auto threadsPerWarp3d = insertValue(blocked.getThreadsPerWarp(), rank, 1); + auto warsPerCTA3d = insertValue(blocked.getWarpsPerCTA(), rank, 1); + auto order3d = insertValue(blocked.getOrder(), 0, rank); + auto ctasPerCGA3d = + insertValue(blocked.getCTALayout().getCTAsPerCGA(), rank, 1); + auto ctasSplitNum3d = + insertValue(blocked.getCTALayout().getCTASplitNum(), rank, 1); + auto ctaOrder3d = + insertValue(blocked.getCTALayout().getCTAOrder(), rank, rank); + auto ctaLayout3d = triton::gpu::CTALayoutAttr::get( + reduce.getContext(), ctasPerCGA3d, ctasSplitNum3d, ctaOrder3d); + auto blocked3d = triton::gpu::BlockedEncodingAttr::get( + reduce.getContext(), sizePerThread3d, threadsPerWarp3d, warsPerCTA3d, + order3d, ctaLayout3d); + return blocked3d; + } + + template + SmallVector insertValue(ArrayRef vec, unsigned index, int value) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + index, static_cast(value)); + return res; + } + template + SmallVector insertValue(const SmallVector &vec, unsigned index, + int value) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + index, static_cast(value)); + return res; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp new file mode 100644 index 000000000..c0b586d60 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -0,0 +1,91 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUREDUCEDATADUPLICATION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUReduceDataDuplicationPass + : public impl::TritonGPUReduceDataDuplicationBase< + TritonGPUReduceDataDuplicationPass> { +public: + void runOnOperation() override { + ModuleOp mod = getOperation(); + mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cast(cvtOp.getSrc().getType()); + auto dstType = cast(cvtOp.getType()); + auto srcEncoding = srcType.getEncoding(); + if (isa(srcEncoding)) + return; + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (!dstDotOp) + return; + if (auto srcMmaEncoding = + dyn_cast(srcEncoding)) { + + if (srcMmaEncoding.getVersionMajor() != 2 || + (srcMmaEncoding.getWarpsPerCTA()[1] == 1 && + dstDotOp.getParent() == srcMmaEncoding)) + return; + } + if (auto srcMfmaEncoding = + dyn_cast(srcEncoding)) { + + if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 && + srcMfmaEncoding.getIsTransposed() && + dstDotOp.getParent() == srcMfmaEncoding) + return; + } + auto srcOrder = triton::gpu::getOrder(srcEncoding); + auto rank = srcOrder.size(); + SmallVector sharedOrder; + if (rank == 3) { + // add all elements except the element that is zero + for (unsigned i = 0; i < rank; ++i) + if (srcOrder[i] != 0) + sharedOrder.emplace_back(srcOrder[i]); + sharedOrder.emplace_back(0); + } else { + sharedOrder = srcOrder; + } + auto tmpType = triton::MemDescType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::SharedEncodingAttr::get( + mod.getContext(), dstDotOp, srcType.getShape(), sharedOrder, + triton::gpu::getCTALayout(srcEncoding), + srcType.getElementType())); + auto tmp = builder.create( + cvtOp.getLoc(), tmpType, cvtOp.getSrc()); + auto newConvert = builder.create(cvtOp.getLoc(), + dstType, tmp); + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp new file mode 100644 index 000000000..99865df5d --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -0,0 +1,1321 @@ +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUREMOVELAYOUTCONVERSIONS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "tritongpu-remove-layout-conversions" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { + +// ----------------------------------------------------------------------------- +// +// ----------------------------------------------------------------------------- + +// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0)) +class ConvertDotConvert : public RewritePattern { +public: + ConvertDotConvert(MLIRContext *context) + : RewritePattern(ConvertLayoutOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto dstOp = cast(op); + auto dotOp = dstOp.getSrc().getDefiningOp(); + if (!dotOp) + return failure(); + if (std::distance(dstOp->user_begin(), dstOp->user_end()) != 1 || + std::distance(dotOp->user_begin(), dotOp->user_end()) != 1) + return failure(); + auto cvtOp = dotOp.getOperand(2).getDefiningOp(); + if (!cvtOp) + return failure(); + if (!cvtOp.getSrc().getDefiningOp()) + return failure(); + RankedTensorType dstTy = dstOp.getType(); + RankedTensorType srcTy = cvtOp.getSrc().getType(); + if (dstTy != srcTy) + return failure(); + + auto _0f = rewriter.create( + op->getLoc(), dstTy.getElementType(), + rewriter.getZeroAttr(dstTy.getElementType())); + auto _0 = rewriter.create(op->getLoc(), dotOp.getType(), _0f); + auto newDot = rewriter.create( + op->getLoc(), dotOp.getType(), dotOp.getOperand(0), dotOp.getOperand(1), + _0, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); + auto newCvt = rewriter.create(op->getLoc(), dstTy, + newDot.getResult()); + rewriter.replaceOpWithNewOp(op, newCvt, cvtOp.getSrc()); + return success(); + } +}; + +// The current algorithm works by analyzing the IR and doing a one-shot rewrite +// based on the analysis. The algorithm is as follows. +// +// 1. Find all the anchor ops. These are ops that have a layout we want to +// preserve. +// +// 2. For each anchor, propagate its layout to all its descendants. +// An op can have multiple ancestors that are anchors, so at this stage an op +// may have multiple layouts associated with it. +// +// 3. Resolve conflicts by deciding which of the multiple layouts the op should +// keep, inserting convert-layout ops to resolve conflicts. After this +// stage, each value has only one layout associated with it. +// +// 4. Rewrite the IR by walking the function in dominance order. Since we +// assume the IR is structured we just need to process the regions in the +// correct order. For each op, rewrite it using the layout decided by the +// analysis phase. +class LayoutPropagation { +public: + // Structure to keep track of the layout associated to a value. + struct LayoutInfo { + LayoutInfo(Attribute encoding) { encodings.insert(encoding); } + LayoutInfo() {} + llvm::SmallSetVector encodings; + }; + LayoutPropagation(FuncOp F) : funcOp(F) {} + // Find the anchor ops and set their layout in the data structure. + void initAnchorLayout(); + // Recursively Propagate the layout to all the users of the anchor ops until + // we reach a fix point. + void propagateLayout(); + // Add layouts given in `Info` to the uses of `value`. + SmallVector propagateToUsers(Value value, LayoutInfo &info); + // Set the encoding to all the values and fill out the values with new layout + // in `changed`. + void setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, Operation *op); + // Resolve cases where a value has multiple layouts associated to it. + void resolveConflicts(); + // Rewrite the IR for the full module. + void rewrite(); + // Rewrite the IR for a region. + void rewriteRegion(Region &R); + // Rewrite an op based on the layout picked by the analysis. + Operation *rewriteOp(Operation *op); + // Rewrite a for op based on the layout picked by the analysis. + Operation *rewriteForOp(scf::ForOp forOp); + Operation *rewriteWhileOp(scf::WhileOp whileOp); + Operation *rewriteIfOp(scf::IfOp ifOp); + void rewriteYieldOp(scf::YieldOp yieldOp); + void rewriteConditionOp(scf::ConditionOp conditionOp); + void rewriteReduceToScalar(Operation *reduceOp); + void rewriteAssertOp(AssertOp assertOp); + Operation *cloneElementwise(OpBuilder &rewriter, Operation *op, + Attribute encoding); + // Map the original value to the rewritten one. + void map(Value old, Value newV); + // Return the mapped value in the given encoding. This will insert a convert + // if the encoding is different than the encoding decided at resolve time. + Value getValueAs(Value value, Attribute encoding); + // Dump the current stage of layout information. + void dump(); + +private: + // map from value to layout information. + llvm::MapVector layouts; + // map of the values rewrite based on their encoding. + DenseMap, Value> rewriteMapping; + SetVector opToDelete; + FuncOp funcOp; +}; + +class LayoutRematerialization { +public: + LayoutRematerialization(FuncOp F) : funcOp(F) {} + // Map the original value to the remat'ed one. + void addRematValue(Value old, Attribute encoding, Value newV); + bool hasRematValue(Value value, Attribute encoding) { + return rematMapping.contains({value, encoding}); + } + // Return the remat'ed value in the given encoding. + Value getRematValue(Value value, Attribute encoding) { + auto it = rematMapping.find({value, encoding}); + assert(it != rematMapping.end()); + return it->second; + } + void cleanup(); + void backwardRematerialization(); + void backwardRematerialization(ConvertLayoutOp convertOp); + void hoistConvertOnTopOfExtOrBroadcast(); + void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp); + void rewriteSlice(SetVector &slice, DenseMap &layout, + ConvertLayoutOp convertOp, IRMapping &mapping); + void rewriteSlice(SetVector &slice, DenseMap &layout, + ConvertLayoutOp convertOp); + +private: + void updateRematMapping(SmallVector> &values); + // Existing tuples of (value, layout) that needs to be updated when recreating + // scf ops. This prevents keeping track of Values that have been delete when + // rewriting slices. + DenseMap mappedValues; + // map of the values remat based on encoding. + DenseMap, Value> rematMapping; + // DenseMap, Operation*> + SetVector opToDelete; + FuncOp funcOp; +}; + +void LayoutRematerialization::addRematValue(Value old, Attribute encoding, + Value newV) { + LDBG("addRematValue " << old << " encoding " << encoding << " " << newV); + rematMapping[{old, encoding}] = newV; + mappedValues[old] = encoding; +} + +// Remove unneeded values now that we are done with the rematMapping. +void LayoutRematerialization::cleanup() { + for (Operation *op : llvm::reverse(opToDelete)) + op->erase(); +} + +// Look ahead to at the transitive uses and see if there is a convert to mma +// operations. +bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { + SmallVector queue = {op->getResult(0)}; + SetVector forwardSlice; + llvm::SmallDenseSet seen; + while (!queue.empty()) { + Value currentValue = queue.back(); + queue.pop_back(); + getForwardSlice(currentValue, &forwardSlice); + for (Operation *op : forwardSlice) { + // HACK: Stop propagation if the ReduceOp is using mma layout but is + // producing tensor smaller than the layout we would like to propagate. + // This is to avoid stepping into the known bug. + if (isa(op)) { + auto tensorType = + dyn_cast(op->getOperand(0).getType()); + if (tensorType && + isa(tensorType.getEncoding())) { + auto mmaInstrShape = + cast(encoding).getInstrShape(); + if (tensorType.getShape()[tensorType.getRank() - 2] < + mmaInstrShape[0] || + tensorType.getShape()[tensorType.getRank() - 1] < + mmaInstrShape[1]) { + return false; + } + } + } + + if (auto convertOp = dyn_cast(op)) { + Attribute dstEncoding = convertOp.getType().getEncoding(); + if (auto mmaLayout = dyn_cast(dstEncoding)) + return (mmaLayout.getVersionMajor() > 1) ? true + : mmaLayout == encoding; + if (isa(dstEncoding)) + return true; + if (isa(dstEncoding)) { + if (auto mmaLayout = dyn_cast(encoding)) { + return mmaLayout.getVersionMajor() > 1; + } else { + assert((mlir::isa(encoding))); + return true; + } + } + } + bool isMMAV3 = + isa(encoding) && + cast(encoding).getVersionMajor() == 3; + if (isMMAV3 && (isa(op) || isa(op))) + return true; + auto yield = dyn_cast(op); + if (!yield) + continue; + if (auto ifOp = dyn_cast(yield->getParentOp())) { + for (OpOperand &operand : yield->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if (def && + (forwardSlice.count(def) || operand.get() == currentValue) && + (seen.insert(operand.get()).second == true)) + queue.push_back(ifOp.getResult(operand.getOperandNumber())); + } + } + auto forOp = dyn_cast(yield.getOperation()->getParentOp()); + if (!forOp) + continue; + for (OpOperand &operand : yield->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if (def && (forwardSlice.count(def) || operand.get() == currentValue) && + (seen.insert(operand.get()).second == true)) + queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber())); + } + } + } + return false; +} + +// Return true if the op is an op with a layout we don't want to change. We will +// propagate the layout starting from anchor ops. +bool isLayoutAnchor(Operation *op) { + if (isa(op)) + return isExpensiveLoadOrStore(op); + if (isa(op)) + return true; + + // Heuristic: Mark permuting reshape as a layout anchor. Its dst can be + // anything, so it stops forward-propagation of layouts. We rely on the + // backwards pass to fix it up if necessary. (If we didn't do this, then + // anything following the reshape won't be covered by the forward pass at + // all.) + if (auto reshape = dyn_cast(op)) + return reshape.getAllowReorder(); + + return false; +} + +void LayoutPropagation::initAnchorLayout() { + auto maybeAddAnchor = [&](Value v) { + if (auto tensorType = dyn_cast(v.getType())) { + // Workaround, don't popagate MMA layout unless there is a convert + // back to mma further down to avoid generating reduction with MMA + // layout that may have lower performance. + // This can be improved with more aggressive backward propagation. + if (isa(tensorType.getEncoding()) && + v.getDefiningOp() && + !hasConvertToMMATransisitiveUse(v.getDefiningOp(), + tensorType.getEncoding())) { + return; + } + layouts.insert({v, LayoutInfo(tensorType.getEncoding())}); + } + }; + + // Consider function args as anchors. This makes it easier to write tests -- + // you can pass a tensor with an encoding as an arg, instead of explicitly + // calling tt.load. + for (auto arg : funcOp.getArguments()) { + maybeAddAnchor(arg); + } + + funcOp.walk([&](Operation *op) { + if (isLayoutAnchor(op)) { + for (auto result : op->getResults()) { + maybeAddAnchor(result); + } + } + }); +} + +void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, + Operation *op) { + for (Value value : values) { + if (!isa(value.getType())) + continue; + bool hasChanged = false; + for (auto encoding : info.encodings) { + std::optional dstEncoding; + if (isa(op)) { + // Try to remove the convert by making the dst encoding match the source + // encoding. + dstEncoding = encoding; + } else { + dstEncoding = inferDstEncoding(op, encoding); + } + if (dstEncoding) + hasChanged |= layouts[value].encodings.insert(*dstEncoding); + } + if (hasChanged) + changed.push_back(value); + } +} + +SmallVector LayoutPropagation::propagateToUsers(Value value, + LayoutInfo &info) { + SmallVector changed; + for (OpOperand &use : value.getUses()) { + Operation *user = use.getOwner(); + if (auto forOp = dyn_cast(user)) { + Value arg = forOp.getTiedLoopRegionIterArg(&use); + Value result = forOp.getTiedLoopResult(&use); + setEncoding({arg, result}, info, changed, user); + continue; + } + if (auto whileOp = dyn_cast(user)) { + Value arg = whileOp.getBeforeArguments()[use.getOperandNumber()]; + setEncoding({arg}, info, changed, user); + continue; + } + if (auto yieldOp = dyn_cast(user)) { + auto parent = yieldOp->getParentOp(); + SmallVector valuesToPropagate; + if (isa(parent)) + valuesToPropagate.push_back(parent->getResult(use.getOperandNumber())); + if (auto forOp = dyn_cast(parent)) + valuesToPropagate.push_back( + forOp.getRegionIterArg(use.getOperandNumber())); + if (auto whileOp = dyn_cast(parent)) { + valuesToPropagate.push_back( + whileOp.getBeforeArguments()[use.getOperandNumber()]); + valuesToPropagate.push_back( + whileOp->getOperand(use.getOperandNumber())); + } + if (isa(parent)) + setEncoding(valuesToPropagate, info, changed, user); + continue; + } + if (auto conditionOp = dyn_cast(user)) { + auto whileOp = cast(conditionOp->getParentOp()); + // Skip arg 0 as it is the condition. + unsigned argIndex = use.getOperandNumber() - 1; + Value afterArg = whileOp.getAfterArguments()[argIndex]; + Value result = whileOp->getResult(argIndex); + setEncoding({afterArg, result}, info, changed, user); + continue; + } + if (user->hasTrait() || + user->hasTrait() || + isa(user)) { + setEncoding(user->getResults(), info, changed, user); + continue; + } + } + return changed; +} + +void LayoutPropagation::propagateLayout() { + SmallVector queue; + for (auto it : layouts) { + queue.push_back(it.first); + } + while (!queue.empty()) { + Value currentValue = queue.back(); + LayoutInfo info = layouts[currentValue]; + queue.pop_back(); + SmallVector changed = propagateToUsers(currentValue, info); + + LLVM_DEBUG({ + DBGS() << "propagateLayout considering " << currentValue << ", which has " + << info.encodings.size() << " candidate encoding(s):\n"; + for (Attribute encoding : info.encodings) + DBGS() << " " << encoding << "\n"; + }); + + queue.insert(queue.end(), changed.begin(), changed.end()); + } +} + +void LayoutPropagation::resolveConflicts() { + for (auto &it : layouts) { + Operation *op = it.first.getDefiningOp(); + LayoutInfo &info = it.second; + if (info.encodings.size() <= 1) + continue; + // Hacky resolve, prefer block encoding. + // TODO: add a proper heuristic. + Attribute encoding = *info.encodings.begin(); + bool isLoadOrStore = + op && isa(op); + for (Attribute e : info.encodings) { + if ((isLoadOrStore && isa(e)) || + (!isLoadOrStore && isa(e))) { + encoding = e; + break; + } + } + info.encodings.clear(); + info.encodings.insert(encoding); + } +} + +void LayoutPropagation::dump() { + for (auto it : layouts) { + llvm::errs() << "Value: "; + OpPrintingFlags flags; + flags.skipRegions(); + it.first.print(llvm::errs(), flags); + llvm::errs() << " \n encoding:\n"; + for (auto encoding : it.second.encodings) { + encoding.print(llvm::errs()); + llvm::errs() << "\n"; + } + llvm::errs() << "--\n"; + } +} + +void LayoutPropagation::rewrite() { rewriteRegion(funcOp->getRegion(0)); } + +bool reduceToScalar(Operation *op) { + // For reductions returning a scalar we can change the src encoding without + // affecting the output. + return isa(op) && !isa(op->getResultTypes()[0]); +} + +void LayoutPropagation::rewriteRegion(Region ®ion) { + SmallVector queue = {®ion}; + while (!queue.empty()) { + Region *currentRegion = queue.back(); + queue.pop_back(); + for (Operation &op : currentRegion->getOps()) { + bool needRewrite = false; + SmallVector results = op.getResults(); + for (Value result : results) { + auto it = layouts.find(result); + // If we haven't mapped this value skip. + if (it == layouts.end()) + continue; + LayoutInfo &info = it->second; + assert(info.encodings.size() == 1 && + "we should have resolved to a single encoding"); + auto encoding = cast(result.getType()).getEncoding(); + // If the encoding is already what we want skip. + if (encoding == *info.encodings.begin()) + continue; + needRewrite = true; + } + if (needRewrite) { + Operation *newOp = rewriteOp(&op); + for (Region &R : newOp->getRegions()) + queue.push_back(&R); + } else if (auto yieldOp = dyn_cast(&op)) { + rewriteYieldOp(yieldOp); + } else if (auto conditionOp = dyn_cast(&op)) { + rewriteConditionOp(conditionOp); + } else if (reduceToScalar(&op)) { + rewriteReduceToScalar(&op); + } else if (auto assertOp = dyn_cast(&op)) { + rewriteAssertOp(assertOp); + } else { + // If we don't need to rewrite the op we still need to remap the + // operands. + for (OpOperand &operand : op.getOpOperands()) { + auto it = layouts.find(operand.get()); + if (it == layouts.end()) + continue; + Attribute encoding = + cast(operand.get().getType()).getEncoding(); + Value newOperand = getValueAs(operand.get(), encoding); + op.setOperand(operand.getOperandNumber(), newOperand); + } + for (Region &R : op.getRegions()) + queue.push_back(&R); + } + } + } + for (Operation *op : llvm::reverse(opToDelete)) + op->erase(); +} + +void LayoutPropagation::map(Value old, Value newV) { + rewriteMapping[{old, cast(newV.getType()).getEncoding()}] = + newV; +} + +Value LayoutPropagation::getValueAs(Value value, Attribute encoding) { + if (auto tensorType = dyn_cast(value.getType())) { + Value rewrittenValue; + auto layoutIt = layouts.find(value); + if (layoutIt == layouts.end()) { + rewrittenValue = value; + } else { + assert(layoutIt->second.encodings.size() == 1 && + "we should have resolved to a single encoding"); + Attribute encodingPicked = *(layoutIt->second.encodings.begin()); + if (encodingPicked == tensorType.getEncoding()) + rewrittenValue = value; + else + rewrittenValue = rewriteMapping[{value, encodingPicked}]; + } + assert(rewrittenValue); + if (cast(rewrittenValue.getType()).getEncoding() == + encoding) + return rewrittenValue; + OpBuilder rewriter(value.getContext()); + rewriter.setInsertionPointAfterValue(rewrittenValue); + auto tmpType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + Value converted = rewriter.create(value.getLoc(), tmpType, + rewrittenValue); + // TODO: we could cache the conversion. + return converted; + } + return value; +} + +Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter, + Operation *op, + Attribute encoding) { + Operation *newOp = rewriter.clone(*op); + + std::optional operandEnc; + if (op->getNumOperands() > 0) { + operandEnc = inferSrcEncoding(op, encoding); + assert(operandEnc.has_value()); + } + + for (OpOperand &operand : op->getOpOperands()) { + newOp->setOperand(operand.getOperandNumber(), + getValueAs(operand.get(), *operandEnc)); + } + + for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) { + auto origType = dyn_cast(op->getResult(i).getType()); + if (!origType) + continue; + auto newType = RankedTensorType::get(origType.getShape(), + origType.getElementType(), encoding); + newOp->getResult(i).setType(newType); + } + return newOp; +} + +Operation *LayoutPropagation::rewriteForOp(scf::ForOp forOp) { + SmallVector operands; + OpBuilder rewriter(forOp); + for (auto [operand, result] : + llvm::zip(forOp.getInitArgs(), forOp.getResults())) { + Value convertedOperand = operand; + if (layouts.count(result)) + convertedOperand = + getValueAs(operand, *layouts[result].encodings.begin()); + operands.push_back(convertedOperand); + } + auto newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), operands); + newForOp->setAttrs(forOp->getAttrs()); + newForOp.getBody()->getOperations().splice( + newForOp.getBody()->getOperations().begin(), + forOp.getBody()->getOperations()); + + for (auto [oldResult, newResult] : + llvm::zip(forOp.getResults(), newForOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + + for (auto [oldArg, newArg] : llvm::zip(forOp.getBody()->getArguments(), + newForOp.getBody()->getArguments())) { + if (oldArg.getType() == newArg.getType()) { + oldArg.replaceAllUsesWith(newArg); + continue; + } + map(oldArg, newArg); + } + return newForOp.getOperation(); +} + +Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) { + SmallVector operands; + SmallVector returnTypes; + OpBuilder rewriter(whileOp); + for (auto [operand, arg] : + llvm::zip(whileOp->getOperands(), whileOp.getBeforeArguments())) { + Value convertedOperand = operand; + if (layouts.count(arg)) + convertedOperand = getValueAs(operand, *layouts[arg].encodings.begin()); + operands.push_back(convertedOperand); + } + for (Value ret : whileOp.getResults()) { + auto it = layouts.find(ret); + if (it == layouts.end()) { + returnTypes.push_back(ret.getType()); + continue; + } + auto origType = dyn_cast(ret.getType()); + auto newType = + RankedTensorType::get(origType.getShape(), origType.getElementType(), + it->second.encodings[0]); + returnTypes.push_back(newType); + } + + auto newWhileOp = + rewriter.create(whileOp.getLoc(), returnTypes, operands); + SmallVector argsTypesBefore; + for (Value operand : operands) + argsTypesBefore.push_back(operand.getType()); + SmallVector bbArgLocsBefore(argsTypesBefore.size(), + whileOp.getLoc()); + SmallVector bbArgLocsAfter(returnTypes.size(), whileOp.getLoc()); + rewriter.createBlock(&newWhileOp.getBefore(), {}, argsTypesBefore, + bbArgLocsBefore); + rewriter.createBlock(&newWhileOp.getAfter(), {}, returnTypes, bbArgLocsAfter); + + for (int i = 0; i < whileOp.getNumRegions(); ++i) { + newWhileOp->getRegion(i).front().getOperations().splice( + newWhileOp->getRegion(i).front().getOperations().begin(), + whileOp->getRegion(i).front().getOperations()); + } + + auto remapArg = [&](Value oldVal, Value newVal) { + if (oldVal.getType() == newVal.getType()) + oldVal.replaceAllUsesWith(newVal); + else + map(oldVal, newVal); + }; + for (auto [oldResult, newResult] : + llvm::zip(whileOp.getResults(), newWhileOp.getResults())) + remapArg(oldResult, newResult); + for (auto [oldArg, newArg] : + llvm::zip(whileOp.getBeforeArguments(), newWhileOp.getBeforeArguments())) + remapArg(oldArg, newArg); + for (auto [oldArg, newArg] : + llvm::zip(whileOp.getAfterArguments(), newWhileOp.getAfterArguments())) + remapArg(oldArg, newArg); + return newWhileOp.getOperation(); +} + +Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) { + SmallVector operands; + OpBuilder rewriter(ifOp); + SmallVector newResultTypes(ifOp->getResultTypes()); + for (unsigned i = 0, e = ifOp->getNumResults(); i < e; ++i) { + auto it = layouts.find(ifOp->getResult(i)); + if (it == layouts.end()) + continue; + auto origType = cast(ifOp->getResult(i).getType()); + Attribute encoding = *(it->second.encodings.begin()); + newResultTypes[i] = RankedTensorType::get( + origType.getShape(), origType.getElementType(), encoding); + } + auto newIfOp = rewriter.create(ifOp.getLoc(), newResultTypes, + ifOp.getCondition(), true, true); + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + for (auto [oldResult, newResult] : + llvm::zip(ifOp.getResults(), newIfOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + return newIfOp.getOperation(); +} + +void LayoutPropagation::rewriteYieldOp(scf::YieldOp yieldOp) { + Operation *parentOp = yieldOp->getParentOp(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + Type yieldType = operand.get().getType(); + if (isa(parentOp)) + yieldType = parentOp->getResult(operand.getOperandNumber()).getType(); + if (auto whileOp = dyn_cast(parentOp)) + yieldType = + whileOp.getBeforeArguments()[operand.getOperandNumber()].getType(); + auto tensorType = dyn_cast(yieldType); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + yieldOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) { + scf::WhileOp whileOp = cast(conditionOp->getParentOp()); + for (unsigned i = 1; i < conditionOp->getNumOperands(); ++i) { + OpOperand &operand = conditionOp->getOpOperand(i); + Type argType = whileOp->getResult(operand.getOperandNumber() - 1).getType(); + auto tensorType = dyn_cast(argType); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + conditionOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteReduceToScalar(Operation *reduceOp) { + OpBuilder rewriter(reduceOp); + Attribute srcEncoding; + // Since all the operands need to have the same encoding pick the first one + // and use it for all the operands. + for (Value operand : reduceOp->getOperands()) { + auto it = layouts.find(operand); + if (it != layouts.end()) { + srcEncoding = it->second.encodings[0]; + break; + } + } + if (!srcEncoding) + return; + for (OpOperand &operand : reduceOp->getOpOperands()) { + Value newOperand = getValueAs(operand.get(), srcEncoding); + reduceOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteAssertOp(AssertOp assertOp) { + Attribute srcEncoding; + // Only need to deal with the first operand which is the condition tensor. + Value operand = assertOp->getOperand(0); + auto it = layouts.find(operand); + if (it == layouts.end()) + return; + srcEncoding = it->second.encodings[0]; + Value newOperand = getValueAs(operand, srcEncoding); + assertOp->setOperand(0, newOperand); +} + +Operation *LayoutPropagation::rewriteOp(Operation *op) { + opToDelete.insert(op); + if (auto forOp = dyn_cast(op)) + return rewriteForOp(forOp); + if (auto whileOp = dyn_cast(op)) + return rewriteWhileOp(whileOp); + if (auto ifOp = dyn_cast(op)) + return rewriteIfOp(ifOp); + OpBuilder rewriter(op); + Attribute encoding = *layouts[op->getResult(0)].encodings.begin(); + if (auto convertOp = dyn_cast(op)) { + Attribute srcEncoding = convertOp.getSrc().getType().getEncoding(); + auto it = layouts.find(convertOp.getSrc()); + if (it != layouts.end()) + srcEncoding = *(it->second.encodings.begin()); + Value src = getValueAs(convertOp.getSrc(), srcEncoding); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + auto cvt = rewriter.create(op->getLoc(), newType, src); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (canFoldIntoConversion(op, encoding)) { + Operation *newOp = rewriter.clone(*op); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + auto cvt = rewriter.create(op->getLoc(), newType, + newOp->getResult(0)); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (op->hasTrait() || + op->hasTrait() || + isa(op)) { + Operation *newOp = cloneElementwise(rewriter, op, encoding); + for (auto [oldResult, newResult] : + llvm::zip(op->getResults(), newOp->getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + return newOp; + } + llvm::report_fatal_error("unexpected op in rewrite"); + return nullptr; +} + +bool canBeRemat(Operation *op) { + if (isa(op)) + return !isExpensiveLoadOrStore(op); + if (isa(op)) + return false; + if (isa(op)) + return false; + + return true; +} + +void LayoutRematerialization::updateRematMapping( + SmallVector> &values) { + for (auto [old, newV] : values) { + auto it = mappedValues.find(old); + if (it != mappedValues.end()) { + Attribute encoding = it->second; + auto rematIt = rematMapping.find({old, it->second}); + assert(rematIt != rematMapping.end()); + Value replacedValue = rematIt->second; + rematMapping.erase(rematIt); + mappedValues.erase(it); + // Loop through the replacement value to find the new version of remat + // value. This should be okay as the number of values should be small. + for (auto [before, after] : values) { + if (before == replacedValue) { + replacedValue = after; + break; + } + } + rematMapping[{newV, encoding}] = replacedValue; + mappedValues[newV] = encoding; + } + } +} + +void LayoutRematerialization::rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp, + IRMapping &mapping) { + SetVector opsToRewrite; + // Keep track of yield operands that need to be duplicated. + DenseMap> yieldOperandsMap; + // Keep these around to remove them from the slice after our collection pass + // This ensures we don't duplicate them during an for rewrite or causing the + // for/yield to fall out of sync + SetVector valuesWithExistingRemat; + for (Value v : slice) { + auto layoutIt = layout.find(v); + assert(layoutIt != layout.end()); + // If we already have a remat value for this value, use it. + if (hasRematValue(v, layoutIt->second)) { + mapping.map(v, getRematValue(v, layoutIt->second)); + valuesWithExistingRemat.insert(v); + continue; + } + if (v.getDefiningOp()) { + opsToRewrite.insert(v.getDefiningOp()); + if (auto ifOp = v.getDefiningOp()) { + unsigned operandIdx = cast(v).getResultNumber(); + opsToRewrite.insert(ifOp.thenYield().getOperation()); + yieldOperandsMap[ifOp.thenYield()].push_back(operandIdx); + opsToRewrite.insert(ifOp.elseYield().getOperation()); + yieldOperandsMap[ifOp.elseYield()].push_back(operandIdx); + } + } else { + BlockArgument blockArg = cast(v); + Operation *parentOp = blockArg.getOwner()->getParentOp(); + if (auto loopOp = cast(parentOp)) { + opsToRewrite.insert(loopOp.getOperation()); + OpOperand *operand = loopOp.getTiedLoopYieldedValue(blockArg); + auto yieldOp = blockArg.getOwner()->getTerminator(); + yieldOperandsMap[yieldOp].push_back(operand->getOperandNumber()); + opsToRewrite.insert(yieldOp); + } + } + } + slice.set_subtract(valuesWithExistingRemat); + opsToRewrite = multiRootTopologicalSort(opsToRewrite); + + // replaceAllUsesWith calls delayed until after initial rewrite. + // This is required for slice.count(value) to work mid rewrite. + SmallVector> replacements; + + SmallVector deadOps; + IRRewriter builder(slice.begin()->getContext()); + for (Operation *op : opsToRewrite) { + if (auto forOp = dyn_cast(op)) { + // Keep a mapping of the operands index to the new operands index. + SmallVector> argMapping; + SmallVector newOperands; + for (auto arg : forOp.getRegionIterArgs()) { + if (slice.count(arg)) { + OpOperand &initVal = *forOp.getTiedLoopInit(arg); + argMapping.push_back(std::make_pair( + forOp.getTiedLoopResult(&initVal).getResultNumber(), + forOp.getInitArgs().size() + newOperands.size())); + newOperands.push_back(mapping.lookup(initVal.get())); + } + } + // Create a new for loop with the new operands. + scf::ForOp newForOp = replaceForOpWithNewSignature( + builder, forOp, newOperands, replacements); + deadOps.push_back(forOp.getOperation()); + Block &loopBody = *newForOp.getBody(); + for (auto m : argMapping) { + mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second)); + int numIndVars = newForOp.getNumInductionVars(); + mapping.map(loopBody.getArgument(m.first + numIndVars), + loopBody.getArgument(m.second + numIndVars)); + LLVM_DEBUG({ + DBGS() << "mapping forOp " + << loopBody.getArgument(m.first + numIndVars) << " to " + << loopBody.getArgument(m.second + numIndVars) << '\n'; + }); + // The result is not in the layout/slice, the argument is. + Value oldArg = loopBody.getArgument(m.first + numIndVars); + addRematValue(newForOp.getResult(m.first), layout[oldArg], + newForOp.getResult(m.second)); + addRematValue(oldArg, layout[oldArg], + loopBody.getArgument(m.second + numIndVars)); + } + continue; + } + if (auto ifOp = dyn_cast(op)) { + SmallVector newTypes; + for (auto res : ifOp.getResults()) { + if (slice.count(res)) { + auto it = layout.find(res); + assert(it != layout.end()); + + auto oldType = cast(res.getType()); + auto newType = RankedTensorType::get( + oldType.getShape(), oldType.getElementType(), it->second); + newTypes.push_back(newType); + } + } + scf::IfOp newIfOp = + replaceIfOpWithNewSignature(builder, ifOp, newTypes, replacements); + unsigned oldIdx = 0; + unsigned newIdx = ifOp.getNumResults(); + for (auto res : ifOp.getResults()) { + if (slice.count(res)) { + // Why can't we use res instead of ifOp.getResult(oldIdx)? + mapping.map(ifOp.getResult(oldIdx), newIfOp.getResult(newIdx)); + addRematValue(ifOp.getResult(oldIdx), layout[res], + newIfOp.getResult(newIdx)); + ++newIdx; + } + ++oldIdx; + } + deadOps.push_back(ifOp.getOperation()); + continue; + } + builder.setInsertionPoint(op); + if (auto yieldOp = dyn_cast(op)) { + auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); + SmallVector operandsToRewrite = yieldOperandsMap[op]; + // Sort so that operands are added in the same order as the new scf + // results/arguments. + std::sort(operandsToRewrite.begin(), operandsToRewrite.end()); + for (int operandIdx : operandsToRewrite) { + yieldOperands.push_back(mapping.lookup(yieldOp.getOperand(operandIdx))); + } + builder.create(op->getLoc(), yieldOperands); + op->erase(); + continue; + } + if (isa(op)) { + Operation *newOp = builder.clone(*op); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), + layout[op->getResult(0)]); + auto cvt = builder.create(op->getLoc(), newType, + newOp->getResult(0)); + mapping.map(op->getResult(0), cvt.getResult()); + addRematValue(op->getResult(0), layout[op->getResult(0)], + cvt.getResult()); + continue; + } + Operation *newOp = builder.clone(*op, mapping); + for (auto [old, newV] : llvm::zip(op->getResults(), newOp->getResults())) { + auto it = layout.find(old); + if (it == layout.end()) + continue; + auto newType = RankedTensorType::get( + cast(old.getType()).getShape(), + cast(old.getType()).getElementType(), it->second); + newV.setType(newType); + addRematValue(old, it->second, newV); + } + } + // Check mapping and see if there are existing convertOps on the old Argument + convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getSrc())); + opToDelete.insert(convertOp); + + updateRematMapping(replacements); + for (auto &kv : replacements) { + builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); + } + + for (Operation *op : deadOps) + opToDelete.insert(op); +} + +void LayoutRematerialization::rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp) { + IRMapping mapping; + rewriteSlice(slice, layout, convertOp, mapping); +} + +LogicalResult getRematerializableSlice( + Value root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation = nullptr) { + LogicalResult result = getConvertBackwardSlice(root, slice, rootEncoding, + layout, stopPropagation); + if (result.failed() || slice.empty()) + return failure(); + + // Check if all the operations in the slice can be rematerialized. + for (Value v : slice) { + if (Operation *op = v.getDefiningOp()) { + if (!canBeRemat(op)) + return failure(); + } + } + return success(); +} + +void LayoutRematerialization::backwardRematerialization() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + backwardRematerialization(convertOp); + } +} + +void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertOnTopOfExtOrBroadcast(convertOp); + } +} + +void LayoutRematerialization::backwardRematerialization( + ConvertLayoutOp convertOp) { + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristic to accommodate fused attention + RankedTensorType targetType = convertOp.getType(); + if (isa(targetType.getEncoding())) + return; + Value oldV = convertOp->getOperand(0); + LDBG("check backward remat with source " << oldV << " encoding " + << targetType.getEncoding()); + // Check to see if there are existing remat'ed values for the pair of oldValue + // and encoding. + if (hasRematValue(oldV, targetType.getEncoding())) { + // Replace it with the remat'ed value. + Value newV = getRematValue(oldV, targetType.getEncoding()); + convertOp.replaceAllUsesWith(newV); + opToDelete.insert(convertOp); + LDBG("found remat'ed value" << newV); + return; + } + + // 1. Take a backward slice of all the tensor dependencies that can be + // rematerialized. + SetVector slice; + DenseMap layout; + LogicalResult result = getRematerializableSlice( + convertOp.getSrc(), targetType.getEncoding(), slice, layout); + if (result.failed()) { + LDBG(" getRematerializableSlice failed"); + return; + } + + LLVM_DEBUG({ + DBGS() << " remat convert op " << convertOp << '\n'; + for (Value v : slice) + DBGS() << " " << v << '\n'; + }); + // 2. Rewrite the slice. + rewriteSlice(slice, layout, convertOp); +} + +// For convert left we try to hoist them above type extension to reduce the cost +// of the convert. +void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( + ConvertLayoutOp convertOp) { + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristics to accommodate fused attention + RankedTensorType targetType = convertOp.getType(); + if (mlir::isa(targetType.getEncoding())) + return; + + auto isExtOrBroadcastOp = [](Operation *op) { + if (isa(op)) { + return true; + } + if (auto fpToFpOp = dyn_cast(op)) { + auto srcType = cast(fpToFpOp.getOperand().getType()); + return getElementBitWidth(srcType) < + getElementBitWidth(fpToFpOp.getType()); + } + return false; + }; + // 1. Take a backward slice of all the tensor dependencies. + SetVector slice; + DenseMap layout; + LogicalResult result = + getRematerializableSlice(convertOp.getSrc(), targetType.getEncoding(), + slice, layout, isExtOrBroadcastOp); + if (result.failed()) + return; + + Operation *extOrBroadcatOp = nullptr; + unsigned sliceSize = slice.size(); + for (unsigned i = 0; i < sliceSize; i++) { + Value v = slice[i]; + Operation *op = v.getDefiningOp(); + if (!op) + continue; + if (isExtOrBroadcastOp(op)) { + SetVector tempSlice; + DenseMap tempLayout; + std::optional srcEncoding = inferSrcEncoding(op, layout[v]); + if (!srcEncoding) + return; + LogicalResult result = getRematerializableSlice( + op->getOperand(0), *srcEncoding, tempSlice, tempLayout); + // If we can rematerialize the rest of the ext slice we can ignore this + // ext as it won't need a convert. + if (result.succeeded()) { + slice.insert(tempSlice.begin(), tempSlice.end()); + layout.insert(tempLayout.begin(), tempLayout.end()); + continue; + } + // Only apply it if there is a single ext op otherwise we would have to + // duplicate the convert. + if (extOrBroadcatOp != nullptr) + return; + extOrBroadcatOp = op; + } + } + + if (extOrBroadcatOp == nullptr) + return; + Attribute dstEncoding = layout[extOrBroadcatOp->getResult(0)]; + std::optional srcEncoding = + inferSrcEncoding(extOrBroadcatOp, dstEncoding); + if (!srcEncoding) + return; + // Move the convert before the ext op and rewrite the slice. + OpBuilder builder(extOrBroadcatOp); + auto tensorType = + cast(extOrBroadcatOp->getOperand(0).getType()); + auto newType = RankedTensorType::get( + tensorType.getShape(), tensorType.getElementType(), *srcEncoding); + auto newConvertOp = builder.create( + convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0)); + Operation *newExtOrBroadcast = builder.clone(*extOrBroadcatOp); + newExtOrBroadcast->setOperand(0, newConvertOp.getResult()); + auto oldExtOrBroadcastType = + cast(extOrBroadcatOp->getResult(0).getType()); + Type newExtOrBroadcasrType = RankedTensorType::get( + oldExtOrBroadcastType.getShape(), oldExtOrBroadcastType.getElementType(), + dstEncoding); + newExtOrBroadcast->getResult(0).setType(newExtOrBroadcasrType); + IRMapping mapping; + mapping.map(extOrBroadcatOp->getResult(0), newExtOrBroadcast->getResult(0)); + slice.remove(extOrBroadcatOp->getResult(0)); + // 3. Rewrite the slice. + rewriteSlice(slice, layout, convertOp, mapping); +} + +void backwardRematerialization(ModuleOp module) { + module.walk([](FuncOp funcOp) { + LayoutRematerialization layoutRemat(funcOp); + layoutRemat.backwardRematerialization(); + layoutRemat.cleanup(); + }); +} + +void hoistConvert(ModuleOp module) { + SmallVector convertOps; + module.walk([](FuncOp funcOp) { + LayoutRematerialization layoutRemat(funcOp); + layoutRemat.hoistConvertOnTopOfExtOrBroadcast(); + layoutRemat.cleanup(); + }); +} +} // namespace + +class TritonGPURemoveLayoutConversionsPass + : public impl::TritonGPURemoveLayoutConversionsBase< + TritonGPURemoveLayoutConversionsPass> { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + // 1. Propagate layout forward starting from "anchor" ops. + m.walk([](FuncOp funcOp) { + LayoutPropagation layoutPropagation(funcOp); + layoutPropagation.initAnchorLayout(); + layoutPropagation.propagateLayout(); + layoutPropagation.resolveConflicts(); + layoutPropagation.rewrite(); + }); + + LLVM_DEBUG({ + DBGS() << "Module after propagating layouts forward:\n"; + m.dump(); + }); + + RewritePatternSet cleanUpPatterns(context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, context); + if (applyPatternsAndFoldGreedily(m, std::move(cleanUpPatterns)).failed()) { + signalPassFailure(); + } + + LLVM_DEBUG({ + DBGS() << "Module after canonicalizing:\n"; + m.dump(); + }); + + // 2. For remaining convert ops, try to rematerialize the slice of producer + // operation to avoid having to convert. + backwardRematerialization(m); + LLVM_DEBUG({ + DBGS() << "Module after backward remat:\n"; + m.dump(); + }); + + // 3. For remaining converts, try to hoist them above cast generating larger + // size types in order to reduce the cost of the convert op. + hoistConvert(m); + LLVM_DEBUG({ + DBGS() << "Module after hoisting converts:\n"; + m.dump(); + }); + + RewritePatternSet decomposePatterns(context); + decomposePatterns.add(context); + if (applyPatternsAndFoldGreedily(m, std::move(decomposePatterns)) + .failed()) { + signalPassFailure(); + } + LLVM_DEBUG({ + DBGS() << "Module after decomposing dot-converts:\n"; + m.dump(); + }); + + // 4. Apply clean up patterns to remove remove dead convert and dead code + // generated by the previous transformations. + RewritePatternSet cleanUpPatterns2(context); + populateForOpDeadArgumentElimination(cleanUpPatterns2); + scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + scf::IfOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + if (applyPatternsAndFoldGreedily(m, std::move(cleanUpPatterns2)).failed()) { + signalPassFailure(); + } + LLVM_DEBUG({ + DBGS() << "Module after final cleanups:\n"; + m.dump(); + }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp new file mode 100644 index 000000000..bff277c59 --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp @@ -0,0 +1,140 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUREORDERINSTRUCTIONS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static bool willIncreaseRegisterPressure(Operation *op) { + if (isa(op)) + return true; + auto cvt = dyn_cast(op); + if (!cvt) + return false; + if (mlir::isa( + cvt.getType().getEncoding())) + return true; + return false; +} + +class TritonGPUReorderInstructionsPass + : public impl::TritonGPUReorderInstructionsBase< + TritonGPUReorderInstructionsPass> { +public: + TritonGPUReorderInstructionsPass() = default; + + Operation *getFirstUse(Operation *op) { + std::vector users; + for (auto user : op->getUsers()) { + if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) + users.push_back(ancestor); + } + auto minOpIt = std::min_element(users.begin(), users.end(), + [](mlir::Operation *a, mlir::Operation *b) { + return a->isBeforeInBlock(b); + }); + return minOpIt != users.end() ? *minOpIt : nullptr; + } + + void runOnOperation() override { + ModuleOp m = getOperation(); + mlir::DominanceInfo dom(m); + // sink conversion after the last dealloc + // before the first use ancestor in its block + m.walk([&](triton::gpu::ConvertLayoutOp op) { + auto curr = mlir::Block::iterator(op); + for (; &*curr != getFirstUse(op); curr++) + if (isa(&*curr)) + op->moveAfter(&*curr); + }); + // Sink conversions into loops when they will increase + // register pressure + DenseMap opToMove; + auto moveAfter = [](Operation *lhs, Operation *rhs) { + lhs->moveAfter(rhs); + }; + m.walk([&](Operation *op) { + if (!willIncreaseRegisterPressure(op)) + return; + auto user_begin = op->user_begin(); + auto user_end = op->user_end(); + if (std::distance(user_begin, user_end) != 1) + return; + if (user_begin->getParentOfType() == + op->getParentOfType()) + return; + opToMove.insert({op, *user_begin}); + }); + for (auto &kv : opToMove) + kv.first->moveBefore(kv.second); + // Move alloc(load) immediately after dependent load + m.walk([&](triton::gpu::LocalAllocOp op) { + if (!op.getSrc()) + return; + Operation *argOp = op.getSrc().getDefiningOp(); + if (!argOp) + return; + moveAfter(op, argOp); + }); + // Move transpositions just after their definition + opToMove.clear(); + m.walk([&](triton::TransOp op) { + Operation *argOp = op.getSrc().getDefiningOp(); + if (!argOp) + return; + moveAfter(op, argOp); + }); + // Move `dot` operand so that conversions to opIdx=1 happens after + // conversions to opIdx=0 + m.walk([&](triton::gpu::LocalLoadOp op) { + auto dstEncoding = mlir::dyn_cast( + op.getType().getEncoding()); + if (!dstEncoding) + return; + int opIdx = dstEncoding.getOpIdx(); + if (opIdx != 1) + return; + if (!op->hasOneUse()) + return; + auto dotUser = dyn_cast(*op->user_begin()); + if (!dotUser) + return; + auto AOp = + dotUser.getOperand(0).getDefiningOp(); + if (!AOp) + return; + // Check that the conversion to OpIdx=1 happens before and can be moved + // after the conversion to OpIdx=0. + if (!dom.dominates(op.getOperation(), AOp.getOperation())) + return; + moveAfter(op, AOp); + }); + return; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Utility.cpp new file mode 100644 index 000000000..286b60d4d --- /dev/null +++ b/third_party/mthreads/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -0,0 +1,976 @@ +#include "triton/Analysis/Utility.h" + +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/Support/Debug.h" +#define DEBUG_TYPE "ttg-utility" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { + +using namespace triton; + +SmallVector mmaVersionToInstrShape(int version, + const ArrayRef &shape, + TensorOrMemDesc type, + int numWarps) { + if (version == 1) + return {16, 16}; + else if (version == 2) { + auto rank = shape.size(); + SmallVector ret(rank, 1); + ret[rank - 1] = 8; + ret[rank - 2] = 16; + return ret; + } else if (version == 3) { + unsigned k = 256 / type.getElementTypeBitWidth(); + if (shape[0] % 64 != 0 || shape[1] % 8 != 0) { + assert(false && "type not supported"); + return {0, 0, 0}; + } + auto eltType = type.getElementType(); + SmallVector validN; + + // MMAv3 with larger instruction shape is preferred. + if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FNUZ() || + eltType.isF16() || eltType.isBF16() || eltType.isF32()) { + validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, + 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, + 80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); + } + + if (eltType.isInteger(8)) { + validN.assign({224, 208, 192, 176, 160, 144, 128, 112, 96, 80, 64, 48, 32, + 24, 16, 8}); + } + + unsigned m = 16; + unsigned mWarps = std::max(shape[0] / m, 1); + unsigned nWarps = std::max(numWarps / mWarps, 1); + unsigned maxN = std::max(shape[1] / nWarps, 8); + for (auto n : validN) { + if (shape[1] % n == 0 && n <= maxN) { + return {m, n, k}; + } + } + + assert(false && "type not supported"); + return {0, 0, 0}; + } else { + assert(false && "version not supported"); + return {0, 0}; + } +} + +bool isLoadFromTensorPtr(triton::LoadOp op) { + return mlir::triton::isTensorPointerType(op.getPtr().getType()); +} + +SmallVector argSort(const SmallVector &arr) { + SmallVector ret(arr.size()); + std::iota(ret.begin(), ret.end(), 0); + std::stable_sort(ret.begin(), ret.end(), + [&](unsigned x, unsigned y) { return arr[x] > arr[y]; }); + return ret; +} + +Value getMemAccessPtr(Operation *op) { + if (auto ld = dyn_cast(op)) + return ld.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto copy = dyn_cast(op)) + return copy.getSrc(); + if (auto store = dyn_cast(op)) + return store.getPtr(); + return nullptr; +} + +unsigned getElementBitWidth(RankedTensorType type) { + auto typeForMem = + isa(type.getElementType()) + ? cast(type.getElementType()).getPointeeType() + : type.getElementType(); + return typeForMem.getIntOrFloatBitWidth(); +} + +unsigned getNumElementsPerThread(Operation *op, SmallVector order, + ModuleAxisInfoAnalysis &axisInfoAnalysis) { + Value val = getMemAccessPtr(op); + auto ty = cast(val.getType()); + auto shapePerCTA = triton::gpu::getShapePerCTA(ty); + AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); + unsigned elemNumBits = getElementBitWidth(ty); + unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); + unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]); + unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); + unsigned maxContig = + std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]); + unsigned alignment = std::min(maxMultiple, maxContig); + unsigned currPerThread = std::min(alignment, 128 / elemNumBits); + LDBG("elemNumBytes: " << elemNumBytes + << ", divisibility: " << maxMultipleBytes + << ", contig: " << valInfo.getContiguity(order[0]) + << ", alignment: " << alignment); + return currPerThread; +} + +//===----------------------------------------------------------------------===// +// GraphDumper +//===----------------------------------------------------------------------===// + +GraphDumper::NodeInfo GraphDumper::onValue(Value value) const { + return {{"shape", "box"}, {"style", "filled"}, {"fillcolor", "white"}}; +} + +GraphDumper::NodeInfo GraphDumper::onOperation(Operation *op) const { + return {{"shape", "ellipse"}, {"style", "filled"}, {"fillcolor", "white"}}; +} + +std::string GraphDumper::dump(triton::FuncOp func) const { + llvm::SetVector values; + llvm::SetVector operations; + + func.walk([&](Operation *op) { + operations.insert(op); + for (Value operand : op->getOperands()) + values.insert(operand); + for (Value result : op->getResults()) + values.insert(result); + }); + + std::ostringstream oss; + oss << "// Generated by Triton GraphDumper\n" + << "\n" + << "digraph {\n"; + + oss << " // Value Nodes\n"; + for (Value value : values) + oss << " " << emitValueNode(value) << "\n"; + oss << "\n"; + + oss << " // Operation Nodes\n"; + for (Operation *op : operations) + oss << " " << emitOperationNode(op) << "\n"; + oss << "\n"; + + oss << " // Edges\n"; + for (Operation *op : operations) { + for (Value operand : op->getOperands()) + oss << " " << emitEdge(getUniqueId(operand), getUniqueId(op)) << "\n"; + for (Value result : op->getResults()) + oss << " " << emitEdge(getUniqueId(op), getUniqueId(result)) << "\n"; + } + + oss << "}\n"; + return oss.str(); +} + +void GraphDumper::dumpToFile(triton::FuncOp func, + const std::string &filename) const { + std::ofstream ofs(filename); + ofs << dump(func); +} + +std::string GraphDumper::getShapeStr(const Type &type) const { + std::ostringstream oss; + oss << "["; + if (auto tensorTy = dyn_cast(type)) { + auto shape = tensorTy.getShape(); + for (unsigned i = 0; i < shape.size(); ++i) { + if (i > 0) + oss << ", "; + oss << shape[i]; + } + } + oss << "]"; + return oss.str(); +} + +std::string GraphDumper::getUniqueId(Value value) const { + std::ostringstream oss; + oss << value.getImpl(); + return oss.str(); +} + +std::string GraphDumper::getUniqueId(Operation *op) const { + std::ostringstream oss; + oss << op; + return oss.str(); +} + +std::string GraphDumper::emitNode(const std::string &id, + const GraphDumper::NodeInfo info) const { + std::ostringstream oss; + oss << "\"" << id << "\" ["; + for (auto it = info.begin(); it != info.end(); ++it) { + if (it != info.begin()) + oss << ", "; + oss << it->first << " = \"" << it->second << "\""; + } + oss << "];"; + return oss.str(); +} + +std::string GraphDumper::emitEdge(const std::string &srcId, + const std::string &destId) const { + std::ostringstream oss; + oss << "\"" << srcId << "\" -> \"" << destId << "\";"; + return oss.str(); +} + +std::string GraphDumper::emitValueNode(Value value) const { + NodeInfo info = onValue(value); + if (info.find("label") == info.end()) { + std::string shapeStr = getShapeStr(value.getType()); + if (auto arg = mlir::dyn_cast(value)) + info["label"] = + "BlockArg" + std::to_string(arg.getArgNumber()) + " " + shapeStr; + else + info["label"] = shapeStr; + } + return emitNode(getUniqueId(value), info); +} + +std::string GraphDumper::emitOperationNode(Operation *op) const { + NodeInfo info = onOperation(op); + if (info.find("label") == info.end()) + info["label"] = op->getName().getStringRef().str(); + return emitNode(getUniqueId(op), info); +} + +//===----------------------------------------------------------------------===// +// GraphLayoutMarker +//===----------------------------------------------------------------------===// + +GraphDumper::NodeInfo GraphLayoutMarker::onValue(Value value) const { + std::string color = getColor(value.getType()); + return {{"shape", "box"}, {"style", "filled"}, {"fillcolor", color}}; +} + +std::string GraphLayoutMarker::getColor(const Type &type) const { + if (auto tensorTy = dyn_cast(type)) { + auto layout = tensorTy.getEncoding(); + if (isa(layout)) + return "green"; + else if (isa(layout)) + return "yellow"; + else if (isa(layout)) + return "lightslateblue"; + else if (isa(layout)) + return "orange"; + else if (isa(layout)) + return "orangered"; + else { + llvm::report_fatal_error("Unrecognized layout"); + return "unknown"; + } + } else { + return "white"; + } +} +// -------------------------------------------------------------------------- // + +static std::optional inferDstEncoding(triton::ReduceOp op, + Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get(op->getContext(), op.getAxis(), + encoding); +} + +static std::optional inferDstEncoding(triton::ExpandDimsOp op, + Attribute encoding) { + auto sliceEncoding = mlir::dyn_cast(encoding); + if (!sliceEncoding) + return std::nullopt; + if (op.getAxis() != sliceEncoding.getDim()) + return std::nullopt; + return sliceEncoding.getParent(); +} + +static std::optional inferDstEncoding(JoinOp op, Attribute srcEnc) { + Attribute dstEnc; + if (srcEnc.getDialect() + .getRegisteredInterface() + ->inferJoinOpEncoding(srcEnc, dstEnc, + /*loc=*/std::nullopt) + .succeeded()) { + return dstEnc; + } + return std::nullopt; +} + +static std::optional inferDstEncoding(SplitOp op, Attribute srcEnc) { + Attribute dstEnc; + if (srcEnc.getDialect() + .getRegisteredInterface() + ->inferSplitOpEncoding(srcEnc, dstEnc, + /*loc=*/std::nullopt) + .succeeded()) { + return dstEnc; + } + return std::nullopt; +} + +static std::optional inferSrcEncoding(triton::ReduceOp op, + Attribute encoding) { + auto sliceEncoding = mlir::dyn_cast(encoding); + if (!sliceEncoding) + return std::nullopt; + if (op.getAxis() != sliceEncoding.getDim()) + return std::nullopt; + return sliceEncoding.getParent(); +} + +static std::optional inferSrcEncoding(triton::ExpandDimsOp op, + Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get(op->getContext(), op.getAxis(), + encoding); +} + +static std::optional inferSrcEncoding(JoinOp op, Attribute dstEnc) { + // Split is the inverse of join. + Attribute srcEnc; + if (dstEnc.getDialect() + .getRegisteredInterface() + ->inferSplitOpEncoding(dstEnc, srcEnc, /*loc=*/std::nullopt) + .succeeded()) { + return srcEnc; + } + return std::nullopt; +} + +static std::optional inferSrcEncoding(SplitOp op, Attribute dstEnc) { + // Join is the inverse of split. + Attribute srcEnc; + if (dstEnc.getDialect() + .getRegisteredInterface() + ->inferJoinOpEncoding(dstEnc, srcEnc, /*loc=*/std::nullopt) + .succeeded()) { + return srcEnc; + } + return std::nullopt; +} + +static std::optional +inferTransOpDstEncoding(Attribute srcEnc, ArrayRef order) { + // Simply forward to the existing inferTransOpEncoding function. + Attribute retEncoding; + if (succeeded( + srcEnc.getDialect() + .getRegisteredInterface() + ->inferTransOpEncoding(srcEnc, order, retEncoding))) { + return retEncoding; + } + return std::nullopt; +} + +static std::optional inferDstEncoding(triton::TransOp op, + Attribute encoding) { + return inferTransOpDstEncoding(encoding, op.getOrder()); +} + +static std::optional inferSrcEncoding(triton::TransOp op, + Attribute encoding) { + // We want to solve for srcEnc in + // transpose(srcEnc, order) -> dstEnc. + // Given the identity + // transpose(transpose(x, order), inverse(order)) == x, + // we can see this is equivalent to + // transpose(dstEnc, inverse(order)) -> srcEnc. + return inferTransOpDstEncoding(encoding, + triton::inversePermutation(op.getOrder())); +} + +static std::optional +inferReshapeOpDstEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, bool allowReorder) { + // We don't do anything smart to allow-reorder reshapes here. They are + // handled in OptimizeThreadLocality. + if (allowReorder) + return std::nullopt; + + Attribute dstEnc; + if (succeeded( + srcEnc.getDialect() + .getRegisteredInterface() + ->inferReshapeOpNoReorderEncoding( + srcShape, srcEnc, dstShape, dstEnc, /*loc=*/std::nullopt))) { + return dstEnc; + } + return std::nullopt; +} + +static std::optional inferDstEncoding(triton::ReshapeOp op, + Attribute encoding) { + return inferReshapeOpDstEncoding(op.getSrc().getType().getShape(), encoding, + op.getType().getShape(), + op.getAllowReorder()); +} + +static std::optional inferSrcEncoding(triton::ReshapeOp op, + Attribute encoding) { + // The encoding of x given the encoding of y in `reshape(x) -> y` is the same + // as the encoding of x given the encoding of y in `reshape(y) -> x`. It's an + // invariant of inferReshapeOpNoReorderEncoding that it's symmetric in this + // way. + return inferReshapeOpDstEncoding(op.getType().getShape(), encoding, + op.getSrc().getType().getShape(), + op.getAllowReorder()); +} + +std::optional inferSrcEncoding(Operation *op, Attribute encoding) { + if (isa(op)) { + // Scan only supports blocked encoding at the moment. + if (!isa(encoding)) + return std::nullopt; + } + if (op->hasTrait() || + op->hasTrait() || + op->hasTrait() || + isa(op)) { + return encoding; + } + + if (auto reduceOp = dyn_cast(op)) + return inferSrcEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferSrcEncoding(expand, encoding); + if (auto join = dyn_cast(op)) + return inferSrcEncoding(join, encoding); + if (auto split = dyn_cast(op)) + return inferSrcEncoding(split, encoding); + if (auto trans = dyn_cast(op)) + return inferSrcEncoding(trans, encoding); + if (auto reshape = dyn_cast(op)) + return inferSrcEncoding(reshape, encoding); + + return std::nullopt; +} + +std::optional inferDstEncoding(Operation *op, Attribute encoding) { + if (isa(op)) { + if (!isa(encoding)) + return std::nullopt; + } + if (op->hasTrait() || + op->hasTrait() || + op->hasTrait() || + isa(op)) + return encoding; + if (auto reduceOp = dyn_cast(op)) + return inferDstEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferDstEncoding(expand, encoding); + if (auto join = dyn_cast(op)) + return inferDstEncoding(join, encoding); + if (auto split = dyn_cast(op)) + return inferDstEncoding(split, encoding); + if (auto trans = dyn_cast(op)) + return inferDstEncoding(trans, encoding); + if (auto reshape = dyn_cast(op)) + return inferDstEncoding(reshape, encoding); + + return std::nullopt; +} + +bool isSingleValue(Value value) { + // Don't consider load as expensive if it is loading a scalar. + if (auto tensorTy = dyn_cast(value.getType())) + return tensorTy.getNumElements() == 1; + // TODO: Handle other cases. + // For example, when ptr is a tensor of single value. + // It means that ptr is a resultant of broadcast or generated through + // a chain of broadcast and other operations. + // Rematerialize it without considering contiguous memory access pattern is + // fine. + return true; +} + +bool isExpensiveLoadOrStore(Operation *op) { + // Case 1: Pointer of tensor is always expensive + auto operandType = op->getOperand(0).getType(); + if (triton::isTensorPointerType(operandType)) + return true; + // Case 2a: A size 1 tensor is not expensive since all threads will load the + // same + if (isSingleValue(op->getOperand(0))) + return false; + // Case 2b: Tensor of pointers has more threads than elements + // we can presume a high hit-rate that makes it cheap to load + auto ptrType = cast(op->getOperand(0).getType()); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + if (ptrType.getNumElements() < numWarps * threadsPerWarp) + return false; + return true; +} + +bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { + if (!op) + return true; + if (isa(op)) + return isExpensiveLoadOrStore(op); + if (isa(op)) + return triton::gpu::isExpensiveCat(cast(op), targetEncoding); + if (isa(op)) + return true; + if (isa( + op)) + return true; + return false; +} + +bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { + if (isa(op)) + return !triton::gpu::isExpensiveCat(cast(op), + targetEncoding); + if (auto convert = dyn_cast(op)) { + if (mlir::isa(targetEncoding)) { + auto srcEncoding = convert.getSrc().getType().getEncoding(); + if (targetEncoding != srcEncoding) + return false; + } + return true; + } + + if (auto reshape = dyn_cast(op)) { + auto reshapeDstType = reshape.getType(); + RankedTensorType newDstType = + RankedTensorType::get(reshapeDstType.getShape(), + reshapeDstType.getElementType(), targetEncoding); + return reshape.getAllowReorder() && + !reshape.getEfficientLayout().has_value() && + !triton::gpu::isExpensiveView(reshape.getSrc().getType(), + newDstType); + } + return isa(op); +} + +scf::ForOp replaceForOpWithNewSignature( + RewriterBase &rewriter, scf::ForOp loop, ValueRange newIterOperands, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(loop); + + // Create a new loop before the existing one, with the extra operands. + auto operands = llvm::to_vector<4>(loop.getInitArgs()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + scf::ForOp newLoop = rewriter.create( + loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), + operands); + newLoop->setAttrs(loop->getAttrs()); + newLoop.getBody()->erase(); + newLoop.getRegion().getBlocks().splice( + newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); + for (Value operand : newIterOperands) + newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); + + for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( + loop.getNumResults()))) + replacements.push_back(it); + return newLoop; +} + +scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, + ValueRange newIterOperands) { + SmallVector> replacements; + auto newForOp = replaceForOpWithNewSignature(rewriter, loop, newIterOperands, + replacements); + for (auto &kv : replacements) { + rewriter.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); + } + return newForOp; +} + +scf::IfOp replaceIfOpWithNewSignature( + RewriterBase &rewriter, scf::IfOp ifOp, TypeRange newResultTypes, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(ifOp); + + // Create a new loop before the existing one, with the extra operands. + auto resultTypes = llvm::to_vector<4>(ifOp.getResults().getTypes()); + resultTypes.append(newResultTypes.begin(), newResultTypes.end()); + scf::IfOp newIf = rewriter.create( + ifOp.getLoc(), resultTypes, ifOp.getCondition(), /*withElse=*/true); + newIf->setAttrs(ifOp->getAttrs()); + + rewriter.inlineBlockBefore(ifOp.thenBlock(), newIf.thenBlock(), + newIf.thenBlock()->begin()); + rewriter.inlineBlockBefore(ifOp.elseBlock(), newIf.elseBlock(), + newIf.elseBlock()->begin()); + + for (auto it : llvm::zip(ifOp.getResults(), + newIf.getResults().take_front(ifOp.getNumResults()))) + replacements.push_back(it); + return newIf; +} + +Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, + IRMapping &mapping) { + Operation *newOp = rewriter.clone(*op, mapping); + // if input types haven't changed, we're done + bool preserveTypes = + std::all_of(op->operand_begin(), op->operand_end(), [&](Value v) { + return !mapping.contains(v) || + v.getType() == mapping.lookup(v).getType(); + }); + if (preserveTypes) + return newOp; + + if (newOp->getNumResults() == 0) + return newOp; + auto origType = dyn_cast(op->getResult(0).getType()); + auto argType = dyn_cast(newOp->getOperand(0).getType()); + if (!origType || !argType) + return newOp; + auto newType = RankedTensorType::get( + origType.getShape(), origType.getElementType(), argType.getEncoding()); + newOp->getResult(0).setType(newType); + auto typeInfer = dyn_cast(newOp); + if (typeInfer) { + SmallVector newTypes; + auto success = typeInfer.inferReturnTypes( + newOp->getContext(), newOp->getLoc(), newOp->getOperands(), + newOp->getAttrDictionary(), newOp->getPropertiesStorage(), + newOp->getRegions(), newTypes); + if (succeeded(success)) { + for (size_t i = 0; i < newTypes.size(); i++) + newOp->getResult(i).setType(newTypes[i]); + } + } + return newOp; +} + +// Check if the convert will be a no-op in codegen. +static bool isFreeConvert(Operation *op) { + auto convertOp = dyn_cast(op); + if (!convertOp) + return false; + return isMmaToMmaShortcut(convertOp.getSrc().getType(), convertOp.getType()); +} + +LogicalResult +getConvertBackwardSlice(Value root, SetVector &slice, + Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation) { + DenseSet> seen; + SmallVector> queue; + + auto enqueue = [&](Value operand, Attribute encoding) { + auto x = std::make_pair(operand, encoding); + if (!seen.insert(x).second) { + return; // Already enqueued, skip + } + queue.push_back(x); + }; + enqueue(root, rootEncoding); + + while (!queue.empty()) { + auto [currentValue, encoding] = queue.back(); + queue.pop_back(); + if (!isa(currentValue.getType())) + continue; + // Skip propagating through for op results for now. + // TODO: enable this based on needs. + if (currentValue.getDefiningOp()) + return failure(); + slice.insert(currentValue); + if (layout.find(currentValue) != layout.end()) { + if (layout[currentValue] != encoding) + return failure(); + } + layout[currentValue] = encoding; + + if (auto ifOp = currentValue.getDefiningOp()) { + auto results = ifOp.getResults(); + unsigned argIdx = mlir::cast(currentValue).getResultNumber(); + + auto thenValue = ifOp.thenYield().getOperand(argIdx); + auto elseValue = ifOp.elseYield().getOperand(argIdx); + + enqueue(thenValue, encoding); + enqueue(elseValue, encoding); + + continue; + } + if (auto *definingOp = currentValue.getDefiningOp()) { + // If the op has multiple results we need to update all results layout. + for (Value result : definingOp->getResults()) { + if (result == currentValue || !isa(result.getType())) + continue; + enqueue(result, encoding); + } + if (!isFreeConvert(definingOp) && + canFoldIntoConversion(definingOp, encoding)) + continue; + if (stopPropagation && stopPropagation(definingOp)) + continue; + if (isa(definingOp)) + return failure(); + for (Value operand : definingOp->getOperands()) { + auto srcEncoding = inferSrcEncoding(definingOp, encoding); + if (!srcEncoding) + return failure(); + enqueue(operand, *srcEncoding); + } + continue; + } + auto blockArg = cast(currentValue); + Block *block = blockArg.getOwner(); + Operation *parentOp = block->getParentOp(); + if (auto forOp = dyn_cast(parentOp)) { + OpOperand *initOperand = forOp.getTiedLoopInit(blockArg); + Value yieldOperand = forOp.getBody()->getTerminator()->getOperand( + blockArg.getArgNumber() - forOp.getNumInductionVars()); + enqueue(initOperand->get(), encoding); + enqueue(yieldOperand, encoding); + continue; + } + // TODO: add support for WhileOp and other region types. + return failure(); + } + return success(); +} + +// TODO(thomas): this is duplicated with what is in GPUToLLVM +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape, + ArrayRef order) { + unsigned rank = shape.size(); + assert(rank == order.size()); + auto reordered = triton::applyPermutation(shape, order); + auto reorderedMultiDim = delinearize(b, loc, linear, reordered); + SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape) { + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + if (rank == 1) { + multiDim[0] = linear; + } else { + Value remained = linear; + for (auto &&en : llvm::enumerate(shape.drop_back())) { + auto dimSize = b.create(loc, en.value(), 32); + multiDim[en.index()] = b.create(loc, remained, dimSize); + remained = b.create(loc, remained, dimSize); + } + multiDim[rank - 1] = remained; + } + return multiDim; +} + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order) { + return linearize(b, loc, triton::applyPermutation(multiDim, order), + triton::applyPermutation(shape, order)); +} + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape) { + auto rank = multiDim.size(); + Value linear = b.create(loc, 0, 32); + if (rank > 0) { + linear = multiDim.back(); + for (auto [dim, dimShape] : + llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { + Value dimSize = b.create(loc, dimShape, 32); + linear = b.create( + loc, b.create(loc, linear, dimSize), dim); + } + } + return linear; +} + +bool isPureUnaryInlineAsm(Operation *op) { + auto inlineAsmOp = dyn_cast(op); + if (!inlineAsmOp) + return false; + return op->getNumOperands() == 1 && op->getNumResults() == 1 && + inlineAsmOp.getPure(); +} + +int getNVIDIAComputeCapability(Operation *module) { + assert(module->hasAttr(triton::AttrTargetName) && + "Expected a target attribute on the module operation"); + + StringAttr targetAttr = + cast(module->getAttr(triton::AttrTargetName)); + + StringRef ref = targetAttr.strref(); + assert(ref.starts_with("cuda:") && + "expected target attribute to be prefixed with \"cuda:\""); + + StringRef capabilityStr = ref.drop_front(5); // drop the "cuda:" + int computeCapability; + bool parseError = capabilityStr.getAsInteger(10, computeCapability); + assert(!parseError && + "invalid compute capability string in target attribute"); + + return computeCapability; +} + +namespace { + +/// Detect dead arguments in scf.for op by assuming all the values are dead and +/// propagate liveness property. +struct ForOpDeadArgElimination : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const final { + Block &block = *forOp.getBody(); + auto yieldOp = cast(block.getTerminator()); + // Assume that nothing is live at the beginning and mark values as live + // based on uses. + DenseSet aliveValues; + SmallVector queue; + // Helper to mark values as live and add them to the queue of value to + // propagate if it is the first time we detect the value as live. + auto markLive = [&](Value val) { + if (!forOp->isAncestor(val.getParentRegion()->getParentOp())) + return; + if (aliveValues.insert(val).second) + queue.push_back(val); + }; + // Mark all yield operands as live if the associated forOp result has any + // use. + for (auto result : llvm::enumerate(forOp.getResults())) { + if (!result.value().use_empty()) + markLive(yieldOp.getOperand(result.index())); + } + if (aliveValues.size() == forOp.getNumResults()) + return failure(); + // Operations with side-effects are always live. Mark all theirs operands as + // live. + block.walk([&](Operation *op) { + if (!isa(op) && !wouldOpBeTriviallyDead(op)) { + for (Value operand : op->getOperands()) + markLive(operand); + } + }); + // Propagate live property until reaching a fixed point. + while (!queue.empty()) { + Value value = queue.pop_back_val(); + if (auto nestedFor = value.getDefiningOp()) { + auto result = mlir::cast(value); + OpOperand &forOperand = *nestedFor.getTiedLoopInit(result); + markLive(forOperand.get()); + auto nestedYieldOp = + cast(nestedFor.getBody()->getTerminator()); + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + continue; + } + if (auto nestedIf = value.getDefiningOp()) { + auto result = mlir::cast(value); + for (scf::YieldOp nestedYieldOp : + {nestedIf.thenYield(), nestedIf.elseYield()}) { + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + } + continue; + } + if (Operation *def = value.getDefiningOp()) { + // TODO: support while ops. + if (isa(def)) + return failure(); + for (Value operand : def->getOperands()) + markLive(operand); + continue; + } + // If an argument block is live then the associated yield operand and + // forOp operand are live. + auto arg = mlir::cast(value); + if (auto forOwner = dyn_cast(arg.getOwner()->getParentOp())) { + if (arg.getArgNumber() < forOwner.getNumInductionVars()) + continue; + unsigned iterIdx = arg.getArgNumber() - forOwner.getNumInductionVars(); + Value yieldOperand = + forOwner.getBody()->getTerminator()->getOperand(iterIdx); + markLive(yieldOperand); + markLive(forOwner.getInitArgs()[iterIdx]); + } + } + SmallVector deadArg; + for (auto yieldOperand : llvm::enumerate(yieldOp->getOperands())) { + if (aliveValues.contains(yieldOperand.value())) + continue; + if (yieldOperand.value() == block.getArgument(yieldOperand.index() + 1)) + continue; + + // The yield operand might live outside the loop, e.g. + // %init = ... + // %x = ... + // %y = for iter_args(%unused = %init) { + // yield %x + // } + // + // In this case, the loop returns %x if it runs 1 or more times, and + // otherwise it returns %init. We cowardly refuse to remove this operand + // from the yield. (We could, but we'd need to prove that the loop runs 0 + // or >=1 times.) + // + // As a special case, if it doesn't matter whether the loop runs 0 or >=1 + // times (because the loop returns the same value in both cases) then we + // can still mark the operand as dead. This occurs in the above example + // when %init is the same as %x. + if (!forOp->isAncestor( + yieldOperand.value().getParentRegion()->getParentOp()) && + yieldOperand.value() != forOp.getInitArgs()[yieldOperand.index()]) + continue; + + deadArg.push_back(yieldOperand.index()); + } + if (deadArg.empty()) + return failure(); + rewriter.modifyOpInPlace(forOp, [&]() { + // For simplicity we just change the dead yield operand to use the + // associated argument and leave the operations and argument removal to + // dead code elimination. + for (unsigned deadArgIdx : deadArg) { + BlockArgument arg = block.getArgument(deadArgIdx + 1); + yieldOp.setOperand(deadArgIdx, arg); + } + }); + return success(); + } +}; + +} // namespace + +void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +} // namespace mlir diff --git a/third_party/mthreads/lib/Target/CMakeLists.txt b/third_party/mthreads/lib/Target/CMakeLists.txt new file mode 100644 index 000000000..39d31dc9b --- /dev/null +++ b/third_party/mthreads/lib/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/third_party/mthreads/lib/Target/LLVMIR/CMakeLists.txt b/third_party/mthreads/lib/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 000000000..f2f9adf8f --- /dev/null +++ b/third_party/mthreads/lib/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1,28 @@ +add_triton_library(TritonLLVMIR + LLVMDIScope.cpp + LLVMIRBreakPhiStruct.cpp + + DEPENDS + LLVMIRIncGen + + LINK_LIBS + ${CMAKE_DL_LIBS} + PUBLIC + MLIRArithToLLVM + MLIRBuiltinToLLVMIRTranslation + MLIRIndexToLLVM + MLIRIR + MLIRLLVMDialect + MLIRLLVMToLLVMIRTranslation + MLIRNVVMToLLVMIRTranslation + MLIRROCDLToLLVMIRTranslation + MLIRSCFToControlFlow + MLIRSupport + MLIRTargetLLVMIRExport + TritonGPUToLLVM + ) + +set_source_files_properties( + LLVMIRTranslation.cpp + PROPERTIES + COMPILE_FLAGS "-D__BUILD_DIR__=\\\"${CMAKE_BINARY_DIR}\\\"") diff --git a/third_party/mthreads/lib/Target/LLVMIR/LLVMDIScope.cpp b/third_party/mthreads/lib/Target/LLVMIR/LLVMDIScope.cpp new file mode 100644 index 000000000..af7079060 --- /dev/null +++ b/third_party/mthreads/lib/Target/LLVMIR/LLVMDIScope.cpp @@ -0,0 +1,161 @@ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "triton/Target/LLVMIR/Passes.h" +#include "llvm/BinaryFormat/Dwarf.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Path.h" + +//===----------------------------------------------------------------------===// +// This file implements a pass to add debug info scope to LLVM operations, and +// is inspired by the DIScopeForLLVMFuncOpPass in LLVM/MLIR. Different from the +// DIScopeForLLVMFuncOpPass, this pass also handles inlined functions. +//===----------------------------------------------------------------------===// + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton/Target/LLVMIR/Passes.h.inc" + +namespace { + +/// Attempt to extract a filename for the given loc. +FileLineColLoc extractFileLoc(Location loc) { + if (auto fileLoc = dyn_cast(loc)) + return fileLoc; + if (auto nameLoc = dyn_cast(loc)) + return extractFileLoc(nameLoc.getChildLoc()); + if (auto opaqueLoc = dyn_cast(loc)) + return extractFileLoc(opaqueLoc.getFallbackLocation()); + if (auto fusedLoc = dyn_cast(loc)) + return extractFileLoc(fusedLoc.getLocations().front()); + if (auto callerLoc = dyn_cast(loc)) + return extractFileLoc(callerLoc.getCaller()); + StringAttr unknownFile = mlir::StringAttr::get(loc.getContext(), ""); + return mlir::FileLineColLoc::get(unknownFile, 0, 0); +} + +/// Add a debug info scope to LLVMFuncOp that are missing it. +struct LLVMDIScopePass : public LLVMDIScopeBase { + LLVMDIScopePass() = default; + + void setSubprogramAttr(LLVM::LLVMFuncOp funcOp) { + Location loc = funcOp.getLoc(); + if (loc->findInstanceOf>()) + return; + + MLIRContext *context = &getContext(); + + // To find a DICompileUnitAttr attached to a parent (the module for + // example), otherwise create a default one. + LLVM::DICompileUnitAttr compileUnitAttr; + if (ModuleOp module = funcOp->getParentOfType()) { + auto fusedCompileUnitAttr = + module->getLoc() + ->findInstanceOf>(); + if (fusedCompileUnitAttr) + compileUnitAttr = fusedCompileUnitAttr.getMetadata(); + } + + // Filename, line and colmun to associate to the function. + LLVM::DIFileAttr fileAttr; + int64_t line = 1, col = 1; + FileLineColLoc fileLoc = extractFileLoc(loc); + if (!fileLoc && compileUnitAttr) { + fileAttr = compileUnitAttr.getFile(); + } else if (!fileLoc) { + fileAttr = LLVM::DIFileAttr::get(context, "", ""); + } else { + line = fileLoc.getLine(); + col = fileLoc.getColumn(); + StringRef inputFilePath = fileLoc.getFilename().getValue(); + fileAttr = LLVM::DIFileAttr::get( + context, llvm::sys::path::filename(inputFilePath), + llvm::sys::path::parent_path(inputFilePath)); + } + auto subroutineTypeAttr = + LLVM::DISubroutineTypeAttr::get(context, llvm::dwarf::DW_CC_normal, {}); + + // Figure out debug information (`subprogramFlags` and `compileUnitAttr`) to + // attach to the function definition / declaration. External functions are + // declarations only, and are defined in a different compile unit, so mark + // them appropriately in `subprogramFlags`, and set an empty + // `compileUnitAttr`. + DistinctAttr distinctId; + auto subprogramFlags = LLVM::DISubprogramFlags::Optimized; + if (!funcOp.isExternal()) { + distinctId = mlir::DistinctAttr::create(mlir::UnitAttr::get(context)); + if (!compileUnitAttr) { + compileUnitAttr = LLVM::DICompileUnitAttr::get( + distinctId, llvm::dwarf::DW_LANG_C, fileAttr, + StringAttr::get(context, "triton"), + /*isOptimized=*/true, LLVM::DIEmissionKind::LineTablesOnly); + } + subprogramFlags = subprogramFlags | LLVM::DISubprogramFlags::Definition; + } else { + compileUnitAttr = {}; + } + + StringAttr funcNameAttr = funcOp.getNameAttr(); + // Note that scopeline is set differently from LLVM's + // DIScopeForLLVMFuncOpPass. I don't find reasons why scopeline should be + // the column offset + auto subprogramAttr = LLVM::DISubprogramAttr::get( + context, distinctId, compileUnitAttr, fileAttr, funcNameAttr, + funcNameAttr, fileAttr, + /*line=*/line, + /*scopeline=*/line, subprogramFlags, subroutineTypeAttr); + funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr)); + } + + // Get a nested loc for inlined functions + Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr, + Location calleeLoc) { + auto calleeFileName = extractFileLoc(calleeLoc).getFilename(); + auto context = op->getContext(); + LLVM::DIFileAttr calleeFileAttr = LLVM::DIFileAttr::get( + context, llvm::sys::path::filename(calleeFileName), + llvm::sys::path::parent_path(calleeFileName)); + auto lexicalBlockFileAttr = LLVM::DILexicalBlockFileAttr::get( + context, scopeAttr, calleeFileAttr, /*discriminator=*/0); + Location loc = calleeLoc; + if (mlir::isa(calleeLoc)) { + auto nestedLoc = mlir::cast(calleeLoc).getCallee(); + loc = getNestedLoc(op, lexicalBlockFileAttr, nestedLoc); + } + return FusedLoc::get(context, {loc}, lexicalBlockFileAttr); + } + + void setLexicalBlockFileAttr(Operation *op) { + auto opLoc = op->getLoc(); + if (auto callSiteLoc = dyn_cast(opLoc)) { + auto callerLoc = callSiteLoc.getCaller(); + auto calleeLoc = callSiteLoc.getCallee(); + LLVM::DIScopeAttr scopeAttr; + // We assemble the full inline stack so the parent of this loc must be a + // function + auto funcOp = op->getParentOfType(); + auto funcOpLoc = mlir::cast(funcOp.getLoc()); + scopeAttr = mlir::cast(funcOpLoc.getMetadata()); + auto loc = + CallSiteLoc::get(getNestedLoc(op, scopeAttr, calleeLoc), callerLoc); + op->setLoc(loc); + } + } + + void runOnOperation() override { + getOperation()->walk([&](Operation *op) -> void { + if (isa(op)) + setSubprogramAttr(cast(op)); + else + setLexicalBlockFileAttr(op); + }); + } +}; + +} // end anonymous namespace + +std::unique_ptr mlir::createLLVMDIScopePass() { + return std::make_unique(); +} diff --git a/third_party/mthreads/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp b/third_party/mthreads/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp new file mode 100644 index 000000000..44afcfd21 --- /dev/null +++ b/third_party/mthreads/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +/// Implements a trivial pass breaking up 1 level deep structure in phi nodes. +/// This handles the common case generated by Triton and allow better +/// optimizations down the compiler pipeline. +//===----------------------------------------------------------------------===// +#include "LLVMPasses.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +static bool processPhiStruct(PHINode *phiNode) { + StructType *STy = dyn_cast(phiNode->getType()); + if (!STy) + return false; + IRBuilder<> builder(phiNode); + unsigned numOperands = phiNode->getNumIncomingValues(); + unsigned numScalarEl = STy->getNumElements(); + Value *newStruct = UndefValue::get(STy); + builder.SetInsertPoint(phiNode->getParent()->getFirstNonPHI()); + llvm::IRBuilderBase::InsertPoint insertInsertPt = builder.saveIP(); + for (unsigned i = 0; i < numScalarEl; i++) { + builder.SetInsertPoint(phiNode); + PHINode *newPhiNode = + builder.CreatePHI(STy->getElementType(i), numOperands); + for (unsigned j = 0; j < numOperands; ++j) { + Value *operand = phiNode->getIncomingValue(j); + builder.SetInsertPoint(phiNode->getIncomingBlock(j)->getTerminator()); + newPhiNode->addIncoming(builder.CreateExtractValue(operand, i), + phiNode->getIncomingBlock(j)); + } + builder.restoreIP(insertInsertPt); + newStruct = builder.CreateInsertValue(newStruct, newPhiNode, i); + insertInsertPt = builder.saveIP(); + } + phiNode->replaceAllUsesWith(newStruct); + return true; +} + +static bool runOnFunction(Function &F) { + bool Changed = false; + SmallVector PhiNodes; + for (BasicBlock &BB : F) { + for (Instruction &inst : BB) { + if (PHINode *phiNode = dyn_cast(&inst)) { + Changed |= processPhiStruct(phiNode); + continue; + } + break; + } + } + return Changed; +} + +PreservedAnalyses BreakStructPhiNodesPass::run(Function &F, + FunctionAnalysisManager &AM) { + + bool b = runOnFunction(F); + return b ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/third_party/mthreads/lib/Target/LLVMIR/LLVMPasses.h b/third_party/mthreads/lib/Target/LLVMIR/LLVMPasses.h new file mode 100644 index 000000000..1dcdb2992 --- /dev/null +++ b/third_party/mthreads/lib/Target/LLVMIR/LLVMPasses.h @@ -0,0 +1,16 @@ +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/CodeGen.h" + +namespace llvm { + +// Pass to pre-process LLVM IR before optimization and break up phi of struct. +// Breaking up those phis into elementary types allows better optimizations +// downstream. +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; + +} // namespace llvm diff --git a/third_party/mthreads/lib/Tools/CMakeLists.txt b/third_party/mthreads/lib/Tools/CMakeLists.txt new file mode 100644 index 000000000..4b021da33 --- /dev/null +++ b/third_party/mthreads/lib/Tools/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(TritonTools + LinearLayout.cpp + + DEPENDS + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMDialect + f2reduce +) diff --git a/third_party/mthreads/lib/Tools/LinearLayout.cpp b/third_party/mthreads/lib/Tools/LinearLayout.cpp new file mode 100644 index 000000000..75e530db5 --- /dev/null +++ b/third_party/mthreads/lib/Tools/LinearLayout.cpp @@ -0,0 +1,427 @@ +#include "triton/Tools/LinearLayout.h" + +#include +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "third_party/f2reduce/f2reduce.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir::triton { + +namespace { +using BasesT = LinearLayout::BasesT; +using llvm::Twine; + +BasesT makeBasesMap( + ArrayRef>>> bases) { + BasesT ret; + for (const auto &[inDim, inDimBases] : bases) { + ret[inDim] = inDimBases; + } + return ret; +} + +std::string stringifyBases(const BasesT &bases, + ArrayRef outDimNames) { + std::string ret; + + if (bases.empty()) + return "(empty layout)\n"; + + // TODO: Add spaces for alignment. + for (const auto &[inDim, inDimBases] : bases) { + if (inDimBases.empty()) { + ret += " - " + inDim.str() + " is a size 1 dimension\n"; + continue; + } + + ret += " - " + + join(llvm::seq(inDimBases.size()), "\n ", + [&, &inDim = inDim, &inDimBases = inDimBases](int i) { + return inDim.str() + "=" + std::to_string(1 << i) + " -> (" + + join(inDimBases[i], ", ") + ")"; + }) + + "\n"; + } + ret += "where out dims are: [" + + join(outDimNames, ", ", [](StringAttr s) { return s.str(); }) + "]\n"; + return ret; +} + +BasesT validateBases(BasesT bases, ArrayRef outDimNames) { + if (bases.empty()) + return bases; + + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + if (llvm::any_of(basis, [](int32_t b) { return b < 0; })) { + llvm::report_fatal_error( + "Invalid bases passed to LinearLayout. Expected all basis " + "values to be non-negative, but found a negative value for " + "in dimension '" + + Twine(inDim) + "'. Full list of bases:\n" + + stringifyBases(bases, outDimNames)); + } + } + } + + // Check that the bases all have length equal to outDimNames.size(). + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + if (basis.size() != outDimNames.size()) { + llvm::report_fatal_error( + "Invalid bases passed to LinearLayout. Expect all bases to have " + "the same size, equal to outDimNames.size() (" + + Twine(outDimNames.size()) + + "). But this failed for in dimension '" + Twine(inDim) + + "'. Full list of bases:\n" + stringifyBases(bases, outDimNames)); + } + } + } + + return bases; +} + +// Compute the rank of the matrix formed by taking the bases for the given +// outDim as columns. In other words, finds the number of linearly-independent +// bases for this output dimension. +int getMatrixRank(const LinearLayout &layout, StringAttr outDim) { + // Suppose we have a layout specified by the following key values. + // + // L(0,1) = 0b01 + // L(0,2) = 0b10 + // L(1,0) = 0b10 + // L(2,0) = 0b11 + // + // We will create one column per key value. The max bit width of these values + // is 2, so our matrix will have 2 rows. The final matrix will be + // + // | ↑ ↑ ↑ ↑ | | 0b0111 | + // | L(0,1) L(0,2) L(1,0) L(2,0) | = | 0b1001 | + // | ↓ ↓ ↓ ↓ | + int numRows = layout.getOutDimSizeLog2(outDim); + + int numCols = 0; + for (StringAttr inDim : layout.getInDimNames()) { + numCols += layout.getInDimSizeLog2(inDim); + } + + if (numCols == 0 || numRows == 0) + return 0; + + // Don't handle giant LLs. This makes some things easier; for example, each + // row can be a single uint64_t. + assert(numCols <= 64 && "LinearLayout too large"); + assert(numRows <= 64 && "LinearLayout too large"); + + // Note that `new int[n]()` is zero-initialized, whereas `new int[n]` is not. + std::unique_ptr m(new uint64_t[numRows]()); + + // Fill in the matrix. + int c = 0; + for (StringAttr inDim : layout.getInDimNames()) { + for (int i = 0; i < layout.getInDimSizeLog2(inDim); i++) { + uint64_t basis = layout.getBasis(inDim, i, outDim); + for (int j = 0; j < numRows; j++) { + m[j] |= ((basis >> j) & 1) << c; + } + c++; + } + } + + // stride is specified in number of 64-bit words per row. + f2reduce::inplace_rref_strided(m.get(), numRows, numCols, /*stride=*/1); + + // The rank of the reduced matrix is simply the number of nonzero rows. + int rank = 0; + for (int i = 0; i < numRows; i++) { + if (m[i] != 0) + rank++; + } + return rank; +} + +// Check that the given layout is surjective, i.e. that every `out` coordinate +// can be reached by some `in` coordinate. +// +// It's sufficient to check each output dimension indepedently. Still, +// it's prohibitively slow to calculate this naively. +// +// Thankfully, this is equivalent to checking that the number of +// linearly-independent bases for outDim d is equal to getOutDimSizeLog2(d). +// This can be computed by finding the rank of the matrix whose columns are +// those bases. We can compute the rank of our matrix using Gaussian +// elimination, which runs in O(n^3) for an n x n matrix. Our matrix size is +// log(product(inDimSize)) x log(outDimSize), and we do this numOutDims times, +// so this should be plenty fast overall. +void validateSurjectivity(const LinearLayout &layout) { + for (const auto &outDim : layout.getOutDimNames()) { + unsigned rank = getMatrixRank(layout, outDim); + unsigned expectedRank = layout.getOutDimSizeLog2(outDim); + if (rank != expectedRank) { + llvm::report_fatal_error( + "Invalid bases passed to LinearLayout. Expected bases to be " + "surjective, i.e. all possible output coordinates can be reached " + "by some input coordinates. But this failed for output dimension " + + Twine(outDim) + ", where we got rank " + Twine(rank) + + " instead of expected rank " + Twine(expectedRank) + + ". Full list of bases:\n" + + Twine(stringifyBases(layout.getBases(), layout.getOutDimNames()))); + } + } +} + +template +void assertDimsEqualIgnoringOrder(T &&a, U &&b) { + llvm::DenseSet as(a.begin(), a.end()); + llvm::DenseSet bs(b.begin(), b.end()); + if (as != bs) { + llvm::report_fatal_error("Dimensions must match, ignoring order, but they " + "don't. Got dims: [" + + Twine(triton::join(a, ", ")) + "] and [" + + triton::join(b, ", ") + "]"); + } +} + +} // anonymous namespace + +LinearLayout::LinearLayout(BasesT bases, ArrayRef outDimNames) + : bases(validateBases(std::move(bases), outDimNames)), + outDimNames(outDimNames.begin(), outDimNames.end()) { + validateSurjectivity(*this); +} + +LinearLayout::LinearLayout( + ArrayRef>>> bases, + ArrayRef outDimNames) + : LinearLayout(makeBasesMap(bases), outDimNames) {} + +/*static*/ LinearLayout LinearLayout::identity1D(int32_t size, + StringAttr inDimName, + StringAttr outDimName) { + if (size == 0) + return LinearLayout::empty(); + + assert(llvm::isPowerOf2_32(size)); + std::vector> powersOf2; + for (int32_t i = 1; i < size; i *= 2) { + powersOf2.emplace_back().push_back(i); + } + return LinearLayout({{inDimName, std::move(powersOf2)}}, {outDimName}); +} + +/*static*/ LinearLayout LinearLayout::zeros1D(int32_t size, + StringAttr inDimName, + StringAttr outDimName) { + if (size == 0) + return LinearLayout::empty(); + + assert(llvm::isPowerOf2_32(size)); + std::vector> zeros; + for (int i = 0; i < llvm::Log2_32(size); i++) { + zeros.emplace_back().push_back(0); + } + return LinearLayout({{inDimName, zeros}}, {outDimName}); +} + +int32_t LinearLayout::getOutDimIndex(StringAttr outDim) const { + // Sadly SetVector doesn't provide an O(1) way to do this. + for (int i = 0; i < outDimNames.size(); ++i) { + if (outDimNames[i] == outDim) { + return i; + } + } + llvm::report_fatal_error("outDim " + Twine(outDim) + " is not in layout\n" + + toString()); +} + +int32_t LinearLayout::getInDimSizeLog2(StringAttr inDim) const { + auto it = bases.find(inDim); + assert(it != bases.end()); + return it->second.size(); +} + +int32_t LinearLayout::getOutDimSizeLog2(StringAttr outDim) const { + // TODO(jlebar): Cache this? + int32_t outDimIdx = getOutDimIndex(outDim); + int32_t max = 0; + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + max = std::max(max, basis[outDimIdx]); + } + } + return max == 0 ? 0 : llvm::Log2_32(max) + 1; +} + +LinearLayout LinearLayout::transposeIns(ArrayRef newInDims) const { + assertDimsEqualIgnoringOrder(newInDims, getInDimNames()); + + BasesT newBases; + for (const auto &inDim : newInDims) { + newBases[inDim] = bases.find(inDim)->second; + } + return LinearLayout(std::move(newBases), outDimNames.getArrayRef()); +} + +LinearLayout +LinearLayout::transposeOuts(ArrayRef newOutDims) const { + assertDimsEqualIgnoringOrder(newOutDims, getOutDimNames()); + + std::vector permutation; + for (const auto &outDim : newOutDims) { + permutation.push_back(getOutDimIndex(outDim)); + } + + BasesT newBases; + for (const auto &[inDim, inDimBases] : bases) { + auto &newInDimBases = newBases[inDim]; + for (const auto &basis : inDimBases) { + std::vector newBasis; + for (int32_t i : permutation) { + newBasis.push_back(basis[i]); + } + newInDimBases.push_back(std::move(newBasis)); + } + } + return LinearLayout(std::move(newBases), newOutDims); +} + +LinearLayout operator*(LinearLayout inner, LinearLayout outer) { + // Check that elements common to both outerDimsRange and innerDimsRange appear + // in the same relative order. + auto checkCommonDims = [&](auto outerDimsRange, auto innerDimsRange) { + llvm::DenseSet outerDims(outerDimsRange.begin(), + outerDimsRange.end()); + llvm::DenseSet innerDims(innerDimsRange.begin(), + innerDimsRange.end()); + + std::vector outerCommonDims; + for (StringAttr dim : outerDimsRange) { + if (innerDims.contains(dim)) { + outerCommonDims.push_back(dim); + } + } + + std::vector innerCommonDims; + for (StringAttr dim : innerDimsRange) { + if (outerDims.contains(dim)) { + innerCommonDims.push_back(dim); + } + } + + if (outerCommonDims != innerCommonDims) { + llvm::report_fatal_error( + "Cannot multiply layouts. All in/out dimensions common to both " + "layouts must appear in the same relative order, but they " + "don't.\nOuter:\n" + + Twine(outer.toString()) + "\nInner:\n" + inner.toString()); + } + }; + + // Check that dims common to outer and inner have the same relative order. + checkCommonDims(outer.getInDimNames(), inner.getInDimNames()); + checkCommonDims(outer.getOutDimNames(), inner.getOutDimNames()); + + // Get the sizeLog2 of all input and output dimensions we're going to + // consider, in order. `inner` is more minor, so its dimensions come first. + llvm::MapVector inDimSizes; + llvm::SetVector outDimNames; + for (const auto &layout : {inner, outer}) { + for (StringAttr inDim : layout.getInDimNames()) { + inDimSizes[inDim] += layout.getInDimSizeLog2(inDim); + } + for (StringAttr outDim : layout.getOutDimNames()) { + outDimNames.insert(outDim); + } + } + BasesT allBases; + for (auto [inDimName, inDimSize] : inDimSizes) { + std::vector> &inDimBases = allBases[inDimName]; + + // Fill with zeros. + inDimBases = std::vector>( + inDimSize, std::vector(outDimNames.size(), 0)); + + for (auto [outDimIdx, outDimName] : llvm::enumerate(outDimNames)) { + if (inner.hasInDim(inDimName) && inner.hasOutDim(outDimName)) { + for (int i = 0; i < inner.getInDimSizeLog2(inDimName); i++) { + inDimBases[i][outDimIdx] = inner.getBasis(inDimName, i, outDimName); + } + } + if (outer.hasInDim(inDimName) && outer.hasOutDim(outDimName)) { + int offset = + inner.hasInDim(inDimName) ? inner.getInDimSizeLog2(inDimName) : 0; + int shift = inner.hasOutDim(outDimName) + ? inner.getOutDimSizeLog2(outDimName) + : 0; + for (int i = 0; i < outer.getInDimSizeLog2(inDimName); i++) { + inDimBases[offset + i][outDimIdx] = + outer.getBasis(inDimName, i, outDimName) << shift; + } + } + } + } + + return LinearLayout(std::move(allBases), outDimNames.getArrayRef()); +} + +SmallVector> +LinearLayout::apply(ArrayRef> ins) const { + assertDimsEqualIgnoringOrder(llvm::make_first_range(ins), getInDimNames()); + + SmallVector> ret; + for (StringAttr outDim : getOutDimNames()) { + int32_t outVal = 0; + for (auto &[inDim, val] : ins) { + for (int i = 0; i < getInDimSizeLog2(inDim); i++) { + if (val & (1 << i)) + outVal ^= getBasis(inDim, i, outDim); + } + } + ret.push_back({outDim, outVal}); + } + return ret; +} + +LinearLayout LinearLayout::compose(const LinearLayout &outer) const { + assertDimsEqualIgnoringOrder(getOutDimNames(), outer.getInDimNames()); + + BasesT newBases; + for (const auto &[inDim, inDimBases] : bases) { + auto &newInDimBases = newBases[inDim]; + for (const auto &basis : inDimBases) { + SmallVector> bases; + for (auto [outDim, b] : llvm::zip(getOutDimNames(), basis)) { + bases.push_back({outDim, b}); + } + auto newBases = outer.apply(bases); + auto newBasesRange = llvm::make_second_range(newBases); + newInDimBases.push_back( + std::vector(newBasesRange.begin(), newBasesRange.end())); + } + } + return LinearLayout(std::move(newBases), outer.getOutDimNames()); +} + +bool operator==(LinearLayout lhs, LinearLayout rhs) { + // llvm::MapVector doesn't have an operator== :(. + if (lhs.getOutDimNames() != rhs.getOutDimNames()) + return false; + if (lhs.bases.size() != rhs.bases.size()) + return false; + for (auto it1 = lhs.bases.begin(), it2 = rhs.bases.begin(); + it1 != lhs.bases.end(); ++it1, ++it2) { + if (*it1 != *it2) + return false; + } + return true; +} + +std::string LinearLayout::toString() const { + return stringifyBases(bases, getOutDimNames()); +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/plugin/CMakeLists.txt b/third_party/mthreads/plugin/CMakeLists.txt new file mode 100644 index 000000000..0d670a5f2 --- /dev/null +++ b/third_party/mthreads/plugin/CMakeLists.txt @@ -0,0 +1,22 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +add_subdirectory(include) +add_subdirectory(lib) + +if(TRITON_BUILD_PYTHON_MODULE) + find_package(Python3 REQUIRED COMPONENTS Development Interpreter) + add_library(mthreadsTritonPlugin SHARED + ${CMAKE_CURRENT_SOURCE_DIR}/triton_mthreads.cc + ) + set_target_properties(mthreadsTritonPlugin PROPERTIES + PREFIX "" + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib + POSITION_INDEPENDENT_CODE ON + ) + target_link_libraries(mthreadsTritonPlugin PRIVATE + TritonMTGPUToLLVM + # Py + ${Python3_LIBRARIES} + ${PYTHON_LDFLAGS} + ) +endif() diff --git a/third_party/mthreads/plugin/include/CMakeLists.txt b/third_party/mthreads/plugin/include/CMakeLists.txt new file mode 100644 index 000000000..d91fde97e --- /dev/null +++ b/third_party/mthreads/plugin/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TritonMTGPUToLLVM) diff --git a/third_party/mthreads/plugin/include/TritonMTGPUToLLVM/CMakeLists.txt b/third_party/mthreads/plugin/include/TritonMTGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..f813becb3 --- /dev/null +++ b/third_party/mthreads/plugin/include/TritonMTGPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonMTGPUToLLVM) +add_public_tablegen_target(TritonMTGPUConversionPassIncGen) diff --git a/third_party/mthreads/plugin/include/TritonMTGPUToLLVM/MUSATranslation.h b/third_party/mthreads/plugin/include/TritonMTGPUToLLVM/MUSATranslation.h new file mode 100644 index 000000000..f42a57860 --- /dev/null +++ b/third_party/mthreads/plugin/include/TritonMTGPUToLLVM/MUSATranslation.h @@ -0,0 +1,21 @@ +#ifndef TRITON_TARGET_MUSATRANSLATION_H +#define TRITON_TARGET_MUSATRANSLATION_H + +#include +#include +#include + +namespace llvm { +class Module; +} // namespace llvm + +namespace mlir::triton { + +// Translate TritonGPU IR to MUSA binary code. +std::tuple +translateLLVMIRToMUBIN(llvm::Module &module, const std::string &opt_option, + int capability, int version); + +} // namespace mlir::triton + +#endif diff --git a/third_party/mthreads/plugin/include/TritonMTGPUToLLVM/Passes.h b/third_party/mthreads/plugin/include/TritonMTGPUToLLVM/Passes.h new file mode 100644 index 000000000..41b904e72 --- /dev/null +++ b/third_party/mthreads/plugin/include/TritonMTGPUToLLVM/Passes.h @@ -0,0 +1,35 @@ +#ifndef TRITONGPU_CONVERSION_TRITONMTGPUTOLLVM_PASSES_H +#define TRITONGPU_CONVERSION_TRITONMTGPUTOLLVM_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +#define GEN_PASS_DECL +#include "mthreads/plugin/include/TritonMTGPUToLLVM/Passes.h.inc" + +namespace MUSA {} // namespace MUSA + +std::unique_ptr> createConvertTritonMTGPUToLLVMPass(); +std::unique_ptr> +createConvertTritonMTGPUToLLVMPass(int32_t computeCapability); +std::unique_ptr> +createConvertMTGPUBuiltinFuncToLLVMPass(); + +#define GEN_PASS_REGISTRATION +#include "mthreads/plugin/include/TritonMTGPUToLLVM/Passes.h.inc" + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/mthreads/plugin/include/TritonMTGPUToLLVM/Passes.td b/third_party/mthreads/plugin/include/TritonMTGPUToLLVM/Passes.td new file mode 100644 index 000000000..b941fa82b --- /dev/null +++ b/third_party/mthreads/plugin/include/TritonMTGPUToLLVM/Passes.td @@ -0,0 +1,37 @@ +#ifndef TRITONMTGPU_CONVERSION_PASSES +#define TRITONMTGPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTritonMTGPUToLLVM : Pass<"convert-triton-mtgpu-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert TritonGPU to LLVM"; + let description = [{ + + }]; + let constructor = "mlir::triton::createConvertTritonMTGPUToLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + "mlir::gpu::GPUDialect", + "mlir::scf::SCFDialect", + "mlir::LLVM::LLVMDialect", + "mlir::tensor::TensorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::MTGPU::MTGPUDialect"]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"80", + "device compute capability">, + ]; +} + +def ConvertMTGPUBuiltinFuncToLLVM : Pass<"convert-mtgpu-builtin-func-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert MTGPU Builtin Func to LLVM"; + let constructor = "mlir::triton::createConvertMTGPUBuiltinFuncToLLVMPass()"; + + let dependentDialects = ["mlir::LLVM::LLVMDialect"]; +} + +#endif // TRITONMTGPU_CONVERSION_PASSES diff --git a/third_party/mthreads/plugin/lib/CMakeLists.txt b/third_party/mthreads/plugin/lib/CMakeLists.txt new file mode 100644 index 000000000..d91fde97e --- /dev/null +++ b/third_party/mthreads/plugin/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TritonMTGPUToLLVM) diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/BuiltinFuncToLLVM.cpp new file mode 100644 index 000000000..adc424b9e --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/BuiltinFuncToLLVM.cpp @@ -0,0 +1,137 @@ +#include "TritonMTGPUToLLVM/Passes.h" + +#include "Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTMTGPUBUILTINFUNCTOLLVM +#include "TritonMTGPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; + +namespace { + +class CallOpConversion : public mlir::RewritePattern { +public: + CallOpConversion(mlir::MLIRContext *context) + : mlir::RewritePattern(LLVM::CallOp::getOperationName(), 1, context) {} + + LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto callOp = cast(op); + if (isPredicatedLoad(callOp)) { + return convertPredicatedLoad(callOp, rewriter); + } else if (isPredicatedStore(callOp)) { + return convertPredicatedStore(callOp, rewriter); + } else { + return failure(); + } + } + +private: + bool isPredicatedLoad(LLVM::CallOp callOp) const { + return callOp.getCallee().value().find(mlir::LLVM::MUSA::Predicated_Load) != + llvm::StringRef::npos; + } + + bool isPredicatedStore(LLVM::CallOp callOp) const { + return callOp.getCallee().value().find( + mlir::LLVM::MUSA::Predicated_Store) != llvm::StringRef::npos; + } + + LogicalResult convertPredicatedStore(LLVM::CallOp callOp, + mlir::PatternRewriter &rewriter) const { + auto operands = callOp.getOperands(); + + auto loc = callOp.getLoc(); + auto ptr = operands[0]; + auto val = operands[1]; + auto pred = operands[2]; + + Block *currentBlock = rewriter.getInsertionBlock(); + Block *afterStore = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *trueBlock = rewriter.createBlock(afterStore); + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create(loc, pred, trueBlock, afterStore); + rewriter.setInsertionPointToStart(trueBlock); + auto storeOp = rewriter.create(loc, val, ptr); + rewriter.create(loc, afterStore); + rewriter.setInsertionPointToStart(afterStore); + rewriter.eraseOp(callOp); + return mlir::success(); + } + + LogicalResult convertPredicatedLoad(LLVM::CallOp callOp, + mlir::PatternRewriter &rewriter) const { + auto operands = callOp.getOperands(); + auto result = callOp.getResult(); + + auto loc = callOp.getLoc(); + auto elemTy = result.getType(); + auto ptr = operands[0]; + auto pred = operands[1]; + auto falseVal = operands[2]; + + Block *currentBlock = rewriter.getInsertionBlock(); + Block *afterLoad = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + afterLoad->addArgument({elemTy}, {loc}); + Block *trueBlock = rewriter.createBlock(afterLoad); + Block *falseBlock = + rewriter.splitBlock(trueBlock, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create(loc, pred, trueBlock, falseBlock); + rewriter.setInsertionPointToStart(trueBlock); + auto loadOp = rewriter.create(loc, elemTy, ptr); + rewriter.create(loc, loadOp->getResult(0), afterLoad); + rewriter.setInsertionPointToStart(falseBlock); + rewriter.create(loc, falseVal, afterLoad); + rewriter.setInsertionPointToStart(afterLoad); + Value loadVal = afterLoad->getArgument(0); + rewriter.replaceOp(callOp, loadVal); + return mlir::success(); + } +}; + +struct ConvertBuiltinFuncToLLVM + : public triton::impl::ConvertMTGPUBuiltinFuncToLLVMBase< + ConvertBuiltinFuncToLLVM> { + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + // Disable block merging because of: + // https://github.com/llvm/llvm-project/issues/63230 + // TODO(giuseros): enable block merging once the above ticket is completed + GreedyRewriteConfig config; + config.enableRegionSimplification = false; + + RewritePatternSet patterns(context); + patterns.add(context); + + if (mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns), config) + .failed()) { + signalPassFailure(); + } + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { + +std::unique_ptr> +createConvertMTGPUBuiltinFuncToLLVMPass() { + return std::make_unique(); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/CMakeLists.txt b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..f5b478ddb --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/CMakeLists.txt @@ -0,0 +1,22 @@ +add_triton_library(TritonMTGPUToLLVM + ConvertLayoutOpToLLVM.cpp + DotOpToLLVM.cpp + ElementwiseOpToLLVM.cpp + LoadStoreOpToLLVM.cpp + TritonGPUToLLVM.cpp + SPMDOpToLLVM.cpp + Utility.cpp + TargetInfo.cpp + MUSATranslation.cpp + FuncOpToLLVM.cpp + BuiltinFuncToLLVM.cpp + + DEPENDS + TritonMTGPUConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRMTGPUDialect + MLIRMTGPUToLLVMIRTranslation + MLIRGPUToMTGPUTransforms + TritonGPUToLLVM +) diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/ConvertLayoutOpToLLVM.cpp new file mode 100644 index 000000000..b11028a2a --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -0,0 +1,32 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "TargetInfo.h" +#include "Utility.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using mlir::isLayoutMmaV1; +using ::mlir::LLVM::getMultiDimOffset; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::LLVM::getWrappedMultiDimOffset; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getShapePerCTATile; +using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::isaDistributedLayout; +using ::mlir::triton::gpu::SharedEncodingAttr; + +void mlir::triton::MUSA::populateConvertLayoutOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + mlir::triton::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo, + patterns, benefit); +} diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/DotOpToLLVM.cpp b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/DotOpToLLVM.cpp new file mode 100644 index 000000000..480d8221a --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/DotOpToLLVM.cpp @@ -0,0 +1,45 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" + +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getShapePerCTA; + +namespace { +struct DotOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + // D = A * B + C + Value A = op.getA(); + Value D = op.getResult(); + + // Here we assume the DotOp's operands always comes from shared memory. + auto AShapePerCTA = getShapePerCTA(A.getType()); + size_t reduceAxis = 1; + unsigned K = AShapePerCTA[reduceAxis]; + bool isOuter = K == 1; + + if (isa( + cast(D.getType()).getEncoding())) + return convertFMADot(op, adaptor, getTypeConverter(), rewriter); + + llvm::report_fatal_error( + "Unsupported DotOp found when converting TritonGPU to LLVM."); + } +}; +} // namespace + +void mlir::triton::MUSA::populateDotOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/ElementwiseOpToLLVM.cpp new file mode 100644 index 000000000..bc1e8ca10 --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -0,0 +1,618 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "TargetInfo.h" +#include "Utility.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" + +using namespace mlir::triton::gpu; +using namespace mlir::triton::MUSA; + +namespace mlir::triton { + +namespace gpu { +namespace { + +struct Fp8ConversionDesc { + std::string funcName; + size_t numElements; +}; + +static SmallVector convertFp8(const LLVMTypeConverter *typeConverter, + Location loc, + ConversionPatternRewriter &rewriter, + const SmallVector &v, + Type &srcElementType, Type &dstElementType, + const std::string funcName) { + size_t numElements = v.size(); + Type inpType; + Type outType; + Value inVals; + + if (numElements == 1) { + inpType = typeConverter->convertType(srcElementType); + outType = typeConverter->convertType(dstElementType); + inVals = v[0]; + } else { + inpType = vec_ty(typeConverter->convertType(srcElementType), numElements); + outType = vec_ty(typeConverter->convertType(dstElementType), numElements); + inVals = undef(inpType); + for (size_t i = 0; i < numElements; i++) { + inVals = insert_element(inpType, inVals, v[i], i32_val(i)); + } + } + + Type funcType = LLVM::LLVMFunctionType::get(outType, inpType); + + std::string libName = ""; + std::string libPath = ""; + + // Call libdevice + LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp( + rewriter, v[0].getDefiningOp(), funcName, funcType, libName, libPath); + auto outVals = rewriter.create(loc, funcOp, inVals).getResult(); + + SmallVector ret; + for (size_t i = 0; i < numElements; i++) { + ret.push_back(numElements == 1 ? outVals + : extract_element(typeConverter->convertType( + dstElementType), + outVals, i32_val(i))); + } + return ret; +} + +// Attempts to use vectorized conversions via inline PTX when possible. +struct FpToFpOpConversion + : public ElementwiseOpConversionBase { + using ElementwiseOpConversionBase< + FpToFpOp, FpToFpOpConversion>::ElementwiseOpConversionBase; + + explicit FpToFpOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + int computeCapability, + PatternBenefit benefit = patternBenefitDefault) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + computeCapability(computeCapability) {} + + static Value convertBf16ToFp32(Location loc, + ConversionPatternRewriter &rewriter, + const Value &v) { + auto as_int16 = bitcast(v, i16_ty); + auto as_int32 = zext(i32_ty, as_int16); + auto shifted = shl(i32_ty, as_int32, i32_val(16)); + return bitcast(shifted, f32_ty); + } + + static Value convertFp32ToBf16(Location loc, + ConversionPatternRewriter &rewriter, + const Value &v, const RoundingMode rounding) { + if (rounding == RoundingMode::RTZ) { + auto as_int32 = bitcast(v, i32_ty); + auto shifted = lshr(i32_ty, as_int32, i32_val(16)); + auto truncated = trunc(i16_ty, shifted); + return bitcast(truncated, bf16_ty); + } + // Otherwise it is (rounding == RoundingMode::RTNE) + auto as_uint32 = bitcast(v, i32_ty); + auto check_exponent = + and_(i32_ty, xor_(i32_ty, as_uint32, i32_val(0xffffffff)), + i32_val(0x7f800000)); + auto exponent_not_all1s = icmp_ne(check_exponent, i32_val(0)); + auto exponent_all1s = icmp_eq(check_exponent, i32_val(0)); + auto rounded = + add(i32_ty, i32_val(0x7fff), + and_(i32_ty, lshr(i32_ty, as_uint32, i32_val(16)), i32_val(1))); + rounded = add(i32_ty, rounded, as_uint32); + auto res = select(exponent_not_all1s, rounded, as_uint32); + + auto preserve_nan = + and_(i1_ty, exponent_all1s, + icmp_ne(and_(i32_ty, as_uint32, i32_val(0xffff)), i32_val(0))); + auto nan = or_(i32_ty, as_uint32, i32_val(0x10000)); + res = select(preserve_nan, nan, res); + + auto shifted = lshr(i32_ty, res, i32_val(16)); + auto truncated = trunc(i16_ty, shifted); + return bitcast(truncated, bf16_ty); + } + + std::pair + getConversionFunc(Type srcTy, Type dstTy, + std::optional roundingMode, + bool enableFp8Burst2) const { + auto F8E4M3TyID = TypeID::get(); + auto F8E5M2TyID = TypeID::get(); + auto F16TyID = TypeID::get(); + auto BF16TyID = TypeID::get(); + auto F32TyID = TypeID::get(); + auto F64TyID = TypeID::get(); + + auto undefRounding = static_cast(-1); + + static DenseMap, + SmallVector> + srcMap = { + // F8 -> F32 + {{F8E4M3TyID, F32TyID, undefRounding}, + {{"__mt_tt_v2e4m3_to_v2f32", 2}, {"__mt_tt_e4m3_to_f32", 1}}}, + {{F8E5M2TyID, F32TyID, undefRounding}, + {{"__mt_tt_v2e5m2_to_v2f32", 2}, {"__mt_tt_e5m2_to_f32", 1}}}, + // F8 -> F16 + {{F8E4M3TyID, F16TyID, undefRounding}, + {{"__mt_tt_v2e4m3_to_v2f16", 2}, {"__mt_tt_e4m3_to_f16", 1}}}, + {{F8E5M2TyID, F16TyID, undefRounding}, + {{"__mt_tt_v2e5m2_to_v2f16", 2}, {"__mt_tt_e5m2_to_f16", 1}}}, + // F8 -> BF16 + {{F8E4M3TyID, BF16TyID, undefRounding}, + {{"__mt_tt_v2e4m3_to_v2bf16", 2}, {"__mt_tt_e4m3_to_bf16", 1}}}, + {{F8E5M2TyID, BF16TyID, undefRounding}, + {{"__mt_tt_v2e5m2_to_v2bf16", 2}, {"__mt_tt_e5m2_to_bf16", 1}}}, + // F32 -> F8 + {{F32TyID, F8E4M3TyID, RoundingMode::RTNE}, + {{"__mt_tt_v2f32_to_v2e4m3", 2}, {"__mt_tt_f32_to_e4m3", 1}}}, + {{F32TyID, F8E5M2TyID, RoundingMode::RTNE}, + {{"__mt_tt_v2f32_to_v2e5m2", 2}, {"__mt_tt_f32_to_e5m2", 1}}}, + // F16 -> F8 + {{F16TyID, F8E4M3TyID, RoundingMode::RTNE}, + {{"__mt_tt_v2f16_to_v2e4m3", 2}, {"__mt_tt_f16_to_e4m3", 1}}}, + {{F16TyID, F8E5M2TyID, RoundingMode::RTNE}, + {{"__mt_tt_v2f16_to_v2e5m2", 2}, {"__mt_tt_f16_to_e5m2", 1}}}, + // BF16 -> F8 + {{BF16TyID, F8E4M3TyID, RoundingMode::RTNE}, + {{"__mt_tt_v2bf16_to_v2e4m3", 2}, {"__mt_tt_bf16_to_e4m3", 1}}}, + {{BF16TyID, F8E5M2TyID, RoundingMode::RTNE}, + {{"__mt_tt_v2bf16_to_v2e5m2", 2}, {"__mt_tt_bf16_to_e5m2", 1}}}, + }; + std::tuple key = { + srcTy.getTypeID(), dstTy.getTypeID(), + roundingMode.value_or(undefRounding)}; + if (srcMap.count(key) == 0) { + llvm::errs() << "Unsupported conversion from " << srcTy << " to " + << dstTy; + if (roundingMode.has_value()) + llvm::errs() << " with rounding mode " + << stringifyRoundingMode(roundingMode.value()); + llvm::errs() << "\n"; + llvm::report_fatal_error("Unsupported rounding mode for conversion."); + } + auto convDesc = + enableFp8Burst2 ? srcMap.lookup(key)[0] : srcMap.lookup(key)[1]; + + return {convDesc.funcName, convDesc.numElements}; + } + + SmallVector createDestOps(FpToFpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto srcElementType = getElementType(op.getSrc()); + auto dstElementType = getElementType(op.getResult()); + auto roundingMode = op.getRounding(); + + bool isFp8Converion = + srcElementType.isFloat8E4M3FNUZ() || srcElementType.isFloat8E5M2() || + dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2(); + assert(isFp8Converion && + "For now only Fp8 conversions are supported for the op FpToFp."); + + if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FNUZ()) { + assert(roundingMode.has_value() && + "Rounding mode must be specified for convertsions to fp8"); + + // For now only RTNE is supported for all conversions + if (roundingMode.value() != RoundingMode::RTNE) { + llvm::errs() << "Unsupported rounding mode for conversion to fp8: " + << stringifyRoundingMode(roundingMode.value()) << "\n"; + llvm_unreachable(""); + } + } + + // Default disable fp8 burst2 + bool enableFp8Burst2 = false; + std::string envValue = + mlir::triton::tools::getStrEnv("MUSA_ENABLE_FP8_BURST2"); + if (!envValue.empty() && + (envValue == "true" || envValue == "TRUE" || envValue == "1")) { + enableFp8Burst2 = true; + } + + auto [funcName, numElements] = getConversionFunc( + srcElementType, dstElementType, roundingMode, enableFp8Burst2); + + // FP8 conversions + if (isFp8Converion) { + SmallVector inVals; + for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) { + inVals.push_back(operands[i][0]); + } + inVals.resize(numElements, + undef(typeConverter->convertType(srcElementType))); + SmallVector outVals = + convertFp8(getTypeConverter(), loc, rewriter, inVals, srcElementType, + dstElementType, funcName); + return outVals; + } + llvm_unreachable("Unsupported conversion"); + return {}; + } + +private: + int computeCapability; +}; + +template +Value EmitDualBF16ElementwiseOp(Location loc, + ConversionPatternRewriter &rewriter, + MultipleOperandsRange operands) { + auto v0 = + FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0]); + auto v1 = + FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][1]); + auto result = rewriter.create(loc, f32_ty, v0, v1); + return FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, result, + RoundingMode::RTNE); +} + +struct FDivOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::DivFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create(loc, elemTy, operands[0][0], + operands[0][1])}; + } +}; + +struct FMulOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::MulFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto lhsElemTy = getElementType(op.getLhs()); + auto rhsElemTy = getElementType(op.getRhs()); + if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) { + return {EmitDualBF16ElementwiseOp(loc, rewriter, operands)}; + } else { + return {rewriter.create(loc, elemTy, operands[0][0], + operands[0][1])}; + } + } +}; + +struct FAddOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::AddFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto lhsElemTy = getElementType(op.getLhs()); + auto rhsElemTy = getElementType(op.getRhs()); + if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) { + return {EmitDualBF16ElementwiseOp(loc, rewriter, operands)}; + } else { + return {rewriter.create(loc, elemTy, operands[0][0], + operands[0][1])}; + } + } +}; + +struct FSubOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::SubFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto lhsElemTy = getElementType(op.getLhs()); + auto rhsElemTy = getElementType(op.getRhs()); + if (lhsElemTy.isBF16() && rhsElemTy.isBF16()) { + return {EmitDualBF16ElementwiseOp(loc, rewriter, operands)}; + } else { + return {rewriter.create(loc, elemTy, operands[0][0], + operands[0][1])}; + } + } +}; + +// Uses inline ptx to convert s8/u8 to bf16, since the +struct SIToFPOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::SIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + Type inElemTy = getElementType(op.getIn()); + Type outElemTy = getElementType(op.getOut()); + if (outElemTy.isBF16() && inElemTy.isInteger(8)) { + // TODO(lingfeng.qiu): use inline asm to vectorize 4*s8. + // s8 -> s32 -> fp32 -> bf16 + Value i32Val = sext(i32_ty, operands[0][0]); + Value f32Val = inttofloat(f32_ty, i32Val); + f32Val = bitcast(f32Val, i32_ty); + auto shifted = lshr(i32_ty, f32Val, i32_val(16)); + auto truncated = trunc(i16_ty, shifted); + auto outVal = bitcast(truncated, bf16_ty); + return {outVal}; + } else if (outElemTy.isBF16()) { + auto value = rewriter.create(loc, f32_ty, operands[0][0]); + return {FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, value, + RoundingMode::RTNE)}; + } else { + return {rewriter.create(loc, elemTy, operands[0][0])}; + } + } +}; + +struct FPToSIOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::FPToSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto inElemTy = getElementType(op.getIn()); + return {rewriter.create(loc, elemTy, operands[0][0])}; + } +}; + +struct ExtFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::ExtFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto inElemTy = getElementType(op.getIn()); + if (inElemTy.isBF16()) { + auto outElemTy = getElementType(op.getOut()); + assert(outElemTy.isF32() && "unsupported conversion"); + return { + FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0])}; + } else { + return {rewriter.create(loc, elemTy, operands[0][0])}; + } + } +}; + +struct TruncFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::TruncFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto outElemTy = getElementType(op.getOut()); + if (outElemTy.isBF16()) { + auto inElemTy = getElementType(op.getIn()); + assert(inElemTy.isF32() && "unsupported conversion"); + return {// Trunc uses the default rounding mode: RTNE + FpToFpOpConversion::convertFp32ToBf16( + loc, rewriter, operands[0][0], RoundingMode::RTNE)}; + } else { + return {rewriter.create(loc, elemTy, operands[0][0])}; + } + } +}; + +struct ExpOpConversionApprox + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(math::ExpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + // For non-FP32 input, call __nv_expf for higher-precision calculation + if (elemTy.getIntOrFloatBitWidth() != 32) + return {}; + + const double log2e = 1.4426950408889634; + Value prod = fmul(f32_ty, operands[0][0], f32_val(log2e)); + + return {rewriter.create(loc, f32_ty, prod, + adaptor.getAttributes().getValue())}; + } +}; + +struct ClampFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + explicit ClampFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + int computeCapability, + PatternBenefit benefit = patternBenefitDefault) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + computeCapability(computeCapability) {} + + bool isClipPattern(ClampFOp op) const { + bool xorsignAbsAvailable = (computeCapability >= 90); + // Pattern matching the sequence of clamp(x, -limit, limit) to generate + // more efficient PTX code. NOTE: This pattern matching is not general + // enough, but it is sufficient. We detect only two cases here: + // 1. where the "-limit" is computed as 0 - limit: + // %cst = arith.constant dense<0.000000e+00> + // %8 = tt.load %7, %2 + // %11 = arith.subf %cst, %8 + // %12 = tt.clamp %5, %11, %8 + // 2. where "-limit" and "limit" are constants. + // %cst_6 = arith.constant dense<-6.0000e+00> + // %cst_7 = arith.constant dense<6.0000e+00> + // %160 = tt.clamp %158, %cst_6, %cst_7 + bool patternFound = false; + + auto getSplatInitializer = [](Value v) -> std::optional { + if (auto constOp = v.getDefiningOp()) { + if (auto attr = mlir::dyn_cast( + constOp.getValueAttr())) { + if (attr.isSplat()) { + return attr.getSplatValue().convertToDouble(); + } + } + } + return std::nullopt; + }; + + if (xorsignAbsAvailable) { + if (auto subOp = op.getOperand(1).getDefiningOp()) { + if (subOp.getOperand(1) == op.getOperand(2)) { + auto initializer = getSplatInitializer(subOp.getOperand(0)); + if (initializer.has_value() && initializer.value() == 0.0) { + patternFound = true; + } + } + } else { + auto initializer1 = getSplatInitializer(op.getOperand(1)); + auto initializer2 = getSplatInitializer(op.getOperand(2)); + if (initializer1.has_value() && initializer2.has_value() && + initializer1.value() == -initializer2.value()) { + patternFound = true; + } + } + } + return patternFound; + } + + SmallVector emitOptimization(ClampFOp op, + ConversionPatternRewriter &rewriter, + Type elemTy, + MultipleOperandsRange operands, + Location loc) const { + llvm_unreachable("This function is not implemented yet."); + return {}; + } + + SmallVector createDestOps(ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + if (isClipPattern(op)) { + return emitOptimization(op, rewriter, elemTy, operands, loc); + } + return {}; + } + +private: + int computeCapability; +}; + +template +struct OpToExternCallConversion + : public ElementwiseOpConversionBase> { + using Base = + ElementwiseOpConversionBase>; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + explicit OpToExternCallConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + StringRef externFuncName, + PatternBenefit benefit) + : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, + benefit), + funcName(externFuncName) {} + + SmallVector createDestOps(TritonOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + return { + rewriter.create(loc, funcOp, operands[0]).getResult()}; + } + +private: + StringRef funcName; +}; +} // namespace +} // namespace gpu + +} // namespace mlir::triton + +void mlir::triton::MUSA::populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, int computeCapability, + const TargetInfo &targetInfo, PatternBenefit benefit) { + using namespace mlir::triton::gpu; + + patterns.add>( + typeConverter, axisInfoAnalysis, "__nv_fsqrt_rn", benefit); + patterns.add>( + typeConverter, axisInfoAnalysis, "__nv_fdiv_rn", benefit); + + mlir::triton::populateElementwiseOpToLLVMPatterns( + typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); + + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + + patterns.add(typeConverter, axisInfoAnalysis, + computeCapability, benefit); + + // ExpOpConversionApprox will try using ex2.approx if the input type is + // FP32. For other input types, ExpOpConversionApprox will return failure and + // ElementwiseOpConversion defined below will call + // __nv_expf for higher-precision calculation + patterns.add(typeConverter, axisInfoAnalysis, benefit); + bool hwNanPropagationSupported = targetInfo.supportMaximumMinimum(); + mlir::triton::populateMinMaxFOpToLLVMPattern( + typeConverter, patterns, axisInfoAnalysis, hwNanPropagationSupported, + benefit); + mlir::triton::populateClampFOpToLLVMPattern( + typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); +} + +void mlir::triton::MUSA::populateClampFOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, int computeCapability, + PatternBenefit benefit) { + using namespace mlir::triton::gpu; + + patterns.add(typeConverter, axisInfoAnalysis, + computeCapability, benefit); +} diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/FuncOpToLLVM.cpp b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 000000000..d107f26d2 --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,119 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +/// FuncOp legalization pattern that converts MemRef arguments to pointers to +/// MemRef descriptors (LLVM struct data types) containing all the MemRef type +/// information. +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, int numWarps, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), numWarps(numWarps) {} + + /// Only retain those attributes that are not constructed by + /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument + /// attributes. + static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result) { + + for (const auto &attr : op->getAttrs()) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == op.getFunctionTypeAttrName() || + attr.getName() == "std.varargs" || + (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) + continue; + result.push_back(attr); + } + } + + triton::FuncOp amendFuncOp(triton::FuncOp funcOp, + ConversionPatternRewriter &rewriter) const { + // Push back a variable that indicates the current stack pointer of shared + // memory to the function arguments. + auto loc = funcOp.getLoc(); + auto ctx = funcOp->getContext(); + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); + // 1. Modify the function type to add the new argument. + auto funcTy = funcOp.getFunctionType(); + auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); + amendedInputTy.push_back(ptrTy); + auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, + funcTy.getResults()); + // 2. Modify the argument attributes to add the new argument. + SmallVector amendedAttrs; + filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); + auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedAttrs.push_back(rewriter.getNamedAttr( + funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); + // 3. Add a new argument to the region + auto amendedFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); + auto ®ion = funcOp.getBody(); + region.addArgument(ptrTy, loc); + rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), + amendedFuncOp.end()); + return amendedFuncOp; + } + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Prevent LLVM's inliner to inline this function + auto amendedFuncOp = funcOp; + if (!LLVM::isKernel(funcOp)) + amendedFuncOp = amendFuncOp(funcOp, rewriter); + + LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( + amendedFuncOp, rewriter, *getTypeConverter()); + if (!newFuncOp) { + return failure(); + } + + auto ctx = funcOp->getContext(); + + if (LLVM::isKernel(funcOp)) { + // Set an attribute to indicate this function is a kernel entry. + newFuncOp->setAttr("mtgpu.kernel", + rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + newFuncOp.setLinkage(LLVM::Linkage::External); + } else { + // The noinline attribute will be used by the LLVM codegen to prevent + // inlining. + // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 + newFuncOp.setPassthroughAttr( + ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); + rewriter.eraseOp(amendedFuncOp); + newFuncOp.setLinkage(LLVM::Linkage::Internal); + } + // Set an attribute for reqntidx, it could be used in latter LLVM codegen + // for `nvvm.annotation` metadata. + newFuncOp->setAttr("mtgpu.maxntid", + rewriter.getI32ArrayAttr(128 * numWarps)); + rewriter.eraseOp(funcOp); + return success(); + } + +private: + int numWarps{0}; +}; + +} // namespace + +void mlir::triton::MUSA::populateFuncOpConversionPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, + PatternBenefit benefit) { + patterns.add(typeConverter, numWarps, benefit); +} diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/LoadStoreOpToLLVM.cpp new file mode 100644 index 000000000..e74767771 --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -0,0 +1,540 @@ +#include "TargetInfo.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/TypeUtilities.h" + +#include "PatternTritonGPUOpToLLVM.h" + +#include "Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::MUSA; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::LLVM::linearize; +using ::mlir::LLVM::MUSA::llLoad; +using ::mlir::LLVM::MUSA::llStore; +using ::mlir::triton::gpu::getCTALayout; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::SharedEncodingAttr; + +namespace { + +// Return the mask for the unique data accessed by given tensor type. +// Used to mask out the redundant data accessed by threads. +Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, + Location loc, const MUSA::TargetInfo &targetInfo) { + auto tensorTy = dyn_cast(valueTy); + Value mask = int_val(1, 1); + auto tid = tid_val(); + auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc); + if (tensorTy) { + auto layout = tensorTy.getEncoding(); + auto shape = tensorTy.getShape(); + unsigned rank = shape.size(); + auto sizePerThread = triton::gpu::getSizePerThread(layout); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout); + auto order = triton::gpu::getOrder(layout); + auto warpOrder = triton::gpu::getWarpOrder(layout); + auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); + auto wrapSizeInt = product(threadsPerWarp); + Value warpSize = i32_val(wrapSizeInt); + Value laneId = urem(tid, warpSize); + Value warpId = udiv(tid, warpSize); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); + SmallVector multiDimThreadId = + delinearize(rewriter, loc, laneId, threadsPerWarp, order); + for (unsigned dim = 0; dim < rank; ++dim) { + // if there is no data replication across threads on this dimension + if (shape[dim] >= shapePerCTATile[dim]) + continue; + // Otherwise, we need to mask threads that will replicate data on this + // dimension. Calculate the thread index on this dimension for the CTA + Value threadDim = + add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])), + multiDimThreadId[dim]); + mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])), + i32_val(shape[dim]))); + } + // Do not write duplicated data when multicast is enabled + if (triton::gpu::getNumCTAs(layout) > 1) { + auto _0 = i32_val(0); + auto CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout); + auto CTASplitNum = triton::gpu::getCTASplitNum(layout); + auto CTAOrder = triton::gpu::getCTAOrder(layout); + + LLVM_DEBUG(DBGS() << "[pattern storeOpConversion] " + << " numCTAS = " << triton::gpu::getNumCTAs(layout)); + auto multiDimClusterCTAId = + delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); + + for (unsigned dim = 0; dim < rank; ++dim) { + // Skip when multicast is not enabled in this dimension + if (CTAsPerCGA[dim] == CTASplitNum[dim]) + continue; + // This wrapping rule must be consistent with emitCTAOffsetForLayout + unsigned splitNum = std::min(shape[dim], CTASplitNum[dim]); + Value repId = udiv(multiDimClusterCTAId[dim], i32_val(splitNum)); + // Consider the example where CTAsPerCGA = [4] and CTASplitNum = [2]: + // CTA0 and CTA2 holds data of block0, + // CTA1 and CTA3 holds data of block1. + // Only CTA0 and CTA1 are expected to write while CTA2 and CTA3 should + // be masked. We add the following mask: + // multiDimClusterCTAId[dim] / splitNum == 0 + // Actually in all existing cases of multicast, splitNum is always 1. + // The mask is equivalent to: + // multiDimClusterCTAId[dim] == 0 + mask = and_(mask, icmp_eq(repId, _0)); + } + } + } else { + // If the tensor is not ranked, then it is a scalar and only thread 0 of + // CTA0 can write + mask = and_(mask, icmp_eq(clusterCTAId, i32_val(0))); + mask = and_(mask, icmp_eq(tid, i32_val(0))); + } + return mask; +} + +// Contains some helper functions for both Load and Store conversions. +struct LoadStoreConversionBase { + explicit LoadStoreConversionBase(const MUSA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass) + : targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass) {} + + unsigned getContiguity(Value ptr) const { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + return axisAnalysisPass.getPtrContiguity(ptr); + } + + unsigned getVectorSize(Value ptr) const { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + auto contiguity = getContiguity(ptr); + auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy); + LDBG("getVectorSize contiguity = " << contiguity << " pointeeBitWidth = " + << pointeeBitWidth); + // The maximum vector size is 128 bits on MTGPU GPUs. + return std::min(128 / pointeeBitWidth, contiguity); + } + + unsigned getMaskAlignment(Value mask) const { + return axisAnalysisPass.getMaskAlignment(mask); + } + +protected: + const MUSA::TargetInfo &targetInfo; + ModuleAxisInfoAnalysis &axisAnalysisPass; +}; + +struct LoadOpConversion : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + LoadOpConversion(LLVMTypeConverter &converter, + const MUSA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // original values + Value ptr = op.getPtr(); + Value mask = op.getMask(); + Value other = op.getOther(); + + // adaptor values + assert(!isTensorPointerType(ptr.getType()) && + "Cannot convert load with a tensor pointer into LLVM; " + "this case should be transformed to normal load before lowering"); + Value llPtr = adaptor.getPtr(); + Value llMask = adaptor.getMask(); + Value llOther = adaptor.getOther(); + + // Determine the vectorization size + Type valueTy = op.getType(); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + unsigned vec = getVectorSize(ptr); + unsigned numElems = getTotalElemsPerThread(ptr.getType()); + if (llMask) + vec = std::min(vec, getMaskAlignment(mask)); + + // Get the LLVM values for pointers + auto ptrElems = unpackLLElements(loc, llPtr, rewriter); + assert(ptrElems.size() == numElems); + + // Get the LLVM values for mask + SmallVector maskElems; + if (llMask) { + maskElems = unpackLLElements(loc, llMask, rewriter); + assert(maskElems.size() == numElems); + } + + // Get the LLVM values for `other` + // TODO: (goostavz) handle when other is const but not splat, which + // should be rarely seen + bool otherIsSplatConstInt = false; + DenseElementsAttr constAttr; + int64_t splatVal = 0; + if (other && isa(valueElemTy) && + matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() && + isa(constAttr.getElementType())) { + otherIsSplatConstInt = true; + splatVal = constAttr.getSplatValue().getSExtValue(); + } + SmallVector otherElems; + if (other) { + otherElems = unpackLLElements(loc, llOther, rewriter); + } + + SmallVector loadedVals; + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + // TODO: optimization when ptr is GEP with constant offset + + Value pred = mask ? maskElems[vecStart] : int_val(1, 1); + auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec); + Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]); + + mlir::Attribute zeroAttr = rewriter.getZeroAttr(valueElemTy); + auto denseValue = + DenseElementsAttr::get(cast(vecTy), zeroAttr); + Value zeroVal = rewriter.create(loc, vecTy, denseValue); + + Value falseVal = zeroVal; + // If we need to mask the loaded value with other elements + if (otherElems.size() != 0) { + Value v = undef(vecTy); + for (size_t s = 0; s < vec; ++s) { + Value otherElem = otherElems[vecStart + s]; + Value indexVal = createIndexAttrConstant( + rewriter, loc, this->getTypeConverter()->getIndexType(), s); + v = insert_element(vecTy, v, otherElem, indexVal); + } + falseVal = v; + } + + auto loadVal = llLoad(rewriter, loc, ptr, vecTy, pred, falseVal); + for (size_t ii = 0; ii < vec; ++ii) { + Value vecIdx = createIndexAttrConstant( + rewriter, loc, this->getTypeConverter()->getIndexType(), ii % vec); + Value loaded = extract_element(valueElemTy, loadVal, vecIdx); + loadedVals.push_back(loaded); + } + } // end vec + + Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), loadedVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct StoreOpConversion : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + StoreOpConversion(LLVMTypeConverter &converter, + const MUSA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value ptr = op.getPtr(); + Value value = op.getValue(); + + Value llPtr = adaptor.getPtr(); + Value llMask = adaptor.getMask(); + Value llValue = adaptor.getValue(); + + auto loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + auto valueTy = value.getType(); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + + unsigned vec = getVectorSize(ptr); + unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType()); + + auto ptrElems = unpackLLElements(loc, llPtr, rewriter); + auto valueElems = unpackLLElements(loc, llValue, rewriter); + assert(ptrElems.size() == valueElems.size()); + + // Determine the vectorization size + SmallVector maskElems; + if (llMask) { + Value mask = op.getMask(); + maskElems = unpackLLElements(loc, llMask, rewriter); + assert(valueElems.size() == maskElems.size()); + + unsigned maskAlign = getMaskAlignment(mask); + vec = std::min(vec, maskAlign); + } + + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); + + auto vecTy = vec_ty(valueElemTy, vec); + for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) { + // TODO: optimization when ptr is AddPtr with constant offset + // TODO(Superjomn) Add cache policy fields to StoreOp. + // TODO(Superjomn) Deal with cache policy here. + + Value storeVal = undef(vecTy); + for (size_t elemIdx = 0; elemIdx < vec; ++elemIdx) { + Value elem = valueElems[vecStart + elemIdx]; + if (elem.getType().isInteger(1)) + elem = sext(i8_ty, elem); + elem = bitcast(elem, valueElemTy); + storeVal = insert_element(vecTy, storeVal, elem, i32_val(elemIdx)); + } + Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask; + auto address = ptrElems[vecStart]; + llStore(rewriter, loc, address, storeVal, maskVal); + } + rewriter.eraseOp(op); + return success(); + } +}; + +static LLVM::AtomicOrdering getMemoryOrdering(MemSemantic memOrdering) { + switch (memOrdering) { + case MemSemantic::RELAXED: + return LLVM::AtomicOrdering::monotonic; + case MemSemantic::ACQUIRE: + return LLVM::AtomicOrdering::acquire; + case MemSemantic::RELEASE: + return LLVM::AtomicOrdering::release; + case MemSemantic::ACQUIRE_RELEASE: + return LLVM::AtomicOrdering::acq_rel; + default: + return LLVM::AtomicOrdering::acq_rel; + } +} + +struct AtomicRMWOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + AtomicRMWOpConversion(LLVMTypeConverter &converter, + const MUSA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + /// Try to match the mlir::triton::RMWOp to LLVM::AtomicBinOp. + static std::optional matchAtomicOp(RMWOp atomicOp) { + switch (atomicOp) { + case RMWOp::AND: + return LLVM::AtomicBinOp::_and; + case RMWOp::OR: + return LLVM::AtomicBinOp::_or; + case RMWOp::XOR: + return LLVM::AtomicBinOp::_xor; + case RMWOp::ADD: + return LLVM::AtomicBinOp::add; + case RMWOp::FADD: + return LLVM::AtomicBinOp::fadd; + case RMWOp::MAX: + return LLVM::AtomicBinOp::max; + case RMWOp::MIN: + return LLVM::AtomicBinOp::min; + case RMWOp::UMAX: + return LLVM::AtomicBinOp::umax; + case RMWOp::UMIN: + return LLVM::AtomicBinOp::umin; + case RMWOp::XCHG: + return LLVM::AtomicBinOp::xchg; + default: + return std::nullopt; + } + llvm_unreachable("Invalid RMWOp"); + } + + LogicalResult + matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + auto funcOp = op->getParentOfType(); + auto atomicRmwAttr = op.getAtomicRmwOp(); + Value ptr = op.getPtr(); + Value val = op.getVal(); + + Value llPtr = adaptor.getPtr(); + Value llVal = adaptor.getVal(); + Value llMask = adaptor.getMask(); + + auto valElements = unpackLLElements(loc, llVal, rewriter); + auto ptrElements = unpackLLElements(loc, llPtr, rewriter); + SmallVector maskElements; + if (llMask) + maskElements = unpackLLElements(loc, llMask, rewriter); + + auto valueTy = op.getType(); + auto tensorTy = dyn_cast(valueTy); + Type valueElemTy = + tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType()) + : valueTy; + const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth(); + auto elemsPerThread = getTotalElemsPerThread(val.getType()); + // vec = 1, numElements = 1 for scalar + auto vec = getVectorSize(ptr); + int numElems = 1; + // tensor + if (tensorTy) { + auto valTy = cast(val.getType()); + // NV for the f16v2 case generates one packed instruction. + if (funcOp->hasAttr("nvvm.kernel")) + vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); + else + vec = std::min(vec, 1); + // mask + numElems = tensorTy.getNumElements(); + } + Value mask = int_val(1, 1); + auto tid = tid_val(); + mask = and_(mask, + icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems))); + + auto memOrdering = op.getSem(); + auto atomicMemOrdering = getMemoryOrdering(memOrdering); + + auto vecTy = vec_ty(valueElemTy, vec); + auto retType = vec == 1 ? valueElemTy : vecTy; + SmallVector resultVals(elemsPerThread); + const bool f16v2 = vec == 2 && valueElemTy.isF16(); + for (size_t i = 0; i < elemsPerThread; i += vec) { + Value rmwPtr = ptrElements[i]; + // TODO: in case llMask is zero we can create only one branch for all + // elemsPerThread. + Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; + + Value undefVal = undef(retType); + // Build blocks to bypass the atomic instruction for ~rmwMask. + auto *curBlock = rewriter.getInsertionBlock(); + auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint()); + auto *atomicBlock = rewriter.createBlock( + curBlock->getParent(), std::next(Region::iterator(curBlock))); + endBlock->addArgument({retType}, {loc}); + + rewriter.setInsertionPointToEnd(curBlock); + rewriter.create(loc, rmwMask, atomicBlock, endBlock, + undefVal); + + rewriter.setInsertionPointToEnd(atomicBlock); + auto maybeKind = matchAtomicOp(atomicRmwAttr); + // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient + // atomics for MI-* series of AMD GPU. + Value atom; + if (*maybeKind == LLVM::AtomicBinOp::fadd) { + StringRef funcName; + Type fpType; + if (valueElemTy.isF16()) { + funcName = "__mt_atomicAdd_f16"; + fpType = rewriter.getF16Type(); + } else if (valueElemTy.isF32()) { + funcName = "__mt_atomicAdd_f32"; + fpType = rewriter.getF32Type(); + } else if (valueElemTy.isF64()) { + funcName = "__mt_atomicAdd_f64"; + fpType = rewriter.getF64Type(); + } else { + llvm_unreachable("Invalid value element type."); + return failure(); + } + auto moduleOp = op->getParentOfType(); + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + LLVM::LLVMFuncOp calleeFuncOp = + moduleOp.lookupSymbol(funcName); + if (!calleeFuncOp) { + rewriter.setInsertionPointToStart(moduleOp.getBody()); + auto type = LLVM::LLVMFunctionType::get(fpType, {ptrTy, fpType}); + calleeFuncOp = rewriter.create( + loc, funcName, type, LLVM::Linkage::External); + } + rewriter.setInsertionPointToEnd(atomicBlock); + Value addressCast = + rewriter.create(loc, ptrTy, rmwPtr); + atom = + rewriter + .create(loc, calleeFuncOp, + ValueRange{addressCast, valElements[i]}) + .getResult(); + } else { + atom = rewriter + .create(loc, *maybeKind, rmwPtr, + valElements[i], atomicMemOrdering) + .getResult(); + } + + // NV for the f16v2 case generates one packed instruction. We have to + // create two separate instructions since LLVM::AtomicRMWOp doesn't + // support this. Can be optimized out with rocdl.raw.buffer.atomic. + if (f16v2 && funcOp->hasAttr("nvvm.kernel")) { + Value atom2 = + rewriter + .create( + loc, *maybeKind, ptrElements[i + 1], valElements[i + 1], + LLVM::AtomicOrdering::monotonic, StringRef("agent")) + .getResult(); + auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0)); + atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult(); + } + rewriter.create(loc, atom, endBlock); + + rewriter.setInsertionPointToStart(endBlock); + Value retVal = endBlock->getArgument(0); + if (tensorTy) { + for (int ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = + vec == 1 ? retVal + : extract_element(valueElemTy, retVal, i32_val(ii)); + } + } else { + Value atomPtr = + LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3)); + store(retVal, atomPtr); + Value ret = load(valueElemTy, atomPtr); + rewriter.replaceOp(op, {ret}); + } + } + if (tensorTy) { + Type structTy = getTypeConverter()->convertType(tensorTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, structTy); + rewriter.replaceOp(op, {resultStruct}); + } + return success(); + } +}; + +} // namespace + +void mlir::triton::MUSA::populateLoadStoreOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, + axisInfoAnalysis, benefit); + patterns.add(typeConverter, targetInfo, + axisInfoAnalysis, benefit); +} diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/MUSATranslation.cpp b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/MUSATranslation.cpp new file mode 100644 index 000000000..b29151356 --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/MUSATranslation.cpp @@ -0,0 +1,295 @@ +#include "TritonMTGPUToLLVM/MUSATranslation.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ExecutionEngine/ExecutionEngine.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IRPrintingPasses.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include +#include +#include +#include +#include +#include +#include + +namespace { + +std::string readStringFromEnv(const std::string &env_name, + const std::string &default_value) { + std::string env_path = mlir::triton::tools::getStrEnv(env_name); + return (!env_path.empty()) ? env_path : default_value; +} + +void execute_llc(const std::string &mtcc_path, + llvm::ArrayRef args) { + auto llc_program = llvm::sys::findProgramByName("llc", {mtcc_path}); + if (!llc_program) { + llvm::errs() << "llc program not found in path: " << mtcc_path << "\n"; + assert("llc program not found in path!"); + } + std::string err_msg; + int ret = llvm::sys::ExecuteAndWait(*llc_program, args, std::nullopt, {}, 0, + 0, &err_msg); + if (ret) { + llvm::errs() << "llc execute fail: " << err_msg << "\n"; + assert("using llc to generate asm or obj failed!"); + } +} + +// convert latest llvm ir to mtcc compatible llvm ir. +// see llvm/docs/ReleaseNotes.rst +void convertLLVMIR(const std::string &filename) { + // LLVM compatible. mtcc dependencies on llvm-14, convert llvm ir to mtcc + // compatible format. + auto make_llvm_compatible = [](std::string &ll_str) { + // clang-format off + std::vector old_format = { + "readnone", + "readonly", + "writeonly", + "argmemonly", + "argmemonly readonly", + "argmemonly writeonly", + "inaccessiblememonly", + "inaccessiblememonly readonly", + "inaccessiblememonly writeonly", + "inaccessiblemem_or_argmemonly", + "inaccessiblemem_or_argmemonly readonly", + "inaccessiblemem_or_argmemonly writeonly" + }; + std::vector new_format = { + "memory\\(none\\)", + "memory\\(read\\)", + "memory\\(write\\)", + "memory\\(argmem: readwrite\\)", + "memory\\(argmem: read\\)", + "memory\\(argmem: write\\)", + "memory\\(inaccessiblemem: readwrite\\)", + "memory\\(inaccessiblemem: read\\)", + "memory\\(inaccessiblemem: write\\)", + "memory\\(argmem: readwrite, inaccessiblemem: readwrite\\)", + "memory\\(argmem: read, inaccessiblemem: read\\)", + "memory\\(argmem: write, inaccessiblemem: write\\)" + }; + // clang-format on + for (int i = 0; i < old_format.size(); ++i) { + ll_str = + std::regex_replace(ll_str, std::regex(new_format[i]), old_format[i]); + } + }; + + // convert latest llvm ir to mtcc compatible llvm ir. + std::ifstream is(filename); + std::string ll_str((std::istreambuf_iterator(is)), + std::istreambuf_iterator()); + is.close(); + make_llvm_compatible(ll_str); + + // save the mtcc compatible llvm ir to ll file. + std::ofstream os(filename); + os << ll_str; + os.close(); + + if (mlir::triton::tools::getBoolEnv("MUSA_LLVMIR_ENABLE_DUMP")) { + std::cout << "// -----// MUSA LLVMIR Dump //----- //\n" + << ll_str << std::endl; + } +} + +std::string generate_muasm(const llvm::Module &llvmModule, + const std::string &opt_option, const int capability, + const int version, std::string &ll_file_name) { + std::string function_name; + std::string ll_file; + std::string asm_file; + + llvm::SmallString<128> kernel; + llvm::sys::fs::createTemporaryFile("mt_triton_kernel", /*suffix*/ "ll", + kernel); + ll_file = llvm::StringRef(kernel).str(); + ll_file_name = ll_file; + llvm::sys::path::replace_extension(kernel, "s"); + asm_file = llvm::StringRef(kernel).str(); + + std::error_code ec; + llvm::raw_fd_ostream os(ll_file, ec, llvm::sys::fs::OF_None); + llvmModule.print(os, nullptr); + os.close(); + + // get the name of mtgpu kernel. + for (auto &F : llvmModule.getFunctionList()) { + if (!F.isDeclaration() && + F.getCallingConv() == llvm::CallingConv::MTGPU_KERNEL) { + function_name = F.getName().str(); + break; + } + } + + // convert latest llvm ir to mtcc compatible llvm ir. + convertLLVMIR(ll_file); + + // because mtcc's building script has an option --disable_asm (default: + // False), which can control mtcc's llc whether can support -filetype=asm or + // not. so here we use an ENV: MTCC_ENABLE_ASM_BIN_PATH to indicate that this + // path's llc can support -filetype=asm. + // + // by default, we use /usr/local/musa/bin/llc, which can't support + // -filetype=asm, so we return the name of mtgpu kernel. otherwise, if we set + // the ENV: MTCC_ENABLE_ASM_BIN_PATH, we will return the generated asm code. + std::string mtcc_enable_asm_bin_path = + readStringFromEnv("MTCC_ENABLE_ASM_BIN_PATH", ""); + + if (!mtcc_enable_asm_bin_path.empty()) { + // set ENV: MTCC_ENABLE_ASM_BIN_PATH, so return the generated asm code. + // llc out.ll -march=mtgpu -O2 -filetype=asm -o out.asm + std::string assign_subtarget = "-mcpu=mp_" + std::to_string(capability); + llvm::SmallVector args{ + llvm::StringRef("llc"), + llvm::StringRef(ll_file), + llvm::StringRef("-march=mtgpu"), + llvm::StringRef(assign_subtarget), + llvm::StringRef("--opaque-pointers"), + llvm::StringRef("-filetype=asm"), + llvm::StringRef("-o"), + llvm::StringRef(asm_file), + llvm::StringRef("-O2"), + llvm::StringRef(opt_option)}; + + // use the mtcc_enable_asm_bin_path's llc to generate asm code. + execute_llc(mtcc_enable_asm_bin_path, args); + + // get the muasm code. + std::ifstream is(asm_file); + std::string muasm((std::istreambuf_iterator(is)), + std::istreambuf_iterator()); + is.close(); + + if (mlir::triton::tools::getBoolEnv("MUASM_ENABLE_DUMP")) { + std::cout << "// -----// MUASM Dump //----- //\n" << muasm << std::endl; + } + + return muasm; + } else { + // by default, /usr/local/musa/bin/llc can't support -filetype=asm, + // so return the name of mtgpu kernel. + return ".globl\t" + function_name; + } +} + +std::string generate_mubin(const std::string &ll_file_name, + const std::string &opt_option, const int capability, + const int version) { + int pos = ll_file_name.find_last_of('.'); + std::string obj_file = ll_file_name.substr(0, pos + 1) + "o"; + std::string lld_obj_file = ll_file_name.substr(0, pos + 1) + "lld.o.mubin"; + + // llc out.ll -march=mtgpu -O2 -filetype=obj -o out.o + std::string assign_subtarget = "-mcpu=mp_" + std::to_string(capability); + llvm::SmallVector args{llvm::StringRef("llc"), + llvm::StringRef(ll_file_name), + llvm::StringRef("-march=mtgpu"), + llvm::StringRef(assign_subtarget), + llvm::StringRef("--opaque-pointers"), + llvm::StringRef("-filetype=obj"), + llvm::StringRef("-o"), + llvm::StringRef(obj_file), + llvm::StringRef("-O2"), + llvm::StringRef(opt_option)}; + + // by default, we use the /usr/local/musa/bin/llc. + // if we set the ENV: MTCC_ENABLE_ASM_BIN_PATH, + // we should keep using the same llc tool with function: generate_muasm + std::string mtcc_path = + readStringFromEnv("MTCC_BIN_PATH", "/usr/local/musa/bin"); + std::string mtcc_enable_asm_bin_path = + readStringFromEnv("MTCC_ENABLE_ASM_BIN_PATH", ""); + + if (!mtcc_enable_asm_bin_path.empty()) { + execute_llc(mtcc_enable_asm_bin_path, args); + } else { + // TODO: pre-install MTCC in docker or build bin in third_party + execute_llc(mtcc_path, args); + } + + // lld -flavor gnu -shared %bin -o %obj + // clang-format off + llvm::SmallVector lld_args{ + llvm::StringRef("ld.lld"), + llvm::StringRef("-flavor"), + llvm::StringRef("gnu"), + llvm::StringRef("-shared"), + llvm::StringRef(obj_file), + llvm::StringRef("-o"), + llvm::StringRef(lld_obj_file) + }; + // clang-format on + auto lld_program = llvm::sys::findProgramByName("ld.lld", {mtcc_path}); + if (!lld_program) { + llvm::errs() << "lld program not found in path: " << mtcc_path << "\n"; + assert("using llc to generate obj failed!"); + } + + std::string err_msg; + int lld_ret = llvm::sys::ExecuteAndWait(*lld_program, lld_args, std::nullopt, + {}, 0, 0, &err_msg); + if (lld_ret) { + llvm::errs() << "lld execute fail: " << err_msg << "\n"; + assert("using llc to generate obj failed!"); + } + + return lld_obj_file; +} + +std::tuple +llir_to_muasm_and_mubin(llvm::Module *module, const std::string &opt_option, + int capability, int version) { + std::string ll_file_name; + auto muasm = + generate_muasm(*module, opt_option, capability, version, ll_file_name); + auto mubin_path = + generate_mubin(ll_file_name, opt_option, capability, version); + + return std::make_tuple(muasm, mubin_path); +} + +} // namespace + +namespace mlir::triton { + +std::tuple +translateLLVMIRToMUBIN(llvm::Module &module, const std::string &opt_option, + int capability, int version) { + auto muCode = + llir_to_muasm_and_mubin(&module, opt_option, capability, version); + return muCode; +} + +} // namespace mlir::triton diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/PatternTritonGPUOpToLLVM.h new file mode 100644 index 000000000..516ff0c25 --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -0,0 +1,62 @@ +#ifndef TRITON_CONVERSION_TRITONMTGPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H +#define TRITON_CONVERSION_TRITONMTGPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H + +#include "TargetInfo.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Analysis/AxisInfo.h" + +namespace mlir { +namespace triton { + +namespace MUSA { + +void populateBarrierOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateClusterOpsToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, int computeCapability, + const TargetInfo &targetInfo, PatternBenefit benefit); + +void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit); + +void populateTensorPtrOpsToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateClampFOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + int computeCapability, + PatternBenefit benefit); + +void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, int numWarps, + PatternBenefit benefit); +} // namespace MUSA +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 000000000..d9fc19809 --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,35 @@ +#include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct GetNumProgramsOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::GetNumProgramsOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x, + mlir::gpu::Dimension::y, + mlir::gpu::Dimension::z}; + Location loc = op->getLoc(); + assert(op.getAxisAsInt() < 3); + Value blockId = + rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxisAsInt()]); + rewriter.replaceOpWithNewOp(op, i32_ty, blockId); + return success(); + } +}; + +} // namespace + +void mlir::triton::MUSA::populateSPMDOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/TargetInfo.cpp b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/TargetInfo.cpp new file mode 100644 index 000000000..9570661f7 --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/TargetInfo.cpp @@ -0,0 +1,109 @@ +#include "TargetInfo.h" +#include "Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/Support/MathExtras.h" + +using namespace mlir; + +using mlir::LLVM::getWrappedMultiDimOffset; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getShapePerCTATile; + +namespace mlir::triton::MUSA { + +bool TargetInfo::supportMaximumMinimum() const { + // TODO(lingfeng.qiu): Mtcc currently does not support llvm.minimum and + // llvm.maximum. + return false; +} + +Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { + // TODO(lingfeng.qiu): Figure out whether MTGPU support CTA clusters. + // On AMD hardware we don't have CTA clusters like NVIDIA. So this will always + // be zero. Whoever calling into this should make sure the whole program does + // not try to utilize CTA clusters. + return rewriter.create(loc, 0, 32); +} + +Value TargetInfo::ballot(ConversionPatternRewriter &rewriter, Location loc, + Type type, Value cmp) const { + auto int32Ty = rewriter.getI32Type(); + return rewriter.create(loc, int32Ty, + rewriter.getI32IntegerAttr(0)); +} + +void TargetInfo::storeShared(ConversionPatternRewriter &rewriter, Location loc, + Value ptr, Value val, Value pred) const { + mlir::LLVM::MUSA::llStore(rewriter, loc, ptr, val, pred); +} + +Value TargetInfo::loadShared(ConversionPatternRewriter &rewriter, Location loc, + const TypeConverter *converter, Value ptr, + Type elemTy, Value pred) const { + Value falseVal = rewriter.create( + loc, elemTy, rewriter.getZeroAttr(elemTy)); + return mlir::LLVM::MUSA::llLoad(rewriter, loc, ptr, elemTy, pred, falseVal); +} + +Value TargetInfo::shuffleXor(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const { + return mlir::LLVM::MUSA::MTGPU_shuffleXor(loc, rewriter, val, i, 128); +} + +Value TargetInfo::shuffleUp(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const { + return mlir::LLVM::MUSA::MTGPU_shuffleUp(loc, rewriter, val, i, 128); +} + +Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const { + return mlir::LLVM::MUSA::MTGPU_shuffleIdx(loc, rewriter, val, i, 128); +} + +Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, + Value val, Value i) const { + return mlir::LLVM::MUSA::MTGPU_shuffleIdx(loc, rewriter, val, i, 128); +} + +Value TargetInfo::programId(ConversionPatternRewriter &rewriter, Location loc, + ModuleOp moduleOp, int axis) const { + return LLVM::MUSA::llGetPid(loc, rewriter, moduleOp, axis); +} + +bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce) const { + return false; +} + +bool TargetInfo::processReplicaUsingStMatrix( + ConversionPatternRewriter &rewriter, Location loc, Value smemBase, + SmallVector &vals, RankedTensorType srcTy, Type elemTy, + ArrayRef paddedRepShape, ArrayRef origRepShape, + ArrayRef outOrd, unsigned accumNumReplicates, + int swizzleByteWidth) const { + return false; +} + +std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { + std::string funcName = + resultElementTy.isInteger(32) ? "__mt_umulhi" : "__mt_umul64hi"; + return funcName; +} + +void TargetInfo::printf(ConversionPatternRewriter &rewriter, + Value formatStrStart, int /*formatStrByteCount*/, + ValueRange args) const { + return; +} + +void TargetInfo::assertFail(ConversionPatternRewriter &rewriter, Location loc, + StringRef message, StringRef file, StringRef func, + int line) const { + return; +} + +} // namespace mlir::triton::MUSA diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/TargetInfo.h b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/TargetInfo.h new file mode 100644 index 000000000..8020a17b9 --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/TargetInfo.h @@ -0,0 +1,63 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOMUSA_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOMUSA_H + +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" + +namespace mlir::triton::MUSA { + +class TargetInfo : public mlir::triton::TargetInfoBase { +public: + TargetInfo(int computeCapability) : computeCapability(computeCapability) {} + + bool supportMaximumMinimum() const override; + + Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; + + Value ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, + Value cmp) const override; + + void storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, + Value val, Value pred) const override; + Value loadShared(ConversionPatternRewriter &rewriter, Location loc, + const TypeConverter *converter, Value ptr, Type elemTy, + Value pred) const override; + + Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, Value val, + int i) const override; + Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc, Value val, + int i) const override; + Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + int i) const override; + Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value i) const override; + + Value programId(ConversionPatternRewriter &rewriter, Location loc, + ModuleOp moduleOp, int axis) const override; + + bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce) const override; + + bool processReplicaUsingStMatrix( + ConversionPatternRewriter &rewriter, Location loc, Value smemBase, + SmallVector &vals, RankedTensorType srcTy, Type elemTy, + ArrayRef paddedRepShape, ArrayRef origRepShape, + ArrayRef outOrd, unsigned accumNumReplicates, + int swizzleByteWidth) const override; + + std::string getMulhiFuncName(Type resultElementTy) const override; + + void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args) const override; + + void assertFail(ConversionPatternRewriter &rewriter, Location loc, + StringRef message, StringRef file, StringRef func, + int line) const override; + +private: + int computeCapability; +}; + +} // namespace mlir::triton::MUSA + +#endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOMUSA_H diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/TritonGPUToLLVM.cpp new file mode 100644 index 000000000..1df2320aa --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/TritonGPUToLLVM.cpp @@ -0,0 +1,216 @@ +#include "TritonMTGPUToLLVM/Passes.h" +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/GPUToMTGPU/GPUToMTGPUPass.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/MTGPUDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "PatternTritonGPUOpToLLVM.h" +#include "Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTTRITONMTGPUTOLLVM +#include "TritonMTGPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton::MUSA; + +namespace { + +// pass ws related named attrs. +static void addAttrs(Operation *op, ArrayRef attrs) { + for (const NamedAttribute attr : attrs) + op->setAttr(attr.getName(), attr.getValue()); +} + +class TritonLLVMFunctionConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + } +}; + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addLegalOp(); + } +}; + +struct ConvertTritonMTGPUToLLVM + : public triton::impl::ConvertTritonMTGPUToLLVMBase< + ConvertTritonMTGPUToLLVM> { + using ConvertTritonMTGPUToLLVMBase::ConvertTritonMTGPUToLLVMBase; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + ConvertTritonMTGPUToLLVM(int32_t computeCapability) + : ConvertTritonMTGPUToLLVMBase({computeCapability}) {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + option.overrideIndexBitwidth(32); + TritonGPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + + // Allocate shared memory and set barrier + ModuleAllocation allocation(mod); + ModuleMembarAnalysis membarPass(&allocation); + membarPass.run(); + + // Lower functions + { + mlir::LowerToLLVMOptions option(context); + TritonGPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMFunctionConversionTarget funcTarget(*context); + RewritePatternSet funcPatterns(context); + mlir::triton::MUSA::populateFuncOpConversionPattern( + typeConverter, funcPatterns, numWarps, patternBenefitDefault); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + funcPatterns); + if (failed( + applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) + return signalPassFailure(); + } + + // initSharedMemory is run before the conversion of call and ret ops, + // because the call op has to know the shared memory base address of each + // function + initSharedMemory(typeConverter); + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + OpBuilder::InsertPoint indexInsertPoint; + + RewritePatternSet patterns(context); + TargetInfo targetInfo(computeCapability); + int benefit = patternBenefitPrioritizeOverLLVMConversions; + mlir::triton::MUSA::populateConvertLayoutOpToLLVMPatterns( + typeConverter, targetInfo, patterns, benefit); + populateDotOpToLLVMPatterns(typeConverter, patterns, benefit); + populateElementwiseOpToLLVMPatterns(typeConverter, patterns, + axisInfoAnalysis, computeCapability, + targetInfo, benefit); + populateClampFOpToLLVMPattern(typeConverter, patterns, axisInfoAnalysis, + computeCapability, + patternBenefitClampOptimizedPattern); + populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, patterns, + axisInfoAnalysis, benefit); + mlir::triton::populateReduceOpToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populateScanOpToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populateHistogramOpToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, + benefit); + mlir::triton::MUSA::populateSPMDOpToLLVMPattern(typeConverter, patterns, + benefit); + mlir::triton::populateSPMDOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + // TODO(thomas): this should probably be done in a separate step to not + // interfere with our own lowering of arith ops. Add arith/math's patterns + // to help convert scalar expression to LLVM. + mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); + mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns); + + // Native lowering patterns. + mlir::populateGpuToMTGPUConversionPatterns(typeConverter, patterns); + + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + patterns); + mlir::triton::populateViewOpToLLVMPatterns(typeConverter, patterns, + benefit); + mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo, + patterns, benefit); + mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, + patterns, benefit); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } + +private: + void initSharedMemory(LLVMTypeConverter &typeConverter) { + ModuleOp mod = getOperation(); + OpBuilder b(mod.getBodyRegion()); + auto ctx = mod.getContext(); + auto loc = mod.getLoc(); + auto elemTy = typeConverter.convertType(b.getIntegerType(8)); + // Set array size 0 and external linkage indicates that we use dynamic + // shared allocation to allow a larger shared memory size for each kernel. + // + // Ask for 16B alignment on global_smem because that's the largest we should + // ever need (4xi32). + auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0); + auto global = b.create( + loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External, + "global_smem", /*value=*/Attribute(), /*alignment=*/16, + static_cast( + mlir::MTGPU::MTGPUMemorySpace::kSharedMemorySpace)); + } + + static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, + Type promotedType) { + Type tensorPromotedType = cast(operand.getType()) + .cloneWith(std::nullopt, promotedType); + return builder.create(loc, tensorPromotedType, operand); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { + +std::unique_ptr> createConvertTritonMTGPUToLLVMPass() { + return std::make_unique(); +} +std::unique_ptr> +createConvertTritonMTGPUToLLVMPass(int32_t computeCapability) { + return std::make_unique(computeCapability); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/Utility.cpp b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/Utility.cpp new file mode 100644 index 000000000..ac33935b2 --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/Utility.cpp @@ -0,0 +1,162 @@ +#include "Utility.h" +#include "PatternTritonGPUOpToLLVM.h" +#include "mlir/Dialect/LLVMIR/MTGPUDialect.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" + +using mlir::triton::gpu::appendOrGetExternFuncOp; +using mlir::triton::gpu::getFunctionType; + +namespace { +std::string getTypeString(Type ty) { + std::string str; + llvm::raw_string_ostream rso(str); + ty.print(rso); + rso.flush(); + return str; +} + +std::string mangleFunc(std::string name, Type type) { + auto funcType = dyn_cast(type); + assert(funcType && "Expecting an LLVMFunctionType"); + std::string mangled = name + "_"; + auto retTy = funcType.getReturnType(); + mangled += getTypeString(retTy) + "_"; + auto params = funcType.getParams(); + for (auto paramType : params) { + mangled += getTypeString(paramType) + "_"; + } + return mangled; +} +} // anonymous namespace + +namespace mlir { +namespace LLVM { +namespace MUSA { + +static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value value, + Value i, const MTGPU::ShflKind &mode, int widthInt) { + auto valueTy = value.getType(); + unsigned bits = valueTy.getIntOrFloatBitWidth(); + + auto int8Ty = rewriter.getI8Type(); + auto int32Ty = rewriter.getI32Type(); + auto nullPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 5); + Value zero = rewriter.create(loc, int32Ty, + rewriter.getI32IntegerAttr(0)); + Value one = rewriter.create(loc, int32Ty, + rewriter.getI32IntegerAttr(1)); + Value seven = rewriter.create( + loc, int32Ty, rewriter.getI32IntegerAttr(7)); + Value num_128 = rewriter.create( + loc, int32Ty, rewriter.getI32IntegerAttr(128)); + Value width = rewriter.create( + loc, int32Ty, rewriter.getI32IntegerAttr(widthInt)); + Value nullPtr = rewriter.create(loc, nullPtrTy); + Value offset = i; + + Value maskAndClamp; + + if (bits == 64) { + Type vecTy = vec_ty(f32_ty, 2); + Value vec = bitcast(value, vecTy); + Value val0 = extract_element(f32_ty, vec, i32_val(0)); + Value val1 = extract_element(f32_ty, vec, i32_val(1)); + val0 = shuffleCommon(loc, rewriter, val0, i, mode, widthInt); + val1 = shuffleCommon(loc, rewriter, val1, i, mode, widthInt); + vec = undef(vecTy); + vec = insert_element(vecTy, vec, val0, i32_val(0)); + vec = insert_element(vecTy, vec, val1, i32_val(1)); + return bitcast(vec, value.getType()); + } + if (valueTy != i32_ty) { + value = bitcast(value, int_ty(bits)); + if (bits < 32) + value = zext(i32_ty, value); + } + + // maskAndClamp is set to 0 when in 'up' mode. + if (mode == MTGPU::ShflKind::up) { + maskAndClamp = zero; + } else { + Value Clamp = rewriter.create(loc, int32Ty, width, one); + Value SegMask = rewriter.create(loc, int32Ty, num_128, width); + SegMask = rewriter.create(loc, int32Ty, SegMask, seven); + maskAndClamp = rewriter.create(loc, int32Ty, SegMask, Clamp); + } + + // shuffle argument pred is default nullptr if not given. + Value result = rewriter.create(loc, int32Ty, value, offset, + maskAndClamp, mode, nullPtr); + + if (valueTy != i32_ty) { + if (bits < 32) + result = trunc(int_ty(bits), result); + result = bitcast(result, valueTy); + } + + return result; +} + +Value MTGPU_shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, + unsigned width) { + return shuffleCommon(loc, rewriter, val, i32_val(i), MTGPU::ShflKind::bfly, + width); +} + +Value MTGPU_shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, + unsigned width) { + return shuffleCommon(loc, rewriter, val, i32_val(i), MTGPU::ShflKind::up, + width); +} + +Value MTGPU_shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, + unsigned width) { + return shuffleCommon(loc, rewriter, val, i32_val(i), MTGPU::ShflKind::idx, + width); +} + +Value MTGPU_shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, + unsigned width) { + return shuffleCommon(loc, rewriter, val, i, MTGPU::ShflKind::idx, width); +} + +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis) { + assert(axis >= 0); + assert(axis < 3); + assert(moduleOp); + static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x, + mlir::gpu::Dimension::y, + mlir::gpu::Dimension::z}; + Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[axis]); + return rewriter.create(loc, i32_ty, blockId); +} + +Value llLoad(ConversionPatternRewriter &rewriter, Location loc, Value ptr, + Type elemTy, Value pred, Value falseVal) { + Type funcType = getFunctionType(elemTy, ValueRange({ptr, pred, falseVal})); + auto parent = ptr.getParentRegion()->getParentOfType(); + auto funcName = mangleFunc(mlir::LLVM::MUSA::Predicated_Load, funcType); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, parent, funcName, funcType); + auto loadVal = + rewriter + .create(loc, funcOp, ValueRange({ptr, pred, falseVal})) + .getResult(); + return loadVal; +} + +void llStore(ConversionPatternRewriter &rewriter, Location loc, Value ptr, + Value val, Value pred) { + auto ctx = ptr.getContext(); + Type funcType = getFunctionType(void_ty(ctx), ValueRange({ptr, val, pred})); + auto parent = ptr.getParentRegion()->getParentOfType(); + auto funcName = mangleFunc(mlir::LLVM::MUSA::Predicated_Store, funcType); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, parent, funcName, funcType); + rewriter.create(loc, funcOp, ValueRange({ptr, val, pred})); +} + +} // namespace MUSA +} // namespace LLVM +} // namespace mlir diff --git a/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/Utility.h b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/Utility.h new file mode 100644 index 000000000..7e989f5ae --- /dev/null +++ b/third_party/mthreads/plugin/lib/TritonMTGPUToLLVM/Utility.h @@ -0,0 +1,49 @@ +#ifndef TRITON_CONVERSION_TRITONMTGPU_TO_LLVM_UTILITY_H +#define TRITON_CONVERSION_TRITONMTGPU_TO_LLVM_UTILITY_H + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" + +#define DEBUG_TYPE "ttgpu_to_llvm" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace LLVM { + +namespace MUSA { +const char Predicated_Load[] = "__predicated_load"; +const char Predicated_Store[] = "__predicated_store"; + +// Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr); +Value MTGPU_shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i, + unsigned width); +Value MTGPU_shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i, + unsigned width); +Value MTGPU_shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i, + unsigned width); +Value MTGPU_shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i, + unsigned width); + +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis); + +// Loads from shared or global memory with predication. +// `otherElems` is used to mask out the elements that are not loaded +Value llLoad(ConversionPatternRewriter &rewriter, Location loc, Value ptr, + Type elemTy, Value pred, Value falseVal); + +// Stores to shared or global memory with predication. +void llStore(ConversionPatternRewriter &rewriter, Location loc, Value ptr, + Value val, Value pred); +} // namespace MUSA +} // namespace LLVM + +} // namespace mlir + +#endif diff --git a/third_party/mthreads/plugin/triton_mthreads.cc b/third_party/mthreads/plugin/triton_mthreads.cc new file mode 100644 index 000000000..3dbc7971f --- /dev/null +++ b/third_party/mthreads/plugin/triton_mthreads.cc @@ -0,0 +1,93 @@ +#include "TritonMTGPUToLLVM/MUSATranslation.h" +#include "TritonMTGPUToLLVM/Passes.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Dialect/MTGPU/MTGPUToLLVMIRTranslation.h" +#include "passes.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Module.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Support/SourceMgr.h" +#include +#include +#include + +#ifdef _WIN32 +#define PLUGIN_EXPORT __declspec(dllexport) +#else +#define PLUGIN_EXPORT __attribute__((visibility("default"))) +#endif + +namespace py = pybind11; + +using namespace mlir; + +PLUGIN_EXPORT void init_triton_mthreads_passes_ttgpuir(py::module &&m) { + using namespace mlir::triton; + + // ttgir -> llvm dialect + m.def("add_to_llvmir", [](mlir::PassManager &pm, int32_t capability) { + pm.addPass(mlir::triton::createConvertTritonMTGPUToLLVMPass(capability)); + }); + m.def("add_mtgpu_builtin_func_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createConvertMTGPUBuiltinFuncToLLVMPass()); + }); +} + +PLUGIN_EXPORT void init_triton_mthreads(py::module &&m) { + using ret = py::return_value_policy; + + auto passes = m.def_submodule("passes"); + init_triton_mthreads_passes_ttgpuir(passes.def_submodule("ttgpuir")); + + // load dialects + m.def("load_dialects", [](mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + mlir::registerMTGPUDialectTranslation(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + m.def( + "translate_llvmir_to_mubin", + [](const std::string llvmIR, const std::string opt_option, int capability, + int version) -> std::tuple { + // create LLVM module from C++ + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + // translate module to mubin + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + auto mubinCode = triton::translateLLVMIRToMUBIN(*module, opt_option, + capability, version); + return mubinCode; + }, + ret::take_ownership); + m.def("attach_datalayout", [](llvm::Module &module) { + const std::string dataLayout = "e-p:64:64:64:64-" + "p1:64:64:64:64-" + "p2:64:64:64:64-" + "p3:32:32-" + "p4:32:32-" + "p5:64:64-" + "i64:64-" + "v16:16-" + "v24:32-" + "v32:32-" + "v48:64-" + "v96:128"; + module.setDataLayout(dataLayout); + }); +} + +extern "C" { +PLUGIN_EXPORT void registerMthreadsPasses() { + mlir::triton::registerConvertTritonMTGPUToLLVM(); + mlir::triton::registerConvertMTGPUBuiltinFuncToLLVM(); +} +} diff --git a/third_party/mthreads/python/src/interpreter.cc b/third_party/mthreads/python/src/interpreter.cc new file mode 100644 index 000000000..6ab7c6c75 --- /dev/null +++ b/third_party/mthreads/python/src/interpreter.cc @@ -0,0 +1,435 @@ +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +namespace { + +enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED }; + +enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX }; + +std::map mem_semantic_map = { + {MemSemantic::ACQUIRE_RELEASE, __ATOMIC_ACQ_REL}, + {MemSemantic::ACQUIRE, __ATOMIC_ACQUIRE}, + {MemSemantic::RELEASE, __ATOMIC_RELEASE}, + {MemSemantic::RELAXED, __ATOMIC_RELAXED}, +}; + +// Use compiler builtin atomics instead of std::atomic which requires +// each variable to be declared as atomic. +// Currently work for clang and gcc. +template T atomic_cmp(T *ptr, T val, int order) { + auto cmp = [](T old, T val) { + if constexpr (is_min) { + return old > val; + } else { + return old < val; + } + }; + // First load + T old_val = __atomic_load_n(ptr, order); + while (cmp(old_val, val)) { + if (__atomic_compare_exchange(ptr, &old_val, &val, false, order, order)) { + break; + } + } + return old_val; +} + +template T atomic_fadd(T *ptr, T val, int order) { + T old_val; + T new_val; + // First load + // Load ptr as if uint32_t or uint64_t and then memcpy to T + if constexpr (sizeof(T) == 4) { + uint32_t tmp = __atomic_load_n(reinterpret_cast(ptr), order); + std::memcpy(&old_val, &tmp, sizeof(T)); + } else if constexpr (sizeof(T) == 8) { + uint64_t tmp = __atomic_load_n(reinterpret_cast(ptr), order); + std::memcpy(&old_val, &tmp, sizeof(T)); + } else { + throw std::invalid_argument("Unsupported data type"); + } + while (true) { + new_val = old_val + val; + if (__atomic_compare_exchange(ptr, &old_val, &new_val, false, order, + order)) { + break; + } + } + return old_val; +} + +class AtomicOp { +public: + AtomicOp(const uint64_t *ptr, size_t numel, int order) + : ptr(ptr), numel(numel), order(order) {} + + void apply() { + for (size_t i = 0; i < numel; ++i) { + applyAt(reinterpret_cast(ptr[i]), i); + } + } + + virtual ~AtomicOp() = default; + +protected: + virtual void applyAt(void *, size_t i) = 0; + + const uint64_t *ptr; + size_t numel; + int order; +}; + +template class AtomicRMWOpBase : public AtomicOp { +public: + AtomicRMWOpBase(const uint64_t *ptr, const void *val, void *ret, + const bool *mask, size_t numel, int order) + : AtomicOp(ptr, numel, order), val(val), ret(ret), mask(mask) {} + +protected: + void applyAt(void *loc, size_t i) override final { + if (mask[i]) { + *(static_cast(ret) + i) = + applyAtMasked(static_cast(loc), + *(static_cast(val) + i), order); + } + } + + virtual DType applyAtMasked(DType *loc, const DType value, int order) = 0; + + const void *val; + void *ret; + const bool *mask; +}; + +template +class AtomicRMWOp : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_add(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return atomic_fadd(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_and(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_or(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_xor(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return atomic_cmp(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return atomic_cmp(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_exchange_n(loc, value, order); + } +}; + +class AtomicCASOp : public AtomicOp { +public: + AtomicCASOp(const uint64_t *ptr, void *expected, const void *desired, + size_t itemsize, size_t numel, int order) + : AtomicOp(ptr, numel, order), expected(expected), desired(desired), + itemsize(itemsize) {} + +protected: + void applyAt(void *loc, size_t i) override { + // Atomic operations perform bitwise comparison, so it's safe to + // use number of bytes (itemsize) to determine the type of pointers + if (itemsize == 1) { + uint8_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else if (itemsize == 2) { + uint16_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else if (itemsize == 4) { + uint32_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else if (itemsize == 8) { + uint64_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else { + // The ‘__atomic’ builtins can be used with any integral scalar or pointer + // type that is 1, 2, 4, or 8 bytes in length. 16-byte integral types are + // also allowed if ‘__int128’ (see 128-bit Integers) is supported by the + // architecture. + // https://gcc.gnu.org/onlinedocs/gcc/_005f_005fatomic-Builtins.html + throw std::invalid_argument("Invalid byte size"); + } + } + +private: + void *expected; + const void *desired; + size_t itemsize; +}; + +// This is a workaround because explicit template parameter list for lambdas is +// a C++20 extension: +// auto try_make_op = [&]() { +// if (dtype.is(pybind11::dtype::of())) { +// atomic_op = std::make_unique>(ptr, val, ret, mask, +// numel, order); +// } +// }; +template struct OpCreator { + pybind11::dtype dtype; + const uint64_t *ptr; + const void *val; + void *ret; + const bool *mask; + size_t numel; + int order; + std::unique_ptr &atomic_op; + + template void create() { + if (!atomic_op && dtype.is(pybind11::dtype::of())) { + atomic_op = std::make_unique>(ptr, val, ret, mask, + numel, order); + } + } +}; + +template +std::unique_ptr +makeAtomicRMWOp(pybind11::dtype dtype, const uint64_t *ptr, const void *val, + void *ret, const bool *mask, size_t numel, int order) { + // Iterate over all supported data types, make one that matches, and return + std::unique_ptr atomic_op; + OpCreator try_make_op{dtype, ptr, val, ret, + mask, numel, order, atomic_op}; + + (try_make_op.template create(), ...); + if (!atomic_op) { + throw std::invalid_argument("Unsupported data type"); + } + // Make it a unique_ptr + return atomic_op; +} + +} // namespace + +void init_triton_interpreter(py::module &&m) { + using ret = py::return_value_policy; + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "RMW_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX) + .export_values(); + + m.def("load", + [](py::array_t ptr, py::array_t mask, py::array other, + py::dtype ret_dtype) -> py::array { + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_others = other.reshape({numel}); + for (size_t i = 0; i < ptr.size(); ++i) { + if (reshaped_mask.at(i)) + memcpy(ret.mutable_data(i), + reinterpret_cast(reshaped_ptr.at(i)), + ret_dtype.itemsize()); + else + memcpy(ret.mutable_data(i), reshaped_others.data(i), + ret_dtype.itemsize()); + } + return ret.reshape(shape); + }); + + m.def("store", + [](py::array_t ptr, py::array value, py::array_t mask) { + int numel = ptr.size(); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_value = value.reshape({numel}); + for (size_t i = 0; i < ptr.size(); ++i) { + if (reshaped_mask.at(i)) { + memcpy(reinterpret_cast(reshaped_ptr.mutable_at(i)), + reshaped_value.data(i), value.dtype().itemsize()); + } + } + }); + + m.def("atomic_rmw", + [](RMWOp rmw_op, py::array_t ptr, py::array val, + py::array_t mask, MemSemantic sem) -> py::array { + int order = mem_semantic_map[sem]; + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + auto ret_dtype = val.dtype(); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_val = val.reshape({numel}); + auto *ptr_data = reshaped_ptr.data(); + auto *mask_data = reshaped_mask.data(); + auto *val_data = static_cast(reshaped_val.data()); + auto *ret_data = static_cast(ret.mutable_data()); + + std::unique_ptr atomic_op; + +#define MAKE_ATOMIC_RMW_OP(OP_NAME, ...) \ + case OP_NAME: \ + atomic_op = makeAtomicRMWOp( \ + ret_dtype, ptr_data, val_data, ret_data, mask_data, numel, order); \ + break; + + switch (rmw_op) { + MAKE_ATOMIC_RMW_OP(RMWOp::ADD, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::FADD, float, double) + MAKE_ATOMIC_RMW_OP(RMWOp::AND, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::OR, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::XOR, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::MAX, int32_t, int64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::UMAX, uint32_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::MIN, int32_t, int64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::UMIN, uint32_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::XCHG, int32_t, uint32_t, int64_t, + uint64_t) + default: + throw std::invalid_argument("Unsupported RMW operation"); + } + +#undef MAKE_ATOMIC_RMW_OP + + atomic_op->apply(); + return ret.reshape(shape); + }); + + m.def("atomic_cas", + [](py::array_t ptr, py::array &cmp, py::array &val, + MemSemantic sem) -> py::array { + int order = mem_semantic_map[sem]; + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + auto ret_dtype = cmp.dtype(); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array reshaped_cmp = cmp.reshape({numel}); + py::array reshaped_val = val.reshape({numel}); + auto itemsize = cmp.itemsize(); + memcpy(static_cast(ret.mutable_data()), + static_cast(reshaped_cmp.data()), + itemsize * numel); + AtomicCASOp(reshaped_ptr.data(), ret.mutable_data(), + static_cast(reshaped_val.data()), itemsize, + numel, order) + .apply(); + return ret.reshape(shape); + }); +} diff --git a/third_party/mthreads/python/src/ir.cc b/third_party/mthreads/python/src/ir.cc new file mode 100644 index 000000000..0befdc491 --- /dev/null +++ b/third_party/mthreads/python/src/ir.cc @@ -0,0 +1,1646 @@ +#include +#include +#include + +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Transforms/LocationSnapshot.h" +#include "mlir/Transforms/Passes.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace { + +namespace py = pybind11; +using namespace mlir; +using namespace triton; + +// A custom op builder that keeps track of the last location +class TritonOpBuilder { +public: + TritonOpBuilder(MLIRContext *context) { + builder = std::make_unique(context); + lastLoc = std::make_unique(builder->getUnknownLoc()); + } + + OpBuilder &getBuilder() { return *builder; } + + bool isLineInfoEnabled() { return lineInfoEnabled; } + + void setLastLoc(Location loc) { + if (lineInfoEnabled) + lastLoc = std::make_unique(loc); + } + + void setLastLoc(const std::string &fileName, int line, int column) { + auto context = builder->getContext(); + setLastLoc(FileLineColLoc::get(context, fileName, line, column)); + } + + Location getLastLoc() { + assert(lastLoc); + return *lastLoc; + } + + void setInsertionPointToStart(Block &block) { + if (!block.empty()) + setLastLoc(block.begin()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToStart(&block); + } + + void setInsertionPointToEnd(Block &block) { + if (!block.empty()) + setLastLoc(block.back().getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToEnd(&block); + } + + void setInsertionPointAfter(Operation &op) { + setLastLoc(op.getLoc()); + builder->setInsertionPointAfter(&op); + } + + void restoreInsertionPoint(OpBuilder::InsertPoint pt) { + if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) + setLastLoc(pt.getPoint()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->restoreInsertionPoint(pt); + } + + template OpTy create(Args &&...args) { + auto loc = getLastLoc(); + return builder->create(loc, std::forward(args)...); + } + + // Overload to create or fold a single result operation. + template + std::enable_if_t(), Value> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + + // Overload to create or fold a zero result operation. + template + std::enable_if_t(), OpTy> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + +private: + std::unique_ptr builder; + std::unique_ptr lastLoc; + bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); +}; + +std::string locationToString(Location loc) { + std::string str; + llvm::raw_string_ostream os(str); + loc.print(os); + os.flush(); // Make sure all the content is dumped into the 'str' string + return str; +} + +void outputWarning(Location loc, const std::string &msg) { + std::string locStr = locationToString(loc); + + PyErr_WarnEx(PyExc_UserWarning, (locStr + ": " + msg).c_str(), + /*stack_level=*/2); +} + +} // anonymous namespace + +/*****************************************************************************/ +/* Python bindings for ir */ +/*****************************************************************************/ + +void init_triton_ir(py::module &&m) { + using ret = py::return_value_policy; + using namespace pybind11::literals; + + py::enum_(m, "PADDING_OPTION", py::module_local()) + .value("PAD_ZERO", PaddingOption::PAD_ZERO) + .value("PAD_NAN", PaddingOption::PAD_NAN) + .export_values(); + + py::enum_(m, "CACHE_MODIFIER", py::module_local()) + .value("NONE", CacheModifier::NONE) + .value("CA", CacheModifier::CA) + .value("CG", CacheModifier::CG) + .value("WB", CacheModifier::WB) + .value("CS", CacheModifier::CS) + .value("WT", CacheModifier::WT) + .export_values(); + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "MEM_SYNC_SCOPE", py::module_local()) + .value("GPU", MemSyncScope::GPU) + .value("CTA", MemSyncScope::CTA) + .value("SYSTEM", MemSyncScope::SYSTEM) + .export_values(); + + py::enum_(m, "EVICTION_POLICY", py::module_local()) + .value("NORMAL", EvictionPolicy::NORMAL) + .value("EVICT_FIRST", EvictionPolicy::EVICT_FIRST) + .value("EVICT_LAST", EvictionPolicy::EVICT_LAST) + .export_values(); + + py::enum_(m, "ATOMIC_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX); + + py::enum_(m, "ROUNDING_MODE", py::module_local()) + .value("RTZ", RoundingMode::RTZ) + .value("RTNE", RoundingMode::RTNE); + + py::enum_(m, "PROPAGATE_NAN", py::module_local()) + .value("NONE", PropagateNan::NONE) + .value("ALL", PropagateNan::ALL); + + py::enum_(m, "INPUT_PRECISION", py::module_local()) + .value("TF32", InputPrecision::TF32) + .value("TF32x3", InputPrecision::TF32x3) + .value("IEEE", InputPrecision::IEEE) + .export_values(); + + py::class_(m, "context", py::module_local()).def(py::init<>()); + + m.def("load_dialects", [](MLIRContext &context) { + DialectRegistry registry; + registry.insert(); + registerBuiltinDialectTranslation(registry); + registerLLVMDialectTranslation(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_(m, "type", py::module_local()) + .def("is_integer", + [](Type &self, unsigned width) { return self.isInteger(width); }) + .def("is_fp16", &Type::isF16) + .def("__str__", [](Type &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "function_type", py::module_local()) + .def("param_types", [](FunctionType &self) { + return std::vector(self.getInputs().begin(), + self.getInputs().end()); + }); + + py::class_(m, "location", py::module_local()) + .def("__str__", [](Location &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "value", py::module_local()) + .def("set_attr", + [](Value &self, std::string &name, Attribute &attr) -> void { + if (Operation *definingOp = self.getDefiningOp()) + definingOp->setAttr(name, attr); + else { + auto arg = mlir::cast(self); + int id = arg.getArgNumber(); + std::string attrName = name + "_arg" + std::to_string(id); + Block *owner = arg.getOwner(); + if (owner->isEntryBlock() && + !isa(owner->getParentOp())) { + owner->getParentOp()->setAttr(attrName, attr); + } + } + }) + .def("get_context", &Value::getContext) + .def("replace_all_uses_with", + [](Value &self, Value &newValue) { + self.replaceAllUsesWith(newValue); + }) + .def("get_type", &Value::getType) + .def("id", [](Value &self) { + // The Value is identified by and compared with + // other Values via the underlying ValueImpl + return (uint64_t)self.getImpl(); + }); + + py::class_(m, "op_result", py::module_local()); + + py::class_(m, "block_argument", py::module_local()); + + py::class_(m, "region", py::module_local()) + .def("get_parent_region", &Region::getParentRegion, ret::reference) + .def("size", [](Region &self) { return self.getBlocks().size(); }) + .def("empty", &Region::empty) + .def("id", [](Region &self) { return (uint64_t)&self; }); + + py::class_(m, "block", py::module_local()) + .def("arg", + [](Block &self, int index) -> BlockArgument { + if (index >= self.getNumArguments()) + throw pybind11::index_error("Block argument index out of range"); + return self.getArgument(index); + }) + .def("add_argument", + [](Block &self, Type ty) { + auto loc = UnknownLoc::get(ty.getContext()); + self.addArgument(ty, loc); + }) + .def("get_num_arguments", &Block::getNumArguments) + .def("get_argument", &Block::getArgument) + .def("dump", &Block::dump) + .def("move_before", + [](Block &self, Block &dst) { self.moveBefore(&dst); }) + .def("insert_before", &Block::insertBefore) + .def("get_parent", &Block::getParent, ret::reference) + .def("merge_block_before", + [](Block &self, Block &dst) { + // ref: RewriterBase::mergeBlocks() + if (self.getNumArguments() != 0) + throw std::runtime_error( + "This block has arguments, don't merge"); + dst.getOperations().splice(dst.begin(), self.getOperations()); + self.dropAllUses(); + self.erase(); + }) + .def("replace_use_in_block_with", + [](Block &self, Value &v, Value &newVal) { + v.replaceUsesWithIf(newVal, [&](OpOperand &operand) { + Operation *user = operand.getOwner(); + Block *currentBlock = user->getBlock(); + while (currentBlock) { + if (currentBlock == &self) + return true; + // Move up one level + currentBlock = + currentBlock->getParent()->getParentOp()->getBlock(); + } + return false; + }); + }) + .def("__str__", + [](Block &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return str; + }) + .def("has_terminator", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("has_return", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("erase", [](Block &self) { self.erase(); }) + .def("id", [](Block &self) { return (uint64_t)&self; }); + + py::class_(m, "attribute", py::module_local()); + py::class_(m, "integer_attr", py::module_local()); + py::class_(m, "bool_attr", py::module_local()); + + // Ops + py::class_(m, "OpState", py::module_local()) + .def("set_attr", + [](OpState &self, std::string &name, Attribute &attr) -> void { + self->setAttr(name, attr); + }) + .def("get_num_results", + [](OpState &self) -> unsigned { return self->getNumResults(); }) + .def("get_result", + [](OpState &self, unsigned idx) -> Value { + if (idx >= self->getNumResults()) + throw pybind11::index_error("Op result index out of range"); + return self->getResult(idx); + }) + .def( + "get_region", + [](OpState &self, unsigned idx) -> Region & { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self->getRegion(idx); + }, + ret::reference) + .def( + "get_body", + [](scf::ForOp &self, unsigned idx) -> Block * { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self.getBody(idx); + }, + ret::reference) + .def("dump", [](OpState &self) { self->dump(); }) + .def("__str__", + [](OpState &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + self->print(os, printingFlags); + return str; + }) + .def("append_operand", + [](OpState &self, Value &val) { + self->insertOperands(self->getNumOperands(), val); + }) + .def("verify", [](OpState &self) -> bool { + return succeeded(verify(self.getOperation())); + }); + // scf Ops + py::class_(m, "ForOp", py::module_local()) + .def("get_induction_var", &scf::ForOp::getInductionVar); + + py::class_(m, "IfOp", py::module_local()) + .def("get_then_block", &scf::IfOp::thenBlock, ret::reference) + .def("get_else_block", &scf::IfOp::elseBlock, ret::reference) + .def("get_then_yield", &scf::IfOp::thenYield) + .def("get_else_yield", &scf::IfOp::elseYield); + py::class_(m, "YieldOp", py::module_local()); + py::class_(m, "WhileOp", py::module_local()) + .def("get_before", &scf::WhileOp::getBefore, ret::reference) + .def("get_after", &scf::WhileOp::getAfter, ret::reference); + py::class_(m, "ConditionOp", py::module_local()); + + py::class_>( + m, "operation", py::module_local()) + .def("get_name", + [](Operation &self) { + llvm::StringRef opName = self.getName().getStringRef(); + return opName.str(); + }) + .def("get_num_operands", &Operation::getNumOperands) + .def("get_operand", &Operation::getOperand) + .def("get_num_results", &Operation::getNumResults) + .def("get_result", &Operation::getResult) + .def("get_num_regions", &Operation::getNumRegions) + .def("get_region", &Operation::getRegion, ret::reference) + .def("get_block", &Operation::getBlock, ret::reference) + .def("get_str_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }) + .def("get_flat_symbol_ref_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }); + + // dynamic_attr is used to transfer ownership of the MLIR context to the + // module + py::class_(m, "module", py::module_local(), + py::dynamic_attr()) + .def("dump", &ModuleOp::dump) + .def("str", + [](ModuleOp &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + self.print(os, printingFlags); + return str; + }) + .def("push_back", + [](ModuleOp &self, FuncOp &funcOp) -> void { + self.push_back(funcOp); + }) + .def("has_function", + [](ModuleOp &self, std::string &funcName) -> bool { + if (self.lookupSymbol(funcName)) + return true; + return false; + }) + .def("get_function", + [](ModuleOp &self, std::string &funcName) -> FuncOp { + return self.lookupSymbol(funcName); + }) + .def("get_int_attr", + [](ModuleOp &self, std::string name) -> py::object { + auto ret = self->getAttrOfType(name); + if (!ret) + return py::none(); + return py::int_(ret.getInt()); + }) + .def("create_location_snapshot", + [](ModuleOp &self, const std::string &fileName) -> void { + generateLocationsFromIR(/*raw_ostream=*/llvm::nulls(), + /*fileName=*/fileName, + /*op=*/self, /*flags=*/{}); + }) + .def("walk", + [](ModuleOp &self, const std::function &fn) { + self.walk(fn); + }); + + m.def("make_attr", [](const std::vector &values, MLIRContext &context) { + return mlir::cast(DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(values.size())}, + IntegerType::get(&context, 32)), + values)); + }); + + m.def( + "parse_mlir_module", + [](const std::string &inputFilename, MLIRContext &context) { + // parse module + OwningOpRef module = + parseSourceFile(inputFilename, &context); + if (!module) + throw std::runtime_error("Parse MLIR file failed."); + return module->clone(); + }, + ret::take_ownership); + + py::class_(m, "function", py::module_local()) + // .def_property_readonly("attrs", &ir::function::attrs) + // .def("add_attr", &ir::function::add_attr); + .def("args", + [](FuncOp &self, unsigned idx) -> BlockArgument { + if (idx >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); + return self.getArgument(idx); + }) + .def( + "add_entry_block", + [](FuncOp &self) -> Block * { return self.addEntryBlock(); }, + ret::reference) + .def( + "set_arg_attr", + [](FuncOp &self, int arg_no, const std::string &name, int val) { + // set arg attributes "name" to value "val" + auto attrTy = IntegerType::get(self.getContext(), 32); + self.setArgAttr(arg_no, name, IntegerAttr::get(attrTy, val)); + }, + ret::reference) + // .def("has_attr", &::FuncOp::hasAttr) + .def("finalize", + [](FuncOp &self) -> void { + // Remove dead code + // 1. Unreachable code after return + self.walk([&](Block *block) { + Operation *retOp = nullptr; + // It's better to not use walk here because we only want to + // check operations in the current block + for (auto &op : block->getOperations()) { + if (isa(op)) + if (retOp == nullptr) { + retOp = &op; + break; + } + } + if (retOp && retOp != &block->back()) { + auto pos = retOp->getIterator(); + pos++; + auto *newBlock = block->splitBlock(pos); + newBlock->erase(); + } + }); + // 2. Check if the result of tl.advance is used + self.walk([&](Operation *op) { + if (isa(op) && op->getResult(0).use_empty()) + outputWarning(op->getLoc(), "The result of tl.advance is not " + "being used. Note that tl.advance " + "does not have any side effects. " + "To move the block pointer, you " + "need to assign the result of " + "tl.advance to a variable."); + }); + }) + .def_property_readonly("type", &FuncOp::getFunctionType) + .def("reset_type", &FuncOp::setType); + + py::class_(m, "InsertPoint", py::module_local()); + + py::class_(m, "builder", py::module_local(), + py::dynamic_attr()) + .def(py::init()) + // getters + .def("create_module", + [](TritonOpBuilder &self) -> ModuleOp { + return self.create(); + }) + // insertion block/point + .def("set_insertion_point_to_start", + [](TritonOpBuilder &self, Block &block) -> void { + self.setInsertionPointToStart(block); + }) + .def("set_insertion_point_to_end", + [](TritonOpBuilder &self, Block &block) { + self.setInsertionPointToEnd(block); + }) + .def("set_insertion_point_after", + [](TritonOpBuilder &self, Operation &op) { + self.setInsertionPointAfter(op); + }) + .def( + "get_insertion_block", + [](TritonOpBuilder &self) -> Block * { + return self.getBuilder().getInsertionBlock(); + }, + ret::reference) + .def("get_insertion_point", + [](TritonOpBuilder &self) { + return self.getBuilder().saveInsertionPoint(); + }) + .def("restore_insertion_point", + [](TritonOpBuilder &self, OpBuilder::InsertPoint pt) { + self.restoreInsertionPoint(pt); + }) + // Attr + .def("get_bool_attr", + [](TritonOpBuilder &self, bool value) { + return self.getBuilder().getBoolAttr(value); + }) + .def("get_int32_attr", + [](TritonOpBuilder &self, int32_t value) { + return self.getBuilder().getI32IntegerAttr(value); + }) + // Use arith.ConstantOp to create constants + // Constants + .def("get_int1", + [](TritonOpBuilder &self, bool v) -> Value { + return Value(self.create( + v, self.getBuilder().getI1Type())); + }) + .def("get_int8", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_int16", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_int32", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_int64", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_uint8", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_uint16", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_uint32", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_uint64", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_bf16", + [](TritonOpBuilder &self, float v) -> Value { + auto type = self.getBuilder().getBF16Type(); + return self.create( + APFloat(type.getFloatSemantics(), std::to_string(v)), type); + }) + .def("get_fp16", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF16FloatAttr(v)); + }) + .def("get_fp32", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF32FloatAttr(v)); + }) + .def("get_fp64", + [](TritonOpBuilder &self, double v) -> Value { + return self.create( + self.getBuilder().getF64FloatAttr(v)); + }) + .def("get_null_value", + [](TritonOpBuilder &self, Type type) -> Value { + if (auto floatTy = dyn_cast(type)) + return self.create( + APFloat(floatTy.getFloatSemantics(), 0), floatTy); + else if (auto intTy = dyn_cast(type)) + return self.create(0, intTy); + else + throw std::runtime_error("Not implemented"); + }) + .def("get_all_ones_value", + [](TritonOpBuilder &self, Type type) -> Value { + uint64_t val = 0xFFFFFFFFFFFFFFFF; + if (auto intTy = dyn_cast(type)) + return self.create(val, intTy); + else + throw std::runtime_error("Not implemented"); + }) + + // Types + .def("get_void_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getNoneType(); + }) + .def("get_int1_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI1Type(); + }) // or ret::copy? + .def("get_int8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_int16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(16); + }) + .def("get_int32_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI32Type(); + }) + .def("get_int64_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI64Type(); + }) + .def("get_fp8e4nv_ty", + // TODO: fp8e4nv is using Float8E4M3FNUZType, which + // does not seem right. It should use FloatE4M3FNType + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b15_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_fp8e5_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e5b16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_half_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF16Type(); + }) + .def("get_bf16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getBF16Type(); + }) + .def("get_float_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF32Type(); + }) + .def("get_double_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF64Type(); + }) + .def("get_ptr_ty", + [](TritonOpBuilder &self, Type &type, int addrSpace) -> Type { + return PointerType::get(type, addrSpace); + }) + .def("get_block_ty", + [](TritonOpBuilder &self, Type &elementType, + std::vector &shape) -> Type { + return RankedTensorType::get(shape, elementType); + }) + .def("get_function_ty", + [](TritonOpBuilder &self, std::vector inTypes, + std::vector outTypes) -> Type { + return self.getBuilder().getFunctionType(inTypes, outTypes); + }) + // locs + .def("set_loc", + [](TritonOpBuilder &self, Location loc) { self.setLastLoc(loc); }) + .def("set_loc", + [](TritonOpBuilder &self, const std::string &fileName, int line, + int column) { self.setLastLoc(fileName, line, column); }) + .def("get_loc", + [](TritonOpBuilder &self) -> Location { return self.getLastLoc(); }) + + // Ops + .def("get_or_insert_function", + [](TritonOpBuilder &self, ModuleOp &module, std::string &funcName, + Type &funcType, std::string &visibility, + bool noinline) -> FuncOp { + if (Operation *funcOperation = module.lookupSymbol(funcName)) + return llvm::dyn_cast(funcOperation); + if (auto funcTy = dyn_cast(funcType)) { + llvm::SmallVector attrs = { + NamedAttribute( + self.getBuilder().getStringAttr("sym_visibility"), + self.getBuilder().getStringAttr(visibility)), + NamedAttribute(self.getBuilder().getStringAttr("noinline"), + self.getBuilder().getBoolAttr(noinline))}; + return self.create(funcName, funcTy, attrs); + } + throw std::invalid_argument("invalid function type"); + }) + .def( + "create_block", + [](TritonOpBuilder &self) -> Block * { + Region *parent = self.getBuilder().getBlock()->getParent(); + return self.getBuilder().createBlock(parent); + }, + ret::reference) + .def( + "create_block_with_parent", + [](TritonOpBuilder &self, Region &parent, + std::vector &argTypes) -> Block * { + // TODO: update arg loc + auto loc = self.getBuilder().getUnknownLoc(); + llvm::SmallVector argLocs(argTypes.size(), loc); + return self.getBuilder().createBlock(&parent, {}, argTypes, + argLocs); + }, + ret::reference) + .def( + "new_block", + [](TritonOpBuilder &self) -> Block * { return new Block(); }, + ret::reference) + // Function + .def("ret", + [](TritonOpBuilder &self, std::vector &vals) -> OpState { + return self.create(vals); + }) + .def("call", + [](TritonOpBuilder &self, FuncOp &func, std::vector &args) + -> OpState { return self.create(func, args); }) + // Unstructured control flow + .def("create_cond_branch", + [](TritonOpBuilder &self, Value condition, Block *trueDest, + Block *falseDest) -> OpState { + return self.create(condition, trueDest, + falseDest); + }) + .def("create_branch", + [](TritonOpBuilder &self, Block *dest, std::vector &args) + -> OpState { return self.create(dest, args); }) + // Structured control flow + .def("create_for_op", + [](TritonOpBuilder &self, Value &lb, Value &ub, Value &step, + std::vector &initArgs) -> scf::ForOp { + return self.create(lb, ub, step, initArgs); + }) + .def("create_if_op", + [](TritonOpBuilder &self, std::vector &retTypes, + Value &condition, bool withElse) -> scf::IfOp { + return self.create(retTypes, condition, withElse); + }) + .def("create_yield_op", + [](TritonOpBuilder &self, std::vector &yields) + -> scf::YieldOp { return self.create(yields); }) + .def("create_while_op", + [](TritonOpBuilder &self, std::vector &retTypes, + std::vector &initArgs) -> scf::WhileOp { + return self.create(retTypes, initArgs); + }) + .def("create_condition_op", + [](TritonOpBuilder &self, Value &cond, + std::vector &args) -> scf::ConditionOp { + return self.create(cond, args); + }) + + // miscellaneous + .def("create_make_range", + [](TritonOpBuilder &self, int start, int end) -> Value { + auto retType = RankedTensorType::get( + {end - start}, self.getBuilder().getI32Type()); + return self.create(retType, start, end); + }) + + // Cast instructions + // Conversions for custom FP types (FP8 and non-standard rounding modes) + .def("create_fp_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType, + std::optional roundingMode) -> Value { + if (roundingMode.has_value()) + return self.create( + dstType, src, + RoundingModeAttr::get(self.getBuilder().getContext(), + roundingMode.value())); + else + return self.create(dstType, src); + }) + // Conversions for standard LLVM builtin types + .def("create_bitcast", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_si_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_ui_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_si", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_ui", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_ext", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_trunc", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_int_cast", + [](TritonOpBuilder &self, Value &src, Type &dstType, + bool isSigned) -> Value { + // get element type if necessary + Type srcType = src.getType(); + auto srcTensorType = dyn_cast(srcType); + auto dstTensorType = dyn_cast(dstType); + Type srcEltType = srcType; + Type dstEltType = dstType; + if (dstTensorType && srcTensorType) { + dstEltType = dstTensorType.getElementType(); + srcEltType = srcTensorType.getElementType(); + } + unsigned srcWidth = srcEltType.getIntOrFloatBitWidth(); + unsigned dstWidth = dstEltType.getIntOrFloatBitWidth(); + if (srcWidth == dstWidth) + return self.create(dstType, src); + else if (srcWidth > dstWidth) + return self.create(dstType, src); + else if (isSigned) + return self.create(dstType, src); + else + return self.create(dstType, src); + }) + .def("create_to_index", + [](TritonOpBuilder &self, Value &input) -> Value { + return self.create( + self.getBuilder().getIndexType(), input); + }) + .def("create_index_to_si", + [](TritonOpBuilder &self, Value &input) -> Value { + return self.create( + self.getBuilder().getI64Type(), input); + }) + .def("create_fmul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_frem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fadd", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fsub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_mul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_umulhi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_udiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_srem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_urem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_add", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_fma", + [](TritonOpBuilder &self, Value &a, Value &b, Value &c) -> Value { + return Value(self.create(a, b, c)); + }) + .def("create_shl", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_lshr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_ashr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minimumf follows the torch.minimum convention and returns NaN if either + // operand is NaN + .def("create_minimumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minnumf follows the torch.fmin convention and returns the non-NaN + // operand + .def("create_minnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maximumf follows the torch.maximum convention and returns NaN if either + // operand is NaN + .def("create_maximumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maxnumf follows the torch.fmax convention and returns the non-NaN + // operand + .def("create_maxnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_clampf", + [](TritonOpBuilder &self, Value &input, Value &min, Value &max, + PropagateNan propagateNan) -> Value { + return Value(self.create(input, min, max, propagateNan)); + }) + .def("create_precise_sqrt", + [](TritonOpBuilder &self, Value &input) -> Value { + return Value(self.create(input)); + }) + .def("create_precise_divf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // AddPtr (similar to GEP) + .def("create_addptr", + [](TritonOpBuilder &self, Value &ptr, Value &offset) -> Value { + return self.create(ptr.getType(), ptr, offset); + }) + // Comparison (int) + .def("create_icmpSLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sle, lhs, + rhs); + }) + .def("create_icmpSLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::slt, lhs, + rhs); + }) + .def("create_icmpSGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sge, lhs, + rhs); + }) + .def("create_icmpSGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sgt, lhs, + rhs); + }) + .def("create_icmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ule, lhs, + rhs); + }) + .def("create_icmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ult, lhs, + rhs); + }) + .def("create_icmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::uge, lhs, + rhs); + }) + .def("create_icmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ugt, lhs, + rhs); + }) + .def("create_icmpEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::eq, lhs, + rhs); + }) + .def("create_icmpNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ne, lhs, + rhs); + }) + // Comparison (float) + .def("create_fcmpOLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLT, lhs, + rhs); + }) + .def("create_fcmpOGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGT, lhs, + rhs); + }) + .def("create_fcmpOLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLE, lhs, + rhs); + }) + .def("create_fcmpOGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGE, lhs, + rhs); + }) + .def("create_fcmpOEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OEQ, lhs, + rhs); + }) + .def("create_fcmpONE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ONE, lhs, + rhs); + }) + .def("create_fcmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULT, lhs, + rhs); + }) + .def("create_fcmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGT, lhs, + rhs); + }) + .def("create_fcmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULE, lhs, + rhs); + }) + .def("create_fcmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGE, lhs, + rhs); + }) + .def("create_fcmpUEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UEQ, lhs, + rhs); + }) + .def("create_fcmpUNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UNE, lhs, + rhs); + }) + // // Logical + .def("create_and", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_xor", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_or", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + // Input/Output + .def("create_load", + [](TritonOpBuilder &self, Value &ptrs, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_store", + [](TritonOpBuilder &self, Value &ptrs, Value &value, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, value, cacheModifier, evictionPolicy); + }) + .def("create_tensor_pointer_load", + [](TritonOpBuilder &self, Value &ptr, + std::vector &boundaryCheck, + std::optional paddingOption, + CacheModifier cacheModifier, EvictionPolicy evictionPolicy, + bool isVolatile) -> Value { + return self.create(ptr, boundaryCheck, paddingOption, + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_tensor_pointer_store", + [](TritonOpBuilder &self, Value &ptr, Value &val, + std::vector &boundaryCheck, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptr, val, boundaryCheck, cacheModifier, + evictionPolicy); + }) + .def("create_masked_load", + [](TritonOpBuilder &self, Value &ptrs, Value &mask, + std::optional &other, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, mask, other.value_or(Value()), + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_masked_store", + [](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, val, mask, cacheModifier, + evictionPolicy); + }) + .def("create_descriptor_load", + [](TritonOpBuilder &self, Value &desc_ptr, + std::vector &indices, Type type, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> Value { + return self.create( + type, desc_ptr, indices, cacheModifier, evictionPolicy); + }) + .def("create_descriptor_store", + [](TritonOpBuilder &self, Value &desc_ptr, Value value, + std::vector &indices) -> void { + self.create(desc_ptr, value, + indices); + }) + .def("create_reshape", + [](TritonOpBuilder &self, Value &arg, std::vector &shape, + bool allowReorder) -> Value { + auto argType = + cast(arg.getType()).getElementType(); + return self.create( + RankedTensorType::get(shape, argType), arg, allowReorder); + }) + .def("create_expand_dims", + [](TritonOpBuilder &self, Value &arg, int axis) -> Value { + auto argType = dyn_cast(arg.getType()); + auto argEltType = argType.getElementType(); + std::vector retShape = argType.getShape(); + retShape.insert(retShape.begin() + axis, 1); + return self.create( + RankedTensorType::get(retShape, argEltType), arg, axis); + }) + .def("create_cat", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + auto lhsType = dyn_cast(lhs.getType()); + auto rhsType = dyn_cast(rhs.getType()); + if (!(lhsType.getShape().size() == 1 && + rhsType.getShape().size() == 1)) + throw std::invalid_argument( + "shape not supported by cat. Expecting rank-1 inputs"); + std::vector shape{lhsType.getShape()[0] + + rhsType.getShape()[0]}; + return self.create( + RankedTensorType::get(shape, lhsType.getElementType()), lhs, + rhs); + }) + .def("create_join", + [](TritonOpBuilder &self, Value &a, Value &b) -> Value { + return self.create(a, b); + }) + .def("create_split", + [](TritonOpBuilder &self, Value &a) -> std::vector { + auto op = self.create(a); + return std::vector(op->result_begin(), op->result_end()); + }) + // Implements tl.trans and tl.permute. + .def("create_trans", + [](TritonOpBuilder &self, Value &arg, + std::vector &order) -> Value { + auto argType = dyn_cast(arg.getType()); + auto argEltType = argType.getElementType(); + auto retShape = applyPermutation(argType.getShape(), order); + return self.create( + RankedTensorType::get(retShape, argEltType), arg, order); + }) + .def("create_broadcast", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + if (auto argType = dyn_cast(arg.getType())) + return self.createOrFold( + RankedTensorType::get(shape, argType.getElementType()), arg); + throw std::invalid_argument( + "arg is not of RankedTensorType, use create_splat"); + }) + .def("create_splat", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + auto argType = arg.getType(); + auto ret = self.createOrFold( + RankedTensorType::get(shape, argType), arg); + return ret; + }) + // // atomic + .def("create_atomic_cas", + [](TritonOpBuilder &self, Value &ptr, Value &cmp, Value &val, + MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, ptr, cmp, val, sem, + scope); + }) + .def("create_atomic_rmw", + [](TritonOpBuilder &self, RMWOp rmwOp, Value &ptr, Value &val, + Value &mask, MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, rmwOp, ptr, val, mask, + sem, scope); + }) + // External + .def("create_extern_elementwise", + [](TritonOpBuilder &self, const std::string &libName, + const std::string &libPath, const std::string &symbol, + std::vector &argList, Type retType, bool isPure) -> Value { + return self.create(retType, argList, libName, + libPath, symbol, isPure); + }) + // Built-in instruction + .def("create_get_program_id", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create( + self.getBuilder().getI32Type(), + ProgramIDDimAttr::get(self.getBuilder().getContext(), + ProgramIDDim(axis))); + }) + .def("create_get_num_programs", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create( + self.getBuilder().getI32Type(), + ProgramIDDimAttr::get(self.getBuilder().getContext(), + ProgramIDDim(axis))); + }) + .def("create_dot", + [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b, + mlir::Value &c, InputPrecision inputPrecision, + int maxNumImpreciseAcc) -> mlir::Value { + return self.create(c.getType(), a, b, c, inputPrecision, + maxNumImpreciseAcc); + }) + .def("create_floor", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_ceil", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_cos", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sin", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_erf", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_rsqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_fabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_iabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_reduce", + [](TritonOpBuilder &self, std::vector operands, int axis) + -> OpState { return self.create(operands, axis); }) + .def("create_reduce_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_scan", + [](TritonOpBuilder &self, std::vector operands, int axis, + bool reverse) -> OpState { + return self.create(operands, axis, reverse); + }) + .def("create_scan_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_ptr_to_int", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_int_to_ptr", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_select", + [](TritonOpBuilder &self, Value &condition, Value &trueValue, + Value &falseValue) -> Value { + return self.create(condition, trueValue, + falseValue); + }) + .def("create_inline_asm", + [](TritonOpBuilder &self, const std::string &inlineAsm, + const std::string &constraints, const std::vector &values, + const std::vector &types, bool isPure, + int pack) -> OpState { + return self.create( + types, inlineAsm, constraints, isPure, pack, values); + }) + .def("create_print", + [](TritonOpBuilder &self, const std::string &prefix, bool hex, + const std::vector &values) -> void { + self.create( + StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(prefix)), + hex, values); + }) + .def("create_assert", + [](TritonOpBuilder &self, Value &condition, + const std::string &message, const std::string &fileName, + const std::string &funcName, unsigned lineNo) -> void { + auto messageAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(message)); + auto fileNameAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(fileName)); + auto funcNameAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(funcName)); + auto lineNoAttr = self.getBuilder().getI32IntegerAttr(lineNo); + self.create(condition, messageAttr, fileNameAttr, + funcNameAttr, lineNoAttr); + }) + // Undef + .def("create_undef", + [](TritonOpBuilder &self, Type &type) -> Value { + return self.create(type); + }) + .def("create_histogram", + [](TritonOpBuilder &self, Value operand, int numBins) -> Value { + return self.create( + RankedTensorType::get( + {static_cast(numBins)}, + IntegerType::get(operand.getContext(), 32)), + operand); + }) + // Force GPU barrier + .def("create_barrier", + [](TritonOpBuilder &self) { self.create(); }) + // Make a block pointer (tensor pointer in Triton IR) + .def("create_make_block_ptr", + [](TritonOpBuilder &self, Value &base, std::vector &shape, + std::vector &strides, std::vector &offsets, + std::vector &tensorShape, + std::vector &order) -> Value { + return self.create(base, shape, strides, offsets, + tensorShape, order); + }) + // Advance a block pointer + .def("create_advance", + [](TritonOpBuilder &self, Value &ptr, + std::vector &offsets) -> Value { + return self.create(ptr.getType(), ptr, offsets); + }); + + py::class_(m, "pass_manager", py::module_local()) + .def(py::init()) + .def("enable_debug", + [](PassManager &self) { + auto *context = self.getContext(); + bool haveDiagnostics = + ::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS"); + bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); + if (haveDiagnostics || haveDump) { + context->disableMultithreading(); + } + if (haveDiagnostics) { + context->printOpOnDiagnostic(true); + context->printStackTraceOnDiagnostic(true); + context->getDiagEngine().registerHandler([](Diagnostic &diag) { + llvm::outs() << diag << "\n"; + return success(); + }); + } + if (haveDump) { + auto printingFlags = OpPrintingFlags(); + printingFlags.elideLargeElementsAttrs(16); + printingFlags.enableDebugInfo(); + auto printAlways = [](Pass *, Operation *) { return true; }; + self.enableIRPrinting( + /*shouldPrintBeforePass=*/printAlways, + /*shouldPrintAfterPass=*/printAlways, + /*printModuleScope=*/true, + /*printAfterOnlyOnChange=*/false, + /*printAfterOnlyOnFailure*/ true, llvm::dbgs(), + printingFlags); + } + }) + .def("run", [](PassManager &self, ModuleOp &mod) { + // TODO: maybe dump module to file and print error for better + // diagnostics + + auto reproducerPath = + triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); + if (!reproducerPath.empty()) { + auto anchorName = self.getOpAnchorName(); + auto passes = self.getPasses(); + Operation *op = mod.getOperation(); + makeReproducer(anchorName, passes, op, reproducerPath); + } + + if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { + ::llvm::DebugFlag = true; + } + + if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); + !debugOnly.empty()) { + llvm::SmallVector split; + llvm::SmallVector storage; + llvm::SmallVector debugTypes; + + StringRef(debugOnly.c_str()).split(split, ','); + llvm::transform(split, std::back_inserter(debugTypes), + [&storage](StringRef str) { + // StringRefs are not always null-terminated. + // The purpose for this storage pattern is to + // produce a collection of C-strings that are. + storage.push_back(str.str()); + return storage.back().c_str(); + }); + + ::llvm::DebugFlag = true; + ::llvm::setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); + } + + bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); + if (haveTiming) { + self.enableTiming(); + } + + if (failed(self.run(mod.getOperation()))) + throw std::runtime_error("PassManager::run failed"); + }); +} + +void init_triton_env_vars(py::module &m) { + m.def("get_cache_invalidating_env_vars", + []() -> std::map { + std::map ret; + for (const auto &envVar : CACHE_INVALIDATING_ENV_VARS) { + auto strVal = triton::tools::getStrEnv(envVar); + if (strVal.empty()) + continue; + auto boolV = triton::tools::isEnvValueBool(strVal); + if (boolV.has_value()) + ret[envVar] = boolV.value() ? "true" : "false"; + else + ret[envVar] = strVal; + } + return ret; + }); +} diff --git a/third_party/mthreads/python/src/llvm.cc b/third_party/mthreads/python/src/llvm.cc new file mode 100644 index 000000000..a66ffa1cb --- /dev/null +++ b/third_party/mthreads/python/src/llvm.cc @@ -0,0 +1,412 @@ +#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Pass.h" +#include "llvm/Passes/OptimizationLevel.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/StandardInstrumentations.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include +#include +#include + +namespace py = pybind11; + +namespace llvm { +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; +} // namespace llvm + +using namespace llvm; + +std::string translateLLVMIRToASM(llvm::Module &module, + const std::string &triple, + const std::string &proc, + const std::string &features, + const std::vector &flags, + bool enable_fp_fusion, bool isObject) { + using namespace mlir; + // options + auto options = llvm::cl::getRegisteredOptions(); + for (std::string flag : flags) { + auto *shortPtr = static_cast *>(options[flag]); + assert(shortPtr); + shortPtr->setValue(true); + } + if (triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + auto optIt = options.find("print-after-all"); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + bool disableLLVMOpt = triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); + if (!disableLLVMOpt) { + // Check to see if we are passing a list of flags to disable optimizations. + auto flagList = triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (auto flag : split) { + auto optIt = options.find(flag); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + } + } + + // inline everything + for (llvm::Function &f : module.functions()) + if (!f.hasFnAttribute(llvm::Attribute::NoInline)) + f.addFnAttr(llvm::Attribute::AlwaysInline); + // verify and store llvm + llvm::legacy::PassManager pm; + pm.add(llvm::createAlwaysInlinerLegacyPass()); + pm.add(llvm::createVerifierPass()); + + const bool enabledTiming = triton::tools::getBoolEnv("LLVM_ENABLE_TIMING"); + if (enabledTiming) { + llvm::TimePassesIsEnabled = true; + llvm::TimePassesPerRun = true; + } + + pm.run(module); + + SmallString<0> timePassesStr; + raw_svector_ostream reportStream(timePassesStr); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } + // module->print(llvm::outs(), nullptr); + + // create machine + module.setTargetTriple(triple); + std::string error; + auto target = + llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error); + llvm::TargetOptions opt; + if (enable_fp_fusion) + opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; + opt.UnsafeFPMath = false; + opt.NoInfsFPMath = false; + opt.NoNaNsFPMath = true; + opt.TrapUnreachable = true; + std::unique_ptr machine{target->createTargetMachine( + module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, + std::nullopt, + disableLLVMOpt ? llvm::CodeGenOptLevel::None + : llvm::CodeGenOptLevel::Aggressive)}; + // set data layout + module.setDataLayout(machine->createDataLayout()); + // emit machine code + std::string result; + { + llvm::raw_string_ostream stream(result); + llvm::buffer_ostream pstream(stream); + for (llvm::Function &f : module.functions()) + f.addFnAttr(llvm::Attribute::AlwaysInline); + llvm::legacy::PassManager pass; + // emit + auto fileType = isObject ? llvm::CodeGenFileType::ObjectFile + : llvm::CodeGenFileType::AssemblyFile; + machine->addPassesToEmitFile(pass, pstream, nullptr, fileType); + pass.run(module); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } + } + return result; +} + +using ret = py::return_value_policy; + +void init_triton_llvm(py::module &&m) { + + py::class_(m, "context", py::module_local()) + .def(py::init<>()); + + py::class_(m, "function_list") + .def( + "__iter__", + [](llvm::Module::FunctionListType &s) { + return py::make_iterator(s.begin(), s.end()); + }, + py::keep_alive<0, 1>()); + + // Module Flag behavior. See + // https://llvm.org/doxygen/classllvm_1_1Module.html#a0a5c55e12c97b80021330fe82b642293 + // for details. + py::class_(m, "module_flag_behavior", + py::module_local()); + m.attr("MODULE_FLAG_BEHAVIOR_ERROR") = llvm::Module::Error; + m.attr("MODULE_FLAG_BEHAVIOR_WARNING") = llvm::Module::Warning; + m.attr("MODULE_FLAG_BEHAVIOR_REQUIRE") = llvm::Module::Require; + m.attr("MODULE_FLAG_BEHAVIOR_OVERRIDE") = llvm::Module::Override; + m.attr("MODULE_FLAG_BEHAVIOR_APPEND") = llvm::Module::Append; + m.attr("MODULE_FLAG_BEHAVIOR_APPEND_UNIQUE") = llvm::Module::AppendUnique; + m.attr("MODULE_FLAG_BEHAVIOR_MAX") = llvm::Module::Max; + m.attr("MODULE_FLAG_BEHAVIOR_MIN") = llvm::Module::Min; + + py::class_(m, "module", py::module_local()) + .def( + "__str__", + [](llvm::Module *self) { + std::string str; + llvm::raw_string_ostream os(str); + os << *self; + return os.str(); + }, + ret::take_ownership) + .def( + "get_functions", + [](llvm::Module *mod) -> llvm::Module::FunctionListType & { + // Note: Backends assume that we are compiling exactly one kernel + // (i.e. one function that's that's called by the CPU) and that it's + // the first function in this list. + return mod->getFunctionList(); + }, + ret::reference_internal) + .def("add_flag", + [](llvm::Module *mod, llvm::Module::ModFlagBehavior behavior, + std::string &key, uint32_t value) { + return mod->addModuleFlag(behavior, key, value); + }); + + py::class_(m, "function", py::module_local()) + .def_property_readonly( + "name", [](llvm::Function *fn) { return fn->getName().str(); }) + .def("set_calling_conv", &llvm::Function::setCallingConv) + .def("add_fn_attr", [](llvm::Function *fn, std::string &name, + std::string &val) { fn->addFnAttr(name, val); }) + + // Sets the nvvm.maxreg property on the given function. + .def("set_nvvm_maxnreg", + [](llvm::Function *fn, int maxnreg) { + auto op = MDNode::get( + fn->getContext(), + { + ValueAsMetadata::get(fn), + MDString::get(fn->getContext(), "maxnreg"), + ConstantAsMetadata::get(ConstantInt::get( + Type::getInt32Ty(fn->getContext()), maxnreg)), + }); + fn->getParent() + ->getOrInsertNamedMetadata("nvvm.annotations") + ->addOperand(op); + }) + // External functions that are definitions (i.e. not declarations) are + // kernel functions. + .def("is_declaration", &llvm::Function::isDeclaration) + .def("is_external_linkage", [](llvm::Function *fn) { + return fn->getLinkage() == llvm::GlobalValue::ExternalLinkage; + }); + + // optimization levels + py::class_(m, "optimization_level", + py::module_local()); + m.attr("OPTIMIZE_O0") = llvm::OptimizationLevel::O0; + m.attr("OPTIMIZE_O1") = llvm::OptimizationLevel::O1; + m.attr("OPTIMIZE_O2") = llvm::OptimizationLevel::O2; + m.attr("OPTIMIZE_O3") = llvm::OptimizationLevel::O3; + m.attr("OPTIMIZE_Os") = llvm::OptimizationLevel::Os; + m.attr("OPTIMIZE_Oz") = llvm::OptimizationLevel::Oz; + + m.def( + "to_module", + [](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) { + return mlir::translateModuleToLLVMIR(mod, ctx); + }, + py::keep_alive<0, 2>()); + + m.def( + "optimize_module", + [](llvm::Module *mod, const llvm::OptimizationLevel &opt, + const std::string triple) { + if (mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT")) + return; + // Check to see if we are passing a list of flags to disable + // optimizations. + auto flagList = mlir::triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + auto options = llvm::cl::getRegisteredOptions(); + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (auto flag : split) { + auto optIt = options.find(flag); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + } + using namespace llvm; + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + + PassInstrumentationCallbacks *instrCbPtr = nullptr; + PassInstrumentationCallbacks passInstrCb; + StandardInstrumentations standardInstr(mod->getContext(), + /*DebugLogging*/ true); + if (mlir::triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + auto optMap = llvm::cl::getRegisteredOptions(); + auto optIt = optMap.find("print-after-all"); + if (optIt != optMap.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + standardInstr.registerCallbacks(passInstrCb, &mam); + instrCbPtr = &passInstrCb; + } + + PipelineTuningOptions tuningOptions; + tuningOptions.LoopUnrolling = true; + tuningOptions.LoopInterleaving = true; + tuningOptions.LoopVectorization = true; + // TODO: currently we run SLP vectorizer with an empty target machine. + // This cause the vectorizer to create larger vector which could be bad. + // Disabling it would currently cause regressions as this pass also + // applies some scheduling that helps performance in some cases. We + // should work on using NVPTX target instead and address the performance + // regressions with some scheduling solution. + // FIXME: SLPVectorization generates some large vectors, + // such as <64 * float>, which mtcc cannot currently handle.(SW-47321) + tuningOptions.SLPVectorization = false; + + if (!triple.empty()) + mod->setTargetTriple(triple.c_str()); + + PassBuilder pb(nullptr /*targetMachine*/, tuningOptions, std::nullopt, + instrCbPtr); + + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + pb.registerVectorizerStartEPCallback( + [&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) { + // Triton generates large structure of scalars which may pessimise + // optimizations, we run a pass to break up phi of struct to make + // sure all the struct are removed for the following passes. + fpm.addPass(BreakStructPhiNodesPass()); + fpm.addPass(InstCombinePass()); + }); + mpm.addPass(pb.buildPerModuleDefaultPipeline(opt)); + mpm.run(*mod, mam); + }, + py::arg("mod"), py::arg("opt"), py::arg("triple") = ""); + + m.def( + "translate_to_asm", + [](std::string llvmIR, std::string triple, std::string proc, + std::string features, std::vector flags, + bool enable_fp_fusion, bool isObject) -> py::object { + std::string obj; + { + // when allow_threads goes out of scope, gil will be released + py::gil_scoped_release allow_threads; + // create LLVM module from C++ + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + obj = translateLLVMIRToASM(*module, triple, proc, features, flags, + enable_fp_fusion, isObject); + } + if (isObject) + return py::bytes(obj); + else + return py::str(obj); + }, + ret::take_ownership); + + m.def("init_targets", []() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + }); + }); + + m.def("link_extern_libs", [](llvm::Module *dstMod, + const std::vector &paths) { + if (paths.empty()) + return; + + LLVMContext &ctx = dstMod->getContext(); + llvm::Linker linker(*dstMod); + for (const std::string &path : paths) { + llvm::SMDiagnostic err; + std::unique_ptr libMod = llvm::parseIRFile(path, err, ctx); + if (!libMod) { + std::string message = "Failed to parse library at " + path; + throw std::invalid_argument(message); + } + libMod->setTargetTriple(dstMod->getTargetTriple()); + libMod->setDataLayout(dstMod->getDataLayout()); + + std::unordered_set externalFns; + for (llvm::Function &fn : libMod->functions()) { + if (!fn.isDeclaration()) + externalFns.insert(fn.getName().str()); + if (fn.hasFnAttribute(llvm::Attribute::NoInline)) { + fn.removeFnAttr(llvm::Attribute::NoInline); + fn.removeFnAttr(llvm::Attribute::OptimizeNone); + fn.addFnAttr(llvm::Attribute::AlwaysInline); + } + } + + if (linker.linkInModule(std::move(libMod), + llvm::Linker::Flags::LinkOnlyNeeded)) { + std::string message = "Failed to link library at " + path; + throw std::invalid_argument(message); + } + + // Mark linked-in functions as internal because backends use external + // linkage as a signifier of kernel functions. + for (llvm::Function &fn : dstMod->functions()) { + if (externalFns.count(fn.getName().str())) { + fn.setLinkage(llvm::GlobalValue::InternalLinkage); + } + } + } + }); +} diff --git a/third_party/mthreads/python/src/main.cc b/third_party/mthreads/python/src/main.cc new file mode 100644 index 000000000..5ad4be7d5 --- /dev/null +++ b/third_party/mthreads/python/src/main.cc @@ -0,0 +1,50 @@ +#include +namespace py = pybind11; + +#define FOR_EACH_1(MACRO, X) MACRO(X) +#define FOR_EACH_2(MACRO, X, ...) MACRO(X) FOR_EACH_1(MACRO, __VA_ARGS__) +#define FOR_EACH_3(MACRO, X, ...) MACRO(X) FOR_EACH_2(MACRO, __VA_ARGS__) +#define FOR_EACH_4(MACRO, X, ...) MACRO(X) FOR_EACH_3(MACRO, __VA_ARGS__) + +#define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__, FOR_EACH_RSEQ_N()) +#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__) +#define FOR_EACH_ARG_N(_1, _2, _3, _4, N, ...) N +#define FOR_EACH_RSEQ_N() 4, 3, 2, 1, 0 + +#define CONCATENATE(x, y) CONCATENATE1(x, y) +#define CONCATENATE1(x, y) x##y + +#define FOR_EACH(MACRO, ...) \ + CONCATENATE(FOR_EACH_, FOR_EACH_NARG_HELPER(__VA_ARGS__))(MACRO, __VA_ARGS__) +#define FOR_EACH_NARG_HELPER(...) FOR_EACH_NARG(__VA_ARGS__) + +// New macro to remove parentheses +#define REMOVE_PARENS(...) __VA_ARGS__ + +// Intermediate macro to ensure correct expansion +#define FOR_EACH_P_INTERMEDIATE(MACRO, ...) FOR_EACH(MACRO, __VA_ARGS__) + +// Modified FOR_EACH to handle parentheses +#define FOR_EACH_P(MACRO, ARGS_WITH_PARENS) \ + FOR_EACH_P_INTERMEDIATE(MACRO, REMOVE_PARENS ARGS_WITH_PARENS) + +#define DECLARE_BACKEND(name) void init_triton_##name(pybind11::module &&m); + +#define INIT_BACKEND(name) init_triton_##name(m.def_submodule(#name)); + +void init_triton_env_vars(pybind11::module &m); +void init_triton_ir(pybind11::module &&m); +void init_triton_llvm(pybind11::module &&m); +void init_triton_interpreter(pybind11::module &&m); +void init_triton_passes(pybind11::module &&m); +FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE) + +PYBIND11_MODULE(libtriton, m) { + m.doc() = "Python bindings to the C++ Triton API"; + init_triton_env_vars(m); + init_triton_ir(m.def_submodule("ir")); + init_triton_passes(m.def_submodule("passes")); + init_triton_interpreter(m.def_submodule("interpreter")); + init_triton_llvm(m.def_submodule("llvm")); + FOR_EACH_P(INIT_BACKEND, TRITON_BACKENDS_TUPLE) +} diff --git a/third_party/mthreads/python/src/passes.cc b/third_party/mthreads/python/src/passes.cc new file mode 100644 index 000000000..557d146e7 --- /dev/null +++ b/third_party/mthreads/python/src/passes.cc @@ -0,0 +1,84 @@ +#include "mlir/Transforms/Passes.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Target/LLVMIR/Passes.h" +#include +#include + +namespace py = pybind11; + +void init_triton_analysis(py::module &&m) { + py::class_(m, "allocation", py::module_local()) + .def(py::init()); + py::class_(m, "membar", py::module_local()) + .def(py::init()) + .def("run", &mlir::ModuleMembarAnalysis::run); +} + +void init_triton_passes_common(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_sccp", createSCCPPass); + ADD_PASS_WRAPPER_0("add_symbol_dce", createSymbolDCEPass); + ADD_PASS_WRAPPER_0("add_inliner", createInlinerPass); + ADD_PASS_WRAPPER_0("add_canonicalizer", createCanonicalizerPass); + ADD_PASS_WRAPPER_0("add_cse", createCSEPass); + ADD_PASS_WRAPPER_0("add_licm", createLoopInvariantCodeMotionPass); +} + +void init_triton_passes_ttir(py::module &&m) { + using namespace mlir::triton; + ADD_PASS_WRAPPER_0("add_combine", createCombineOpsPass); + ADD_PASS_WRAPPER_0("add_reorder_broadcast", createReorderBroadcastPass); + ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer", + createRewriteTensorPointerPass); + ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", + createConvertTritonToTritonGPUPass, const std::string &, + int, int, int); +} + +void init_triton_passes_ttgpuir(py::module &&m) { + using namespace mlir::triton::gpu; + ADD_PASS_WRAPPER_0("add_coalesce", createTritonGPUCoalesce); + ADD_PASS_WRAPPER_0("add_optimize_thread_locality", + createTritonGPUOptimizeThreadLocality); + ADD_PASS_WRAPPER_0("add_reorder_instructions", + createTritonGPUReorderInstructions); + ADD_PASS_WRAPPER_0("add_remove_layout_conversions", + createTritonGPURemoveLayoutConversions); + ADD_PASS_WRAPPER_0("add_reduce_data_duplication", + createTritonGPUReduceDataDuplication); + ADD_PASS_WRAPPER_0("add_allocate_shared_memory", + createAllocateSharedMemoryPass); + ADD_PASS_WRAPPER_0("add_combine_tensor_select_and_if", + createTritonGPUCombineTensorSelectAndIf); +} + +void init_triton_passes_convert(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_scf_to_cf", createConvertSCFToCFPass); + ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); + ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); + ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); +} + +void init_triton_passes_llvmir(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_di_scope", createLLVMDIScopePass); +} + +void init_triton_passes(py::module &&m) { + init_triton_analysis(m.def_submodule("analysis")); + init_triton_passes_common(m.def_submodule("common")); + init_triton_passes_convert(m.def_submodule("convert")); + init_triton_passes_ttir(m.def_submodule("ttir")); + init_triton_passes_ttgpuir(m.def_submodule("ttgpuir")); + init_triton_passes_llvmir(m.def_submodule("llvmir")); +} diff --git a/third_party/mthreads/python/src/passes.h b/third_party/mthreads/python/src/passes.h new file mode 100644 index 000000000..46801d802 --- /dev/null +++ b/third_party/mthreads/python/src/passes.h @@ -0,0 +1,40 @@ +#define ADD_PASS_WRAPPER_0(name, builder) \ + m.def(name, [](mlir::PassManager &pm) { pm.addPass(builder()); }) + +#define ADD_PASS_WRAPPER_1(name, builder, ty0) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder(val0)); }) + +#define ADD_PASS_WRAPPER_2(name, builder, ty0, ty1) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ + pm.addPass(builder(val0, val1)); \ + }) + +#define ADD_PASS_WRAPPER_3(name, builder, ty0, ty1, ty2) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ + pm.addPass(builder(val0, val1, val2)); \ + }) + +#define ADD_PASS_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \ + ty3 val3) { pm.addPass(builder(val0, val1, val2, val3)); }) + +#define ADD_PASS_OPTION_WRAPPER_1(name, builder, ty0) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder({val0})); }) + +#define ADD_PASS_OPTION_WRAPPER_2(name, builder, ty0, ty1) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ + pm.addPass(builder({val0, val1})); \ + }) + +#define ADD_PASS_OPTION_WRAPPER_3(name, builder, ty0, ty1, ty2) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ + pm.addPass(builder({val0, val1, val2})); \ + }) + +#define ADD_PASS_OPTION_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3) { \ + pm.addPass(builder({val0, val1, val2, val3})); \ + }) diff --git a/third_party/mthreads/python/test/unit/conftest.py b/third_party/mthreads/python/test/unit/conftest.py new file mode 100644 index 000000000..ee4276761 --- /dev/null +++ b/third_party/mthreads/python/test/unit/conftest.py @@ -0,0 +1,12 @@ +# content of conftest.py + +import pytest + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default='musa') + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") diff --git a/third_party/mthreads/python/test/unit/language/assert_helper.py b/third_party/mthreads/python/test/unit/language/assert_helper.py new file mode 100644 index 000000000..832a26d85 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/assert_helper.py @@ -0,0 +1,154 @@ +import sys + +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +def get_current_target_warp_size(): + return triton.runtime.driver.active.get_current_target().warp_size + + +@triton.jit +def kernel_device_assert(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_assert(x == 0, "x != 0") + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_assert_passes(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # Trivial assert, should not be an error. + tl.device_assert(0 == 0, "x != 0") + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit(debug=False) +def kernel_device_assert_no_debug(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_assert(x == 0, "x != 0") + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_assert(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + assert x == 0, "x != 0" + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_static_assert(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.static_assert(BLOCK == 128, "BLOCK != 128") + tl.store(Y + tl.arange(0, BLOCK), x) + + +def test_assert(func: str, device: str): + N = 128 # This value should match with test_print in test_subprocess.py. + num_warps = N // get_current_target_warp_size() + + x = torch.arange(0, N, dtype=torch.int32, device=device) + y = torch.zeros((N, ), dtype=x.dtype, device=device) + if func == "device_assert": + kernel_device_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N) + if func == "device_assert_passes": + # Assert passes; no error. + kernel_assert_passes[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "no_debug": + # TRITON_DEBUG=1 can override the debug flag + kernel_device_assert_no_debug[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "assert": + kernel_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "static_assert": + kernel_static_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "double_assert": + # Launching a different kernel after the first one asserted used to + # segfault. What seems to have happened is: + # - The first kernel is enqueued but doesn't run yet. + # - We go to launch the second kernel. Because this is the first time + # we're running it, we have to load the kernel into the GPU. + # - Loading the kernel takes some time, during which the first launch + # completes. + # - Now the GPU is in an error state. We need to detect this inside + # the kernel-launch/loading code and bail out properly. If we don't, + # we segfault. + kernel_device_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N) + kernel_assert_passes[(1, )](x, y, num_warps=num_warps, BLOCK=N) + assert_close(y, x) + + +@triton.jit +def jit_device_assert_none(x): + tl.device_assert(x == 0, "x != 0") + + +@triton.jit(debug=True) +def jit_device_assert_true(x): + tl.device_assert(x == 0, "x != 0") + + +@triton.jit(debug=False) +def jit_device_assert_false(x): + tl.device_assert(x == 0, "x != 0") + + +@triton.jit +def kernel_device_assert_nested(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + if jit_debug == "true": + jit_device_assert_true(x) + elif jit_debug == "false": + jit_device_assert_false(x) + else: + jit_device_assert_none(x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit(debug=True) +def kernel_device_assert_nested_true(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + if jit_debug == "true": + jit_device_assert_true(x) + elif jit_debug == "false": + jit_device_assert_false(x) + else: + jit_device_assert_none(x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit(debug=False) +def kernel_device_assert_nested_false(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + if jit_debug == "true": + jit_device_assert_true(x) + elif jit_debug == "false": + jit_device_assert_false(x) + else: + jit_device_assert_none(x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +def test_assert_nested(caller: str, callee: str, device: str): + N = 128 # This value should match with test_print in test_subprocess.py. + num_warps = N // get_current_target_warp_size() + + x = torch.arange(0, N, dtype=torch.int32, device=device) + y = torch.zeros((N, ), dtype=x.dtype, device=device) + if caller == "none": + kernel_device_assert_nested[(1, )](x, y, num_warps=num_warps, BLOCK=N, jit_debug=callee) + elif caller == "true": + kernel_device_assert_nested_true[(1, )](x, y, num_warps=num_warps, BLOCK=N, jit_debug=callee) + elif caller == "false": + kernel_device_assert_nested_false[(1, )](x, y, num_warps=num_warps, BLOCK=N, jit_debug=callee) + assert_close(y, x) + + +if __name__ == "__main__": + if len(sys.argv) == 4: + test_assert_nested(sys.argv[1], sys.argv[2], sys.argv[3]) + else: + test_assert(sys.argv[1], sys.argv[2]) diff --git a/third_party/mthreads/python/test/unit/language/conftest.py b/third_party/mthreads/python/test/unit/language/conftest.py new file mode 100644 index 000000000..091f9ea41 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/conftest.py @@ -0,0 +1,5 @@ +# content of conftest.py + + +def pytest_configure(config): + config.addinivalue_line("markers", "interpreter: indicate whether interpreter supports the test") diff --git a/third_party/mthreads/python/test/unit/language/print_helper.py b/third_party/mthreads/python/test/unit/language/print_helper.py new file mode 100644 index 000000000..c12b822a5 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/print_helper.py @@ -0,0 +1,125 @@ +import sys +import uuid + +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +def get_current_target_warp_size(): + return triton.runtime.driver.active.get_current_target().warp_size + + +@triton.jit +def kernel_device_print(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_print("x: ", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_device_print_hex(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_print("x: ", x, hex=True) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_print(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # Triton should add a space after this prefix. + print("x:", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_device_print_large( + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32) + # Triton should change this prefix to "x: ". + tl.device_print("x ", x) + + +@triton.jit +def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK, ), 1, tl.int32) + print("", x, y) + + +@triton.jit +def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK, ), 1, tl.int32) + tl.device_print("", x, y) + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit +def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr): + # This function takes an extra value as a tl.constexpr so this kernel is not + # cached. This way the static print is run every time. + x = tl.load(X + tl.arange(0, BLOCK)) + tl.static_print("", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_no_arg_print(): + print("", tl.program_id(0)) + + +@triton.jit +def kernel_print_no_arg(): + print("no arg") + + +@triton.jit +def kernel_print_pointer(X, Y, BLOCK: tl.constexpr): + tl.device_print("ptr ", X + tl.arange(0, BLOCK)) + + +def test_print(func: str, data_type: str, device: str): + N = 128 # This value should match with test_print in test_subprocess.py. + # TODO(antiagainst): Currently the warp count is chosen to make sure wedon't have multiple + # threads printing duplicated messages due to broadcasting. Improve print op lowering logic + # to filter out duplicated data range. + num_warps = N // get_current_target_warp_size() + + x = torch.arange(0, N, dtype=torch.int32, device=device).to(getattr(torch, data_type)) + y = torch.zeros((N, ), dtype=x.dtype, device=device) + if func == "device_print": + kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "print": + kernel_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_large": + kernel_device_print_large[(1, 2)](BLOCK_M=64, num_warps=num_warps, BLOCK_N=N) + elif func == "print_multiple_args": + kernel_print_multiple_args[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_multiple_args": + kernel_device_print_multiple_args[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "static_print": + kernel_static_print[(1, )](x, y, num_warps=num_warps, BLOCK=N, PLACEHOLDER=uuid.uuid4()) + elif func == "no_arg_print": + kernel_no_arg_print[(1, )](num_warps=num_warps) + elif func == "print_no_arg": + kernel_print_no_arg[(1, )](num_warps=num_warps) + elif func == "device_print_hex": + kernel_device_print_hex[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_pointer": + kernel_print_pointer[(1, )](x, y, num_warps=num_warps, BLOCK=N) + else: + assert f"Unknown kernel: {func}" + + if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \ + func != "print_multiple_args" and func != "device_print_multiple_args" and \ + func != "device_print_pointer": + assert_close(y, x) + + +if __name__ == "__main__": + test_print(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/third_party/mthreads/python/test/unit/language/test_annotations.py b/third_party/mthreads/python/test/unit/language/test_annotations.py new file mode 100644 index 000000000..633e5d1d5 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_annotations.py @@ -0,0 +1,50 @@ +from __future__ import annotations +import torch +import triton +import triton.language as tl +import pytest + + +def annotated_function(return_type=None, **arg_types): + """A decorator to add annotations to a function.""" + + def decorator(func): + func.__annotations__ = {**arg_types, 'return': return_type} + return func + + return decorator + + +# Test integer annotations +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.parametrize(("signed", "width"), [ + (signed, width) for signed in [False, True]\ + for width in [8, 16, 32, 64] +] + [(False, 1)] + ) +def test_int_annotation(signed, width, device): + + @triton.jit + @annotated_function(X=torch.tensor, v=f"tl.{'' if signed else 'u'}int{width}") + def _kernel(X, v): + tl.store(X, v) + + h = _kernel[(1, )](torch.empty(1, device=device), 3) + pfx = 'si' if signed else 'ui' + assert f'%arg1: i{width}' in h.asm["ttir"] + assert f'arith.{pfx}tofp' in h.asm["ttir"] + + +# Test that unknown annotations do not emit an error +def test_unknown_annotation(device): + + @triton.jit + def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr): + pass + + x = torch.empty(1, device=device) + _kernel[(1, )](x, x.shape[0], 32) + try: + _kernel[(1, )](x.shape[0], x.shape[0], 32) + except AttributeError: + pass diff --git a/third_party/mthreads/python/test/unit/language/test_block_pointer.py b/third_party/mthreads/python/test/unit/language/test_block_pointer.py new file mode 100644 index 000000000..1a9a2a18b --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_block_pointer.py @@ -0,0 +1,101 @@ +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr): + pid = tl.program_id(0) + # We only copy half of the data to see if the padding works + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option) + tl.store(b_block_ptr, a, boundary_check=(0, )) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtypes_str, n, padding_option", [ # + (dtypes_str, n, padding) + for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("float16", "float16"), ("int16", "float16")) + for n in (64, 128, 256, 512, 1024) + for padding in ("zero", "nan") # +]) +def test_block_copy(dtypes_str, n, padding_option, device): + src_dtype_str = dtypes_str[0] + dst_dtype_str = dtypes_str[0] + src_dtype = getattr(torch, src_dtype_str) + dst_dtype = getattr(torch, dst_dtype_str) + if src_dtype_str in ("bool", "int16"): + if padding_option == "nan": + pytest.skip("Padding with NaN is not supported for integer types") + a = torch.randint(0, 2, (n, ), device=device, dtype=src_dtype) + else: + a = torch.randn((n, ), device=device, dtype=src_dtype) + b = torch.zeros((n, ), device=device, dtype=dst_dtype) + + grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), ) + block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option) + a.to(dst_dtype) + assert torch.all(a[0:n // 2] == b[0:n // 2]) + if padding_option == "zero": + assert torch.all(b[n // 2:n] == 0) + else: + assert torch.all(torch.isnan(b[n // 2:n])) + + +@triton.jit +def matmul_no_scf_with_advance_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr # +): + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), order=(1, 0)) + # Below two lines are just for testing negative offsets for the `advance` API, which could be removed + a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K)) + a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K)) + a = tl.load(a_block_ptr, boundary_check=(1, ), padding_option="zero") + b = tl.load(b_block_ptr, boundary_check=(0, ), padding_option="zero") + + c = tl.dot(a, b) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, c) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, num_warps", [ # + (shape, num_warps) for shape in [ + [64, 64, 16], + [64, 64, 32], + [64, 64, 64], + ] for num_warps in [4, 8] +]) +def test_block_ptr_matmul_no_scf(shape, num_warps, device): + m, n, k = shape + a = torch.randn((m, k), device=device, dtype=torch.float16) + b = torch.randn((k, n), device=device, dtype=torch.float16) + c = torch.empty((m, n), device=device, dtype=torch.float32) + + grid = lambda META: (1, ) + matmul_no_scf_with_advance_kernel[grid]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=m, N=n, K=k, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, # + num_warps=num_warps) + golden = torch.matmul(a, b) + torch.testing.assert_close(c, golden, check_dtype=False) diff --git a/third_party/mthreads/python/test/unit/language/test_compile_errors.py b/third_party/mthreads/python/test/unit/language/test_compile_errors.py new file mode 100644 index 000000000..503c57a40 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_compile_errors.py @@ -0,0 +1,306 @@ +import pytest + +import triton +import triton.language as tl +from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure +import traceback + +pytestmark = pytest.mark.skip(reason="Skipping entire test file") + + +def test_err_undefined_variable(): + + @triton.jit + def kernel(): + a += 1 # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert "is not defined" in str(e.value), "error should mention the undefined variable" + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_binary_operator(): + + @triton.jit + def kernel(): + 0 + "a" + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert "at 2:4:" in str(e.value), "error should point to the 0" + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_static_assert(): + + @triton.jit + def kernel(): + tl.static_assert(isinstance(0, tl.tensor)) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert isinstance(e.value, CompileTimeAssertionFailure) + assert e.value.__cause__ is None + assert "at 2:4:" in str(e.value), "error should point to the static_assert call" + assert "" not in str(e.value) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_unary_op(): + # Currently Triton can't evaluate `not` of a tuple at compile time. That's + # ok, but the error message needs to point to the correct spot. + @triton.jit + def kernel(): + not (0, 0) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert e.value.__cause__ is None + assert "at 2:4:" in str(e.value), "error should point to the `not`" + assert "" not in str(e.value) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_binary_op(): + + @triton.jit + def kernel(): + 1.0 << 1 + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert "at 2:4:" in str(e.value), "error should point to the 1.0" + assert "" not in str(e.value) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +# This has to be defined as a top-level function; jit'ed functions can't call +# nested functions. +@triton.jit +def nested_call(): + xyz # noqa + + +def test_err_in_nested_call(): + + @triton.jit + def kernel(): + # this is a comment to push nested_call() onto the next line + nested_call() + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + inner = e.value.__cause__ + outer = e.value + assert "at 2:4:" in str(inner), "error should point to xyz" + assert "" not in str(inner) + + assert "at 3:4" in str(outer), "error should point to the nested_call" + assert "" not in str(outer) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_builtin(): + + # The root error here comes from core.py. Make sure the stacktrace reflects + # this. + @triton.jit + def kernel(): + tl.expand_dims(None, -1) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + inner = e.value.__cause__ + outer = e.value + assert "/core.py" in '\n'.join(traceback.format_tb(inner.__traceback__)), "error should point inside core.py" + + assert "at 2:4:" in str(outer), "error should point to expand_dims call" + assert "" not in str(outer) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +@triton.jit +def two_returns(): + return tl.arange(0, 4) + return tl.arange(0, 8) + + +def test_two_returns_no_err(): + # This program is valid; `a` has shape (10,). + @triton.jit + def kernel(): + a = two_returns() + a + tl.arange(0, 4) # only works if we took the first return + + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +@triton.jit +def returns_branched_on_constexpr(N: tl.constexpr): + if N == 0: + return tl.arange(0, 4) + # Ideally this would work even without the `else`, but we're not that smart + # yet. + else: + return tl.arange(0, 8) + + +def test_returns_branched_on_constexpr(): + + @triton.jit + def kernel1(N: tl.constexpr): + a = returns_branched_on_constexpr(N) + a + tl.arange(0, 4) + + triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0})) + + @triton.jit + def kernel2(N: tl.constexpr): + a = returns_branched_on_constexpr(N) + a + tl.arange(0, 8) + + triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1})) + + +@triton.jit +def returns_branched_on_non_constexpr(N: int): + if N == 0: + return tl.arange(0, 4) + else: + return tl.arange(0, 8) + + +def test_returns_branched_on_non_constexpr(): + + @triton.jit + def kernel(N: int): + returns_branched_on_non_constexpr(N) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={})) + + try: + assert "at 2:4:" in str(e.value), "error should point to the function call" + assert "at 5:8:" in str(e.value.__cause__), "error should point to the second `return`" + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_power_of_two_shapes(): + + @triton.jit + def kernel(): + tl.arange(2, 7) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + assert str(e.value.__cause__) == "arange's range must be a power of 2" + + +def test_power_of_two_shapes_2(): + + @triton.jit + def kernel(): + tl.full((33, ), 0, dtype=tl.int64) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + assert str(e.value.__cause__) == "Shape element 0 must be a power of 2" + + +def test_captured_var_access(): + + CAPTURED = 42 + + @triton.jit + def kernel(): + a = CAPTURED # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + assert "CAPTURED is not defined" in str(e.value) + + +GLOBAL = 42 + + +def test_global_var_access(): + + @triton.jit + def kernel(): + a = GLOBAL # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + assert "global variable" in str(e.value) + + +CONSTEXPR_ANNOTATED_GLOBAL: tl.constexpr = 42 + + +def test_constexpr_annotated_global_var_access(): + + @triton.jit + def kernel(): + a = CONSTEXPR_ANNOTATED_GLOBAL # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +CONSTEXPR_GLOBAL = tl.constexpr(42) + + +def test_constexpr_global_var_access(): + + @triton.jit + def kernel(): + a = CONSTEXPR_GLOBAL # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +TYPE_ALIAS = tl.pointer_type(tl.int32) + + +def test_global_type_alias_access(): + + @triton.jit + def kernel(): + a = TYPE_ALIAS # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +def test_global_access_in_fn_default_arg(): + + @triton.jit + def kernel(a=GLOBAL): + pass + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={0: "i32"}, constants={})) diff --git a/third_party/mthreads/python/test/unit/language/test_conversions.py b/third_party/mthreads/python/test/unit/language/test_conversions.py new file mode 100644 index 000000000..183d3390c --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_conversions.py @@ -0,0 +1,356 @@ +# fmt: off + + +import os +import numpy as np +import torch +import pytest +import triton +import triton.language as tl + +pytestmark = pytest.mark.skip(reason="Skipping entire test file") + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + +def is_cuda(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda" + +def is_hip(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip" + +def is_on_mi300(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942') + +def matching_int(dtype): + if dtype.primitive_bitwidth == 8: + return torch.int8 + elif dtype.primitive_bitwidth == 16: + return torch.int16 + elif dtype.primitive_bitwidth == 32: + return torch.int32 + elif dtype.primitive_bitwidth == 64: + return torch.int64 + else: + raise ValueError('unsupported number of bits') + +@triton.jit +def type_convert_triton(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr): + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(src + idxs) + y = x.to(dst.dtype.element_ty, fp_downcast_rounding=rounding) + tl.store(dst + idxs, y) + + +def launch_type_convert_triton(src, src_dtype, dst_dtype, device, rounding=None, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device) + type_convert_triton[(src.shape[0] // BLOCK_SIZE,)](triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE) + return dst + + +@triton.jit +def exhaustive_populate(dst, offset, BLOCK_SIZE : tl.constexpr, force_odd : tl.constexpr, output_bits : tl.constexpr, max_repr : tl.constexpr): + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + vals = (idxs + offset).to(tl.uint32) + + # pseudorandom permutation: + multiplier = vals << 1 + multiplier += 3511 + vals *= multiplier + + if force_odd: + vals *= 2 + vals += 1 + + if (output_bits == 8): + vals &= 0xff + avals = vals & 0x7f + elif (output_bits == 16): + vals &= 0xffff + avals = vals & 0x7fff + elif (output_bits == 32): + avals = vals & 0x7fffffff + + vals = tl.where(avals <= max_repr, vals, 0) + + if (output_bits == 8): + vals = vals.to(tl.uint8) + elif (output_bits == 16): + vals = vals.to(tl.uint16) + + vals = vals.to(dst.dtype.element_ty, bitcast=True) + tl.store(dst + idxs, vals) + + +def launch_exhaustive_populate(dst_dtype, offset, numel, force_odd, output_bits, max_repr, device, BLOCK_SIZE=4096): + + assert(numel % BLOCK_SIZE == 0) + dst = torch.empty((numel,), dtype=matching_int(dst_dtype), device=device) + exhaustive_populate[(numel // BLOCK_SIZE,)](triton.reinterpret(dst, dst_dtype), offset, BLOCK_SIZE, force_odd, output_bits, max_repr) + # 0x80 in float8e4b8 or float8e5b16 represents inf/nan. We don't need to have that + # as input to the conversion kernels. + if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16: + dst = torch.where(dst == 0x80, 0, dst) + return dst + + +@triton.jit +def arbitrary_fp32_downcast(x, rounding : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + tl.static_assert(x.dtype == tl.float32, "input must be float32") + numbits_dst : tl.constexpr = 1 + exponent_bits + mantissa_bits + tl.static_assert((numbits_dst == 8) or (numbits_dst == 16), "numbits_dst must be 8 or 16") + + x = x.to(tl.uint32, bitcast=True) + + mantissa = (x & 0x7fffff) + exponent = ((x >> 23) & 0xff).to(tl.int32) + mantissa = tl.where(exponent == 0, mantissa, mantissa + 0x800000).to(tl.int32) + exponent = tl.where(exponent == 0, exponent, exponent - 1) + + sign = (x >> 31) + + exponent = exponent + exponent_bias - 127 + adjustment : tl.constexpr = 0.5 ** (23 - mantissa_bits) + mantissa = mantissa.to(tl.float32) * adjustment + + # make exponent nonnegative: + mantissa = tl.where(exponent > -16, mantissa, 0.0) # destination has fewer than 16 mantissa bits, so safe + exponent = tl.where(exponent > -16, exponent, 0) + mantissa = tl.where(exponent > -8, mantissa, mantissa * 0.00390625) + exponent = tl.where(exponent > -8, exponent, exponent + 8) + mantissa = tl.where(exponent > -4, mantissa, mantissa * 0.0625) + exponent = tl.where(exponent > -4, exponent, exponent + 4) + mantissa = tl.where(exponent > -2, mantissa, mantissa * 0.25) + exponent = tl.where(exponent > -2, exponent, exponent + 2) + mantissa = tl.where(exponent > -1, mantissa, mantissa * 0.5) + exponent = tl.where(exponent > -1, exponent, exponent + 1) + + if rounding == 'rtne': + # Bring the value to the range [2 ** 23, 2 ** 24] + # where the representable floats map exactly to integers. + # Addition has RTNE semantics. + mantissa += 0x800000 + # Bring the value back to the original range. + mantissa -= 0x800000 + mantissa = mantissa.to(tl.int32) + elif rounding == 'rtz': + mantissa = mantissa.to(tl.int32) + else: + raise ValueError('unrecognized rounding mode') + + # Reassemble output floating-point representation: + exponent = exponent.to(tl.uint32) + y = (sign << (exponent_bits + mantissa_bits)) + (exponent << mantissa_bits) + mantissa + if numbits_dst == 8: + y = y.to(tl.uint8) + elif numbits_dst == 16: + y = y.to(tl.uint16) + return y + + +@triton.jit +def downcast_emulated(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + tl.static_assert(src.dtype.element_ty == tl.float32, "src dtype must be float32") + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + idxs) + y = arbitrary_fp32_downcast(x, rounding, exponent_bits, mantissa_bits, exponent_bias) + y = y.to(dst.dtype.element_ty, bitcast=True) + tl.store(dst + idxs, y) + + +def launch_downcast_emulated(src, src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device) + downcast_emulated[(src.shape[0] // BLOCK_SIZE,)]( + triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias) + # 0x80 in float8e4b8 or float8e5b16 represents inf/nan. downcast_emulated kernel will + # convert -0. in higher precision to 0x80 and thus need to fix the result to 0. + if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16: + dst = torch.where(dst == 0x80, 0, dst) + return dst + + +@triton.jit +def upcast_emulated(src, dst, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + exponent_compensator : tl.constexpr = 2.0 ** (127 - exponent_bias) + + numbits_src : tl.constexpr = 1 + exponent_bits + mantissa_bits + tl.static_assert((numbits_src == 8) or (numbits_src == 16), "numbits_src must be 8 or 16") + tl.static_assert(dst.dtype.element_ty == tl.float32, "dst dtype must be float32") + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(src + idxs) + + if numbits_src == 8: + x = x.to(tl.uint8, bitcast=True) + elif numbits_src == 16: + x = x.to(tl.uint16, bitcast=True) + + x = x.to(tl.uint32) + + mantissa_mask : tl.constexpr = (1 << mantissa_bits) - 1 + exponent_mask : tl.constexpr = (1 << exponent_bits) - 1 + + mantissa = x & mantissa_mask + exponent = (x >> mantissa_bits) & exponent_mask + sign = (x >> (numbits_src - 1)) + + y = (sign << 31) | (exponent << 23) | (mantissa << (23 - mantissa_bits)) + y = y.to(tl.float32, bitcast=True) + y = y * exponent_compensator + + tl.store(dst + idxs, y) + + +def launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=torch.int32, device=device) + upcast_emulated[(src.shape[0] // BLOCK_SIZE,)](src, triton.reinterpret(dst, tl.float32), BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias) + return dst + + +def downcast_test(src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, max_repr, offset, device): + + src = launch_exhaustive_populate(src_dtype, offset << 24, 2**24, False, src_dtype.primitive_bitwidth, max_repr, device) + dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device, rounding=rounding) + src = launch_type_convert_triton(src, src_dtype, tl.float32, device=device) + + dst2 = launch_downcast_emulated(src, tl.float32, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device=device) + + dst = launch_upcast_emulated(dst, exponent_bits, mantissa_bits, exponent_bias, device=device) + dst2 = launch_upcast_emulated(dst2, exponent_bits, mantissa_bits, exponent_bias, device=device) + + if not (torch.equal(dst, dst2)): + print('Error!!!') + + dst = dst.cpu().detach().numpy() + dst2 = dst2.cpu().detach().numpy() + src = src.cpu().detach().numpy() + + print(src[dst != dst2][0]) + print(dst[dst != dst2][0]) + print(dst2[dst != dst2][0]) + print(hex(src.view(np.uint32)[dst != dst2][0])) + print(hex(dst.view(np.uint32)[dst != dst2][0])) + print(hex(dst2.view(np.uint32)[dst != dst2][0])) + print('') + raise ValueError('%d elements mismatch' % (dst != dst2).sum()) + + +def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bias, max_repr, device): + + numbits_src = exponent_bits + mantissa_bits + 1 + + src = launch_exhaustive_populate(src_dtype, 0, 65536, False, numbits_src, max_repr, device=device) + + dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device) + dst = launch_type_convert_triton(dst, dst_dtype, tl.float32, device=device) + + dst2 = launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device=device) + + assert(torch.equal(dst, dst2)) + + +@pytest.mark.parametrize("src_dtype, dst_dtype", [ + ('float16', 'float32'), + ('bfloat16', 'float32'), + + ('float8e5', 'float16'), + ('float8e5', 'bfloat16'), + ('float8e5', 'float32'), + + ('float8e4b15', 'float16'), + # ('float8e4b15', 'bfloat16'), # Unsupported conversion from f8E4M3B11FNUZ to bf16 + ('float8e4b15', 'float32'), + + ('float8e4nv', 'float16'), + ('float8e4nv', 'bfloat16'), + ('float8e4nv', 'float32'), + + ('float8e4b8', 'float32'), + ('float8e4b8', 'float16'), + + ('float8e5b16', 'float32'), + ('float8e5b16', 'float16'), +]) +def test_typeconvert_upcast(src_dtype, dst_dtype, device): + + if src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip("float8e4nv upcast tests only supported on NVGPU with compute capability 9.0+") + + if src_dtype in ('float8e4nv', 'float8e4b15') and is_hip(): + pytest.skip(f"{src_dtype} upcast tests not supported on ROCm") + + if src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_on_mi300()): + pytest.skip("{src_dtype} upcast tests only supported on AMDGPU MI300") + + # dtype : (exponent_bits, mantissa_bits, exponent_bias, max_repr) + stuff = { + 'float8e4b15': (4, 3, 15, 0x7e), + 'float8e4nv': (4, 3, 7, 0x7e), + 'float8e5': (5, 2, 15, 0x7b), + 'float8e4b8': (4, 3, 8, 0x7f), + 'float8e5b16': (5, 2, 16, 0x7f), + 'float16': (5, 10, 15, 0x7bff), + 'bfloat16': (8, 7, 127, 0x7f7f), + }[src_dtype] + + upcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), *stuff, device=device) + +@pytest.mark.parametrize("src_dtype, dst_dtype, rounding, max_repr", [ + ('float32', 'float16', 'rtne', 0x477fe000), + ('float32', 'float16', 'rtz', 0x477fe000), + ('float32', 'bfloat16', 'rtne', 0x7f7f0000), + ('float32', 'bfloat16', 'rtz', 0x7f7f0000), + ('float32', 'float8e5', 'rtne', 0x47600000), + ('float32', 'float8e5', 'rtz', 0x47600000), + ('float32', 'float8e4nv', 'rtne', 0x43e00000), + ('float32', 'float8e4b8', 'rtne', 0x43700000), + ('float32', 'float8e5b16', 'rtne', 0x47600000), + # ('float32', 'float8e4b15', 'rtne', 0x3fe00000), # Skip, no HW rtne conversion from f32 to f8e4b15 + + ('bfloat16', 'float8e5', 'rtne', 0x4760), + ('bfloat16', 'float8e4nv', 'rtne', 0x43e0), + + ('float16', 'float8e5', 'rtne', 0x7b00), + ('float16', 'float8e4nv', 'rtne', 0x5f00), + + ('bfloat16', 'float8e5b16', 'rtne', 0x4760), + ('bfloat16', 'float8e4b8', 'rtne', 0x4370), + + ('float16', 'float8e5b16', 'rtne', 0x7b00), + ('float16', 'float8e4b8', 'rtne', 0x5b80), +]) +def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device): + + if src_dtype != 'float32' and is_cuda() and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+") + + if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.get_device_capability(0) < (9, 0)): + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+") + + if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_on_mi300()): + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300") + + # dtype : (exponent_bits, mantissa_bits, exponent_bias) + stuff = { + 'float16': (5, 10, 15), + 'bfloat16': (8, 7, 127), + 'float8e5': (5, 2, 15), + 'float8e4b15': (4, 3, 15), + 'float8e4nv': (4, 3, 7), + 'float8e4b8': (4, 3, 8), + 'float8e5b16': (5, 2, 16), + }[dst_dtype] + + for i in range(256): + downcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), rounding, *stuff, max_repr, i, device=device) diff --git a/third_party/mthreads/python/test/unit/language/test_core.py b/third_party/mthreads/python/test/unit/language/test_core.py new file mode 100644 index 000000000..ba47f3982 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_core.py @@ -0,0 +1,5483 @@ +# flake8: noqa: F821,F841 +import itertools +import re +from typing import Optional, Union +import math +import textwrap +import tempfile + +import numpy as np +import pytest +import torch +import os +import inspect +from numpy.random import RandomState + +import triton +import triton.language as tl +from triton.runtime.jit import TensorWrapper, reinterpret + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def is_cuda(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def is_hip(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "hip" + + +def is_musa(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "musa" + + +int_dtypes = ['int8', 'int16', 'int32', 'int64'] +uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] +float_dtypes = ['float16', 'float32', 'float64'] +dtypes = int_dtypes + uint_dtypes + float_dtypes +dtypes_with_bfloat16 = dtypes + ['bfloat16'] +torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2'] +torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16'] + +# TODO: enable multiple cta cluster testing. +# num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1] +num_ctas_list = [1] + +GPU_DIALECT = "triton_gpu" +if is_interpreter(): + THREADS_PER_WARP = 1 +elif is_hip(): + THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size +else: + THREADS_PER_WARP = 32 + + +def _bitwidth(dtype: str) -> int: + # ex.: "int64" -> 64 + return int(re.search(r'(\d+)$', dtype).group(1)) + + +def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): + """ + Override `rs` if you're calling this function twice and don't want the same + result for both calls. + """ + if isinstance(shape, int): + shape = (shape, ) + if rs is None: + rs = RandomState(seed=17) + if dtype_str in int_dtypes + uint_dtypes: + iinfo = np.iinfo(getattr(np, dtype_str)) + low = iinfo.min if low is None else max(low, iinfo.min) + high = iinfo.max if high is None else min(high, iinfo.max) + dtype = getattr(np, dtype_str) + x = rs.randint(low, high, shape, dtype=dtype) + x[x == 0] = 1 # Workaround. Never return zero so tests of division don't error out. + return x + elif dtype_str and 'float8' in dtype_str: + x = rs.randint(20, 40, shape, dtype=np.int8) + return x + elif dtype_str in float_dtypes: + return rs.normal(0, 1, shape).astype(dtype_str) + elif dtype_str == 'bfloat16': + return (rs.normal(0, 1, shape).astype('float32').view('uint32') & np.uint32(0xffff0000)).view('float32') + elif dtype_str in ['bool', 'int1', 'bool_']: + return rs.normal(0, 1, shape) > 0.0 + else: + raise RuntimeError(f'Unknown dtype {dtype_str}') + + +def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torch.Tensor]: + ''' + Note: We need dst_type because the type of x can be different from dst_type. + For example: x is of type `float32`, dst_type is `bfloat16`. + If dst_type is None, we infer dst_type from x. + ''' + t = x.dtype.name + if t in uint_dtypes: + signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16" + x_signed = x.astype(getattr(np, signed_type_name)) + return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t)) + else: + if dst_type and 'float8' in dst_type: + return reinterpret(torch.tensor(x, device=device), getattr(tl, dst_type)) + if t == 'float32' and dst_type == 'bfloat16': + return torch.tensor(x, device=device).bfloat16() + return torch.tensor(x, device=device) + + +def torch_dtype_name(dtype) -> str: + if isinstance(dtype, triton.language.dtype): + return dtype.name + elif isinstance(dtype, torch.dtype): + # 'torch.int64' -> 'int64' + m = re.match(r'^torch\.(\w+)$', str(dtype)) + return m.group(1) + else: + raise TypeError(f'not a triton or torch dtype: {type(dtype)}') + + +def to_numpy(x): + if isinstance(x, TensorWrapper): + return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype))) + elif isinstance(x, torch.Tensor): + if x.dtype is torch.bfloat16: + return x.cpu().float().numpy() + return x.cpu().numpy() + else: + raise ValueError(f"Not a triton-compatible tensor: {x}") + + +def patch_kernel(template, to_replace): + if is_interpreter(): + local_namespace = {} + src = textwrap.dedent(inspect.getsource(template.fn)) + for k, v in to_replace.items(): + src = src.replace(k, v) + exec(src, globals(), local_namespace) + return local_namespace[template.fn.__name__] + else: + kernel = triton.JITFunction(template.fn) + for key, value in to_replace.items(): + kernel.src = kernel.src.replace(key, value) + return kernel + + +def check_cuda_or_hip(device): + # CUDA and HIP both use pytorch device 'cuda'. Other backends like Intel + # GPU do not. + if device not in ['cuda']: + pytest.skip("Only for cuda") + + +def check_type_supported(dtype, device): + ''' + skip test if dtype is not supported on the current device + ''' + if device in ['cuda']: + cc = torch.cuda.get_device_capability() + if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): + pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") + if cc[0] < 9 and dtype in {tl.float8e4nv, "float8e4nv", "float8_e4m3fn"}: + pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90") + if is_interpreter(): + if dtype in [tl.bfloat16, "bfloat16", torch.bfloat16]: + pytest.skip("bfloat16 is not supported in the interpreter") + + +class MfmaLayout: + + def __init__(self, version, warps_per_cta, instr_shape, is_transposed): + self.version = version + self.warps_per_cta = warps_per_cta + self.instr_shape = instr_shape + self.is_transposed = is_transposed + + def __str__(self): + return f"#{GPU_DIALECT}.amd_mfma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA = {self.warps_per_cta}, instrShape={self.instr_shape}, isTransposed = {str(self.is_transposed).lower()}}}>" + + +class WmmaLayout: + + def __init__(self, warps_per_cta): + self.warps_per_cta = warps_per_cta + + def __str__(self): + return f"#{GPU_DIALECT}.amd_wmma<{{warpsPerCTA = {self.warps_per_cta}}}>" + + +class MmaLayout: + + def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape): + self.version = version + self.warps_per_cta = warps_per_cta + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + self.instr_shape = instr_shape + + def __str__(self): + return f"#{GPU_DIALECT}.nvidia_mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" + + +class BlockedLayout: + + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): + self.sz_per_thread = size_per_thread + self.threads_per_warp = threads_per_warp + self.warps_per_cta = warps_per_cta + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +class SharedLayout: + + def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order): + self.vec = vec + self.per_phase = per_phase + self.max_phase = max_phase + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +def is_layout_applicable(layout) -> bool: + common_layouts = [BlockedLayout, SharedLayout] + if layout in common_layouts: + return True + elif is_cuda(): + return isinstance(layout, MmaLayout) + elif is_hip(): + target_arch = triton.runtime.driver.active.get_current_target().arch + if "gfx11" in target_arch: + # RDNA 3 + return isinstance(layout, WmmaLayout) + elif any(arch for arch in ["gfx8", "gfx9"] if arch in target_arch): + # CDNA 1, 2, 3 + return isinstance(layout, MfmaLayout) + else: + return False + else: + return True + + +def filter_layouts(layouts): + return [l for l in layouts if is_layout_applicable(l)] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) +def test_empty_kernel(dtype_x, device): + SIZE = 128 + + @triton.jit + def kernel(X, SIZE: tl.constexpr): + pass + + check_type_supported(dtype_x, device) + x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x) + kernel[(1, )](x, SIZE=SIZE, num_warps=4) + + +# generic test functions +def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda', num_ctas=1): + check_type_supported(dtype_x, device) # early return if dtype_x is not supported + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) + # inputs + x = numpy_random(SIZE, dtype_str=dtype_x) + if 'log' in expr: + x = np.abs(x) + 0.01 + # reference result + z_ref = eval(expr if numpy_expr is None else numpy_expr) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_x) + kernel[(1, )](Z=z_tri, X=x_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) + # compare + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + + +def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: + """ + Given two dtype strings, returns the numpy dtype Triton thinks binary + operations on the two types should return. Returns None if the return value + matches numpy. This is generally needed because Triton and pytorch return + narrower floating point types than numpy in mixed operations, and because + Triton follows C/C++ semantics around mixed signed/unsigned operations, and + numpy/pytorch do not. + """ + overrides = { + ('float16', 'int16'): np.float16, + ('float16', 'int32'): np.float16, + ('float16', 'int64'): np.float16, + ('float16', 'uint16'): np.float16, + ('float16', 'uint32'): np.float16, + ('float16', 'uint64'): np.float16, + ('int8', 'uint8'): np.uint8, + ('int8', 'uint16'): np.uint16, + ('int8', 'uint32'): np.uint32, + ('int8', 'uint64'): np.uint64, + ('int16', 'uint16'): np.uint16, + ('int16', 'uint32'): np.uint32, + ('int16', 'uint64'): np.uint64, + ('int32', 'uint32'): np.uint32, + ('int32', 'uint64'): np.uint64, + ('int64', 'uint64'): np.uint64, + } + key = (a, b) if a < b else (b, a) + return overrides.get(key) + + +def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, + y_low=None, y_high=None, test_broadcast=True): + check_type_supported(dtype_x, device) # early return if dtype_x is not supported + check_type_supported(dtype_y, device) + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_broadcast_lhs(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_broadcast_rhs(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + replacements = {'GENERATE_TEST_HERE': expr} + kernel = patch_kernel(kernel, replacements) + kernel_broadcast_lhs = patch_kernel(kernel_broadcast_lhs, replacements) + kernel_broadcast_rhs = patch_kernel(kernel_broadcast_rhs, replacements) + + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high) + if mode_x == 'nan': + x[:] = float('nan') + if mode_y == 'nan': + y[:] = float('nan') + + def do_test(x, y, kernel_fn): + # reference result + z_ref = eval(expr if numpy_expr is None else numpy_expr) + dtype_z = _binary_op_dtype_override(dtype_x, dtype_y) + if dtype_z is not None: + z_ref = z_ref.astype(dtype_z) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + y_tri = to_triton(y, device=device, dst_type=dtype_y) + z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) + kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) + err_msg = f"{expr}, {kernel_fn.__name__}" + np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=1e-3, rtol=0.01) + + do_test(x, y, kernel) + if test_broadcast: + do_test(x[:1].reshape(()), y, kernel_broadcast_lhs) + do_test(x, y[:1].reshape(()), kernel_broadcast_rhs) + + +def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: + # The result of x % y is ill-conditioned if x % y is much smaller than x. + # pytorch/CUDA has slightly different (probably better) rounding on + # remainders than stock LLVM. We currently don't expect to match it + # bit-for-bit. + return (dtype_x, dtype_y) in [ + ('int32', 'bfloat16'), + ('int32', 'float16'), + ('int32', 'float32'), + ('int64', 'bfloat16'), + ('int64', 'float16'), + ('int64', 'float32'), + ('int64', 'float64'), + ('uint16', 'bfloat16'), + ('uint16', 'float16'), + ('uint16', 'float32'), + ('uint32', 'bfloat16'), + ('uint32', 'float16'), + ('uint32', 'float32'), + ('uint64', 'bfloat16'), + ('uint64', 'float16'), + ('uint64', 'float32'), + ('uint64', 'float64'), + ] + + +def test_dtype_codegen(): + for dtype in dtypes_with_bfloat16: + full_name = f"triton.language.{dtype}" + assert repr(eval(full_name)) == full_name + + +# --------------- +# test binary ops +# --------------- + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['+', '-', '*', '/', '%'] + for dtype_x in dtypes_with_bfloat16 + for dtype_y in dtypes_with_bfloat16 +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f' x {op} y' + if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + numpy_expr = 'np.fmod(x, y)' + elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', + 'bfloat16'): + # Triton promotes 16-bit floating-point / and % to 32-bit because there + # are no native div or FRem operations on float16. Since we have to + # convert anyway, we may as well take the accuracy bump. + numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)' + elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): + with pytest.raises(AssertionError, match="Not equal to tolerance"): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + elif (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or + (dtype_x in uint_dtypes and dtype_y in int_dtypes))): + with pytest.raises(triton.TritonError, match='Cannot use .* because they have different signedness'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + else: + _test_binary( + dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, + # fails with values where fmod(x, y) is roughly zero, but happens to + # pass with the random values chosen for non-broadcast tests + test_broadcast=(op != "%")) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) +def test_addptr(dtype, order, device): + check_type_supported(dtype, device) + + @triton.jit + def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): + offs = tl.arange(0, SIZE) + if ORDER == 0: + tl.store(y + offs, tl.load(x + offs)) + else: + tl.store(offs + y, tl.load(offs + x)) + + SIZE = 1024 + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + x_tri = to_triton(x, dst_type=dtype, device=device) + y_tri = to_triton(y, dst_type=dtype, device=device) + y = x + kernel[ + 1, + ](x_tri, y_tri, order, SIZE) + np.testing.assert_allclose(y, to_numpy(y_tri)) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y", [ # + (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes +] + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_floordiv(dtype_x, dtype_y, num_ctas, device): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + expr = 'x // y' + numpy_expr = '((x - np.fmod(x, y)) / y)' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +def test_unsigned_name_mangling(device): + # Test that uint32 and int32 are mangled differently by the compiler + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(O1, O2, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + out1 = tl.abs(x) # uint32 -> nop + out2 = tl.abs(-y) # int32 -> should have an effect + tl.store(O1 + off, out1) + tl.store(O2 + off, out2) + + dtype_x = 'uint32' + dtype_y = 'int32' + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs) + # reference result + expect = (np.abs(x), np.abs(-y)) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + y_tri = to_triton(y, device=device, dst_type=dtype_y) + actual = tuple(to_triton(np.empty_like(e), device=device) for e in expect) + kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4) + + # Bitwise op, so expect exact equality + assert (expect[0] == to_numpy(actual[0])).all() + assert (expect[1] == to_numpy(actual[1])).all() + + +# test bitwise ops +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['&', '|', '^'] + for dtype_x in dtypes + dtypes_with_bfloat16 + for dtype_y in dtypes + dtypes_with_bfloat16 +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + if 'float' in dtype_x + dtype_y: + # The CompilationError must have been caused by a C++ exception with this text. + with pytest.raises(triton.TritonError, match='invalid operands of type'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device, num_ctas=num_ctas) + else: + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['<<', '>>'] + for dtype_x in int_dtypes + uint_dtypes + for dtype_y in int_dtypes + uint_dtypes +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_shift_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f'x {op} y' + bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y)) + if dtype_x.startswith('int'): + dtype_z = f'int{bw}' + else: + dtype_z = f'uint{bw}' + numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, y_low=0, y_high=bw) + + +# --------------- +# test compare ops +# --------------- +ops = ['==', '!=', '>', '<', '>=', '<='] + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize( + "dtype_x, dtype_y, op, mode_x, mode_y", + # real + [(dtype_x, dtype_y, op, 'real', 'real') for op in ops for dtype_x in dtypes for dtype_y in dtypes] + # NaNs + + [('float32', 'float32', op, mode_x, mode_y) + for op in ops + for mode_x, mode_y in [('nan', 'real'), ('real', 'nan'), ('nan', 'nan')]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device, num_ctas=num_ctas) + + +# --------------- +# test broadcast +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16) +def test_broadcast(dtype, device): + check_type_supported(dtype, device) + + @triton.jit + def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :]) + y = tl.load(y_ptr + offset2) + _, y_broadcasted = tl.broadcast(x, y) + tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted) + + M = 32 + N = 64 + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype, rs=rs) + y = numpy_random(N, dtype_str=dtype, rs=rs) + _, y_broadcasted_np = np.broadcast_arrays(x, y) + + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype) + + broadcast_kernel[(1, )](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) + assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() + + +# ---------- +# test slice +# ---------- + + +@pytest.mark.interpreter +def test_slice(device): + + @triton.jit + def slice_kernel(XBLOCK: tl.constexpr): + data = tl.arange(0, XBLOCK) + tl.static_assert(data.shape == [XBLOCK]) + + t = data[None, :] + tl.static_assert(t.shape == [1, XBLOCK]) + + t = data[None, :, None] + tl.static_assert(t.shape == [1, XBLOCK, 1]) + + scalar = tl.full([], 1, tl.int32) + tl.static_assert(scalar.shape == []) + + t = scalar[None] + tl.static_assert(t.shape == [1]) + + t = scalar[None, None] + tl.static_assert(t.shape == [1, 1]) + + slice_kernel[(1, )](XBLOCK=32) + + +# ------------------ +# test invalid slice +# ------------------ + + +@pytest.mark.interpreter +def test_invalid_slice(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + dst[10:] + + with pytest.raises(triton.TritonError, match='unsupported tensor index'): + _kernel[(1, )](dst=dst) + + +# ---------------- +# test expand_dims +# ---------------- +@pytest.mark.interpreter +def test_expand_dims(device): + + @triton.jit + def expand_dims_kernel(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 0) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, 1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -2) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, (0, -1)) + tl.static_assert(t.shape == [1, N, 1]) + + t = tl.expand_dims(offset1, (0, 1, 3)) + tl.static_assert(t.shape == [1, 1, N, 1]) + + t = tl.expand_dims(offset1, (-4, 2, -1)) + tl.static_assert(t.shape == [1, N, 1, 1]) + + t = tl.expand_dims(offset1, (3, 1, 2)) + tl.static_assert(t.shape == [N, 1, 1, 1]) + + scalar = tl.sum(offset1) + tl.static_assert(scalar.shape == []) + t = tl.expand_dims(scalar, 0) + tl.static_assert(t.shape == [1]) + + t = tl.expand_dims(scalar, -1) + tl.static_assert(t.shape == [1]) + + # N is a scalar that's not even a tl.tensor -- this should work too. + t = tl.expand_dims(N, -1) + tl.static_assert(t.shape == [1]) + + N = 32 + dummy_tensor = torch.empty((), device=device) + expand_dims_kernel[(1, )](dummy_tensor, N) + + +@pytest.mark.interpreter +def test_expand_dims_error_cases(device): + + @triton.jit + def dim_out_of_range1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, -2) + t = tl.expand_dims(offset1, -3) + + @triton.jit + def dim_out_of_range2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 1) + t = tl.expand_dims(offset1, 2) + + @triton.jit + def dim_out_of_range3(dummy, N: tl.constexpr): + offset1 = tl.arange(0, 1) + scalar = tl.sum(offset1) + + t = tl.expand_dims(scalar, 1) + + @triton.jit + def duplicate_dim1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, 0)) + + @triton.jit + def duplicate_dim2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, -3)) + + N = 32 + dummy_tensor = torch.empty((), device=device) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range1[(1, )](dummy_tensor, N) + assert "invalid axis -3" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range2[(1, )](dummy_tensor, N) + assert "invalid axis 2" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range3[(1, )](dummy_tensor, N) + assert "invalid axis 1" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + duplicate_dim1[(1, )](dummy_tensor, N) + assert re.search(r"duplicate axes, normalized axes = \[0, 0\]", str(exc_info.value.__cause__)) + + with pytest.raises(triton.TritonError) as exc_info: + duplicate_dim2[(1, )](dummy_tensor, N) + assert re.search(r"duplicate axes, normalized axes = \[0, 0\]", str(exc_info.value.__cause__)) + + +# ---------------------------- +# test invalid program id axis +# ---------------------------- +@pytest.mark.interpreter +def test_invalid_pid_axis(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + pid = tl.program_id(20) + + with pytest.raises(triton.TritonError) as exc_info: + _kernel[(1, )](dst) + assert re.search(r"program_id axis must be 0, 1, or 2 but got 20", str(exc_info.value.__cause__)) + + +# --------------- +# test where +# --------------- +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where(dtype, num_ctas, device): + select_ptrs = False + if dtype == "*int32": + dtype = "int64" + select_ptrs = True + check_type_supported(dtype, device) + + @triton.jit + def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + TEST_POINTERS: tl.constexpr, TEST_SCALAR_POINTERS: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + decide = tl.load(cond_ptr + offsets, mask=mask) + if TEST_SCALAR_POINTERS: + ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr) + output = tl.load(ptr + offsets, mask=mask) + else: + if TEST_POINTERS: + a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t) + b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t) + else: + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + output = tl.where(decide, a, b) + tl.store(output_ptr + offsets, output, mask=mask) + + SIZE = 1_000 + rs = RandomState(17) + cond = numpy_random(SIZE, 'bool', rs) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + z = np.where(cond, x, y) + + cond_tri = to_triton(cond, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device=device, dst_type=dtype) + + grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']), ) + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=False, num_ctas=num_ctas) + assert (z == to_numpy(z_tri)).all() + if select_ptrs: + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=True) + z = np.where(cond[0], x, y) + assert (z == to_numpy(z_tri)).all() + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where_broadcast(num_ctas, device): + + @triton.jit + def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + + mask = tl.load(cond_ptr + yoffsets) + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + + @triton.jit + def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + mask = 0 + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + + SIZE = 32 + dtype = 'float32' + rs = RandomState(17) + x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs) + mask = numpy_random(SIZE, 'bool', rs=rs) + z = np.where(mask, x, 0) + cond_tri = to_triton(mask, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype) + z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device=device, dst_type=dtype) + where_kernel[(1, )](cond_tri, x_tri, z_tri, SIZE) + assert (z == to_numpy(z_tri)).all() + where_scalar_condition[(1, )](x_tri, z_tri, SIZE, num_ctas=num_ctas) + z = np.where(0, x, 0) + assert (z == to_numpy(z_tri)).all() + + +# --------------- +# test unary ops +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, expr", + [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x') + for dtype_x in int_dtypes]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_unary_op(dtype_x, expr, num_ctas, device): + _test_unary(dtype_x, expr, device=device, num_ctas=num_ctas) + + +# ---------------- +# test math ops +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, expr, x", + [(dtype_x, expr, x) + for dtype_x in ["float32", "float64"] + for expr in ['exp', 'log', 'cos', 'sin', 'exp2', 'log2', 'sqrt', 'floor', 'ceil'] + for x in ['x', '3.0']]) +def test_math_op(dtype_x, expr, x, device): + _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) +def test_math_erf_op(dtype, device): + check_type_supported(dtype, device) + SIZE = 128 + + if is_musa(): + # muDNN dose not support double, thus use cpu as ref here. + device = "cpu" + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = tl.math.erf(x) + tl.store(Z + off, z) + + torch_dtype = torch.float32 if dtype == "float32" else torch.float64 + x = torch.randn(SIZE, dtype=torch_dtype, device=device) + z_ref = torch.erf(x) + z_tri = torch.zeros_like(x) + if is_musa(): + x_musa = x.musa() + z_tri_musa = z_tri.musa() + kernel[(1, )](z_tri_musa, x_musa, SIZE=SIZE, num_warps=4) + z_tri = z_tri.to(device) + else: + kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_tri, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) +def test_math_fma_op(dtype, device): + check_type_supported(dtype, device) + SIZE = 128 + + if is_musa(): + # muDNN dose not support double, thus use cpu as ref here. + device = "cpu" + + @triton.jit + def kernel(Z, X, Y, W, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + w = tl.load(W + off) + z = tl.math.fma(x, y, w) + tl.store(Z + off, z) + + torch_dtype = torch.float32 if dtype == "float32" else torch.float64 + x = torch.randn(SIZE, dtype=torch_dtype, device=device) + y = torch.randn(SIZE, dtype=torch_dtype, device=device) + w = torch.randn(SIZE, dtype=torch_dtype, device=device) + z_ref = x * y + w + z_tri = torch.zeros_like(x) + if is_musa(): + z_tri_musa = z_tri.musa() + x_musa = x.musa() + y_musa = y.musa() + w_musa = w.musa() + kernel[(1, )](z_tri_musa, x_musa, y_musa, w_musa, SIZE=SIZE, num_warps=4) + z_tri = z_tri_musa.to(device) + else: + kernel[(1, )](z_tri, x, y, w, SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_tri, z_ref) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("expr", ["tl.math.fdiv(x, y)", "tl.math.div_rn(x, y)"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_math_divide_op(expr, num_ctas, device): + numpy_expr = "x / y" + dtype = "float32" + _test_binary(dtype, dtype, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +# ------------- +# test precise math +# ------------- +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("expr_prec, expr_ref", + [('tl.math.sqrt_rn(x)', 'tl.math.sqrt(x.to(tl.float64)).to(tl.float32)'), + ('tl.math.div_rn(x,y)', '(x.to(tl.float64) / y.to(tl.float64)).to(tl.float32)')]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_precise_math(expr_prec, expr_ref, num_ctas, device): + + @triton.jit + def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + prec = PREC_CALC + ref = REF_CALC + tl.store(OUT + tl.arange(0, BLOCK), prec) + tl.store(OUT_REF + tl.arange(0, BLOCK), ref) + + shape = (128, ) + out = torch.zeros(shape, dtype=torch.float32, device=device) + out_ref = torch.zeros(shape, dtype=torch.float32, device=device) + + x = torch.randn(shape, dtype=torch.float32, device=device) + y = torch.randn(shape, dtype=torch.float32, device=device) + + if (expr_prec.count('sqrt') > 0): + x = torch.abs(x) + + if (expr_prec.count('div') > 0): + y += 1e-6 + + kernel = patch_kernel(kernel, {'PREC_CALC': expr_prec, 'REF_CALC': expr_ref}) + + kernel[(1, )](x, y, out, out_ref, BLOCK=shape[0], num_ctas=num_ctas) + assert torch.all(out == out_ref) # bitwise exact + + +# ---------------- +# test abs +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) +def test_abs(dtype_x, device): + _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4nv, tl.float8e5]) +def test_abs_fp8(in_dtype, device): + if is_hip(): + pytest.skip('test_abs_fp8 not supported on HIP.') + + @triton.jit + def abs_kernel(X, Z, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = tl.abs(x) + tl.store(Z + off, z) + + f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device=device) + # f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan + all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width + f8_tensor[all_exp_ones] = 0 + f8 = triton.reinterpret(f8_tensor, in_dtype) + n_elements = f8_tensor.numel() + out_f8 = torch.empty_like(f8_tensor) + abs_kernel[(1, )](f8, triton.reinterpret(out_f8, in_dtype), n_elements) + + f32_tensor = convert_float_to_float32(f8_tensor, in_dtype) + expect = f32_tensor.abs() + actual_f8 = convert_float_to_float32(out_f8, in_dtype) + torch.testing.assert_close(actual_f8, expect, equal_nan=True) + + +# ---------------- +# test passing shapes as individual params rather than tuples +# ---------------- + + +@pytest.mark.interpreter +def test_shapes_as_params(device): + + @triton.jit + def kernel(): + a = tl.arange(0, 32).expand_dims(-1).broadcast_to(32, 32) + tl.static_assert(a.shape == [tl.constexpr(32), tl.constexpr(32)]) + + a = tl.arange(0, 32).reshape(4, 8).permute(1, 0) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4)]) + + a = tl.arange(0, 32).reshape(4, 8).reshape(32) + tl.static_assert(a.shape == [tl.constexpr(32)]) + + a = tl.arange(0, 64).reshape(2, 4, 8).trans(2, 1, 0) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + + a = tl.arange(0, 64).view(2, 4, 8) + tl.static_assert(a.shape == [tl.constexpr(2), tl.constexpr(4), tl.constexpr(8)]) + + kernel[(1, )]() + + +# ---------------- +# test transpose +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) +def test_transpose(dtype_x, device): + check_type_supported(dtype_x, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + off2d = off[None, :] + (tl.arange(0, 2) * SIZE)[:, None] + x = tl.load(X + off2d) + z = x.T + tl.store(Z + off2d.T, z) + + x = numpy_random([SIZE, 2], dtype_str=dtype_x) + z_ref = x.T + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x) + kernel[(1, )](z_tri, x_tri, SIZE=SIZE) + np.testing.assert_allclose(z_ref, to_numpy(z_tri)) + + +# ---------------- +# test indexing +# ---------------- + + +def make_ptr_str(name, shape): + rank = len(shape) + offsets = [] + stride = 1 + for i in reversed(range(rank)): + idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)]) + offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}'] + stride *= shape[i] + return f"{name} + {' + '.join(offsets)}" + + +# TODO: handle `%4 = triton_gpu.convert_layout %3 : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`` +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.parametrize("expr, dtype_str", [(f'x[{s}]', d) + for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] + for d in ['int32', 'uint32', 'uint16']]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_index1d(expr, dtype_str, num_ctas, device): + rank_x = expr.count(':') + rank_y = expr.count(',') + 1 + shape_x = [32 for _ in range(rank_x)] + shape_z = [32 for _ in range(rank_y)] + shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)] + shape_z_dim_mismatch = [64 for _ in range(rank_y)] + + # Triton kernel + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + m = tl.arange(0, SIZE) + n = tl.arange(0, SIZE) + x = tl.load(X_PTR_EXPR) + z = GENERATE_TEST_HERE + tl.store(Z_PTR_EXPR, z) + + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + 'GENERATE_TEST_HERE': expr, + } + return patch_kernel(kernel, to_replace) + + kernel_match = generate_kernel(shape_x, shape_z) + kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch) + kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch) + + # torch result + x = numpy_random(shape_x, dtype_str=dtype_str) + y = np.zeros(shape_z, dtype=getattr(np, dtype_str)) + z_ref = eval(expr) + y + # triton result + z_tri = to_triton(np.empty_like(z_ref), device=device) + x_tri = to_triton(x, device=device) + kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) + # compare + assert (z_ref == to_numpy(z_tri)).all() + + def catch_compilation_error(kernel): + try: + kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0], num_ctas=num_ctas) + except triton.CompilationError as e: + np.testing.assert_(True) + except BaseException: + np.testing.assert_(False) + + catch_compilation_error(kernel_dim_mismatch) + catch_compilation_error(kernel_rank_mismatch) + + +# --------------- +# test tuples +# --------------- + + +@triton.jit +def tuples_fn(a, b): + return a + b, \ + a - b, \ + a * b + + +@pytest.mark.interpreter +def test_tuples(device): + + @triton.jit + def with_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = tuples_fn(x, y) + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + @triton.jit + def without_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = x + y, x - y, x * y + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + x = torch.tensor([1.3], device=device, dtype=torch.float32) + y = torch.tensor([1.9], device=device, dtype=torch.float32) + a_tri = torch.tensor([0], device=device, dtype=torch.float32) + b_tri = torch.tensor([0], device=device, dtype=torch.float32) + c_tri = torch.tensor([0], device=device, dtype=torch.float32) + for kernel in [with_fn, without_fn]: + kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1) + a_ref, b_ref, c_ref = x + y, x - y, x * y + assert a_tri == a_ref + assert b_tri == b_ref + assert c_tri == c_ref + + +@triton.jit(noinline=True) +def noinline_simple_fn(x, y, Z): + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_graph_fn1(x): + return x + 1 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn2(y): + return y + 2 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn(x, y, Z): + t0 = noinline_call_graph_fn1(x) + t1 = noinline_call_graph_fn2(y) + z = t0 + t1 + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_shared_fn(x, y, Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + x + y + tl.store(Z + offs, z) + + +@triton.jit(noinline=True) +def noinline_dynamic_fn(x, y, Z): + if x >= 1: + x = noinline_call_graph_fn1(x) + else: + x = noinline_call_graph_fn2(x) + if y >= 2: + y = noinline_call_graph_fn2(y) + else: + y = noinline_call_graph_fn1(y) + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_multi_values_fn(x, y): + return x + 1, y + 2 + + +@triton.jit(noinline=True) +def noinline_multi_values_fn(x, y, Z): + x, y = noinline_call_multi_values_fn(x, y) + z = x + y + tl.store(Z, z) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) +def test_noinline(mode, device): + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + GENERATE_TEST_HERE(x, y, Z) + + func_name = f'noinline_{mode}_fn' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': func_name}) + x = torch.tensor([1.0], device=device, dtype=torch.float32) + y = torch.tensor([2.0], device=device, dtype=torch.float32) + if mode == "shared": + z = torch.ones((16, 16), device=device, dtype=torch.float32) + else: + z = torch.tensor([0.0], device=device, dtype=torch.float32) + kernel[(1, )](x, y, z, num_warps=1) + if mode == "simple": + assert torch.equal(z, x + y) + elif mode == "call_graph" or mode == "dynamic" or mode == "multi_values": + assert torch.equal(z, x + 1 + y + 2) + elif mode == "shared": + ref = torch.full((16, 16), 16, device=device, dtype=torch.float32) + assert torch.equal(z, ref + x + y) + + +# --------------- +# test atomics +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize( + "op, dtype_x_str, mode, sem", + itertools.chain.from_iterable([[ + ('add', 'float16', mode, sem), + ('add', 'uint32', mode, sem), + ('add', 'int32', mode, sem), + ('add', 'float32', mode, sem), + ('add', 'uint64', mode, sem), + ('add', 'int64', mode, sem), + ('add', 'float64', mode, sem), + ('max', 'uint32', mode, sem), + ('max', 'int32', mode, sem), + ('max', 'float32', mode, sem), + ('max', 'uint64', mode, sem), + ('max', 'int64', mode, sem), + ('max', 'float64', mode, sem), + ('min', 'uint32', mode, sem), + ('min', 'int32', mode, sem), + ('min', 'float32', mode, sem), + ('min', 'uint64', mode, sem), + ('min', 'int64', mode, sem), + ('min', 'float64', mode, sem), + ] + for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos'] + for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']])) +def test_atomic_rmw(op, dtype_x_str, mode, sem, device): + if is_interpreter(): + if dtype_x_str == 'float16': + pytest.skip("Only test atomic float16 ops on GPU") + + n_programs = 5 + + # triton kernel + @triton.jit + def kernel(X, Z): + pid = tl.program_id(0) + x = tl.load(X + pid) + old = GENERATE_TEST_HERE + tl.static_assert(old.dtype == x.dtype) + + sem_arg = sem if sem is None else f'"{sem}"' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'}) + numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op] + max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min + min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max + neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op] + + # triton result + rs = RandomState(17) + x = np.array([2**i for i in range(n_programs)], dtype=getattr(np, dtype_x_str)) + if mode == 'all_neg': + x = -np.abs(x) + if mode == 'all_pos': + x = np.abs(x) + if mode == 'min_neg': + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = -np.max(np.abs(x)) - 1 + if mode == 'max_pos': + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = np.max(np.abs(x)) + 1 + x_tri = to_triton(x, device=device) + + z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device) + h = kernel[(n_programs, )](x_tri, z_tri) + # torch result + z_ref = numpy_op(x).astype(getattr(np, dtype_x_str)) + # compare + exact = op not in ['add'] + if exact: + assert z_ref.item() == to_numpy(z_tri).item() + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + sem_str = "acq_rel" if sem is None else sem + if not is_cuda(): + return + + assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_rmw_predicate(num_ctas, device): + + @triton.jit + def kernel(X): + val = tl.program_id(0) + if val < 64: + tl.atomic_max(X, val) + + x = torch.zeros((1, ), device=device, dtype=torch.int32) + kernel[(4096, )](x, num_ctas=num_ctas) + assert x.item() == 63 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, axis, num_ctas", [(shape, axis, num_ctas) + for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] + for axis in [0, 1] + for num_ctas in num_ctas_list]) +def test_tensor_atomic_rmw(shape, axis, num_ctas, device): + shape0, shape1 = shape + # triton kernel + + @triton.jit + def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :]) + z = tl.sum(x, axis=AXIS) + if AXIS == 1: + tl.atomic_add(Z + off0, z) + else: + tl.atomic_add(Z + off1, z) + + rs = RandomState(17) + x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs) + # reference result + z_ref = np.sum(x, axis=axis, keepdims=False) + # triton result + x_tri = to_triton(x, device=device) + z_shape = (shape0, ) if axis == 1 else (shape1, ) + z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device) + kernel[(1, )](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_rmw_block(num_ctas, device): + shape = (8, 8) + + @triton.jit + def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + offs = off0[:, None] * SHAPE1 + off1[None, :] + val = offs.to(tl.float32) + x = X + offs + tl.atomic_min(x, val) + + x = torch.ones((8, 8), device=device, dtype=torch.float32) + kernel[(2, )](x, shape[0], shape[1], num_ctas=num_ctas) + assert torch.min(x).item() == 0.0 + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_cas(sem, num_ctas, device): + # 1. make sure that atomic_cas changes the original value (Lock) + @triton.jit + def change_value(Lock): + tl.atomic_cas(Lock, 0, 1) + + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + change_value[(1, )](Lock) + + assert (Lock[0] == 1) + + # 2. only one block enters the critical section + @triton.jit + def serialized_add(data, Lock, SEM: tl.constexpr): + ptrs = data + tl.arange(0, 128) + while tl.atomic_cas(Lock, 0, 1, SEM) == 1: + pass + + tl.store(ptrs, tl.load(ptrs) + 1.0) + + # release lock + tl.atomic_xchg(Lock, 0) + + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + data = torch.zeros((128, ), device=device, dtype=torch.float32) + ref = torch.full((128, ), 2000.0) + h = serialized_add[(2000, )](data, Lock, SEM=sem, num_ctas=num_ctas) + sem_str = "acq_rel" if sem is None else sem + np.testing.assert_allclose(to_numpy(data), to_numpy(ref)) + if not is_cuda(): + return + assert f"atom.global.{sem_str}" in h.asm["ptx"] + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_cas(sem, num_ctas, device): + + @triton.jit + def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + t1 = tl.full((BLOCK_SIZE, ), 0, dtype=tl.int64) + t2 = tl.full((BLOCK_SIZE, ), 2, dtype=tl.int64) + tl.atomic_cas(X + offsets, t1, t2, sem=sem) + + X = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], device=device, dtype=torch.int64) + Y = torch.tensor([2, 1, 2, 1, 2, 1, 2, 1], device=device, dtype=torch.int64) + + change_value[(2, )](X, 4, sem) + assert (torch.equal(X, Y)) + + +# --------------- +# test cast +# --------------- + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", + [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ + ('float32', 'bfloat16', False, 1024), + ('bfloat16', 'float32', False, 1024), + ('float32', 'int32', True, 1024), + ('float32', 'int1', False, 1024), + ('int8', 'bfloat16', False, 1024), + ] + [(f'uint{x}', f'int{x}', True, 1024) + for x in [8, 16, 32, 64]] + [(f'int{x}', f'uint{x}', True, 1024) + for x in [8, 16, 32, 64]] + + (([(dtype_x, dtype_z, False, size) + for dtype_x in torch_float8_dtypes + for dtype_z in ["float16", "float32", "bfloat16"] + for size in [1024, 32]] # + + [(dtype_x, dtype_z, False, size) + for dtype_z in torch_float8_dtypes + for dtype_x in ["float16", "float32", "bfloat16"] + for size in [1024, 32]]) if torch.__version__ >= "2.1" else [])) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): + # CUDA: bfloat16 on cc < 80 will not be tested + # Interpreter: Only bfloat16 <-> float32 is supported + if not is_interpreter() or \ + (is_interpreter() and not ((dtype_z == 'bfloat16' and dtype_x == 'float32') + or (dtype_z == 'float32' and dtype_x == 'bfloat16'))): + check_type_supported(dtype_x, device) + check_type_supported(dtype_z, device) + + if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"): + pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.') + + torch.manual_seed(0) + # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. + if dtype_x.startswith('bfloat'): + x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device) + elif dtype_x.startswith('float8'): + x_tri = torch.randn(size, dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_x)) + else: + x = numpy_random(size, dtype_str=dtype_x, low=-10, high=10) * 10 + # Triton clamps negative values to zero, while numpy wraps around + # intmax, so avoid negatives for now. + # TODO: figure out which one should actually be happening, and test it + if dtype_z in uint_dtypes: + x = np.absolute(x) + x_tri = to_triton(x, device=device) + if 'float' in dtype_z and 'float' in dtype_x: + # make sure we use values that can be represented in both types + x_tri = x_tri.to(getattr(torch, dtype_z)).to(getattr(torch, dtype_x)) + # triton kernel + + @triton.jit + def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr): + x_ptr = X + tl.arange(0, SIZE) + z_ptr = Z + tl.arange(0, SIZE) + x = tl.load(x_ptr) + + # Depending on the value of ARG_HASH (a "random" number determined by + # the test parameters), spell the cast one of three different ways. + if ARG_HASH % 3 == 0: + z = x.to(Z.dtype.element_ty, bitcast=BITCAST) + elif ARG_HASH % 3 == 1: + z = x.cast(Z.dtype.element_ty, bitcast=BITCAST) + else: + z = tl.cast(x, Z.dtype.element_ty, bitcast=BITCAST) + + tl.store(z_ptr, z) + + # "Random" number used inside the kernel to determine how we spell the cast. + # This way we don't have to increase the number of tests. + arg_hash = hash((dtype_x, dtype_z, bitcast, size, num_ctas)) + + dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_' + # triton result + if dtype_z.startswith('bfloat'): + z_tri = torch.empty((size, ), dtype=getattr(torch, dtype_z), device=device) + elif dtype_z.startswith('float8'): + z_tri = torch.empty((size, ), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z)) + else: + z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) + kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1, num_ctas=num_ctas) + # torch result + if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith( + 'float8') or dtype_x.startswith('float8'): + assert bitcast is False + z_ref = x_tri.to(z_tri.dtype) + if dtype_z.startswith('float8') and device not in ['cuda']: + t = z_ref.byte() ^ z_tri.byte() + torch.testing.assert_close(torch.zeros_like(t, dtype=torch.uint8), t) + else: + torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0) + else: + if bitcast: + z_ref = x.view(getattr(np, dtype_z_np)) + else: + z_ref = x.astype(getattr(np, dtype_z_np)) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0, atol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, num_warps", + [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]]) +def test_cat(dtype_str, num_warps, device): + check_type_supported(dtype_str, device) + + if is_musa(): + # muDNN dose not support i16 and i64, thus use cpu as ref here. + device = "cpu" + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.cat(x, y, can_reorder=True) + tl.store(Z + tl.arange(0, 2 * N), z) + + x = torch.arange(0, 128, device=device).to(getattr(torch, dtype_str)) + y = torch.arange(-128, 0, device=device).to(getattr(torch, dtype_str)) + z_ref = torch.cat([x, y], dim=0).sum() + z = torch.zeros((256, ), dtype=getattr(torch, dtype_str), device=device) + if is_musa(): + x_musa = x.musa() + y_musa = y.musa() + z_musa = z.musa() + kernel[(1, )](x_musa, y_musa, z_musa, N=128, num_warps=num_warps) + z = z_musa.to(device) + else: + kernel[(1, )](x, y, z, N=128, num_warps=num_warps) + assert z.sum() == z_ref + # check if there's no duplicate value in z + assert z.unique().size(0) == z.size(0) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", list(torch_dtypes)) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_store_constant(dtype_str, num_ctas, device): + check_type_supported(dtype_str, device) + """Tests that boolean True is stored as 1""" + + @triton.jit + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + output = GENERATE_TEST_HERE + tl.store(output_ptr + offsets, output, mask=mask) + + triton_dtype_str = 'uint8' if dtype_str == 'bool' else dtype_str + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.zeros([BLOCK_SIZE], dtype=tl.{triton_dtype_str}) + 1'}) + block_size = 128 + ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device) + output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) + + assert torch.all(output == ref) + + +def test_load_store_same_ptr(device): + + @triton.jit() + def kernel(in_out_ptr): + pid = tl.program_id(axis=0) + x = tl.load(in_out_ptr + pid) + out = x * 2 + tl.store(in_out_ptr + pid, out) + + for _ in range(1000): + x = torch.ones((65536, ), device=device, dtype=torch.float32) + if is_hip(): + kernel[(65536, )](x, num_warps=16) # threads per Warp for ROCM is 64 + else: + kernel[(65536, )](x, num_warps=32) + assert torch.all(x == 2) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ['int32']) +def test_umulhi(dtype_str, device): + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.umulhi(x, y) + tl.store(Z + tl.arange(0, N), z) + + def umulhi32(a, b): + # Convert to 64-bit unsigned integers to prevent overflow + a_64 = a.astype(np.int64) + b_64 = b.astype(np.int64) + + # Perform the multiplication in 64-bit + product_64 = a_64 * b_64 + + # Shift right by 32 bits to get the high part of the product + result_high_32 = product_64 >> 32 + return result_high_32 + + rs = RandomState(17) + N = 128 + x = numpy_random((N, ), dtype_str=dtype_str, rs=rs, low=0) + x_tri = to_triton(x, device=device) + y = numpy_random((N, ), dtype_str=dtype_str, rs=rs, low=0) + y_tri = to_triton(y, device=device) + z_tri = torch.zeros_like(x_tri) + kernel[(1, )](x_tri, y_tri, z_tri, N=N) + + z_ref = umulhi32(x, y) + np.testing.assert_equal(z_ref, to_numpy(z_tri)) + + +@pytest.mark.interpreter +def test_join(device): + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.join(x, y) + tl.store(Z + tl.arange(0, N)[:, None] * 2 + tl.arange(0, 2)[None, :], z) + + x = torch.arange(0, 128, device=device).to(torch.int32) + y = torch.arange(-128, 0, device=device).to(torch.int32) + z_ref = torch.stack([x, y], dim=-1) + z = torch.zeros_like(z_ref) + kernel[(1, )](x, y, z, N=128) + + np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) + + +@pytest.mark.interpreter +def test_join_scalars(device): + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + z = tl.join(x, y) + tl.static_assert(z.shape == [2]) + tl.store(Z + tl.arange(0, 2), z) + + x = torch.full([1], 42, device=device).to(torch.int32) + y = torch.full([1], 100, device=device).to(torch.int32) + z = torch.zeros([2], device=device) + kernel[(1, )](x, y, z) + + np.testing.assert_equal([42, 100], to_numpy(z)) + + +@pytest.mark.interpreter +def test_join_with_mma(device): + + @triton.jit + def kernel(X, Z): + x = tl.load(X + 16 * tl.arange(0, 32)[:, None] + tl.arange(0, 16)[None, :]) # (32,16) + x2 = tl.join(x, 2 * x) # (32,16,2) + x3 = tl.reshape(x2, (32, 32)) + z = tl.dot(x3, x3) # (32,32) + tl.store(Z + 32 * tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :], z) + + x = torch.arange(0, 32 * 16, device=device, dtype=torch.float32).reshape((32, 16)) + r = torch.stack([x, 2 * x], dim=-1).reshape((32, 32)) + z_ref = torch.matmul(r, r) + z = torch.zeros_like(z_ref) + kernel[(1, )](x, z) + + torch.testing.assert_close(z, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("debug", [False, True]) +def test_interleave(device, debug): + + @triton.jit(debug=debug) + def kernel(Z, N: tl.constexpr): + z = tl.interleave(tl.arange(0, N), tl.arange(N, 2 * N)) + tl.store(Z + tl.arange(0, 2 * N), z) + + x = torch.arange(0, 128, device=device).to(torch.int32) + y = torch.arange(128, 256, device=device).to(torch.int32) + z_ref = torch.stack([x, y], dim=-1).reshape(256) + z = torch.zeros_like(z_ref) + kernel[(1, )](z, N=128) + + np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) + + +@pytest.mark.interpreter +def test_interleave_scalars(device): + + @triton.jit + def kernel(X, Y, Z): + z = tl.interleave(X, Y) + tl.static_assert(z.shape == [tl.constexpr(2)]) + tl.store(Z + tl.arange(0, 2), z) + + z = torch.zeros(2, device=device) + kernel[(1, )](10, 20, z) + + np.testing.assert_equal([10, 20], to_numpy(z)) + + +@pytest.mark.interpreter +def test_split(device): + + @triton.jit + def kernel(X, Z1, Z2, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + x1 = tl.reshape(x, (N // 2, 2)) + z1, z2 = tl.split(x1) + tl.store(Z1 + tl.arange(0, N // 2), z1) + tl.store(Z2 + tl.arange(0, N // 2), z2) + + x = torch.arange(0, 256, device=device).to(torch.int32).reshape((128, 2)) + z1_ref, z2_ref = (x[:, 0], x[:, 1]) + z1 = torch.zeros_like(z1_ref) + z2 = torch.zeros_like(z2_ref) + kernel[(1, )](x, z1, z2, N=256) + + np.testing.assert_equal(to_numpy(z1_ref), to_numpy(z1)) + np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) + + +@pytest.mark.interpreter +def test_split_to_scalar(device): + + @triton.jit + def kernel(X, Z1, Z2): + offs = tl.arange(0, 2) + x = tl.load(X + offs) + z1, z2 = tl.split(x) + tl.static_assert(isinstance(z1, tl.tensor)) + tl.static_assert(isinstance(z2, tl.tensor)) + tl.static_assert(z1.shape == []) + tl.static_assert(z2.shape == []) + tl.store(Z1, z1) + tl.store(Z2, z2) + + N = 2 + x = torch.arange(0, N, device=device).reshape(N // 2, 2) + z1_ref, z2_ref = (x[:, 0], x[:, 1]) + z1 = torch.zeros_like(z1_ref) + z2 = torch.zeros_like(z2_ref) + kernel[(1, )](x, z1, z2) + + np.testing.assert_equal(to_numpy(z1_ref), to_numpy(z1)) + np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) + + +def convert_float_to_float32(fp: torch.tensor, dtype=None): + if not dtype: + dtype = getattr(tl, torch_dtype_name(fp.dtype)) + + fp = fp.view(getattr(torch, f"int{dtype.primitive_bitwidth}")) + exp_width = dtype.primitive_bitwidth - dtype.fp_mantissa_width - 1 + exp_bias = dtype.exponent_bias + sign = ((fp >> (dtype.primitive_bitwidth - 1)) & 0x01).int() + exp = ((fp >> dtype.fp_mantissa_width) & ((1 << exp_width) - 1)).int() + frac = (fp & ((1 << dtype.fp_mantissa_width) - 1)).int() + + output = torch.where( + exp == 0, + # subnormal + ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (frac / (2.0**dtype.fp_mantissa_width)), + # normal + ((-1.0)**sign) * (2.0**(exp - exp_bias)) * (1.0 + frac / (2.0**dtype.fp_mantissa_width))).float() + + extended_exp = ( + (1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width + # special cases, exp is 0b11..1 + if dtype in [tl.float8e4nv, tl.float8e4b15]: + # float8e4m3nv does not have infinities + output[fp == 0b01111111] = torch.nan + output[fp == 0b11111111] = torch.nan + else: + output = torch.where(exp == (1 << exp_width) - 1, + ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp + | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))) # + .view(torch.float32), output) + return output + + +@pytest.mark.interpreter +@pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16]) +def test_convert_float16_to_float32(in_dtype, device): + """Tests that check convert_float_to_float32 function""" + check_type_supported(in_dtype, device) + + f16_input = torch.tensor(range(-int(2**(16 - 1)), int(2**(16 - 1))), dtype=torch.int16).view(in_dtype) + f32_output = convert_float_to_float32(f16_input) + + nan = f16_input.isnan() + assert torch.all(f32_output[nan].isnan()) + inf = f16_input.isinf() + assert torch.all(f32_output[inf].isinf()) + other = torch.logical_not(torch.logical_or(nan, inf)) + assert torch.all(f16_input[other] == f32_output[other]) + + +def serialize_fp8(np_data, in_dtype): + return np_data + + +# inverse of `serialize_fp8` + + +def deserialize_fp8(np_data, in_dtype): + return np_data + + +# --------------- +# test reduce +# --------------- + + +@pytest.mark.interpreter +def test_max_returns_zero(device): + # Simple test with a tl.max call that returns 0. The interpreter had a bug + # where it didn't handle this correctly. + @triton.jit + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + z = tl.max(x) + tl.store(Z, z) + + BLOCK = 128 + x = torch.zeros((BLOCK, ), device=device) + z = torch.ones((1, ), device=device) + + kernel[(1, )](x, z, BLOCK=BLOCK) + assert z[0] == 0 + + +def get_reduced_dtype(dtype_str, op): + if op in ('argmin', 'argmax'): + return 'int32' + if dtype_str == 'bfloat16': + return 'float32' + return dtype_str + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [ + 'min', + 'max', + 'min-with-indices', + 'max-with-indices', + 'argmin-tie-break-left', + 'argmax-tie-break-left', + 'sum', +] for dtype in dtypes_with_bfloat16 for shape in [32, 64, 128, 512]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce1d(op, dtype_str, shape, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + + # triton kernel + @triton.jit + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + GENERATE_TEST_HERE + tl.store(Z, z) + + if 'with-indices' in op: + patch = f'z, _ = tl.{op.split("-")[0]}(x, axis=0, return_indices=True)' + elif 'arg' in op: + tie_break_left = 'tie-break-left' in op + patch = f'z = tl.{op.split("-")[0]}(x, axis=0, tie_break_left={tie_break_left})' + else: + patch = f'z = tl.{op}(x, axis=0)' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': patch}) + # input + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random((shape, ), dtype_str=dtype_str, rs=rs) + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + 'max-with-indices': np.max, + 'min-with-indices': np.min, + 'argmin-tie-break-fast': np.argmin, + 'argmin-tie-break-left': np.argmin, + 'argmax-tie-break-fast': np.argmax, + 'argmax-tie-break-left': np.argmax, + }[op] + if 'tie-break-left' in op: + x[3:10] = numpy_op(x) + x_tri = to_triton(x, device=device) + # numpy result + z_dtype_str = 'int32' if op in ('argmin', 'argmax') else dtype_str + z_tri_dtype_str = z_dtype_str + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + z_tri_dtype_str = 'bfloat16' + else: + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # triton result + z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas) + z_tri = to_numpy(z_tri) + # compare + if op == 'sum': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + if op in ('argmin', 'argmax'): + # argmin and argmax can have multiple valid indices. + # so instead we compare the values pointed by indices + np.testing.assert_equal(x[z_ref], x[z_tri]) + else: + np.testing.assert_equal(z_ref, z_tri) + + +# TODO: [Qingyi] Fix argmin / argmax +reduce_configs1 = [(op, dtype, (1, 1024), axis, False) + for dtype in dtypes_with_bfloat16 + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [1]] + +# shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory +# exceeds the limit of 99KB +reduce2d_shapes = [(2, 32), (4, 32), (4, 128)] +# TODO: fix and uncomment +# , (32, 64), (64, 128)] +if is_cuda() and 'V100' in torch.cuda.get_device_name(0): + reduce2d_shapes += [(128, 256) and (32, 1024)] + +reduce_configs2 = [(op, 'float32', shape, axis, False) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce2d_shapes + for axis in [0, 1]] + [(op, 'float32', [16, 32], None, False) for op in ['min', 'max', 'sum']] + +reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)] +reduce_configs3 = [(op, 'float32', shape, axis, False) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce3d_shapes + for axis in [0, 1, 2]] +invalid_config = [('sum', 'float32', (32, 32), axis, False) for axis in [2, 3]] +negative_config = [('sum', 'float32', (32, 32), -1, False)] +keep_dims_2d_configs = [(op, 'float32', (32, 32), axis, True) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [0, 1]] + [(op, 'float32', (32, 32), None, True) for op in ['min', 'max', 'sum']] +keep_dims_3d_configs = [(op, 'float32', (32, 2, 16), axis, True) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [0, 1, 2]] + [(op, 'float32', (32, 2, 16), None, True) + for op in ['min', 'max', 'sum']] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + + negative_config + keep_dims_2d_configs + keep_dims_3d_configs) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + + @triton.jit + def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, + AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + range_k = tl.arange(0, BLOCK_K) + if IS_3D: + x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + + range_k[None, None, :]) + else: + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + z = GENERATE_TEST_HERE + + z_ptr = Z + if KEEP_DIMS and AXIS is None: + if IS_3D: + z_ptr = z_ptr[None, None, None, :] + else: + z_ptr = z_ptr[None, None, :] + if IS_3D: + if AXIS == 0: + z_ptr = Z + range_n[:, None] * BLOCK_K + range_k[None, :] + elif AXIS == 1 or AXIS == -2: + z_ptr = Z + range_m[:, None] * BLOCK_K + range_k[None, :] + elif AXIS == 2 or AXIS == -1: + z_ptr = Z + range_m[:, None] * BLOCK_N + range_n[None, :] + else: + if AXIS == 0: + z_ptr = Z + range_n + elif AXIS == 1 or AXIS == -1: + z_ptr = Z + range_m + if KEEP_DIMS and AXIS is not None: + z_ptr = tl.expand_dims(z_ptr, axis=AXIS) + tl.store(z_ptr, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS, keep_dims=KEEP_DIMS)'}) + # input + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + x_tri = to_triton(x, device=device) + numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax}[op] + z_dtype_str = get_reduced_dtype(dtype_str, op) + z_tri_dtype_str = z_dtype_str + + # numpy result + # Silence numpy error on axis out of bounds, to give triton a chance to fail + np_axis = axis if axis is not None and axis < len(shape) else None + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_tri_dtype_str = 'bfloat16' + z_ref = numpy_op(x, axis=np_axis, keepdims=keep_dims).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + else: + z_ref = numpy_op(x, axis=np_axis, keepdims=keep_dims).astype(getattr(np, z_dtype_str)) + + # triton result + z_shape = z_ref.shape + z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + BLOCK_K = 1 if len(shape) == 2 else shape[2] + IS_3D = bool(len(shape) == 3) + if axis is not None and axis >= len(shape): + with pytest.raises(triton.TritonError): + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + KEEP_DIMS=keep_dims, num_ctas=num_ctas) + return + else: + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + KEEP_DIMS=keep_dims, num_ctas=num_ctas) + + z_tri = to_numpy(z_tri) + + # compare + if op == 'sum': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + if op in ('argmin', 'argmax'): + # argmin and argmax can have multiple valid indices. + # so instead we compare the values pointed by indices + z_ref_index = z_ref + z_tri_index = z_tri + if not keep_dims: + z_ref_index = np.expand_dims(z_ref, axis=axis) + z_tri_index = np.expand_dims(z_tri, axis=axis) + z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis) + z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis) + np.testing.assert_equal(z_ref_value, z_tri_value) + else: + np.testing.assert_equal(z_ref, z_tri) + + +scan2d_shapes = [(8, 32), (16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)] + +scan_configs = [(op, type, shape, axis, reverse, num_warps) + for num_warps in [4, 16] + for type in ['int32', 'float32', 'bfloat16'] + for axis in [1, 0] + for reverse in [True, False] + for shape in scan2d_shapes + for op in ['cumsum', 'cumprod', 'get_first_element', 'linear_recurrence', 'cummax', 'roll']] +negative_config = [('cumsum', 'float32', (32, 32), -1, False, 4)] + + +@triton.jit +# trivial associative but not commutative function +def get_first_element(a, b): + return a + + +# Compute x_i = a_i * x_{i-1} + b_i +@triton.jit +def linear_recurrence(a1, b1, a2, b2): + return a1 * a2, b1 * a2 + b2 + + +@triton.jit +def cummax(v0, i0, v1, i1): + gt = v0 > v1 + return tl.where(gt, v0, v1), tl.where(gt, i0, i1) + + +@triton.jit +def roll(a1, b1_last, b1_cur, a2, b2_last, b2_cur): + return a1 + a2, tl.where(a2 == 1, b1_cur, 0) + b2_last, b2_cur + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op, dtype_str, shape, axis, reverse, num_warps", scan_configs + negative_config) +def test_scan2d(op, dtype_str, shape, axis, reverse, num_warps, device): + check_type_supported(dtype_str, device) + if dtype_str == 'bfloat16': + if op == 'cummax': + pytest.skip("bfloat16 compare not suppoted before sm90") + if op == 'linear_recurrence': + pytest.skip("Skipping linear_recurrence scan on bfloat16 due to accuracy issues") + numpy_dtype_str = 'float32' if dtype_str == 'bfloat16' else dtype_str + + # triton kernel + @triton.jit + def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + y = tl.load(Y + range_m[:, None] * BLOCK_N + range_n[None, :]) + GENERATE_TEST_HERE + tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z) + + if op == 'cumsum' or op == 'cumprod': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'z = tl.{op}(x, axis={axis}, reverse={reverse})'}) + elif op == 'get_first_element': + kernel = patch_kernel( + kernel, + {'GENERATE_TEST_HERE': f'z = tl.associative_scan(x, axis={axis}, combine_fn={op}, reverse={reverse})'}) + elif op == 'cummax': + rg = "range_m[:, None]" if axis == 0 else "range_n[None, :]" + rg = f"tl.broadcast_to({rg}.to(tl.int64), [BLOCK_M, BLOCK_N])" + kernel = patch_kernel(kernel, { + 'GENERATE_TEST_HERE': + f'_, z = tl.associative_scan((x, {rg}), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + elif op == 'roll': + assert op == 'roll' + kernel = patch_kernel( + kernel, { + 'GENERATE_TEST_HERE': + f'_, z, _ = tl.associative_scan((1 + 0* x, 0 * x, x), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + else: + assert op == 'linear_recurrence' + kernel = patch_kernel(kernel, { + 'GENERATE_TEST_HERE': + f'_, z = tl.associative_scan((x, y), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + # input + rs = RandomState(17) + if op == 'linear_recurrence' and dtype_str in int_dtypes: + # If the numbers are too large the op will overflow + # We sample numbers in -1, 0, 1 + x = rs.randint(-1, 2, shape, dtype=dtype_str) + y = rs.randint(-1, 2, shape, dtype=dtype_str) + else: + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + # y is just used in linear_recurrence + y = numpy_random(shape, dtype_str=dtype_str, rs=rs) + x_in = x + if reverse: + x_in = np.flip(x, axis) + z = np.empty_like(x) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + y_tri = to_triton(y, device=device, dst_type=dtype_str) + if op == 'cumsum' or op == 'cumprod': + numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op] + z_ref = numpy_op(x_in, axis=axis).astype(getattr(np, numpy_dtype_str)) + if reverse: + z_ref = np.flip(z_ref, axis) + + elif op == 'cummax': + # NumPy does not have cummax + z = z.astype(np.int64) + z_ref = torch.cummax(torch.from_numpy(x_in.copy()), axis=axis).indices.numpy() + if reverse: + z_ref = x_in.shape[axis] - np.flip(z_ref, axis) - 1 + elif op == 'roll': + ROLL = 1 + z_ref = np.roll(x_in.copy(), ROLL, axis=axis) + if axis == 0: + z_ref[:ROLL] = 0 + else: + z_ref[:, :ROLL] = 0 + + if reverse: + z_ref = np.flip(z_ref, axis) + elif op == 'linear_recurrence': + # Simplify to the axis=1 case + x_ref = x.T if axis == 0 else x + y_ref = y.T if axis == 0 else y + if reverse: + x_ref = np.flip(x_ref, 1) + y_ref = np.flip(y_ref, 1) + + result = [] + for x_refi, y_refi in zip(x_ref, y_ref): + li = [] + acc = 0 + for xi, yi in zip(x_refi, y_refi): + acc = xi * acc + yi + li.append(acc) + result.append(li) + z_ref = np.array(result) + if reverse: + z_ref = np.flip(z_ref, 1) + + if axis == 0: + z_ref = z_ref.T + else: + assert op == 'get_first_element' + z_ref = x + if axis == 0: + if reverse: + z_ref[:-1] = x[-1] + else: + z_ref[1:] = x[0] + else: + if reverse: + z_ref[:, :-1] = x[:, -1:] + else: + z_ref[:, 1:] = x[:, 0:1] + + # triton result + # we don't cast the `fp32 = bf16 op bf16` result to bfloat16 to alleviate accuracy issues + z_tri = to_triton(z, device=device) + kernel[(1, )](x_tri, y_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) + + z_tri = to_numpy(z_tri) + # compare + if dtype_str not in int_dtypes: + if op == 'cumprod': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01, atol=1e-3) + else: + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + np.testing.assert_equal(z_ref, z_tri) + + +scan_layouts = [ + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), +] + +# --------------- +# test histogram +# --------------- + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[2048, 2], [1024, 8], [1024, 128], [256, 512], [32, 512], [8, 512], [8, 2]]) +def test_histogram(M, N, device): + + @triton.jit + def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + offset1) + z = tl.histogram(x, N) + tl.store(z_ptr + offset2, z) + + torch.manual_seed(17) + x = torch.randint(0, N, (M, ), device=device, dtype=torch.int32) + z = torch.empty(N, dtype=torch.int32, device=device) + # torch.histc does not work when the input type is not float and the device is CPU + # https://github.com/pytorch/pytorch/issues/74236 + # This is a workload by converting the input to float + z_torch = torch.histc(x.float(), bins=N, min=0, max=N - 1) + histogram_kernel[(1, )](x, z, M=M, N=N) + assert (z_torch == z).all() + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("op", ['sum', 'max', 'min']) +@pytest.mark.parametrize("BLOCK_N", [32, 64, 128]) +@pytest.mark.parametrize("N", [512, 1024, 2048]) +@pytest.mark.parametrize("num_pid_n", [2, 4]) +def test_optimize_thread_locality(op, BLOCK_N, N, num_pid_n, device): + + @triton.jit + def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.constexpr): + start_m = tl.program_id(0) + pid_n = tl.program_id(1) + local = INITIALIZE_PATCH + off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + for start_n in range(pid_n, tl.cdiv(N, BLOCK_N), NUM_PID_N): + off_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * N + off_n[None, :] + x = tl.load(Xs) + local = ACCUMULATE_PATCH + tl.store(Y + off_m * NUM_PID_N + pid_n, local) + # the following segfaults AMD backend following #3492 + # really unclear why; the llvm-ir and kernel arguments are + # identical ! + # tl.store(Y + off_m * tl.num_programs(1) + pid_n, local) + + initialize_patch = { + 'sum': 'tl.zeros([BLOCK_M], dtype=tl.float32)', + 'max': 'tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)', + 'min': 'tl.full([BLOCK_M], float("inf"), dtype=tl.float32)', + }[op] + reduce_patch = { + 'sum': 'local + tl.sum(x, axis=1)', + 'max': 'tl.maximum(local, tl.max(x, axis=1))', + 'min': 'tl.minimum(local, tl.min(x, axis=1))', + }[op] + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + }[op] + kernel = patch_kernel(kernel, {'ACCUMULATE_PATCH': reduce_patch, 'INITIALIZE_PATCH': initialize_patch}) + torch.manual_seed(0) + BLOCK_M = 32 + x = torch.randn((BLOCK_M, N), dtype=torch.float32, device=device) + y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device=device) + h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N, NUM_PID_N=num_pid_n) + if not is_interpreter(): + assert h.asm['ttgir'].count( + '"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work" + y_ref = numpy_op(x.cpu().numpy(), axis=1, keepdims=True) + y_tri = numpy_op(y.cpu().numpy(), axis=1, keepdims=True) + np.testing.assert_allclose(y_tri, y_ref, rtol=0.01, atol=1e-3) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) +@pytest.mark.parametrize("src_layout", scan_layouts) +@pytest.mark.parametrize("axis", [0, 1]) +def test_scan_layouts(M, N, src_layout, axis, device): + + ir = f""" + #blocked = {src_layout} + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> + %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> + %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> + %7 = tt.broadcast %4 : tensor<{M}x1x!tt.ptr, #blocked> -> tensor<{M}x{N}x!tt.ptr, #blocked> + %8 = tt.broadcast %6 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr, #blocked> + %11 = "tt.scan"(%10) <{{axis = {axis} : i32, reverse = false}}> ({{ + ^bb0(%arg2: i32, %arg3: i32): + %16 = arith.addi %arg2, %arg3 : i32 + tt.scan.return %16 : i32 + }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> + %13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %14 = tt.broadcast %13 : tensor<{M}x1x!tt.ptr, #blocked> -> tensor<{M}x{N}x!tt.ptr, #blocked> + %15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + tt.store %15, %11 : tensor<{M}x{N}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + rs = RandomState(17) + x = rs.randint(-100, 100, (M, N)).astype('int32') + + z = np.zeros((M, N)).astype('int32') + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + kernel[(1, 1, 1)](x_tri, z_tri) + + z_ref = np.cumsum(x, axis=axis) + + np.testing.assert_equal(z_ref, z_tri.cpu().numpy()) + + +layouts = [ + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 4], [THREADS_PER_WARP // 16, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 2], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], + instr_shape=[16, 16, 16]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], + instr_shape=[16, 32, 16]), + MfmaLayout(version=(2, 0), warps_per_cta=[2, 2], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[2, 2], instr_shape=[32, 32], is_transposed=True), + MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=True), + MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=True), + WmmaLayout(warps_per_cta=[2, 2]), + WmmaLayout(warps_per_cta=[4, 1]), + WmmaLayout(warps_per_cta=[1, 4]), +] + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [64, 64], [32, 128], [32, 32], [16, 16]]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d']) +@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"]) +@pytest.mark.parametrize("reduce_op", ["sum", "max"]) +def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device): + if isinstance(src_layout, + (MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]): + pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape") + if is_hip() and isinstance(src_layout, MfmaLayout) and ((M, N) == (128, 128)): + pytest.skip("Skipping test because it runs out of shared memory") + if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024: + pytest.skip("Skipping sum reduction on float16 due to accuracy issues") + if epilogue_kind == 'expand_reduce2d' and isinstance(src_layout, MmaLayout): + pytest.skip( + "Currently MmaLayout combined with slice encoding and reduce op trigger device illegal memory access") + + if isinstance(src_layout, MmaLayout) and src_layout.version == 3: + src_layout[2] = 16 if dtype_str == "float16" else 8 + + ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str] + arith_op = { + "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"}, # + "sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"} + }[reduce_op][dtype_str] + numpy_op = {"max": np.max, "sum": np.sum}[reduce_op] + rdims_1d = f"{N}" if axis == 0 else f"{M}" + rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" + store_range = "%7" if axis == 0 else "%1" + blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]) + num_warps = src_layout.warps_per_cta[0] * src_layout.warps_per_cta[1] + if num_warps == 8: + blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, 2], [0, 1], [1, 1], [1, 1], [0, 1]) + one_d_layout = BlockedLayout([1], [THREADS_PER_WARP], [4], [0], [1], [1], [0]) + + expanded_shape = f"1x{N}" if axis == 0 else f"{M}x1" + other_axis = 1 - axis + epilogue = { + "reduce1d": + f""" + %14 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked> + %15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked> + %16 = {GPU_DIALECT}.convert_layout %13 : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> + %17 = tt.expand_dims %16 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> -> tensor<{rdims_2d}x{ty}, #blocked> + tt.store %15, %17 : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked> + tt.return + }} + }} + """, "reduce2d": + f""" + %14 = "tt.reduce"(%13) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = 0 : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> {ty} + tt.store %arg2, %14 : !tt.ptr<{ty}> + tt.return + }} + }} + """, "expand_reduce2d": + f""" + %14 = tt.expand_dims %13 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{expanded_shape}x{ty}, #src> + %15 = "tt.reduce"(%14) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = {other_axis} : i32}} : (tensor<{expanded_shape}x{ty}, #src>) -> (tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>>) + %16 = triton_gpu.convert_layout %15 : tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>> -> tensor<1x{ty}, #one_d_layout> + %17 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<1x!tt.ptr<{ty}>, #one_d_layout> + tt.store %17, %16 : tensor<1x!tt.ptr<{ty}>, #one_d_layout> + tt.return + }} + }} + """ + }[epilogue_kind] + + ir = f""" + #blocked = {blocked} + #src = {src_layout} + #one_d_layout = {one_d_layout} + module attributes {{"triton_gpu.num-warps" = {num_warps} : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %2 = tt.splat %arg1 : i32 -> tensor<{M}x1xi32, #blocked> + %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked> + %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked> + %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> + %7 = tt.expand_dims %6 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked> + %9 = tt.broadcast %7 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked> + %12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src> + %13 = "tt.reduce"(%12) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> + """ + epilogue + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10) + reduce2d = 'reduce2d' in epilogue_kind + z_shape = (1, 1) if reduce2d else (1, N) if axis == 0 else (M, 1) + z = np.zeros(z_shape).astype(dtype_str) + + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + pgm = kernel[(1, 1, 1)](x_tri, x_tri.stride(0), z_tri) + z_ref = numpy_op(x) if reduce2d else numpy_op(x, axis=axis, keepdims=True) + + if dtype_str == 'float16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + + +layouts = [ + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]) +] + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.parametrize("M", [32, 64, 128, 256]) +@pytest.mark.parametrize("src_layout", layouts) +def test_store_op(M, src_layout, device): + + ir = f""" + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "{GPU_DIALECT}.num-ctas" = 1 : i32, "{GPU_DIALECT}.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %3 = tt.load %2 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 1 : i32}} : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xf32, #src> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %6 = tt.expand_dims %5 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #src> + %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr, #src>, tensor<{M}x1xi32, #src> + tt.store %8, %4 : tensor<{M}x1x!tt.ptr, #src> + tt.return + }} + }} + """ + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + store_kernel = triton.compile(f.name) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, 1)).astype('float32') + y = np.zeros((M, 1), dtype='float32') + x_tri = torch.tensor(x, device=device) + y_tri = torch.tensor(y, device=device) + + pgm = store_kernel[(1, 1, 1)](x_tri, y_tri) + y_ref = x + np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +layouts = [ + # TODO (lixun): Add MfmaLayout + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]) +] + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.parametrize("M", [64, 128, 256]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("src_dim", [0, 1]) +@pytest.mark.parametrize("dst_dim", [0, 1]) +def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): + + ir = f""" + #dst = {dst_layout} + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %3 = tt.load %2 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %7 = {GPU_DIALECT}.convert_layout %3 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> -> tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + tt.store %6, %7 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + tt.return + }} + }} + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, )).astype('int32') + y = np.zeros((M, ), dtype='int32') + x_tri = torch.tensor(x, device=device) + y_tri = torch.tensor(y, device=device) + pgm = kernel[(1, 1, 1)](x_tri, y_tri) + y_ref = x + np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +@triton.jit +def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): + delta = mean_2 - mean_1 + new_weight = weight_1 + weight_2 + w2_over_w = weight_2 / new_weight + return ( + mean_1 + delta * w2_over_w, + m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w, + new_weight, + ) + + +layouts = [ + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + # [HIP] TO DO: some tests are flaky with the layout, so turn off them for now. + # BlockedLayout([1, 4], [1, THREADS_PER_WARP], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]) +] + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.parametrize("M, N", [[128, 128], [256, 128], [256, 256], [128, 256]]) +@pytest.mark.parametrize("src_layout", layouts) +@pytest.mark.parametrize("op", ["sum", "max"]) +@pytest.mark.parametrize("first_axis", [0, 1]) +def test_chain_reduce(M, N, src_layout, op, device, first_axis): + + op_str = "" + if op == "sum": + op_str = """ + %13 = arith.addi %arg2, %arg3 : i32 + tt.reduce.return %13 : i32""" + elif op == "max": + op_str = """ + %13 = arith.cmpi "sgt", %arg2, %arg3 : i32 + %14 = arith.select %13, %arg2, %arg3 : i32 + tt.reduce.return %14 : i32""" + ir = f""" + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src> + %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %5 = tt.broadcast %2 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %6 = tt.broadcast %4 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr, #src> + %11 = "tt.reduce"(%10) ({{ + ^bb0(%arg2: i32, %arg3: i32): + {op_str} + }}) {{axis = {first_axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>> + %12 = "tt.reduce"(%11) ({{ + ^bb0(%arg2: i32, %arg3: i32): + {op_str} + }}) {{axis = 0 : i32}} : (tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>>) -> i32 + tt.store %arg1, %12 : !tt.ptr + tt.return + }} + }} + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, N)).astype('int32') + + z = np.zeros((1, )).astype('int32') + + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + pgm = kernel[(1, 1, 1)](x_tri, z_tri) + if op == "sum": + z_ref = np.sum(x) + elif op == "max": + z_ref = np.max(x) + + np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +@pytest.mark.interpreter +def test_generic_reduction(device): + + @triton.jit + def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr): + xindex = tl.arange(0, BLOCK) + x = tl.load(X + xindex) + mean = x + m2 = tl.zeros_like(x) + weight = tl.full(x.shape, 1, x.dtype) + (mean, m2, weight) = tl.reduce((mean, m2, weight), 0, _welford_combine) + tl.store(out_mean, mean) + tl.store(out_var, m2 / weight) + + SIZE = 512 + x = torch.rand(SIZE, device=device) + out_mean = torch.empty((), device=device) + out_var = torch.empty((), device=device) + + var_mean_kernel[(1, )](x, out_mean, out_var, BLOCK=SIZE) + + expect_var, expect_mean = torch.var_mean(x, dim=0, correction=0) + torch.testing.assert_close(out_mean, expect_mean) + torch.testing.assert_close(out_var, expect_var) + + +# --------------- +# test permute +# --------------- + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) + # TODO: bfloat16 + for dtype in ['float8e4b15', 'float16', 'float32'] + for shape in [(64, 64), (128, 128)] + for perm in [(1, 0)]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_permute(dtype_str, shape, perm, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + if is_hip() and shape == (128, 128) and dtype_str == 'float32': + pytest.skip("TODO Out of LDS for float32 with shape 128x128") + + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + tl.store(Zs, tl.load(Xs)) + + # input + x = numpy_random(shape, dtype_str=dtype_str) + # triton result + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), z_tri, z_tri.stride(1), z_tri.stride(0), + BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), + x_tri.stride(0), z_tri_contiguous, z_tri_contiguous.stride(0), + z_tri_contiguous.stride(1), BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + # numpy result + if dtype_str == 'float8e4b15': + ty = tl.float8e4b15 + z_ref = serialize_fp8(deserialize_fp8(x, ty).T.copy(), ty) + z_tri = z_tri.base + z_tri_contiguous = z_tri_contiguous.base + else: + z_ref = x.transpose(*perm) + # compare + np.testing.assert_allclose(to_numpy(z_tri), z_ref) + np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref) + + if not is_cuda(): + return + + # parse ptx to make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + ptx = pgm_contiguous.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["int32", "int8"]) +@pytest.mark.parametrize("shape", [(2, 4), (16, 16)]) +@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1]))) +def test_trans_2d(dtype_str, shape, perm, device): + + @triton.jit + def kernel(In, Out, in_shape1: tl.constexpr, in_shape2: tl.constexpr, ou_shape1: tl.constexpr, + ou_shape2: tl.constexpr, trans1: tl.constexpr, trans2: tl.constexpr): + in_offs = tl.arange(0, in_shape1)[:, None] * in_shape2 + tl.arange(0, in_shape2)[None, :] + ou_offs = tl.arange(0, ou_shape1)[:, None] * ou_shape2 + tl.arange(0, ou_shape2)[None, :] + tl.store(Out + ou_offs, tl.permute(tl.load(In + in_offs), (trans1, trans2))) + + input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device=device).reshape(shape) + expected = torch.permute(input, perm) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device) + + kernel[(1, )](input, actual, *shape, *[shape[i] for i in perm], *perm) + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["int32", "int8"]) +@pytest.mark.parametrize("shape", [(2, 2, 8, 64), (4, 4, 4, 4)]) +@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1, 2, 3]))) +def test_trans_4d(dtype_str, shape, perm, device): + + @triton.jit + def kernel(In, Out, # + in_shape1: tl.constexpr, in_shape2: tl.constexpr, in_shape3: tl.constexpr, in_shape4: tl.constexpr, + ou_shape1: tl.constexpr, ou_shape2: tl.constexpr, ou_shape3: tl.constexpr, ou_shape4: tl.constexpr, + trans1: tl.constexpr, trans2: tl.constexpr, trans3: tl.constexpr, trans4: tl.constexpr): + in_ptr = tl.make_block_ptr( + base=In, + shape=(in_shape1, in_shape2, in_shape3, in_shape4), + strides=(in_shape4 * in_shape3 * in_shape2, in_shape4 * in_shape3, in_shape4, 1), + offsets=(0, 0, 0, 0), + block_shape=(in_shape1, in_shape2, in_shape3, in_shape4), + order=(3, 2, 1, 0), + ) + out_ptr = tl.make_block_ptr( + base=Out, + shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4), + strides=(ou_shape4 * ou_shape3 * ou_shape2, ou_shape4 * ou_shape3, ou_shape4, 1), + offsets=(0, 0, 0, 0), + block_shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4), + order=(3, 2, 1, 0), + ) + tl.store(out_ptr, tl.load(in_ptr).permute((trans1, trans2, trans3, trans4))) + + input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device=device).reshape(shape) + expected = torch.permute(input, perm) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device) + + kernel[(1, )](input, actual, *shape, *[shape[i] for i in perm], *perm, num_warps=8) + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +# --------------- +# test dot +# --------------- + + +def convert_fp8_to_fp32(x, device, dtype_str): + if dtype_str == 'float8e4nv': + return torch.tensor(x, device=device).view(torch.float8_e4m3fn).to(torch.float32) + elif dtype_str == 'float8e5': + return torch.tensor(x, device=device).view(torch.float8_e5m2).to(torch.float32) + assert "Unsupported float8 dtype" + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize( + "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack", + [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1) + for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] + for input_precision in ['tf32', 'tf32x3', 'ieee'] + for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')] + if not (input_precision != 'ieee' and (in_dtype in ['float16']))] + + [(*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype, kpack) + for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4], + [32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] + for input_precision in ["ieee" if is_hip() else "tf32"] + for col_a in [True, False] + for col_b in [True, False] + for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', + 'float32')] + for kpack in [1, 2 if is_hip() else 1]] + [(64, 64, 64, 4, col_a, col_b, 'none', 'ieee', 'float32', 'float32', 1) + for col_a in [True, False] + for col_b in [True, False]] + + [(64, 64, 64, 4, False, False, 'chain-dot', 'ieee', 'bfloat16', 'float32', 1)] + + [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1) + for float8_type in ["float8e5", "float8e4nv"]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, num_ctas, device): + if is_interpreter(): + if in_dtype == 'bfloat16': + pytest.skip("bfloat16 is not supported in the interpreter") + else: + if is_cuda(): + capability = torch.cuda.get_device_capability() + + if capability[0] < 7: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + if capability[0] < 8: + if capability[1] == 0 and in_dtype == 'int8': + pytest.skip("Only test int8 on devices with sm >= 75") + if input_precision != "ieee": + pytest.skip("Only test tf32 on devices with sm >= 80") + if capability[0] == 7: + if (M, N, K, num_warps) in [(128, 256, 32, 8), (64, 128, 128, 4), (64, 128, 128, 2)]: + pytest.skip("shared memory out of resource") + if out_dtype == 'float16': + # TODO: support out_dtype=float16 for tl.dot on V100 + pytest.skip("Only test out_dtype=float16 on devices with sm >=80") + if capability[0] < 9 and in_dtype == 'float8e4nv': + pytest.skip("float8e4nv not supported on sm <= 80") + if is_hip() and (in_dtype == 'float8e4nv' or in_dtype == 'float8e5'): + pytest.skip("float8e4nv and float8e5 not supported on HIP") + if is_hip() and (input_precision != "ieee"): + pytest.skip(f"{input_precision} not supported on HIP") + if is_hip() and (kpack == 2 and in_dtype == 'int8' and K < 64): + pytest.skip("kpack too large for K") + if not is_hip() and kpack == 2: + pytest.skip("Skip duplicated tests on nv path") + + torch.backends.cuda.matmul.allow_tf32 = input_precision == "tf32" + + if num_ctas > 1 and in_dtype == 'int8': + # FIXME: mma v2 with num_ctas > 1 does not work + pytest.skip() + + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, stride_wl, Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ADD_MATRIX: tl.constexpr, + ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, INPUT_PRECISION: tl.constexpr, DO_SOFTMAX: tl.constexpr, + CHAIN_DOT: tl.constexpr, COL_A: tl.constexpr, COL_B: tl.constexpr, out_dtype: tl.constexpr = tl.float32): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_l = tl.arange(0, BLOCK_N) + off_k = tl.arange(0, BLOCK_K) + Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk + Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn + Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + x = tl.load(Xs) + y = tl.load(Ys) + z = tl.dot(x, y, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + if ADD_MATRIX: + z += tl.load(Zs) + if ADD_ROWS: + ZRs = Z + off_m * stride_zm + z += tl.load(ZRs)[:, None] + if ADD_COLS: + ZCs = Z + off_n * stride_zn + z += tl.load(ZCs)[None, :] + if DO_SOFTMAX: + max = tl.max(z, 1) + z = z - max[:, None] + num = tl.exp(z.to(tl.float32)).to(max.dtype) + den = tl.sum(num, 1) + z = num / den[:, None] + if CHAIN_DOT: + w = tl.load(Ws) + z = tl.dot(z.to(w.dtype), w, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + tl.store(Zs, z) + + # input + rs = RandomState(17) + if col_a: + x = numpy_random((K, M), dtype_str=in_dtype, rs=rs).T + else: + x = numpy_random((M, K), dtype_str=in_dtype, rs=rs) + if col_b: + y = numpy_random((N, K), dtype_str=in_dtype, rs=rs).T + else: + y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) + w = numpy_random((N, N), dtype_str=in_dtype, rs=rs) + if 'int' not in in_dtype and 'float8' not in in_dtype: + x *= .1 + y *= .1 + if in_dtype == 'float32' and input_precision == "tf32": + x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') + y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') + w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32') + x_tri = to_triton(x, device=device, dst_type=in_dtype) + y_tri = to_triton(y, device=device, dst_type=in_dtype) + w_tri = to_triton(w, device=device, dst_type=in_dtype) + # triton result + if out_dtype == 'int8': + z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs) + else: + z = 1 + numpy_random((M, N), dtype_str=in_dtype, rs=rs) * .1 + + z_tri = to_triton(z, device=device) + if epilogue == 'trans': + z_tri = torch.as_strided(z_tri, (M, N), [1, M]) + + if out_dtype == 'int8': + out_dtype = tl.int8 + elif out_dtype == 'float16' and epilogue != 'softmax': + # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will + # fail with the following error: 'llvm.fmul' op requires the same type + # for all operands and results + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + + kern_kwargs = { + 'COL_A': col_a, 'COL_B': col_b, 'BLOCK_M': M, 'BLOCK_K': K, 'BLOCK_N': N, 'ADD_MATRIX': + epilogue == 'add-matrix', 'ADD_ROWS': epilogue == 'add-rows', 'ADD_COLS': epilogue == 'add-cols', 'DO_SOFTMAX': + epilogue == 'softmax', 'CHAIN_DOT': epilogue == 'chain-dot', 'INPUT_PRECISION': input_precision, 'num_warps': + num_warps, 'num_ctas': num_ctas, 'out_dtype': out_dtype + } + + if is_hip(): + kern_kwargs['kpack'] = kpack + + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri, + w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), **kern_kwargs) + + if epilogue == 'softmax' and (in_dtype != 'float32' or input_precision == "tf32"): + if not is_cuda(): + pass + else: + ptx = pgm.asm["ptx"] + start = ptx.find("shfl.sync.bfly") + end = ptx.find("cvt.rn.f16.f32") + red_code = ptx[start:end] + assert len(red_code) > 0 + + # skip this check on hopper because there are some functions whose name contain "shared" in ptx. + # TODO: we should eliminate these unused functions in ptx code. + if not (capability[0] >= 9): + assert "shared" not in red_code + assert "bar.sync" not in red_code + # torch result + if in_dtype == 'int8': + z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32) + elif 'float8' in in_dtype: + x = convert_fp8_to_fp32(x, device, in_dtype) + y = convert_fp8_to_fp32(y, device, in_dtype) + z_ref = to_numpy(torch.matmul(x, y)) + else: + z_ref = np.matmul(x, y) + + if epilogue == 'add-matrix': + z_ref += z + if epilogue == 'add-rows': + z_ref += z[:, 0][:, None] + if epilogue == 'add-cols': + z_ref += z[0, :][None, :] + if epilogue == 'softmax': + num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True)) + denom = np.sum(num, axis=-1, keepdims=True) + z_ref = num / denom + if epilogue == 'chain-dot': + if 'float8' in in_dtype: + w = to_numpy(convert_fp8_to_fp32(w, device, in_dtype)) + z_ref = np.matmul(z_ref, w) + # compare + if in_dtype == 'float32': + # XXX: Somehow there's a larger difference when we use float32 + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + elif out_dtype == tl.float16 or in_dtype == 'bfloat16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + # added atol, to loose precision for float16xfloat16->float32 case + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + if not is_cuda(): + return + # make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4): + # XXX: skip small sizes because they are not vectorized + assert 'ld.global.v4' in ptx + if 'float8' in in_dtype: + assert 'st.global.v2' in ptx + else: + assert 'st.global.v4' in ptx + if in_dtype == 'float32' and input_precision != "ieee": + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float32: + if capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float16: + if capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f16.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx) + elif in_dtype == 'int8': + if capability[0] == 7 and capability[1] == 5: # Turing + assert 'mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32' in ptx + else: + assert 'wgmma.mma_async.sync.aligned' in ptx or\ + 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx + elif in_dtype == "float8e5" and out_dtype == tl.float32: + if capability[0] == 9: + assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2' in ptx + elif in_dtype == "float8e4nv" and out_dtype == tl.float32: + if capability[0] == 9: + assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("B", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_warps", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("M, N, K", [(64, 64, 64), (32, 32, 32)]) +@pytest.mark.parametrize("in_dtype_str, out_dtype_str", [('int8', 'int8'), ('float16', 'float16'), + ('float16', 'float32'), ('float32', 'float32')]) +def test_dot3d(B, num_warps, M, N, K, in_dtype_str, out_dtype_str, device): + if is_hip(): + # hip does not support tf32 precision, so use ieee for all tests + input_precision = "ieee" + if "gfx11" in triton.runtime.driver.active.get_current_target().arch: + if in_dtype_str == "float32": + pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d") + if out_dtype_str == "float16": + pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") + else: + input_precision = "tf32" if in_dtype_str == 'float32' else "ieee" + + @triton.jit + def kernel( + q_ptr, + k_ptr, + o_ptr, + stride_qb, + stride_qm, + stride_qk, + stride_kb, + stride_kk, + stride_kn, + stride_ob, + stride_om, + stride_on, + BLOCK_B: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + INPUT_PRECISION: tl.constexpr, + out_dtype: tl.constexpr = tl.float32, + ): + startm = tl.program_id(0) * BLOCK_M + startn = tl.program_id(1) * BLOCK_N + offs_b = tl.arange(0, BLOCK_B) + offs_m = startm + tl.arange(0, BLOCK_M) + offs_n = startn + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + q_ptrs = q_ptr + offs_b[:, None, None] * stride_qb + offs_m[None, :, None] * stride_qm + offs_k[ + None, None, :] * stride_qk + k_ptrs = k_ptr + offs_b[:, None, None] * stride_kb + offs_k[None, :, None] * stride_kk + offs_n[ + None, None, :] * stride_kn + q = tl.load(q_ptrs) + k = tl.load(k_ptrs) + qk = tl.dot(q, k, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + o_ptrs = o_ptr + offs_b[:, None, None] * stride_ob + offs_m[None, :, None] * stride_om + offs_n[ + None, None, :] * stride_on + tl.store(o_ptrs, qk) + + if out_dtype_str == 'int8': + out_dtype = tl.int8 + elif out_dtype_str == 'float16': + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + + rs = RandomState(17) + x = numpy_random((B, M, K), dtype_str=in_dtype_str, rs=rs) + y = numpy_random((B, K, N), dtype_str=in_dtype_str, rs=rs) + if in_dtype_str == 'int8': + out = numpy_random((B, M, N), dtype_str='int32', rs=rs) + else: + out = numpy_random((B, M, N), dtype_str=out_dtype_str, rs=rs) + + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + out_tri = to_triton(out, device=device) + + BLOCK_B = B + BLOCK_M, BLOCK_N = 32, 32 + BLOCK_K = K + + grid = ( + triton.cdiv(M, BLOCK_M), + triton.cdiv(N, BLOCK_N), + ) + kernel[grid]( + x_tri, + y_tri, + out_tri, + x_tri.stride(0), + x_tri.stride(1), + x_tri.stride(2), + y_tri.stride(0), + y_tri.stride(1), + y_tri.stride(2), + out_tri.stride(0), + out_tri.stride(1), + out_tri.stride(2), + BLOCK_B=BLOCK_B, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + INPUT_PRECISION=input_precision, + out_dtype=out_dtype, + num_warps=num_warps, + ) + + if in_dtype_str == 'int8': + out_ref = np.matmul(x.astype(np.float32), y.astype(np.float32)).astype(np.int32) + else: + out_ref = np.matmul(x, y) + np.testing.assert_allclose(out_ref, to_numpy(out_tri), rtol=0.01, atol=1e-2) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +def test_max_num_imprecise_acc(device): + + if not hasattr(torch, 'float8_e5m2'): + pytest.skip(f"torch {torch.__version__} does not support float8_e5m2") + + if is_cuda(): + capability = torch.cuda.get_device_capability() + if capability != (9, 0): + return + + @triton.jit + def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + MAX_NUM_IMPRECISE_ACC: tl.constexpr): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_k = tl.arange(0, BLOCK_K) + x = tl.load(X + off_m[:, None] * BLOCK_K + off_k[None, :]) + y = tl.load(Y + off_k[:, None] * BLOCK_N + off_n[None, :]) + z = tl.load(Z + off_m[:, None] * BLOCK_N + off_n[None, :]) + z = tl.dot(x, y, acc=z, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC) + tl.store(Z + off_m[:, None] * BLOCK_N + off_n[None, :], z) + + M, N, K, num_warps, MAX_NUM_IMPRECISE_ACC = 128, 128, 128, 4, 64 + x = torch.zeros((M, K), dtype=torch.float8_e5m2, device=device) + y = torch.zeros((K, N), dtype=torch.float8_e5m2, device=device) + z = torch.zeros((M, N), dtype=torch.float32, device=device) + h = kernel[(1, 1)](x, y, z, M, N, K, MAX_NUM_IMPRECISE_ACC, num_warps=num_warps) + if not is_cuda(): + return + assert h.asm["ptx"].count("add.f32") == (M * N) // (32 * num_warps) * (K / MAX_NUM_IMPRECISE_ACC) + + +@pytest.mark.parametrize('in_dtype', ['float32']) +def test_dot_mulbroadcasted(in_dtype, device): + if is_cuda(): + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + pytest.skip("Requires sm >= 80 to run") + + @triton.jit + def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BM: tl.constexpr, BN: tl.constexpr, + BK: tl.constexpr): + pidn = tl.program_id(1) + pidm = tl.program_id(0) + offm = tl.arange(0, BM)[:, None] + offn = tl.arange(0, BN)[None, :] + offak = tl.arange(0, BK)[None, :] + offbk = tl.arange(0, BK)[:, None] + acc = tl.full((BM, BN), 0.0, tl.float32) + for ridx5 in range(0, K // BK): + x = tl.load(X + ((pidm * K * BM) + (offm * K) + (ridx5 * BK) + offak)) + y = tl.load(Y + ((pidn * BN) + (offbk * N) + (ridx5 * N * BK) + offn)) + x = tl.expand_dims(x, axis=2) + y = tl.expand_dims(y, axis=0) + t = tl.sum(x * y, axis=1) + acc = t + acc + tl.store(Z + ((pidm * BM * N) + (pidn * BN) + (offm * N) + offn), acc) + + M, N, K = 256, 192, 160 + BM, BN, BK = 128, 32, 32 + rs = RandomState(17) + x = numpy_random((M, K), dtype_str=in_dtype, rs=rs) + y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) + x = x * 0.1 + y = y * 0.1 + z = numpy_random((M, N), dtype_str=in_dtype, rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(z, device=device) + grid = M // BM, N // BN + h = kernel[grid](z_tri, x_tri, y_tri, M, N, K, BM, BN, BK) + z_ref = np.matmul(x, y) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), atol=0.01) + + if not is_cuda(): + return + assert "tt.dot" in h.asm['ttir'] + # When using MMAv3, we will not pipeline the load op for Y, as the loaded + # value is in rowmajor. But MMAv3 requires its second operand is in colmajor + # because transpose is not supported for MMAv3 with float32 input. + if capability[0] >= 9: + assert re.search(r"triton_gpu.async_wait %.* {num = 1 : i32}", h.asm["ttgir"]) is not None + else: + assert re.search(r"triton_gpu.async_wait %.* {num = 2 : i32}", h.asm["ttgir"]) is not None + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) +@pytest.mark.parametrize("shape", [(), (1, ), (128, )]) +def test_full(dtype_str, shape, device): + if dtype_str in uint_dtypes and not hasattr(torch, dtype_str): + # PyTorch only has unsigned 8, but not 16, 32, or 64 + dtype = getattr(torch, dtype_str[1:]) # uintx -> intx + else: + dtype = getattr(torch, dtype_str) + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + @triton.jit + def kernel_static(out): + a = GENERATE_TEST_HERE + tl.static_assert(a.shape == SHAPE) + out_ptr = out + tl.arange(0, 128)[:] + tl.store(out_ptr, a) + + @triton.jit + def kernel_dynamic(out, val, dtype: tl.constexpr): + a = tl.full(SHAPE, val, dtype) + tl.static_assert(a.shape == SHAPE) + out_ptr = out + tl.arange(0, 128)[:] + tl.store(out_ptr, a) + + kernel_static_patched = patch_kernel(kernel_static, { + 'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})", + 'SHAPE': str(list(shape)), + }) + out_static = torch.zeros((128), dtype=dtype, device=device) + kernel_static_patched[(1, )](out_static) + assert torch.all(out_static == 2) + + kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))}) + out_dynamic = torch.zeros((128), dtype=dtype, device=device) + kernel_dynamic_patched[(1, )](out_dynamic, 2, getattr(triton.language, dtype_str)) + assert torch.all(out_dynamic == 2) + + +@pytest.mark.parametrize("literal, dtype_str", [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"), ('float("inf")', "f32"), + ('float("-inf")', "f32"), ('float("nan")', "f32"), + ('float("-nan")', "f32"), (0., "f32"), (5, "i32"), (2**40, "i64")]) +def test_constexpr(literal, dtype_str, device): + + @triton.jit + def kernel(out_ptr): + val = GENERATE_TEST_HERE + tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val) + + kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"}) + out = torch.zeros((1, ), dtype=torch.float32, device=device) + h = kernel_patched[(1, )](out) + assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None + + +@triton.jit +def pass_const(a, b, choose_b): + if choose_b: + return b + else: + return a + + +@pytest.mark.parametrize("choose_const", [True, False]) +@pytest.mark.parametrize("constexpr", [True, False]) +@pytest.mark.parametrize("mode", ["direct", "call", "ternary", "if"]) +def test_const(device, choose_const, constexpr, mode): + + @triton.jit(do_not_specialize=["choose_const"]) + def kernel(in_ptr: tl.const, out, c_out: tl.const, choose_const, n_elems: tl.int32, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elems + val = tl.load(in_ptr + offsets, mask=mask) + LOSE_TAIL + tl.store(final_out + offsets, val, mask=mask) + + @triton.jit + def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.constexpr, n_elems: tl.int32, + BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elems + val = tl.load(in_ptr + offsets, mask=mask) + LOSE_TAIL + tl.store(final_out + offsets, val, mask=mask) + + if mode == "direct": + if choose_const: + LOSE_TAIL = "final_out = c_out" + else: + LOSE_TAIL = "final_out = out" + elif mode == "call": + LOSE_TAIL = "final_out = pass_const(out, c_out, choose_const)" + elif mode == "ternary": + LOSE_TAIL = "final_out = c_out if choose_const else out" + elif mode == "if": + LOSE_TAIL = """ + if choose_const: + final_out = c_out + else: + final_out = out +""" + + SIZE = 128 + input = torch.randn((SIZE, ), dtype=torch.float32, device=device) + output = torch.zeros((SIZE, ), dtype=torch.float32, device=device) + patched_kernel = patch_kernel(kernel_constexpr if constexpr else kernel, {'LOSE_TAIL': LOSE_TAIL, 'CONSTEXPR': ''}) + + expect_fail = (not constexpr and mode != "direct") or choose_const + if expect_fail: + with pytest.raises(triton.CompilationError) as exc_info: + patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) + else: + patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) + assert torch.all(input == output) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ['float32', 'float16']) +def test_dot_without_load(dtype_str, device): + + @triton.jit + def _kernel(out): + a = GENERATE_TEST_HERE + b = GENERATE_TEST_HERE + c = tl.dot(a, b) + out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(out_ptr, c) + + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"}) + a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) + b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) + out_ref = torch.matmul(a, b) + out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](out) + assert torch.all(out == out_ref) + + +# --------------- +# test arange +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("start", [0, 1, 7, 16]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_arange(start, num_ctas, device): + BLOCK = 128 + z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) + + @triton.jit + def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): + off = tl.arange(0, BLOCK) + val = tl.arange(START, END) + tl.store(z + off, val) + + _kernel[(1, )](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK, num_ctas=num_ctas) + z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device) + np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref)) + + +# --------------- +# test load +# --------------- + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, size, size_diff, other", [(dtype_str, size, size_diff, other) + for dtype_str in torch_dtypes + for size in [128, 512] + for size_diff in [0, 1, 2, 3, 4] + for other in [0, 1]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_masked_load(dtype_str, size, size_diff, other, num_ctas, device): + dtype = getattr(torch, dtype_str) + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + input_size = size - size_diff + output_size = size + if dtype_str == 'bool': + input = torch.randint(0, 2, (input_size, ), dtype=dtype, device=device) + elif dtype_str in int_dtypes or dtype_str in uint_dtypes: + input = torch.randint(0, 127, (input_size, ), dtype=dtype, device=device) + else: + input = torch.rand(input_size, dtype=dtype, device=device) + output = torch.zeros((output_size, ), dtype=dtype, device=device) + + @triton.jit + def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): + in_offsets = tl.arange(0, out_size) + # Load inputs. + x = GENERATE_TEST_HERE + # Store output + output_offsets = tl.arange(0, out_size) + tl.store(out_ptr + output_offsets, x) + + mask_str = f"mask=in_offsets < in_size, other={other}" if size_diff > 0 else "None" + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"}) + kernel[(1, )](input, output, input_size, output_size, num_ctas=num_ctas) + + reference_out = torch.cat((input, torch.full((size_diff, ), other, dtype=dtype, device=device))) + torch.testing.assert_close(output, reference_out) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +@pytest.mark.parametrize("mask_val", [True, False]) +@pytest.mark.parametrize("other_val", [0, 1]) +def test_masked_load_scalar(num_ctas, mask_val, other_val, device): + input_val = 4.0 + size = 128 + dtype = torch.float32 + input = torch.full((size, ), input_val, dtype=dtype, device=device) + output = torch.zeros((size, ), dtype=dtype, device=device) + + @triton.jit + def kernel(in_ptr, out_ptr, size: tl.constexpr, mask: tl.constexpr, other: tl.constexpr): + offsets = tl.arange(0, size) + x = tl.load(in_ptr + offsets, mask=mask, other=other) + tl.store(out_ptr + offsets, x) + + kernel[(1, )](input, output, size, mask_val, other_val, num_ctas=num_ctas) + + if mask_val: + reference_out = torch.full((size, ), input_val, dtype=dtype, device=device) + else: + reference_out = torch.full((size, ), other_val, dtype=dtype, device=device) + + torch.testing.assert_close(output, reference_out) + + +# Testing masked loads with an intermate copy to shared memory run. +# FIXME: Shape too small for ldmatrix when num_ctas=4 +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +def test_masked_load_shared_memory(dtype, device): + + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + M = 32 + N = 32 + K = 16 + + in1 = torch.rand((M, K), dtype=dtype, device=device) + in2 = torch.rand((K, N), dtype=dtype, device=device) + out = torch.zeros((M, N), dtype=dtype, device=device) + + @triton.jit + def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_numel, in2_numel, out_numel, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): + + M_offsets = tl.arange(0, M) + N_offsets = tl.arange(0, N) + K_offsets = tl.arange(0, K) + + in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :] + in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :] + + # Load inputs. + x = tl.load(in1_ptr + in_offsets, mask=in_offsets < M * K) + w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < K * N) + + # Without a dot product the memory doesn't get promoted to shared. + o = tl.dot(x, w, out_dtype=tl.float32) + + # Store output + output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :] + tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N) + + pgm = _kernel[(1, )](in1, in2, out, in1.stride()[0], in2.stride()[0], out.stride()[0], in1.numel(), in2.numel(), + out.numel(), M=M, N=N, K=K) + + reference_out = torch.matmul(in1, in2) + torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("cache", ["", ".ca", ".cg"]) +def test_load_cache_modifier(cache, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, CACHE: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets, cache_modifier=CACHE) + tl.store(dst + offsets, x) + + pgm = _kernel[(1, )](dst, src, CACHE=cache) + if not is_cuda(): + return + + ptx = pgm.asm['ptx'] + if cache == '': + assert 'ld.global.ca' not in ptx + assert 'ld.global.cg' not in ptx + if cache == '.cg': + assert 'ld.global.cg' in ptx + assert 'ld.global.ca' not in ptx + if cache == '.ca': + assert 'ld.global.ca' in ptx + assert 'ld.global.cg' not in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("N", [16, 10, 11, 1024]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_vectorization(N, num_ctas, device): + block_size = 1024 * num_ctas + src = torch.empty(block_size, device=device) + dst = torch.empty(block_size, device=device) + + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + pgm = _kernel[(1, )](dst, src, N=N, BLOCK_SIZE=block_size) + + if not is_cuda(): + return + + ptx = pgm.asm["ptx"] + if N % 16 == 0: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.b32" in ptx + # np.testing.assert_allclose(dst, src[:N]) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("has_hints", [False, True]) +def test_vectorization_hints(has_hints, device): + src = torch.empty(1024, device=device) + dst = torch.empty(1024, device=device) + off = torch.zeros(1, device=device, dtype=torch.int32) + + @triton.jit + def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offsets = offsets + tl.load(off) + if HINT: + tl.max_contiguous(tl.multiple_of(offsets, 1024), 1024) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + pgm = _kernel[(1, )](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints) + if not is_cuda(): + return + + ptx = pgm.asm["ptx"] + if has_hints: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.v4.b32" not in ptx + + +# --------------- +# test store +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs", ".wt"]) +def test_store_cache_modifier(cache, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, CACHE: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets) + tl.store(dst + offsets, x, cache_modifier=CACHE) + + if not is_cuda(): + return + pgm = _kernel[(1, )](dst, src, CACHE=cache) + ptx = pgm.asm['ptx'] + if cache == '': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.wb': + assert 'st.global.wb' in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cg': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cs': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' in ptx + assert 'st.global.wt' not in ptx + if cache == '.wt': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' in ptx + + +# --------------- +# test default +# --------------- +# TODO: can't be local to test_default + + +@triton.jit +def _impl(value=10): + return value + + +@pytest.mark.interpreter +def test_default(device): + value = 5 + ret0 = torch.zeros(1, dtype=torch.int32, device=device) + ret1 = torch.zeros(1, dtype=torch.int32, device=device) + + @triton.jit + def _kernel(ret0, ret1, value=3): + tl.store(ret0, _impl()) + tl.store(ret1, _impl(value)) + + _kernel[(1, )](ret0, ret1, value) + assert ret0.item() == 10 + assert ret1.item() == value + + _kernel[(1, )](ret0, ret1) + assert ret0.item() == 10 + assert ret1.item() == 3 + + +# --------------- +# test noop +# ---------------- + + +@pytest.mark.interpreter +def test_noop(device): + + @triton.jit + def kernel(x): + pass + + x = to_triton(numpy_random((1, ), dtype_str='int32'), device=device) + kernel[(1, )](x) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned']) +def test_pointer_arguments(device): + + @triton.jit + def kernel(x): + pass + + pin_memory = 'pinned' in device + x = torch.empty(1024, device=device.split('_')[0], pin_memory=pin_memory) + if device == "cpu": + with pytest.raises(ValueError): + kernel[(1, )](x) + else: + kernel[(1, )](x) + + +@pytest.mark.parametrize("value, value_type", [(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'), + (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')]) +def test_value_specialization(value: int, value_type: str, device) -> None: + + def repr(specialization): + spec_type = specialization.signature["VALUE"] + return f"kernel_{spec_type}" + + @triton.jit(repr=repr) + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device=device) + h = kernel[(1, )](value, x) + assert value_type in h.name + + +# -------------------- +# value specialization +# -------------------- + + +@pytest.mark.parametrize("value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]) +def test_value_specialization_overflow(value: int, overflow: bool, device) -> None: + + @triton.jit + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device=device) + + if overflow: + with pytest.raises(OverflowError): + kernel[(1, )](value, x) + else: + kernel[(1, )](value, x) + + +# ---------------- +# test constexpr +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) +@pytest.mark.parametrize("is_lhs_constexpr", [False, True]) +@pytest.mark.parametrize("is_rhs_constexpr", [True, False]) +def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr, device): + + @triton.jit + def kernel(Z, X, Y): + x = tl.load(X) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z, z) + + if op in ['<<', '>>', '&', '^', '|']: # int op + x_str = "3" if is_lhs_constexpr else "x" + y_str = "4" if is_rhs_constexpr else "y" + x = numpy_random((1, ), dtype_str="int32") + + # NOTE: bitshifting beyond bitwidth can lead to undefined behavior + if op in ['<<', '>>']: + y = numpy_random((1, ), dtype_str="int32", low=0, high=_bitwidth("int32")) + else: + y = numpy_random((1, ), dtype_str="int32") + else: + x_str = "3.14" if is_lhs_constexpr else "x" + y_str = "4.13" if is_rhs_constexpr else "y" + x = numpy_random((1, ), dtype_str="float32") + y = numpy_random((1, ), dtype_str="float32") + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"}) + z = np.array(eval(f"{x_str} {op} {y_str}")) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(np.empty((1, ), dtype=z.dtype), device=device) + kernel[(1, )](z_tri, x_tri, y_tri) + np.testing.assert_allclose(z, to_numpy(z_tri), rtol=1e-3) + + +@pytest.mark.interpreter +def test_constexpr_shape(device): + + @triton.jit + def kernel(X): + off = tl.arange(0, 128 + 128) + tl.store(X + off, off) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) + kernel[(1, )](x_tri) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) + + +@pytest.mark.interpreter +def test_constexpr_scalar_shape(device): + + @triton.jit + def kernel(X, s): + off = tl.arange(0, 256) + val = off % (256 // s) + tl.store(X + off, val) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) + kernel[(1, )](x_tri, 32) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8) + + +reshape_list = [((64, ), (8, 8)), ((2, 32), (16, 4)), ((512, ), (2, 2, 2, 2, 2, 2, 2, 2, 2)), ((64, 32), (16, 8, 16))] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("formats", reshape_list) +def test_reshape(formats, device): + in_format, out_format = formats + + @triton.jit + def kernel(Z, X, out_tuple: tl.constexpr): + x = tl.load(X_PTR_EXPR) + z = tl.reshape(x, out_tuple) + tl.store(Z_PTR_EXPR, z) + + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + } + return patch_kernel(kernel, to_replace) + + x = numpy_random(in_format, dtype_str="int32") + z = x.reshape(out_format) + x_tri = to_triton(x, device=device) + patched_kernel = generate_kernel(in_format, out_format) + z_tri = to_triton(np.empty(out_format, dtype=np.int32), device=device) + patched_kernel[(1, )](z_tri, x_tri, out_format) + np.testing.assert_equal(z, to_numpy(z_tri)) + + +def test_reshape_err(device): + + @triton.jit + def kernel(): + x = tl.arange(0, 8 * 8) + y = tl.reshape(x, (8 * 4, )) + + with pytest.raises(triton.CompilationError) as exc_info: + kernel[(1, )]() + + assert "reshape" in str(exc_info.value) + + +def test_trans_reshape(device): + + @triton.jit + def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.constexpr): + + in_block_ptr = tl.make_block_ptr( + base=in_base_ptr, + shape=(IN_SHAPE0, IN_SHAPE1), + strides=(IN_SHAPE1, 1), + offsets=(0, 0), + block_shape=(IN_SHAPE0, IN_SHAPE1), + order=(1, 0), + ) + x = tl.load(in_block_ptr) + x = tl.reshape(x, (32, 4, 4, 2)) + x = tl.permute(x, (1, 2, 3, 0)) + x = tl.reshape(x, (IN_SHAPE0 * IN_SHAPE1, )) + tl.store(out_base_ptr + tl.arange(0, IN_SHAPE0 * IN_SHAPE1), x) + + shape = (32, 32) + input = torch.arange(math.prod(shape), dtype=torch.int32, device=device).reshape(shape) + expected = torch.permute(input, (1, 0)) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=torch.int32, device=device) + + k = kernel[(1, )](input, actual, shape[0], shape[1]) + assert k.asm['ttgir'].count( + 'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +# ------------- +# test call +# ------------- + + +@triton.jit +def val_multiplier(val, i): + return val * i + + +@triton.jit(noinline=True) +def val_multiplier_noinline(val, i): + return val * i + + +@triton.jit +def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr): + pid = tl.program_id(axis=0) + offsets = pid * 128 + tl.arange(0, 128) + mask = offsets < n_elements + vec = tl.load(ptr + offsets, mask=mask) + for i in range(1, rep): + if type == "inline": + vec = val_multiplier(vec, i) + else: + vec = val_multiplier_noinline(vec, i) + tl.store(ptr + offsets, vec, mask=mask) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("type", ["inline", "noinline"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_call(type, num_ctas, device): + + @triton.jit + def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): + vecmul_kernel(ptr, n_elements, num1, type) + vecmul_kernel(ptr, n_elements, num2, type) + + size = 1024 + rand_val = numpy_random((size, ), dtype_str="float32") + rand_val_tri = to_triton(rand_val, device=device) + err_msg = "" + try: + kernel[(size // 128, )](rand_val_tri, size, 3, 5, type, num_ctas=num_ctas) + except Exception as e: + err_msg = str(e) + + if type == "noinline" and not is_interpreter(): + assert err_msg != "" + else: + ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4 + np.testing.assert_equal(to_numpy(rand_val_tri), ans) + + +# ------------- +# test if +# ------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("if_type", [ + "if", "if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", + "if_and_static" +]) +def test_if(if_type, device): + + @triton.jit + def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr): + pid = tl.program_id(0) + cond = tl.load(Cond) + if IfType == "if": + if pid % 2 == 0: # eq + tl.store(Ret, tl.load(XTrue)) + elif 1 == pid % 2: # req + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_dynamic": + val = tl.load(XTrue) if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_constexpr": + val = 3.14 if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_void": + tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_static": + tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_dynamic": + if BoolVar and (1 != pid % 2 and pid % 2 != 1): # rne and ne + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_static": + if StaticVaue != 0 and StaticVaue != 0: + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + + cond = torch.ones(1, dtype=torch.int32, device=device) + x_true = torch.tensor([3.14], dtype=torch.float32, device=device) + x_false = torch.tensor([1.51], dtype=torch.float32, device=device) + ret = torch.zeros(1, dtype=torch.float32, device=device) + + kernel[(1, )](cond, x_true, x_false, ret, if_type, True, 1) + assert torch.equal(ret, x_true) + + +def test_num_warps_pow2(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + pass + + with pytest.raises(AssertionError, match='must be a power of 2'): + _kernel[(1, )](dst=dst, num_warps=3) + _kernel[(1, )](dst=dst, num_warps=1) + _kernel[(1, )](dst=dst, num_warps=2) + _kernel[(1, )](dst=dst, num_warps=4) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func_str", ['sqrt', 'rsqrt', 'exp', 'exp2', 'log', 'log2', 'sin', 'cos']) +def test_unary_math(func_str, device): + + if is_musa(): + # torch_musa does not support aten::exp2.out + device = "cpu" + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.FUNC_STR(x) + tl.store(Y + tl.arange(0, BLOCK), y) + + kernel = patch_kernel(kernel, {'FUNC_STR': func_str}) + + shape = (128, ) + x = torch.randn(shape, dtype=torch.float32, device=device) + if func_str in ['sqrt', 'rsqrt']: + x = torch.abs(x) + if func_str in ['log', 'log2']: + x = torch.max(x, torch.tensor(1e-6, dtype=torch.float32, device=device)) + y = torch.zeros(shape, dtype=torch.float32, device=device) + + if is_musa(): + x_musa = x.musa() + y_musa = y.musa() + kernel[(1, )](x_musa, y_musa, BLOCK=shape[0]) + else: + kernel[(1, )](x, y, BLOCK=shape[0]) + torch.allclose(getattr(torch, func_str)(x), y, rtol=1e-3) + + +# ----------------------- +# test inline asm +# ----------------------- + + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm(num_ctas, device): + if not is_cuda(): + pytest.skip("test_inline_asm is only supported in CUDA") + + @triton.jit + def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + s = tl.full([BLOCK], n, tl.int32) + z = tl.inline_asm_elementwise("shf.l.wrap.b32 $0, $1, $2, $3;", "=r,r, r, r", [x, y, s], dtype=tl.int32, + is_pure=True, pack=1) + tl.store(Z + tl.arange(0, BLOCK), z) + + shape = (128, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint32', rs=rs) + y = numpy_random(shape, dtype_str='uint32', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + n = 17 + z_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, z_tri, n, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = (y << n) | (x >> (32 - n)) + # compare + np.testing.assert_equal(y_ref, to_numpy(z_tri)) + + +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm_packed(num_ctas, device): + if not is_cuda(): + pytest.skip("test_inline_asm is only supported in CUDA") + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # shift 4x8bits values together. + y = tl.inline_asm_elementwise( + "and.b32 $0, $1, 0x1F1F1F1F; \ + shl.b32 $0, $0, 3;", "=r,r", [ + x, + ], dtype=tl.int8, is_pure=True, pack=4) + tl.store(Y + tl.arange(0, BLOCK), y) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + + +@pytest.mark.parametrize('num_ctas', num_ctas_list) +def test_inline_asm_with_pointers(num_ctas, device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x_ptrs = X + tl.arange(0, BLOCK) + y_ptrs = Y + tl.arange(0, BLOCK) + tl.inline_asm_elementwise( + "ld.global.b8 $0, [$1]; \ + shl.b32 $0, $0, 3; \ + st.global.b8 [$2], $0;", "=r,l,l", [x_ptrs, y_ptrs], dtype=tl.int8, is_pure=False, + pack=1) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + + +def test_inline_asm_multiple_outputs(device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + # C = A - B + # D = B - A + (c, d) = tl.inline_asm_elementwise( + asm=""" + sub.u32 $0, $2, $3; // C = A - B + sub.u32 $1, $3, $2; // D = B - A + """, + constraints=( + # 2 output registers: $0=C and $1=D. + "=r,=r," + # 2 input registers: $2=A and $3=B. + "r,r"), + args=[a, b], + dtype=(tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint32', rs=rs) + B = numpy_random(shape, dtype_str='uint32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A - B + D_ref = B - A + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +def test_inline_asm_packed_multiple_outputs(device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint8', rs=rs) + B = numpy_random(shape, dtype_str='float32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='int32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='float32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A.astype(np.int32) + D_ref = np.maximum(A.astype(np.float32), B) + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +# ----------------------- +# test control flow +# ----------------------- + + +@pytest.mark.parametrize("lo, hi, iv", [(2**35, 2**35 + 20, 1), (2**35, 2**35 + 20, 2), (2**35, 2**35 + 20, 3), + (15, -16, -1), (15, -16, -2), (15, -16, -3), (-18, -22, -1), (22, 18, -1)]) +def test_for_iv(lo, hi, iv, device): + + @triton.jit + def kernel(Out, lo, hi, iv: tl.constexpr): + acc = 0 + acc = acc.to(tl.int64) + for i in range(lo, hi, iv): + acc += i + tl.store(Out, acc) + + lo = 2**35 + hi = 2**35 + 20 + out = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + kernel[(1, )](out, lo, hi, iv) + assert out[0] == sum(range(lo, hi, iv)) + + +@pytest.mark.interpreter +def test_if_else(device): + + @triton.jit + def kernel(Cond, TrueVal, FalseVal, Out): + if tl.load(Cond): + val = tl.load(TrueVal) + else: + val = tl.load(FalseVal) + tl.store(Out, val) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + true_val = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + false_val = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + cond = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + # True + cond[0] = True + kernel[(1, )](cond, true_val, false_val, out) + assert to_numpy(out)[0] == true_val[0] + # False + cond[0] = False + kernel[(1, )](cond, true_val, false_val, out) + assert to_numpy(out)[0] == false_val[0] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("mode", ["dynamic", "static"]) +def test_if_return(mode, device): + + @triton.jit + def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr): + if mode == "dynamic": + if tl.load(ExitEarly): + tl.store(Out, 0) + return + else: + if cond: + tl.store(Out, 0) + return + tl.store(Out, 1) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + exit_early = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + # exit early path taken + exit_early[0] = 1 + kernel[(1, )](exit_early, out, True, mode) + assert to_numpy(out)[0] == 0 + # exit early path not taken + exit_early[0] = 0 + kernel[(1, )](exit_early, out, False, mode) + assert to_numpy(out)[0] == 1 + + +@triton.jit +def add_fn(x): + return x + 1 + + +@triton.jit(noinline=True) +def add_fn_noinline(x): + return x + 1 + + +@triton.jit +def add_fn_return(x, pid): + if pid == 0: + return x + 1 + else: + return x + 2 + + +@triton.jit +def add_fn_expr(Out, x): + tl.store(Out, x) + + +@triton.jit +def add_fn_static_cond(x, cond: tl.constexpr): + if cond == "": + return x + else: + return x + 1 + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize( + "call_type", + ["attribute", "attribute_jit", "jit", "jit_if", "jit_expr", "jit_static_cond", "jit_noinline", "jit_extern"]) +def test_if_call(call_type, device): + + @triton.jit + def kernel(Out, call_type: tl.constexpr): + pid = tl.program_id(0) + o = tl.load(Out) + if call_type == "attribute": + # call attribute + if pid == 0: + a = o + a = a.to(tl.int32).to(tl.int32) + 1 + o = a + elif call_type == "attribute_jit": + # call attribute and jit function + if pid == 0: + a = o + a = tl.load(Out + add_fn(a) - 1).to(tl.int32) + 1 + o = a + elif call_type == "jit": + if pid == 0: + # regular function call + a = o + a = add_fn(a) + o = a + elif call_type == "jit_if": + # function without end_if block + if pid == 0: + a = o + a = add_fn_return(a, pid) + o = a + elif call_type == "jit_if_exp": + # ifexp expression + if pid == 0: + a = o + a = add_fn(a) if pid == 0 else add_fn_return(a, pid) + o = a + elif call_type == "jit_expr": + # call without return + if pid == 0: + a = o + 1 + add_fn_expr(Out, a) + o = a + elif call_type == "jit_static_cond": + if pid == 0: + a = o + 1 + add_fn_static_cond(o, call_type) + o = a + elif call_type == "jit_noinline": + if pid == 0: + a = o + 1 + add_fn_noinline(a) + o = a + elif call_type == "jit_extern": + if pid == 0: + a = o + 1 + tl.cdiv(a, a) + o = a + + tl.store(Out, o) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + kernel[(1, )](out, call_type) + assert to_numpy(out)[0] == 1 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("_cond1", [True, False]) +@pytest.mark.parametrize("_cond2", [True, False]) +@pytest.mark.parametrize("_cond3", [True, False]) +def test_nested_if_else_return(_cond1, _cond2, _cond3, device): + + @triton.jit + def kernel(Cond1, Cond2, Cond3, Val1, Val2, Val3, Out): + val = 0 + if tl.load(Cond1): + if tl.load(Cond2): + val = tl.load(Val1) + else: + return + else: + if tl.load(Cond3): + val = tl.load(Val2) + else: + val = tl.load(Val3) + tl.store(Out, val) + + out = to_triton(np.full((1, ), -1, dtype=np.int32), device=device) + cond1 = to_triton(np.full((1, ), _cond1, dtype=np.int32), device=device) + cond2 = to_triton(np.full((1, ), _cond2, dtype=np.int32), device=device) + cond3 = to_triton(np.full((1, ), _cond3, dtype=np.int32), device=device) + val1 = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + val2 = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + val3 = to_triton(np.full((1, ), 3, dtype=np.int32), device=device) + kernel[(1, )](cond1, cond2, cond3, val1, val2, val3, out) + targets = { + (True, True, True): val1[0], + (True, True, False): val1[0], + (True, False, True): out[0], + (True, False, False): out[0], + (False, True, True): val2[0], + (False, True, False): val3[0], + (False, False, True): val2[0], + (False, False, False): val3[0], + } + assert out[0] == targets[(_cond1, _cond2, _cond3)] + + +@pytest.mark.interpreter +def test_while(device): + + @triton.jit + def kernel(InitI, Bound, CutOff, OutI, OutInitI, OutJ): + init_i = tl.load(InitI) + curr_i = init_i + j = 0 + # Check that init_i is not updated by the loop + while j < tl.load(Bound): + curr_i = curr_i + (j == tl.load(CutOff)) + j += 1 + tl.store(OutInitI, init_i) + tl.store(OutI, curr_i) + tl.store(OutJ, j) + + out_i = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + out_j = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + init_i = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + out_init_i = to_triton(np.full((1, ), 0, dtype=np.int32), device=device) + bound = to_triton(np.full((1, ), 10, dtype=np.int32), device=device) + cut_off = to_triton(np.full((1, ), 5, dtype=np.int32), device=device) + kernel[(1, )](init_i, bound, cut_off, out_i, out_init_i, out_j) + assert out_init_i[0] == init_i[0] + assert out_i[0] == init_i[0] + 1 + assert out_j[0] == bound[0] + + +@pytest.mark.interpreter +def test_nested_while(device): + + @triton.jit + def nested_while(data, countPtr): + for i in range(10): + count = tl.load(countPtr) + while count > 0: + tl.store(data, tl.load(data) + 1.0) + count = count - 2 + + counter = torch.tensor([8], dtype=torch.int32, device=device) + data = torch.zeros((1, ), device=device, dtype=torch.float32) + nested_while[(1, )](data, counter) + assert data[0] == 40 + + +# ----------------------- +# test extra +# ----------------------- + + +def test_num_threads(device): + if is_hip(): + pytest.skip("test_num_threads is not supported in HIP") + + @triton.jit + def kernel(Out): + num_threads: tl.constexpr = tl.extra.cuda.num_threads() + offs = tl.arange(0, num_threads) + tl.store(Out + offs, 1) + + num_threads = 256 + out = to_triton(np.zeros((num_threads, ), dtype=np.int32), device=device) + kernel[(1, )](out, num_warps=num_threads // 32) + assert torch.sum(out) == 256 + + +def test_globaltimer(device): + if is_hip(): + pytest.skip("test_globaltimer is not supported in HIP") + check_cuda_or_hip(device) + + @triton.jit + def kernel(Out1, Out2): + start = tl.extra.cuda.globaltimer() + off = tl.arange(0, 128) + for i in range(10000): + tl.store(Out1 + off, tl.load(Out1 + off) + 1) + end = tl.extra.cuda.globaltimer() + tl.store(Out2, end - start) + + out1 = to_triton(np.zeros((128, ), dtype=np.int64), device=device) + out2 = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + h = kernel[(1, )](out1, out2) + assert out2[0] > 0 + assert h.asm["ptx"].count("%globaltimer") == 2 + + +def test_smid(device): + if is_hip(): + pytest.skip("test_smid is not supported in HIP") + check_cuda_or_hip(device) + + @triton.jit + def kernel(Out): + tl.store(Out + tl.program_id(0), tl.extra.cuda.smid()) + + out = to_triton(np.zeros((1024, ), dtype=np.int32), device=device) + h = kernel[(out.shape[0], )](out) + assert out.sort()[0].unique().shape[0] > 0 + assert h.asm["ptx"].count("%smid") == 1 + + +# ----------------------- +# test layout conversions +# ----------------------- +# TODO: backend should be tested separately + +layouts = [ + BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [2, THREADS_PER_WARP // 2], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([8, 1], [16, THREADS_PER_WARP // 16], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), +] + +intermediate_layouts = [ + None, + SharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), +] + + +def compute_rep_shape(layout): + if type(layout) is BlockedLayout: + warp_shape = np.multiply(layout.sz_per_thread, layout.threads_per_warp) + rep_shape = np.multiply(warp_shape, layout.warps_per_cta) + return rep_shape + else: + assert False, "TODO: support compute_rep_shape for layout " + str(type(layout)) + + +# This function gives a lower bound approximation of scratch buffer shape for convert_layout operation +def compute_scratch_buffer_shape(src_layout, dst_layout, shape): + src_rep_shape = compute_rep_shape(src_layout) + dst_rep_shape = compute_rep_shape(dst_layout) + full_scratch_shape = np.maximum(src_rep_shape, dst_rep_shape) + return np.minimum(full_scratch_shape, shape) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]]) +@pytest.mark.parametrize("dtype", ['float16']) +@pytest.mark.parametrize("src_layout", layouts) +@pytest.mark.parametrize("interm_layout", intermediate_layouts) +@pytest.mark.parametrize("dst_layout", layouts) +def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): + if (M == 1 or N == 1) and interm_layout: + # TODO(jlebar): These OOB accesses don't even hit an assert in the + # compiler, and some of them return the wrong result instead of + # crashing! + pytest.skip("Out of bound access when maxPhase > 1") + if str(src_layout) == str(dst_layout): + pytest.skip() + if is_hip(): + try: + scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N)) + except AssertionError: + pytest.skip("Can't compute scratch buffer size") + lds_size = 65536 + # consider int32 dtype in scratch buffer size, + # because it is the largest dtype used in convert_layout in this test + int32_size = 4 + # skip even if scratch buffer equal to lds_size, because real scratch buffer is typically larger due to padding + if scratch_shape[0] * scratch_shape[1] * int32_size >= lds_size: + pytest.skip("Scratch buffer is too large") + + layouts = f""" + #src = {src_layout} + #dst = {dst_layout} + """ if interm_layout is None else f""" + #src = {src_layout} + #interm = {interm_layout} + #dst = {dst_layout} + """ + + conversion = f""" + %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ if interm_layout is None else f""" + %15 = triton_gpu.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !tt.memdesc<{M}x{N}xi32, #interm> + %16 = triton_gpu.local_load %15 : !tt.memdesc<{M}x{N}xi32, #interm> -> tensor<{M}x{N}xi32, #src> + %17 = triton_gpu.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !tt.memdesc<{M}x{N}xf16, #interm> + %18 = triton_gpu.local_load %17 : !tt.memdesc<{M}x{N}xf16, #interm> -> tensor<{M}x{N}xf16, #src> + + %12 = triton_gpu.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ + + ir = layouts + f""" + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} +}} +""" + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x, device=device) + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) + + assert torch.equal(z, x) + + +mma_pairs = [ + [ + MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 1), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 1), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 1), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 1), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + # Mma -> mma support is TODO on Hopper (and Volta) + # [ + # MmaLayout((3, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], + # [ + # MmaLayout((3, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], + # [ + # MmaLayout((3, 1), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 1), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], + # [ + # MmaLayout((3, 1), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 1), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], +] + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]]) +@pytest.mark.parametrize("dtype", ['float16']) +@pytest.mark.parametrize("mma_pair", mma_pairs) +def test_convertmma2mma(M, N, mma_pair, dtype, device): + if is_hip(): + pytest.skip("test_mma2mma is not supported in HIP") + + src_layout, _ = mma_pair + num_warps = np.cumprod(src_layout.warps_per_cta)[-1] + + def do_test(src_layout, dst_layout): + layouts = f""" + #src = {src_layout} + #dst = {dst_layout} + """ + + conversion = f""" + %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ + + ir = layouts + f""" + module attributes {{"triton_gpu.num-warps" = {num_warps} : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} + }} + """ + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x) + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) + + assert torch.equal(z, x) + + do_test(mma_pair[0], mma_pair[1]) + do_test(mma_pair[1], mma_pair[0]) + + +@pytest.mark.interpreter +def test_load_scalar_with_mask(device): + + @triton.jit + def kernel(Input, Index, Out, N: int): + index = tl.load(Index) + scalar = tl.load(Input + index, mask=index < N, other=0) + tl.store(Out, scalar, mask=index < N) + + Index = torch.tensor([0], dtype=torch.int32, device=device) + Input = torch.tensor([0], dtype=torch.int32, device=device) + Out = torch.empty_like(Index, device=device) + kernel[(1, )](Input, Index, Out, Index.numel()) + assert Out.data[0] == 0 + + +# This test is used to test our own PTX codegen for float16 and int16 conversions +# maybe delete it later after ptxas has been fixed +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.parametrize("dtype_str", ['float16', 'int16']) +def test_ptx_cast(dtype_str, device): + + @triton.jit + def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + rbase = tl.arange(0, RBLOCK)[None, :] + x0 = xindex + _tmp4 = (tl.zeros([XBLOCK, RBLOCK], dtype) - 10000).to(dtype) + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + r1 = rindex + tmp0 = tl.load(in_ptr0 + (r1 + (197 * x0)), rmask & xmask).to(dtype) + tmp1 = 2 + tmp2 = tmp0 * tmp1 + tmp3 = tmp2.to(dtype) + tmp5 = _tmp4 < tmp3 + _tmp4 = tl.where(rmask & xmask & tmp5, tmp3, _tmp4) + tl.store(out_ptr2 + (r1 + (197 * x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask) + + torch.manual_seed(123) + if dtype_str == 'int16': + torch_dtype = torch.int16 + triton_dtype = tl.int32 + else: + torch_dtype = torch.float16 + triton_dtype = tl.float32 + + s0 = 4 + buf11 = -torch.ones((6 * s0, 197, 197), device=device, dtype=torch_dtype) + buf14 = -torch.ones((s0, 6, 197, 197), device=device, dtype=torch_dtype) + kernel[(4728, )](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2) + assert buf14.to(torch.float32).mean() == -2.0 + + +# ----------------------- +# test fp8 -> fp32 dot +# ----------------------- + + +def f8_to_f16(x, dtype): + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + +@triton.jit +def matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + low_precision_acc: tl.constexpr, # + num_pipeline_stages: tl.constexpr = 3 # +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_pipeline_stages): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_ptrs, accumulator) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv', 'float8e4b15']) +@pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128]) +def test_fp8_dot_acc(in_type_str, low_precision_acc, device): + if is_hip(): + pytest.skip('test_fp8_dot_acc for HIP currently broken in upstream.') + if is_cuda(): + cc = torch.cuda.get_device_capability() + if cc[0] >= 9 and in_type_str == "float8e4b15": + pytest.skip("Dot op does not support fp8e4b15 on CUDA arch >= 90") + check_type_supported(in_type_str, device) + M, N, K = 128, 256, 256 + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 128 + A = numpy_random((M, K), dtype_str=in_type_str) + B = numpy_random((K, N), dtype_str=in_type_str) + C = torch.empty((M, N), dtype=torch.float32, device=device) + num_warps = 8 + a = to_triton(A, device=device, dst_type=in_type_str) + b = to_triton(B, device=device, dst_type=in_type_str) + grid = (triton.cdiv(M, BLOCK_M), 1) + matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), C.stride(1), + BLOCK_M, BLOCK_N, BLOCK_K, low_precision_acc, num_warps=num_warps) + torch_a = torch.from_numpy(A).to(device=device) + th_a = f8_to_f16(torch_a, in_type_str) + torch_b = torch.from_numpy(B).to(device=device) + th_b = f8_to_f16(torch_b, in_type_str) + ref_out = torch.matmul(th_a, th_b).to(torch.float32) + if in_type_str == 'float8e4nv': + torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01) + elif low_precision_acc > 32: + torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) + else: + torch.testing.assert_close(ref_out, C) + + +# ----------------------- +# test enable_fp_fusion +# ----------------------- + + +@pytest.mark.parametrize("enable_fp_fusion", [False, True]) +def test_enable_fp_fusion(enable_fp_fusion, device): + if is_hip(): + pytest.skip( + 'test_enable_fp_fusion for HIP currently broken in https://github.com/triton-lang/triton. Use https://github.com/ROCmSoftwarePlatform/triton' + ) + + # Sequential multiply add can be fused by backend + @triton.jit + def mul_add(data): + ptrs = data + tl.arange(0, 128) + tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0) + + data = torch.randn((128, ), device=device, dtype=torch.float32) + h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion) + + if not is_cuda(): + return + found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None + assert found_fma == enable_fp_fusion + + +# ----------------------- +# test propagate_nan +# ----------------------- + + +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +@pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL']) +@pytest.mark.parametrize("func", ['minimum', 'maximum', 'clamp']) +def test_propagate_nan(dtype, propagate_nan, func, device): + + @triton.jit + def kernel(A, B, C, propagate_nan: tl.constexpr, func: tl.constexpr): + if func == 'clamp': + tl.store( + C, + getattr(tl, func)(tl.load(A), -tl.load(B), tl.load(B), + propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + else: + tl.store(C, + getattr(tl, func)(tl.load(A), tl.load(B), propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + + for mode in ['A', 'B', 'both']: + if func == 'clamp' and mode == 'B': + # clamp does not guarantee propagation from 'min' and 'max' args + continue + A = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) + if mode == 'A' or mode == 'both': A[0] = torch.nan + B = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) + if mode == 'B' or mode == 'both': B[0] = torch.nan + C = torch.zeros_like(A, device=device, dtype=getattr(torch, dtype)) + kernel[(1, )](A, B, C, propagate_nan, func) + + if mode == 'both' or propagate_nan == 'ALL': + assert torch.isnan(C[0]) + else: + assert not torch.isnan(C[0]) + + +# ----------------------- +# test clamp +# ----------------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +def test_clamp(dtype, device): + + @triton.jit + def kernel(x_ptr, min_ptr, max_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + min = tl.load(min_ptr + off, mask=mask) + max = tl.load(max_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, min, max), mask=mask) + ref_val = tl.minimum(tl.maximum(x, min), max) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + a = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + b = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + min = torch.min(a, b) + max = torch.max(a, b) + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, min, max, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# Test for symmetric clamp(x, -limit, limit), as it may go through optimized +# codegen in the backends +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +def test_clamp_symmetric(dtype, device): + + @triton.jit + def kernel(x_ptr, limit_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + limit = tl.load(limit_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, -limit, limit), mask=mask) + ref_val = tl.minimum(tl.maximum(x, -limit), limit) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + limit = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)).abs() + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, limit, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# ----------------------- +# test iterators +# ----------------------- + + +@pytest.mark.interpreter +def test_static_range(device): + + @triton.jit + def loop_kernel(Z, N: tl.constexpr, step: tl.constexpr): + acc = 0 + for i in tl.static_range(0, N, step=step): + acc += i + tl.store(Z, acc) + + N = 100 + step = 7 + Out = torch.empty(1, dtype=torch.int32, device=device) + loop_kernel[(1, )](Out, N, step) + Acc = torch.tensor([0], dtype=torch.int32, device=device) + for i in range(0, N, step): + Acc += i + assert (Out == Acc).all(), (Out, Acc) + + +@pytest.mark.skip(reason="Random error") +@pytest.mark.interpreter +def test_tl_range(device): + if is_hip(): + pytest.skip("test_tl_range is not supported in HIP") + M, N, K = 64, 64, 512 + BLOCK_M, BLOCK_N, BLOCK_K = M, N, 64 + a = torch.randn((M, K), device=device, dtype=torch.float16) + b = torch.randn((K, N), device=device, dtype=torch.float16) + c = torch.empty((M, N), dtype=torch.float32, device=device) + pgm = matmul_kernel[ + 1, + ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N, + BLOCK_K, 0, num_pipeline_stages=5) + ref_out = torch.matmul(a, b).to(torch.float32) + if is_interpreter(): + # GPU invokes tensor core for float16 matmul, which is not supported in interpreter. + # Thus we use a higher tolerance + torch.testing.assert_close(ref_out, c, rtol=1e-2, atol=1e-1) + else: + torch.testing.assert_close(ref_out, c, rtol=1e-3, atol=1e-3) + if device in ['cuda']: + capability = torch.cuda.get_device_capability() + if capability[0] >= 8: + ptx = pgm.asm['ptx'] + # check that the loop got pipelined with the right number of stages. + assert 'cp.async.wait_group 0x6' in ptx + + +@triton.jit(noinline=True) +def maxnreg_noinline1(X): + tl.store(X, 0) + + +@triton.jit(noinline=True) +def maxnreg_noinline2(X): + tl.store(X, 0) + + +@pytest.mark.skip(reason="TO FIX") +def test_maxnreg(device): + assert not is_interpreter(), "this test won't work with the interpreter" + if is_hip(): + pytest.skip('maxnreg only works on CUDA') + + # triton kernel + @triton.jit + def kernel(X): + maxnreg_noinline1(X) + tl.store(X, 0) + maxnreg_noinline2(X) + + X = torch.empty(1, dtype=torch.int32, device=device) + k = kernel[(1, )](X, maxnreg=42) + + # Ensure that .maxnreg is set on the kernel function (marked with .entry) + # and not on either of the noinline functions (marked with .func). + try: + assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) + assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) + except AssertionError: + print("Failing ptx:\n", k.asm["ptx"]) + raise + + +@pytest.mark.interpreter +def test_temp_var_in_loop(device): + + @triton.jit + def temp_in_loop(Z, N: tl.constexpr, BLOCK: tl.constexpr): + acc = tl.full((BLOCK, ), 0, dtype=tl.int32) + for i in range(N): + if i == 0: + temp = tl.full((BLOCK, ), 2, dtype=tl.int32) + acc = temp + else: + acc += tl.full((BLOCK, ), 1, dtype=tl.int32) + # re-use the temp variable and make sure to check that it isn't creating incorrect IR. + temp = tl.full((BLOCK, ), 1, dtype=tl.int32) + acc += temp + z = Z + tl.arange(0, BLOCK) + tl.store(z, acc) + + N = 10 + BLOCK = 32 + out = torch.empty((BLOCK, ), dtype=torch.int32, device=device) + temp_in_loop[(1, )](out, N, BLOCK) + acc = torch.full((BLOCK, ), 0, dtype=torch.int32, device=device) + for i in range(N): + if i == 0: + temp = torch.full((BLOCK, ), 2, dtype=torch.int32, device=device) + acc = temp + else: + acc += torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + temp = torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + acc += temp + assert (acc == out).all() diff --git a/third_party/mthreads/python/test/unit/language/test_decorator.py b/third_party/mthreads/python/test/unit/language/test_decorator.py new file mode 100644 index 000000000..531197a78 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_decorator.py @@ -0,0 +1,49 @@ +import torch + +import triton +import triton.language as tl +import pytest + + +@pytest.mark.skip(reason="TO FIX") +def test_decorator_with_def(device): + + def triton_heuristics_pointwise(**kwargs): + + def decorator(func): + return func + + return decorator + + # "def" might appear in a decorator call, e.g. a hash string argument. + # This test makes sure the compiler can find the right position of function + # definition. + @triton_heuristics_pointwise(inductor_meta={'backend_hash': 'def0aeffabe53b3f8'}, ) + @triton.jit + def kernel(): + pass + + try: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + except Exception as e: + pytest.fail(f"triton compile failed with error: {e}") + + +def test_triton_heuristic(device): + N = 1023 + src = torch.empty(N, device=device) + dst = torch.zeros(N, device=device) + + @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], warmup=1, rep=1) + @triton.heuristics({'EVEN_N': lambda nargs: nargs['N'] % 2 == 0}) # test kwargs + @triton.heuristics({'EVEN_src': lambda nargs: nargs['src'].data_ptr() % 2 == 0}) # test args + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr, EVEN_N: tl.constexpr, EVEN_src: tl.constexpr): + tl.store(dst, EVEN_N) + tl.store(dst + 1, EVEN_src) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + assert dst[0].item() == 0.0 + assert dst[1].item() == 1.0 + assert _kernel.base_fn.__name__ == "_kernel" diff --git a/third_party/mthreads/python/test/unit/language/test_line_info.py b/third_party/mthreads/python/test/unit/language/test_line_info.py new file mode 100644 index 000000000..6421c7309 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_line_info.py @@ -0,0 +1,171 @@ +import subprocess +import tempfile + +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def kernel_single(X, + Y, + BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def device_inline(x): + return x + x + + +@triton.jit +def kernel_call(X, + Y, + BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = device_inline(x) + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit(noinline=True) +def device_noinline(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = x + x + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit +def kernel_call_noinline(X, Y, BLOCK: tl.constexpr): + device_noinline(X, Y, BLOCK) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK": 128}, num_warps=4), + ], + key=[], +) +@triton.jit +def kernel_autotune(X, Y, SIZE: tl.constexpr, BLOCK: tl.constexpr): + for i in range(0, SIZE, BLOCK): + x = tl.load(X + i + tl.arange(0, BLOCK)) + tl.store(Y + i + tl.arange(0, BLOCK), x) + + +# AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +# Since the + symbol will take effect in the dot op after combination, +# it seems making sense to annotate with the same line as dot. +@triton.jit +def kernel_dot_combine(x): + c = tl.full((32, 32), 4, dtype=tl.int8) + a = (tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :]).to(tl.int8) + d = tl.dot(a, a) + d = d + c + tl.device_print("", d) + + +def get_disassembler_command_and_debug_line_format(): + """Gets backend specific disassembler information. + + Returns a tuple: (object file kind, disassembler tool command, + debug line anchor, debug line file and line number separator). + """ + backend = triton.runtime.driver.active.get_current_target().backend + + if backend == "cuda": + from triton.backends.nvidia.compiler import _path_to_binary + nvdisasm, _ = _path_to_binary("nvdisasm") + return ("cubin", [nvdisasm, "-g"], "## File", ",") + + if backend == "hip": + import shutil + # Try to find llvm-objdump from the current PATH to disassmble hsaco. + tool = shutil.which("llvm-objdump") + if tool is not None: + return ("hsaco", [tool, "-D", "-l", "--arch=amdgcn"], ";", ":") + raise RuntimeError("llvm-objdump not found in PATH") + + raise RuntimeError(f"unknown backend {backend}") + + +def extract_file_lines(command, anchor, separator, asm): + fd, path = tempfile.mkstemp() + with open(fd, 'wb') as cubin: + cubin.write(asm) + asm = subprocess.check_output(command + [path]).decode("utf-8") + file_lines = [] + lines = asm.splitlines() + for line in lines: + # We are looking for an anchor string and a separator between the file name and line number. + if anchor in line and separator in line: + entries = line[line.index(anchor):].split(separator) + if len(entries) == 2 and all(len(e) != 0 for e in entries): + file_lines.append((entries[0].strip(), entries[1].strip())) + return file_lines + + +def check_file_lines(file_lines, file_name, lineno, should_contain=True): + """ + Check if the file name and line number is in the file_lines + + Args: + file_lines: list of (file_name, line_number) + file_name: file name + lineno: line number, -1 means do not check line number + should_contain: whether the file name and line number should be in the file_lines + """ + for file, line in file_lines: + if lineno == -1: + if file_name in file: + return True + if file_name in file and str(lineno) in line: + return should_contain + return not should_contain + + +func_types = ["single", "call", "call_noinline", "autotune", "dot_combine"] + + +@pytest.mark.parametrize("func", func_types) +def test_line_info(func: str): + try: + obj_kind, command, anchor, separator = get_disassembler_command_and_debug_line_format() + except BaseException: + pytest.skip("disassembler is not available") + + shape = (128, ) + kernel_info = {} + if func == "single": + kernel_info = kernel_single.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,)) + elif func == "call": + kernel_info = kernel_call.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,)) + elif func == "call_noinline": + kernel_info = kernel_call_noinline.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,)) + elif func == "autotune": + kernel_info = kernel_autotune.warmup(torch.float32, torch.float32, SIZE=shape[0], grid=(1,))[0] + elif func == "dot_combine": + kernel_info = kernel_dot_combine.warmup(20, grid=(1,)) + + file_lines = extract_file_lines(command, anchor, separator, kernel_info.asm[obj_kind]) + if func == "single": + assert (check_file_lines(file_lines, "test_line_info.py", 15)) + assert (check_file_lines(file_lines, "test_line_info.py", 16)) + elif func == "call": + assert (check_file_lines(file_lines, "test_line_info.py", 28)) + assert (check_file_lines(file_lines, "test_line_info.py", 21)) + assert (check_file_lines(file_lines, "test_line_info.py", 30)) + elif func == "call_noinline": + assert (check_file_lines(file_lines, "test_line_info.py", 42)) + assert (check_file_lines(file_lines, "test_line_info.py", 35)) + assert (check_file_lines(file_lines, "test_line_info.py", 36)) + assert (check_file_lines(file_lines, "test_line_info.py", 37)) + elif func == "autotune": + assert (check_file_lines(file_lines, "test_line_info.py", 53)) + assert (check_file_lines(file_lines, "test_line_info.py", 54)) + assert (check_file_lines(file_lines, "test_line_info.py", 55)) + elif func == "dot_combine": + assert (check_file_lines(file_lines, "test_line_info.py", 65)) + assert (check_file_lines(file_lines, "test_line_info.py", 66, should_contain=False)) diff --git a/third_party/mthreads/python/test/unit/language/test_random.py b/third_party/mthreads/python/test/unit/language/test_random.py new file mode 100644 index 000000000..e0e59b069 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_random.py @@ -0,0 +1,255 @@ +import numpy as np +import pytest +import scipy.stats +import torch + +import triton +import triton.language as tl + +##################################### +# Reference Philox Implementation +##################################### + + +class PhiloxConfig: + + def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): + self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) + self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE) + self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE) + self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE) + self.DTYPE = DTYPE + + +# This is better for GPU +PHILOX_32 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B9, + PHILOX_KEY_B=0xBB67AE85, + PHILOX_ROUND_A=0xD2511F53, + PHILOX_ROUND_B=0xCD9E8D57, + DTYPE=np.uint32, +) + +# This is what numpy implements +PHILOX_64 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B97F4A7C15, + PHILOX_KEY_B=0xBB67AE8584CAA73B, + PHILOX_ROUND_A=0xD2E7470EE14C6C93, + PHILOX_ROUND_B=0xCA5A826395121157, + DTYPE=np.uint64, +) + + +class CustomPhilox4x: + + def __init__(self, seed, config): + self._config = config + seed = self._into_pieces(seed) + self._key = np.array(seed[:2], dtype=self._dtype) + self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype) + + @property + def _dtype(self): + return self._config.DTYPE + + def _into_pieces(self, n, pad=4): + res = [] + while len(res) < pad: + res.append(np.array(n, dtype=self._dtype)) + n >>= (np.dtype(self._dtype).itemsize * 8) + assert n == 0 + return tuple(res) + + def _multiply_low_high(self, a, b): + low = a * b + high = int(a) * int(b) + high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype) + return low, high + + def _single_round(self, counter, key): + lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0]) + lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2]) + ret0 = hi1 ^ counter[1] ^ key[0] + ret1 = lo1 + ret2 = hi0 ^ counter[3] ^ key[1] + ret3 = lo0 + return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) + + def _raise_key(self, key): + pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B] + return key + np.array(pk, dtype=self._dtype) + + def random_raw(self): + counter = self._counter + key = self._key + for _ in range(10): + counter = self._single_round(counter, key) + key = self._raise_key(key) + self.advance(1) + return counter + + def advance(self, n_steps): + self._counter[0] += n_steps + assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets" + + +class CustomPhilox(CustomPhilox4x): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.buffer = [] + + def random_raw(self): + if len(self.buffer) == 0: + self.buffer = list(super().random_raw())[::-1] + return int(self.buffer.pop()) + + +##################################### +# Unit Tests +##################################### + +BLOCK: tl.constexpr = 1024 + +# test generation of random uint32 + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in ['10', '4,53', '400'] + for seed in [0, 42, 124, 54, 0xffffffff, 0x0000000fcafeb0ba] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_randint(size, seed, device, dtype, const_seed): + size = list(map(int, size.split(','))) + torch_dtype = getattr(torch, dtype) + numpy_dtype = getattr(np, f"u{dtype}") + config = {'int32': PHILOX_32, 'int64': PHILOX_64}[dtype] + + @triton.jit + def kernel(X, N, seed): + pid = tl.program_id(0).to(X.dtype.element_ty) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr): + pid = tl.program_id(0).to(X.dtype.element_ty) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch_dtype, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + if const_seed: + const_kernel[grid](x, N, seed=seed) + else: + kernel[grid](x, N, seed) + out_tri = x.cpu().numpy().astype(numpy_dtype).flatten().tolist() + # reference result + gen = CustomPhilox4x(seed, config=config) + out_ref = [gen.random_raw()[0] for _ in out_tri] + assert out_tri == out_ref + + +# test uniform PRNG + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in [100000] + for seed in [0, 42, 124, 54] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_rand(size, seed, dtype, device, const_seed): + + @triton.jit + def kernel(X, N, seed, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + if const_seed: + const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype)) + else: + kernel[grid](x, N, seed, dtype=getattr(tl, dtype)) + assert all((x >= 0) & (x <= 1)) + assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 + + +# test normal PRNG + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in [100000] + for seed in [0, 42, 124, 54] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_randn(size, seed, dtype, device, const_seed): + + @triton.jit + def kernel(X, N, seed, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randn(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randn(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + if const_seed: + const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype)) + else: + kernel[grid](x, N, seed, dtype=getattr(tl, dtype)) + assert abs(x.mean()) < 1e-2 + assert abs(x.std() - 1) < 1e-2 + + +# tl.rand() should never produce >=1.0 + + +@pytest.mark.interpreter +@pytest.mark.parametrize('dtype', ['int32', 'int64']) +def test_rand_limits(dtype, device): + + @triton.jit + def kernel(input, output, n: tl.constexpr): + idx = tl.arange(0, n) + x = tl.load(input + idx) + y = tl.random.uint_to_uniform_float(x) + tl.store(output + idx, y) + + torch_dtype = getattr(torch, dtype) + min_max_int = torch.tensor([ + torch.iinfo(torch_dtype).min, + torch.iinfo(torch_dtype).max, + ], dtype=torch_dtype, device=device) + output = torch.empty(2, dtype=torch.float32, device=device) + kernel[(1, )](min_max_int, output, 2) + + assert output[0] == output[1] + assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0 diff --git a/third_party/mthreads/python/test/unit/language/test_reproducer.py b/third_party/mthreads/python/test/unit/language/test_reproducer.py new file mode 100644 index 000000000..a045e8f30 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_reproducer.py @@ -0,0 +1,42 @@ +import os +import shutil + +import pytest + +import torch +import triton +import re + + +@triton.jit +def triton_(): + return + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda") +def test_reproducer(): + tmpdir = ".tmp" + reproducer = 'triton-reproducer.mlir' + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + if os.path.exists(reproducer): + os.remove(reproducer) + os.environ["TRITON_CACHE_DIR"] = tmpdir + os.environ["TRITON_REPRODUCER_PATH"] = reproducer + triton_[(1, )]() + foundPipeline = "" + with open(reproducer, 'r') as f: + line = f.read() + if 'pipeline:' in line: + foundPipeline = line + if 0 == len(foundPipeline): + raise Exception("Failed to find pipeline info in reproducer file.") + + ttgir_to_llvm_pass = re.compile("convert-triton-{{.*}}gpu-to-llvm") + if ttgir_to_llvm_pass.search(foundPipeline): + raise Exception("Failed to find triton passes in pipeline") + # cleanup + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + if os.path.exists(reproducer): + os.remove(reproducer) diff --git a/third_party/mthreads/python/test/unit/language/test_standard.py b/third_party/mthreads/python/test/unit/language/test_standard.py new file mode 100644 index 000000000..d0ef00118 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_standard.py @@ -0,0 +1,76 @@ +import triton +import pytest +import torch +import triton.language as tl + +from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random + +# --------------- +# test maximum/minimum ops +# --------------- + + +# TODO: Tests with unsigned integers failed at compilation stage. +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", int_dtypes + uint_dtypes + float_dtypes + ["bfloat16"]) +@pytest.mark.parametrize("op", ["maximum", "minimum"]) +def test_maximum_minium(dtype, op, device): + expr = f'tl.{op}(x, y)' + numpy_expr = f'np.{op}(x, y)' + _test_binary(dtype, dtype, expr, numpy_expr, device=device) + + +# --------------- +# test sort op +# --------------- + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32']) +def test_sort(M, N, descending, dtype_str, device): + + @triton.jit + def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr): + offx = tl.arange(0, M) + offy = tl.arange(0, N) * M + off2d = offx[None, :] + offy[:, None] + x = tl.load(X + off2d) + x = tl.sort(x, descending=descending) + tl.store(Z + off2d, x) + + x = numpy_random((N, M), dtype_str=dtype_str) + x = torch.from_numpy(x).to(device) + y = torch.sort(x, descending=descending)[0] + z = torch.empty_like(x) + sort_kernel[(1, )](x, z, N, M, descending, num_warps=8) + assert (y == z).all(), (y, z) + + +# --------------- +# test flip op +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32']) +def test_flip(M, N, dtype_str, device): + + @triton.jit + def flip_kernel(X, Z, N: tl.constexpr, M: tl.constexpr): + offx = tl.arange(0, M) + offy = tl.arange(0, N) * M + off2d = offx[None, :] + offy[:, None] + x = tl.load(X + off2d) + x = tl.flip(x) + tl.store(Z + off2d, x) + + x = numpy_random((N, M), dtype_str=dtype_str) + x = torch.from_numpy(x).to(device) + y = torch.flip(x, (1, )) + z = torch.empty_like(x, device=device) + flip_kernel[(1, )](x, z, N, M, num_warps=8) + assert (y == z).all(), (y, z) diff --git a/third_party/mthreads/python/test/unit/language/test_subprocess.py b/third_party/mthreads/python/test/unit/language/test_subprocess.py new file mode 100644 index 000000000..8afdc8973 --- /dev/null +++ b/third_party/mthreads/python/test/unit/language/test_subprocess.py @@ -0,0 +1,161 @@ +import itertools +import os +import subprocess +import sys +from collections import Counter + +import pytest + +dir_path = os.path.dirname(os.path.realpath(__file__)) +print_path = os.path.join(dir_path, "print_helper.py") +assert_path = os.path.join(dir_path, "assert_helper.py") + +# TODO: bfloat16 after LLVM-15 +assert_types = ["device_assert", "device_assert_passes", "assert", "static_assert", "no_debug", "double_assert"] +nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]] +torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"] + +pytestmark = pytest.mark.skip(reason="Skipping entire test file") + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +# TODO: Print with multiple operands + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func_type, data_type", [("device_print", data_type) for data_type in torch_types] + [ + ("print", "int32"), + ("static_print", "int32"), + ("no_arg_print", "int32"), + ("print_no_arg", "int32"), + ("device_print_large", "int32"), + ("print_multiple_args", "int32"), + ("device_print_multiple_args", "int32"), + ("device_print_hex", "int16"), + ("device_print_hex", "int32"), + ("device_print_hex", "int64"), + ("device_print_pointer", "int32"), +]) +def test_print(func_type: str, data_type: str, device: str): + proc = subprocess.Popen([sys.executable, print_path, func_type, data_type, device], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=False) + outs, err = proc.communicate() + assert proc.returncode == 0 + + if is_interpreter() and func_type != "static_assert": + # Interpreter uses a different format for device_print + # Only check if there's no error + assert err == b'' + return + + outs = [line for line in outs.decode("UTF-8").split("\n") if line] + # The total number of elements in the 1-D tensor to print. + N = 128 + + # Format is + # pid (, , ) idx (, , ...) (operand ) + expected_lines = Counter() + if func_type == "print" or func_type == "device_print": + for i in range(N): + line = f"pid (0, 0, 0) idx ({i:3}) x: {i}" + if data_type.startswith("float"): + line += ".000000" + expected_lines[line] = 1 + elif func_type == "device_print_hex": + for i in range(N): + line = f"pid (0, 0, 0) idx ({i:3}) x: 0x" + if data_type == "int16": + line += f"{i:04x}" + if data_type == "int32": + line += f"{i:08x}" + if data_type == "int64": + line += f"{i:016x}" + expected_lines[line] = 1 + elif func_type == "static_print": + expected_lines[f" int32[constexpr[{N}]]"] = 1 + elif func_type == "no_arg_print": + expected_lines["pid (0, 0, 0) idx (): 0"] = N + elif func_type == "print_no_arg": + expected_lines["pid (0, 0, 0) no arg"] = N + elif func_type == "device_print_large": + for i, j, k in itertools.product(range(2), range(64), range(N)): + expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1 + elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args": + for i in range(N): + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1 + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1 + elif func_type == "device_print_pointer": + for i in range(N): + expected_lines[f"pid (0, 0, 0) idx ({i:3}) ptr: 0x"] = 1 + + actual_lines = Counter() + for line in outs: + # Trim the exact pointer address in the output--they can change per run. + line = (line.split(':')[0] + ": 0x") if func_type == "device_print_pointer" else line + actual_lines[line] += 1 + + diff = Counter(actual_lines) + diff.subtract(expected_lines) + for line, delta in diff.items(): + if delta == 0: + continue + print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)') + assert all(delta == 0 for delta in diff.values()) + + +@pytest.mark.parametrize("func_type", assert_types) +def test_assert(func_type: str, device: str): + # The total number of elements in the 1-D tensor to assert on. + N = 128 + + os.environ["TRITON_DEBUG"] = "1" + proc = subprocess.Popen([sys.executable, assert_path, func_type, device], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=False) + _, errs = proc.communicate() + errs = errs.splitlines() + num_errs = 0 + for err in errs: + if "x != 0" in err.decode("utf-8", errors="ignore"): + num_errs += 1 + + # Check for segfaults. + assert all("segmentation fault" not in line.decode("utf-8", errors="ignore").lower() for line in errs) + + os.environ["TRITON_DEBUG"] = "0" + if func_type == "static_assert" or func_type == "device_assert_passes": + assert num_errs == 0 + else: + assert num_errs == N - 1 + + +@pytest.mark.parametrize("caller_type, callee_type", nested_types) +def test_assert_nested(caller_type, callee_type, device): + # The total number of elements in the 1-D tensor to assert on. + N = 128 + + proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type, device], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=False) + _, errs = proc.communicate() + errs = errs.splitlines() + num_errs = 0 + for err in errs: + if "x != 0" in err.decode("utf-8", errors="ignore"): + num_errs += 1 + if caller_type == "none": + if callee_type == "true": + assert num_errs == N - 1 + else: + assert num_errs == 0 + elif caller_type == "true": + if callee_type == "false": + assert num_errs == 0 + else: + assert num_errs == N - 1 + elif caller_type == "false": + if callee_type == "true": + assert num_errs == N - 1 + else: + assert num_errs == 0 diff --git a/third_party/mthreads/python/test/unit/operators/conftest.py b/third_party/mthreads/python/test/unit/operators/conftest.py new file mode 100644 index 000000000..091f9ea41 --- /dev/null +++ b/third_party/mthreads/python/test/unit/operators/conftest.py @@ -0,0 +1,5 @@ +# content of conftest.py + + +def pytest_configure(config): + config.addinivalue_line("markers", "interpreter: indicate whether interpreter supports the test") diff --git a/third_party/mthreads/python/test/unit/operators/test_blocksparse.py b/third_party/mthreads/python/test/unit/operators/test_blocksparse.py new file mode 100644 index 000000000..f30582e52 --- /dev/null +++ b/third_party/mthreads/python/test/unit/operators/test_blocksparse.py @@ -0,0 +1,235 @@ +import pytest +import torch + +import triton +import triton.ops + + +def is_hip_mi200(): + target = triton.runtime.driver.active.get_current_target() + return target.backend == 'hip' and target.arch == 'gfx90a' + + +def sparsify_tensor(x, mask, block): + ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device) + for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))): + ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] + return ret + + +def make_pair(shape, device="musa", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32): + if data is None: + data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device) + ref_ret = data + ref_ret = ref_ret * alpha + beta + ref_ret = ref_ret.half().to(dtype) + if trans: + ref_ret = ref_ret.t().requires_grad_() + ref_ret = ref_ret.detach().requires_grad_() + tri_ret = ref_ret.clone().detach().requires_grad_() + return ref_ret, tri_ret + + +def mask_tensor(x, mask, block, value=0): + ret = x.clone() + for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)): + ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value + return ret + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"]) +@pytest.mark.parametrize("TRANS_A", [False, True]) +@pytest.mark.parametrize("TRANS_B", [False, True]) +@pytest.mark.parametrize("BLOCK", [16, 32, 64]) +@pytest.mark.parametrize("DTYPE", [torch.float16]) +def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, device, Z=3, H=2, M=512, N=384, K=256): + seed = 0 + torch.manual_seed(seed) + is_sdd = MODE == "sdd" + is_dsd = MODE == "dsd" + is_dds = MODE == "dds" + do_sparsify = lambda x: sparsify_tensor(x, layout, BLOCK) + do_mask = lambda x: mask_tensor(x, layout, BLOCK) + # create inputs + # create op + a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K) + b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N) + c_shape = (Z, H, M, N) + shape = { + "sdd": (M, N), + "dsd": (a_shape[2], a_shape[3]), + "dds": (b_shape[2], b_shape[3]), + }[MODE] + layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) + layout[1, 2, :] = 0 + layout[1, :, 1] = 0 + # create data + a_ref, a_tri = make_pair(a_shape, alpha=.1, dtype=DTYPE) + b_ref, b_tri = make_pair(b_shape, alpha=.1, dtype=DTYPE) + dc_ref, dc_tri = make_pair(c_shape, dtype=DTYPE) + # compute [torch] + dc_ref = do_mask(dc_ref) if is_sdd else dc_ref + a_ref = do_mask(a_ref) if is_dsd else a_ref + b_ref = do_mask(b_ref) if is_dds else b_ref + a_ref.retain_grad() + b_ref.retain_grad() + c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref, b_ref.transpose(2, 3) if TRANS_B else b_ref) + c_ref.backward(dc_ref) + c_ref = do_sparsify(c_ref) if is_sdd else c_ref + da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad + db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad + # triton result + dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri + a_tri = do_sparsify(a_tri) if is_dsd else a_tri + b_tri = do_sparsify(b_tri) if is_dds else b_tri + a_tri.retain_grad() + b_tri.retain_grad() + op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device=device) + c_tri = op(a_tri, b_tri) + c_tri.backward(dc_tri) + da_tri = a_tri.grad + db_tri = b_tri.grad + + # Bigger tolerance for AMD MI200 devices. + # MI200 devices use reduced precision fp16 and bf16 and flush input and + # output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + tol = {'atol': 1e-3, 'rtol': 0} if is_hip_mi200() else {} + + # compare + torch.testing.assert_close(c_ref, c_tri, **tol) + torch.testing.assert_close(da_ref, da_tri, **tol) + torch.testing.assert_close(db_ref, db_tri, **tol) + + +configs = [ + (16, 256), + (32, 576), + (64, 1871), + (128, 2511), +] + + +@pytest.mark.parametrize("is_dense", [False, True]) +@pytest.mark.parametrize("BLOCK, WIDTH", configs) +def test_softmax(BLOCK, WIDTH, is_dense, device, Z=2, H=2, is_causal=True, scale=0.4): + # set seed + torch.random.manual_seed(0) + Z, H, M, N = 2, 3, WIDTH, WIDTH + # initialize layout + # make sure each row has at least one non-zero element + layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) + if is_dense: + layout[:] = 1 + else: + layout[1, 2, :] = 0 + layout[1, :, 1] = 0 + # initialize data + a_shape = (Z, H, M, N) + a_ref, a_tri = make_pair(a_shape) + dout_ref, dout_tri = make_pair(a_shape) + # compute [torch] + a_ref = mask_tensor(a_ref, layout, BLOCK, value=float("-inf")) + a_ref.retain_grad() + at_mask = torch.ones((M, N), device=device) + if is_causal: + at_mask = torch.tril(at_mask) + M = at_mask[None, None, :, :] + torch.zeros_like(a_ref) + a_ref[M == 0] = float("-inf") + out_ref = torch.softmax(a_ref * scale, -1) + out_ref.backward(dout_ref) + out_ref = sparsify_tensor(out_ref, layout, BLOCK) + da_ref = sparsify_tensor(a_ref.grad, layout, BLOCK) + # compute [triton] + a_tri = sparsify_tensor(a_tri, layout, BLOCK) + a_tri.retain_grad() + dout_tri = sparsify_tensor(dout_tri, layout, BLOCK) + op = triton.ops.blocksparse.softmax(layout, BLOCK, device=device, is_dense=is_dense) + out_tri = op(a_tri, scale=scale, is_causal=is_causal) + out_tri.backward(dout_tri) + da_tri = a_tri.grad + # compare + torch.testing.assert_close(out_tri, out_ref, equal_nan=True) + torch.testing.assert_close(da_tri, da_ref, equal_nan=True) + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.parametrize("block", [16, 32, 64]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_attention_fwd_bwd( + block, + dtype, + device, + input_scale=1.0, + scale=1 / 8.0, + n_ctx=256, + batch_size=2, + n_heads=2, +): + # inputs + qkv_shape = (batch_size, n_heads, n_ctx, 64) + qkvs = [ + torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).musa() for _ in range(3) + ] + + # Triton: + n_blocks = n_ctx // block + layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long)) + query, key, value = [x.clone() for x in qkvs] + query.retain_grad() + key.retain_grad() + value.retain_grad() + attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale) + # ad hoc loss + loss = (attn_out**2).mean() + loss.backward() + grads = [query.grad, key.grad, value.grad] + + # Torch version: + torch_q, torch_k, torch_v = [x.clone() for x in qkvs] + attn_mask = torch.ones([n_ctx, n_ctx], device=device, dtype=dtype) + attn_mask = torch.tril(attn_mask, diagonal=0) + attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).musa())) + torch_q.retain_grad() + torch_k.retain_grad() + torch_v.retain_grad() + scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k) + scores = scores + attn_mask + probs = torch.softmax(scores, dim=-1) + torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v) + # ad hoc loss + torch_loss = (torch_attn_out**2).mean() + torch_loss.backward() + torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad] + + # comparison + # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...") + torch.testing.assert_close(loss, torch_loss, atol=1e-3, rtol=0) + + # Bigger tolerance for AMD MI200 devices. + # MI200 devices use reduced precision fp16 and bf16 and flush input and + # output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + tol = {'atol': 1e-3, 'rtol': 0} if is_hip_mi200() else {} + for g1, g2 in zip(grads, torch_grads): + torch.testing.assert_close(g1, g2, **tol) + + +@pytest.mark.parametrize("block", [16, 32, 64]) +def triton_attention( + layout, + block: int, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, +): + sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, + device=value.device) + sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, + device=value.device) + sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device) + + w = sparse_dot_sdd_nt(query, key) + w = sparse_softmax(w, scale=scale, is_causal=True) + a = sparse_dot_dsd_nn(w, value) + return a diff --git a/third_party/mthreads/python/test/unit/operators/test_cross_entropy.py b/third_party/mthreads/python/test/unit/operators/test_cross_entropy.py new file mode 100644 index 000000000..5182510d7 --- /dev/null +++ b/third_party/mthreads/python/test/unit/operators/test_cross_entropy.py @@ -0,0 +1,38 @@ +import pytest +import torch + +import triton +import triton.ops + + +@pytest.mark.parametrize("M, N, dtype, mode", [ # + (M, N, dtype, mode) + for M in [1024, 821] + for N in [512, 857, 1871, 2089, 8573, 31000] + for dtype in ['float16', 'float32'] + for mode in ['forward', 'backward'] +]) +def test_op(M, N, dtype, mode, device): + dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype] + # create inputs + x = torch.randn(M, N, dtype=dtype, device=device, requires_grad=True) + idx = 4 + torch.ones(M, dtype=torch.int64, device=device) + # forward pass + tt_y = triton.ops.cross_entropy(x, idx) + th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx) + if mode == 'forward': + torch.testing.assert_close(th_y, tt_y) + # backward pass + elif mode == 'backward': + dy = torch.randn_like(tt_y) + # triton backward + tt_y.backward(dy) + tt_dx = x.grad.clone() + # torch backward + x.grad = None + th_y.backward(dy) + th_dx = x.grad.clone() + if dtype == torch.float16: + torch.testing.assert_close(th_dx, tt_dx, rtol=0.001, atol=0.001) + else: + torch.testing.assert_close(th_dx, tt_dx) diff --git a/third_party/mthreads/python/test/unit/operators/test_flash_attention.py b/third_party/mthreads/python/test/unit/operators/test_flash_attention.py new file mode 100644 index 000000000..88599464e --- /dev/null +++ b/third_party/mthreads/python/test/unit/operators/test_flash_attention.py @@ -0,0 +1,115 @@ +import pytest +import torch +import os + +import triton +import triton.ops + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.interpreter +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ # + (2, 4, 512, 16), + (2, 4, 512, 32), + (2, 4, 512, 64), + (2, 4, 512, 128), +]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('seq_par', [True, False]) +def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par, device): + if dtype == torch.bfloat16 and os.environ.get("TRITON_INTERPRET", "0") == "1": + pytest.skip("Flash attention bfloat16 not supported in interpreter mode") + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=device).normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=device).normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=device).normal_(mean=0., std=0.5).requires_grad_() + sm_scale = 0.5 + dout = torch.randn_like(q) + # reference implementation + M = torch.tril(torch.ones((N_CTX, N_CTX), device=device)) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if causal: + p[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float(), dim=-1).to(dtype) + # p = torch.exp(p) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # # triton implementation + tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # compare + atol = 1e-1 if dtype == torch.bfloat16 else 1e-2 + torch.testing.assert_close(torch.nn.functional.normalize(torch.flatten(ref_out), dim=0), + torch.nn.functional.normalize(torch.flatten(tri_out), dim=0), atol=atol, rtol=0) + torch.testing.assert_close(torch.nn.functional.normalize(torch.flatten(ref_dv), dim=0), + torch.nn.functional.normalize(torch.flatten(tri_dv), dim=0), atol=atol, rtol=0) + torch.testing.assert_close(torch.nn.functional.normalize(torch.flatten(ref_dk), dim=0), + torch.nn.functional.normalize(torch.flatten(tri_dk), dim=0), atol=atol, rtol=0) + torch.testing.assert_close(torch.nn.functional.normalize(torch.flatten(ref_dq), dim=0), + torch.nn.functional.normalize(torch.flatten(tri_dq), dim=0), atol=atol, rtol=0) + + +try: + from flash_attn.flash_attn_interface import flash_attn_func + HAS_FLASH = True +except BaseException: + HAS_FLASH = False + +BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 +# vary seq length for fixed head and batch=4 +configs = [ + triton.testing.Benchmark( + x_names=['N_CTX'], x_vals=[2**i for i in range(10, 14)], line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), styles=[('red', '-'), ('blue', '-')], ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{casual}-{seq_par}', args={ + 'H': N_HEADS, + 'BATCH': BATCH, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + 'casual': casual, + 'seq_par': seq_par, + }) for mode in ['fwd', 'bwd'] for casual in [True, False] for seq_par in [True, False] +] + + +@triton.testing.perf_report(configs) +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, casual, seq_par, provider, device, dtype=torch.float16): + assert mode in ['fwd', 'bwd'] + warmup = 25 + rep = 100 + sm_scale = 1.3 + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True) + v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=device, requires_grad=True) + if provider == "triton": + fn = lambda: triton.ops.attention(q, k, v, casual, sm_scale, seq_par) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + if provider == "flash": + lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32) + cu_seqlens[1:] = lengths.cumsum(0) + fn = lambda: flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=sm_scale, causal=casual) + if mode == 'bwd': + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +# only works on post-Ampere GPUs right now +# bench_flash_attention.run(save_path='.', print_data=True) diff --git a/third_party/mthreads/python/test/unit/operators/test_inductor.py b/third_party/mthreads/python/test/unit/operators/test_inductor.py new file mode 100644 index 000000000..a638cb633 --- /dev/null +++ b/third_party/mthreads/python/test/unit/operators/test_inductor.py @@ -0,0 +1,198 @@ +import pytest +import torch + +import triton +import triton.language as tl + + +def test_normalization_with_remat(device): + + @triton.jit + def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr): + xnumel = 512 + rnumel = 4096 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + rbase = tl.arange(0, RBLOCK)[None, :] + x3 = xindex + x0 = xindex % 64 + tmp1 = tl.load(in_ptr0 + (x0), xmask) + tmp3 = tl.load(in_ptr1 + (x0), xmask) + tmp11 = tl.load(in_ptr2 + (x0), xmask) + tmp13 = tl.load(in_ptr3 + (x0), xmask) + _tmp17 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0 + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + r2 = rindex + tmp0 = tl.load(in_out_ptr0 + (r2 + (4096 * x3)), rmask & xmask, eviction_policy='evict_last', other=0) + tmp2 = tmp0 - tmp1 + tmp4 = 1e-05 + tmp5 = tmp3 + tmp4 + tmp6 = tl.sqrt(tmp5) + tmp7 = 1 / tmp6 + tmp8 = 1.0 + tmp9 = tmp7 * tmp8 + tmp10 = tmp2 * tmp9 + tmp12 = tmp10 * tmp11 + tmp14 = tmp12 + tmp13 + _tmp17 = tl.where(rmask & xmask, _tmp17 + tmp14, _tmp17) + tl.store(in_out_ptr0 + (r2 + (4096 * x3) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp14, rmask & xmask) + tmp17 = tl.sum(_tmp17, 1)[:, None] + tmp18 = 4096.0 + tmp19 = tmp17 / tmp18 + tl.store(in_out_ptr1 + (x3 + tl.zeros([XBLOCK, 1], tl.int32)), tmp19, xmask) + + torch.manual_seed(123) + + buf14 = torch.rand(8, 64, 64, 64, device=device) + buf16 = torch.rand(8, 1, 64, device=device) + arg114_1 = torch.rand(64, device=device) + arg115_1 = torch.rand(64, device=device) + arg8_1 = torch.rand(64, device=device) + arg9_1 = torch.rand(64, device=device) + triton_[(512, )](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048) + torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0) + + +def test_avg_pool_bw(device): + + @triton.jit + def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + x1 = (xindex // 8) % 8 + x0 = xindex % 8 + x2 = (xindex // 64) + x5 = xindex + tmp0 = (-1) + x1 + tmp1 = (-1) + x0 + tmp2 = 2 + x1 + tmp3 = 2 + x0 + tmp4 = 0 + tmp5 = tl.where(tmp0 != tmp0, tmp0, tl.where(tmp0 > tmp4, tmp0, tmp4)) + tmp6 = tl.where(tmp1 != tmp1, tmp1, tl.where(tmp1 > tmp4, tmp1, tmp4)) + tmp7 = 8 + tmp8 = tl.where(tmp2 != tmp2, tmp2, tl.where(tmp2 < tmp7, tmp2, tmp7)) + tmp9 = tl.where(tmp3 != tmp3, tmp3, tl.where(tmp3 < tmp7, tmp3, tmp7)) + tmp10 = tmp5 + tmp4 + tmp11 = tmp6 + tmp4 + tmp12 = 1 + tmp13 = tmp8 - tmp12 + tmp14 = tl.where(tmp10 != tmp10, tmp10, tl.where(tmp10 < tmp13, tmp10, tmp13)) + tmp15 = tmp9 - tmp12 + tmp16 = tl.where(tmp11 != tmp11, tmp11, tl.where(tmp11 < tmp15, tmp11, tmp15)) + tmp17 = tl.load(in_ptr0 + (tmp16 + (8 * tmp14) + (64 * x2)), None).to(tl.float32) + tmp18 = tmp17 / 9 + tmp19 = tmp10 < tmp8 + tmp20 = tmp11 < tmp9 + tmp21 = tmp19 & tmp20 + tmp22 = 0.0 + tmp23 = tl.where(tmp21, tmp18, tmp22) + tmp24 = tmp6 + tmp12 + tmp25 = tl.where(tmp24 != tmp24, tmp24, tl.where(tmp24 < tmp15, tmp24, tmp15)) + tmp26 = tl.load(in_ptr0 + (tmp25 + (8 * tmp14) + (64 * x2)), None).to(tl.float32) + tmp27 = tmp26 / 9 + tmp28 = tmp24 < tmp9 + tmp29 = tmp19 & tmp28 + tmp30 = tmp23 + tmp27 + tmp31 = tl.where(tmp29, tmp30, tmp23) + tmp32 = 2 + tmp33 = tmp6 + tmp32 + tmp34 = tl.where(tmp33 != tmp33, tmp33, tl.where(tmp33 < tmp15, tmp33, tmp15)) + tmp35 = tl.load(in_ptr0 + (tmp34 + (8 * tmp14) + (64 * x2)), None).to(tl.float32) + tmp36 = tmp35 / 9 + tmp37 = tmp33 < tmp9 + tmp38 = tmp19 & tmp37 + tmp39 = tmp31 + tmp36 + tmp40 = tl.where(tmp38, tmp39, tmp31) + tmp41 = tmp5 + tmp12 + tmp42 = tl.where(tmp41 != tmp41, tmp41, tl.where(tmp41 < tmp13, tmp41, tmp13)) + tmp43 = tl.load(in_ptr0 + (tmp16 + (8 * tmp42) + (64 * x2)), None).to(tl.float32) + tmp44 = tmp43 / 9 + tmp45 = tmp41 < tmp8 + tmp46 = tmp45 & tmp20 + tmp47 = tmp40 + tmp44 + tmp48 = tl.where(tmp46, tmp47, tmp40) + tmp49 = tl.load(in_ptr0 + (tmp25 + (8 * tmp42) + (64 * x2)), None).to(tl.float32) + tmp50 = tmp49 / 9 + tmp51 = tmp45 & tmp28 + tmp52 = tmp48 + tmp50 + tmp53 = tl.where(tmp51, tmp52, tmp48) + tmp54 = tl.load(in_ptr0 + (tmp34 + (8 * tmp42) + (64 * x2)), None).to(tl.float32) + tmp55 = tmp54 / 9 + tmp56 = tmp45 & tmp37 + tmp57 = tmp53 + tmp55 + tmp58 = tl.where(tmp56, tmp57, tmp53) + tmp59 = tmp5 + tmp32 + tmp60 = tl.where(tmp59 != tmp59, tmp59, tl.where(tmp59 < tmp13, tmp59, tmp13)) + tmp61 = tl.load(in_ptr0 + (tmp16 + (8 * tmp60) + (64 * x2)), None).to(tl.float32) + tmp62 = tmp61 / 9 + tmp63 = tmp59 < tmp8 + tmp64 = tmp63 & tmp20 + tmp65 = tmp58 + tmp62 + tmp66 = tl.where(tmp64, tmp65, tmp58) + tmp67 = tl.load(in_ptr0 + (tmp25 + (8 * tmp60) + (64 * x2)), None).to(tl.float32) + tmp68 = tmp67 / 9 + tmp69 = tmp63 & tmp28 + tmp70 = tmp66 + tmp68 + tmp71 = tl.where(tmp69, tmp70, tmp66) + tmp72 = tl.load(in_ptr0 + (tmp34 + (8 * tmp60) + (64 * x2)), None).to(tl.float32) + tmp73 = tmp72 / 9 + tmp74 = tmp63 & tmp37 + tmp75 = tmp71 + tmp73 + tmp76 = tl.where(tmp74, tmp75, tmp71) + tl.store(out_ptr0 + (x5 + tl.zeros([XBLOCK], tl.int32)), tmp76, None) + + inp = torch.ones(8, 2048, 8, 8, device=device, dtype=torch.half) + out = torch.ones_like(inp) * 3 + numel = inp.numel() + triton_[(numel // 1024, )](inp, out, 1024) + out_ref = torch.ones_like(inp) + out_ref[:, :, 1:7, 0::7] = 2 / 3 + out_ref[:, :, 0::7, 1:7] = 2 / 3 + out_ref[:, :, 0::7, 0::7] = 4 / 9 + torch.testing.assert_close(out, out_ref) + + +@pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128]) +@pytest.mark.parametrize("num_warps", [1, 4]) +def test_scan2d_broadcast(RBLOCK, num_warps, device): + + @triton.jit(debug=True) + def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + rindex = tl.arange(0, RBLOCK)[None, :] + xindex = tl.arange(0, XBLOCK)[:, None] + data = tl.load(in_ptr + rindex) + scan = tl.cumsum(data, 1) + expected_max = tl.sum(data, 1) + tl.device_assert(scan <= expected_max) + tl.store(out_ptr + xindex * RBLOCK + rindex, scan) + + XBLOCK = 4 + input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device=device) + output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device=device) + fn[(1, )](input, output, XBLOCK, RBLOCK, num_warps=num_warps) + ref = input.cumsum(1).broadcast_to((XBLOCK, RBLOCK)) + torch.testing.assert_close(output, ref) + + +def test_scan2d_for(device): + + @triton.jit + def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr): + rbase = tl.arange(0, RBLOCK)[None, :] + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + tmp3 = tl.where(rmask, 1, 0) + tmp6 = tl.cumsum(tmp3, 1) + tl.store(out_ptr0 + rindex, tmp6, rmask) + + RBLOCK = 8 + out0 = torch.empty(RBLOCK, device=device, dtype=torch.int64) + fn[(1, )](out0, RBLOCK, RBLOCK) + ref = torch.arange(RBLOCK, device=device, dtype=torch.int64) + 1 + torch.testing.assert_close(out0, ref) diff --git a/third_party/mthreads/python/test/unit/operators/test_matmul.py b/third_party/mthreads/python/test/unit/operators/test_matmul.py new file mode 100644 index 000000000..eb746ff99 --- /dev/null +++ b/third_party/mthreads/python/test/unit/operators/test_matmul.py @@ -0,0 +1,193 @@ +import itertools + +import pytest +import torch + +import triton +import triton.language as tl +import triton.ops + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@pytest.mark.skip(reason="TO FIX") +@pytest.mark.parametrize( + "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, INPUT_PRECISION, F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE", + itertools.chain( + *[[ + # 1 warp + (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + # 2 warp + (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + # 4 warp + (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + # 8 warp + (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, None, True, None, None), + # variable input + (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, None, True, None, None), + ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]], + # n-stage + *[[ + (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, None, True, None, None), + (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, None, True, None, None), + (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, None, True, None, None), + (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, None, True, None, None), + ] + for DTYPE in ["float16", "bfloat16", "float32"] + for AT in [False, True] + for BT in [False, True] + for STAGES in [4]], + # tf32x3 + *[[ + (16, 16, 16, 1, 1, 2, 32, 32, 80, AT, BT, "float32", "float32", "tf32x3", True, None, None), + (64, 32, 64, 1, 2, 2, 128, 64, 128, AT, BT, "float32", "float32", "tf32x3", True, None, None), + (128, 64, 16, 1, 4, 2, 256, 128, 80, AT, BT, "float32", "float32", "tf32x3", True, None, None), + (256, 128, 32, 1, 8, 2, 512, 256, 160, AT, BT, "float32", "float32", "tf32x3", True, None, None), + (128, 128, 32, 1, 4, 2, 256, 256, 160, AT, BT, "float32", "float32", "tf32x3", True, None, None), + ] for AT in [False, True] for BT in [False, True]], + # mixed-precision + *[[ + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, FASTACCUM, None, None), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, FASTACCUM, None, None), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, None, FASTACCUM, None, None), + ] for ADTYPE, BDTYPE in [ + ("float8e4nv", "float8e5"), + ("float8e4nv", "float8e4nv"), + ("float8e5", "float8e4nv"), + ("float8e5", "float8e5"), + ("float8e4b15", "float8e4b15"), + ("float8e4nv", "float16"), + ("float16", "float8e5"), + ("int8", "bfloat16"), + ("float16", "int8"), + ("float16", "float32"), + ("float32", "float16"), + ("bfloat16", "float32"), + ("float32", "bfloat16"), + ] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]], + # mixed-precision block layout + *[[ + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, True, None, None), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, None, True, None, None), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, None, True, None, None), + ] for ADTYPE, BDTYPE in [ + ("float8e4nv", "float16"), + ("float16", "float8e5"), + ("float16", "float32"), + ("float32", "float16"), + ("bfloat16", "float32"), + ("float32", "bfloat16"), + ] for AT in [False, True] for BT in [False, True]], + # acc-out-dtype and output_dtype + *[[ + (32, 32, 32, 1, 1, 2, None, None, None, False, False, "float16", "float16", None, True, ACC_DTYPE, + OUTPUT_DTYPE), + (128, 256, 32, 1, 8, 2, None, None, None, False, False, "float16", "float16", None, True, ACC_DTYPE, + OUTPUT_DTYPE), + ] for ACC_DTYPE in [None, "float16", "float32"] for OUTPUT_DTYPE in [None, "float16", "float32"]], + ), +) +def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, INPUT_PRECISION, + F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE, device): + torch.manual_seed(0) + # nuke kernel decorators -- will set meta-parameters manually + kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K} + pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_() + configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)] + kernel = triton.ops._matmul.kernel + kernel.configs = configs + # kernel.run = kernel.run.run.run + + # get matrix shape + M = BLOCK_M if M is None else M + N = BLOCK_N if N is None else N + K = BLOCK_K * SPLIT_K if K is None else K + + def is_fp8(dtype): + return "float8" in dtype + + def f8_to_f16(x, dtype): + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty_strided(x.shape, x.stride(), dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + def upcast_if_fp8(x, dtype): + if is_fp8(dtype): + return f8_to_f16(x, dtype) + return x + + def init_input(m, n, dtype, acc_dtype): + if 'float8' in dtype: + ewidth = {'float8e4b15': 4, 'float8e4nv': 4, 'float8e5': 5}[dtype] + sign = torch.randint(2, size=(m, n), device=device, dtype=torch.int8) * 128 + val = torch.randint(2**3 - 1, size=(m, n), device=device, dtype=torch.int8) << 7 - ewidth + return sign | val + if dtype == "int8": + return torch.randint(-128, 127, (m, n), device=device, dtype=torch.int8) + # Use small range of values to prevent numerical issues. + min_exp = -4 if acc_dtype == "float16" else -10 + exponents = torch.randint(min_exp, 0, size=(m, n)) + ret = (2.**exponents).to(getattr(torch, dtype)).to(device) + return ret + + if is_hip(): + if INPUT_PRECISION == 'tf32x3' or is_fp8(ADTYPE) or is_fp8(BDTYPE): + pytest.skip("fp8 inputs or tf32x3 precison does not have native support on hip") + # allocate/transpose inputs + a = init_input(M, K, ADTYPE, ACC_DTYPE) + b = init_input(K, N, BDTYPE, ACC_DTYPE) + a = a if not AT else a.T.contiguous().T + b = b if not BT else b.T.contiguous().T + # run test + th_a = upcast_if_fp8(a, ADTYPE) + th_b = upcast_if_fp8(b, BDTYPE) + ab_dtype = triton.ops.get_higher_dtype(th_a.dtype, th_b.dtype) + acc_dtype = getattr(torch, ACC_DTYPE) if ACC_DTYPE else ab_dtype + output_dtype = getattr(torch, OUTPUT_DTYPE) if OUTPUT_DTYPE else ab_dtype + th_c = torch.matmul(th_a.to(output_dtype), th_b.to(output_dtype)) + try: + if is_fp8(ADTYPE): + a = triton.reinterpret(a, getattr(tl, ADTYPE)) + if is_fp8(BDTYPE): + b = triton.reinterpret(b, getattr(tl, BDTYPE)) + tt_c = triton.ops.matmul(a, b, acc_dtype if ACC_DTYPE else None, INPUT_PRECISION, F8_FASTACCUM, output_dtype) + torch.testing.assert_close(th_c, tt_c) + except triton.OutOfResources as e: + pytest.skip(str(e)) diff --git a/third_party/mthreads/python/test/unit/runtime/test_autotuner.py b/third_party/mthreads/python/test/unit/runtime/test_autotuner.py new file mode 100644 index 000000000..679782a32 --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_autotuner.py @@ -0,0 +1,132 @@ +import torch + +import triton +import triton.language as tl +import pytest + + +@pytest.mark.parametrize('use_cuda_graph', [False, True]) +def test_kwargs(use_cuda_graph: bool, device: str): + N = 1024 + src = torch.empty(N, device=device) + dst = torch.empty(N, device=device) + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N) + _kernel[grid](dst=dst, src=src, N=N) + + +def test_restore(device: str): + N = 1024 + src = torch.zeros(N, device=device) + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + @triton.autotune(configs=configs, key=['N'], restore_value=['src'], warmup=1, rep=1) + @triton.jit + def _kernel(src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + 1 + tl.store(src + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](src, N) + triton.testing.assert_close(src, torch.ones_like(src)) + + +def test_hooks(device: str): + # Autotuner's pre- and post- hooks should be called the same number of times + N = 4096 + src = torch.zeros(N, device=device) + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 4096}), triton.Config(kwargs={'BLOCK_SIZE': 32})] + + values = {"counter": 0, "has_exception": False} + + def _pre_hook(*args, **kwargs): + values["counter"] += 1 + + def _post_hook(*args, exception): + values["counter"] -= 1 + if exception is not None: + values["has_exception"] = True + assert values["counter"] == 0 + + @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, pre_hook=_pre_hook, post_hook=_post_hook) + @triton.heuristics({"N_STAGES": lambda nargs: 100 if nargs['N'] == 4096 else 4}) + @triton.jit + def _kernel(src, N, N_STAGES: tl.constexpr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + max_iters = tl.cdiv(N, BLOCK_SIZE) + for _ in tl.range(max_iters, num_stages=N_STAGES): + x = tl.load(src + offsets, mask=offsets < N) + tl.store(src + offsets, x, mask=offsets < N) + offsets += BLOCK_SIZE + + _kernel[(1, )](src, N) + + # On NVIDIA GPUs: + # The tunning knob `num_stages` can be set by users. + # This will cause out of resources when N_STAGES = 100 + # shared memory bytes = N_STAGES * BLOCK_SIZE * sizeof(float) + # On AMD GPUs: + # `num_stages` is a fixed value of 2, so it won't cause out of resources + if triton.runtime.driver.active.get_current_target().backend == "cuda": + assert values["has_exception"] is True + else: + assert values["has_exception"] is False + + +@pytest.mark.parametrize('with_perf_model', [False, True]) +def test_prune_configs(with_perf_model: bool, device: str): + N = 1024 + src = torch.empty(N, device=device) + dst = torch.empty(N, device=device) + records = {} + + def early_config_prune(configs, named_args, **kwargs): + records['run_early_config_prune'] = True + if "N" in kwargs and kwargs["N"] == 1024: + records['capture_kwargs'] = True + if "dst" in named_args and "src" in named_args and len(named_args) == 2: + records['capture_named_args'] = True + return [configs[0]] + + def perf_model(*args, **kwargs): + records['run_perf_model'] = True + return kwargs['BLOCK_SIZE'] + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + if with_perf_model: + prune_configs_by = {'perf_model': perf_model, 'top_k': 1} + else: + prune_configs_by = {'early_config_prune': early_config_prune} + + @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, warmup=1, rep=1) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + torch.testing.assert_close(src, dst) + if with_perf_model: + assert len(records) == 1 + assert records['run_perf_model'] + else: + assert len(records) == 3 + assert records['run_early_config_prune'] + assert records['capture_kwargs'] + assert records['capture_named_args'] diff --git a/third_party/mthreads/python/test/unit/runtime/test_bindings.py b/third_party/mthreads/python/test/unit/runtime/test_bindings.py new file mode 100644 index 000000000..ea9a13087 --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_bindings.py @@ -0,0 +1,81 @@ +import triton +import triton.language as tl + +import torch + + +@triton.jit +def add_helper(x, y): + return x + y + + +@triton.jit +def add_kernel( + in_ptr0, + in_ptr1, + n_elements, + out_ptr, + BLOCK_SIZE: "tl.constexpr", +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = add_helper(x, y) + tl.store(out_ptr + offsets, output, mask=mask) + + +def test_module_walk(device: str): + """ + Test the MLIR bindings exposed for the out-ot-tree walk. + """ + + def walk_fn(op): + name = op.get_name() + for i in range(op.get_num_results()): + op.get_result(i).id() + for i in range(op.get_num_operands()): + op.get_operand(i).id() + for i in range(op.get_num_regions()): + op.get_region(i).id() + block = op.get_block() + if block is not None: + block.id() + for i in range(block.get_num_arguments()): + block.get_argument(i) + if name == "tt.func": + op.get_str_attr("sym_name") + if name == "tt.call": + op.get_flat_symbol_ref_attr("callee") + + kernel = add_kernel + args = [ + torch.empty((32, 32), device=device), # in_ptr0 + torch.empty((32, 32), device=device), # in_ptr1 + 1024, # n_elements + torch.empty((32, 32), device=device), # out_ptr + 16, # BLOCK_SIZE + ] + src = triton.compiler.compiler.ASTSource( + fn=kernel, + signature={i: kernel._type_of(kernel._key_of(arg)) + for i, arg in enumerate(args) + if i not in kernel.constexprs}, + constants={i: arg + for i, arg in enumerate(args) + if not isinstance(arg, torch.Tensor)}, + attrs=kernel._get_config(*args, ), + ) + + context = triton._C.libtriton.ir.context() + target = triton.runtime.driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) + options = backend.parse_options(dict()) + codegen_fns = dict() + triton._C.libtriton.ir.load_dialects(context) + backend.load_dialects(context) + + ttir_module = src.make_ir(options, codegen_fns, context) + ttir_module.walk(walk_fn) diff --git a/third_party/mthreads/python/test/unit/runtime/test_cache.py b/third_party/mthreads/python/test/unit/runtime/test_cache.py new file mode 100644 index 000000000..ed7be83e9 --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_cache.py @@ -0,0 +1,536 @@ +import importlib.util +import itertools +import os +import shutil +import tempfile + +import pytest +import torch + +import triton +import triton.language as tl +from triton.runtime.jit import JITFunction + +tmpdir = ".tmp" + + +@triton.jit +def function_1(i): + i = i + 1 + i = function_2(i) + return i + + +@triton.jit +def function_2(i): + i = i + 1 + return i + + +@triton.jit +def combine_fn(a, b): + return COMBINE_OP # noqa: F821 + + +@triton.jit +def kernel(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit(do_not_specialize=["i"]) +def kernel_nospec(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit +def kernel_with_combine_fn(X, BLOCK: tl.constexpr): + i = tl.arange(0, BLOCK) + i = REDUCE_OR_SCAN(i, 0, combine_fn) # noqa: F821 + tl.store(X, i) + + +def apply_src_change(target, old, new): + kernel.hash = None + function_1.hash = None + function_2.hash = None + function_1.src = function_1.src.replace(old, new) + target.src = target.src.replace(old, new) + ret = target.cache_key + target.src = target.src.replace(new, old) + return ret + + +def test_nochange(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 1') + assert baseline == updated + + +def test_toplevel_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2') + assert baseline != updated + + +def test_nested1_change(): + baseline = kernel.cache_key + updated = apply_src_change(function_1, 'i + 1', 'i + 2') + assert baseline != updated + + +def test_combine_fn_change(): + # Test that tl.reduce and associative_scan calls include + # the combine_fn in the hash + + orig_combine_fn_src = combine_fn.src + orig_kernel_src = kernel_with_combine_fn.src + seen_keys = set() + + for reduce_or_scan, combine_op in itertools.product( + ["tl.reduce", "tl.associative_scan"], + ["a + b", "a * b"], + ): + combine_fn.src = orig_combine_fn_src.replace("COMBINE_OP", combine_op) + kernel_with_combine_fn.src = orig_kernel_src.replace("REDUCE_OR_SCAN", reduce_or_scan) + try: + key = kernel_with_combine_fn.cache_key + finally: + combine_fn.src = orig_combine_fn_src + kernel_with_combine_fn.src = orig_kernel_src + + kernel_with_combine_fn.hash = None + combine_fn.hash = None + + assert key not in seen_keys + seen_keys.add(key) + + +def write_and_load_module(code, num_extra_lines): + with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f: + f.write(('# extra line\n' * num_extra_lines) + code) + f.flush() + spec = importlib.util.spec_from_file_location("module.name", f.name) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_changed_line_numbers_invalidate_cache(): + from textwrap import dedent + code = dedent(""" + import triton + @triton.jit + def test_kernel(i): + i = i + 1 + """) + orig_mod = write_and_load_module(code, 0) + orig_cache_key = orig_mod.test_kernel.cache_key + + updated_mod = write_and_load_module(code, 1) + updated_cache_key = updated_mod.test_kernel.cache_key + assert orig_cache_key != updated_cache_key + + +def reset_tmp_dir(): + os.environ["TRITON_CACHE_DIR"] = tmpdir + if os.path.exists(tmpdir): + # https://stackoverflow.com/questions/303200/how-do-i-remove-delete-a-folder-that-is-not-empty + shutil.rmtree(tmpdir, ignore_errors=True) + + +def test_reuse(device: str): + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + reset_tmp_dir() + x = torch.empty(1, dtype=torch.int32, device=device) + for i in range(10): + kernel[(1, )](x, 1, BLOCK=1024) + assert counter == 1 + + +@pytest.mark.parametrize('mode', ['enable', 'disable']) +def test_specialize(mode, device: str): + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + reset_tmp_dir() + x = torch.empty(1, dtype=torch.int32, device=device) + function = {'enable': kernel, 'disable': kernel_nospec}[mode] + target = {'enable': 3, 'disable': 1}[mode] + for i in [1, 2, 4, 8, 16, 32]: + function[(1, )](x, i, BLOCK=512) + assert counter == target + + +def test_annotation(device: str): + + @triton.jit + def kernel(X, i: tl.int32): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device=device) + + device = eval(f'torch.{device}.current_device()') + kernel[(1, )](x, 1) + kernel[(1, )](x, 8) + kernel[(1, )](x, 16) + kernel[(1, )](x, 17) + assert len(kernel.cache[device]) == 3 + + +GLOBAL_DEFAULT_ARG = 1 + + +def test_kernel_default_arg(device: str): + global GLOBAL_DEFAULT_ARG + + @triton.jit + def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device=device) + kernel[(1, )](x) + assert x == torch.ones_like(x) + + # Changing the global variable should not change the default argument in + # `kernel`. That value gets set at the time the function is declared. + GLOBAL_DEFAULT_ARG = 2 + kernel[(1, )](x) + assert x == torch.ones_like(x) + + device = eval(f'torch.{device}.current_device()') + assert len(kernel.cache[device]) == 1 + + +GLOBAL_VAR: tl.constexpr = 1 + + +def test_kernel_global_var_change(device: str): + global GLOBAL_VAR + + @triton.jit + def kernel(X): + tl.store(X, GLOBAL_VAR) + + x = torch.empty(1, dtype=torch.int32, device=device) + kernel[(1, )](x) + assert x == torch.ones_like(x) + + GLOBAL_VAR = 2 + with pytest.raises(RuntimeError) as e: + kernel[(1, )](x) + + assert "global variable" in str(e.value).lower() + + +GLOBAL = 42 # noqa + + +def test_local_shadows_global(): + global GLOBAL + + @triton.jit + def kernel(): + _, GLOBAL = 0, 0 # noqa + a = GLOBAL # noqa + + # No error because the `GLOBAL` we're modifying is not the same `GLOBAL` as + # inside the kernel. + GLOBAL = 42 + kernel[(1, )]() + GLOBAL = 43 + kernel[(1, )]() + + +CONSTEXPR_GLOBAL: tl.constexpr = 42 + + +def test_local_does_not_shadow_global(): + global CONSTEXPR_GLOBAL + + @triton.jit + def kernel(): + a = CONSTEXPR_GLOBAL # noqa + _, CONSTEXPR_GLOBAL = 0, 0 # noqa + + CONSTEXPR_GLOBAL = 42 + kernel[(1, )]() + CONSTEXPR_GLOBAL = 43 + + # Error because the `CONSTEXPR_GLOBAL` we're modifying is the same + # `CONSTEXPR_GLOBAL` that's read inside `kernel`. (Alternatively, we could + # make this kernel an error altogether, as it is if it's a pure Python + # function -- the fact that we store to `CONSTEXPR_GLOBAL` inside the kernel + # makes the first read a read of the local variable, which doesn't exist + # yet.) + with pytest.raises(RuntimeError): + kernel[(1, )]() + + +CONFLICTING_GLOBAL: tl.constexpr = 0 + + +@triton.jit +def conflicting_global_inner(): + a = CONFLICTING_GLOBAL # noqa + + +def test_conflicting_global_in_inner_function(): + global CONFLICTING_GLOBAL + + @triton.jit + def kernel1(): + a = CONFLICTING_GLOBAL # noqa + conflicting_global_inner() + + @triton.jit + def kernel2(): + a = CONFLICTING_GLOBAL #noqa + conflicting_global_inner() + + kernel1[(1, )]() + + # This should be an error because kernel2 calls conflicting_global_inner, + # which saw a value for 42 for the global when it was first compiled. + CONFLICTING_GLOBAL = 1 + + with pytest.raises(RuntimeError) as e: + kernel2[(1, )]() + + assert "Global variable CONFLICTING_GLOBAL has value" in str(e.value) + + +def test_use_builtin(): + + @triton.jit + def kernel(): + a = float(0) # noqa + + # No error about the value of `float` changing. + kernel[(1, )]() + kernel[(1, )]() + + +def test_no_cache_module_as_global(): + + @triton.jit + def kernel(): + tl.arange(0, 16) + + kernel[(1, )]() + # `tl` should not be entered into used_global_vals + assert not kernel.used_global_vals + + +BUILTIN_AS_GLOBAL = tl.int32 + + +def test_cache_builtin_as_global(): + global BUILTIN_AS_GLOBAL + + @triton.jit + def kernel(): + x = BUILTIN_AS_GLOBAL # noqa + + kernel[(1, )]() + + BUILTIN_AS_GLOBAL = tl.int64 + with pytest.raises(RuntimeError) as e: + kernel[(1, )]() + + assert "global variable" in str(e.value).lower() + + +@triton.jit +def no_cache_callable_inner(): + pass + + +def test_no_cache_callable(): + + @triton.jit + def kernel(): + no_cache_callable_inner() + + kernel[(1, )]() + # `no_cache_callable_inner` should not be entered into used_global_vals. + assert not kernel.used_global_vals + + +def test_constexpr_not_callable(device: str) -> None: + + @triton.jit + def kernel(X, c: tl.constexpr): + tl.store(X, 2) + + x = torch.empty(1, dtype=torch.int32, device=device) + error = False + try: + kernel[(1, )](x, c="str") + except BaseException: + error = True + assert error is False + # try and catch + try: + kernel[(1, )](x, c=tl.abs) + except BaseException: + error = True + assert error is True + + +def test_jit_warmup_cache(device: str) -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + args = [ + torch.randn(32, dtype=torch.float32, device=device), + torch.randn(32, dtype=torch.float32, device=device), + torch.randn(32, dtype=torch.float32, device=device), + 32, + ] + device = eval(f'torch.{device}.current_device()') + assert len(kernel_add.cache[device]) == 0 + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 1 + kernel_add.warmup(*args, grid=(1, )) + assert len(kernel_add.cache[device]) == 1 + kernel_add.warmup(*args, grid=(1, )) + assert len(kernel_add.cache[device]) == 1 + + +def test_jit_debug(device: str) -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + device = eval(f'torch.{device}.current_device()') + assert len(kernel_add.cache[device]) == 0 + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 1 + kernel_add.debug = False + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 2 + kernel_add.debug = True + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 3 + bins = list(kernel_add.cache[device].values()) + assert bins[2].asm['ttir'] != bins[1].asm['ttir'] + + +@triton.jit +def add_fn(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + +@pytest.mark.skip(reason="TO FIX") +def test_jit_noinline(device: str) -> None: + + @triton.jit + def kernel_add_device(a, b, o, N: tl.constexpr): + add_fn(a, b, o, N) + + device = eval(f'torch.{device}.current_device()') + assert len(kernel_add_device.cache[device]) == 0 + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add_device.cache[device]) == 1 + bins = list(kernel_add_device.cache[device].values()) + inline_ttir = bins[0].asm['ttir'] + add_fn.noinline = True + add_fn.hash = None + kernel_add_device.hash = None + kernel_add_device.cache[device].clear() + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add_device.cache[device]) == 1 + bins = list(kernel_add_device.cache[device].values()) + noinline_ttir = bins[0].asm['ttir'] + assert inline_ttir != noinline_ttir + + +def test_memory_leak() -> None: + + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 10 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) + + +@pytest.mark.skip(reason="TO FIX") +def test_preload(device: str) -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + @triton.jit + def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx)) + + device = eval(f'torch.{device}.current_device()') + + # get the serialized specialization data + specialization_data = None + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + specialization_data = kwargs["compile"]["specialization_data"] + + JITFunction.cache_hook = cache_hook + pre_compile = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + hash = pre_compile.hash + assert specialization_data is not None + + # clear the cache + reset_tmp_dir() + kernel_add.cache[device].clear() + + # preload the kernel + kernel_preload = kernel_add.preload(specialization_data) + assert kernel_preload.hash == hash + assert len(kernel_add.cache[device]) == 1 + + # we should hit the cache and not compile anything + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + final_kernel = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + JITFunction.cache_hook = None + assert counter == 0 + assert len(kernel_add.cache[device]) == 1 + assert final_kernel.hash == hash + + # test that we can't preload a mismatched kernel + with pytest.raises(RuntimeError, match="Specialization data is for"): + kernel_sub.preload(specialization_data) diff --git a/third_party/mthreads/python/test/unit/runtime/test_driver.py b/third_party/mthreads/python/test/unit/runtime/test_driver.py new file mode 100644 index 000000000..de00082f5 --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_driver.py @@ -0,0 +1,14 @@ +import sys + +import triton + + +def test_is_lazy(): + from importlib import reload + reload(sys.modules["triton.runtime.driver"]) + reload(sys.modules["triton.runtime"]) + mod = sys.modules[triton.runtime.driver.__module__] + assert isinstance(triton.runtime.driver.active, getattr(mod, "LazyProxy")) + assert triton.runtime.driver.active._obj is None + utils = triton.runtime.driver.active.utils # noqa: F841 + assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase")) diff --git a/third_party/mthreads/python/test/unit/runtime/test_jit.py b/third_party/mthreads/python/test/unit/runtime/test_jit.py new file mode 100644 index 000000000..5892494c4 --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_jit.py @@ -0,0 +1,42 @@ +import itertools +import pytest +import torch + +import triton +import triton.language as tl + + +def test_pre_call_hooks(device): + + @triton.jit + def add_kernel( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + class MyTensor(torch.Tensor): + pass + + def my_hook(*args, **kwargs): + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, MyTensor): + raise Exception("MyTensor is not allowed") + + add_kernel.add_pre_run_hook(my_hook) + + x = torch.randn(4, device=device) + y = MyTensor(x) + out = torch.zeros_like(x) + with pytest.raises(Exception): + add_kernel[(4, )](x, y, out, 4, 4) diff --git a/third_party/mthreads/python/test/unit/runtime/test_launch.py b/third_party/mthreads/python/test/unit/runtime/test_launch.py new file mode 100644 index 000000000..2a52ce285 --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_launch.py @@ -0,0 +1,134 @@ +import gc +# import importlib +# import os +# import sys +# import tempfile +# import textwrap +# import time +import tracemalloc + +import torch + +import triton +import triton.language as tl + +# from typing import Tuple + + +def test_metadata() -> None: + + used_hook = False + + def _launch_metadata(grid, kernel, args): + ret = dict() + ret["grid"] = grid + ret["value"] = args["x"] + return ret + + def hook(launch_metadata): + nonlocal used_hook + metadata = launch_metadata.get() + assert metadata["grid"] == (1, 3, 2) + assert metadata["value"] == 6 + used_hook = True + + @triton.jit(launch_metadata=_launch_metadata) + def kernel(x): + pass + + # launch kernel + triton.compiler.CompiledKernel.launch_enter_hook = hook + kernel[(1, 3, 2)](6) + triton.compiler.CompiledKernel.launch_enter_hook = None + assert used_hook + + +def test_memory_leak(device: str) -> None: + + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 10 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) + + tracemalloc.start() + try: + inp = torch.randn(10, device=device) + out = torch.randn(10, device=device) + kernel[(10, )](inp, out, 10, XBLOCK=16) + gc.collect() + begin, _ = tracemalloc.get_traced_memory() + for _ in range(100): + kernel[(10, )](inp, out, 10, XBLOCK=16) + gc.collect() + end, _ = tracemalloc.get_traced_memory() + assert end - begin < 30000 + finally: + tracemalloc.stop() + + +# LATENCY_THRESHOLD_US = 46 + +# def test_kernel_launch_latency() -> None: +# def define_kernel(kernel_name: str, num_tensor_args: int) -> str: +# arg_str = ",".join([f"arg{i}: torch.Tensor" for i in range(num_tensor_args)]) +# arg_str += ", n_elements: int, BLOCK_SIZE: tl.constexpr" +# func_str = f""" +# import torch + +# import triton +# import triton.language as tl + +# @triton.jit +# def {kernel_name}({arg_str}): +# pass +# """ +# with tempfile.NamedTemporaryFile(mode="w+t", suffix=".py", delete=False) as temp_file: +# temp_file.write(textwrap.dedent(func_str)) +# temp_file_path = temp_file.name + +# return temp_file_path + +# def import_kernel(file_path, kernel_name): +# directory, filename = os.path.split(file_path) +# module_name, _ = os.path.splitext(filename) +# sys.path.insert(0, directory) + +# module = importlib.import_module(module_name) +# kernel = getattr(module, kernel_name) +# return kernel + +# def empty(*kernel_args: Tuple[torch.Tensor]): +# first_arg = kernel_args[0] +# n_elements = first_arg.numel() +# grid = (triton.cdiv(n_elements, 1024),) +# device = torch.cuda.current_device() +# # Warmup +# empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device) +# torch.cuda.synchronize() +# # Measure launch overhead at steady state +# num_runs = 1000 +# start_time = time.time() +# for i in range(num_runs): +# empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device) +# end_time = time.time() +# latency_us = (end_time - start_time) / num_runs * 1e6 + +# assert latency_us < LATENCY_THRESHOLD_US, "Kernel launch time has increased!" + +# num_tensor_args = 40 +# kernel_name = 'empty_kernel' +# file_path = define_kernel(kernel_name, num_tensor_args) +# empty_kernel = import_kernel(file_path, kernel_name) + +# # Initialize random tensors for the empty_kernel +# torch.manual_seed(0) +# size = 1024 +# kernel_args = (torch.rand(size, device='cuda') for i in range(num_tensor_args)) + +# # Run empty, which would run empty_kernel internally +# empty(*kernel_args) diff --git a/third_party/mthreads/python/test/unit/runtime/test_subproc.py b/third_party/mthreads/python/test/unit/runtime/test_subproc.py new file mode 100644 index 000000000..03ce44efc --- /dev/null +++ b/third_party/mthreads/python/test/unit/runtime/test_subproc.py @@ -0,0 +1,76 @@ +import multiprocessing +import os +import shutil + +import torch +import pytest + +import triton +import triton.language as tl +from triton.compiler import ASTSource + +tmpdir = ".tmp" + +target = triton.runtime.driver.active.get_current_target() + + +def reset_tmp_dir(): + os.environ["TRITON_CACHE_DIR"] = tmpdir + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + + +def compile_fn(attrs, capability): + + @triton.jit + def kernel_sub(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777) + + src = ASTSource( + fn=kernel_sub, + constants={3: 32}, + signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, + attrs=attrs, + ) + triton.compile(src=src, target=target) + + +@pytest.mark.skip(reason="TO FIX") +def test_compile_in_subproc(device: str) -> None: + major, minor = eval(f'torch.{device}.get_device_capability(0)') + cc = major * 10 + minor + config = triton.compiler.AttrsDescriptor(tuple(range(4)), ()) + + multiprocessing.set_start_method('fork') + proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) + proc.start() + proc.join() + assert proc.exitcode == 0 + + +def compile_fn_dot(attrs, capability): + + @triton.jit + def kernel_dot(Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + tl.store(Z + offs, z) + + src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict()) + triton.compile(src=src, target=target) + + +@pytest.mark.skip(reason="TO FIX") +def test_compile_in_forked_subproc(device: str) -> None: + reset_tmp_dir() + major, minor = eval(f'torch.{device}.get_device_capability(0)') + capability = major * 10 + minor + config = triton.compiler.AttrsDescriptor(tuple(range(1)), ()) + + assert multiprocessing.get_start_method() == 'fork' + proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability)) + proc.start() + proc.join() + assert proc.exitcode == 0 diff --git a/third_party/mthreads/python/triton/_C/include b/third_party/mthreads/python/triton/_C/include new file mode 120000 index 000000000..b85a40983 --- /dev/null +++ b/third_party/mthreads/python/triton/_C/include @@ -0,0 +1 @@ +../../../include/ \ No newline at end of file diff --git a/third_party/mthreads/python/triton/__init__.py b/third_party/mthreads/python/triton/__init__.py new file mode 100644 index 000000000..f96017c1e --- /dev/null +++ b/third_party/mthreads/python/triton/__init__.py @@ -0,0 +1,75 @@ +"""isort:skip_file""" +__version__ = '3.1.0' + +# --------------------------------------- +# Note: import order is significant here. + +# submodules +from .runtime import ( + autotune, + Config, + heuristics, + JITFunction, + KernelInterface, + reinterpret, + TensorWrapper, + OutOfResources, + InterpreterError, + MockTensor, +) +from .runtime.jit import jit +from .compiler import compile, CompilationError +from .errors import TritonError + +from . import language +from . import testing +from . import tools +from .backends.mthreads import musa_testing + +__all__ = [ + "autotune", + "cdiv", + "CompilationError", + "compile", + "Config", + "heuristics", + "impl", + "InterpreterError", + "jit", + "JITFunction", + "KernelInterface", + "language", + "MockTensor", + "next_power_of_2", + "ops", + "OutOfResources", + "reinterpret", + "runtime", + "TensorWrapper", + "TritonError", + "testing", + "musa_testing", + "tools", +] + +# ------------------------------------- +# misc. utilities that don't fit well +# into any specific module +# ------------------------------------- + + +def cdiv(x: int, y: int): + return (x + y - 1) // y + + +def next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n diff --git a/third_party/mthreads/python/triton/backends b/third_party/mthreads/python/triton/backends new file mode 120000 index 000000000..19987ee14 --- /dev/null +++ b/third_party/mthreads/python/triton/backends @@ -0,0 +1 @@ +../../../../python/triton/backends/ \ No newline at end of file diff --git a/third_party/mthreads/python/triton/compiler/__init__.py b/third_party/mthreads/python/triton/compiler/__init__.py new file mode 100644 index 000000000..ce0cfedfc --- /dev/null +++ b/third_party/mthreads/python/triton/compiler/__init__.py @@ -0,0 +1,4 @@ +from .compiler import CompiledKernel, ASTSource, compile, AttrsDescriptor, make_backend, LazyDict +from .errors import CompilationError + +__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"] diff --git a/third_party/mthreads/python/triton/compiler/code_generator.py b/third_party/mthreads/python/triton/compiler/code_generator.py new file mode 100644 index 000000000..6903052ca --- /dev/null +++ b/third_party/mthreads/python/triton/compiler/code_generator.py @@ -0,0 +1,1302 @@ +import ast +import inspect +import re +import sys +import warnings +import os +import textwrap +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from .. import language +from .._C.libtriton import ir +from ..language import constexpr, tensor, str_to_ty +from ..runtime.jit import _normalize_ty +# ideally we wouldn't need any runtime component +from ..runtime import JITFunction +from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) +from types import ModuleType + + +def mangle_ty(ty): + if ty.is_ptr(): + return 'P' + mangle_ty(ty.element_ty) + if ty.is_int(): + SIGNED = language.dtype.SIGNEDNESS.SIGNED + prefix = 'i' if ty.int_signedness == SIGNED else 'u' + return prefix + str(ty.int_bitwidth) + if ty.is_floating(): + return str(ty) + if ty.is_block(): + elt = mangle_ty(ty.scalar) + shape = '_'.join(map(str, ty.shape)) + return f'{elt}S{shape}S' + if ty.is_void(): + return 'V' + assert False, "Unsupported type" + + +def mangle_fn(name, arg_tys, constants): + # doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) + mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + # [ and ] are not allowed in LLVM identifiers + mangled_constants = mangled_constants.replace('[', '_').replace(']', '_') + ret = f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + + +def _is_triton_tensor(o: Any) -> bool: + return isinstance(o, tensor) + + +def _is_constexpr(o: Any) -> bool: + return isinstance(o, constexpr) + + +def _is_triton_scalar(o: Any) -> bool: + return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1) + + +def _is_list_like(o: Any) -> bool: + return isinstance(o, (list, tuple)) + + +def _unwrap_if_constexpr(o: Any): + return o.value if isinstance(o, constexpr) else o + + +def _check_fn_args(node, fn, args): + if fn.noinline: + for idx, arg in enumerate(args): + if not _is_constexpr(arg) and not _is_triton_scalar(arg): + raise UnsupportedLanguageConstruct( + fn.src, node, + f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}' + ) + + +def _get_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITFunction): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + lines, begin_line = inspect.getsourcelines(base_fn.fn) + # Match the following pattern: + # @triton.autotune(...) <- foo.__code__.co_firstlineno + # @triton.heuristics(...) + # @triton.jit + # def foo(...): <- this line is the first line + for idx, line in enumerate(lines): + if line.strip().startswith("def "): + begin_line += idx + break + return file_name, begin_line + + +_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels + + +class enter_sub_region: + + def __init__(self, generator): + self.generator = generator + + def __enter__(self): + # record lscope & local_defs in the parent scope + self.liveins = self.generator.lscope.copy() + self.prev_defs = self.generator.local_defs.copy() + self.generator.local_defs = {} + self.insert_block = self.generator.builder.get_insertion_block() + self.insert_point = self.generator.builder.get_insertion_point() + return self.liveins, self.insert_block + + def __exit__(self, *args, **kwargs): + self.generator.builder.restore_insertion_point(self.insert_point) + self.generator.lscope = self.liveins + self.generator.local_defs = self.prev_defs + + +# Check if the given syntax node has an "early" return +class ContainsReturnChecker(ast.NodeVisitor): + + def __init__(self, gscope): + self.gscope = gscope + + def _visit_stmts(self, body) -> bool: + for s in body: + if self.visit(s): + return True + return False + + def _visit_function(self, fn) -> bool: + # Currently we only support JITFunctions defined in the global scope + if isinstance(fn, JITFunction) and not fn.noinline: + fn_node = fn.parse() + return ContainsReturnChecker(self.gscope).visit(fn_node) + return False + + def generic_visit(self, node) -> bool: + ret = False + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + ret = ret or self.visit(item) + elif isinstance(value, ast.AST): + ret = ret or self.visit(value) + return ret + + def visit_Attribute(self, node: ast.Attribute) -> bool: + # If the left part is a name, it's possible that + # we call triton native function or a jit function from another module. + # If the left part is not a name, it must return a tensor or a constexpr + # whose methods do not contain return statements + # e.g., (tl.load(x)).to(y) + # So we only check if the expressions within value have return or not + if isinstance(node.value, ast.Name): + if node.value.id in self.gscope: + value = self.gscope[node.value.id] + fn = getattr(value, node.attr) + return self._visit_function(fn) + return False + return self.visit(node.value) + + def visit_Name(self, node: ast.Name) -> bool: + if type(node.ctx) == ast.Store: + return False + if node.id in self.gscope: + fn = self.gscope[node.id] + return self._visit_function(fn) + return False + + def visit_Return(self, node: ast.Return) -> bool: + return True + + def visit_Assign(self, node: ast.Assign) -> bool: + # There couldn't be an early return + # x = ... + return False + + def visit_AugAssign(self, node: ast.AugAssign) -> bool: + # There couldn't be an early return + # x += ... + return False + + def visit_Module(self, node: ast.Module) -> bool: + return self._visit_stmts(node.body) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: + return self._visit_stmts(node.body) + + def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return + ret = self._visit_stmts(node.body) + if node.orelse: + ret = ret or self._visit_stmts(node.orelse) + return ret + + def visit_IfExp(self, node: ast.IfExp) -> bool: + return self.visit(node.body) or self.visit(node.orelse) + + def visit_Call(self, node: ast.Call) -> bool: + return self.visit(node.func) + + +class CodeGenerator(ast.NodeVisitor): + + def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, + codegen_fns, debug=None, module=None, is_kernel=False, function_types: Optional[Dict] = None, + noinline=False, file_name: Optional[str] = None, begin_line=0): + self.context = context + self.builder = ir.builder(context) + self.file_name = file_name + # node.lineno starts from 1, so we need to subtract 1 + self.begin_line = begin_line - 1 + self.builder.set_loc(file_name, begin_line, 0) + self.builder.options = options + # dict of functions provided by the backend. Below are the list of possible functions: + # Convert custom types not natively supported on HW. + # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) + self.builder.codegen_fns = codegen_fns + self.module = self.builder.create_module() if module is None else module + self.function_ret_types = {} if function_types is None else function_types + self.prototype = prototype + self.gscope = gscope + self.lscope = dict() + self.attributes = attributes + self.constants = constants + self.jit_fn = jit_fn + self.function_name = function_name + self.is_kernel = is_kernel + self.cur_node = None + self.debug = options.debug if debug is None else debug + self.noinline = noinline + self.scf_stack = [] + self.ret_type = None + # SSA-construction + # name => language.tensor + self.local_defs: Dict[str, tensor] = {} + self.dereference_name: Callable[[str], Any] = self._define_name_lookup() + self.fn = None + # Are we currently visiting an ast.arg's default value? These have some + # special handling. + self.visiting_arg_default_value = False + + builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)} + builtin_namespace.update(( + ('print', language.core.device_print), + ('min', language.minimum), + ('max', language.maximum), + )) + + def _unsupported(self, node, message): + return UnsupportedLanguageConstruct(self.jit_fn.src, node, message) + + def _is_constexpr_global(self, name): + absent_marker = object() + val = self.gscope.get(name, absent_marker) + if val is absent_marker: + return False + + if _is_constexpr(val): + return True + + if a := self.gscope.get("__annotations__", {}).get(name): + return _normalize_ty(a) == "constexpr" + + return False + + def _define_name_lookup(self): + + def local_lookup(name: str, absent): + # this needs to be re-fetched from `self` every time, because it gets switched occasionally + return self.lscope.get(name, absent) + + def global_lookup(name: str, absent): + val = self.gscope.get(name, absent) + # The high-level rule is that only constexpr globals are allowed. + # But actually a bunch of other things, such as module imports, are + # technically Python globals. We have to allow these too! + if (val is absent # + or name in self.builtin_namespace # + or type(val) == ModuleType # + or isinstance(val, JITFunction) # + or getattr(val, "__triton_builtin__", False) # + or getattr(val, "__module__", "").startswith("triton.language") # + or isinstance(val, language.dtype) # + or self._is_constexpr_global(name) # + # Allow accesses to globals while visiting an ast.arg + # because you should be able to do + # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... + or self.visiting_arg_default_value # + or os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"): + return val + raise NameError( + textwrap.dedent(f"""\ + Cannot access global variable {name} from within @jit'ed + function. Triton kernels can only access global variables that + are annotated as constexpr (`x: triton.language.constexpr = 42` + or `x = triton.language.constexpr(42)`). Alternatively, set the + envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not + promise to support this forever.""").replace("\n", " ")) + + absent_marker = object() + + def name_lookup(name: str) -> Any: + absent = absent_marker + for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get: + value = lookup_function(name, absent) + if value is not absent: + return value + raise NameError(f'{name} is not defined') + + return name_lookup + + def set_value(self, name: str, value: Union[tensor, constexpr]) -> None: + ''' This function: + called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. store tensor in self.lvalue + ''' + self.lscope[name] = value + self.local_defs[name] = value + + def _get_insertion_point_and_loc(self): + # XXX: this is a hack to get the location of the insertion point. + # The insertion point's location could be invalid sometimes, + # so we need to explicitly set the location + loc = self.builder.get_loc() + ip = self.builder.get_insertion_point() + return ip, loc + + def _set_insertion_point_and_loc(self, ip, loc): + self.builder.restore_insertion_point(ip) + self.builder.set_loc(loc) + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + # Ensure that stmts is iterable + if not _is_list_like(stmts): + stmts = [stmts] + for stmt in stmts: + self.visit(stmt) + + # Stop parsing as soon as we hit a `return` statement; everything + # after this is dead code. + if isinstance(stmt, ast.Return): + break + + def visit_Module(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_List(self, node): + ctx = self.visit(node.ctx) + assert ctx is None + elts = [self.visit(elt) for elt in node.elts] + return elts + + # By design, only non-kernel functions can return + def visit_Return(self, node): + ret_value = self.visit(node.value) + # ret_block = self.builder.create_block() + # post_ret_block = self.builder.create_block() + # self.builder.create_branch(ret_block) + # self.builder.set_insertion_point_to_end(ret_block) + if ret_value is None: + self.builder.ret([]) + ret_ty = language.void + elif isinstance(ret_value, tuple): + ret_values = [language.core._to_tensor(v, self.builder) for v in ret_value] + ret_types = [v.type for v in ret_values] + self.builder.ret([v.handle for v in ret_values]) + ret_ty = tuple(ret_types) + else: + ret = language.core._to_tensor(ret_value, self.builder) + self.builder.ret([ret.handle]) + ret_ty = ret.type + # self.builder.create_branch(post_ret_block) + # self.builder.set_insertion_point_to_end(post_ret_block) + + if self.ret_type is None: + self.ret_type = ret_ty + elif self.ret_type != ret_ty: + raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}') + + def visit_FunctionDef(self, node): + arg_names, kwarg_names = self.visit(node.args) + if self.fn: + raise self._unsupported(node, "nested function definition is not supported.") + # initialize defaults + for i, default_value in enumerate(node.args.defaults): + arg_node = node.args.args[-i - 1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + self.visit(init_node) + finally: + self.visiting_arg_default_value = False + + # initialize function + visibility = "public" if self.is_kernel else "private" + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, + self.prototype.to_ir(self.builder), visibility, self.noinline) + self.module.push_back(self.fn) + entry = self.fn.add_entry_block() + arg_values = [] + idx = 0 + for i, arg_name in enumerate(arg_names): + if i in self.constants: + cst = self.constants[i] + if not _is_constexpr(cst): + cst = constexpr(self.constants[i]) + arg_values.append(cst) + continue + else: + if i in self.attributes: + for name, value in self.attributes[i]: + self.fn.set_arg_attr(idx, name, value) + arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) + idx += 1 + + insert_pt = self.builder.get_insertion_block() + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + self.builder.set_insertion_point_to_start(entry) + # visit function body + self.visit_compound_statement(node.body) + # finalize function + if self.ret_type is None or self.ret_type == language.void: + self.ret_type = language.void + self.builder.ret([]) + else: + # update return type + if isinstance(self.ret_type, tuple): + self.prototype.ret_types = list(self.ret_type) + self.fn.reset_type(self.prototype.to_ir(self.builder)) + else: + self.prototype.ret_types = [self.ret_type] + self.fn.reset_type(self.prototype.to_ir(self.builder)) + if insert_pt: + self.builder.set_insertion_point_to_end(insert_pt) + # Remove dead code + self.fn.finalize() + + def visit_arguments(self, node): + arg_names = [] + for arg in node.args: + arg_names += [self.visit(arg)] + kwarg_names = self.visit(node.kwarg) + return arg_names, kwarg_names + + def visit_arg(self, node): + ast.NodeVisitor.generic_visit(self, node) + return node.arg + + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = self.visit(node.value) + # constexpr + if annotation == constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' + f' constexpr cannot be reassigned.') + if not _is_constexpr(value): + value = constexpr(value) + self.lscope[target] = value + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def visit_Assign(self, node): + _names = [] + for target in node.targets: + _names += [self.visit(target)] + if len(_names) > 1: + raise self._unsupported(node, "simultaneous multiple assignment is not supported.") + names = _names[0] + values = self.visit(node.value) + if not _is_list_like(names): + names = [names] + if not _is_list_like(values): + values = [values] + native_nontensor_types = (language.dtype, ) + for name, value in zip(names, values): + # by default, constexpr are assigned into python variable + value = _unwrap_if_constexpr(value) + if value is not None and \ + not _is_triton_tensor(value) and \ + not isinstance(value, native_nontensor_types): + value = language.core._to_tensor(value, self.builder) + self.set_value(name, value) + + def visit_AugAssign(self, node): + name = node.target.id + lhs = ast.Name(id=name, ctx=ast.Load()) + rhs = ast.BinOp(lhs, node.op, node.value) + assign = ast.Assign(targets=[node.target], value=rhs) + self.visit(assign) + return self.dereference_name(name) + + def visit_Name(self, node): + if type(node.ctx) == ast.Store: + return node.id + return self.dereference_name(node.id) + + def visit_Store(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Load(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Tuple(self, node): + args = [self.visit(x) for x in node.elts] + return tuple(args) + + def _apply_binary_method(self, method_name, lhs, rhs): + # TODO: raise something meaningful if getattr fails below, esp for reverse method + if _is_triton_tensor(lhs): + return getattr(lhs, method_name)(rhs, _builder=self.builder) + if _is_triton_tensor(rhs): + reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name) + return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder) + return getattr(lhs, method_name)(rhs) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + method_name = self._method_name_for_bin_op.get(type(node.op)) + if method_name is None: + raise self._unsupported(node, + "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bin_op: Dict[Type[ast.operator], str] = { + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', + } + + def visit_then_else_blocks(self, node, liveins, then_block, else_block): + # then block + self.builder.set_insertion_point_to_start(then_block) + self.visit_compound_statement(node.body) + then_block = self.builder.get_insertion_block() + then_defs = self.local_defs.copy() + # else block + else_defs = {} + if node.orelse: + self.builder.set_insertion_point_to_start(else_block) + self.lscope = liveins.copy() + self.local_defs = {} + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else_block = self.builder.get_insertion_block() + + # update block arguments + names = [] + ret_types = [] + ir_ret_types = [] + # variables in livein whose value is updated in `if` + for name in liveins: + # check type + for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: + if name in defs: + assert defs[name].type == liveins[name].type, \ + f'initial value for `{name}` is of type {liveins[name].type}, '\ + f'but the {block_name} block redefines it as {defs[name].type}' + if name in then_defs or name in else_defs: + names.append(name) + ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type) + ir_ret_types.append(then_defs[name].handle.get_type() if name in + then_defs else else_defs[name].handle.get_type()) + # variable defined in then but not in else + if name in then_defs and name not in else_defs: + else_defs[name] = liveins[name] + # variable defined in else but not in then + if name in else_defs and name not in then_defs: + then_defs[name] = liveins[name] + # variables that are both in then and else but not in liveins + # TODO: could probably be cleaned up + for name in then_defs.keys() & else_defs.keys(): + if name in names: + continue + then_ty = then_defs[name].type + else_ty = else_defs[name].type + assert then_ty == else_ty, \ + f'mismatched type for {name} between then block ({then_ty}) '\ + f'and else block ({else_ty})' + names.append(name) + ret_types.append(then_ty) + ir_ret_types.append(then_defs[name].handle.get_type()) + + return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types + + def visit_if_top_level(self, cond, node): + has_endif_block = True + with enter_sub_region(self) as sr: + liveins, ip_block = sr + then_block = self.builder.create_block() + else_block = self.builder.create_block() + # create basic-block after conditional + endif_block = self.builder.create_block() + # create branch + self.builder.set_insertion_point_to_end(ip_block) + self.builder.create_cond_branch(cond.handle, then_block, else_block) + # visit then and else blocks + then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # then terminator + self.builder.set_insertion_point_to_end(then_block) + if then_block.has_return() and else_block.has_return(): + has_endif_block = False + endif_block.erase() + if not then_block.has_terminator() and has_endif_block: + self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) + # else terminator + self.builder.set_insertion_point_to_end(else_block) + if not else_block.has_terminator() and has_endif_block: + self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) + if has_endif_block: + for ty in ir_ret_types: + endif_block.add_argument(ty) + if has_endif_block: + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + for i, name in enumerate(names): + new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) + self.set_value(name, new_tensor) + + # TODO: refactor + def visit_if_scf(self, cond, node): + with enter_sub_region(self) as sr: + liveins, _ = sr + ip, last_loc = self._get_insertion_point_and_loc() + then_block = self.builder.create_block() + else_block = self.builder.create_block() if node.orelse else None + then_defs, else_defs, then_block, else_block, names, ret_types, _ = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create if op + self._set_insertion_point_and_loc(ip, last_loc) + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + if len(names) > 0: + self.builder.create_yield_op([then_defs[n].handle for n in names]) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + if len(names) > 0: + self.builder.create_yield_op([else_defs[n].handle for n in names]) + # update values + for i, name in enumerate(names): + new_tensor = language.core.tensor(if_op.get_result(i), ret_types[i]) + self.set_value(name, new_tensor) + + def visit_If(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + contains_return = ContainsReturnChecker(self.gscope).visit(node) + if self.scf_stack and contains_return: + raise self._unsupported( + node, "Cannot have `return` statements inside `while` or `for` statements in triton " + "(note that this also applies to `return` statements that are inside functions " + "transitively called from within `while`/`for` statements)") + elif self.scf_stack or not contains_return: + self.visit_if_scf(cond, node) + else: + self.visit_if_top_level(cond, node) + else: + cond = _unwrap_if_constexpr(cond) + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + if cond: + self.visit_compound_statement(node.body) + else: + self.visit_compound_statement(node.orelse) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + # TODO: Deal w/ more complicated return types (e.g tuple) + with enter_sub_region(self): + ip, last_loc = self._get_insertion_point_and_loc() + + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + then_val = language.core._to_tensor(self.visit(node.body), self.builder) + then_block = self.builder.get_insertion_block() + + else_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(else_block) + # do not need to reset lscope since + # ternary expressions cannot define new variables + else_val = language.core._to_tensor(self.visit(node.orelse), self.builder) + else_block = self.builder.get_insertion_block() + + self._set_insertion_point_and_loc(ip, last_loc) + + assert then_val.type == else_val.type, \ + f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' + ret_type = then_val.type + + ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else [] + if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([then_val.handle]) + + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + else_block.merge_block_before(if_op.get_else_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([else_val.handle]) + return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None + else: + cond = _unwrap_if_constexpr(cond) + + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + if cond: + return self.visit(node.body) + else: + return self.visit(node.orelse) + + def visit_Pass(self, node): + pass + + def visit_Compare(self, node): + if not (len(node.comparators) == 1 and len(node.ops) == 1): + raise self._unsupported(node, "simultaneous multiple comparison is not supported") + lhs = self.visit(node.left) + rhs = self.visit(node.comparators[0]) + lhs_value = _unwrap_if_constexpr(lhs) + rhs_value = _unwrap_if_constexpr(rhs) + if type(node.ops[0]) == ast.Is: + return constexpr(lhs_value is rhs_value) + if type(node.ops[0]) == ast.IsNot: + return constexpr(lhs_value is not rhs_value) + method_name = self._method_name_for_comp_op.get(type(node.ops[0])) + if method_name is None: + raise self._unsupported( + node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = { + ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__' + } + + def visit_UnaryOp(self, node): + operand = self.visit(node.operand) + fn = self._method_name_for_unary_op.get(type(node.op)) + if fn is None: + raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.") + if _is_triton_tensor(operand): + return getattr(operand, fn)(_builder=self.builder) + try: + return getattr(operand, fn)() + except AttributeError: + raise self._unsupported( + node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}") + + _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = { + ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' + } + + def visit_While(self, node): + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # loop body (the after region) + # loop_block = self.builder.create_block() + dummy = self.builder.create_block() + self.builder.set_insertion_point_to_start(dummy) + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + dummy.erase() + + # collect loop-carried values + names = [] + ret_types = [] + init_args = [] + for name in loop_defs: + if name in liveins: + # We should not def new constexpr + assert _is_triton_tensor(loop_defs[name]), f'cannot reassign constxpr {name} in the loop' + assert _is_triton_tensor(liveins[name]), f'cannot reasign constexpr {name} in the loop' + assert loop_defs[name].type == liveins[name].type, \ + f'Loop-carried variable {name} has initial type {liveins[name].type} '\ + f'but is re-assigned to {loop_defs[name].type} in loop! '\ + f'Please make sure that the type stays consistent.' + + # these are loop-carried values + names.append(name) + ret_types.append(loop_defs[name].type) + init_args.append(liveins[name]) + + self._set_insertion_point_and_loc(ip, last_loc) + while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], + [arg.handle for arg in init_args]) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before(), + [ty.to_ir(self.builder) for ty in ret_types]) + self.builder.set_insertion_point_to_start(before_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(before_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + cond = self.visit(node.test) + self.builder.set_insertion_point_to_end(before_block) + # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... + self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after(), + [ty.to_ir(self.builder) for ty in ret_types]) + + # generate loop body + self.builder.set_insertion_point_to_start(after_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(after_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + yields = [] + for name in loop_defs: + if name in liveins: + yields.append(loop_defs[name]) + self.builder.create_yield_op([y.handle for y in yields]) + + # WhileOp defines new values, update the symbol table (lscope, local_defs) + for i, name in enumerate(names): + new_def = language.core.tensor(while_op.get_result(i), ret_types[i]) + self.lscope[name] = new_def + self.local_defs[name] = new_def + + for stmt in node.orelse: + assert False, "Not implemented" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Subscript(self, node): + assert node.ctx.__class__.__name__ == "Load" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + if _is_triton_tensor(lhs): + return lhs.__getitem__(slices, _builder=self.builder) + return lhs[slices] + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] + + def visit_For(self, node): + IteratorClass = self.visit(node.iter.func) + iter_args = [self.visit(arg) for arg in node.iter.args] + iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords) + if IteratorClass == language.static_range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + static_range = range(iterator.start.value, iterator.end.value, iterator.step.value) + for i in static_range: + self.lscope[node.target.id] = constexpr(i) + self.visit_compound_statement(node.body) + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return + num_stages = None + if IteratorClass is language.range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iterator.start + ub = iterator.end + step = iterator.step + num_stages = iterator.num_stages + elif IteratorClass is range: + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0)) + ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0]) + step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) + else: + raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + # handle negative constant step (not supported by scf.for in MLIR) + negative_step = False + if _is_constexpr(step) and step.value < 0: + step = constexpr(-step.value) + negative_step = True + lb, ub = ub, lb + lb = language.core._to_tensor(lb, self.builder) + ub = language.core._to_tensor(ub, self.builder) + step = language.core._to_tensor(step, self.builder) + # induction variable type + if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): + raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") + iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype) + iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype) + iv_ir_type = iv_type.to_ir(self.builder) + iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED + # lb/ub/step might be constexpr, we need to cast them to tensor + lb = lb.handle + ub = ub.handle + step = step.handle + # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index + lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed) + ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed) + step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed) + # Create placeholder for the loop induction variable + iv = self.builder.create_undef(iv_ir_type) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # create loop body block + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) + # dry visit loop body + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + block.erase() + + # If a variable (name) is defined in both its parent & itself, then it's + # a loop-carried variable. (They must be of the same type) + init_args = [] + yields = [] + names = [] + for name in self.local_defs: + if name in liveins: + assert _is_triton_tensor(self.local_defs[name]), f'{name} is not tensor' + assert _is_triton_tensor(liveins[name]) + assert self.local_defs[name].type == liveins[name].type, \ + f'Loop-carried variable {name} has initial type {liveins[name].type} '\ + f'but is re-assigned to {self.local_defs[name].type} in loop! '\ + f'Please make sure that the type stays consistent.' + + names.append(name) + init_args.append(language.core._to_tensor(liveins[name], self.builder)) + yields.append(language.core._to_tensor(self.local_defs[name], self.builder)) + + # create ForOp + self._set_insertion_point_and_loc(ip, last_loc) + for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) + if num_stages is not None: + for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) + + self.scf_stack.append(node) + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + # reset local scope to not pick up local defs from the previous dry run. + self.lscope = liveins.copy() + self.local_defs = {} + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type)) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + yields = [] + for name in self.local_defs: + if name in liveins: + yields.append(language.core._to_tensor(self.local_defs[name], self.builder)) + + # create YieldOp + if len(yields) > 0: + self.builder.create_yield_op([y.handle for y in yields]) + for_op_region = for_op.get_body(0).get_parent() + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + + # update induction variable with actual value, and replace all uses + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + iv = for_op.get_induction_var() + if negative_step: + iv = self.builder.create_sub(ub, iv) + iv = self.builder.create_add(iv, lb) + self.lscope[node.target.id].handle.replace_all_uses_with(iv) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + # update lscope & local_defs (ForOp defines new values) + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_result(i), yields[i].type)) + + for stmt in node.orelse: + assert False, "Don't know what to do with else after for" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Slice(self, node): + lower = self.visit(node.lower) + upper = self.visit(node.upper) + step = self.visit(node.step) + return slice(lower, upper, step) + + def visit_Index(self, node): + return self.visit(node.value) + + def visit_keyword(self, node) -> Tuple[str, Any]: + return node.arg, self.visit(node.value) + + def visit_Assert(self, node) -> Any: + if not self.debug: + return + test = self.visit(node.test) + msg = self.visit(node.msg) if node.msg is not None else "" + # Convert assert to triton's device_assert which happens on the device + return language.core.device_assert(test, msg, _builder=self.builder) + + def call_JitFunction(self, fn: JITFunction, args, kwargs): + args = inspect.getcallargs(fn.fn, *args, **kwargs) + args = [args[name] for name in fn.arg_names] + args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args] + # generate function def + attributes = dict() + constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in args if arg is not None] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + prototype = language.function_type([], arg_types) + gscope = fn.__globals__ + # If the callee is not set, we use the same debug setting as the caller + file_name, begin_line = _get_fn_file_line(fn) + debug = self.debug if fn.debug is None else fn.debug + generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, + jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, + noinline=fn.noinline, file_name=file_name, begin_line=begin_line, + options=self.builder.options, codegen_fns=self.builder.codegen_fns, debug=debug) + try: + generator.visit(fn.parse()) + except Exception as e: + # Wrap the error in the callee with the location of the call. + raise CompilationError(self.jit_fn.src, self.cur_node, None) from e + + callee_ret_type = generator.ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + call_op = self.builder.call(symbol, arg_vals) + if call_op.get_num_results() == 0 or callee_ret_type is None: + return None + elif call_op.get_num_results() == 1: + return tensor(call_op.get_result(0), callee_ret_type) + else: + # should return a tuple of tl.tensor + results = [] + for i in range(call_op.get_num_results()): + results.append(tensor(call_op.get_result(i), callee_ret_type[i])) + return tuple(results) + + def visit_Call(self, node): + fn = _unwrap_if_constexpr(self.visit(node.func)) + static_implementation = self.statically_implemented_functions.get(fn) + if static_implementation is not None: + return static_implementation(self, node) + + kws = dict(self.visit(keyword) for keyword in node.keywords) + args = [self.visit(arg) for arg in node.args] + if fn is language.core.device_assert: # TODO: this should not be so hardcoded + if not self.debug: + return + if isinstance(fn, JITFunction): + _check_fn_args(node, fn, args) + return self.call_JitFunction(fn, args, kws) + if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn): + extra_kwargs = dict(_builder=self.builder) + sig = inspect.signature(fn) + if '_generator' in sig.parameters: + extra_kwargs['_generator'] = self + try: + return fn(*args, **extra_kwargs, **kws) + except Exception as e: + # Normally when we raise a CompilationError, we raise it as + # `from None`, because the original fileline from the exception + # is not relevant (and often points into code_generator.py + # itself). But when calling a function, we raise as `from e` to + # preserve the traceback of the original error, which may e.g. + # be in core.py. + raise CompilationError(self.jit_fn.src, node, None) from e + + if fn in self.builtin_namespace.values(): + args = map(_unwrap_if_constexpr, args) + return fn(*args, **kws) + + def visit_Constant(self, node): + return constexpr(node.value) + + def visit_BoolOp(self, node: ast.BoolOp): + if len(node.values) != 2: + raise self._unsupported( + node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") + lhs = self.visit(node.values[0]) + rhs = self.visit(node.values[1]) + method_name = self._method_name_for_bool_op.get(type(node.op)) + if method_name is None: + raise self._unsupported( + node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} + + if sys.version_info < (3, 8): + + def visit_NameConstant(self, node): + return constexpr(node.value) + + def visit_Num(self, node): + return constexpr(node.n) + + def visit_Str(self, node): + return constexpr(ast.literal_eval(node)) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + if _is_triton_tensor(lhs): + if node.attr == "T": + return language.semantic.permute(lhs, (1, 0), builder=self.builder) + return getattr(lhs, node.attr) + + def visit_Expr(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_NoneType(self, node): + return None + + def visit_JoinedStr(self, node): + values = list(node.values) + for i, value in enumerate(values): + if isinstance(value, ast.Constant): + values[i] = str(value.value) + elif isinstance(value, ast.FormattedValue): + conversion_code = value.conversion + evaluated = self.visit(value.value) + if not _is_constexpr(evaluated): + raise self._unsupported( + node, + "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + + str(type(evaluated))) + values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) + else: + raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) + return ''.join(values) + + def visit(self, node): + if node is None: + return + with warnings.catch_warnings(): + # The ast library added visit_Constant and deprecated some other + # methods but we can't move to that without breaking Python 3.6 and 3.7. + warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 + warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 + last_node = self.cur_node + last_loc = self.builder.get_loc() + self.cur_node = node + if hasattr(node, 'lineno') and hasattr(node, 'col_offset'): + self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset) + last_loc = self.builder.get_loc() + try: + ret = super().visit(node) + except CompilationError: + raise + except Exception as e: + # Wrap the error in a CompilationError which contains the source + # of the @jit function. + raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None + + # Reset the location to the last one before the visit + if last_loc: + self.cur_node = last_node + self.builder.set_loc(last_loc) + return ret + + def generic_visit(self, node): + raise self._unsupported(node, "unsupported AST node type: {}".format(type(node).__name__)) + + def execute_static_assert(self, node: ast.Call) -> None: + arg_count = len(node.args) + if not (0 < arg_count <= 2) or len(node.keywords): + raise TypeError("`static_assert` requires one or two positional arguments only") + + passed = _unwrap_if_constexpr(self.visit(node.args[0])) + if not isinstance(passed, bool): + raise NotImplementedError( + "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values" + ) + if not passed: + if arg_count == 1: + message = "" + else: + try: + message = self.visit(node.args[1]) + except Exception as e: + message = "" + + raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message)) + return None + + def static_executor(python_fn): + + def ret(self, node: ast.Call): + kws = { + name: _unwrap_if_constexpr(value) + for name, value in (self.visit(keyword) for keyword in node.keywords) + } + args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args] + return constexpr(python_fn(*args, **kws)) + + return ret + + statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = { + language.core.static_assert: execute_static_assert, + language.core.static_print: static_executor(print), + int: static_executor(int), + len: static_executor(len), + } + + +def kernel_suffix(signature, specialization): + # suffix format: + # <'c' if equal to 1><'d' if divisible by 16><'e' if divisible by 8> + suffix = '' + for i, _ in enumerate(signature): + suffix += str(i) + if i in specialization.equal_to_1: + suffix += 'c' + if i in specialization.divisible_by_16: + suffix += 'd' + return suffix + + +def ast_to_ttir(fn, specialization, context, options, codegen_fns): + attrs = specialization.attrs + # create kernel prototype + cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in specialization.constants.items()} + # visit kernel AST + gscope = fn.__globals__.copy() + function_name = fn.repr(specialization) + tys = list(specialization.signature.values()) + new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1} + new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16} + + all_constants = constants.copy() + all_constants.update(new_constants) + arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] + file_name, begin_line = _get_fn_file_line(fn) + + prototype = language.function_type([], arg_types) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, + jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name, + begin_line=begin_line, options=options, codegen_fns=codegen_fns) + generator.visit(fn.parse()) + + ret = generator.module + # module takes ownership of the context + ret.context = context + return ret diff --git a/third_party/mthreads/python/triton/compiler/compiler.py b/third_party/mthreads/python/triton/compiler/compiler.py new file mode 100644 index 000000000..226360fd7 --- /dev/null +++ b/third_party/mthreads/python/triton/compiler/compiler.py @@ -0,0 +1,475 @@ +from __future__ import annotations +import hashlib +import json +from .._C.libtriton import get_cache_invalidating_env_vars, ir +from ..backends import backends +from ..backends.compiler import GPUTarget +from .. import __version__ +from ..runtime.jit import JITFunction +from ..runtime.autotuner import OutOfResources +from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager +from ..runtime.driver import driver +# TODO: this shouldn't be here +from dataclasses import dataclass +from .code_generator import ast_to_ttir +from pathlib import Path +import re +import functools +import os + + +@dataclass +class AttrsDescriptor: + divisible_by_16: set = None + equal_to_1: set = None + + def __post_init__(self): + if self.divisible_by_16 is None: + self.divisible_by_16 = set() + if self.equal_to_1 is None: + self.equal_to_1 = set() + + def to_dict(self): + return {'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1)} + + @staticmethod + def from_dict(data): + return AttrsDescriptor(divisible_by_16=set(data.get('divisible_by_16', [])), + equal_to_1=set(data.get('equal_to_1', []))) + + def hash(self): + key = str([sorted(x) for x in self.__dict__.values()]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, +# and any following whitespace +# - (public\s+)? : optionally match the keyword public and any following whitespace +# - (@\w+) : match an @ symbol followed by one or more word characters +# (letters, digits, or underscores), and capture it as group 1 (the function name) +# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing +# zero or more arguments separated by commas, and capture it as group 2 (the argument list) +# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 +mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" +ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" +prototype_pattern = { + "ttir": mlir_prototype_pattern, + "ttgir": mlir_prototype_pattern, + "ptx": ptx_prototype_pattern, +} + +mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+),?' +ptx_arg_type_pattern = r"\.param\s+\.(\w+)" +arg_type_pattern = { + "ttir": mlir_arg_type_pattern, + "ttgir": mlir_arg_type_pattern, + "ptx": ptx_arg_type_pattern, +} + + +def convert_type_repr(x): + # Currently we only capture the pointer type and assume the pointer is on global memory. + # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) + if match is not None: + return '*' + convert_type_repr(match.group(1)) + return x + + +def _get_num_warps_from_ir_str(src: str): + ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' + # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if + # e.g. someone has an instruction (not module) attribute named "num-warps". + num_warps_matches = re.findall(ttgir_num_warps_pattern, src) + assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" + num_warps = int(num_warps_matches[0]) + return num_warps + + +class ASTSource: + + def __init__(self, fn, signature, constants=None, attrs=None) -> None: + self.fn = fn + self.ext = "ttir" + self.name = fn.__name__ + self.signature = signature + self.constants = constants + self.attrs = attrs + if isinstance(self.signature, str): + self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + if self.constants is None: + self.constants = dict() + if self.attrs is None: + self.attrs = AttrsDescriptor() + + def hash(self): + sorted_sig = [v for k, v in sorted(self.signature.items())] + # Note - we stringify the keys here to allow sorting to work for cases + # where constants have mixed int/str keys. + sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) + key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, context): + return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns) + + def parse_options(self): + return dict() + + +class IRSource: + + def __init__(self, path): + self.path = path + path = Path(path) + self.ext = path.suffix[1:] + self.src = path.read_text() + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + + def hash(self): + return hashlib.sha256(self.src.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, context): + module = ir.parse_mlir_module(self.path, context) + module.context = context + return module + + def parse_options(self): + if self.ext == "ttgir": + return {'num_warps': _get_num_warps_from_ir_str(self.src)} + return dict() + + +@functools.lru_cache() +def triton_key(): + import pkgutil + TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + # compiler + path_prefixes = [ + (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."), + (os.path.join(TRITON_PATH, "backends"), "triton.backends."), + ] + for path, prefix in path_prefixes: + for lib in pkgutil.walk_packages([path], prefix=prefix): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + + # backend + libtriton_hash = hashlib.sha256() + with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + libtriton_hash.update(chunk) + contents.append(libtriton_hash.hexdigest()) + # language + language_path = os.path.join(TRITON_PATH, 'language') + for lib in pkgutil.iter_modules([language_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + return f'{__version__}' + '-'.join(contents) + + +def parse(full_name, ext, context): + if ext == "ttir" or ext == "ttgir": + module = ir.parse_mlir_module(full_name, context) + module.context = context + return module + if ext == "llir" or ext == "ptx": + return Path(full_name).read_text() + if ext == "cubin": + return Path(full_name).read_bytes() + + +def filter_traceback(e: BaseException): + """ + Removes code_generator.py and related files from tracebacks. + + These are uninteresting to the user -- "just show me *my* code!" + """ + if e.__cause__ is not None: + filter_traceback(e.__cause__) + if e.__context__ is not None: + filter_traceback(e.__context__) + + # If a user has a file that matches one of these, they're out of luck. + BAD_FILES = [ + "/triton/compiler/code_generator.py", + "/ast.py", + ] + + tb = e.__traceback__ + frames = [] + while tb is not None: + if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)): + frames.append(tb) + tb = tb.tb_next + + for (cur_frame, next_frame) in zip(frames, frames[1:]): + cur_frame.tb_next = next_frame + + if not frames: + e.__traceback__ = None + else: + frames[-1].tb_next = None + e.__traceback__ = frames[0] + + +def adaptConfigToTorch_2_2(fn, **kwargs): + # TODO: Get correct cc here. + try: + from torch._dynamo.device_interface import CudaInterface + device = CudaInterface.device + cc = CudaInterface.get_compute_capability(device) + except: + cc = kwargs["cc"] + attrs = AttrsDescriptor( + divisible_by_16=kwargs["configs"][0].divisible_by_16, + equal_to_1=kwargs["configs"][0].equal_to_1, + ) + + src = ASTSource(fn, kwargs["signature"], kwargs["constants"], attrs) + target = GPUTarget(kwargs["device_type"], cc, 128 + # TODO: Warp size is hard-code now. + # rocm_warp_size if torch.version.hip else 32, + ) + options = { + "num_warps": kwargs["num_warps"], + "num_stages": kwargs["num_stages"], + "debug": kwargs["debug"], + } + + return src, target, options + + +def compile(fn, **kwargs): + # Adapt Triton3.0.0 to torch2.2.0. + if isinstance(fn, JITFunction): + src, target, options = adaptConfigToTorch_2_2(fn, **kwargs) + else: + src = fn + target = kwargs["target"] + options = kwargs["options"] + + if target is None: + target = driver.active.get_current_target() + assert isinstance(target, GPUTarget), "target must be of GPUTarget type" + backend = make_backend(target) + ir_source = not isinstance(src, ASTSource) + # create backend + if ir_source: + assert isinstance(src, str), "source must be either AST or a filepath" + src = IRSource(src) + extra_options = src.parse_options() + options = backend.parse_options(dict(options or dict(), **extra_options)) + # create cache manager + env_vars = get_cache_invalidating_env_vars() + key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" + hash = hashlib.sha256(key.encode("utf-8")).hexdigest() + fn_cache_manager = get_cache_manager(hash) + # For dumping/overriding only hash the source as we want it to be independent of triton + # core changes to make it easier to track kernels by hash. + enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" + enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1" + fn_override_manager = get_override_manager(src.hash()) if enable_override else None + fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None + metadata_filename = f"{src.name}.json" + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = metadata_group.get(metadata_filename) + always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1" + if not always_compile and metadata_path is not None: + # cache hit! + metadata = json.loads(Path(metadata_path).read_text()) + return CompiledKernel(src, metadata_group, hash) + # initialize metadata + metadata = { + "hash": hash, + "target": target, + **options.__dict__, + **env_vars, + } + # run compilation pipeline and populate metadata + stages = dict() + backend.add_stages(stages, options) + first_stage = list(stages.keys()).index(src.ext) + # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. + if ir_source: + first_stage += 1 + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + codegen_fns = backend.get_codegen_implementation() + try: + module = src.make_ir(options, codegen_fns, context) + except Exception as e: + filter_traceback(e) + raise + use_ttgir_loc = os.environ.get("USE_TTGIR_LOC", "0") == "1" + for ext, compile_ir in list(stages.items())[first_stage:]: + next_module = compile_ir(module, metadata) + ir_filename = f"{src.name}.{ext}" + # Adapt to musa backend. + if ext == "mubin": + metadata_group[ir_filename] = next_module[1] # Get mubin path. + else: + metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) + if fn_dump_manager is not None: + fn_dump_manager.put(next_module, ir_filename) + if (fn_override_manager is not None and fn_override_manager.has_file(ir_filename)): + print(f"\nOverriding kernel with file {ir_filename}") + full_name = fn_override_manager.get_file(ir_filename) + next_module = parse(full_name, ext, context) + # use an env variable to parse ttgir from file + if use_ttgir_loc and ext == "ttgir": + ttgir_full_name = fn_cache_manager.get_file(ir_filename) + next_module.create_location_snapshot(ttgir_full_name) + print(f"Create new locations for {ttgir_full_name}") + module = next_module + # write-back metadata + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) + # return handle to compiled kernel + return CompiledKernel(src, metadata_group, hash) + + +def make_backend(target): + actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)] + if len(actives) != 1: + raise RuntimeError( + f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.") + return actives[0](target) + + +class LazyDict: + + def __init__(self, data): + self.data = data + self.extras = [] + + def get(self) -> None: + for func, args in self.extras: + self.data = self.data | func(*args) + self.extras.clear() + return self.data + + def add(self, func, args): + self.extras.append((func, args)) + + +class CompiledKernel: + + # Hooks for external tools to monitor the execution of triton kernels + # TODO: move out of this namespace since it's a runtime thing + launch_enter_hook = None + launch_exit_hook = None + + def __init__(self, src, metadata_group, hash): + from collections import namedtuple + metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) + metadata = json.loads(metadata_path.read_text()) + metadata['cluster_dims'] = tuple(metadata['cluster_dims']) + # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. + target = metadata['target'] + metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size']) + KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys()))) + self.metadata = KernelMetadata(**metadata) + backend = make_backend(self.metadata.target) + self.packed_metadata = backend.pack_metadata(self.metadata) + self.src = src + self.hash = hash + self.name = self.metadata.name + # stores the text of each level of IR that was generated during compilation + asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] + binary_ext = backend.binary_ext + self.asm = { + file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text() + for file in asm_files + } + self.kernel = self.asm[binary_ext] + # binaries are lazily initialized + # because it involves doing runtime things + # (e.g., checking amount of shared memory on current device) + self.module = None + self.function = None + + # Adapt to torch2.2. + self.num_warps = self.packed_metadata[0] + self.num_ctas = self.packed_metadata[1] + self.shared = self.packed_metadata[2] + self.clusterDims = self.packed_metadata[3:6] + + def _init_handles(self): + + def adaptToTorch_2_2Wrapper(fn): + + def wrapper(grid_0, grid_1, grid_2, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared, + stream, cu_function, launch_metadata, launch_enter_hook, launch_exit_hook, *args): + + metadata = (num_warps, num_ctas, shared, clusterDimX, clusterDimY, clusterDimZ) + + return fn(grid_0, grid_1, grid_2, stream, cu_function, metadata, launch_metadata, launch_enter_hook, + launch_exit_hook, *args) + + return wrapper + + if self.module is not None: + return + device = driver.active.get_current_device() + # create launcher + self._run = driver.active.launcher_cls(self.src, self.metadata) + self.c_wrapper = adaptToTorch_2_2Wrapper(self._run) + self.run = self._run + # not enough shared memory to run the kernel + max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] + if self.metadata.shared > max_shared: + raise OutOfResources(self.metadata.shared, max_shared, "shared memory") + # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` + self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( + self.name, self.kernel, self.metadata.shared, device) + + self.cu_function = self.function + + def __getattribute__(self, name): + if name in ["c_wrapper", "run"]: + self._init_handles() + return super().__getattribute__(name) + + def launch_metadata(self, grid, stream, *args): + if CompiledKernel.launch_enter_hook is None: + return None + ret = LazyDict({"name": self.name, "function": self.function, "stream": stream}) + if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None: + return ret + arg_dict = {} + arg_idx = 0 + for i, arg_name in enumerate(self.src.fn.arg_names): + if i in self.src.fn.constexprs: + arg_dict[arg_name] = self.src.constants[arg_name] + else: + arg_dict[arg_name] = args[arg_idx] + arg_idx += 1 + ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict)) + return ret + + def __getitem__(self, grid): + self._init_handles() + + def runner(*args, stream=None): + if stream is None: + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + launch_metadata = self.launch_metadata(grid, stream, *args) + self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata, + CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args) + + return runner diff --git a/third_party/mthreads/python/triton/compiler/errors.py b/third_party/mthreads/python/triton/compiler/errors.py new file mode 100644 index 000000000..39e6c4dfb --- /dev/null +++ b/third_party/mthreads/python/triton/compiler/errors.py @@ -0,0 +1,51 @@ +import ast +from typing import Optional +from ..errors import TritonError + + +class CompilationError(TritonError): + """Base class for all errors raised during compilation""" + source_line_count_max_in_message = 12 + + def _format_message(self) -> str: + node = self.node + if self.src is None: + source_excerpt = " " + else: + if hasattr(node, 'lineno'): + source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:] + if source_excerpt: + source_excerpt.append(' ' * node.col_offset + '^') + source_excerpt = '\n'.join(source_excerpt) + else: + source_excerpt = " " + else: + source_excerpt = self.src + + message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, source_excerpt) if hasattr( + node, 'lineno') else source_excerpt + if self.error_message: + message += '\n' + self.error_message + return message + + def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None): + self.src = src + self.node = node + self.error_message = error_message + self.message = self._format_message() + + def __str__(self): + return self.message + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return type(self), (self.src, self.node, self.error_message) + + +class CompileTimeAssertionFailure(CompilationError): + """Specific exception for failed tests in `static_assert` invocations""" + pass + + +class UnsupportedLanguageConstruct(CompilationError): + pass diff --git a/third_party/mthreads/python/triton/compiler/make_launcher.py b/third_party/mthreads/python/triton/compiler/make_launcher.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/mthreads/python/triton/errors.py b/third_party/mthreads/python/triton/errors.py new file mode 100644 index 000000000..3a0a86355 --- /dev/null +++ b/third_party/mthreads/python/triton/errors.py @@ -0,0 +1,5 @@ +"""Base class for all errors raised by Triton""" + + +class TritonError(Exception): + ... diff --git a/third_party/mthreads/python/triton/language/__init__.py b/third_party/mthreads/python/triton/language/__init__.py new file mode 100644 index 000000000..168dccfea --- /dev/null +++ b/third_party/mthreads/python/triton/language/__init__.py @@ -0,0 +1,284 @@ +"""isort:skip_file""" +# Import order is significant here. + +from . import math +from . import extra +from .standard import ( + argmax, + argmin, + cdiv, + cumprod, + cumsum, + flip, + interleave, + max, + min, + ravel, + sigmoid, + softmax, + sort, + sum, + swizzle2d, + xor_sum, + zeros, + zeros_like, +) +from .core import ( + PropagateNan, + TRITON_MAX_TENSOR_NUMEL, + _experimental_descriptor_load, + _experimental_descriptor_store, + advance, + arange, + associative_scan, + atomic_add, + atomic_and, + atomic_cas, + atomic_max, + atomic_min, + atomic_or, + atomic_xchg, + atomic_xor, + bfloat16, + block_type, + broadcast, + broadcast_to, + cat, + cast, + clamp, + const, + const_pointer_type, + constexpr, + debug_barrier, + device_assert, + device_print, + dot, + dtype, + expand_dims, + float16, + float32, + float64, + float8e4b15, + float8e4nv, + float8e4b8, + float8e5, + float8e5b16, + full, + function_type, + histogram, + inline_asm_elementwise, + int1, + int16, + int32, + int64, + int8, + join, + load, + make_block_ptr, + max_constancy, + max_contiguous, + maximum, + minimum, + multiple_of, + num_programs, + permute, + pi32_t, + pointer_type, + program_id, + range, + reduce, + reshape, + split, + static_assert, + static_print, + static_range, + store, + tensor, + trans, + uint16, + uint32, + uint64, + uint8, + view, + void, + where, +) +from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, + ceil) +from .random import ( + pair_uniform_to_normal, + philox, + philox_impl, + rand, + rand4x, + randint, + randint4x, + randn, + randn4x, + uint_to_uniform_float, +) + +__all__ = [ + "PropagateNan", + "TRITON_MAX_TENSOR_NUMEL", + "_experimental_descriptor_load", + "_experimental_descriptor_store", + "abs", + "advance", + "arange", + "argmax", + "argmin", + "associative_scan", + "atomic_add", + "atomic_and", + "atomic_cas", + "atomic_max", + "atomic_min", + "atomic_or", + "atomic_xchg", + "atomic_xor", + "bfloat16", + "block_type", + "broadcast", + "broadcast_to", + "builtin", + "cat", + "cast", + "cdiv", + "ceil", + "clamp", + "const", + "const_pointer_type", + "constexpr", + "cos", + "cumprod", + "cumsum", + "debug_barrier", + "device_assert", + "device_print", + "div_rn", + "dot", + "dtype", + "erf", + "exp", + "exp2", + "expand_dims", + "extra", + "fdiv", + "flip", + "float16", + "float32", + "float64", + "float8e4b15", + "float8e4nv", + "float8e4b8", + "float8e5", + "float8e5b16", + "floor", + "fma", + "full", + "function_type", + "histogram", + "inline_asm_elementwise", + "interleave", + "int1", + "int16", + "int32", + "int64", + "int8", + "ir", + "join", + "load", + "log", + "log2", + "make_block_ptr", + "math", + "max", + "max_constancy", + "max_contiguous", + "maximum", + "min", + "minimum", + "multiple_of", + "num_programs", + "pair_uniform_to_normal", + "permute", + "philox", + "philox_impl", + "pi32_t", + "pointer_type", + "program_id", + "rand", + "rand4x", + "randint", + "randint4x", + "randn", + "randn4x", + "range", + "ravel", + "reduce", + "reshape", + "rsqrt", + "sigmoid", + "sin", + "softmax", + "sort", + "split", + "sqrt", + "sqrt_rn", + "static_assert", + "static_print", + "static_range", + "store", + "sum", + "swizzle2d", + "tensor", + "trans", + "triton", + "uint16", + "uint32", + "uint64", + "uint8", + "uint_to_uniform_float", + "umulhi", + "view", + "void", + "where", + "xor_sum", + "zeros", + "zeros_like", +] + + +def str_to_ty(name): + if name[0] == "*": + name = name[1:] + if name[0] == "k": + name = name[1:] + ty = str_to_ty(name) + return const_pointer_type(ty) + ty = str_to_ty(name) + return pointer_type(ty) + tys = { + "fp8e4nv": float8e4nv, + "fp8e4b8": float8e4b8, + "fp8e5": float8e5, + "fp8e5b16": float8e5b16, + "fp8e4b15": float8e4b15, + "fp16": float16, + "bf16": bfloat16, + "fp32": float32, + "fp64": float64, + "i1": int1, + "i8": int8, + "i16": int16, + "i32": int32, + "i64": int64, + "u1": int1, + "u8": uint8, + "u16": uint16, + "u32": uint32, + "u64": uint64, + "B": int1, + } + return tys[name] diff --git a/third_party/mthreads/python/triton/language/core.py b/third_party/mthreads/python/triton/language/core.py new file mode 100644 index 000000000..f2d3266e9 --- /dev/null +++ b/third_party/mthreads/python/triton/language/core.py @@ -0,0 +1,2621 @@ +from __future__ import annotations + +from warnings import warn +from contextlib import contextmanager +from enum import Enum +from functools import partial, wraps +import typing +from typing import Union, Callable, List, Sequence, TypeVar, Optional +import builtins +from ..runtime.jit import jit +import inspect +import os + +from .._C.libtriton import ir +from . import semantic + +T = TypeVar('T') + +TRITON_MAX_TENSOR_NUMEL = 1048576 + +TRITON_BUILTIN = "__triton_builtin__" + +PropagateNan = ir.PROPAGATE_NAN + + +def builtin(fn: T) -> T: + """Mark a function as a builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + raise ValueError("Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + setattr(wrapper, TRITON_BUILTIN, True) + + return wrapper + + +def _tensor_member_fn(fn: T) -> T: + """Decorator that adds this free function as a member fn on class tensor. + + When called as a member function on class tensor, the first argument to `fn` + is `self`, i.e. the tensor object. + + If there are multiple decorators on a function, you probably want this one + to be the highest one (i.e. furthest from the function's `def`), so it's + applied last. + + Unfortunately you still need to add a type stub to the body of class tensor + in order for pytype to know about it. + """ + assert callable(fn) + orig_sig = inspect.signature(fn) + # Does fn take args other than _builder, _generator, and the tensor itself? + has_args = len(orig_sig.parameters.keys() - {"_builder", "_generator"}) > 1 + + if not fn.__doc__: + fn.__doc__ = "" + fn.__doc__ += f""" + This function can also be called as a member function on :py:class:`tensor`, + as :code:`x.{fn.__name__}({"..." if has_args else ""})` instead of + :code:`{fn.__name__}(x{", ..." if has_args else ""})`. + """ + + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + # Match the signature of `fn`, but change the first arg to `self` so the + # docs are a little less weird. + new_params = list(orig_sig.parameters.values()) + new_params[0] = new_params[0].replace(name='self') + new_sig = orig_sig.replace(parameters=new_params) + wrapper.__signature__ = new_sig + wrapper.__doc__ = f"Forwards to :py:func:`{fn.__name__}` free function" + # If fn is a builtin, mark the wrapper as a builtin too. + if is_builtin(fn): + setattr(wrapper, TRITON_BUILTIN, True) + + setattr(tensor, fn.__name__, wrapper) + return fn + + +def _unwrap_iterable(x): + """Returns x[0] if x has one element and x[0] is iterable.""" + if len(x) == 1: + # Determine whether x[0] is iterable. + # + # You might want to use collections.abc.Iterable instead of this + # try/except block. Unfortunately, this doesn't work with constexpr. + # + # The problem is that abc.Iterable checks for __iter__ on the *class*. + # But we want constexpr to expose an __iter__ method if and only if the + # wrapped *object* (i.e. self.value) is iterable. Therefore there's no + # right answer for whether the class constexpr defines __iter__, and + # abc.Iterable doesn't work (at least not without some metaclass magic). + try: + iter(x[0]) + return x[0] + except TypeError: + pass + + return x + + +def is_builtin(fn) -> bool: + """Is this a registered triton builtin function?""" + return getattr(fn, TRITON_BUILTIN, False) + + +@builtin +def to_tensor(x, _builder=None): + return _to_tensor(x, _builder) + + +def _to_tensor(x, builder): + if isinstance(x, bool): + return tensor(builder.get_int1(x), int1) + # Note: compile-time const integers are represented by unsigned values + elif isinstance(x, int): + if -2**31 <= x < 2**31: + return tensor(builder.get_int32(x), int32) + elif 2**31 <= x < 2**32: + return tensor(builder.get_uint32(x), uint32) + elif -2**63 <= x < 2**63: + return tensor(builder.get_int64(x), int64) + elif 2**63 <= x < 2**64: + return tensor(builder.get_uint64(x), uint64) + else: + raise RuntimeError(f'Nonrepresentable integer {x}.') + elif isinstance(x, float): + min_float32 = 2**-126 + max_float32 = (2 - 2**-23) * 2**127 + abs_x = __builtins__['abs'](x) + if abs_x == float("inf") or\ + abs_x == 0.0 or \ + x != x or \ + min_float32 <= abs_x <= max_float32: + return tensor(builder.get_fp32(x), float32) + else: + return tensor(builder.get_fp64(x), float64) + + elif isinstance(x, constexpr): + return _to_tensor(x.value, builder) + elif isinstance(x, tensor): + return x + assert False, f"cannot convert {x} of type {type(x)} to tensor" + + +class dtype: + SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64'] + STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + def __init__(self, name): + if hasattr(name, 'value'): + name = name.value + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8e4b15': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e4nv': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 7 + elif name == 'fp8e4b8': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 8 + elif name == 'fp8e5': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e5b16': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 16 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.primitive_bitwidth = 16 + self.exponent_bias = 15 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.primitive_bitwidth = 16 + self.exponent_bias = 127 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.primitive_bitwidth = 32 + self.exponent_bias = 127 + elif name == 'fp64': + self.fp_mantissa_width = 53 + self.primitive_bitwidth = 64 + self.exponent_bias = 1023 + else: + raise RuntimeError(f'Unsupported floating-point type {name}') + elif name == 'void': + self.primitive_bitwidth = 0 + + def is_fp8(self): + return 'fp8' in self.name + + def is_fp8e4nv(self): + return self.name == 'fp8e4nv' + + def is_fp8e4b8(self): + return self.name == 'fp8e4b8' + + def is_fp8e4b15(self): + return self.name == 'fp8e4b15' + + def is_fp8e5(self): + return self.name == 'fp8e5' + + def is_fp8e5b16(self): + return self.name == 'fp8e5b16' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_standard_floating(self): + return self.name in dtype.STANDARD_FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int_unsigned(self): + return self.name in dtype.UINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + @staticmethod + def is_dtype(type_str): + return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES + + @staticmethod + def is_void(): + raise RuntimeError("Not implemented") + + @staticmethod + def is_block(): + return False + + @staticmethod + def is_ptr(): + return False + + @staticmethod + def is_const(): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __ne__(self, other: dtype): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.name, )) + + @property + def scalar(self): + return self + + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + def __str__(self): + return self.name + + def codegen_name(self): + if self.name.startswith("fp"): + return "float" + self.name[2:] + elif self.name.startswith("bf"): + return "bfloat" + self.name[2:] + else: + return self.name + + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + """Output of repr needs to be an evaluatable expression""" + return f'triton.language.{self.codegen_name()}' + + +# Some functions have a param named `dtype`, which shadows the `dtype` class. +# We can't change the param name because it is part of function's public API. +# Declare an alias so those functions can still reference the dtype class. +_DtypeClass = dtype + + +class pointer_type(dtype): + + def __init__(self, element_ty: dtype, address_space: int = 1): + if not isinstance(element_ty, dtype): + raise TypeError(f'element_ty is a {type(element_ty).__name__}.') + self.element_ty = element_ty + self.address_space = address_space + + self.name = f'pointer<{element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.pointer_type: + return builder.get_ptr_ty(self.element_ty.to_ir(builder), 1) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_ptr(self): + return True + + def __eq__(self, other: pointer_type) -> bool: + if not isinstance(other, pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space + + def __ne__(self, other: pointer_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self + + +class const_pointer_type(pointer_type): + + def __init__(self, element_ty: dtype, address_space: int = 1): + super().__init__(element_ty, address_space) + + def __str__(self): + return f'const_pointer<{self.element_ty}>' + + def is_const(self): + return True + + def __eq__(self, other) -> bool: + if not isinstance(other, const_pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space + + +class block_type(dtype): + + def __init__(self, element_ty: dtype, shape: List): + self.element_ty = element_ty + + # Note that block_type's shape is a list of int + # while tensor's shape is a list of constexpr. + + # shape can be empty ([]) when an input is a 0D tensor. + if not shape: + raise TypeError('0d block_type is forbidden') + if isinstance(shape[0], constexpr): + shape = [s.value for s in shape] + + self.shape = shape + self.numel = 1 + for s in self.shape: + self.numel *= s + if self.numel > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") + + self.name = f'<{self.shape}, {self.element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> List[int]: + return self.shape + + def __eq__(self, other: block_type) -> bool: + if not isinstance(other, block_type): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + def __ne__(self, other: block_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self.element_ty + + +class function_type(dtype): + + def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: + self.ret_types = ret_types + self.param_types = param_types + + def __str__(self): + return f'fn ({self.param_types}) -> {self.ret_types}' + + def to_ir(self, builder: ir.builder): + ir_param_types = [ty.to_ir(builder) for ty in self.param_types] + ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] + return builder.get_function_ty(ir_param_types, ret_types) + + +# scalar types +void = dtype('void') +int1 = dtype('int1') +int8 = dtype('int8') +int16 = dtype('int16') +int32 = dtype('int32') +int64 = dtype('int64') +uint8 = dtype('uint8') +uint16 = dtype('uint16') +uint32 = dtype('uint32') +uint64 = dtype('uint64') +float8e5 = dtype('fp8e5') +float8e5b16 = dtype('fp8e5b16') +float8e4nv = dtype('fp8e4nv') +float8e4b8 = dtype('fp8e4b8') +float8e4b15 = dtype('fp8e4b15') +float16 = dtype('fp16') +bfloat16 = dtype('bf16') +float32 = dtype('fp32') +float64 = dtype('fp64') +# pointer types +pi32_t = pointer_type(int32) + + +def get_int_dtype(bitwidth: int, signed: bool) -> dtype: + if bitwidth == 1: + return int1 + elif bitwidth == 8 and signed: + return int8 + elif bitwidth == 8 and not signed: + return uint8 + elif bitwidth == 16 and signed: + return int16 + elif bitwidth == 16 and not signed: + return uint16 + elif bitwidth == 32 and signed: + return int32 + elif bitwidth == 32 and not signed: + return uint32 + elif bitwidth == 64 and signed: + return int64 + elif bitwidth == 64 and not signed: + return uint64 + else: + raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}') + + +# ----------------------- +# constexpr +# ----------------------- + + +class const: + """ + This class is used as a type annotation to mark pointers to constant data. + The `store` function cannot be called with a pointer to const. Constness + is part of the pointer type and the usual Triton type consistency rules + apply. For example you cannot have a function that returns constant pointer + in one return statement and non-constant pointer in another. + """ + pass + + +class constexpr: + """ + This class is used to store a value that is known at compile-time. + """ + + def __init__(self, value): + if isinstance(value, constexpr): + self.value = value.value + else: + self.value = value + + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + def __index__(self): + return self.value + + # In interpreter mode, constant values are not wrapped in constexpr, + # and therefore do not have a .value attribute. + # As a result, from here and below, we need to call the _constexpr_to_value + # function to obtain either constexpr.value or the value itself. + def __add__(self, other): + return constexpr(self.value + _constexpr_to_value(other)) + + def __radd__(self, other): + return constexpr(_constexpr_to_value(other) + self.value) + + def __sub__(self, other): + return constexpr(self.value - _constexpr_to_value(other)) + + def __rsub__(self, other): + return constexpr(_constexpr_to_value(other) - self.value) + + def __mul__(self, other): + return constexpr(self.value * _constexpr_to_value(other)) + + def __mod__(self, other): + return constexpr(self.value % _constexpr_to_value(other)) + + def __rmul__(self, other): + return constexpr(_constexpr_to_value(other) * self.value) + + def __truediv__(self, other): + return constexpr(self.value / _constexpr_to_value(other)) + + def __rtruediv__(self, other): + return constexpr(_constexpr_to_value(other) / self.value) + + def __floordiv__(self, other): + return constexpr(self.value // _constexpr_to_value(other)) + + def __rfloordiv__(self, other): + return constexpr(_constexpr_to_value(other) // self.value) + + def __gt__(self, other): + return constexpr(self.value > _constexpr_to_value(other)) + + def __rgt__(self, other): + return constexpr(_constexpr_to_value(other) > self.value) + + def __ge__(self, other): + return constexpr(self.value >= _constexpr_to_value(other)) + + def __rge__(self, other): + return constexpr(_constexpr_to_value(other) >= self.value) + + def __lt__(self, other): + return constexpr(self.value < _constexpr_to_value(other)) + + def __rlt__(self, other): + return constexpr(_constexpr_to_value(other) < self.value) + + def __le__(self, other): + return constexpr(self.value <= _constexpr_to_value(other)) + + def __rle__(self, other): + return constexpr(_constexpr_to_value(other) <= self.value) + + def __eq__(self, other): + return constexpr(self.value == _constexpr_to_value(other)) + + def __ne__(self, other): + return constexpr(self.value != _constexpr_to_value(other)) + + def __bool__(self): + return bool(self.value) + + def __neg__(self): + return constexpr(-self.value) + + def __and__(self, other): + return constexpr(self.value & _constexpr_to_value(other)) + + def logical_and(self, other): + return constexpr(self.value and _constexpr_to_value(other)) + + def __or__(self, other): + return constexpr(self.value | _constexpr_to_value(other)) + + def __xor__(self, other): + return constexpr(self.value ^ _constexpr_to_value(other)) + + def logical_or(self, other): + return constexpr(self.value or _constexpr_to_value(other)) + + def __pos__(self): + return constexpr(+self.value) + + def __invert__(self): + return constexpr(~self.value) + + def __pow__(self, other): + return constexpr(self.value**_constexpr_to_value(other)) + + def __rpow__(self, other): + return constexpr(_constexpr_to_value(other)**self.value) + + def __rshift__(self, other): + return constexpr(self.value >> _constexpr_to_value(other)) + + def __lshift__(self, other): + return constexpr(self.value << _constexpr_to_value(other)) + + def __not__(self): + return constexpr(not self.value) + + def __iter__(self): + return iter(self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + + +CONSTEXPR_0 = constexpr(0) + + +def check_bit_width(value, shift_value): + if isinstance(value, tensor) and isinstance(shift_value, constexpr): + bitwidth = value.type.scalar.primitive_bitwidth + if shift_value.value >= bitwidth: + warn( + f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type '{value.dtype}'. This may result in undefined behavior." + ) + + +class tensor: + """Represents an N-dimensional array of values or pointers. + + :code:`tensor` is the fundamental data structure in Triton programs. Most + functions in :py:mod:`triton.language` operate on and return tensors. + + Most of the named member functions here are duplicates of the free functions + in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is + equivalent to :code:`x.sqrt()`. + + :code:`tensor` also defines most of the magic/dunder methods, so you can + write :code:`x+y`, :code:`x << 2`, etc. + + .. rubric:: Constructors + .. + For some reason Sphinx includes __init__ before printing the full table + of methods. Not what I want, but I can't figure out how to fix it. Give + it its own section so it looks intentional. :) + """ + + def __init__(self, handle, type: dtype): + """Not called by user code.""" + # IR handle + self.handle = handle + # Block shape + self.shape = type.shape if type.is_block() else () + self.numel = 1 + for s in self.shape: + self.numel *= s + self.numel = constexpr(self.numel) + self.type = type # Tensor type (can be block_type) + # Following the practice in pytorch, dtype is scalar type + self.dtype = type.scalar + self.shape = [constexpr(s) for s in self.shape] + + def __str__(self) -> str: + # ex. "float32[16, 32]" + return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']' + + @builtin + def __add__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.add(self, other, _builder) + + @builtin + def __radd__(self, other, _builder=None): + return self.__add__(other, _builder=_builder) + + @builtin + def __sub__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.sub(self, other, _builder) + + @builtin + def __rsub__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.sub(other, self, _builder) + + @builtin + def __mul__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mul(self, other, _builder) + + @builtin + def __rmul__(self, other, _builder=None): + return self.__mul__(other, _builder=_builder) + + @builtin + def __truediv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.truediv(self, other, _builder) + + @builtin + def __rtruediv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.truediv(other, self, _builder) + + @builtin + def __floordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(self, other, _builder) + + @builtin + def __rfloordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(other, self, _builder) + + @builtin + def __mod__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mod(self, other, _builder) + + @builtin + def __rmod__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mod(other, self, _builder) + + # unary operators + @builtin + def __neg__(self, _builder=None): + return semantic.minus(self, _builder) + + @builtin + def __invert__(self, _builder=None): + return semantic.invert(self, _builder) + + # bitwise operators + + @builtin + def __and__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.and_(self, other, _builder) + + @builtin + def __rand__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.and_(other, self, _builder) + + @builtin + def __or__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.or_(self, other, _builder) + + @builtin + def __ror__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.or_(other, self, _builder) + + @builtin + def __xor__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.xor_(self, other, _builder) + + @builtin + def __rxor__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.xor_(other, self, _builder) + + @builtin + def __lshift__(self, other, _builder=None): + check_bit_width(self, other) + other = _to_tensor(other, _builder) + return semantic.shl(self, other, _builder) + + @builtin + def __rlshift__(self, other, _builder=None): + check_bit_width(other, self) + other = _to_tensor(other, _builder) + return semantic.shl(other, self, _builder) + + @builtin + def __rshift__(self, other, _builder=None): + check_bit_width(self, other) + other = _to_tensor(other, _builder) + if self.dtype.is_int_signed(): + return semantic.ashr(self, other, _builder) + else: + return semantic.lshr(self, other, _builder) + + @builtin + def __rrshift__(self, other, _builder=None): + check_bit_width(other, self) + other = _to_tensor(other, _builder) + if self.dtype.is_int_signed(): + return semantic.ashr(other, self, _builder) + else: + return semantic.lshr(other, self, _builder) + + # > + @builtin + def __gt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_than(self, other, _builder) + + @builtin + def __rgt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_than(other, self, _builder) + + # >= + @builtin + def __ge__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_equal(self, other, _builder) + + @builtin + def __rge__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_equal(other, self, _builder) + + # < + @builtin + def __lt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_than(self, other, _builder) + + @builtin + def __rlt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_than(other, self, _builder) + + # <= + @builtin + def __le__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_equal(self, other, _builder) + + @builtin + def __rle__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_equal(other, self, _builder) + + # == + @builtin + def __eq__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.equal(self, other, _builder) + + @builtin + def __req__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.equal(other, self, _builder) + + @builtin + def __ne__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.not_equal(self, other, _builder) + + @builtin + def __rne__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.not_equal(other, self, _builder) + + @builtin + def logical_and(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.logical_and(self, other, _builder) + + @builtin + def logical_or(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.logical_or(self, other, _builder) + + # note: __not__ isn't actually a magic method in python + # but it's ok because our ASTVisitor handles it + @builtin + def __not__(self, _builder=None): + return semantic.not_(self, _builder) + + @builtin + def __getitem__(self, slices, _builder=None): + if isinstance(slices, (slice, constexpr)) or slices is None: + slices = [slices] + ret = self + for dim, sl in enumerate(slices): + if sl is None or isinstance(sl, constexpr) and sl.value is None: + ret = semantic.expand_dims(ret, dim, _builder) + elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None: + pass + else: + raise ValueError(f"unsupported tensor index: {sl}") + return ret + + @property + def T(self): + """Transposes a 2D tensor.""" + assert False, "Transposition must be created by the AST Visitor" + + @builtin + def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None): + """ + Alias for :py:func:`tensor.cast`. + """ + # Triton doesn't like core functions calling other core functions, so we + # just copy-paste the implementation of cast here. It's not too bad. + if isinstance(bitcast, constexpr): + bitcast = bitcast.value + if bitcast: + return semantic.bitcast(self, dtype, _builder) + return semantic.cast(self, dtype, _builder, fp_downcast_rounding) + + # Type stubs for functions added by the _tensor_member_fn decorator. + # (Unfortunately these can't be created automatically.) + # + # We couldn't write these definitions out even if we wanted to, because some + # of these functions are defined in standard.py. + def broadcast_to(self, *shape) -> tensor: + ... + + def trans(self, *dims) -> tensor: + ... + + def permute(self, *dims) -> tensor: + ... + + def split(self) -> tuple[tensor, tensor]: + ... + + def view(self, *shape) -> tensor: + ... + + def reshape(self, *shape) -> tensor: + ... + + def expand_dims(self, axis) -> tensor: + ... + + def cast(self, dtype, fp_downcast_rounding=None, bitcast=False) -> tensor: + ... + + def store(self, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="") -> tensor: + ... + + def advance(self, offsets) -> tensor: + ... + + def atomic_cas(self, cmp, val, sem=None, scope=None) -> tensor: + ... + + def atomic_xchg(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_add(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_max(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_min(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_and(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_or(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_xor(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def exp(self) -> tensor: + ... + + def log(self) -> tensor: + ... + + def cos(self) -> tensor: + ... + + def sin(self) -> tensor: + ... + + def sqrt(self) -> tensor: + ... + + def rsqrt(self) -> tensor: + ... + + def abs(self) -> tensor: + ... + + def reduce(self, axis, combine_fn, keep_dims=False) -> tensor: + ... + + def associative_scan(self, axis, combine_fn, reverse=False) -> tensor: + ... + + def histogram(self, num_bins) -> tensor: + ... + + def cdiv(self, div) -> tensor: + ... + + def sigmoid(self) -> tensor: + ... + + def softmax(self, ieee_rounding=False) -> tensor: + ... + + def ravel(self) -> tensor: + ... + + def max(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmax(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def min(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def xor_sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def cumsum(self, axis=0, reverse=False) -> tensor: + ... + + def cumprod(self, axis=0, reverse=False) -> tensor: + ... + + def sort(self, dim: constexpr = None, descending: constexpr = CONSTEXPR_0) -> tensor: + ... + + def flip(self, dim=None) -> tensor: + ... + + +def get_bool_env_var(var_name): + v = os.getenv(var_name, "0") + return v == "1" or v == "true" or v == "on" + + +# ----------------------- +# SPMD Programming Model +# ----------------------- +def _constexpr_to_value(v): + if isinstance(v, constexpr): + return v.value + return v + + +@builtin +def program_id(axis, _builder=None): + """ + Returns the id of the current program instance along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + # if axis == -1: + # pid0 = program_id(0, _builder) + # pid1 = program_id(1, _builder) + # pid2 = program_id(2, _builder) + # npg0 = num_programs(0, _builder) + # npg1 = num_programs(0, _builder) + # return pid0 + pid1*npg0 + pid2*npg0*npg1 + axis = _constexpr_to_value(axis) + return semantic.program_id(axis, _builder) + + +@builtin +def num_programs(axis, _builder=None): + """ + Returns the number of program instances launched along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + axis = _constexpr_to_value(axis) + return semantic.num_programs(axis, _builder) + + +# ----------------------- +# Block Initialization +# ----------------------- + + +@builtin +def arange(start, end, _builder=None): + """ + Returns contiguous values within the half-open interval :code:`[start, + end)`. :code:`end - start` must be less than or equal to + :code:`TRITON_MAX_TENSOR_NUMEL = 131072` + + :param start: Start of the interval. Must be a power of two. + :type start: int32 + :param end: End of the interval. Must be a power of two greater than + :code:`start`. + :type end: int32 + """ + start = _constexpr_to_value(start) + end = _constexpr_to_value(end) + return semantic.arange(start, end, _builder) + + +def _shape_check_impl(shape): + shape = _constexpr_to_value(shape) + for i, d in enumerate(shape): + if isinstance(d, int): + d = constexpr(d) + if not isinstance(d, constexpr): + raise TypeError(f"Shape element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + if d.value & (d.value - 1) != 0: + raise ValueError(f"Shape element {i} must be a power of 2") + return [_constexpr_to_value(x) for x in shape] + + +@builtin +def full(shape, value, dtype, _builder=None): + """ + Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :value value: A scalar value to fill the array with + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + shape = _shape_check_impl(shape) + value = _constexpr_to_value(value) + dtype = _constexpr_to_value(dtype) + return semantic.full(shape, value, dtype, _builder) + + +# ----------------------- +# Shape Manipulation +# ----------------------- + + +@builtin +def broadcast(input, other, _builder=None): + """ + Tries to broadcast the two given blocks to a common compatible shape. + + :param input: The first input tensor. + :type input: Block + :param other: The second input tensor. + :type other: Block + """ + return semantic.broadcast_impl_value(input, other, _builder) + + +@_tensor_member_fn +@builtin +def broadcast_to(input, *shape, _builder=None): + """ + Tries to broadcast the given tensor to a new :code:`shape`. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + :type shape: + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + broadcast_to(x, (32, 32)) + broadcast_to(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.broadcast_impl_shape(input, shape, _builder) + + +@_tensor_member_fn +@builtin +def trans(input: tensor, *dims, _builder=None): + """ + Permutes the dimensions of a tensor. + + If no permutation is specified, tries to do a (1,0) permutation, i.e. tries + to transpose a 2D tensor. + + :param input: The input tensor. + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + trans(x, (2, 1, 0)) + trans(x, 2, 1, 0) + + :py:func:`permute` is equivalent to this function, except it doesn't + have the special case when no permutation is specified. + """ + if not dims: + dims = (1, 0) + return semantic.permute(input, dims, _builder) + + +@_tensor_member_fn +@builtin +def permute(input, *dims, _builder=None): + """ + Permutes the dimensions of a tensor. + + :param input: The input tensor. + :type input: Block + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + permute(x, (2, 1, 0)) + permute(x, 2, 1, 0) + + :py:func:`trans` is equivalent to this function, except when + :code:`dims` is empty, it tries to do a (1,0) permutation. + """ + dims = _unwrap_iterable(dims) + return semantic.permute(input, dims, _builder) + + +@builtin +def cat(input, other, can_reorder=False, _builder=None): + """ + Concatenate the given blocks + + :param input: The first input tensor. + :type input: + :param other: The second input tensor. + :type other: + :param reorder: Compiler hint. If true, the compiler is + allowed to reorder elements while concatenating inputs. Only use if the + order does not matter (e.g., result is only used in reduction ops) + """ + return semantic.cat(input, other, can_reorder, _builder) + + +@builtin +def join(a, b, _builder=None): + """ + Join the given tensors in a new, minor dimension. + + For example, given two tensors of shape (4,8), produces a new tensor of + shape (4,8,2). Given two scalars, returns a tensor of shape (2). + + The two inputs are broadcasted to be the same shape. + + If you want to join more than two elements, you can use multiple calls to + this function. This reflects the constraint in Triton that tensors must + have power-of-two sizes. + + join is the inverse of split. + + :param a: The first input tensor. + :type a: Tensor + :param b: The second input tensor. + :type b: Tensor + """ + return semantic.join(a, b, _builder) + + +@jit +def _take_first(a, b): + return a + + +@_tensor_member_fn +@builtin +def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]: + """ + Split a tensor in two along its last dim, which must have size 2. + + For example, given a tensor of shape (4,8,2), produces two tensors of shape + (4,8). Given a tensor of shape (2), returns two scalars. + + If you want to split into more than two pieces, you can use multiple calls + to this function (probably plus calling reshape). This reflects the + constraint in Triton that tensors must have power-of-two sizes. + + split is the inverse of join. + + :param a: The tensor to split. + :type a: Tensor + """ + # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars. + # But semantic.split can only handle returning tensors. Work around this by + # expanding the input to shape [1,2] and then reducing the result. + was_rank_1 = len(a.shape) == 1 + if was_rank_1: + a = semantic.expand_dims(a, 0, _builder) + + out_lhs, out_rhs = semantic.split(a, _builder) + + if was_rank_1: + # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar. + out_lhs = typing.cast(tensor, reduce(out_lhs, None, _take_first, _builder=_builder, _generator=_generator)) + out_rhs = typing.cast(tensor, reduce(out_rhs, None, _take_first, _builder=_builder, _generator=_generator)) + + return out_lhs, out_rhs + + +@_tensor_member_fn +@builtin +def view(input, *shape, _builder=None): + """ + Returns a tensor with the same elements as `input` but a different shape. + The order of the elements may not be preserved. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + view(x, (32, 32)) + view(x, 32, 32) + """ + warn("view is deprecated, please use reshape with can_reorder being true.") + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.reshape(input, shape, can_reorder=True, builder=_builder) + + +@_tensor_member_fn +@builtin +def reshape(input, *shape, can_reorder=False, _builder=None): + """ + Returns a tensor with the same number of elements as input but with the + provided shape. + + :param input: The input tensor. + :type input: Block + :param shape: The new shape. + + :code:`shape ` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + reshape(x, (32, 32)) + reshape(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.reshape(input, shape, can_reorder, _builder) + + +def _wrap_axis(axis, ndim): + if not (-ndim <= axis < ndim): + raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}") + + return axis if axis >= 0 else axis + ndim + + +@_tensor_member_fn +@builtin +def expand_dims(input, axis, _builder=None): + """ + Expand the shape of a tensor, by inserting new length-1 dimensions. + + Axis indices are with respect to the resulting tensor, so + ``result.shape[axis]`` will be 1 for each axis. + + :param input: The input tensor. + :type input: tl.tensor + :param axis: The indices to add new axes + :type axis: int | Sequence[int] + + """ + input = _to_tensor(input, _builder) + axis = _constexpr_to_value(axis) + axes = list(axis) if isinstance(axis, Sequence) else [axis] + new_ndim = len(input.shape) + len(axes) + axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] + + if len(set(axes)) != len(axes): + raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}") + + ret = input + for a in sorted(axes): + ret = semantic.expand_dims(ret, a, _builder) + return ret + + +@_tensor_member_fn +@builtin +def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None): + """ + Casts a tensor to the given :code:`dtype`. + + :param dtype: The target data type. + :param fp_downcast_rounding: The rounding mode for downcasting + floating-point values. This parameter is only used when self is a + floating-point tensor and dtype is a floating-point type with a + smaller bitwidth. Supported values are :code:`"rtne"` (round to + nearest, ties to even) and :code:`"rtz"` (round towards zero). + :param bitcast: If true, the tensor is bitcasted to the given + :code:`dtype`, instead of being numerically casted. + """ + input = _to_tensor(input, _builder) + if isinstance(bitcast, constexpr): + bitcast = bitcast.value + if bitcast: + return semantic.bitcast(input, dtype, _builder) + return semantic.cast(input, dtype, _builder, fp_downcast_rounding) + + +# ----------------------- +# Linear Algebra +# ----------------------- + + +@builtin +def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32, + _builder=None): + """ + Returns the matrix product of two blocks. + + The two blocks must be two-dimensional and have compatible inner dimensions. + + :param input: The first tensor to be multiplied. + :type input: 2D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second tensor to be multiplied. + :type other: 2D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param input_precision: How to exercise the Tensor Cores for f32 x f32. If + the device does not have Tensor Cores or the inputs are not of dtype f32, + this option is ignored. For devices that do have tensor cores, the + default precision is tf32. + :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Avaliable options for amd: :code:`"ieee"`. + :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32". + Only one of :code:`input_precision` and :code:`allow_tf32` can be + specified (i.e. at least one must be :code:`None`). + """ + assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" + if input_precision is None: + supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions + default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee" + input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision) + + input_precision = _constexpr_to_value(input_precision) + out_dtype = _constexpr_to_value(out_dtype) + max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) + return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder) + + +# ----------------------- +# Non-Atomic Memory Operations +# ----------------------- + + +@builtin +def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="", + volatile=False, _builder=None): + """ + Return a tensor of data whose values are loaded from memory at location defined by `pointer`: + + (1) If `pointer` is a single element pointer, a scalar is be loaded. In + this case: + + - `mask` and `other` must also be scalars, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional tensor is loaded. In this case: + + - `mask` and `other` are implicitly broadcast to `pointer.shape`, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a + tensor is loaded. In this case: + + - `mask` and `other` must be None, and + - `boundary_check` and `padding_option` can be specified to control + the behavior of out-of-bound access. + + :param pointer: Pointer to the data to be loaded + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]` + (must be `None` with block pointers) + :type mask: Block of `triton.int1`, optional + :param other: if `mask[idx]` is false, return `other[idx]` + :type other: Block, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param padding_option: should be one of {"", "zero", "nan"}, do padding while out of bound + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + :param volatile: changes volatile option in NVIDIA PTX + :type volatile: bool, optional + """ + # `mask` and `other` can be constexpr + mask = _constexpr_to_value(mask) + other = _constexpr_to_value(other) + if mask is not None: + mask = _to_tensor(mask, _builder) + if other is not None: + other = _to_tensor(other, _builder) + padding_option = _constexpr_to_value(padding_option) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + volatile = _constexpr_to_value(volatile) + return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy, + volatile, _builder) + + +@builtin +def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None): + """ + Experimental feature to access TMA descriptors loads. This is an escape hatch to easily exercise TTGIR operations. + This will be removed in the future and shouldn't be used in production code. + + This loads a tensor of data based on the descriptor and offsets. + """ + type = block_type(dtype, shape) + return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder) + + +@builtin +def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None): + """ + Experimental feature to access TMA descriptors stores. This is an escape hatch to easily exercise TTGIR operations. + This will be removed in the future and shouldn't be used in production code. + + This stores a tensor of data based on the descriptor and offsets. + """ + return semantic.descriptor_store(desc_pointer, value, offsets, _builder) + + +@_tensor_member_fn +@builtin +def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None): + """ + Store a tensor of data into memory locations defined by `pointer`. + + (1) If `pointer` is a single element pointer, a scalar is stored. In + this case: + + - `mask` must also be scalar, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional block is stored. In this case: + + - `mask` is implicitly broadcast to `pointer.shape`, and + - `boundary_check` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a block + of data is stored. In this case: + + - `mask` must be None, and + - `boundary_check` can be specified to control the behavior of out-of-bound access. + + `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`. + + :param pointer: The memory location where the elements of `value` are stored + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param value: The tensor of elements to be stored + :type value: Block + :param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]` + :type mask: Block of triton.int1, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + """ + # `value` can be constexpr + value = _to_tensor(value, _builder) + mask = _constexpr_to_value(mask) + if mask is not None: + mask = _to_tensor(mask, _builder) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder) + + +@builtin +def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None): + """ + Returns a pointer to a block in a parent tensor + + :param base: The base pointer to the parent tensor + :param shape: The shape of the parent tensor + :param strides: The strides of the parent tensor + :param offsets: The offsets to the block + :param block_shape: The shape of the block + :param order: The order of the original data format + """ + return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder) + + +@_tensor_member_fn +@builtin +def advance(base, offsets, _builder=None): + """ + Advance a block pointer + + :param base: the block pointer to advance + :param offsets: the offsets to advance, a tuple by dimension + """ + return semantic.advance(base, offsets, _builder) + + +# ----------------------- +# Atomic Memory Operations +# ----------------------- + + +def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = f""" + Performs an atomic {name} at the memory location specified by :code:`pointer`. + + Return the data stored at :code:`pointer` before the atomic operation. + + :param pointer: The memory locations to operate on + :type pointer: Block of dtype=triton.PointerDType""" + if has_cmp: + docstr += """ + :param cmp: The values expected to be found in the atomic object + :type cmp: Block of dtype=pointer.dtype.element_ty""" + docstr += """ + :param val: The values with which to perform the atomic operation + :type val: Block of dtype=pointer.dtype.element_ty + :param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default), + "ACQUIRE", "RELEASE", or "RELAXED") + :type sem: str + :param scope: Scope of threads that observe synchronizing effect of the + atomic operation ("GPU" (default), "CTA", or "SYSTEM") + :type scope: str + """ + func.__doc__ = docstr + return func + + return _decorator + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("compare-and-swap", has_cmp=True) +def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None): + cmp = _to_tensor(cmp, _builder) + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("exchange") +def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("add") +def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_add(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("max") +def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_max(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("min") +def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_min(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical and") +def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_and(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical or") +def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_or(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical xor") +def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder) + + +# ----------------------- +# Conditioning +# ----------------------- + + +@builtin +def where(condition, x, y, _builder=None): + """ + Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + + Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. + + If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead. + + The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`. + :code:`x` and :code:`y` must have the same data type. + + :param condition: When True (nonzero), yield x, otherwise yield y. + :type condition: Block of triton.bool + :param x: values selected at indices where condition is True. + :param y: values selected at indices where condition is False. + """ + condition = _to_tensor(condition, _builder) + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + return semantic.where(condition, x, y, _builder) + + +# ----------------------- +# Math +# ----------------------- + + +@builtin +def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Computes the element-wise minimum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + y = _promote_bfloat16_to_float32(y, _builder=_builder) + propagate_nan = _constexpr_to_value(propagate_nan) + return semantic.minimum(x, y, propagate_nan, _builder) + + +@builtin +def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Computes the element-wise maximum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + y = _promote_bfloat16_to_float32(y, _builder=_builder) + propagate_nan = _constexpr_to_value(propagate_nan) + return semantic.maximum(x, y, propagate_nan, _builder) + + +@builtin +def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Clamps the input tensor :code:`x` within the range [min, max]. + Behavior when :code:`min` > :code:`max` is undefined. + + :param x: the input tensor + :type x: Block + :param min: the lower bound for clamping + :type min: Block + :param max: the upper bound for clamping + :type max: Block + :param propagate_nan: whether to propagate NaN values. Applies only to the :code:`x` tensor. + If either :code:`min` or :code:`max` is NaN, the result is undefined. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + min = _to_tensor(min, _builder) + max = _to_tensor(max, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + min = _promote_bfloat16_to_float32(min, _builder=_builder) + max = _promote_bfloat16_to_float32(max, _builder=_builder) + + propagate_nan = _constexpr_to_value(propagate_nan) + + return semantic.clamp(x, min, max, propagate_nan, _builder) + + +# ----------------------- +# Reductions +# ----------------------- + + +def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :param axis: the dimension along which the reduction should be done + :param keep_dims: if true, keep the reduced dimensions with length 1""" + if return_indices_arg is not None: + docstr += f""" + :param {return_indices_arg}: if true, return index corresponding to the {name} value""" + if tie_break_arg is not None: + docstr += f""" + :param {tie_break_arg}: if true, return the left-most indices in case of ties for values that aren't NaN""" + + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@contextmanager +def _insertion_guard(builder): + ip = builder.get_insertion_point() + yield + builder.restore_insertion_point(ip) + + +@_tensor_member_fn +@builtin +def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None): + """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis` + + :param input: the input tensor, or tuple of tensors + :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :param keep_dims: if true, keep the reduced dimensions with length 1 + + """ + if isinstance(input, tensor): + return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(reduce_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = reduce_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_reduce_ret(*handles) + + def expand_ndims(t, ndims): + for _ in builtins.range(ndims): + t = expand_dims(t, 0, _builder=_builder) + return t + + axis = _constexpr_to_value(axis) + keep_dims = _constexpr_to_value(keep_dims) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + ret = semantic.reduction(input, axis, make_combine_region, _builder) + if keep_dims: + if axis is not None: + ret = tuple(expand_dims(t, axis, _builder=_builder) for t in ret) + else: + ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret) + return ret + + +@builtin +def _promote_bfloat16_to_float32(t, _builder=None): + scalar_ty = t.type.scalar + + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is bfloat16: + return t.to(float32, _builder=_builder) + return t + + +@builtin +def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None): + axis = _constexpr_to_value(axis) + n = input.shape[axis] + index = arange(0, n, _builder=_builder) + + if len(input.shape) > 1: + # Broadcast index across the non-reduced axes + axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))] + del axes_to_expand[axis] + index = expand_dims(index, axes_to_expand, _builder=_builder) + index = broadcast_to(index, input.shape, _builder=_builder) + + rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, + _generator=_generator) + return rvalue, rindices + + +# ----------------------- +# Scans +# ----------------------- + + +def _add_scan_docstr(name: str) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :param axis: the dimension along which the scan should be done""" + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@_tensor_member_fn +@builtin +def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _generator=None): + """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry + + :param input: the input tensor, or tuple of tensors + :param axis: the dimension along which the reduction should be done + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :param reverse: apply the associative scan in the reverse direction along axis. + + """ + if isinstance(input, tensor): + return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(scan_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = scan_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_scan_ret(*handles) + + axis = _constexpr_to_value(axis) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + return semantic.associative_scan(input, axis, make_combine_region, reverse, _builder) + + +@_tensor_member_fn +@builtin +def histogram(input, num_bins, _builder=None, _generator=None): + """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0. + + :param input: the input tensor + :param num_bins: number of histogram bins + + """ + num_bins = _constexpr_to_value(num_bins) + return semantic.histogram(input, num_bins, _builder) + + +# ----------------------- +# Compiler Hint Ops +# ----------------------- + + +@builtin +def debug_barrier(_builder=None): + ''' + Insert a barrier to synchronize all threads in a block. + ''' + return semantic.debug_barrier(_builder) + + +@builtin +def multiple_of(input, values, _builder=None): + """ + Let the compiler know that the values in :code:`input` are all multiples of :code:`value`. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.multiple_of(input, values) + + +@builtin +def max_contiguous(input, values, _builder=None): + """ + Let the compiler know that the `value` first values in :code:`input` are contiguous. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_contiguous(input, values) + + +@builtin +def max_constancy(input, values, _builder=None): + """ + Let the compiler know that the `value` first values in :code:`input` are constant. + + e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal, + for example [0, 0, 0, 0, 1, 1, 1, 1]. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_constancy(input, values) + + +# ----------------------- +# Debugging functions +# ----------------------- + + +@builtin +def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None): + ''' + Print the values at compile time. The parameters are the same as the builtin :code:`print`. + + NOTE: Calling the Python builtin :code:`print` is not the same as calling this, it instead maps to :code:`device_print`, + which has special requirements for the arguments. + + .. highlight:: python + .. code-block:: python + + tl.static_print(f"{BLOCK_SIZE=}") + ''' + pass + + +@builtin +def static_assert(cond, msg="", _builder=None): + ''' + Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable + is set. + + .. highlight:: python + .. code-block:: python + + tl.static_assert(BLOCK_SIZE == 1024) + ''' + pass + + +@builtin +def device_print(prefix, *args, hex=False, _builder=None): + ''' + Print the values at runtime from the device. String formatting does not work for runtime values, so you should + provide the values you want to print as arguments. The first value must be a string, all following values must + be scalars or tensors. + + Calling the Python builtin :code:`print` is the same as calling this function, and the requirements for the arguments will match + this function (not the normal requirements for :code:`print`). + + .. highlight:: python + .. code-block:: python + + tl.device_print("pid", pid) + print("pid", pid) + + On CUDA, printfs are streamed through a buffer of limited size (on one host, + we measured the default as 6912 KiB, but this may not be consistent across + GPUs and CUDA versions). If you notice some printfs are being dropped, you + can increase the buffer size by calling + + .. highlight:: python + .. code-block:: python + + triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes) + + CUDA may raise an error if you try to change this value after running a + kernel that uses printfs. The value set here may only affect the current + device (so if you have multiple GPUs, you'd need to call it multiple times). + + :param prefix: a prefix to print before the values. This is required to be a string literal. + :param args: the values to print. They can be any tensor or scalar. + :param hex: print all values as hex instead of decimal + ''' + import string + prefix = _constexpr_to_value(prefix) + assert isinstance(prefix, str), f"{prefix} is not string" + b_ascii = True + for ch in prefix: + if ch not in string.printable: + b_ascii = False + break + assert b_ascii, f"{prefix} is not an ascii string" + new_args = [] + for arg in args: + new_args.append(_to_tensor(arg, _builder)) + return semantic.device_print(prefix, new_args, hex, _builder) + + +@builtin +def device_assert(cond, msg="", _builder=None): + ''' + Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG` + is set to a value besides :code:`0` in order for this to have any effect. + + Using the Python :code:`assert` statement is the same as calling this function, except that the second argument + must be provided and must be a string, e.g. :code:`assert pid == 0, "pid != 0"`. The environment variable must + be set for this :code:`assert` statement to have any effect. + + .. highlight:: python + .. code-block:: python + + tl.device_assert(pid == 0) + assert pid == 0, f"pid != 0" + + :param cond: the condition to assert. This is required to be a boolean tensor. + :param msg: the message to print if the assertion fails. This is required to be a string literal. + ''' + msg = _constexpr_to_value(msg) + import inspect + frame = inspect.currentframe() + module = inspect.getmodule(frame) + # The triton function module doesn't have the name attribute. + # We use this trick to find the caller. + while hasattr(module, "__name__"): + frame = frame.f_back + module = inspect.getmodule(frame) + lineno = 0 + func_name = 'unknown' + file_name = 'unknown' + if frame is not None and frame.f_back is not None: + func_name = frame.f_code.co_name + file_name = frame.f_back.f_code.co_filename + # TODO: The line number currently indicates the line + # where the triton function is called but not where the + # device_assert is called. Need to enhance this. + lineno = frame.f_back.f_lineno + return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder) + + +@builtin +def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]], + is_pure: bool, pack: int, _builder=None): + ''' + Execute inline assembly over a tensor. Essentially, this is :code:`map` + where the function is inline assembly. + + The input tensors :code:`args` are implicitly broadcasted to the same shape. + + :code:`dtype` can be a tuple of types, in which case the output is a + tuple of tensors. + + Each invocation of the inline asm processes :code:`pack` elements at a + time. Exactly which set of inputs a block receives is unspecified. + Input elements of size less than 4 bytes are packed into 4-byte + registers. + + This op does not support empty :code:`dtype` -- the inline asm must + return at least one tensor, even if you don't need it. You can work + around this by returning a dummy tensor of arbitrary type; it shouldn't + cost you anything if you don't use it. + + Example using + [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + assembly: + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor + b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + :param asm: assembly to run. Must match target's assembly format. + :param constraints: asm constraints in + [LLVM format](https://llvm.org/docs/LangRef.html#inline-asm-constraint-string) + :param args: the input tensors, whose values are passed to the asm block + :param dtype: the element type(s) of the returned tensor(s) + :param is_pure: if true, the compiler assumes the asm block has no side-effects + :param pack: the number of elements to be processed by one instance of inline assembly + :param _builder: the builder + :return: one tensor or a tuple of tensors of the given dtypes + ''' + asm = _constexpr_to_value(asm) + constraints = _constexpr_to_value(constraints) + pack = _constexpr_to_value(pack) + is_pure = _constexpr_to_value(is_pure) + + # Wrap `dtype` in a tuple if it's not already. + try: + iter(dtype) # type: ignore + has_multiple_outputs = True + except TypeError: + has_multiple_outputs = False + dtype = (dtype, ) # type: ignore + + dtype = typing.cast(Sequence[_DtypeClass], dtype) + + res_tys = dtype + if dispatch_args := [_to_tensor(arg, _builder) for arg in args]: + bin_op_type_checking = partial( + semantic.binary_op_type_checking_impl, + builder=_builder, + arithmetic_check=False, + allow_lhs_ptr=True, + allow_rhs_ptr=True, + ) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = bin_op_type_checking(item, broadcast_arg) + if broadcast_arg.shape: + # Change the shape of each argument based on the broadcast shape + for i, item in enumerate(dispatch_args): + dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg) + res_tys = [block_type(dt, broadcast_arg.shape) for dt in dtype] + handles = [t.handle for t in dispatch_args] + call = _builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(_builder) for ty in res_tys], is_pure, pack) + + if not has_multiple_outputs: + return tensor(call.get_result(0), res_tys[0]) + return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys)) + + +# ----------------------- +# Iterators +# ----------------------- + + +class static_range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.static_range(10): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + """ + + def __init__(self, arg1, arg2=None, step=None): + assert isinstance(arg1, constexpr) + if step is None: + self.step = constexpr(1) + else: + assert isinstance(step, constexpr) + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + assert isinstance(arg2, constexpr) + self.start = arg1 + self.end = arg2 + + def __iter__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + +class range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.range(10, num_stages=3): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + :param num_stages: pipeline the loop into this many stages (so there are + :code:`num_stages` iterations of the loop in flight at once). + + Note this is subtly different than passing :code:`num_stages` as a + kernel argument. The kernel argument only pipelines loads that feed + into :code:`dot` operations, while this attribute tries to pipeline most + (though not all) loads in this loop. + """ + + def __init__(self, arg1, arg2=None, step=None, num_stages=None): + if step is None: + self.step = constexpr(1) + else: + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + self.start = arg1 + self.end = arg2 + self.num_stages = num_stages + + def __iter__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + +# ----------------------- +# Extern functions +# ----------------------- + + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, + is_pure: bool, _builder=None): + ''' + Dispatch a function to a library + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :param _builder: the builder + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." + f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_type = arg_type_symbol_dict[arg_types][1] + if ret_shape: + ret_type = block_type(ret_type, ret_shape) + return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type) + + +@builtin +def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, + _builder=None): + ''' + Dispatch an elementwise function to a library + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param is_pure: whether the function is pure + :param _builder: the builder + :return: the return value of the function + ''' + dispatch_args = args.copy() + all_scalar = True + ret_shape = None + arg_types = [] + for i in builtins.range(len(dispatch_args)): + dispatch_args[i] = _to_tensor(dispatch_args[i], _builder) + arg_types.append(dispatch_args[i].dtype) + if dispatch_args[i].type.is_block(): + all_scalar = False + if len(arg_types) > 0: + arg_types = tuple(arg_types) + arithmetic_check = True + # If there's a type tuple that is not supported by the library, we will do arithmetic check + if arg_types in arg_type_symbol_dict: + arithmetic_check = False + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder, + arithmetic_check=arithmetic_check) + # Change the shape of each argument based on the broadcast shape + for i in builtins.range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder, + arithmetic_check=arithmetic_check) + if not all_scalar: + ret_shape = broadcast_arg.shape + func = _builder.create_extern_elementwise + return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder) + + +def binary_op_type_legalization(lhs, rhs, builder): + ''' + Convert both operands to a single common type + :param lhs: the left operand + :param rhs: the right operand + :param builder: the builder + ''' + return semantic.binary_op_type_checking_impl(lhs, rhs, builder) + + +def extern(fn): + """A decorator for external functions.""" + return builtin(fn) diff --git a/third_party/mthreads/python/triton/language/extra/__init__.py b/third_party/mthreads/python/triton/language/extra/__init__.py new file mode 100644 index 000000000..14e1778d2 --- /dev/null +++ b/third_party/mthreads/python/triton/language/extra/__init__.py @@ -0,0 +1,4 @@ +from . import cuda +from . import hip + +__all__ = ['cuda', 'hip'] diff --git a/python/triton/language/extra/cuda/__init__.py b/third_party/mthreads/python/triton/language/extra/cuda/__init__.py similarity index 100% rename from python/triton/language/extra/cuda/__init__.py rename to third_party/mthreads/python/triton/language/extra/cuda/__init__.py diff --git a/third_party/mthreads/python/triton/language/extra/cuda/libdevice.py b/third_party/mthreads/python/triton/language/extra/cuda/libdevice.py new file mode 100644 index 000000000..3490e6b0e --- /dev/null +++ b/third_party/mthreads/python/triton/language/extra/cuda/libdevice.py @@ -0,0 +1,1629 @@ +from triton.language import core + + +@core.extern +def clz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_clz", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_clzll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def popc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_popc", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_popcll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def byte_perm(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("int32")): ("__nv_byte_perm", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mulhi(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_mulhi", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umulhi", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64")): ("__nv_mul64hi", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64")): ("__nv_umul64hi", core.dtype("uint64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul24(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_mul24", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umul24", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def brev(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_brev", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_brevll", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sad(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("uint32")): ("__nv_sad", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32")): ("__nv_usad", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_abs", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_llabs", core.dtype("int64")), + (core.dtype("fp32"), ): ("__nv_fabsf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_fabs", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_floorf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_floor", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp64h(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_rcp64h", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_rsqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rsqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_ceil", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_ceilf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_trunc", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_truncf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_exp2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def saturatef(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_saturatef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rn(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rz(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rd(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_ru(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_dividef(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_fdividef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rn(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rn", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rz", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rd(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rd", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_ru(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_ru", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rn(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rn", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rz", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rd(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rd", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_ru(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_ru", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_ru", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__nv_dmul_ru", core.dtype("fp64")), + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__nv_fmul_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hiloint2double(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_hiloint2double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2loint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2loint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2hiint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2hiint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_int(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_int", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_uint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_uint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def longlong_as_double(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_longlong_as_double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double_as_longlong(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double_as_longlong", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_sinf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_sinf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_cosf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_cosf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log2f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_log2f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_logf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_logf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_expf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_expf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_tanf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_tanf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_exp10f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_exp10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log10f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_log10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_powf(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_powf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_hadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_uhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_rhadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_urhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_frsqrt_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ffs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("__nv_ffs", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_ffsll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_rintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_llrintf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llrint", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nearbyint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_nearbyintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_nearbyint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_isnanf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_isnand", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_signbitf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_signbitd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_copysignf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_copysign", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def finitef(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_finitef", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_isinff", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_isinfd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nextafter(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_nextafterf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_nextafter", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinpi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinpif", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sinpi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cospi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cospif", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cospi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_exp10f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_coshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan2(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_atan2f", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_atan2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_logf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log10f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log1pf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log1p", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acoshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expm1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_expm1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_hypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_hypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_rhypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_rhypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_norm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_norm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_rnorm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_rnorm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("__nv_norm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("__nv_norm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("__nv_rnorm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("__nv_rnorm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cbrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_rcbrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rcbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j0(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_j0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_j0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_j1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_j1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y0(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_y0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_y0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y1(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_y1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_y1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def yn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("__nv_ynf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("__nv_yn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def jn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("__nv_jnf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("__nv_jn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cyl_bessel_i0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cyl_bessel_i1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erff", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfc", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcx(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcxf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfcx", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfcinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_normcdfinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_normcdfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_normcdff", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_normcdf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_lgammaf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_lgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_ldexpf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_ldexp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def scalbn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_scalbnf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_scalbn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmodf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fmod", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def remainder(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_remainderf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_remainder", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_powif", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_powi", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_powf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_pow", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tgammaf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def round(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_roundf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_round", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llround(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_llroundf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llround", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fdim(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdimf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fdim", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_ilogbf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_ilogb", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def logb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_logbf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_logb", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isfinited(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_isfinited", core.dtype("int32")), + }, is_pure=True, _builder=_builder) diff --git a/third_party/mthreads/python/triton/language/extra/cuda/utils.py b/third_party/mthreads/python/triton/language/extra/cuda/utils.py new file mode 100644 index 000000000..01bc040b2 --- /dev/null +++ b/third_party/mthreads/python/triton/language/extra/cuda/utils.py @@ -0,0 +1,109 @@ +from triton.language import core + + +@core.extern +def globaltimer(_builder=None): + return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1, + _builder=_builder) + + +@core.extern +def smid(_builder=None): + return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1, + _builder=_builder) + + +@core.builtin +def num_threads(_builder=None): + return core.constexpr(_builder.options.num_warps * 32) + + +@core.builtin +def num_warps(_builder=None): + return core.constexpr(_builder.options.num_warps) + + +# ----- FP8E4M3B15 ------ +# This data-type is a variant of the standard FP8E4M3 format. +# It was designed for fast software conversion to FP16 on +# nvidia GPUs that do not support it natively. +# This is the same format as FP8E4M3Nv, but: +# - the exponent bias is 15 instead of 7 +# - 0xff and 0x7f are mapped to +-1.750 instead of +-nan +@core.builtin +def convert_fp8e4b15_to_float16(arg, _builder=None): + return core.inline_asm_elementwise( + "{ \n" + ".reg .b32 a<2>, b<2>; \n" + "prmt.b32 a0, 0, $2, 0x5746; \n" + "and.b32 b0, a0, 0x7f007f00; \n" + "and.b32 b1, a0, 0x00ff00ff; \n" + "and.b32 a1, a0, 0x00800080; \n" + "shr.b32 b0, b0, 1; \n" + "add.u32 b1, b1, a1; \n" + "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" + "shl.b32 $1, b1, 7; \n" + "} \n", "=r,=r,r", [arg], dtype=core.float16, is_pure=True, pack=4, + _builder=_builder) + + +@core.builtin +def convert_float16_to_fp8e4b15(arg, has_minx2, _builder=None): + asm = """{ + .reg .pred p<4>; + .reg .b32 a<2>, b<2>; + .reg .b16 c<4>; + .reg .b16 max_val_f16; + .reg .b32 max_val_f16x2; + mov.b16 max_val_f16, 0x3F00; + mov.b32 max_val_f16x2, 0x3F003F00; + and.b32 a0, $1, 0x7fff7fff; + and.b32 a1, $2, 0x7fff7fff;""" + if has_minx2: + asm += """min.f16x2 a0, a0, max_val_f16x2; + min.f16x2 a1, a1, max_val_f16x2;""" + else: + asm += """setp.lt.f16x2 p0|p1, a0, max_val_f16x2; + setp.lt.f16x2 p2|p3, a1, max_val_f16x2; + mov.b32 {c0, c1}, a0; + mov.b32 {c2, c3}, a1; + selp.b16 c0, c0, max_val_f16, p0; + selp.b16 c1, c1, max_val_f16, p1; + selp.b16 c2, c2, max_val_f16, p2; + selp.b16 c3, c3, max_val_f16, p3; + mov.b32 a0, {c0, c1}; + mov.b32 a1, {c2, c3};""" + asm += """mad.lo.u32 a0, a0, 2, 0x00800080; + mad.lo.u32 a1, a1, 2, 0x00800080; + lop3.b32 b0, $1, 0x80008000, a0, 0xea; + lop3.b32 b1, $2, 0x80008000, a1, 0xea; + prmt.b32 $0, b0, b1, 0x7531; + }""" + return core.inline_asm_elementwise(asm, "=r,r,r", [arg], dtype=core.float8e4b15, is_pure=True, pack=4, + _builder=_builder) + + +@core.builtin +def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _builder=None): + if arg.type.scalar.is_fp8e4b15(): + upcast_val = convert_fp8e4b15_to_float16(arg, _builder=_builder) + if dst_ty.scalar.is_fp32(): + upcast_val = upcast_val.to(core.float32, _builder=_builder) + return upcast_val + + assert arg.type.scalar.is_fp16() or arg.type.scalar.is_fp32() + downcast_val = arg + if arg.type.scalar.is_fp32(): + downcast_val = downcast_val.to(core.float16, fp_downcast_rounding="rtz", _builder=_builder) + downcast_val = convert_float16_to_fp8e4b15(downcast_val, has_minx2=has_minx2, _builder=_builder) + return downcast_val + + +@core.builtin +def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _builder=None): + return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=True, _builder=_builder) + + +@core.builtin +def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _builder=None): + return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=False, _builder=_builder) diff --git a/third_party/mthreads/python/triton/language/extra/hip/__init__.py b/third_party/mthreads/python/triton/language/extra/hip/__init__.py new file mode 100644 index 000000000..229b57d87 --- /dev/null +++ b/third_party/mthreads/python/triton/language/extra/hip/__init__.py @@ -0,0 +1,3 @@ +from . import libdevice + +__all__ = ["libdevice"] diff --git a/third_party/mthreads/python/triton/language/extra/hip/libdevice.py b/third_party/mthreads/python/triton/language/extra/hip/libdevice.py new file mode 100644 index 000000000..02e5d2d0b --- /dev/null +++ b/third_party/mthreads/python/triton/language/extra/hip/libdevice.py @@ -0,0 +1,468 @@ +from triton.language import core + + +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__triton_hip_iabs", core.dtype("int32")), + (core.dtype("int64"), ): ("__triton_hip_iabs", core.dtype("int64")), + (core.dtype("fp32"), ): ("__triton_hip_fabs", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__triton_hip_fabs", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_floor_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_floor_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_rsqrt_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_rsqrt_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_ceil_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_ceil_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_trunc_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_trunc_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_exp2_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_exp2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_exp_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_exp_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_dividef(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__triton_hip_fast_fdividef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sqrt_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sqrt_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__triton_hip_llrint", core.dtype("int64")), + (core.dtype("fp64"), ): ("__triton_hip_llrint", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nearbyint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_nearbyint_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_nearbyint_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_isnan_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_isnan_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_signbit_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_signbit_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_copysign_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_copysign_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_isinf_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_isinf_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nextafter(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_nextafter_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_nextafter_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sin_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sin_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_cos_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_cos_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_tan_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_tan_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log2_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_cosh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_cosh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sinh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sinh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_tanh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_tanh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan2(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_atan2_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_atan2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_atan_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_atan_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_asin_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_asin_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_acos_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_acos_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log10_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log10_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log1p_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log1p_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_acosh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_acosh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_asinh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_asinh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_atanh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_atanh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_expm1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_expm1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_hypot_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_hypot_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_j0_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_j0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_j1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_j1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_y0_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_y0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_y1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_y1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_i0_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_i0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_i1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_i1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erf_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erf_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfinv_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfinv_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfc_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfc_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcx(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfcx_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfcx_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_lgamma_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_lgamma_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__ocml_ldexp_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__ocml_ldexp_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_fmod_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_fmod_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__ocml_fma_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__ocml_fma_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__ocml_pown_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__ocml_pown_f64", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_pow_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_pow_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_ilogb_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_ilogb_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) diff --git a/third_party/mthreads/python/triton/language/extra/libdevice.py b/third_party/mthreads/python/triton/language/extra/libdevice.py new file mode 100644 index 000000000..1acbab109 --- /dev/null +++ b/third_party/mthreads/python/triton/language/extra/libdevice.py @@ -0,0 +1,1216 @@ +from .cuda import libdevice as cuda_libdevice +from .hip import libdevice as hip_libdevice +from .musa import libdevice as musa_libdevice +from triton.language import core +from functools import wraps +from typing import TypeVar + +T = TypeVar('T') + + +def dispatch(fn: T) -> T: + """Dispatch a function to a correct implementation.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + _backend = kwargs["_builder"].options.backend_name + if _backend == 'cuda': + _curr_libdevice_module = cuda_libdevice + elif _backend == 'hip': + _curr_libdevice_module = hip_libdevice + elif _backend == 'musa': + _curr_libdevice_module = musa_libdevice + else: + raise RuntimeError('unknown backend') + + try: + _impl = getattr(_curr_libdevice_module, fn.__name__) + except AttributeError: + raise RuntimeError(f'`{_backend}` does not provide support for `{fn.__name__}` extra function') + + return _impl(*args, **kwargs) + + return wrapper + + +@core.extern +@dispatch +def clz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def popc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def byte_perm(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def mulhi(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul24(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def brev(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sad(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def abs(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def floor(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp64h(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rsqrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ceil(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def trunc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp2(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def saturatef(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fma_rn(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fma_rz(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fma_rd(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fma_ru(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fast_dividef(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def add_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def hiloint2double(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def double2loint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2hiint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int_as_float(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float_as_int(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint_as_float(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float_as_uint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def longlong_as_double(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double_as_longlong(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_sinf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_cosf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_log2f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_logf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_expf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_tanf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_exp10f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_log10f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_powf(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def hadd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rhadd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rsqrt_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ffs(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def llrint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def nearbyint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isnan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def signbit(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def copysign(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def finitef(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isinf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def nextafter(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sin(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cos(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sinpi(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cospi(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def tan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log2(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp10(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cosh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sinh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def tanh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def atan2(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def atan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def asin(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def acos(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log10(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log1p(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def acosh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def asinh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def atanh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def expm1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def hypot(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rhypot(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def norm3d(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def rnorm3d(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + ... + + +@core.extern +@dispatch +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + ... + + +@core.extern +@dispatch +def cbrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcbrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def j0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def j1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def y0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def y1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def yn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def jn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def cyl_bessel_i0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cyl_bessel_i1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfcx(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfcinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def normcdfinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def normcdf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def lgamma(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ldexp(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def scalbn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def fmod(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def remainder(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def fma(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def pow(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def tgamma(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def round(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def llround(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fdim(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def ilogb(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def logb(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isfinited(arg0, _builder=None): + ... diff --git a/third_party/mthreads/python/triton/language/extra/musa/__init__.py b/third_party/mthreads/python/triton/language/extra/musa/__init__.py new file mode 100644 index 000000000..2ce22e700 --- /dev/null +++ b/third_party/mthreads/python/triton/language/extra/musa/__init__.py @@ -0,0 +1,11 @@ +from . import libdevice + +# TODO: Maybe add some codegen funcitons for musa. +# from .utils import (globaltimer, num_threads, num_warps, smid, convert_custom_float8_sm70, convert_custom_float8_sm80) + +# __all__ = [ +# "libdevice", "globaltimer", "num_threads", "num_warps", "smid", "convert_custom_float8_sm70", +# "convert_custom_float8_sm80" +# ] + +__all__ = ["libdevice"] diff --git a/third_party/mthreads/python/triton/language/extra/musa/libdevice.py b/third_party/mthreads/python/triton/language/extra/musa/libdevice.py new file mode 100644 index 000000000..07b47e06d --- /dev/null +++ b/third_party/mthreads/python/triton/language/extra/musa/libdevice.py @@ -0,0 +1,1803 @@ +from enum import Enum +from triton.language import core + + +class RoundingMode(Enum): + rn = 0 # rte + rz = 1 # rtz + rd = 2 # rtn + ru = 3 # rtp + reserve0 = 4 + reserve1 = 5 + + +@core.extern +def clz(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def popc(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def byte_perm(arg0, arg1, arg2, _builder=None): + raise NotImplementedError + + +@core.extern +def mulhi(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("int32"), + core.dtype("int32"), + ): ("__mt_mulhi", core.dtype("int32")), + ( + core.dtype("uint32"), + core.dtype("uint32"), + ): ("__mt_umulhi", core.dtype("uint32")), + ( + core.dtype("int64"), + core.dtype("int64"), + ): ("__mt_mul64hi", core.dtype("int64")), + ( + core.dtype("uint64"), + core.dtype("uint64"), + ): ("__mt_umul64hi", core.dtype("uint64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul24(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("int32"), + core.dtype("int32"), + ): ("__mt_mul24", core.dtype("int32")), + ( + core.dtype("uint32"), + core.dtype("uint32"), + ): ("__mt_umul24", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def brev(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def sad(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + arg2, + ], { + ( + core.dtype("int32"), + core.dtype("int32"), + core.dtype("uint32"), + ): ("__mt_sad", core.dtype("int32")), + ( + core.dtype("uint32"), + core.dtype("uint32"), + core.dtype("uint32"), + ): ("__mt_usad", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("__mt_abs_i32", core.dtype("int32")), + # FIXME mtcc not support abs(int64) + # (core.dtype("int64"),): ("__nv_llabs", core.dtype("int64")), + ( + core.dtype("fp32"), ): ("__mt_fabs_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_fabs_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_floor_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_floor_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp64h(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_rsqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rsqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_ceil", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_ceilf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("__mt_trunc_f64", core.dtype("fp64")), # FIXME: maybe bad perf + (core.dtype("fp32"), ): ("__mt_trunc_f32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], {(core.dtype("fp32"), ): ("__mt_exp2_f32", core.dtype("fp32")), (core.dtype("fp64"), ): + ("__mt_exp2_f64", core.dtype("fp64")), # FIXME: maybe bad perf + }, is_pure=True, _builder=_builder) + + +@core.extern +def saturatef(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def fma_rn(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + arg2, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_fmaf_rn_f32", core.dtype("fp32")), + # FIXME mtcc not support __mt_fmaf_rn_f64 + # (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rz(arg0, arg1, arg2, _builder=None): + raise NotImplementedError + + +@core.extern +def fma_rd(arg0, arg1, arg2, _builder=None): + raise NotImplementedError + + +@core.extern +def fma_ru(arg0, arg1, arg2, _builder=None): + raise NotImplementedError + + +@core.extern +def fast_dividef(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_fast_fdivide_f32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_div_rte_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_div_rte_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_div_rtz_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_div_rtz_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_div_rtn_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_div_rtn_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + # FIXME mtcc not yet support __mt_div_rtp_f32. + # (core.dtype("fp32"), core.dtype("fp32"),): ("__mt_div_rtp_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_div_rtp_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rn(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def rcp_rz(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def rcp_rd(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def rcp_ru(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def sqrt_rn(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def sqrt_rz(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def sqrt_rd(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def sqrt_ru(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_sqrt_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_sqrt_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_add_rte_f64", core.dtype("fp64")), + # FIXME mtcc not yet support __mt_add_rte_f32. + # (core.dtype("fp32"), core.dtype("fp32"),): ("__mt_add_rte_f32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_add_rtz_f64", core.dtype("fp64")), + # FIXME mtcc not yet support __mt_add_rtz_f32. + # (core.dtype("fp32"), core.dtype("fp32"),): ("__mt_add_rtz_f32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_add_rtn_f64", core.dtype("fp64")), + # FIXME mtcc not yet support __mt_add_rtn_f32. + # (core.dtype("fp32"), core.dtype("fp32"),): ("__mt_add_rtn_f32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_add_rtp_f64", core.dtype("fp64")), + # FIXME mtcc not yet support __mt_add_rtp_f32. + # (core.dtype("fp32"), core.dtype("fp32"),): ("__mt_add_rtp_f32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_mul_rte_f64", core.dtype("fp64")), + # FIXME mtcc not yet support __mt_mul_rte_f32. + # (core.dtype("fp32"), core.dtype("fp32"),): ("__mt_mul_rte_f32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_mul_rtz_f64", core.dtype("fp64")), + # FIXME mtcc not yet support __mt_mul_rtz_f32. + # (core.dtype("fp32"), core.dtype("fp32"),): ("__mt_mul_rtz_f32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_mul_rtn_f64", core.dtype("fp64")), + # FIXME mtcc not yet support __mt_mul_rtn_f32. + # (core.dtype("fp32"), core.dtype("fp32"),): ("__mt_mul_rtn_f32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_mul_rtp_f64", core.dtype("fp64")), + # FIXME mtcc not yet support __mt_mul_rtp_f32. + # (core.dtype("fp32"), core.dtype("fp32"),): ("__mt_mul_rtp_f32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("__mt_i32_to_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("uint32"), ): ("__mt_ui32_to_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rn(arg0, _builder=None): + raise NotImplementedError + # TODO make sure __mt_f32_to_i32 eq to __nv_float2int_rn, which rounds to nearest. + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_f32_to_i32", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rz(arg0, _builder=None): + raise NotImplementedError + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("xxx", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_f32_to_i32_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_f32_to_i32_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rn(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def float2uint_rz(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def float2uint_rd(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def float2uint_ru(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def int2float_rn(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def int2float_rz(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def int2float_rd(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def int2float_ru(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def uint2float_rn(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def uint2float_rz(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def uint2float_rd(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def uint2float_ru(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def hiloint2double(arg0, arg1, _builder=None): + raise NotImplementedError + + +@core.extern +def double2loint(arg0, _builder=None): + # FIXME(lingfeng.qiu): It seems like this function is missed in libdevice.bc of musa. + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2loint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2hiint(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def float2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_f32_to_ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_f32_to_ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_f32_to_ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_f32_to_ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rn(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def ll2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("int64"), ): ("__mt_ll_to_f32_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("int64"), ): ("__mt_ll_to_f32_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("int64"), ): ("__mt_ll_to_f32_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rn(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def ull2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("uint64"), ): ("__mt_ull_to_f32_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("uint64"), ): ("__mt_ull_to_f32_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("uint64"), ): ("__mt_ull_to_f32_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.rn], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_i64_to_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.rz], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_i64_to_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.rd], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_i64_to_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.ru], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_i64_to_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.rn], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_ui64_to_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.rz], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_ui64_to_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.rd], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_ui64_to_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0, RoundingMode.ru], { + (core.dtype("int64"), core.dtype("int8")): ("__mt_ui64_to_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("__mt_int_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_int(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_float_as_int", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("uint32"), ): ("__mt_uint_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_uint(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_float_as_uint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def longlong_as_double(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("int64"), ): ("__mt_longlong_as_double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double_as_longlong(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp64"), ): ("__mt_double_as_longlong", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +# TODO It seems lack of fast_math in mtcc. + + +@core.extern +def fast_sinf(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def fast_cosf(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def fast_log2f(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def fast_logf(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def fast_expf(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def fast_tanf(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def fast_exp10f(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def fast_log10f(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def fast_powf(arg0, arg1, _builder=None): + raise NotImplementedError + + +@core.extern +def hadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("int32"), + core.dtype("int32"), + ): ("__mt_hadd", core.dtype("int32")), + ( + core.dtype("uint32"), + core.dtype("uint32"), + ): ("__mt_uhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("int32"), + core.dtype("int32"), + ): ("__mt_rhadd", core.dtype("int32")), + ( + core.dtype("uint32"), + core.dtype("uint32"), + ): ("__mt_urhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_fsub_rn_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_sub_rte_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + # FIXME mtcc not yet support __mt_sub_rtz_f32. + # (core.dtype("fp32"), core.dtype("fp32"),): ("__mt_sub_rtz_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_sub_rtz_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + # FIXME mtcc not yet support __mt_sub_rtn_f32. + # (core.dtype("fp32"), core.dtype("fp32"),): ("__mt_sub_rtn_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_sub_rtn_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + # FIXME mtcc not yet support __mt_sub_rtp_f32. + # (core.dtype("fp32"), core.dtype("fp32"),): ("__mt_sub_rtp_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_sub_rtp_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_rsqrt_rn_f32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ffs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("__mt_ffs_i32", core.dtype("int32")), + (core.dtype("int64"), ): ("__mt_ffsll_i64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_rintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_llrintf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llrint", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nearbyint(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_isnan_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__mt_isnan_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_signbit_f32", core.dtype("int1")), + (core.dtype("fp64"), ): ("__mt_signbit_f64", core.dtype("int1")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_copysignf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_copysign", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def finitef(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_isfinite_f32", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_isinf_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__mt_isinf_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nextafter(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_nextafter_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_nextafter_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinpi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_sinpi_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_sinpi_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cospi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + # FIXME mtcc not yet support __mt_cospi_f32. + # (core.dtype("fp32"),): ("__mt_cospi_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), ): ("__mt_cospi_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_tan_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_tan_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + # FIXME mtcc not yet support __mt_exp10_f32. + # (core.dtype("fp32"),): ("__mt_exp10_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), ): ("__mt_exp10_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_coshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_sinh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_sinh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], + {(core.dtype("fp32"), ): + ("__mt_tanh_f32", + core.dtype("fp32")), # FIXME: mtcc should wrap the libdevice func to support mp_32 hw supported tanhf + (core.dtype("fp64"), ): ("__mt_tanh_f64", core.dtype("fp64")), # FIXME: maybe bad perf + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan2(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_atan2f", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_atan2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_log_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_log_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_log10_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_log10_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_log1p_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_log1p_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acoshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expm1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_expm1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_hypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_hypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhypot(arg0, arg1, _builder=None): + raise NotImplementedError + + +@core.extern +def norm3d(arg0, arg1, arg2, _builder=None): + raise NotImplementedError + + +@core.extern +def rnorm3d(arg0, arg1, arg2, _builder=None): + raise NotImplementedError + + +@core.extern +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + raise NotImplementedError + + +@core.extern +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + raise NotImplementedError + + +@core.extern +def cbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_cbrt_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_cbrt_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcbrt(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def j0(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def j1(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def y0(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def y1(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def yn(arg0, arg1, _builder=None): + raise NotImplementedError + + +@core.extern +def jn(arg0, arg1, _builder=None): + raise NotImplementedError + + +@core.extern +def cyl_bessel_i0(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def cyl_bessel_i1(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], {(core.dtype("fp32"), ): ("__mt_erf_f32", core.dtype("fp32")), (core.dtype("fp64"), ): + ("__mt_erf_f64", core.dtype("fp64")), # FIXME: maybe bad perf + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_erfinv_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_erfinv_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_erfc_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_erfc_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcx(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def erfcinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_erfcinv_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_erfcinv_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdfinv(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def normcdf(arg0, _builder=None): + raise NotImplementedError + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_lgamma_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_lgamma_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], {(core.dtype("fp32"), core.dtype("int32")): ("__mt_ldexp_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): + ("__mt_ldexp_f64", core.dtype("fp64")), # FIXME: maybe bad perf + }, is_pure=True, _builder=_builder) + + +@core.extern +def scalbn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("int32"), + ): ("__mt_scalbn_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("int32"), + ): ("__mt_scalbn_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_fmod_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_fmod_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def remainder(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_remainder_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_remainder_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp32"), + core.dtype("int32"), + ): ("__mt_pown_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("int32"), + ): ("__mt_pown_f64", core.dtype("fp64")), + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__mt_pow_f32", core.dtype("fp32")), + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__mt_pow_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_tgamma_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_tgamma_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def round(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_round_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_round_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llround(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_llround_f32", core.dtype("int64")), + (core.dtype("fp64"), ): ("__mt_llround_f64", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fdim(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdimf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fdim", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_ilogbf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_ilogb", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def logb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__mt_logb_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__mt_logb_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isfinited(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp64"), ): ("__mt_isfinite_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +# device capability >= 31 +@core.extern +def fast_gelu(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__mt_tt_gelu_f32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +# device capability >= 31 +@core.extern +def fast_tanh(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__mt_tt_tanh_f32", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) diff --git a/third_party/mthreads/python/triton/language/extra/musa/utils.py b/third_party/mthreads/python/triton/language/extra/musa/utils.py new file mode 100644 index 000000000..7ed910ea7 --- /dev/null +++ b/third_party/mthreads/python/triton/language/extra/musa/utils.py @@ -0,0 +1,101 @@ +from triton.language import core + +# TODO: Maybe add some codegen funcitons for musa. +# @core.extern +# def globaltimer(_builder=None): +# return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1, +# _builder=_builder) + +# @core.extern +# def smid(_builder=None): +# return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1, +# _builder=_builder) + +# @core.builtin +# def num_threads(_builder=None): +# return core.constexpr(_builder.options.num_warps * 32) + +# @core.builtin +# def num_warps(_builder=None): +# return core.constexpr(_builder.options.num_warps) + +# # ----- FP8E4M3B15 ------ +# # This data-type is a variant of the standard FP8E4M3 format. +# # It was designed for fast software conversion to FP16 on +# # nvidia GPUs that do not support it natively. +# # This is the same format as FP8E4M3Nv, but: +# # - the exponent bias is 15 instead of 7 +# # - 0xff and 0x7f are mapped to +-1.750 instead of +-nan +# @core.builtin +# def convert_fp8e4b15_to_float16(arg, _builder=None): +# return core.inline_asm_elementwise( +# "{ \n" +# ".reg .b32 a<2>, b<2>; \n" +# "prmt.b32 a0, 0, $2, 0x5746; \n" +# "and.b32 b0, a0, 0x7f007f00; \n" +# "and.b32 b1, a0, 0x00ff00ff; \n" +# "and.b32 a1, a0, 0x00800080; \n" +# "shr.b32 b0, b0, 1; \n" +# "add.u32 b1, b1, a1; \n" +# "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" +# "shl.b32 $1, b1, 7; \n" +# "} \n", "=r,=r,r", [arg], dtype=core.float16, is_pure=True, pack=4, +# _builder=_builder) + +# @core.builtin +# def convert_float16_to_fp8e4b15(arg, has_minx2, _builder=None): +# asm = """{ +# .reg .pred p<4>; +# .reg .b32 a<2>, b<2>; +# .reg .b16 c<4>; +# .reg .b16 max_val_f16; +# .reg .b32 max_val_f16x2; +# mov.b16 max_val_f16, 0x3F00; +# mov.b32 max_val_f16x2, 0x3F003F00; +# and.b32 a0, $1, 0x7fff7fff; +# and.b32 a1, $2, 0x7fff7fff;""" +# if has_minx2: +# asm += """min.f16x2 a0, a0, max_val_f16x2; +# min.f16x2 a1, a1, max_val_f16x2;""" +# else: +# asm += """setp.lt.f16x2 p0|p1, a0, max_val_f16x2; +# setp.lt.f16x2 p2|p3, a1, max_val_f16x2; +# mov.b32 {c0, c1}, a0; +# mov.b32 {c2, c3}, a1; +# selp.b16 c0, c0, max_val_f16, p0; +# selp.b16 c1, c1, max_val_f16, p1; +# selp.b16 c2, c2, max_val_f16, p2; +# selp.b16 c3, c3, max_val_f16, p3; +# mov.b32 a0, {c0, c1}; +# mov.b32 a1, {c2, c3};""" +# asm += """mad.lo.u32 a0, a0, 2, 0x00800080; +# mad.lo.u32 a1, a1, 2, 0x00800080; +# lop3.b32 b0, $1, 0x80008000, a0, 0xea; +# lop3.b32 b1, $2, 0x80008000, a1, 0xea; +# prmt.b32 $0, b0, b1, 0x7531; +# }""" +# return core.inline_asm_elementwise(asm, "=r,r,r", [arg], dtype=core.float8e4b15, is_pure=True, pack=4, +# _builder=_builder) + +# @core.builtin +# def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _builder=None): +# if arg.type.scalar.is_fp8e4b15(): +# upcast_val = convert_fp8e4b15_to_float16(arg, _builder=_builder) +# if dst_ty.scalar.is_fp32(): +# upcast_val = upcast_val.to(core.float32, _builder=_builder) +# return upcast_val + +# assert arg.type.scalar.is_fp16() or arg.type.scalar.is_fp32() +# downcast_val = arg +# if arg.type.scalar.is_fp32(): +# downcast_val = downcast_val.to(core.float16, fp_downcast_rounding="rtz", _builder=_builder) +# downcast_val = convert_float16_to_fp8e4b15(downcast_val, has_minx2=has_minx2, _builder=_builder) +# return downcast_val + +# @core.builtin +# def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _builder=None): +# return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=True, _builder=_builder) + +# @core.builtin +# def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _builder=None): +# return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=False, _builder=_builder) diff --git a/third_party/mthreads/python/triton/language/math.py b/third_party/mthreads/python/triton/language/math.py new file mode 100644 index 000000000..d15bcd514 --- /dev/null +++ b/third_party/mthreads/python/triton/language/math.py @@ -0,0 +1,251 @@ +from . import core +from . import semantic +from functools import wraps +from typing import List +from .extra.libdevice import * + +T = core.TypeVar('T') + + +def _check_dtype(dtypes: List[str]) -> T: + """ + We're following libdevice's convention to check accepted data types for math functions. + It is not a good practice to support all data types as accelerators/GPUs don't support + many float16 and bfloat16 math operations. + We should let the users know that they are using and invoke explicit cast to convert + the data type to the supported one. + """ + + def wrapper(fn): + + @wraps(fn) + def check(*args, **kwargs): + # concatenate args and kwargs + all_args = list(args) + list(kwargs.values()) + for arg in [a for a in all_args if isinstance(a, core.tensor)]: + if arg.type.scalar.name not in dtypes: + raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}") + return fn(*args, **kwargs) + + return check + + return wrapper + + +def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`. + + :param x: the input values + :type x: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x` and :code:`y`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + :param z: the input values + :type z: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@core.builtin +@_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"]) +@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") +def umulhi(x, y, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential") +@core._tensor_member_fn +def exp(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_exp(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential (base 2)") +@core._tensor_member_fn +def exp2(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_exp2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("natural logarithm") +@core._tensor_member_fn +def log(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_log(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("logarithm (base 2)") +@core._tensor_member_fn +def log2(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_log2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("cosine") +@core._tensor_member_fn +def cos(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_cos(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("sine") +@core._tensor_member_fn +def sin(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_sin(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("fast square root") +@core._tensor_member_fn +def sqrt(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_1arg_docstr("precise square root (rounding to nearest)") +@core._tensor_member_fn +def sqrt_rn(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_precise_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("inverse square root") +@core._tensor_member_fn +def rsqrt(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_rsqrt(x.handle), x.type) + + +@core.builtin +@_add_math_1arg_docstr("absolute value") +@core._tensor_member_fn +def abs(x, _builder=None): + x = core._to_tensor(x, _builder) + dtype = x.dtype + if dtype.is_fp8e4b15(): + mask = core.full(x.shape, 0x7F, core.int8, _builder=_builder) + return core.tensor(_builder.create_and(x.handle, mask.handle), x.type) + elif dtype.is_floating(): + return core.tensor(_builder.create_fabs(x.handle), x.type) + elif dtype.is_int_signed(): + return core.tensor(_builder.create_iabs(x.handle), x.type) + elif dtype.is_int_unsigned(): + return x # no-op + else: + assert False, f"Unexpected dtype {dtype}" + + +@core.builtin +@_add_math_2arg_docstr("fast division") +def fdiv(x, y, ieee_rounding=False, _builder=None): + ieee_rounding = core._constexpr_to_value(ieee_rounding) + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + return semantic.fdiv(x, y, ieee_rounding, _builder) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_2arg_docstr("precise division (rounding to nearest)") +def div_rn(x, y, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("error function") +@core._tensor_member_fn +def erf(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_erf(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("floor") +@core._tensor_member_fn +def floor(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_floor(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("ceil") +@core._tensor_member_fn +def ceil(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_ceil(x.handle), x.type) + + +@core.builtin +@_add_math_3arg_docstr("fused multiply-add") +def fma(x, y, z, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + z = core._to_tensor(z, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + z, x = core.binary_op_type_legalization(z, x, _builder) + z, y = core.binary_op_type_legalization(z, y, _builder) + return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type) diff --git a/third_party/mthreads/python/triton/language/random.py b/third_party/mthreads/python/triton/language/random.py new file mode 100644 index 000000000..430aeb09e --- /dev/null +++ b/third_party/mthreads/python/triton/language/random.py @@ -0,0 +1,207 @@ +from ..runtime.jit import jit +from . import core as tl +from . import math + +N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox + +# ------------------- +# randint +# ------------------- + + +@jit +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1). + """ + if c0.dtype == tl.uint32: + PHILOX_KEY_A: tl.constexpr = 0x9E3779B9 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE85 + PHILOX_ROUND_A: tl.constexpr = 0xD2511F53 + PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57 + else: + tl.static_assert(c0.dtype == tl.uint64, "dtype not supported in philox_impl") + PHILOX_KEY_A: tl.constexpr = 0x9E3779B97F4A7C15 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE8584CAA73B + PHILOX_ROUND_A: tl.constexpr = 0xD2E7470EE14C6C93 + PHILOX_ROUND_B: tl.constexpr = 0xCA5A826395121157 + + for _ in tl.static_range(n_rounds): + # for _ in range(n_rounds): + # update random state + A = PHILOX_ROUND_A + B = PHILOX_ROUND_B + _c0, _c2 = c0, c2 + c0 = math.umulhi(B, _c2) ^ c1 ^ k0 + c2 = math.umulhi(A, _c0) ^ c3 ^ k1 + c1 = B * _c2 + c3 = A * _c0 + # raise key + k0 = k0 + PHILOX_KEY_A + k1 = k1 + PHILOX_KEY_B + return c0, c1, c2, c3 + + +@jit +def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + seed = tl.to_tensor(seed) + c0 = tl.to_tensor(c0) + c1 = tl.to_tensor(c1) + c2 = tl.to_tensor(c2) + c3 = tl.to_tensor(c3) + seed = seed.to(tl.uint64) + if tl.constexpr(c0.dtype.primitive_bitwidth) == 32: + int_dtype = tl.uint32 + seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) + seed_lo = (seed & 0xffffffff).to(tl.uint32) + else: + tl.static_assert(tl.constexpr(c0.dtype.primitive_bitwidth) == 64, "bitwidth not supported in philox") + int_dtype = tl.uint64 + seed_hi = tl.full((1, ), 0, dtype=int_dtype) + seed_lo = seed + c0 = c0.to(int_dtype, bitcast=True) + c1 = c1.to(int_dtype, bitcast=True) + c2 = c2.to(int_dtype, bitcast=True) + c3 = c3.to(int_dtype, bitcast=True) + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + +@jit +def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns a single + block of random :code:`int32`. + + If you need multiple streams of random numbers, + using `randint4x` is likely to be faster than calling `randint` 4 times. + + :param seed: The seed for generating random numbers. + :param offset: The offsets to generate random numbers for. + """ + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret + + +@jit +def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns four + blocks of random :code:`int32`. + + This is the maximally efficient entry point + to Triton's Philox pseudo-random number generator. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + # _0 = tl.zeros(offset.shape, offset.dtype) + _0 = offset * 0 + return philox(seed, offset, _0, _0, _0, n_rounds) + + +# ------------------- +# rand +# ------------------- + +# @jit +# def uint32_to_uniform_float(x): +# """ +# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). +# """ +# two_to_the_minus_32: tl.constexpr = 2.328306e-10 +# return x * two_to_the_minus_32 + + +@jit +def uint_to_uniform_float(x): + """ + Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1). + """ + # TODO: fix frontend issues and cleanup + # conditions can be simplified + # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1) + if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32): + # maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + x = x.to(tl.int32, bitcast=True) + scale = 4.6566127342e-10 + else: + tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64)) + x = x.to(tl.int64, bitcast=True) + scale = 1.0842020432385337e-19 + x = tl.where(x < 0, -x - 1, x) + return x * scale + + +@jit +def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + source = randint(seed, offset, n_rounds) + return uint_to_uniform_float(source) + + +@jit +def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offsets` block, + returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + u3 = uint_to_uniform_float(i3) + u4 = uint_to_uniform_float(i4) + return u1, u2, u3, u4 + + +# ------------------- +# randn +# ------------------- + + +@jit +def pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = tl.maximum(1.0e-7, u1) + th = 6.283185307179586 * u2 + r = math.sqrt(-2.0 * math.log(u1)) + return r * math.cos(th), r * math.sin(th) + + +@jit +def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, _, _ = randint4x(seed, offset, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + n1, _ = pair_uniform_to_normal(u1, u2) + return n1 + + +@jit +def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + u1, u2, u3, u4 = rand4x(seed, offset, n_rounds) + n1, n2 = pair_uniform_to_normal(u1, u2) + n3, n4 = pair_uniform_to_normal(u3, u4) + return n1, n2, n3, n4 diff --git a/third_party/mthreads/python/triton/language/semantic.py b/third_party/mthreads/python/triton/language/semantic.py new file mode 100644 index 000000000..d0e60afd1 --- /dev/null +++ b/third_party/mthreads/python/triton/language/semantic.py @@ -0,0 +1,1622 @@ +from __future__ import annotations # remove after python 3.11 + +from typing import List, Optional, Sequence, Tuple, TypeVar + +from .._C.libtriton import ir +from . import core as tl +from . import math + +T = TypeVar('T') + + +class IncompatibleTypeErrorImpl(Exception): + + def __init__(self, type_a, type_b): + self.type_a = type_a + self.type_b = type_b + self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() + super(IncompatibleTypeErrorImpl, self).__init__(self.message) + + +# ===----------------------------------------------------------------------===## +# Programming Model +# ===----------------------------------------------------------------------===## + + +def program_id(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}") + return tl.tensor(builder.create_get_program_id(axis), tl.int32) + + +def num_programs(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}") + return tl.tensor(builder.create_get_num_programs(axis), tl.int32) + + +# ===----------------------------------------------------------------------===// +# Implicit Casting Utilities +# ===----------------------------------------------------------------------===// + + +def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: + a_rank = a_ty.int_bitwidth + b_rank = b_ty.int_bitwidth + a_sn = a_ty.int_signedness + b_sn = b_ty.int_signedness + # Rules for signedness taken from "Usual arithmetic conversions" on + # https://en.cppreference.com/w/c/language/conversion. + if a_sn == b_sn: + return a_ty if a_rank > b_rank else b_ty + elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return a_ty if a_rank >= b_rank else b_ty + elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return b_ty if b_rank >= a_rank else a_ty + raise TypeError(f"unexpected signedness {a_sn} and {b_sn}") + + +def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype: + # 1) if one operand is double, the other is implicitly + # converted to double + if a_ty.is_fp64() or b_ty.is_fp64(): + return tl.float64 + # 2) if one operand is float, the other is implicitly + # converted to float + if a_ty.is_fp32() or b_ty.is_fp32(): + return tl.float32 + # 3 ) if one operand is half, the other is implicitly converted to half + # unless we're doing / or %, which do not exist natively in PTX for fp16. + # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp + if a_ty.is_fp16() or b_ty.is_fp16(): + if div_or_mod: + return tl.float32 + else: + return tl.float16 + # 4) return bf16 only if both operands are of bf16 + if a_ty.is_bf16() or b_ty.is_bf16(): + if div_or_mod: + return tl.float32 + if a_ty.is_bf16() and b_ty.is_bf16(): + return tl.bfloat16 + return tl.float32 + if not a_ty.is_int() or not b_ty.is_int(): + raise TypeError(f"unexpected type {a_ty} and {b_ty}") + # 5 ) both operands are integer and undergo + # integer promotion + if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: + raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + + " because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + return integer_promote_impl(a_ty, b_ty) + + +# ===----------------------------------------------------------------------===// +# Binary Operators +# ===----------------------------------------------------------------------===// + + +def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: + if type_a.is_ptr(): + if not allow_ptr_a: + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + U* with T != U + if type_b.is_ptr() and (type_a != type_b): + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + float + if type_b.is_floating(): + raise IncompatibleTypeErrorImpl(type_a, type_b) + + +def binary_op_type_checking_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, allow_lhs_ptr=False, + allow_rhs_ptr=False, arithmetic_check=True, + div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]: + # implicit broadcasting + lhs, rhs = broadcast_impl_value(lhs, rhs, builder) + # implicit typecasting + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) + check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) + if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): + ret_sca_ty = computation_type_impl(lhs_sca_ty, rhs_sca_ty, div_or_mod) + lhs = cast(lhs, ret_sca_ty, builder) + rhs = cast(rhs, ret_sca_ty, builder) + return lhs, rhs + + +def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr(): + raise TypeError("cannot add pointers together") + + # offset + ptr + # ptr + offset + if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): + input, other = other, input + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type) + # float + float + elif input_scalar_ty.is_floating(): + return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) + # int + int + elif input_scalar_ty.is_int(): + return tl.tensor(builder.create_add(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, False) + scalar_ty = input.type.scalar + # ptr - offset + if scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type) + # float - float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) + # int - int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def mul(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float * float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type) + # * int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def truediv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float / int + if input_scalar_ty.is_floating() and other_scalar_ty.is_int(): + other = cast(other, input_scalar_ty, builder) + # int / float + elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): + input = cast(input, other_scalar_ty, builder) + # int / int (cast to tl.float32) + elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): + input = cast(input, tl.float32, builder) + other = cast(other, tl.float32, builder) + # float / float (cast to the highest exponent type) + elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): + if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: + other = cast(other, input_scalar_ty, builder) + else: + input = cast(input, other_scalar_ty, builder) + # unreachable + else: + raise TypeError(f"unexpected type {input_scalar_ty}") + return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) + + +def floordiv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): + ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) + input = cast(input, ret_ty, builder) + other = cast(other, ret_ty, builder) + if ret_ty.is_int_signed(): + return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def fdiv(input: tl.tensor, other: tl.tensor, ieee_rounding: bool, builder: ir.builder) -> tl.tensor: + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): + raise TypeError("both operands of fdiv must have floating scalar type") + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True) + ret = builder.create_fdiv(input.handle, other.handle) + return tl.tensor(ret, input.type) + + +def mod(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float % float + if scalar_ty.is_floating(): + # input - input.div(other, rounding_mode="floor") * other + ret = sub(input, mul(math.floor(fdiv(input, other, False, builder), _builder=builder), other, builder), builder) + return ret + # % int + elif scalar_ty.is_int(): + if scalar_ty.int_signedness != other_scalar_ty.int_signedness: + raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " + "because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +############## +# other arithmetic ops +############## + + +def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_minimumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_minnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_minsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_minui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + +def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_maximumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_maxnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_maxsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_maxui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + +def clamp(x: tl.tensor, min: tl.tensor, max: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + min, max = binary_op_type_checking_impl(min, max, builder) + x, min = binary_op_type_checking_impl(x, min, builder) + x, max = binary_op_type_checking_impl(x, max, builder) + + dtype = x.dtype + if dtype.is_floating(): + return tl.tensor(builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}. Only floating point clamp is supported") + + +############## +# bitwise ops +############## + + +def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False) + input_sca_ty = input.type.scalar + other_sca_ty = other.type.scalar + if not input_sca_ty.is_int() or not other_sca_ty.is_int(): + raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty) + ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty) + if ret_sca_ty != input_sca_ty: + input = cast(input, ret_sca_ty, builder) + if ret_sca_ty != other_sca_ty: + other = cast(other, ret_sca_ty, builder) + return input, other + + +def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_and(input.handle, other.handle), input.type) + + +def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_or(input.handle, other.handle), input.type) + + +def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) + + +def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return and_(input, other, builder) + + +def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return or_(input, other, builder) + + +def not_(input: tl.tensor, builder: ir.builder): + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + return invert(input, builder) + + +def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type) + + +def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type) + + +def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_shl(input.handle, other.handle), input.type) + + +# ===----------------------------------------------------------------------===// +# Unary Operators +# ===----------------------------------------------------------------------===// + + +def plus(input: tl.tensor) -> tl.tensor: + return input + + +def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") + _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return sub(_0, input, builder) + + +def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): + raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") + _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return xor_(input, _1, builder) + + +# ===----------------------------------------------------------------------===// +# Comparison Operators +# ===----------------------------------------------------------------------===// +def _bool_like(v: tl.tensor) -> tl.block_type: + if not v.type.is_block(): + return tl.int1 + shape = v.type.shape + return tl.block_type(tl.int1, shape) + + +def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float > float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input)) + # > int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float >= float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input)) + # >= int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +# ===----------------------------------------------------------------------===// +# Block Creation +# ===----------------------------------------------------------------------===// + + +def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("arange's arguments must be of type tl.constexpr") + is_start_int64 = bool(start >> 32) + is_end_int64 = bool(end >> 32) + if is_start_int64 or is_end_int64: + raise ValueError("arange must fit in int32") + if end <= start: + raise ValueError("arange's end argument must be greater than the start argument") + range = end - start + if (range & (range - 1)) != 0: + raise ValueError("arange's range must be a power of 2") + shape = [range] + ret_ty = tl.block_type(tl.int32, shape) + return tl.tensor(builder.create_make_range(start, end), ret_ty) + + +def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + if isinstance(value, tl.tensor): + assert value.numel.value == 1, "only accepts size-1 tensor" + value = cast(value, dtype, builder) + else: + # scalar + if dtype is None: + raise ValueError("dtype must be specified when value is not a tensor") + if value == 0: + value = builder.get_null_value(dtype.to_ir(builder)) + else: + get_value_fn = getattr(builder, f"get_{dtype.name}") + value = get_value_fn(value) + value = tl.tensor(value, dtype) + + return splat(value, shape, builder) + + +# ===----------------------------------------------------------------------===// +# Shape Manipulation +# ===----------------------------------------------------------------------===// + + +def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + assert not value.type.is_block(), "Cannot splat a block tensor" + if len(shape) == 0: + return value + ret_ty = tl.block_type(value.dtype, shape) + return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) + + +def reshape(input: tl.tensor, dst_shape: List[int], can_reorder: bool, builder: ir.builder) -> tl.tensor: + numel = 1 + for s in dst_shape: + numel *= s + if input.type.numel != numel: + raise ValueError("reshape() cannot change total number of elements in tensor") + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_reshape(input.handle, dst_shape, can_reorder), ret_ty) + + +def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + dst_shape = [tl._constexpr_to_value(x) for x in input.shape] + dst_shape.insert(axis, 1) + + if not input.type.is_block(): + return splat(input, shape=dst_shape, builder=builder) + + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty) + + +def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor: + assert can_reorder, "current implementation of `cat` always may reorder elements" + assert len(lhs.shape) == 1 + ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) + return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type) + + +def join(a: tl.tensor, b: tl.tensor, builder: ir.builder) -> tl.tensor: + a, b = broadcast_impl_value(a, b, builder) + + # The IR can't handle joining two scalars, so upcast them to 1D tensors, + # then downcast the result. + was_rank_1 = a.shape == [] + if was_rank_1: + a = expand_dims(a, 0, builder) + b = expand_dims(b, 0, builder) + + if isinstance(a.shape[-1], tl.constexpr): + two = tl.constexpr(2) + else: + two = 2 + new_shape = a.shape + [two] + + ret_type = tl.block_type(a.type.scalar, new_shape) + ret = tl.tensor(builder.create_join(a.handle, b.handle), ret_type) + + if was_rank_1: + ret = reshape(ret, [2], can_reorder=False, builder=builder) + + return ret + + +def split(a: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + assert (len(a.shape) > 0) + assert (tl._constexpr_to_value(a.shape[-1]) == 2) + + new_shape = a.shape[:-1] + ret_type = tl.block_type(a.type.scalar, new_shape) + outLHS, outRHS = builder.create_split(a.handle) + return ( + tl.tensor(outLHS, ret_type), + tl.tensor(outRHS, ret_type), + ) + + +def permute(input: tl.tensor, dims: Tuple[int], builder: ir.builder) -> tl.tensor: + if len(input.shape) != len(dims): + raise ValueError("permute dims must have the same length as input shape") + if sorted(tl._constexpr_to_value(d) for d in dims) != list(range(len(dims))): + raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}") + + ret_type = tl.block_type(input.type.scalar, [input.shape[d] for d in dims]) + return tl.tensor(builder.create_trans(input.handle, dims), ret_type) + + +def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + if not input.type.is_block(): + ret_ty = tl.block_type(input.type, shape) + return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) + src_shape = input.type.get_block_shapes() + if len(src_shape) != len(shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + for i, item in enumerate(src_shape): + if shape[i] != item and item != 1: + raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" + f" must match the existing size ({item}) at non-singleton dimension" + f" {i}: {src_shape}, {shape}") + ret_ty = tl.block_type(input.type.scalar, shape) + return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) + + +def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: + lhs_ty = lhs.type + rhs_ty = rhs.type + + # make_shape_compatible(block, scalar) + if lhs_ty.is_block() and not rhs_ty.is_block(): + rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape) + rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty) + # make_shape_compatible(scalar, block) + elif not lhs_ty.is_block() and rhs_ty.is_block(): + lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape) + lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty) + # make_shape_compatible(block, block) + elif lhs_ty.is_block() and rhs_ty.is_block(): + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + + if len(lhs_shape) < len(rhs_shape): + # Add new axes to lhs + for _ in range(len(lhs_shape), len(rhs_shape)): + lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), + tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) + lhs_ty = lhs.type + lhs_shape = lhs_ty.get_block_shapes() + elif len(rhs_shape) < len(lhs_shape): + # Add new axes to rhs + for _ in range(len(rhs_shape), len(lhs_shape)): + rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), + tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) + rhs_ty = rhs.type + rhs_shape = rhs_ty.get_block_shapes() + assert len(rhs_shape) == len(lhs_shape) + + ret_shape = [] + for i, left in enumerate(lhs_shape): + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif (right == 1) or (right == left): + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + if lhs_shape != ret_shape: + ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) + lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty) + if rhs_shape != ret_shape: + ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) + rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty) + # (scalar, scalar) => returns original blocks + return lhs, rhs + + +####### +# cast +####### + + +def _str_to_rounding_mode(rounding_mode: Optional[str]): + if rounding_mode is None: + return None + if rounding_mode == 'rtne': + return ir.ROUNDING_MODE.RTNE + if rounding_mode == 'rtz': + return ir.ROUNDING_MODE.RTZ + raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.") + + +def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + return cast(input, dst_ty, builder) + # Bitcast + src_bits = src_sca_ty.primitive_bitwidth + dst_bits = dst_sca_ty.primitive_bitwidth + if src_bits != dst_bits: + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " + "data-type of size " + str(dst_bits)) + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + +def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, + fp_downcast_rounding: Optional[str] = None) -> tl.tensor: + src_ty = input.type + if isinstance(dst_ty, tl.constexpr): + dst_ty = dst_ty.value + if isinstance(fp_downcast_rounding, tl.constexpr): + fp_downcast_rounding = fp_downcast_rounding.value + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + + # For fp downcasting default rounding mode should be RTNE, for all other conversions it should + # not be set + fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) + use_custom_rounding = False + if dst_sca_ty.is_floating() and src_sca_ty.is_floating( + ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: + if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE + elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True + else: + if fp_downcast_rounding is not None: + raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + + if (src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()): + assert builder.options.allow_fp8e4nv, "fp8e4nv data type is not supported on CUDA arch < 89" + + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): + assert builder.codegen_fns.get( + "convert_custom_types") is not None, "target doesn't provide conversion for this type." + return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + # and non-default rounding modes for downcasting + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ + use_custom_rounding: + return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) + + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + else: + return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) + + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif dst_sca_ty.is_int_signed(): + return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) + if bitwidth == 1: + return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder) + + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to pointer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + assert False, f'cannot cast {input} to {dst_ty}' + + +# ===----------------------------------------------------------------------===// +# Memory Operators +# ===----------------------------------------------------------------------===// + + +def _str_to_load_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".ca": + cache = ir.CACHE_MODIFIER.CA + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def _str_to_store_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".wb": + cache = ir.CACHE_MODIFIER.WB + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cs": + cache = ir.CACHE_MODIFIER.CS + elif cache_modifier == ".wt": + cache = ir.CACHE_MODIFIER.WT + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def _str_to_eviction_policy(eviction_policy): + eviction = ir.EVICTION_POLICY.NORMAL # default + if eviction_policy: + if eviction_policy == "evict_last": + eviction = ir.EVICTION_POLICY.EVICT_LAST + elif eviction_policy == "evict_first": + eviction = ir.EVICTION_POLICY.EVICT_FIRST + else: + raise ValueError(f"Eviction policy {eviction_policy} not supported") + return eviction + + +def _str_to_padding_option(padding_option): + padding = None # default + if padding_option: + if padding_option == "zero": + padding = ir.PADDING_OPTION.PAD_ZERO + elif padding_option == "nan": + padding = ir.PADDING_OPTION.PAD_NAN + else: + raise ValueError(f"Padding option {padding_option} not supported") + return padding + + +def _str_to_sem(sem_option): + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + if sem_option: + if sem_option == "acquire": + sem = ir.MEM_SEMANTIC.ACQUIRE + elif sem_option == "release": + sem = ir.MEM_SEMANTIC.RELEASE + elif sem_option == "acq_rel": + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + elif sem_option == "relaxed": + sem = ir.MEM_SEMANTIC.RELAXED + else: + raise ValueError(f"Memory semantic {sem_option} not supported") + return sem + + +def _str_to_scope(scope_option): + scope = ir.MEM_SYNC_SCOPE.GPU + if scope_option: + if scope_option == "gpu": + scope = ir.MEM_SYNC_SCOPE.GPU + elif scope_option == "cta": + scope = ir.MEM_SYNC_SCOPE.CTA + elif scope_option == "sys": + scope = ir.MEM_SYNC_SCOPE.SYSTEM + else: + raise ValueError(f"Memory semantic {scope_option} not supported") + return scope + + +def _canonicalize_boundary_check(boundary_check, block_shape): + if boundary_check: + if not hasattr(boundary_check, "__iter__"): + boundary_check = [boundary_check] + boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check] + for dim in boundary_check: + assert isinstance(dim, int) and 0 <= dim < len(block_shape) + assert len(boundary_check) > 0 + assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`" + return sorted(boundary_check) + return () + + +def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a block pointer: `pointer_type>` + # Block pointer can not have `mask` and `other` arguments + if mask is not None or other is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" + if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN: + raise ValueError("Padding option `nan` is not supported for integer block pointers") + + # `dst_ty` is de-referenced type of the pointer type + dst_ty = ptr.type.element_ty + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) + + # Build IR + return tl.tensor( + builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) + + +def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") + + # Check `mask`, `other`, `boundary_check`, and `padding` arguments + if mask is None and other is not None: + raise ValueError("`other` cannot be provided without `mask`") + if padding or boundary_check: + raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of" + "pointers or loading a scalar. Because the compiler does not know the boundary; please " + "use block pointers (defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `mask` and `other` + if not ptr.type.is_block(): + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + if other and other.type.is_block(): + raise ValueError("Other argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `other` into the same shape as `ptr` + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if other is not None: + other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder) + + # Get `pointer_type` and `elt_ty` + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast `other` into `ele_ty` type + if other is not None: + other = cast(other, elt_ty, builder) + + # Create loaded result type `dst_ty` + if ptr.type.is_block(): + shape = ptr.type.get_block_shapes() + dst_ty = tl.block_type(elt_ty, shape) + else: + # Load by de-referencing the pointer of scalar + dst_ty = elt_ty + + # Build IR + if mask is None: + return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) + else: + return tl.tensor( + builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, + is_volatile), dst_ty) + + +def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple, + padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool, + builder: ir.builder) -> tl.tensor: + # Cache, eviction and padding options + cache = _str_to_load_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + padding = _str_to_padding_option(padding_option) + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Load by a block pointer: `pointer_type>` + return _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + else: + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + + +def descriptor_load(desc_ptr: tl.tensor, offsets, cache_modifier: str, eviction_policy: str, type, + builder: ir.builder) -> tl.tensor: + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + x = builder.create_descriptor_load(desc_ptr.handle, offsets, type.to_ir(builder), + _str_to_load_cache_modifier(cache_modifier), + _str_to_eviction_policy(eviction_policy)) + return tl.tensor(x, type) + + +def descriptor_store(desc_ptr: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + return tl.tensor(builder.create_descriptor_store(desc_ptr.handle, value.handle, offsets), tl.void) + + +def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a block pointer: `pointer_type>` + # Block pointers can not have the `mask` argument + if mask is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + # Check same shape and element type + block_shape = ptr.type.element_ty.get_block_shapes() + if not val.type.is_block(): + val = broadcast_impl_shape(val, block_shape, builder) + assert val.type.is_block(), "Value argument must be block type or a scalar" + assert block_shape == val.type.get_block_shapes( + ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" + assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch" + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, block_shape) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + return tl.tensor(builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction), + tl.void) + + +def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`") + + # Check `boundary_check` argument + if boundary_check: + raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a " + "scalar. Because the compiler does not know the boundary; please use block pointers " + "(defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `val` and `mask` + if not ptr.type.is_block(): + if val.type.is_block(): + raise ValueError("Value argument cannot be block type if pointer argument is not a block") + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `val` into the same shape as `ptr` + if ptr.type.is_block(): + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + if not mask: + return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void) + if not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") + return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void) + + +def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str, + eviction_policy: str, builder: ir.builder) -> tl.tensor: + # Cache and eviction options + cache = _str_to_store_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + + if ptr.type.is_const() or ptr.type.scalar.is_const(): + raise ValueError("Cannot store to a constant pointer") + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Store by a block pointer: `pointer_type>` + return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder) + else: + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder) + + +######### +# atomic +######### + + +def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + element_ty = ptr.type.scalar.element_ty + if element_ty.primitive_bitwidth not in [16, 32, 64]: + raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) + + +def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_const() or ptr.type.element_ty.is_const(): + raise ValueError("Cannot store to a constant pointer") + element_ty = ptr.type.scalar.element_ty + if element_ty is tl.float16 and op != 'add': + raise ValueError("atomic_" + op + " does not support fp16") + if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]: + raise ValueError("atomic_" + op + " does not support " + str(element_ty)) + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if val is not None: + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + val = cast(val, ptr.type.scalar.element_ty, builder) + if not mask: + mask_ir = builder.get_int1(True) + mask_ty = tl.int1 + if ptr.type.is_block(): + mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes()) + mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) + mask = tl.tensor(mask_ir, mask_ty) + return ptr, val, mask + + +def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_max for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + # for float + # return atomic_smax(i_ptr, i_val) if val >= 0 + # return atomic_umin(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_max not supported for dtype {sca_ty}") + + zero = full([], 0.0, sca_ty, builder) + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = bitcast(val, i_type, builder) + i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = bitcast(val, ui_type, builder) + ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle, + and_(mask, neg, builder).handle, sem, scope), ui_val.type) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) + + +def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_min for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + # for float + # return atomic_smin(i_ptr, i_val) if val >= 0 + # return atomic_umax(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_min not supported for dtype {sca_ty}") + + zero = full([], 0.0, sca_ty, builder) + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = bitcast(val, i_type, builder) + i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = bitcast(val, ui_type, builder) + ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle, + and_(mask, neg, builder).handle, sem, scope), ui_ptr.type) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) + + +def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD + return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + +def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +# ===----------------------------------------------------------------------===// +# Linear Algebra +# ===----------------------------------------------------------------------===// + + +def _str_to_dot_input_precision(input_precision, builder): + assert input_precision.lower() in builder.options.allowed_dot_input_precisions, \ + f"input_precision must be one of {builder.options.allowed_dot_input_precisions}. Got {input_precision}" + input_precision = input_precision.upper() + if input_precision == "TF32X3": + input_precision = "TF32x3" + return getattr(ir.INPUT_PRECISION, input_precision) + + +def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int, + out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + + def assert_dtypes_valid(lhs_dtype, rhs_dtype, options): + # TODO(lingfeng.qiu): Disable dot with fp8 for backend musa. + if (not options.allow_fp8e4nv) or (options.backend_name == "musa"): + assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv( + ), "Dot op does not support fp8e4nv on CUDA arch < 90" + if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): + return + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + else: + if lhs_dtype.is_int() or rhs_dtype.is_int(): + assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})" + assert lhs_dtype.is_int8() or lhs_dtype.is_uint8( + ), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" + elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8(): + if options.allow_fp8e4b15: + allowed_types = ['fp8e4nv', 'fp8e5', 'fp8e4b15'] + else: + allowed_types = ['fp8e4nv', 'fp8e5'] + + def _validate_dtype(dtype, allowed_types, operand_name): + if not any(getattr(dtype, f'is_{dtype_name}')() for dtype_name in allowed_types): + supported_types = ', '.join(allowed_types) + raise AssertionError(f"Only supports {supported_types}. {operand_name} ({dtype})") + + _validate_dtype(lhs_dtype, allowed_types, "First operand") + _validate_dtype(rhs_dtype, allowed_types, "Second operand") + else: + assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1( + ), f"Unsupported dtype {lhs_dtype}" + assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1( + ), f"Unsupported dtype {rhs_dtype}" + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + + assert lhs.type.is_block() and rhs.type.is_block() + assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options) + if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): + lhs = cast(lhs, tl.float16, builder) + rhs = cast(rhs, tl.float16, builder) + + if input_precision is None: + input_precision = builder.options.default_dot_input_precision + + input_precision = _str_to_dot_input_precision(input_precision, builder) + + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + assert lhs.shape[-1].value == rhs.shape[ + -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})" + assert lhs.shape[-2].value >= 16 and lhs.shape[-1].value >= 16 \ + and rhs.shape[-1].value >= 16, \ + f"All non-batch values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!" + if lhs.type.scalar.is_int(): + assert lhs.type.scalar == tl.int8, "only int8 supported!" + # TODO: This is CUDA specific, check if ROCm has the same limitation + assert lhs.shape[1].value >= 32, "small blocks not supported!" + _0 = builder.get_int32(0) + ret_scalar_ty = tl.int32 + elif out_dtype.is_bf16(): + raise ValueError( + "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") + elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): + _0 = builder.get_fp32(0) + ret_scalar_ty = tl.float32 + else: + _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0) + ret_scalar_ty = out_dtype + + M = lhs.type.shape[-2] + N = rhs.type.shape[-1] + B = lhs.type.shape[0] if lhs_rank == 3 else None + ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N]) + if acc is None: + acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + + # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 + if max_num_imprecise_acc is None: + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default + else: + max_num_imprecise_acc = 0 + + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), + ret_ty) + + +# ===----------------------------------------------------------------------===// +# Indexing +# ===----------------------------------------------------------------------===// + + +def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: + condition = cast(condition, tl.int1, builder) + if condition.type.is_block(): + condition, x = broadcast_impl_value(condition, x, builder) + x, y = broadcast_impl_value(x, y, builder) + condition, x = broadcast_impl_value(condition, x, builder) + + x, y = binary_op_type_checking_impl(x, y, builder, True, True) + if not condition.type.is_block(): + condition, _ = broadcast_impl_value(condition, x, builder) + ret_ty = x.type + return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) + + +# ===----------------------------------------------------------------------===// +# Reduction +# ===----------------------------------------------------------------------=== + + +def wrap_tensor(x, scalar_ty, ret_shape): + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = scalar_ty + return tl.tensor(x, res_ty) + + +def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]: + if axis is None: + inputs = tuple(reshape(t, [t.numel.value], can_reorder=True, builder=builder) for t in inputs) + axis = 0 + # get result shape + shape = inputs[0].type.shape + rank = len(shape) + assert axis < rank, f"reduction axis must be < inputs rank ({rank})" + ret_shape = [s for i, s in enumerate(shape) if i != axis] + assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape" + + reduce_op = builder.create_reduce([t.handle for t in inputs], axis) + region_builder_fn(reduce_op) + reduce_op.verify() + + return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs))) + + +# ===----------------------------------------------------------------------=== +# Associative Scan +# ===----------------------------------------------------------------------=== + + +def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, reverse: bool, + builder: ir.builder) -> Tuple[tl.tensor, ...]: + shape = inputs[0].type.shape + rank = len(shape) + + assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})" + + if axis < 0: + axis += rank + + for t in inputs: + assert t.type.shape == shape, "all scan inputs must have the same shape" + + scan_op = builder.create_scan([t.handle for t in inputs], axis, reverse) + region_builder_fn(scan_op) + scan_op.verify() + + return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs))) + + +# ===----------------------------------------------------------------------=== +# Histogram +# ===----------------------------------------------------------------------=== + + +def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor: + assert len(input.shape) == 1, "histogram only supports 1D input" + assert input.dtype.is_int(), "histogram only supports integer input" + return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, (num_bins, ))) + + +## + + +def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: + if max(1, len(x.shape)) != len(values): + raise ValueError("Shape of input to multiple_of does not match the length of values") + x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) + return x + + +def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_contiguous does not match the length of values") + x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context())) + return x + + +def max_constancy(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_constancy does not match the length of values") + x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context())) + return x + + +def debug_barrier(builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_barrier(), tl.void) + + +def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor: + # It makes sense visually for prefix to end in ": "; make it so. Also, + # non-empty prefixes should start with " ". + if not prefix.endswith(" ") and args: + prefix += " " + if not prefix.endswith(": ") and args: + prefix = prefix[:-1] + ": " + if len(prefix) > 2 and not prefix.startswith(" "): + prefix = " " + prefix + + new_args = [arg.handle for arg in args] + return tl.tensor(builder.create_print(prefix, hex, new_args), tl.void) + + +def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor: + cond_ty = cond.type + if not cond_ty.is_block(): + cond_ty = tl.block_type(cond_ty.scalar, (1, )) + cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty) + return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void) + + +def _convert_elem_to_ir_value(builder, elem, require_i64): + if isinstance(elem, int): + elem = tl.constexpr(elem) + if isinstance(elem, tl.constexpr): + if require_i64: + assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int64(elem.value) + else: + assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int32(elem.value) + elif isinstance(elem, tl.tensor): + assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets" + assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets" + if elem.dtype != tl.int64 and require_i64: + return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed()) + elif elem.dtype != tl.int32 and not require_i64: + assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \ + "add a `.to(tl.int32)` or use regular indexing for 64 bit support" + return elem.handle + assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}" + + +def _convert_to_ir_values(builder, list_like, require_i64=True): + if hasattr(list_like, "__iter__"): + return [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in list_like] + return [_convert_elem_to_ir_value(builder, list_like, require_i64)] + + +def make_block_ptr(base: tl.tensor, shape, strides, offsets, block_shape, order, builder: ir.builder) -> tl.tensor: + # Convert dynamic arguments to IR values + # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t` + shape = _convert_to_ir_values(builder, shape) + strides = _convert_to_ir_values(builder, strides) + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Check `base` type + if not base.type.is_ptr() or base.type.element_ty.is_block(): + raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)") + + # Treat `pointer_type` as `pointer_type` + if base.type.element_ty == tl.int1: + base = cast(base, tl.pointer_type(tl.int8, base.type.address_space), builder) + + # Check whether `block_shape` is static + if not hasattr(block_shape, "__iter__"): + block_shape = [block_shape] + block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape] + assert all(isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape), \ + "Expected a list of constant integers (`int32_t` range) in `block_shape`" + + # Check `order` + if not hasattr(order, "__iter__"): + order = [order] + order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order] + assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order" + + # Must have same length + assert all(len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]), \ + "Expected shape/strides/offsets/block_shape to have the same length" + + # Build value, the type is: + # `pointer_type>` in Python + # `tt.ptr>` in MLIR + handle = builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order) + return tl.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape))) + + +def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + # Convert dynamic offsets to IR values + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Advanced block pointer type is the same as before + return tl.tensor(builder.create_advance(base.handle, offsets), base.type) diff --git a/third_party/mthreads/python/triton/language/standard.py b/third_party/mthreads/python/triton/language/standard.py new file mode 100644 index 000000000..de30cf260 --- /dev/null +++ b/third_party/mthreads/python/triton/language/standard.py @@ -0,0 +1,441 @@ +from __future__ import annotations + +from ..runtime.jit import jit +from . import core +from . import math + +# constexpr utilities (triton metaprogramming sucks) + + +def _unwrap_if_constexpr(o): + return o.value if isinstance(o, core.constexpr) else o + + +def _log2(i: core.constexpr): + log2 = 0 + n = i.value + while n > 1: + n >>= 1 + log2 += 1 + return core.constexpr(log2) + + +def _is_power_of_two(i: core.constexpr): + n = i.value + return core.constexpr((n & (n - 1)) == 0 and n != 0) + + +# ----------------------- +# Standard library +# ----------------------- + + +@core._tensor_member_fn +@jit +def cdiv(x, div): + """ + Computes the ceiling division of :code:`x` by :code:`div` + + :param x: the input number + :type x: Block + :param div: the divisor + :param div: Block + """ + return (x + div - 1) // div + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("sigmoid") +def sigmoid(x): + return 1 / (1 + math.exp(-x)) + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("softmax") +def softmax(x, ieee_rounding=False): + z = x - max(x, 0) + num = math.exp(z) + den = sum(num, 0) + return math.fdiv(num, den, ieee_rounding) + + +@core._tensor_member_fn +@jit +def ravel(x): + """ + Returns a contiguous flattened view of :code:`x`. + + :param x: the input tensor + :type x: Block + """ + return core.reshape(x, [x.numel], can_reorder=True) + + +@jit +def swizzle2d(i, j, size_i, size_j, size_g): + """ + Transforms indices of a row-major :code:`size_i * size_j` matrix into those + of one where the indices are col-major for each group of :code:`size_g` + rows. + + For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will + transform :: + + [[0 , 1 , 2 , 3 ], + [4 , 5 , 6 , 7 ], + [8 , 9 , 10, 11], + [12, 13, 14, 15]] + + into :: + + [[0, 2, 4 , 6 ], + [1, 3, 5 , 7 ], + [8, 10, 12, 14], + [9, 11, 13, 15]] + """ + # "unrolled index in array" + ij = i * size_j + j + # number of elements in `size_g` groups + # of `size_j` columns + size_gj = size_g * size_j + # index of the group in which (i,j) is + group_id = ij // size_gj + # row-index of the first element of this group + off_i = group_id * size_g + # last group may have fewer rows + size_g = core.minimum(size_i - off_i, size_g) + # new row and column indices + new_i = off_i + (ij % size_g) + new_j = (ij % size_gj) // size_g + return new_i, new_j + + +@jit +def zeros(shape, dtype): + """ + Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + return core.full(shape, 0, dtype) + + +@jit +def zeros_like(input): + """ + Creates a tensor of zeros with the same shape and type as a given tensor. + """ + return zeros(input.shape, input.dtype) + + +# max and argmax + + +@jit +def _argmax_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + gt = value1 > value2 or tie + v_ret = core.where(gt, value1, value2) + i_ret = core.where(gt, index1, index2) + return v_ret, i_ret + + +@jit +def _argmax_combine_tie_break_left(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, True) + + +@jit +def _argmax_combine_tie_break_fast(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_max(a, b): + return core.maximum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32): + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left") +def argmax(input, axis, tie_break_left=True, keep_dims=False): + (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +# min and argmin + + +@jit +def _argmin_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + lt = value1 < value2 or tie + value_ret = core.where(lt, value1, value2) + index_ret = core.where(lt, index1, index2) + return value_ret, index_ret + + +@jit +def _argmin_combine_tie_break_left(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, True) + + +@jit +def _argmin_combine_tie_break_fast(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_min(a, b): + return core.minimum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < 32: + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left") +def argmin(input, axis, tie_break_left=True, keep_dims=False): + _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +@jit +def _sum_combine(a, b): + return a + b + + +# sum + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("sum") +def sum(input, axis=None, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims) + + +@jit +def _xor_combine(a, b): + return a ^ b + + +# xor sum + + +@core._tensor_member_fn +@core.builtin +@core._add_reduction_docstr("xor sum") +def xor_sum(input, axis=None, keep_dims=False, _builder=None, _generator=None): + scalar_ty = input.type.scalar + if not scalar_ty.is_int(): + raise ValueError("xor_sum only supported for integers") + + input = core._promote_bfloat16_to_float32(input, _builder=_builder) + return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims, _builder=_builder, _generator=_generator) + + +# cumsum + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumsum") +def cumsum(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _sum_combine, reverse) + + +# cumprod + + +@jit +def _prod_combine(a, b): + return a * b + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumprod") +def cumprod(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _prod_combine, reverse) + + +# sort + + +@jit +def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr): + n_outer: core.constexpr = x.numel >> n_dims + shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)] + y = core.reshape(x, shape) + # slice left/right with 'stride' 2**(n_dims - i - 1) + mask = core.arange(0, 2)[None, :, None] + left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape) + right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape) + left = core.reshape(left, x.shape) + right = core.reshape(right, x.shape) + # actual compare-and-swap + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + ileft = left.to(idtype, bitcast=True) + iright = right.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + ret = ix ^ core.where((left > right) ^ flip, ileft ^ iright, zeros_like(ix)) + return ret.to(x.dtype, bitcast=True) + + +@jit +def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr): + ''' + order_type 0 == ascending + order_type 1 == descending + order_type 2 == alternating + ''' + n_outer: core.constexpr = x.numel >> n_dims + core.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage] + flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape) + else: + flip = order + # perform `stage` rounds of `compare-and-swap` + for i in core.static_range(stage): + x = _compare_and_swap(x, flip, i + (n_dims - stage), n_dims) + return x + + +@core._tensor_member_fn +@jit +def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0): + # handle default dimension or check that it is the most minor dim + _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim + core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") + # iteratively run bitonic merge-sort steps + n_dims: core.constexpr = _log2(x.shape[_dim]) + for i in core.static_range(1, n_dims + 1): + x = _bitonic_merge(x, i, 2 if i < n_dims else descending, n_dims) + return x + + +# flip + + +def _get_flip_dim(dim, shape): + dim = _unwrap_if_constexpr(dim) + shape = _unwrap_if_constexpr(shape) + if dim is None: + dim = len(shape) - 1 + assert dim == len(shape) - 1, "Currently only support flipping the last dimension" + return core.constexpr(dim) + + +@core._tensor_member_fn +@jit +def flip(x, dim=None): + """ + Flips a tensor `x` along the dimension `dim`. + + :param x: the first input tensor + :type x: Block + :param dim: the dimension to flip along (currently only final dimension supported) + :type dim: int + """ + core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)])) + core.static_assert(_is_power_of_two(x.numel)) + # # reshape the tensor to have all dimensions be 2. + # # TODO: We shouldn't have to change the dimensions not sorted. + steps: core.constexpr = _log2(x.numel) + start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)]) + y = core.reshape(x, [2] * steps) + y = core.expand_dims(y, start) + flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2)) + for i in core.static_range(start, steps): + flip2 = flip + for j in core.static_range(0, steps + 1): + if j != i and j != i + 1: + flip2 = core.expand_dims(flip2, j) + y = sum(y * flip2, i + 1, keep_dims=True) + x = core.reshape(y, x.shape) + return x + + +@jit +def interleave(a, b): + """ + Interleaves the values of two tensors along their last dimension. + + The two tensors must have the same shape. + + Equivalent to `tl.join(a, b).reshape(a.shape[-1:] + [2 * a.shape[-1]])` + """ + c = core.join(a, b) + + assert isinstance(c.shape, list) + if len(c.shape) == 1: + # We must have interleaved two scalars. + return c + else: + # This `else` is necessary because Triton's AST parser doesn't + # understand that if we take the `if` above we definitely don't run this + # `else`. + return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]]) diff --git a/python/triton/ops/__init__.py b/third_party/mthreads/python/triton/ops/__init__.py similarity index 100% rename from python/triton/ops/__init__.py rename to third_party/mthreads/python/triton/ops/__init__.py diff --git a/third_party/mthreads/python/triton/ops/blocksparse/__init__.py b/third_party/mthreads/python/triton/ops/blocksparse/__init__.py new file mode 100644 index 000000000..6b24b5377 --- /dev/null +++ b/third_party/mthreads/python/triton/ops/blocksparse/__init__.py @@ -0,0 +1,7 @@ +from .matmul import matmul +from .softmax import softmax + +__all__ = [ + "matmul", + "softmax", +] diff --git a/third_party/mthreads/python/triton/ops/blocksparse/matmul.py b/third_party/mthreads/python/triton/ops/blocksparse/matmul.py new file mode 100644 index 000000000..098e15438 --- /dev/null +++ b/third_party/mthreads/python/triton/ops/blocksparse/matmul.py @@ -0,0 +1,432 @@ +import torch + +from ... import cdiv, heuristics, jit +from ... import language as tl + +# ******************************************************** +# -------------------------------------------------------- +# Sparse = Dense x Dense (SDD) +# This operation uses super-blocking to make sure that +# it's done efficiently when small blocks can be grouped +# together +# -------------------------------------------------------- +# ******************************************************** + + +@heuristics({ + 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0, +}) +@jit +def _sdd_kernel(A, B, C, # + stride_za, stride_ha, stride_ma, stride_ak, # + stride_zb, stride_hb, stride_bk, stride_nb, # + stride_zc, stride_hc, stride_mc, stride_nc, # + K, grid_offset, lut, # + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, # + BLOCK: tl.constexpr, EVEN_K: tl.constexpr # + ): + # ------------ # + # - Prologue - # + # ------------ # + block_id = tl.program_id(0) + grid_offset + lut += block_id * 3 + # offsets + off_z = tl.program_id(2) # batch + off_h = tl.load(lut + 0) # head + + # initialize pointers to A + start_am = tl.load(lut + 1) + offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) + offs_ak = tl.arange(0, TILE_K) + a_ptrs = A \ + + off_z * stride_za \ + + off_h * stride_ha \ + + offs_am[:, None] * stride_ma \ + + offs_ak[None, :] * stride_ak + # initialize pointers to B + start_bn = tl.load(lut + 2) + offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) + offs_bk = tl.arange(0, TILE_K) + b_ptrs = B \ + + off_z * stride_zb \ + + off_h * stride_hb \ + + offs_bn[None, :] * stride_nb \ + + offs_bk[:, None] * stride_bk + # ---------------- # + # Inner Loop # + # ---------------- # + acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + for k in range(K, 0, -TILE_K): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.) + b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.) + acc += tl.dot(a, b, out_dtype=tl.float32) + a_ptrs += TILE_K * stride_ak + b_ptrs += TILE_K * stride_bk + c = acc.to(C.dtype.element_ty) + # ---------------- # + # Epilogue # + # ---------------- # + offs_cm = tl.arange(0, TILE_M) % BLOCK + offs_cn = tl.arange(0, TILE_N) % BLOCK + pc = C \ + + off_z * stride_zc \ + + block_id * stride_hc \ + + offs_cm[:, None] * stride_mc \ + + offs_cn[None, :] * stride_nc + tl.store(pc, c, mask=True) + + +def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() + # (A * B)^T = B^T * A^T + if trans_c: + a, b = b, a + trans_a, trans_b = not trans_b, not trans_a + # shape constraints + a_dim = -2 if trans_a else -1 + b_dim = -1 if trans_b else -2 + Ka, Kb = a.shape[a_dim], b.shape[b_dim] + if Ka != Kb: + raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})") + # allocate output + if out is None: + c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device) + else: + assert out.shape == (a.shape[0], lut.shape[0], block, block) + c = out + grid = [c.shape[1], 1, c.shape[0]] + _sdd_kernel[grid]( + a, b, c, # + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), # + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), # + c.stride(0), c.stride(1), c.stride(2), c.stride(3), # + Ka, 0, lut, # + TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, # + num_warps=4 # + ) + return c + + +def sdd_lut(layout, block, device): + lut = layout.nonzero(as_tuple=False).to(device).int() + lut = lut.contiguous() + return lut, None + + +# ----------------------------- +# Dense = Sparse x Dense (DSD) +# This operation uses a look-up table that contains pre-computed pointer increments +# in order to minimize computations in the inner loop of the matmul kernel. +# ----------------------------- + + +@jit +def _dsd_kernel(A, B, C, # + stride_az, stride_ha, stride_am, stride_ak, # + stride_zb, stride_hb, stride_bk, stride_bn, # + stride_zc, stride_hc, stride_cm, stride_cn, # + DS0, DS1, lut, # + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr # + ): + # ------------ # + # - Prologue - # + # ------------ # + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) + pidz = tl.program_id(2) + header = lut + pid_n * 4 + offset = tl.load(header + 0) + K = tl.load(header + 1) + column = tl.load(header + 2) + off_h = tl.load(header + 3) + pinc = lut + offset + # initialize pointers to A (sparse) + block_id = tl.load(pinc + 1) + block_id = tl.multiple_of(block_id, 8) # compiler hint + offs_am = tl.arange(0, TILE_M) + offs_ak = tl.arange(0, TILE_K) + pa = A + pidz * stride_az \ + + block_id * stride_ha \ + + offs_am[:, None] * stride_am \ + + offs_ak[None, :] * stride_ak + # initialize pointers to B (dense) + offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N) + start_bk = tl.load(pinc) + start_bk = tl.multiple_of(start_bk, 8) # compiler hint + offs_bk = start_bk + tl.arange(0, TILE_K) + pb = B + pidz * stride_zb \ + + off_h * stride_hb \ + + offs_bn[None, :] * stride_bn \ + + offs_bk[:, None] * stride_bk + # ---------------- # + # Inner Loop # + # ---------------- # + acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + pinc += 2 + inc_a = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.load(pinc) + inc_b = tl.multiple_of(inc_b, 8) + for k in range(K, 0, -TILE_K): + a = tl.load(pa) + b = tl.load(pb) + acc += tl.dot(a, b, out_dtype=tl.float32) + pa += inc_a + pb += inc_b * stride_bk + pinc += 2 + inc_a = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.load(pinc) + inc_b = tl.multiple_of(inc_b, 8) + c = acc.to(C.dtype.element_ty) + # initialize pointers to C + offs_cm = column * TILE_M + tl.arange(0, TILE_M) + offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N) + pc = C \ + + off_h * stride_hc \ + + pidz * stride_zc \ + + offs_cm[:, None] * stride_cm \ + + offs_cn[None, :] * stride_cn + tl.store(pc, c, mask=offs_cn[None, :] < DS0) + + +def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() + # shapes / dtypes + AS1 = block * spdims[2 if trans_a else 1] + BS0 = b.size(0) + BS1 = b.size(1) + BS3 = b.size(2 if trans_b else 3) + dtype = a.dtype + # allocate output + CS0 = BS0 + CS1 = BS1 + CS2 = BS3 if trans_c else AS1 + CS3 = AS1 if trans_c else BS3 + if out is None: + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + else: + assert out.shape == (CS0, CS1, CS2, CS3) + c = out + # meta-parameter heuristics + TILE_N = 128 + # compute output + grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0] + _dsd_kernel[grid]( + a, b, c, # + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), # + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), # + c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), # + BS3, AS1, lut, # + TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, # + num_warps=4, GROUP_SIZE_M=4 # + ) + # exit() + return c + + +def dsd_lut(layout, block, step, trans, device): + """ + Generates the look-up table for incrementing pointers in the DSD/DDS matmul. + Example (BLOCK=32, STEP=16) + [[1, 0, 0, 1, 0], + [0, 1, 1, 0, 1], + [1, 0, 1, 0, 0]] + + Then the offsets for A are + [0 , 16, 32, 48] <- row 0 + \\----/ \\----/ + col=0 col=3 + [64, 80, 96, 112, 128, 144] <- row 1 + \\----/ \\----/ \\------/ + col=1 col=2 col=3 + [160, 176, 192, 208] + which leads to increments table + [0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16] + + Because B is dense, the offsets are + [0, 16, 96, 112] <- row 0 + [32, 48, 64, 80] <- row 1 + [0, 16, 64, 80] <- row 2 + """ + sizes = torch.sum(layout, 2 if trans else 1) + head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True) + sizes = sizes.flatten() + segments = sizes * step + # pointer increments + if trans: + nnz = layout.nonzero(as_tuple=False) + else: + nnz = layout.transpose(1, 2).nonzero(as_tuple=False) + num_blocks = nnz.size(0) + offsets = torch.zeros_like(sizes) + offsets[1:] = torch.cumsum(sizes[:-1], dim=0) + offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets)) + # ------------------------------- + # dense input pointer increments + # ------------------------------- + # Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K) + # that is smaller than the block size, so we need to do a bit of extra work + # to handle this case + B_idx = nnz[:, 2] * block + B_incs = B_idx.clone() + B_incs[1:] -= B_idx[:-1] + div = block // step + B_incs = B_incs.view(-1, 1).repeat(1, div) + B_incs[:, 1:] = step + B_incs[:, 0] -= (div - 1) * step + # first increment for each reduction is actually the offset + B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]] + B_incs = B_incs.view(-1) + # ------------------------------- + # sparse input pointer increments + # ------------------------------- + # same as above, except that the increments are in the sparse memory layout + if trans: + A_idx = torch.arange(num_blocks, device=layout.device) + else: + A_idx = torch.tensor([], dtype=torch.int64, device=layout.device) + current_offset = 0 + for z in range(layout.size(0)): + layoutw = layout[z, :, :].clone().long() + msum = layoutw.sum() + layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device) + A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1)) + current_offset += msum + A_incs = A_idx * block * block + A_incs[1:] -= A_idx[:-1] * block * block + A_incs = A_incs.view(-1, 1).repeat(1, div) + if trans: + A_incs[:, 1:] = step + A_incs[:, 0] -= (div - 1) * step + else: + A_incs[:, 1:] = step * block + A_incs[:, 0] -= (div - 1) * step * block + A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]] + A_incs = A_incs.view(-1) + # create header + width = col_id.size(0) + offsets = offsets * 2 * div + 4 * width + segments = segments * div + header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous() + # create increments + incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous() + # pad by a factor 2*MAX_NUM_STAGES + # to accommodate pre-fetching inside the kernel + pad = torch.zeros(20, device=incs.device, dtype=incs.dtype) + incs = torch.cat((incs, pad)) + # create lut + lut = torch.cat((header, incs)) + lut = lut.type(torch.int32).to(device) + # create locks + return lut, width + + +# ----------------------------- +# Dense = Dense x Sparse (DDS) +# ----------------------------- +# AB = (B^T A^T)^T + + +def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): + return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out) + + +############## +# MAIN API # +############## + + +class _matmul(torch.autograd.Function): + + fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul} + + @staticmethod + def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut, + db_width, out): + c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out) + # save for backward + ctx.save_for_backward(a, b) + ctx.da_lut = da_lut + ctx.da_width = da_width + ctx.db_lut = db_lut + ctx.db_width = db_width + ctx.mode = mode + ctx.spdims = spdims + ctx.block = block + ctx.trans_a = trans_a + ctx.trans_b = trans_b + ctx.trans_c = trans_c + ctx.has_out = out is not None + return c + + @staticmethod + def backward(ctx, dc): + # saved for backward + a, b = ctx.saved_tensors + da, db = None, None + mode = ctx.mode + # gradients w.r.t. a + if ctx.needs_input_grad[0]: + mode_da = mode[1] + mode[0] + mode[2] + da = _matmul.fn[mode_da](dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, + ctx.da_lut, ctx.da_width) + # gradients w.r.t. b + if ctx.needs_input_grad[1]: + mode_db = mode[2] + mode[1] + mode[0] + db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, + ctx.db_lut, ctx.db_width) + dout = dc if ctx.has_out else None + return da, db, None, None, None, \ + None, None, None, None, \ + None, None, None, None, None, dout + + +class matmul: + + def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False): + if mode not in ['sdd', 'dsd', 'dds']: + raise NotImplementedError('Supported modes are: sdd, dsd, dds') + self.block = block + self.mode = mode + self.trans_a = trans_a + self.trans_b = trans_b + self.trans_c = trans_c + self.layout = layout + self.spdims = layout.shape + step = min(block, 32) + if self.mode == 'sdd': + self.c_lut, self.c_width = sdd_lut(layout, block, device) + self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device) + self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device) + if self.mode == 'dsd': + self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device) + self.da_lut, self.da_width = sdd_lut(layout, block, device) + self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device) + if self.mode == 'dds': + self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device) + self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device) + self.db_lut, self.db_width = sdd_lut(layout, block, device) + + def __call__(self, a, b, out=None): + c = _matmul.apply(a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, # + self.c_lut, self.c_width, # + self.da_lut, self.da_width, # + self.db_lut, self.db_width, # + out) + return c diff --git a/third_party/mthreads/python/triton/ops/blocksparse/softmax.py b/third_party/mthreads/python/triton/ops/blocksparse/softmax.py new file mode 100644 index 000000000..bcffff26b --- /dev/null +++ b/third_party/mthreads/python/triton/ops/blocksparse/softmax.py @@ -0,0 +1,228 @@ +import torch + +from ... import jit +from ... import language as tl +from ... import next_power_of_2 + + +def num_warps(n): + if n <= 128: + return 1 + if n <= 256: + return 2 + if n <= 512: + return 4 + if n <= 4096: + return 8 + return 16 + + +@jit +def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT, # + R, extent, stride_zr, stride_hr, # relative attention + scale, is_causal, # + ROW_SIZE: tl.constexpr, # + BLOCK_SIZE: tl.constexpr, # + IS_DENSE: tl.constexpr # + ): + h = tl.program_id(0) + m = tl.program_id(1) + z = tl.program_id(2) + # create index ranges + hm = h * tl.num_programs(1) + m + lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE + block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE + # extract information from LUT + header = LUT + (hm // BLOCK_SIZE) * 2 + size = tl.load(header + 0) + offset = tl.load(header + 1) + # pointer offset + off_a = z * stride_xz + off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx + off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx + # do not need to read column indices in the dense case + if IS_DENSE: + ns = tl.arange(0, ROW_SIZE) + else: + off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE + start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0) + ns = start_n * BLOCK_SIZE + lane_n + # load X + mask = block_n < size + a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf")) + a = a.to(tl.float32) + # compute + out = a + out *= scale + # apply relative attention + if R is not None: + R += z * stride_zr + R += h * stride_hr + off_lo = (extent - m - 1) + ns + mask_lo = (off_lo >= 0) & (off_lo < extent) + rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0) + out += rel_logits + out = out.to(tl.float32) + # apply causal mask + out = tl.where((ns > m) & is_causal, -float("inf"), out) + # computation + out = tl.softmax(out) + # write-back + tl.store(Out + off_a + lane_n, out, mask=mask) + + +@jit +def _blocksparse_softmax_bwd(DA, stride_zdx, # + DOut, stride_zdout, # + Out, stride_zout, # + scale, # + LUT, # + DR, extent, stride_zr, stride_hr, stride_er, # + is_causal, # + ROW_SIZE: tl.constexpr, # + BLOCK_SIZE: tl.constexpr, # + IS_DENSE: tl.constexpr): + h = tl.program_id(0) + m = tl.program_id(1) + z = tl.program_id(2) + # create index ranges + hm = h * tl.num_programs(1) + m + lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE + block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE + # extract information from LUT + header = LUT + (hm // BLOCK_SIZE) * 2 + size = tl.load(header + 0) + offset = tl.load(header + 1) + # row-col offset + off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE + off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE + mask = block_n < size + # pointers + As = Out + z * stride_zout + off_mn + DOuts = DOut + z * stride_zdout + off_mn + # do not need to read column indices in the dense case + if IS_DENSE: + ns = tl.arange(0, ROW_SIZE) + else: + off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE + start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0) + ns = start_n * BLOCK_SIZE + lane_n + # load data + a = tl.load(As + lane_n, mask=mask, other=0.0) + a = a.to(tl.float32) + dout = tl.load(DOuts + lane_n, mask=mask, other=0.0) + dout = dout.to(tl.float32) + # compute + a = tl.where((ns > m) & is_causal & (a == a), 0., a) + da = a * (dout - tl.sum(a * dout, 0)) + # apply relative attention + if DR is not None: + DR += z * stride_zr + DR += h * stride_hr + off_lo = (extent - m - 1) + ns + mask_lo = (off_lo >= 0) & (off_lo < extent) & mask + tl.store(DR + m * extent + off_lo, da, mask=mask_lo) + da = da * scale + # convert da + # write-back + DAs = DA + z * stride_zdx + off_mn + tl.store(DAs + lane_n, da, mask=mask) + + +class _softmax(torch.autograd.Function): + + @staticmethod + def make_lut(layout, block, device): + _empty = torch.tensor([], dtype=torch.int64, device=layout.device) + sizes = _empty.clone() + # sizes along rows + for h in range(layout.shape[0]): + sizes = torch.cat((sizes, layout[h, :, :].sum(-1))) + total_sizes = sizes * block + # offsets in block format + offsets = torch.zeros_like(sizes) + offsets[1:] = torch.cumsum(sizes[:-1], dim=0) + # block indices + columns = layout.nonzero(as_tuple=False)[:, 2] + header = torch.stack((sizes, offsets), dim=1).view(-1) + lut = torch.cat((header, columns)).type(torch.int32).to(device) + return lut, int(total_sizes.max()) + + @staticmethod + def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense): + if scale is not None and isinstance(scale, torch.Tensor): + assert scale.device.type == "cpu" + scale = scale.item() + M = a.shape[0] + grid = [spdims[0], spdims[1] * block, M] + rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape + rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride() + # enqueue kernel + out = torch.empty_like(a) + _blocksparse_softmax_fwd[grid]( + out, a, a.stride(0), lut, # + rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn# + scale, # + is_causal, # + BLOCK_SIZE=block, # + ROW_SIZE=next_power_of_2(maxlut), # + IS_DENSE=is_dense, # + num_warps=num_warps(maxlut) # + ) + # save to context + # ctx.mark_dirty(x) + ctx.save_for_backward(out, lut) + ctx.spdims = spdims + ctx.block = block + ctx.maxlut = maxlut + ctx.scale = scale + ctx.rel_shape = rel_shape + ctx.rel_strides = rel_strides + ctx.rel_dtype = a.dtype + ctx.is_dense = is_dense + ctx.is_causal = is_causal + return out + + @staticmethod + def backward(ctx, dout): + # retrieve from context + out, lut = ctx.saved_tensors + # relative logits gradients + dr = None + if ctx.needs_input_grad[3]: + dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device) + # run kernel + M = out.shape[0] + grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M) + da = torch.empty_like(dout) + _blocksparse_softmax_bwd[grid]( + da, da.stride(0), # + dout, dout.stride(0), # + out, out.stride(0), # + ctx.scale, # + lut, # + dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], # + ctx.is_causal, # + BLOCK_SIZE=ctx.block, # + ROW_SIZE=next_power_of_2(ctx.maxlut), # + IS_DENSE=ctx.is_dense, # + num_warps=num_warps(ctx.maxlut) # + ) + return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None) + + +class softmax: + + def __init__(self, layout, block, device, is_dense=False): + self.spdims = layout.shape + self.layout = layout + self.block = block + self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device) + self.is_dense = is_dense + + def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False): + if rel_logits is not None and rel_logits.dtype != a.dtype: + raise ValueError(f"relative position embedding must be {a.dtype}") + a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut, + self.is_dense) + return a diff --git a/third_party/mthreads/python/triton/ops/cross_entropy.py b/third_party/mthreads/python/triton/ops/cross_entropy.py new file mode 100644 index 000000000..88e8dae50 --- /dev/null +++ b/third_party/mthreads/python/triton/ops/cross_entropy.py @@ -0,0 +1,96 @@ +import torch + +from .. import heuristics, jit +from .. import language as tl +from .. import next_power_of_2 + + +def num_warps(N): + if N < 2048: + return 4 + elif N < 8192: + return 8 + return 16 + + +@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) +@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])}) +@jit +def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK) + idx = tl.load(IDX + row) + # pointers to logit and probs + LOGITS = LOGITS + row * N + cols + WRIT_PROBS = PROBS + row * N + cols + READ_PROBS = PROBS + row * N + idx + # write-back negative log-probs + logits = tl.load(LOGITS, mask=cols < N, other=-float('inf')) + logits = logits.to(tl.float32) + logits = logits - tl.max(logits, 0) + probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits + tl.store(WRIT_PROBS, probs, mask=cols < N) + # There is a bug in the compiler, which fails to insert a barrier here. + # We add it explicitly for now. Will be fixed soon. + tl.debug_barrier() + # write-back loss + probs = tl.load(READ_PROBS) + tl.store(LOSS + row, probs) + + +@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) +@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])}) +@jit +def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK) + idx = tl.load(IDX + row) + # pointers to probs + PROBS = PROBS + row * N + cols + # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] + # and we have -log(p[k]) stored in PROBS, so this is easy + probs = -tl.load(PROBS, mask=cols < N, other=float('inf')) + probs = tl.exp(probs.to(tl.float32)) + delta = cols == idx + # write result in-place in PROBS + dout = tl.load(DPROBS + row) + din = (probs - delta) * dout + tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N) + + +class _cross_entropy(torch.autograd.Function): + + @classmethod + def forward(cls, ctx, logits, indices): + # make sure we can use triton + assert (indices.dtype == torch.int64), "Indices are expected to be of type long." + # make kernel + device, dtype = logits.device, logits.dtype + n_cols = logits.shape[-1] + # run the kernel + result = torch.empty_like(indices, dtype=dtype, device=device) + neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device) + grid = lambda opt: (logits.numel() // n_cols, ) + _forward[grid](logits, neg_logprobs, indices, result, n_cols) + # save for backward + ctx.save_for_backward(neg_logprobs, indices) + return result + + @classmethod + def backward(cls, ctx, dneg_logprobs): + """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] + so we initialize the gradient as neg_logprobs, so we can just exponentiate + to get p[k], which is most of what we need... neg_logprobs will be + modified in place to become the gradient we want + """ + # load saved tensors + neg_logprobs, indices = ctx.saved_tensors + # run the kernel + # neg_logprobs will be modified in place to become our gradient: + n_cols = neg_logprobs.shape[-1] + grid = lambda opt: (neg_logprobs.numel() // n_cols, ) + _backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols) + return neg_logprobs, None + + +cross_entropy = _cross_entropy.apply diff --git a/python/triton/ops/flash_attention.py b/third_party/mthreads/python/triton/ops/flash_attention.py similarity index 100% rename from python/triton/ops/flash_attention.py rename to third_party/mthreads/python/triton/ops/flash_attention.py diff --git a/python/triton/ops/matmul.py b/third_party/mthreads/python/triton/ops/matmul.py similarity index 100% rename from python/triton/ops/matmul.py rename to third_party/mthreads/python/triton/ops/matmul.py diff --git a/python/triton/ops/matmul_perf_model.py b/third_party/mthreads/python/triton/ops/matmul_perf_model.py similarity index 100% rename from python/triton/ops/matmul_perf_model.py rename to third_party/mthreads/python/triton/ops/matmul_perf_model.py diff --git a/third_party/mthreads/python/triton/runtime/__init__.py b/third_party/mthreads/python/triton/runtime/__init__.py new file mode 100644 index 000000000..0b3979d28 --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/__init__.py @@ -0,0 +1,23 @@ +from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics) +from .cache import RedisRemoteCacheBackend, RemoteCacheBackend +from .driver import driver +from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret +from .errors import OutOfResources, InterpreterError + +__all__ = [ + "autotune", + "Autotuner", + "Config", + "driver", + "Heuristics", + "heuristics", + "InterpreterError", + "JITFunction", + "KernelInterface", + "MockTensor", + "OutOfResources", + "RedisRemoteCacheBackend", + "reinterpret", + "RemoteCacheBackend", + "TensorWrapper", +] diff --git a/third_party/mthreads/python/triton/runtime/autotuner.py b/third_party/mthreads/python/triton/runtime/autotuner.py new file mode 100644 index 000000000..a6746ce22 --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/autotuner.py @@ -0,0 +1,378 @@ +from __future__ import annotations + +import builtins +import os +import time +import inspect +from typing import Dict + +from ..backends.mthreads.driver import MusaDriver +if MusaDriver.is_active(): + from ..backends.mthreads.musa_testing import do_bench +else: + from ..testing import do_bench +from ..testing import do_bench_cudagraph +from .jit import KernelInterface +from .errors import OutOfResources + + +class Autotuner(KernelInterface): + + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=None, + post_hook=None, + prune_configs_by: Dict = None, + warmup=25, + rep=100, + use_cuda_graph=False, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. + """ + if not configs: + self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.cache = {} + self.arg_names = arg_names + + # Reset to zero or restore values + self.reset_idx = [] + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + self.restore_idx = [] + if restore_value is not None: + self.restore_idx = [arg_names.index(k) for k in restore_value] + + # Hook to reset or restore for required tensors + self.pre_hook = lambda args, reset_only=False: 0 + self.post_hook = lambda args, exception: 0 + if pre_hook: + self.pre_hook = pre_hook + elif (len(self.reset_idx) > 0 or len(self.restore_idx) > 0): + + def _pre_hook(args, reset_only=False): + for i in self.reset_idx: + args[i].zero_() + if not reset_only: + self.restore_copies = [args[i].clone() for i in self.restore_idx] + + self.pre_hook = _pre_hook + + if post_hook: + self.post_hook = post_hook + elif len(self.restore_idx) > 0: + + def _post_hook(args, exception): + for i, j in enumerate(self.restore_idx): + args[j].copy_(self.restore_copies[i]) + self.restore_copies = [] + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune) + + self.fn = fn + self.base_fn = fn + while not inspect.isfunction(self.base_fn): + self.base_fn = self.base_fn.fn + self.num_warmups = warmup + self.num_reps = rep + import torch + self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available() + + def _bench(self, *args, config, **meta): + from ..compiler.errors import CompileTimeAssertionFailure + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(args) + try: + self.fn.run( + *args, + **current, + ) + except Exception as e: + try: + self.post_hook(args, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + + self.post_hook(args, exception=None) + + try: + if self.use_cuda_graph: + import torch + with torch.cuda.stream(torch.cuda.Stream()): + bench_res = do_bench_cudagraph(kernel_call, rep=self.num_reps, return_mode="median") + return bench_res + return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8)) + except (OutOfResources, CompileTimeAssertionFailure): + return float("inf") if self.use_cuda_graph else [float("inf"), float("inf"), float("inf")] + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + used_cached_result = True + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = [] + for name in self.arg_names: + if name in all_args: + _args.append(all_args[name]) + key = [_args[i] for i in self.key_idx] + for arg in _args: + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + # prune configs + used_cached_result = False + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.pre_hook(args, reset_only=True) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: + print(f"Triton autotuning for function {self.base_fn.__name__} finished after " + f"{self.bench_time:.2f}s; best config selected: {self.best_config};") + if config.pre_hook is not None: + config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()}) + ret = self.fn.run( + *args, + **kwargs, + **config.all_kwargs(), + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.all_kwargs(), + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for config in self.prune_configs(kwargs): + ret.append(self.fn.warmup( + *args, + **kwargs, + **config.all_kwargs(), + )) + self.nargs = None + return ret + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type kwargs: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_ctas: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :type maxnreg: Optional[int] + :ivar maxnreg: maximum number of registers one thread can use. Corresponds + to ptx .maxnreg directive. Not supported on all platforms. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + """ + + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, maxnreg=None, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.maxnreg = maxnreg + self.pre_hook = pre_hook + try: + import torch_musa + device = torch_musa.current_device() + capability = torch_musa.get_device_capability(device) + if capability[0] < 3 and self.num_warps > 8: + import logging + logging.warning( + f"QY1/QY2 has 128 threads per warp, one should limit number of warps not exceed 8, so decrease self.num_warps: {self.num_warps} to 8" + ) + self.num_warps = 8 + except ImportError: + pass + + def all_kwargs(self): + return { + **self.kwargs, **{ + k: v + for (k, v) in ( + ("num_warps", self.num_warps), + ("num_ctas", self.num_ctas), + ("num_stages", self.num_stages), + ("maxnreg", self.maxnreg), + ) if v is not None + } + } + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + res.append(f"maxnreg: {self.maxnreg}") + return ", ".join(res) + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, + warmup=25, rep=100, use_cuda_graph=False): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + + If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to + :code:`"1"`, Triton will print a message to stdout after autotuning each + kernel, including the time spent autotuning and the best configuration. + + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param pre_hook: a function that will be called before the kernel is called. + This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. + 'args': a list of arguments passed to the kernel. + 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. + :type pre_hook: lambda args, reset_only + :param post_hook: a function that will be called after the kernel is called. + This overrides the default post_hook used for 'restore_value'. + 'args': a list of arguments passed to the kernel. + 'exception': the exception raised by the kernel in case of a compilation or runtime error. + :type post_hook: lambda args, exception + :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25. + :type warmup: int + :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100. + :type rep: int + """ + + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph) + + return decorator + + +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + :type values: dict[str, Callable[[list[Any]], Any]] + """ + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/third_party/mthreads/python/triton/runtime/build.py b/third_party/mthreads/python/triton/runtime/build.py new file mode 100644 index 000000000..d7baeb286 --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/build.py @@ -0,0 +1,78 @@ +import contextlib +import sys +import io +import sysconfig +import os +import shutil +import subprocess +import setuptools + + +@contextlib.contextmanager +def quiet(): + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = io.StringIO(), io.StringIO() + try: + yield + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + + +def _build(name, src, srcdir, library_dirs, include_dirs, libraries): + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + # TODO: support more things here. + clang = shutil.which("clang") + gcc = shutil.which("gcc") + cc = gcc if gcc is not None else clang + if cc is None: + raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + include_dirs = include_dirs + [srcdir, py_include_dir] + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so] + cc_cmd += [f'-l{lib}' for lib in libraries] + cc_cmd += [f"-L{dir}" for dir in library_dirs] + cc_cmd += [f"-I{dir}" for dir in include_dirs] + ret = subprocess.check_call(cc_cmd) + if ret == 0: + return so + # fallback on setuptools + extra_compile_args = [] + # extra arguments + extra_link_args = [] + # create extension module + ext = setuptools.Extension( + name=name, + language='c', + sources=[src], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args + ['-O3'], + extra_link_args=extra_link_args, + library_dirs=library_dirs, + libraries=libraries, + ) + # build extension module + args = ['build_ext'] + args.append('--build-temp=' + srcdir) + args.append('--build-lib=' + srcdir) + args.append('-q') + args = dict( + name=name, + ext_modules=[ext], + script_args=args, + ) + with quiet(): + setuptools.setup(**args) + return so diff --git a/third_party/mthreads/python/triton/runtime/cache.py b/third_party/mthreads/python/triton/runtime/cache.py new file mode 100644 index 000000000..bd3c29b99 --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/cache.py @@ -0,0 +1,281 @@ +import importlib +import json +import os +import uuid +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Dict, List, Optional +import hashlib + + +def default_cache_dir(): + return os.path.join(Path.home(), ".triton", "cache") + + +def default_override_dir(): + return os.path.join(Path.home(), ".triton", "override") + + +def default_dump_dir(): + return os.path.join(Path.home(), ".triton", "dump") + + +class CacheManager(ABC): + + def __init__(self, key): + pass + + @abstractmethod + def get_file(self, filename) -> Optional[str]: + pass + + @abstractmethod + def put(self, data, filename, binary=True) -> str: + pass + + @abstractmethod + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + pass + + @abstractmethod + def put_group(self, filename: str, group: Dict[str, str]): + pass + + +class FileCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = default_dump_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") + + def _make_path(self, filename) -> str: + return os.path.join(self.cache_dir, filename) + + def has_file(self, filename) -> bool: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + return os.path.exists(self._make_path(filename)) + + def get_file(self, filename) -> Optional[str]: + if self.has_file(filename): + return self._make_path(filename) + else: + return None + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + grp_filename = f"__grp__{filename}" + if not self.has_file(grp_filename): + return None + grp_filepath = self._make_path(grp_filename) + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + # Invalid group data. + if child_paths is None: + return None + result = {} + for c, p in child_paths.items(): + if os.path.exists(p): + result[c] = p + return result + + # Note a group of pushed files as being part of a group + def put_group(self, filename: str, group: Dict[str, str]) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + grp_contents = json.dumps({"child_paths": group}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename, binary=False) + + def put(self, data, filename, binary=True) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + binary = isinstance(data, bytes) + if not binary: + data = str(data) + assert self.lock_path is not None + filepath = self._make_path(filename) + # Random ID to avoid any collisions + rnd_id = str(uuid.uuid4()) + # we use the PID in case a bunch of these around so we can see what PID made it + pid = os.getpid() + # use tempfile to be robust against program interruptions + temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}" + mode = "wb" if binary else "w" + with open(temp_path, mode) as f: + f.write(data) + # Replace is guaranteed to be atomic on POSIX systems if it succeeds + # so filepath cannot see a partial write + os.replace(temp_path, filepath) + return filepath + + +class RemoteCacheBackend: + """ + A backend implementation for accessing a remote/distributed cache. + """ + + def __init__(self, key: str): + pass + + @abstractmethod + def get(self, filenames: List[str]) -> Dict[str, bytes]: + pass + + @abstractmethod + def put(self, filename: str, data: bytes): + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend): + + def __init__(self, key): + import redis + self._key = key + self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}") + self._redis = redis.Redis( + host=os.environ.get("TRITON_REDIS_HOST", "localhost"), + port=int(os.environ.get("TRITON_REDIS_PORT", 6379)), + ) + + def _get_key(self, filename: str) -> str: + return self._key_fmt.format(key=self._key, filename=filename) + + def get(self, filenames: List[str]) -> Dict[str, str]: + results = self._redis.mget([self._get_key(f) for f in filenames]) + return {filename: result for filename, result in zip(filenames, results) if result is not None} + + def put(self, filename: str, data: bytes) -> Dict[str, bytes]: + self._redis.set(self._get_key(filename), data) + + +class RemoteCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + # Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`. + remote_cache_manager = os.environ["TRITON_REMOTE_CACHE_BACKEND"] + module_path, clz_nme = remote_cache_manager.split(":") + module = importlib.import_module(module_path) + remote_cache_cls = getattr(module, clz_nme) + self._backend = remote_cache_cls(key) + + self._override = override + self._dump = dump + + # Use a `FileCacheManager` to materialize remote cache paths locally. + self._file_cache_manager = FileCacheManager(key, override=override, dump=dump) + + def _materialize(self, filename: str, data: bytes): + # We use a backing `FileCacheManager` to provide the materialized data. + return self._file_cache_manager.put(data, filename, binary=True) + + def get_file(self, filename: str) -> Optional[str]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_file(filename) + + # We always check the remote cache backend -- even if our internal file- + # based cache has the item -- to make sure LRU accounting works as + # expected. + results = self._backend.get([filename]) + if len(results) == 0: + return None + (_, data), = results.items() + return self._materialize(filename, data) + + def put(self, data, filename: str, binary=True) -> str: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put(data, filename, binary=binary) + + if not isinstance(data, bytes): + data = str(data).encode("utf-8") + self._backend.put(filename, data) + return self._materialize(filename, data) + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_group(filename) + + grp_filename = f"__grp__{filename}" + grp_filepath = self.get_file(grp_filename) + if grp_filepath is None: + return None + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + + result = None + + # Found group data. + if child_paths is not None: + result = {} + for child_path, data in self._backend.get(child_paths).items(): + result[child_path] = self._materialize(child_path, data) + + return result + + def put_group(self, filename: str, group: Dict[str, str]): + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put_group(filename, group) + + grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename) + + +__cache_cls = FileCacheManager +__cache_cls_nme = "DEFAULT" + + +def get_cache_manager(key) -> CacheManager: + import os + + user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None) + global __cache_cls + global __cache_cls_nme + + if user_cache_manager is not None and user_cache_manager != __cache_cls_nme: + module_path, clz_nme = user_cache_manager.split(":") + module = importlib.import_module(module_path) + __cache_cls = getattr(module, clz_nme) + __cache_cls_nme = user_cache_manager + + return __cache_cls(key) + + +def get_override_manager(key) -> CacheManager: + return __cache_cls(key, override=True) + + +def get_dump_manager(key) -> CacheManager: + return __cache_cls(key, dump=True) + + +def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): + # Get unique key for the compiled code + signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} + key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}" + for kw in kwargs: + key = f"{key}-{kwargs.get(kw)}" + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + return key diff --git a/third_party/mthreads/python/triton/runtime/driver.py b/third_party/mthreads/python/triton/runtime/driver.py new file mode 100644 index 000000000..c3b97a764 --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/driver.py @@ -0,0 +1,60 @@ +from ..backends import backends +from ..backends import DriverBase + + +def _create_driver(): + actives = [x.driver for x in backends.values() if x.driver.is_active()] + if len(actives) != 1: + raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.") + return actives[0]() + + +class LazyProxy: + + def __init__(self, init_fn): + self._init_fn = init_fn + self._obj = None + + def _initialize_obj(self): + if self._obj is None: + self._obj = self._init_fn() + + def __getattr__(self, name): + self._initialize_obj() + return getattr(self._obj, name) + + def __setattr__(self, name, value): + if name in ["_init_fn", "_obj"]: + super().__setattr__(name, value) + else: + self._initialize_obj() + setattr(self._obj, name, value) + + def __delattr__(self, name): + self._initialize_obj() + delattr(self._obj, name) + + def __repr__(self): + if self._obj is None: + return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>" + return repr(self._obj) + + def __str__(self): + self._initialize_obj() + return str(self._obj) + + +class DriverConfig: + + def __init__(self): + self.default = LazyProxy(_create_driver) + self.active = self.default + + def set_active(self, driver: DriverBase): + self.active = driver + + def reset_active(self): + self.active = self.default + + +driver = DriverConfig() diff --git a/third_party/mthreads/python/triton/runtime/errors.py b/third_party/mthreads/python/triton/runtime/errors.py new file mode 100644 index 000000000..4dce91767 --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/errors.py @@ -0,0 +1,26 @@ +from ..errors import TritonError +from typing import Optional + + +class InterpreterError(TritonError): + + def __init__(self, error_message: Optional[str] = None): + self.error_message = error_message + + def __str__(self) -> str: + return self.error_message or "" + + +class OutOfResources(TritonError): + + def __init__(self, required, limit, name): + self.required = required + self.limit = limit + self.name = name + + def __str__(self) -> str: + return f"out of resource: {self.name}, Required: {self.required}, Hardware limit: {self.limit}. Reducing block sizes or `num_stages` may help." + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return (type(self), (self.required, self.limit, self.name)) diff --git a/third_party/mthreads/python/triton/runtime/interpreter.py b/third_party/mthreads/python/triton/runtime/interpreter.py new file mode 100644 index 000000000..a82832ecf --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/interpreter.py @@ -0,0 +1,1127 @@ +import inspect +from typing import Tuple + +import math +import numpy as np + +import triton +import triton.language as tl +from dataclasses import dataclass +from .errors import InterpreterError +from functools import partial +from .._C.libtriton import interpreter as _interpreter +from .._C.libtriton import ir as _ir + + +class TensorHandle: + + def __init__(self, data, dtype): + ''' + data: numpy array + dtype: triton type, either pointer_type or scalar_type. + we don't store block_type here because the shape information is already availale in the data field + attr: a dictionary of attributes + ''' + self.data = data + self.dtype = dtype + self.attr = {} + + def __bool__(self): + return bool(self.data.all()) + + def get_element_ty(self): + dtype = self.dtype + while hasattr(dtype, "element_ty"): + dtype = dtype.element_ty + return dtype + + def clone(self): + return TensorHandle(self.data.copy(), self.dtype) + + def set_attr(self, key, value): + self.attr[key] = value + + +class BlockPointerHandle: + + def __init__(self, base, shape, strides, offsets, tensor_shape, order): + self.base = base + self.shape = shape + self.strides = strides + self.offsets = offsets + self.tensor_shape = tensor_shape + self.order = order + + def materialize_pointers(self, boundary_check): + dtype_tt = self.base.get_element_ty() + n_bytes = dtype_tt.primitive_bitwidth // 8 + tensor_shape = self.tensor_shape + ptrs = np.broadcast_to(self.base.data, self.tensor_shape) + masks = np.ones(self.tensor_shape, dtype=bool) + for dim in range(len(tensor_shape)): + bcast_dims = [1] * len(tensor_shape) + bcast_dims[dim] = tensor_shape[dim] + off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims) + ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64) + if dim in boundary_check: + masks = np.logical_and(masks, off < self.shape[dim].data) + ptrs = TensorHandle(ptrs, self.base.dtype.scalar) + return ptrs, masks + + +@dataclass(frozen=True) +class InterpreterOptions: + extern_libs: dict = None + debug: bool = False + arch: str = None + allow_fp8e4nv: bool = True + allow_fp8e4b15: bool = True + default_dot_input_precision: str = "tf32" + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + max_num_imprecise_acc_default: int = 0 + + +def _get_signed_np_dtype(dtype): + if dtype == np.uint8: + return np.int8 + if dtype == np.uint16: + return np.int16 + if dtype == np.uint32: + return np.int32 + if dtype == np.uint64: + return np.int64 + return dtype + + +def _get_np_dtype(tt_dtype): + if isinstance(tt_dtype, tl.pointer_type): + return np.dtype(np.uint64) + np_types = { + tl.int1: np.dtype(bool), + tl.float16: np.dtype(np.float16), + tl.float32: np.dtype(np.float32), + tl.float64: np.dtype(np.float64), + tl.int8: np.dtype(np.int8), + tl.uint8: np.dtype(np.uint8), + tl.int16: np.dtype(np.int16), + tl.uint16: np.dtype(np.uint16), + tl.int32: np.dtype(np.int32), + tl.uint32: np.dtype(np.uint32), + tl.int64: np.dtype(np.int64), + tl.uint64: np.dtype(np.uint64), + # bfloat16 types are stored as uint16 + tl.bfloat16: np.dtype(np.uint16), + # float8 types are stored as uint8 + tl.float8e5: np.dtype(np.uint8), + tl.float8e5b16: np.dtype(np.uint8), + tl.float8e4nv: np.dtype(np.uint8), + tl.float8e4b8: np.dtype(np.uint8), + tl.float8e4b15: np.dtype(np.uint8), + } + if isinstance(tt_dtype, tl.block_type): + if isinstance(tt_dtype.element_ty, tl.pointer_type): + return np.dtype(np.uint64) + return np_types[tt_dtype.element_ty] + return np_types[tt_dtype] + + +def _convert_float(input, input_dtype, output_dtype, rounding_mode): + input_uint_dtype = getattr(np, f"uint{input_dtype.primitive_bitwidth}") + output_unint_dtype = getattr(np, f"uint{output_dtype.primitive_bitwidth}") + input_bin = np.frombuffer(input.tobytes(), dtype=input_uint_dtype) + sign = (input_bin >> (input_dtype.primitive_bitwidth - 1)) & 0x01 + input_exponent_width = input_dtype.primitive_bitwidth - input_dtype.fp_mantissa_width - 1 + output_exponent_width = output_dtype.primitive_bitwidth - output_dtype.fp_mantissa_width - 1 + significand = input_bin & ((1 << input_dtype.fp_mantissa_width) - 1) + bias_input = input_dtype.exponent_bias + bias_output = output_dtype.exponent_bias + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + subnormal_index = exponent == 0 + if np.any(subnormal_index): + # Credit to Phil: phil@openai.com + # subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # convert it to normal repr: ((-1.0)**sign) * (2.0**(1 + m0 - exp_bias)) * (1 + 2^(m1 - m0) + ... + 2^(mn - m0)) + bit_pos = np.zeros_like(input_bin, dtype=np.int32) + # Find the most significant bit of the mantissa in the significand + for i in range(input_dtype.fp_mantissa_width): + bit_index = ((significand >> i) & 0x01) + # pos should be >= 1 + bit_pos[bit_index == 1] = input_dtype.fp_mantissa_width - i + zero_significand_index = significand == 0 + exponent[subnormal_index] = 1 - bit_pos[subnormal_index] + # 0 significand and subnormal should be treated as 0 + exponent[zero_significand_index & subnormal_index] = bias_input - bias_output + significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & ( + (1 << input_dtype.fp_mantissa_width) - 1) + # Prevent overflow and underflow + exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1)) + exponent_output = exponent_output.astype(output_unint_dtype) + sign_output = sign.astype(output_unint_dtype) + if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth: # Downcast + significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + if rounding_mode == _ir.ROUNDING_MODE.RTNE: # Round to nearst even + # find the cut-off bit + cut_off = significand & (1 << (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width - 1)) + significand_output = significand_output + (cut_off > 0) + significand_output = significand_output.astype(output_unint_dtype) + else: # Upcast + significand_output = (significand.astype(output_unint_dtype) << + (output_dtype.fp_mantissa_width - input_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + subnormal_index = exponent_output == 0 + if np.any(subnormal_index): # underflow + # normal repr: ((-1.0)**sign) * (2.0**(exp - exp_bias_input)) * (1 + 2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # shift = (1 - exp_bias_output) - (exp - exp_bias_input) + # convert it to subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias_output)) * (2^(-shift) + 2^(m0 - shift) + 2^(m1 - shift) + ... + 2^(mn - shift)) + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + non_zero_exponent_index = exponent != 0 + # If the original exponent is not zero, we still need to shift the significand and consider the 1.0 part in mantissa + subnormal_index = subnormal_index & non_zero_exponent_index + shift = np.zeros_like(input_bin, dtype=np.int32) + shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input) + significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | ( + 1 << (output_dtype.fp_mantissa_width - shift[subnormal_index])) + output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | ( + exponent_output << output_dtype.fp_mantissa_width) | significand_output + return output.reshape(input.shape) + + +def _erf(x): + # Numpy does not support erf + return math.erf(x) + + +def _umulhi_64(a, b): + # Numpy does not support 128-bit multiplication + # So we have to implement it manually + return (int(a) * int(b)) >> 64 + + +np_erf_fp32 = np.vectorize(_erf, otypes=[np.float32]) +np_erf_fp64 = np.vectorize(_erf, otypes=[np.float64]) +np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64]) + + +class ExtraFunctions: + + @staticmethod + def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _builder): + return tl.tensor(_builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty) + + +class InterpreterBuilder: + ir_sem_to_interpreter_sem = { + _ir.MEM_SEMANTIC.ACQUIRE: _interpreter.MEM_SEMANTIC.ACQUIRE, + _ir.MEM_SEMANTIC.RELEASE: _interpreter.MEM_SEMANTIC.RELEASE, + _ir.MEM_SEMANTIC.RELAXED: _interpreter.MEM_SEMANTIC.RELAXED, + _ir.MEM_SEMANTIC.ACQUIRE_RELEASE: _interpreter.MEM_SEMANTIC.ACQUIRE_RELEASE, + } + + ir_rmw_op_to_interpreter_rmw_op = { + _ir.ATOMIC_OP.ADD: _interpreter.RMW_OP.ADD, + _ir.ATOMIC_OP.FADD: _interpreter.RMW_OP.FADD, + _ir.ATOMIC_OP.MIN: _interpreter.RMW_OP.MIN, + _ir.ATOMIC_OP.UMIN: _interpreter.RMW_OP.UMIN, + _ir.ATOMIC_OP.MAX: _interpreter.RMW_OP.MAX, + _ir.ATOMIC_OP.UMAX: _interpreter.RMW_OP.UMAX, + _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND, + _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR, + _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR, + _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG, + } + + def __init__(self) -> None: + self.arch = None + self.options = InterpreterOptions() + self.codegen_fns = {} + self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types + + def set_grid_idx(self, x, y, z): + if not x < self.grid_dim[0]: + raise ValueError("x >= grid_dim[0]") + if not y < self.grid_dim[1]: + raise ValueError("y >= grid_dim[1]") + if not z < self.grid_dim[2]: + raise ValueError("z >= grid_dim[2]") + self.grid_idx = (x, y, z) + + def set_grid_dim(self, nx, ny, nz): + self.grid_dim = (nx, ny, nz) + + # constants + + def get_half_ty(self): + return tl.float16 + + def get_bf16_ty(self): + return tl.bfloat16 + + def get_float_ty(self): + return tl.float32 + + def get_double_ty(self): + return tl.float64 + + def get_int8_ty(self): + return tl.int8 + + def get_uint8_ty(self): + return tl.uint8 + + def get_int16_ty(self): + return tl.int16 + + def get_uint16_ty(self): + return tl.uint16 + + def get_int32_ty(self): + return tl.int32 + + def get_uint32_ty(self): + return tl.uint32 + + def get_int64_ty(self): + return tl.int64 + + def get_uint64_ty(self): + return tl.uint64 + + def get_fp8e4nv_ty(self): + return tl.float8e4nv + + def get_fp8e4b15_ty(self): + return tl.float8e4b15 + + def get_fp8e4b8_ty(self): + return tl.float8e4b8 + + def get_fp8e5_ty(self): + return tl.float8e5 + + def get_fp8e5b16_ty(self): + return tl.float8e5b16 + + def get_ptr_ty(self, elt_ty, addr_space): + return tl.pointer_type(elt_ty, addr_space) + + def get_block_ty(self, dtype, shape): + return tl.block_type(dtype, shape) + + def get_int1(self, value): + return TensorHandle(np.array([value], dtype=np.bool_), tl.int1) + + def get_uint8(self, value): + return TensorHandle(np.array([value], dtype=np.uint8), tl.uint8) + + def get_int8(self, value): + return TensorHandle(np.array([value], dtype=np.int8), tl.int8) + + def get_uint16(self, value): + return TensorHandle(np.array([value], dtype=np.uint16), tl.uint16) + + def get_int16(self, value): + return TensorHandle(np.array([value], dtype=np.int16), tl.int16) + + def get_uint32(self, value): + return TensorHandle(np.array([value], dtype=np.uint32), tl.uint32) + + def get_int32(self, value): + return TensorHandle(np.array([value], dtype=np.int32), tl.int32) + + def get_uint64(self, value): + return TensorHandle(np.array([value], dtype=np.uint64), tl.uint64) + + def get_int64(self, value): + return TensorHandle(np.array([value], dtype=np.int64), tl.int64) + + def get_fp16(self, value): + return TensorHandle(np.array([value], dtype=np.float16), tl.float16) + + def get_fp32(self, value): + return TensorHandle(np.array([value], dtype=np.float32), tl.float32) + + def get_fp64(self, value): + return TensorHandle(np.array([value], dtype=np.float64), tl.float64) + + def get_null_value(self, type): + return TensorHandle(np.array([0], dtype=_get_np_dtype(type)), type) + + # programming model + def create_get_program_id(self, axis): + if self.grid_idx is None: + raise ValueError("grid_idx is None") + return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32) + + def create_get_num_programs(self, axis): + return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32) + + # memory ops + def create_load(self, ptr, _0, _1, is_volatile): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + other = None + return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile) + + def create_store(self, ptr, val, _0, _1): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + return self.create_masked_store(ptr, val, mask, None, None) + + def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if other is None: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np) + return TensorHandle(ret, dtype_tt) + + def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy): + return _interpreter.store(ptrs.data, value.data, mask.data) + + # casting ops + def cast_impl(self, src, dst_type): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + if (src_element_type == tl.bfloat16 and dst_element_type == tl.float32) or \ + (src_element_type == tl.float32 and dst_element_type == tl.bfloat16): + data = _convert_float(src.data, src_element_type, dst_element_type, None).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + else: + return TensorHandle(src.data.astype(_get_np_dtype(dst_type)), dst_type.scalar) + + create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type) + + def create_fp_to_fp(self, src, dst_type, rounding_mode): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + data = _convert_float(src.data, src_element_type, dst_element_type, rounding_mode).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + + def create_bitcast(self, src, dst_type): + return TensorHandle(src.data.view(_get_np_dtype(dst_type)), dst_type.scalar) + + # binary operators + def binary_op(self, lhs, rhs, op): + return TensorHandle(op(lhs.data, rhs.data), lhs.dtype.scalar) + + create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_sdiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + create_udiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift) + create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minimumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maximumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and) + create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor) + create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or) + + def create_idiv(self, lhs, rhs): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + return TensorHandle((lhs.data - np.fmod(lhs.data, rhs.data)) // rhs.data, lhs.dtype.scalar) + + def create_ashr(self, lhs, rhs): + # Triton's rshift operator depends on the signedness of the left operand + lhs_dtype = _get_signed_np_dtype(lhs.data.dtype) + rhs_dtype = _get_signed_np_dtype(rhs.data.dtype) + lhs.data = lhs.data.astype(lhs_dtype) + rhs.data = rhs.data.astype(rhs_dtype) + return self.binary_op(lhs, rhs, np.right_shift) + + def create_umulhi(self, lhs, rhs): + dtype = lhs.data.dtype + if dtype == np.int64 or dtype == np.uint64: + return TensorHandle(np_umulhi_u64(lhs.data, rhs.data), lhs.dtype.scalar) + else: + compute_dtype = getattr(np, f"uint{dtype.itemsize * 8 * 2}") + lhs_data = lhs.data.astype(compute_dtype) + rhs_data = rhs.data.astype(compute_dtype) + ret_data = np.multiply(lhs_data, rhs_data) >> (dtype.itemsize * 8) + return TensorHandle(ret_data.astype(dtype), lhs.dtype.scalar) + + # ternary functions + def ternary_op(self, lhs, rhs, other, op): + return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype.scalar) + + create_clampf = lambda self, arg, lo, hi, propagate_nans: self.ternary_op(arg, lo, hi, np.clip) + create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where) + + def create_fma(self, x, y, z): + return TensorHandle(x.data * y.data + z.data, z.dtype.scalar) + + # unary functions + def unary_op(self, arg, op): + return TensorHandle(op(arg.data), arg.dtype.scalar) + + def create_fabs(self, arg): + # Mask out the sign bit based on the primitive length + dtype_tt = arg.dtype + mask_bitwidth = dtype_tt.primitive_bitwidth - 1 + np_uint_dtype = getattr(np, f"uint{dtype_tt.primitive_bitwidth}") + data = arg.data.view(np_uint_dtype) + mask = (1 << mask_bitwidth) - 1 + ret = (data & mask).view(_get_np_dtype(dtype_tt)) + return TensorHandle(ret, arg.dtype.scalar) + + create_cos = lambda self, arg: self.unary_op(arg, np.cos) + create_exp = lambda self, arg: self.unary_op(arg, np.exp) + create_exp2 = lambda self, arg: self.unary_op(arg, np.exp2) + create_iabs = lambda self, arg: self.unary_op(arg, np.abs) + create_floor = lambda self, arg: self.unary_op(arg, np.floor) + create_ceil = lambda self, arg: self.unary_op(arg, np.ceil) + create_log = lambda self, arg: self.unary_op(arg, np.log) + create_log2 = lambda self, arg: self.unary_op(arg, np.log2) + create_precise_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sin = lambda self, arg: self.unary_op(arg, np.sin) + + def create_erf(self, arg): + ret = np_erf_fp32(arg.data) if arg.data.dtype == np.float32 else np_erf_fp64(arg.data) + return TensorHandle(ret, arg.dtype.scalar) + + def create_rsqrt(self, arg): + return TensorHandle(1 / np.sqrt(arg.data), arg.dtype.scalar) + + # tensor operators + create_reshape = lambda self, arg, shape, allow_reorder: TensorHandle(arg.data.reshape(shape), arg.dtype.scalar) + + def create_trans(self, arg, perm): + return TensorHandle(np.transpose(arg.data, perm), arg.dtype.scalar) + + def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc): + a_data = a.data + b_data = b.data + if (a.dtype.primitive_bitwidth == 8 and a.dtype.is_floating()) or \ + (b.dtype.primitive_bitwidth == 8 and b.dtype.is_floating()): + a_data = _convert_float(a_data, a.dtype, tl.float16, None).view(np.float16) + b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16) + return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar) + + def create_make_range(self, start, stop): + return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32) + + def create_histogram(self, data, bins): + return TensorHandle(np.histogram(data.data, bins=bins, range=(0, bins))[0], tl.int32) + + # pointer arithmetic + + def create_addptr(self, ptr, offset): + dtype_tt = ptr.get_element_ty() + element_bitwidth = dtype_tt.primitive_bitwidth + # int1's bitwidth is 1, but we need to use 8 for pointer arithmetic + element_bytewidth = max(1, element_bitwidth // 8) + return TensorHandle(ptr.data + element_bytewidth * offset.data.astype(np.uint64), ptr.dtype) + + def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, + is_volatile): + ptrs, masks = ptr.materialize_pointers(boundary_check) + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if padding_option is None: + other = None + elif padding_option == _ir.PADDING_OPTION.PAD_ZERO: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + elif padding_option == _ir.PADDING_OPTION.PAD_NAN: + other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt) + else: + raise ValueError(f"unsupported padding option {padding_option}") + return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile) + + def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): + ptrs, masks = ptr.materialize_pointers(boundary_check) + return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy) + + def create_expand_dims(self, arg, axis): + return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype.scalar) + + def create_broadcast(self, arg, shape): + return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype.scalar) + + def create_int_to_ptr(self, val, dst_ty): + return TensorHandle(val.data.astype(np.uint64), dst_ty.scalar) + + def create_ptr_to_int(self, val, dst_ty): + return TensorHandle(val.data.astype(np.uint64), dst_ty.scalar) + + def create_cat(self, lhs, rhs): + return TensorHandle(np.concatenate([lhs.data, rhs.data]), lhs.dtype.scalar) + + def create_join(self, lhs, rhs): + # Triton only supports joining two original tensors into a new one along the last axis + return TensorHandle(np.stack([lhs.data, rhs.data], axis=-1), lhs.dtype.scalar) + + def create_split(self, val): + # Triton only supports splitting the original tensor into two along the last axis + return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar)) + + def create_splat(self, arg, shape): + if isinstance(arg.dtype, tl.block_type): + return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + else: # scalar + return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + + def create_atomic_cas(self, ptr, cmp, val, sem, scope): + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_cas(ptr.data, cmp.data, val.data, sem), cmp.dtype.scalar) + + def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem, scope): + if rmwOp not in self.ir_rmw_op_to_interpreter_rmw_op: + raise ValueError(f"unsupported rmwOp {rmwOp}") + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + rmwOp = self.ir_rmw_op_to_interpreter_rmw_op[rmwOp] + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar) + + def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): + raise NotImplementedError("extern_elementwise not supported in interpreter mode") + + def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): + raise NotImplementedError("inline_asm not supported in interpreter mode") + + def create_print(self, prefix, hex, values): + # Interpreter's device_print function has a different format than Triton's device_print + msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})" + if prefix: + msg += f" {prefix}" + if hex: + np.set_printoptions(formatter={'all': lambda x: f"0x{x:02x}"}) + for value in values: + print(msg + f" {value.data}") + if hex: + np.set_printoptions(formatter=None) + + def create_assert(self, condition, message, fileName, funcName, lineNo): + # Interpreter's device_assert function has a different format than Triton's device_assert + assert condition, f"{message} in {fileName}:{funcName}:{lineNo}" + + def create_barrier(self): + # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter + pass + + def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order): + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in offsets] + return BlockPointerHandle(base, shape, strides, new_offsets, tensor_shape, order) + + def create_advance(self, ptr, offsets): + if len(ptr.offsets) != len(offsets): + raise ValueError("len(ptr.offsets) != len(offsets)") + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in ptr.offsets] + ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order) + for i in range(len(offsets)): + ret.offsets[i].data += offsets[i].data + return ret + + def get_all_ones_value(self, type): + np_type = _get_np_dtype(type) + if "int" in np_type.name: + return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar) + else: + raise TypeError(f"unsupported type {type}") + + +def _patch_attr(obj, name, member, builder): + new_member = lambda *args, member=member, **kwargs: (member(*args, ** + {k: v + for k, v in kwargs.items() + if k != "_builder"}, _builder=builder)) + setattr(obj, name, new_member) + + +def _patch_builtin(pkg, builder): + for name, member in inspect.getmembers(pkg): + if tl.core.is_builtin(member): + _patch_attr(pkg, name, member, builder) + + +def _patch_lang_tensor(tensor): + + def _get_bool(self): + data = self.handle.data + # in triton, only scalars can be converted to booleans + # here we need this hack because all scalars are tensors + return bool(data) if data.size == 1 else True + + def _get_transpose(self): + return tl.core.tensor(TensorHandle(np.transpose(self.handle.data), self.handle.dtype), self.dtype.scalar) + + tensor.__index__ = lambda self: int(self.handle.data) + tensor.__bool__ = lambda self: _get_bool(self) + tensor.__repr__ = lambda self: repr(self.handle.data) + tensor.__str__ = lambda self: str(self.handle.data) + tensor.T = property(_get_transpose) + + +class ReduceScanOpIneterface: + + def __init__(self, axis, combine_fn): + self.axis = axis + self.combine_fn = combine_fn + + def check_axis(self, shape, axis): + if axis is not None and axis >= len(shape): + raise ValueError(f"axis {axis} out of bounds for shape {shape}") + + def check_tensor(self, input): + for arg in input: + if not isinstance(arg, tl.core.tensor): + raise ValueError(f"input must be a tensor, got {type(arg)}") + self.check_axis(arg.shape, self.axis) + + def to_tensor(self, ret, dtype): + if hasattr(ret, "shape") and ret.shape: + ret_type = tl.block_type(dtype, ret.shape) + else: + ret = np.array([ret], dtype=_get_np_dtype(dtype)) + ret_type = dtype + return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type) + + def apply(self, input): + if not isinstance(input, tuple): + input = (input, ) + self.check_tensor(input) + return self.apply_impl(input) + + def apply_impl(self, input): + raise NotImplementedError("apply_impl not implemented") + + +class ReduceOps(ReduceScanOpIneterface): + + def __init__(self, axis, combine_fn, keep_dims): + super().__init__(axis, combine_fn) + self.keep_dims = keep_dims + + def unravel(self, input, axis): + ret = [] + for data in input: + if axis is not None: + ret.append(data) + else: + axis = 0 + ret.append(self.to_tensor(data.handle.data.flatten(), data.dtype)) + return tuple(ret), axis + + def generic_reduce(self, input): + original_axis = self.axis + input, axis = self.unravel(input, self.axis) + input_data = [] + output_data = [] + input_shape = input[0].handle.data.shape + output_shape = input_shape[0:axis] + input_shape[axis + 1:] + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype)) + # Reduce on axis + for i in range(input_data[0].size): + # Recover input_index from i using input_shape + input_index = np.unravel_index(i, input_shape) + output_index = input_index[0:axis] + input_index[axis + 1:] + input_tuple = tuple(self.to_tensor(d[input_index], input[ii].dtype) for ii, d in enumerate(input_data)) + if input_index[axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][output_index] = input_tuple[j].handle.data.item() + else: + acc_tuple = tuple(self.to_tensor(o[output_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *input_tuple) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][output_index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + if self.keep_dims: + if original_axis is not None: + data = np.expand_dims(data, axis) + else: + for _ in range(len(input_shape)): + data = np.expand_dims(data, 0) + + elif original_axis is None: + # Take a scalar + data = data.item() + ret.append(self.to_tensor(data, input[i].dtype)) + return ret[0] if len(ret) == 1 else tuple(ret) + + def min_max(self, input, val_reduce_op, idx_reduce_op=None): + # If input is a tuple, it must be (val, index), and we only take val + input = input[0] if isinstance(input, tuple) else input + val = None + idx = None + if val_reduce_op: + val = self.to_tensor(val_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + if idx_reduce_op: + idx = self.to_tensor(idx_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), tl.int32) + if val is not None and idx is not None: + return val, idx + elif val is not None: + return val + elif idx is not None: + return idx + else: + raise ValueError("val_reduce_op and idx_reduce_op are both None") + + def sum(self, input): + return self.to_tensor(np.sum(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + + def apply_impl(self, input): + if self.combine_fn == tl.standard._argmin_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=np.argmin) + elif self.combine_fn == tl.standard._argmax_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax) + elif self.combine_fn == tl.standard._elementwise_max: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=None) + elif self.combine_fn == tl.standard._elementwise_min: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=None) + elif self.combine_fn == tl.standard._sum_combine: + return self.sum(input[0]) + else: + # Fall back to the slow mode + return self.generic_reduce(input) + + +class ScanOps(ReduceScanOpIneterface): + + def __init__(self, axis, combine_fn, reverse): + super().__init__(axis, combine_fn) + self.reverse = reverse + + def cumsum(self, input): + return [self.to_tensor(np.cumsum(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def cumprod(self, input): + return [self.to_tensor(np.cumprod(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def generic_scan(self, input): + input_data = [] + output_data = [] + shape = input[0].handle.data.shape + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(shape, dtype=arg.handle.data.dtype)) + # Scan on axis + for i in range(input_data[0].size): + # Recover index from i using shape + index = np.unravel_index(i, shape) + data = tuple(self.to_tensor(d[index], input[ii].dtype) for ii, d in enumerate(input_data)) + if index[self.axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][index] = data[j].handle.data.item() + else: + prev_index = tuple(index[i] - 1 if i == self.axis else index[i] for i in range(len(index))) + acc_tuple = tuple(self.to_tensor(o[prev_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *data) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + ret.append(self.to_tensor(data, input[i].dtype)) + return ret + + def apply_impl(self, input): + new_input = [] + if self.reverse: + for arg in input: + new_input.append(self.to_tensor(np.flip(arg.handle.data, axis=self.axis), arg.dtype)) + else: + new_input = input + if self.combine_fn == tl.standard._sum_combine: + ret = self.cumsum(new_input[0]) + elif self.combine_fn == tl.standard._prod_combine: + ret = self.cumprod(new_input[0]) + else: + # Fall back to the slow mode + ret = self.generic_scan(new_input) + if self.reverse: + for arg in ret: + arg.handle.data = np.flip(arg.handle.data, axis=self.axis) + return len(ret) == 1 and ret[0] or tuple(ret) + + +def _patch_reduce_scan(): + # Because interpreter doesn't support region_builder_fn, we cannot patch the builder + # to use the new reduce and scan functions. + # Instead, we need to patch reduce and reduce functions in tl and tl.core + def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs): + return ReduceOps(axis, combine_fn, keep_dims).apply(input) + + def _new_scan(input, axis, combine_fn, reverse=False, **kwargs): + return ScanOps(axis, combine_fn, reverse).apply(input) + + tl.reduce = _new_reduce + tl.associative_scan = _new_scan + tl.core.reduce = _new_reduce + tl.core.associative_scan = _new_scan + + +def _patch_lang_core(lang): + + def _new_to_ir(self, builder): + # We need to specify signedness for integer types in the numpy mode + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name == 'int8': + return builder.get_int8_ty() + elif self.name == 'uint8': + return builder.get_uint8_ty() + elif self.name == 'int16': + return builder.get_int16_ty() + elif self.name == 'uint16': + return builder.get_uint16_ty() + elif self.name == 'int32': + return builder.get_int32_ty() + elif self.name == 'uint32': + return builder.get_uint32_ty() + elif self.name == 'int64': + return builder.get_int64_ty() + elif self.name == 'uint64': + return builder.get_uint64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + # can't just map lang.static_range to `range`, because `tl.static_range` + # can get `step` passed by keyword + def _new_range(arg1, arg2=None, step=None, **kwargs): + if step is None: + step = 1 + if arg2 is None: + start, end = 0, arg1 + else: + start, end = arg1, arg2 + return range(start, end, step) + + def _new_static_assert(cond, msg=""): + assert cond, msg + + def _set_attr(input, values, name): + # skip non tensor types. This may happen for induction variables. + if not isinstance(input, tl.tensor): + return input + # Unwrap constexpr + values = [values] if not isinstance(values, (list, tuple)) else values + values = [v.value if isinstance(v, tl.constexpr) else v for v in values] + if len(values) != max(1, len(input.shape)): + raise ValueError(f"len(values) != len(input.shape) for {name}") + input.handle.set_attr(name, values) + return input + + lang.range = _new_range + lang.static_range = _new_range + lang.static_assert = _new_static_assert + lang.static_print = print + lang.dtype.to_ir = _new_to_ir + lang.multiple_of = partial(_set_attr, name="tt.divisiblity") + lang.max_contiguous = partial(_set_attr, name="tt.contiguity") + lang.max_constancy = partial(_set_attr, name="tt.constancy") + + _patch_reduce_scan() + + +def _patch_lang(fn): + lang = [value for _, value in fn.__globals__.items() if value in [tl, tl.core]] + assert len(lang) == 1, "triton.language must be visible from within jit'd function" + _patch_builtin(lang[0], interpreter_builder) + _patch_builtin(lang[0].tensor, interpreter_builder) + if lang[0] == tl: + _patch_builtin(lang[0].math, interpreter_builder) + _patch_lang_tensor(lang[0].tensor) + _patch_lang_core(lang[0]) + + +# TODO: wrap everything in triton tensors +def _implicit_cvt(arg): + if isinstance(arg, int): + ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + dtype = np.int32 + if -2**31 <= arg < 2**31: + dtype = np.int32 + elif 2**31 <= arg < 2**32: + dtype = np.uint32 + elif -2**63 <= arg < 2**63: + dtype = np.int64 + elif 2**63 <= arg < 2**64: + dtype = np.uint64 + else: + raise ValueError(f"Unsupported integer value {arg}") + handle = TensorHandle(np.array([arg], dtype=dtype), ty) + return tl.tensor(handle, ty) + if hasattr(arg, "data_ptr"): + ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) + return tl.tensor(handle, ty) + return arg + + +interpreter_builder = InterpreterBuilder() + +# These keywords are not supported by the interpreter +RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"] + + +class GridExecutor: + + def __init__(self, fn, arg_names, grid): + from .jit import _normalize_ty # TODO: modularize + + self.fn = fn + self.arg_names = arg_names + self.grid = grid + __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} + self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"] + + def _init_args_hst(self, args_dev, kwargs): + args_hst = [] + for arg in args_dev: + if hasattr(arg, "data_ptr"): + args_hst.append(arg.cpu()) + else: + args_hst.append(arg) + # Process keyword arguments + kwargs_hst = {} + for key, value in kwargs.items(): + if hasattr(value, "data_ptr"): + kwargs_hst[key] = value.cpu() + else: + kwargs_hst[key] = value + return args_hst, kwargs_hst + + def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst): + for arg_dev, arg_hst in zip(args_dev, args_hst): + if hasattr(arg_dev, "data_ptr"): + arg_dev.data.copy_(arg_hst.to(arg_dev.device).data) + + # Restore keyword arguments + for key, kwarg_dev in kwargs.items(): + kwarg_hst = kwargs_hst[key] + if hasattr(kwarg_dev, "data_ptr"): + kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data) + + def __call__(self, *args_dev, **kwargs): + # removes reserved keywords from kwargs + kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} + if kwargs.pop("warmup", False): + return + # copy arguments to the host + args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) + # remaps core language functions to interpreted ones + _patch_lang(self.fn) + # we need to copy arguments to the host for the interpreter + # implicitly convert tensor arguments to their base pointers + args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst) + args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()} + # iterate through grid + grid = self.grid(args) if callable(self.grid) else self.grid + assert len(grid) <= 3, "grid must have at most 3 dimensions" + grid = grid + (1, ) * (3 - len(grid)) + interpreter_builder.set_grid_dim(*grid) + try: + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + interpreter_builder.set_grid_idx(x, y, z) + self.fn(**args) + except Exception as e: + raise InterpreterError(repr(e)) from e + # copy arguments back to propagate side-effects + self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst) + + +class InterpretedFunction: + + def __init__(self, fn) -> None: + self.fn = fn + + def run(*args, **kwargs): + grid = kwargs["grid"] + return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs) + + self.run = run + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + + @property + def __name__(self): + return self.fn.__name__ + + def __getitem__(self, grid): + return GridExecutor(self.fn, self.arg_names, grid) + + def __call__(self, *args, **kwargs): + # This is a device function call + _patch_lang(self.fn) + try: + return self.fn(*args, **kwargs) + except Exception as e: + raise InterpreterError(repr(e)) from e diff --git a/third_party/mthreads/python/triton/runtime/jit.py b/third_party/mthreads/python/triton/runtime/jit.py new file mode 100644 index 000000000..8a88d76be --- /dev/null +++ b/third_party/mthreads/python/triton/runtime/jit.py @@ -0,0 +1,973 @@ +from __future__ import annotations, division +import ast +import hashlib +import inspect +import itertools +import os +import re +import textwrap +from collections import defaultdict +from functools import cached_property +from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple +from ..runtime.driver import driver +from types import ModuleType + +TRITON_MODULE = __name__[:-len(".runtime.jit")] + +T = TypeVar("T") + + +def get_cuda_stream(idx=None): + if idx is None: + idx = get_current_device() + try: + from torch._C import _cuda_getCurrentRawStream + return _cuda_getCurrentRawStream(idx) + except ImportError: + import torch + return torch.cuda.current_stream(idx).cuda_stream + + +def get_current_device(): + import torch + return torch.cuda.current_device() + + +# ----------------------------------------------------------------------------- +# Dependencies Finder +# ----------------------------------------------------------------------------- + + +class DependenciesFinder(ast.NodeVisitor): + """ + This AST visitor is used to find dependencies of a JITFunction. This can + be used to invalidate a JITFunction's hash when its source code -- or + that of its dependencies -- changes. + + This visitor also keeps track of the global variables touched by the + JITFunction. When we launch the kernel, we check that these have the same + values as they did when we ran this visitor. If not, we raise an error (or + otherwise we could recompile). + """ + + def __init__(self, name, globals, src) -> None: + super().__init__() + self.name = name + self.hasher = hashlib.sha256(src.encode("utf-8")) + + # This function's __globals__ dict. + self.globals = globals + + # Python builtins that can be accessed from Triton kernels. + self.supported_python_builtins = { + 'float', + 'getattr', + 'int', + 'isinstance', + 'len', + 'list', + 'max', + 'min', + 'print', + 'range', + } + + # used_global_vals tells us which global variables are used by this + # function and all those it transitively calls, plus the values of those + # variables when each function was initially run. (That is, if A calls + # C, and B calls C, then the values for C in used_global_vals will be + # from the first time C was run, either by A or B.) + # + # Each function may have a different __globals__ dict, so the global + # variable `foo` may actually have a different value in the different + # functions. Thus this map is actually + # (var_name, id(__globals__)) -> (var_value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + self.visiting_arg_default_value = False + + @property + def ret(self): + return self.hasher.hexdigest() + + def visit_Name(self, node): + if type(node.ctx) == ast.Store: + return node.id + + if node.id in self.local_names: + # The global name is hidden by the local name. + return None + + val = self.globals.get(node.id, None) + + # Only keep track of "interesting" global variables, that non-evil users + # might change. Don't consider functions, modules, builtins, etc. This + # helps keep the list of vars we have to check small. + if (val is not None # + # Python default arguments are resolved only once, when the + # function is defined. So if you do `foo(a=A)` and the value of + # A changes, foo will still use the old value of A. + and not self.visiting_arg_default_value + # It would be pretty evil if someone did `import x` and then + # `x = blah`. + and type(val) != ModuleType + # It would be pretty evil if we used function `foo` inside of + # `bar` and then someone did `foo = baz`. + and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) # + and node.id not in self.supported_python_builtins # + ): + self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals) + + return val + + def visit_Tuple(self, node): + # We need to explicitly return the tuple values so that visit_Assign can + # access them in the case of `a, b = ...`. + return [self.visit(elt) for elt in node.elts] + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE): + return None + return getattr(lhs, node.attr) + + def visit_Call(self, node): + + def is_triton_builtin(func): + if inspect.isbuiltin(node.func): + return True + module = getattr(func, "__module__", "") + return module.startswith(TRITON_MODULE) + + func = self.visit(node.func) + assert func is None or is_triton_builtin(func) or isinstance( + func, JITFunction + ), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this' + + # Traverse arguments as well as node.func so we can find JITFunctions + # passed to tl.reduce or tl.associative_scan as the combine_fn + for obj in itertools.chain( + (func, ), + map(self.visit, node.args), + (self.visit(kw.value) for kw in node.keywords), + ): + if not isinstance(obj, JITFunction): + continue + if is_triton_builtin(obj): + continue + + func_cache_key = obj.cache_key + + # Merge our used_global_vals with those of the called function, + # after checking that all overlapping values are consistent. + for k in self.used_global_vals.keys() & obj.used_global_vals.keys(): + var_name, _ = k + v1, _ = self.used_global_vals[k] + v2, _ = obj.used_global_vals[k] + if v1 != v2: + raise RuntimeError( + f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed." + ) + + self.used_global_vals.update(obj.used_global_vals) + + noinline = str(getattr(obj, "noinline", False)) + + key = func_cache_key + noinline + self.hasher.update(key.encode("utf-8")) + + def visit_FunctionDef(self, node): + # Save the local name, which may hide the global name. + self.local_names = {arg.arg for arg in node.args.args} + self.generic_visit(node) + + def visit_arguments(self, node): + # The purpose of this function is to visit everything in `arguments` + # just like `generic_visit`, except when we're visiting default values + # (i.e. the `foo` part of `def fn(x = foo)`), we set + # self.visiting_arg_default_value = True. This allows visit_Name to be + # aware that we're inside function default values, which have special + # semantics. + + # According to the AST docs, the arguments node has the following structure. + # + # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, + # expr* kw_defaults, arg? kwarg, expr* defaults) + def visit_defaults(defaults): + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + for expr in defaults: + if expr is not None: + self.visit(expr) + finally: + self.visiting_arg_default_value = False + + for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs): + self.visit(arg) + + visit_defaults(node.kw_defaults) + + if node.kwarg is not None: + self.visit(node.kwarg) + + visit_defaults(node.defaults) + + def visitAssnTarget(self, node): + # Target is either a single string, or a list of strings (if the assn + # target is a tuple). + target = self.visit(node) + if isinstance(target, list): + self.local_names |= set(target) + else: + self.local_names.add(target) + + def visit_Assign(self, node): + if len(node.targets) != 1: + # TODO(jlebar): I don't actually know how to hit this. You don't + # get it from `a, b = ...` -- in that case, node.targets is a single + # Tuple, and in fact we *do* need to handle that case if we want + # existing code to work. + raise TypeError("Simultaneous multiple assignment is not supported.") + + self.visitAssnTarget(node.targets[0]) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_AnnAssign(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_For(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's fine. + self.generic_visit(node) + + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + + +def _normalize_ty(ty) -> str: + if isinstance(ty, type): + return ty.__name__ + elif isinstance(ty, str): + return ty + return repr(ty) + + +class KernelParam: + """Represents a parameter (name plus metadata) to a @jit'ed function.""" + + def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool): + self.num = num + self._param = param + self.do_not_specialize = do_not_specialize + + @cached_property + def name(self): + return self._param.name + + @cached_property + def annotation(self): + if not self._param.annotation or self._param.annotation == inspect.Parameter.empty: + return "" + return _normalize_ty(self._param.annotation) + + @cached_property + def annotation_type(self): + annotation = self.annotation + for ty1, ty2 in [("uint", 'u'), ("int", 'i')]: + width = annotation[annotation.find(ty1) + len(ty1):] + if width and ty1 in annotation: + return f"{ty2}{width}" + if annotation == "bool": + return "u1" + return "" + + @cached_property + def is_constexpr(self): + return "constexpr" in self.annotation + + @cached_property + def is_const(self): + return "const" in self.annotation and not self.is_constexpr + + @property + def default(self): + return self._param.default + + @property + def has_default(self): + return self._param.default != inspect.Parameter.empty + + +def compute_spec_key(v): + + if hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0): + return "D" + elif isinstance(v, int): + # bool is a subclass of int, so we don't check explicitly above. + if (v % 16 == 0): + return "D" + elif v == 1: + return "1" + return "N" + + +dtype2str = {} + + +def mangle_type(arg, is_const=False): + + if arg is None: + return "none" + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + else: + # dtypes are hashable so we can memoize this mapping: + dsk = (arg.dtype, is_const) + res = dtype2str.get(dsk, None) + if res is None: + res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]] + dtype2str[dsk] = res + return res + + +class KernelInterface(Generic[T]): + run: T + + def __getitem__(self, grid) -> T: + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. + """ + return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) + + +def serialize_specialization_data(name, signature, constants, attrs, options, key): + constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} + import json + obj = { + 'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options': + options.__dict__, 'key': key + } + serialized_obj = json.dumps(obj) + return serialized_obj + + +def create_function_from_signature(sig, kparams): + """ + Equivalent to sig.bind followed by apply_defaults. This generates a + native Python function (using exec) which can be memoized on a per-kernel + basis to avoid having to run these expensive functions -- which constitute + much of the kernel launch overhead -- every time we run the kernel. + """ + + assert len(sig.parameters) == len(kparams) + + # Create the function argument list and the dict entries for the return statement + func_args = [] + dict_entries = [] + constexpr_vals = [] + non_constexpr_vals = [] + signature_types = [] + specialisations = [] + + for ((name, sp), kp) in zip(sig.parameters.items(), kparams): + if sp.default is inspect.Parameter.empty: + func_args.append(name) + dict_entries.append(f"'{name}': {name}") + else: + func_args.append(f"{name}=default_{name}") + dict_entries.append(f"'{name}': {name}") + if kp.is_constexpr: + constexpr_vals.append(name) + else: + non_constexpr_vals.append(name) + if not kp.do_not_specialize: + specialisations.append('compute_spec_key(%s)' % name) + if kp.annotation_type: + signature_types.append('"%s"' % kp.annotation_type) + else: + signature_types.append('mangle_type(%s, %s)' % (name, 'True' if kp.is_const else 'False')) + + cache_key = ''.join([x + ', ' for x in signature_types + specialisations]) + constexpr_vals = ''.join([x + ', ' for x in constexpr_vals]) + non_constexpr_vals = ''.join([x + ', ' for x in non_constexpr_vals]) + + func_args.append('**excess_kwargs') + + # Join all arguments into a function definition string + args_str = ', '.join(func_args) + dict_str = ', '.join(dict_entries) + func_body = "def dynamic_func(%s):\n return {%s}, (%s), (%s), (%s), excess_kwargs" % ( + args_str, dict_str, cache_key, constexpr_vals, non_constexpr_vals) + + # Prepare defaults to be inserted into function namespace + func_namespace = { + f"default_{name}": param.default + for name, param in sig.parameters.items() + if param.default is not inspect.Parameter.empty + } + + func_namespace['mangle_type'] = mangle_type + func_namespace['compute_spec_key'] = compute_spec_key + + # Execute the function string in func_namespace to create the function + exec(func_body, func_namespace) + + # Extract the newly created function from the namespace + return func_namespace['dynamic_func'] + + +type_canonicalisation_dict = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8_e4m3fn": "fp8e4nv", + "float8e4b8": "fp8e4b8", + "float8_e4m3fnuz": "fp8e4b8", + "float8_e5m2": "fp8e5", + "float8e5b16": "fp8e5b16", + "float8_e5m2fnuz": "fp8e5b16", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", +} + +for v in list(type_canonicalisation_dict.values()): + type_canonicalisation_dict[v] = v + + +class JITFunction(KernelInterface[T]): + # Hook for inspecting compiled functions and modules + cache_hook = None + divisibility = 16 + + @staticmethod + def _key_of(arg): + if hasattr(arg, "dtype"): + return arg.dtype + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + elif arg is None: + return None + else: + raise TypeError(f"Unsupported type {type(arg)} for {arg}") + + @staticmethod + def _spec_of(arg): + if hasattr(arg, "data_ptr"): + return arg.data_ptr() % JITFunction.divisibility == 0 + elif isinstance(arg, int): + return (arg % 16 == 0, arg == 1) + return (arg is None, ) + + def _get_config(self, *args): + from ..compiler import AttrsDescriptor + + def is_divisible_by_16(x): + if hasattr(x, "data_ptr"): + return x.data_ptr() % JITFunction.divisibility == 0 + elif isinstance(x, int): + return x % JITFunction.divisibility == 0 + if x is None: + return True + return False + + divisible_by_16 = { + param.num + for param, arg in zip(self.params, args) + if is_divisible_by_16(arg) and not param.do_not_specialize + } + equal_to_1 = { + param.num + for param, arg in zip(self.params, args) + if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize + } + # folded equal_to_1 and None + # TODO: method to collect all folded args + return AttrsDescriptor(tuple(divisible_by_16), tuple(equal_to_1)) + # return _triton.code_gen.instance_descriptor(divisible_by_16, + # equal_to_1) + + @staticmethod + def _type_of(key, is_const=False): + # `None` is nullptr. Implicitly convert to *i8. + if key is None: + return "*i8" + elif isinstance(key, str): + return key + + dtype_str = str(key).split(".")[-1] + dtype_str = type_canonicalisation_dict[dtype_str] + const_str = "*k" if is_const else "*" + return const_str + dtype_str + + def _make_constants(self, constexpr_key): + constants = dict(zip(self.constexprs, constexpr_key)) + return constants + + def _call_hook( + self, + key, + signature, + device, + constants, + options, + configs, + ): + if JITFunction.cache_hook is None: + return False + + name = self.fn.__name__ + module = self.fn.__module__ + arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])]) + repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})" + + class JitFunctionInfo: + + def __init__(self, module, name, jit_function): + self.module = module + self.name = name + self.jit_function = jit_function + pass + + specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key) + + kwargs = { + 'signature': signature, + 'device': device, + 'constants': constants, + 'num_warps': options.num_warps, + 'num_ctas': options.num_ctas, + 'num_stages': options.num_stages, + 'enable_fp_fusion': options.enable_fp_fusion, + 'extern_libs': options.extern_libs, + 'configs': configs, + 'specialization_data': specialization_data, + } + + return JITFunction.cache_hook( + key=key, + repr=repr, + fn=JitFunctionInfo(module, name, self), + compile={"key": key, **kwargs}, + is_manual_warmup=False, + already_compiled=False, + ) + + def add_pre_run_hook(self, hook): + ''' + Add a hook that will be executed prior to the execution of run + function with args and kwargs passed into the kernel + ''' + assert callable(hook) + self.pre_run_hooks.append(hook) + + def create_binder(self): + """ + Precompute as much as possible. + """ + from ..compiler import CompiledKernel, compile, ASTSource, make_backend + self.CompiledKernel = CompiledKernel + self.compile = compile + self.ASTSource = ASTSource + self.make_backend = make_backend + self.binder = create_function_from_signature(self.signature, self.params) + self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr] + self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr] + self.specialised_indices = [ + i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr) + ] + + def run(self, *args, grid, warmup, **kwargs): + # parse options + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + kwargs["debug"] = self.debug + + # Execute pre run hooks with args and kwargs + for hook in self.pre_run_hooks: + hook(*args, **kwargs) + + if self.binder is None: + self.create_binder() + + bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs) + + # compute cache key + key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs)) + kernel = self.cache[device].get(key, None) + + if kernel is None: + # Kernel is not cached; we have to compile. + target = driver.active.get_current_target() + backend = self.make_backend(target) + options = backend.parse_options(kwargs) + + # deprecated arguments + assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" + assert "device" not in kwargs, "device option is deprecated; current device will be used" + assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + for k in excess_kwargs: + if k not in options.__dict__: + raise KeyError("Keyword argument %s was specified but unrecognised" % k) + + bound_vals = tuple(bound_args.values()) + + # `None` is nullptr. Implicitly convert to *i8. This needs to be + # done here rather than when we build the signature as otherwise + # the kernel cache key could not distinguish between byte pointers + # and None arguments, resulting in a downstream mismatch: + sigkeys = [self.params[i].name for i in self.non_constexpr_indices] + sigvals = sig_and_spec[:len(sigkeys)] + signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} + + configs = (self._get_config(*bound_vals), ) + constants = { + p.name: v + for (v, p) in zip(bound_vals, self.params) + if p.is_constexpr or p.num in configs[0].equal_to_1 or v is None + } + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index {i} is not supported") + + if self._call_hook(key, signature, device, constants, options, configs): + return None + # compile the kernel + src = self.ASTSource(self, signature, constants, configs[0]) + kernel = self.compile( + src, + target=target, + options=options.__dict__, + ) + self.cache[device][key] = kernel + + # Check that used global values have not changed. + not_present = object() + for (name, globals_dict_id), (val, globals_dict) in self.used_global_vals.items(): + if (newVal := globals_dict.get(name, not_present)) != val: + raise RuntimeError( + f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}") + + if not warmup: + # canonicalize grid + assert grid is not None + if callable(grid): + # Arguments are passed as a dict to `grid`, by contract. + # TODO(jlebar): In the new launch API, pass the compiler flags as a + # second parameter to `grid`. + grid = grid(bound_args) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + + # launch kernel + launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) + kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, + self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals) + return kernel + + def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None, repr=None, + launch_metadata=None): + do_not_specialize = do_not_specialize if do_not_specialize else [] + + self.fn = fn + self.module = fn.__module__ + self.version = version + self.signature = inspect.signature(fn) + self.do_not_specialize = do_not_specialize + self.starting_line_number = inspect.getsourcelines(fn)[1] + self.repr = lambda _: fn.__name__ if repr is None else repr(_) + self.launch_metadata = launch_metadata + + self.binder = None + + self.params = [] + for i, param in enumerate(self.signature.parameters.values()): + dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize) + self.params.append(KernelParam(i, param, dns)) + + # function source code (without decorators) + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():] + # cache of just-in-time compiled kernels + self.cache = defaultdict(dict) + self.hash = None + + # Map of global variables used by the function and any functions it + # transitively calls, plus their values. The values are collected when + # the function is first compiled. Then every time we run the function, + # we check that the values of the globals match what's expected, + # otherwise we raise an error. + # + # Different functions can have different __globals__ maps, so the map + # key is actually (var name, id(__globals__)), and the map value is + # (value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel = None + self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug + self.noinline = noinline + + # TODO(jlebar): Remove uses of these fields outside this file, then + # remove the fields here. + self.arg_names = [p.name for p in self.params] + self.constexprs = [p.num for p in self.params if p.is_constexpr] + + # Hooks that will be called prior to executing "run" + self.pre_run_hooks = [] + + # reuse docs of wrapped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + @property + def cache_key(self): + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + str(self.starting_line_number) + self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items())) + return self.hash + + def warmup(self, *args, grid, **kwargs): + return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) + + def preload(self, specialization_data): + from ..compiler import AttrsDescriptor, compile, ASTSource + import json + import triton.language as tl + device = driver.active.get_current_device() + deserialized_obj = json.loads(specialization_data) + if deserialized_obj['name'] != self.fn.__name__: + raise RuntimeError( + f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}") + constants = { + key: tl.dtype(value) if tl.dtype.is_dtype(value) else value + for key, value in deserialized_obj['constants'].items() + } + signature = dict(deserialized_obj['signature'].items()) + src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) + options = { + key: tuple(value) if isinstance(value, list) else value + for key, value in deserialized_obj['options'].items() + } + key = deserialized_obj['key'] + kernel = compile(src, None, options) + self.cache[device][key] = kernel + return kernel + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Our unit tests do this, for example. + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + def __setattr__(self, name, value): + super(JITFunction, self).__setattr__(name, value) + # - when `.src` attribute is set, cache path needs + # to be reinitialized + if name == "src": + self.hash = None + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__name__})" + + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + + +@overload +def jit(fn: T) -> JITFunction[T]: + ... + + +@overload +def jit( + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Callable[[T], JITFunction[T]]: + ... + + +def jit( + fn: Optional[T] = None, + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, arguments are + implicitly converted to pointers if they have a :code:`.data_ptr()` method + and a `.dtype` attribute. + + :note: This function will be compiled and run on the GPU. It will only have access to: + + * python primitives, + * builtins within the triton package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + + def decorator(fn: T) -> JITFunction[T]: + assert callable(fn) + if os.getenv("TRITON_INTERPRET", "0") == "1": + from .interpreter import InterpretedFunction + return InterpretedFunction(fn) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + debug=debug, + noinline=noinline, + repr=repr, + launch_metadata=launch_metadata, + ) + + if fn is not None: + return decorator(fn) + + else: + return decorator + + +# ----------------------------------------------------------------------------- +# Utilities for mocking tensors +# ----------------------------------------------------------------------------- + + +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) + """ + + @staticmethod + def wrap_dtype(arg): + if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch": + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + @staticmethod + def data_ptr(): + return 0 # optimistically assumes multiple of 16 + + +class TensorWrapper: + + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.data = base.data + self.device = base.device + self.shape = self.base.shape + + def data_ptr(self): + return self.base.data_ptr() + + def stride(self, i): + return self.base.stride(i) + + def __str__(self) -> str: + return f"TensorWrapper[{self.dtype}]({self.base})" + + def element_size(self): + return self.base.element_size() + + def cpu(self): + return TensorWrapper(self.base.cpu(), self.dtype) + + def copy_(self, other): + self.base.copy_(other.base) + + def to(self, device): + return TensorWrapper(self.base.to(device), self.dtype) + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif hasattr(tensor, "data_ptr"): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f"Cannot reinterpret a {type(tensor)}.") diff --git a/third_party/mthreads/python/triton/testing.py b/third_party/mthreads/python/triton/testing.py new file mode 100644 index 000000000..0c8d4bcea --- /dev/null +++ b/third_party/mthreads/python/triton/testing.py @@ -0,0 +1,496 @@ +import functools +import os +import subprocess +import sys +from contextlib import contextmanager +from typing import Any, Dict, List +from . import language as tl + + +def nvsmi(attrs): + attrs = ','.join(attrs) + cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] + out = subprocess.check_output(cmd) + ret = out.decode(sys.stdout.encoding).split(',') + ret = [int(x) for x in ret] + return ret + + +def do_bench_cudagraph(fn, rep=20, grad_to_none=None, return_mode="mean"): + """ + Benchmark the runtime of the provided function. + + :param fn: Function to benchmark + :type fn: Callable + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + """ + import torch + assert return_mode in ["min", "max", "mean", "median"] + + if torch.cuda.current_stream() == torch.cuda.default_stream(): + raise RuntimeError("Cannot capture graph in default stream. Please use side stream in benchmark code.") + # warmup + fn() + # step 1 - we estimate the amount of time the kernel call takes + # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point + # but it is probably good enough + if grad_to_none is not None: + for x in grad_to_none: + x.detach_() + x.requires_grad_(True) + x.grad = None + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) + n_repeat = max(1, int(rep / estimate_ms)) + # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize + # host overhead + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for i in range(n_repeat): + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + fn() + torch.cuda.synchronize() + # measure time and return + ret = [] + n_retries = 10 + for i in range(n_retries): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + ret += [start_event.elapsed_time(end_event) / n_repeat] + times = torch.tensor(ret) + return getattr(torch, return_mode)(times).item() + + +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean", + device_type="cuda"): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float] + :param fast_flush: Use faster kernel to flush L2 between measurements + :type fast_flush: bool + """ + assert return_mode in ["min", "max", "mean", "median"] + import torch + + di = torch._dynamo.device_interface.get_interface_for_device(device_type) + + fn() + di.synchronize() + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + if fast_flush: + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=device_type) + else: + cache = torch.empty(int(256e6), dtype=torch.int8, device=device_type) + + # Estimate the runtime of the function + start_event = di.Event(enable_timing=True) + end_event = di.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + di.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + start_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + di.synchronize() + times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + return getattr(torch, return_mode)(times).item() + + +def assert_close(x, y, atol=None, rtol=None, err_msg=''): + import numpy as np + import torch + + # canonicalize arguments to be tensors + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) + # absolute tolerance + if atol is None: + atol = 1e-2 + atol = atol(x.dtype) if callable(atol) else atol + # relative tolerance hook + if rtol is None: + rtol = 0. + rtol = rtol(x.dtype) if callable(rtol) else rtol + # we use numpy instead of pytorch + # as it seems more memory efficient + # pytorch tends to oom on large tensors + if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() + y = y.cpu().detach().numpy() + # we handle size==1 case separately as we can + # provide better error message there + if x.size > 1 or y.size > 1: + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True) + return + if not np.allclose(x, y, atol=atol, rtol=rtol): + raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') + + +class Benchmark: + """ + This class is used by the :code:`perf_report` function to generate line plots with a concise API. + """ + + def __init__( + self, + x_names: List[str], + x_vals: List[Any], + line_arg: str, + line_vals: List[Any], + line_names: List[str], + plot_name: str, + args: Dict[str, Any], + xlabel: str = '', + ylabel: str = '', + x_log: bool = False, + y_log: bool = False, + color=None, + styles=None, + ): + """ + Constructor. + x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list + of scalars and there are multiple x_names, all arguments will have the same value. + If x_vals is a list of tuples/lists, each element should have the same length as + x_names. + + :param x_names: Name of the arguments that should appear on the x axis of the plot. + :type x_names: List[str] + :param x_vals: List of values to use for the arguments in :code:`x_names`. + :type x_vals: List[Any] + :param line_arg: Argument name for which different values correspond to different lines in the plot. + :type line_arg: str + :param line_vals: List of values to use for the arguments in :code:`line_arg`. + :type line_vals: List[Any] + :param line_names: Label names for the different lines. + :type line_names: List[str] + :param plot_name: Name of the plot. + :type plot_name: str + :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark. + :type args: Dict[str, Any] + :param xlabel: Label for the x axis of the plot. + :type xlabel: str, optional + :param ylabel: Label for the y axis of the plot. + :type ylabel: str, optional + :param x_log: Whether the x axis should be log scale. + :type x_log: bool, optional + :param y_log: Whether the y axis should be log scale. + :type y_log: bool, optional + """ + self.x_names = x_names + self.x_vals = x_vals + self.x_log = x_log + self.line_arg = line_arg + self.line_vals = line_vals + self.line_names = line_names + self.y_log = y_log + self.styles = styles + # plot info + self.xlabel = xlabel + self.ylabel = ylabel + self.plot_name = plot_name + self.args = args + + +class Mark: + + def __init__(self, fn, benchmarks): + self.fn = fn + self.benchmarks = benchmarks + + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, + save_precision=6, **kwrags): + import os + + import matplotlib.pyplot as plt + import pandas as pd + y_mean = bench.line_names + y_min = [f'{x}-min' for x in bench.line_names] + y_max = [f'{x}-max' for x in bench.line_names] + x_names = list(bench.x_names) + df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max) + for x in bench.x_vals: + # x can be a single value or a sequence of values. + if not isinstance(x, (list, tuple)): + x = [x for _ in x_names] + + if len(x) != len(x_names): + raise ValueError(f"Expected {len(x_names)} values, got {x}") + x_args = dict(zip(x_names, x)) + + row_mean, row_min, row_max = [], [], [] + for y in bench.line_vals: + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags) + try: + y_mean, y_min, y_max = ret + except TypeError: + y_mean, y_min, y_max = ret, None, None + row_mean += [y_mean] + row_min += [y_min] + row_max += [y_max] + df.loc[len(df)] = list(x) + row_mean + row_min + row_max + + if bench.plot_name: + plt.figure() + ax = plt.subplot() + # Plot first x value on x axis if there are multiple. + first_x = x_names[0] + for i, y in enumerate(bench.line_names): + y_min, y_max = df[y + '-min'], df[y + '-max'] + col = bench.styles[i][0] if bench.styles else None + sty = bench.styles[i][1] if bench.styles else None + ax.plot(df[first_x], df[y], label=y, color=col, ls=sty) + if not y_min.isnull().all() and not y_max.isnull().all(): + y_min = y_min.astype(float) + y_max = y_max.astype(float) + ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col) + ax.legend() + ax.set_xlabel(bench.xlabel or first_x) + ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) + ax.set_xscale("log" if bench.x_log else "linear") + ax.set_yscale("log" if bench.y_log else "linear") + if show_plots: + plt.show() + if save_path: + plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) + df = df[x_names + bench.line_names] + if diff_col and df.shape[1] == 2: + col0, col1 = df.columns.tolist() + df['Diff'] = df[col1] - df[col0] + + if print_data: + print(bench.plot_name + ':') + print(df.to_string()) + if save_path: + df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f", + index=False) + return df + + def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs): + has_single_bench = isinstance(self.benchmarks, Benchmark) + benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks + result_dfs = [] + if save_path: + # Create directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + html = open(os.path.join(save_path, "results.html"), "w") + html.write("\n") + for bench in benchmarks: + result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs)) + if save_path: + html.write(f"\n") + if save_path: + html.write("\n") + html.close() + if return_df: + if has_single_bench: + return result_dfs[0] + else: + return result_dfs + return None + + +def perf_report(benchmarks): + """ + Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value. + + :param benchmarks: Benchmarking configurations. + :type benchmarks: List of :class:`Benchmark` + """ + wrapper = lambda fn: Mark(fn, benchmarks) + return wrapper + + +def get_dram_gbps(device=None): + ''' return DRAM bandwidth in GB/s ''' + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"] # in kHz + bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"] + bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s + return bw_gbps + + +def get_max_tensorcore_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8: + assert dtype == torch.float16 + ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores + else: + if dtype in [torch.float32, torch.int32]: + ops_per_sub_core = 256 + elif dtype in [torch.float16, torch.bfloat16, torch.int16]: + ops_per_sub_core = 512 + elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]: + ops_per_sub_core = 1024 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops + + +# create decorator that wraps test function into +# a cuda-memcheck system call + + +def cuda_memcheck(**target_kwargs): + + def decorator(test_fn): + + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + import psutil + ppid_name = psutil.Process(os.getppid()).name() + run_cuda_memcheck = target_kwargs.items() <= kwargs.items() + if run_cuda_memcheck and ppid_name != "cuda-memcheck": + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} + assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" + test_id = kwargs['request'].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" + assert "ERROR SUMMARY: 0 errors" in str(out.stdout) + else: + test_fn(*args, **kwargs) + + return wrapper + + return decorator + + +@contextmanager +def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): + try: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ]) + cur_sm_clock = nvsmi(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi(["clocks.current.memory"])[0] + assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" + assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" + tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock + gbps = 640 * 2 * ref_mem_clock * 1e-3 + yield tflops, gbps + finally: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) + + +def get_max_simd_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + if dtype == torch.float32: + ops_per_sub_core = 32 # 2*16 + elif dtype == torch.float16: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + else: + if dtype == torch.float32: + ops_per_sub_core = 32 + elif dtype in [torch.float16, torch.bfloat16]: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops diff --git a/third_party/mthreads/python/triton/tools/__init__.py b/third_party/mthreads/python/triton/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/mthreads/python/triton/tools/build_extern.py b/third_party/mthreads/python/triton/tools/build_extern.py new file mode 100644 index 000000000..8f0168d59 --- /dev/null +++ b/third_party/mthreads/python/triton/tools/build_extern.py @@ -0,0 +1,365 @@ +import argparse +import subprocess +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + + +class Symbol: + _name: str + _op_name: str + _ret_type: str + _arg_names: List[str] + _arg_types: List[str] + + def __init__( + self, + name: str, + op_name: str, + ret_type: str, + arg_names: List[str], + arg_types: List[str], + ) -> None: + ''' + A symbol is a function declaration. + :param name: name of the symbol + :param op_name: name of the operation + :param ret_type: return type of the operation + :param arg_names: names of the arguments + :param arg_types: types of the arguments + ''' + self._name = name + self._op_name = op_name + self._ret_type = ret_type + self._arg_names = list(arg_names) + self._arg_types = list(arg_types) + + @property + def name(self) -> str: + return self._name + + @property + def op_name(self) -> str: + return self._op_name + + @property + def ret_type(self) -> str: + return self._ret_type + + @property + def arg_names(self) -> List[str]: + return self._arg_names + + @property + def arg_types(self) -> List[str]: + return self._arg_types + + +def convert_type(type_str) -> Optional[str]: + if type_str == "i32": + return "int32" + elif type_str == "u32": + return "uint32" + elif type_str == "i64": + return "int64" + elif type_str == "u64": + return "uint64" + elif type_str == "float": + return "fp32" + elif type_str == "double": + return "fp64" + else: + # ignore other types, such as pointer types + return None + + +def to_unsigned(type_str) -> str: + if type_str == "int32": + return "uint32" + elif type_str == "int64": + return "uint64" + else: + return type_str + + +class ExternLibrary(ABC): + _name: str + _path: str + _symbols: Dict[str, Symbol] + _format: bool + _grouping: bool + + def __init__( + self, + name: str, + path: str, + format: bool = True, + grouping: bool = True, + ) -> None: + ''' + Abstract class for extern library. + :param name: name of the library + :param path: path of the library + :param format: whether to format the generated stub file + ''' + self._name = name + self._path = path + self._symbols = {} + self._format = format + self._grouping = grouping + + @property + def name(self) -> str: + return self._name + + @property + def path(self) -> str: + return self._path + + @property + def symbols(self) -> Dict[str, Symbol]: + return self._symbols + + @property + def grouping(self) -> bool: + return self._grouping + + @abstractmethod + def parse_symbols(self, input_file) -> None: + pass + + @abstractmethod + def _output_stubs(self) -> str: + pass + + def generate_stub_file(self, output_dir) -> None: + file_str = self._output_stubs() + if file_str is None or len(file_str) == 0: + raise Exception("file_str is empty") + + output_file = f"{output_dir}/{self._name}.py" + with open(output_file, "w") as f: + f.write(file_str) + f.close() + if self._format: + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate() + subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() + + +class Libdevice(ExternLibrary): + _symbol_groups: Dict[str, List[Symbol]] + + def __init__(self, path) -> None: + ''' + Constructor for Libdevice. + :param path: path of the libdevice library + ''' + super().__init__("libdevice", path) + self._symbol_groups = {} + self.is_pure = True + + @staticmethod + def _extract_symbol(line) -> Optional[Symbol]: + # Extract symbols from line in the following format: + # "define [internal] @(,)" + entries = line.split("@") + ret_str = entries[0] + func_str = entries[1] + # Get ret_type, skip internal symbols + ret_strs = ret_str.split() + if ret_strs[1] == "internal": + return None + ret_type = convert_type(ret_strs[1]) + if ret_type is None: + return None + # Get function name + func_strs = func_str.split("(") + func_name = func_strs[0].replace("@", "") + op_name = func_name.replace("__nv_", "") + if 'ieee' in op_name: + return None + # Get arg_types + arg_strs = func_strs[1].split(",") + arg_types = [] + arg_names = [] + for i, arg_str in enumerate(arg_strs): + arg_type = convert_type(arg_str.split()[0]) + if arg_type is None: + return None + arg_name = 'arg' + str(i) + arg_types.append(arg_type) + arg_names.append(arg_name) + if op_name == "sad": + # Special case for sad, where the last argument is an unsigned int + arg_types[-1] = to_unsigned(arg_types[-1]) + elif op_name.startswith("u"): + # LLVM does not differentiate between signed and unsigned integer type. + # We have to convert the types to unsigned + ret_type = to_unsigned(ret_type) + for i, arg_type in enumerate(arg_types): + arg_types[i] = to_unsigned(arg_type) + return Symbol(func_name, op_name, ret_type, arg_names, arg_types) + + def _group_symbols(self) -> None: + symbol_set = {} + for symbol in self._symbols.values(): + op_name = symbol.op_name + symbol_set[op_name] = symbol + + # Group functions together by renaming. + renaming = { + 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': + 'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz': + 'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh', + 'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos', + 'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', + 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', + 'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf': + 'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2', + 'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll': + 'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru', + 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff': + 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', + 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f': + 'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax': + 'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min', + 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', + 'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24', + 'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': + 'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv', + 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', + 'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru', + 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', + 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt', + 'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit', + 'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd': + 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', + 'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn', + 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', + 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf': + 'yn' + } + + for symbol in self._symbols.values(): + op_name = symbol.op_name + if op_name in renaming: + op_name = renaming[op_name] + symbol._op_name = op_name + if op_name in self._symbol_groups: + self._symbol_groups[op_name].append(symbol) + else: + self._symbol_groups[op_name] = [symbol] + + def parse_symbols(self, input_file) -> None: + if len(self.symbols) > 0: + return + output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() + for line in output: + symbol = self._extract_symbol(line) + if symbol is None: + continue + self._symbols[symbol.name] = symbol + + self._group_symbols() + + def _output_stubs(self) -> str: + # Generate python functions in the following format: + # @extern.extern + # def (, _builder=None): + # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}} + # return core.extern_elementwise("libdevice", , , , _builder) + import_str = "from . import core\n" + + header_str = "" + func_str = "" + for symbols in self._symbol_groups.values(): + func_str += "@core.extern\n" + func_name_str = f"def {symbols[0].op_name}(" + for arg_name in symbols[0].arg_names: + func_name_str += f"{arg_name}, " + func_name_str += "_builder=None):\n" + + return_str = f"\treturn core.extern_elementwise(\"{self._name}\", libdevice_path(), [" + for arg_name in symbols[0].arg_names: + return_str += f"{arg_name}, " + return_str += "], \n" + + arg_type_symbol_dict_str = "{" + for symbol in symbols: + arg_type_symbol_dict_str += "(" + for arg_type in symbol.arg_types: + arg_type_symbol_dict_str += f'core.dtype("{arg_type}"),' + ret_type = f'core.dtype("{symbol.ret_type}")' + arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n" + arg_type_symbol_dict_str += "}" + + return_str += arg_type_symbol_dict_str + return_str += f", is_pure={self.is_pure}" + return_str += ", _builder=_builder)\n" + + func_str += func_name_str + return_str + "\n" + file_str = import_str + header_str + func_str + + return file_str + + +class LLVMDisassembler: + _path: str + _ll_file: str + + def __init__(self, path) -> None: + ''' + Invoke llvm-dis to disassemble the given file. + :param path: path to llvm-dis + ''' + self._path = path + self._ll_file = "/tmp/extern_lib.ll" + + def disasm(self, lib_path: str) -> None: + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate() + + @property + def ll_file(self) -> str: + return self._ll_file + + @property + def path(self) -> str: + return self._path + + +extern_libs = ["libdevice"] + + +def build( + llvm_dis_path: str, + lib_path: str, + lib_name: str, + output_dir: str, +) -> None: + ''' + Interface function to build the library file. + :param llvm_dis_path: path to the llvm-dis binary + :param lib_path: path to the external library file + :param lib_name: name of the library + :param output_dir: path to the output directory + ''' + if lib_name == "libdevice": + extern_lib = Libdevice(lib_path) + else: + raise Exception(f"Unknown extern library: {lib_name}") + + llvm_disassembler = LLVMDisassembler(llvm_dis_path) + llvm_disassembler.disasm(lib_path) + + extern_lib.parse_symbols(llvm_disassembler.ll_file) + extern_lib.generate_stub_file(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis") + parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library") + parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library") + parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/") + args = parser.parse_args() + + build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) diff --git a/third_party/mthreads/python/triton/tools/compile.c b/third_party/mthreads/python/triton/tools/compile.c new file mode 100644 index 000000000..971bf6191 --- /dev/null +++ b/third_party/mthreads/python/triton/tools/compile.c @@ -0,0 +1,67 @@ +/* clang-format off */ +#include +#include +#include +#include +#include + + +// helpers to check for cuda errors +#define CUDA_CHECK(ans) {{\ + gpuAssert((ans), __FILE__, __LINE__);\ + }}\ + +static inline void gpuAssert(CUresult code, const char *file, int line) {{ + if (code != CUDA_SUCCESS) {{ + const char *prefix = "Triton Error [CUDA]: "; + const char *str; + cuGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + printf("%s\\n", err); + exit(code); + }} +}} + +// globals +#define CUBIN_NAME {kernel_name}_cubin +CUmodule {kernel_name}_mod = NULL; +CUfunction {kernel_name}_func = NULL; +unsigned char CUBIN_NAME[{bin_size}] = {{ {bin_data} }}; + + +void unload_{kernel_name}(void) {{ + CUDA_CHECK(cuModuleUnload({kernel_name}_mod)); +}} + +// TODO: some code duplication with `runtime/backend/cuda.c` +void load_{kernel_name}() {{ + int dev = 0; + void *bin = (void *)&CUBIN_NAME; + int shared = {shared}; + CUDA_CHECK(cuModuleLoadData(&{kernel_name}_mod, bin)); + CUDA_CHECK(cuModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}")); + // set dynamic shared memory if necessary + int shared_optin; + CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev)); + if (shared > 49152 && shared_optin > 49152) {{ + CUDA_CHECK(cuFuncSetCacheConfig({kernel_name}_func, CU_FUNC_CACHE_PREFER_SHARED)); + CUDA_CHECK(cuFuncSetAttribute({kernel_name}_func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin)) + }} +}} + +/* +{kernel_docstring} +*/ +CUresult {kernel_name}(CUstream stream, {signature}) {{ + if ({kernel_name}_func == NULL) + load_{kernel_name}(); + unsigned int gX = {gridX}; + unsigned int gY = {gridY}; + unsigned int gZ = {gridZ}; + void *args[{num_args}] = {{ {arg_pointers} }}; + // TODO: shared memory + if(gX * gY * gZ > 0) + return cuLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * 32, 1, 1, {shared}, stream, args, NULL); +}} diff --git a/third_party/mthreads/python/triton/tools/compile.h b/third_party/mthreads/python/triton/tools/compile.h new file mode 100644 index 000000000..d98b7063b --- /dev/null +++ b/third_party/mthreads/python/triton/tools/compile.h @@ -0,0 +1,14 @@ +#ifndef TT_KERNEL_INCLUDES +#define TT_KERNEL_INCLUDES + +#include +#include +#include +#include + +#endif + +void unload_{kernel_name}(void); +void load_{kernel_name}(void); +// tt-linker: {kernel_name}:{full_signature}:{algo_info} +CUresult{_placeholder} {kernel_name}(CUstream stream, {signature}); diff --git a/third_party/mthreads/python/triton/tools/compile.py b/third_party/mthreads/python/triton/tools/compile.py new file mode 100644 index 000000000..872332b03 --- /dev/null +++ b/third_party/mthreads/python/triton/tools/compile.py @@ -0,0 +1,145 @@ +import binascii +import hashlib +import importlib.util +import sys +from argparse import ArgumentParser +from pathlib import Path +from typing import List + +import triton +from triton.compiler.code_generator import kernel_suffix +from triton.backends.nvidia.driver import ty_to_cpp + +desc = """ +Triton ahead-of-time compiler: + +This program compiles the kernel with name `kernel-name` in the file at the +provided `path` into self-contained C source-code that embeds the `cubin` +data along with utilities to load, unload and launch the kernel. + +signature is provided as a list of (optionally divisibility-hinted) types +or constexpr values, e.g. + +`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` + +will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`. +Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16, +and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype. + +The resulting entry point will have signature + +CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2) + +Different such specialized entry points can be combined using the `linker.py` script. + +NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the python interpreter +used to run this `compile.py` script +""" + +if __name__ == "__main__": + + # command-line arguments + parser = ArgumentParser(description=desc) + parser.add_argument("path", + help="Path to Python source containing desired kernel in its scope. File will be executed.") + parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", + required=True) + parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel") + parser.add_argument("--num-stages", "-ns", type=int, default=3, + help="Number of stages (meta-parameter of the kernel)") + parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel") + parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename") + parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) + parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True) + args = parser.parse_args() + + out_name = args.out_name if args.out_name else args.kernel_name + out_path = args.out_path if args.out_path else Path(out_name) + + # execute python sources and extract functions wrapped in JITFunction + arg_path = Path(args.path) + sys.path.insert(0, str(arg_path.parent)) + spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + kernel = getattr(mod, args.kernel_name) + grid = args.grid.split(",") + assert len(grid) == 3 + + # validate and parse signature + signature = list(map(lambda s: s.strip(" "), args.signature.split(","))) + + def hash_signature(signature: List[str]): + m = hashlib.sha256() + m.update(" ".join(signature).encode()) + return m.hexdigest()[:8] + + meta_sig = f"warps{args.num_warps}xstages{args.num_stages}" + sig_hash = hash_signature(signature + [meta_sig]) + + def constexpr(s): + try: + ret = int(s) + return ret + except ValueError: + pass + try: + ret = float(s) + return ret + except ValueError: + pass + return None + + hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} + hints = {k: v for k, v in hints.items() if v is not None} + constants = {i: constexpr(s) for i, s in enumerate(signature)} + constants = {k: v for k, v in constants.items() if v is not None} + signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constants} + const_sig = 'x'.join([str(v) for v in constants.values()]) + doc_string = [f"{kernel.arg_names[i]}={constants[i]}" for i in constants.keys()] + doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] + + # compile ast into cubin + for h in hints.values(): + assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" + divisible_by_16 = [i for i, h in hints.items() if h == 16] + equal_to_1 = [i for i, h in hints.items() if h == 1] + attrs = triton.compiler.AttrsDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) + for i in equal_to_1: + constants.update({i: 1}) + src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) + opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} + ccinfo = triton.compile(src, options=opts) + arg_names = [] + arg_types = [] + for i in signature.keys(): + if i not in equal_to_1: + arg_names += [kernel.arg_names[i]] + arg_types += [signature[i]] + + # dump C stub code + suffix = kernel_suffix(signature.values(), attrs) + func_name = '_'.join([out_name, sig_hash, suffix]) + hex_ = str(binascii.hexlify(ccinfo.asm["cubin"]))[2:-1] + params = { + "kernel_name": func_name, + "triton_kernel_name": args.kernel_name, + "bin_size": len(hex_), + "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), + "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), + "full_signature": ", ".join([f"{ty_to_cpp(signature[i])} {kernel.arg_names[i]}" for i in signature.keys()]), + "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]), + "num_args": len(arg_names), + "kernel_docstring": doc_string, + "shared": ccinfo.metadata.shared, + "num_warps": args.num_warps, + "algo_info": '_'.join([const_sig, meta_sig]), + "gridX": grid[0], + "gridY": grid[1], + "gridZ": grid[2], + "_placeholder": "", + } + for ext in ['h', 'c']: + template_path = Path(__file__).parent / f"compile.{ext}" + with out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}").open("w") as fp: + fp.write(Path(template_path).read_text().format(**params)) diff --git a/third_party/mthreads/python/triton/tools/disasm.py b/third_party/mthreads/python/triton/tools/disasm.py new file mode 100644 index 000000000..1e309a2e4 --- /dev/null +++ b/third_party/mthreads/python/triton/tools/disasm.py @@ -0,0 +1,142 @@ +# MIT License + +# Copyright (c) 2020 Da Yan @ HKUST + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import functools +import os +import re +import subprocess +import tempfile + +from ..common.backend import path_to_cuobjdump, path_to_nvdisasm + +FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') +SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') +FNAME_RE = re.compile(r'\s*Function : (\w+)\s*') +BRA_RE = re.compile(r'(.*BRA(?:\.U)? )(0x\w+);') + + +def parseCtrl(sline): + enc = int(SLINE_RE.match(sline).group(1), 16) + stall = (enc >> 41) & 0xf + yld = (enc >> 45) & 0x1 + wrtdb = (enc >> 46) & 0x7 + readb = (enc >> 49) & 0x7 + watdb = (enc >> 52) & 0x3f + + yld_str = 'Y' if yld == 0 else '-' + wrtdb_str = '-' if wrtdb == 7 else str(wrtdb) + readb_str = '-' if readb == 7 else str(readb) + watdb_str = '--' if watdb == 0 else f'{watdb:02d}' + return f'{watdb_str}:{readb_str}:{wrtdb_str}:{yld_str}:{stall:x}' + + +def processSassLines(fline, sline, labels): + asm = FLINE_RE.match(fline).group(1) + # Remove tailing space + if asm.endswith(" ;"): + asm = asm[:-2] + ";" + ctrl = parseCtrl(sline) + # BRA target address + if BRA_RE.match(asm) is not None: + target = int(BRA_RE.match(asm).group(2), 16) + if target in labels: + pass + else: + labels[target] = len(labels) + return (f'{ctrl}', f'{asm}') + + +@functools.lru_cache() +def get_sass(cubin_asm, fun=None): + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(cubin_asm) + sass = extract(path, fun) + finally: + os.remove(path) + return sass + + +def extract(file_path, fun): + cuobjdump, _ = path_to_cuobjdump() + nvdisasm, _ = path_to_nvdisasm() + os.environ["NVDISASM_PATH"] = nvdisasm + if fun is None: + sass_str = subprocess.check_output([cuobjdump, "-sass", file_path]) + else: + sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path]) + sass_lines = sass_str.splitlines() + line_idx = 0 + while line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + # format: + # function : + # .headerflags: ... + # /*0000*/ asmstr /*0x...*/ + # /*0x...*/ + + # Looking for new function header (function: ) + while FNAME_RE.match(line) is None: + line_idx += 1 + if line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + else: + return + + fname = FNAME_RE.match(line).group(1) + ret = '' + ret += f'Function:{fname}\n' + line_idx += 2 # bypass .headerflags + line = sass_lines[line_idx].decode() + # Remapping address to label + labels = {} # address -> label_idx + # store sass asm in buffer and them print them (for labels) + # (ctrl, asm) + asm_buffer = [] + while FLINE_RE.match(line) is not None: + # First line (Offset ASM Encoding) + fline = sass_lines[line_idx].decode() + line_idx += 1 + # Second line (Encoding) + sline = sass_lines[line_idx].decode() + line_idx += 1 + asm_buffer.append(processSassLines(fline, sline, labels)) + # peek the next line + line = sass_lines[line_idx].decode() + # Print sass + # label naming convention: LBB#i + for idx, (ctrl, asm) in enumerate(asm_buffer): + # Print label if this is BRA target + offset = idx * 16 + if offset in labels: + label_name = f'LBB{labels[offset]}' + ret += f'{label_name}:\n' + ret += ctrl + '\t' + # if this is BRA, remap offset to label + if BRA_RE.match(asm): + target = int(BRA_RE.match(asm).group(2), 16) + target_name = f'LBB{labels[target]}' + asm = BRA_RE.sub(rf'\1{target_name};', asm) + ret += asm + '\n' + ret += '\n' + return ret diff --git a/third_party/mthreads/python/triton/tools/link.py b/third_party/mthreads/python/triton/tools/link.py new file mode 100644 index 000000000..75a1157a5 --- /dev/null +++ b/third_party/mthreads/python/triton/tools/link.py @@ -0,0 +1,322 @@ +from collections import defaultdict +from pathlib import Path +from typing import Sequence, Union + +from dataclasses import dataclass + + +def _exists(x): + return x is not None + + +class LinkerError(Exception): + pass + + +@dataclass +class KernelLinkerMeta: + orig_kernel_name: str + arg_names: Sequence[str] + arg_ctypes: Sequence[str] + sizes: Sequence[Union[int, None]] + sig_hash: str + triton_suffix: str + suffix: str + num_specs: int + """ number of specialized arguments """ + + +class HeaderParser: + + def __init__(self) -> None: + import re + + # [kernel_name, c signature] + self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)") + # [name, hash, suffix] + self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$") + # [(type, name)] + self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?") + # [d|c] + self.arg_suffix = re.compile("[c,d]") + + self.kernels = defaultdict(list) + + def extract_linker_meta(self, header: str): + for ln in header.splitlines(): + if ln.startswith("//"): + m = self.linker_directives.match(ln) + if _exists(m): + ker_name, c_sig, algo_info = m.group(1), m.group(2), m.group(3) + name, sig_hash, suffix = self._match_name(ker_name) + c_types, arg_names = self._match_c_sig(c_sig) + num_specs, sizes = self._match_suffix(suffix, c_sig) + self._add_kernel( + "_".join([name, algo_info]), + KernelLinkerMeta( + orig_kernel_name=name, + arg_names=arg_names, + arg_ctypes=c_types, + sizes=sizes, + sig_hash=sig_hash, + triton_suffix=suffix, + suffix=suffix, + num_specs=num_specs, + ), + ) + + def _match_name(self, ker_name: str): + m = self.kernel_name.match(ker_name) + if _exists(m): + name, sig_hash, suffix = m.group(1), m.group(2), m.group(3) + return name, sig_hash, suffix + raise LinkerError(f"{ker_name} is not a valid kernel name") + + def _match_c_sig(self, c_sig: str): + m = self.c_sig.findall(c_sig) + if len(m): + tys, args = [], [] + for ty, arg_name in m: + tys.append(ty) + args.append(arg_name) + return tys, args + + raise LinkerError(f"{c_sig} is not a valid argument signature") + + def _match_suffix(self, suffix: str, c_sig: str): + args = c_sig.split(",") + s2i = {"c": 1, "d": 16} + num_specs = 0 + sizes = [] + # scan through suffix, first find the index, + # then see if it is followed by d or c + for i in range(len(args)): + pos = suffix.find(str(i)) + if pos == -1: + raise LinkerError(f"{suffix} is not a valid kernel suffix") + pos += len(str(i)) + if self.arg_suffix.match(suffix, pos): + num_specs += 1 + sizes.extend([None] * (i - len(sizes))) + sizes.append(s2i[suffix[pos]]) + pos += 1 + if i < len(args) - 1: + suffix = suffix[pos:] + else: + sizes.extend([None] * (len(args) - len(sizes))) + return num_specs, sizes + + def _add_kernel(self, name: str, ker: KernelLinkerMeta): + if name in self.kernels: + last: KernelLinkerMeta = self.kernels[name][-1] + + for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes): + if cur != new_: + raise LinkerError( + f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}" + ) + + self.kernels[name].append(ker) + + +def gen_signature_with_full_args(m): + return ", ".join([f"{ty} {arg}" for ty, arg in zip(m.arg_ctypes, m.arg_names)]) + + +def gen_signature(m): + arg_types = [ty for ty, hint in zip(m.arg_ctypes, m.sizes) if hint != 1] + arg_names = [arg for arg, hint in zip(m.arg_names, m.sizes) if hint != 1] + sig = ", ".join([f"{ty} {arg}" for ty, arg in zip(arg_types, arg_names)]) + return sig + + +# generate declarations of kernels with meta-parameter and constant values +def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + return f""" +CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}); +void load_{name}(); +void unload_{name}(); + """ + + +# generate declarations of kernels with meta-parameter and constant values +def make_global_decl(meta: KernelLinkerMeta) -> str: + return f""" +CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}); +CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id); +void load_{meta.orig_kernel_name}(); +void unload_{meta.orig_kernel_name}(); + """ + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_default_algo_kernel(meta: KernelLinkerMeta) -> str: + src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n" + src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different integer value hints +def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + src = f"// launcher for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n" + src += "\n" + + src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{") + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + cond_fn = ( # + lambda val, hint: f"({val} % {hint} == 0)" # + if hint == 16 # + else f"({val} == {hint})" # + if hint == 1 # + else None) + conds = " && ".join([ # + cond_fn(val, hint) # + for val, hint in zip(meta.arg_names, meta.sizes) # + if hint is not None + ]) + src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n" + ) # Edge case where no specializations hence no dispatching required + arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1] + src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n" + src += "\n" + src += " return CUDA_ERROR_INVALID_VALUE;\n" + src += "}\n" + + for mode in ["load", "unload"]: + src += f"\n// {mode} for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"void {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n" + src += f"void {mode}_{name}() {{" + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str: + src = f"CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n" + src += f" assert (algo_id < (int)sizeof({meta.orig_kernel_name}_kernels));\n" + src += f" return {meta.orig_kernel_name}_kernels[algo_id](stream, {', '.join(meta.arg_names)});\n" + src += "}\n" + return src + + +# generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values +def make_func_pointers(names: str, meta: KernelLinkerMeta) -> str: + # the table of hint dispatchers + src = f"typedef CUresult (*kernel_func_t)(CUstream stream, {gen_signature_with_full_args(meta)});\n" + src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n" + for name in names: + src += f" {name},\n" + src += "};\n" + return src + + +# generate definition for load/unload functions for kernels with different meta-parameter and constant values +def make_kernel_load_def(names: str, meta: KernelLinkerMeta) -> str: + src = "" + for mode in ["load", "unload"]: + src += f"void {mode}_{meta.orig_kernel_name}(void){{\n" + for name in names: + src += f" {mode}_{name}();\n" + src += "}\n\n" + return src + + +def make_get_num_algos_decl(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void);" + return src + + +def make_get_num_algos_def(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void){{\n" + src += f" return (int)(sizeof({meta.orig_kernel_name}_kernels) / sizeof({meta.orig_kernel_name}_kernels[0]));\n" + src += "}\n" + return src + + +desc = """ +Triton ahead-of-time linker: + +This program takes in header files generated by compile.py, and generates a +single entry-point responsible for dispatching the user's input to the right +kernel given the specializations that were compiled. + +Example usage: +python link.py /path/to/headers/*.h -o kernel_name +""" + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser(description=desc) + parser.add_argument( + "headers", + nargs="+", + help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)", + ) + parser.add_argument("--out", "-o", type=Path, help="Out filename") + parser.add_argument( + "--prefix", + type=str, + default="", + help="String to prefix kernel dispatcher names", + ) + args = parser.parse_args() + + # metadata + parser = HeaderParser() + includes = [] + for header in args.headers: + h_path = Path(header) + h_str = h_path.read_text() + includes.append(h_path.name) + parser.extract_linker_meta(h_str) + + # generate headers + algo_decls = [make_algo_decls(name, meta) for name, meta in parser.kernels.items()] + meta_lists = [meta for name, meta in parser.kernels.items()] + meta = meta_lists[0][0] + get_num_algos_decl = make_get_num_algos_decl(meta) + global_decl = make_global_decl(meta) + with args.out.with_suffix(".h").open("w") as fp: + out = "#include \n" + out += "\n".join(algo_decls) + out += "\n" + out += get_num_algos_decl + out += "\n" + out += global_decl + fp.write(out) + + # generate source + defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()] + names = [name for name in parser.kernels.keys()] + func_pointers_def = make_func_pointers(names, meta) + meta_const_def = make_kernel_meta_const_dispatcher(meta) + load_unload_def = make_kernel_load_def(names, meta) + get_num_algos_def = make_get_num_algos_def(meta) + default_algo_kernel = make_default_algo_kernel(meta) + with args.out.with_suffix(".c").open("w") as fp: + out = "" + out += "#include \n" + out += "#include \n" + out += "#include \n" + out += "\n" + out += "\n".join(defs) + out += "\n" + out += func_pointers_def + out += "\n" + out += get_num_algos_def + out += "\n" + out += meta_const_def + out += "\n" + out += load_unload_def + out += "\n" + out += default_algo_kernel + fp.write(out) diff --git a/third_party/nvidia/CMakeLists.txt b/third_party/nvidia/CMakeLists.txt index d36a88272..75f98fa8f 100644 --- a/third_party/nvidia/CMakeLists.txt +++ b/third_party/nvidia/CMakeLists.txt @@ -5,3 +5,6 @@ add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonNVIDIA ${CMAKE_CURRENT_SOURCE_DIR}/triton_nvidia.cc LINK_LIBS TritonNVIDIAGPUToLLVM NVGPUToLLVM) endif() +if(TRITON_BUILD_UT) + add_subdirectory(unittest) +endif() diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 6d7994923..ac5318893 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -3,7 +3,8 @@ from dataclasses import dataclass import functools -from typing import Any, Tuple, Optional +from typing import Any, Dict, Tuple, Optional +from types import ModuleType import hashlib import re import tempfile @@ -13,6 +14,10 @@ from pathlib import Path +def min_dot_size(target: GPUTarget): + return lambda lhsType, rhsType: (16, 32, 16) if lhsType.is_int8() else (16, 16, 16) + + @functools.lru_cache() def _path_to_binary(binary: str): paths = [ @@ -44,12 +49,37 @@ def ptx_get_version(cuda_version) -> int: assert isinstance(cuda_version, str) major, minor = map(int, cuda_version.split('.')) if major == 12: - return 80 + minor + if minor < 6: + return 80 + minor + elif minor == 6: + return 85 if major == 11: return 70 + minor if major == 10: return 63 + minor - raise RuntimeError("Triton only support CUDA 10.0 or higher") + raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version) + + +def get_ptx_version_from_options(options): + ptx_version = options.ptx_version + if ptx_version is None: + _, cuda_version = _path_to_binary("ptxas") + ptx_version = ptx_get_version(cuda_version) + return ptx_version + + +@functools.lru_cache() +def get_features(options): + ptx_version = get_ptx_version_from_options(options) + + # PTX 8.3 is the max version supported by llvm 3a83162168. + # + # To check if a newer PTX version is supported, increase this value + # and run a test. If it's not supported, LLVM will print a warning + # like "+ptx8.4 is not a recognized feature for this target". + llvm_ptx_version = min(83, ptx_version) + features = f'+ptx{llvm_ptx_version}' + return features @functools.lru_cache(None) @@ -63,20 +93,25 @@ class CUDAOptions: num_warps: int = 4 num_ctas: int = 1 num_stages: int = 3 + num_buffers_warp_spec: int = 0 + num_consumer_groups: int = 0 + reg_dec_producer: int = 0 + reg_inc_consumer: int = 0 # maxnreg corresponds to the ptx parameter .maxnreg, which controls the # maximum number of 32-bit registers used by one thread. maxnreg: Optional[int] = None cluster_dims: tuple = (1, 1, 1) ptx_version: int = None enable_fp_fusion: bool = True - allow_fp8e4nv: bool = False - allow_fp8e4b15: bool = False + supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15") + deprecated_fp8_dtypes: Tuple[str] = () default_dot_input_precision: str = "tf32" allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") max_num_imprecise_acc_default: bool = None extern_libs: dict = None debug: bool = False backend_name: str = 'cuda' + sanitize_overflow: bool = True def __post_init__(self): default_libdir = Path(__file__).parent / 'lib' @@ -108,8 +143,18 @@ def __init__(self, target: GPUTarget) -> None: def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts} - args["allow_fp8e4nv"] = self.capability >= 89 - args["allow_fp8e4b15"] = self.capability < 90 + if "supported_fp8_dtypes" not in args: + supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes) + if self.capability >= 89: + supported_fp8_dtypes.add("fp8e4nv") + args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes)) + + if "deprecated_fp8_dtypes" not in args: + if self.capability >= 90: + args["deprecated_fp8_dtypes"] = ("fp8e4b15", ) + + if "enable_fp_fusion" not in args: + args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1" args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0 return CUDAOptions(**args) @@ -127,10 +172,15 @@ def get_codegen_implementation(self): import triton.language.extra.cuda as cuda codegen_fns = { "convert_custom_types": - cuda.convert_custom_float8_sm80 if self.capability >= 80 else cuda.convert_custom_float8_sm70 + cuda.convert_custom_float8_sm80 if self.capability >= 80 else cuda.convert_custom_float8_sm70, + "min_dot_size": min_dot_size(self.target) } return codegen_fns + def get_module_map(self) -> Dict[str, ModuleType]: + from triton.language.extra.cuda import libdevice + return {"triton.language.extra.libdevice": libdevice} + def load_dialects(self, ctx): nvidia.load_dialects(ctx) @@ -146,6 +196,7 @@ def make_ttir(mod, metadata, opt): passes.common.add_cse(pm) passes.common.add_licm(pm) passes.common.add_symbol_dce(pm) + passes.ttir.add_loop_unroll(pm) pm.run(mod) return mod @@ -156,6 +207,11 @@ def make_ttgir(mod, metadata, opt, capability): cluster_info.clusterDimX = opt.cluster_dims[0] cluster_info.clusterDimY = opt.cluster_dims[1] cluster_info.clusterDimZ = opt.cluster_dims[2] + # Set up Diagnostic + if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1": + srcMgr = llvm.source_mgr() + diag = ir.source_mgr_diag(srcMgr, mod.context) + mod.context.printOpOnDiagnostic(True) # TTIR -> TTGIR pm = ir.pass_manager(mod.context) pm.enable_debug() @@ -173,8 +229,15 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) passes.common.add_cse(pm) if capability // 10 >= 8: + passes.ttgpuir.add_optimize_accumulator_init(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) + passes.ttgpuir.add_ws_task_partition(pm, opt.num_consumer_groups) + passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups) + passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups) + passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups, + opt.reg_dec_producer, opt.reg_inc_consumer) passes.ttgpuir.add_pipeline(pm, opt.num_stages) + passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups) passes.ttgpuir.add_prefetch(pm) passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) passes.ttgpuir.add_remove_layout_conversions(pm) @@ -192,6 +255,8 @@ def make_ttgir(mod, metadata, opt, capability): @staticmethod def make_llir(src, metadata, options, capability): + ptx_version = get_ptx_version_from_options(options) + # warp-specialization mutates num_warps num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") if num_warp_groups is not None: @@ -200,12 +265,17 @@ def make_llir(src, metadata, options, capability): # TritonGPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() + # Set up Diagnostic + if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1": + srcMgr = llvm.source_mgr() + diag = ir.source_mgr_diag(srcMgr, mod.context) + mod.context.printOpOnDiagnostic(True) nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.convert.add_scf_to_cf(pm) passes.convert.add_index_to_llvmir(pm) passes.ttgpuir.add_allocate_shared_memory(pm) - nvidia.passes.ttgpuir.add_to_llvmir(pm, capability) + nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version) nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) passes.convert.add_arith_to_llvmir(pm) passes.common.add_canonicalizer(pm) @@ -217,7 +287,12 @@ def make_llir(src, metadata, options, capability): # LLVM-IR (MLIR) -> LLVM-IR (LLVM) llvm.init_targets() context = llvm.context() + llvm_mod = llvm.to_module(mod, context) + proc = 'sm_90a' if capability == 90 else f'sm_{capability}' + features = get_features(options) + triple = 'nvptx64-nvidia-cuda' + llvm.attach_datalayout(llvm_mod, triple, proc, features) nvidia.set_nvvm_reflect_ftz(llvm_mod) # Set maxnreg on all kernels, if it was provided. @@ -241,21 +316,11 @@ def make_llir(src, metadata, options, capability): @staticmethod def make_ptx(src, metadata, opt, capability): - ptx_version = opt.ptx_version - if ptx_version is None: - _, cuda_version = _path_to_binary("ptxas") - ptx_version = ptx_get_version(cuda_version) - - # PTX 8.3 is the max version supported by llvm 3a83162168. - # - # To check if a newer PTX version is supported, increase this value - # and run a test. If it's not supported, LLVM will print a warning - # like "+ptx8.4 is not a recognized feature for this target". - llvm_ptx_version = min(83, ptx_version) + ptx_version = get_ptx_version_from_options(opt) triple = 'nvptx64-nvidia-cuda' proc = 'sm_90a' if capability == 90 else f'sm_{capability}' - features = f'+ptx{llvm_ptx_version}' + features = get_features(opt) ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False) # Find kernel names (there should only be one) names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret) @@ -280,31 +345,35 @@ def make_cubin(src, metadata, opt, capability): fsrc.flush() fbin = fsrc.name + '.o' - line_info = '' if os.environ.get('TRITON_DISABLE_LINE_INFO') else ' -lineinfo' - fmad = '' if opt.enable_fp_fusion else ' --fmad=false' - suffix = 'a ' if capability == 90 else ' ' - if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1": - cmd = f'{ptxas}{line_info}{fmad} -v --opt-level 0 --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}' - else: - cmd = f'{ptxas}{line_info}{fmad} -v --gpu-name=sm_{capability}{suffix}{fsrc.name} -o {fbin} 2> {flog.name}' - + line_info = [] if os.environ.get('TRITON_DISABLE_LINE_INFO') else ['-lineinfo'] + fmad = [] if opt.enable_fp_fusion else ['--fmad=false'] + suffix = 'a' if capability == 90 else '' + opt_level = ['--opt-level', '0'] if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1" else [] + ptxas_cmd = [ + ptxas, *line_info, *fmad, '-v', *opt_level, f'--gpu-name=sm_{capability}{suffix}', fsrc.name, '-o', fbin + ] try: - subprocess.run(cmd, shell=True, check=True) + subprocess.run(ptxas_cmd, check=True, close_fds=False, stderr=flog) + if os.path.exists(fsrc.name): + os.remove(fsrc.name) + if os.path.exists(flog.name): + os.remove(flog.name) except subprocess.CalledProcessError as e: with open(flog.name) as log_file: log = log_file.read() + if os.path.exists(flog.name): + os.remove(flog.name) + if e.returncode == 255: - raise RuntimeError(f'Internal Triton PTX codegen error: \n{log}') + error = 'Internal Triton PTX codegen error' elif e.returncode == 128 + signal.SIGSEGV: - raise RuntimeError( - f'Please run `ptxas {fsrc.name}` to confirm that this is a bug in `ptxas`\n{log}') + error = '`ptxas` raised SIGSEGV' else: - raise RuntimeError(f'`ptxas` failed with error code {e.returncode}: \n{log}') - finally: - if os.path.exists(fsrc.name): - os.remove(fsrc.name) - if os.path.exists(flog.name): - os.remove(flog.name) + error = f'`ptxas` failed with error code {e.returncode}' + + raise RuntimeError(f'{error}\n' + f'`ptxas` stderr:\n{log}\n' + f'Repro command: {" ".join(ptxas_cmd)}\n') with open(fbin, 'rb') as f: cubin = f.read() diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index 44524da27..bb0d86888 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -274,6 +274,7 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { } Py_END_ALLOW_THREADS; + Py_INCREF(Py_None); return Py_None; } @@ -284,12 +285,11 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { uint64_t dim; uint32_t tensorDim; int elementSize; - Py_buffer desc_buffer; - if (!PyArg_ParseTuple(args, "KKiiy*", &global_address, &dim, &tensorDim, - &elementSize, &desc_buffer)) { + unsigned long long desc_address; + if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim, + &elementSize, &desc_address)) { return NULL; } - char *desc = (char *)desc_buffer.buf; uint64_t dims[1] = {dim}; uint64_t globalStrides[1] = {dim * elementSize}; uint32_t boxDim[1] = {tensorDim}; @@ -307,18 +307,19 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { break; default: PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); + return NULL; } assert((elementSize * tensorDim) >= 32 && "block size too small."); int rank = 1; static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, getCuTensorMapEncodeTiledHandle); - CUresult result = cuTensorMapEncodeTiled( - (CUtensorMap *)desc, type, rank, (void *)global_address, dims, + CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( + (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - assert(result == CUDA_SUCCESS); + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + Py_INCREF(Py_None); return Py_None; } @@ -329,13 +330,12 @@ static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { uint64_t dims[2]; uint32_t tensorDims[2]; int elementSize; - Py_buffer desc_buffer; - if (!PyArg_ParseTuple(args, "KKKiiiy*", &global_address, &dims[1], &dims[0], + unsigned long long desc_address; + if (!PyArg_ParseTuple(args, "KKKiiiK", &global_address, &dims[1], &dims[0], &tensorDims[1], &tensorDims[0], &elementSize, - &desc_buffer)) { + &desc_address)) { return NULL; } - char *desc = (char *)desc_buffer.buf; uint64_t globalStrides[2] = {dims[0] * elementSize, dims[0] * dims[1] * elementSize}; uint32_t elementStrides[2] = {1, 1}; @@ -377,12 +377,12 @@ static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, getCuTensorMapEncodeTiledHandle); - CUresult result = cuTensorMapEncodeTiled( - (CUtensorMap *)desc, type, rank, (void *)global_address, dims, + CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( + (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - assert(result == CUDA_SUCCESS); + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + Py_INCREF(Py_None); return Py_None; } diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 90f71138b..0a62c378c 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -110,6 +110,7 @@ def ty_to_cpp(ty): "fp32": "float", "f32": "float", "fp64": "double", + "nvTmaDesc": "CUtensorMap", }[ty] @@ -121,6 +122,9 @@ def make_launcher(constants, signature, ids): def _extracted_type(ty): if ty[0] == '*': return "PyObject*" + if ty == "nvTmaDesc": + return "PyObject*" + return ty_to_cpp(ty) def format_of(ty): @@ -143,6 +147,16 @@ def format_of(ty): format = "iiiKKOOOO" + args_format args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + internal_args_list = [] + for i, ty in signature.items(): + if ty[0] == "*": + internal_args_list.append(f"ptr_info{i}.dev_ptr") + elif ty == "nvTmaDesc": + # Note: we have to dereference the pointer + internal_args_list.append(f"*tma_ptr{i}") + else: + internal_args_list.append(f"_arg{i}") + # generate glue code params = [i for i in signature.keys() if i not in constants] src = f""" @@ -261,6 +275,9 @@ def format_of(ty): PyErr_Format(PyExc_ValueError, "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); ptr_info.valid = false; + }} else if (status != CUDA_SUCCESS) {{ + CUDA_CHECK(status); // Catch any other cuda API errors + ptr_info.valid = false; }} ptr_info.dev_ptr = dev_ptr; Py_DECREF(ret); // Thanks ChatGPT! @@ -271,7 +288,68 @@ def format_of(ty): return ptr_info; }} +static inline CUtensorMap* getTmaDesc(PyObject *obj) {{ + if (sizeof(CUtensorMap*) != 8) {{ + PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation"); + return NULL; + }} + + PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_cpu_ptr"); + if (!method_handle) {{ + PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist"); + return NULL; + }} + + PyObject *empty_tuple = PyTuple_New(0); + if (!empty_tuple) {{ + Py_DECREF(method_handle); + PyErr_SetString(PyExc_SystemError, "Internal Python error!"); + return NULL; + }} + PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(method_handle); + if (!method_ret) {{ + PyErr_SetString(PyExc_SystemError, "Internal Python error!"); + return NULL; + }} + + if (!PyLong_Check(method_ret)) {{ + PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() must return 64-bit int"); + Py_DECREF(method_ret); + return NULL; + }} + + uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret); + Py_DECREF(method_ret); + if (!ptr_as_uint) {{ + PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_cpu_ptr()"); + return NULL; + }} + if (ptr_as_uint % 64 != 0) {{ + PyErr_SetString(PyExc_ValueError, "tma_desc_cpu_ptr() must be 64-byte aligned"); + return NULL; + }} + + return (CUtensorMap*)(ptr_as_uint); +}} + +static void ensureCudaContext() {{ + CUcontext pctx; + CUDA_CHECK(cuCtxGetCurrent(&pctx)); + if (!pctx) {{ + // Ensure device context. + CUdevice device; + CUDA_CHECK(cuDeviceGet(&device, 0)); + CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK(cuCtxSetCurrent(pctx)); + }} +}} + static PyObject* launch(PyObject* self, PyObject* args) {{ + // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes + ensureCudaContext(); + int gridX, gridY, gridZ; uint64_t _stream; uint64_t _function; @@ -302,9 +380,10 @@ def format_of(ty): }} // raise exception asap - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + {"".join([f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" if ty == "nvTmaDesc" else "" for i, ty in signature.items()])}; Py_BEGIN_ALLOW_THREADS; - _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); Py_END_ALLOW_THREADS; if (PyErr_Occurred()) {{ return NULL; @@ -379,7 +458,24 @@ def get_current_target(self): warp_size = 32 return GPUTarget("cuda", capability, warp_size) + def get_device_interface(self): + import torch + return torch.cuda + @staticmethod def is_active(): import torch return torch.cuda.is_available() and (torch.version.hip is None) + + def get_benchmarker(self): + from triton.testing import do_bench + return do_bench + + def get_empty_cache_for_benchmark(self): + import torch + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 cache + # doesn't contain any input data before the run + cache_size = 256 * 1024 * 1024 + return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda') diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td index ca9d18873..840e0714c 100644 --- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td @@ -39,6 +39,7 @@ def NVGPU_WGMMAFenceOp : NVGPU_Op<"wgmma_fence", []> { let assemblyFormat = "attr-dict"; } + def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> { let assemblyFormat = "attr-dict"; } @@ -52,6 +53,32 @@ def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", let assemblyFormat = "$input attr-dict `:` type($input)"; } +def MBarrier_ArriveTypeAttr : I32EnumAttr<"MBarriveType", + "mbarrier arrive type, either 'normal', 'expect_tx', 'cp_async'", + [ + I32EnumAttrCase<"normal", 0>, + I32EnumAttrCase<"cp_async", 1>, + I32EnumAttrCase<"expect_tx", 2>, + I32EnumAttrCase<"remote", 3>, + ]>{ + let cppNamespace = "::mlir::triton::nvgpu"; +} + +def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier_arrive", []> { + let arguments = (ins LLVM_PointerShared:$mbarrier, I1:$pred, Optional:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr:$txCount); + let assemblyFormat = "$mbarrier `,` $pred (`,` $ctaId^)? attr-dict `:` type($mbarrier)"; +} + +def NVGPU_NamedBarrierArriveOp : NVGPU_Op<"bar_arrive", []> { + let arguments = (ins I32:$bar, I32:$numThreads); + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + +def NVGPU_NamedBarrierWaitOp : NVGPU_Op<"bar_wait", []> { + let arguments = (ins I32:$bar, I32:$numThreads); + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + def WGMMA_LayoutAttr : I32EnumAttr<"WGMMALayout", "wgmma layout, either 'row' or 'col'", [ @@ -79,36 +106,12 @@ def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType", def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">; def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { - let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, Optional:$opC, + let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, I1:$useC, Optional:$opC, I32Attr:$m, I32Attr:$n, I32Attr:$k, WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB, WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB); let results = (outs LLVM_AnyStruct:$res); - let assemblyFormat = "$opA `,` $opB (`,` $opC^)? attr-dict `:` functional-type(operands, $res)"; -} - -def NVGPU_LoadDSmemOp : NVGPU_Op<"load_dsmem", [MemoryEffects<[MemRead]>]> { - let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId, I32Attr:$bitwidth, I32Attr:$vec); - let builders = [ - OpBuilder<(ins "Type":$resultTy, "Value":$addr, "Value":$ctaId)>, - OpBuilder<(ins "Value":$addr, "Value":$ctaId, "unsigned":$bitwidth, "unsigned":$vec)>, - OpBuilder<(ins "Value":$addr, "Value":$ctaId, "unsigned":$bitwidth)> - ]; - let results = (outs LLVM_LoadableType:$result); - let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; -} - -def NVGPU_StoreDSmemOp : NVGPU_Op<"store_dsmem", [MemoryEffects<[MemWrite]>]> { - let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId, - Variadic:$values, I1:$pred); - let builders = [ - OpBuilder<(ins "Value":$addr, "Value":$ctaId, "Value":$value, "Value":$pred)>, - ]; - let assemblyFormat = "operands attr-dict `:` type(operands)"; - let extraClassDeclaration = [{ - unsigned getBitwidth(); - unsigned getVec(); - }]; + let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? attr-dict `:` functional-type(operands, $res)"; } def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> { @@ -136,4 +139,19 @@ def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> { let assemblyFormat = "attr-dict"; } +def NVGPU_CanonicalWarpIdOp : NVGPU_Op<"canonical_warp_id", [Pure]> { + let results = (outs I32:$result); + let assemblyFormat = "attr-dict"; +} + +def NVGPU_RegAllocOp : NVGPU_Op<"reg_alloc", []> { + let arguments = (ins I32Attr: $regCount); + let assemblyFormat = "attr-dict"; +} + +def NVGPU_RegDeallocOp : NVGPU_Op<"reg_dealloc", []> { + let arguments = (ins I32Attr: $regCount); + let assemblyFormat = "attr-dict"; +} + #endif diff --git a/third_party/nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h b/third_party/nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h index e4b91550c..12ac194a8 100644 --- a/third_party/nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h +++ b/third_party/nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h @@ -2,6 +2,14 @@ #define TRITON_CONVERSION_NVGPU_TO_LLVM_PASS_H #include +#include +#include +#include + +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" namespace mlir { @@ -10,6 +18,19 @@ template class OperationPass; namespace triton { +namespace nvgpu { + +using Constraints = std::vector; +using OperandsAndConstraints = std::vector>; + +LogicalResult +rewriteAsPtxAsm(mlir::Operation *op, mlir::PatternRewriter &rewriter, + std::string ptxAsm, + const OperandsAndConstraints &operandsAndConstraints = {}, + const Constraints &outputConstraints = {}); + +} // namespace nvgpu + std::unique_ptr> createConvertNVGPUToLLVMPass(); } // namespace triton diff --git a/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h index 30bfaea7d..8cd8a180c 100644 --- a/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h +++ b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h @@ -26,6 +26,8 @@ createDecomposeUnsupportedConversionsPass(); std::unique_ptr> createConvertTritonGPUToLLVMPass(); std::unique_ptr> createConvertTritonGPUToLLVMPass(int32_t computeCapability); +std::unique_ptr> +createConvertTritonGPUToLLVMPass(int32_t computeCapability, int32_t ptxVersion); #define GEN_PASS_REGISTRATION #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h.inc" diff --git a/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td index 07624c72d..9f942dd53 100644 --- a/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td +++ b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td @@ -30,6 +30,9 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" Option<"computeCapability", "compute-capability", "int32_t", /*default*/"80", "device compute capability">, + Option<"ptxVersion", "ptx-version", + "int32_t", /*default*/"80", + "PTX version">, ]; } diff --git a/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h new file mode 100644 index 000000000..6d1c3c06a --- /dev/null +++ b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Utility.h @@ -0,0 +1,17 @@ +#ifndef TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_UTILITY_H +#define TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_UTILITY_H + +#include "mlir/IR/Operation.h" + +namespace mlir { +namespace triton { +namespace NVIDIA { + +/// Return true if we can skip a barrier synchronization between two operations +/// even if they access the same shared memory. +bool canSkipBarSync(Operation *before, Operation *after); +} // namespace NVIDIA +} // namespace triton +} // namespace mlir + +#endif // TRITONGPU_CONVERSION_TRITONNVIDIAGPUTOLLVM_UTILITY_H diff --git a/third_party/nvidia/include/cublas_instance.h b/third_party/nvidia/include/cublas_instance.h new file mode 100644 index 000000000..d79d4d76b --- /dev/null +++ b/third_party/nvidia/include/cublas_instance.h @@ -0,0 +1,213 @@ +#ifndef TRITON_CUBLAS_INSTANCE_H +#define TRITON_CUBLAS_INSTANCE_H + +#include "cublas_types.h" +#include +#include +#include + +class CublasLtInstance { + // Typedefs for cublas functions + typedef cublasStatus_t (*cublasLtCreate_t)(cublasLtHandle_t *); + typedef cublasStatus_t (*cublasLtDestroy_t)(cublasLtHandle_t); + typedef cublasStatus_t (*cublasLtMatmulDescCreate_t)(cublasLtMatmulDesc_t *, + cublasComputeType_t, + cudaDataType_t); + typedef cublasStatus_t (*cublasLtMatmulDescDestroy_t)(cublasLtMatmulDesc_t); + typedef cublasStatus_t (*cublasLtMatmulDescSetAttribute_t)( + cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, const void *, + size_t); + typedef cublasStatus_t (*cublasLtMatrixLayoutCreate_t)( + cublasLtMatrixLayout_t *, cudaDataType_t, uint64_t, uint64_t, int64_t); + typedef cublasStatus_t (*cublasLtMatrixLayoutDestroy_t)( + cublasLtMatrixLayout_t); + typedef cublasStatus_t (*cublasLtMatmulPreferenceCreate_t)( + cublasLtMatmulPreference_t *); + typedef cublasStatus_t (*cublasLtMatmulPreferenceDestroy_t)( + cublasLtMatmulPreference_t); + typedef cublasStatus_t (*cublasLtMatmulPreferenceSetAttribute_t)( + cublasLtMatmulPreference_t, cublasLtMatmulPreferenceAttributes_t, + const void *, size_t); + typedef cublasStatus_t (*cublasLtMatmulAlgoGetHeuristic_t)( + cublasLtHandle_t, cublasLtMatmulDesc_t, cublasLtMatrixLayout_t, + cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, cublasLtMatrixLayout_t, + cublasLtMatmulPreference_t, int, cublasLtMatmulHeuristicResult_t *, + int *); + typedef cublasStatus_t (*cublasLtMatmul_t)( + cublasLtHandle_t, cublasLtMatmulDesc_t, const void *, const void *, + const cublasLtMatrixLayout_t, const void *, const cublasLtMatrixLayout_t, + const void *, const void *, const cublasLtMatrixLayout_t, void *, + const cublasLtMatrixLayout_t, const cublasLtMatmulAlgo_t *, void *, + size_t, cudaStream_t); + + static constexpr const char *name = "libcublas.so"; + + cublasLtCreate_t cublasLtCreate; + cublasLtDestroy_t cublasLtDestroy; + cublasLtMatmulDescCreate_t cublasLtMatmulDescCreate; + cublasLtMatmulDescDestroy_t cublasLtMatmulDescDestroy; + cublasLtMatmulDescSetAttribute_t cublasLtMatmulDescSetAttribute; + cublasLtMatrixLayoutCreate_t cublasLtMatrixLayoutCreate; + cublasLtMatrixLayoutDestroy_t cublasLtMatrixLayoutDestroy; + cublasLtMatmulPreferenceCreate_t cublasLtMatmulPreferenceCreate; + cublasLtMatmulPreferenceDestroy_t cublasLtMatmulPreferenceDestroy; + cublasLtMatmulPreferenceSetAttribute_t cublasLtMatmulPreferenceSetAttribute; + cublasLtMatmulAlgoGetHeuristic_t cublasLtMatmulAlgoGetHeuristic; + cublasLtMatmul_t cublasLtMatmul; + + void *dylibHandle = nullptr; + cublasLtHandle_t ltHandle; + + void *workspace = nullptr; + size_t workspaceSize = 0; + + cublasLtMatmulPreference_t preference = NULL; + + void loadCublasDylib() { + if (dylibHandle == nullptr) { + // First reuse the existing handle + dylibHandle = dlopen(name, RTLD_NOLOAD); + } + if (dylibHandle == nullptr) { + // If not found, try to load it + dylibHandle = dlopen(name, RTLD_LOCAL | RTLD_LAZY); + } + if (dylibHandle == nullptr) { + throw std::runtime_error("Could not find `" + std::string(name) + + "`. Make sure it is in your " + "LD_LIBRARY_PATH."); + } + dlerror(); // Clear any existing error + + cublasLtCreate = (cublasLtCreate_t)dlsym(dylibHandle, "cublasLtCreate"); + cublasLtDestroy = (cublasLtDestroy_t)dlsym(dylibHandle, "cublasLtDestroy"); + cublasLtMatmulDescCreate = (cublasLtMatmulDescCreate_t)dlsym( + dylibHandle, "cublasLtMatmulDescCreate"); + cublasLtMatmulDescDestroy = (cublasLtMatmulDescDestroy_t)dlsym( + dylibHandle, "cublasLtMatmulDescDestroy"); + cublasLtMatmulDescSetAttribute = (cublasLtMatmulDescSetAttribute_t)dlsym( + dylibHandle, "cublasLtMatmulDescSetAttribute"); + cublasLtMatrixLayoutCreate = (cublasLtMatrixLayoutCreate_t)dlsym( + dylibHandle, "cublasLtMatrixLayoutCreate"); + cublasLtMatrixLayoutDestroy = (cublasLtMatrixLayoutDestroy_t)dlsym( + dylibHandle, "cublasLtMatrixLayoutDestroy"); + cublasLtMatmulPreferenceCreate = (cublasLtMatmulPreferenceCreate_t)dlsym( + dylibHandle, "cublasLtMatmulPreferenceCreate"); + cublasLtMatmulPreferenceDestroy = (cublasLtMatmulPreferenceDestroy_t)dlsym( + dylibHandle, "cublasLtMatmulPreferenceDestroy"); + cublasLtMatmulPreferenceSetAttribute = + (cublasLtMatmulPreferenceSetAttribute_t)dlsym( + dylibHandle, "cublasLtMatmulPreferenceSetAttribute"); + cublasLtMatmulAlgoGetHeuristic = (cublasLtMatmulAlgoGetHeuristic_t)dlsym( + dylibHandle, "cublasLtMatmulAlgoGetHeuristic"); + cublasLtMatmul = (cublasLtMatmul_t)dlsym(dylibHandle, "cublasLtMatmul"); + + const char *dlsym_error = dlerror(); + if (dlsym_error) { + throw std::runtime_error("Could not load symbol from `" + + std::string(name) + + "`: " + std::string(dlsym_error)); + } + } + + void unloadCublasDylib() { dlclose(dylibHandle); } + + void successOrExit(cublasStatus_t status) { + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("cuBLAS Error: " + std::to_string(status) + + "\n"); + } + } + + // Simple wrapper around the cublasLtMatmul function + void matmul_impl(int m, int n, int k, uint64_t A, uint64_t B, uint64_t D, + cudaDataType_t dtype) { + cublasLtMatmulDesc_t matmulDesc = NULL; + + cublasOperation_t transa = CUBLAS_OP_T; + cublasOperation_t transb = CUBLAS_OP_N; + + int8_t fastAccum = 1; + + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, + Ddesc = NULL; + + int returnedResults = 0; + cublasLtMatmulHeuristicResult_t heuristicResult = {}; + + successOrExit( + cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F)); + successOrExit(cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa))); + successOrExit(cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb))); + if (dtype == CUDA_R_8F_E4M3) { + successOrExit(cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccum, + sizeof(fastAccum))); + } + + successOrExit(cublasLtMatrixLayoutCreate(&Adesc, dtype, k, m, k)); + successOrExit(cublasLtMatrixLayoutCreate(&Bdesc, dtype, k, n, k)); + successOrExit(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_16F, m, n, m)); + successOrExit(cublasLtMatrixLayoutCreate(&Ddesc, dtype, m, n, m)); + + successOrExit(cublasLtMatmulAlgoGetHeuristic( + ltHandle, matmulDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, + &heuristicResult, &returnedResults)); + if (returnedResults == 0) { + throw std::runtime_error( + "No valid algorithm found by cublasLtMatmulAlgoGetHeuristic"); + } + + float alpha = 1.0f; + float beta = 0.0f; + successOrExit(cublasLtMatmul(ltHandle, matmulDesc, &alpha, (void *)A, Adesc, + (void *)B, Bdesc, &beta, nullptr, Cdesc, + (void *)D, Ddesc, &heuristicResult.algo, + (void *)workspace, workspaceSize, 0)); + if (Ddesc) + successOrExit(cublasLtMatrixLayoutDestroy(Ddesc)); + if (Cdesc) + successOrExit(cublasLtMatrixLayoutDestroy(Cdesc)); + if (Bdesc) + successOrExit(cublasLtMatrixLayoutDestroy(Bdesc)); + if (Adesc) + successOrExit(cublasLtMatrixLayoutDestroy(Adesc)); + if (matmulDesc) + successOrExit(cublasLtMatmulDescDestroy(matmulDesc)); + } + +public: + CublasLtInstance(uint64_t workspace, size_t workspaceSize) + : workspace((void *)workspace), workspaceSize(workspaceSize) { + loadCublasDylib(); + cublasLtCreate(<Handle); + + successOrExit(cublasLtMatmulPreferenceCreate(&preference)); + successOrExit(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, + sizeof(workspaceSize))); + } + ~CublasLtInstance() { + if (preference) + successOrExit(cublasLtMatmulPreferenceDestroy(preference)); + + cublasLtDestroy(ltHandle); + unloadCublasDylib(); + } + + // C = A * B + // Matrix B needs to be transposed, while matrix A does not. The function + // *will-not* transpose the matrices, so the caller is responsible for + // ensuring that the matrices are in the correct format and have the correct + // dimensions. + void matmul(int m, int n, int k, uint64_t A, uint64_t B, uint64_t C, + cudaDataType_t dtype) { + // CUDA is column-major, while triton is row-major, therefore we need to + // reverse the order of the matrices ( A * B = (B^T * A^T)^T ). + matmul_impl(n, m, k, B, A, C, dtype); + } +}; + +#endif // TRITON_CUBLAS_INSTANCE_H diff --git a/third_party/nvidia/include/cublas_types.h b/third_party/nvidia/include/cublas_types.h new file mode 100644 index 000000000..9972c3585 --- /dev/null +++ b/third_party/nvidia/include/cublas_types.h @@ -0,0 +1,152 @@ +#ifndef TRITON_CUBLAS_TYPES_H +#define TRITON_CUBLAS_TYPES_H + +// Forward declarations of cuBLAS types and functions. + +#include "backend/include/cuda.h" +#include "backend/include/driver_types.h" + +/* CUBLAS status type returns */ +typedef enum { + CUBLAS_STATUS_SUCCESS = 0, + CUBLAS_STATUS_NOT_INITIALIZED = 1, + CUBLAS_STATUS_ALLOC_FAILED = 3, + CUBLAS_STATUS_INVALID_VALUE = 7, + CUBLAS_STATUS_ARCH_MISMATCH = 8, + CUBLAS_STATUS_MAPPING_ERROR = 11, + CUBLAS_STATUS_EXECUTION_FAILED = 13, + CUBLAS_STATUS_INTERNAL_ERROR = 14, + CUBLAS_STATUS_NOT_SUPPORTED = 15, + CUBLAS_STATUS_LICENSE_ERROR = 16 +} cublasStatus_t; + +typedef enum { + CUBLAS_COMPUTE_16F = 64, /* half - default */ + CUBLAS_COMPUTE_16F_PEDANTIC = 65, /* half - pedantic */ + CUBLAS_COMPUTE_32F = 68, /* float - default */ + CUBLAS_COMPUTE_32F_PEDANTIC = 69, /* float - pedantic */ + CUBLAS_COMPUTE_32F_FAST_16F = + 74, /* float - fast, allows down-converting inputs to half or TF32 */ + CUBLAS_COMPUTE_32F_FAST_16BF = + 75, /* float - fast, allows down-converting inputs to bfloat16 or TF32 */ + CUBLAS_COMPUTE_32F_FAST_TF32 = + 77, /* float - fast, allows down-converting inputs to TF32 */ + CUBLAS_COMPUTE_64F = 70, /* double - default */ + CUBLAS_COMPUTE_64F_PEDANTIC = 71, /* double - pedantic */ + CUBLAS_COMPUTE_32I = 72, /* signed 32-bit int - default */ + CUBLAS_COMPUTE_32I_PEDANTIC = 73, /* signed 32-bit int - pedantic */ +} cublasComputeType_t; + +typedef enum { + CUBLASLT_MATMUL_DESC_COMPUTE_TYPE = 0, + CUBLASLT_MATMUL_DESC_SCALE_TYPE = 1, + CUBLASLT_MATMUL_DESC_POINTER_MODE = 2, + CUBLASLT_MATMUL_DESC_TRANSA = 3, + CUBLASLT_MATMUL_DESC_TRANSB = 4, + CUBLASLT_MATMUL_DESC_TRANSC = 5, + CUBLASLT_MATMUL_DESC_FILL_MODE = 6, + CUBLASLT_MATMUL_DESC_EPILOGUE = 7, + CUBLASLT_MATMUL_DESC_BIAS_POINTER = 8, + CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE = 10, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER = 11, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD = 12, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE = 13, + CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE = 14, + CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET = 15, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER = 17, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER = 18, + CUBLASLT_MATMUL_DESC_C_SCALE_POINTER = 19, + CUBLASLT_MATMUL_DESC_D_SCALE_POINTER = 20, + CUBLASLT_MATMUL_DESC_AMAX_D_POINTER = 21, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE = 22, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER = 23, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_AMAX_POINTER = 24, + CUBLASLT_MATMUL_DESC_FAST_ACCUM = 25, + CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE = 26, + CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS = 27, + CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS = 28, + CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER = 29, + CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER = 30, +} cublasLtMatmulDescAttributes_t; + +typedef enum { + CUBLAS_OP_N = 0, + CUBLAS_OP_T = 1, + CUBLAS_OP_C = 2, + CUBLAS_OP_HERMITAN = 2, /* synonym if CUBLAS_OP_C */ + CUBLAS_OP_CONJG = + 3 /* conjugate, placeholder - not supported in the current release */ +} cublasOperation_t; + +typedef enum { + CUBLASLT_MATMUL_PREF_SEARCH_MODE = 0, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES = 1, + CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK = 3, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES = 5, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES = 6, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES = 7, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES = 8, + CUBLASLT_MATMUL_PREF_MAX_WAVES_COUNT = 9, + CUBLASLT_MATMUL_PREF_IMPL_MASK = 12, +} cublasLtMatmulPreferenceAttributes_t; +typedef struct { + uint64_t data[8]; +} cublasLtMatrixLayoutOpaque_t; +typedef cublasLtMatrixLayoutOpaque_t *cublasLtMatrixLayout_t; + +typedef struct { + uint64_t data[8]; +} cublasLtMatmulPreferenceOpaque_t; +typedef cublasLtMatmulPreferenceOpaque_t *cublasLtMatmulPreference_t; + +typedef struct { + uint64_t data[8]; +} cublasLtMatmulAlgo_t; + +typedef struct { + cublasLtMatmulAlgo_t algo; + size_t workspaceSize; + cublasStatus_t state; + float wavesCount; + int reserved[4]; +} cublasLtMatmulHeuristicResult_t; + +typedef enum cudaDataType_t { + CUDA_R_16F = 2, /* real as a half */ + CUDA_C_16F = 6, /* complex as a pair of half numbers */ + CUDA_R_16BF = 14, /* real as a nv_bfloat16 */ + CUDA_C_16BF = 15, /* complex as a pair of nv_bfloat16 numbers */ + CUDA_R_32F = 0, /* real as a float */ + CUDA_C_32F = 4, /* complex as a pair of float numbers */ + CUDA_R_64F = 1, /* real as a double */ + CUDA_C_64F = 5, /* complex as a pair of double numbers */ + CUDA_R_4I = 16, /* real as a signed 4-bit int */ + CUDA_C_4I = 17, /* complex as a pair of signed 4-bit int numbers */ + CUDA_R_4U = 18, /* real as a unsigned 4-bit int */ + CUDA_C_4U = 19, /* complex as a pair of unsigned 4-bit int numbers */ + CUDA_R_8I = 3, /* real as a signed 8-bit int */ + CUDA_C_8I = 7, /* complex as a pair of signed 8-bit int numbers */ + CUDA_R_8U = 8, /* real as a unsigned 8-bit int */ + CUDA_C_8U = 9, /* complex as a pair of unsigned 8-bit int numbers */ + CUDA_R_16I = 20, /* real as a signed 16-bit int */ + CUDA_C_16I = 21, /* complex as a pair of signed 16-bit int numbers */ + CUDA_R_16U = 22, /* real as a unsigned 16-bit int */ + CUDA_C_16U = 23, /* complex as a pair of unsigned 16-bit int numbers */ + CUDA_R_32I = 10, /* real as a signed 32-bit int */ + CUDA_C_32I = 11, /* complex as a pair of signed 32-bit int numbers */ + CUDA_R_32U = 12, /* real as a unsigned 32-bit int */ + CUDA_C_32U = 13, /* complex as a pair of unsigned 32-bit int numbers */ + CUDA_R_64I = 24, /* real as a signed 64-bit int */ + CUDA_C_64I = 25, /* complex as a pair of signed 64-bit int numbers */ + CUDA_R_64U = 26, /* real as a unsigned 64-bit int */ + CUDA_C_64U = 27, /* complex as a pair of unsigned 64-bit int numbers */ + CUDA_R_8F_E4M3 = 28, /* real as a nv_fp8_e4m3 */ + CUDA_R_8F_E5M2 = 29, /* real as a nv_fp8_e5m2 */ +} cudaDataType; + +struct cublasContext; +typedef struct cublasLtContext *cublasLtHandle_t; +struct cublasLtMatmulDescOpaque_t; +typedef cublasLtMatmulDescOpaque_t *cublasLtMatmulDesc_t; + +#endif // TRITON_CUBLAS_TYPES_H diff --git a/third_party/nvidia/language/cuda/__init__.py b/third_party/nvidia/language/cuda/__init__.py new file mode 100644 index 000000000..9fffa216b --- /dev/null +++ b/third_party/nvidia/language/cuda/__init__.py @@ -0,0 +1,13 @@ +from . import libdevice + +from .utils import (globaltimer, num_threads, num_warps, smid, convert_custom_float8_sm70, convert_custom_float8_sm80) + +from ._experimental_tma import * # noqa: F403 +from ._experimental_tma import __all__ as _tma_all + +__all__ = [ + "libdevice", "globaltimer", "num_threads", "num_warps", "smid", "convert_custom_float8_sm70", + "convert_custom_float8_sm80", *_tma_all +] + +del _tma_all diff --git a/third_party/nvidia/language/cuda/_experimental_tma.py b/third_party/nvidia/language/cuda/_experimental_tma.py new file mode 100644 index 000000000..567781019 --- /dev/null +++ b/third_party/nvidia/language/cuda/_experimental_tma.py @@ -0,0 +1,108 @@ +from typing import Sequence + +from triton.language import core +from triton.language import semantic +from triton._C.libtriton import ir + +__all__ = [ + "experimental_device_tensormap_create1d", + "experimental_device_tensormap_create2d", + "experimental_tensormap_fenceproxy_acquire", +] + + +def _determine_elem_type(element_ty: core.dtype): + if element_ty.primitive_bitwidth == 8: + return 0 + elif element_ty.primitive_bitwidth == 16: + return 1 + elif element_ty.primitive_bitwidth == 32: + return 2 + else: + raise ValueError("element_ty must be a primitive of size 1, 2, or 4 bytes but got") + + +@core.builtin +def experimental_device_tensormap_create1d( + desc_ptr: core.tensor, + global_address: core.tensor, + load_size: core.tensor, + global_size: core.tensor, + element_ty: core.dtype, + _builder: ir.builder, +): + load_size = core._constexpr_to_value(load_size) + global_size = semantic.to_tensor(global_size, _builder) + element_ty = core._constexpr_to_value(element_ty) + element_stride = [core.full([], 1, core.int32, _builder=_builder)] + + semantic.tensormap_create( + desc_ptr=desc_ptr, + global_address=global_address, + box_dim=[semantic.to_tensor(load_size, _builder)], + global_dim=[global_size], + global_stride=[], + element_stride=element_stride, + elem_type=_determine_elem_type(element_ty), + interleave_layout=0, + swizzle_mode=0, + fill_mode=0, + builder=_builder, + ) + + +@core.builtin +def experimental_device_tensormap_create2d( + desc_ptr: core.tensor, + global_address: core.tensor, + load_size: Sequence[core.constexpr], + global_size: Sequence[core.tensor], + element_ty: core.dtype, + _builder: ir.builder, +): + assert len(load_size) == 2 + assert len(global_size) == 2 + load_size = [core._constexpr_to_value(x) for x in load_size] + global_size = [semantic.to_tensor(x, _builder) for x in global_size] + + element_size = element_ty.primitive_bitwidth // 8 + element_size_t = core.full([], element_size, core.int64, _builder=_builder) + global_stride = semantic.mul(element_size_t, global_size[-1], True, _builder) + # Undocumented, but global_stride seems to be divided by 16 + global_stride = semantic.ashr(global_stride, semantic.to_tensor(4, _builder), _builder) + + contig_dim_size_in_bytes = element_size * load_size[-1] + if contig_dim_size_in_bytes > 128: + load_size[-1] = 128 // element_size + + elem_stride = core.full([], 1, core.int32, _builder=_builder) + + semantic.tensormap_create( + desc_ptr=desc_ptr, + global_address=global_address, + box_dim=[semantic.to_tensor(x, _builder) for x in load_size[::-1]], + global_dim=global_size[::-1], + global_stride=[global_stride], + element_stride=[elem_stride, elem_stride], + elem_type=_determine_elem_type(element_ty), + interleave_layout=0, + swizzle_mode=_determine_swizzle_mode_2d(contig_dim_size_in_bytes, load_size), + fill_mode=0, + builder=_builder, + ) + + +def _determine_swizzle_mode_2d(contig_dim_size_in_bytes, load_size): + if contig_dim_size_in_bytes >= 128: + return 3 + elif contig_dim_size_in_bytes >= 64: + return 2 + elif contig_dim_size_in_bytes >= 32: + return 1 + else: + raise ValueError("block size too small") + + +@core.builtin +def experimental_tensormap_fenceproxy_acquire(desc_ptr: core.tensor, _builder: ir.builder): + semantic.tensormap_fenceproxy_acquire(desc_ptr, _builder) diff --git a/third_party/nvidia/language/cuda/libdevice.py b/third_party/nvidia/language/cuda/libdevice.py new file mode 100644 index 000000000..37e810bb1 --- /dev/null +++ b/third_party/nvidia/language/cuda/libdevice.py @@ -0,0 +1,1629 @@ +from triton.language import core + + +@core.extern +def clz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_clz", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_clzll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def popc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_popc", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_popcll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def byte_perm(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("int32")): ("__nv_byte_perm", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mulhi(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_mulhi", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umulhi", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64")): ("__nv_mul64hi", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64")): ("__nv_umul64hi", core.dtype("uint64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul24(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_mul24", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umul24", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def brev(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_brev", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_brevll", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sad(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("uint32")): ("__nv_sad", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32")): ("__nv_usad", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_abs", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_llabs", core.dtype("int64")), + (core.dtype("fp32"), ): ("__nv_fabsf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_fabs", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_floorf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_floor", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp64h(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_rcp64h", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_rsqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rsqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_ceil", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_ceilf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_trunc", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_truncf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_exp2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def saturatef(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_saturatef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rn(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rz(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rd(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_ru(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_dividef(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_fdividef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rn(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rn", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rz", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rd(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rd", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_ru(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_ru", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rn(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rn", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rz", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rd(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rd", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_ru(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_ru", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_ru", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__nv_dmul_ru", core.dtype("fp64")), + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__nv_fmul_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hiloint2double(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_hiloint2double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2loint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2loint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2hiint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2hiint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_int(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_int", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_uint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_uint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def longlong_as_double(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_longlong_as_double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double_as_longlong(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double_as_longlong", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_sinf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_sinf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_cosf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_cosf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log2f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_log2f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_logf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_logf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_expf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_expf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_tanf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_tanf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_exp10f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_exp10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log10f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_log10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_powf(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_powf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_hadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_uhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_rhadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_urhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_frsqrt_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ffs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("__nv_ffs", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_ffsll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_rintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_llrintf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llrint", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nearbyint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_nearbyintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_nearbyint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_isnanf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_isnand", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_signbitf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_signbitd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_copysignf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_copysign", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def finitef(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_finitef", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_isinff", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_isinfd", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) + + +@core.extern +def nextafter(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_nextafterf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_nextafter", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinpi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinpif", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sinpi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cospi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cospif", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cospi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_exp10f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_coshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan2(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_atan2f", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_atan2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_logf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log10f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log1pf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log1p", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acoshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expm1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_expm1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_hypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_hypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_rhypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_rhypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_norm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_norm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_rnorm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_rnorm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("__nv_norm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("__nv_norm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("__nv_rnorm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("__nv_rnorm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cbrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_rcbrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rcbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j0(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_j0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_j0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_j1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_j1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y0(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_y0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_y0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y1(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_y1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_y1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def yn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("__nv_ynf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("__nv_yn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def jn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("__nv_jnf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("__nv_jn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cyl_bessel_i0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cyl_bessel_i1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erff", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfc", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcx(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcxf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfcx", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfcinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_normcdfinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_normcdfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_normcdff", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_normcdf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_lgammaf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_lgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_ldexpf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_ldexp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def scalbn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_scalbnf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_scalbn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmodf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fmod", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def remainder(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_remainderf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_remainder", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_powif", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_powi", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_powf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_pow", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tgammaf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def round(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_roundf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_round", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llround(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_llroundf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llround", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fdim(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdimf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fdim", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_ilogbf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_ilogb", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def logb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_logbf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_logb", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isfinited(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_isfinited", core.dtype("int32")), + }, is_pure=True, _builder=_builder).to(core.int1, _builder=_builder) diff --git a/third_party/nvidia/language/cuda/utils.py b/third_party/nvidia/language/cuda/utils.py new file mode 100644 index 000000000..01bc040b2 --- /dev/null +++ b/third_party/nvidia/language/cuda/utils.py @@ -0,0 +1,109 @@ +from triton.language import core + + +@core.extern +def globaltimer(_builder=None): + return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1, + _builder=_builder) + + +@core.extern +def smid(_builder=None): + return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1, + _builder=_builder) + + +@core.builtin +def num_threads(_builder=None): + return core.constexpr(_builder.options.num_warps * 32) + + +@core.builtin +def num_warps(_builder=None): + return core.constexpr(_builder.options.num_warps) + + +# ----- FP8E4M3B15 ------ +# This data-type is a variant of the standard FP8E4M3 format. +# It was designed for fast software conversion to FP16 on +# nvidia GPUs that do not support it natively. +# This is the same format as FP8E4M3Nv, but: +# - the exponent bias is 15 instead of 7 +# - 0xff and 0x7f are mapped to +-1.750 instead of +-nan +@core.builtin +def convert_fp8e4b15_to_float16(arg, _builder=None): + return core.inline_asm_elementwise( + "{ \n" + ".reg .b32 a<2>, b<2>; \n" + "prmt.b32 a0, 0, $2, 0x5746; \n" + "and.b32 b0, a0, 0x7f007f00; \n" + "and.b32 b1, a0, 0x00ff00ff; \n" + "and.b32 a1, a0, 0x00800080; \n" + "shr.b32 b0, b0, 1; \n" + "add.u32 b1, b1, a1; \n" + "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" + "shl.b32 $1, b1, 7; \n" + "} \n", "=r,=r,r", [arg], dtype=core.float16, is_pure=True, pack=4, + _builder=_builder) + + +@core.builtin +def convert_float16_to_fp8e4b15(arg, has_minx2, _builder=None): + asm = """{ + .reg .pred p<4>; + .reg .b32 a<2>, b<2>; + .reg .b16 c<4>; + .reg .b16 max_val_f16; + .reg .b32 max_val_f16x2; + mov.b16 max_val_f16, 0x3F00; + mov.b32 max_val_f16x2, 0x3F003F00; + and.b32 a0, $1, 0x7fff7fff; + and.b32 a1, $2, 0x7fff7fff;""" + if has_minx2: + asm += """min.f16x2 a0, a0, max_val_f16x2; + min.f16x2 a1, a1, max_val_f16x2;""" + else: + asm += """setp.lt.f16x2 p0|p1, a0, max_val_f16x2; + setp.lt.f16x2 p2|p3, a1, max_val_f16x2; + mov.b32 {c0, c1}, a0; + mov.b32 {c2, c3}, a1; + selp.b16 c0, c0, max_val_f16, p0; + selp.b16 c1, c1, max_val_f16, p1; + selp.b16 c2, c2, max_val_f16, p2; + selp.b16 c3, c3, max_val_f16, p3; + mov.b32 a0, {c0, c1}; + mov.b32 a1, {c2, c3};""" + asm += """mad.lo.u32 a0, a0, 2, 0x00800080; + mad.lo.u32 a1, a1, 2, 0x00800080; + lop3.b32 b0, $1, 0x80008000, a0, 0xea; + lop3.b32 b1, $2, 0x80008000, a1, 0xea; + prmt.b32 $0, b0, b1, 0x7531; + }""" + return core.inline_asm_elementwise(asm, "=r,r,r", [arg], dtype=core.float8e4b15, is_pure=True, pack=4, + _builder=_builder) + + +@core.builtin +def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _builder=None): + if arg.type.scalar.is_fp8e4b15(): + upcast_val = convert_fp8e4b15_to_float16(arg, _builder=_builder) + if dst_ty.scalar.is_fp32(): + upcast_val = upcast_val.to(core.float32, _builder=_builder) + return upcast_val + + assert arg.type.scalar.is_fp16() or arg.type.scalar.is_fp32() + downcast_val = arg + if arg.type.scalar.is_fp32(): + downcast_val = downcast_val.to(core.float16, fp_downcast_rounding="rtz", _builder=_builder) + downcast_val = convert_float16_to_fp8e4b15(downcast_val, has_minx2=has_minx2, _builder=_builder) + return downcast_val + + +@core.builtin +def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _builder=None): + return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=True, _builder=_builder) + + +@core.builtin +def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _builder=None): + return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=False, _builder=_builder) diff --git a/third_party/nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp b/third_party/nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp index ed87a588f..f623f50c6 100644 --- a/third_party/nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp +++ b/third_party/nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp @@ -32,59 +32,6 @@ using namespace mlir; using namespace mlir::triton::nvgpu; -void LoadDSmemOp::build(OpBuilder &builder, OperationState &state, - Type resultTy, Value addr, Value ctaId) { - unsigned vec, bitwidth; - if (auto structTy = dyn_cast(resultTy)) { - auto types = structTy.getBody(); - assert(types.size() > 0 && "Invalid result type of LoadDSmemOp"); - vec = types.size(); - for (unsigned i = 0; i < vec; ++i) - assert(types[0] == types[i]); - bitwidth = types[0].getIntOrFloatBitWidth(); - } else { - vec = 1; - bitwidth = resultTy.getIntOrFloatBitWidth(); - } - build(builder, state, resultTy, addr, ctaId, bitwidth, vec); -} - -void LoadDSmemOp::build(OpBuilder &builder, OperationState &state, Value addr, - Value ctaId, unsigned bitwidth, unsigned vec) { - Type resultTy = builder.getIntegerType(bitwidth); - if (vec > 1) { - SmallVector types(vec, resultTy); - resultTy = LLVM::LLVMStructType::getLiteral(builder.getContext(), types); - } - build(builder, state, resultTy, addr, ctaId, bitwidth, vec); -} - -void LoadDSmemOp::build(OpBuilder &builder, OperationState &state, Value addr, - Value ctaId, unsigned bitwidth) { - build(builder, state, addr, ctaId, bitwidth, /*vec*/ 1); -} - -void StoreDSmemOp::build(OpBuilder &builder, OperationState &state, Value addr, - Value ctaId, Value value, Value pred) { - SmallVector values = {value}; - build(builder, state, addr, ctaId, values, pred); -} - -unsigned StoreDSmemOp::getBitwidth() { - auto addrTy = getAddr().getType(); - assert(isa(addrTy) && "addr must be a pointer type"); - if (getValues().empty()) - return 0; - auto elemTy = getValues().back().getType(); - return elemTy.getIntOrFloatBitWidth(); -} - -unsigned StoreDSmemOp::getVec() { return getValues().size(); } - -static LogicalResult verify(mlir::triton::nvgpu::WGMMAOp op) { - return success(); -} - void mlir::triton::nvgpu::NVGPUDialect::initialize() { addAttributes< #define GET_ATTRDEF_LIST diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index e19216520..5a461fb72 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -18,12 +18,11 @@ using namespace mlir::triton; namespace ttn = mlir::triton::nvgpu; using ::mlir::LLVM::NVIDIA::getSRegValue; +using ttn::Constraints; +using ttn::OperandsAndConstraints; namespace { -using OperandsAndConstraints = std::vector>; -typedef std::vector Constraints; - const std::string Wgmma_Fence_Op = "wgmma.fence.sync.aligned;"; const std::string Wgmma_Commit_Group_Op = "wgmma.commit_group.sync.aligned;"; const std::string Cluster_Wait_Op = "barrier.cluster.wait.aligned;"; @@ -39,6 +38,11 @@ const std::string Cluster_Cta_Id_Op = "{\n" "mad.lo.u32 a1, a2, a4, a1; \n" "mad.lo.u32 $0, a1, a3, a0; \n" "}"; +const std::string Reg_Alloc_Op = "setmaxnreg.inc.sync.aligned.u32 #regCount;"; +const std::string Reg_Dealloc_Op = "setmaxnreg.dec.sync.aligned.u32 #regCount;"; + +const std::string Named_Barrier_Arrive_Op = "bar.arrive $0, $1;"; +const std::string Named_Barrier_Wait_Op = "bar.sync $0, $1;"; const std::string Canonical_Warp_Id_Op = "{\n" ".reg .u32 a<5>; \n" @@ -63,7 +67,7 @@ bool isNumber(const std::string &s) { }) == s.end(); } -Type getTypeFromConstraint(char constraint, mlir::PatternRewriter &rewriter) { +Type getTypeFromConstraint(char constraint, PatternRewriter &rewriter) { Type ty; if (constraint == 'b') ty = IntegerType::get(rewriter.getContext(), 1); @@ -83,242 +87,185 @@ Type getTypeFromConstraint(char constraint, mlir::PatternRewriter &rewriter) { return ty; } -template -class NVGPUOpPatternBase : public mlir::RewritePattern { -public: - explicit NVGPUOpPatternBase(mlir::MLIRContext *context) - : mlir::RewritePattern(SourceOp::getOperationName(), 1, context) {} - - // Converts the given value to the type represented by the constraint - // E.g. if val is of type llvmptr and constraint is 'r', then we convert - // val to i32 using ptrtoint(i32_ty, val) - mlir::Value convertToType(mlir::Value val, std::string constraint, - Location &loc, - mlir::PatternRewriter &rewriter) const { - auto isConstraintNumber = isNumber(constraint); - if (!isConstraintNumber) { - auto ty = getTypeFromConstraint(constraint[0], rewriter); - if (isa(val.getType())) { - return ptrtoint(ty, val); - } else { - assert(val.getType().getIntOrFloatBitWidth() <= - ty.getIntOrFloatBitWidth() && - "Cannot convert to a smaller type"); - if (val.getType().getIntOrFloatBitWidth() < ty.getIntOrFloatBitWidth()) - return zext(ty, val); - } +// Converts the given value to the type represented by the constraint +// E.g. if val is of type llvmptr and constraint is 'r', then we convert +// val to i32 using ptrtoint(i32_ty, val) +Value convertToType(Value val, std::string constraint, Location loc, + PatternRewriter &rewriter) { + auto isConstraintNumber = isNumber(constraint); + if (!isConstraintNumber) { + auto ty = getTypeFromConstraint(constraint[0], rewriter); + if (isa(val.getType())) { + return ptrtoint(ty, val); + } else { + assert(val.getType().getIntOrFloatBitWidth() <= + ty.getIntOrFloatBitWidth() && + "Cannot convert to a smaller type"); + if (val.getType().getIntOrFloatBitWidth() < ty.getIntOrFloatBitWidth()) + return zext(ty, val); } - return val; } + return val; +} - SmallVector - getPtxOutputs(std::vector &outputConstraints, - PTXBuilder &ptxBuilder) const { - SmallVector ptxOutputs; - for (unsigned i = 0; i < outputConstraints.size(); i++) { - auto *ptxOutput = ptxBuilder.newOperand(outputConstraints[i]); - ptxOutputs.push_back(ptxOutput); - } - return ptxOutputs; +SmallVector +getPtxOutputs(const nvgpu::Constraints &outputConstraints, + PTXBuilder &ptxBuilder) { + SmallVector ptxOutputs; + for (unsigned i = 0; i < outputConstraints.size(); i++) { + auto *ptxOutput = ptxBuilder.newOperand(outputConstraints[i]); + ptxOutputs.push_back(ptxOutput); } + return ptxOutputs; +} - OperandsAndConstraints - unpackOperands(OperandsAndConstraints &operandsAndConstraints, - PTXBuilder &ptxBuilder, Location &loc, - mlir::PatternRewriter &rewriter) const { - OperandsAndConstraints unpackedOperands; - for (auto &[operand, constraint] : operandsAndConstraints) { - auto llvmStruct = llvm::dyn_cast(operand.getType()); - // if a constraint is a number, then we are doing input/output tying - // if the operand is a struct, then we need to unpack it, and - // add the constraint to each of the unpacked operands uses the constraint - // as an offset - auto isConstraintNumber = isNumber(constraint); - if (llvmStruct) { - for (unsigned i = 0; i < llvmStruct.getBody().size(); i++) { - if (isConstraintNumber) { - auto constraintInt = std::stoi(constraint) + i; - unpackedOperands.push_back( - {extract_val(llvmStruct.getBody()[i], operand, i), - std::to_string(constraintInt)}); - } else { - unpackedOperands.push_back( - {extract_val(llvmStruct.getBody()[i], operand, i), constraint}); - } +OperandsAndConstraints +unpackOperands(const OperandsAndConstraints &operandsAndConstraints, + PTXBuilder &ptxBuilder, Location loc, + PatternRewriter &rewriter) { + OperandsAndConstraints unpackedOperands; + for (const auto &[operand, constraint] : operandsAndConstraints) { + auto llvmStruct = llvm::dyn_cast(operand.getType()); + // if a constraint is a number, then we are doing input/output tying + // if the operand is a struct, then we need to unpack it, and + // add the constraint to each of the unpacked operands uses the constraint + // as an offset + auto isConstraintNumber = isNumber(constraint); + if (llvmStruct) { + for (unsigned i = 0; i < llvmStruct.getBody().size(); i++) { + if (isConstraintNumber) { + auto constraintInt = std::stoi(constraint) + i; + unpackedOperands.push_back( + {extract_val(llvmStruct.getBody()[i], operand, i), + std::to_string(constraintInt)}); + } else { + unpackedOperands.push_back( + {extract_val(llvmStruct.getBody()[i], operand, i), constraint}); } - } else { - unpackedOperands.push_back({operand, constraint}); } + } else { + unpackedOperands.push_back({operand, constraint}); } - return unpackedOperands; - } - - SmallVector - getPtxOperands(OperandsAndConstraints &operandsAndConstraints, - PTXBuilder &ptxBuilder, Location &loc, - mlir::PatternRewriter &rewriter) const { - SmallVector ptxOperands; - auto unpackedOperandsAndConstraints = - unpackOperands(operandsAndConstraints, ptxBuilder, loc, rewriter); - for (auto &[operand, constraint] : unpackedOperandsAndConstraints) { - auto convertedOperand = convertToType(operand, constraint, loc, rewriter); - auto *ptxOperand = ptxBuilder.newOperand(convertedOperand, constraint); - ptxOperands.push_back(ptxOperand); - } - return ptxOperands; } + return unpackedOperands; +} - virtual std::vector getOutputConstraints(SourceOp op) const { - return {}; +SmallVector +getPtxOperands(const OperandsAndConstraints &operandsAndConstraints, + PTXBuilder &ptxBuilder, Location loc, + PatternRewriter &rewriter) { + SmallVector ptxOperands; + auto unpackedOperandsAndConstraints = + unpackOperands(operandsAndConstraints, ptxBuilder, loc, rewriter); + for (auto &[operand, constraint] : unpackedOperandsAndConstraints) { + auto convertedOperand = convertToType(operand, constraint, loc, rewriter); + auto *ptxOperand = ptxBuilder.newOperand(convertedOperand, constraint); + ptxOperands.push_back(ptxOperand); } + return ptxOperands; +} - virtual OperandsAndConstraints getOperandsAndConstraints(SourceOp op) const { - return {}; +std::string patchPtxAsm(Operation *op, std::string ptxAsm) { + std::vector> patchLocations; + std::vector patchValues; + auto start = ptxAsm.find("#", 0); + while (start != std::string::npos) { + auto endIterator = + std::find_if(ptxAsm.begin() + start + 1, ptxAsm.end(), + [](unsigned char c) { return !std::isalnum(c); }); + + assert(endIterator != ptxAsm.end() && "unexpected asm format"); + + auto end = std::distance(ptxAsm.begin(), endIterator); + auto patchLocation = std::make_pair(start, end); + patchLocations.push_back(patchLocation); + auto patchValue = ptxAsm.substr(start + 1, end - start - 1); + patchValues.push_back(patchValue); + start = ptxAsm.find("#", end); } - - std::string patchPtxAsm(mlir::Operation *op, std::string ptxAsm) const { - std::vector> patchLocations; - std::vector patchValues; - auto start = ptxAsm.find("#", 0); - while (start != std::string::npos) { - auto endIterator = - std::find_if(ptxAsm.begin() + start + 1, ptxAsm.end(), - [](unsigned char c) { return !std::isalnum(c); }); - - assert(endIterator != ptxAsm.end() && "unexpected asm format"); - - auto end = std::distance(ptxAsm.begin(), endIterator); - auto patchLocation = std::make_pair(start, end); - patchLocations.push_back(patchLocation); - auto patchValue = ptxAsm.substr(start + 1, end - start - 1); - patchValues.push_back(patchValue); - start = ptxAsm.find("#", end); - } - assert(patchLocations.size() == patchValues.size() && - "patchLocations and patchValues should have the same size"); - if (patchLocations.size() == 0) { - return ptxAsm; - } - std::string res = ""; - size_t prevStart = 0; - unsigned i = 0; - for (auto &[start, end] : patchLocations) { - res += ptxAsm.substr(prevStart, start - prevStart); - auto integerAttr = op->getAttrOfType(patchValues[i]); - auto attr = integerAttr.getInt(); - res += std::to_string(attr); - prevStart = end; - i++; - } - if (prevStart < ptxAsm.size()) - res += ptxAsm.substr(prevStart, ptxAsm.size() - prevStart); - return res; + assert(patchLocations.size() == patchValues.size() && + "patchLocations and patchValues should have the same size"); + if (patchLocations.size() == 0) { + return ptxAsm; } - - LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto loc = op->getLoc(); - auto sourceOp = llvm::dyn_cast(op); - if (!sourceOp) - return mlir::failure(); - auto concrete = static_cast(this); - auto ptxAsm = concrete->getPtxAsm(sourceOp); - auto ptxAsmPatched = patchPtxAsm(sourceOp, ptxAsm); - auto hasSideEffects = !isMemoryEffectFree(sourceOp); - auto operandsAndConstraints = concrete->getOperandsAndConstraints(sourceOp); - auto outputConstraints = concrete->getOutputConstraints(sourceOp); - - PTXBuilder ptxBuilder; - auto ptxOutputs = getPtxOutputs(outputConstraints, ptxBuilder); - auto ptxOperands = - getPtxOperands(operandsAndConstraints, ptxBuilder, loc, rewriter); - SmallVector outputsAndOperands = ptxOutputs; - outputsAndOperands.append(ptxOperands.begin(), ptxOperands.end()); - auto &ptxInstr = *ptxBuilder.create(ptxAsmPatched); - ptxInstr(outputsAndOperands, /*onlyAttachMLIRArgs=*/true); - auto retTy = - op->getNumResults() == 0 ? void_ty(ctx) : op->getResult(0).getType(); - auto res = ptxBuilder.launch(rewriter, loc, retTy, - /*hasSideEffects*/ hasSideEffects); - if (op->getNumResults() == 0) { - rewriter.eraseOp(op); - } else { - rewriter.replaceOp(op, res); - } - - return mlir::success(); + std::string res = ""; + size_t prevStart = 0; + unsigned i = 0; + for (auto &[start, end] : patchLocations) { + res += ptxAsm.substr(prevStart, start - prevStart); + auto integerAttr = op->getAttrOfType(patchValues[i]); + auto attr = integerAttr.getInt(); + res += std::to_string(attr); + prevStart = end; + i++; } -}; + if (prevStart < ptxAsm.size()) + res += ptxAsm.substr(prevStart, ptxAsm.size() - prevStart); + return res; +} template -class NVGPUOpGenericPattern - : public NVGPUOpPatternBase> { +class NVGPUOpGenericPattern : public OpRewritePattern { public: - explicit NVGPUOpGenericPattern(mlir::MLIRContext *context, std::string ptxAsm, - std::vector outputConstraints, - std::vector inputConstraints) - : NVGPUOpPatternBase>(context), - ptxAsm(ptxAsm), outputConstraints(outputConstraints), + explicit NVGPUOpGenericPattern(MLIRContext *context, std::string ptxAsm, + Constraints outputConstraints, + Constraints inputConstraints) + : OpRewritePattern(context), ptxAsm(std::move(ptxAsm)), + outputConstraints(outputConstraints), inputConstraints(inputConstraints) {} - std::vector getOutputConstraints(SourceOp op) const { - return outputConstraints; - } - OperandsAndConstraints getOperandsAndConstraints(SourceOp op) const { + LogicalResult matchAndRewrite(SourceOp op, + PatternRewriter &rewriter) const override { OperandsAndConstraints operandsAndConstraints; for (unsigned i = 0; i < inputConstraints.size(); i++) { operandsAndConstraints.push_back( {op->getOperand(i), inputConstraints[i]}); } - return operandsAndConstraints; + return rewriteAsPtxAsm(op, rewriter, ptxAsm, operandsAndConstraints, + outputConstraints); } - std::string getPtxAsm(SourceOp op) const { return ptxAsm; } private: std::string ptxAsm; - std::vector outputConstraints; - std::vector inputConstraints; + Constraints outputConstraints; + Constraints inputConstraints; }; class FenceAsyncSharedOpPattern - : public NVGPUOpPatternBase { + : public OpRewritePattern { public: - using Base = - NVGPUOpPatternBase; - using Base::Base; - - std::string getPtxAsm(ttn::FenceAsyncSharedOp op) const { - auto bCluster = op.getBCluster(); - if (bCluster) - return "fence.proxy.async.shared::cluster;"; - else - return "fence.proxy.async.shared::cta;"; + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttn::FenceAsyncSharedOp op, + PatternRewriter &rewriter) const override { + std::string ptxAsm = op.getBCluster() ? "fence.proxy.async.shared::cluster;" + : "fence.proxy.async.shared::cta;"; + return rewriteAsPtxAsm(op, rewriter, std::move(ptxAsm)); } }; -class ClusterArriveOpPattern - : public NVGPUOpPatternBase { +class ClusterArriveOpPattern : public OpRewritePattern { public: - using Base = NVGPUOpPatternBase; - using Base::Base; - - std::string getPtxAsm(ttn::ClusterArriveOp op) const { - auto relaxed = op.getRelaxed(); - if (relaxed) - return "barrier.cluster.arrive.relaxed.aligned;"; - else - return "barrier.cluster.arrive.aligned;"; + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttn::ClusterArriveOp op, + PatternRewriter &rewriter) const override { + std::string ptxAsm = op.getRelaxed() + ? "barrier.cluster.arrive.relaxed.aligned;" + : "barrier.cluster.arrive.aligned;"; + return rewriteAsPtxAsm(op, rewriter, std::move(ptxAsm)); } }; -class StoreMatrixOpPattern - : public NVGPUOpPatternBase { +class StoreMatrixOpPattern : public OpRewritePattern { public: - using Base = NVGPUOpPatternBase; - using Base::Base; + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttn::StoreMatrixOp op, + PatternRewriter &rewriter) const override { + return rewriteAsPtxAsm(op, rewriter, getPtxAsm(op), + getOperandsAndConstraints(op)); + } OperandsAndConstraints getOperandsAndConstraints(ttn::StoreMatrixOp op) const { @@ -353,137 +300,94 @@ class StoreMatrixOpPattern } }; -class StoreDSmemOpPattern - : public NVGPUOpPatternBase { +class MBarrierArriveOpPattern : public OpRewritePattern { public: - using Base = NVGPUOpPatternBase; - using Base::Base; + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttn::MBarrierArriveOp op, + PatternRewriter &rewriter) const override { + return rewriteAsPtxAsm(op, rewriter, getPtxAsm(op), + getOperandsAndConstraints(op)); + } - OperandsAndConstraints getOperandsAndConstraints(ttn::StoreDSmemOp op) const { + OperandsAndConstraints + getOperandsAndConstraints(ttn::MBarrierArriveOp op) const { OperandsAndConstraints operandsAndTypes; - auto addr = op.getAddr(); - auto ctaId = op.getCtaId(); - auto values = op.getValues(); - auto pred = op.getPred(); - auto bitwidth = op.getBitwidth(); - operandsAndTypes.push_back({addr, "r"}); - operandsAndTypes.push_back({ctaId, "r"}); - operandsAndTypes.push_back({pred, "b"}); - std::string c = bitwidth == 16 ? "h" : (bitwidth == 32 ? "r" : "l"); - for (unsigned i = 0; i < values.size(); i++) { - operandsAndTypes.push_back({values[i], c}); + Value mbarrier = op.getMbarrier(); + Value pred = op.getPred(); + Value ctaId = op.getCtaId(); + auto arriveType = op.getArriveType(); + + switch (arriveType) { + case ttn::MBarriveType::normal: + case ttn::MBarriveType::cp_async: + case ttn::MBarriveType::expect_tx: + operandsAndTypes.push_back({mbarrier, "r"}); + operandsAndTypes.push_back({pred, "b"}); + break; + case ttn::MBarriveType::remote: + operandsAndTypes.push_back({mbarrier, "r"}); + operandsAndTypes.push_back({ctaId, "r"}); + operandsAndTypes.push_back({pred, "b"}); + break; + default: + llvm::errs() << "Unsupported mbarrier arrive type " << arriveType << "\n"; + llvm_unreachable(""); + break; } return operandsAndTypes; } - std::string getPtxAsm(ttn::StoreDSmemOp op) const { - auto bitwidth = op.getBitwidth(); - auto vec = op.getVec(); - auto values = op.getValues(); - assert( - (bitwidth == 8 || bitwidth == 16 || bitwidth == 32 || bitwidth == 64) && - "invalid bitwidth"); - assert((vec == 1 || vec == 2 || vec == 4) && vec == values.size() && - "invalid vec size"); + std::string getPtxAsm(ttn::MBarrierArriveOp op) const { + Value ctaId = op.getCtaId(); + auto arriveType = op.getArriveType(); + uint32_t txCount = op.getTxCount(); std::string ptxAsm; - if (vec == 1) { - ptxAsm = "{ \n" - ".reg .u32 remoteAddr; \n" - "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n" - ".reg .pred p; \n" - "mov.pred p, $2; \n" - "@p st.shared::cluster.u#bitwidth [remoteAddr], $3; \n" - "}\n"; - } - if (vec == 2) { - ptxAsm = "{ \n" - ".reg .u32 remoteAddr; \n" - "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n" - ".reg .pred p; \n" - "mov.pred p, $2; \n" - "@p st.shared::cluster.v.u#bitwidth [remoteAddr], {$3, $4}; \n" - "}\n"; - } - if (vec == 4) { - ptxAsm = "{ \n" - ".reg .u32 remoteAddr; \n" - "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n" - ".reg .pred p; \n" - "mov.pred p, $2; \n" - "@p st.shared::cluster.v.u#bitwidth [remoteAddr], {$3, $4, $5, " - "$6}; \n" - "}\n"; + switch (arriveType) { + case ttn::MBarriveType::normal: + ptxAsm = "@$1 mbarrier.arrive.shared.b64 _, [$0];"; + break; + case ttn::MBarriveType::cp_async: + ptxAsm = "@$1 cp.async.mbarrier.arrive.noinc.shared.b64 [$0];"; + break; + case ttn::MBarriveType::expect_tx: + assert(txCount > 0 && "txCount should be valid"); + ptxAsm = "@$1 mbarrier.arrive.expect_tx.shared.b64 _, [$0], " + + std::to_string(txCount) + ";"; + break; + case ttn::MBarriveType::remote: + assert(ctaId && "ctaId should have a valid value"); + ptxAsm = + " { .reg .b32 remAddr32; \n" + " @$2 mapa.shared::cluster.u32 remAddr32, $0, $1; \n" + " @$2 mbarrier.arrive.shared::cluster.b64 _, [remAddr32]; } \n"; + break; + default: + llvm::errs() << "Unsupported mbarrier arrive type " << arriveType << "\n"; + llvm_unreachable(""); + break; } return ptxAsm; } }; -class LoadDSmemOpPattern - : public NVGPUOpPatternBase { +class WGMMAWaitGroupOpPattern : public OpRewritePattern { public: - using Base = NVGPUOpPatternBase; - using Base::Base; - - std::vector getOutputConstraints(ttn::LoadDSmemOp op) const { - auto bitwidth = op.getBitwidth(); - std::string c = bitwidth == 16 ? "=h" : (bitwidth == 32 ? "=r" : "=l"); - auto vec = op.getVec(); - return std::vector(vec, c); - } - OperandsAndConstraints getOperandsAndConstraints(ttn::LoadDSmemOp op) const { - OperandsAndConstraints operandsAndTypes; - auto addr = op.getAddr(); - auto ctaId = op.getCtaId(); + using OpRewritePattern::OpRewritePattern; - operandsAndTypes.push_back({addr, "r"}); - operandsAndTypes.push_back({ctaId, "r"}); - return operandsAndTypes; + LogicalResult matchAndRewrite(ttn::WGMMAWaitGroupOp op, + PatternRewriter &rewriter) const override { + return rewriteAsPtxAsm(op, rewriter, getPtxAsm(op), + getOperandsAndConstraints(op), + getOutputConstraints(op)); } - std::string getPtxAsm(ttn::LoadDSmemOp op) const { - auto addr = op.getAddr(); - auto ctaId = op.getCtaId(); - auto bitwidth = op.getBitwidth(); - auto vec = op.getVec(); - - assert( - (bitwidth == 8 || bitwidth == 16 || bitwidth == 32 || bitwidth == 64) && - "invalid bitwidth"); - assert((vec == 1 || vec == 2 || vec == 4) && "invalid vec size"); - - std::string o1 = vec > 1 ? ".v.u" : ".u"; - std::string vecStr = vec == 1 ? "$0" - : vec == 2 ? "{$0, $1}" - : "{$0, $1, $2, $3}"; - unsigned argNum = vec == 1 ? 1 : vec == 2 ? 2 : 4; - auto ptxAsm = "{\n" - ".reg .u32 remoteAddr;\n" - "mapa.shared::cluster.u32 remoteAddr, $" + - std::to_string(argNum) + " , $" + std::to_string(argNum + 1) + - " ; \n" - "ld.shared::cluster" + - o1 + std::to_string(bitwidth) + " " + vecStr + - ", [remoteAddr];\n" - "}\n"; - return ptxAsm; - } -}; - -class WGMMAWaitGroupOpPattern - : public NVGPUOpPatternBase { -public: - using Base = - NVGPUOpPatternBase; - using Base::Base; - - std::vector - getOutputConstraints(ttn::WGMMAWaitGroupOp op) const { + Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const { auto outputStructType = cast(op.getType()); uint32_t numOutputRegs = outputStructType.getBody().size(); std::string output = outputStructType.getBody().front().isF32() ? "=f" : "=r"; - return std::vector(numOutputRegs, output); + return Constraints(numOutputRegs, output); } OperandsAndConstraints @@ -508,10 +412,16 @@ class WGMMAWaitGroupOpPattern } }; -class WGMMAOpPattern : public NVGPUOpPatternBase { +class WGMMAOpPattern : public OpRewritePattern { public: - using Base = NVGPUOpPatternBase; - using Base::Base; + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttn::WGMMAOp op, + PatternRewriter &rewriter) const override { + return rewriteAsPtxAsm(op, rewriter, getPtxAsm(op), + getOperandsAndConstraints(op), + getOutputConstraints(op)); + } std::vector getOutputConstraints(ttn::WGMMAOp op) const { // TODO (zahi): Return type must always be a struct for wgmma, currently @@ -531,6 +441,7 @@ class WGMMAOpPattern : public NVGPUOpPatternBase { auto opA = op.getOpA(); auto opB = op.getOpB(); auto opC = op.getOpC(); + auto opScaleD = op.getUseC(); auto typeA = opA.getType(); auto structTypeA = dyn_cast(typeA); @@ -547,6 +458,11 @@ class WGMMAOpPattern : public NVGPUOpPatternBase { // Operand B (must be `desc`) operandsAndConstraints.push_back({opB, "l"}); + + // `scale-d` + if (op.getOpC()) + operandsAndConstraints.push_back({opScaleD, "b"}); + return operandsAndConstraints; } @@ -643,8 +559,11 @@ class WGMMAOpPattern : public NVGPUOpPatternBase { // Operand B (must be `desc`) args += "$" + std::to_string(asmOpIdx++) + ", "; - // `scale-d` is 1 if we have a C operand. - args += op.getOpC() ? "1" : "0"; + // `scale-d` + if (op.getOpC()) + args += "$" + std::to_string(asmOpIdx++); + else + args += "0"; // `imm-scale-a`, and `imm-scale-b` are 1 by default only for float-based // WGMMA @@ -681,17 +600,25 @@ class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase { #define POPULATE_NVGPU_OP(SRC_OP, ASM) \ patterns.add>(context, ASM, Constraints(), \ Constraints()); + POPULATE_NVGPU_OP(ttn::RegAllocOp, Reg_Alloc_Op) POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, Wgmma_Fence_Op) POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, Wgmma_Commit_Group_Op) POPULATE_NVGPU_OP(ttn::ClusterWaitOp, Cluster_Wait_Op) + POPULATE_NVGPU_OP(ttn::RegDeallocOp, Reg_Dealloc_Op) #undef POPULATE_NVGPU_OP + patterns.add>( + context, Named_Barrier_Arrive_Op, Constraints(), + Constraints({"r", "r"})); + patterns.add>( + context, Named_Barrier_Wait_Op, Constraints(), Constraints({"r", "r"})); patterns.add>( context, Cluster_Cta_Id_Op, Constraints({"=r"}), Constraints()); + patterns.add>( + context, Canonical_Warp_Id_Op, Constraints({"=r"}), Constraints()); - patterns - .add( - context); + patterns.add(context); if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) signalPassFailure(); @@ -703,6 +630,37 @@ class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase { namespace mlir { namespace triton { +LogicalResult +nvgpu::rewriteAsPtxAsm(Operation *op, PatternRewriter &rewriter, + std::string ptxAsm, + const OperandsAndConstraints &operandsAndConstraints, + const Constraints &outputConstraints) { + auto ctx = rewriter.getContext(); + auto loc = op->getLoc(); + ptxAsm = patchPtxAsm(op, std::move(ptxAsm)); + auto hasSideEffects = !isMemoryEffectFree(op); + + PTXBuilder ptxBuilder; + auto ptxOutputs = getPtxOutputs(outputConstraints, ptxBuilder); + auto ptxOperands = + getPtxOperands(operandsAndConstraints, ptxBuilder, loc, rewriter); + SmallVector outputsAndOperands = ptxOutputs; + outputsAndOperands.append(ptxOperands.begin(), ptxOperands.end()); + auto &ptxInstr = *ptxBuilder.create(ptxAsm); + ptxInstr(outputsAndOperands, /*onlyAttachMLIRArgs=*/true); + auto retTy = + op->getNumResults() == 0 ? void_ty(ctx) : op->getResult(0).getType(); + auto res = ptxBuilder.launch(rewriter, loc, retTy, + /*hasSideEffects*/ hasSideEffects); + if (op->getNumResults() == 0) { + rewriter.eraseOp(op); + } else { + rewriter.replaceOp(op, res); + } + + return success(); +} + std::unique_ptr> createConvertNVGPUToLLVMPass() { return std::make_unique<::ConvertNVGPUToLLVM>(); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp index 746b910e1..268d1dbf6 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp @@ -55,6 +55,77 @@ struct BarrierOpConversion } }; +// -------------------------------------------------------------------------- +// -- MBarrier related Ops lowering, to be moved to a separate file --------- +// -------------------------------------------------------------------------- +struct MBarrierArriveOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::MBarrierArriveOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::MBarrierArriveOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto mbarrier = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getMbarrier(), + typeConverter->convertType(op.getMbarrier().getType().getElementType()), + rewriter); + + bool trackAsyncOp = op.getTrackAsyncOp(); + triton::nvgpu::MBarriveType type = triton::nvgpu::MBarriveType::normal; + uint32_t txCount = op.getTxCount(); + auto remoteCtaId = adaptor.getRemoteCtaId(); + if (trackAsyncOp) { + type = triton::nvgpu::MBarriveType::cp_async; + } else if (remoteCtaId) { + assert(txCount == 0 && + "remote arrive of transaction mbarrier is not implemented yet"); + type = triton::nvgpu::MBarriveType::remote; + } else if (txCount > 0) { + type = triton::nvgpu::MBarriveType::expect_tx; + } + Value pred = adaptor.getPred(); + if (pred == nullptr) { + pred = int_val(/*width*/ 1, 1); + } + rewriter.replaceOpWithNewOp( + op, mbarrier.getBase(), pred, remoteCtaId, type, txCount); + return success(); + } +}; + +struct NamedBarrierArriveOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::NamedBarrierArriveOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::NamedBarrierArriveOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getBar(), adaptor.getNumThreads()); + return success(); + } +}; + +struct NamedBarrierWaitOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::NamedBarrierWaitOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::NamedBarrierWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getBar(), adaptor.getNumThreads()); + return success(); + } +}; + struct FenceAsyncSharedOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -83,8 +154,18 @@ struct InitBarrierOpConversion typeConverter->convertType(op.getAlloc().getType().getElementType()), rewriter); + auto asyncTaskIds = getAsyncTaskIds(op); + int executingThreadId = 0; + if (!asyncTaskIds.empty()) { + assert(asyncTaskIds.size() == 1 && "only support single async task"); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + executingThreadId = asyncTaskIds[0] * numWarps * warpSize; + } + auto id = getThreadId(rewriter, loc); - auto pred = icmp_eq(id, i32_val(0)); + auto pred = icmp_eq(id, i32_val(executingThreadId)); ::mlir::triton::PTXBuilder ptxBuilder; const std::string ptx = "@$0 mbarrier.init.shared::cta.b64 [$1], " + std::to_string(op.getCount()) + ";"; @@ -112,8 +193,17 @@ struct InvalBarrierOpConversion typeConverter->convertType(op.getAlloc().getType().getElementType()), rewriter); + auto asyncTaskIds = getAsyncTaskIds(op); + int executingThreadId = 0; + if (!asyncTaskIds.empty()) { + assert(asyncTaskIds.size() == 1 && "only support single async task"); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + executingThreadId = asyncTaskIds[0] * numWarps * warpSize; + } auto id = getThreadId(rewriter, loc); - Value pred = icmp_eq(id, i32_val(0)); + Value pred = icmp_eq(id, i32_val(executingThreadId)); ::mlir::triton::PTXBuilder ptxBuilder; const std::string ptx = "@$0 mbarrier.inval.shared::cta.b64 [$1];"; auto &barSyncOp = *ptxBuilder.create<>(ptx); @@ -140,8 +230,17 @@ struct BarrierExpectConversion typeConverter->convertType(op.getAlloc().getType().getElementType()), rewriter); + auto asyncTaskIds = getAsyncTaskIds(op); + int executingThreadId = 0; + if (!asyncTaskIds.empty()) { + assert(asyncTaskIds.size() == 1 && "only support single async task"); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + executingThreadId = asyncTaskIds[0] * numWarps * warpSize; + } auto id = getThreadId(rewriter, loc); - Value pred = icmp_eq(id, i32_val(0)); + Value pred = icmp_eq(id, i32_val(executingThreadId)); pred = and_(pred, adaptor.getPred()); ::mlir::triton::PTXBuilder ptxBuilder; const std::string ptx = @@ -194,6 +293,9 @@ void mlir::triton::NVIDIA::populateBarrierOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt index 78aa1493d..6432ae305 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt @@ -10,13 +10,16 @@ add_triton_library(TritonNVIDIAGPUToLLVM LoadStoreOpToLLVM.cpp BarrierOpToLLVM.cpp TritonGPUToLLVM.cpp + TMAToLLVM.cpp DecomposeUnsupportedConversions.cpp SPMDOpToLLVM.cpp TensorPtrOpsToLLVM.cpp ClusterOpsToLLVM.cpp PTXAsmFormat.cpp Utility.cpp + UpcastMXFPToLLVM.cpp TargetInfo.cpp + RegReallocOpToLLVM.cpp DEPENDS TritonNVIDIAGPUConversionPassIncGen diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index f8ece0f1c..fbc5121d3 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -21,7 +21,6 @@ using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getShapePerCTATile; using ::mlir::triton::gpu::getSizePerThread; using ::mlir::triton::gpu::getTotalElemsPerThread; -using ::mlir::triton::gpu::isaDistributedLayout; using ::mlir::triton::gpu::SharedEncodingAttr; // Forward declarations @@ -140,74 +139,6 @@ struct LocalLoadOpConversion } }; -struct ConvertLayoutOpOptimizedConversion - : public ConvertOpToLLVMPattern { -public: - using ConvertOpToLLVMPattern< - triton::gpu::ConvertLayoutOp>::ConvertOpToLLVMPattern; - LogicalResult - matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - RankedTensorType srcTy = op.getSrc().getType(); - RankedTensorType dstTy = op.getType(); - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - // forwarding on mma->mma shortcut, lower distributed->distributed otherwise - if (isa(srcLayout) && - isa(dstLayout)) { - if (isMmaToMmaShortcut(srcTy, dstTy)) { - return lowerMmaToMma(op, adaptor, rewriter); - } - } - return failure(); - } - -private: - // mma -> mma - LogicalResult lowerMmaToMma(triton::gpu::ConvertLayoutOp op, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - RankedTensorType srcTy = op.getSrc().getType(); - RankedTensorType dstTy = op.getType(); - if (triton::gpu::getTotalElemsPerThread(srcTy) == - triton::gpu::getTotalElemsPerThread(dstTy)) { - rewriter.replaceOp(op, adaptor.getSrc()); - return success(); - } - auto dstMmaLayout = cast(dstTy.getEncoding()); - auto srcMmaLayout = cast(srcTy.getEncoding()); - assert(dstMmaLayout.isHopper() && srcMmaLayout.isHopper() && - "only MMAV3 layout is supported"); - auto dstShape = dstTy.getShape(); - auto shapePerCTA = getShapePerCTA(dstMmaLayout, dstShape); - ArrayRef dstInstrShape = dstMmaLayout.getInstrShape(); - ArrayRef srcInstrShape = srcMmaLayout.getInstrShape(); - SmallVector retVals; - unsigned numBlockM = - ceil(shapePerCTA[0], getShapePerCTATile(dstMmaLayout)[0]); - unsigned numBlockN = - ceil(shapePerCTA[1], getShapePerCTATile(dstMmaLayout)[1]); - // Remap the values based on MMAV3 layout, there may be duplicated values in - // either the source or destination. - auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - for (unsigned i = 0; i < numBlockM; i++) { - for (unsigned j = 0; j < numBlockN; j++) { - for (unsigned k = 0; k < dstInstrShape[1] / 2; k++) { - int index = i * numBlockN * (srcInstrShape[1] / 2) + j + - (k % (srcInstrShape[1] / 2)); - retVals.push_back(vals[index]); - } - } - } - assert(retVals.size() == triton::gpu::getTotalElemsPerThread(dstTy)); - Value view = - packLLElements(loc, getTypeConverter(), retVals, rewriter, dstTy); - rewriter.replaceOp(op, view); - return success(); - } -}; - struct ConvertLayoutOpConversion : public ConvertOpToLLVMPattern { public: @@ -224,11 +155,14 @@ struct ConvertLayoutOpConversion RankedTensorType dstTy = op.getType(); Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); - if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) { + if (isa( + srcLayout) && + isa( + dstLayout)) { if (shouldUseDistSmem(srcLayout, dstLayout)) - return lowerDistToDistWithDistSmem(op, adaptor, rewriter); + return lowerDistToDistWithDistSmem(op, adaptor, rewriter, targetInfo); if (isLayoutMmaV1(srcLayout) || isLayoutMmaV1(dstLayout)) - return lowerDistributedToDistributed(op, adaptor, rewriter); + return lowerDistributedToDistributed(op, adaptor, rewriter, targetInfo); } if (isa(srcLayout) && isa(dstLayout)) { @@ -430,7 +364,9 @@ struct ConvertLayoutOpConversion LogicalResult lowerDistToDistWithDistSmem(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { + MLIRContext *ctx = rewriter.getContext(); auto loc = op.getLoc(); auto typeConverter = getTypeConverter(); auto srcTy = op.getSrc().getType(); @@ -446,7 +382,7 @@ struct ConvertLayoutOpConversion auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); Value smemBase = - LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); smemBase = bitcast(smemBase, elemPtrTy); auto smemShape = convertType(srcShapePerCTA); @@ -495,7 +431,8 @@ struct ConvertLayoutOpConversion Value localOffset = linearize(rewriter, loc, localCoord, smemShape); Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, localOffset); - outVals.push_back(load_dsmem(ptr, remoteCTAId, llvmElemTy)); + outVals.push_back(targetInfo.loadDShared( + rewriter, loc, ptr, remoteCTAId, llvmElemTy, /*pred=*/true_val())); } Value result = @@ -515,7 +452,8 @@ struct ConvertLayoutOpConversion LogicalResult lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) const { auto loc = op.getLoc(); auto typeConverter = getTypeConverter(); RankedTensorType srcTy = op.getSrc().getType(); @@ -524,7 +462,7 @@ struct ConvertLayoutOpConversion Attribute dstLayout = dstTy.getEncoding(); Value smemBase = - LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation()); auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); smemBase = bitcast(smemBase, elemPtrTy); auto shape = dstTy.getShape(); @@ -554,10 +492,12 @@ struct ConvertLayoutOpConversion // Potentially we need to store for multiple CTAs in this replication auto accumNumReplicates = product(numReplicates); auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - unsigned inVec = 0; - unsigned outVec = 0; - auto origRepShape = getRepShapeForCvtLayout(op); - auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec); + auto scratchConfig = + getScratchConfigForCvt(op.getSrc().getType(), op.getType()); + unsigned inVec = scratchConfig.inVec; + unsigned outVec = scratchConfig.outVec; + const auto &origRepShape = scratchConfig.repShape; + const auto &paddedRepShape = scratchConfig.paddedRepShape; unsigned outElems = getTotalElemsPerThread(dstTy); auto outOrd = getOrder(dstLayout); @@ -567,7 +507,7 @@ struct ConvertLayoutOpConversion auto multiDimRepId = getMultiDimIndex(repId, numReplicates, outOrd); if (repId != 0) { - barrier(); + insertBarrier(rewriter, op); } if (isLayoutMmaV1(srcLayout)) @@ -579,7 +519,7 @@ struct ConvertLayoutOpConversion multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd, vals, smemBase); - barrier(); + insertBarrier(rewriter, op); if (isLayoutMmaV1(dstLayout)) processReplicaForMMAV1(loc, rewriter, /*stNotRd*/ false, dstTy, @@ -701,7 +641,6 @@ struct ConvertLayoutOpConversion // for the destination type, we need to pack values together // so they can be consumed by tensor core operations SmallVector vecVals; - SmallVector types; // For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer // instructions to pack & unpack sub-word integers. A workaround is to // store the results of ldmatrix in i32 @@ -715,37 +654,20 @@ struct ConvertLayoutOpConversion shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j)); val = or_(i32_ty, val, ext); } - vecVals.push_back(val); + vecVals.push_back(bitcast(val, i32_ty)); } - elems = elems / (32 / elemSize); - types = SmallVector(elems, i32_ty); } else { unsigned vecSize = std::max(32 / elemSize, 1); Type vecTy = vec_ty(elemTy, vecSize); - types = SmallVector(elems / vecSize, vecTy); for (unsigned i = 0; i < elems; i += vecSize) { Value packed = rewriter.create(loc, vecTy); for (unsigned j = 0; j < vecSize; j++) packed = insert_element(vecTy, packed, vals[i + j], i32_val(j)); - vecVals.push_back(packed); + vecVals.push_back(bitcast(packed, i32_ty)); } } - - // This needs to be ordered the same way that - // ldmatrix.x4 would order it - // TODO: this needs to be refactor so we don't - // implicitly depends on how emitOffsetsForMMAV2 - // is implemented - SmallVector reorderedVals; - for (unsigned i = 0; i < vecVals.size(); i += 4) { - reorderedVals.push_back(bitcast(vecVals[i], i32_ty)); - reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty)); - reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty)); - reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty)); - } - - Value view = packLLElements(loc, getTypeConverter(), reorderedVals, - rewriter, dstTy); + Value view = + packLLElements(loc, getTypeConverter(), vecVals, rewriter, dstTy); rewriter.replaceOp(op, view); return success(); } @@ -787,25 +709,59 @@ struct LocalAllocOpConversion else return failure(); + auto *ctx = rewriter.getContext(); Location loc = op->getLoc(); + RankedTensorType srcTy = op.getSrc().getType(); - Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, op); - auto srcs = unpackLLElements(loc, adaptor.getSrc(), rewriter); - SmallVector shape; - for (int64_t dim : srcTy.getShape()) - shape.push_back(dim); - bool loweredToStMatrix = targetInfo.processReplicaUsingStMatrix( - rewriter, loc, smemBase, srcs, srcTy, - getTypeConverter()->convertType(srcTy.getElementType()), shape, shape, - sharedLayout.getOrder(), 1, swizzleByteSize); - if (!loweredToStMatrix) + SmallVector shape = + convertType(srcTy.getShape()); + auto order = sharedLayout.getOrder(); + auto layout = chooseStMatrixLayout(rewriter.getContext(), srcTy, shape, + shape, order, swizzleByteSize); + if (!layout.has_value()) return failure(); + Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); + auto smemPtrTy = ptr_ty(ctx, 3); + + auto kRegister = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kBlock = str_attr("block"); + + Value threadId = getThreadId(rewriter, loc); + Value threadsPerWarp = i32_val(layout->getInDimSize(kLane)); + Value laneId = urem(threadId, threadsPerWarp); + Value warpId = udiv(threadId, threadsPerWarp); + + auto regBase = applyLinearLayout(loc, rewriter, *layout, + {{kRegister, i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, i32_val(0)}})[0] + .second; + auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto srcVec = layout->getNumConsecutiveInOut(); + Type llvmElemTy = typeConverter->convertType(srcTy.getElementType()); + for (int i = 0; i < srcVals.size(); i += srcVec) { + auto regIdx = + layout + ->apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0] + .second; + Value offset = xor_(regBase, i32_val(regIdx)); + auto vecAddr = gep(smemPtrTy, llvmElemTy, smemBase, offset); + vecAddr.setInbounds(true); + SmallVector inValsVec; + for (int j = 0; j < srcVec; j++) + inValsVec.push_back(srcVals[i + j]); + Value valsVec = packLLVector(loc, inValsVec, rewriter); + targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec); + } + auto resultTy = cast(op.getType()); // Workaround for 3D tensors // TODO: we need to modify the pipeline pass to give a proper shared // encoding to 3D tensors - auto order = sharedLayout.getOrder(); SmallVector newOrder; if (resultTy.getShape().size() != order.size()) { for (auto i = 0; i < order.size(); ++i) @@ -814,7 +770,6 @@ struct LocalAllocOpConversion } else { newOrder = SmallVector(order.begin(), order.end()); } - auto llvmElemTy = typeConverter->convertType(resultTy.getElementType()); auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape()); auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA, newOrder, loc, rewriter); @@ -832,7 +787,6 @@ struct LocalAllocOpConversion void mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMOptimizedPatterns( LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(typeConverter, benefit); patterns.add(typeConverter, targetInfo, benefit); } @@ -841,6 +795,9 @@ void mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { // For now give ConvertLayoutOpConversion higher benefit, I can split before // merging + // + // TODO(jlebar): lowerDistributedToDistributed does not get hit in any + // testcases. Is this dead code? Does the benefit need to be increased? patterns.add(typeConverter, targetInfo, benefit); // Same default benefit patterns.add(typeConverter, benefit); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp index 75abe1145..6847c0550 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp @@ -10,7 +10,6 @@ using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getSizePerThread; using ::mlir::triton::gpu::getTotalElemsPerThread; -using ::mlir::triton::gpu::isaDistributedLayout; using ::mlir::triton::gpu::SharedEncodingAttr; // Compute the offset of the matrix to load. @@ -150,8 +149,8 @@ static Value loadA(Value tensor, const SharedMemoryObject &smemObj, Type elemX2Ty = vec_ty(f16_ty, 2); Type elemTy = f16_ty; if (tensorTy.getElementType().isBF16()) { - elemX2Ty = vec_ty(i16_ty, 2); - elemTy = i16_ty; + elemX2Ty = vec_ty(bf16_ty, 2); + elemTy = bf16_ty; } // prepare arguments @@ -276,8 +275,8 @@ static Value loadB(Value tensor, const SharedMemoryObject &smemObj, Type elemTy = f16_ty; Type elemX2Ty = vec_ty(f16_ty, 2); if (tensorTy.getElementType().isBF16()) { - elemTy = i16_ty; - elemX2Ty = vec_ty(i16_ty, 2); + elemTy = bf16_ty; + elemX2Ty = vec_ty(bf16_ty, 2); } SmallVector ptrB(numPtrB); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index 6977e0597..21c2bee58 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -14,7 +14,6 @@ using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getSizePerThread; using ::mlir::triton::gpu::getTotalElemsPerThread; -using ::mlir::triton::gpu::isaDistributedLayout; using ::mlir::triton::gpu::SharedEncodingAttr; // Data loader for mma.16816 instruction. @@ -491,12 +490,17 @@ Type getSharedMemTy(Type argType) { if (argType.isF16()) return type::f16Ty(ctx); else if (argType.isBF16()) - return type::i16Ty(ctx); + return type::bf16Ty(ctx); else if (argType.isF32()) return type::f32Ty(ctx); else if (argType.getIntOrFloatBitWidth() == 8) return type::i8Ty(ctx); - else + else if (argType.isInteger(16) || argType.isInteger(32)) { + auto bitwidth = argType.getIntOrFloatBitWidth(); + auto signed_type = + argType.isSignedInteger() ? IntegerType::Signed : IntegerType::Unsigned; + return IntegerType::get(ctx, bitwidth, signed_type); + } else llvm::report_fatal_error("mma16816 data type not supported"); } @@ -509,8 +513,8 @@ Value composeValuesToDotOperandLayoutStruct( for (int m = 0; m < n0; ++m) for (int k = 0; k < n1; ++k) { elems.push_back(vals.at({b, 2 * m, 2 * k})); - elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); elems.push_back(vals.at({b, 2 * m + 1, 2 * k})); + elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1})); } assert(!elems.empty()); @@ -599,9 +603,9 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth; int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth; - auto numRep = - mmaLayout.getMMAv2Rep(shapePerCTA, bitwidth, encoding.getOpIdx()); int kWidth = encoding.getKWidth(); + auto numRep = mmaLayout.getMMAv2RepForOperand(shapePerCTA, bitwidth, kWidth, + encoding.getOpIdx()); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); auto order = triton::gpu::getOrder(mmaLayout); @@ -738,7 +742,8 @@ MemDescType getExpandedDesc(MemDescType descTy) { expandedShape[2] = shape[1]; auto encoding = descTy.getEncoding(); auto expandedEncoding = getExpandedEncoding(encoding); - auto expandedDesc = MemDescType::get(expandedShape, elTy, expandedEncoding); + auto expandedDesc = MemDescType::get(expandedShape, elTy, expandedEncoding, + descTy.getMemorySpace()); return expandedDesc; } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index 4407a50bd..cf0ddc248 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -34,8 +34,8 @@ class DecomposeLocalLoadToDotOperand auto dstDotOp = dyn_cast( op.getType().getEncoding()); - auto sharedEncoding = - cast(op.getSrc().getType().getEncoding()); + MemDescType srcType = op.getSrc().getType(); + auto sharedEncoding = cast(srcType.getEncoding()); if (!dstDotOp || !sharedEncoding.getHasLeadingOffset()) return failure(); RankedTensorType type = op.getType(); @@ -55,7 +55,8 @@ class DecomposeLocalLoadToDotOperand triton::gpu::SharedEncodingAttr::get( op.getContext(), dstDotOp, type.getShape(), triton::gpu::getOrder(parentEnc), - triton::gpu::getCTALayout(parentEnc), type.getElementType())); + triton::gpu::getCTALayout(parentEnc), type.getElementType()), + srcType.getMemorySpace()); auto tmp = rewriter.create( op.getLoc(), newSharedDescTy, load); auto newConvert = @@ -71,8 +72,8 @@ struct DecomposeUnsupportedConversions void runOnOperation() override { ModuleOp mod = getOperation(); triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod); - triton::gpu::decomposeTensorCoreToDotLayoutConversion< - triton::gpu::NvidiaMmaEncodingAttr>(mod, isMmaToDotShortcut); + triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod, + isMmaToDotShortcut); triton::gpu::decomposeBlockedToDotLayoutConversion(mod); mlir::RewritePatternSet patterns(&getContext()); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp index 374b9ec9e..3e915a577 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp @@ -23,15 +23,10 @@ LogicalResult convertMMA16816(triton::DotOp op, triton::DotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter); -LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, +LogicalResult convertWGMMA(triton::nvidia_gpu::WarpGroupDotOp op, + triton::nvidia_gpu::WarpGroupDotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Value thread); - -LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op, - triton::nvidia_gpu::DotAsyncOp::Adaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, - Value thread); namespace { struct DotOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -59,9 +54,6 @@ struct DotOpConversion : public ConvertOpToLLVMPattern { return convertMMA1688(op, adaptor, getTypeConverter(), rewriter); if (mmaLayout.isAmpere()) return convertMMA16816(op, adaptor, getTypeConverter(), rewriter); - if (mmaLayout.isHopper()) - return convertWGMMA(op, adaptor, getTypeConverter(), rewriter, - getThreadId(rewriter, loc)); llvm::report_fatal_error( "Unsupported MMA kind found when converting DotOp to LLVM."); @@ -76,13 +68,13 @@ struct DotOpConversion : public ConvertOpToLLVMPattern { } }; -struct DotAsyncOpConversion - : public ConvertOpToLLVMPattern { +struct WarpGroupDotOpConversion + : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< - triton::nvidia_gpu::DotAsyncOp>::ConvertOpToLLVMPattern; + triton::nvidia_gpu::WarpGroupDotOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(triton::nvidia_gpu::DotAsyncOp op, OpAdaptor adaptor, + matchAndRewrite(triton::nvidia_gpu::WarpGroupDotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); // D = A * B + C @@ -100,26 +92,26 @@ struct DotAsyncOpConversion if (!isOuter && mmaLayout && supportMMA(op.getOperand(0), mmaLayout.getVersionMajor())) { if (mmaLayout.isHopper()) { - return convertAsyncWGMMA(op, adaptor, getTypeConverter(), rewriter, - getThreadId(rewriter, loc)); + return convertWGMMA(op, adaptor, getTypeConverter(), rewriter, + getThreadId(rewriter, loc)); } llvm::report_fatal_error( - "Unsupported MMA kind found when converting DotAsyncOp to LLVM."); + "Unsupported MMA kind found when converting WarpGroupDotOp to LLVM."); } llvm::report_fatal_error( - "Unsupported DotAsyncOp found when converting TritonGPU to LLVM."); + "Unsupported WarpGroupDotOp found when converting TritonGPU to LLVM."); } }; -struct DotWaitOpConversion - : public ConvertOpToLLVMPattern { +struct WarpGroupDotWaitOpConversion + : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< - triton::nvidia_gpu::DotWaitOp>::ConvertOpToLLVMPattern; + triton::nvidia_gpu::WarpGroupDotWaitOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor, + matchAndRewrite(triton::nvidia_gpu::WarpGroupDotWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto pendings = op.getPendings(); Location loc = op.getLoc(); @@ -180,6 +172,6 @@ void mlir::triton::NVIDIA::populateDotOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index 2d16dc19b..c2940a043 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -1,6 +1,9 @@ #include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" #include "Utility.h" #include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "llvm/ADT/SmallVector.h" using namespace mlir; using namespace mlir::triton; @@ -58,16 +61,89 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct( const LLVMTypeConverter *typeConverter, Location loc, ConversionPatternRewriter &rewriter, Value value, int batch, int n0, int n1, RankedTensorType type) { - auto elems = unpackLLElements(loc, value, rewriter); int offset{}; ValueTableV2 vals; + + // FIXME [Dot LL] + // [ez] Generalize the logic below for kWidth * elemBitWidth > 32 + auto dot = cast(type.getEncoding()); + auto largeK = dot.getKWidth() == 8 && + cast(dot.getParent()).isAmpere(); + if (largeK) { + llvm::SmallVector si; + + // For kWidth = 8, split the mma into 4 mmas with "stride 4" along K + if (dot.getOpIdx() == 0) { + // Original register layout: + // + // [0, 1, 2, 3], [8, 9, 10, 11] + // [4, 5, 6, 7], [12, 13, 14, 15] + // + // Each element in the layout consists of two bf16 values. + // For example, the row [0, 1, 2, 3] expands to: + // + // [[0/0, 0/1], [1/0, 1/1], [2/0, 2/1], [3/0, 3/1]] + // + // Here, 0/0 refers to the first half of element 0, and 0/1 refers to the + // second half, matching kWidth = 8. + // + // To derive four independent MMA operations, a stride of 4 is applied to + // the original register layout: + // + // 1st MMA: [0, 4, 8, 12] + // 2nd MMA: [1, 5, 9, 13] + // 3rd MMA: [2, 6, 10, 14] + // 4th MMA: [3, 7, 11, 15] + si = llvm::SmallVector{0, 4, 8, 12, 1, 5, 9, 13, + 2, 6, 10, 14, 3, 7, 11, 15}; + } else { + // Original register layout: + // + // [0, 1, 2, 3]^T, [4, 5, 6, 7]^T + // + // A stride of 4 is applied to derive four independent MMA operations: + // + // 1st MMA: [0, 4] + // 2nd MMA: [1, 5] + // 3rd MMA: [2, 6] + // 4th MMA: [3, 7] + si = llvm::SmallVector{0, 4, 1, 5, 2, 6, 3, 7}; + } + + auto step = si.size(); + SmallVector perm(step); + for (auto i = 0; i < elems.size() / step; ++i) { + for (auto j = 0; j < step; ++j) { + perm[j] = elems[i * step + si[j]]; + } + std::copy(perm.begin(), perm.end(), elems.begin() + i * step); + } + + if (dot.getOpIdx() == 1) { + // there are kWidth * 2 elems packed as bf16x2 + int elemsInTile = dot.getKWidth(); + // n0 and n1 are unrolled in the legacy path + // Unrolling n1 makes some sense, but unrolling n0 makes absolutely no + // sense IMO + n0 *= 2; + n1 *= 2; + for (auto b = 0; b < batch; ++b) + for (auto j = 0; j < n1 / elemsInTile; ++j) + for (auto i = 0; i < n0; ++i) + for (auto k = 0; k < elemsInTile; ++k) { + vals[{b, i, elemsInTile * j + k}] = elems[offset++]; + } + return vals; + } + } + for (auto b = 0; b < batch; ++b) for (auto i = 0; i < n0; ++i) { for (auto j = 0; j < n1; j++) { vals[{b, 2 * i, 2 * j}] = elems[offset++]; - vals[{b, 2 * i, 2 * j + 1}] = elems[offset++]; vals[{b, 2 * i + 1, 2 * j}] = elems[offset++]; + vals[{b, 2 * i, 2 * j + 1}] = elems[offset++]; vals[{b, 2 * i + 1, 2 * j + 1}] = elems[offset++]; } } @@ -81,9 +157,9 @@ enum class TensorCoreType : uint8_t { FP32_TF32_TF32_FP32, FP16_FP16_FP16_FP16, FP32_FP8E5M2_FP8E5M2_FP32, - FP32_FP8E5M2_FP8E4M3FNUZ_FP32, - FP32_FP8E4M3FNUZ_FP8E5M2_FP32, - FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32, + FP32_FP8E5M2_FP8E4M3FN_FP32, + FP32_FP8E4M3FN_FP8E5M2_FP32, + FP32_FP8E4M3FN_FP8E4M3FN_FP32, // integer tensor core instr INT32_INT1_INT1_INT32, // Not implemented INT32_INT4_INT4_INT32, // Not implemented @@ -112,9 +188,9 @@ Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) { case TensorCoreType::FP16_FP16_FP16_FP16: return fp16x2Pack2Ty; case TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32: - case TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32: - case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32: - case TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32: + case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32: + case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32: + case TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32: return fp32x4Ty; case TensorCoreType::INT32_INT8_INT8_INT32: return i32x4Ty; @@ -140,14 +216,14 @@ TensorCoreType getMmaType(triton::DotOp op) { bTy.getElementType().isFloat8E5M2()) return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32; if (aTy.getElementType().isFloat8E5M2() && - bTy.getElementType().isFloat8E4M3FNUZ()) - return TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32; - if (aTy.getElementType().isFloat8E4M3FNUZ() && + bTy.getElementType().isFloat8E4M3FN()) + return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32; + if (aTy.getElementType().isFloat8E4M3FN() && bTy.getElementType().isFloat8E5M2()) - return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32; - if (aTy.getElementType().isFloat8E4M3FNUZ() && - bTy.getElementType().isFloat8E4M3FNUZ()) - return TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32; + return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32; + if (aTy.getElementType().isFloat8E4M3FN() && + bTy.getElementType().isFloat8E4M3FN()) + return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32; if (aTy.getElementType().isF32() && bTy.getElementType().isF32() && op.getInputPrecision() == InputPrecision::TF32) return TensorCoreType::FP32_TF32_TF32_FP32; @@ -193,11 +269,11 @@ inline static const std::map mmaInstrPtxAmpere = { {TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32, "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32"}, - {TensorCoreType::FP32_FP8E5M2_FP8E4M3FNUZ_FP32, + {TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32, "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32"}, - {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E5M2_FP32, + {TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32, "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32"}, - {TensorCoreType::FP32_FP8E4M3FNUZ_FP8E4M3FNUZ_FP32, + {TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32, "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32"}, }; @@ -318,19 +394,25 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth(); auto dotOpA = cast(aTensorTy.getEncoding()); auto repA = cast(dotOpA.getParent()) - .getMMAv2Rep(aShapePerCTA, bitwidth, dotOpA.getOpIdx()); + .getMMAv2RepForOperand(aShapePerCTA, bitwidth, + dotOpA.getKWidth(), dotOpA.getOpIdx()); auto dotOpB = cast(bTensorTy.getEncoding()); auto repB = cast(dotOpB.getParent()) - .getMMAv2Rep(bShapePerCTA, bitwidth, dotOpB.getOpIdx()); + .getMMAv2RepForOperand(bShapePerCTA, bitwidth, + dotOpB.getKWidth(), dotOpB.getOpIdx()); assert(repA[2] == repB[1]); assert(repA[0] == repB[0]); int repM = repA[1], repN = repB[2], repK = repA[2]; int repBatch = repA[0]; - // shape / shape_per_cta auto ha = getValuesFromDotOperandLayoutStruct( typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy); + + // FIXME [Dot LL] + // max(repN / 2, 1) is wrong for repN = 1! + // This is also wrong in + // NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand auto hb = getValuesFromDotOperandLayoutStruct( typeConverter, loc, rewriter, loadedB, repBatch, std::max(repN / 2, 1), repK, bTensorTy); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 738f0fe04..1bb55373e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -58,7 +58,7 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) { return triton::nvgpu::WGMMAEltType::s8; } else if (aTy.isFloat8E5M2()) { return triton::nvgpu::WGMMAEltType::e5m2; - } else if (aTy.isFloat8E4M3FNUZ()) { + } else if (aTy.isFloat8E4M3FN()) { return triton::nvgpu::WGMMAEltType::e4m3; } else { llvm::report_fatal_error("Unsupported mma operand type found"); @@ -316,11 +316,6 @@ SmallVector unpackAccumulator(ConversionPatternRewriter &rewriter, return results; } -static bool isFP8(triton::nvgpu::WGMMAEltType eltType) { - return eltType == triton::nvgpu::WGMMAEltType::e5m2 || - eltType == triton::nvgpu::WGMMAEltType::e4m3; -} - static Value faddAccumulate(ConversionPatternRewriter &rewriter, Location loc, Value a, Value b) { int numEl = cast(a.getType()).getBody().size(); @@ -357,9 +352,10 @@ static SmallVector emitWait(ConversionPatternRewriter &rewriter, LogicalResult convertDot(const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, Operation *op, Value a, Value b, Value c, Value d, - Value loadedA, Value loadedB, Value loadedC, - bool allowTF32, uint32_t maxNumImpreciseAcc, bool sync, - Value thread) { + Value useCOperand, Value loadedA, Value loadedB, + Value loadedC, bool allowTF32, + bool needsPartialAccumulator, + uint32_t maxNumImpreciseAcc, bool sync, Value thread) { auto aTensorTy = cast(a.getType()); auto bTensorTy = cast(b.getType()); auto dTensorTy = cast(d.getType()); @@ -420,10 +416,6 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, auto func = op->getParentOfType(); Operation *startSequence = rewriter.create(loc); - // WGMMA fp8 -> fp32 accumulates in lower precision than fp32. - bool needsPartialAccumulator = isFP8(eltTypeA) && - eltTypeC == triton::nvgpu::WGMMAEltType::f32 && - maxNumImpreciseAcc <= aTensorTy.getShape()[1]; SmallVector mmaResults; for (int m = 0; m < numRepM; ++m) { for (int n = 0; n < numRepN; ++n) { @@ -436,8 +428,13 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, auto accTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); Value d; - if (!zeroAcc) + Value useC = i1_val(0); + if (!zeroAcc) { d = packLLElements(loc, typeConverter, mmaOut, rewriter, accTy); + useC = i1_val(1); + } + if (useCOperand) + useC = and_(useC, useCOperand); uint32_t numLowPrecisionAcc = 0; Value partialAcc; for (int k = 0; k < numRepK; ++k) { @@ -463,8 +460,9 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, (numLowPrecisionAcc >= maxNumImpreciseAcc || k == numRepK - 1); Value mmaAcc = needsPartialAccumulator ? partialAcc : d; mmaAcc = rewriter.create( - loc, accTy, a, b, mmaAcc, M, N, K, eltTypeC, eltTypeA, eltTypeB, - layoutA, layoutB); + loc, accTy, a, b, useC, mmaAcc, M, N, K, eltTypeC, eltTypeA, + eltTypeB, layoutA, layoutB); + useC = i1_val(1); if (needsPartialAccumulator) partialAcc = mmaAcc; else @@ -500,35 +498,20 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, return success(); } -LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, +LogicalResult convertWGMMA(triton::nvidia_gpu::WarpGroupDotOp op, + triton::nvidia_gpu::WarpGroupDotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Value thread) { auto AEnc = op.getA().getType().getEncoding(); auto BEnc = op.getB().getType().getEncoding(); - assert((mlir::isa(AEnc))); - assert(mlir::isa(BEnc) && - "Operand B should use Shared layout."); - return convertDot(typeConverter, rewriter, op.getLoc(), op.getOperation(), // - op.getA(), op.getB(), op.getC(), op.getD(), // - adaptor.getA(), adaptor.getB(), adaptor.getC(), // - op.getInputPrecision() == InputPrecision::TF32, - op.getMaxNumImpreciseAcc(), true, thread); -} - -LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op, - triton::nvidia_gpu::DotAsyncOp::Adaptor adaptor, - const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter, - Value thread) { - auto AEnc = op.getA().getType().getEncoding(); - auto BEnc = op.getB().getType().getEncoding(); assert(mlir::isa(AEnc) || mlir::isa(AEnc)); assert(mlir::isa(BEnc) && "Operand B should use Shared layout."); - return convertDot(typeConverter, rewriter, op.getLoc(), op.getOperation(), // - op.getA(), op.getB(), op.getC(), op.getD(), // - adaptor.getA(), adaptor.getB(), adaptor.getC(), + return convertDot(typeConverter, rewriter, op.getLoc(), op.getOperation(), // + op.getA(), op.getB(), op.getC(), op.getD(), op.getUseC(), // + adaptor.getA(), adaptor.getB(), adaptor.getC(), // op.getInputPrecision() == InputPrecision::TF32, - op.getMaxNumImpreciseAcc(), false, thread); + op.needsPartialAccumulator(), op.getMaxNumImpreciseAcc(), + !op.getIsAsync(), thread); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index a87fc936d..ef69b96fc 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -5,6 +5,7 @@ #include "mlir/Support/LLVM.h" #include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" using namespace mlir::triton::gpu; @@ -234,6 +235,21 @@ static const std::string S8_to_Bf16 = "prmt.b32 $0, f0, f1, 0x7632; \n" // f32->bf16 + pack "prmt.b32 $1, f2, f3, 0x7632; \n" // "}"; +// Conversions have low throughput, rely on bit tricks instead of cvt +// instruction on Hopper and later GPUs. +static const std::string S8_to_Bf16_sm90 = + "{ \n" + ".reg .b32 l<3>; \n" + ".reg .b32 h<3>; \n" + "prmt.b32 l0, $2, 0x43, 0x4140; \n" // Unpack to shifted bf16. + "prmt.b32 h0, $2, 0x43, 0x4342; \n" + "and.b32 l1, l0, 0xff7fff7f; \n" // Zero the least exp bit. + "and.b32 h1, h0, 0xff7fff7f; \n" + "and.b32 l2, l0, 0xff80ff80; \n" // Zero the mantissa. + "and.b32 h2, h0, 0xff80ff80; \n" + "sub.bf16x2 $0, l1, l2; \n" // Subtract the offset. + "sub.bf16x2 $1, h1, h2; \n" + "}"; typedef std::function(Location, ConversionPatternRewriter &, const SmallVector &)> @@ -355,7 +371,7 @@ struct FpToFpOpConversion cvt(res, operand); // TODO: This is a hack to get the right type. We should be able to invoke // the type converter - return builder.launch(rewriter, loc, i16_ty, false); + return builder.launch(rewriter, loc, bf16_ty, false); } static Value convertFp32ToFp16(Location loc, @@ -386,7 +402,7 @@ struct FpToFpOpConversion std::pair getConversionFunc(Type srcTy, Type dstTy, std::optional roundingMode) const { - auto F8E4M3TyID = TypeID::get(); + auto F8E4M3TyID = TypeID::get(); auto F8E5M2TyID = TypeID::get(); auto F16TyID = TypeID::get(); auto BF16TyID = TypeID::get(); @@ -430,7 +446,7 @@ struct FpToFpOpConversion llvm::report_fatal_error("Unsupported rounding mode for conversion."); } if (computeCapability < 89 && - (srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) { + (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) { llvm::errs() << "Conversion from/to f8e4m3nv is only supported on " "compute capability >= 89" << "\n"; @@ -452,7 +468,7 @@ struct FpToFpOpConversion auto dstElementType = getElementType(op.getResult()); auto roundingMode = op.getRounding(); - if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FNUZ()) { + if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) { assert(roundingMode.has_value() && "Rounding mode must be specified for convertsions to fp8"); @@ -489,7 +505,7 @@ struct FpToFpOpConversion bool useFP16IntermediateSrc = srcElementType.isF32() && - (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FNUZ() || + (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() || dstElementType.isFloat8E5M2())) || roundingMode.value() == RoundingMode::RTZ); bool isDstFP32 = dstElementType.isF32(); @@ -574,7 +590,7 @@ struct FMulOpConversion auto lhs = builder.newOperand(operands[0][0], "h"); auto rhs = builder.newOperand(operands[0][1], "h"); fMul({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true); - return {builder.launch(rewriter, loc, i16_ty, false)}; + return {builder.launch(rewriter, loc, bf16_ty, false)}; } else { return {rewriter.create(loc, elemTy, operands[0][0], operands[0][1])}; @@ -604,7 +620,7 @@ struct FAddOpConversion auto lhs = builder.newOperand(operands[0][0], "h"); auto rhs = builder.newOperand(operands[0][1], "h"); fAdd({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true); - return {builder.launch(rewriter, loc, i16_ty, false)}; + return {builder.launch(rewriter, loc, bf16_ty, false)}; } else { return {rewriter.create(loc, elemTy, operands[0][0], operands[0][1])}; @@ -634,7 +650,7 @@ struct FSubOpConversion auto lhs = builder.newOperand(operands[0][0], "h"); auto rhs = builder.newOperand(operands[0][1], "h"); fSub({res, lhs, rhs}, /*onlyAttachMLIRArgs=*/true); - return {builder.launch(rewriter, loc, i16_ty, false)}; + return {builder.launch(rewriter, loc, bf16_ty, false)}; } else { return {rewriter.create(loc, elemTy, operands[0][0], operands[0][1])}; @@ -646,9 +662,15 @@ struct FSubOpConversion struct SIToFPOpConversion : ElementwiseOpConversionBase { using Base = ElementwiseOpConversionBase; - using Base::Base; using Adaptor = typename Base::OpAdaptor; + explicit SIToFPOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + int computeCapability, + PatternBenefit benefit = patternBenefitDefault) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + computeCapability(computeCapability) {} + SmallVector createDestOps(arith::SIToFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, @@ -657,21 +679,21 @@ struct SIToFPOpConversion Type outElemTy = getElementType(op.getOut()); if (outElemTy.isBF16() && inElemTy.isInteger(8) && operands.size() >= 4) { auto cvtFunc = makeConverterFromPtx( - S8_to_Bf16, getTypeConverter()->convertType(inElemTy), + computeCapability >= 90 ? S8_to_Bf16_sm90 : S8_to_Bf16, + getTypeConverter()->convertType(inElemTy), getTypeConverter()->convertType(outElemTy)); SmallVector inVals = {operands[0][0], operands[1][0], operands[2][0], operands[3][0]}; auto outVals = cvtFunc(loc, rewriter, inVals); assert(outVals.size() == 4); return outVals; - } else if (outElemTy.isBF16()) { - auto value = rewriter.create(loc, f32_ty, operands[0][0]); - return {FpToFpOpConversion::convertFp32ToBf16(loc, rewriter, value, - RoundingMode::RTNE)}; } else { return {rewriter.create(loc, elemTy, operands[0][0])}; } } + +private: + int computeCapability; }; struct FPToSIOpConversion @@ -685,13 +707,7 @@ struct FPToSIOpConversion Type elemTy, MultipleOperandsRange operands, Location loc) const { auto inElemTy = getElementType(op.getIn()); - if (inElemTy.isBF16()) { - auto value = - FpToFpOpConversion::convertBf16ToFp32(loc, rewriter, operands[0][0]); - return {rewriter.create(loc, elemTy, value)}; - } else { - return {rewriter.create(loc, elemTy, operands[0][0])}; - } + return {rewriter.create(loc, elemTy, operands[0][0])}; } }; @@ -897,7 +913,7 @@ struct OpToExternCallConversion LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp(rewriter, op, funcName, funcType); return { - rewriter.create(loc, funcOp, operands[0]).getResult()}; + LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()}; } private: @@ -930,8 +946,9 @@ void mlir::triton::NVIDIA::populateElementwiseOpToLLVMPatterns( patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); - patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, + computeCapability, benefit); patterns.add(typeConverter, axisInfoAnalysis, computeCapability, benefit); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 6b566a967..65746d013 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -6,7 +6,10 @@ #include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" #include "Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" using namespace mlir; using namespace mlir::triton; @@ -96,6 +99,23 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, return mask; } +std::string getRegisterSizeCode(int size, bool is_float) { + switch (size) { + case 1: + return "b"; + case 16: + return "h"; + case 32: + return is_float ? "f" : "r"; + case 64: + return is_float ? "d" : "l"; + case 128: + return "q"; + default: + llvm_unreachable("Unsupported register size"); + } +} + // Contains some helper functions for both Load and Store conversions. struct LoadStoreConversionBase { explicit LoadStoreConversionBase(const NVIDIA::TargetInfo &targetInfo, @@ -164,6 +184,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, typeConverter->convertType(getElementTypeOrSelf(op.getType())); unsigned vec = getVectorSize(ptr); unsigned numElems = getTotalElemsPerThread(ptr.getType()); + unsigned vecOrig = vec; if (llMask) { LLVM_DEBUG(DBGS() << "vec = " << vec << " mask_alignment = " << getMaskAlignment(mask)); @@ -171,6 +192,13 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, LLVM_DEBUG(llvm::dbgs() << " vec = " << vec << '\n'); } + if (vec == 1 && numElems > 1) { + int maskValue = !llMask ? -1 : getMaskAlignment(mask); + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " numElems = " << numElems << " mask is " << maskValue + << "\n"; + } // Get the LLVM values for pointers auto ptrElems = unpackLLElements(loc, llPtr, rewriter); assert(ptrElems.size() == numElems); @@ -376,6 +404,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, assert(ptrElems.size() == valueElems.size()); // Determine the vectorization size + unsigned vecOrig = vec; SmallVector maskElems; if (llMask) { Value mask = op.getMask(); @@ -386,6 +415,14 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, vec = std::min(vec, maskAlign); } + if (vec == 1 && elemsPerThread > 1) { + int mask = !llMask ? -1 : getMaskAlignment(op.getMask()); + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " elemsPerThread = " << elemsPerThread << " mask is " + << mask << "\n"; + } + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); const size_t dtsize = std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); @@ -467,10 +504,11 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, } }; -void createBarrier(ConversionPatternRewriter &rewriter, Location loc, +void createBarrier(ConversionPatternRewriter &rewriter, Operation *op, int numCTAs) { + auto loc = op->getLoc(); if (numCTAs == 1) { - barrier(); + insertBarrier(rewriter, op); } else { rewriter.create(loc, false); rewriter.create(loc); @@ -514,12 +552,18 @@ struct AtomicCASOpConversion auto elemsPerThread = getTotalElemsPerThread(op.getVal().getType()); // vec = 1 for scalar auto vec = getVectorSize(op.getPtr()); + auto vecOrig = vec; // tensor if (tensorTy) { auto valTy = cast(op.getVal().getType()); vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); } + if (vec == 1 && elemsPerThread > 1) + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " origin vec = " << vecOrig + << " elemsPerThread = " << elemsPerThread << "\n"; + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); @@ -561,9 +605,12 @@ struct AtomicCASOpConversion } } else { auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy); - createBarrier(rewriter, loc, numCTAs); - Value atomPtr = - LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + if (!atomicNeedsSharedMemory(op.getResult())) { + rewriter.eraseOp(op); + return success(); + } + Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, + op.getOperation()); atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3)); // Only threads with mask = True store the result PTXBuilder ptxBuilderStore; @@ -574,9 +621,8 @@ struct AtomicCASOpConversion st(dstOprStore, valOprStore).predicate(mask); auto ASMReturnTy = void_ty(ctx); ptxBuilderStore.launch(rewriter, loc, ASMReturnTy); - createBarrier(rewriter, loc, numCTAs); + createBarrier(rewriter, op, numCTAs); Value ret = load(valueElemTy, atomPtr); - createBarrier(rewriter, loc, numCTAs); rewriter.replaceOp(op, {ret}); } } @@ -601,6 +647,18 @@ struct AtomicRMWOpConversion : ConvertOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + bool supportsVectorized(RMWOp opType, Type elementType) const { + // vectorized atomics are only supported on hopper, + // and only for specific atomic ops (add, min, max). + // Note that "packed types" like f16x2 are supported sm60+. + if (!targetInfo.supportVectorizedAtomics()) { + return false; + } + + return opType == RMWOp::FADD && + (elementType.isF16() || elementType.isBF16() || elementType.isF32()); + } + LogicalResult matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -633,38 +691,81 @@ struct AtomicRMWOpConversion : valueTy; const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(val.getType()); - // vec = 1, numElements = 1 for scalar - auto vec = getVectorSize(ptr); - int numElems = 1; - // tensor + // packed: e.g. packed=2 for f16x2 + // vec: e.g. .v2, .v4, .v8 version of atom instruction. + unsigned vec, vecOrig; + int numElems, packed; if (tensorTy) { + vec = getVectorSize(ptr); + if (llMask) { + vec = std::min(vec, getMaskAlignment(op.getMask())); + } + vecOrig = vec; + packed = 1; auto valTy = cast(val.getType()); - vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); - // mask + if (!supportsVectorized(atomicRmwAttr, valTy.getElementType())) { + packed = + std::min(vecOrig, valTy.getElementType().isF16() ? 2 : 1); + vec = 1; + } numElems = tensorTy.getNumElements(); + } else { + // scalar + vec = 1; + vecOrig = 1; + numElems = 1; + packed = 1; } + assert((packed == 1 || vec == 1) && "packed or vec must be 1"); + + if (vec * packed == 1 && numElems > 1) + op->emitRemark() << "Warning: vectorization fails vec = " << vec + << " packed = " << packed << " origin vec = " << vecOrig + << " numElems = " << numElems; + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); - auto vecTy = vec_ty(valueElemTy, vec); + auto packedTy = vec_ty(valueElemTy, packed); SmallVector resultVals(elemsPerThread); - for (size_t i = 0; i < elemsPerThread; i += vec) { - Value rmwVal = undef(vecTy); - for (int ii = 0; ii < vec; ++ii) { - Value iiVal = createIndexAttrConstant( - rewriter, loc, getTypeConverter()->getIndexType(), ii); - rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal); - } - + for (size_t i = 0; i < elemsPerThread; i += vec * packed) { Value rmwPtr = ptrElements[i]; Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; std::string sTy; PTXBuilder ptxBuilderAtomicRMW; - std::string tyId = valueElemNBits * vec == 64 - ? "l" - : (valueElemNBits * vec == 32 ? "r" : "h"); - auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true); + // 16-bit -> "h", 32-bit -> "r", 64-bit -> "l" + std::string tyId = + getRegisterSizeCode(valueElemNBits * packed, /*is_float=*/false); + + PTXBuilder::Operand *dstOpr; + if (vec > 1) { + dstOpr = ptxBuilderAtomicRMW.newListOperand(); + for (unsigned ii = 0; ii < vec; ++ii) { + dstOpr->listAppend( + ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true)); + } + } else { + dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true); + } + auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l"); - auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId); + + PTXBuilder::Operand *valOpr; + if (vec > 1) { + valOpr = ptxBuilderAtomicRMW.newListOperand(); + for (unsigned ii = 0; ii < vec; ++ii) { + valOpr->listAppend( + ptxBuilderAtomicRMW.newOperand(valElements[i + ii], tyId)); + } + } else if (packed > 1) { + Value rmwVal = undef(packedTy); + for (int ii = 0; ii < packed; ++ii) { + rmwVal = insert_element(packedTy, rmwVal, valElements[i + ii], + i32_val(ii)); + } + valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId); + } else { + valOpr = ptxBuilderAtomicRMW.newOperand(valElements[i], tyId); + } auto scope = stringifyMemSyncScope(op.getScope()).str(); auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o(scope); @@ -687,7 +788,7 @@ struct AtomicRMWOpConversion rmwOp = "add"; rmwOp += (valueElemNBits == 16 ? ".noftz" : ""); sTy = "f" + sBits; - sTy += (vec == 2 && valueElemNBits == 16) ? "x2" : ""; + sTy += (packed == 2 && valueElemNBits == 16) ? "x2" : ""; break; case RMWOp::MAX: sTy = "s" + sBits; @@ -712,25 +813,43 @@ struct AtomicRMWOpConversion std::string semStr; llvm::raw_string_ostream os(semStr); os << op.getSem(); - atom.o(semStr).o(rmwOp).o(sTy); + atom.o(semStr).o(rmwOp).v(vec).o(sTy); if (tensorTy) { atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask); - auto retType = vec == 1 ? valueElemTy : vecTy; + Type retType; + if (vec > 1) { + SmallVector retTys(vec, valueElemTy); + retType = struct_ty(retTys); + } else if (packed > 1) { + retType = packedTy; + } else { + retType = valueElemTy; + } + auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType); - for (int ii = 0; ii < vec; ++ii) { - resultVals[i + ii] = - vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii)); + + if (vec > 1) { + for (unsigned ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = extract_val(valueElemTy, ret, ii); + } + } else if (packed > 1) { + for (unsigned ii = 0; ii < packed; ++ii) { + resultVals[i + ii] = extract_element(valueElemTy, ret, i32_val(ii)); + } + } else { + resultVals[i] = ret; } + } else { auto ASMReturnTy = void_ty(ctx); atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask); auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy); - if (op->user_begin() == op->user_end()) { - rewriter.replaceOp(op, {old}); + if (!atomicNeedsSharedMemory(op.getResult())) { + rewriter.eraseOp(op); return success(); } - Value atomPtr = - LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + Value atomPtr = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, + op.getOperation()); atomPtr = bitcast(atomPtr, ptr_ty(ctx, 3)); // Only threads with rmwMask = True store the result PTXBuilder ptxBuilderStore; @@ -740,9 +859,8 @@ struct AtomicRMWOpConversion auto *valOpr = ptxBuilderStore.newOperand(old, tyId); storeShared(ptrOpr, valOpr).predicate(rmwMask); ptxBuilderStore.launch(rewriter, loc, void_ty(ctx)); - createBarrier(rewriter, loc, numCTAs); + createBarrier(rewriter, op, numCTAs); Value ret = load(valueElemTy, atomPtr); - createBarrier(rewriter, loc, numCTAs); rewriter.replaceOp(op, {ret}); } } @@ -807,70 +925,73 @@ struct AsyncCopyGlobalToLocalOpConversion // %other SmallVector otherElems; if (llOther) { - // FIXME(Keren): always assume other is 0 for now + // FIXME(Keren): assume other is 0 for now. + // // It's not necessary for now because the pipeline pass will skip // generating insert_slice_async if the load op has any "other" tensor. - // assert(false && "insert_slice_async: Other value not supported yet"); otherElems = unpackLLElements(loc, llOther, rewriter); assert(srcElems.size() == otherElems.size()); } - // We don't use getVec() here because we are copying from memory to memory. - // If contiguity > vector size, we can have one pointer maintaining the - // start of the vector and the other pointer moving to the next vector. - unsigned inVec = getContiguity(op.getSrc()); - unsigned outVec = resSharedLayout.getVec(); - unsigned minVec = inVec; - if (outVec > 1) - minVec = std::min(outVec, inVec); - unsigned numElems = getTotalElemsPerThread(srcTy); - unsigned perPhase = resSharedLayout.getPerPhase(); - unsigned maxPhase = resSharedLayout.getMaxPhase(); - SmallVector offsetVals = {smemObj.strides.size(), i32_val(0)}; - DenseMap sharedPtrs = getSwizzledSharedPtrs( - loc, targetInfo, inVec, srcTy, resSharedLayout, resElemTy, smemObj, - rewriter, offsetVals, smemObj.strides); - - // A sharedLayout encoding has a "vec" parameter. - // On the column dimension, if inVec > outVec, it means we have to divide - // single vector read into multiple ones - auto numVecCols = std::max(inVec / outVec, 1); - - for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) { - // 16 * 8 = 128bits - auto maxBitWidth = - std::max(128, resElemTy.getIntOrFloatBitWidth()); - auto vecBitWidth = resElemTy.getIntOrFloatBitWidth() * minVec; - auto bitWidth = std::min(maxBitWidth, vecBitWidth); - auto numWords = vecBitWidth / bitWidth; - auto numWordElems = bitWidth / resElemTy.getIntOrFloatBitWidth(); - - // Tune CG and CA here. - auto byteWidth = bitWidth / 8; - CacheModifier srcCacheModifier = - byteWidth == 16 ? CacheModifier::CG : CacheModifier::CA; - assert(byteWidth == 16 || byteWidth == 8 || byteWidth == 4); - auto resByteWidth = resElemTy.getIntOrFloatBitWidth() / 8; - - Value basePtr = sharedPtrs[elemIdx]; - for (size_t wordIdx = 0; wordIdx < numWords; ++wordIdx) { + // We can load N elements at a time if: + // 1. Every group of N source pointers are contiguous. For example, if + // N=2, then the pointers should be [x, x+1, y, y+1, ...]. + // 2. The mask (if present) has "alignment" N, meaning that each group of N + // mask bits are the same. For example if N=2, the mask must be + // [x, x, y, y, ...]. + unsigned maxVec = getContiguity(op.getSrc()); + if (mask) { + maxVec = std::min(maxVec, getMaskAlignment(mask)); + } + + // Addresses to store into, one per `vecTy`. + VectorType vecTy; + SmallVector shmemAddrs; + bool ok = emitTransferBetweenRegistersAndShared( + srcTy, dstTy, resElemTy, maxVec, smemObj.base, smemObj.strides, loc, + rewriter, targetInfo, [&](VectorType vecTy_, Value shmemAddr) { + vecTy = vecTy_; + shmemAddrs.push_back(shmemAddr); + }); + assert(ok); + + int vecBytes = vecTy.getNumElements() * vecTy.getElementTypeBitWidth() / 8; + assert(llvm::isPowerOf2_32(vecBytes)); + if (vecBytes < 4) { + return emitError(loc, "cp.async does not support transfers smaller than " + "4 bytes; calculated this as ") + << vecBytes << " bytes"; + } + + for (int i = 0; i < shmemAddrs.size(); i++) { + // It's possible that vecTy is larger than 128 bits, in which case we have + // to use multiple cp.async instructions. + int wordBytes = std::min(vecBytes, 16); + int wordElems = wordBytes * 8 / vecTy.getElementTypeBitWidth(); + int numWordsInVec = std::max(1, vecBytes / wordBytes); + for (int j = 0; j < numWordsInVec; j++) { + int elemIdx = i * vecTy.getNumElements() + j * wordElems; + + // Tune CG and CA. + CacheModifier srcCacheModifier = + wordBytes == 16 ? CacheModifier::CG : CacheModifier::CA; + assert(wordBytes == 16 || wordBytes == 8 || wordBytes == 4); + PTXBuilder ptxBuilder; - auto wordElemIdx = wordIdx * numWordElems; auto ©AsyncOp = *ptxBuilder.create(srcCacheModifier); - auto *dstOperand = - ptxBuilder.newAddrOperand(basePtr, "r", wordElemIdx * resByteWidth); - auto *srcOperand = - ptxBuilder.newAddrOperand(srcElems[elemIdx + wordElemIdx], "l"); - auto *copySize = ptxBuilder.newConstantOperand(byteWidth); + auto *dstOperand = ptxBuilder.newAddrOperand(shmemAddrs[i], "r", + /*offset=*/j * wordBytes); + auto *srcOperand = ptxBuilder.newAddrOperand(srcElems[elemIdx], "l"); + auto *copySize = ptxBuilder.newConstantOperand(wordBytes); auto *srcSize = copySize; if (op.getMask()) { // We don't use predicate in this case, setting src-size to 0 // if there's any mask. cp.async will automatically fill the // remaining slots with 0 if cp-size > src-size. // XXX(Keren): Always assume other = 0 for now. - auto selectOp = select(maskElems[elemIdx + wordElemIdx], - i32_val(byteWidth), i32_val(0)); + auto selectOp = + select(maskElems[elemIdx], i32_val(wordBytes), i32_val(0)); srcSize = ptxBuilder.newOperand(selectOp, "r"); } @@ -948,6 +1069,13 @@ struct AsyncTMACopyGlobalToLocalOpConversion if (rank > 1) numCopies = ceil(contigDimSizeInByte, 128); + auto asyncTaskIds = getAsyncTaskIds(op); + int firstThreadId = 0; + if (!asyncTaskIds.empty()) { + assert(asyncTaskIds.size() == 1 && "only support single async task"); + firstThreadId = asyncTaskIds[0] * numWarps * warpSize; + } + // The bounding box inner dimension must be less than or equal to the // swizzle size. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 @@ -957,8 +1085,9 @@ struct AsyncTMACopyGlobalToLocalOpConversion int numWarpsToCopy = std::min(numCopies - copyIdx, numWarps); if (numWarpsToCopy == 1) warpID = i32_val(0); - Value boxPred = - and_(pred, icmp_ult(id, i32_val(numWarpsToCopy * warpSize))); + Value boxPred = and_( + pred, + icmp_ult(id, i32_val(numWarpsToCopy * warpSize + firstThreadId))); ::mlir::triton::PTXBuilder ptxBuilderTMA; Type elemPtrTy = ptr_ty(rewriter.getContext(), 3); Value copyIdxVal = add(warpID, i32_val(copyIdx)); @@ -997,6 +1126,14 @@ struct AsyncTMACopyGlobalToLocalOpConversion } }; +int getWarpOffset(Operation *op) { + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.size() > 0) { + return 4 * *std::min_element(asyncTaskIds.begin(), asyncTaskIds.end()); + } + return 0; +} + struct AsyncTMACopyLocalToGlobalOpConversion : public ConvertOpToLLVMPattern< triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp> { @@ -1042,6 +1179,9 @@ struct AsyncTMACopyLocalToGlobalOpConversion int numWarpsToCopy = std::min(numCopies - copyIdx, numWarps); if (numWarpsToCopy == 1) warpID = i32_val(0); + auto warpOffset = getWarpOffset(op); + warpID = sub(warpID, i32_val(warpOffset)); + id = sub(id, i32_val(warpOffset * warpSize)); Value boxPred = and_(pred, icmp_ult(id, i32_val(numWarpsToCopy * warpSize))); ::mlir::triton::PTXBuilder ptxBuilderTMA; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PTXAsmFormat.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PTXAsmFormat.cpp index 78a624768..2f4f03007 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PTXAsmFormat.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PTXAsmFormat.cpp @@ -172,10 +172,6 @@ std::string PTXInstrExecution::dump() const { std::string osStr; llvm::raw_string_ostream os(osStr); - std::string instrRepr = strJoin(instr->instrParts, "."); - if (onlyAttachMLIRArgs) - return instrRepr; - if (pred) { if (!pred->repr) os << "@" << pred->dump() << " "; @@ -183,6 +179,13 @@ std::string PTXInstrExecution::dump() const { os << pred->repr(pred->idx) << " "; } + std::string instrRepr = strJoin(instr->instrParts, "."); + if (onlyAttachMLIRArgs) { + os << instrRepr; + os.flush(); + return osStr; + } + llvm::SmallVector argReprs; for (auto *arg : argsInOrder) { argReprs.push_back(arg->dump()); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h index 1013d5bc2..4060378fa 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -36,6 +36,11 @@ void populateElementwiseOpToLLVMPatterns( ModuleAxisInfoAnalysis &axisInfoAnalysis, int computeCapability, const TargetInfo &targetInfo, PatternBenefit benefit); +void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfo &targetInfo, + PatternBenefit benefit); + void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, @@ -46,6 +51,11 @@ void populateTensorPtrOpsToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); +void populateTMAToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/RegReallocOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/RegReallocOpToLLVM.cpp new file mode 100644 index 000000000..51c91c4af --- /dev/null +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/RegReallocOpToLLVM.cpp @@ -0,0 +1,47 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { +struct RegAllocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::RegAllocOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::RegAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getRegCount()); + return success(); + } +}; + +struct RegDeallocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::RegDeallocOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::RegDeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getRegCount()); + return success(); + } +}; +} // namespace + +void mlir::triton::populateRegReallocOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + return; +} diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp index 93ad46971..6bcb74436 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp @@ -1,5 +1,6 @@ #include "PatternTritonGPUOpToLLVM.h" #include "Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" namespace { @@ -33,10 +34,25 @@ struct GetNumProgramsOpConversion } }; +struct GetCanonicalWarpIdConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::GetCanonicalWarpIdOp>::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(triton::nvidia_gpu::GetCanonicalWarpIdOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto warpIdOp = rewriter.create( + op->getLoc(), rewriter.getI32Type()); + rewriter.replaceOp(op, warpIdOp); + return success(); + } +}; } // namespace void mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp new file mode 100644 index 000000000..9f8ca5519 --- /dev/null +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp @@ -0,0 +1,294 @@ +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/TypeUtilities.h" + +#include "PatternTritonGPUOpToLLVM.h" +#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" + +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { +constexpr int64_t TMA_SIZE_BYTES = 128; + +void tensormap_cp_fenceproxy(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, Value outPtr, + Value inPtr) { + PTXBuilder ptxBuilder; + + // prepare asm operands + auto *outAddrOpr = ptxBuilder.newAddrOperand(outPtr, "l"); + auto *inAddrOpr = ptxBuilder.newAddrOperand(inPtr, "l"); + auto *sizeOpr = ptxBuilder.newConstantOperand(TMA_SIZE_BYTES); + + // Define the instruction opcode + auto &cp = + *ptxBuilder.create<>("tensormap.cp_fenceproxy.global.shared::cta." + "tensormap::generic.release.gpu.sync.aligned"); + + // Execute collectively on first warp in block + constexpr int kWarpSize = 32; + Value threadId = getThreadId(rewriter, loc); + Value pred = icmp_slt(threadId, i32_val(kWarpSize)); + cp(outAddrOpr, inAddrOpr, sizeOpr).predicate(pred); + + ptxBuilder.launch(rewriter, loc, void_ty(ctx)); +}; + +void tensormap_replace_generic(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + std::string fieldName, Value descPtr, + int32_t newVal) { + PTXBuilder ptxBuilder; + + // prepare asm operands + auto *descAddrOpr = ptxBuilder.newAddrOperand(descPtr, "l"); + auto newValOpr = ptxBuilder.newConstantOperand(newVal); + + // Define the instruction opcode + auto &replace = ptxBuilder.create<>("tensormap.replace.tile") + ->o(fieldName) + .o("shared::cta") + .o("b1024") + .o("b32"); + + Value threadId = getThreadId(rewriter, loc); + Value pred = icmp_eq(threadId, i32_val(0)); + replace(descAddrOpr, newValOpr).predicate(pred); + + ptxBuilder.launch(rewriter, loc, void_ty(ctx)); +} + +void tensormap_replace_generic(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + std::string fieldName, Value descPtr, + Value newVal, + std::optional ord = std::nullopt) { + PTXBuilder ptxBuilder; + + auto newValTy = newVal.getType(); + int width = 0; + + // prepare asm operands + auto *descAddrOpr = ptxBuilder.newAddrOperand(descPtr, "l"); + PTXInstr::Operand *ordOpr = + ord ? ptxBuilder.newConstantOperand(*ord) : nullptr; + PTXInstr::Operand *newValOpr = nullptr; + if (mlir::isa(newValTy)) { + width = mlir::cast(newValTy).getWidth(); + } else { + assert(mlir::isa(newValTy)); + width = 64; + } + const char *constraint = width == 64 ? "l" : "r"; + newValOpr = ptxBuilder.newOperand(newVal, constraint); + + // Define the instruction opcode + auto &replace = ptxBuilder.create<>("tensormap.replace.tile") + ->o(fieldName) + .o("shared::cta") + .o("b1024") + .o("b32", width == 32) + .o("b64", width == 64); + + Value threadId = getThreadId(rewriter, loc); + Value pred = icmp_eq(threadId, i32_val(0)); + + if (ord) { + replace(descAddrOpr, ordOpr, newValOpr).predicate(pred); + } else { + replace(descAddrOpr, newValOpr).predicate(pred); + } + + ptxBuilder.launch(rewriter, loc, void_ty(ctx)); +} + +void tensormap_replace_global_address(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, Value newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "global_address", descPtr, + newVal); +} + +void tensormap_replace_rank(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, Value descPtr, + int32_t newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "rank", descPtr, newVal); +} + +void tensormap_replace_box_dim(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t ord, Value newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "box_dim", descPtr, newVal, + ord); +} + +void tensormap_replace_global_dim(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t ord, Value newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "global_dim", descPtr, newVal, + ord); +} + +void tensormap_replace_global_stride(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t ord, Value newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "global_stride", descPtr, + newVal, ord); +} + +void tensormap_replace_element_stride(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t ord, + Value newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "element_stride", descPtr, + newVal, ord); +} + +void tensormap_replace_elemtype(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "elemtype", descPtr, newVal); +} + +void tensormap_replace_interleave_layout(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "interleave_layout", descPtr, + newVal); +} + +void tensormap_replace_swizzle_mode(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "swizzle_mode", descPtr, + newVal); +} + +void tensormap_replace_fill_mode(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + Value descPtr, int32_t newVal) { + tensormap_replace_generic(loc, ctx, rewriter, "fill_mode", descPtr, newVal); +} + +struct ExperimentalTensormapFenceproxyAcquireOpConversion + : public ConvertOpToLLVMPattern< + triton::ExperimentalTensormapFenceproxyAcquireOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ExperimentalTensormapFenceproxyAcquireOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + PTXBuilder ptxBuilder; + + // prepare asm operands + auto *descAddrOpr = ptxBuilder.newAddrOperand(adaptor.getDescPtr(), "l"); + auto *sizeOpr = ptxBuilder.newConstantOperand(TMA_SIZE_BYTES); + + // Define the instruction opcode + constexpr int kWarpSize = 32; + Value threadId = getThreadId(rewriter, loc); + Value pred = icmp_slt(threadId, i32_val(kWarpSize)); + auto &fence = + *ptxBuilder.create<>("fence.proxy.tensormap::generic.acquire.gpu"); + fence(descAddrOpr, sizeOpr).predicate(pred); + + ptxBuilder.launch(rewriter, loc, getVoidType()); + + // We run the fence on a single warp, then use a barrier to synchronize the + // rest. This ends up being faster than running the fence on each warp. + // TODO: Ideally we only emit one barrier after all fences are issued + insertBarrier(rewriter, op); + + rewriter.eraseOp(op); + return success(); + } +}; + +void zero_fill_tma(Location loc, MLIRContext *ctx, + ConversionPatternRewriter &rewriter, + const NVIDIA::TargetInfo &targetInfo, Value descPtr) { + // Write out zeros + constexpr int kWarpSize = 32; + Value threadId = getThreadId(rewriter, loc); + Value pred = icmp_slt(threadId, i32_val(kWarpSize)); + + auto fillVal = i32_val(0); + auto writeAddr = gep(descPtr.getType(), fillVal.getType(), descPtr, threadId); + targetInfo.storeShared(rewriter, loc, writeAddr, fillVal, pred); + + // Sync warp + PTXBuilder ptxBuilder; + auto &bar = *ptxBuilder.create<>("bar.warp.sync"); + auto *maskOpr = ptxBuilder.newConstantOperand(0xffffffff); + bar(maskOpr).predicate(pred); + ptxBuilder.launch(rewriter, loc, void_ty(ctx)); +} + +struct ExperimentalTensormapCreateOpConversion + : public ConvertOpToLLVMPattern { + const NVIDIA::TargetInfo &targetInfo; + + ExperimentalTensormapCreateOpConversion(LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::ExperimentalTensormapCreateOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto ctx = getContext(); + + auto smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); + + zero_fill_tma(loc, ctx, rewriter, targetInfo, smemBase); + tensormap_replace_global_address(loc, ctx, rewriter, smemBase, + adaptor.getGlobalAddress()); + tensormap_replace_rank(loc, ctx, rewriter, smemBase, op.getRank() - 1); + for (int i = 0; i < op.getRank(); ++i) { + tensormap_replace_box_dim(loc, ctx, rewriter, smemBase, i, + op.getBoxDim()[i]); + } + for (int i = 0; i < op.getRank(); ++i) { + tensormap_replace_global_dim(loc, ctx, rewriter, smemBase, i, + op.getGlobalDim()[i]); + } + for (int i = 0; i + 1 < op.getRank(); ++i) { + tensormap_replace_global_stride(loc, ctx, rewriter, smemBase, i, + op.getGlobalStride()[i]); + } + for (int i = 0; i < op.getRank(); ++i) { + tensormap_replace_element_stride(loc, ctx, rewriter, smemBase, i, + op.getElementStride()[i]); + } + tensormap_replace_elemtype(loc, ctx, rewriter, smemBase, op.getElemType()); + tensormap_replace_interleave_layout(loc, ctx, rewriter, smemBase, + op.getInterleaveLayout()); + tensormap_replace_swizzle_mode(loc, ctx, rewriter, smemBase, + op.getSwizzleMode()); + tensormap_replace_fill_mode(loc, ctx, rewriter, smemBase, op.getFillMode()); + tensormap_cp_fenceproxy(loc, ctx, rewriter, adaptor.getDescPtr(), smemBase); + rewriter.eraseOp(op); + return success(); + } +}; + +} // namespace + +void mlir::triton::NVIDIA::populateTMAToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, + targetInfo, benefit); + patterns.add( + typeConverter, benefit); +} diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 0afdf6fba..75f935410 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -3,8 +3,10 @@ #include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" #include "Utility.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/Support/MathExtras.h" using namespace mlir; @@ -13,134 +15,8 @@ using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getShapePerCTATile; namespace { -Value computeStMatrixAddr(Value laneId, int matStride, Location loc, - ConversionPatternRewriter &rewriter, - int swizzleByteWidth) { - Value rowInMat = urem(laneId, i32_val(8)); // row in the 8x8 matrix - // linear index of the matrix in the 2x2 matrices - // Decompose matIndex => s_0, s_1, that is the coordinate in 2x2 matrices in - // a warp. - Value matIndex = udiv(laneId, i32_val(8)); - Value s0 = urem(matIndex, i32_val(2)); - Value s1 = udiv(matIndex, i32_val(2)); - if (swizzleByteWidth >= 32) - s1 = xor_(s1, and_(laneId, i32_val(1))); - Value mIndex = add(rowInMat, mul(s0, i32_val(8))); - int m8n8Stride = 8; - Value offset = - add(mul(mIndex, i32_val(matStride)), mul(s1, i32_val(m8n8Stride))); - return offset; -} - -void stMatrixm8n8x4(Value offset, ArrayRef vals, int indexOffset, - Value smemBase, Type elemTy, Location loc, - ConversionPatternRewriter &rewriter) { - SmallVector inputs; - auto prTy = ptr_ty(rewriter.getContext(), 3); - // Pack the input into 2xf16 - Type packedTy = vec_ty(vals[0].getType(), 2); - for (int i = 0; i < 4; i++) { - Value input = undef(packedTy); - for (int j = 0; j < 2; j++) { - input = insert_element(packedTy, input, vals[indexOffset + i * 2 + j], - i32_val(j)); - } - inputs.push_back(bitcast(input, i32_ty)); - } - Value addr = gep(smemBase.getType(), elemTy, smemBase, offset); - rewriter.create(loc, addr, inputs); -} -void storeDistributedToSharedWithStMatrix( - RankedTensorType tensorTy, Type elemTy, SmallVector &inVals, - Value smemBase, ArrayRef paddedRepShape, - ArrayRef origRepShape, Location loc, - ConversionPatternRewriter &rewriter, int swizzlingByteWidth) { - auto shapePerCTA = getShapePerCTA(tensorTy); - auto mmaLayout = mlir::cast(tensorTy.getEncoding()); - auto order = triton::gpu::getOrder(mmaLayout); - auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); - auto shapePerCTATile = getShapePerCTATile(mmaLayout); - ArrayRef mmaShape = mmaLayout.getInstrShape(); - // 4xm8n8 matches exactly the size of 1 warp of wgmma layout for 16bit type - // and has a shape of 16x16. - int instrN = mmaShape[1] * warpsPerCTA[1]; - int instrM = mmaShape[0] * warpsPerCTA[0]; - std::array numRep = {ceil((int)origRepShape[0], instrM), - ceil((int)origRepShape[1], instrN)}; - int numBoxes = 1; - if (swizzlingByteWidth == 128) { - int contigDimSizeInByte = - origRepShape[1] * elemTy.getIntOrFloatBitWidth() / 8; - numBoxes = ceil(contigDimSizeInByte, 128); - } - SmallVector boxShape = {paddedRepShape[0], paddedRepShape[1]}; - boxShape[1] = boxShape[1] / numBoxes; - Value thread = getThreadId(rewriter, loc); - Value warp = udiv(thread, i32_val(32)); - Value lane = urem(thread, i32_val(32)); - - SmallVector multiDimWarpId = - delinearize(rewriter, loc, warp, warpsPerCTA); - - // Compute the relative offset for each lane. - Value stMatrixLaneOffset = - computeStMatrixAddr(lane, boxShape[1], loc, rewriter, swizzlingByteWidth); - multiDimWarpId[0] = mul(multiDimWarpId[0], i32_val(mmaShape[0])); - multiDimWarpId[1] = mul(multiDimWarpId[1], i32_val(mmaShape[1])); - SmallVector multiDimOffsetWrapped = getWrappedMultiDimOffset( - rewriter, loc, multiDimWarpId, boxShape, shapePerCTATile, shapePerCTA); - Value relativeOffset = - linearize(rewriter, loc, multiDimOffsetWrapped, boxShape, order); - relativeOffset = add(relativeOffset, stMatrixLaneOffset); - int indexOffset = 0; - int m8n8x4Stride = 16; - int numNChunk = mmaShape[1] / m8n8x4Stride; - unsigned totalNumElements = product(origRepShape); - numNChunk = numNChunk / numBoxes; - for (int m = 0; m < numRep[0]; m++) { - for (int n = 0; n < numRep[1]; n++) { - for (int box = 0; box < numBoxes; box++) { - for (int k = 0; k < numNChunk; k++) { - Value kOffset; - if (swizzlingByteWidth >= 64) { - int swizzleBits = swizzlingByteWidth == 128 ? 6 : 2; - Value o = lshr(and_(lane, i32_val(swizzleBits)), i32_val(1)); - Value kV = xor_(o, i32_val(k)); - kOffset = mul(kV, i32_val(m8n8x4Stride)); - } else { - kOffset = i32_val(k * m8n8x4Stride); - } - Value addr = add(relativeOffset, - i32_val(n * instrN + m * instrM * boxShape[1] + - box * (totalNumElements / numBoxes))); - addr = add(addr, kOffset); - - stMatrixm8n8x4(addr, inVals, indexOffset, smemBase, elemTy, loc, - rewriter); - indexOffset += 8; - } - } - } - } -} - -bool isStMatrixCompatible(RankedTensorType tensorTy, int swizzlingByteWidth) { - auto mmaLayout = - mlir::dyn_cast(tensorTy.getEncoding()); - if (!mmaLayout || !mmaLayout.isHopper()) - return false; - if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16) - return false; - if (swizzlingByteWidth > 0 && mmaLayout.getInstrShape()[1] < 64) - return false; - if (swizzlingByteWidth != 0 && swizzlingByteWidth != 32 && - swizzlingByteWidth != 64 && swizzlingByteWidth != 128) - return false; - return true; -} - // declare vprintf(i8*, i8*) as external function -LLVM::LLVMFuncOp getVprintfDeclaration(ConversionPatternRewriter &rewriter) { +LLVM::LLVMFuncOp getVprintfDeclaration(RewriterBase &rewriter) { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); StringRef funcName("vprintf"); Operation *funcOp = moduleOp.lookupSymbol(funcName); @@ -152,7 +28,7 @@ LLVM::LLVMFuncOp getVprintfDeclaration(ConversionPatternRewriter &rewriter) { SmallVector argsType{ptr_ty(context), ptr_ty(context)}; auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType); - ConversionPatternRewriter::InsertionGuard guard(rewriter); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); return rewriter.create(UnknownLoc::get(context), funcName, @@ -161,8 +37,7 @@ LLVM::LLVMFuncOp getVprintfDeclaration(ConversionPatternRewriter &rewriter) { // extend integer to int32, extend float to float64 // this comes from vprintf alignment requirements. -std::pair printfPromoteValue(ConversionPatternRewriter &rewriter, - Value value) { +std::pair printfPromoteValue(RewriterBase &rewriter, Value value) { auto *context = rewriter.getContext(); auto type = value.getType(); Value newOp = value; @@ -186,7 +61,7 @@ std::pair printfPromoteValue(ConversionPatternRewriter &rewriter, return {newType, newOp}; } -LLVM::LLVMFuncOp getAssertfailDeclaration(ConversionPatternRewriter &rewriter) { +LLVM::LLVMFuncOp getAssertfailDeclaration(RewriterBase &rewriter) { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); StringRef funcName("__assertfail"); { @@ -200,7 +75,7 @@ LLVM::LLVMFuncOp getAssertfailDeclaration(ConversionPatternRewriter &rewriter) { SmallVector argsType{ptr_ty(ctx), ptr_ty(ctx), i32_ty, ptr_ty(ctx), rewriter.getIntegerType(sizeof(size_t) * 8)}; auto funcType = LLVM::LLVMFunctionType::get(void_ty(ctx), argsType); - ConversionPatternRewriter::InsertionGuard guard(rewriter); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); auto funcOp = rewriter.create(UnknownLoc::get(ctx), funcName, funcType); @@ -260,70 +135,312 @@ Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { rewriter.getI32Type()); } -Value TargetInfo::ballot(ConversionPatternRewriter &rewriter, Location loc, - Type type, Value cmp) const { +Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const { Value threadMask = int_val(type.getIntOrFloatBitWidth(), -1); return rewriter.create(loc, type, threadMask, cmp); } -void TargetInfo::storeShared(ConversionPatternRewriter &rewriter, Location loc, - Value ptr, Value val, Value pred) const { + +static Value mapa(RewriterBase &rewriter, Location loc, Value ptr, Value ctaid, + Value pred) { + PTXBuilder builder; + (*builder.create<>("mapa.shared::cluster.u32"))( + builder.newOperand("=r"), // + builder.newAddrOperand(ptr, "r"), builder.newAddrOperand(ctaid, "r")) + .predicate(pred, "b"); + return builder.launch(rewriter, loc, i32_ty, /*hasSideEffects=*/false); +} + +static std::string getConstraintForBitwidth(unsigned bitwidth) { + switch (bitwidth) { + case 8: + case 16: + return "h"; + case 32: + return "r"; + case 64: + return "l"; + default: + llvm_unreachable("unsupported bitwidth"); + } +} + +static bool isConstantTruePred(Value pred) { + if (auto constOp = pred.getDefiningOp()) { + return cast(constOp.getValue()).getInt() != 0; + } + return false; +} + +void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const { MLIRContext *ctx = rewriter.getContext(); - unsigned bits = std::max(8u, val.getType().getIntOrFloatBitWidth()); - const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r"); + auto ptrTy = cast(ptr.getType()); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + + if (!isa(val.getType())) { + storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, {val}, rewriter), + pred); + return; + } + + auto vecTy = cast(val.getType()); + Type elemTy = vecTy.getElementType(); + unsigned vec = vecTy.getNumElements(); + unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth(); + assert(llvm::isPowerOf2_32(vec)); + + if (elemBitwidth < 8) { + assert(vec == 1 && + "don't know how to load/store vectors of sub-byte elems"); + SmallVector vals = unpackLLVector(loc, val, rewriter); + for (Value &v : vals) { + v = zext(int_ty(8), bitcast(v, int_ty(elemBitwidth))); + } + storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, vals, rewriter), + pred); + return; + } + + if (!elemTy.isInteger()) { + SmallVector vals = unpackLLVector(loc, val, rewriter); + for (Value &v : vals) { + v = bitcast(v, int_ty(elemBitwidth)); + } + storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, vals, rewriter), + pred); + return; + } + + // load/store ops only support v2 and v4. If the vector width is larger than + // 4, we have two strategies for dealing with it. + // 1. If the element type is smaller than b32, store b32's instead. + // 2. Otherwise, split the store into multiple stores. + if (vec > 4 && elemBitwidth < 32) { + assert(llvm::isPowerOf2_32(vec)); + int elemsPerPack = 32 / elemBitwidth; + SmallVector oldVals = unpackLLVector(loc, val, rewriter); + + SmallVector newVals; + for (int i = 0; i < vec / elemsPerPack; i++) { + Value v = packLLVector( + loc, ArrayRef(oldVals).slice(i * elemsPerPack, elemsPerPack), + rewriter); + newVals.push_back(bitcast(v, i32_ty)); + } + storeDShared(rewriter, loc, ptr, ctaId, + packLLVector(loc, newVals, rewriter), pred); + return; + } + + if (vec * elemBitwidth > 128) { + assert(llvm::isPowerOf2_32(vec)); + assert(elemBitwidth == 32 || elemBitwidth == 64); + int maxVec = 128 / elemBitwidth; + + auto newVecTy = vec_ty(elemTy, maxVec); + SmallVector vals = unpackLLVector(loc, val, rewriter); + for (int i = 0; i < vec / maxVec; i++) { + auto newPtr = gep(ptr.getType(), elemTy, ptr, i32_val(i * maxVec), + /*inbounds=*/true); + storeDShared( + rewriter, loc, newPtr, ctaId, + packLLVector(loc, ArrayRef(vals).slice(i * maxVec, maxVec), rewriter), + pred); + } + return; + } + + // At this point we're committed to doing the store! + assert(elemBitwidth >= 8); + assert(elemTy.isInteger()); + assert(1 <= vec && vec <= 4); + assert(vec * elemBitwidth <= 128); + + // Get pointer to remote shared memory if needed. + if (ctaId.has_value()) { + ptr = mapa(rewriter, loc, ptr, *ctaId, pred); + } PTXBuilder builder; + auto st = builder.create<>("st") + ->o("shared::cta", ctaId.has_value()) + .o("shared", !ctaId.has_value()) + .v(vec, /*predicate=*/vec > 1) + .b(elemBitwidth); auto *ptrOpr = builder.newAddrOperand(ptr, "r"); - auto *valOpr = builder.newOperand(val, c); - auto &st = builder.create<>("st")->shared().b(bits); + + PTXBuilder::Operand *valOpr; + std::string constraint = getConstraintForBitwidth(elemBitwidth); + if (vec > 1) { + SmallVector> vecVals; + for (int i = 0; i < vec; i++) { + vecVals.push_back({extract_element(val, i32_val(i)), constraint}); + } + valOpr = builder.newListOperand(vecVals); + } else { + valOpr = builder.newOperand(val, constraint); + } st(ptrOpr, valOpr).predicate(pred, "b"); builder.launch(rewriter, loc, void_ty(ctx)); } -Value TargetInfo::loadShared(ConversionPatternRewriter &rewriter, Location loc, - const TypeConverter *converter, Value ptr, - Type elemTy, Value pred) const { +Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type loadTy, + Value pred) const { MLIRContext *ctx = rewriter.getContext(); auto ptrTy = cast(ptr.getType()); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for loadShared"); - unsigned bitwidth = std::max(8u, elemTy.getIntOrFloatBitWidth()); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + + if (!isa(loadTy)) { + SmallVector values = unpackLLVector( + loc, loadDShared(rewriter, loc, ptr, ctaId, vec_ty(loadTy, 1), pred), + rewriter); + assert(values.size() == 1); + return values[0]; + } + + auto vecTy = cast(loadTy); + Type elemTy = vecTy.getElementType(); + unsigned vec = vecTy.getNumElements(); + unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth(); + assert(llvm::isPowerOf2_32(vec)); + + if (elemBitwidth < 8) { + assert(vec == 1 && + "don't know how to load/store vectors of sub-byte elems"); + SmallVector vals = unpackLLVector( + loc, loadDShared(rewriter, loc, ptr, ctaId, int_ty(8), pred), rewriter); + assert(vals.size() == 1); + return bitcast(trunc(int_ty(elemBitwidth), vals[0]), elemTy); + } + + // We only know how to load integers. + if (!elemTy.isInteger()) { + Type newLoadTy = vec_ty(int_ty(elemBitwidth), vec); + SmallVector vals = unpackLLVector( + loc, loadDShared(rewriter, loc, ptr, ctaId, newLoadTy, pred), rewriter); + for (Value &v : vals) { + v = bitcast(v, elemTy); + } + return packLLVector(loc, vals, rewriter); + } + + // load/store ops only support v2 and v4. If the vector width is larger than + // 4, we have two strategies for dealing with it. + // 1. If the element type is smaller than b32, load b32's instead. + // 2. Otherwise, split the load into multiple loads. + if (vec > 4 && elemBitwidth < 32) { + int newVec = vec / (32 / elemBitwidth); + auto newVecTy = vec_ty(i32_ty, newVec); + auto res = loadDShared(rewriter, loc, ptr, ctaId, newVecTy, pred); + + // Unpack the b32's into the original vector type. + SmallVector vals; + for (Value v : unpackLLVector(loc, res, rewriter)) { + Value vv = bitcast(v, vec_ty(elemTy, 32 / elemBitwidth)); + for (Value vvv : unpackLLVector(loc, vv, rewriter)) { + vals.push_back(vvv); + } + } + return packLLVector(loc, vals, rewriter); + } + + if (vec * elemBitwidth > 128) { + assert(elemBitwidth == 32 || elemBitwidth == 64); + assert(llvm::isPowerOf2_32(vec)); + int maxVec = 128 / elemBitwidth; + + SmallVector vals; + for (int i = 0; i < vec / maxVec; i++) { + auto newPtr = gep(ptr.getType(), elemTy, ptr, i32_val(i * maxVec), + /*inbounds=*/true); + auto newVal = loadDShared(rewriter, loc, newPtr, ctaId, + vec_ty(elemTy, maxVec), pred); + for (Value v : unpackLLVector(loc, newVal, rewriter)) { + vals.push_back(v); + } + } + return packLLVector(loc, vals, rewriter); + } - const char *c = bitwidth == 64 ? "=l" : (bitwidth == 16 ? "=h" : "=r"); + // At this point we're committed to actually do the load! + assert(elemBitwidth >= 8); + assert(elemTy.isInteger()); + assert(1 <= vec && vec <= 4); + assert(vec * elemBitwidth <= 128); + + // Get pointer to remote shared memory if needed. + if (ctaId.has_value()) { + ptr = mapa(rewriter, loc, ptr, *ctaId, pred); + } PTXBuilder builder; - auto *dOpr = builder.newOperand(c); - auto *ptrOpr = builder.newAddrOperand(ptr, "r"); - auto &ld = builder.create<>("ld")->shared().b(bitwidth); - ld(dOpr, ptrOpr).predicate(pred, "b"); - return builder.launch(rewriter, loc, elemTy); + auto ld = builder.create<>("ld") + ->o("shared::cta", ctaId.has_value()) + .o("shared", !ctaId.has_value()) + .v(vec, /*predicate=*/vec > 1) + .b(elemBitwidth); + + Value load; + if (isConstantTruePred(pred)) { + Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth)) + : Type(vec_ty(int_ty(elemBitwidth), vec)); + load = load(resultTy, ptr); + if (vec > 1) { + Type structTy = struct_ty(SmallVector(vec, int_ty(elemBitwidth))); + Value structValue = undef(structTy); + for (int i = 0; i < vec; i++) { + structValue = insert_val(structTy, structValue, + extract_element(load, i32_val(i)), i); + } + load = structValue; + } + } else { + std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth); + auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint) + : builder.newListOperand(vec, elemConstraint); + ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b"); + + Type resultTy = + vec == 1 + ? Type(int_ty(elemBitwidth)) + : Type(struct_ty(SmallVector(vec, int_ty(elemBitwidth)))); + load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true); + } + SmallVector resultVals = unpackLLElements(loc, load, rewriter); + return packLLVector(loc, resultVals, rewriter); } -Value TargetInfo::shuffleXor(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::NVIDIA::shuffleXor(loc, rewriter, val, i); } -Value TargetInfo::shuffleUp(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::NVIDIA::shuffleUp(loc, rewriter, val, i); } -Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::NVIDIA::shuffleIdx(loc, rewriter, val, i); } -Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, Value i) const { +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const { return LLVM::NVIDIA::shuffleIdx(loc, rewriter, val, i); } -Value TargetInfo::programId(ConversionPatternRewriter &rewriter, Location loc, +Value TargetInfo::programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, int axis) const { return LLVM::NVIDIA::llGetPid(loc, rewriter, moduleOp, axis); } -bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, +bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce) const { + unsigned numLaneToReduce, + unsigned interleave) const { if (auto kind = matchReduxKind(op, computeCapability)) { // Based on benchmarking on A100 redux op gives a speed up only when doing // a single reduction (not partitioned) and when the mask is static. @@ -360,20 +477,25 @@ bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, } return false; } -bool TargetInfo::processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, - SmallVector &vals, RankedTensorType srcTy, Type elemTy, - ArrayRef paddedRepShape, ArrayRef origRepShape, - ArrayRef outOrd, unsigned accumNumReplicates, - int swizzlingByteWidth) const { - if (isStMatrixCompatible(srcTy, swizzlingByteWidth) && - accumNumReplicates == 1 && outOrd[0] == 1 && paddedRepShape[1] % 8 == 0) { - storeDistributedToSharedWithStMatrix(srcTy, elemTy, vals, smemBase, - paddedRepShape, origRepShape, loc, - rewriter, swizzlingByteWidth); - return true; + +void TargetInfo::storeMatrixShared(RewriterBase &rewriter, Location loc, + Value ptr, Value val) const { + auto vals = unpackLLVector(loc, val, rewriter); + // Ensure input consists of 4 vectors, each holding 2 elements of 16 bits + assert(vals[0].getType().getIntOrFloatBitWidth() == 16 && + "stmatrix requires elements to be 16-bit integers or floats"); + assert(vals.size() == 8 && + "stmatrix requires exactly 8 elements in the input vector"); + Type packedTy = vec_ty(vals[0].getType(), 2); + SmallVector inputs; + for (int i = 0; i < 4; i++) { + Value input = undef(packedTy); + for (int j = 0; j < 2; j++) { + input = insert_element(packedTy, input, vals[i * 2 + j], i32_val(j)); + } + inputs.push_back(bitcast(input, i32_ty)); } - return false; + rewriter.create(loc, ptr, inputs); } std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { @@ -382,9 +504,8 @@ std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { return funcName; } -void TargetInfo::printf(ConversionPatternRewriter &rewriter, - Value formatStrStart, int /*formatStrByteCount*/, - ValueRange args) const { +void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, + int /*formatStrByteCount*/, ValueRange args) const { auto *ctx = rewriter.getContext(); Type ptr = ptr_ty(ctx); auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); @@ -425,22 +546,45 @@ void TargetInfo::printf(ConversionPatternRewriter &rewriter, call(funcOp, operands); } -void TargetInfo::assertFail(ConversionPatternRewriter &rewriter, Location loc, +void TargetInfo::printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), rewriter, + "printfFormat_", msgNewline); + printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); +} + +void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const { auto funcOp = getAssertfailDeclaration(rewriter); auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); - Value messageString = - LLVM::addStringToModule(loc, rewriter, "assertMessage_", message); - Value fileString = - LLVM::addStringToModule(loc, rewriter, "assertFile_", file); - Value funcString = - LLVM::addStringToModule(loc, rewriter, "assertFunc_", func); + llvm::SmallString<64> messageString(message), fileString(file), + funcString(func); + messageString.push_back('\0'); + fileString.push_back('\0'); + funcString.push_back('\0'); + Value messageStringVal = + LLVM::addStringToModule(loc, rewriter, "assertMessage_", messageString); + Value fileStringVal = + LLVM::addStringToModule(loc, rewriter, "assertFile_", fileString); + Value funcStringVal = + LLVM::addStringToModule(loc, rewriter, "assertFunc_", funcString); Value lineNumber = i32_val(line); Value charSize = int_val(sizeof(size_t) * 8, sizeof(char)); - SmallVector operands = {messageString, fileString, lineNumber, - funcString, charSize}; + SmallVector operands = {messageStringVal, fileStringVal, lineNumber, + funcStringVal, charSize}; call(funcOp, operands); } +int TargetInfo::getSharedAddressSpace() const { return 3; } + +bool TargetInfo::supportVectorizedAtomics() const { + return computeCapability >= 90 && ptxVersion >= 81; +} + } // namespace mlir::triton::NVIDIA diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index 5a5a45653..ed9bd91a8 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -7,54 +7,58 @@ namespace mlir::triton::NVIDIA { class TargetInfo : public mlir::triton::TargetInfoBase { public: - TargetInfo(int computeCapability) : computeCapability(computeCapability) {} + TargetInfo(int computeCapability, int ptxVersion) + : computeCapability(computeCapability), ptxVersion(ptxVersion) {} bool supportMaximumMinimum() const override; Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; - Value ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, + Value ballot(RewriterBase &rewriter, Location loc, Type type, Value cmp) const override; - void storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Value val, Value pred) const override; - Value loadShared(ConversionPatternRewriter &rewriter, Location loc, - const TypeConverter *converter, Value ptr, Type elemTy, - Value pred) const override; + void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const override; + Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, + Value pred) const override; + void storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, + Value val) const override; - Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, Value i) const override; - Value programId(ConversionPatternRewriter &rewriter, Location loc, - ModuleOp moduleOp, int axis) const override; + Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, + int axis) const override; - bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, - SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce) const override; - - bool processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, - SmallVector &vals, RankedTensorType srcTy, Type elemTy, - ArrayRef paddedRepShape, ArrayRef origRepShape, - ArrayRef outOrd, unsigned accumNumReplicates, - int swizzleByteWidth) const override; + bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, + triton::ReduceOp op, unsigned numLaneToReduce, + unsigned interleave) const override; std::string getMulhiFuncName(Type resultElementTy) const override; - void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + void printf(RewriterBase &rewriter, Value formatStrStart, int formatStrByteCount, ValueRange args) const override; - void assertFail(ConversionPatternRewriter &rewriter, Location loc, - StringRef message, StringRef file, StringRef func, - int line) const override; + + void printf(RewriterBase &rewriter, StringRef msg, + ValueRange args) const override; + + void assertFail(RewriterBase &rewriter, Location loc, StringRef message, + StringRef file, StringRef func, int line) const override; + int getSharedAddressSpace() const override; + + bool supportVectorizedAtomics() const override; private: int computeCapability; + int ptxVersion; }; } // namespace mlir::triton::NVIDIA diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index f2742218e..30e76abcd 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -1,28 +1,25 @@ #include "Dialect/NVGPU/IR/Dialect.h" #include "TritonNVIDIAGPUToLLVM/Passes.h" -#include "mlir/Analysis/DataFlowFramework.h" +#include "TritonNVIDIAGPUToLLVM/Utility.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" -#include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "PatternTritonGPUOpToLLVM.h" -#include "Utility.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" @@ -35,7 +32,6 @@ namespace triton { using namespace mlir; using namespace mlir::triton::NVIDIA; -namespace ttng = mlir::triton::nvidia_gpu; namespace { @@ -83,31 +79,43 @@ struct ConvertTritonGPUToLLVM ConvertTritonGPUToLLVM(int32_t computeCapability) : ConvertTritonGPUToLLVMBase({computeCapability}) {} + ConvertTritonGPUToLLVM(int32_t computeCapability, int32_t ptxVersion) + : ConvertTritonGPUToLLVMBase({computeCapability, ptxVersion}) {} + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); mlir::LowerToLLVMOptions option(context); option.overrideIndexBitwidth(32); - TritonGPUToLLVMTypeConverter typeConverter(context, option); + TargetInfo targetInfo(computeCapability, ptxVersion); + TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo); TritonLLVMConversionTarget convTarget(*context); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + // Hack: WSLowering may have changed the effective number of warps, + // in a way that isn't reflected in triton_gpu.num-warps. If so, we have to + // respect that here. + if (Attribute attr = mod->getAttr("triton_gpu.num-warp-groups-per-cta")) { + numWarps *= cast(attr).getInt(); + } + // Allocate shared memory and set barrier ModuleAllocation allocation(mod); - ModuleMembarAnalysis membarPass(&allocation); + ModuleMembarAnalysis membarPass(&allocation, NVIDIA::canSkipBarSync); membarPass.run(); // Lower functions { mlir::LowerToLLVMOptions option(context); - TritonGPUToLLVMTypeConverter typeConverter(context, option); + TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo); TritonLLVMFunctionConversionTarget funcTarget(*context); RewritePatternSet funcPatterns(context); - mlir::triton::populateFuncOpConversionPattern( - typeConverter, funcPatterns, numWarps, patternBenefitDefault); + mlir::triton::populateFuncOpConversionPattern(typeConverter, funcPatterns, + numWarps, targetInfo, + patternBenefitDefault); mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, funcPatterns); if (failed( @@ -123,13 +131,14 @@ struct ConvertTritonGPUToLLVM OpBuilder::InsertPoint indexInsertPoint; RewritePatternSet patterns(context); - TargetInfo targetInfo(computeCapability); int benefit = patternBenefitPrioritizeOverLLVMConversions; mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMOptimizedPatterns( typeConverter, targetInfo, patterns, patternBenefitConvertLayoutOptimizedPattern); mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMPatterns( typeConverter, targetInfo, patterns, benefit); + mlir::triton::NVIDIA::populateTMAToLLVMPatterns(typeConverter, targetInfo, + patterns, benefit); populateDotOpToLLVMPatterns(typeConverter, patterns, benefit); populateElementwiseOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis, computeCapability, @@ -151,7 +160,7 @@ struct ConvertTritonGPUToLLVM mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, - benefit); + targetInfo, benefit); mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern(typeConverter, patterns, benefit); mlir::triton::populateSPMDOpToLLVMPattern(typeConverter, patterns, @@ -164,6 +173,7 @@ struct ConvertTritonGPUToLLVM mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns); mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); + mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns); mlir::triton::populateViewOpToLLVMPatterns(typeConverter, patterns, benefit); mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, @@ -172,6 +182,10 @@ struct ConvertTritonGPUToLLVM patterns, benefit); mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, patterns, benefit); + mlir::triton::NVIDIA::populateUpcastMXFPToLLVMPatterns( + typeConverter, patterns, targetInfo, benefit); + mlir::triton::populateRegReallocOpToLLVMPatterns(typeConverter, patterns, + benefit); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); @@ -225,6 +239,54 @@ std::unique_ptr> createConvertTritonGPUToLLVMPass(int32_t computeCapability) { return std::make_unique(computeCapability); } +std::unique_ptr> +createConvertTritonGPUToLLVMPass(int32_t computeCapability, + int32_t ptxVersion) { + return std::make_unique(computeCapability, + ptxVersion); +} + +bool NVIDIA::canSkipBarSync(Operation *before, Operation *after) { + // Multiple init barriers on the same allocation would usually not happen but + // that allows us to avoid barriers between multiple subslice of an array of + // mbarriers. This is still correct even if the inits happen on the same + // allocation. + if (isa(before) && + isa(after)) + return true; + + if (isa(before) && + isa(after)) + return true; + + // We can't have a warp get ahead when we have a chain of mbarrier wait so we + // need a barrier in between two WaitBarrierOp. + if (isa(before) && + isa(after)) + return false; + + // Even though WaitBarrierOp, AsyncTMACopyGlobalToLocalOp and + // AsyncTMACopyGlobalToLocalOp read and write to the mbarrier allocation it is + // valid for them to happen in different order on different threads, therefore + // we don't need a barrier between those operations. + if (isa(before) && + isa(after)) + return true; + + // A mbarrier wait is released only when the whole operations is done, + // therefore any thread can access the memory after the barrier even if some + // threads haven't reached the mbarrier wait. + if (isa(before) && + !isa(after)) + return true; + + return false; +} } // namespace triton } // namespace mlir diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp new file mode 100644 index 000000000..722bf56cd --- /dev/null +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -0,0 +1,161 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" + +#include "PatternTritonGPUOpToLLVM.h" + +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { +class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { +private: + const TargetInfoBase &targetInfo; + +public: + UpcastMXFPOpPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + llvm::SmallVector + unpackFP4Elements(Location loc, ConversionPatternRewriter &rewriter, + const llvm::SmallVector &vals, Value laneId) const { + auto fp4x2ToBf16x2 = [&loc, &rewriter](Value v) -> Value { + auto em0 = and_(v, i8_val(0x70)); + auto em1 = and_(v, i8_val(0x7)); + Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)), + shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8))); + Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)), + shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12))); + + // Three cases: + // 1) x is normal and non-zero: Correct bias + v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)), + add(v0, i16_val((127 - 1) << 7)), v0); + v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)), + add(v1, i16_val((127 - 1) << 7)), v1); + + // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in + // bf16 + v0 = select(icmp_eq(em0, i8_val(0x10)), + or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0); + v1 = select(icmp_eq(em1, i8_val(0x1)), + or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1); + // 3) x is zero, nothing to do + + // Swap as they come packed in big endian + return or_(zext(i32_ty, v0), shl(zext(i32_ty, v1), i32_val(16))); + }; + + auto fp4x8ToBf16x2 = [&loc, &rewriter, &fp4x2ToBf16x2]( + Value v) -> llvm::SmallVector { + llvm::SmallVector results(4); + for (int i = 0; i < 4; ++i) { + auto v_i = trunc(i8_ty, lshr(v, i32_val(8 * i))); + results[i] = fp4x2ToBf16x2(v_i); + } + return results; + }; + + // Split fp4x8 into 4 bf16x2 + llvm::SmallVector ret; + ret.reserve(vals.size() * 4); + for (int i = 0; i < vals.size(); ++i) { + auto vs = fp4x8ToBf16x2(vals[i]); + assert(vs.size() == 4); + for (auto v : vs) { + ret.push_back(v); + } + } + + return ret; + } + + LogicalResult + matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + auto tyX = cast(op->getOperandTypes()[0]); + auto operands = adaptor.getOperands(); + + auto xVals = unpackLLElements(loc, operands[0], rewriter); + auto scaleVals = unpackLLElements(loc, operands[1], rewriter); + auto fpType = op.getFpType(); + + Value tid = tid_val(); + auto mod = op->getParentOfType(); + Value warpSize = + i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)); + Value warpId = udiv(tid, warpSize); + Value laneId = urem(tid, warpSize); + + if (fpType == F8F6F4Type::E2M1) { + xVals = unpackFP4Elements(loc, rewriter, xVals, laneId); + } + + auto scaleBf16x2 = [&loc, &rewriter](Value v, Value s) -> Value { + // Split bf16x2 into 2 bf16, scale each of them, and pack them back + // TODO Is it true that the bfloats are always packed as bf16x2? + auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty); + auto bf16_1 = bitcast(trunc(i16_ty, lshr(v, i32_val(16))), bf16_ty); + auto scaleIsNan = icmp_eq(s, i8_val(0xff)); + auto scaleBf16 = bitcast(shl(zext(i16_ty, s), i16_val(7)), bf16_ty); + auto scaledBf16_0 = fmul(bf16_0, scaleBf16); + auto scaledBf16_1 = fmul(bf16_1, scaleBf16); + auto i16_0 = bitcast(scaledBf16_0, i16_ty); + auto i16_1 = bitcast(scaledBf16_1, i16_ty); + auto packed = + or_(zext(i32_ty, i16_0), shl(zext(i32_ty, i16_1), i32_val(16))); + // Account for NaN in the scale as per the mxfp specification + auto packed_nan = select(scaleIsNan, i32_val(0x7fff7fff), packed); + return packed_nan; + }; + + // Each thread owns elements of 4 mxfp vectors so we need 4 scales + // Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c + + // 16, c + 17 + auto c = mul(udiv(laneId, i32_val(4)), i32_val(2)); + std::array ci = {c, add(c, i32_val(1)), add(c, i32_val(16)), + add(c, i32_val(17))}; + + for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { + // column major as per the DotOperandEncoding(opidx=0) layout + auto si = std::array{ + targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[0]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[2]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[1]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[3]), + }; + + for (int j = 0; j < 16; ++j) { + xVals[16 * i + j] = scaleBf16x2(xVals[16 * i + j], si[j / 4]); + } + } + + Value result = + packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType()); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // anonymous namespace + +void mlir::triton::NVIDIA::populateUpcastMXFPToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfo &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp index 37c5b6ec7..4963a13b7 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp @@ -8,9 +8,8 @@ namespace LLVM { namespace NVIDIA { using namespace mlir::triton; -static Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter, - Value val, Value i, NVVM::ShflKind mode, - Value clamp) { +static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, + Value i, NVVM::ShflKind mode, Value clamp) { unsigned bits = val.getType().getIntOrFloatBitWidth(); if (bits == 64) { @@ -42,31 +41,27 @@ static Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter, return result; } -Value shuffleXor(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleCommon(loc, rewriter, val, i32_val(i), NVVM::ShflKind::bfly, i32_val(0x1f)); } -Value shuffleUp(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleCommon(loc, rewriter, val, i32_val(i), NVVM::ShflKind::up, i32_val(0x0)); } -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleIdx(loc, rewriter, val, i32_val(i)); } -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - Value i) { +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) { return shuffleCommon(loc, rewriter, val, i, NVVM::ShflKind::idx, i32_val(0x1f)); } -Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, - ModuleOp moduleOp, int axis) { +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis) { assert(axis >= 0); assert(axis < 3); assert(moduleOp); @@ -92,8 +87,8 @@ Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr) { return val; } -Value permute(Location loc, ConversionPatternRewriter &rewriter, Value a, - Value b, Value mask) { +Value permute(Location loc, RewriterBase &rewriter, Value a, Value b, + Value mask) { PTXBuilder builder; auto &prmt = builder.create("prmt")->o("b32"); auto *destOpr = builder.newOperand("=r"); @@ -104,94 +99,8 @@ Value permute(Location loc, ConversionPatternRewriter &rewriter, Value a, return builder.launch(rewriter, loc, rewriter.getIntegerType(32), false); } -// A wrapper of LoadDSmemOp when vec = 1 -// (1) Get bitwidth from elemTy -// (2) Create LoadDSmemOp -// (3) Bitcast result from dataTy (u16/u32/u64) back to elemTy -Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Type elemTy) { - assert(isa(addr.getType()) && "addr must be a pointer type"); - auto ptrTy = cast(addr.getType()); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); - unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); - Value ret = - rewriter.create(loc, addr, ctaId, bitwidth); - return bitcast(ret, elemTy); -} - -// A wrapper of LoadDSmemOp when vec > 1 -// (1) Get bitwidth from elemTy -// (2) Create LoadDSmemOp and extract results from retStruct -// (3) Bitcast results from dataTy (u16/u32/u64) back to elemTy -SmallVector createLoadDSmem(Location loc, PatternRewriter &rewriter, - Value addr, Value ctaId, unsigned vec, - Type elemTy) { - assert(isa(addr.getType()) && "addr must be a pointer type"); - auto ptrTy = cast(addr.getType()); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); - unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); - Value retStruct = rewriter.create( - loc, addr, ctaId, bitwidth, vec); - SmallVector retVals; - for (unsigned i = 0; i < vec; ++i) { - auto dataTy = rewriter.getIntegerType(bitwidth); - Value data = extract_val(dataTy, retStruct, i); - retVals.push_back(bitcast(data, elemTy)); - } - return retVals; -} - -// A wrapper of StoreDSmemOp when vec = 1 -// (1) Get bitwidth from elemTy -// (2) Bitcast value from elemTy to dataTy (u16/u32/u64) -// (3) Create StoreDSmemOp -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Value value, Value pred) { - assert(isa(addr.getType()) && "addr must be a pointer type"); - auto ptrTy = cast(addr.getType()); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); - unsigned bitwidth = value.getType().getIntOrFloatBitWidth(); - auto dataTy = rewriter.getIntegerType(bitwidth); - Value data = bitcast(value, dataTy); - rewriter.create(loc, addr, ctaId, data, pred); -} - -// A wrapper of StoreDSmemOp when vec = 1 and pred = 1 -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Value value) { - Value pred = int_val(/*width=*/1, 1); - createStoreDSmem(loc, rewriter, addr, ctaId, value, pred); -} - -// A wrapper of StoreDSmemOp when vec > 1 -// (1) Get bitwidth from elemTy -// (2) Bitcast values from elemTy to dataTy (u16/u32/u64) -// (3) Create StoreDSmemOp -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, ArrayRef values, Value pred) { - assert(isa(addr.getType()) && "addr must be a pointer type"); - auto ptrTy = cast(addr.getType()); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); - unsigned bitwidth = 0; - if (!values.empty()) { - bitwidth = values.back().getType().getIntOrFloatBitWidth(); - } - auto dataTy = rewriter.getIntegerType(bitwidth); - SmallVector data; - for (unsigned i = 0; i < values.size(); ++i) - data.push_back(bitcast(values[i], dataTy)); - rewriter.create(loc, addr, ctaId, data, pred); -} - -// A wrapper of StoreDSmemOp when vec > 1 and pred = 1 -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, ArrayRef values) { - Value pred = int_val(/*width=*/1, 1); - createStoreDSmem(loc, rewriter, addr, ctaId, values, pred); -} - /// Create a predicate with just single active thread. -Value createElectPredicate(Location loc, PatternRewriter &rewriter) { +Value createElectPredicate(Location loc, RewriterBase &rewriter) { PTXBuilder ptxBuilder; auto &elect = *ptxBuilder.create<>("elect.sync _|$0, 0xffffffff;"); elect({ptxBuilder.newOperand("=b")}, /*onlyAttachMLIRArgs=*/true); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h index bb4e9dd33..12344bfb5 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h @@ -29,56 +29,24 @@ using namespace mlir::triton; ptxBuilder.launch(rewriter, op->getLoc(), voidTy); \ } while (0) -#define load_dsmem(...) \ - ::mlir::LLVM::NVIDIA::createLoadDSmem(loc, rewriter, __VA_ARGS__) -#define store_dsmem(...) \ - ::mlir::LLVM::NVIDIA::createStoreDSmem(loc, rewriter, __VA_ARGS__) - namespace mlir { namespace LLVM { namespace NVIDIA { Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr); -Value shuffleXor(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleUp(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - Value i); -Value permute(Location loc, ConversionPatternRewriter &rewriter, Value a, - Value b, Value mask); - -Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, - ModuleOp moduleOp, int axis); - -/// Usage of macro load_dsmem -/// (1) load_dsmem(addr, ctaId) -/// (2) load_dsmem(addr, ctaId, vec) -Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Type elemTy); -SmallVector createLoadDSmem(Location loc, PatternRewriter &rewriter, - Value addr, Value ctaId, unsigned vec, - Type elemTy); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i); +Value permute(Location loc, RewriterBase &rewriter, Value a, Value b, + Value mask); -/// Usage of macro store_dsmem -/// (1) store_dsmem(addr, ctaId, value, pred) -/// (2) store_dsmem(addr, ctaId, value) -/// (3) store_dsmem(addr, ctaId, values, pred) -/// (4) store_dsmem(addr, ctaId, values) -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Value value, Value pred); -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Value value); -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, ArrayRef values, Value pred); -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, ArrayRef values); +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis); /// Create a predicate with just single active thread. -Value createElectPredicate(Location loc, PatternRewriter &rewriter); +Value createElectPredicate(Location loc, RewriterBase &rewriter); } // namespace NVIDIA } // namespace LLVM diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc index 97e491cdc..a7a036401 100644 --- a/third_party/nvidia/triton_nvidia.cc +++ b/third_party/nvidia/triton_nvidia.cc @@ -1,6 +1,7 @@ -#include "Dialect/NVGPU/IR/Dialect.h" +#include "Dialect/NVGPU/IR/Dialect.h" #include "NVGPUToLLVM/NVGPUToLLVMPass.h" #include "TritonNVIDIAGPUToLLVM/Passes.h" +#include "cublas_instance.h" #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "passes.h" @@ -17,9 +18,11 @@ void init_triton_nvidia_passes_ttgpuir(py::module &&m) { using namespace mlir::triton; // TODO: it is weird to pass mlir::triton::NVVM here since the conversion is // nvidia-specificontext - m.def("add_to_llvmir", [](mlir::PassManager &pm, int32_t capability) { - pm.addPass(mlir::triton::createConvertTritonGPUToLLVMPass(capability)); - }); + m.def("add_to_llvmir", + [](mlir::PassManager &pm, int32_t capability, int32_t ptxVersion) { + pm.addPass(mlir::triton::createConvertTritonGPUToLLVMPass( + capability, ptxVersion)); + }); m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm) { pm.addPass(NVIDIA::createDecomposeUnsupportedConversionsPass()); }); @@ -82,4 +85,79 @@ void init_triton_nvidia(py::module &&m) { auto *reflect = MDNode::get(ctx, {mdFour, mdName, mdOne}); mod->addModuleFlag(reflect); }); + + // cublas + auto cublas = m.def_submodule("cublas"); + + py::class_(cublas, "CublasLt") + .def(py::init<>([&](py::object &workspace) { + auto wrk_ptr = workspace.attr("data_ptr")().cast(); + auto wrk_size = workspace.attr("numel")().cast() * + workspace.attr("element_size")().cast(); + return new CublasLtInstance(wrk_ptr, wrk_size); + })) + .def("matmul", [](CublasLtInstance &self, py::object &A, py::object &B, + py::object &C) { + auto A_ptr = A.attr("data_ptr")().cast(); + auto B_ptr = B.attr("data_ptr")().cast(); + auto C_ptr = C.attr("data_ptr")().cast(); + + auto A_shape = A.attr("shape").cast>(); + auto B_shape = B.attr("shape").cast>(); + auto C_shape = C.attr("shape").cast>(); + + auto A_dtype = A.attr("dtype").attr("__str__")().cast(); + auto B_dtype = B.attr("dtype").attr("__str__")().cast(); + auto C_dtype = C.attr("dtype").attr("__str__")().cast(); + + assert(A_dtype == B_dtype && A_dtype == C_dtype); + assert(A_dtype == "torch.float8_e4m3fn" || A_dtype == "torch.float16"); + + std::string dtype_str = A_dtype.substr(A_dtype.find_last_of('.') + 1); + cudaDataType_t dtype; + if (dtype_str == "float8_e4m3fn") { + dtype = CUDA_R_8F_E4M3; + } else if (dtype_str == "float16") { + dtype = CUDA_R_16F; + } + + if (A_shape.size() != 2 || B_shape.size() != 2 || C_shape.size() != 2) { + throw std::runtime_error("Only 2D matrices are supported."); + } + + int k = A_shape[1]; + if (k != B_shape[1]) { + throw std::runtime_error("Matrix dimensions do not match. A is [" + + std::to_string(A_shape[0]) + ", " + + std::to_string(A_shape[1]) + "], B is [" + + std::to_string(B_shape[0]) + ", " + + std::to_string(B_shape[1]) + + "]. Expected A.shape[1] == B.shape[1]. Note " + "that B needs to be transposed."); + } + + int m = A_shape[0]; + if (m != C_shape[0]) { + throw std::runtime_error("Matrix dimensions do not match. A is [" + + std::to_string(A_shape[0]) + ", " + + std::to_string(A_shape[1]) + "], C is [" + + std::to_string(C_shape[0]) + ", " + + std::to_string(C_shape[1]) + + "]. Expected A.shape[0] == C.shape[0]."); + } + + int n = B_shape[0]; + if (n != C_shape[1]) { + throw std::runtime_error("Matrix dimensions do not match. B is [" + + std::to_string(B_shape[0]) + ", " + + std::to_string(B_shape[1]) + "], C is [" + + std::to_string(C_shape[0]) + ", " + + std::to_string(C_shape[1]) + + "]. Expected B.shape[0] == C.shape[1]. Note " + "that B needs to be transposed."); + } + + self.matmul(A_shape[0], B_shape[0], A_shape[1], A_ptr, B_ptr, C_ptr, + dtype); + }); } diff --git a/third_party/nvidia/unittest/CMakeLists.txt b/third_party/nvidia/unittest/CMakeLists.txt new file mode 100644 index 000000000..bd3c0c6c0 --- /dev/null +++ b/third_party/nvidia/unittest/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Conversion) diff --git a/unittest/Conversion/CMakeLists.txt b/third_party/nvidia/unittest/Conversion/CMakeLists.txt similarity index 100% rename from unittest/Conversion/CMakeLists.txt rename to third_party/nvidia/unittest/Conversion/CMakeLists.txt diff --git a/third_party/nvidia/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt b/third_party/nvidia/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..5d2dbbb0b --- /dev/null +++ b/third_party/nvidia/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,5 @@ +add_triton_ut( + NAME TestPtxAsmFormat + SRCS PTXAsmFormatTest.cpp + LIBS TritonGPUToLLVM TritonNVIDIAGPUToLLVM +) diff --git a/unittest/Conversion/TritonGPUToLLVM/PTXAsmFormatTest.cpp b/third_party/nvidia/unittest/Conversion/TritonGPUToLLVM/PTXAsmFormatTest.cpp similarity index 95% rename from unittest/Conversion/TritonGPUToLLVM/PTXAsmFormatTest.cpp rename to third_party/nvidia/unittest/Conversion/TritonGPUToLLVM/PTXAsmFormatTest.cpp index 774335a06..4fd6cefbd 100644 --- a/unittest/Conversion/TritonGPUToLLVM/PTXAsmFormatTest.cpp +++ b/third_party/nvidia/unittest/Conversion/TritonGPUToLLVM/PTXAsmFormatTest.cpp @@ -2,6 +2,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/Support/Signals.h" #include @@ -145,3 +146,9 @@ TEST_F(PTXAsmFormatTest, onlyAttachMLIRArgs) { } // namespace triton } // namespace mlir + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/proton/CMakeLists.txt b/third_party/proton/CMakeLists.txt index 1518f2f20..e2d9152c9 100644 --- a/third_party/proton/CMakeLists.txt +++ b/third_party/proton/CMakeLists.txt @@ -35,7 +35,14 @@ endif() # Check if the platform is MacOS if(APPLE) - set(PROTON_PYTHON_LDFLAGS "-undefined dynamic_lookup -flto") + set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") + # Other platforms build with -flto, but we found that this adds significant overhead to our macos CI without providing a major benefit. + set(PROTON_PYTHON_LDFLAGS "-undefined dynamic_lookup") +endif() + +if(DEFINED CUPTI_LIB_DIR) + message(STATUS "CUPTI lib directory: ${CUPTI_LIB_DIR}") + add_compile_definitions(CUPTI_LIB_DIR=${CUPTI_LIB_DIR}) endif() include_directories(${CUPTI_INCLUDE_DIR}) diff --git a/third_party/proton/README.md b/third_party/proton/README.md index 054d44936..fede11ced 100644 --- a/third_party/proton/README.md +++ b/third_party/proton/README.md @@ -14,13 +14,12 @@ cd triton/python pip install . ``` -To not build Proton, you can set the `TRITON_BUILD_PROTON` environment variable to `OFF`: +To **not build** Proton, you can set the `TRITON_BUILD_PROTON` environment variable to `OFF`: ```bash TRITON_BUILD_PROTON=OFF pip install . ``` - ## Usage ### Basic usage @@ -120,7 +119,7 @@ flops64: float # The number of 64-bit floating-point operations bytes: int # The number of bytes expected to be transferred ``` -### Command Line +### Command line Proton can be used as a command-line tool to profile Python scripts and Pytest tests. The following examples demonstrate how to use Proton command-line. @@ -138,11 +137,79 @@ When profiling in the command line mode, the `proton.start` and `proton.finalize By default, proton profiles are in the *json* format and can be read by *Hatchet*. The following command visualizes the profile data on terminal. ```bash +pip install llnl-hatchet proton-viewer -m time/s ``` +NOTE: `pip install hatchet` does not work because the API is slightly different. + More options can be found by running the following command. ```bash proton-viewer -h ``` + +### Instruction sampling (experimental) + +Proton supports instruction sampling on NVIDIA GPUs. +Please note that this is an experimental feature and may not work on all GPUs. +You may experience ~20x end-to-end overhead when using instruction sampling, although the overhead for each individual GPU kernel is negligible. +The overhead is mostly caused by data transfer and processing on the CPU. +Additionally, the proton-viewer options `-i -d -t ` can be helpful for filtering out GPU kernels that are not of interest. +The following example demonstrates how to use instruction sampling: + +```python +import triton.profiler as proton + +proton.start(name="profile_name", context="shadow", backend="cupti_pcsampling") +``` + +## Proton *vs* nsys + +- Runtime overhead (up to 1.5x) + +Proton has a lower profiling overhead than nsys. Even for workload with a large number of small GPU kernels, proton triggers less than ~1.5x overhead. + +For GPU-bound workload, both proton and nsys has similar overhead, with little impact on the workload. + +The lower overhead of proton is due to its less profiling metrics and callbacks compared to nsys. + +- Profile size (significantly smaller than nsys) + +nsys traces and records every GPU kernel, while proton aggregates the metrics of GPU kernels under the same calling context. + +As a result, proton's profile size can be up to thousands of times smaller than nsys's profile size, depending on the running time. + +- Portability (support different GPUs) + +Proton is designed to be portable and can be used on AMD GPUs. nsys only supports NVIDIA GPUs. + +- Insights (more insightful than nsys on triton kernels) + +Proton can register hooks to analyze the metadata of triton kernels, while nsys cannot. **Note** that the hooks do add additional overhead to proton. + +## Proton *vs* ncu + +Similar to the comparison between Proton and Nsight Systems (Nsys), Proton has a lower profiling overhead than Nsight Compute (NCU). We also plan to support instruction sampling on AMD GPUs. +However, Nsight Compute supports the collection of more detailed metrics than Proton, such as memory access patterns, memory transactions, and other instruction-level metrics. +In contrast, Proton only supports instruction sampling and is designed to be lightweight and portable. + +## Known issues + +- CUDA graph + +`hooks` cannot be used to accurately accumulate the number of FLOPs in CUDA graph mode profiling because kernels are captured and launched separately; metrics are not accumulated when kernels are launched in graph mode. This issue can be circumvented by using `scope` to supply FLOPs. + +If profiling is initiated after CUDA graph capturing, there may be minor memory leak issues. +This is because the number of kernels in a graph instance (i.e., `cuGraphExec`) is unknown, preventing the deletion of mappings between the kernel ID and the graph ID. + +- Instruction sampling + +If you encounter permission related problems when using instruction sampling, you can lookup this [page](https://developer.nvidia.com/nvidia-development-tools-solutions-err_nvgpuctrperm-permission-issue-performance-counters) for help. + +The overhead of instruction sampling on NVIDIA GPUs is about 20x using Proton because we haven't enabled continuous sampling yet. +Continuous sampling can allow for more runtime optimizations, but it makes it more challenging to attribute performance data back to the GPU kernels because: (1) it enables profiling of concurrent kernels, (2) it doesn't allow profiling of time and instruction samples simultaneously, and (3) it works best if we have a separate thread dedicated to attributing instruction samples to the GPU kernels + +- Visible devices on AMD GPUs + +Environment variables such as `HIP_VISIBLE_DEVICES`, and `CUDA_VISIBLE_DEVICES` are not supported on AMD GPUs. Once it's set, we cannot find a valid mapping between the device ID returned by RocTracer and the physical device ID. Instead, `ROCR_VISIBLE_DEVICES` is recommended to be used. diff --git a/third_party/proton/csrc/include/Context/Context.h b/third_party/proton/csrc/include/Context/Context.h index 19defc7e1..9b1205f81 100644 --- a/third_party/proton/csrc/include/Context/Context.h +++ b/third_party/proton/csrc/include/Context/Context.h @@ -88,8 +88,8 @@ class OpInterface { if (isOpInProgress()) { return; } - setOpInProgress(true); startOp(scope); + setOpInProgress(true); } void exitOp(const Scope &scope) { if (!isOpInProgress()) { @@ -106,11 +106,7 @@ class OpInterface { virtual void setOpInProgress(bool value) = 0; }; -/// Internal op interface is used for objects that do not internally generate -/// new ops. For example, the TreeData object and the TraceData object do not -/// generate new ops. In contrast, the CuptiProfiler object may contribute to -/// new ops not trackable by the user. -class InternalOpInterface : public OpInterface { +class ThreadLocalOpInterface : public OpInterface { public: using OpInterface::OpInterface; @@ -124,7 +120,7 @@ class InternalOpInterface : public OpInterface { private: inline static const int MAX_CACHE_OBJECTS = 10; - static thread_local std::map opInProgress; + static thread_local std::map opInProgress; }; } // namespace proton diff --git a/third_party/proton/csrc/include/Data/Data.h b/third_party/proton/csrc/include/Data/Data.h index 2d49da5bc..edb6f02b6 100644 --- a/third_party/proton/csrc/include/Data/Data.h +++ b/third_party/proton/csrc/include/Data/Data.h @@ -12,12 +12,17 @@ namespace proton { enum class OutputFormat { Hatchet, Count }; -class Data : public InternalOpInterface { +class Data : public ThreadLocalOpInterface { public: Data(const std::string &path, ContextSource *contextSource = nullptr) : path(path), contextSource(contextSource) {} virtual ~Data() = default; + /// Add a new scope to the data. + /// If the scope is already present, add a child scope under/inside it. + /// [MT] The implementation must be thread-safe. + virtual size_t addScope(size_t scopeId, const std::string &name = {}) = 0; + /// Add a single metric to the data. /// [MT] The implementation must be thread-safe. virtual void addMetric(size_t scopeId, std::shared_ptr metric) = 0; @@ -37,9 +42,7 @@ class Data : public InternalOpInterface { /// [MT] Thread-safe. virtual void doDump(std::ostream &os, OutputFormat outputFormat) const = 0; -protected: mutable std::shared_mutex mutex; - const std::string path{}; ContextSource *contextSource{}; }; diff --git a/third_party/proton/csrc/include/Data/Metric.h b/third_party/proton/csrc/include/Data/Metric.h index 7e9e9f97c..a75692877 100644 --- a/third_party/proton/csrc/include/Data/Metric.h +++ b/third_party/proton/csrc/include/Data/Metric.h @@ -2,13 +2,12 @@ #define PROTON_DATA_METRIC_H_ #include "Utility/Traits.h" -#include #include #include namespace proton { -enum class MetricKind { Flexible, Kernel, Count }; +enum class MetricKind { Flexible, Kernel, PCSampling, Count }; using MetricValueType = std::variant; @@ -141,11 +140,81 @@ class KernelMetric : public Metric { virtual bool isAggregable(int valueId) const { return AGGREGABLE[valueId]; } private: - const static inline bool AGGREGABLE[kernelMetricKind::Count] = {false, false, - true, true}; + const static inline bool AGGREGABLE[kernelMetricKind::Count] = { + false, false, true, true, false, false}; const static inline std::string VALUE_NAMES[kernelMetricKind::Count] = { - "StartTime (ns)", "EndTime (ns)", "Count", - "Time (ns)", "DeviceId", "DeviceType", + "start_time (ns)", "end_time (ns)", "count", + "time (ns)", "device_id", "device_type", + }; +}; + +class PCSamplingMetric : public Metric { +public: + enum PCSamplingMetricKind : int { + NumSamples, + NumStalledSamples, + StalledBranchResolving, + StalledNoInstruction, + StalledShortScoreboard, + StalledWait, + StalledLongScoreboard, + StalledTexThrottle, + StalledBarrier, + StalledMembar, + StalledIMCMiss, + StalledMIOThrottle, + StalledMathPipeThrottle, + StalledDrain, + StalledLGThrottle, + StalledNotSelected, + StalledMisc, + StalledDispatchStall, + StalledSleeping, + StalledSelected, + Count, + }; + + PCSamplingMetric() + : Metric(MetricKind::PCSampling, PCSamplingMetricKind::Count) {} + + PCSamplingMetric(PCSamplingMetricKind kind, uint64_t samples, + uint64_t stalledSamples) + : PCSamplingMetric() { + this->values[kind] = stalledSamples; + this->values[PCSamplingMetricKind::NumSamples] = samples; + this->values[PCSamplingMetricKind::NumStalledSamples] = stalledSamples; + } + + virtual const std::string getName() const { return "PCSamplingMetric"; } + + virtual const std::string getValueName(int valueId) const { + return VALUE_NAMES[valueId]; + } + + virtual bool isAggregable(int valueId) const { return true; } + +private: + const static inline std::string VALUE_NAMES[PCSamplingMetricKind::Count] = { + "num_samples", + "num_stalled_samples", + "stalled_branch_resolving", + "stalled_no_instruction", + "stalled_short_scoreboard", + "stalled_wait", + "stalled_long_scoreboard", + "stalled_tex_throttle", + "stalled_barrier", + "stalled_membar", + "stalled_imc_miss", + "stalled_mio_throttle", + "stalled_math_pipe_throttle", + "stalled_drain", + "stalled_lg_throttle", + "stalled_not_Selected", + "stalled_misc", + "stalled_dispatch_stall", + "stalled_sleeping", + "stalled_selected", }; }; diff --git a/third_party/proton/csrc/include/Data/TraceData.h b/third_party/proton/csrc/include/Data/TraceData.h index 349274280..c434f6213 100644 --- a/third_party/proton/csrc/include/Data/TraceData.h +++ b/third_party/proton/csrc/include/Data/TraceData.h @@ -8,6 +8,9 @@ namespace proton { class TraceData : public Data { public: using Data::Data; + virtual ~TraceData() = default; + + size_t addScope(size_t scopeId, const std::string &name) override; void addMetric(size_t scopeId, std::shared_ptr metric) override; diff --git a/third_party/proton/csrc/include/Data/TreeData.h b/third_party/proton/csrc/include/Data/TreeData.h index ca935ac3b..0250f2647 100644 --- a/third_party/proton/csrc/include/Data/TreeData.h +++ b/third_party/proton/csrc/include/Data/TreeData.h @@ -4,18 +4,19 @@ #include "Context/Context.h" #include "Data.h" #include +#include namespace proton { class TreeData : public Data { public: - TreeData(const std::string &path, ContextSource *contextSource) - : Data(path, contextSource) { - init(); - } + TreeData(const std::string &path, ContextSource *contextSource); + virtual ~TreeData(); TreeData(const std::string &path) : TreeData(path, nullptr) {} + size_t addScope(size_t scopeId, const std::string &name) override; + void addMetric(size_t scopeId, std::shared_ptr metric) override; void addMetrics(size_t scopeId, @@ -29,103 +30,14 @@ class TreeData : public Data { void stopOp(const Scope &scope) override; private: - class Tree { - public: - struct TreeNode : public Context { - inline static const size_t RootId = 0; - inline static const size_t DummyId = std::numeric_limits::max(); - - TreeNode() = default; - explicit TreeNode(size_t id, const std::string &name) - : id(id), Context(name) {} - TreeNode(size_t id, size_t parentId, const std::string &name) - : id(id), parentId(parentId), Context(name) {} - virtual ~TreeNode() = default; - - void addChild(const Context &context, size_t id) { - children[context] = id; - } - - bool hasChild(const Context &context) const { - return children.find(context) != children.end(); - } - - size_t getChild(const Context &context) const { - return children.at(context); - } - - size_t parentId = DummyId; - size_t id = DummyId; - std::map children = {}; - std::map> metrics = {}; - std::map flexibleMetrics = {}; - friend class Tree; - }; - - Tree() { - treeNodeMap.try_emplace(TreeNode::RootId, TreeNode::RootId, "ROOT"); - } - - size_t addNode(const Context &context, size_t parentId) { - if (treeNodeMap[parentId].hasChild(context)) { - return treeNodeMap[parentId].getChild(context); - } - auto id = nextContextId++; - treeNodeMap.try_emplace(id, id, parentId, context.name); - treeNodeMap[parentId].addChild(context, id); - return id; - } - - size_t addNode(const std::vector &indices) { - if (indices.empty()) { - throw std::runtime_error("Indices is empty"); - } - auto parentId = TreeNode::RootId; - for (auto index : indices) { - parentId = addNode(index, parentId); - } - return parentId; - } - - TreeNode &getNode(size_t id) { return treeNodeMap.at(id); } - - enum class WalkPolicy { PreOrder, PostOrder }; - - template void walk(FnT &&fn) { - if constexpr (walkPolicy == WalkPolicy::PreOrder) { - walkPreOrder(TreeNode::RootId, fn); - } else if constexpr (walkPolicy == WalkPolicy::PostOrder) { - walkPostOrder(TreeNode::RootId, fn); - } - } - - template void walkPreOrder(size_t contextId, FnT &&fn) { - fn(getNode(contextId)); - for (auto &child : getNode(contextId).children) { - walkPreOrder(child.second, fn); - } - } - - template void walkPostOrder(size_t contextId, FnT &&fn) { - for (auto &child : getNode(contextId).children) { - walkPostOrder(child.second, fn); - } - fn(getNode(contextId)); - } - - private: - size_t nextContextId = TreeNode::RootId + 1; - // tree node id->tree node - std::map treeNodeMap; - }; - void init(); void dumpHatchet(std::ostream &os) const; void doDump(std::ostream &os, OutputFormat outputFormat) const override; + class Tree; std::unique_ptr tree; // ScopeId -> ContextId - std::map scopeIdToContextId; + std::unordered_map scopeIdToContextId; }; } // namespace proton diff --git a/third_party/proton/csrc/include/Driver/Dispatch.h b/third_party/proton/csrc/include/Driver/Dispatch.h index be92423c2..1d8ec017c 100644 --- a/third_party/proton/csrc/include/Driver/Dispatch.h +++ b/third_party/proton/csrc/include/Driver/Dispatch.h @@ -44,8 +44,9 @@ namespace proton { struct ExternLibBase { using RetType = int; // Generic type, can be overridden in derived structs - static constexpr const char *name = ""; // Placeholder - static constexpr RetType success = 0; // Placeholder + static constexpr const char *name = ""; // Placeholder + static constexpr const char *defaultDir = ""; // Placeholder + static constexpr RetType success = 0; // Placeholder ExternLibBase() = delete; ExternLibBase(const ExternLibBase &) = delete; ExternLibBase &operator=(const ExternLibBase &) = delete; @@ -62,9 +63,17 @@ template class Dispatch { *lib = dlopen(name, RTLD_NOLOAD); } if (*lib == nullptr) { - // If not found, try to load it + // If not found, try to load it from LD_LIBRARY_PATH *lib = dlopen(name, RTLD_LOCAL | RTLD_LAZY); } + if (*lib == nullptr) { + // If still not found, try to load it from the default path + auto dir = std::string(ExternLib::defaultDir); + if (dir.length() > 0) { + auto fullPath = dir + "/" + name; + *lib = dlopen(fullPath.c_str(), RTLD_LOCAL | RTLD_LAZY); + } + } if (*lib == nullptr) { throw std::runtime_error("Could not find `" + std::string(name) + "`. Make sure it is in your " diff --git a/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h b/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h index 44c7cf3dd..495964923 100644 --- a/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h +++ b/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h @@ -2,11 +2,17 @@ #define PROTON_DRIVER_GPU_CUPTI_H_ #include "cupti.h" +#include "cupti_pcsampling.h" namespace proton { namespace cupti { +template CUptiResult getVersion(uint32_t *version); + +template +CUptiResult getContextId(CUcontext context, uint32_t *pCtxId); + template CUptiResult activityRegisterCallbacks( CUpti_BuffersCallbackRequestFunc funcBufferRequested, @@ -20,6 +26,10 @@ template CUptiResult enableDomain(uint32_t enable, CUpti_SubscriberHandle subscriber, CUpti_CallbackDomain domain); +template +CUptiResult enableCallback(uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain, CUpti_CallbackId cbid); + template CUptiResult activityEnableContext(CUcontext context, CUpti_ActivityKind kind); @@ -56,6 +66,46 @@ CUptiResult unsubscribe(CUpti_SubscriberHandle subscriber); template CUptiResult finalize(); +template +CUptiResult getGraphExecId(CUgraphExec graph, uint32_t *pId); + +template +CUptiResult getGraphId(CUgraph graph, uint32_t *pId); + +template +CUptiResult getCubinCrc(CUpti_GetCubinCrcParams *pParams); + +template +CUptiResult +getSassToSourceCorrelation(CUpti_GetSassToSourceCorrelationParams *pParams); + +template +CUptiResult +pcSamplingGetNumStallReasons(CUpti_PCSamplingGetNumStallReasonsParams *pParams); + +template +CUptiResult +pcSamplingGetStallReasons(CUpti_PCSamplingGetStallReasonsParams *pParams); + +template +CUptiResult pcSamplingSetConfigurationAttribute( + CUpti_PCSamplingConfigurationInfoParams *pParams); + +template +CUptiResult pcSamplingEnable(CUpti_PCSamplingEnableParams *pParams); + +template +CUptiResult pcSamplingDisable(CUpti_PCSamplingDisableParams *pParams); + +template +CUptiResult pcSamplingGetData(CUpti_PCSamplingGetDataParams *pParams); + +template +CUptiResult pcSamplingStart(CUpti_PCSamplingStartParams *pParams); + +template +CUptiResult pcSamplingStop(CUpti_PCSamplingStopParams *pParams); + } // namespace cupti } // namespace proton diff --git a/third_party/proton/csrc/include/Profiler/Cupti/CuptiPCSampling.h b/third_party/proton/csrc/include/Profiler/Cupti/CuptiPCSampling.h new file mode 100644 index 000000000..58b6e2be8 --- /dev/null +++ b/third_party/proton/csrc/include/Profiler/Cupti/CuptiPCSampling.h @@ -0,0 +1,141 @@ +#ifndef PROTON_PROFILER_CUPTI_PC_SAMPLING_H_ +#define PROTON_PROFILER_CUPTI_PC_SAMPLING_H_ + +#include "CuptiProfiler.h" +#include "Driver/GPU/CudaApi.h" +#include "Driver/GPU/CuptiApi.h" +#include "Utility/Map.h" +#include "Utility/Singleton.h" +#include +#include + +namespace proton { + +struct CubinData { + size_t cubinCrc; + const char *cubin; + size_t cubinSize; + + struct LineInfoKey { + uint32_t functionIndex; + uint64_t pcOffset; + + bool operator<(const LineInfoKey &other) const { + return functionIndex < other.functionIndex || + (functionIndex == other.functionIndex && + pcOffset < other.pcOffset); + } + }; + + struct LineInfoValue { + uint32_t lineNumber{}; + const std::string functionName{}; + const std::string dirName{}; + const std::string fileName{}; + + LineInfoValue() = default; + + LineInfoValue(uint32_t lineNumber, const std::string &functionName, + const std::string &dirName, const std::string &fileName) + : lineNumber(lineNumber), functionName(functionName), dirName(dirName), + fileName(fileName) {} + }; + + std::map lineInfo; +}; + +struct ConfigureData { + ConfigureData() = default; + + ~ConfigureData() { + if (stallReasonNames) { + for (size_t i = 0; i < numStallReasons; i++) { + if (stallReasonNames[i]) + std::free(stallReasonNames[i]); + } + std::free(stallReasonNames); + } + if (stallReasonIndices) + std::free(stallReasonIndices); + if (pcSamplingData.pPcData) { + for (size_t i = 0; i < numValidStallReasons; ++i) { + std::free(pcSamplingData.pPcData[i].stallReason); + } + std::free(pcSamplingData.pPcData); + } + } + + void initialize(CUcontext context); + + CUpti_PCSamplingConfigurationInfo configureStallReasons(); + CUpti_PCSamplingConfigurationInfo configureSamplingPeriod(); + CUpti_PCSamplingConfigurationInfo configureSamplingBuffer(); + CUpti_PCSamplingConfigurationInfo configureScratchBuffer(); + CUpti_PCSamplingConfigurationInfo configureHardwareBufferSize(); + CUpti_PCSamplingConfigurationInfo configureStartStopControl(); + CUpti_PCSamplingConfigurationInfo configureCollectionMode(); + + // The amount of data reserved on the GPU + static constexpr size_t HardwareBufferSize = 128 * 1024 * 1024; + // The amount of data copied from the hardware buffer each time + static constexpr size_t ScratchBufferSize = 16 * 1024 * 1024; + // The number of PCs copied from the scratch buffer each time + static constexpr size_t DataBufferPCCount = 1024; + // The sampling period in cycles = 2^frequency + static constexpr uint32_t DefaultFrequency = 10; + + CUcontext context{}; + uint32_t contextId; + uint32_t numStallReasons{}; + uint32_t numValidStallReasons{}; + char **stallReasonNames{}; + uint32_t *stallReasonIndices{}; + std::map stallReasonIndexToMetricIndex{}; + std::set notIssuedStallReasonIndices{}; + CUpti_PCSamplingData pcSamplingData{}; + // The memory storing configuration information has to be kept alive during + // the profiling session + std::vector configurationInfos; +}; + +class CuptiPCSampling : public Singleton { + +public: + CuptiPCSampling() = default; + virtual ~CuptiPCSampling() = default; + + void initialize(CUcontext context); + + void start(CUcontext context); + + void stop(CUcontext context, uint64_t externId, bool isAPI); + + void finalize(CUcontext context); + + void loadModule(const char *cubin, size_t cubinSize); + + void unloadModule(const char *cubin, size_t cubinSize); + +private: + ConfigureData *getConfigureData(uint32_t contextId); + + CubinData *getCubinData(uint64_t cubinCrc); + + void processPCSamplingData(ConfigureData *configureData, uint64_t externId, + bool isAPI); + + ThreadSafeMap contextIdToConfigureData; + // In case the same cubin is loaded multiple times, we need to keep track of + // all of them + ThreadSafeMap> + cubinCrcToCubinData; + ThreadSafeSet contextInitialized; + + std::atomic pcSamplingStarted{false}; + std::mutex pcSamplingMutex{}; + std::mutex contextMutex{}; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_CUPTI_PC_SAMPLING_H_ diff --git a/third_party/proton/csrc/include/Profiler/CuptiProfiler.h b/third_party/proton/csrc/include/Profiler/Cupti/CuptiProfiler.h similarity index 90% rename from third_party/proton/csrc/include/Profiler/CuptiProfiler.h rename to third_party/proton/csrc/include/Profiler/Cupti/CuptiProfiler.h index 344d0fd4b..c443ec2e3 100644 --- a/third_party/proton/csrc/include/Profiler/CuptiProfiler.h +++ b/third_party/proton/csrc/include/Profiler/Cupti/CuptiProfiler.h @@ -1,7 +1,7 @@ #ifndef PROTON_PROFILER_CUPTI_PROFILER_H_ #define PROTON_PROFILER_CUPTI_PROFILER_H_ -#include "GPUProfiler.h" +#include "Profiler/GPUProfiler.h" namespace proton { diff --git a/third_party/proton/csrc/include/Profiler/GPUProfiler.h b/third_party/proton/csrc/include/Profiler/GPUProfiler.h index c3c148658..d5033b06a 100644 --- a/third_party/proton/csrc/include/Profiler/GPUProfiler.h +++ b/third_party/proton/csrc/include/Profiler/GPUProfiler.h @@ -4,10 +4,14 @@ #include "Context/Context.h" #include "Profiler.h" #include "Utility/Atomic.h" -#include +#include "Utility/Map.h" +#include "Utility/Set.h" #include -#include +#include +#include +#include +#include namespace proton { @@ -15,58 +19,79 @@ namespace proton { // CuptiProfiler, should be a singleton. template class GPUProfiler : public Profiler, - public OpInterface, + public ThreadLocalOpInterface, public Singleton { public: GPUProfiler() = default; virtual ~GPUProfiler() = default; -protected: - // OpInterface - void startOp(const Scope &scope) override { pImpl->startOp(scope); } - void stopOp(const Scope &scope) override { pImpl->stopOp(scope); } + using CorrIdToExternIdMap = + ThreadSafeMap, /**/ + std::unordered_map>>; + using ApiExternIdSet = ThreadSafeSet>; - void setOpInProgress(bool value) override { - profilerState.isRecording = value; + ConcreteProfilerT &enablePCSampling() { + pcSamplingEnabled = true; + return dynamic_cast(*this); } + ConcreteProfilerT &disablePCSampling() { + pcSamplingEnabled = false; + return dynamic_cast(*this); + } + bool isPCSamplingEnabled() const { return pcSamplingEnabled; } - bool isOpInProgress() override { return profilerState.isRecording; } +protected: + // OpInterface + void startOp(const Scope &scope) override { + this->correlation.pushExternId(scope.scopeId); + } + void stopOp(const Scope &scope) override { this->correlation.popExternId(); } // Profiler virtual void doStart() override { pImpl->doStart(); } virtual void doFlush() override { pImpl->doFlush(); } virtual void doStop() override { pImpl->doStop(); } - struct ProfilerState { + struct ThreadState { ConcreteProfilerT &profiler; - std::set dataSet; - bool isRecording{false}; - Scope scope{}; - ProfilerState(ConcreteProfilerT &profiler) : profiler(profiler) {} + ThreadState(ConcreteProfilerT &profiler) : profiler(profiler) {} - void record(const Scope &scope) { - this->scope = scope; - // Take a snapshot of the current dataset - this->dataSet = profiler.getDataSet(); + void record(size_t scopeId) { + if (profiler.isOpInProgress()) + return; + std::set dataSet = profiler.getDataSet(); + for (auto data : dataSet) + data->addScope(scopeId); + profiler.correlation.apiExternIds.insert(scopeId); } - void enterOp() { - profiler.enterOp(scope); - for (auto data : dataSet) - data->enterOp(scope); + void enterOp(size_t scopeId) { + if (profiler.isOpInProgress()) + return; + profiler.correlation.pushExternId(scopeId); + profiler.setOpInProgress(true); } void exitOp() { - profiler.exitOp(scope); - for (auto data : dataSet) - data->exitOp(this->scope); + if (!profiler.isOpInProgress()) + return; + profiler.correlation.popExternId(); + profiler.setOpInProgress(false); } }; struct Correlation { std::atomic maxSubmittedCorrelationId{0}; std::atomic maxCompletedCorrelationId{0}; + // Mapping from a native profiler correlation id to an external id. + CorrIdToExternIdMap corrIdToExternId; + // A set of kernels triggered by GPU runtime APIs (e.g., torch + // kernels) other than Triton. + // It stores a subset of external ids in corrIdToExternId. + ApiExternIdSet apiExternIds; + static thread_local std::deque externIdQueue; Correlation() = default; @@ -78,6 +103,17 @@ class GPUProfiler : public Profiler, atomicMax(maxCompletedCorrelationId, correlationId); } + void pushExternId(size_t externId) { externIdQueue.push_back(externId); } + + void popExternId() { externIdQueue.pop_front(); } + + // Correlate the correlationId with the last externId + void correlate(uint64_t correlationId, size_t numInstances = 1) { + if (externIdQueue.empty()) + return; + corrIdToExternId[correlationId] = {externIdQueue.back(), numInstances}; + } + template void flush(uint64_t maxRetries, uint64_t sleepMs, FlushFnT &&flushFn) { flushFn(); @@ -93,7 +129,7 @@ class GPUProfiler : public Profiler, } }; - static thread_local ProfilerState profilerState; + static thread_local ThreadState threadState; Correlation correlation; // Use the pimpl idiom to hide the implementation details. This lets us avoid @@ -106,8 +142,6 @@ class GPUProfiler : public Profiler, : profiler(profiler) {} virtual ~GPUProfilerPimplInterface() = default; - virtual void startOp(const Scope &scope) = 0; - virtual void stopOp(const Scope &scope) = 0; virtual void doStart() = 0; virtual void doFlush() = 0; virtual void doStop() = 0; @@ -116,6 +150,8 @@ class GPUProfiler : public Profiler, ConcreteProfilerT &profiler; }; std::unique_ptr pImpl; + + bool pcSamplingEnabled{false}; }; } // namespace proton diff --git a/third_party/proton/csrc/include/Profiler/RoctracerProfiler.h b/third_party/proton/csrc/include/Profiler/Roctracer/RoctracerProfiler.h similarity index 91% rename from third_party/proton/csrc/include/Profiler/RoctracerProfiler.h rename to third_party/proton/csrc/include/Profiler/Roctracer/RoctracerProfiler.h index 2f1791dcb..b9bc08de8 100644 --- a/third_party/proton/csrc/include/Profiler/RoctracerProfiler.h +++ b/third_party/proton/csrc/include/Profiler/Roctracer/RoctracerProfiler.h @@ -1,7 +1,7 @@ #ifndef PROTON_PROFILER_ROCTRACER_PROFILER_H_ #define PROTON_PROFILER_ROCTRACER_PROFILER_H_ -#include "GPUProfiler.h" +#include "Profiler/GPUProfiler.h" namespace proton { diff --git a/third_party/proton/csrc/include/Utility/Atomic.h b/third_party/proton/csrc/include/Utility/Atomic.h index d7e40e73c..0f759e0d6 100644 --- a/third_party/proton/csrc/include/Utility/Atomic.h +++ b/third_party/proton/csrc/include/Utility/Atomic.h @@ -1,4 +1,8 @@ +#ifndef PROTON_UTILITY_ATOMIC_H_ +#define PROTON_UTILITY_ATOMIC_H_ + #include +#include namespace proton { @@ -16,4 +20,20 @@ template T atomicMin(std::atomic &target, T value) { return current; } +template +void doubleCheckedLock(Condition enterCondition, std::mutex &lock, + Function function) { + if (!enterCondition()) + return; + + std::unique_lock guard(lock); + + if (!enterCondition()) + return; + + function(); +} + } // namespace proton + +#endif // PROTON_UTILITY_ATOMIC_H_ diff --git a/third_party/proton/csrc/include/Utility/Errors.h b/third_party/proton/csrc/include/Utility/Errors.h index 62d4f3f66..094723d6f 100644 --- a/third_party/proton/csrc/include/Utility/Errors.h +++ b/third_party/proton/csrc/include/Utility/Errors.h @@ -1,3 +1,6 @@ +#ifndef PROTON_UTILITY_ERRORS_H_ +#define PROTON_UTILITY_ERRORS_H_ + #include namespace proton { @@ -8,3 +11,5 @@ class NotImplemented : public std::logic_error { }; } // namespace proton + +#endif // PROTON_UTILITY_ERRORS_H_ diff --git a/third_party/proton/csrc/include/Utility/Map.h b/third_party/proton/csrc/include/Utility/Map.h new file mode 100644 index 000000000..c173d163e --- /dev/null +++ b/third_party/proton/csrc/include/Utility/Map.h @@ -0,0 +1,61 @@ +#ifndef PROTON_UTILITY_MAP_H_ +#define PROTON_UTILITY_MAP_H_ + +#include +#include + +namespace proton { + +/// A simple thread safe map with read/write lock. +template > +class ThreadSafeMap { +public: + ThreadSafeMap() = default; + + Value &operator[](const Key &key) { + std::unique_lock lock(mutex); + return map[key]; + } + + Value &operator[](Key &&key) { + std::unique_lock lock(mutex); + return map[std::move(key)]; + } + + Value &at(const Key &key) { + std::shared_lock lock(mutex); + return map.at(key); + } + + void insert(const Key &key, const Value &value) { + std::unique_lock lock(mutex); + map[key] = value; + } + + bool contain(const Key &key) { + std::shared_lock lock(mutex); + auto it = map.find(key); + if (it == map.end()) + return false; + return true; + } + + bool erase(const Key &key) { + std::unique_lock lock(mutex); + return map.erase(key) > 0; + } + + void clear() { + std::unique_lock lock(mutex); + map.clear(); + } + +private: + Container map; + std::shared_mutex mutex; +}; + +} // namespace proton + +#endif // PROTON_UTILITY_MAP_H_ diff --git a/third_party/proton/csrc/include/Utility/Set.h b/third_party/proton/csrc/include/Utility/Set.h new file mode 100644 index 000000000..50ce165db --- /dev/null +++ b/third_party/proton/csrc/include/Utility/Set.h @@ -0,0 +1,45 @@ +#ifndef PROTON_UTILITY_SET_H_ +#define PROTON_UTILITY_SET_H_ + +#include +#include + +namespace proton { + +/// A simple thread safe set with read/write lock. +template > +class ThreadSafeSet { +public: + ThreadSafeSet() = default; + + void insert(const Key &key) { + std::unique_lock lock(mutex); + set.insert(key); + } + + bool contain(const Key &key) { + std::shared_lock lock(mutex); + auto it = set.find(key); + if (it == set.end()) + return false; + return true; + } + + bool erase(const Key &key) { + std::unique_lock lock(mutex); + return set.erase(key) > 0; + } + + void clear() { + std::unique_lock lock(mutex); + set.clear(); + } + +private: + Container set; + std::shared_mutex mutex; +}; + +} // namespace proton + +#endif // PROTON_UTILITY_MAP_H_ diff --git a/third_party/proton/csrc/include/Utility/String.h b/third_party/proton/csrc/include/Utility/String.h index b7d45ae1f..b4a1d3ff9 100644 --- a/third_party/proton/csrc/include/Utility/String.h +++ b/third_party/proton/csrc/include/Utility/String.h @@ -13,6 +13,18 @@ inline std::string toLower(const std::string &str) { return lower; } +inline std::string replace(const std::string &str, const std::string &src, + const std::string &dst) { + std::string replaced = str; + size_t pos = replaced.find(src, pos); + while (pos != std::string::npos) { + replaced.replace(pos, src.length(), dst); + pos += dst.length(); + pos = replaced.find(src, pos); + } + return replaced; +} + } // namespace proton #endif // PROTON_UTILITY_STRING_H_ diff --git a/third_party/proton/csrc/lib/Context/Context.cpp b/third_party/proton/csrc/lib/Context/Context.cpp index f1b0177d1..676bdd8d6 100644 --- a/third_party/proton/csrc/lib/Context/Context.cpp +++ b/third_party/proton/csrc/lib/Context/Context.cpp @@ -4,7 +4,7 @@ namespace proton { std::atomic Scope::scopeIdCounter{1}; -/*static*/ thread_local std::map - InternalOpInterface::opInProgress; +/*static*/ thread_local std::map + ThreadLocalOpInterface::opInProgress; } // namespace proton diff --git a/third_party/proton/csrc/lib/Data/TraceData.cpp b/third_party/proton/csrc/lib/Data/TraceData.cpp index 1076c0863..03406368a 100644 --- a/third_party/proton/csrc/lib/Data/TraceData.cpp +++ b/third_party/proton/csrc/lib/Data/TraceData.cpp @@ -9,6 +9,10 @@ void TraceData::startOp(const Scope &scope) { throw NotImplemented(); } void TraceData::stopOp(const Scope &scope) { throw NotImplemented(); } +size_t TraceData::addScope(size_t scopeId, const std::string &name) { + throw NotImplemented(); +} + void TraceData::addMetric(size_t scopeId, std::shared_ptr metric) { throw NotImplemented(); } diff --git a/third_party/proton/csrc/lib/Data/TreeData.cpp b/third_party/proton/csrc/lib/Data/TreeData.cpp index b69e55b8c..ec6ea1c78 100644 --- a/third_party/proton/csrc/lib/Data/TreeData.cpp +++ b/third_party/proton/csrc/lib/Data/TreeData.cpp @@ -14,6 +14,91 @@ using json = nlohmann::json; namespace proton { +class TreeData::Tree { +public: + struct TreeNode : public Context { + inline static const size_t RootId = 0; + inline static const size_t DummyId = std::numeric_limits::max(); + + TreeNode() = default; + explicit TreeNode(size_t id, const std::string &name) + : id(id), Context(name) {} + TreeNode(size_t id, size_t parentId, const std::string &name) + : id(id), parentId(parentId), Context(name) {} + virtual ~TreeNode() = default; + + void addChild(const Context &context, size_t id) { children[context] = id; } + + bool hasChild(const Context &context) const { + return children.find(context) != children.end(); + } + + size_t getChild(const Context &context) const { + return children.at(context); + } + + size_t parentId = DummyId; + size_t id = DummyId; + std::map children = {}; + std::map> metrics = {}; + std::map flexibleMetrics = {}; + friend class Tree; + }; + + Tree() { + treeNodeMap.try_emplace(TreeNode::RootId, TreeNode::RootId, "ROOT"); + } + + size_t addNode(const Context &context, size_t parentId) { + if (treeNodeMap[parentId].hasChild(context)) { + return treeNodeMap[parentId].getChild(context); + } + auto id = nextContextId++; + treeNodeMap.try_emplace(id, id, parentId, context.name); + treeNodeMap[parentId].addChild(context, id); + return id; + } + + size_t addNode(const std::vector &indices) { + auto parentId = TreeNode::RootId; + for (auto index : indices) { + parentId = addNode(index, parentId); + } + return parentId; + } + + TreeNode &getNode(size_t id) { return treeNodeMap.at(id); } + + enum class WalkPolicy { PreOrder, PostOrder }; + + template void walk(FnT &&fn) { + if constexpr (walkPolicy == WalkPolicy::PreOrder) { + walkPreOrder(TreeNode::RootId, fn); + } else if constexpr (walkPolicy == WalkPolicy::PostOrder) { + walkPostOrder(TreeNode::RootId, fn); + } + } + + template void walkPreOrder(size_t contextId, FnT &&fn) { + fn(getNode(contextId)); + for (auto &child : getNode(contextId).children) { + walkPreOrder(child.second, fn); + } + } + + template void walkPostOrder(size_t contextId, FnT &&fn) { + for (auto &child : getNode(contextId).children) { + walkPostOrder(child.second, fn); + } + fn(getNode(contextId)); + } + +private: + size_t nextContextId = TreeNode::RootId + 1; + // tree node id->tree node + std::map treeNodeMap; +}; + void TreeData::init() { tree = std::make_unique(); } void TreeData::startOp(const Scope &scope) { @@ -29,6 +114,25 @@ void TreeData::startOp(const Scope &scope) { void TreeData::stopOp(const Scope &scope) {} +size_t TreeData::addScope(size_t parentScopeId, const std::string &name) { + std::unique_lock lock(mutex); + auto scopeIdIt = scopeIdToContextId.find(parentScopeId); + auto scopeId = parentScopeId; + if (scopeIdIt == scopeIdToContextId.end()) { + std::vector contexts; + if (contextSource != nullptr) + contexts = contextSource->getContexts(); + // Record the parent context + scopeIdToContextId[parentScopeId] = tree->addNode(contexts); + } else { + // Add a new context under it and update the context + scopeId = Scope::getNewScopeId(); + scopeIdToContextId[scopeId] = + tree->addNode(Context(name), scopeIdIt->second); + } + return scopeId; +} + void TreeData::addMetric(size_t scopeId, std::shared_ptr metric) { std::unique_lock lock(mutex); auto scopeIdIt = scopeIdToContextId.find(scopeId); @@ -76,66 +180,76 @@ void TreeData::dumpHatchet(std::ostream &os) const { jsonNodes[Tree::TreeNode::RootId] = &(output.back()); std::set valueNames; std::map> deviceIds; - this->tree->template walk( - [&](Tree::TreeNode &treeNode) { - const auto contextName = treeNode.name; - auto contextId = treeNode.id; - json *jsonNode = jsonNodes[contextId]; - (*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}}; - (*jsonNode)["metrics"] = json::object(); - for (auto [metricKind, metric] : treeNode.metrics) { - if (metricKind == MetricKind::Kernel) { - auto kernelMetric = std::dynamic_pointer_cast(metric); - auto duration = std::get( - kernelMetric->getValue(KernelMetric::Duration)); - auto invocations = std::get( - kernelMetric->getValue(KernelMetric::Invocations)); - auto deviceId = std::get( - kernelMetric->getValue(KernelMetric::DeviceId)); - auto deviceType = std::get( - kernelMetric->getValue(KernelMetric::DeviceType)); - auto deviceTypeName = - getDeviceTypeString(static_cast(deviceType)); - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::Duration)] = - duration; - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::Invocations)] = - invocations; - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::DeviceId)] = - std::to_string(deviceId); - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::DeviceType)] = - deviceTypeName; - valueNames.insert( - kernelMetric->getValueName(KernelMetric::Duration)); - valueNames.insert( - kernelMetric->getValueName(KernelMetric::Invocations)); - deviceIds.insert({deviceType, {deviceId}}); - } else { - throw std::runtime_error("MetricKind not supported"); - } - } - for (auto [_, flexibleMetric] : treeNode.flexibleMetrics) { - auto valueName = flexibleMetric.getValueName(0); + this->tree->template walk([&](Tree::TreeNode + &treeNode) { + const auto contextName = treeNode.name; + auto contextId = treeNode.id; + json *jsonNode = jsonNodes[contextId]; + (*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}}; + (*jsonNode)["metrics"] = json::object(); + for (auto [metricKind, metric] : treeNode.metrics) { + if (metricKind == MetricKind::Kernel) { + std::shared_ptr kernelMetric = + std::dynamic_pointer_cast(metric); + uint64_t duration = + std::get(kernelMetric->getValue(KernelMetric::Duration)); + uint64_t invocations = std::get( + kernelMetric->getValue(KernelMetric::Invocations)); + uint64_t deviceId = + std::get(kernelMetric->getValue(KernelMetric::DeviceId)); + uint64_t deviceType = std::get( + kernelMetric->getValue(KernelMetric::DeviceType)); + std::string deviceTypeName = + getDeviceTypeString(static_cast(deviceType)); + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::Duration)] = + duration; + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::Invocations)] = + invocations; + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::DeviceId)] = + std::to_string(deviceId); + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::DeviceType)] = + deviceTypeName; + valueNames.insert(kernelMetric->getValueName(KernelMetric::Duration)); + valueNames.insert( + kernelMetric->getValueName(KernelMetric::Invocations)); + deviceIds.insert({deviceType, {deviceId}}); + } else if (metricKind == MetricKind::PCSampling) { + auto pcSamplingMetric = + std::dynamic_pointer_cast(metric); + for (size_t i = 0; i < PCSamplingMetric::Count; i++) { + auto valueName = pcSamplingMetric->getValueName(i); valueNames.insert(valueName); std::visit( [&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; }, - flexibleMetric.getValues()[0]); - } - (*jsonNode)["children"] = json::array(); - auto children = treeNode.children; - for (auto _ : children) { - (*jsonNode)["children"].push_back(json::object()); + pcSamplingMetric->getValues()[i]); } - auto idx = 0; - for (auto child : children) { - auto [index, childId] = child; - jsonNodes[childId] = &(*jsonNode)["children"][idx]; - idx++; - } - }); + } else { + throw std::runtime_error("MetricKind not supported"); + } + } + for (auto [_, flexibleMetric] : treeNode.flexibleMetrics) { + auto valueName = flexibleMetric.getValueName(0); + valueNames.insert(valueName); + std::visit( + [&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; }, + flexibleMetric.getValues()[0]); + } + (*jsonNode)["children"] = json::array(); + auto children = treeNode.children; + for (auto _ : children) { + (*jsonNode)["children"].push_back(json::object()); + } + auto idx = 0; + for (auto child : children) { + auto [index, childId] = child; + jsonNodes[childId] = &(*jsonNode)["children"][idx]; + idx++; + } + }); // Hints for all available metrics for (auto valueName : valueNames) { output[Tree::TreeNode::RootId]["metrics"][valueName] = 0; @@ -173,4 +287,11 @@ void TreeData::doDump(std::ostream &os, OutputFormat outputFormat) const { } } +TreeData::TreeData(const std::string &path, ContextSource *contextSource) + : Data(path, contextSource) { + init(); +} + +TreeData::~TreeData() {} + } // namespace proton diff --git a/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp b/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp index aae8b4ceb..d1617b48a 100644 --- a/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp +++ b/third_party/proton/csrc/lib/Driver/GPU/CudaApi.cpp @@ -11,6 +11,7 @@ struct ExternLibCuda : public ExternLibBase { // On WSL, "libcuda.so" and "libcuda.so.1" may not be linked, so we use // "libcuda.so.1" instead. static constexpr const char *name = "libcuda.so.1"; + static constexpr const char *defaultDir = ""; static constexpr RetType success = CUDA_SUCCESS; static void *lib; }; diff --git a/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp b/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp index eaf26dbd3..2c399d31c 100644 --- a/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp +++ b/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp @@ -6,15 +6,27 @@ namespace proton { namespace cupti { +#define STRINGIFY(x) #x +#define TOSTRING(x) STRINGIFY(x) struct ExternLibCupti : public ExternLibBase { using RetType = CUptiResult; static constexpr const char *name = "libcupti.so"; +#ifdef CUPTI_LIB_DIR + static constexpr const char *defaultDir = TOSTRING(CUPTI_LIB_DIR); +#else + static constexpr const char *defaultDir = ""; +#endif static constexpr RetType success = CUPTI_SUCCESS; static void *lib; }; void *ExternLibCupti::lib = nullptr; +DEFINE_DISPATCH(ExternLibCupti, getVersion, cuptiGetVersion, uint32_t *); + +DEFINE_DISPATCH(ExternLibCupti, getContextId, cuptiGetContextId, CUcontext, + uint32_t *); + DEFINE_DISPATCH(ExternLibCupti, activityRegisterCallbacks, cuptiActivityRegisterCallbacks, CUpti_BuffersCallbackRequestFunc, @@ -26,6 +38,9 @@ DEFINE_DISPATCH(ExternLibCupti, subscribe, cuptiSubscribe, DEFINE_DISPATCH(ExternLibCupti, enableDomain, cuptiEnableDomain, uint32_t, CUpti_SubscriberHandle, CUpti_CallbackDomain) +DEFINE_DISPATCH(ExternLibCupti, enableCallback, cuptiEnableCallback, uint32_t, + CUpti_SubscriberHandle, CUpti_CallbackDomain, CUpti_CallbackId); + DEFINE_DISPATCH(ExternLibCupti, activityEnable, cuptiActivityEnable, CUpti_ActivityKind) @@ -61,6 +76,46 @@ DEFINE_DISPATCH(ExternLibCupti, unsubscribe, cuptiUnsubscribe, DEFINE_DISPATCH(ExternLibCupti, finalize, cuptiFinalize) +DEFINE_DISPATCH(ExternLibCupti, getGraphExecId, cuptiGetGraphExecId, + CUgraphExec, uint32_t *); + +DEFINE_DISPATCH(ExternLibCupti, getGraphId, cuptiGetGraphId, CUgraph, + uint32_t *); + +DEFINE_DISPATCH(ExternLibCupti, getCubinCrc, cuptiGetCubinCrc, + CUpti_GetCubinCrcParams *); + +DEFINE_DISPATCH(ExternLibCupti, getSassToSourceCorrelation, + cuptiGetSassToSourceCorrelation, + CUpti_GetSassToSourceCorrelationParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingGetNumStallReasons, + cuptiPCSamplingGetNumStallReasons, + CUpti_PCSamplingGetNumStallReasonsParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingGetStallReasons, + cuptiPCSamplingGetStallReasons, + CUpti_PCSamplingGetStallReasonsParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingSetConfigurationAttribute, + cuptiPCSamplingSetConfigurationAttribute, + CUpti_PCSamplingConfigurationInfoParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingEnable, cuptiPCSamplingEnable, + CUpti_PCSamplingEnableParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingDisable, cuptiPCSamplingDisable, + CUpti_PCSamplingDisableParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingGetData, cuptiPCSamplingGetData, + CUpti_PCSamplingGetDataParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingStart, cuptiPCSamplingStart, + CUpti_PCSamplingStartParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingStop, cuptiPCSamplingStop, + CUpti_PCSamplingStopParams *); + } // namespace cupti } // namespace proton diff --git a/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp b/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp index 18de4a4f6..9e8ef8d22 100644 --- a/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp +++ b/third_party/proton/csrc/lib/Driver/GPU/HipApi.cpp @@ -10,6 +10,7 @@ namespace hip { struct ExternLibHip : public ExternLibBase { using RetType = hipError_t; static constexpr const char *name = "libamdhip64.so"; + static constexpr const char *defaultDir = ""; static constexpr RetType success = hipSuccess; static void *lib; }; diff --git a/third_party/proton/csrc/lib/Driver/GPU/HsaApi.cpp b/third_party/proton/csrc/lib/Driver/GPU/HsaApi.cpp index e07f5eb1b..7c607b4b9 100644 --- a/third_party/proton/csrc/lib/Driver/GPU/HsaApi.cpp +++ b/third_party/proton/csrc/lib/Driver/GPU/HsaApi.cpp @@ -8,6 +8,7 @@ namespace hsa { struct ExternLibHsa : public ExternLibBase { using RetType = hsa_status_t; static constexpr const char *name = "libhsa-runtime64.so"; + static constexpr const char *defaultDir = ""; static constexpr RetType success = HSA_STATUS_SUCCESS; static void *lib; }; diff --git a/third_party/proton/csrc/lib/Driver/GPU/RoctracerApi.cpp b/third_party/proton/csrc/lib/Driver/GPU/RoctracerApi.cpp index 21b6a03a4..a6dcdcf34 100644 --- a/third_party/proton/csrc/lib/Driver/GPU/RoctracerApi.cpp +++ b/third_party/proton/csrc/lib/Driver/GPU/RoctracerApi.cpp @@ -8,6 +8,7 @@ namespace roctracer { struct ExternLibRoctracer : public ExternLibBase { using RetType = roctracer_status_t; static constexpr const char *name = "libroctracer64.so"; + static constexpr const char *defaultDir = ""; static constexpr RetType success = ROCTRACER_STATUS_SUCCESS; static void *lib; }; diff --git a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp new file mode 100644 index 000000000..19b50214b --- /dev/null +++ b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp @@ -0,0 +1,445 @@ +#include "Profiler/Cupti/CuptiPCSampling.h" +#include "Data/Metric.h" +#include "Driver/GPU/CudaApi.h" +#include "Driver/GPU/CuptiApi.h" +#include "Utility/Atomic.h" +#include "Utility/Map.h" +#include "Utility/String.h" +#include +#include +#include + +namespace proton { + +namespace { + +uint64_t getCubinCrc(const char *cubin, size_t size) { + CUpti_GetCubinCrcParams cubinCrcParams = { + .size = CUpti_GetCubinCrcParamsSize, + .cubinSize = size, + .cubin = cubin, + .cubinCrc = 0, + }; + cupti::getCubinCrc(&cubinCrcParams); + return cubinCrcParams.cubinCrc; +} + +size_t getNumStallReasons(CUcontext context) { + size_t numStallReasons = 0; + CUpti_PCSamplingGetNumStallReasonsParams numStallReasonsParams = { + .size = CUpti_PCSamplingGetNumStallReasonsParamsSize, + .pPriv = NULL, + .ctx = context, + .numStallReasons = &numStallReasons}; + cupti::pcSamplingGetNumStallReasons(&numStallReasonsParams); + return numStallReasons; +} + +std::tuple +getSassToSourceCorrelation(const char *functionName, uint64_t pcOffset, + const char *cubin, size_t cubinSize) { + CUpti_GetSassToSourceCorrelationParams sassToSourceParams = { + .size = CUpti_GetSassToSourceCorrelationParamsSize, + .cubin = cubin, + .functionName = functionName, + .cubinSize = cubinSize, + .lineNumber = 0, + .pcOffset = pcOffset, + .fileName = NULL, + .dirName = NULL, + }; + // Get source can fail if the line mapping is not available in the cubin so we + // don't check the return value + cupti::getSassToSourceCorrelation(&sassToSourceParams); + auto fileNameStr = sassToSourceParams.fileName + ? std::string(sassToSourceParams.fileName) + : ""; + auto dirNameStr = + sassToSourceParams.dirName ? std::string(sassToSourceParams.dirName) : ""; + // It's user's responsibility to free the memory + if (sassToSourceParams.fileName) + std::free(sassToSourceParams.fileName); + if (sassToSourceParams.dirName) + std::free(sassToSourceParams.dirName); + return std::make_tuple(sassToSourceParams.lineNumber, fileNameStr, + dirNameStr); +} + +std::pair +getStallReasonNamesAndIndices(CUcontext context, size_t numStallReasons) { + char **stallReasonNames = + static_cast(std::calloc(numStallReasons, sizeof(char *))); + for (size_t i = 0; i < numStallReasons; i++) { + stallReasonNames[i] = static_cast( + std::calloc(CUPTI_STALL_REASON_STRING_SIZE, sizeof(char))); + } + uint32_t *stallReasonIndices = + static_cast(std::calloc(numStallReasons, sizeof(uint32_t))); + // Initialize the names with 128 characters to avoid buffer overflow + CUpti_PCSamplingGetStallReasonsParams stallReasonsParams = { + .size = CUpti_PCSamplingGetStallReasonsParamsSize, + .pPriv = NULL, + .ctx = context, + .numStallReasons = numStallReasons, + .stallReasonIndex = stallReasonIndices, + .stallReasons = stallReasonNames, + }; + cupti::pcSamplingGetStallReasons(&stallReasonsParams); + return std::make_pair(stallReasonNames, stallReasonIndices); +} + +size_t matchStallReasonsToIndices( + size_t numStallReasons, char **stallReasonNames, + uint32_t *stallReasonIndices, + std::map &stallReasonIndexToMetricIndex, + std::set ¬IssuedStallReasonIndices) { + // In case there's any invalid stall reasons, we only collect valid ones. + // Invalid ones are swapped to the end of the list + std::vector validIndex(numStallReasons, false); + size_t numValidStalls = 0; + for (size_t i = 0; i < numStallReasons; i++) { + bool notIssued = std::string(stallReasonNames[i]).find("not_issued") != + std::string::npos; + std::string cuptiStallName = std::string(stallReasonNames[i]); + for (size_t j = 0; j < PCSamplingMetric::PCSamplingMetricKind::Count; j++) { + auto metricName = PCSamplingMetric().getValueName(j); + if (cuptiStallName.find(metricName) != std::string::npos) { + if (notIssued) + notIssuedStallReasonIndices.insert(stallReasonIndices[i]); + stallReasonIndexToMetricIndex[stallReasonIndices[i]] = j; + validIndex[i] = true; + numValidStalls++; + break; + } + } + } + int invalidIndex = -1; + for (size_t i = 0; i < numStallReasons; i++) { + if (invalidIndex == -1 && !validIndex[i]) { + invalidIndex = i; + } else if (invalidIndex != -1 && validIndex[i]) { + std::swap(stallReasonIndices[invalidIndex], stallReasonIndices[i]); + std::swap(stallReasonNames[invalidIndex], stallReasonNames[i]); + validIndex[invalidIndex] = true; + invalidIndex++; + } + } + return numValidStalls; +} + +#define CUPTI_CUDA12_4_VERSION 22 +#define CUPTI_CUDA12_4_PC_DATA_PADDING_SIZE sizeof(uint32_t) + +CUpti_PCSamplingData allocPCSamplingData(size_t collectNumPCs, + size_t numValidStallReasons) { + uint32_t libVersion = 0; + cupti::getVersion(&libVersion); + size_t pcDataSize = sizeof(CUpti_PCSamplingPCData); + // Check cupti api version < 12.4 but cupti header version >= 12.4 + // If so, we subtract 4 bytes from the size of CUpti_PCSamplingPCData + // because it introduces a new field (i.e., correlationId) at the end of the + // struct, which is not compatible with the previous versions. + if (libVersion < CUPTI_CUDA12_4_VERSION && + CUPTI_API_VERSION >= CUPTI_CUDA12_4_VERSION) + pcDataSize -= CUPTI_CUDA12_4_PC_DATA_PADDING_SIZE; + CUpti_PCSamplingData pcSamplingData{ + .size = pcDataSize, + .collectNumPcs = collectNumPCs, + .pPcData = static_cast( + std::calloc(collectNumPCs, sizeof(CUpti_PCSamplingPCData)))}; + for (size_t i = 0; i < collectNumPCs; ++i) { + pcSamplingData.pPcData[i].stallReason = + static_cast(std::calloc( + numValidStallReasons, sizeof(CUpti_PCSamplingStallReason))); + } + return pcSamplingData; +} + +void enablePCSampling(CUcontext context) { + CUpti_PCSamplingEnableParams params = { + .size = CUpti_PCSamplingEnableParamsSize, + .pPriv = NULL, + .ctx = context, + }; + cupti::pcSamplingEnable(¶ms); +} + +void disablePCSampling(CUcontext context) { + CUpti_PCSamplingDisableParams params = { + .size = CUpti_PCSamplingDisableParamsSize, + .pPriv = NULL, + .ctx = context, + }; + cupti::pcSamplingDisable(¶ms); +} + +void startPCSampling(CUcontext context) { + CUpti_PCSamplingStartParams params = { + .size = CUpti_PCSamplingStartParamsSize, + .pPriv = NULL, + .ctx = context, + }; + cupti::pcSamplingStart(¶ms); +} + +void stopPCSampling(CUcontext context) { + CUpti_PCSamplingStopParams params = { + .size = CUpti_PCSamplingStopParamsSize, + .pPriv = NULL, + .ctx = context, + }; + cupti::pcSamplingStop(¶ms); +} + +void getPCSamplingData(CUcontext context, + CUpti_PCSamplingData *pcSamplingData) { + CUpti_PCSamplingGetDataParams params = { + .size = CUpti_PCSamplingGetDataParamsSize, + .pPriv = NULL, + .ctx = context, + .pcSamplingData = pcSamplingData, + }; + cupti::pcSamplingGetData(¶ms); +} + +void setConfigurationAttribute( + CUcontext context, + std::vector &configurationInfos) { + CUpti_PCSamplingConfigurationInfoParams infoParams = { + .size = CUpti_PCSamplingConfigurationInfoParamsSize, + .pPriv = NULL, + .ctx = context, + .numAttributes = configurationInfos.size(), + .pPCSamplingConfigurationInfo = configurationInfos.data(), + }; + cupti::pcSamplingSetConfigurationAttribute(&infoParams); +} + +} // namespace + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureStallReasons() { + numStallReasons = getNumStallReasons(context); + std::tie(this->stallReasonNames, this->stallReasonIndices) = + getStallReasonNamesAndIndices(context, numStallReasons); + numValidStallReasons = matchStallReasonsToIndices( + numStallReasons, stallReasonNames, stallReasonIndices, + stallReasonIndexToMetricIndex, notIssuedStallReasonIndices); + CUpti_PCSamplingConfigurationInfo stallReasonInfo{}; + stallReasonInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_STALL_REASON; + stallReasonInfo.attributeData.stallReasonData.stallReasonCount = + numValidStallReasons; + stallReasonInfo.attributeData.stallReasonData.pStallReasonIndex = + stallReasonIndices; + return stallReasonInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureSamplingPeriod() { + CUpti_PCSamplingConfigurationInfo samplingPeriodInfo{}; + samplingPeriodInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_PERIOD; + samplingPeriodInfo.attributeData.samplingPeriodData.samplingPeriod = + DefaultFrequency; + return samplingPeriodInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureSamplingBuffer() { + CUpti_PCSamplingConfigurationInfo samplingBufferInfo{}; + samplingBufferInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_DATA_BUFFER; + this->pcSamplingData = + allocPCSamplingData(DataBufferPCCount, numValidStallReasons); + samplingBufferInfo.attributeData.samplingDataBufferData.samplingDataBuffer = + &this->pcSamplingData; + return samplingBufferInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureScratchBuffer() { + CUpti_PCSamplingConfigurationInfo scratchBufferInfo{}; + scratchBufferInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SCRATCH_BUFFER_SIZE; + scratchBufferInfo.attributeData.scratchBufferSizeData.scratchBufferSize = + ScratchBufferSize; + return scratchBufferInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureHardwareBufferSize() { + CUpti_PCSamplingConfigurationInfo hardwareBufferInfo{}; + hardwareBufferInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_HARDWARE_BUFFER_SIZE; + hardwareBufferInfo.attributeData.hardwareBufferSizeData.hardwareBufferSize = + HardwareBufferSize; + return hardwareBufferInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureStartStopControl() { + CUpti_PCSamplingConfigurationInfo startStopControlInfo{}; + startStopControlInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL; + startStopControlInfo.attributeData.enableStartStopControlData + .enableStartStopControl = true; + return startStopControlInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureCollectionMode() { + CUpti_PCSamplingConfigurationInfo collectionModeInfo{}; + collectionModeInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_COLLECTION_MODE; + collectionModeInfo.attributeData.collectionModeData.collectionMode = + CUPTI_PC_SAMPLING_COLLECTION_MODE_CONTINUOUS; + return collectionModeInfo; +} + +void ConfigureData::initialize(CUcontext context) { + this->context = context; + cupti::getContextId(context, &contextId); + configurationInfos.emplace_back(configureStallReasons()); + configurationInfos.emplace_back(configureSamplingPeriod()); + configurationInfos.emplace_back(configureHardwareBufferSize()); + configurationInfos.emplace_back(configureScratchBuffer()); + configurationInfos.emplace_back(configureSamplingBuffer()); + configurationInfos.emplace_back(configureStartStopControl()); + configurationInfos.emplace_back(configureCollectionMode()); + setConfigurationAttribute(context, configurationInfos); +} + +ConfigureData *CuptiPCSampling::getConfigureData(uint32_t contextId) { + return &contextIdToConfigureData[contextId]; +} + +CubinData *CuptiPCSampling::getCubinData(uint64_t cubinCrc) { + return &(cubinCrcToCubinData[cubinCrc].first); +} + +void CuptiPCSampling::initialize(CUcontext context) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + doubleCheckedLock([&]() { return !contextInitialized.contain(contextId); }, + contextMutex, + [&]() { + enablePCSampling(context); + getConfigureData(contextId)->initialize(context); + contextInitialized.insert(contextId); + }); +} + +void CuptiPCSampling::start(CUcontext context) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + doubleCheckedLock([&]() -> bool { return !pcSamplingStarted; }, + pcSamplingMutex, + [&]() { + initialize(context); + // Ensure all previous operations are completed + cuda::ctxSynchronize(); + startPCSampling(context); + pcSamplingStarted = true; + }); +} + +void CuptiPCSampling::processPCSamplingData(ConfigureData *configureData, + uint64_t externId, bool isAPI) { + auto *pcSamplingData = &configureData->pcSamplingData; + auto &profiler = CuptiProfiler::instance(); + auto dataSet = profiler.getDataSet(); + // In the first round, we need to call getPCSamplingData to get the unsynced + // data from the hardware buffer + bool firstRound = true; + while (pcSamplingData->totalNumPcs > 0 || + pcSamplingData->remainingNumPcs > 0 || firstRound) { + // Handle data + for (size_t i = 0; i < pcSamplingData->totalNumPcs; ++i) { + auto *pcData = pcSamplingData->pPcData + i; + auto *cubinData = getCubinData(pcData->cubinCrc); + auto key = + CubinData::LineInfoKey{pcData->functionIndex, pcData->pcOffset}; + if (cubinData->lineInfo.find(key) == cubinData->lineInfo.end()) { + auto [lineNumber, fileName, dirName] = + getSassToSourceCorrelation(pcData->functionName, pcData->pcOffset, + cubinData->cubin, cubinData->cubinSize); + cubinData->lineInfo.try_emplace(key, lineNumber, + std::string(pcData->functionName), + dirName, fileName); + } + auto &lineInfo = cubinData->lineInfo[key]; + for (size_t j = 0; j < pcData->stallReasonCount; ++j) { + auto *stallReason = &pcData->stallReason[j]; + if (!configureData->stallReasonIndexToMetricIndex.count( + stallReason->pcSamplingStallReasonIndex)) + throw std::runtime_error("Invalid stall reason index"); + for (auto *data : dataSet) { + auto scopeId = externId; + if (isAPI) + scopeId = data->addScope(externId, lineInfo.functionName); + if (lineInfo.fileName.size()) + scopeId = data->addScope( + scopeId, lineInfo.dirName + "/" + lineInfo.fileName + ":" + + lineInfo.functionName + "@" + + std::to_string(lineInfo.lineNumber)); + auto metricKind = static_cast( + configureData->stallReasonIndexToMetricIndex + [stallReason->pcSamplingStallReasonIndex]); + auto samples = stallReason->samples; + auto stalledSamples = + configureData->notIssuedStallReasonIndices.count( + stallReason->pcSamplingStallReasonIndex) + ? 0 + : samples; + auto metric = std::make_shared(metricKind, samples, + stalledSamples); + data->addMetric(scopeId, metric); + } + } + } + if (pcSamplingData->remainingNumPcs > 0 || firstRound) { + getPCSamplingData(configureData->context, pcSamplingData); + firstRound = false; + } else + break; + } +} + +void CuptiPCSampling::stop(CUcontext context, uint64_t externId, bool isAPI) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + doubleCheckedLock([&]() -> bool { return pcSamplingStarted; }, + pcSamplingMutex, + [&]() { + auto *configureData = getConfigureData(contextId); + stopPCSampling(context); + pcSamplingStarted = false; + processPCSamplingData(configureData, externId, isAPI); + }); +} + +void CuptiPCSampling::finalize(CUcontext context) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + if (!contextInitialized.contain(contextId)) + return; + auto *configureData = getConfigureData(contextId); + contextIdToConfigureData.erase(contextId); + contextInitialized.erase(contextId); + disablePCSampling(context); +} + +void CuptiPCSampling::loadModule(const char *cubin, size_t cubinSize) { + auto cubinCrc = getCubinCrc(cubin, cubinSize); + auto *cubinData = getCubinData(cubinCrc); + cubinData->cubinCrc = cubinCrc; + cubinData->cubinSize = cubinSize; + cubinData->cubin = cubin; +} + +void CuptiPCSampling::unloadModule(const char *cubin, size_t cubinSize) { + // XXX: Unload module is supposed to be called in a thread safe manner + // i.e., no two threads will be calling unload module the same time + auto cubinCrc = getCubinCrc(cubin, cubinSize); + auto count = cubinCrcToCubinData[cubinCrc].second; + if (count > 1) + cubinCrcToCubinData[cubinCrc].second = count - 1; + else + cubinCrcToCubinData.erase(cubinCrc); +} + +} // namespace proton diff --git a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp new file mode 100644 index 000000000..9ddbd7a71 --- /dev/null +++ b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp @@ -0,0 +1,435 @@ +#include "Profiler/Cupti/CuptiProfiler.h" +#include "Context/Context.h" +#include "Data/Metric.h" +#include "Driver/Device.h" +#include "Driver/GPU/CudaApi.h" +#include "Driver/GPU/CuptiApi.h" +#include "Profiler/Cupti/CuptiPCSampling.h" +#include "Utility/Map.h" + +#include +#include +#include +#include + +namespace proton { + +template <> +thread_local GPUProfiler::ThreadState + GPUProfiler::threadState(CuptiProfiler::instance()); + +template <> +thread_local std::deque + GPUProfiler::Correlation::externIdQueue{}; + +namespace { + +std::shared_ptr convertActivityToMetric(CUpti_Activity *activity) { + std::shared_ptr metric; + switch (activity->kind) { + case CUPTI_ACTIVITY_KIND_KERNEL: + case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: { + auto *kernel = reinterpret_cast(activity); + if (kernel->start < kernel->end) { + metric = std::make_shared( + static_cast(kernel->start), + static_cast(kernel->end), 1, + static_cast(kernel->deviceId), + static_cast(DeviceType::CUDA)); + } // else: not a valid kernel activity + break; + } + default: + break; + } + return metric; +} + +uint32_t +processActivityKernel(CuptiProfiler::CorrIdToExternIdMap &corrIdToExternId, + CuptiProfiler::ApiExternIdSet &apiExternIds, + std::set &dataSet, CUpti_Activity *activity) { + // Support CUDA >= 11.0 + auto *kernel = reinterpret_cast(activity); + auto correlationId = kernel->correlationId; + if (/*Not a valid context*/ !corrIdToExternId.contain(correlationId)) + return correlationId; + auto [parentId, numInstances] = corrIdToExternId.at(correlationId); + if (kernel->graphId == 0) { + // Non-graph kernels + for (auto *data : dataSet) { + auto scopeId = parentId; + if (apiExternIds.contain(scopeId)) { + // It's triggered by a CUDA op but not triton op + scopeId = data->addScope(parentId, kernel->name); + } + data->addMetric(scopeId, convertActivityToMetric(activity)); + } + } else { + // Graph kernels + // A single graph launch can trigger multiple kernels. + // Our solution is to construct the following maps: + // --- Application threads --- + // 1. graphId -> numKernels + // 2. graphExecId -> graphId + // --- CUPTI thread --- + // 3. corrId -> numKernels + for (auto *data : dataSet) { + auto externId = data->addScope(parentId, kernel->name); + data->addMetric(externId, convertActivityToMetric(activity)); + } + } + apiExternIds.erase(parentId); + --numInstances; + if (numInstances == 0) { + corrIdToExternId.erase(correlationId); + } else { + corrIdToExternId[correlationId].second = numInstances; + } + return correlationId; +} + +uint32_t processActivity(CuptiProfiler::CorrIdToExternIdMap &corrIdToExternId, + CuptiProfiler::ApiExternIdSet &apiExternIds, + std::set &dataSet, CUpti_Activity *activity) { + auto correlationId = 0; + switch (activity->kind) { + case CUPTI_ACTIVITY_KIND_KERNEL: + case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: { + correlationId = processActivityKernel(corrIdToExternId, apiExternIds, + dataSet, activity); + break; + } + default: + break; + } + return correlationId; +} + +void setRuntimeCallbacks(CUpti_SubscriberHandle subscriber, bool enable) { +#define CALLBACK_ENABLE(id) \ + cupti::enableCallback(static_cast(enable), subscriber, \ + CUPTI_CB_DOMAIN_RUNTIME_API, id) + + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_v3020); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_ptsz_v7000); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_ptsz_v7000); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernelExC_v11060); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernelExC_ptsz_v11060); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_v9000); + CALLBACK_ENABLE( + CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_ptsz_v9000); + CALLBACK_ENABLE( + CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernelMultiDevice_v9000); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaGraphLaunch_v10000); + CALLBACK_ENABLE(CUPTI_RUNTIME_TRACE_CBID_cudaGraphLaunch_ptsz_v10000); + +#undef CALLBACK_ENABLE +} + +void setDriverCallbacks(CUpti_SubscriberHandle subscriber, bool enable) { +#define CALLBACK_ENABLE(id) \ + cupti::enableCallback(static_cast(enable), subscriber, \ + CUPTI_CB_DOMAIN_DRIVER_API, id) + + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunch); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchGrid); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchGridAsync); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel_ptsz); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx_ptsz); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel_ptsz); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch); + CALLBACK_ENABLE(CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz); +#undef CALLBACK_ENABLE +} + +void setGraphCallbacks(CUpti_SubscriberHandle subscriber, bool enable) { + +#define CALLBACK_ENABLE(id) \ + cupti::enableCallback(static_cast(enable), subscriber, \ + CUPTI_CB_DOMAIN_RESOURCE, id) + + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_GRAPHNODE_CREATED); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_GRAPHNODE_CLONED); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_GRAPHNODE_DESTROY_STARTING); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_GRAPHEXEC_CREATED); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_GRAPHEXEC_DESTROY_STARTING); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_GRAPH_DESTROY_STARTING); +#undef CALLBACK_ENABLE +} + +void setResourceCallbacks(CUpti_SubscriberHandle subscriber, bool enable) { +#define CALLBACK_ENABLE(id) \ + cupti::enableCallback(static_cast(enable), subscriber, \ + CUPTI_CB_DOMAIN_RESOURCE, id) + + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_MODULE_LOADED); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_MODULE_UNLOAD_STARTING); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_CONTEXT_CREATED); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING); +#undef CALLBACK_ENABLE +} + +bool isDriverAPILaunch(CUpti_CallbackId cbId) { + return cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunch || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchGrid || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchGridAsync || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel_ptsz || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx_ptsz || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel_ptsz || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice || + cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch || + cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz; +} + +} // namespace + +struct CuptiProfiler::CuptiProfilerPimpl + : public GPUProfiler::GPUProfilerPimplInterface { + CuptiProfilerPimpl(CuptiProfiler &profiler) + : GPUProfiler::GPUProfilerPimplInterface(profiler) {} + virtual ~CuptiProfilerPimpl() = default; + + void doStart() override; + void doFlush() override; + void doStop() override; + + static void allocBuffer(uint8_t **buffer, size_t *bufferSize, + size_t *maxNumRecords); + static void completeBuffer(CUcontext context, uint32_t streamId, + uint8_t *buffer, size_t size, size_t validSize); + static void callbackFn(void *userData, CUpti_CallbackDomain domain, + CUpti_CallbackId cbId, const void *cbData); + + static constexpr size_t AlignSize = 8; + static constexpr size_t BufferSize = 64 * 1024 * 1024; + static constexpr size_t AttributeSize = sizeof(size_t); + + CUpti_SubscriberHandle subscriber{}; + CuptiPCSampling pcSampling; + + ThreadSafeMap> + graphIdToNumInstances; + ThreadSafeMap> + graphExecIdToGraphId; +}; + +void CuptiProfiler::CuptiProfilerPimpl::allocBuffer(uint8_t **buffer, + size_t *bufferSize, + size_t *maxNumRecords) { + *buffer = static_cast(aligned_alloc(AlignSize, BufferSize)); + if (*buffer == nullptr) { + throw std::runtime_error("aligned_alloc failed"); + } + *bufferSize = BufferSize; + *maxNumRecords = 0; +} + +void CuptiProfiler::CuptiProfilerPimpl::completeBuffer(CUcontext ctx, + uint32_t streamId, + uint8_t *buffer, + size_t size, + size_t validSize) { + CuptiProfiler &profiler = threadState.profiler; + auto &dataSet = profiler.dataSet; + uint32_t maxCorrelationId = 0; + CUptiResult status; + CUpti_Activity *activity = nullptr; + do { + status = cupti::activityGetNextRecord(buffer, validSize, &activity); + if (status == CUPTI_SUCCESS) { + auto correlationId = + processActivity(profiler.correlation.corrIdToExternId, + profiler.correlation.apiExternIds, dataSet, activity); + maxCorrelationId = std::max(maxCorrelationId, correlationId); + } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) { + break; + } else { + throw std::runtime_error("cupti::activityGetNextRecord failed"); + } + } while (true); + + std::free(buffer); + + profiler.correlation.complete(maxCorrelationId); +} + +void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData, + CUpti_CallbackDomain domain, + CUpti_CallbackId cbId, + const void *cbData) { + CuptiProfiler &profiler = threadState.profiler; + if (domain == CUPTI_CB_DOMAIN_RESOURCE) { + auto *resourceData = + static_cast(const_cast(cbData)); + auto *pImpl = dynamic_cast(profiler.pImpl.get()); + if (cbId == CUPTI_CBID_RESOURCE_MODULE_LOADED) { + auto *moduleResource = static_cast( + resourceData->resourceDescriptor); + if (profiler.isPCSamplingEnabled()) { + pImpl->pcSampling.loadModule(moduleResource->pCubin, + moduleResource->cubinSize); + } + } else if (cbId == CUPTI_CBID_RESOURCE_MODULE_UNLOAD_STARTING) { + auto *moduleResource = static_cast( + resourceData->resourceDescriptor); + if (profiler.isPCSamplingEnabled()) { + pImpl->pcSampling.unloadModule(moduleResource->pCubin, + moduleResource->cubinSize); + } + } else if (cbId == CUPTI_CBID_RESOURCE_CONTEXT_CREATED) { + if (profiler.isPCSamplingEnabled()) { + pImpl->pcSampling.initialize(resourceData->context); + } + } else if (cbId == CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING) { + if (profiler.isPCSamplingEnabled()) { + pImpl->pcSampling.finalize(resourceData->context); + } + } else { + auto *graphData = + static_cast(resourceData->resourceDescriptor); + uint32_t graphId = 0; + uint32_t graphExecId = 0; + if (graphData->graph) + cupti::getGraphId(graphData->graph, &graphId); + if (graphData->graphExec) + cupti::getGraphExecId(graphData->graphExec, &graphExecId); + if (cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_CREATED || + cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_CLONED) { + if (!pImpl->graphIdToNumInstances.contain(graphId)) + pImpl->graphIdToNumInstances[graphId] = 1; + else + pImpl->graphIdToNumInstances[graphId]++; + } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_DESTROY_STARTING) { + pImpl->graphIdToNumInstances[graphId]--; + } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHEXEC_CREATED) { + pImpl->graphExecIdToGraphId[graphExecId] = graphId; + } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHEXEC_DESTROY_STARTING) { + pImpl->graphExecIdToGraphId.erase(graphExecId); + } else if (cbId == CUPTI_CBID_RESOURCE_GRAPH_DESTROY_STARTING) { + pImpl->graphIdToNumInstances.erase(graphId); + } + } + } else { + const CUpti_CallbackData *callbackData = + static_cast(cbData); + auto *pImpl = dynamic_cast(profiler.pImpl.get()); + if (callbackData->callbackSite == CUPTI_API_ENTER) { + auto scopeId = Scope::getNewScopeId(); + threadState.record(scopeId); + threadState.enterOp(scopeId); + size_t numInstances = 1; + if (cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch || + cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz) { + auto graphExec = static_cast( + callbackData->functionParams) + ->hGraph; + uint32_t graphExecId = 0; + cupti::getGraphExecId(graphExec, &graphExecId); + numInstances = std::numeric_limits::max(); + auto findGraph = false; + if (pImpl->graphExecIdToGraphId.contain(graphExecId)) { + auto graphId = pImpl->graphExecIdToGraphId[graphExecId]; + if (pImpl->graphIdToNumInstances.contain(graphId)) { + numInstances = pImpl->graphIdToNumInstances[graphId]; + findGraph = true; + } + } + if (!findGraph) + std::cerr << "[PROTON] Cannot find graph for graphExecId: " + << graphExecId + << ", and t may cause memory leak. To avoid this problem, " + "please start profiling before the graph is created." + << std::endl; + } + profiler.correlation.correlate(callbackData->correlationId, numInstances); + if (profiler.isPCSamplingEnabled() && isDriverAPILaunch(cbId)) { + pImpl->pcSampling.start(callbackData->context); + } + } else if (callbackData->callbackSite == CUPTI_API_EXIT) { + if (profiler.isPCSamplingEnabled() && isDriverAPILaunch(cbId)) { + // XXX: Conservatively stop every GPU kernel for now + auto scopeId = profiler.correlation.externIdQueue.back(); + pImpl->pcSampling.stop( + callbackData->context, scopeId, + profiler.correlation.apiExternIds.contain(scopeId)); + } + threadState.exitOp(); + profiler.correlation.submit(callbackData->correlationId); + } + } +} + +void CuptiProfiler::CuptiProfilerPimpl::doStart() { + cupti::subscribe(&subscriber, callbackFn, nullptr); + if (profiler.isPCSamplingEnabled()) { + setResourceCallbacks(subscriber, /*enable=*/true); + // Continuous PC sampling is not compatible with concurrent kernel profiling + cupti::activityEnable(CUPTI_ACTIVITY_KIND_KERNEL); + } else { + cupti::activityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL); + } + cupti::activityRegisterCallbacks(allocBuffer, completeBuffer); + setGraphCallbacks(subscriber, /*enable=*/true); + setRuntimeCallbacks(subscriber, /*enable=*/true); + setDriverCallbacks(subscriber, /*enable=*/true); +} + +void CuptiProfiler::CuptiProfilerPimpl::doFlush() { + // cuptiActivityFlushAll returns the activity records associated with all + // contexts/streams. + // This is a blocking call but it doesn’t issue any CUDA synchronization calls + // implicitly thus it’s not guaranteed that all activities are completed on + // the underlying devices. + // We do an "opportunistic" synchronization here to try to ensure that all + // activities are completed on the current context. + // If the current context is not set, we don't do any synchronization. + CUcontext cuContext = nullptr; + cuda::ctxGetCurrent(&cuContext); + if (cuContext) { + cuda::ctxSynchronize(); + } + if (profiler.isPCSamplingEnabled()) { + pcSampling.finalize(cuContext); + } + profiler.correlation.flush( + /*maxRetries=*/100, /*sleepMs=*/10, + /*flush=*/[]() { + cupti::activityFlushAll( + /*flag=*/0); + }); + // CUPTI_ACTIVITY_FLAG_FLUSH_FORCED is used to ensure that even incomplete + // activities are flushed so that the next profiling session can start with + // new activities. + cupti::activityFlushAll(/*flag=*/CUPTI_ACTIVITY_FLAG_FLUSH_FORCED); +} + +void CuptiProfiler::CuptiProfilerPimpl::doStop() { + if (profiler.isPCSamplingEnabled()) { + setResourceCallbacks(subscriber, /*enable=*/false); + cupti::activityDisable(CUPTI_ACTIVITY_KIND_KERNEL); + } else { + cupti::activityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL); + } + setGraphCallbacks(subscriber, /*enable=*/false); + setRuntimeCallbacks(subscriber, /*enable=*/false); + setDriverCallbacks(subscriber, /*enable=*/false); + cupti::unsubscribe(subscriber); + cupti::finalize(); +} + +CuptiProfiler::CuptiProfiler() { + pImpl = std::make_unique(*this); +} + +CuptiProfiler::~CuptiProfiler() = default; + +} // namespace proton diff --git a/third_party/proton/csrc/lib/Profiler/CuptiProfiler.cpp b/third_party/proton/csrc/lib/Profiler/CuptiProfiler.cpp deleted file mode 100644 index 81cef5fa0..000000000 --- a/third_party/proton/csrc/lib/Profiler/CuptiProfiler.cpp +++ /dev/null @@ -1,292 +0,0 @@ -#include "Profiler/CuptiProfiler.h" -#include "Context/Context.h" -#include "Data/Metric.h" -#include "Driver/Device.h" -#include "Driver/GPU/CudaApi.h" -#include "Driver/GPU/CuptiApi.h" - -#include -#include -#include - -namespace proton { - -template <> -thread_local GPUProfiler::ProfilerState - GPUProfiler::profilerState(CuptiProfiler::instance()); - -namespace { - -std::shared_ptr convertActivityToMetric(CUpti_Activity *activity) { - std::shared_ptr metric; - switch (activity->kind) { - case CUPTI_ACTIVITY_KIND_KERNEL: - case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: { - auto *kernel = reinterpret_cast(activity); - metric = - std::make_shared(static_cast(kernel->start), - static_cast(kernel->end), 1, - static_cast(kernel->deviceId), - static_cast(DeviceType::CUDA)); - break; - } - default: - break; - } - return metric; -} - -void addMetric(size_t scopeId, std::set &dataSet, - CUpti_Activity *activity) { - for (auto *data : dataSet) { - data->addMetric(scopeId, convertActivityToMetric(activity)); - } -} - -uint32_t -processActivityExternalCorrelation(std::map &corrIdToExternId, - CUpti_Activity *activity) { - auto *externalActivity = - reinterpret_cast(activity); - corrIdToExternId[externalActivity->correlationId] = - externalActivity->externalId; - return externalActivity->correlationId; -} - -uint32_t processActivityKernel(std::map &corrIdToExternId, - std::set &dataSet, - CUpti_Activity *activity) { - // Support CUDA >= 11.0 - auto *kernel = reinterpret_cast(activity); - auto correlationId = kernel->correlationId; - if (corrIdToExternId.find(correlationId) == corrIdToExternId.end()) - return correlationId; - auto externalId = corrIdToExternId[correlationId]; - addMetric(externalId, dataSet, activity); - // Track correlation ids from the same stream and erase those < correlationId - corrIdToExternId.erase(correlationId); - return correlationId; -} - -uint32_t processActivity(std::map &corrIdToExternId, - std::set &dataSet, CUpti_Activity *activity) { - auto correlationId = 0; - switch (activity->kind) { - case CUPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION: { - correlationId = - processActivityExternalCorrelation(corrIdToExternId, activity); - break; - } - case CUPTI_ACTIVITY_KIND_KERNEL: - case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: { - correlationId = processActivityKernel(corrIdToExternId, dataSet, activity); - break; - } - default: - break; - } - return correlationId; -} - -std::pair matchKernelCbId(CUpti_CallbackId cbId) { - bool isRuntimeApi = false; - bool isDriverApi = false; - switch (cbId) { - // TODO: switch to directly subscribe the APIs - case CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_v3020: - case CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000: - case CUPTI_RUNTIME_TRACE_CBID_cudaLaunch_ptsz_v7000: - case CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_ptsz_v7000: - case CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernelExC_v11060: - case CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernelExC_ptsz_v11060: - case CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_v9000: - case CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_ptsz_v9000: - case CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernelMultiDevice_v9000: { - isRuntimeApi = true; - break; - } - case CUPTI_DRIVER_TRACE_CBID_cuLaunch: - case CUPTI_DRIVER_TRACE_CBID_cuLaunchGrid: - case CUPTI_DRIVER_TRACE_CBID_cuLaunchGridAsync: - case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel: - case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel_ptsz: - case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx: - case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx_ptsz: - case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel: - case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel_ptsz: - case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice: { - isDriverApi = true; - break; - } - default: - break; - } - return std::make_pair(isRuntimeApi, isDriverApi); -} - -} // namespace - -struct CuptiProfiler::CuptiProfilerPimpl - : public GPUProfiler::GPUProfilerPimplInterface { - CuptiProfilerPimpl(CuptiProfiler &profiler) - : GPUProfiler::GPUProfilerPimplInterface(profiler) {} - virtual ~CuptiProfilerPimpl() = default; - - void startOp(const Scope &scope); - void stopOp(const Scope &scope); - - void doStart(); - void doFlush(); - void doStop(); - - static void allocBuffer(uint8_t **buffer, size_t *bufferSize, - size_t *maxNumRecords); - static void completeBuffer(CUcontext context, uint32_t streamId, - uint8_t *buffer, size_t size, size_t validSize); - static void callbackFn(void *userData, CUpti_CallbackDomain domain, - CUpti_CallbackId cbId, const void *cbData); - - static constexpr size_t AlignSize = 8; - static constexpr size_t BufferSize = 64 * 1024 * 1024; - - std::map corrIdToExternId; - CUpti_SubscriberHandle subscriber{}; -}; - -void CuptiProfiler::CuptiProfilerPimpl::allocBuffer(uint8_t **buffer, - size_t *bufferSize, - size_t *maxNumRecords) { - *buffer = reinterpret_cast(aligned_alloc(AlignSize, BufferSize)); - if (*buffer == nullptr) { - throw std::runtime_error("aligned_alloc failed"); - } - *bufferSize = BufferSize; - *maxNumRecords = 0; -} - -void CuptiProfiler::CuptiProfilerPimpl::completeBuffer(CUcontext ctx, - uint32_t streamId, - uint8_t *buffer, - size_t size, - size_t validSize) { - CuptiProfiler &profiler = - dynamic_cast(CuptiProfiler::instance()); - auto &pImpl = dynamic_cast(*profiler.pImpl.get()); - auto &dataSet = profiler.dataSet; - uint32_t maxCorrelationId = 0; - CUptiResult status; - CUpti_Activity *activity = nullptr; - do { - status = cupti::activityGetNextRecord(buffer, validSize, &activity); - if (status == CUPTI_SUCCESS) { - auto correlationId = - processActivity(pImpl.corrIdToExternId, dataSet, activity); - maxCorrelationId = std::max(maxCorrelationId, correlationId); - } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) { - break; - } else { - throw std::runtime_error("cupti::activityGetNextRecord failed"); - } - } while (true); - - std::free(buffer); - - profiler.correlation.complete(maxCorrelationId); -} - -void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData, - CUpti_CallbackDomain domain, - CUpti_CallbackId cbId, - const void *cbData) { - auto [isRuntimeAPI, isDriverAPI] = matchKernelCbId(cbId); - if (!(isRuntimeAPI || isDriverAPI)) { - return; - } - CuptiProfiler &profiler = - dynamic_cast(CuptiProfiler::instance()); - const CUpti_CallbackData *callbackData = - reinterpret_cast(cbData); - if (callbackData->callbackSite == CUPTI_API_ENTER) { - if (callbackData->context) { - // Valid context and outermost level of the kernel launch - auto scopeId = Scope::getNewScopeId(); - auto scope = Scope(scopeId, callbackData->symbolName); - profilerState.record(scope); - } - profilerState.enterOp(); - } else if (callbackData->callbackSite == CUPTI_API_EXIT) { - profilerState.exitOp(); - profiler.correlation.submit(callbackData->correlationId); - } -} - -void CuptiProfiler::CuptiProfilerPimpl::startOp(const Scope &scope) { - cupti::activityPushExternalCorrelationId( - CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM0, scope.scopeId); -} - -void CuptiProfiler::CuptiProfilerPimpl::stopOp(const Scope &scope) { - uint64_t correlationId; - cupti::activityPopExternalCorrelationId( - CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM0, &correlationId); -} - -void CuptiProfiler::CuptiProfilerPimpl::doStart() { - cupti::activityRegisterCallbacks(allocBuffer, completeBuffer); - cupti::activityEnable(CUPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION); - // Enable driver and runtime activities after external correlation so that - // external correlation id returned is not 0 - cupti::activityEnable(CUPTI_ACTIVITY_KIND_DRIVER); - cupti::activityEnable(CUPTI_ACTIVITY_KIND_RUNTIME); - cupti::activityEnable(CUPTI_ACTIVITY_KIND_FUNCTION); - cupti::activityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL); - // TODO: switch to directly subscribe the APIs and measure overhead - cupti::subscribe(&subscriber, callbackFn, nullptr); - cupti::enableDomain(1, subscriber, CUPTI_CB_DOMAIN_DRIVER_API); - cupti::enableDomain(1, subscriber, CUPTI_CB_DOMAIN_RUNTIME_API); -} - -void CuptiProfiler::CuptiProfilerPimpl::doFlush() { - // cuptiActivityFlushAll returns the activity records associated with all - // contexts/streams. - // This is a blocking call but it doesn’t issue any CUDA synchronization calls - // implicitly thus it’s not guaranteed that all activities are completed on - // the underlying devices. - // We do an "oppurtunistic" synchronization here to try to ensure that all - // activities are completed on the current context. - // If the current context is not set, we don't do any synchronization. - CUcontext cuContext = nullptr; - cuda::ctxGetCurrent(&cuContext); - if (cuContext) - cuda::ctxSynchronize(); - profiler.correlation.flush( - /*maxRetries=*/100, /*sleepMs=*/10, - /*flush=*/[]() { - cupti::activityFlushAll( - /*flag=*/0); - }); - // CUPTI_ACTIVITY_FLAG_FLUSH_FORCED is used to ensure that even incomplete - // activities are flushed so that the next profiling session can start with - // new activities. - cupti::activityFlushAll(/*flag=*/CUPTI_ACTIVITY_FLAG_FLUSH_FORCED); -} - -void CuptiProfiler::CuptiProfilerPimpl::doStop() { - cupti::activityDisable(CUPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION); - cupti::activityDisable(CUPTI_ACTIVITY_KIND_DRIVER); - cupti::activityDisable(CUPTI_ACTIVITY_KIND_RUNTIME); - cupti::activityDisable(CUPTI_ACTIVITY_KIND_FUNCTION); - cupti::activityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL); - cupti::enableDomain(0, subscriber, CUPTI_CB_DOMAIN_DRIVER_API); - cupti::enableDomain(0, subscriber, CUPTI_CB_DOMAIN_RUNTIME_API); - cupti::unsubscribe(subscriber); - cupti::finalize(); -} - -CuptiProfiler::CuptiProfiler() { - pImpl = std::make_unique(*this); -} - -CuptiProfiler::~CuptiProfiler() = default; - -} // namespace proton diff --git a/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp new file mode 100644 index 000000000..68f3f0bea --- /dev/null +++ b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp @@ -0,0 +1,393 @@ +#include "Profiler/Roctracer/RoctracerProfiler.h" +#include "Context/Context.h" +#include "Data/Metric.h" +#include "Driver/GPU/HipApi.h" +#include "Driver/GPU/HsaApi.h" +#include "Driver/GPU/RoctracerApi.h" + +#include "hip/amd_detail/hip_runtime_prof.h" +#include "roctracer/roctracer_ext.h" +#include "roctracer/roctracer_hip.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace proton { + +template <> +thread_local GPUProfiler::ThreadState + GPUProfiler::threadState(RoctracerProfiler::instance()); + +template <> +thread_local std::deque + GPUProfiler::Correlation::externIdQueue{}; + +namespace { + +class DeviceInfo : public Singleton { +public: + DeviceInfo() = default; + int mapDeviceId(int id) { + // Lazy initialization of device offset by calling hip API. + // Otherwise on nvidia platforms, the HSA call will fail because of no + // available libraries. + std::call_once(deviceOffsetFlag, [this]() { initDeviceOffset(); }); + return id - deviceOffset; + } + +private: + void initDeviceOffset() { + int dc = 0; + auto ret = hip::getDeviceCount(&dc); + hsa::iterateAgents( + [](hsa_agent_t agent, void *data) { + auto &offset = *static_cast(data); + int nodeId; + hsa::agentGetInfo( + agent, + static_cast(HSA_AMD_AGENT_INFO_DRIVER_NODE_ID), + &nodeId); + int deviceType; + hsa::agentGetInfo( + agent, static_cast(HSA_AGENT_INFO_DEVICE), + &deviceType); + if ((nodeId < offset) && (deviceType == HSA_DEVICE_TYPE_GPU)) + offset = nodeId; + + return HSA_STATUS_SUCCESS; + }, + &deviceOffset); + } + + std::once_flag deviceOffsetFlag; + int deviceOffset = 0x7fffffff; +}; + +std::shared_ptr +convertActivityToMetric(const roctracer_record_t *activity) { + std::shared_ptr metric; + switch (activity->kind) { + case kHipVdiCommandKernel: { + if (activity->begin_ns < activity->end_ns) { + metric = std::make_shared( + static_cast(activity->begin_ns), + static_cast(activity->end_ns), 1, + static_cast( + DeviceInfo::instance().mapDeviceId(activity->device_id)), + static_cast(DeviceType::HIP)); + } + break; + } + default: + break; + } + return metric; +} + +void processActivityKernel( + RoctracerProfiler::CorrIdToExternIdMap &corrIdToExternId, size_t externId, + std::set &dataSet, const roctracer_record_t *activity, bool isAPI, + bool isGraph) { + if (externId == Scope::DummyScopeId) + return; + auto correlationId = activity->correlation_id; + auto [parentId, numInstances] = corrIdToExternId.at(correlationId); + if (!isGraph) { + for (auto *data : dataSet) { + auto scopeId = parentId; + if (isAPI) + scopeId = data->addScope(parentId, activity->kernel_name); + data->addMetric(scopeId, convertActivityToMetric(activity)); + } + } else { + // Graph kernels + // A single grpah launch can trigger multiple kernels. + // Our solution is to construct the following maps: + // --- Application threads --- + // 1. Graph -> numKernels + // 2. GraphExec -> Graph + // --- Roctracer thread --- + // 3. corrId -> numKernels + for (auto *data : dataSet) { + auto externId = data->addScope(parentId, activity->kernel_name); + data->addMetric(externId, convertActivityToMetric(activity)); + } + } + --numInstances; + if (numInstances == 0) { + corrIdToExternId.erase(correlationId); + } else { + corrIdToExternId[correlationId].second = numInstances; + } + return; +} + +void processActivity(RoctracerProfiler::CorrIdToExternIdMap &corrIdToExternId, + RoctracerProfiler::ApiExternIdSet &apiExternIds, + size_t externId, std::set &dataSet, + const roctracer_record_t *record, bool isAPI, + bool isGraph) { + switch (record->kind) { + case 0x11F1: // Task - kernel enqueued by graph launch + case kHipVdiCommandKernel: { + processActivityKernel(corrIdToExternId, externId, dataSet, record, isAPI, + isGraph); + break; + } + default: + break; + } +} + +} // namespace + +namespace { + +std::pair matchKernelCbId(uint32_t cbId) { + bool isRuntimeApi = false; + bool isDriverApi = false; + switch (cbId) { + // TODO: switch to directly subscribe the APIs + case HIP_API_ID_hipStreamBeginCapture: + case HIP_API_ID_hipStreamEndCapture: + case HIP_API_ID_hipExtLaunchKernel: + case HIP_API_ID_hipExtLaunchMultiKernelMultiDevice: + case HIP_API_ID_hipExtModuleLaunchKernel: + case HIP_API_ID_hipHccModuleLaunchKernel: + case HIP_API_ID_hipLaunchCooperativeKernel: + case HIP_API_ID_hipLaunchCooperativeKernelMultiDevice: + case HIP_API_ID_hipLaunchKernel: + case HIP_API_ID_hipModuleLaunchKernel: + case HIP_API_ID_hipGraphLaunch: + case HIP_API_ID_hipModuleLaunchCooperativeKernel: + case HIP_API_ID_hipModuleLaunchCooperativeKernelMultiDevice: + case HIP_API_ID_hipGraphExecDestroy: + case HIP_API_ID_hipGraphInstantiate: { + isRuntimeApi = true; + break; + } + default: + break; + } + return std::make_pair(isRuntimeApi, isDriverApi); +} + +} // namespace + +struct RoctracerProfiler::RoctracerProfilerPimpl + : public GPUProfiler::GPUProfilerPimplInterface { + RoctracerProfilerPimpl(RoctracerProfiler &profiler) + : GPUProfiler::GPUProfilerPimplInterface(profiler) {} + virtual ~RoctracerProfilerPimpl() = default; + + void doStart() override; + void doFlush() override; + void doStop() override; + + static void apiCallback(uint32_t domain, uint32_t cid, + const void *callbackData, void *arg); + static void activityCallback(const char *begin, const char *end, void *arg); + + static constexpr size_t BufferSize = 64 * 1024 * 1024; + + ThreadSafeMap> + CorrIdToIsHipGraph; + + ThreadSafeMap> + GraphExecToGraph; + + ThreadSafeMap> + GraphToNumInstances; + + ThreadSafeMap> + StreamToCaptureCount; + + ThreadSafeMap> + StreamToCapture; +}; + +void RoctracerProfiler::RoctracerProfilerPimpl::apiCallback( + uint32_t domain, uint32_t cid, const void *callbackData, void *arg) { + auto [isRuntimeAPI, isDriverAPI] = matchKernelCbId(cid); + + if (!(isRuntimeAPI || isDriverAPI)) { + return; + } + + auto &profiler = + dynamic_cast(RoctracerProfiler::instance()); + auto *pImpl = dynamic_cast( + profiler.pImpl.get()); + if (domain == ACTIVITY_DOMAIN_HIP_API) { + const hip_api_data_t *data = (const hip_api_data_t *)(callbackData); + if (data->phase == ACTIVITY_API_PHASE_ENTER) { + // Valid context and outermost level of the kernel launch + auto scopeId = Scope::getNewScopeId(); + threadState.record(scopeId); + threadState.enterOp(scopeId); + size_t numInstances = 1; + if (cid == HIP_API_ID_hipGraphLaunch) { + pImpl->CorrIdToIsHipGraph[data->correlation_id] = true; + hipGraphExec_t GraphExec = data->args.hipGraphLaunch.graphExec; + numInstances = std::numeric_limits::max(); + bool findGraph = false; + if (pImpl->GraphExecToGraph.contain(GraphExec)) { + hipGraph_t Graph = pImpl->GraphExecToGraph[GraphExec]; + if (pImpl->GraphToNumInstances.contain(Graph)) { + numInstances = pImpl->GraphToNumInstances[Graph]; + findGraph = true; + } + } + if (!findGraph) + std::cerr + << "[PROTON] Cannot find graph and it may cause a memory leak." + "To avoid this problem, please start profiling before the " + "graph is created." + << std::endl; + } + profiler.correlation.correlate(data->correlation_id, numInstances); + } else if (data->phase == ACTIVITY_API_PHASE_EXIT) { + switch (cid) { + case HIP_API_ID_hipStreamBeginCapture: { + hipStream_t Stream = data->args.hipStreamBeginCapture.stream; + pImpl->StreamToCaptureCount[Stream] = 0; + pImpl->StreamToCapture[Stream] = true; + break; + } + case HIP_API_ID_hipStreamEndCapture: { + hipGraph_t Graph = *(data->args.hipStreamEndCapture.pGraph); + hipStream_t Stream = data->args.hipStreamEndCapture.stream; + // How many times did we capture a kernel launch for this stream + uint32_t StreamCaptureCount = pImpl->StreamToCaptureCount[Stream]; + pImpl->GraphToNumInstances[Graph] = StreamCaptureCount; + pImpl->StreamToCapture.erase(Stream); + } + case HIP_API_ID_hipLaunchKernel: { + hipStream_t Stream = data->args.hipLaunchKernel.stream; + if (pImpl->StreamToCapture.contain(Stream)) + pImpl->StreamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipExtLaunchKernel: { + hipStream_t Stream = data->args.hipExtLaunchKernel.stream; + if (pImpl->StreamToCapture.contain(Stream)) + pImpl->StreamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipLaunchCooperativeKernel: { + hipStream_t Stream = data->args.hipLaunchCooperativeKernel.stream; + if (pImpl->StreamToCapture.contain(Stream)) + pImpl->StreamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipModuleLaunchKernel: { + hipStream_t Stream = data->args.hipModuleLaunchKernel.stream; + if (pImpl->StreamToCapture.contain(Stream)) + pImpl->StreamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipModuleLaunchCooperativeKernel: { + hipStream_t Stream = data->args.hipModuleLaunchCooperativeKernel.stream; + if (pImpl->StreamToCapture.contain(Stream)) + pImpl->StreamToCaptureCount[Stream]++; + break; + } + case HIP_API_ID_hipGraphInstantiate: { + hipGraph_t Graph = data->args.hipGraphInstantiate.graph; + hipGraphExec_t GraphExec = *(data->args.hipGraphInstantiate.pGraphExec); + pImpl->GraphExecToGraph[GraphExec] = Graph; + break; + } + } + threadState.exitOp(); + // Track outstanding op for flush + profiler.correlation.submit(data->correlation_id); + } + } +} + +void RoctracerProfiler::RoctracerProfilerPimpl::activityCallback( + const char *begin, const char *end, void *arg) { + auto &profiler = + dynamic_cast(RoctracerProfiler::instance()); + auto *pImpl = dynamic_cast( + profiler.pImpl.get()); + auto &dataSet = profiler.dataSet; + auto &correlation = profiler.correlation; + + const roctracer_record_t *record = + reinterpret_cast(begin); + const roctracer_record_t *endRecord = + reinterpret_cast(end); + uint64_t maxCorrelationId = 0; + + while (record != endRecord) { + // Log latest completed correlation id. Used to ensure we have flushed all + // data on stop + maxCorrelationId = + std::max(maxCorrelationId, record->correlation_id); + // TODO(Keren): Roctracer doesn't support cuda graph yet. + auto externId = + correlation.corrIdToExternId.contain(record->correlation_id) + ? correlation.corrIdToExternId.at(record->correlation_id).first + : Scope::DummyScopeId; + auto isAPI = correlation.apiExternIds.contain(externId); + bool isGraph = pImpl->CorrIdToIsHipGraph.contain(record->correlation_id); + processActivity(correlation.corrIdToExternId, correlation.apiExternIds, + externId, dataSet, record, isAPI, isGraph); + // Track correlation ids from the same stream and erase those < + // correlationId + correlation.corrIdToExternId.erase(record->correlation_id); + correlation.apiExternIds.erase(externId); + roctracer::getNextRecord(record, &record); + } + correlation.complete(maxCorrelationId); +} + +void RoctracerProfiler::RoctracerProfilerPimpl::doStart() { + roctracer::enableDomainCallback(ACTIVITY_DOMAIN_HIP_API, apiCallback, + nullptr); + // Activity Records + roctracer_properties_t properties{0}; + properties.buffer_size = BufferSize; + properties.buffer_callback_fun = activityCallback; + roctracer::openPool(&properties); + roctracer::enableDomainActivity(ACTIVITY_DOMAIN_HIP_OPS); + roctracer::start(); +} + +void RoctracerProfiler::RoctracerProfilerPimpl::doFlush() { + // Implement reliable flushing. + // Wait for all dispatched ops to be reported. + std::ignore = hip::deviceSynchronize(); + // If flushing encounters an activity record still being written, flushing + // stops. Use a subsequent flush when the record has completed being written + // to resume the flush. + profiler.correlation.flush( + /*maxRetries=*/100, /*sleepMs=*/10, /*flush=*/ + []() { roctracer::flushActivity(); }); +} + +void RoctracerProfiler::RoctracerProfilerPimpl::doStop() { + roctracer::stop(); + roctracer::disableDomainCallback(ACTIVITY_DOMAIN_HIP_API); + roctracer::disableDomainActivity(ACTIVITY_DOMAIN_HIP_OPS); + roctracer::closePool(); +} + +RoctracerProfiler::RoctracerProfiler() { + pImpl = std::make_unique(*this); +} + +RoctracerProfiler::~RoctracerProfiler() = default; + +} // namespace proton diff --git a/third_party/proton/csrc/lib/Profiler/RoctracerProfiler.cpp b/third_party/proton/csrc/lib/Profiler/RoctracerProfiler.cpp deleted file mode 100644 index a56d23e3e..000000000 --- a/third_party/proton/csrc/lib/Profiler/RoctracerProfiler.cpp +++ /dev/null @@ -1,343 +0,0 @@ -#include "Profiler/RoctracerProfiler.h" -#include "Context/Context.h" -#include "Data/Metric.h" -#include "Driver/GPU/HipApi.h" -#include "Driver/GPU/HsaApi.h" -#include "Driver/GPU/RoctracerApi.h" - -#include "hip/amd_detail/hip_runtime_prof.h" -#include "roctracer/roctracer_ext.h" -#include "roctracer/roctracer_hip.h" - -#include -#include -#include -#include -#include - -#include -#include - -namespace proton { - -template <> -thread_local GPUProfiler::ProfilerState - GPUProfiler::profilerState( - RoctracerProfiler::instance()); - -namespace { - -// Node to device id mapping -int deviceOffset = 0x7fffffff; - -void createDeviceMap() { - int dc = 0; - auto ret = hip::getDeviceCount(&dc); - hsa::iterateAgents( - [](hsa_agent_t agent, void *data) { - auto &deviceOffset = *static_cast(data); - int nodeId; - hsa::agentGetInfo( - agent, - static_cast(HSA_AMD_AGENT_INFO_DRIVER_NODE_ID), - &nodeId); - int deviceType; - hsa::agentGetInfo( - agent, static_cast(HSA_AGENT_INFO_DEVICE), - &deviceType); - if ((nodeId < deviceOffset) && (deviceType == HSA_DEVICE_TYPE_GPU)) - deviceOffset = nodeId; - - return HSA_STATUS_SUCCESS; - }, - &deviceOffset); -}; - -int mapDeviceId(int id) { return id - deviceOffset; }; - -std::shared_ptr -convertActivityToMetric(const roctracer_record_t *activity) { - std::shared_ptr metric; - switch (activity->kind) { - case kHipVdiCommandKernel: { - metric = std::make_shared( - static_cast(activity->begin_ns), - static_cast(activity->end_ns), 1, - static_cast(mapDeviceId(activity->device_id)), - static_cast(DeviceType::HIP)); - break; - } - default: - break; - } - return metric; -} - -void addMetric(size_t scopeId, std::set &dataSet, - const roctracer_record_t *activity) { - for (auto *data : dataSet) { - data->addMetric(scopeId, convertActivityToMetric(activity)); - } -} - -void processActivityKernel(std::mutex &corrIdToExternIdMutex, - std::map &corrIdToExternId, - std::set &dataSet, - const roctracer_record_t *activity) { - auto correlationId = activity->correlation_id; - std::unique_lock lock(corrIdToExternIdMutex); - if (corrIdToExternId.find(correlationId) == corrIdToExternId.end()) - return; - auto externalId = corrIdToExternId[correlationId]; - addMetric(externalId, dataSet, activity); - // Track correlation ids from the same stream and erase those < correlationId - corrIdToExternId.erase(correlationId); -} - -void processActivity(std::mutex &corrIdToExternIdMutex, - std::map &corrIdToExternId, - std::set &dataSet, - const roctracer_record_t *record) { - switch (record->kind) { - case 0x11F1: // Task - kernel enqueued by graph launch - case kHipVdiCommandKernel: { - processActivityKernel(corrIdToExternIdMutex, corrIdToExternId, dataSet, - record); - break; - } - default: - break; - } -} - -} // namespace - -namespace { - -std::pair matchKernelCbId(uint32_t cbId) { - bool isRuntimeApi = false; - bool isDriverApi = false; - switch (cbId) { - // TODO: switch to directly subscribe the APIs - case HIP_API_ID_hipExtLaunchKernel: - case HIP_API_ID_hipExtLaunchMultiKernelMultiDevice: - case HIP_API_ID_hipExtModuleLaunchKernel: - case HIP_API_ID_hipHccModuleLaunchKernel: - case HIP_API_ID_hipLaunchCooperativeKernel: - case HIP_API_ID_hipLaunchCooperativeKernelMultiDevice: - case HIP_API_ID_hipLaunchKernel: - case HIP_API_ID_hipModuleLaunchKernel: - case HIP_API_ID_hipGraphLaunch: - case HIP_API_ID_hipModuleLaunchCooperativeKernel: - case HIP_API_ID_hipModuleLaunchCooperativeKernelMultiDevice: { - isRuntimeApi = true; - break; - } - default: - break; - } - return std::make_pair(isRuntimeApi, isDriverApi); -} -// C++ symbol demangle -static inline const std::string cxxDemangle(const char *symbol) { - size_t funcNameSize; - int status; - if (const char *name = - abi::__cxa_demangle(symbol, NULL, &funcNameSize, &status)) { - std::string ret(name); - std::free(reinterpret_cast(const_cast(name))); - return ret; - } - return std::string(symbol); -} - -const std::string kernelName(uint32_t domain, uint32_t cid, - const void *callback_data) { - std::string name; - if (domain == ACTIVITY_DOMAIN_HIP_API) { - const hip_api_data_t *data = (const hip_api_data_t *)(callback_data); - switch (cid) { - case HIP_API_ID_hipExtLaunchKernel: { - auto ¶ms = data->args.hipExtLaunchKernel; - name = cxxDemangle( - hip::getKernelNameRefByPtr(params.function_address, params.stream)); - } break; - case HIP_API_ID_hipExtLaunchMultiKernelMultiDevice: { - auto ¶ms = - data->args.hipExtLaunchMultiKernelMultiDevice.launchParamsList__val; - name = - cxxDemangle(hip::getKernelNameRefByPtr(params.func, params.stream)); - } break; - case HIP_API_ID_hipExtModuleLaunchKernel: { - auto ¶ms = data->args.hipExtModuleLaunchKernel; - name = cxxDemangle(hip::getKernelNameRef(params.f)); - } break; - case HIP_API_ID_hipHccModuleLaunchKernel: { - auto ¶ms = data->args.hipHccModuleLaunchKernel; - name = cxxDemangle(hip::getKernelNameRef(params.f)); - } break; - case HIP_API_ID_hipLaunchCooperativeKernel: { - auto ¶ms = data->args.hipLaunchCooperativeKernel; - name = cxxDemangle(hip::getKernelNameRefByPtr(params.f, params.stream)); - } break; - case HIP_API_ID_hipLaunchCooperativeKernelMultiDevice: { - auto ¶ms = data->args.hipLaunchCooperativeKernelMultiDevice - .launchParamsList__val; - name = - cxxDemangle(hip::getKernelNameRefByPtr(params.func, params.stream)); - } break; - case HIP_API_ID_hipLaunchKernel: { - auto ¶ms = data->args.hipLaunchKernel; - name = cxxDemangle( - hip::getKernelNameRefByPtr(params.function_address, params.stream)); - } break; - case HIP_API_ID_hipModuleLaunchKernel: { - auto ¶ms = data->args.hipModuleLaunchKernel; - name = cxxDemangle(hip::getKernelNameRef(params.f)); - } break; - case HIP_API_ID_hipGraphLaunch: { - name = "graphLaunch"; - } break; - default: - break; - } - } - return name; -} - -} // namespace - -enum CorrelationDomain { Default, Domain0, Domain1, Count }; - -struct RoctracerProfiler::RoctracerProfilerPimpl - : public GPUProfiler::GPUProfilerPimplInterface { - RoctracerProfilerPimpl(RoctracerProfiler &profiler) - : GPUProfiler::GPUProfilerPimplInterface(profiler) {} - virtual ~RoctracerProfilerPimpl() = default; - - void startOp(const Scope &scope); - void stopOp(const Scope &scope); - - void doStart(); - void doFlush(); - void doStop(); - - static void apiCallback(uint32_t domain, uint32_t cid, - const void *callbackData, void *arg); - static void activityCallback(const char *begin, const char *end, void *arg); - - static constexpr size_t BufferSize = 64 * 1024 * 1024; - - std::mutex corrIdToExternIdMutex; - std::map corrIdToExternId; - inline static thread_local std::deque - externIdQueue[CorrelationDomain::Count]; -}; - -void RoctracerProfiler::RoctracerProfilerPimpl::apiCallback( - uint32_t domain, uint32_t cid, const void *callback_data, void *arg) { - auto [isRuntimeAPI, isDriverAPI] = matchKernelCbId(cid); - if (!(isRuntimeAPI || isDriverAPI)) { - return; - } - auto &profiler = - dynamic_cast(RoctracerProfiler::instance()); - auto &pImpl = dynamic_cast( - *profiler.pImpl); - if (domain == ACTIVITY_DOMAIN_HIP_API) { - const hip_api_data_t *data = (const hip_api_data_t *)(callback_data); - if (data->phase == ACTIVITY_API_PHASE_ENTER) { - // Valid context and outermost level of the kernel launch - const std::string name = kernelName(domain, cid, callback_data); - auto scopeId = Scope::getNewScopeId(); - auto scope = Scope(scopeId, name); - profilerState.record(scope); - profilerState.enterOp(); - if (externIdQueue[CorrelationDomain::Domain0].empty()) - return; - std::unique_lock lock(pImpl.corrIdToExternIdMutex); - pImpl.corrIdToExternId[data->correlation_id] = - externIdQueue[CorrelationDomain::Domain0].back(); - } else if (data->phase == ACTIVITY_API_PHASE_EXIT) { - profilerState.exitOp(); - // Track outstanding op for flush - profiler.correlation.submit(data->correlation_id); - } - } -} - -void RoctracerProfiler::RoctracerProfilerPimpl::activityCallback( - const char *begin, const char *end, void *arg) { - auto &profiler = - dynamic_cast(RoctracerProfiler::instance()); - auto &pImpl = dynamic_cast( - *profiler.pImpl); - auto &dataSet = profiler.dataSet; - auto &correlation = profiler.correlation; - - const roctracer_record_t *record = - reinterpret_cast(begin); - const roctracer_record_t *endRecord = - reinterpret_cast(end); - uint64_t maxCorrelationId = 0; - - while (record != endRecord) { - // Log latest completed correlation id. Used to ensure we have flushed all - // data on stop - maxCorrelationId = - std::max(maxCorrelationId, record->correlation_id); - processActivity(pImpl.corrIdToExternIdMutex, pImpl.corrIdToExternId, - dataSet, record); - roctracer::getNextRecord(record, &record); - } - correlation.complete(maxCorrelationId); -} - -void RoctracerProfiler::RoctracerProfilerPimpl::startOp(const Scope &scope) { - // Track correlation id for the scope - externIdQueue[CorrelationDomain::Domain0].push_back(scope.scopeId); -} - -void RoctracerProfiler::RoctracerProfilerPimpl::stopOp(const Scope &scope) { - externIdQueue[CorrelationDomain::Domain0].pop_back(); -} - -void RoctracerProfiler::RoctracerProfilerPimpl::doStart() { - roctracer::enableDomainCallback(ACTIVITY_DOMAIN_HIP_API, apiCallback, - nullptr); - // Activity Records - roctracer_properties_t properties{0}; - properties.buffer_size = BufferSize; - properties.buffer_callback_fun = activityCallback; - roctracer::openPool(&properties); - roctracer::enableDomainActivity(ACTIVITY_DOMAIN_HIP_OPS); - roctracer::start(); -} - -void RoctracerProfiler::RoctracerProfilerPimpl::doFlush() { - // Implement reliable flushing. - // Wait for all dispatched ops to be reported. - std::ignore = hip::deviceSynchronize(); - // If flushing encounters an activity record still being written, flushing - // stops. Use a subsequent flush when the record has completed being written - // to resume the flush. - profiler.correlation.flush( - /*maxRetries=*/100, /*sleepMs=*/10, /*flush=*/ - []() { roctracer::flushActivity(); }); -} - -void RoctracerProfiler::RoctracerProfilerPimpl::doStop() { - roctracer::stop(); - roctracer::disableDomainCallback(ACTIVITY_DOMAIN_HIP_API); - roctracer::disableDomainActivity(ACTIVITY_DOMAIN_HIP_OPS); - roctracer::closePool(); -} - -RoctracerProfiler::RoctracerProfiler() { - pImpl = std::make_unique(*this); - createDeviceMap(); -} - -RoctracerProfiler::~RoctracerProfiler() = default; - -} // namespace proton diff --git a/third_party/proton/csrc/lib/Session/Session.cpp b/third_party/proton/csrc/lib/Session/Session.cpp index 0265981e5..5ff74f0fc 100644 --- a/third_party/proton/csrc/lib/Session/Session.cpp +++ b/third_party/proton/csrc/lib/Session/Session.cpp @@ -2,8 +2,8 @@ #include "Context/Python.h" #include "Context/Shadow.h" #include "Data/TreeData.h" -#include "Profiler/CuptiProfiler.h" -#include "Profiler/RoctracerProfiler.h" +#include "Profiler/Cupti/CuptiProfiler.h" +#include "Profiler/Roctracer/RoctracerProfiler.h" #include "Utility/String.h" namespace proton { @@ -13,6 +13,9 @@ Profiler *getProfiler(const std::string &profilerName) { if (proton::toLower(profilerName) == "cupti") { return &CuptiProfiler::instance(); } + if (proton::toLower(profilerName) == "cupti_pcsampling") { + return &CuptiProfiler::instance().enablePCSampling(); + } if (proton::toLower(profilerName) == "roctracer") { return &RoctracerProfiler::instance(); } @@ -37,10 +40,21 @@ makeContextSource(const std::string &contextSourceName) { } throw std::runtime_error("Unknown context source: " + contextSourceName); } + +void throwIfSessionNotInitialized( + const std::map> &sessions, + size_t sessionId) { + if (!sessions.count(sessionId)) { + throw std::runtime_error("Session has not been initialized: " + + std::to_string(sessionId)); + } +} + } // namespace void Session::activate() { profiler->start(); + profiler->flush(); profiler->registerData(data.get()); } @@ -76,6 +90,7 @@ void SessionManager::deactivateSession(size_t sessionId) { } void SessionManager::activateSessionImpl(size_t sessionId) { + throwIfSessionNotInitialized(sessions, sessionId); if (activeSessions[sessionId]) return; activeSessions[sessionId] = true; @@ -85,6 +100,7 @@ void SessionManager::activateSessionImpl(size_t sessionId) { } void SessionManager::deActivateSessionImpl(size_t sessionId) { + throwIfSessionNotInitialized(sessions, sessionId); if (!activeSessions[sessionId]) { return; } diff --git a/third_party/proton/proton/hook.py b/third_party/proton/proton/hook.py index a9ec5f36b..94f94bee0 100644 --- a/third_party/proton/proton/hook.py +++ b/third_party/proton/proton/hook.py @@ -6,7 +6,7 @@ class TritonHook: flops_width = [8, 16, 32, 64] - metrics = [f"flops{width}" for width in flops_width] + ["bytes"] + metrics = [f"flops{width}" for width in flops_width] + ["bytes"] + ["flops"] @staticmethod def enter(lazy_dict: LazyDict) -> None: diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index 2bf7938a5..808a1742a 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -1,5 +1,6 @@ import functools import triton +import os from triton._C.libproton import proton as libproton from .hook import register_triton_hook, unregister_triton_hook @@ -19,6 +20,16 @@ def _select_backend() -> str: raise ValueError("No backend is available for the current target.") +def _check_env(backend: str) -> None: + if backend == "roctracer": + hip_device_envs = ["HIP_VISIBLE_DEVICES", "CUDA_VISIBLE_DEVICES"] + for env in hip_device_envs: + if os.getenv(env, None) is not None: + raise ValueError( + f"Proton does not work when the environment variable {env} is set on AMD GPUs. Please unset it and use `ROCR_VISIBLE_DEVICES` instead" + ) + + def start( name: Optional[str] = None, *, @@ -42,7 +53,7 @@ def start( name (str, optional): The name (with path) of the profiling session. If not provided, the default name is "~/proton.hatchet". backend (str, optional): The backend to use for profiling. - Available options are ["cupti"]. + Available options are [None, "cupti", "cupti_pcsampling", "roctracer"]. Defaults to None, which automatically selects the backend matching the current active runtime. context (str, optional): The context to use for profiling. Available options are ["shadow", "python"]. @@ -66,6 +77,8 @@ def start( if backend is None: backend = _select_backend() + _check_env(backend) + set_profiling_on() if hook and hook == "triton": register_triton_hook() diff --git a/third_party/proton/proton/proton.py b/third_party/proton/proton/proton.py index 21267f97d..cbb7a0b6f 100644 --- a/third_party/proton/proton/proton.py +++ b/third_party/proton/proton/proton.py @@ -13,13 +13,15 @@ def parse_arguments(): python -m triton.profiler.proton [options] script.py [script_args] [script_options] """, formatter_class=argparse.RawTextHelpFormatter) parser.add_argument("-n", "--name", type=str, help="Name of the profiling session") - parser.add_argument("-b", "--backend", type=str, help="Profiling backend", default=None, choices=["cupti"]) + parser.add_argument("-b", "--backend", type=str, help="Profiling backend", default=None, + choices=["cupti", "cupti_pcsampling", "roctracer"]) parser.add_argument("-c", "--context", type=str, help="Profiling context", default="shadow", choices=["shadow", "python"]) parser.add_argument("-d", "--data", type=str, help="Profiling data", default="tree", choices=["tree"]) parser.add_argument("-k", "--hook", type=str, help="Profiling hook", default=None, choices=[None, "triton"]) - args, target_args = parser.parse_known_args() - return args, target_args + parser.add_argument('target_args', nargs=argparse.REMAINDER, help='Subcommand and its arguments') + args = parser.parse_args() + return args, args.target_args def is_pytest(script): @@ -38,6 +40,8 @@ def execute_as_main(script, args): original_argv = sys.argv sys.argv = [script] + args + # Append the script's directory in case the script uses relative imports + sys.path.append(os.path.dirname(script_path)) # Execute in the isolated environment try: diff --git a/third_party/proton/proton/scope.py b/third_party/proton/proton/scope.py index 5695b8807..26d946a8c 100644 --- a/third_party/proton/proton/scope.py +++ b/third_party/proton/proton/scope.py @@ -5,7 +5,7 @@ from .flags import get_profiling_on from triton._C.libproton import proton as libproton -_local = threading.local() +thread_local_scopes = threading.local() MetricValueType = Union[float, int] PropertyValueType = Union[float, int, str] @@ -22,7 +22,7 @@ class scope: foo[1,](x, y) ``` - decoarator: + decorator: ```python @proton.scope("test0", {metric_name: metric_value}) def foo(x, y): @@ -36,25 +36,25 @@ def foo(x, y): def __init__(self, name: str, metrics: Optional[dict[str, MetricValueType]] = None, properties: Optional[dict[str, PropertyValueType]] = None) -> None: - self._name = name - self._metrics = metrics - self._properties = properties + self.name = name + self.metrics = metrics + self.properties = properties def __enter__(self): if not get_profiling_on(): return self - self._id = libproton.record_scope() - libproton.enter_scope(self._id, self._name) - if self._metrics: - libproton.add_metrics(self._id, self._metrics) - if self._properties: - libproton.set_properties(self._id, self._properties) + self.id = libproton.record_scope() + libproton.enter_scope(self.id, self.name) + if self.metrics: + libproton.add_metrics(self.id, self.metrics) + if self.properties: + libproton.set_properties(self.id, self.properties) return self def __exit__(self, exc_type, exc_value, traceback) -> None: if not get_profiling_on(): return - libproton.exit_scope(self._id, self._name) + libproton.exit_scope(self.id, self.name) def __call__(self, func): @@ -62,14 +62,14 @@ def __call__(self, func): def wrapper(*args, **kwargs): if get_profiling_on(): id = libproton.record_scope() - libproton.enter_scope(id, self._name) - if self._metrics: - libproton.add_metrics(id, self._metrics) - if self._properties: - libproton.set_properties(id, self._properties) + libproton.enter_scope(id, self.name) + if self.metrics: + libproton.add_metrics(id, self.metrics) + if self.properties: + libproton.set_properties(id, self.properties) ret = func(*args, **kwargs) if get_profiling_on(): - libproton.exit_scope(id, self._name) + libproton.exit_scope(id, self.name) return ret return wrapper @@ -80,9 +80,9 @@ def enter_scope(name: str, *, triton_op: bool = False, metrics: Optional[dict[st if not get_profiling_on(): return -1 id = libproton.record_scope() - if not hasattr(_local, "scopes"): - _local.scopes = [] - _local.scopes.append((id, name)) + if not hasattr(thread_local_scopes, "scopes"): + thread_local_scopes.scopes = [] + thread_local_scopes.scopes.append((id, name)) if triton_op: libproton.enter_op(id, name) else: @@ -97,7 +97,7 @@ def enter_scope(name: str, *, triton_op: bool = False, metrics: Optional[dict[st def exit_scope(triton_op: bool = False) -> int: if not get_profiling_on(): return -1 - id, name = _local.scopes.pop() + id, name = thread_local_scopes.scopes.pop() if triton_op: libproton.exit_op(id, name) else: diff --git a/third_party/proton/proton/viewer.py b/third_party/proton/proton/viewer.py index 3ef3a4c93..fe7c98807 100644 --- a/third_party/proton/proton/viewer.py +++ b/third_party/proton/proton/viewer.py @@ -2,8 +2,12 @@ from collections import namedtuple import json import pandas as pd - -import hatchet as ht +try: + import hatchet as ht + from hatchet.query import NegationQuery +except ImportError: + raise ImportError("Failed to import hatchet. `pip install llnl-hatchet` to get the correct version.") +import numpy as np from triton.profiler.hook import COMPUTE_METADATA_SCOPE_NAME, TritonHook @@ -18,7 +22,9 @@ def match_available_metrics(metrics, raw_metrics): ret.append(raw_metric + " (inc)") break else: - ret = [raw_metrics[0]] + " (inc)" + ret = [raw_metrics[0] + " (inc)"] + if len(ret) == 0: + raise RuntimeError(f"Metric {metric} is not found. Use the --list flag to list available metrics") return ret @@ -37,7 +43,7 @@ def get_min_time_flops(df, device_info): num_sms = device_info[device_type][device_index]["num_sms"] clock_rate = device_info[device_type][device_index]["clock_rate"] for width in TritonHook.flops_width: - idx = df["DeviceId"] == device_index + idx = df["device_id"] == device_index device_frames = df[idx] if f"flops{width}" not in device_frames.columns: continue @@ -66,7 +72,7 @@ def get_min_time_bytes(df, device_info): min_time_bytes = pd.DataFrame(0.0, index=df.index, columns=["min_time"]) for device_type in device_info: for device_index in device_info[device_type]: - idx = df["DeviceId"] == device_index + idx = df["device_id"] == device_index device_frames = df[idx] memory_clock_rate = device_info[device_type][device_index]["memory_clock_rate"] # in khz bus_width = device_info[device_type][device_index]["bus_width"] # in bits @@ -77,71 +83,131 @@ def get_min_time_bytes(df, device_info): FactorDict = namedtuple("FactorDict", ["name", "factor"]) time_factor_dict = FactorDict("time", {"time/s": 1, "time/ms": 1e-3, "time/us": 1e-6, "time/ns": 1e-9}) -flops_factor_dict = FactorDict("flops", {"flop/s": 1, "gflop/s": 1e9, "tflop/s": 1e12}) +avg_time_factor_dict = FactorDict("avg_time", {f"avg_{key}": value for key, value in time_factor_dict.factor.items()}) bytes_factor_dict = FactorDict("bytes", {"byte/s": 1, "gbyte/s": 1e9, "tbyte/s": 1e12}) derivable_metrics = { - **{key: flops_factor_dict - for key in flops_factor_dict.factor.keys()}, **{key: bytes_factor_dict for key in bytes_factor_dict.factor.keys()}, } +# FLOPS have a specific width to their metric +default_flop_factor_dict = {f"flop/s": 1, f"gflop/s": 1e9, f"tflop/s": 1e12} +derivable_metrics.update( + {key: FactorDict("flops", default_flop_factor_dict) + for key in default_flop_factor_dict.keys()}) +for width in TritonHook.flops_width: + factor_name = f"flops{width}" + factor_dict = {f"flop{width}/s": 1, f"gflop{width}/s": 1e9, f"tflop{width}/s": 1e12} + derivable_metrics.update({key: FactorDict(factor_name, factor_dict) for key in factor_dict.keys()}) + def derive_metrics(gf, metrics, raw_metrics, device_info): derived_metrics = [] - original_metrics = [] - time_metric_name = match_available_metrics([time_factor_dict.name], raw_metrics)[0] - time_unit = (time_factor_dict.name + "/" + time_metric_name.split("(")[1].split(")")[0]) + internal_frame_indices = gf.dataframe["device_id"].isna() + + def get_time_seconds(df): + time_metric_name = match_available_metrics([time_factor_dict.name], raw_metrics)[0] + time_unit = (time_factor_dict.name + "/" + time_metric_name.split("(")[1].split(")")[0]) + return df[time_metric_name] * time_factor_dict.factor[time_unit] + for metric in metrics: if metric == "util": # Tensor core only min_time_bytes = get_min_time_bytes(gf.dataframe, device_info) min_time_flops = get_min_time_flops(gf.dataframe, device_info) - time_sec = gf.dataframe[time_metric_name] * (time_factor_dict.factor[time_unit] / - time_factor_dict.factor["time/s"]) + time_sec = get_time_seconds(gf.dataframe) gf.dataframe["util (inc)"] = min_time_flops["min_time"].combine(min_time_bytes["min_time"], max) / time_sec + gf.dataframe.loc[internal_frame_indices, "util (inc)"] = np.nan derived_metrics.append("util (inc)") - elif metric in derivable_metrics: - deriveable_metric = derivable_metrics[metric] - metric_name = deriveable_metric.name - metric_factor_dict = deriveable_metric.factor + elif metric in derivable_metrics: # flop/s, byte/s + derivable_metric = derivable_metrics[metric] + metric_name = derivable_metric.name + metric_factor_dict = derivable_metric.factor matched_metric_name = match_available_metrics([metric_name], raw_metrics)[0] - gf.dataframe[f"{metric} (inc)"] = (gf.dataframe[matched_metric_name] / - (gf.dataframe[time_metric_name] * time_factor_dict.factor[time_unit]) / + gf.dataframe[f"{metric} (inc)"] = (gf.dataframe[matched_metric_name] / (get_time_seconds(gf.dataframe)) / metric_factor_dict[metric]) derived_metrics.append(f"{metric} (inc)") elif metric in time_factor_dict.factor: metric_time_unit = time_factor_dict.name + "/" + metric.split("/")[1] - gf.dataframe[f"{metric} (inc)"] = gf.dataframe[time_metric_name] * ( - time_factor_dict.factor[time_unit] / time_factor_dict.factor[metric_time_unit]) + gf.dataframe[f"{metric} (inc)"] = (get_time_seconds(gf.dataframe) / + time_factor_dict.factor[metric_time_unit]) + derived_metrics.append(f"{metric} (inc)") + elif metric in avg_time_factor_dict.factor: + metric_time_unit = avg_time_factor_dict.name + "/" + metric.split("/")[1] + gf.dataframe[f"{metric} (inc)"] = (get_time_seconds(gf.dataframe) / gf.dataframe['count'] / + avg_time_factor_dict.factor[metric_time_unit]) + gf.dataframe.loc[internal_frame_indices, f"{metric} (inc)"] = np.nan derived_metrics.append(f"{metric} (inc)") else: - original_metrics.append(metric) + metric_name_and_unit = metric.split("/") + metric_name = metric_name_and_unit[0] + if len(metric_name_and_unit) > 1: + metric_unit = metric_name_and_unit[1] + if metric_unit != "%": + raise ValueError(f"Unsupported unit {metric_unit}") + matched_metric_name = match_available_metrics([metric_name], raw_metrics)[0] + single_frame = gf.dataframe[matched_metric_name] + total = gf.dataframe[matched_metric_name].iloc[0] + gf.dataframe[f"{metric_name}/% (inc)"] = (single_frame / total) * 100.0 + derived_metrics.append(f"{metric_name}/% (inc)") + else: + matched_metric_name = match_available_metrics([metric_name], raw_metrics)[0] + derived_metrics.append(matched_metric_name) + return derived_metrics - if original_metrics: - original_metrics = match_available_metrics(original_metrics, raw_metrics) - return derived_metrics + original_metrics + +def format_frames(gf, format): + if format == "file_function_line": + gf.dataframe["name"] = gf.dataframe["name"].apply(lambda x: x.split("/")[-1]) + elif format == "function_line": + gf.dataframe["name"] = gf.dataframe["name"].apply(lambda x: x.split(":")[-1]) + elif format == "file_function": + gf.dataframe["name"] = gf.dataframe["name"].apply(lambda x: x.split("/")[-1].split("@")[0]) + return gf + + +def filter_frames(gf, include=None, exclude=None, threshold=None, metric=None): + if include: + query = f""" +MATCH ("*")->(".", p)->("*") +WHERE p."name" =~ "{include}" +""" + gf = gf.filter(query, squash=True) + if exclude: + inclusion_query = f""" +MATCH (".", p)->("*") +WHERE p."name" =~ "{exclude}" +""" + query = NegationQuery(inclusion_query) + gf = gf.filter(query, squash=True) + # filter out metadata computation + query = [{"name": f"^(?!{COMPUTE_METADATA_SCOPE_NAME}).*"}] + gf = gf.filter(query, squash=True) + if threshold: + query = ["*", {metric: f">= {threshold}"}] + gf = gf.filter(query, squash=True) + return gf -def parse(metrics, filename, include, exclude, threshold, depth): +def parse(metrics, filename, include=None, exclude=None, threshold=None, depth=100, format=None): with open(filename, "r") as f: gf, raw_metrics, device_info = get_raw_metrics(f) + gf = format_frames(gf, format) assert len(raw_metrics) > 0, "No metrics found in the input file" gf.update_inclusive_columns() metrics = derive_metrics(gf, metrics, raw_metrics, device_info) - if include or exclude: - # make regex do negative match - name_filter = f"^(?!{exclude}).*" if exclude else include - query = ["*", {"name": name_filter}] - gf = gf.filter(query, squash=True) - # filter out metadata computation - query = [{"name": f"^(?!{COMPUTE_METADATA_SCOPE_NAME}).*"}] - gf = gf.filter(query, squash=True) - if threshold: - # TODO: generalize to support multiple metrics - query = ["*", {metrics[0]: f">= {threshold}"}] - gf = gf.filter(query, squash=True) + # TODO: generalize to support multiple metrics, not just the first one + gf = filter_frames(gf, include, exclude, threshold, metrics[0]) print(gf.tree(metric_column=metrics, expand_name=True, depth=depth, render_header=False)) + emit_warnings(gf, metrics) + + +def emit_warnings(gf, metrics): + if "bytes (inc)" in metrics: + byte_values = gf.dataframe["bytes (inc)"].values + min_byte_value = np.nanmin(byte_values) + if min_byte_value < 0: + print("Warning: Negative byte values detected, this is usually the result of a datatype overflow\n") def show_metrics(file_name): @@ -152,7 +218,6 @@ def show_metrics(file_name): for raw_metric in raw_metrics: raw_metric_no_unit = raw_metric.split("(")[0].strip().lower() print(f"- {raw_metric_no_unit}") - return def main(): @@ -167,9 +232,11 @@ def main(): help="""List available metrics. Metric names are case insensitive and ignore units. Derived metrics can be created when source metrics are available. - time/s, time/ms, time/us, time/ns: time -- flop/s, gflop/s, tflop/s: flops / time +- avg_time/s, avg_time/ms, avg_time/us, avg_time/ns: time / count +- flop[<8/16/32/64>]/s, gflop[<8/16/32/64>]/s, tflop[<8/16/32/64>]/s: flops / time - byte/s, gbyte/s, tbyte/s: bytes / time -- util: max(sum(flops) / peak_flops_time, bytes / peak_bandwidth_time)) +- util: max(sum(flops) / peak_flops_time, sum(bytes) / peak_bandwidth_time) +- /%%: frame(metric) / sum(metric). Only availble for inclusive metrics (e.g. time) """, ) argparser.add_argument( @@ -188,16 +255,26 @@ def main(): "--include", type=str, default=None, - help="Include frames(kernels) that match the given regular expression", + help= + """Find frames that match the given regular expression and return all nodes in the paths that pass through the matching frames. +For example, the following command will display all paths that contain frames that contains "test": +``` +proton-viewer -i ".*test.*" path/to/file.json +``` +""", ) argparser.add_argument( "-e", "--exclude", type=str, default=None, - help="Exclude frames(kernels) that match the given regular expression", + help="""Exclude frames that match the given regular expression and their children. +For example, the following command will exclude all paths that contain frames that contains "test": +``` +proton-viewer -e ".*test.*" path/to/file.json +``` +""", ) - argparser.add_argument( "-t", "--threshold", @@ -206,7 +283,6 @@ def main(): help= "Exclude frames(kernels) whose metrics are below the given threshold. This filter only applies on the first metric.", ) - argparser.add_argument( "-d", "--depth", @@ -214,6 +290,14 @@ def main(): default=100, help="The depth of the tree to display", ) + argparser.add_argument( + "-f", "--format", type=str, choices=["full", "file_function_line", "function_line", "file_function"], + default="full", help="""Formatting the frame name. +- full: include the path, file name, function name and line number. +- file_function_line: include the file name, function name and line number. +- function_line: include the function name and line number. +- file_function: include the file name and function name. +""") args, target_args = argparser.parse_known_args() assert len(target_args) == 1, "Must specify a file to read" @@ -224,12 +308,13 @@ def main(): exclude = args.exclude threshold = args.threshold depth = args.depth + format = args.format if include and exclude: raise ValueError("Cannot specify both include and exclude") if args.list: show_metrics(file_name) elif metrics: - parse(metrics, file_name, include, exclude, threshold, depth) + parse(metrics, file_name, include, exclude, threshold, depth, format) if __name__ == "__main__": diff --git a/third_party/proton/test/example_cuda.json b/third_party/proton/test/example_cuda.json index 9e148ff79..445f0e224 100644 --- a/third_party/proton/test/example_cuda.json +++ b/third_party/proton/test/example_cuda.json @@ -8,10 +8,10 @@ "type": "function" }, "metrics": { - "Count": 1, - "DeviceId": "1", - "DeviceType": "CUDA", - "Time (ns)": 204800, + "count": 10, + "device_id": "1", + "device_type": "CUDA", + "time (ns)": 204800, "flops8": 1e11, "bytes": 1e8 } @@ -23,10 +23,10 @@ "type": "function" }, "metrics": { - "Count": 1, - "DeviceId": "0", - "DeviceType": "CUDA", - "Time (ns)": 204800, + "count": 1, + "device_id": "0", + "device_type": "CUDA", + "time (ns)": 204800, "flops8": 1e10, "bytes": 1e7 } @@ -37,8 +37,8 @@ "type": "function" }, "metrics": { - "Count": 0, - "Time (ns)": 0, + "count": 0, + "time (ns)": 0, "flops8": 0, "bytes": 0 } diff --git a/third_party/proton/test/example_frame.json b/third_party/proton/test/example_frame.json new file mode 100644 index 000000000..0069476fb --- /dev/null +++ b/third_party/proton/test/example_frame.json @@ -0,0 +1,58 @@ +[ + { + "children": [ + { + "children": [ + { + "children": [], + "frame": { + "name": "/home/user/projects/example.py/test.py:foo@1", + "type": "function" + }, + "metrics": { + "count": 1, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 204800 + } + } + ], + "frame": { + "name": "test0" + }, + "metrics": {} + }, + { + "children": [], + "frame": { + "name": "test1" + }, + "metrics": { + "count": 1, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 204800 + } + } + ], + "frame": { + "name": "ROOT", + "type": "function" + }, + "metrics": { + "count": 0, + "time (ns)": 0 + } + }, + { + "HIP": { + "0": { + "arch": "gfx90a", + "bus_width": 4096, + "clock_rate": 1700000, + "memory_clock_rate": 1600000, + "num_sms": 104 + } + } + } +] diff --git a/third_party/proton/test/example_hip.json b/third_party/proton/test/example_hip.json index 2fcfad3c5..68538706c 100644 --- a/third_party/proton/test/example_hip.json +++ b/third_party/proton/test/example_hip.json @@ -8,10 +8,10 @@ "type": "function" }, "metrics": { - "Count": 1, - "DeviceId": "1", - "DeviceType": "HIP", - "Time (ns)": 204800, + "count": 1, + "device_id": "1", + "device_type": "HIP", + "time (ns)": 204800, "flops8": 1e11, "bytes": 1e8 } @@ -23,10 +23,10 @@ "type": "function" }, "metrics": { - "Count": 1, - "DeviceId": "0", - "DeviceType": "HIP", - "Time (ns)": 204800, + "count": 1, + "device_id": "0", + "device_type": "HIP", + "time (ns)": 204800, "flops8": 1e10, "bytes": 1e7 } @@ -37,8 +37,8 @@ "type": "function" }, "metrics": { - "Count": 0, - "Time (ns)": 0, + "count": 0, + "time (ns)": 0, "flops8": 0, "bytes": 0 } diff --git a/third_party/proton/test/helper.py b/third_party/proton/test/helper.py index 7cfc7a452..4591aeb54 100644 --- a/third_party/proton/test/helper.py +++ b/third_party/proton/test/helper.py @@ -1,14 +1,9 @@ -import triton import triton.profiler as proton -import triton.language as tl import torch import sys - -@triton.jit -def custom_add(a_ptr): - tl.store(a_ptr, 1.0) +from helper_kernels import custom_add def main(): diff --git a/third_party/proton/test/helper_kernels.py b/third_party/proton/test/helper_kernels.py new file mode 100644 index 000000000..7a128dbac --- /dev/null +++ b/third_party/proton/test/helper_kernels.py @@ -0,0 +1,7 @@ +import triton.language as tl +import triton + + +@triton.jit +def custom_add(a_ptr): + tl.store(a_ptr, 1.0) diff --git a/third_party/proton/test/test_api.py b/third_party/proton/test/test_api.py index 7f1d20dc8..713572c4f 100644 --- a/third_party/proton/test/test_api.py +++ b/third_party/proton/test/test_api.py @@ -146,3 +146,28 @@ def foo(): assert child["metrics"]["a"] == 1.0 elif child["frame"]["name"] == "test0": assert child["metrics"]["a"] == "1" + + +def test_throw(): + # Catch an exception thrown by c++ + session_id = 100 + with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: + activate_error = "" + try: + session_id = proton.start(f.name.split(".")[0]) + proton.activate(session_id + 1) + except Exception as e: + activate_error = str(e) + finally: + proton.finalize() + assert "Session has not been initialized: " + str(session_id + 1) in activate_error + + deactivate_error = "" + try: + session_id = proton.start(f.name.split(".")[0]) + proton.deactivate(session_id + 1) + except Exception as e: + deactivate_error = str(e) + finally: + proton.finalize() + assert "Session has not been initialized: " + str(session_id + 1) in deactivate_error diff --git a/third_party/proton/test/test_cmd.py b/third_party/proton/test/test_cmd.py index 333557013..fa3331c02 100644 --- a/third_party/proton/test/test_cmd.py +++ b/third_party/proton/test/test_cmd.py @@ -22,9 +22,10 @@ def test_exec(mode): ret = subprocess.check_call(["python3", "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) elif mode == "pytest": - ret = subprocess.check_call(["proton", "-n", name, "pytest", helper_file], stdout=subprocess.DEVNULL) + ret = subprocess.check_call(["proton", "-n", name, "pytest", "-k", "test_main", helper_file], + stdout=subprocess.DEVNULL) assert ret == 0 data = json.load(f, ) kernels = data[0]["children"] assert len(kernels) == 2 - assert kernels[1]["frame"]["name"] == "test" + assert kernels[0]["frame"]["name"] == "test" or kernels[1]["frame"]["name"] == "test" diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py index a64c6ab8b..13cb9bd99 100644 --- a/third_party/proton/test/test_profile.py +++ b/third_party/proton/test/test_profile.py @@ -9,6 +9,10 @@ import triton.language as tl +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + @pytest.mark.parametrize("context", ["shadow", "python"]) def test_torch(context): with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: @@ -21,6 +25,7 @@ def test_torch(context): if context == "shadow": assert len(data[0]["children"]) == 1 assert data[0]["children"][0]["frame"]["name"] == "test" + assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 elif context == "python": assert len(data[0]["children"]) == 1 # The last frame is the torch kernel @@ -56,6 +61,59 @@ def foo(x, y): assert data[0]["children"][1]["frame"]["name"] == "test2" +def test_cudagraph(): + stream = torch.cuda.Stream() + torch.cuda.set_stream(stream) + + @triton.jit + def foo(x, y, z): + tl.store(z, tl.load(y) + tl.load(x)) + + def fn(): + a = torch.ones((2, 2), device="cuda") + b = torch.ones((2, 2), device="cuda") + c = a + b + foo[(1, )](a, b, c) + + with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: + proton.start(f.name.split(".")[0], context="shadow") + + # warmup + # four kernels + fn() + + # no kernels + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(10): + fn() + + proton.enter_scope("test") + g.replay() + g.reset() + torch.cuda.synchronize() + proton.exit_scope() + proton.finalize() + + data = json.load(f) + # CUDA/HIP graph may also invoke additional kernels to reset outputs + # {torch.ones, add, foo, test} + assert len(data[0]["children"]) >= 4 + # find the test frame + test_frame = None + for child in data[0]["children"]: + if child["frame"]["name"] == "test": + test_frame = child + break + assert test_frame is not None + # {torch.ones, add, foo} + if is_hip(): + assert len(test_frame["children"]) >= 2 + else: + assert len(test_frame["children"]) >= 3 + assert test_frame["children"][0]["metrics"]["time (ns)"] > 0 + + def test_metrics(): @triton.jit @@ -139,3 +197,54 @@ def foo(x, size: tl.constexpr, y): assert data[0]["children"][0]["frame"]["name"] == "test0" assert data[0]["children"][0]["children"][0]["frame"]["name"] == "foo_test_1ctas_1elems" assert data[0]["children"][0]["children"][0]["metrics"]["flops32"] == 1.0 + assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0 + + +def test_pcsampling(): + if is_hip(): + pytest.skip("HIP backend does not support pc sampling") + + import os + if os.environ.get("PROTON_SKIP_PC_SAMPLING_TEST", "0") == "1": + pytest.skip("PC sampling test is disabled") + + @triton.jit + def foo(x, y, size: tl.constexpr): + offs = tl.arange(0, size) + for _ in range(1000): + tl.store(y + offs, tl.load(x + offs)) + + with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: + proton.start(f.name.split(".")[0], hook="triton", backend="cupti_pcsampling") + with proton.scope("init"): + x = torch.ones((1024, ), device="cuda", dtype=torch.float32) + y = torch.zeros_like(x) + with proton.scope("test"): + foo[(1, )](x, y, x.size()[0], num_warps=4) + proton.finalize() + data = json.load(f) + init_frame = data[0]["children"][0] + test_frame = data[0]["children"][1] + # With line mapping + assert "foo" in test_frame["children"][0]["frame"]["name"] + assert test_frame["children"][0]["children"][0]["metrics"]["num_samples"] > 0 + assert "@" in test_frame["children"][0]["children"][0]["frame"]["name"] + # Without line mapping + assert "elementwise" in init_frame["children"][0]["frame"]["name"] + assert init_frame["children"][0]["metrics"]["num_samples"] > 0 + + +def test_deactivate(): + with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f: + session_id = proton.start(f.name.split(".")[0], hook="triton") + proton.deactivate(session_id) + torch.randn((10, 10), device="cuda") + proton.activate(session_id) + torch.zeros((10, 10), device="cuda") + proton.deactivate(session_id) + proton.finalize() + data = json.load(f) + # Root shouldn't have device id + assert "device_id" not in data[0]["metrics"] + assert len(data[0]["children"]) == 1 + assert "device_id" in data[0]["children"][0]["metrics"] diff --git a/third_party/proton/test/test_viewer.py b/third_party/proton/test/test_viewer.py index 63a74b06c..b2d4d39f9 100644 --- a/third_party/proton/test/test_viewer.py +++ b/third_party/proton/test/test_viewer.py @@ -1,10 +1,12 @@ +import pytest import subprocess -from triton.profiler.viewer import get_min_time_flops, get_min_time_bytes, get_raw_metrics +from triton.profiler.viewer import get_min_time_flops, get_min_time_bytes, get_raw_metrics, format_frames, derive_metrics, filter_frames import numpy as np file_path = __file__ cuda_example_file = file_path.replace("test_viewer.py", "example_cuda.json") hip_example_file = file_path.replace("test_viewer.py", "example_hip.json") +frame_example_file = file_path.replace("test_viewer.py", "example_frame.json") def test_help(): @@ -13,12 +15,45 @@ def test_help(): assert ret == 0 +@pytest.mark.parametrize("option", ["full", "file_function_line", "function_line", "file_function"]) +def test_format_frames(option): + with open(frame_example_file, "r") as f: + gf, _, _ = get_raw_metrics(f) + gf = format_frames(gf, option) + if option == "full": + idx = gf.dataframe["name"] == "/home/user/projects/example.py/test.py:foo@1" + elif option == "file_function_line": + idx = gf.dataframe["name"] == "test.py:foo@1" + elif option == "function_line": + idx = gf.dataframe["name"] == "foo@1" + elif option == "file_function": + idx = gf.dataframe["name"] == "test.py:foo" + assert idx.sum() == 1 + + +@pytest.mark.parametrize("option", ["include", "exclude"]) +def test_filter_frames(option): + include = "" + exclude = "" + with open(frame_example_file, "r") as f: + gf, _, _ = get_raw_metrics(f) + if option == "include": + include = ".*test0.*" + elif option == "exclude": + exclude = ".*test1.*" + gf = filter_frames(gf, include=include, exclude=exclude) + idx = gf.dataframe["name"] == "test1" + assert idx.sum() == 0 + idx = gf.dataframe["name"] == "test0" + assert idx.sum() == 1 + + def test_min_time_flops(): with open(cuda_example_file, "r") as f: gf, _, device_info = get_raw_metrics(f) ret = get_min_time_flops(gf.dataframe, device_info) - device0_idx = gf.dataframe["DeviceId"] == "0" - device1_idx = gf.dataframe["DeviceId"] == "1" + device0_idx = gf.dataframe["device_id"] == "0" + device1_idx = gf.dataframe["device_id"] == "1" # sm89 np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[0.000025]], atol=1e-5) # sm90 @@ -26,8 +61,8 @@ def test_min_time_flops(): with open(hip_example_file, "r") as f: gf, _, device_info = get_raw_metrics(f) ret = get_min_time_flops(gf.dataframe, device_info) - device0_idx = gf.dataframe["DeviceId"] == "0" - device1_idx = gf.dataframe["DeviceId"] == "1" + device0_idx = gf.dataframe["device_id"] == "0" + device1_idx = gf.dataframe["device_id"] == "1" # MI200 np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[0.000026]], atol=1e-5) # MI300 @@ -38,8 +73,8 @@ def test_min_time_bytes(): with open(cuda_example_file, "r") as f: gf, _, device_info = get_raw_metrics(f) ret = get_min_time_bytes(gf.dataframe, device_info) - device0_idx = gf.dataframe["DeviceId"] == "0" - device1_idx = gf.dataframe["DeviceId"] == "1" + device0_idx = gf.dataframe["device_id"] == "0" + device1_idx = gf.dataframe["device_id"] == "1" # sm89 np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[9.91969e-06]], atol=1e-6) # sm90 @@ -47,9 +82,64 @@ def test_min_time_bytes(): with open(hip_example_file, "r") as f: gf, _, device_info = get_raw_metrics(f) ret = get_min_time_bytes(gf.dataframe, device_info) - device0_idx = gf.dataframe["DeviceId"] == "0" - device1_idx = gf.dataframe["DeviceId"] == "1" + device0_idx = gf.dataframe["device_id"] == "0" + device1_idx = gf.dataframe["device_id"] == "1" # MI200 np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[6.10351e-06]], atol=1e-6) # MI300 np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[1.93378e-05]], atol=1e-6) + + +def derivation_metrics_test(metrics, expected_data, sample_file, rtol=1e-7, atol=1e-6): + with open(sample_file, "r") as f: + gf, raw_metrics, device_info = get_raw_metrics(f) + assert len(raw_metrics) > 0, "No metrics found in the input file" + gf.update_inclusive_columns() + derived_metrics = derive_metrics(gf, metrics, raw_metrics, device_info) + for derived_metric in derived_metrics: + np.testing.assert_allclose(gf.dataframe[derived_metric].to_numpy(), expected_data[derived_metric], + rtol=rtol, atol=atol) + + +def test_avg_time_derivation(): + derivation_metrics_test( + metrics=["avg_time/s", "avg_time/ms", "avg_time/us", "avg_time/ns"], expected_data={ + 'avg_time/s (inc)': [np.nan, 0.0000205, 0.000205], 'avg_time/ms (inc)': [np.nan, 0.02048, 0.2048], + 'avg_time/us (inc)': [np.nan, 20.48, 204.8], 'avg_time/ns (inc)': [np.nan, 20480.0, 204800.0] + }, sample_file=cuda_example_file) + + +def test_util(): + derivation_metrics_test(metrics=["util"], expected_data={ + 'util (inc)': [np.nan, 0.247044, 0.147830], + }, sample_file=cuda_example_file) + + +def test_time_derivation(): + derivation_metrics_test( + metrics=["time/s", "time/ms", "time/us", "time/ns"], expected_data={ + 'time/s (inc)': [0.0004096, 0.0002048, 0.0002048], + 'time/ms (inc)': [0.4096, 0.2048, 0.2048], + 'time/us (inc)': [409.6, 204.8, 204.8], + 'time/ns (inc)': [409600.0, 204800.0, 204800.0], + 'time/% (inc)': [100.0, 50.0, 50.0], + }, sample_file=cuda_example_file) + + +def test_bytes_derivation(): + derivation_metrics_test( + metrics=["byte/s", "gbyte/s", "tbyte/s"], expected_data={ + 'byte/s (inc)': [2.68554687e+11, 4.88281250e+11, 4.88281250e+10], 'gbyte/s (inc)': + [268.5546875, 488.28125, 48.828125], 'tbyte/s (inc)': [0.26855469, 0.48828125, 0.04882812] + }, sample_file=cuda_example_file) + + +def test_flops_derivation(): + derivation_metrics_test( + metrics=["flop8/s", "gflop8/s", "tflop8/s"], + expected_data={ + 'flop8/s (inc)': [2.68554687e+14, 4.88281250e+14, 4.88281250e+13], 'gflop8/s (inc)': + [268554.6875, 488281.25, 48828.125], 'tflop8/s (inc)': [268.554687, 488.28125, 48.828125] + }, + sample_file=cuda_example_file, + ) diff --git a/third_party/proton/tutorials/dynamic_net.py b/third_party/proton/tutorials/dynamic_net.py index a1a82b53e..5793bebd0 100644 --- a/third_party/proton/tutorials/dynamic_net.py +++ b/third_party/proton/tutorials/dynamic_net.py @@ -85,13 +85,14 @@ def run(): argparser.add_argument("--profile", action="store_true") argparser.add_argument("--mode", default="torch", choices=["torch", "torchinductor"]) argparser.add_argument("--context", default="shadow", choices=["shadow", "python"]) +argparser.add_argument("--backend", default=None, choices=["cupti", "roctracer", "cupti_pcsampling"]) args = argparser.parse_args() mode = args.mode if args.profile: - func = proton.profile(run, name="dynamic_net", context=args.context) + func = proton.profile(run, name="dynamic_net", context=args.context, backend=args.backend) else: func = run diff --git a/third_party/proton/tutorials/matmul.py b/third_party/proton/tutorials/matmul.py index 1b5424af4..5ee8d9f3e 100644 --- a/third_party/proton/tutorials/matmul.py +++ b/third_party/proton/tutorials/matmul.py @@ -26,32 +26,13 @@ def metadata_fn( num_stages = metadata.num_stages cluster_x, cluster_y, cluster_z = metadata.cluster_dims shared_memory = metadata.shared + M, K = args["a_ptr"].shape + K, N = args["b_ptr"].shape return { "name": - f"matmul_____" - } - - -def matmul_metrics_fn( - grid_x: int, - grid_y: int, - grid_z: int, - num_warps: int, - num_ctas: int, - cluster_x: int, - cluster_y: int, - cluster_z: int, - shared_memory: int, - stream: int, - function: int, - metadata, - *args, -): - M, K = args[0].shape - K, N = args[1].shape - return { + f"matmul_____", "flops": 2 * M * N * K, - "bytes": (M * N + N * K + K * M) * args[0].element_size(), + "bytes": (M * N + N * K + K * M) * args["a_ptr"].element_size(), } @@ -255,6 +236,12 @@ def grid(META): return c +argparser = argparse.ArgumentParser() +argparser.add_argument("--profile", action="store_true") +argparser.add_argument("--cudagraph", action="store_true", default=False) +args = argparser.parse_args() + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot @@ -287,7 +274,11 @@ def benchmark(M, N, K, provider): def cublas_matmul(a, b): torch.matmul(a, b) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: cublas_matmul(a, b), quantiles=quantiles) + if args.cudagraph: + ms = triton.testing.do_bench_cudagraph(lambda: cublas_matmul(a, b)) + min_ms = max_ms = ms + else: + ms, min_ms, max_ms = triton.testing.do_bench(lambda: cublas_matmul(a, b), quantiles=quantiles) if provider == "triton": def enter_autotune(args, reset_only=False): @@ -301,7 +292,11 @@ def exit_autotune(args, exception): matmul_kernel.pre_hook = enter_autotune matmul_kernel.post_hook = exit_autotune with proton.scope("triton"): - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + if args.cudagraph: + ms = triton.testing.do_bench_cudagraph(lambda: matmul(a, b)) + min_ms = max_ms = ms + else: + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) def perf(ms): return 2 * M * N * K * 1e-12 / (ms * 1e-3) @@ -309,10 +304,6 @@ def perf(ms): return perf(ms), perf(max_ms), perf(min_ms) -argparser = argparse.ArgumentParser() -argparser.add_argument("--profile", action="store_true") -args = argparser.parse_args() - if args.profile: proton.start("matmul", hook="triton") benchmark.run(show_plots=True, print_data=True) diff --git a/third_party/xpu/CMakeLists.txt b/third_party/xpu/CMakeLists.txt new file mode 100644 index 000000000..fdb9700b0 --- /dev/null +++ b/third_party/xpu/CMakeLists.txt @@ -0,0 +1,10 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) # For #include "Dialect/TritonXPU/IR/Dialect.h" +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) + +add_subdirectory(device) +add_subdirectory(include) +add_subdirectory(lib) + +if(TRITON_BUILD_PYTHON_MODULE) + add_triton_plugin(TritonXPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_xpu.cc) #LINK_LIBS TritonToTritonXPU TritonXPUToLLVM) +endif() diff --git a/third_party/xpu/__init__.py b/third_party/xpu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/xpu/backend/__init__.py b/third_party/xpu/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/xpu/backend/_inductor_201/__init__.py b/third_party/xpu/backend/_inductor_201/__init__.py new file mode 100644 index 000000000..6826bf310 --- /dev/null +++ b/third_party/xpu/backend/_inductor_201/__init__.py @@ -0,0 +1,2 @@ +from . import codegen +from . import triton_ops diff --git a/third_party/xpu/backend/_inductor_201/codegen/__init__.py b/third_party/xpu/backend/_inductor_201/codegen/__init__.py new file mode 100644 index 000000000..ff2e2c1c3 --- /dev/null +++ b/third_party/xpu/backend/_inductor_201/codegen/__init__.py @@ -0,0 +1 @@ +from . import triton diff --git a/third_party/xpu/backend/_inductor_201/codegen/triton.py b/third_party/xpu/backend/_inductor_201/codegen/triton.py new file mode 100644 index 000000000..3751467fa --- /dev/null +++ b/third_party/xpu/backend/_inductor_201/codegen/triton.py @@ -0,0 +1,407 @@ +# Reuse Across Files +import os +from torch._inductor.virtualized import V +from torch._inductor.codegen.common import ( + IndentedBuffer, + SizeArg, + PythonPrinter, +) + +from torch._dynamo import config as dynamo_config + +# Reuse Within The Same File +from torch._inductor.codegen.triton import ( + signature_of, + config_of, +) +from torch._inductor.codegen import triton + + +# ===-------------------- For XPytorch Inductor -----------------------=== +# Modifed Base Pytorch(v2.0.1) torch/_inductor/codegen/triton.py::TritonPrinter +class XPUTritonPrinter(triton.TritonPrinter): + + def _print_floor(self, expr): + assert len(expr.args) == 1 + return f"libdevice.floor({self.paren(self._print(expr.args[0]))})" + + +triton.TritonPrinter = XPUTritonPrinter + +# ===------------------------------------------------------------------=== + + +# ===-------------------- For XPytorch Inductor -----------------------=== +# Modifed Base Pytorch(v2.0.1) torch/_inductor/codegen/triton.py::TritonOverrides +class XPUTritonOverrides(triton.TritonOverrides): + """Map element-wise ops to Triton""" + + @staticmethod + def libdevice_abs(x): + return f"libdevice.abs({x})" + + @staticmethod + def libdevice_exp(x): + return f"libdevice.exp({x})" + + @staticmethod + def exp2(x): + return f"libdevice.exp2({x})" + + @staticmethod + def expm1(x): + return f"libdevice.expm1({x})" + + @staticmethod + def libdevice_sqrt(x): + return f"libdevice.sqrt({x})" + + @staticmethod + def libdevice_cos(x): + return f"libdevice.cos({x})" + + @staticmethod + def libdevice_sin(x): + return f"libdevice.sin({x})" + + @staticmethod + def lgamma(x): + return f"libdevice.lgamma({x})" + + @staticmethod + def erf(x): + return f"libdevice.erf({x})" + + @staticmethod + def cosh(x): + return f"libdevice.cosh({x})" + + @staticmethod + def sinh(x): + return f"libdevice.sinh({x})" + + @staticmethod + def acos(x): + return f"libdevice.acos({x})" + + @staticmethod + def acosh(x): + return f"libdevice.acosh({x})" + + @staticmethod + def asin(x): + return f"libdevice.asin({x})" + + @staticmethod + def asinh(x): + return f"libdevice.asinh({x})" + + @staticmethod + def atan2(x, y): + return f"libdevice.atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"libdevice.atan({x})" + + @staticmethod + def atanh(x): + return f"libdevice.atanh({x})" + + @staticmethod + def copysign(x, y): + return f"libdevice.copysign({x}, {y})" + + @staticmethod + def erfc(x): + return f"libdevice.erfc({x})" + + @staticmethod + def hypot(x, y): + return f"libdevice.hypot({x}, {y})" + + @staticmethod + def log10(x): + return f"libdevice.log10({x})" + + @staticmethod + def nextafter(x, y): + return f"libdevice.nextafter({x}, {y})" + + @staticmethod + def rsqrt(x): + return f"libdevice.rsqrt({x})" + + @staticmethod + def log1p(x): + return f"libdevice.log1p({x})" + + @staticmethod + def tan(x): + return f"libdevice.tan({x})" + + @staticmethod + def tanh(x): + return f"libdevice.tanh({x})" + + @staticmethod + def libdevice_sigmoid(x): + return f"1/(1 + libdevice.exp(-({x})))" + + @staticmethod + def signbit(x): + # XX: This is wrong for the value -0.0 in floating point + return f"libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0" + + @staticmethod + def fmod(a, b): + return f"libdevice.fmod({a}, {b})" + + @staticmethod + def pow(a, b): + return f"libdevice.pow({a}, {b})" + + @staticmethod + def libdevice_log(x): + return f"libdevice.log({x})" + + @staticmethod + def isinf(x): + return f"libdevice.isinf({x})" + + @staticmethod + def isnan(x): + return f"libdevice.isnan({x})" + + @staticmethod + def round(x): + return f"libdevice.nearbyint({x})" + + @staticmethod + def floor(x): + return f"libdevice.floor({x})" + + @staticmethod + def trunc(x): + return f"libdevice.trunc({x})" + + @staticmethod + def ceil(x): + return f"libdevice.ceil({x})" + + +triton.TritonOverrides = XPUTritonOverrides +# ===------------------------------------------------------------------=== + +# ===-------------------- For XPytorch Inductor -----------------------=== +# Modifed Base Pytorch(v2.0.1) torch/_inductor/codegen/triton.py::TritonKernel +from xpu.backend.driver import get_xpu_spec +import math + +pexpr = PythonPrinter().doprint +xpu_hasAtomic = False + + +class XPUTritonKernel(triton.TritonKernel): + overrides = triton.TritonOverrides + sexpr = pexpr + + def store(self, name, index, value, mode=None): + var = self.args.output(name) + index, mask_vars, mask = self.indexing(index, dense_indexing=True) + if mode is None: + line = f"tl.store({var} + ({index}), {value}, {mask})" + elif mode == "atomic_add": + line = f"tl.atomic_add({var} + ({index}), {value}, {mask})" + global xpu_hasAtomic + xpu_hasAtomic = True + else: + raise NotImplementedError(f"store mode={mode}") + self.stores.writeline(name, line) + if not self.inside_reduction: + self.outside_loop_vars.add(value) + + def codegen_kernel(self, name=None): + from triton import next_power_of_2 + + # ===-------------------- For Triton XPU -----------------------=== + from xpu.backend.driver import get_xpu_spec + arch = int(os.environ.get('TRITON_XPU_ARCH', '3')) + cluster_num = get_xpu_spec(arch)[0] + core_num_per_cluster = get_xpu_spec(arch)[1] + core_num = cluster_num * core_num_per_cluster + + def next_multiply_of_num(n: int, m: int) -> int: + res = math.ceil(n / m) * m + return res + + def get_xpu_1d_hint(numel_0: int, cluster_num: int) -> int: + size_hint_0 = math.ceil(numel_0 / cluster_num) + size_hint_0 = next_power_of_2(size_hint_0) + size_hint_0 = size_hint_0 * cluster_num + return size_hint_0 + + def get_xpu_2d_hint(numel_0: int, numel_1: int, cluster_num: int, core_num: int) -> int: + if numel_0 < core_num: + size_hint_0 = math.ceil(numel_0 / cluster_num) + size_hint_0 = next_power_of_2(size_hint_0) + size_hint_0 = size_hint_0 * cluster_num + else: + size_hint_0 = next_multiply_of_num(numel_0, core_num) + size_hint_1 = next_power_of_2(numel_1) + return [size_hint_0, size_hint_1] + + code = IndentedBuffer() + # size_hints = [ + # next_power_of_2(V.graph.sizevars.size_hint(numel)) for numel in self.numels + # ] + # vvv + numel_0 = V.graph.sizevars.size_hint(self.numels[0]) + if len(self.numels) == 1: + size_hints = [ + get_xpu_1d_hint(numel_0, cluster_num), + ] + elif len(self.numels) == 2: + numel_1 = V.graph.sizevars.size_hint(self.numels[1]) + if numel_1 != 1: + size_hints = get_xpu_2d_hint(numel_0, numel_1, cluster_num, core_num) + else: + size_hints = [ + get_xpu_1d_hint(numel_0, cluster_num), + 1, + ] + else: + raise AssertionError(f"invalid size for numels {len(self.numels)}") + # ===-----------------------------------------------------------=== + + if self.persistent_reduction: + assert self.inside_reduction + heuristics = "persistent_reduction" + elif self.inside_reduction: + heuristics = "reduction" + else: + size_hints.pop() + heuristics = "pointwise" + + if name is None: + code.splice(f""" + import triton + import triton.language as tl + from torch._inductor.ir import ReductionHint + from torch._inductor.ir import TileHint + from torch._inductor.triton_ops.autotune import {heuristics} + from torch._inductor.utils import instance_descriptor + # ===-------------------- For Triton XPU -----------------------=== + # Borrowed From Pytorch(v2.5.0-rc9) torch/_inductor/runtime/triton_helpers.py + # In the latest triton, math functions were shuffled around into different modules: + # https://github.com/openai/triton/pull/3172 + try: + from triton.language.extra import libdevice + + libdevice = tl.extra.libdevice # noqa: F811 + math = tl.math + except ImportError: + if hasattr(tl.extra, "xpu") and hasattr(tl.extra.xpu, "libdevice"): + libdevice = tl.extra.xpu.libdevice + math = tl.math + elif hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"): + libdevice = tl.extra.cuda.libdevice + math = tl.math + elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"): + libdevice = tl.extra.intel.libdevice + math = tl.math + else: + libdevice = tl.math + math = tl + # ===-----------------------------------------------------------=== + """) + + argdefs, _, signature = self.args.python_argdefs() + # maps actual expression to SizeArg if its in sizevars replacements + for i, arg in enumerate(signature): + if (isinstance(arg, SizeArg) and arg.expr in V.graph.sizevars.inv_precomputed_replacements): + signature[i] = SizeArg(arg.name, V.graph.sizevars.inv_precomputed_replacements[arg.expr]) + + mutated_args = set() + for mutation in self.mutations: + if mutation in self.args.input_buffers: + mutated_args.add(self.args.input_buffers[mutation]) + if mutation in self.args.inplace_buffers: + mutated_args.add(self.args.inplace_buffers[mutation].inner_name) + if mutation in self.args.output_buffers: + mutated_args.add(self.args.output_buffers[mutation]) + mutated_args = sorted(mutated_args) + + global xpu_hasAtomic + triton_meta = { + "signature": dict(enumerate(map(signature_of, signature))), + "device": V.graph.scheduler.current_device.index, + "constants": {}, + "mutated_arg_names": mutated_args, + "hasAtomic": xpu_hasAtomic, + } + xpu_hasAtomic = False # reset global var + + for tree in self.range_trees: + if tree.prefix != "r" or self.inside_reduction: + sizearg = SizeArg(f"{tree.prefix}numel", tree.numel) + signature.append(sizearg) + triton_meta["signature"][len(argdefs)] = signature_of(sizearg) + argdefs.append(f"{tree.prefix}numel") + # constexpr version causes issues, see + # https://github.com/pytorch/torchdynamo/pull/1362 + # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint( + # tree.numel + # ) + # argdefs.append(f"{tree.prefix}numel: tl.constexpr") + triton_meta["configs"] = [config_of(signature)] + + for tree in self.range_trees: + if tree.prefix != "r" or self.inside_reduction: + argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr") + + if self.inside_reduction: + reduction_hint = self.reduction_hint + heuristics_line = f""" + @{heuristics}( + size_hints={size_hints!r}, + reduction_hint={reduction_hint}, + filename=__file__, + meta={triton_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + if len(signature) == 4: # input, output and 2 args + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @{heuristics}(size_hints={size_hints!r}, {tile_hint}filename=__file__, meta={triton_meta!r}) + @triton.jit + """ + code.splice(heuristics_line) + code.writeline(f"def {name or 'KERNEL_NAME'}({', '.join(argdefs)}):") + self.codegen_body() + with code.indent(): + if not dynamo_config.dynamic_shapes: + self.codegen_static_numels(code) + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.body) + + if name is not None: + return code.getvalue() + + wrapper = IndentedBuffer() + wrapper.writeline("async_compile.triton('''") + wrapper.splice(code.getvalue(), strip=True) + wrapper.writeline("''')") + return wrapper.getvalue() + + +triton.TritonKernel = XPUTritonKernel + +# ===------------------------------------------------------------------=== diff --git a/third_party/xpu/backend/_inductor_201/triton_ops/__init__.py b/third_party/xpu/backend/_inductor_201/triton_ops/__init__.py new file mode 100644 index 000000000..bfea22481 --- /dev/null +++ b/third_party/xpu/backend/_inductor_201/triton_ops/__init__.py @@ -0,0 +1 @@ +from . import autotune diff --git a/third_party/xpu/backend/_inductor_201/triton_ops/autotune.py b/third_party/xpu/backend/_inductor_201/triton_ops/autotune.py new file mode 100644 index 000000000..658ec02bf --- /dev/null +++ b/third_party/xpu/backend/_inductor_201/triton_ops/autotune.py @@ -0,0 +1,387 @@ +import os +import copy +import torch +from torch._inductor.triton_ops import has_triton +from torch._inductor.utils import ceildiv + +from .runtime_utils import ( + get_first_attr, ) + +if has_triton(): + import triton + from triton import cdiv, Config, next_power_of_2 + from triton.runtime.jit import get_cuda_stream, KernelInterface + try: + from triton.compiler.compiler import ASTSource + except ImportError: + ASTSource = None + + try: + from triton.backends.compiler import GPUTarget + except ImportError: + GPUTarget = None +else: + cdiv = None + Config = object + get_cuda_stream = None + KernelInterface = object + next_power_of_2 = None + triton = None + ASTSource = None + GPUTarget = None + +from torch._inductor.triton_ops import autotune + +from xpu.backend.driver import get_xpu_spec + +arch = int(os.environ.get('TRITON_XPU_ARCH', '3')) +CLUSTER_NUM = get_xpu_spec(arch)[0] +CORE_NUM = get_xpu_spec(arch)[1] + + +# ===-------------------- For XPytorch Inductor -----------------------=== +# Base Pytorch(v2.0.1) torch/_inductor/triton_ops/autotune.py +# vvv +# Target Pytorch(v2.5.0-rc9) torch/_inductor/runtime/triton_heuristics.py +class XPUCachingAutotuner(autotune.CachingAutotuner): + + def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int): + """Ahead of time compile a given autotuner config.""" + compile_meta = copy.deepcopy(self.meta) + for k, v in cfg.kwargs.items(): + compile_meta["constants"][self.fn.arg_names.index(k)] = v + compile_meta["num_warps"] = cfg.num_warps + compile_meta["num_stages"] = cfg.num_stages + + compile_meta["device_type"] = "xpu" + compile_meta["cc"] = 3 + compile_meta["debug"] = True + + if ASTSource: + compile_args = (ASTSource( + self.fn, + compile_meta["signature"], + compile_meta["constants"], + compile_meta["configs"][0], + ), ) + + cc_str = str(compile_meta["cc"]) + if "gfx10" in cc_str or "gfx11" in cc_str: + rocm_warp_size = 32 + else: + rocm_warp_size = 64 + + if GPUTarget: + target = GPUTarget( + compile_meta["device_type"], + compile_meta["cc"], + rocm_warp_size if torch.version.hip else 32, + ) + else: + target = ((compile_meta["device_type"], compile_meta["cc"]) if not torch.version.hip else [ + compile_meta["device_type"], + compile_meta["cc"], + rocm_warp_size, + ]) + + options = { + "num_warps": compile_meta["num_warps"], + "num_stages": compile_meta["num_stages"], + "debug": compile_meta["debug"], + } + # if self.device_props.type != "hip": + # if "waves_per_eu" in compile_meta: + # options["waves_per_eu"] = compile_meta["waves_per_eu"] + # if "matrix_instr_nonkdim" in compile_meta: + # options["matrix_instr_nonkdim"] = compile_meta[ + # "matrix_instr_nonkdim" + # ] + compile_kwargs = { + "target": target, + "options": options, + } + else: + compile_args = (self.fn, ) + compile_kwargs = compile_meta + + if warm_cache_only_with_cc: + triton.compile(*compile_args, **compile_kwargs) + return + + # load binary to the correct device + with torch.cuda.device(compile_meta["device"]): + # need to initialize context + torch.cuda.synchronize(torch.cuda.current_device()) + binary = triton.compile(*compile_args, **compile_kwargs) + binary._init_handles() + + call_args = [arg for i, arg in enumerate(self.fn.arg_names) if i not in self.fn.constexprs] + def_args = list(self.fn.arg_names) + while def_args and def_args[-1] in cfg.kwargs: + def_args.pop() + + binary_shared = (binary.shared if hasattr(binary, "shared") else binary.metadata.shared) + + scope = { + "grid_meta": cfg.kwargs, + "bin": binary, + "torch": torch, + "set_device": torch.cuda.set_device, + "current_device": torch.cuda.current_device, + "metadata": binary.packed_metadata, + "launch_enter_hook": binary.launch_enter_hook, + "launch_exit_hook": binary.launch_exit_hook, + "shared": binary_shared, + } + + scope["num_warps"] = (binary.num_warps if hasattr(binary, "num_warps") else binary.metadata.num_warps) + + scope["cta_args"] = ((binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims")) if hasattr( + binary, "num_ctas") else ((binary.metadata.num_ctas, + *binary.metadata.cluster_dims) if hasattr(binary, "metadata") else ())) + + scope["function"] = get_first_attr(binary, "function", "cu_function") + + def get_launch_args_without_kernel_launch_metadata( + grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args before CompiledKernel.launch_metadata is added. + """ + return ( + grid_0, + grid_1, + grid_2, + num_warps, + *cta_args, + shared, + stream, + function, + launch_enter_hook, + launch_exit_hook, + metadata, + ) + + # Getting the kernel launch args is extremely perf-sensitive. Evaluating + # `bin.launch_metadata` is relatively expensive, and returns None unless a + # `launch_enter_hook` is installed. So if we don't have that hook installed, + # we want to burn None in to the launch args with zero overhead. + # See https://github.com/pytorch/pytorch/issues/123597 + if binary.launch_enter_hook: + + def get_launch_args_with_kernel_launch_metadata( + grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args after CompiledKernel.launch_metadata is added + by https://github.com/openai/triton/pull/3492 . + """ + return ( + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + bin.launch_metadata(grid, stream, *args), + launch_enter_hook, + launch_exit_hook, + ) + + else: + + def get_launch_args_with_kernel_launch_metadata( + grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args after CompiledKernel.launch_metadata is added + by https://github.com/openai/triton/pull/3492 . + """ + return ( + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + None, + launch_enter_hook, + launch_exit_hook, + ) + + scope["get_launch_args"] = (get_launch_args_with_kernel_launch_metadata if hasattr(binary, "launch_metadata") + else get_launch_args_without_kernel_launch_metadata) + + scope["runner"] = get_first_attr(binary, "run", "c_wrapper") + exec( + f""" + def launcher({', '.join(def_args)}, grid, stream): + if callable(grid): + grid_0, grid_1, grid_2 = grid(grid_meta) + else: + grid_0, grid_1, grid_2 = grid + + args = {', '.join(call_args)}, + launch_args = get_launch_args( + grid, grid_0, grid_1, grid_2, stream, function, + metadata, bin, launch_enter_hook, launch_exit_hook, + num_warps, shared, cta_args, args + ) + runner(*launch_args, *args) + return bin + """.lstrip(), + scope, + ) + + launcher = scope["launcher"] + launcher.config = cfg + return launcher + + +autotune.CachingAutotuner = XPUCachingAutotuner +# ===------------------------------------------------------------------=== + +# ===-------------------- For XPytorch Inductor -----------------------=== +from torch._inductor.triton_ops.autotune import cached_autotune, triton_config +from torch._inductor import config + +# Modified Pytorch(v2.0.1) torch/_inductor/triton_ops/autotune.py::reduction() && pointwise() + + +def triton_xpu_config_reduction(size_hints, x, r, num_stages=2) -> Config: + """ + Construct a reduction triton config with some adjustment heuristics + based on size_hints. Size_hints is a tuple of numels in each tile + dimension and will be rounded up to the nearest power of 2. + """ + + cfg = {"XBLOCK": x, "RBLOCK": r} + num_warps = -1 # invalid value, just a placeholder + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def tritonxpu_reduction(size_hints, reduction_hint=False, meta=None, filename=None): + from torch._inductor.ir import ReductionHint + import math + """args to @triton.heuristics()""" + assert meta is not None + + # ===-------------------- For Triton XPU -----------------------=== + if bool(meta.get("hasAtomic", False)): + xnumel = math.ceil(size_hints[0] / 1) + else: + xnumel = math.ceil(size_hints[0] / CLUSTER_NUM) + # ===-----------------------------------------------------------=== + + rnumel = size_hints[-1] + if len(size_hints) == 2: + contiguous_config = triton_xpu_config_reduction(size_hints, xnumel, (rnumel if 0 < rnumel < 8192 else 8192), + num_stages=1) + buffersize_config = triton_xpu_config_reduction(size_hints, xnumel, (rnumel if 0 < rnumel < 128 else 128), + num_stages=1) + + if config.max_autotune: + pass # skip all these cases + elif reduction_hint == ReductionHint.INNER or ReductionHint.DEFAULT: + return cached_autotune(configs=[contiguous_config], meta=meta) + else: + raise NotImplementedError(f"reduction_hint: {reduction_hint}") + + raise NotImplementedError(f"size_hints: {size_hints}") + + +def tritonxpu_pointwise(size_hints, meta, tile_hint=None, filename=None): + import functools + import operator + """ + Construct @triton.heuristics() based on size_hints. + """ + # ===-------------------- For Triton XPU -----------------------=== + numel = functools.reduce(operator.mul, size_hints) + if bool(meta.get("hasAtomic", False)): + # We need to tile all data in only one cluster for atomic simulation + bs = max(CORE_NUM, numel // 1) + else: + bs = max(CORE_NUM, numel // CLUSTER_NUM) + # ===-----------------------------------------------------------=== + + if len(size_hints) == 1: + return cached_autotune([triton_config(size_hints, bs)], meta=meta) + if len(size_hints) == 2: + raise NotImplementedError(f"[Triton XPU] len(size_hints) == 2 Not Supported") + if len(size_hints) == 3: + raise NotImplementedError(f"[Triton XPU] len(size_hints) == 3 Not Supported") + raise NotImplementedError(f"size_hints: {size_hints}") + + +from torch._inductor.triton_ops import autotune + +autotune.reduction = tritonxpu_reduction +autotune.pointwise = tritonxpu_pointwise + +# ===------------------------------------------------------------------=== + + +# ===-------------------- For XPytorch Inductor -----------------------=== +def grid(xnumel, ynumel=None, znumel=None): + """Helper function to compute triton grids""" + + def get_grid_dim(numel, block_name, block): + if numel is None: + return 1 + core_nums = CLUSTER_NUM * CORE_NUM + grid_num = ceildiv(numel, block) if numel < core_nums else CLUSTER_NUM + return grid_num + + def grid_fn(meta): + return ( + get_grid_dim(xnumel, "XBLOCK", meta.get("XBLOCK", 1)), + get_grid_dim(ynumel, "YBLOCK", meta.get("YBLOCK", None)), + get_grid_dim(znumel, "ZBLOCK", meta.get("ZBLOCK", None)), + ) + + return grid_fn + + +autotune.grid = grid + +# ===------------------------------------------------------------------=== diff --git a/third_party/xpu/backend/_inductor_201/triton_ops/runtime_utils.py b/third_party/xpu/backend/_inductor_201/triton_ops/runtime_utils.py new file mode 100644 index 000000000..1b81c7d03 --- /dev/null +++ b/third_party/xpu/backend/_inductor_201/triton_ops/runtime_utils.py @@ -0,0 +1,10 @@ +# Borrowed From Pytorch(v2.5.0-rc9) torch/_inductor/runtime/runtime_utils.py +def get_first_attr(obj, *attrs): + """ + Return the first available attribute or throw an exception if none is present. + """ + for attr in attrs: + if hasattr(obj, attr): + return getattr(obj, attr) + + raise AssertionError(f"{obj} does not has any of the attributes: {attrs}") diff --git a/third_party/xpu/backend/_inductor_210/__init__.py b/third_party/xpu/backend/_inductor_210/__init__.py new file mode 100644 index 000000000..b7c682499 --- /dev/null +++ b/third_party/xpu/backend/_inductor_210/__init__.py @@ -0,0 +1,2 @@ +from . import codegen +from . import triton_heuristics diff --git a/third_party/xpu/backend/_inductor_210/codegen/__init__.py b/third_party/xpu/backend/_inductor_210/codegen/__init__.py new file mode 100644 index 000000000..ff2e2c1c3 --- /dev/null +++ b/third_party/xpu/backend/_inductor_210/codegen/__init__.py @@ -0,0 +1 @@ +from . import triton diff --git a/third_party/xpu/backend/_inductor_210/codegen/triton.py b/third_party/xpu/backend/_inductor_210/codegen/triton.py new file mode 100644 index 000000000..29ebc4499 --- /dev/null +++ b/third_party/xpu/backend/_inductor_210/codegen/triton.py @@ -0,0 +1,868 @@ +import os +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +from typing import Dict, Iterable, List, Set + +import sympy + +import torch + +import torch._logging +from torch._prims_common import is_integer_dtype +from torch.utils._sympy.functions import FloorDiv, ModularIndexing +from torch.utils._sympy.value_ranges import ValueRanges + +from torch._inductor import ir +from torch._inductor.codegen import triton +from torch._inductor.utils import ( + DeferredLineBase, + get_fused_kernel_name, + get_kernel_metadata, + green_text, + is_welford_reduction, + next_power_of_2, + sympy_product, + sympy_subs, + sympy_symbol, + unique, + yellow_text, +) +from torch._inductor import config +from torch._inductor.virtualized import ops, V +from torch._inductor.codegen.common import ( + CSEVariable, + DeferredLine, + free_symbol_startswith, + IndentedBuffer, + index_prevent_reordering, + Kernel, + OpOverrides, + PythonPrinter, + SizeArg, +) + +from torch._inductor.codegen.triton import triton_compute_type, triton_constant, texpr +from torch._dynamo.utils import counters + +from torch._inductor.codegen.triton_utils import ( + config_of, + signature_of, + signature_to_meta, +) + +CLUSTER_NUM = 8 +CORE_NUM = 64 + + +class TritonXPUPrinter(triton.TritonPrinter): + + def _print_floor(self, expr): + assert len(expr.args) == 1 + return f"tl.libdevice.floor({self.paren(self._print(expr.args[0]))})" + + def _helper_sqrt(self, expr): + return f"tl.libdevice.sqrt({self.paren(self._print(expr))}.to(tl.float32))" + + def _print_Min(self, expr): + nargs = len(expr.args) + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + a = self._print(sympy.Min(*expr.args[:mid])) + b = self._print(sympy.Min(*expr.args[mid:])) + return f"tl.libdevice.min({a}, {b})" + + def _print_Max(self, expr): + nargs = len(expr.args) + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + a = self._print(sympy.Max(*expr.args[:mid])) + b = self._print(sympy.Max(*expr.args[mid:])) + return f"tl.libdevice.max({a}, {b})" + + +triton.TritonPrinter = TritonXPUPrinter +texpr = TritonXPUPrinter().doprint + + +class XPUIterationRangesEntry(triton.IterationRangesEntry): + + def _codegen(self): + self.writeline(f"{self.name} = " + texpr(V.kernel.rename_indexing(self.expr))) + return self.name + + +triton.IterationRangesEntry = XPUIterationRangesEntry + + +# ===-------------------- For XPytorch Inductor -----------------------=== +# Modifed Base Pytorch(v2.0.1) torch/_inductor/codegen/triton.py::TritonOverrides +class XPUTritonOverrides(triton.TritonOverrides): + """Map element-wise ops to Triton""" + + @staticmethod + def libdevice_abs(x): + return f"libdevice.abs({x})" + + @staticmethod + def libdevice_exp(x): + return f"libdevice.exp({x})" + + @staticmethod + def exp2(x): + return f"libdevice.exp2({x})" + + @staticmethod + def expm1(x): + return f"libdevice.expm1({x})" + + @staticmethod + def libdevice_sqrt(x): + return f"libdevice.sqrt({x})" + + @staticmethod + def libdevice_cos(x): + return f"libdevice.cos({x})" + + @staticmethod + def libdevice_sin(x): + return f"libdevice.sin({x})" + + @staticmethod + def lgamma(x): + return f"libdevice.lgamma({x})" + + @staticmethod + def erf(x): + return f"libdevice.erf({x})" + + @staticmethod + def cosh(x): + return f"libdevice.cosh({x})" + + @staticmethod + def sinh(x): + return f"libdevice.sinh({x})" + + @staticmethod + def acos(x): + return f"libdevice.acos({x})" + + @staticmethod + def acosh(x): + return f"libdevice.acosh({x})" + + @staticmethod + def asin(x): + return f"libdevice.asin({x})" + + @staticmethod + def asinh(x): + return f"libdevice.asinh({x})" + + @staticmethod + def atan2(x, y): + return f"libdevice.atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"libdevice.atan({x})" + + @staticmethod + def atanh(x): + return f"libdevice.atanh({x})" + + @staticmethod + def copysign(x, y): + return f"libdevice.copysign({x}, {y})" + + @staticmethod + def erfc(x): + return f"libdevice.erfc({x})" + + @staticmethod + def hypot(x, y): + return f"libdevice.hypot({x}, {y})" + + @staticmethod + def log10(x): + return f"libdevice.log10({x})" + + @staticmethod + def nextafter(x, y): + return f"libdevice.nextafter({x}, {y})" + + @staticmethod + def rsqrt(x): + return f"libdevice.rsqrt({x})" + + @staticmethod + def log1p(x): + return f"libdevice.log1p({x})" + + @staticmethod + def tan(x): + return f"libdevice.tan({x})" + + @staticmethod + def tanh(x): + return f"libdevice.tanh({x})" + + @staticmethod + def libdevice_sigmoid(x): + return f"1/(1 + libdevice.exp(-({x})))" + + @staticmethod + def signbit(x): + # XX: This is wrong for the value -0.0 in floating point + return f"libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0" + + @staticmethod + def fmod(a, b): + return f"libdevice.fmod({a}, {b})" + + @staticmethod + def pow(a, b): + return f"libdevice.pow({a}, {b})" + + @staticmethod + def libdevice_log(x): + return f"libdevice.log({x})" + + @staticmethod + def isinf(x): + return f"libdevice.isinf({x})" + + @staticmethod + def isnan(x): + return f"libdevice.isnan({x})" + + @staticmethod + def round(x): + return f"libdevice.nearbyint({x})" + + @staticmethod + def floor(x): + return f"libdevice.floor({x})" + + @staticmethod + def trunc(x): + return f"libdevice.trunc({x})" + + @staticmethod + def ceil(x): + return f"libdevice.ceil({x})" + + +triton.TritonOverrides = XPUTritonOverrides +# ===------------------------------------------------------------------=== + + +class IterationXPURangesRoot(triton.IterationRangesRoot): + # Remove no_x_dim mode(test_bilibili_layernorm.py) + def codegen_header(self, code, no_x_dim=False): + x = self.prefix + if self.is_loop(): + code.writeline(f"{self.name} = {x}offset + {x}base") + elif x == "r" and self.kernel.persistent_reduction: + # no need to "roffset = " + code.writeline(f"{self.name} = {self.ranges_code()}", ) + else: + line = f"{x}offset + {self.ranges_code()}" + code.writelines([ + f"{x}offset = {self.get_pid()} * {x.upper()}BLOCK", + f"{self.name} = {line}", + ]) + code.writeline(f"{x}mask = {self.name} < {x}numel") + + +triton.IterationRangesRoot = IterationXPURangesRoot + +xpu_hasAtomic = False + + +class TritonXPUKernel(triton.TritonKernel): + + overrides = XPUTritonOverrides + + # Remove evict_last flag, it is only used in cuda cache(test_softmax.py perf) + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + indirect_indexing = self.is_indirect_indexing(index) + original_index = index + index, mask_vars, mask, expand_str = self.indexing(index) + + ep = "" + # "other" below is a workaround for https://github.com/openai/triton/issues/737 + # for bool, even though it's likely subject to the same bug, setting `other` leads + # to LLVM errors so we are skipping it for now + if ("tmp" in mask or "rmask" in mask) and V.graph.get_dtype(name) != torch.bool: + other = ", other=0" + else: + other = "" + + append_broadcast = None + if V.graph.is_unspec_arg(name): + line = var + else: + if isinstance(original_index, sympy.Integer): + line = f"tl.load({var} + ({original_index}))" + append_broadcast = expand_str + else: + line = f"tl.load({var} + ({index}), {mask}{ep}{other})" + if V.graph.get_dtype(name) in (torch.float16, torch.bfloat16): + line += ".to(tl.float32)" + + if "tmp" in mask: + # Masked loads must come after the mask is computed + load_buffer = self.compute + elif (self.inside_reduction and not self.persistent_reduction and "rmask" not in mask + and not indirect_indexing): + # can lift a common load outside of reduction loop + # One exception is when this is an indirect_load. + load_buffer = self.body + else: + load_buffer = self.loads + + result_var = self.cse.generate(load_buffer, line) + result_var.mask_vars = mask_vars + + if append_broadcast: + line = f"tl.broadcast_to({result_var}, {append_broadcast})" + result_var = self.cse.generate(load_buffer, line) + + if not self.inside_reduction or "rmask" not in mask: + self.outside_loop_vars.add(result_var) + + return result_var + + # Remove tl.debug_barrier()(test_ks_batchnorm_online.py) + def store(self, name, index, value, mode=None): + var = self.args.output(name) + indirect_indexing = self.is_indirect_indexing(index) + original_index = index + index, mask_vars, mask, expand_str = self.indexing(index, dense_indexing=True) + + if mode is None: + line = f"tl.store({var} + ({index}), {value}, {mask})" + elif mode == "atomic_add": + line = f"tl.atomic_add({var} + ({index}), {value}, {mask})" + global xpu_hasAtomic + xpu_hasAtomic = True + else: + raise NotImplementedError(f"store mode={mode}") + self.stores.writeline(DeferredLine(name, line)) + if not self.inside_reduction: + self.outside_loop_vars.add(value) + + def reduction(self, dtype, src_dtype, reduction_type, value): + assert self.inside_reduction + masks = {f"{tree.prefix}mask" for tree in self.range_trees} + self.filter_masks(masks) + masks = sorted(masks) + if self._load_mask: + masks.append(self._load_mask) + reduction_range_prefix = self.range_trees[-1].prefix + reduction_sizes = ["None" for _ in self.range_trees] + reduction_sizes[-1] = ":" + + # Say we have + # tmp0 = ops.constant(1, torch.int64) + # tmp1 = ops.reduction(torch.int64, torch.int64, "sum", tmp0) + # tmp0 in the triton code is either a scalar, or single-element tensor + # so if we emit tl.sum directly, it will only give 1 instead of RBLOCK * 1 + # To avoid this, we broadcast to the expected shape first. + dense_size_str = self.dense_size_str() + value = self._map_tuple_or_scalar( + lambda v: self.cse.generate(self.compute, f"tl.broadcast_to({v}, {dense_size_str})"), + value, + ) + + def final_reduction(value): + module = "tl" + return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})") + + def final_argreduce(buffer, result_var, value, index): + buffer.splice(f"""\ + _, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim}) + {result_var} = {self.reduction_resize(f'{result_var}_tmp')} + """) + + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + dim = len(self.range_trees) - 1 - int(bool(self.no_x_dim)) + acc_type = triton.triton_acc_type(src_dtype) + result_var = self.cse.newvar() + result_var.mask_vars = {var for var in masks if var[0] != "r"} + cond = " & ".join(masks) + + if self.persistent_reduction: + default = ir.Reduction.default_value(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(triton.triton_constant, default) + + def _mask_value(value, default): + return self.cse.generate(self.compute, f"tl.where({cond}, {value}, {default})") + + if isinstance(value, tuple): + masked_value = [_mask_value(v, d) for v, d in zip(value, default)] + else: + masked_value = _mask_value(value, default) + + if reduction_type in {"argmax", "argmin"}: + accumulator_index = self.cse.generate( + self.compute, + f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)", + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + final_argreduce(self.compute, result_var, masked_value, accumulator_index) + elif reduction_type == "welford_reduce": + # For persistent reductions, don't bother with + # welford's algorithm since it uses more registers, and + # taking two reductions doesn't increase memory usage. + sum_ = ops.reduction(dtype, dtype, "sum", value) + self.inside_reduction = False + rnumel = ops.index_expr(self.numels[-1], dtype) + mean = ops.div(sum_, rnumel) + + self.inside_reduction = True + dx = ops.sub(value, mean) + dx2 = ops.mul(dx, dx) + m2 = ops.reduction(dtype, dtype, "sum", dx2) + result_var = (mean, m2, rnumel) + elif reduction_type == "welford_combine": + mean, m2, weight = masked_value + welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})" + mean, m2, weight = (self.cse.newvar() for _ in range(3)) + self.compute.writeline(f"{mean}, {m2}, {weight} = {welford}") + + result_var = tuple( + self.cse.generate(self.compute, self.reduction_resize(var_name)) for var_name in (mean, m2, weight)) + else: + result_var = self.cse.generate(self.compute, final_reduction(masked_value)) + else: + accumulator = f"_{result_var}" + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(triton.triton_constant, default) + if not isinstance(default, tuple): + self.body.writeline(f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})") + + if reduction_type in {"argmax", "argmin"}: + accumulator_index = f"_{result_var}_index" + long_max = torch.iinfo(torch.int64).max + self.body.writeline(f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)") + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + + self.compute.splice(f"""\ + {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index( + {accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index + ) + {accumulator} = tl.where({cond}, {accumulator}_next, {accumulator}) + {accumulator_index} = tl.where({cond}, {accumulator_index}_next, {accumulator_index}) + """) + final_argreduce(self.suffix, result_var, accumulator, accumulator_index) + elif is_welford_reduction(reduction_type): + accumulator = f"{result_var}_mean" + accumulator_m2 = f"{result_var}_m2" + accumulator_weight = f"{result_var}_weight" + self.body.writeline(f"{accumulator} = tl.zeros({self.dense_size_str()}, {acc_type})") + self.body.writeline(f"{accumulator_m2} = tl.zeros({self.dense_size_str()}, {acc_type})") + self.body.writeline(f"{accumulator_weight} = tl.zeros({self.dense_size_str()}, {acc_type})") + + if reduction_type == "welford_combine": + mean, m2, weight = value + self.compute.splice(f"""\ + {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_combine( + {accumulator}, {accumulator_m2}, {accumulator_weight}, + {mean}, {m2}, {weight} + ) + """) + else: + assert reduction_type == "welford_reduce" + self.compute.splice(f"""\ + {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_reduce( + {value}, {accumulator}, {accumulator_m2}, {accumulator_weight}, + ) + """) + + self.compute.splice(f"""\ + {accumulator} = tl.where({cond}, {accumulator}_next, {accumulator}) + {accumulator_m2} = tl.where({cond}, {accumulator_m2}_next, {accumulator_m2}) + {accumulator_weight} = tl.where({cond}, {accumulator_weight}_next, {accumulator_weight}) + """) + + result_mean = result_var + result_m2 = self.cse.newvar() + result_weight = self.cse.newvar() + self.suffix.splice(f"""\ + {result_mean}_tmp, {result_m2}_tmp, {result_weight}_tmp = triton_helpers.welford( + {accumulator}, {accumulator_m2}, {accumulator_weight}, {dim} + ) + {result_mean} = {self.reduction_resize(f'{result_mean}_tmp')} + {result_m2} = {self.reduction_resize(f'{result_m2}_tmp')} + {result_weight} = {self.reduction_resize(f'{result_weight}_tmp')} + """) + result_var = result_mean, result_m2, result_weight + else: + combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype) + updated = combine_fn(accumulator, value) + self.compute.writeline(f"{accumulator} = tl.where({cond}, {updated}, {accumulator})") + + if src_dtype == torch.bool: + # This is only really used for aten.any. It changes the + # final reduction of a non-persistent reduction from + # tmp5 = triton_helpers.max(_tmp5, 1)[:, None] + # to + # tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1) + # which is needed because tl.reduce doesn't support tl.int1 + accumulator = f"{accumulator}.to(tl.int8)" + result_type = triton_compute_type(dtype) + self.suffix.writeline(f"{result_var} = {final_reduction(accumulator)}.to({result_type})") + else: + self.suffix.writeline(f"{result_var} = {final_reduction(accumulator)}") + + self.cse.reduction_cache[cache_key] = result_var + + if isinstance(result_var, tuple): + self.outside_loop_vars |= set(result_var) + else: + self.outside_loop_vars.add(result_var) + + return result_var + + # Remove tl.where(test_gather.py) + def indirect_indexing(self, var, size, check=True): + # TODO(lezcano) This code should be lifted to codegen/common.py. + # This should be easy, as now CSE variables carry bounds info + class IndirectAssertLine(DeferredLineBase): + + def __init__(self, line, var, mask, size_map): + self.var = var + self.mask = mask + self.line = line + self.size_map = size_map + + def __call__(self): + size, size_str = self.size_map[(self.var, self.mask)] + + # We assert if we've not been able to prove the bound + assert_min = (self.var.bounds.lower >= 0) != sympy.true + assert_max = (self.var.bounds.upper < size) != sympy.true + + # FooBar interview question + if not (assert_min or assert_max): + return None + elif assert_min and assert_max: + # The conditions need to be in parens because of Python's operator precedence. + # It'd be less error-prone to use and/or/not, which is suported by triton + cond = f"(0 <= {self.var}) & ({self.var} < {size_str})" + cond_print = f"0 <= {self.var} < {size_str}" + elif assert_min: + cond = f"0 <= {self.var}" + cond_print = cond + else: + assert assert_max + cond = f"{self.var} < {size_str}" + cond_print = cond + + if self.mask: + cond = f"({cond}) | ~{self.mask}" + return self.line.format(cond=cond, cond_print=cond_print) + + def _new_line(self, line): + return IndirectAssertLine(line, self.var, self.mask, self.size_map) + + generate_assert = ((check or config.debug_index_asserts) and config.triton.assert_indirect_indexing + and torch.version.hip is None) + if generate_assert: + mask_vars = set(var.mask_vars) + if self._load_mask: + mask_vars.add(self._load_mask) + + mask = "" + if mask_vars: + mask = (f"{list(mask_vars)[0]}" + if len(mask_vars) == 1 else f"({' & '.join(str(v) for v in mask_vars)})") + + # An assertion line may have been written already, if so just + # update the max size. + map_key = (var, mask) + existing_size, _ = self.indirect_max_sizes.get(map_key, (None, None)) + if existing_size is not None: + size = sympy.Min(size, existing_size) + + self.indirect_max_sizes[map_key] = (size, self.index_to_str(size)) + + return sympy_symbol(str(var)) + + def index_to_str(self, index: sympy.Expr) -> str: + """ + Convert an index expr to a string that can be used in triton code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the triton kernel. + + Index expressions often need to be passed in as arguments to the triton kernel. + Rename_indexing and codegen_indexing keep track of the needed indices and add + new parameters to the function signature. + """ + return texpr(self.rename_indexing(self.codegen_indexing(index))) + + def codegen_kernel(self, name=None): + + from triton import next_power_of_2 + + def next_multiply_of_512(n: int) -> int: + m = CLUSTER_NUM * CORE_NUM # 8 cluster 64 core + if n < m: + res = next_power_of_2(n) + else: + res = math.ceil(n / m) * m + return res + + code = IndentedBuffer() + + if len(self.numels) == 1: + size_hints = [next_power_of_2(V.graph.sizevars.size_hint(numel)) for numel in self.numels] + elif len(self.numels) == 2: + if self.numels[1] != 1: + size_hints = [ + next_multiply_of_512(V.graph.sizevars.size_hint(self.numels[0])), + next_power_of_2(V.graph.sizevars.size_hint(self.numels[1])), + ] + else: + size_hints = [next_power_of_2(V.graph.sizevars.size_hint(numel)) for numel in self.numels] + else: + raise AssertionError(f"invalid size for numels {len(self.numels)}") + + if self.persistent_reduction: + assert self.inside_reduction + heuristics = "persistent_reduction" + elif self.inside_reduction: + heuristics = "reduction" + else: + size_hints.pop() + heuristics = "pointwise" + + if name is None: + code.splice(f""" + import triton + import triton.language as tl + from torch._inductor.ir import ReductionHint + from torch._inductor.ir import TileHint + from torch._inductor.triton_heuristics import AutotuneHint, {heuristics} + from torch._inductor.utils import instance_descriptor + from torch._inductor import triton_helpers + # ===-------------------- For Triton XPU -----------------------=== + # Borrowed From Pytorch(v2.5.0-rc9) torch/_inductor/runtime/triton_helpers.py + # In the latest triton, math functions were shuffled around into different modules: + # https://github.com/openai/triton/pull/3172 + try: + from triton.language.extra import libdevice + + libdevice = tl.extra.libdevice # noqa: F811 + math = tl.math + except ImportError: + if hasattr(tl.extra, "xpu") and hasattr(tl.extra.xpu, "libdevice"): + libdevice = tl.extra.xpu.libdevice + math = tl.math + elif hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"): + libdevice = tl.extra.cuda.libdevice + math = tl.math + elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"): + libdevice = tl.extra.intel.libdevice + math = tl.math + else: + libdevice = tl.math + math = tl + # ===-----------------------------------------------------------=== + """) + if config.benchmark_kernel: + code.splice(""" + from torch._dynamo.testing import rand_strided + from torch._C import _cuda_getCurrentRawStream as get_cuda_stream + import torch + from torch._inductor.triton_heuristics import grid + """) + + argdefs, _, signature = self.args.python_argdefs() + # maps actual expression to SizeArg if its in sizevars replacements + for i, arg in enumerate(signature): + if (isinstance(arg, SizeArg) and arg.expr in V.graph.sizevars.inv_precomputed_replacements): + signature[i] = SizeArg(arg.name, V.graph.sizevars.inv_precomputed_replacements[arg.expr]) + + mutated_args = set() + for mutation in self.mutations: + if mutation in self.args.input_buffers: + mutated_args.add(self.args.input_buffers[mutation]) + if (mutation in self.args.inplace_buffers and mutation not in V.graph.removed_buffers): + mutated_args.add(self.args.inplace_buffers[mutation].inner_name) + if mutation in self.args.output_buffers: + mutated_args.add(self.args.output_buffers[mutation]) + mutated_args = sorted(mutated_args) + + global xpu_hasAtomic + triton_meta = { + "signature": signature_to_meta(signature, size_dtype=self.index_dtype), + "device": V.graph.scheduler.current_device.index, + "device_type": V.graph.scheduler.current_device.type, + "constants": {}, + "mutated_arg_names": mutated_args, + "autotune_hints": set(self.autotune_hints), + "kernel_name": "DESCRIPTIVE_KRNL_NAME", + "hasAtomic": xpu_hasAtomic, + } + xpu_hasAtomic = False # reset global var + + for tree in self.range_trees: + if tree.prefix != "r" or self.inside_reduction: + sizearg = SizeArg(f"{tree.prefix}numel", tree.numel) + signature.append(sizearg) + triton_meta["signature"][len(argdefs)] = signature_of(sizearg, size_dtype=self.index_dtype) + argdefs.append(f"{tree.prefix}numel") + # constexpr version causes issues, see + # https://github.com/pytorch/torchdynamo/pull/1362 + # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint( + # tree.numel + # ) + # argdefs.append(f"{tree.prefix}numel: tl.constexpr") + triton_meta["configs"] = [config_of(signature)] + + for tree in self.range_trees: + if tree.prefix == "r" and (not self.inside_reduction or self.persistent_reduction): + continue + if tree.prefix == "x" and self.no_x_dim: + continue + argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr") + + if self.inside_reduction: + reduction_hint = self.reduction_hint + heuristics_line = f""" + @{heuristics}( + size_hints={size_hints!r}, + reduction_hint={reduction_hint}, + filename=__file__, + meta={triton_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + if len(signature) == 4: # input, output and 2 args + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @{heuristics}(size_hints={size_hints!r}, {tile_hint}filename=__file__, meta={triton_meta!r}) + @triton.jit + """ + code.splice(heuristics_line) + code.writeline(f"def {name or 'KERNEL_NAME'}({', '.join(argdefs)}):") + self.codegen_body() + with code.indent(): + self.codegen_static_numels(code) + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.body) + + if config.benchmark_kernel: + code.splice(self.codegen_kernel_benchmark()) + + if name is not None: + return code.getvalue() + + return code.getvalue() + + +triton.TritonKernel = TritonXPUKernel + +import torch._inductor.scheduler as scheduler +from torch._inductor.codegen.triton import EnableReduction + +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") + + +class TritonXPUScheduling(triton.TritonScheduling): + + @classmethod + def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)): + """ + Heuristics to decide how to tile kernels. + Currently, we tile based on stride-1 dimensions. + + Returns: + `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel` + + """ + if reduction_numel != 1 or config.triton.max_tiles <= 1: + # TODO(jansel): should we tile reductions? + # do perf hint here if stride-1 dim is not being reduced + if perf_hint_log.level <= logging.WARNING: + for node in EnableReduction.filter(node_schedule): + if len(cls.candidate_tilings(node)) > 0: + perf_hint_log.info("reduction over non-contiguous dims") + break + return (numel, reduction_numel) + + seen_names = set() + candidate_tiles = collections.Counter() + for node in EnableReduction.filter(node_schedule): + for tiling in cls.candidate_tilings(node): + if tiling.name in seen_names: + continue + seen_names.add(tiling.name) + candidate_tiles[tiling.tiling] += tiling.score + + ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()] + + if config.triton.max_tiles >= 3: + # Consider adding a third dimension of tiling, but only + # when a1 is a multiple of b1; otherwise, you have a lot + # of stragglers which is annoying to generate code for. + # + # NB: More than three max tiles is not enabled by default. + + # Add one 3D tiling choice + for i in range(1, len(ranked_tilings)): + a0, a1 = ranked_tilings[0] + b0, b1 = ranked_tilings[i] + if V.graph.sizevars.size_hint(a1 - b1) == 0: + continue + if V.graph.sizevars.size_hint(a1 - b1) < 0: + # swap so a0 is bigger + a0, a1 = ranked_tilings[i] + b0, b1 = ranked_tilings[0] + assert V.graph.sizevars.size_hint(a1 - b1) > 0 + if V.graph.sizevars.statically_known_multiple_of(a1, b1): + tiling = (a0, FloorDiv(a1, b1), b1) + ranked_tilings = [tiling] + ranked_tilings + break # only 1 choice for now + + if len(ranked_tilings) > 1: + perf_hint_log.info("possibly bad tiling: %s", ranked_tilings) + + for tiled_groups in ranked_tilings: + # [TODO]: Remove this tiling limit + if (len(tiled_groups) != 1) and all(isinstance(node.node.data, ir.Pointwise) for node in node_schedule): + print(f"[TORCH_XPU Warning] Pointwise Kernel Limit To 1D-Tiling") + return (numel, reduction_numel) + new_groups = (*tiled_groups, reduction_numel) + if all( + triton.TritonKernel.is_compatible(new_groups, node.get_ranges()) + for node in node_schedule + if isinstance(node, scheduler.SchedulerNode)): + return new_groups + + return (numel, reduction_numel) + + +triton.TritonScheduling = TritonXPUScheduling diff --git a/third_party/xpu/backend/_inductor_210/runtime_utils.py b/third_party/xpu/backend/_inductor_210/runtime_utils.py new file mode 100644 index 000000000..1b81c7d03 --- /dev/null +++ b/third_party/xpu/backend/_inductor_210/runtime_utils.py @@ -0,0 +1,10 @@ +# Borrowed From Pytorch(v2.5.0-rc9) torch/_inductor/runtime/runtime_utils.py +def get_first_attr(obj, *attrs): + """ + Return the first available attribute or throw an exception if none is present. + """ + for attr in attrs: + if hasattr(obj, attr): + return getattr(obj, attr) + + raise AssertionError(f"{obj} does not has any of the attributes: {attrs}") diff --git a/third_party/xpu/backend/_inductor_210/triton_heuristics.py b/third_party/xpu/backend/_inductor_210/triton_heuristics.py new file mode 100644 index 000000000..e3d92e588 --- /dev/null +++ b/third_party/xpu/backend/_inductor_210/triton_heuristics.py @@ -0,0 +1,445 @@ +import os +import copy +import math +import torch +from torch._inductor.utils import (has_triton, ceildiv) +from torch._inductor import config +import torch.autograd.profiler as autograd_profiler + +from .runtime_utils import ( + get_first_attr, ) + +if has_triton(): + import triton + from triton import cdiv, Config, next_power_of_2 + from triton.runtime.jit import get_cuda_stream, KernelInterface + try: + from triton.compiler.compiler import ASTSource + except ImportError: + ASTSource = None + + try: + from triton.backends.compiler import GPUTarget + except ImportError: + GPUTarget = None +else: + cdiv = None + Config = object + get_cuda_stream = None + KernelInterface = object + next_power_of_2 = None + triton = None + ASTSource = None + GPUTarget = None + +from torch._inductor import ( + triton_heuristics, ) +from torch._inductor.triton_heuristics import (cached_autotune, HeuristicType) +from torch._inductor.ir import TileHint, ReductionHint + +from xpu.backend.driver import get_xpu_spec + +arch = int(os.environ.get('TRITON_XPU_ARCH', '3')) +CLUSTER_NUM = get_xpu_spec(arch)[0] +CORE_NUM = get_xpu_spec(arch)[1] + + +# ===-------------------- For XPytorch Inductor -----------------------=== +# Base Pytorch(v2.1.0) torch/_inductor/triton_heuristics.py +# vvv +# Target Pytorch(v2.5.0-rc9) torch/_inductor/runtime/triton_heuristics.py +class XPUCachingAutotuner(triton_heuristics.CachingAutotuner): + + def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int): + """Ahead of time compile a given autotuner config.""" + compile_meta = copy.deepcopy(self.meta) + for k, v in cfg.kwargs.items(): + compile_meta["constants"][self.fn.arg_names.index(k)] = v + compile_meta["num_warps"] = cfg.num_warps + compile_meta["num_stages"] = cfg.num_stages + + compile_meta["device_type"] = "xpu" + compile_meta["cc"] = 3 + compile_meta["debug"] = False + + if ASTSource: + compile_args = (ASTSource( + self.fn, + compile_meta["signature"], + compile_meta["constants"], + compile_meta["configs"][0], + ), ) + + cc_str = str(compile_meta["cc"]) + if "gfx10" in cc_str or "gfx11" in cc_str: + rocm_warp_size = 32 + else: + rocm_warp_size = 64 + + if GPUTarget: + target = GPUTarget( + compile_meta["device_type"], + compile_meta["cc"], + rocm_warp_size if torch.version.hip else 32, + ) + else: + target = ((compile_meta["device_type"], compile_meta["cc"]) if not torch.version.hip else [ + compile_meta["device_type"], + compile_meta["cc"], + rocm_warp_size, + ]) + + options = { + "num_warps": compile_meta["num_warps"], + "num_stages": compile_meta["num_stages"], + "debug": compile_meta["debug"], + } + # if self.device_props.type != "hip": + # if "waves_per_eu" in compile_meta: + # options["waves_per_eu"] = compile_meta["waves_per_eu"] + # if "matrix_instr_nonkdim" in compile_meta: + # options["matrix_instr_nonkdim"] = compile_meta[ + # "matrix_instr_nonkdim" + # ] + compile_kwargs = { + "target": target, + "options": options, + } + else: + compile_args = (self.fn, ) + compile_kwargs = compile_meta + + if warm_cache_only_with_cc: + triton.compile(*compile_args, **compile_kwargs) + return + + # load binary to the correct device + with torch.cuda.device(compile_meta["device"]): + # need to initialize context + torch.cuda.synchronize(torch.cuda.current_device()) + binary = triton.compile(*compile_args, **compile_kwargs) + binary._init_handles() + + call_args = [arg for i, arg in enumerate(self.fn.arg_names) if i not in self.fn.constexprs] + def_args = list(self.fn.arg_names) + while def_args and def_args[-1] in cfg.kwargs: + def_args.pop() + + binary_shared = (binary.shared if hasattr(binary, "shared") else binary.metadata.shared) + + scope = { + "grid_meta": cfg.kwargs, + "bin": binary, + "torch": torch, + "set_device": torch.cuda.set_device, + "current_device": torch.cuda.current_device, + "metadata": binary.packed_metadata, + "launch_enter_hook": binary.launch_enter_hook, + "launch_exit_hook": binary.launch_exit_hook, + "shared": binary_shared, + } + + scope["num_warps"] = (binary.num_warps if hasattr(binary, "num_warps") else binary.metadata.num_warps) + + scope["cta_args"] = ((binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims")) if hasattr( + binary, "num_ctas") else ((binary.metadata.num_ctas, + *binary.metadata.cluster_dims) if hasattr(binary, "metadata") else ())) + + scope["function"] = get_first_attr(binary, "function", "cu_function") + + def get_launch_args_without_kernel_launch_metadata( + grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args before CompiledKernel.launch_metadata is added. + """ + return ( + grid_0, + grid_1, + grid_2, + num_warps, + *cta_args, + shared, + stream, + function, + launch_enter_hook, + launch_exit_hook, + metadata, + ) + + # Getting the kernel launch args is extremely perf-sensitive. Evaluating + # `bin.launch_metadata` is relatively expensive, and returns None unless a + # `launch_enter_hook` is installed. So if we don't have that hook installed, + # we want to burn None in to the launch args with zero overhead. + # See https://github.com/pytorch/pytorch/issues/123597 + if binary.launch_enter_hook: + + def get_launch_args_with_kernel_launch_metadata( + grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args after CompiledKernel.launch_metadata is added + by https://github.com/openai/triton/pull/3492 . + """ + return ( + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + bin.launch_metadata(grid, stream, *args), + launch_enter_hook, + launch_exit_hook, + ) + + else: + + def get_launch_args_with_kernel_launch_metadata( + grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args after CompiledKernel.launch_metadata is added + by https://github.com/openai/triton/pull/3492 . + """ + return ( + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + None, + launch_enter_hook, + launch_exit_hook, + ) + + scope["get_launch_args"] = (get_launch_args_with_kernel_launch_metadata if hasattr(binary, "launch_metadata") + else get_launch_args_without_kernel_launch_metadata) + + scope["runner"] = get_first_attr(binary, "run", "c_wrapper") + exec( + f""" + def launcher({', '.join(def_args)}, grid, stream): + if callable(grid): + grid_0, grid_1, grid_2 = grid(grid_meta) + else: + grid_0, grid_1, grid_2 = grid + + args = {', '.join(call_args)}, + launch_args = get_launch_args( + grid, grid_0, grid_1, grid_2, stream, function, + metadata, bin, launch_enter_hook, launch_exit_hook, + num_warps, shared, cta_args, args + ) + runner(*launch_args, *args) + return bin + """.lstrip(), + scope, + ) + + launcher = scope["launcher"] + launcher.config = cfg + return launcher + + def run(self, *args, grid, stream): + if len(self.launchers) != 1: + if len(self.launchers) == 0: + self.precompile() + if len(self.launchers) > 1: + self.autotune_to_one_config(*args, grid=grid) + + if (not getattr(self.launchers[0].config, "found_by_coordesc", False) and config.coordinate_descent_tuning): + self.launchers = [self.coordinate_descent_tuning(self.launchers[0], *args, grid=grid)] + + (launcher, ) = self.launchers + + if launcher.config.pre_hook is not None: + launcher.config.pre_hook({**dict(zip(self.arg_names, args)), **launcher.config.kwargs}) + + # guard the record_function_ctx and only call it if profiling is currently + # in progress, to reduce latency when profiler is not turned on. Note that + # the "if" statement (instead of, say, a contextlib.nullcontext) is intentional; + # it is faster than entering and exiting a context manager, even if the context + # manager is a nullcontext. + if autograd_profiler._is_profiler_enabled: + with self.record_function_ctx: + return launcher( + *args, + grid=grid, + stream=stream, + ) + else: + return launcher( + *args, + grid=grid, + stream=stream, + ) + + +triton_heuristics.CachingAutotuner = XPUCachingAutotuner +# ===------------------------------------------------------------------=== + +# ===-------------------- For XPytorch Inductor -----------------------=== + + +def triton_config(size_hints, x, y=None, z=None, num_stages=1, num_elements_per_warp=256) -> Config: + + cfg = {"XBLOCK": x} + if y: + cfg["YBLOCK"] = y + if z: + cfg["ZBLOCK"] = z + + num_warps = 16 # num_warps represents groups in XPU2/xpu3(16) + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def tritonxpu_pointwise(size_hints, meta, tile_hint=None, filename=None): + import functools + import operator + """ + Construct @triton.heuristics() based on size_hints. + """ + numel = functools.reduce(operator.mul, size_hints) + bs = max(256, min(numel // 128, 1024)) + + if len(size_hints) == 1: + # TODO: make it more tunable. + if bool(meta.get("hasAtomic", False)): + # We need to tile all data in only one cluster for atomic simulation + bs = max(CORE_NUM, math.ceil(numel / 1)) + else: + bs = max(CORE_NUM, math.ceil(numel / CLUSTER_NUM)) + return cached_autotune( + size_hints=size_hints, + configs=[triton_config(size_hints, bs)], + meta=meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + raise NotImplementedError(f"size_hints: {size_hints}") + + +def triton_xpu_config_reduction(size_hints, x, r, num_stages=2) -> Config: + + cfg = {"XBLOCK": x, "RBLOCK": r} + num_warps = 16 # num_warps represents groups in XPU2/XPU3(16) + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def tritonxpu_reduction(size_hints, reduction_hint=False, meta=None, filename=None): + """args to @triton.heuristics()""" + assert meta is not None + if bool(meta.get("hasAtomic", False)): + xnumel = math.ceil(size_hints[0] / 1) + else: + xnumel = math.ceil(size_hints[0] / CLUSTER_NUM) + rnumel = size_hints[-1] + + if len(size_hints) == 2: + contiguous_config = triton_xpu_config_reduction(size_hints, xnumel, (rnumel if 0 < rnumel < 8192 else 8192), + num_stages=1) + buffersize_config = triton_xpu_config_reduction(size_hints, xnumel, (rnumel if 0 < rnumel < 128 else 128), + num_stages=1) + if config.max_autotune: + pass # skip all these cases + elif reduction_hint == ReductionHint.INNER or ReductionHint.DEFAULT: + return cached_autotune( + size_hints=size_hints, + configs=[ + contiguous_config, + # buffersize_config, # TODO: Open autotune + ], + meta=meta, + heuristic_type=HeuristicType.REDUCTION, + filename=filename, + ) + raise NotImplementedError(f"size_hints: {size_hints}") + + +triton_heuristics.pointwise = tritonxpu_pointwise +triton_heuristics.reduction = tritonxpu_reduction + +# ===------------------------------------------------------------------=== + + +# ===-------------------- For XPytorch Inductor -----------------------=== +# Base Pytorch(v2.1.0) torch/_inductor/triton_heuristics.py +def grid(*numels): + """Helper function to compute triton grids""" + + if len(numels) == 1: + xnumel, ynumel, znumel = numels[0], None, None + # ===-------------------- For Triton XPU -----------------------=== + # elif len(numels) == 2: + # xnumel, ynumel, znumel = numels[1], numels[0], None + # elif len(numels) == 3: + # xnumel, ynumel, znumel = numels[2], numels[1], numels[0] + # ===-----------------------------------------------------------=== + else: + raise AssertionError(f"invalid size for numels {len(numels)}") + + def get_grid_dim(numel, block): + if numel is None: + return 1 + # return ceildiv(numel, block) + # ===-------------------- For Triton XPU -----------------------=== + if block is None: + return numel + core_num = CLUSTER_NUM * CORE_NUM + grid_num = ceildiv(numel, block) if numel < core_num else CLUSTER_NUM + return grid_num + # ===-----------------------------------------------------------=== + + def grid_fn(meta): + return ( + get_grid_dim(xnumel, meta.get("XBLOCK", 1)), + get_grid_dim(ynumel, meta.get("YBLOCK", None)), + get_grid_dim(znumel, meta.get("ZBLOCK", None)), + ) + + return grid_fn + + +triton_heuristics.grid = grid +# ===------------------------------------------------------------------=== diff --git a/third_party/xpu/backend/compiler.py b/third_party/xpu/backend/compiler.py new file mode 100644 index 000000000..2de76b23d --- /dev/null +++ b/third_party/xpu/backend/compiler.py @@ -0,0 +1,294 @@ +from triton.backends.compiler import BaseBackend, GPUTarget +from triton._C.libtriton import ir, passes, xpu, llvm +from triton.runtime.cache import get_cache_manager +import subprocess +import tempfile +import re +import warnings + +import os + +from dataclasses import dataclass +import functools +from typing import Any, Tuple, Optional +import hashlib +from pathlib import Path + + +@functools.lru_cache(None) +def file_hash(path): + with open(path, "rb") as f: + return hashlib.sha256(f.read()).hexdigest() + + +# @dataclass create __dataclass_fields__ specical attribute +# frozen=True can't modift entry's attribute once it have been created +# [raise FrozenInstanceError] +@dataclass(frozen=True) +class XPUOptions: + arch: int = int(os.environ.get("TRITON_XPU_ARCH", "3")) + assert arch in [2, 3, 4], "Invalid XPU ARCH" + cluster_num: int = 12 if arch == 3 else 8 + core_num: int = 64 + buffer_size_limit: int = 512 + extern_libs: dict = None + debug: bool = False + backend_name: str = "xpu" + cluster_dims: tuple = (1, 1, 1) # TODO: find mapping relationship + + isOpenCmpNan: bool = False + isCloseOffsetAnalysis: bool = False + isCloseCoreTiling: bool = False + isCloseUnrollControl: bool = False + isCLOSE_TTXPU_O_ATOMIC_SIM: bool = False + + enable_fp_fusion: bool = False + allow_fp8e4nv: bool = False + allow_fp8e4b15: bool = False + default_dot_input_precision: str = "ieee" + allowed_dot_input_precisions: Tuple[str] = ("ieee", ) + + num_warps: int = (-1) # TODO: invalid value, just to keep num_warps function signature + num_ctas: int = -1 # TODO: invalid value, just to keep num_ctas function signature + num_stages: int = 1 + + def __post_init__(self): + default_libdir = Path(__file__).parent / f"xpu{self.arch}" + extern_libs = {} if self.extern_libs is None else dict(self.extern_libs) + if not extern_libs.get("libdevice", None): + extern_libs["libdevice"] = os.getenv( + "TRITON_LIBDEVICE_PATH", + str(default_libdir / "lib" / f"libdevice-xpu{self.arch}.bc"), + ) + if not os.path.exists(extern_libs["libdevice"]): + warnings.warn(f'libdevice not found: {extern_libs["libdevice"]}', UserWarning) + del extern_libs["libdevice"] + + object.__setattr__(self, "extern_libs", tuple(extern_libs.items())) + + invalid_params = [] + if self.num_warps != -1: + invalid_params.append(f"num_warps={self.num_warps}") + if self.num_ctas != -1: + invalid_params.append(f"num_ctas={self.num_ctas}") + if len(invalid_params) > 0: + warnings.warn(f"Invalid {', '.join(invalid_params)} in xpu arch", UserWarning) + + def hash(self): + hash_dict = dict(self.__dict__) + hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"])) + key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +class XPUBackend(BaseBackend): + + def __init__(self, target: GPUTarget) -> None: + super().__init__(target) + assert isinstance(target.arch, int) + self.binary_ext = "xpubin" + self.buffer_len = 128 + + @staticmethod + def supports_target(target: GPUTarget): + return target.backend == "xpu" + + @staticmethod + def path_to_xpu_compile_tool(opt): + # Check env path for clang + if "TRITON_XPU_CLANG_PATH" in os.environ: + clang_path = os.getenv("TRITON_XPU_CLANG_PATH") + return clang_path + return os.path.join(Path(__file__).parent, f"xpu{opt.arch}", "bin") + + @staticmethod + def make_ttir(mod, metadata, opt): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_rewrite_tensor_pointer(pm) + passes.ttir.add_combine(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + pm.run(mod) + return mod + + @staticmethod + def make_ttxir(mod, metadata, opt): + metadata["xpu_arch"] = opt.arch + metadata["shared"] = (-1) # TODO: invalid value, just to keep CompiledKernel _init_handles() success + + max_buffer_size = int(os.environ.get("TRITONXPU_BUFFER_SIZE", metadata["buffer_size_limit"])) + max_buffer_size = metadata["buffer_size_limit"] + XPUBackend.buffer_len = xpu.get_buffer_len(mod, max_buffer_size) + # print(f"XPUBackend.buffer_len = {XPUBackend.buffer_len}") + core_num = metadata["core_num"] + + # F/O Prefix For Function/Optimization Macro + TTXPU_F_OHTER_VALUE_SIM = int(os.environ.get("TRITONXPU_OTHER_SIM", 0)) + TTXPU_F_STORE_MASK_SIM = int(os.environ.get("TRITONXPU_STORE_MASK_SIM", 0)) + TTXPU_F_DTYPE_CONVERT = int(os.environ.get("TRITONXPU_DTYPE_CONVERT", 1)) + TTXPU_O_ATOMIC_SIM = 0 if metadata["isCLOSE_TTXPU_O_ATOMIC_SIM"] else int( + os.environ.get("TRITONXPU_ATOMIC_SIM", 1)) + TTXPU_O_CLOSE_OPT = int(os.environ.get("TRITONXPU_CLOSE_OPTIMIZE", 0)) + + pm = ir.pass_manager(mod.context) + pm.enable_debug() + + xpu.passes.ttxpuir.add_convert_triton_to_tritonxpu_pass(pm, opt.arch, XPUBackend.buffer_len, core_num) + xpu.passes.ttxpuir.add_tritonxpu_gm2lm_pass(pm, opt.arch, TTXPU_O_ATOMIC_SIM) + passes.common.add_canonicalizer(pm) + if TTXPU_F_DTYPE_CONVERT: + xpu.passes.ttxpuir.add_tritonxpu_dtype_convert_pass(pm, opt.arch) + if not metadata["isCloseCoreTiling"]: + xpu.passes.ttxpuir.add_tritonxpu_core_tiling_pass( + pm, 0, XPUBackend.buffer_len) if not TTXPU_O_CLOSE_OPT else None # dumpFlag=0 + # xpu.passes.ttxpuir.add_tritonxpu_lm_to_sm_pass(pm) + if not metadata["isCloseOffsetAnalysis"]: + xpu.passes.ttxpuir.add_tritonxpu_offset_state_pass(pm, 0) if not TTXPU_O_CLOSE_OPT else None # dumpFlag=0 + passes.common.add_canonicalizer(pm) + xpu.passes.ttxpuir.add_tritonxpu_legalize_pass(pm, XPUBackend.buffer_len, core_num) + if not TTXPU_F_OHTER_VALUE_SIM: + xpu.passes.ttxpuir.add_tritonxpu_mask_pass(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + xpu.passes.ttxpuir.add_tritonxpu_interleave_pass(pm) if not TTXPU_O_CLOSE_OPT else None + # xpu.passes.ttxpuir.add_tritonxpu_interleave_mask_pass(pm) + passes.common.add_canonicalizer(pm) + xpu.passes.ttxpuir.add_tritonxpu_vectorize_pass(pm, 0) if not TTXPU_O_CLOSE_OPT else None # dumpFlag=0 + xpu.passes.ttxpuir.add_tritonxpu_alloca_pass(pm) + if not TTXPU_F_OHTER_VALUE_SIM: + xpu.passes.ttxpuir.add_tritonxpu_other_sim_pass(pm) + xpu.passes.ttxpuir.add_tritonxpu_memory_async_pass(pm, 0) if not TTXPU_O_CLOSE_OPT else None # dumpFlag=0 + if not metadata["isCloseUnrollControl"]: + xpu.passes.ttxpuir.add_tritonxpu_unroll_control_pass(pm) if not TTXPU_O_CLOSE_OPT else None + xpu.passes.ttxpuir.add_tritonxpu_store_control_pass(pm) if not TTXPU_O_CLOSE_OPT else None + xpu.passes.ttxpuir.add_tritonxpu_loop_grid_pass(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + + pm.run(mod) + return mod + + @staticmethod + def make_llir(mod, metadata, opt): + # TritonXPU -> LLVM-IR (MLIR) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + # xpu.passes.ttxpuir.add_decompose_unsupported_conversions(pm, opt.arch) + passes.convert.add_scf_to_cf(pm) # cf->llvm exist choose scf->cf->llvm + # passes.convert.add_index_to_llvmir(pm) // TODO[dyq]: necessary? + + passes.ttgpuir.add_allocate_shared_memory(pm) + xpu.passes.ttxpuir.add_convert_tritonxpu_to_llvm_pass(pm, opt.arch, XPUBackend.buffer_len) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + + # passes.convert.add_cf_to_llvmir(pm) + # passes.convert.add_arith_to_llvmir(pm) + # passes.common.add_canonicalizer(pm) + # passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": + passes.llvmir.add_di_scope(pm) + + pm.run(mod) + + # LLVM-IR (MLIR) -> LLVM-IR (LLVM) + llvm.init_targets() + context = llvm.context() + llvm_mod = llvm.to_module(mod, context) + + if opt.extern_libs: + paths = [path for (name, path) in opt.extern_libs if xpu.llvm.need_extern_lib(mod)] + assert (len(paths) <= 1), f"Expected 0/1 extern_lib path, but found {len(paths)}" + llvm.link_extern_libs(llvm_mod, paths) + + llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, f"xpu{opt.arch}") + xpu.llvm.amend_func(llvm_mod, mod, context, opt.arch) + + del context + return llvm_mod + + @staticmethod + def make_elf(mod, metadata, opt): + # Find kernel names (there should only be one) + # We get the name at the last possible step to accomodate `triton.compile` + # on user-provided LLVM + metadata["name"] = xpu.llvm.get_kernel_name(mod) + + # llvm -> elf/asm + triple = f"xpu{opt.arch}-baidu-none-gnu" + proc = f"xpu{opt.arch}" + flags = ["xpu-cmp-nan"] if metadata["isOpenCmpNan"] else [] + ret_asm = xpu.llvm.translate_to_asm(mod, triple, proc, "", flags, False, False) + fn_cache_manager = get_cache_manager(metadata["hash"]) + fn_cache_manager.put(ret_asm, f"{metadata['name']}.asm") + ret_elf = xpu.llvm.translate_to_asm(mod, triple, proc, "", [], False, True) + + del mod + return ret_elf + + @staticmethod + def make_xpubin(mod, metadata, opt): + with tempfile.TemporaryDirectory() as tmpdir: + clang_path = XPUBackend.path_to_xpu_compile_tool(opt) + elfconv = os.path.join(Path(__file__).parent, f"xpu{opt.arch}-elfconv") + objfile = os.path.join(tmpdir, "kernel.o") + binfile = os.path.join(tmpdir, "kernel.bin") + with open(objfile, "wb") as f: + f.write(mod) + cmd = ["bash", elfconv, objfile, binfile, clang_path] + out = subprocess.run(cmd, check=True, capture_output=True) + printf_buf_offset_res = re.search(rb"0x[0-9a-fA-F]+", out.stdout) + if printf_buf_offset_res: + printf_buf_offset_hex = printf_buf_offset_res.group(0) + printf_buf_offset_hex_str = printf_buf_offset_hex.decode("utf-8") + printf_buf_offset = int(printf_buf_offset_hex_str, 16) + else: + printf_buf_offset = 0 + metadata["printf_buf_offset"] = printf_buf_offset + with open(binfile, "rb") as f: + return f.read() + + @staticmethod + def is_elf_stack_size_oob(mod) -> bool: + stack_size_oob = llvm.is_elf_stack_size_oob(mod) + return stack_size_oob + + def hash(self) -> str: + """Returns a unique identifier for this backend""" + # TODO: + return f"1" + + def parse_options(self, options: dict) -> object: + args = {"arch": self.target.arch} + args.update({k: options[k] for k in XPUOptions.__dataclass_fields__.keys() if k in options}) + return XPUOptions(**args) + + def add_stages(self, stages: dict, options: object) -> None: + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + stages["ttxir"] = lambda src, metadata: self.make_ttxir(src, metadata, options) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) + stages["elf"] = lambda src, metadata: self.make_elf(src, metadata, options) + stages["xpubin"] = lambda src, metadata: self.make_xpubin(src, metadata, options) + + def pack_metadata(self, metadata): + return ( + metadata.cluster_dims[0], + metadata.cluster_dims[1], + metadata.cluster_dims[2], + ) + + def get_codegen_implementation(self): + codegen_fns = dict() + return codegen_fns + + def load_dialects(self, context): + xpu.load_dialects(context) diff --git a/third_party/xpu/backend/driver.c b/third_party/xpu/backend/driver.c new file mode 100644 index 000000000..965403381 --- /dev/null +++ b/third_party/xpu/backend/driver.c @@ -0,0 +1,192 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include +#define PY_SSIZE_T_CLEAN // control type size +#include + +static inline void xpuAssert(int code, const char *file, int line, + const char *call) { + if (code != XPU_SUCCESS) { + const char *err_msg = xpu_strerror(code); + char buf[1024] = {0}; + sprintf(buf, "%s:%d: %s -> %s(err_code: %d)", file, line, call, err_msg, + code); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, buf); + PyGILState_Release(gil_state); + } +} + +#define XPU_CHECK(ans) \ + { \ + xpuAssert((ans), __FILE__, __LINE__, #ans); \ + if (PyErr_Occurred()) \ + return NULL; \ + } + +uint32_t checksum(const char *data, size_t length) { + uint32_t crc32 = 0; + for (size_t i = 0; i < length; ++i) { + crc32 += static_cast(data[i]); + } + return crc32; +} + +// XPU3 runtime export this function no more, copy it from runtime repo. +// XPU Kernel type +enum kernel_type { + /// XPU Cluster kernel + KT_CLUSTER = 0, + /// XPU SDNN kernel + KT_SDCDNN = 1, +}; + +// Place of XPU kernel binary +enum kernel_place { + /// XPU kernel binary locates on CPU memory + KP_CPU = 0, + /// XPU kernel binary locates on XPU memory + KP_XPU = 1, +}; + +// XPU Kernel +struct xpu_kernel { + /// Combination of kernel place and type: + /// [31:16] kernel place, KP_CPU or KP_XPU + /// [15:0] kernel type, KT_CLUSTER or KT_SDCDNN + uint32_t type : 16; + uint32_t place : 16; + /// kernel code address on CPU Memory + uint64_t code_addr; + /// kernel code size in bytes + uint32_t code_byte_size; + /// initial program counter + uint32_t code_pc; + /// dword size kernel needed to transfer params + /// essentially, this is the count of param registers needed + uint32_t param_dword_size; + /// kernel code hash, for cache indexing + uint64_t hash; + /// (maybe mangled) function name + const char *name; + /// private data structure used by xpu runtime + void *rt_private; + uint64_t printf_buffer_offset; +}; + +static int __xpu_create_func(XPUFunc *pfunc, int type, uint64_t code_addr, + uint32_t code_bsz, uint32_t code_pc, + uint32_t param_dsz, uint64_t hash, + const char *name, bool on_xpu, + uint64_t printf_buf_offset) { + if (pfunc == NULL) { + return -XPUERR_INVALID_PARAM; + } + + struct xpu_kernel *kern = new struct xpu_kernel(); + // printf("create func @0x%" PRIx64 " hash=0x%" PRIx64 " name='%s'(%p)\n", + // code_addr, hash, (name == NULL) ? "NULL" : name, name); + + kern->type = type; + kern->place = (on_xpu) ? KP_XPU : KP_CPU; + kern->code_addr = code_addr; + kern->code_byte_size = code_bsz; + kern->code_pc = code_pc; + kern->param_dword_size = param_dsz; + kern->hash = hash; + kern->name = name; + // printf("printf_buf_offset = 0x%08lx\n", printf_buf_offset); + kern->printf_buffer_offset = printf_buf_offset; + + *pfunc = kern; + + return 0; +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + // Parse Input Data + const char *name; + const char *data; + Py_ssize_t data_size; + uint64_t printf_buf_offset; + if (!PyArg_ParseTuple(args, "ss#K", &name, &data, &data_size, + &printf_buf_offset)) { + return NULL; + } + + // Create XPUFunc + XPUFunc pfunc; + int type = KT_CLUSTER; + uint64_t code_addr = reinterpret_cast(data); + uint32_t code_byte_size = static_cast(data_size); + uint32_t code_pc = 0; + uint32_t param_dword_size = 0; + uint32_t hash = checksum(data, data_size); + bool on_xpu = false; + + XPU_CHECK(__xpu_create_func(&pfunc, type, code_addr, code_byte_size, code_pc, + param_dword_size, hash, name, on_xpu, + printf_buf_offset)); + + // Build Output Value + const void *mod = static_cast(data); + int32_t n_regs = 0; + int32_t n_spills = 0; + return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)pfunc, n_regs, + n_spills); +} + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + + // create a struct to hold device properties + int max_shared_mem = 256 * 1024; // 256K for XPU2 + int max_num_regs = 0; + int warp_size = 1; + int sm_clock_rate = 0; + int mem_clock_rate = 0; + int mem_bus_width = 0; + + int multiprocessor_count = 0; + uint64_t num_cluster = 0; + XPU_CHECK(xpu_device_get_attr(&num_cluster, XPUATTR_NUM_CLUSTER, device_id)); + multiprocessor_count = num_cluster; + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", + max_shared_mem, "max_num_regs", max_num_regs, + "multiprocessor_count", multiprocessor_count, "warpSize", + warp_size, "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, "mem_bus_width", + mem_bus_width); +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load provided xpubin into XPU driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "xpu_utils", + NULL, // documentation + -1, // size global + ModuleMethods}; + +// Python C API Binding +PyMODINIT_FUNC PyInit_xpu_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + + PyModule_AddFunctions(m, ModuleMethods); + + return m; +} diff --git a/third_party/xpu/backend/driver.py b/third_party/xpu/backend/driver.py new file mode 100644 index 000000000..7c382c0d7 --- /dev/null +++ b/third_party/xpu/backend/driver.py @@ -0,0 +1,398 @@ +import os +import hashlib +import tempfile +import functools +import subprocess +from pathlib import Path + +from triton.runtime.build import _build +from triton.runtime.cache import get_cache_manager +from triton.backends.compiler import GPUTarget +from triton.backends.driver import GPUDriver + +dirname = os.path.dirname(os.path.realpath(__file__)) +arch = int(os.environ.get('TRITON_XPU_ARCH', '3')) +include_dir = [os.path.join(dirname, f"xpu{arch}", "include")] +libdevice_dir = os.path.join(dirname, f"xpu{arch}", "lib") +library_dir = os.path.join(dirname, f"xpu{arch}", "so") +libraries = ['xpurt'] + + +def get_xpu_spec(xpu_arch, is_sdnn=False): + """ + `is_sdnn=False`: return a tuple represents (num_clusters, num_cores) + + `is_sdnn=True`: return a tuple represents (num_sdnns, num_cores) + """ + if xpu_arch == 2: + return (8, 8) if is_sdnn else (8, 64) + elif xpu_arch == 3: + return (12, 8) if is_sdnn else (12, 64) + elif xpu_arch == 4: + return (6, 8) if is_sdnn else (12, 64) + else: + raise RuntimeError(f"Unknown XPU architecture: {xpu_arch}") + + +@functools.lru_cache() +def library_dirs(): + return [libdevice_dir, library_dir] + + +def compile_module_from_src(src, name): + key = hashlib.sha256(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + # print(f"src_path = {src_path}") + so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# ------------------------ +# Launcher +# ------------------------ + + +def ty_to_cpp(ty): + if ty[0] == '*': + return "void *" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + +def make_launcher(constants, signature, ids, xpu_arch): + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + + def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return ty_to_cpp(ty) + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", # TODO[dyq]: L? + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty] + + def generate_argument_set_code(signature, constants, xpu_arch): + newline = "\n " + eightBytesTypes = ['void *', 'int64_t', 'uint64_t', 'double'] + lines = [] + for i, ty in signature.items(): + if i in constants: + continue + is_align_to_8 = (ty_to_cpp(ty) in eightBytesTypes) and (xpu_arch == 3 or xpu_arch == 4) + if is_align_to_8: + offset_align_to_8_line = "offset = alignSizeTo8Bytes(offset);" + lines.append(offset_align_to_8_line) + align_fn = "alignSizeTo8Bytes" if is_align_to_8 else "alignSizeTo4Bytes" + xpu_check_line = f"XPU_CHECK(xpu_launch_argument_set(&arg{i}, sizeof(arg{i}), offset));" + offset_increment_line = f"offset += {align_fn}(sizeof(arg{i}));" + lines.append(f"{xpu_check_line} {offset_increment_line}") + + return newline.join(lines) + + args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiKKOOOO" + args_format + + args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + + def read_data_to_hexstr(file_name): + if not file_name: + return "" + with open(file_name, 'rb') as f: + data = f.read() + hex_lines = [] + for i in range(0, len(data), 128): + chunk = data[i:i + 128] + hex_string = ','.join(f'0x{byte:02x}' for byte in chunk) + hex_lines.append(hex_string) + return ',\n '.join(hex_lines) + + # generate glue code + src = f""" +#include +#include +#include +#include + +// XPU_SPEC_START +static inline void xpuAssert(int code, const char *file, int line, + const char *call) +{{ + if (code != XPU_SUCCESS) + {{ + const char* err_msg = xpu_strerror(code); + char buf[1024] = {{0}}; + sprintf(buf, "%s:%d: %s -> %s(err_code: %d)", + file, line, call, err_msg, code); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); + PyErr_SetString(PyExc_RuntimeError, buf); + PyGILState_Release(gil_state); + }} +}} + +#define XPU_CHECK(ans) {{ xpuAssert((ans), __FILE__, __LINE__, #ans); }} + +static inline size_t alignSizeTo4Bytes(size_t size) {{ + return (size + 3) & ~3; +}} + +static inline size_t alignSizeTo8Bytes(size_t size) {{ + return (size + 7) & ~7; +}} + +enum {{ + kINVALID = 0, + kL3, + kGM +}}; + +static inline int xpu2PointerCheck(void *ptr) {{ + unsigned int ptr_high = (((unsigned long long) ptr) >> 32); + unsigned int ptr_low = (((unsigned long long) ptr)); + if (ptr_high == 0 && ptr_low >= 0xC0000000 && ptr_low <= 0xC3FFFFFF) {{ + return kL3; + }} + if (ptr_high >= 8 && ptr_high <= 15) {{ + return kGM; + }} + printf("ptr_high = %u\\n", ptr_high); + printf("ptr_low = %u\\n", ptr_low); + return kINVALID; +}} + +static inline int xpu3PointerCheck(void *ptr) {{ + // TODO: do it for XPU3. + return kGM; +}} + +static inline int xpu4PointerCheck(void *ptr) {{ + // TODO: do it for XPU4. + return kGM; +}} + +inline int min(int a, int b) {{ + return a < b ? a : b; +}} + +static void _launch(int gridX, int gridY, int gridZ, int clusterDimX, int clusterDimY, int clusterDimZ, XPUStream stream, XPUFunc function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + if (gridX*gridY*gridZ > 0) {{ + size_t offset = 0; + {generate_argument_set_code(signature, constants, xpu_arch)} + // printf("gridXYZ=[%d, %d, %d]\\n", gridX, gridY, gridZ); + int nclusters = {get_xpu_spec(xpu_arch)[0]}; + int ncores = {get_xpu_spec(xpu_arch)[1]}; + xpu_launch_argument_set(&gridX, sizeof(gridX), offset+0); + xpu_launch_argument_set(&gridY, sizeof(gridY), offset+4); + xpu_launch_argument_set(&gridZ, sizeof(gridZ), offset+8); + XPU_CHECK(xpu_launch_config(min(gridX*gridY*gridZ, nclusters), ncores)); // TODO[dyq]: should we set stream config + // xpu_kernel_debug_reset(); + XPU_CHECK(xpu_launch_async(function)); + }} +}} +// XPU_SPEC_END + +typedef struct _DevicePtrInfo {{ + void *dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = PyLong_AsVoidPtr(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = PyLong_AsVoidPtr(ret); + if(!ptr_info.dev_ptr) + return ptr_info; + void *dev_ptr = PyLong_AsVoidPtr(ret); + if (xpu{xpu_arch}PointerCheck(dev_ptr) == kINVALID) {{ + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); + ptr_info.valid = false; + }} + ptr_info.dev_ptr = dev_ptr; + Py_DECREF(ret); // Thanks ChatGPT! + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {args_list})) {{ + return NULL; + }} + + int clusterDimX, clusterDimY, clusterDimZ; + if (!PyArg_ParseTuple(kernel_metadata, \"iii\", &clusterDimX, &clusterDimY, &clusterDimZ)) {{ + PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); + return NULL; + }} + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + Py_BEGIN_ALLOW_THREADS; + _launch(gridX, gridY, gridZ, clusterDimX, clusterDimY, clusterDimZ, (XPUStream)_stream, (XPUFunc)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); + Py_END_ALLOW_THREADS; + if (PyErr_Occurred()) {{ + return NULL; + }} + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src + + +class XPUUtils(object): + + def __init__(self): + mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "xpu_utils") + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + + +class XPULauncher(object): + + def __init__(self, src, metadata): + ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + + src = make_launcher(constants, signature, ids, metadata.xpu_arch) + mod = compile_module_from_src(src, "__triton_launcher") + self.launch = mod.launch + + def __call__(self, *args, **kwargs): + self.launch(*args, **kwargs) + + +class XPUDriver(GPUDriver): + + def __init__(self): + self.utils = XPUUtils() + self.launcher_cls = XPULauncher + super().__init__() + + @staticmethod + def is_active(): + return True + + def get_current_target(self): + arch = int(os.environ.get('TRITON_XPU_ARCH', '3')) + warp_size = 1 # we don't have warp + return GPUTarget("xpu", arch, warp_size) diff --git a/third_party/xpu/backend/xpu3-elfconv b/third_party/xpu/backend/xpu3-elfconv new file mode 100644 index 000000000..12ca50cdb --- /dev/null +++ b/third_party/xpu/backend/xpu3-elfconv @@ -0,0 +1,55 @@ +set -e +INPUT=$1 +OUTPUT=$2 +CLANG_PATH=$3 +ALIGNE_BITS=6 +cdnn_attribute=`$CLANG_PATH/llvm-readelf --arch-specific "$INPUT"` +cdnn_key="CDNN_Inst_Is_Used" +if [ "${cdnn_attribute/$cdnn_key}" = "$cdnn_attribute" ]; then + CDNN_INST_IS_USED=false + CDNN_Inst_Is_Used_FLAG="" + LM_SIZE=8 #32KB + LITTLE_CORE_SUFFIX='' +else + CDNN_INST_IS_USED=true + CDNN_Inst_Is_Used_FLAG="-DCDNN_Inst_Is_Used" + LM_SIZE=8 + LITTLE_CORE_SUFFIX='s' +fi +index=0 +crt_object_file="$OUTPUT.crt.o" +kernel_bin_file="$OUTPUT.bin" +for function_name in $($CLANG_PATH/llvm-readelf -s "$INPUT" | grep FUNC | grep GLOBAL | grep DEFAULT | awk '{print $8}'); do + # Read PARAM_SIZE from section .XPU.KERNEL_PARAM_SIZE. + sec_file=$(dirname $OUTPUT).${function_name}_PARAM_SIZE + $CLANG_PATH/llvm-objcopy "$INPUT" --dump-section=.XPU.KERNEL_PARAM_SIZE.${function_name}=${sec_file} + OLD_IFS="$IFS" + IFS=" " + sec_file_array=($($CLANG_PATH/xpu-xxd -e ${sec_file})) + IFS="$OLD_IFS" + PARAM_SIZE=$(printf %d $(echo ${sec_file_array[1]} | sed -r 's/0*([0-9])/0x\1/')) + # //XPU3 has 64bit arguments, need to align 8, and stack aligned to 64 constrain by datalayout + PARAM_SPACE=$(( ($PARAM_SIZE + 63)/64*64)) + rm -f ${sec_file} + # Prepare crt.o for each global function + $CLANG_PATH/clang --xpu-arch=xpu3 -c $CLANG_PATH/xpu3-crt.xpu -o "$crt_object_file" -DKERNEL_ENTRY="$function_name" -DALIGNE_BITS=${ALIGNE_BITS} -DPARAM_SIZE=${PARAM_SIZE} -DPARAM_SPACE=${PARAM_SPACE} -DLOCAL_MEM_SIZE=${LM_SIZE} ${CDNN_Inst_Is_Used_FLAG} -O2 --xpu-device-only + # Link crt.o $INPUT into a bin file + if [ -n "$DEVICE_LIB_PATH" ]; then + device_libs="" + for device_lib_name in $DEVICE_LIB_PATH/*.a + do + device_libs=$device_libs"$device_lib_name " + done + device_libs="${device_libs#"${device_libs%%[![:space:]]*}"}" + device_libs="${device_libs%"${device_libs##*[![:space:]]}"}" + $CLANG_PATH/ld.lld -gc-sections "$crt_object_file" "$INPUT" "${CLANG_PATH}/../lib/linux/libclang_rt.builtins-xpu3$LITTLE_CORE_SUFFIX.a" "${device_libs}" -T "$CLANG_PATH"/xpu-kernel.t -o "$kernel_bin_file" + elif [ -f "${CLANG_PATH}/../lib/linux/libclang_rt.builtins-xpu3.a" ]; then + $CLANG_PATH/ld.lld -gc-sections "$crt_object_file" "$INPUT" "${CLANG_PATH}/../lib/linux/libclang_rt.builtins-xpu3$LITTLE_CORE_SUFFIX.a" -T "$CLANG_PATH"/xpu-kernel.t -o "$kernel_bin_file" + else + $CLANG_PATH/ld.lld -gc-sections "$crt_object_file" "$INPUT" -T "$CLANG_PATH"/xpu-kernel.t -o "$kernel_bin_file" + fi + let index=index+1 + $CLANG_PATH/llvm-objcopy "$kernel_bin_file" --dump-section="KERNEL"="$OUTPUT" + printf_buffer_offset=`$CLANG_PATH/llvm-objdump -t "$kernel_bin_file" | grep .xpu_kernel_printf_buffer | cut -d' ' -f1` + echo "[TritonXPU] printf_buffer_offset = 0x$printf_buffer_offset" +done diff --git a/third_party/xpu/device/CMakeLists.txt b/third_party/xpu/device/CMakeLists.txt new file mode 100644 index 000000000..fc2e85934 --- /dev/null +++ b/third_party/xpu/device/CMakeLists.txt @@ -0,0 +1,18 @@ +function(add_xpu_libdevice OUTPUT SRC ARCH) + set(CLANG ${LLVM_TOOLS_BINARY_DIR}/clang) + + get_filename_component(OUTPUT_NAME ${OUTPUT} NAME_WE) + + add_custom_target( + libdevice-${ARCH} ALL + COMMAND ${CLANG} --xpu-arch=${ARCH} ${SRC} -c -emit-llvm --xpu-device-only -O3 -o ${OUTPUT} -std=c++11 -Wno-literal-range + DEPENDS ${SRC} + COMMENT "Building libdevice-${ARCH} ..." + VERBATIM + ) +endfunction() + +set(XPU_LIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/trigonometric.xpu) +set(XPU3_LIB_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../backend/xpu3/lib/) +file(MAKE_DIRECTORY ${XPU3_LIB_DIR}) +add_xpu_libdevice(${XPU3_LIB_DIR}/libdevice-xpu3.bc ${XPU_LIB_SRC} xpu3) diff --git a/third_party/xpu/device/trigonometric.xpu b/third_party/xpu/device/trigonometric.xpu new file mode 100644 index 000000000..8fa85d803 --- /dev/null +++ b/third_party/xpu/device/trigonometric.xpu @@ -0,0 +1,1446 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "xpu/kernel/xtdk.h" +#include "xpu/kernel/xtdk_math.h" +#include "xpu/kernel/xtdk_simd.h" +#include "xpu/kernel/xtdk_atomic_sm_xpu3.h" +#include "xpu/kernel/xtdk_io.h" + +#define inline __attribute__((always_inline)) +#define PI 3.14159265358979323846f +// Value of ln2 and ln10 +#define LN2 0.69314718055994530942f +#define LN10 2.30258509299404568402f + +#define HUGE_VALF 1e10000f +#define INFINITY HUGE_VALF + +#ifdef __GNUC__ +#define predict_true(x) __builtin_expect(!!(x), 1) +#define predict_false(x) __builtin_expect(x, 0) +#else +#define predict_true(x) (x) +#define predict_false(x) (x) +#endif + +// TODO[dyq]: replace asm by builtin for scheduling +namespace xpu { + +#ifdef __arch_xpu2__ +typedef struct { + unsigned short val; +} float16; +#endif + +#ifdef __arch_xpu3__ +typedef _Float16 float16; +#endif + + +__device__ inline void printTest(int customIdx) { + printf("[printTest_%d] cluster_id[%d] cid[%d/%d]\n", customIdx, cluster_id(), core_num(), core_id()); +} + +__device__ inline void printBool(bool b, int customIdx) { + int clusterId = cluster_id(); + int coreId = core_id(); + if (clusterId == 0 && coreId == 0) { + if (b) { + printf("[printBool_%d] cluster_id = %d, core_id = %d, local_x[0] = true \n", customIdx, clusterId, coreId); + } else { + printf("[printBool_%d] cluster_id = %d, core_id = %d, local_x[0] = false \n", customIdx, clusterId, coreId); + } + } +} + +__device__ inline void printFloat(float a, int customIdx) { + int clusterId = cluster_id(); + int coreId = core_id(); + if (clusterId == 0 && coreId == 0) { + printf("[printFloat_%d] cluster_id = %d, core_id = %d, local_x[0] = %f \n", customIdx, clusterId, coreId, a); + } +} + +__device__ inline void printInt(int a, int customIdx) { + int clusterId = cluster_id(); + int coreId = core_id(); + if (clusterId == 0 && coreId == 0) { + printf("[printInt_%d] cluster_id = %d, core_id = %d, local_x[0] = %d \n", customIdx, clusterId, coreId, a); + } +} + +__device__ inline void printMMaOp(int a, int b, int c) { + int clusterId = cluster_id(); + int coreId = core_id(); + if (clusterId == 0) { + printf("[printMMAOp] cluster_id = %d, core_id = %d, mma %d %d %d\n", clusterId, coreId, a, b, c); + } +} + +__device__ inline void printDsOp(int a, int b) { + int clusterId = cluster_id(); + int coreId = core_id(); + if (clusterId == 0) { + printf("[printDsOp] cluster_id = %d, core_id = %d, ds %d %d\n", clusterId, coreId, a, b); + } +} + +__device__ inline void printDmaOp(int a) { + int clusterId = cluster_id(); + int coreId = core_id(); + if (clusterId == 0) { + printf("[printDmaOp] cluster_id = %d, core_id = %d, dma %d\n", clusterId, coreId, a); + } +} + +__device__ inline void printDmaoOp(int a) { + int clusterId = cluster_id(); + int coreId = core_id(); + if (clusterId == 0) { + printf("[printDMAOop] cluster_id = %d, core_id = %d, dmao %d\n", clusterId, coreId, a); + } +} + +__device__ void inline fp16tofp32(const float16* x, float* y, int len) { + int start = (len - 1) / 32 * 32; + x = x + start; + y = y + start; + for (int i = start; i >= 0; i -= 32) { + float16x32_t X; + float32x16_t X_l; + float32x16_t X_h; + __asm__ __volatile__("vload.mz %0{mr1}, 0(%1)":"=&v"(X):"r"(x)); + __asm__ __volatile__("vfp162float_l.rn %0, %1":"=&v"(X_l):"v"(X)); + __asm__ __volatile__("vfp162float_h.rn %0, %1":"=&v"(X_h):"v"(X)); + __asm__ __volatile__("vstore_mask16.mz %0{mr1}, 0(%1)"::"v"(X_h), "r"(y + 16)); + __asm__ __volatile__("vstore_mask16.mz %0{mr1}, 0(%1)"::"v"(X_l), "r"(y)); + x -= 32; + y -= 32; + } + mfence_lm(); +} + + +__device__ void inline fp32tofp16(const float* x, float16* y, int len) { + for (int i = 0; i < len; i += 32) { + float32x16_t Y_h = __builtin_xpu2_vload_mask16_mr1(x + 16, 0); + float32x16_t Y_l = __builtin_xpu2_vload_mask16_mr1(x, 0); + __asm__ __volatile__("vfloat2fp16_l.rn vr0, %0\t\n" + "vfloat2fp16_h.rn vr0, %1\t\n" + "vstore.mz vr0{mr1}, 0(%2)" + ::"v"(Y_l), "v"(Y_h), "r"(y):"vr0"); + x += 32; + y += 32; + } + mfence_lm(); +} + +static __device__ inline void taylor_sin(float *C1, float *C3, float *C5, float *C7, float *C9) { + *C1 = 1; + *C3 = -0.16666666666666666; + *C5 = 0.008333333333333333; + *C7 = -0.0001984126984126984; + *C9 = 2.7557319223985893e-06; +} + +static __device__ inline void taylor_arcsin(float *C1, float *C3, float *C5, float *C7, float *C9) { + *C1 = 1; + *C3 = 1.0f / 6.0f; + *C5 = 3.0f / 40.0f; + *C7 = 5.0f / 112.0f; + *C9 = 35.0f / 1152.0f; +} + +// standardized input into [-PI, PI] interval +static __device__ inline void translation_sin(float *input) { + float tmp; + tmp = *input; + int factor = int(*input / PI); + tmp = tmp - factor * PI; + if (factor % 2 != 0){ + tmp = -tmp; + } + *input = tmp; +} + +__device__ inline float nearbyint(float input) { + return rint(input); +} + +__device__ inline float rint(float __x) { return __builtin_rintf(__x); } + +__device__ inline float sinf(float input) { + float C1, C3, C5, C7, C9; + taylor_sin(&C1, &C3, &C5, &C7, &C9); + + translation_sin(&input); + + // make the input exist in the [- PI/2, PI/2] interval, + // taylor expansion is more effective for this part + if (input > (PI / 2)) { + input = PI - input; + } else if (input < (-PI / 2)) { + input = -PI - input; + } + + // taylor expansion process + float tmp0 = input * input; + float tmp1 = tmp0 * input; + float tmp2 = tmp1 * tmp0; + float tmp3 = tmp2 * tmp0; + float tmp4 = tmp3 * tmp0; + float result = C1 * input + C3 * tmp1 + C5 * tmp2 + C7 * tmp3 + C9 * tmp4; + + return result; +} + + +// 把不在 [-pi, pi]的数map到[-pi, pi]区间内,len通常=32 +static __device__ inline void translation_sin(float* input, int len) { + for (int k = 0; k < len; k++) { + float tmp; + tmp = *(input + k); + int factor = int(*(input + k) / PI); + tmp = tmp - factor * PI; //取余数 + if (factor % 2 != 0) { + tmp = -tmp; + } + *(input + k) = tmp; + } +} + +__device__ inline float32x16_t vsinf(float32x16_t input) { + float C1, C3, C5, C7, C9; + taylor_sin(&C9, &C7, &C5, &C3, &C1); + + // 把转成float32的数在LM进行处理,使得处于区间[-pi, pi] + __simd__ float tmp_lm[16]; + vstore_lm_float32x16(tmp_lm, input); + __asm__("mfence{lm}"); + translation_sin(tmp_lm, 16); + __asm__("mfence{lm}"); + + float32x16_t tmp1, tmp2, v0l; + v0l = vload_lm_float32x16(tmp_lm); + + // 如果x>pi/2, x=pi-x + // 如果x< -pi/2, x= -pi -x + // 使得所有数都在[-pi/2, pi/2],泰勒展开对这部分比较有效 + int mask = svle_float32x16(PI / 2, v0l); + tmp1 = svsub_float32x16_mz(PI, v0l, mask); + tmp2 = svadd_float32x16_mz(0, v0l, ~mask); + v0l = vvadd_float32x16(tmp1, tmp2); + tmp1 = svmul_float32x16(-1, v0l); + mask = svlt_float32x16(PI / 2, tmp1); + tmp1 = svsub_float32x16_mz(-PI, v0l, mask); + tmp2 = svadd_float32x16_mz(0, v0l, ~mask); + v0l = vvadd_float32x16(tmp1, tmp2); + + //泰勒展开过程 + tmp1 = vvmul_float32x16(v0l, v0l); // tmp0 = a*a + tmp2 = svmul_float32x16(C1, tmp1); // C1*tmp0 + tmp2 = svadd_float32x16(C3, tmp2); // C3 + C1*tmp0 + tmp2 = vvmul_float32x16(tmp2, tmp1); // C3*tmp0 + C1*tmp0*tmp0 + tmp2 = svadd_float32x16(C5, tmp2); // C5 + C3*tmp0 + C1*tmp0*tmp0 + tmp2 = vvmul_float32x16(tmp2, tmp1); // C5*tmp0 + C3*tmp0*tmp0 + C1*tmp0*tmp0*tmp0 + tmp2 = svadd_float32x16(C7, tmp2); // C7 + C5*tmp0 + C3*tmp0*tmp0 + C1*tmp0*tmp0*tmp0 + tmp2 = vvmul_float32x16(tmp2, tmp1); // C7*tmp0 + C5*tmp0 + C3*tmp0*tmp0 + C1*tmp0*tmp0*tmp0 + tmp2 = svadd_float32x16(C9, tmp2); + v0l = vvmul_float32x16(tmp2, v0l); + + return v0l; +} + +__device__ inline float16x32_t vsinf(float16x32_t input) { + float32x16_t vl_input = vfp162float_l(input); + float32x16_t vh_input = vfp162float_h(input); + float32x16_t vl_sin = vsinf(vl_input); + float32x16_t vh_sin = vsinf(vh_input); + float16x32_t v_sin = vfloat2fp16_lh(vl_sin, vh_sin); + return v_sin; +} + +// 先把cos函数平移成sin函数,再把不在 [-pi, pi]的数map到[-pi, pi]区间内,len通常=32 +static __device__ inline void translation_cos(float* input, int len) { + for (int k = 0; k < len; k++) { + float tmp = *(input + k) + PI / 2; //平移 + int factor = int(tmp / PI); + tmp = tmp - factor * PI; + if (factor % 2 != 0) { + tmp = -tmp; + } + *(input + k) = tmp; + } +} + +__device__ inline float32x16_t vcosf(float32x16_t input) { + float C1, C3, C5, C7, C9; + taylor_sin(&C9, &C7, &C5, &C3, &C1); + + // 把转成float32的数在LM进行处理,使得处于区间[-pi, pi] + __simd__ float tmp_lm[16]; + vstore_lm_float32x16(tmp_lm, input); + __asm__("mfence{lm}"); + translation_cos(tmp_lm, 16); + __asm__("mfence{lm}"); + + float32x16_t tmp1, tmp2, v0l; + v0l = vload_lm_float32x16(tmp_lm); + + // 如果x>pi/2, x=pi-x + // 如果x< -pi/2, x= -pi -x + // 使得所有数都在[-pi/2, pi/2],泰勒展开对这部分比较有效 + int mask = svle_float32x16(PI / 2, v0l); + tmp1 = svsub_float32x16_mz(PI, v0l, mask); + tmp2 = svadd_float32x16_mz(0, v0l, ~mask); + v0l = vvadd_float32x16(tmp1, tmp2); + tmp1 = svmul_float32x16(-1, v0l); + mask = svlt_float32x16(PI / 2, tmp1); + tmp1 = svsub_float32x16_mz(-PI, v0l, mask); + tmp2 = svadd_float32x16_mz(0, v0l, ~mask); + v0l = vvadd_float32x16(tmp1, tmp2); + + //泰勒展开过程 + tmp1 = vvmul_float32x16(v0l, v0l); + tmp2 = svmul_float32x16(C1, tmp1); + tmp2 = svadd_float32x16(C3, tmp2); + tmp2 = vvmul_float32x16(tmp2, tmp1); + tmp2 = svadd_float32x16(C5, tmp2); + tmp2 = vvmul_float32x16(tmp2, tmp1); + tmp2 = svadd_float32x16(C7, tmp2); + tmp2 = vvmul_float32x16(tmp2, tmp1); + tmp2 = svadd_float32x16(C9, tmp2); + v0l = vvmul_float32x16(tmp2, v0l); + + return v0l; +} + +__device__ inline float16x32_t vcosf(float16x32_t input) { + float32x16_t vl_input = vfp162float_l(input); + float32x16_t vh_input = vfp162float_h(input); + float32x16_t vl_cos = vcosf(vl_input); + float32x16_t vh_cos = vcosf(vh_input); + float16x32_t v_cos = vfloat2fp16_lh(vl_cos, vh_cos); + return v_cos; +} + +// Translate cos into sin, convert to [-PI, PI] +static __device__ inline void translation_cos(float* input) { + float tmp; + tmp = *input + PI / 2; + int factor = int(tmp / PI); + tmp = tmp - factor * PI; + if (factor % 2 != 0){ + tmp = -tmp; + } + *input = tmp; +} + +extern __device__ inline float cosf(float input) { + float C1, C3, C5, C7, C9; + taylor_sin(&C1, &C3, &C5, &C7, &C9); + + translation_cos(&input); + + // make the input exist in the [-PI/2, PI/2] + if (input > (PI / 2)) { + input = PI - input; + } else if (input < (-PI / 2)) { + input = -PI - input; + } + + // taylor expansion process + float tmp0 = input * input; + float tmp1 = tmp0 * input; + float tmp2 = tmp1 * tmp0; + float tmp3 = tmp2 * tmp0; + float tmp4 = tmp3 * tmp0; + float result = C1 * input + C3 * tmp1 + C5 * tmp2 + C7 * tmp3 + C9 * tmp4; + + return result; +} + + + +static __device__ inline void translation_tan(float* input) { + float tmp = *input; + int factor = int(tmp / (PI/2)); + tmp = tmp - factor * (PI/2); + if (factor % 2 != 0) { + if (factor > 0) { + tmp = tmp - (PI/2); + } else if (factor < 0) { + tmp = tmp + (PI/2); + } + } + *input = tmp; +} + +__device__ inline float tanf(float input) { + float C1, C3, C5, C7, C9; + taylor_sin(&C1, &C3, &C5, &C7, &C9); + + translation_tan(&input); + + // sin + float tmp0 = input * input; + float tmp1 = tmp0 * input; + float tmp2 = tmp1 * tmp0; + float tmp3 = tmp2 * tmp0; + float tmp4 = tmp3 * tmp0; + float result_sin = C1 * input + C3 * tmp1 + C5 * tmp2 + C7 * tmp3 + C9 * tmp4; + + // cos + input = input + PI/2; + if (input > (PI/2)) { + input = PI - input; // sin (pi-x) = sin x + } + + tmp0 = input * input; + tmp1 = tmp0 * input; + tmp2 = tmp1 * tmp0; + tmp3 = tmp2 * tmp0; + tmp4 = tmp3 * tmp0; + float result_cos = C1 * input + C3 * tmp1 + C5 * tmp2 + C7 * tmp3 + C9 * tmp4; + + return result_sin/result_cos; +} + +static __device__ inline void translation_arcsin(float* input) { + float tmp; + tmp = *input; + if (tmp < 0) {tmp = -tmp;} + if (tmp > 0.62f){ + tmp = sqrt(1.0f - tmp * tmp); + } + *input = tmp; +} + +// There is no feedback when the input is not [- 1, 1] +__device__ inline float asinf(float input) { + float C1, C3, C5, C7, C9; + taylor_arcsin(&C1, &C3, &C5, &C7, &C9); + + float input_ori = input; + translation_arcsin(&input); // The interval becomes [0, 0.62] + + float tmp0 = input * input; + float tmp1 = tmp0 * input; + float tmp2 = tmp1 * tmp0; + float tmp3 = tmp2 * tmp0; + float tmp4 = tmp3 * tmp0; + float result = C1 * input + C3 * tmp1 + C5 * tmp2 + C7 * tmp3 + C9 * tmp4; + + if (input_ori > 0.62f || input_ori < -0.62f) { + result = PI/2 - result; + } + if (input_ori < 0) { + result = - result; + } + + return result; +} + +__device__ inline float acosf(float input) { + float result = PI/2 - asinf(input); + return result; +} + +__device__ inline float atanf(float input) { + input = input / sqrt(1.0f + input * input); + float result = asinf(input); + return result; +} + +__device__ inline float atan2f(float input, float other) { + if (other > 0) { + return atanf(input / other); + } else if (other < 0 && input >= 0) { + return atanf(input / other) + PI; + } else if (other < 0 && input < 0) { + return atanf(input / other) - PI; + } else if (other == 0 && input > 0) { + return PI / 2; + } else if (other == 0 && input < 0) { + return -PI / 2; + } else { + return 0; + } +} + +__device__ inline float sinhf(float input) { + float tmp = exp(input); + float result = (tmp-1.0f/tmp)/2; + return result; +} + +__device__ inline float coshf(float input) { + float tmp = exp(input); + float result = (tmp+1.0f/tmp)/2; + return result; +} + +__device__ inline float tanhf(float input) { + float tmp = exp(input); + float result = (tmp-1.0f/tmp)/(tmp+1.0f/tmp); + return result; +} + +__device__ inline float32x16_t vtanhf(float32x16_t input) { + int len = 16; + __simd__ float input_lm[len]; + __simd__ float exp_lm[len]; + __simd__ float inv_exp_lm[len]; + vstore_lm_float32x16(input_lm, input); + mfence_lm(); + for (int k = 0; k < len; k++) { + float input_tmp = *(input_lm + k); + float exp_tmp = exp(input_tmp); + float inv_exp_tmp = 1 / exp_tmp; + *(exp_lm + k) = exp_tmp; + *(inv_exp_lm + k) = inv_exp_tmp; + } + mfence_lm(); + float32x16_t _exp = vload_lm_float32x16(exp_lm); + float32x16_t _inv_exp = vload_lm_float32x16(inv_exp_lm); + float32x16_t _sub = vvsub_float32x16(_exp, _inv_exp); + float32x16_t _add = vvadd_float32x16(_exp, _inv_exp); + vstore_lm_float32x16(exp_lm, _sub); + vstore_lm_float32x16(inv_exp_lm, _add); + mfence_lm(); + for (int k = 0; k < len; k++) { + float sub_tmp = *(exp_lm + k); + float add_tmp = *(inv_exp_lm + k); + float res_tmp = sub_tmp / add_tmp; + *(input_lm + k) = res_tmp; + } + mfence_lm(); + float32x16_t res = vload_lm_float32x16(input_lm); + return res; +} + +__device__ inline float16x32_t hvtanh(float16x32_t input) { + float32x16_t vl_input = vfp162float_l(input); + float32x16_t vh_input = vfp162float_h(input); + float32x16_t vl_tanh = vcosf(vl_input); + float32x16_t vh_tanh = vcosf(vh_input); + float16x32_t v_tanh = vfloat2fp16_lh(vl_tanh, vh_tanh); + return v_tanh; +} + +__device__ inline float asinhf(float input) { + float result = log(input + sqrt(1.0f + input * input)); + return result; +} + +__device__ inline float acoshf(float input) { + float result = log(input + sqrt(-1.0f + input * input)); + return result; +} + +__device__ inline float atanhf(float input) { + float result = log((1.0f + input)/(1.0f - input)) / 2; + return result; +} + +// other +__device__ inline float rsqrtf(float input) { + return 1.0f / sqrt(input); +} + +__device__ inline double rsqrtf(double input) { + return 1.0f / sqrt(input); +} + +__device__ inline float __fsqrt_rn(float __a) { + float ret; + asm volatile("sqrt.f.rn %0, %1" : "=r"(ret) : "r"(__a)); + return ret; +} + +__device__ inline double __dsqrt_rn(double __a) { + double dsqrt = sqrt(__a); + double rounded_dsqrt = static_cast(dsqrt); + return rounded_dsqrt; +} + +__device__ inline float pow(float input1, float input2) { + float ret; + float mul_res; + float log_res; + if (input1 > 0.0f) { + log_res = __builtin_xpu_log2f(input1); + mul_res = __builtin_xpu_mulf(log_res, input2); + ret = __builtin_xpu_exp2f(mul_res); + } else if ((input1 < 0.0f) && (rint(input2) != input2)) { // intput1 is negative and input2 is not integer + ret = __builtin_nanf (""); + } else if (input2 == 0.0f) { + ret = 1.0f; + } else if (input1 == 0.0f && input2 > 0.0f) { + ret = 0.0f; + } else if (input1 == 0.0f && input2 < 0.0f) { + ret = __builtin_inff(); + } else if (rint(0.5f * input2) == (0.5f * input2)) { // input2 is even + log_res = __builtin_xpu_log2f(-input1); + mul_res = __builtin_xpu_mulf(log_res, input2); + ret = __builtin_xpu_exp2f(mul_res); + } else { + log_res = __builtin_xpu_log2f(-input1); + mul_res = __builtin_xpu_mulf(log_res, input2); + ret = -__builtin_xpu_exp2f(mul_res); + } + return ret; +} + +__device__ inline int32_t ffs(int32_t x) { + if (x == 0) return 0; + + int32_t position = 1; + + // Check lower half (16 bits) + if ((x & 0x0000FFFF) == 0) { + x >>= 16; + position += 16; + } + // Check lower 8 bits of the current half + if ((x & 0x000000FF) == 0) { + x >>= 8; + position += 8; + } + // Check lower 4 bits of the current quarter + if ((x & 0x0000000F) == 0) { + x >>= 4; + position += 4; + } + // Check lower 2 bits of the current nibble + if ((x & 0x00000003) == 0) { + x >>= 2; + position += 2; + } + // Check the lowest bit of the current pair + if ((x & 0x00000001) == 0) { + position += 1; + } + + return position; +} + +__device__ inline float log1pf(float input) { + return log(input + 1.0f); +} + +__device__ inline float expm1f(float input) { + return exp(input) - 1.0f; +} + +__device__ inline uint64_t asuint64(double f) { + union { + double _f; + uint64_t _i; + } u; + u._f = f; + return u._i; +} + +/* Top 12 bits of a double (sign and exponent bits). */ +__device__ inline uint32_t top12(double x) { + return asuint64(x) >> 52; +} + +__device__ inline uint32_t asuint(float f) { + union { + float _f; + uint32_t _i; + } u; + u._f = f; + return u._i; +} + +__device__ inline double eval_as_double(double x) { + double y = x; + return y; +} + +__device__ inline double asdouble(uint64_t i) { + union { + uint64_t _i; + double _f; + } u; + u._i = i; + return u._f; +} + +__device__ inline float eval_as_float(float x) { + float y = x; + return y; +} + +__device__ inline float exp2f(float x) { + int N = 32; + uint32_t abstop; + uint64_t ki, t; + double kd, xd, z, r, r2, y, s; + uint64_t T[32] = { + 0x3ff0000000000000, 0x3fefd9b0d3158574, 0x3fefb5586cf9890f, 0x3fef9301d0125b51, 0x3fef72b83c7d517b, + 0x3fef54873168b9aa, 0x3fef387a6e756238, 0x3fef1e9df51fdee1, 0x3fef06fe0a31b715, 0x3feef1a7373aa9cb, + 0x3feedea64c123422, 0x3feece086061892d, 0x3feebfdad5362a27, 0x3feeb42b569d4f82, 0x3feeab07dd485429, + 0x3feea47eb03a5585, 0x3feea09e667f3bcd, 0x3fee9f75e8ec5f74, 0x3feea11473eb0187, 0x3feea589994cce13, + 0x3feeace5422aa0db, 0x3feeb737b0cdc5e5, 0x3feec49182a3f090, 0x3feed503b23e255d, 0x3feee89f995ad3ad, + 0x3feeff76f2fb5e47, 0x3fef199bdd85529c, 0x3fef3720dcef9069, 0x3fef5818dcfba487, 0x3fef7c97337b9b5f, + 0x3fefa4afa2a490da, 0x3fefd0765b6e4540, + }; + double SHIFT = 0x1.8p+52 / N; + xd = (double)x; + abstop = top12(x) & 0x7ff; + if (predict_false(abstop >= top12(128.0f))) { + /* |x| >= 128 or x is nan. */ + if (asuint(x) == asuint(-INFINITY)) + return 0.0f; + if (abstop >= top12(INFINITY)) + return x + x; + if (x > 0.0f) + return 0; + if (x <= -150.0f) + return 0; + } + + /* x = k/N + r with r in [-1/(2N), 1/(2N)] and int k. */ + kd = eval_as_double(xd + SHIFT); + ki = asuint64(kd); + kd -= SHIFT; /* k/N for int k. */ + r = xd - kd; + + /* exp2(x) = 2^(k/N) * 2^r ~= s * (C0*r^3 + C1*r^2 + C2*r + 1) */ + t = T[ki % N]; + t += ki << (52 - 5); + s = asdouble(t); + z = 0x1.c6af84b912394p-5 * r + 0x1.ebfce50fac4f3p-3; + r2 = r * r; + y = 0x1.62e42ff0c52d6p-1 * r + 1; + y = z * r2 + y; + y = y * s; + return eval_as_float(y); +} + +__device__ inline float roundf(float input) { + return (input > 0.0) ? floor(input + 0.5) : ceil(input - 0.5); +} + +__device__ inline float log2f(float input) { + return log(input) / LN2; +} + +__device__ inline float log10f(float input) { + return log(input) / LN10; +} + +__device__ inline float xpu_sqrt(float input) { + return sqrt(input); +} + +__device__ inline double xpu_sqrt(double input) { + return sqrt(input); +} + +__device__ inline float xpu_floor(float input) { + return floor(input); +} + +__device__ inline double xpu_floor(double input) { + return floor(input); +} + +__device__ inline float xpu_ceil(float input) { + return ceil(input); +} + +__device__ inline double xpu_ceil(double input) { + return ceil(input); +} + +__device__ inline int32_t xpu_min(int32_t a, int32_t b) { + return min(a, b); +} + +__device__ inline uint32_t xpu_min(uint32_t a, uint32_t b) { + return min(a, b); +} + +__device__ inline int64_t xpu_min(int64_t a, int64_t b) { + return min(a, b); +} + +__device__ inline uint64_t xpu_min(uint64_t a, uint64_t b) { + return min(a, b); +} + +__device__ inline float xpu_min(float a, float b) { + return min(a, b); +} + +__device__ inline double xpu_min(double a, double b) { + return min(a, b); +} + +__device__ inline int32_t xpu_max(int32_t a, int32_t b) { + return max(a, b); +} + +__device__ inline uint32_t xpu_max(uint32_t a, uint32_t b) { + return max(a, b); +} + +__device__ inline int64_t xpu_max(int64_t a, int64_t b) { + return max(a, b); +} + +__device__ inline uint64_t xpu_max(uint64_t a, uint64_t b) { + return max(a, b); +} + +__device__ inline float xpu_max(float a, float b) { + return max(a, b); +} + +__device__ inline double xpu_max(double a, double b) { + return max(a, b); +} + +__device__ inline float fma(float x, float y, float z) { + return x * y + z; +} + +__device__ inline float erf(float x) { + float a1 = 0.254829592; + float a2 = -0.284496736; + float a3 = 1.421413741; + float a4 = -1.453152027; + float a5 = 1.061405429; + float p = 0.3275911; + + int sign = 1; + if (x < 0) { + sign = -1; + } + x = fabs(x); + + float t = 1.0 / (1.0 + p * x); + float y = (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t; + float result = 1 - (y * exp(-x * x)); + return sign * result; +} + +__device__ inline float erfc(float x) { + return 1.0f - erf(x); +} + + __device__ inline float erfinv(float x) { + if (x < -1.0f || x > 1.0f) { + return __builtin_nanf(""); + } else if (x == 1.0f) { + return INFINITY; + } else if (x == -1.0f) { + return -INFINITY; + } + + float a0 = 1.1975323115670912564578e0; + float a1 = 4.7072688112383978012285e1; + float a2 = 6.9706266534389598238465e2; + float a3 = 4.8548868893843886794648e3; + float a4 = 1.6235862515167575384252e4; + float a5 = 2.3782041382114385731252e4; + float a6 = 1.1819493347062294404278e4; + float a7 = 8.8709406962545514830200e2; + + float b0 = 1.0000000000000000000e0; + float b1 = 4.2313330701600911252e1; + float b2 = 6.8718700749205790830e2; + float b3 = 5.3941960214247511077e3; + float b4 = 2.1213794301586595867e4; + float b5 = 3.9307895800092710610e4; + float b6 = 2.8729085735721942674e4; + float b7 = 5.2264952788528545610e3; + + float c0 = 1.42343711074968357734e0; + float c1 = 4.63033784615654529590e0; + float c2 = 5.76949722146069140550e0; + float c3 = 3.64784832476320460504e0; + float c4 = 1.27045825245236838258e0; + float c5 = 2.41780725177450611770e-1; + float c6 = 2.27238449892691845833e-2; + float c7 = 7.74545014278341407640e-4; + + float d0 = 1.4142135623730950488016887e0; + float d1 = 2.9036514445419946173133295e0; + float d2 = 2.3707661626024532365971225e0; + float d3 = 9.7547832001787427186894837e-1; + float d4 = 2.0945065210512749128288442e-1; + float d5 = 2.1494160384252876777097297e-2; + float d6 = 7.7441459065157709165577218e-4; + float d7 = 1.4859850019840355905497876e-9; + + float e0 = 6.65790464350110377720e0; + float e1 = 5.46378491116411436990e0; + float e2 = 1.78482653991729133580e0; + float e3 = 2.96560571828504891230e-1; + float e4 = 2.65321895265761230930e-2; + float e5 = 1.24266094738807843860e-3; + float e6 = 2.71155556874348757815e-5; + float e7 = 2.01033439929228813265e-7; + + float f0 = 1.414213562373095048801689e0; + float f1 = 8.482908416595164588112026e-1; + float f2 = 1.936480946950659106176712e-1; + float f3 = 2.103693768272068968719679e-2; + float f4 = 1.112800997078859844711555e-3; + float f5 = 2.611088405080593625138020e-5; + float f6 = 2.010321207683943062279931e-7; + float f7 = 2.891024605872965461538222e-15; + + float abs_x = fabs(x); + + if (abs_x <= 0.85f) { + float r = 0.180625f - 0.25f * x * x; + float num = (((((((a7 * r + a6) * r + a5) * r + a4) * r + a3) * r + a2) * r + a1) * r + a0); + float den = (((((((b7 * r + b6) * r + b5) * r + b4) * r + b3) * r + b2) * r + b1) * r + b0); + return x * num / den; + } + + float r = sqrt(LN2 - log(1.0f - abs_x)); + + float num, den; + if (r <= 5.0f) { + r = r - 1.6f; + num = (((((((c7 * r + c6) * r + c5) * r + c4) * r + c3) * r + c2) * r + c1) * r + c0); + den = (((((((d7 * r + d6) * r + d5) * r + d4) * r + d3) * r + d2) * r + d1) * r + d0); + } else { + r = r - 5.0f; + num = (((((((e7 * r + e6) * r + e5) * r + e4) * r + e3) * r + e2) * r + e1) * r + e0); + den = (((((((f7 * r + f6) * r + f5) * r + f4) * r + f3) * r + f2) * r + f1) * r + f0); + } + + float sign = x >= 0.0f ? 1.0f : -1.0f; + return num / den * sign; +} + +__device__ inline float32x16_t verf(float32x16_t vl) { + // 定义 erf 中使用的常数 + float a1 = 0.254829592; + float a2 = -0.284496736; + float a3 = 1.421413741; + float a4 = -1.453152027; + float a5 = 1.061405429; + float p = 0.3275911; + + __simd__ float expx2[16]; + __simd__ float t[16]; + float32x16_t tmp1, tmp2, tmp; + int signl; + + // 计算 x 的平方 + tmp1 = vvmul_float32x16(vl, vl); + vstore_lm_float32x16(expx2, tmp1); + + // 计算 |x| + signl = svlt_float32x16(0, vl); // x>0 + tmp1 = svmul_float32x16_mz(1, vl, signl); + tmp2 = svmul_float32x16_mz(-1, vl, ~signl); + tmp = vvadd_float32x16(tmp1, tmp2); + vstore_lm_float32x16(t, tmp); + + // 计算 t = 1.0 / (1.0 + p * |x|) + mfence_lm(); + for (int k = 0; k < 16; k++) { + expx2[k] = exp(-expx2[k]); + t[k] = 1.0f / (1.0f + p * t[k]); + } + mfence_lm(); + + // 计算 erf 多项式 (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t + vl = vload_lm_float32x16(t); + tmp1 = svmul_float32x16(a5, vl); + tmp1 = svadd_float32x16(a4, tmp1); + tmp1 = vvmul_float32x16(tmp1, vl); + tmp1 = svadd_float32x16(a3, tmp1); + tmp1 = vvmul_float32x16(tmp1, vl); + tmp1 = svadd_float32x16(a2, tmp1); + tmp1 = vvmul_float32x16(tmp1, vl); + tmp1 = svadd_float32x16(a1, tmp1); + tmp1 = vvmul_float32x16(tmp1, vl); + + // 计算 *exp(-x^2) + vl = vload_lm_float32x16(expx2); + vl = vvmul_float32x16(vl, tmp1); + vl = svsub_float32x16(1, vl); + + // sign ? y : -y + tmp1 = svmul_float32x16_mz(1, vl, signl); + tmp2 = svmul_float32x16_mz(-1, vl, ~signl); + vl = vvadd_float32x16(tmp1, tmp2); + + return vl; +} + +#define CHAR_BIT 8 +typedef int32_t si_int; +typedef uint32_t su_int; +typedef int64_t di_int; +typedef uint64_t du_int; + +typedef union { + di_int all; + struct { + su_int low; + si_int high; // _YUGA_LITTLE_ENDIAN + } s; +} dwords; + +static __device__ inline di_int __muldsi3(su_int a, su_int b) { + dwords r; + const int bits_in_word_2 = (int)(sizeof(si_int) * CHAR_BIT) / 2; + const su_int lower_mask = (su_int)~0 >> bits_in_word_2; + r.s.low = (a & lower_mask) * (b & lower_mask); + su_int t = r.s.low >> bits_in_word_2; + r.s.low &= lower_mask; + t += (a >> bits_in_word_2) * (b & lower_mask); + r.s.low += (t & lower_mask) << bits_in_word_2; + r.s.high = t >> bits_in_word_2; + t = r.s.low >> bits_in_word_2; + r.s.low &= lower_mask; + t += (b >> bits_in_word_2) * (a & lower_mask); + r.s.low += (t & lower_mask) << bits_in_word_2; + r.s.high += t >> bits_in_word_2; + r.s.high += (a >> bits_in_word_2) * (b >> bits_in_word_2); + return r.all; +} + +static __device__ inline di_int __muldi3(di_int a, di_int b) { + dwords x; + x.all = a; + dwords y; + y.all = b; + dwords r; + r.all = __muldsi3(x.s.low, y.s.low); + r.s.high += x.s.high * y.s.low + x.s.low * y.s.high; + return r.all; +} + +__device__ inline int32_t mulhi (int32_t a, int32_t b) { + int64_t product = __muldi3((int64_t)a, (int64_t)b); + int32_t high_part = (int32_t)(product >> 32); + return high_part; +} + +__device__ inline uint32_t umulhi (uint32_t a, uint32_t b) { + int64_t product = __muldi3((int64_t)a, (int64_t)b); + uint32_t high_part = (uint32_t)(product >> 32); + return high_part; +} + + +__device__ inline uint32_t __FLOAT_BITS(float __f) { + union { + float __f; + uint32_t __i; + } __u; + __u.__f = __f; + return __u.__i; +} + +__device__ inline int32_t isnan(float x) { + return (__FLOAT_BITS(x) & 0x7fffffffU) > 0x7f800000U; +} + +__device__ inline int32_t isinf(float __a) { + uint32_t bits = *reinterpret_cast(&__a); + uint32_t exponentMask = 0x7F800000U; + uint32_t fractionMask = 0x007FFFFFU; + uint32_t exponent = bits & exponentMask; + uint32_t fraction = bits & fractionMask; + return (exponent == exponentMask) && (fraction == 0); +} + +__device__ inline int32x16_t visinf(float32x16_t vx0) { + vx0 = reinterpret_cast(svand_int32x16(0x7FFFFFFF, reinterpret_cast(vx0))); + int32x16_t mask0 = sveq_int32x16(0x7F800000, reinterpret_cast(vx0)); + return mask0; +} + +__device__ inline int32_t finitef(float __a) { + unsigned int *valueAsBits = reinterpret_cast(&__a); + unsigned int exponentMask = 0x7F800000U; + unsigned int fractionMask = 0x007FFFFFU; + unsigned int exponent = (*valueAsBits) & exponentMask; + unsigned int fraction = (*valueAsBits) & fractionMask; + return (exponent != exponentMask) || (fraction != 0) && (exponent == 0); +} + + +__device__ inline float __fdiv_rn(float __a, float __b) { + float ret; + asm volatile("div.f.rn %0, %1, %2" : "=r"(ret) : "r"(__a), "r"(__b)); + return ret; +} + +__device__ inline double __ddiv_rn(double __a, double __b) { + double div = __a / __b; + double rounded_div = static_cast(div); + return rounded_div; +} + +__device__ inline float __fdiv_rz(float __a, float __b) { + float ret; + asm volatile("div.f.rz %0, %1, %2" : "=r"(ret) : "r"(__a), "r"(__b)); + return ret; +} + + +__device__ inline double __ddiv_rz(double __a, double __b) { + double div = __a / __b; + double rounded_div = static_cast(div); + return rounded_div; +} + +__device__ inline float __fdiv_rd(float __a, float __b) { + float ret; + asm volatile("div.f.rd %0, %1, %2" : "=r"(ret) : "r"(__a), "r"(__b)); + return ret; +} + +__device__ inline double __ddiv_rd(double __a, double __b) { + double div = __a / __b; + double rounded_div = static_cast(div); + return rounded_div; +} + +__device__ inline float truncf(float input) { return __builtin_truncf(input); } + +__device__ inline double trunc(double x) { + union { + double f; + uint64_t i; + } u = {x}; + int e = (int)(u.i >> 52 & 0x7ff) - 0x3ff + 12; + uint64_t m; + if (e >= 52 + 12) + return x; + if (e < 12) + e = 1; + m = -1ULL >> e; + if ((u.i & m) == 0) + return x; + u.i &= ~m; + return u.f; +} + + +__device__ inline float xpu_trunc(float x, float y) { + float output_temp1 = x / y; + float z = floor(output_temp1); + float output_temp2 = fabs(z); + z = z + output_temp2; + z = z / 2; + + output_temp1 = -output_temp1; + output_temp1 = floor(output_temp1); + output_temp2 = fabs(output_temp1); + output_temp1 = output_temp1 + output_temp2; + output_temp1 = output_temp1 / 2; + output_temp1 = -output_temp1; + z = z + output_temp1; + return z; +} + +__device__ inline bool isnanf(float x) { return __builtin_isnan(x); } + +__device__ inline float fmodf(float x, float y) { + union { + float f; + unsigned i; + } ux = {x}, uy = {y}; + int ex = ux.i >> 23 & 0xff; + int ey = uy.i >> 23 & 0xff; + unsigned sx = ux.i & 0x80000000; + unsigned i; + unsigned uxi = ux.i; + + if (uy.i << 1 == 0 || isnanf(y) || ex == 0xff) + return (x * y) / (x * y); + if (uxi << 1 <= uy.i << 1) { + if (uxi << 1 == uy.i << 1) + return 0 * x; + return x; + } + + /* normalize x and y */ + if (!ex) { + for (i = uxi << 9; i >> 31 == 0; ex--, i <<= 1) + ; + uxi <<= -ex + 1; + } else { + uxi &= -1U >> 9; + uxi |= 1U << 23; + } + if (!ey) { + for (i = uy.i << 9; i >> 31 == 0; ey--, i <<= 1) + ; + uy.i <<= -ey + 1; + } else { + uy.i &= -1U >> 9; + uy.i |= 1U << 23; + } + + /* x mod y */ + for (; ex > ey; ex--) { + i = uxi - uy.i; + if (i >> 31 == 0) { + if (i == 0) + return 0 * x; + uxi = i; + } + uxi <<= 1; + } + i = uxi - uy.i; + if (i >> 31 == 0) { + if (i == 0) + return 0 * x; + uxi = i; + } + for (; uxi >> 23 == 0; uxi <<= 1, ex--) + ; + + /* scale result up */ + if (ex > 0) { + uxi -= 1U << 23; + uxi |= (unsigned)ex << 23; + } else { + uxi >>= -ex + 1; + } + uxi |= sx; + ux.i = uxi; + return ux.f; +} + + +__device__ inline int __signbitf(float a) { + return a < 0.f; +} + +// float16 +#ifdef __arch_xpu3__ +__device__ inline float fp16tofp32(float16 input) { + __simd__ float16 _input[32] = {input}; + __simd__ float _input_fp32[32]; + fp16tofp32(_input, _input_fp32, 1); + return _input_fp32[0]; +} + +__device__ inline float16 fp32tofp16(float input) { + __simd__ float _input[32] = {input}; + __simd__ float16 _input_fp16[32]; + fp32tofp16(_input, _input_fp16, 1); + return _input_fp16[0]; +} + +__device__ inline float16 hsin(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = sinf(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 hcos(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = cosf(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 htan(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = tanf(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 hasin(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = asinf(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 hacos(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = acosf(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 htanh(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = tanhf(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 hasinh(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = asinhf(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 hacosh(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = acoshf(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 hatanh(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = atanhf(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 hrsqrt(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = rsqrtf(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 hpow(float16 input1, float16 input2) { + float input1_fp32 = fp16tofp32(input1); + float input2_fp32 = fp16tofp32(input2); + float result_fp32 = pow(input1_fp32, input2_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 hexpm1(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = expm1f(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 hexp2(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = exp2f(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 hlog2(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = log2f(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 hlog10(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = log10f(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 xpu_hfloor(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = xpu_floor(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 xpu_hceil(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = xpu_ceil(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 xpu_hmin(float16 input1, float16 input2) { + float input1_fp32 = fp16tofp32(input1); + float input2_fp32 = fp16tofp32(input2); + float result_fp32 = xpu_min(input1_fp32, input2_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 xpu_hmax(float16 input1, float16 input2) { + float input1_fp32 = fp16tofp32(input1); + float input2_fp32 = fp16tofp32(input2); + float result_fp32 = xpu_max(input1_fp32, input2_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline float16 herf(float16 input) { + float input_fp32 = fp16tofp32(input); + float result_fp32 = erf(input_fp32); + return fp32tofp16(result_fp32); +} + +__device__ inline int32_t hisnan(float16 input) { + float input_fp32 = fp16tofp32(input); + int32_t result_fp32 = isnan(input_fp32); + return result_fp32; +} + +__device__ inline int32_t hisin(float16 input) { + float input_fp32 = fp16tofp32(input); + int32_t result_fp32 = isinf(input_fp32); + return result_fp32; +} + +__device__ inline int32_t hfinite(float16 input) { + float input_fp32 = fp16tofp32(input); + int32_t result_fp32 = finitef(input_fp32); + return result_fp32; +} + + +// borrowed from xtrans include/xpu/kernel/atomic.h +#define GM_BASE_ADDR 0x4000000000 +#define NOP_TIME 5 + +__attribute__((weak)) __device__ int XTransSpinLock = 0; + +__attribute__((used)) __device__ void xtransLock() { + int nop_time = NOP_TIME; + // 使用amoadd获取返回值,如果不为0的话,则一直循环,直到获取到锁为止,在里面设置一些nop来减少竞争 + while (__builtin_xpu_amoadd(1, &XTransSpinLock) != 0) { + for (int i = 0; i < nop_time; i++) { + __asm __volatile__("nop"); + } + nop_time += NOP_TIME; + } +} + +__attribute__((used)) __device__ inline void xtransUnlock() { + // 释放锁就是将锁置为0 + __builtin_xpu_amoswap(0, &XTransSpinLock); +} + +__device__ float atomicAdd(__global_ptr__ float *address, float val) { + if ((long)address >= GM_BASE_ADDR) { + xtransLock(); + float x = *address; + *address = x + val; + xtransUnlock(); + return x; + } else { + ticket_lock_mix(); + float x = *address; + *address = x + val; + ticket_unlock_mix(); + return x; + } +} + +__device__ float16 atomicAdd(__global_ptr__ float16 *address, float16 val) { + if ((long)address >= GM_BASE_ADDR) { + xtransLock(); + float16 x = *address; + *address = __hadd(x, val); + xtransUnlock(); + return x; + } else { + ticket_lock_mix(); + float16 x = *address; + *address = __hadd(x, val); + ticket_unlock_mix(); + return x; + } +} + +#endif + +} // namespace xpu diff --git a/third_party/xpu/include/CMakeLists.txt b/third_party/xpu/include/CMakeLists.txt new file mode 100644 index 000000000..109c292fe --- /dev/null +++ b/third_party/xpu/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(triton) diff --git a/third_party/xpu/include/triton/Analysis/Alias.h b/third_party/xpu/include/triton/Analysis/Alias.h new file mode 100644 index 000000000..a06df5ae2 --- /dev/null +++ b/third_party/xpu/include/triton/Analysis/Alias.h @@ -0,0 +1,96 @@ +#ifndef TRITON_ANALYSIS_ALIAS_H +#define TRITON_ANALYSIS_ALIAS_H + +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/ADT/DenseSet.h" + +namespace mlir { + +class AliasInfo { +public: + AliasInfo() = default; + AliasInfo(Value value) { insert(value); } + + void insert(Value value) { allocs.insert(value); } + + const DenseSet &getAllocs() const { return allocs; } + + bool operator==(const AliasInfo &other) const { + return allocs == other.allocs; + } + + /// The pessimistic value state of a value without alias + static AliasInfo getPessimisticValueState(MLIRContext *context = nullptr) { + return AliasInfo(); + } + static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); } + + /// The union of both arguments + static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs); + + void print(raw_ostream &os) const { + llvm::interleaveComma(allocs, os, [&](Value alloc) { alloc.print(os); }); + } + +private: + /// The set of allocated values that are aliased by this lattice. + /// For now, we only consider aliased value produced by the following + /// situations: + /// 1. values returned by scf.yield + /// 2. block arguments in scf.for + /// Example: + /// alloc v1 alloc v2 + /// | | + /// |--------------| |------------| + /// scf.for v3 scf.for v4 scf.for v5 + /// | + /// scf.yield v6 + /// + /// v1's alloc [v1] + /// v2's alloc [v2] + /// v3's alloc [v1] + /// v4's alloc [v1, v2] + /// v5's alloc [v2] + /// v6's alloc [v1] + /// + /// Therefore, v1's liveness range is the union of v3, v4, and v6 + /// v2's liveness range is the union of v4 and v5. + DenseSet allocs; +}; + +//===----------------------------------------------------------------------===// +// Shared Memory Alias Analysis +//===----------------------------------------------------------------------===// +class SharedMemoryAliasAnalysis + : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +public: + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::SparseForwardDataFlowAnalysis; + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + + /// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use. + /// Given two values, returns their aliasing behavior. + AliasResult alias(Value lhs, Value rhs); + + /// Returns the modify-reference behavior of `op` on `location`. + ModRefResult getModRef(Operation *op, Value location); + + void setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged( + lattice, lattice->join( + AliasInfo::getPessimisticValueState(lattice->getPoint()))); + } + + /// Computes if the alloc set of the results are changed. + void + visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_ALIAS_H diff --git a/third_party/xpu/include/triton/Analysis/Allocation.h b/third_party/xpu/include/triton/Analysis/Allocation.h new file mode 100644 index 000000000..a9e02b420 --- /dev/null +++ b/third_party/xpu/include/triton/Analysis/Allocation.h @@ -0,0 +1,258 @@ +#ifndef TRITON_ANALYSIS_ALLOCATION_H +#define TRITON_ANALYSIS_ALLOCATION_H + +#include "triton/Analysis/Utility.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/raw_ostream.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include +#include + +namespace mlir { + +namespace triton { +class AllocationAnalysis; + +SmallVector +getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, + unsigned &outVec); +SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op); + +} // namespace triton + +/// Modified from llvm-15.0: llvm/ADT/AddressRanges.h +/// A class that represents an interval, specified using a start and an end +/// values: [Start, End). +template class Interval { +public: + Interval() {} + Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); } + T start() const { return Start; } + T end() const { return End; } + T size() const { return End - Start; } + bool contains(T Addr) const { return Start <= Addr && Addr < End; } + bool intersects(const Interval &R) const { + return Start < R.End && R.Start < End; + } + bool operator==(const Interval &R) const { + return Start == R.Start && End == R.End; + } + bool operator!=(const Interval &R) const { return !(*this == R); } + bool operator<(const Interval &R) const { + return std::make_pair(Start, End) < std::make_pair(R.Start, R.End); + } + +private: + T Start = std::numeric_limits::min(); + T End = std::numeric_limits::max(); +}; + +template Interval(T, T) -> Interval; + +class Allocation { +public: + /// A unique identifier for shared memory buffers + using BufferId = size_t; + using BufferIdSetT = DenseSet; + using FuncAllocMapT = CallGraph::FuncDataMapT; + + static constexpr BufferId InvalidBufferId = + std::numeric_limits::max(); + + Allocation() = default; + /// Creates a new Allocation analysis that computes the shared memory + /// information for all associated shared memory values. + explicit Allocation(Operation *operation) : operation(operation) {} + + /// Runs allocation analysis on the given top-level operation. + void run(FuncAllocMapT &funcAllocMap); + + /// Returns the operation this analysis was constructed from. + Operation *getOperation() const { return operation; } + + /// Returns the offset of the given buffer in the shared memory. + size_t getOffset(BufferId bufferId) const { + return bufferSet.at(bufferId).offset; + } + + /// Returns the size of the given buffer in the shared memory. + size_t getAllocatedSize(BufferId bufferId) const { + return bufferSet.at(bufferId).size; + } + + /// Returns the allocated interval of the given buffer. + Interval getAllocatedInterval(BufferId bufferId) const { + auto &buffer = bufferSet.at(bufferId); + return Interval(buffer.offset, buffer.offset + buffer.size); + } + + /// Returns the buffer id of the given value. + /// This interface only returns the allocated buffer id. + /// If you want to get all the buffer ids that are associated with the given + /// value, including alias buffers, use getBufferIds. + BufferId getBufferId(Value value) const { + if (valueBuffer.count(value)) { + return valueBuffer.lookup(value)->id; + } else { + return InvalidBufferId; + } + } + + /// Returns all the buffer ids of the given value, including alias buffers. + BufferIdSetT getBufferIds(Value value) const { + BufferIdSetT bufferIds; + auto allocBufferId = getBufferId(value); + if (allocBufferId != InvalidBufferId) + bufferIds.insert(allocBufferId); + for (auto *buffer : aliasBuffer.lookup(value)) { + if (buffer->id != InvalidBufferId) + bufferIds.insert(buffer->id); + } + return bufferIds; + } + + /// Returns the scratch buffer id of the given value. + BufferId getBufferId(Operation *operation) const { + if (opScratch.count(operation)) { + return opScratch.lookup(operation)->id; + } else if (opVirtual.count(operation)) { + return opVirtual.lookup(operation)->id; + } else { + return InvalidBufferId; + } + } + + /// Returns if the given buffer is a virtual buffer. + bool isVirtualBuffer(BufferId bufferId) const { + return bufferSet.at(bufferId).kind == BufferT::BufferKind::Virtual; + } + + /// Returns the size of total shared memory allocated + size_t getSharedMemorySize() const { return sharedMemorySize; } + +private: + /// A class that represents a shared memory buffer + struct BufferT { + /// Explicit: triton_gpu.local_alloc + /// Scratch: triton_gpu.convert_layout + /// Virtual: triton.call + enum class BufferKind { Explicit, Scratch, Virtual }; + + /// MT: thread-safe + inline static std::atomic nextId = 0; + + BufferKind kind; + BufferId id; + size_t size; + size_t alignment; + size_t offset; + + bool operator==(const BufferT &other) const { return id == other.id; } + bool operator<(const BufferT &other) const { return id < other.id; } + + BufferT() : BufferT(BufferKind::Explicit, 0) {} + BufferT(BufferKind kind, size_t size, size_t alignment = 4, + size_t offset = 0) + : kind(kind), id(nextId++), size(size), alignment(alignment), + offset(offset) {} + + size_t setOffsetAligned(size_t newOffset) { + return offset = llvm::alignTo(newOffset, alignment); + } + }; + + /// Op -> Scratch Buffer + using OpScratchMapT = DenseMap; + /// Value -> Explicit Buffer + using ValueBufferMapT = llvm::MapVector; + /// Value -> Alias Buffer + using AliasBufferMapT = llvm::MapVector>; + /// BufferId -> Buffer + using BufferSetT = std::map; + +private: + template + void addBuffer(KeyType &key, Args &&...args) { + auto buffer = BufferT(Kind, std::forward(args)...); + bufferSet[buffer.id] = std::move(buffer); + if constexpr (Kind == BufferT::BufferKind::Explicit) { + valueBuffer[key] = &bufferSet[buffer.id]; + } else if constexpr (Kind == BufferT::BufferKind::Virtual) { + opVirtual[key] = &bufferSet[buffer.id]; + } else { + opScratch[key] = &bufferSet[buffer.id]; + } + } + + void addAlias(Value value, Value alloc) { + aliasBuffer[value].insert(valueBuffer[alloc]); + } + +private: + Operation *operation = nullptr; + OpScratchMapT opScratch; + OpScratchMapT opVirtual; + ValueBufferMapT valueBuffer; + AliasBufferMapT aliasBuffer; + BufferSetT bufferSet; + size_t sharedMemorySize = 0; + + friend class triton::AllocationAnalysis; +}; + +/// Static analysis that computes the allocation of shared memory buffers +/// of the entire call graph. +/// The allocation is performed in a post-order walk of the call graph. +/// Each call op is treated like convert_layout that allocates a scratch buffer. +/// At each call, we compute the start offset of the scratch buffer and pass it +/// as an argument to the callee. +class ModuleAllocation : public CallGraph { +public: + using FuncOffsetMapT = DenseMap; + + explicit ModuleAllocation(ModuleOp moduleOp) + : CallGraph(moduleOp) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp); + if (inserted) + iter->second.run(funcMap); + }); + } + + size_t getSharedMemorySize() { + size_t size = 0; + for (auto funcOp : getRoots()) { + auto *alloc = getFuncData(funcOp); + size = std::max(size, alloc->getSharedMemorySize()); + } + return size; + } + + size_t getSharedMemorySize(FunctionOpInterface funcOp) { + return getFuncData(funcOp)->getSharedMemorySize(); + } + + void setFunctionSharedMemoryValue(FunctionOpInterface funcOp, Value value) { + sharedMemoryValue[funcOp] = value; + } + + Value getFunctionSharedMemoryBase(FunctionOpInterface funcOp) { + return sharedMemoryValue[funcOp]; + } + +private: + FuncOffsetMapT sharedMemoryValue; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_ALLOCATION_H diff --git a/third_party/xpu/include/triton/Analysis/AxisInfo.h b/third_party/xpu/include/triton/Analysis/AxisInfo.h new file mode 100644 index 000000000..22a7ed554 --- /dev/null +++ b/third_party/xpu/include/triton/Analysis/AxisInfo.h @@ -0,0 +1,215 @@ +#ifndef TRITON_ANALYSIS_AXISINFO_H +#define TRITON_ANALYSIS_AXISINFO_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include +#include + +namespace mlir::triton { + +//===----------------------------------------------------------------------===// +// AxisInfo +//===----------------------------------------------------------------------===// + +/// This lattice value represents known information on the axes of a lattice. +class AxisInfo { +public: + typedef SmallVector DimVectorT; + +public: + AxisInfo() : AxisInfo({}, {}, {}) {} + + AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy) + : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {} + + AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy, + std::optional constantValue) + : contiguity(contiguity), divisibility(divisibility), + constancy(constancy), constantValue(constantValue) { + assert(divisibility.size() == contiguity.size()); + assert(constancy.size() == contiguity.size()); + } + + // contiguity[d] is the length of the shortest sequence of contiguous integers + // along dimension d. + // + // If we have an array of N elements with a contiguity value C, then the array + // can be divided into a list of N/C sequences of C contiguous elements. + // Since we have N = 2^k, C must be a power of two. + // + // For example, the 2D array + // + // [[10, 11, 12, 13, 18, 19, 20, 21], + // [20, 21, 22, 23, 28, 29, 30, 31]] + // + // has contiguity [1, 4], and + // + // [[12, 16, 20, 24], + // [13, 17, 21, 25], + // [14, 18, 22, 26], + // [15, 19, 23, 27], + // [18, 22, 26, 30], + // [19, 23, 27, 31]] + // + // has contiguity [2, 1]. + int64_t getContiguity(size_t dim) const { return contiguity[dim]; } + const DimVectorT &getContiguity() const { return contiguity; } + + // divisibility[d] is the largest power of two that divides the first element + // of all groups of length contiguity[d] along dimension d. + // + // For example, + // + // [[10, 11, 12, 13, 18, 19, 20, 21], + // [20, 21, 22, 23, 28, 29, 30, 31]] + // + // has divisibility [1, 2], and + // + // [[12, 16, 20, 24], + // [13, 17, 21, 25], + // [14, 18, 22, 26], + // [15, 19, 23, 27]] + // + // has divisibility [4, 1]. + // + // On the other hand, + // + // [0, 1, 2, 0, 4, 5, 6, 7] + // + // has divisibility 1 because its contiguity is 1. + int64_t getDivisibility(size_t dim) const { return divisibility[dim]; } + const DimVectorT &getDivisibility() const { return divisibility; } + + // constancy[d] is the length of the shortest sequence of repeating integers + // along dimension d. + // + // This is particularly useful to infer the contiguity of operations (e.g. + // add) involving a constant. + // + // If we have an array of N elements, with a constancy value C, then the array + // can be divided into a list of N/C sequences of C elements with the same + // value. Since we have N = 2^k, C must be a power of two. + // + // For example + // + // [[8, 8, 8, 8, 12, 12, 12, 12], + // [16, 16, 16, 16, 20, 20, 20, 20]] + // + // has constancy [1, 4]. + int64_t getConstancy(size_t dim) const { return constancy[dim]; } + const DimVectorT &getConstancy() const { return constancy; } + + int getRank() const { return contiguity.size(); } + + std::optional getConstantValue() const { return constantValue; } + + template + static void + initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity, + DimVectorT *divisibility, DimVectorT *constancy); + + bool operator==(const AxisInfo &other) const { + return contiguity == other.contiguity && + divisibility == other.divisibility && constancy == other.constancy && + constantValue == other.constantValue; + } + + static AxisInfo getPessimisticValueState(Value value); + + // The gcd of both arguments for each dimension + static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs); + + void print(raw_ostream &os) const { + auto print = [&](StringRef name, DimVectorT vec) { + os << name << " = ["; + llvm::interleaveComma(vec, os); + os << "]"; + }; + print("contiguity", contiguity); + print(", divisibility", divisibility); + print(", constancy", constancy); + os << ", constant_value = "; + if (constantValue) + os << *constantValue; + else + os << ""; + } + +private: + DimVectorT contiguity; + DimVectorT divisibility; + DimVectorT constancy; + + // The constant value of the lattice if we can infer it. + std::optional constantValue; +}; + +// Module level axis info analysis based on the call graph, assuming that we do +// not have recursive functions. +// +// Since each function will be called multiple times, we need to calculate the +// axis info based on the axis info of all the callers. In the future, we can +// perform optimization using function cloning so that each call site will have +// unique axis info. +using AxisInfoMapT = DenseMap; +class ModuleAxisInfoAnalysis : public CallGraph { +public: + explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp) + : CallGraph(moduleOp) { + SmallVector funcs; + for (auto root : getRoots()) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + funcs.push_back(funcOp); + funcMap.try_emplace(funcOp, AxisInfoMapT{}); + }); + } + SetVector sortedFuncs(funcs.begin(), funcs.end()); + SymbolTableCollection symbolTable; + for (auto funcOp : llvm::reverse(sortedFuncs)) { + initialize(funcOp); + funcOp.walk([&](CallOpInterface callOp) { + auto callee = + dyn_cast(callOp.resolveCallable(&symbolTable)); + update(callOp, callee); + }); + } + } + + AxisInfo *getAxisInfo(Value value) { + auto funcOp = + value.getParentRegion()->getParentOfType(); + auto *axisInfoMap = getFuncData(funcOp); + if (!axisInfoMap) { + return nullptr; + } + auto it = axisInfoMap->find(value); + if (it == axisInfoMap->end()) { + return nullptr; + } + return &(it->second); + } + + unsigned getPtrContiguity(Value ptr); + unsigned getPtrAlignment(Value ptr); + unsigned getMaskAlignment(Value mask); + +private: + void initialize(FunctionOpInterface funcOp); + void update(CallOpInterface callOp, FunctionOpInterface funcOp); +}; + +} // namespace mlir::triton + +#endif diff --git a/third_party/xpu/include/triton/Analysis/Membar.h b/third_party/xpu/include/triton/Analysis/Membar.h new file mode 100644 index 000000000..43bd5d15b --- /dev/null +++ b/third_party/xpu/include/triton/Analysis/Membar.h @@ -0,0 +1,154 @@ +#ifndef TRITON_ANALYSIS_MEMBAR_H +#define TRITON_ANALYSIS_MEMBAR_H + +#include "Allocation.h" +#include "llvm/ADT/SmallPtrSet.h" + +#include + +namespace mlir { + +class OpBuilder; + +struct BlockInfo { + using BufferIdSetT = Allocation::BufferIdSetT; + using IntervalSetT = std::set>; + + IntervalSetT syncReadIntervals; + IntervalSetT syncWriteIntervals; + + BlockInfo() = default; + + /// Unions two BlockInfo objects. + BlockInfo &join(const BlockInfo &other) { + syncReadIntervals.insert(other.syncReadIntervals.begin(), + other.syncReadIntervals.end()); + syncWriteIntervals.insert(other.syncWriteIntervals.begin(), + other.syncWriteIntervals.end()); + return *this; + } + + /// Returns true if intervals in two BlockInfo objects are intersected. + bool isIntersected(const BlockInfo &other) const { + return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals) || + /*WAR*/ + isIntersected(syncReadIntervals, other.syncWriteIntervals) || + /*WAW*/ + isIntersected(syncWriteIntervals, other.syncWriteIntervals); + } + + /// Clears the intervals because a barrier is inserted. + void sync() { + syncReadIntervals.clear(); + syncWriteIntervals.clear(); + } + + /// Compares two BlockInfo objects. + bool operator==(const BlockInfo &other) const { + return syncReadIntervals == other.syncReadIntervals && + syncWriteIntervals == other.syncWriteIntervals; + } + + bool operator!=(const BlockInfo &other) const { return !(*this == other); } + +private: + bool isIntersected(const IntervalSetT &lhsIntervalSet, + const IntervalSetT &rhsIntervalSet) const { + for (auto &lhs : lhsIntervalSet) + for (auto &rhs : rhsIntervalSet) + if (lhs.intersects(rhs)) + return true; + return false; + } +}; + +//===----------------------------------------------------------------------===// +// Shared Memory Barrier Analysis +//===----------------------------------------------------------------------===// +class MembarAnalysis { +public: + using FuncBlockInfoMapT = CallGraph::FuncDataMapT; + /// Creates a new Membar analysis that generates the shared memory barrier + /// in the following circumstances: + /// - RAW: If a shared memory write is followed by a shared memory read, and + /// their addresses are intersected, a barrier is inserted. + /// - WAR: If a shared memory read is followed by a shared memory write, and + /// their addresses are intersected, a barrier is inserted. + /// The following circumstances do not require a barrier: + /// - WAW: not possible because overlapped memory allocation is not allowed. + /// - RAR: no write is performed. + /// Temporary storage of operations such as Reduce are considered as both + /// a shared memory read. If the temporary storage is written but not read, + /// it is considered as the problem of the operation itself but not the membar + /// analysis. + MembarAnalysis() = default; + explicit MembarAnalysis(Allocation *allocation) : allocation(allocation) {} + + /// Runs the membar analysis to the given operation, inserts a barrier if + /// necessary. + void run(FuncBlockInfoMapT &funcBlockInfoMap); + +private: + /// Applies the barrier analysis based on the SCF dialect, in which each + /// region has a single basic block only. + /// Example: + /// region1 + /// op1 + /// op2 (scf.if) + /// region2 + /// op3 + /// op4 + /// region3 + /// op5 + /// op6 + /// op7 + /// TODO: Explain why we don't use ForwardAnalysis: + void resolve(FunctionOpInterface funcOp, FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder); + + /// Updates the BlockInfo operation based on the operation. + void update(Operation *operation, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder); + + /// Collects the successors of the terminator + void visitTerminator(Operation *operation, SmallVector &successors); + + void insertBarrier(Operation *operation, OpBuilder *builder); + +private: + Allocation *allocation = nullptr; +}; + +/// Postorder traversal on the callgraph to insert membar instructions +/// of each function. +/// Each function maintains a BlockInfo map that includes all potential buffers +/// after returning. This way users do not have to explicitly insert membars +/// before and after function calls, but might be a bit conservative. +class ModuleMembarAnalysis : public CallGraph { +public: + ModuleMembarAnalysis(ModuleAllocation *moduleAllocation) + : CallGraph(moduleAllocation->getModuleOp()), + moduleAllocation(moduleAllocation) {} + + void run() { + walk( + // Pre-order walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order walk callback + [&](FunctionOpInterface funcOp) { + auto *allocation = moduleAllocation->getFuncData(funcOp); + auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo()); + if (inserted) { + MembarAnalysis analysis(allocation); + analysis.run(funcMap); + } + }); + } + +private: + ModuleAllocation *moduleAllocation; +}; + +} // namespace mlir + +#endif // TRITON_ANALYSIS_MEMBAR_H diff --git a/third_party/xpu/include/triton/Analysis/Utility.h b/third_party/xpu/include/triton/Analysis/Utility.h new file mode 100644 index 000000000..554477646 --- /dev/null +++ b/third_party/xpu/include/triton/Analysis/Utility.h @@ -0,0 +1,490 @@ +#ifndef TRITON_ANALYSIS_UTILITY_H +#define TRITON_ANALYSIS_UTILITY_H + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "triton-analysis" + +namespace mlir { + +inline bool isZeroConst(Value v) { + auto constantOp = v.getDefiningOp(); + if (!constantOp) + return false; + if (auto denseAttr = dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + if (auto denseAttr = + dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + return false; +} + +struct redSMOffsetInfo { + int64_t startOffset; // bytes + int64_t endOffset; // bytes + llvm::SmallVector offsets; + + redSMOffsetInfo() : startOffset(0), endOffset(0) {} + + redSMOffsetInfo(int64_t _startOffset, llvm::SmallVector &_offsets) + : startOffset(_startOffset), endOffset(_startOffset), offsets(_offsets) { + for (auto offset : offsets) + endOffset += offset; + } +}; + +class ReduceOpHelper { +public: + explicit ReduceOpHelper(triton::ReduceOp op) + : op(op.getOperation()), axis(op.getAxis()) { + auto firstTy = cast(op.getOperands()[0].getType()); + srcShape = firstTy.getShape(); + srcEncoding = firstTy.getEncoding(); + srcElementTypes = op.getElementTypes(); + + for (const auto &t : op.getInputTypes()) { + if (t.getShape() != srcShape) { + op.emitError() << "shape mismatch"; + } + if (t.getEncoding() != srcEncoding) { + op.emitError() << "encoding mismatch"; + } + } + } + + ArrayRef getSrcShape() { return srcShape; } + + Attribute getSrcLayout() { return srcEncoding; } + + triton::ReduceOp getOperation() { return op; } + + bool isReductionOnLayoutFastAxis(); + + unsigned getThreadOffsetOnReductionAxis(); + + bool isWarpSynchronous(); + + unsigned getInterWarpSize(); + + unsigned getIntraWarpSize(); + + unsigned getInterWarpSizeWithUniqueData(); + + unsigned getIntraWarpSizeWithUniqueData(); + + unsigned getThreadsReductionAxis(); + + SmallVector getScratchConfig(); + + SmallVector getOrderWithAxisAtBeginning(); + + unsigned getScratchSizeInBytes(); + + bool isSupportedLayout(); + + bool isReduceWithinCTA(); + + unsigned getAxis() { return axis; } + + //===-------------------- For Triton XPU -----------------------===// + explicit ReduceOpHelper(triton::xpu::ReduceOp op) + : xpu_op(op.getOperation()), axis(op.getAxis()) { + auto firstTy = cast(op.getOperands()[0].getType()); + srcShape = firstTy.getShape(); + srcEncoding = firstTy.getEncoding(); + srcElementTypes = op.getElementTypes(); + + for (const auto &[i, t] : llvm::enumerate(op.getInputTypes())) { + if (i == (op.getInputTypes().size() - 1)) + continue; // skip loopIndex + if (t.getShape() != srcShape) { + op.emitError() << "shape mismatch"; + } + if (t.getEncoding() != srcEncoding) { + op.emitError() << "encoding mismatch"; + } + } + } + + triton::xpu::ReduceOp getXPUOperation() { return xpu_op; } + + bool isCoreSynchronous(); + + unsigned getIntraGroupSizeWithUniqueData(); + + SmallVector getXPUScratchConfig(); + + unsigned getXPUScratchSizeInBytes(); + + void setReduceId(unsigned _reduceId) { reduceIdMap[xpu_op] = _reduceId; } + + unsigned getReduceId() { return reduceIdMap[xpu_op]; } + + void setReduceNum(unsigned _reduceNum) { reduceNum = _reduceNum; } + + unsigned getReduceNum() { return reduceNum; } + + void setSMOffsets(unsigned _reduceId, SmallVector &_offsets) { + int64_t _startOffset; + if (_reduceId == 0) { + _startOffset = 0; + } else { + _startOffset = getSMOffsets(getReduceId() - 1)->endOffset; + } + reduceSMOffsetMap[_reduceId] = + std::make_unique(_startOffset, _offsets); + } + + redSMOffsetInfo *getSMOffsets(unsigned _reduceId) { + auto it = reduceSMOffsetMap.find(_reduceId); + if (it != reduceSMOffsetMap.end()) { + return it->second.get(); + } + return nullptr; + } + + void dumpSMOffsets() { + LLVM_DEBUG({ + for (auto i = 0; i < reduceNum; ++i) { + auto info = getSMOffsets(i); + if (info == nullptr) + continue; + llvm::dbgs() << "\nreduceOp" << i << " [start, end] = [" + << info->startOffset << ", " << info->endOffset << "]\n"; + llvm::dbgs() << "detail offsets: ["; + for (auto offset : info->offsets) { + llvm::dbgs() << offset << ","; + } + llvm::dbgs() << "]\n"; + } + }); + } + + SmallVector getReturnDefOps() { + SmallVector returnDefOps; + for (Block &block : xpu_op.getCombineOp().getBlocks()) { + triton::xpu::ReduceReturnOp returnOp = + cast(block.getTerminator()); + for (auto operand : returnOp.getOperands()) { + returnDefOps.emplace_back(operand.getDefiningOp()); + } + } + return returnDefOps; + } + + bool isVectorized() { + for (auto type : xpu_op.getInputTypes()) { + if (!isa(getElementTypeOrSelf(type))) { + return false; + } + } + return true; + } + //===-----------------------------------------------------------===// + +private: + triton::ReduceOp op; + ArrayRef srcShape; + Attribute srcEncoding; + SmallVector srcElementTypes; + int axis; + + //===-------------------- For Triton XPU -----------------------===// + triton::xpu::ReduceOp xpu_op; + static std::map reduceIdMap; + static unsigned reduceNum; + static std::map> reduceSMOffsetMap; + SmallVector returnDefOps; + //===-----------------------------------------------------------===// +}; + +class ScanLoweringHelper { +public: + explicit ScanLoweringHelper(triton::ScanOp op) : scanOp(op) { + auto firstTy = cast(op.getOperands()[0].getType()); + srcShape = firstTy.getShape(); + srcEncoding = firstTy.getEncoding(); + srcElementTypes = op.getElementTypes(); + + for (const auto &t : op.getInputTypes()) { + if (t.getShape() != srcShape) { + op.emitError() << "shape mismatch"; + } + if (t.getEncoding() != srcEncoding) { + op.emitError() << "encoding mismatch"; + } + } + } + // Return true if the lowering of the scan op is supported. + bool isSupported(); + // Return the number of elements per thread along axis dim. + unsigned getAxisNumElementsPerThread(); + // Return the number of elements per thread along non-axis dims. + unsigned getNonAxisNumElementsPerThread(); + // Return the number of threads per warp along non-axis dims. + unsigned getNonAxisNumThreadsPerWarp(); + // Return the flat numbers of threads computing independent scan results. + unsigned getNonAxisNumThreadsPerCTA(); + // Return the number of warps per CTA along axis dim. + unsigned getAxisNumWarps(); + // Return the number of warps per CTA along axis dim with unique data. + unsigned getAxisNumWarpsWithUniqueData(); + // Return the number of threads per warp along axis dim. + unsigned getAxisNumThreadsPerWarp(); + // Return the number of threads per warp along axis dim with unique data. + unsigned getAxisNumThreadsPerWarpWithUniqueData(); + // Return the number of blocks along axis dim. + unsigned getAxisNumBlocks(); + // Return the number of blocks along non axis dim. + unsigned getNonAxisNumBlocks(); + // Return the size of the scratch space needed for scan lowering. + unsigned getScratchSizeInBytes(); + // Return the number of elements of the scratch space needed for scan + // lowering. + unsigned getScratchSizeInElems(); + + // Stride between contiguous element along axis dim. + unsigned getAxisElementStride(); + // Stride between contiguous threads along axis dim. + unsigned getAxisThreadStride(); + // Stride between contiguous blocks along axis dim. + unsigned getAxisBlockStride(); + + Location getLoc() { return scanOp.getLoc(); } + unsigned getAxis() { return scanOp.getAxis(); } + bool getReverse() { return scanOp.getReverse(); } + triton::gpu::BlockedEncodingAttr getEncoding(); + llvm::ArrayRef getShape() { return srcShape; } + unsigned getNumOperands() { return scanOp.getNumOperands(); } + SmallVector getElementTypes() { return srcElementTypes; } + Attribute getSrcLayout() { return srcEncoding; } + Region &getCombineOp(); + +private: + triton::ScanOp scanOp; + Attribute srcEncoding; + llvm::ArrayRef srcShape; + SmallVector srcElementTypes; +}; + +// Decomposes a reshape into simpler pieces. +// +// As an example, suppose we have a reshape from [4,4,4] to [2,2,8,2]. +// You might explain what this does as follows. +// +// - Split the first input dimension into [2,2]. +// - Take the remaining two input dimensions, merge them into a single [16] +// dim, and then split that into [8,2]. +// +// In general, a reshape can be described a sequence of smushing one or more +// input dimensions together and then breaking them apart into one or more +// output dimensions. So we could represent the example above as follows. +// +// [ +// ([0], [0, 1]), # input dim [0] -> output dims [0, 1] +// ([1, 2], [2, 3]), # input dims [1, 2] -> output dims [2, 3] +// ] +// +// Notice that the input dims (first tuple elems) appear in sequential order if +// you read left-to-right-top-to-bottom, and so do the output dims. +// +// This function returns the above decomposition. +SmallVector, SmallVector>> +getReshapeDecomposition(ArrayRef srcShape, ArrayRef dstShape); + +bool maybeSharedAllocationOp(Operation *op); + +bool supportMFMA(triton::DotOp op); + +bool supportWMMA(triton::DotOp op); + +bool supportMMA(triton::DotOp op, int version); + +bool supportMMA(Value value, int version); + +bool isSingleValue(Value value); + +bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); + +bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); + +bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy); + +// Return true if the src and dst layout match. +bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, + RankedTensorType dstTy); + +// TODO: Move utility functions that belong to ConvertLayoutOp to class +// ConvertLayoutOpHelper in the future +bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout); + +/// Multi-root DAG topological sort. +/// Performs a topological sort of the Operation in the `toSort` SetVector. +/// Returns a topologically sorted SetVector. +/// It is faster than mlir::topologicalSort because it prunes nodes that have +/// been visited before. +SetVector +multiRootTopologicalSort(const SetVector &toSort); + +/// This uses the toplogicalSort above +SetVector +multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr, + TransitiveFilter forwardFilter = nullptr); + +/// Create a basic DataFlowSolver with constant and dead code analysis included. +std::unique_ptr createDataFlowSolver(); + +/// This class represents a call graph for a given ModuleOp and holds +/// data of type T associated with each FunctionOpInterface. +template class CallGraph { +public: + using FuncDataMapT = DenseMap; + + /// Constructor that builds the call graph for the given moduleOp. + explicit CallGraph(ModuleOp moduleOp) : moduleOp(moduleOp) { build(); } + + /// Walks the call graph and applies the provided update functions + /// to the edges and nodes. + template + void walk(UpdateEdgeFn updateEdgeFn, UpdateNodeFn updateNodeFn) { + DenseSet visited; + for (auto root : roots) { + doWalk(root, visited, updateEdgeFn, + updateNodeFn); + } + } + + /// Retrieves the data associated with a function + T *getFuncData(FunctionOpInterface funcOp) { + if (funcMap.count(funcOp)) { + return &funcMap[funcOp]; + } + return nullptr; + } + + /// Getters + ModuleOp getModuleOp() const { return moduleOp; } + SmallVector getRoots() const { return roots; } + size_t getNumFunctions() const { return funcMap.size(); } + + /// Returns true if the given function is a root. + bool isRoot(FunctionOpInterface funcOp) const { + return llvm::is_contained(roots, funcOp); + } + + /// Maps the data and the graph nodes associated with a funcOp to a + /// targetFuncOp. + template + void mapFuncOp(FROM funcOp, TO targetFuncOp) { + // Iterate over graph and replace + for (auto &kv : graph) { + for (auto &edge : kv.second) { + if (edge.second == funcOp) { + edge.second = targetFuncOp; + } + } + } + graph[targetFuncOp] = graph[funcOp]; + // Replace in roots + for (auto it = roots.begin(); it != roots.end(); ++it) { + if (*it == funcOp) { + *it = targetFuncOp; + break; + } + } + // Replace in funcMap + funcMap[targetFuncOp] = funcMap[funcOp]; + } + + /// Maps the graph edges associated with a callOp to a targetCallOp. + template + void mapCallOp(FROM callOp, TO targetCallOp) { + // Iterate over graph and replace + for (auto &kv : graph) { + for (auto &edge : kv.second) { + if (edge.first == callOp) { + edge.first = targetCallOp; + } + } + } + } + +private: + void build() { + SymbolTableCollection symbolTable; + DenseSet visited; + // Build graph + moduleOp.walk([&](Operation *op) { + auto caller = op->getParentOfType(); + if (auto callOp = dyn_cast(op)) { + auto *callee = callOp.resolveCallable(&symbolTable); + auto funcOp = dyn_cast_or_null(callee); + if (funcOp) { + graph[caller].emplace_back( + std::pair(callOp, funcOp)); + visited.insert(funcOp); + } + } + }); + // Find roots + moduleOp.walk([&](FunctionOpInterface funcOp) { + if (!visited.count(funcOp)) { + roots.push_back(funcOp); + } + }); + } + + template + void doWalk(FunctionOpInterface funcOp, + DenseSet &visited, UpdateEdgeFn updateEdgeFn, + UpdateNodeFn updateNodeFn) { + if (visited.count(funcOp)) { + llvm::report_fatal_error("Cycle detected in call graph"); + } + if constexpr (UpdateNodeOrder == WalkOrder::PreOrder) { + updateNodeFn(funcOp); + } + for (auto [callOp, callee] : graph[funcOp]) { + if constexpr (UpdateEdgeOrder == WalkOrder::PreOrder) { + updateEdgeFn(callOp, callee); + } + doWalk(callee, visited, updateEdgeFn, + updateNodeFn); + if constexpr (UpdateEdgeOrder == WalkOrder::PostOrder) { + updateEdgeFn(callOp, callee); + } + } + if constexpr (UpdateNodeOrder == WalkOrder::PostOrder) { + updateNodeFn(funcOp); + } + visited.erase(funcOp); + } + +protected: + ModuleOp moduleOp; + DenseMap>> + graph; + FuncDataMapT funcMap; + SmallVector roots; +}; +// Create a basic DataFlowSolver with constant and dead code analysis included. +std::unique_ptr createDataFlowSolver(); + +triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v); + +} // namespace mlir + +#undef DEBUG_TYPE + +#endif // TRITON_ANALYSIS_UTILITY_H diff --git a/third_party/xpu/include/triton/Analysis/UtilityXPU.h b/third_party/xpu/include/triton/Analysis/UtilityXPU.h new file mode 100644 index 000000000..cf00d933e --- /dev/null +++ b/third_party/xpu/include/triton/Analysis/UtilityXPU.h @@ -0,0 +1,158 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TRITONXPU_ANALYSIS_UTILITY_H +#define TRITONXPU_ANALYSIS_UTILITY_H + +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +#include "triton/Dialect/TritonXPU/IR/Dialect.h" + +namespace mlir { + +#define XPU_MEMORY_OP \ + triton::xpu::GM2LMOp, triton::xpu::LM2GMOp, triton::xpu::SM2GMOp + +template struct is_xpu_memory_op { + static const bool value = false; +}; +template <> struct is_xpu_memory_op { + static const bool value = true; +}; +template <> struct is_xpu_memory_op { + static const bool value = true; +}; +template <> struct is_xpu_memory_op { + static const bool value = true; +}; + +#define ARITH_PTR_UNARY_OP arith::ExtSIOp + +#define ARITH_PTR_BINARY_OP \ + arith::DivSIOp, arith::RemSIOp, arith::MulIOp, arith::AddIOp, arith::SubIOp + +#define XPU_VVECTORIZED_BINARY_OP \ + triton::xpu::VvaddFOp, triton::xpu::VvmulFOp, triton::xpu::VvsubFOp, \ + triton::xpu::VvmaxFOp + +#define XPU_SVECTORIZED_BINARY_OP \ + triton::xpu::SvaddFOp, triton::xpu::SvmulFOp, triton::xpu::SvsubFOp, \ + triton::xpu::SvmaxFOp + +enum class OffsetState { + Unknown = -1, + DiscreteSame = 0, + Continuous = 1, + Discrete = 2, + LocallyContinuous = 3 +}; + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const OffsetState &state); + +enum class AtomicMaskCond { + PostiveCond = 1, + NegativeCond = -1, + NonActivate = 0, +}; + +enum class AtomicMaskType { + NaiveMask = 1, + OptimizationMask = 2, +}; + +enum class XPUArch { XPU2 = 2, XPU3 = 3 }; + +enum class MemCpyType { GM2LM = 0, LM2GM = 1, GM2SM = 2, SM2GM = 3 }; + +class SMHelper { +public: + explicit SMHelper(Operation *op) : op(op) {} + + void setOffset(int64_t offset) { smOffsetMap[op] = offset; } + + int64_t getOffset() { + int64_t offset = 0; + if (hasOffset()) { + offset = smOffsetMap[op]; + } + return offset; + } + + bool hasOffset() { return smOffsetMap.find(op) != smOffsetMap.end(); } + +private: + Operation *op; + static std::map smOffsetMap; +}; + +Type addrspaceCast(Type type, int addressSpace); + +bool inOpChain(llvm::SetVector &opChain, Operation *op); + +void getOpChainBwd(llvm::SetVector &opChain, Operation *op); +void getOpChainFwd(llvm::SetVector &opChain, Operation *op); +void getOpTreeBwd(llvm::SetVector &opTree, + llvm::SetVector &visitedOps, Operation *op); +void getOpTreeBwd(llvm::SetVector &opTree, + llvm::SetVector &visitedOps, Operation *op, + Block *block); + +llvm::SmallVector +sortOpTreeBwd(llvm::SmallVector &opTree); +llvm::SetVector +sortOpTreeBwd(llvm::SetVector &opTree); +llvm::SetVector sortOpTree(llvm::SetVector &opTree); + +bool inSameSCFIfBlock(llvm::SetVector &storeOps, + Operation *storeOp); + +template +Operation *findUserOpImpl(Operation *op, + llvm::SetVector &visitedOps) { + if (!op || op->use_empty() || visitedOps.contains(op)) + return nullptr; + + visitedOps.insert(op); + + if (isa(op)) { + return op; + } + + for (Operation *user : op->getUsers()) { + Operation *userOp = findUserOpImpl(user, visitedOps); + if (userOp) { + return userOp; + } + } + + return nullptr; +} + +template Operation *findUserOp(Operation *op) { + llvm::SetVector visitedOps; + return findUserOpImpl(op, visitedOps); +} + +template Operation *findDefOpBwd(const Value &val) { + if (!val || !val.getDefiningOp()) { + return nullptr; + } + auto op = val.getDefiningOp(); + if (op && isa(op)) { + return op; + } + for (auto operand : op->getOperands()) { + op = findDefOpBwd(operand); + if (op) { + return op; + } + } + return nullptr; +} + +} // namespace mlir + +#endif // TRITONXPU_ANALYSIS_UTILITY_H diff --git a/third_party/xpu/include/triton/CMakeLists.txt b/third_party/xpu/include/triton/CMakeLists.txt new file mode 100644 index 000000000..27c703b3c --- /dev/null +++ b/third_party/xpu/include/triton/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Conversion) +add_subdirectory(Dialect) +add_subdirectory(Target) diff --git a/third_party/xpu/include/triton/Conversion/CMakeLists.txt b/third_party/xpu/include/triton/Conversion/CMakeLists.txt new file mode 100644 index 000000000..f052ffdcf --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(TritonGPUToLLVM) +add_subdirectory(TritonToTritonGPU) +add_subdirectory(TritonXPUToLLVM) +add_subdirectory(TritonToTritonXPU) diff --git a/third_party/xpu/include/triton/Conversion/MLIRTypes.h b/third_party/xpu/include/triton/Conversion/MLIRTypes.h new file mode 100644 index 000000000..fadba413f --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/MLIRTypes.h @@ -0,0 +1,42 @@ +#ifndef TRITON_CONVERSION_MLIR_TYPES_H +#define TRITON_CONVERSION_MLIR_TYPES_H + +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// This file redefines some common MLIR types for easy usage. +namespace mlir { +namespace triton { +namespace type { + +// Integer types +inline Type i32Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 32); } +inline Type i16Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 16); } +inline Type i8Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 8); } +inline Type u32Ty(MLIRContext *ctx) { + return IntegerType::get(ctx, 32, IntegerType::Unsigned); +} +inline Type u1Ty(MLIRContext *ctx) { + return IntegerType::get(ctx, 1, IntegerType::Unsigned); +} + +// Float types +inline Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); } +inline Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); } +inline Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); } +inline Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); } + +inline bool isFloat(Type type) { + return type.isF32() || type.isF64() || type.isF16() || type.isF128() || + type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() || + type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() || + type.isFloat8E5M2FNUZ(); +} + +inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); } + +} // namespace type +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_MLIR_TYPES_H diff --git a/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h new file mode 100644 index 000000000..00ec88089 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h @@ -0,0 +1,27 @@ +#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ +#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ + +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include +#include + +namespace mlir { +class ConversionPatternRewriter; +class Location; + +namespace triton { +using llvm::StringRef; + +inline std::string strJoin(llvm::ArrayRef strs, + llvm::StringRef delimiter) { + return llvm::join(strs.begin(), strs.end(), delimiter); +} + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_ diff --git a/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..93f8374e5 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonGPUToLLVM) +add_public_tablegen_target(TritonGPUConversionPassIncGen) diff --git a/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h new file mode 100644 index 000000000..b6ebfe9b2 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h @@ -0,0 +1,227 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H +#define TRITON_CONVERSION_TRITONGPU_TO_ELEMENTWISE_OP_H + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::triton { + +namespace gpu { + +SmallVector reorderValues(const SmallVector &values, Type inType, + Type ouType); + +SmallVector unpackI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter); + +SmallVector packI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter); + +Type getElementType(Value value); + +class MultipleOperandsRange + : public iterator_range>::iterator> { + using ContainerT = SmallVector>; + +public: + using iterator_range::iterator_range; + ContainerT::reference operator[](ContainerT::size_type idx) { + return begin()[idx]; + } + ContainerT::const_reference operator[](ContainerT::size_type idx) const { + return begin()[idx]; + } + ContainerT::size_type size() const { return end() - begin(); } +}; + +// Base pattern for elementwise conversion using ConcreteT. Unpacks individual +// elements from a `!llvm.struct` via `llvm.extactvalue`, calls +// ConcreteT::createDestOps on each element, and packs them back into an +// `!llvm.struct` using `llvm.insertvalue`. +// +// Also supports processing the inputs in a vectorized form by consuming and +// producing multiple operand sets in ConcreteT::createDestOps. +template +class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { +public: + using OpAdaptor = typename SourceOp::Adaptor; + + explicit ElementwiseOpConversionBase( + LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit), + axisAnalysisPass(axisAnalysisPass) {} + + // Try to deduplicate the resultVals based on the + // constancy properties of the result discovered by + // the axis analysis pass. If possible, redundant + // computation is eliminated. + SmallVector maybeDeduplicate(SourceOp op, + SmallVector resultVals) const { + if (!isMemoryEffectFree(op)) + // the op has side effects: can't dedup + return resultVals; + SmallVector results = op->getResults(); + if (results.size() == 0 || results.size() > 1) + // there must be exactly 1 result + return resultVals; + Value result = results[0]; + Type type = result.getType(); + if (!type) + return resultVals; + RankedTensorType rtType = dyn_cast(type); + if (!rtType) + // the result must be a tensor + return resultVals; + Attribute encoding = rtType.getEncoding(); + if (!encoding) + // encoding not available + return resultVals; + if (!dyn_cast(encoding) && + !dyn_cast(encoding)) { + // TODO: constraining the ecndoing type here is necessary for avoiding + // crashes in the getElemsPerThread call below happening in the + // test_core::test_fp8_dot_acc + return resultVals; + } + + SmallVector elemsPerThread = getElemsPerThread(rtType); + int rank = elemsPerThread.size(); + if (product(elemsPerThread) != resultVals.size()) + return resultVals; + AxisInfo *axisInfo = axisAnalysisPass.getAxisInfo(result); + if (!axisInfo) + // axis info (e.g., constancy) not available + return resultVals; + SmallVector contigPerThread = getContigPerThread(encoding); + if (rank != contigPerThread.size()) + return resultVals; + + SmallVector constancy = axisInfo->getConstancy(); + if (rank != constancy.size()) + return resultVals; + bool hasConstancy = false; + for (int i = 0; i < rank; ++i) { + if (constancy[i] > contigPerThread[i]) { + if (constancy[i] % contigPerThread[i] != 0) + // constancy is not evenly covered by contigPerThread + return resultVals; + // can't move the values across different + // "contigPerThread"-sized blocks + constancy[i] = contigPerThread[i]; + } + if (elemsPerThread[i] < 1 || constancy[i] < 1) + return resultVals; + if (!(elemsPerThread[i] % constancy[i] == 0 || + constancy[i] % elemsPerThread[i] == 0)) + // either the constancy along each dimension must fit + // into the elemsPerThread or the other way around + return resultVals; + if (constancy[i] > 1) + hasConstancy = true; + } + if (!hasConstancy) + // nothing to deduplicate + return resultVals; + + if (rank > 1) { + // reorder the shape and constancy vectors by the axis order: + // from the fastest-changing to the smallest-changing axis + SmallVector order = getOrder(encoding); + if (rank != order.size()) + return resultVals; + elemsPerThread = applyPermutation(elemsPerThread, order); + constancy = applyPermutation(constancy, order); + } + + SmallVector strides(rank, 1); + for (int i = 1; i < rank; ++i) { + strides[i] = strides[i - 1] * elemsPerThread[i - 1]; + } + SmallVector dedupResultVals; + dedupResultVals.reserve(resultVals.size()); + for (int i = 0; i < resultVals.size(); ++i) { + // each coordinate of the orig_idx is "coarsened" using the + // constancy along this dimension: the resulting dedup_idx + // points to the reused value in the original resultsVal + int orig_idx = i; + int dedup_idx = 0; + for (int j = 0; j < rank; ++j) { + int coord_j = orig_idx % elemsPerThread[j]; + dedup_idx += (coord_j / constancy[j] * constancy[j]) * strides[j]; + orig_idx /= elemsPerThread[j]; + } + dedupResultVals.push_back(resultVals[dedup_idx]); + } + + return dedupResultVals; + } + LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resultTy = op.getType(); + Location loc = op->getLoc(); + // element type + auto resultElementTy = + isa(resultTy) + ? resultTy + : getElementTypeOrSelf(resultTy); // For TritonXPU Reduce Vectorize + Type elemTy = this->getTypeConverter()->convertType(resultElementTy); + SmallVector> allOperands; + for (auto operand : adaptor.getOperands()) { + auto argTy = op->getOperand(0).getType(); + auto subOperands = unpackLLElements(loc, operand, rewriter); + subOperands = unpackI32(subOperands, argTy, rewriter, loc, + this->getTypeConverter()); + allOperands.resize(subOperands.size()); + for (auto v : llvm::enumerate(subOperands)) + allOperands[v.index()].push_back(v.value()); + } + if (allOperands.size() == 0) + allOperands.push_back({}); + + SmallVector resultVals; + for (auto it = allOperands.begin(), end = allOperands.end(); it != end;) { + auto curr = static_cast(this)->createDestOps( + op, adaptor, rewriter, elemTy, MultipleOperandsRange(it, end), loc); + if (curr.size() == 0) + return failure(); + for (auto v : curr) { + if (!static_cast(v)) + return failure(); + resultVals.push_back(v); + } + it += curr.size(); + } + if (op->getNumOperands() > 0) { + auto argTy = op->getOperand(0).getType(); + resultVals = reorderValues(resultVals, argTy, resultTy); + } + resultVals = maybeDeduplicate(op, resultVals); + resultVals = + packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); + Value view = packLLElements(loc, this->getTypeConverter(), resultVals, + rewriter, resultTy); + rewriter.replaceOp(op, view); + + return success(); + } + +protected: + ModuleAxisInfoAnalysis &axisAnalysisPass; +}; + +} // namespace gpu + +} // namespace mlir::triton +#endif diff --git a/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/Passes.h b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/Passes.h new file mode 100644 index 000000000..b013f2628 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/Passes.h @@ -0,0 +1,32 @@ +#ifndef TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PASSES_H +#define TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +#define GEN_PASS_DECL +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" + +namespace gpu { +std::unique_ptr> createAllocateSharedMemoryPass(); + +} // namespace gpu + +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/Passes.td b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/Passes.td new file mode 100644 index 000000000..700dcd6b4 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/Passes.td @@ -0,0 +1,11 @@ +#ifndef TRITONCOMMONGPU_CONVERSION_PASSES +#define TRITONCOMMONGPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def AllocateSharedMemory : Pass<"allocate-shared-memory", "mlir::ModuleOp"> { + let summary = "Add metadata for shared memory allocation"; + let constructor = "mlir::triton::gpu::createAllocateSharedMemoryPass()"; +} + +#endif diff --git a/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h new file mode 100644 index 000000000..d1494fd7e --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -0,0 +1,104 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_PATTERNS_TRITON_GPU_OP_TO_LLVM_H + +#include "TargetInfoBase.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Analysis/AxisInfo.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::BlockedEncodingAttr; + +namespace SharedToDotOperandFMA { +Value convertLayout(int opIdx, Value val, Value llVal, + BlockedEncodingAttr dLayout, Value thread, Location loc, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); +} +LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter); +namespace mlir { +namespace triton { + +constexpr int patternBenefitDefault = 1; +constexpr int patternBenefitPrioritizeOverLLVMConversions = 10; +constexpr int patternBenefitClampOptimizedPattern = 20; +constexpr int patternBenefitConvertLayoutOptimizedPattern = 20; + +void populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateMakeRangeOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateViewOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateMinMaxFOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + bool hwNanPropagationSupported, + PatternBenefit benefit); +void populateClampFOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateHistogramOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); +void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); +void populateScanOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, int numWarps, + PatternBenefit benefit); + +void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/Patterns.h b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/Patterns.h new file mode 100644 index 000000000..934501ad3 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/Patterns.h @@ -0,0 +1,32 @@ +#ifndef TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PATTERNS_H +#define TRITONGPU_CONVERSION_TRITONGPUTOLLVM_PATTERNS_H + +#include + +namespace mlir { +class ModuleOp; +class RankedTensorType; + +namespace triton::gpu { + +/// Replaces `blocked -> dot_op` with `blocked -> shared -> dot_op` in the given +/// |module| op because the codegen doesn't handle `blocked -> dot_op` directly. +void decomposeBlockedToDotLayoutConversion(ModuleOp module); + +/// Replaces `splat -> shared` with `splat -> blocked -> shared` in the given +/// |module| op. +void decomposeSplatOpToSharedLayoutConversion(ModuleOp module); + +/// Replaces `mma/mfma -> dot_op` with `mma/mfma -> blocked -> dot_op` in the +/// given |module| op, but bypass the decomposition if |shortcutFn| returns +/// true. +using ShortcutFn = std::function; +template +void decomposeTensorCoreToDotLayoutConversion(ModuleOp module, + ShortcutFn shortcutFn); + +} // namespace triton::gpu + +} // namespace mlir + +#endif diff --git a/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h new file mode 100644 index 000000000..d03f6b862 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -0,0 +1,66 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H + +#include "triton/Conversion/MLIRTypes.h" + +namespace mlir::triton { +class TargetInfoBase { +public: + virtual bool supportMaximumMinimum() const = 0; + + virtual Value getClusterCTAId(RewriterBase &rewriter, Location loc) const = 0; + + virtual Value ballot(ConversionPatternRewriter &rewriter, Location loc, + Type type, Value cmp) const = 0; + + virtual void storeShared(ConversionPatternRewriter &rewriter, Location loc, + Value ptr, Value val, Value pred) const = 0; + virtual Value loadShared(ConversionPatternRewriter &rewriter, Location loc, + const TypeConverter *converter, Value ptr, + Type elemTy, Value pred) const = 0; + + virtual Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const = 0; + virtual Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const = 0; + virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const = 0; + virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, + Value val, Value i) const = 0; + + virtual Value programId(ConversionPatternRewriter &rewriter, Location loc, + ModuleOp moduleOp, int axis) const = 0; + + virtual bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce) const = 0; + + virtual bool processReplicaUsingStMatrix( + ConversionPatternRewriter &rewriter, Location loc, Value smemBase, + SmallVector &vals, RankedTensorType srcTy, Type elemTy, + ArrayRef paddedRepShape, ArrayRef origRepShape, + ArrayRef outOrd, unsigned accumNumReplicates, + int swizzleByteWidth = 0) const = 0; + + virtual std::string getMulhiFuncName(Type resultElementTy) const = 0; + // Emits LLVM code with |rewriter| to print a message following the given + // format from the device. |formatStrStart| is the pointer to the start of + // the format string global variable; |args| are the arguments to fill + // placeholders in the format string. + virtual void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args) const = 0; + // Emits LLVM code with |rewriter| to perform assertion failure with the given + // |message| from the given |func| in |file|. + virtual void assertFail(ConversionPatternRewriter &rewriter, Location loc, + StringRef message, StringRef file, StringRef func, + int line) const = 0; + + // Whether to enable linear layout. This is a per-backend temporary escape + // hatch to disable linear layout while figuring out issues. Eventually we + // want to enable linear layout everywhere and delete this control. + virtual bool enableLinearLayout() const { return true; } + + virtual ~TargetInfoBase() {} +}; +} // namespace mlir::triton +#endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H diff --git a/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h new file mode 100644 index 000000000..ab9d0ebf8 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h @@ -0,0 +1,26 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/TritonGPU/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter { +public: + using TypeConverter::convertType; + + TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis = nullptr); + + Type getElementTypeForStruct(TensorOrMemDesc type); + Type convertTritonPointerType(triton::PointerType type); + Type convertTritonTensorType(RankedTensorType type); + Type convertMemDescType(MemDescType type); + Type convertAsyncToken(triton::gpu::AsyncTokenType type); +}; + +#endif diff --git a/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/Utility.h new file mode 100644 index 000000000..87851c6f5 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -0,0 +1,1647 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H + +#include + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Analysis/UtilityXPU.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/StrUtil.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/ErrorHandling.h" + +#define DEBUG_TYPE "ttgpu_to_llvm" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::triton; + +// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive +// Operators +#define inttofloat(...) rewriter.create(loc, __VA_ARGS__) +#define inttoptr(...) rewriter.create(loc, __VA_ARGS__) +#define ptrtoint(...) rewriter.create(loc, __VA_ARGS__) +#define zext(...) rewriter.create(loc, __VA_ARGS__) +#define sext(...) rewriter.create(loc, __VA_ARGS__) +#define fpext(...) rewriter.create(loc, __VA_ARGS__) +#define trunc(...) rewriter.create(loc, __VA_ARGS__) +#define udiv(...) rewriter.create(loc, __VA_ARGS__) +#define urem(...) rewriter.create(loc, __VA_ARGS__) +#define add(...) rewriter.create(loc, __VA_ARGS__) +#define sub(...) rewriter.create(loc, __VA_ARGS__) +#define fadd(...) rewriter.create(loc, __VA_ARGS__) +#define mul(...) rewriter.create(loc, __VA_ARGS__) +#define fmul(...) rewriter.create(loc, __VA_ARGS__) +#define smax(...) rewriter.create(loc, __VA_ARGS__) +#define umax(...) rewriter.create(loc, __VA_ARGS__) +#define fmax(...) rewriter.create(loc, __VA_ARGS__) +#define smin(...) rewriter.create(loc, __VA_ARGS__) +#define umin(...) rewriter.create(loc, __VA_ARGS__) +#define fmin(...) rewriter.create(loc, __VA_ARGS__) +#define shl(...) rewriter.create(loc, __VA_ARGS__) +#define lshr(...) rewriter.create(loc, __VA_ARGS__) +#define and_(...) rewriter.create(loc, __VA_ARGS__) +#define xor_(...) rewriter.create(loc, __VA_ARGS__) +#define or_(...) rewriter.create(loc, __VA_ARGS__) +#define bitcast(val__, type__) \ + rewriter.create(loc, type__, val__) +#define addrspacecast(...) \ + rewriter.create(loc, __VA_ARGS__) +#define gep(...) rewriter.create(loc, __VA_ARGS__) +#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__) +#define insert_val(...) rewriter.create(loc, __VA_ARGS__) +#define extract_val(...) rewriter.create(loc, __VA_ARGS__) +#define insert_element(...) \ + rewriter.create(loc, __VA_ARGS__) +#define extract_element(...) \ + rewriter.create(loc, __VA_ARGS__) +#define load(...) rewriter.create(loc, __VA_ARGS__) +#define store(...) rewriter.create(loc, __VA_ARGS__) +#define fcmp_ogt(lhs, rhs) \ + rewriter.create(loc, rewriter.getI1Type(), \ + LLVM::FCmpPredicate::ogt, lhs, rhs) +#define fcmp_olt(lhs, rhs) \ + rewriter.create(loc, rewriter.getI1Type(), \ + LLVM::FCmpPredicate::olt, lhs, rhs) +#define fcmp_eq(lhs, rhs) \ + rewriter.create(loc, rewriter.getI1Type(), \ + LLVM::FCmpPredicate::oeq, lhs, rhs) +#define icmp_eq(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__) +#define icmp_ne(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__) +#define icmp_slt(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__) +#define icmp_sle(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::sle, __VA_ARGS__) +#define icmp_sgt(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::sgt, __VA_ARGS__) +#define icmp_sge(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::sge, __VA_ARGS__) +#define icmp_ult(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::ult, __VA_ARGS__) +#define icmp_ule(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::ule, __VA_ARGS__) +#define icmp_ugt(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::ugt, __VA_ARGS__) +#define icmp_uge(...) \ + rewriter.create(loc, LLVM::ICmpPredicate::uge, __VA_ARGS__) +#define select(...) rewriter.create(loc, __VA_ARGS__) +#define address_of(...) rewriter.create(loc, __VA_ARGS__) +#define barrier() rewriter.create(loc) +#define undef(...) rewriter.create(loc, __VA_ARGS__) +#define null(...) rewriter.create(loc, __VA_ARGS__) +#define call(...) rewriter.create(loc, __VA_ARGS__) + +// Types +#define int_ty(width) rewriter.getIntegerType(width) +#define i64_ty rewriter.getIntegerType(64) +#define i32_ty rewriter.getIntegerType(32) +#define i16_ty rewriter.getIntegerType(16) +#define i32_ty rewriter.getIntegerType(32) +#define i64_ty rewriter.getIntegerType(64) +#define ui32_ty rewriter.getIntegerType(32, false) +#define ui64_ty rewriter.getIntegerType(64, false) +#define f16_ty rewriter.getF16Type() +#define bf16_ty rewriter.getBF16Type() +#define i8_ty rewriter.getIntegerType(8) +#define i1_ty rewriter.getI1Type() +#define f32_ty rewriter.getF32Type() +#define f64_ty rewriter.getF64Type() +#define vec_ty(type, num) VectorType::get(num, type) +#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx) +#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__) +#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count) + +// Constants +#define f16_val(...) LLVM::createConstantF16(loc, rewriter, __VA_ARGS__) +#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__) +#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__) +#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__) +#define i64_val(...) LLVM::createConstantI64(loc, rewriter, __VA_ARGS__) +#define int_val(width, val) \ + LLVM::createLLVMIntegerConstant(rewriter, loc, width, val) +#define tid_val() getThreadId(rewriter, loc) + +// Attributes +#define i32_arr_attr(...) rewriter.getI32ArrayAttr({__VA_ARGS__}) +#define i64_arr_attr(...) rewriter.getI64ArrayAttr({__VA_ARGS__}) +#define str_attr(str) ::mlir::StringAttr::get(ctx, (str)) + +namespace mlir { +namespace triton { + +// Delinearize supposing order is [0, 1, .. , n] +template +llvm::SmallVector getMultiDimIndexImpl(T linearIndex, + llvm::ArrayRef shape) { + // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c} + size_t rank = shape.size(); + T accMul = product(shape.drop_back()); + T linearRemain = linearIndex; + llvm::SmallVector multiDimIndex(rank); + for (int i = rank - 1; i >= 0; --i) { + multiDimIndex[i] = linearRemain / accMul; + linearRemain = linearRemain % accMul; + if (i != 0) { + accMul = accMul / shape[i - 1]; + } + } + return multiDimIndex; +} + +template +llvm::SmallVector getMultiDimIndex(T linearIndex, llvm::ArrayRef shape, + llvm::ArrayRef order) { + size_t rank = shape.size(); + assert(rank == order.size()); + auto reordered = applyPermutation(shape, order); + auto reorderedMultiDim = getMultiDimIndexImpl(linearIndex, reordered); + llvm::SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +// Linearize supposing order is [0, 1, .. , n] +template +T getLinearIndexImpl(llvm::ArrayRef multiDimIndex, llvm::ArrayRef shape) { + assert(multiDimIndex.size() == shape.size()); + // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c} + size_t rank = shape.size(); + T accMul = product(shape.drop_back()); + T linearIndex = 0; + for (int i = rank - 1; i >= 0; --i) { + linearIndex += multiDimIndex[i] * accMul; + if (i != 0) { + accMul = accMul / shape[i - 1]; + } + } + return linearIndex; +} + +template +T getLinearIndex(llvm::ArrayRef multiDimIndex, llvm::ArrayRef shape, + llvm::ArrayRef order) { + assert(shape.size() == order.size()); + return getLinearIndexImpl(applyPermutation(multiDimIndex, order), + applyPermutation(shape, order)); +} + +namespace gpu { +Type getFunctionType(Type resultType, ValueRange operands); + +LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter, + Operation *op, StringRef funcName, + Type funcType, StringRef libname = "", + StringRef libpath = ""); +} // namespace gpu + +} // namespace triton + +namespace LLVM { +using namespace mlir::triton; + +Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v); + +/// Create a 64-bit integer constant. +Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v); + +/// Create a 16-bit float constant. +Value createConstantF16(Location loc, OpBuilder &rewriter, float v); + +/// Create a 32-bit float constant. +Value createConstantF32(Location loc, OpBuilder &rewriter, float v); + +/// Create a 64-bit float constant. +Value createConstantF64(Location loc, OpBuilder &rewriter, double v); + +/// Create NaN constant of specified type. +Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type); + +/// Create an index type constant. +Value createIndexConstant(OpBuilder &builder, Location loc, + const TypeConverter *converter, int64_t value); + +/// Create an integer constant of \param width bits. +Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, + int64_t value); + +/// Helper function to get strides from a given shape and its order +SmallVector getStridesFromShapeAndOrder(ArrayRef shape, + ArrayRef order, + Location loc, + RewriterBase &rewriter); +struct SharedMemoryObject { + Value base; // i32 ptr. The start address of the shared memory object after + // the initial allocation or the last slicing operation. + Type baseElemType; + // We need to store strides as Values, not integers, because the + // extract_slice instruction can take a slice at arbitrary offsets. + // Take $a[16:32, 16:32] as an example; though we know the stride of $a[0] is + // 32, we need to let the instruction that uses $a be aware of that. + // Otherwise, when we use $a, we only know that the shape of $a is 16x16. If + // we store strides into an attribute array of integers, the information + // cannot pass through block argument assignment because attributes are + // associated with operations, not Values. + // TODO(Keren): We may need to figure out a way to store strides as integers + // if we want to support more optimizations. + SmallVector + strides; // i32 int. The strides of the shared memory object. + SmallVector offsets; // i32 int. + // Offsets are applied at the last slicing operation. + // We can use offsets to recover the previous base. + // The offsets are zero at the initial allocation. + + SharedMemoryObject(Value base, Type baseElemType, ArrayRef strides, + ArrayRef offsets) + : base(base), baseElemType(baseElemType), + strides(strides.begin(), strides.end()), + offsets(offsets.begin(), offsets.end()) {} + + SharedMemoryObject(Value base, Type baseElemType, ArrayRef shape, + ArrayRef order, Location loc, + RewriterBase &rewriter) + : base(base), baseElemType(baseElemType) { + strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter); + offsets.append(order.size(), i32_val(0)); + } + + SmallVector getStrides() const { return strides; } + SmallVector getOffsets() const { return offsets; } + Value getBase() const { return base; } + Type getBaseElemType() const { return baseElemType; } + + SmallVector getElems() const { + SmallVector elems; + elems.push_back(base); + elems.append(strides.begin(), strides.end()); + elems.append(offsets.begin(), offsets.end()); + return elems; + } + + SmallVector getTypes() const { + SmallVector types; + types.push_back(base.getType()); + types.append(strides.size(), IntegerType::get(base.getContext(), 32)); + types.append(offsets.size(), IntegerType::get(base.getContext(), 32)); + return types; + } + + Value getCSwizzleOffset(int order) const { + assert(order >= 0 && order < strides.size()); + return offsets[order]; + } + + Value getBaseBeforeSlice(int order, Location loc, + ConversionPatternRewriter &rewriter) const { + Value cSwizzleOffset = getCSwizzleOffset(order); + Value offset = sub(i32_val(0), cSwizzleOffset); + Type type = base.getType(); + return gep(type, baseElemType, base, offset); + } +}; + +SharedMemoryObject +getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, + ConversionPatternRewriter &rewriter); + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape, + ArrayRef order); + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + unsigned linear, ArrayRef shape); + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape); + +Value linearize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDim, ArrayRef shape, + ArrayRef order); + +Value linearize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDim, ArrayRef shape); + +Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, + StringRef key, StringRef content); + +// Given an elemId which represents the index of an element from the list of +// elements that are in the thread's registers (i.e. total of +// numel(sizePerThread)), it calculates the multi dim offset of the element in +// the smem buffer. Recall that the smem buffer will only store a single replica +// when converting distributed to distributed layout. Also, a replica is the +// smallest CTA tile that is common between input and output layouts. +SmallVector getMultiDimOffset(Attribute layout, Location loc, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + unsigned elemId, RankedTensorType type, + ArrayRef multiDimCTAInRepId, + ArrayRef shapePerCTATile); + +// Given a multiDimOffset, this function wraps around each dimension to be +// within shape. +SmallVector getWrappedMultiDimOffset( + ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDimOffset, ArrayRef shape, + SmallVector shapePerCTATile, SmallVector shapePerCTA); + +inline bool isKernel(FunctionOpInterface funcOp) { + return funcOp.getVisibility() == SymbolTable::Visibility::Public; +} + +inline Value getStackPointer(PatternRewriter &rewriter, + FunctionOpInterface funcOp) { + auto mod = funcOp->getParentOfType(); + LLVM::GlobalOp globalBase = nullptr; + mod.walk([&](LLVM::GlobalOp op) { + if (op.getSymName() == "global_smem") + globalBase = op; + }); + assert(globalBase); + if (isKernel(funcOp)) + return rewriter.create(funcOp.getLoc(), globalBase); + else + return funcOp.getArgument(funcOp.getNumArguments() - 1); +} + +inline Value getSharedMemoryBase(Location loc, + ConversionPatternRewriter &rewriter, + Operation *op) { + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); + FunctionOpInterface func = + op->template getParentOfType(); + assert(op->hasAttr("allocation.offset")); + size_t offset = cast(op->getAttr("allocation.offset")) + .getValue() + .getZExtValue(); + Value offVal = i32_val(offset); + Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); + return base; +} +} // namespace LLVM + +namespace triton::xpu { + +inline SmallVector> +emitOffsetForClusterLayout(const triton::xpu::ClusterLayoutAttr &clusterLayout, + RankedTensorType type) { + auto ctx = type.getContext(); + auto shape = type.getShape(); + auto sizePerThread = clusterLayout.getSizePerCore(); + auto threadsPerWarp = clusterLayout.getCoresPerGroup(); + auto warpsPerCTA = clusterLayout.getGroupsPerCluster(); + auto order = clusterLayout.getOrder(); + auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(clusterLayout); + auto shapePerCTA = triton::gpu::getShapePerCTA(clusterLayout, shape); + + unsigned rank = shape.size(); + SmallVector tilesPerDim(rank); + for (unsigned k = 0; k < rank; ++k) + tilesPerDim[k] = ceil(shapePerCTA[k], shapePerCTATile[k]); + + unsigned elemsPerThread = triton::gpu::getTotalElemsPerThread(type); + unsigned totalSizePerThread = product(sizePerThread); + SmallVector> reorderedOffset(elemsPerThread); + for (unsigned n = 0; n < elemsPerThread; ++n) { + unsigned linearNanoTileId = n / totalSizePerThread; + unsigned linearNanoTileElemId = n % totalSizePerThread; + SmallVector multiDimNanoTileId = + getMultiDimIndex(linearNanoTileId, tilesPerDim, order); + SmallVector multiDimNanoTileElemId = + getMultiDimIndex(linearNanoTileElemId, sizePerThread, order); + for (unsigned k = 0; k < rank; ++k) { + unsigned reorderedMultiDimId = + (multiDimNanoTileId[k] * + (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) + + multiDimNanoTileElemId[k]) % + shapePerCTA[k]; + + reorderedOffset[n].push_back(reorderedMultiDimId); + } + } + + return reorderedOffset; +} +} // namespace triton::xpu + +/* ------------------------------------ */ +// Returns CTA level thread idx +inline Value getThreadIdInCTA(RewriterBase &rewriter, Location loc) { + Value tid = + rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x); + return rewriter.create(loc, i32_ty, tid); +} + +// Returns CTA level thread idx. +inline Value getThreadId(RewriterBase &rewriter, Location loc) { + Value tid = getThreadIdInCTA(rewriter, loc); + auto mod = rewriter.getBlock()->getParent()->getParentOfType(); + return tid; +} + +// ----------------------------------------------------------------------- +// Shared memory utilities +// ----------------------------------------------------------------------- +using LLVM::getMultiDimIndex; +using LLVM::SharedMemoryObject; +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::SharedMemoryObject; +using ::mlir::triton::gpu::AMDMfmaEncodingAttr; +using ::mlir::triton::gpu::AMDWmmaEncodingAttr; +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::CTALayoutAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; + +inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef offsets, + ArrayRef strides) { + assert(offsets.size() == strides.size()); + Value ret = i32_val(0); + for (auto [offset, stride] : llvm::zip(offsets, strides)) { + ret = add(ret, mul(offset, stride)); + } + return ret; +} + +// ----------------------------------------------------------------------- +// Blocked layout indices +// ----------------------------------------------------------------------- + +// "Applies" the given layout by computing layout(indices) and returning the +// resulting Values. +// +// In other words, this generates LLVM-dialect MLIR code to "run" the layout +// function. +SmallVector> +applyLinearLayout(Location loc, RewriterBase &rewriter, + const LinearLayout &layout, + ArrayRef> indices); + +inline SmallVector +emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter, + const BlockedEncodingAttr &blockedLayout, + RankedTensorType type) { + MLIRContext *ctx = rewriter.getContext(); + auto shape = type.getShape(); + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(triton::gpu::getWarpSize(blockedLayout)); + Value laneId = urem(threadId, warpSize); + Value warpId = udiv(threadId, warpSize); + auto sizePerThread = blockedLayout.getSizePerThread(); + auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); + auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); + auto order = blockedLayout.getOrder(); + auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape); + unsigned rank = shape.size(); + + // delinearize threadId to get the base index + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, warpsPerCTA, order); + SmallVector multiDimThreadId = + delinearize(rewriter, loc, laneId, threadsPerWarp, order); + + SmallVector multiDimBase(rank); + for (unsigned k = 0; k < rank; ++k) { + // Wrap around multiDimWarpId/multiDimThreadId in case + // shapePerCTATile[k] > shapePerCTA[k] + auto maxWarps = + ceil(shapePerCTA[k], sizePerThread[k] * threadsPerWarp[k]); + auto maxThreads = ceil(shapePerCTA[k], sizePerThread[k]); + multiDimWarpId[k] = urem(multiDimWarpId[k], i32_val(maxWarps)); + multiDimThreadId[k] = urem(multiDimThreadId[k], i32_val(maxThreads)); + // multiDimBase[k] = (multiDimThreadId[k] + + // multiDimWarpId[k] * threadsPerWarp[k]) * + // sizePerThread[k]; + Value threadsPerWarpK = i32_val(threadsPerWarp[k]); + Value sizePerThreadK = i32_val(sizePerThread[k]); + multiDimBase[k] = + mul(sizePerThreadK, + add(multiDimThreadId[k], mul(multiDimWarpId[k], threadsPerWarpK))); + } + + return multiDimBase; +} + +inline SmallVector> +emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout, + RankedTensorType type) { + auto ctx = type.getContext(); + auto shape = type.getShape(); + auto sizePerThread = blockedLayout.getSizePerThread(); + auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); + auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); + auto order = blockedLayout.getOrder(); + auto shapePerCTATile = getShapePerCTATile(blockedLayout); + auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape); + + unsigned rank = shape.size(); + SmallVector tilesPerDim(rank); + for (unsigned k = 0; k < rank; ++k) + tilesPerDim[k] = ceil(shapePerCTA[k], shapePerCTATile[k]); + + unsigned elemsPerThread = triton::gpu::getTotalElemsPerThread(type); + unsigned totalSizePerThread = product(sizePerThread); + SmallVector> reorderedOffset(elemsPerThread); + for (unsigned n = 0; n < elemsPerThread; ++n) { + unsigned linearNanoTileId = n / totalSizePerThread; + unsigned linearNanoTileElemId = n % totalSizePerThread; + SmallVector multiDimNanoTileId = + getMultiDimIndex(linearNanoTileId, tilesPerDim, order); + SmallVector multiDimNanoTileElemId = + getMultiDimIndex(linearNanoTileElemId, sizePerThread, order); + for (unsigned k = 0; k < rank; ++k) { + unsigned reorderedMultiDimId = + (multiDimNanoTileId[k] * + (sizePerThread[k] * threadsPerWarp[k] * warpsPerCTA[k]) + + multiDimNanoTileElemId[k]) % + shapePerCTA[k]; + + reorderedOffset[n].push_back(reorderedMultiDimId); + } + } + + return reorderedOffset; +} + +// ----------------------------------------------------------------------- +// Mma layout indices +// ----------------------------------------------------------------------- + +inline SmallVector +emitBaseIndexWithinCTAForMmaLayoutV1(Location loc, RewriterBase &rewriter, + const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto wpt = mmaLayout.getWarpsPerCTA(); + static constexpr std::array fpw{{2, 2, 1}}; + auto [isARow, isBRow, isAVec4, isBVec4, _] = + mmaLayout.decodeVoltaLayoutStates(); + + Value thread = getThreadId(rewriter, loc); + auto *ctx = thread.getContext(); + Value _1 = i32_val(1); + Value _2 = i32_val(2); + Value _4 = i32_val(4); + Value _16 = i32_val(16); + Value _32 = i32_val(32); + Value _fpw0 = i32_val(fpw[0]); + Value _fpw1 = i32_val(fpw[1]); + + // A info + auto aRep = mmaLayout.getMMAv1Rep(0); + auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0); + // B info + auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1); + auto bRep = mmaLayout.getMMAv1Rep(1); + + SmallVector rep({aRep[0], bRep[1]}); + SmallVector spw({aSpw[0], bSpw[1]}); + SmallVector shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]}); + + Value lane = urem(thread, _32); + Value warp = udiv(thread, _32); + + Value warp0 = urem(warp, i32_val(wpt[0])); + Value warp12 = udiv(warp, i32_val(wpt[0])); + Value warp1 = urem(warp12, i32_val(wpt[1])); + + // warp offset + Value offWarpM = mul(warp0, i32_val(spw[0])); + Value offWarpN = mul(warp1, i32_val(spw[1])); + // quad offset + Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0); + Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1); + // pair offset + Value offPairM = udiv(urem(lane, _16), _4); + offPairM = urem(offPairM, _fpw0); + offPairM = mul(offPairM, _4); + Value offPairN = udiv(urem(lane, _16), _4); + offPairN = udiv(offPairN, _fpw0); + offPairN = urem(offPairN, _fpw1); + offPairN = mul(offPairN, _4); + offPairM = mul(offPairM, i32_val(rep[0] / 2)); + offQuadM = mul(offQuadM, i32_val(rep[0] / 2)); + offPairN = mul(offPairN, i32_val(rep[1] / 2)); + offQuadN = mul(offQuadN, i32_val(rep[1] / 2)); + // quad pair offset + Value offLaneM = add(offPairM, offQuadM); + Value offLaneN = add(offPairN, offQuadN); + // a, b offset + Value offsetAM = add(offWarpM, offLaneM); + Value offsetBN = add(offWarpN, offLaneN); + // m indices + Value offsetCM = add(and_(lane, _1), offsetAM); + // n indices + Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN))); + return {offsetCM, offsetCN}; +} + +inline SmallVector> +emitOffsetForMmaLayoutV1(const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + + auto [isARow, isBRow, isAVec4, isBVec4, _] = + mmaLayout.decodeVoltaLayoutStates(); + + // TODO: seems like the pattern below to get `rep`/`spw` appears quite often + // A info + auto aRep = mmaLayout.getMMAv1Rep(0); + auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0); + // B info + auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1); + auto bRep = mmaLayout.getMMAv1Rep(1); + + auto wpt = mmaLayout.getWarpsPerCTA(); + static constexpr std::array fpw{{2, 2, 1}}; + SmallVector rep({aRep[0], bRep[1]}); + SmallVector spw({aSpw[0], bSpw[1]}); + SmallVector shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]}); + + SmallVector idxM; + for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0]) + for (unsigned mm = 0; mm < rep[0]; ++mm) + idxM.push_back(m + mm * 2); + + SmallVector idxN; + for (int n = 0; n < shape[1]; n += shapePerCTA[1]) { + for (int nn = 0; nn < rep[1]; ++nn) { + idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1]); + idxN.push_back(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1); + } + } + + SmallVector> ret; + for (unsigned x1 : idxN) { // N + for (unsigned x0 : idxM) { // M + SmallVector idx(2); + idx[0] = x0; // M + idx[1] = x1; // N + ret.push_back(std::move(idx)); + } + } + return ret; +} + +inline SmallVector> +emitOffsetForMmaLayoutV2(const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + SmallVector> ret; + + auto rank = shape.size(); + for (unsigned i = 0; i < shapePerCTA[rank - 2]; + i += getShapePerCTATile(mmaLayout)[rank - 2]) { + for (unsigned j = 0; j < shapePerCTA[rank - 1]; + j += getShapePerCTATile(mmaLayout)[rank - 1]) { + if (rank == 3) { + ret.push_back({0, i, j}); + ret.push_back({0, i, j + 1}); + ret.push_back({0, i + 8, j}); + ret.push_back({0, i + 8, j + 1}); + } else { + ret.push_back({i, j}); + ret.push_back({i, j + 1}); + ret.push_back({i + 8, j}); + ret.push_back({i + 8, j + 1}); + } + } + } + return ret; +} + +// Note that this may return a null Value for one or more dimensions. This is +// valid only if you're going to slice off the relevant dimension. +inline SmallVector +emitBaseIndexWithinCTAForMmaLayoutV2V3(Location loc, RewriterBase &rewriter, + const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto _warpsPerCTA = mmaLayout.getWarpsPerCTA(); + auto rank = shape.size(); + assert(rank == 2 || rank == 3); + auto warpOrder = triton::gpu::getWarpOrder(mmaLayout); + ArrayRef instrShape = mmaLayout.getInstrShape(); + SmallVector warpsPerCTA; + for (unsigned i = 0; i < rank; ++i) + warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(32); + Value laneId = urem(threadId, warpSize); + Value warpId = udiv(threadId, warpSize); + + uint32_t repM = + (_warpsPerCTA[rank - 2] * instrShape[rank - 2]) / shapePerCTA[rank - 2]; + uint32_t repN = + (_warpsPerCTA[rank - 1] * instrShape[rank - 1]) / shapePerCTA[rank - 1]; + + uint32_t warpsM; + if (repM > 1) + warpsM = _warpsPerCTA[rank - 2] / repM; + else + warpsM = shape[rank - 2] / instrShape[rank - 2]; + + uint32_t warpsN; + if (repN > 1) + warpsN = _warpsPerCTA[rank - 1] / repN; + else + warpsN = shape[rank - 1] / instrShape[rank - 1]; + + SmallVector multiDimWarpId(rank); + multiDimWarpId = delinearize(rewriter, loc, warpId, _warpsPerCTA, warpOrder); + Value warpIdM = urem(multiDimWarpId[rank - 2], i32_val(warpsM)); + Value warpIdN = urem(multiDimWarpId[rank - 1], i32_val(warpsN)); + + Value offWarpM = mul(warpIdM, i32_val(instrShape[rank - 2])); + Value offWarpN = mul(warpIdN, i32_val(instrShape[rank - 1])); + + SmallVector multiDimBase(rank); + if (rank == 3) + multiDimBase[0] = multiDimWarpId[0]; + + // warpsM/N may be 0, in which case warpIDM/N is poison (division by 0), which + // will cause LLVM to eliminate all ops that depend on the poison value. This + // *can* be okay, if the bad dimension is filtered out by a slice layout. So + // we rely on the caller to check. Worst case we crash, which is better than + // silently producing bad code. + if (warpsM != 0) + multiDimBase[rank - 2] = add(udiv(laneId, i32_val(4)), offWarpM); + if (warpsN != 0) + multiDimBase[rank - 1] = + add(mul(i32_val(2), urem(laneId, i32_val(4))), offWarpN); + + return multiDimBase; +} + +inline SmallVector> +emitOffsetForMmaLayoutV3(const NvidiaMmaEncodingAttr &mmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + SmallVector> ret; + ArrayRef instrShape = mmaLayout.getInstrShape(); + + for (unsigned i = 0; i < shapePerCTA[0]; + i += getShapePerCTATile(mmaLayout)[0]) { + for (unsigned j = 0; j < shapePerCTA[1]; + j += getShapePerCTATile(mmaLayout)[1]) { + for (unsigned k = 0; k < instrShape[1]; k += 8) { + ret.push_back({i, j + k}); + ret.push_back({i, j + k + 1}); + ret.push_back({i + 8, j + k}); + ret.push_back({i + 8, j + k + 1}); + } + } + } + return ret; +} + +inline SmallVector +emitBaseIndexForMfmaLayout(Location loc, RewriterBase &rewriter, + const AMDMfmaEncodingAttr &mfmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto rank = shape.size(); + assert(rank == 2 || rank == 3); + auto _warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + SmallVector warpsPerCTA; + for (unsigned i = 0; i < rank; ++i) + warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); + unsigned mDim = mfmaLayout.getMDim(); + unsigned nDim = mfmaLayout.getNDim(); + assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout)); + Value effectiveWarpSize = warpSize; + if (mDim == 4 && nDim == 4) { + const int uniqueValuesPerWarp = 4; + effectiveWarpSize = i32_val(uniqueValuesPerWarp); + } + Value laneId = urem(threadId, effectiveWarpSize); + Value warpId = udiv(threadId, warpSize); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, _warpsPerCTA, + triton::gpu::getWarpOrder(mfmaLayout)); + if (shape[rank - 2] >= mDim) { + assert(shape[rank - 2] % mDim == 0); + multiDimWarpId[rank - 2] = + urem(multiDimWarpId[rank - 2], + i32_val(ceil(shape[rank - 2], mDim))); + } + if (shape[rank - 1] >= nDim) { + assert(shape[rank - 1] % nDim == 0); + multiDimWarpId[rank - 1] = + urem(multiDimWarpId[rank - 1], + i32_val(ceil(shape[rank - 1], nDim))); + } + Value offWarp0 = mul(multiDimWarpId[rank - 2], i32_val(mDim)); + Value offWarp1 = mul(multiDimWarpId[rank - 1], i32_val(nDim)); + + SmallVector multiDimBase(rank); + if (mfmaLayout.getIsTransposed()) { + multiDimBase[rank - 1] = + add(mul(i32_val(4), udiv(laneId, i32_val(mDim))), offWarp1); + multiDimBase[rank - 2] = add(urem(laneId, i32_val(mDim)), offWarp0); + } else { + multiDimBase[rank - 2] = + add(mul(i32_val(4), udiv(laneId, i32_val(nDim))), offWarp0); + multiDimBase[rank - 1] = add(urem(laneId, i32_val(nDim)), offWarp1); + } + // TODO(Lixun): It is assumed when rank = 3, warpsPerCTA is set to + // {numWarps, 1, 1}. We need to generalize the offset computation. + if (rank == 3) { + assert(_warpsPerCTA[1] == 1 && _warpsPerCTA[2] == 1); + multiDimBase[0] = urem(warpId, i32_val(shape[0])); + } + return multiDimBase; +} + +inline void emitMfmaOffsetForCTA(const AMDMfmaEncodingAttr &mfmaLayout, + SmallVector> &offsets, + unsigned bOff, unsigned ctaOffsetX, + unsigned ctaOffsetY) { + auto mDim = mfmaLayout.getMDim(); + auto nDim = mfmaLayout.getNDim(); + assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + // MFMA output tile consists of repeated "dot operand B" layout groups along + // row axis. This variable defines number of these groups. + DenseMap groups{{4, 1}, {16, 1}, {32, 4}}; + unsigned numGroups = groups.at(std::min(mDim, nDim)); + const unsigned elemsPerThreadPerGroup = 4; + auto warpSize = getWarpSize(mfmaLayout); + assert(warpSize == 64); + auto shapePerCta = getShapePerCTATile(mfmaLayout); + auto rank = shapePerCta.size(); + SmallVector elemOff(rank, 0); + for (unsigned block = 0; block < numGroups; block++) { + unsigned rowOrColOffset = + block * elemsPerThreadPerGroup * warpSize / std::min(mDim, nDim); + for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) { + if (mfmaLayout.getIsTransposed()) { + elemOff[rank - 2] = ctaOffsetX * shapePerCta[rank - 2]; + elemOff[rank - 1] = + ctaOffsetY * shapePerCta[rank - 1] + elem + rowOrColOffset; + } else { + elemOff[rank - 2] = + ctaOffsetX * shapePerCta[rank - 2] + elem + rowOrColOffset; + elemOff[rank - 1] = ctaOffsetY * shapePerCta[rank - 1]; + } + if (rank == 3) + elemOff[0] = bOff; + offsets.push_back(elemOff); + } + } +} + +inline SmallVector> +emitOffsetForMfmaLayout(const AMDMfmaEncodingAttr &mfmaLayout, + RankedTensorType type) { + auto tensorShape = type.getShape(); + SmallVector> offsets; + auto shapePerCTA = getShapePerCTA(mfmaLayout, tensorShape); + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + auto rank = type.getRank(); + SmallVector numReps(rank); + unsigned mDim = mfmaLayout.getMDim(); + unsigned nDim = mfmaLayout.getNDim(); + assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + SmallVector shapePerWarp(rank, 1); + shapePerWarp[rank - 2] = mDim; + shapePerWarp[rank - 1] = nDim; + for (unsigned d = 0; d < rank; ++d) { + unsigned inPerCTA = std::min(tensorShape[d], shapePerCTA[d]); + unsigned inPerWarp = ceil(inPerCTA, warpsPerCTA[d]); + numReps[d] = ceil(inPerWarp, shapePerWarp[d]); + } + + unsigned repBatch = rank == 3 ? numReps[0] : 1; + auto warpsPerBatch = + rank == 3 ? std::min(tensorShape[0], warpsPerCTA[0]) : 1; + + for (unsigned b = 0; b < repBatch; ++b) { + for (unsigned i = 0; i < numReps[rank - 2]; ++i) { + for (unsigned j = 0; j < numReps[rank - 1]; ++j) { + emitMfmaOffsetForCTA(mfmaLayout, offsets, b * warpsPerBatch, i, j); + } + } + } + return offsets; +} + +inline void emitWmmaOffsetForCTA(const AMDWmmaEncodingAttr &wmmaLayout, + SmallVector> &offsets, + unsigned ctaBatchOffset, unsigned ctaOffsetX, + unsigned ctaOffsetY) { + const unsigned elemsPerThreadPerGroup = 8; + auto warpSize = getWarpSize(wmmaLayout); + assert(warpSize == 32); + auto shapePerCta = getShapePerCTATile(wmmaLayout); + auto rank = shapePerCta.size(); + assert(rank == 2 || rank == 3); + SmallVector elemOffset(rank, 0); + if (rank == 3) + elemOffset[0] = ctaBatchOffset; + for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) { + elemOffset[rank - 2] = ctaOffsetX * shapePerCta[rank - 2] + 2 * elem; + elemOffset[rank - 1] = ctaOffsetY * shapePerCta[rank - 1]; + offsets.push_back(elemOffset); + } +} + +inline SmallVector +emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter, + const AMDWmmaEncodingAttr &wmmaLayout, + RankedTensorType type) { + auto shape = type.getShape(); + auto _warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + auto rank = _warpsPerCTA.size(); + assert(rank == 2 || rank == 3); + SmallVector warpsPerCTA; + for (unsigned i = 0; i < rank; ++i) + warpsPerCTA.push_back(i32_val(_warpsPerCTA[i])); + auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(triton::gpu::getWarpSize(wmmaLayout)); + Value laneId = + urem(threadId, i32_val(triton::gpu::getWarpSize(wmmaLayout) / 2)); + Value threadIdPerWarp = urem(threadId, warpSize); + + Value warpId = udiv(threadId, warpSize); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, _warpsPerCTA, + triton::gpu::getWarpOrder(wmmaLayout)); + if (shape[rank - 2] >= mnkDim[0]) { + assert(shape[rank - 2] % mnkDim[0] == 0); + multiDimWarpId[rank - 2] = + urem(multiDimWarpId[rank - 2], + i32_val(ceil(shape[rank - 2], mnkDim[0]))); + } + if (shape[rank - 1] >= mnkDim[1]) { + assert(shape[rank - 1] % mnkDim[1] == 0); + multiDimWarpId[rank - 1] = + urem(multiDimWarpId[rank - 1], + i32_val(ceil(shape[rank - 1], mnkDim[1]))); + } + Value offWarp0 = mul(multiDimWarpId[rank - 2], i32_val(mnkDim[0])); + Value offWarp1 = mul(multiDimWarpId[rank - 1], i32_val(mnkDim[1])); + + SmallVector multiDimBase(rank); + + multiDimBase[rank - 2] = + add(udiv(threadIdPerWarp, i32_val(mnkDim[2])), offWarp0); + multiDimBase[rank - 1] = add(laneId, offWarp1); + + // TODO: It is assumed when rank = 3, warpsPerCTA is set to + // {numWarps, 1, 1}. We need to generalize the offset computation. + if (rank == 3) { + assert(_warpsPerCTA[1] == 1 && _warpsPerCTA[2] == 1); + multiDimBase[0] = urem(warpId, i32_val(shape[0])); + } + return multiDimBase; +} + +inline SmallVector> +emitOffsetForWmmaLayout(const AMDWmmaEncodingAttr &wmmaLayout, + RankedTensorType type) { + auto tensorShape = type.getShape(); + SmallVector> offsets; + auto shapePerCTA = getShapePerCTA(wmmaLayout, tensorShape); + auto warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + + auto rank = tensorShape.size(); + assert(rank == 2 || rank == 3); + + SmallVector numWarpsPerDim(rank, 1); + auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr(); + SmallVector shapePerWarp(rank, 1); + shapePerWarp[rank - 2] = mnkDim[0]; + shapePerWarp[rank - 1] = mnkDim[1]; + for (unsigned d = 0; d < rank; ++d) { + unsigned inPerCTA = std::min(tensorShape[d], shapePerCTA[d]); + unsigned inPerWarp = ceil(inPerCTA, warpsPerCTA[d]); + numWarpsPerDim[d] = ceil(inPerWarp, shapePerWarp[d]); + } + + unsigned repBatch = rank == 3 ? numWarpsPerDim[0] : 1; + unsigned repM = numWarpsPerDim[rank - 2]; + unsigned repN = numWarpsPerDim[rank - 1]; + auto warpsPerBatch = + rank == 3 ? std::min(tensorShape[0], warpsPerCTA[0]) : 1; + + for (unsigned b = 0; b < repBatch; ++b) { + for (unsigned i = 0; i < repM; ++i) { + for (unsigned j = 0; j < repN; ++j) { + emitWmmaOffsetForCTA(wmmaLayout, offsets, b * warpsPerBatch, i, j); + } + } + } + return offsets; +} + +inline SmallVector> +emitOffsetForLayout(Attribute layout, RankedTensorType type); + +inline SmallVector> +emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout, + RankedTensorType type) { + auto parentEncoding = sliceLayout.getParent(); + unsigned dim = sliceLayout.getDim(); + auto parentShape = sliceLayout.paddedShape(type.getShape()); + RankedTensorType parentTy = + RankedTensorType::get(parentShape, type.getElementType(), parentEncoding); + auto parentOffsets = emitOffsetForLayout(parentEncoding, parentTy); + if (parentOffsets.empty()) + return {}; + + SmallVector> resultOffsets; + std::set> uniqueOffsets; + + for (unsigned i = 0; i < parentOffsets.size(); ++i) { + SmallVector offsets(parentOffsets[i].begin(), + parentOffsets[i].end()); + offsets.erase(offsets.begin() + dim); + if (auto [it, inserted] = uniqueOffsets.insert(offsets); inserted) { + resultOffsets.push_back(offsets); + } + } + + // It can happen that after deduplicating elements above, resultOffsets has + // fewer than getTotalElementsPerThread() elements. In that case repeat the + // sequence. + int elemsPerThread = triton::gpu::getTotalElemsPerThread(type); + assert(resultOffsets.size() > 0); + assert(elemsPerThread % resultOffsets.size() == 0); + int numRepeats = elemsPerThread / resultOffsets.size(); + SmallVector> ret; + for (int i = 0; i < numRepeats; ++i) { + for (unsigned j = 0; j < resultOffsets.size(); ++j) { + ret.push_back(SmallVector(resultOffsets[j])); + } + } + return ret; +} + +// ----------------------------------------------------------------------- +// Get offsets / indices for any layout +// ----------------------------------------------------------------------- + +inline SmallVector emitCTAOffsetForLayout(Location loc, + RewriterBase &rewriter, + const TargetInfoBase &target, + Attribute layout, + ArrayRef shape) { + unsigned rank = shape.size(); + SmallVector CTAsPerCGA = triton::gpu::getCTAsPerCGA(layout); + SmallVector CTASplitNum = triton::gpu::getCTASplitNum(layout); + SmallVector CTAOrder = triton::gpu::getCTAOrder(layout); + SmallVector shapePerCTA = + triton::gpu::getShapePerCTA(CTASplitNum, shape); + + // Delinearize clusterCTAId + Value clusterCTAId = target.getClusterCTAId(rewriter, loc); + SmallVector multiDimClusterCTAId = + delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); + + // CTA Wrapping + for (unsigned i = 0; i < rank; ++i) { + // This wrapping rule must be consistent with getShapePerCTA + unsigned splitNum = std::min(shape[i], CTASplitNum[i]); + multiDimClusterCTAId[i] = urem(multiDimClusterCTAId[i], i32_val(splitNum)); + } + + SmallVector CTAOffset(rank); + for (unsigned i = 0; i < rank; ++i) + CTAOffset[i] = mul(multiDimClusterCTAId[i], i32_val(shapePerCTA[i])); + + return CTAOffset; +} + +inline SmallVector +emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, bool withCTAOffset) { + auto shape = type.getShape(); + + SmallVector baseIndex; + RewriterBase::InsertionGuard guard(rewriter); + SmallVector result; + if (auto blockedLayout = mlir::dyn_cast(layout)) { + result = emitBaseIndexWithinCTAForBlockedLayout(loc, rewriter, + blockedLayout, type); + } else if (auto mmaLayout = mlir::dyn_cast(layout)) { + if (mmaLayout.isVolta()) + result = + emitBaseIndexWithinCTAForMmaLayoutV1(loc, rewriter, mmaLayout, type); + if (mmaLayout.isAmpere() || mmaLayout.isHopper()) + result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter, mmaLayout, + type); + } else if (auto mfmaLayout = mlir::dyn_cast(layout)) { + result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type); + } else if (auto wmmaLayout = mlir::dyn_cast(layout)) { + result = emitBaseIndexForWmmaLayout(loc, rewriter, wmmaLayout, type); + } else if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(type.getShape()); + RankedTensorType parentTy = + RankedTensorType::get(parentShape, type.getElementType(), parentLayout); + result = emitBaseIndexForLayoutImpl(loc, rewriter, target, parentLayout, + parentTy, withCTAOffset); + result.erase(result.begin() + sliceLayout.getDim()); + // CTAOffset has been added in emitBaseIndexForLayout of parentLayout + return result; + } else { + llvm_unreachable("unsupported emitBaseIndexForLayout"); + } + if (withCTAOffset) { + auto CTAOffset = + emitCTAOffsetForLayout(loc, rewriter, target, layout, shape); + assert(CTAOffset.size() == result.size() && "Rank mismatch"); + for (unsigned k = 0; k < result.size(); ++k) { + // Individual elements of `result` may be null. In the caller + // (emitBaseIndexForLayout), we assert that all such dimensions are sliced + // off. + if (!result[k]) + continue; + result[k] = add(result[k], CTAOffset[k]); + } + } + return result; +} + +inline SmallVector +emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, bool withCTAOffset) { + SmallVector idx = emitBaseIndexForLayoutImpl( + loc, rewriter, target, layout, type, withCTAOffset); + + // Check that any null values were sliced out. + for (Value v : idx) { + if (!v) { + llvm::errs() << "Failed to generate indexing code, possibly due to bad " + "#mma layout. Please rerun your program with " + "MLIR_ENABLE_DUMP=1 and file a bug." + << "\nloc: " << loc << "\nlayout: " << layout + << "\ntype: " << type << "\nwithCTAOffset: " << withCTAOffset + << "\n"; + llvm::report_fatal_error("Failed to generate indexing code"); + } + } + + return idx; +} + +inline SmallVector> +emitOffsetForLayout(Attribute layout, RankedTensorType type) { + if (auto clusterLayout = dyn_cast(layout)) + return emitOffsetForClusterLayout(clusterLayout, type); + if (auto blockedLayout = dyn_cast(layout)) + return emitOffsetForBlockedLayout(blockedLayout, type); + if (auto mmaLayout = dyn_cast(layout)) { + if (mmaLayout.isVolta()) + return emitOffsetForMmaLayoutV1(mmaLayout, type); + if (mmaLayout.isAmpere()) + return emitOffsetForMmaLayoutV2(mmaLayout, type); + if (mmaLayout.isHopper()) + return emitOffsetForMmaLayoutV3(mmaLayout, type); + } + if (auto mfmaLayout = mlir::dyn_cast(layout)) { + return emitOffsetForMfmaLayout(mfmaLayout, type); + } + if (auto wmmaLayout = mlir::dyn_cast(layout)) { + return emitOffsetForWmmaLayout(wmmaLayout, type); + } + if (auto sliceLayout = mlir::dyn_cast(layout)) + return emitOffsetForSliceLayout(sliceLayout, type); + llvm_unreachable("unsupported emitOffsetForLayout"); +} + +// Eventually this will become the only emitIndices function. +std::optional>> +emitIndicesUsingLinearLayouts(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, bool withCTAOffset); + +// Emit indices calculation within each ConversionPattern, and returns a +// [elemsPerThread X rank] index matrix. +inline SmallVector> +emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + Attribute layout, RankedTensorType type, bool withCTAOffset, + bool allowLL = true) { + // Eventually the LinearLayout path will be the only one. For now we allow + // both paths so we can test that they produce the same results. + if (allowLL && target.enableLinearLayout()) { + std::optional>> llOffsets = + emitIndicesUsingLinearLayouts(loc, rewriter, target, layout, type, + withCTAOffset); + if (llOffsets.has_value()) + return *llOffsets; + } + + // step 1, delinearize threadId to get the base index + auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, target, layout, + type, withCTAOffset); + // step 2, get offset of each element + auto offset = emitOffsetForLayout(layout, type); + // step 3, add offset to base, and reorder the sequence + // of indices to guarantee that elems in the same + // sizePerThread are adjacent in order + auto shape = type.getShape(); + unsigned rank = shape.size(); + unsigned elemsPerThread = offset.size(); + SmallVector> multiDimIdx(elemsPerThread, + SmallVector(rank)); + for (unsigned n = 0; n < elemsPerThread; ++n) + for (unsigned k = 0; k < rank; ++k) + multiDimIdx[n][k] = add(multiDimBase[k], i32_val(offset[n][k])); + + return multiDimIdx; +} + +/* ---------------- */ +/* ---------------- */ +inline DenseMap getSwizzledSharedPtrs( + Location loc, const TargetInfoBase &target, unsigned inVec, + RankedTensorType srcTy, triton::gpu::SharedEncodingAttr resSharedLayout, + Type resElemTy, SharedMemoryObject smemObj, RewriterBase &rewriter, + SmallVectorImpl &offsetVals, SmallVectorImpl &srcStrides) { + // This utility computes the pointers for accessing the provided swizzled + // shared memory layout `resSharedLayout`. More specifically, it computes, + // for all indices (row, col) of `srcEncoding` such that idx % inVec = 0, + // the pointer: ptr[(row, col)] = base + (rowOff * strides[ord[1]] + + // colOff) where : + // phase = (row // perPhase) % maxPhase + // rowOff = row + // colOff = colOffSwizzled + colOffOrdered + // colOffSwizzled = ((col // outVec) ^ phase) * outVec + // colOffOrdered = (col % outVec) // minVec * minVec + // + // Note 1: + // ------- + // Because swizzling happens at a granularity of outVec, we need to + // decompose the offset into a swizzled factor and a non-swizzled + // (ordered) factor + // + // Note 2: + // ------- + // If we have x, y, z of the form: + // x = 0b00000xxxx + // y = 0byyyyy0000 + // z = 0b00000zzzz + // then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y + // This means that we can use some immediate offsets for shared memory + // operations. + auto dstPtrTy = ptr_ty(rewriter.getContext(), 3); + auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides); + Value dstPtrBase = gep(dstPtrTy, resElemTy, smemObj.base, dstOffset); + + auto srcEncoding = srcTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto srcShapePerCTA = triton::gpu::getShapePerCTA(srcTy); + unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); + // swizzling params as described in TritonGPUAttrDefs.td + unsigned outVec = resSharedLayout.getVec(); + unsigned perPhase = resSharedLayout.getPerPhase(); + unsigned maxPhase = resSharedLayout.getMaxPhase(); + // Order + auto inOrder = triton::gpu::getOrder(srcEncoding); + auto outOrder = triton::gpu::getOrder(resSharedLayout); + assert(maxPhase == 1 || + outVec * maxPhase <= srcShape[outOrder[0]] && + "Swizzling would generate out of bounds memory accesses"); + // Tensor indices held by the current thread, as LLVM values + auto srcIndices = emitIndices(loc, rewriter, target, srcEncoding, srcTy, + /*withCTAOffset=*/false); + // Swizzling with leading offsets (e.g. Hopper GMMA) + unsigned swizzlingByteWidth = 0; + if (resSharedLayout.getHasLeadingOffset()) { + if (perPhase == 4 && maxPhase == 2) + swizzlingByteWidth = 32; + else if (perPhase == 2 && maxPhase == 4) + swizzlingByteWidth = 64; + else if (perPhase == 1 && maxPhase == 8) + swizzlingByteWidth = 128; + else + llvm::report_fatal_error("Unsupported shared layout."); + } + unsigned numElemsPerSwizzlingRow = + swizzlingByteWidth * 8 / resElemTy.getIntOrFloatBitWidth(); + Value numElemsPerSwizzlingRowVal = i32_val(numElemsPerSwizzlingRow); + unsigned leadingDimOffset; + if (outOrder.size() >= 2) { + leadingDimOffset = numElemsPerSwizzlingRow * srcShapePerCTA[outOrder[1]]; + } else { + leadingDimOffset = numElemsPerSwizzlingRow; + } + + Value leadingDimOffsetVal = i32_val(leadingDimOffset); + // Return values + DenseMap ret; + // cache for non-immediate offsets + DenseMap cacheCol, cacheRow; + unsigned minVec = std::min(outVec, inVec); + Value strideRow = outOrder.size() >= 2 ? srcStrides[outOrder[1]] : i32_val(0); + Value strideCol = srcStrides[outOrder[0]]; + LDBG("getSwizzledSharedPtrs: perPhase = " + << perPhase << " maxPhase = " << maxPhase << " minVec = " << minVec + << " inVec = " << inVec << " outVec = " << outVec << " strideRow " + << strideRow << " strideCol " << strideCol); + for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) { + Value offset = i32_val(0); + // Extract multi dimensional index for current element + auto idx = srcIndices[elemIdx]; + Value idxCol = idx[outOrder[0]]; // contiguous dimension + Value idxRow; + if (outOrder.size() >= 2) { + idxRow = idx[outOrder[1]]; // discontiguous dimension + } else { + idxRow = i32_val(0); + } + // compute phase = (row // perPhase) % maxPhase + Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase)); + // extract dynamic/static offset for immediate offsetting + unsigned immedateOffCol = 0; + unsigned immedateOffRow = 0; + if (leadingDimOffset) { + // hopper + offset = + mul(udiv(idxCol, numElemsPerSwizzlingRowVal), leadingDimOffsetVal); + // Shrink by swizzling blocks + idxCol = urem(idxCol, numElemsPerSwizzlingRowVal); + strideRow = numElemsPerSwizzlingRowVal; + } + if (auto add = dyn_cast_or_null(idxCol.getDefiningOp())) { + if (auto _cst = dyn_cast_or_null( + add.getRhs().getDefiningOp())) { + unsigned cst = + cast(_cst.getValue()).getValue().getSExtValue(); + unsigned key = cst % (outVec * maxPhase); + cacheCol.insert({key, idxCol}); + idxCol = cacheCol[key]; + immedateOffCol = cst / (outVec * maxPhase) * (outVec * maxPhase); + } + } + if (auto add = dyn_cast_or_null(idxRow.getDefiningOp())) { + if (auto _cst = dyn_cast_or_null( + add.getRhs().getDefiningOp())) { + unsigned cst = + mlir::cast(_cst.getValue()).getValue().getSExtValue(); + unsigned key = cst % (perPhase * maxPhase); + cacheRow.insert({key, idxRow}); + idxRow = cacheRow[key]; + immedateOffRow = cst / (perPhase * maxPhase) * (perPhase * maxPhase); + } + } + // row offset is simply row index + Value rowOff = mul(idxRow, strideRow); + // because swizzling happens at a granularity of outVec, we need to + // decompose the offset into a swizzled factor and a non-swizzled + // (ordered) factor: colOffSwizzled = ((col // outVec) ^ phase) * outVec + // colOffOrdered = (col % outVec) // minVec * minVec + Value colOffSwizzled = xor_(udiv(idxCol, i32_val(outVec)), phase); + colOffSwizzled = mul(colOffSwizzled, i32_val(outVec)); + Value colOffOrdered = urem(idxCol, i32_val(outVec)); + colOffOrdered = udiv(colOffOrdered, i32_val(minVec)); + colOffOrdered = mul(colOffOrdered, i32_val(minVec)); + Value colOff = add(colOffSwizzled, colOffOrdered); + // compute non-immediate offset + if (outOrder.size() == 3) + offset = add(offset, mul(idx[outOrder[2]], srcStrides[outOrder[2]])); + offset = add(offset, add(rowOff, mul(colOff, strideCol))); + Value currPtr = gep(dstPtrTy, resElemTy, dstPtrBase, offset); + // compute immediate offset + Value immediateOff; + if (outOrder.size() >= 2) { + immediateOff = + add(mul(i32_val(immedateOffRow), strideRow), i32_val(immedateOffCol)); + } else { + immediateOff = i32_val(immedateOffCol); + } + + ret[elemIdx] = gep(dstPtrTy, resElemTy, currPtr, immediateOff); + } + return ret; +} + +inline SmallVector loadSharedToDistributed( + Value dst, Value src, SharedMemoryObject smemObj, Type elemTy, Location loc, + ConversionPatternRewriter &rewriter, const TargetInfoBase &target) { + auto dstTy = cast(dst.getType()); + auto dstShape = dstTy.getShape(); + assert(dstShape.size() <= 2 && "Unexpected rank of loadSharedToDistributed"); + auto srcTy = cast(src.getType()); + auto dstDistributedLayout = dstTy.getEncoding(); + if (auto mmaLayout = dyn_cast(dstDistributedLayout)) { + assert((!mmaLayout.isVolta()) && + "ConvertLayout Shared->MMAv1 is not supported yet"); + } + auto srcSharedLayout = + cast(srcTy.getEncoding()); + auto srcElemTy = srcTy.getElementType(); + auto dstElemTy = dstTy.getElementType(); + LDBG("loadSharedToDistributed elemTy " << elemTy << " srcElemTy " << srcElemTy + << " dstElemTy " << dstElemTy); + auto inOrd = triton::gpu::getOrder(srcSharedLayout); + auto outOrd = triton::gpu::getOrder(dstDistributedLayout); + unsigned outVec = inOrd == outOrd + ? triton::gpu::getUniqueContigPerThread( + dstDistributedLayout, dstShape)[outOrd[0]] + : 1; + + // If the shmem layout is not swizzled, we can trivially vectorize loads + // across the whole width of the most-minor dimension of the shape, because + // Triton requires all the dims are powers of 2. + unsigned inVec = srcSharedLayout.getMaxPhase() == 1 + ? srcTy.getShape()[inOrd[0]] + : srcSharedLayout.getVec(); + unsigned minVec = std::min(outVec, inVec); + unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy); + SmallVector offsetVals = {smemObj.strides.size(), i32_val(0)}; + + DenseMap sharedPtrs = + getSwizzledSharedPtrs(loc, target, outVec, dstTy, srcSharedLayout, elemTy, + smemObj, rewriter, offsetVals, smemObj.strides); + assert(outElems % minVec == 0 && "Unexpected number of elements"); + unsigned numVecs = outElems / minVec; + auto wordTy = vec_ty(elemTy, minVec); + SmallVector outVals(outElems); + for (unsigned i = 0; i < numVecs; ++i) { + Value smemAddr = sharedPtrs[i * minVec]; + smemAddr = bitcast(smemAddr, ptr_ty(rewriter.getContext(), 3)); + auto valVec = load(wordTy, smemAddr); + valVec.setAlignment(minVec * elemTy.getIntOrFloatBitWidth() / 8); + for (unsigned v = 0; v < minVec; ++v) { + Value currVal = extract_element(elemTy, valVec, i32_val(v)); + outVals[i * minVec + v] = currVal; + } + } + return outVals; +} + +inline void storeDistributedToShared(Value src, ArrayRef inVals, + ArrayRef dstStrides, Value dst, + Value smemBase, Type elemTy, Location loc, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &target) { + auto srcTy = cast(src.getType()); + auto srcShape = srcTy.getShape(); + auto rank = srcShape.size(); + assert(rank <= 3 && "Unexpected rank of storeDistributedToShared"); + auto dstTy = cast(dst.getType()); + auto srcDistributedLayout = srcTy.getEncoding(); + if (auto mmaLayout = dyn_cast(srcDistributedLayout)) { + assert((!mmaLayout.isVolta()) && + "ConvertLayout MMAv1->Shared is not supported yet"); + } + auto dstSharedLayout = + cast(dstTy.getEncoding()); + auto dstElemTy = dstTy.getElementType(); + auto inOrd = triton::gpu::getOrder(srcDistributedLayout); + auto outOrd = dstSharedLayout.getOrder(); + unsigned inVec = inOrd == outOrd + ? triton::gpu::getUniqueContigPerThread( + srcDistributedLayout, srcShape)[inOrd[0]] + : 1; + // If the shmem layout is not swizzled, we can trivially vectorize stores + // across the whole width of the most-minor dimension of the shape, because + // Triton requires all the dims are powers of 2. + unsigned outVec = dstSharedLayout.getMaxPhase() == 1 + ? dstTy.getShape()[inOrd[0]] + : dstSharedLayout.getVec(); + unsigned minVec = std::min(outVec, inVec); + unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); + auto wordTy = vec_ty(elemTy, minVec); + Value word; + + SmallVector srcStrides(dstStrides); + SmallVector offsetVals(rank, i32_val(0)); + SharedMemoryObject smemObj(smemBase, elemTy, srcStrides, offsetVals); + + DenseMap sharedPtrs = + getSwizzledSharedPtrs(loc, target, inVec, srcTy, dstSharedLayout, elemTy, + smemObj, rewriter, offsetVals, srcStrides); + LDBG("storeDistributedToShared: numElems = " << numElems << " minVec = " + << minVec << " " << wordTy); + for (unsigned i = 0; i < numElems; ++i) { + if (i % minVec == 0) + word = undef(wordTy); + word = insert_element(wordTy, word, inVals[i], i32_val(i % minVec)); + if (i % minVec == minVec - 1) { + Value smemAddr = sharedPtrs[i / minVec * minVec]; + smemAddr = bitcast(smemAddr, ptr_ty(rewriter.getContext(), 3)); + store(word, smemAddr) + .setAlignment(minVec * elemTy.getIntOrFloatBitWidth() / 8); + } + } +} + +inline Value +getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, + ConversionPatternRewriter &rewriter) { + auto elems = smemObj.getElems(); + auto types = smemObj.getTypes(); + auto structTy = + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types); + // pack into struct + Value llvmStruct = rewriter.create(loc, structTy); + for (const auto &v : llvm::enumerate(elems)) { + assert(v.value() && "can not insert null values"); + llvmStruct = insert_val(structTy, llvmStruct, v.value(), v.index()); + } + return llvmStruct; +} + +inline SmallVector +unpackLLElements(Location loc, Value llvmStruct, + ConversionPatternRewriter &rewriter) { + assert(bool(llvmStruct) && "can not unpack null values"); + if (llvmStruct.getType().isIntOrIndexOrFloat() || + isa(llvmStruct.getType()) || + isa(llvmStruct.getType()) || + isa(llvmStruct.getType())) + return {llvmStruct}; + ArrayRef types = + cast(llvmStruct.getType()).getBody(); + SmallVector results(types.size()); + for (unsigned i = 0; i < types.size(); ++i) { + Type type = types[i]; + results[i] = extract_val(type, llvmStruct, i); + } + return results; +} + +inline Value packLLElements(Location loc, + const LLVMTypeConverter *typeConverter, + ValueRange resultVals, + ConversionPatternRewriter &rewriter, Type type) { + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (!structType) { + assert(resultVals.size() == 1); + return *resultVals.begin(); + } + + auto elementTypes = structType.getBody(); + if (elementTypes.size() != resultVals.size()) { + emitError(loc) << " size mismatch when packing elements for LLVM struct" + << " expected " << elementTypes.size() << " but got " + << resultVals.size(); + } + Value llvmStruct = rewriter.create(loc, structType); + for (const auto &v : llvm::enumerate(resultVals)) { + if (!v.value()) { + emitError(loc) + << "cannot insert null values into struct, but tried to insert" + << v.value(); + } + if (v.value().getType() != elementTypes[v.index()]) { + LDBG("type " << type << " structType " << structType); + LDBG("value " << v.value()); + emitError(loc) << "invalid element type in packLLEElements. Expected " + << elementTypes[v.index()] << " but got " + << v.value().getType(); + } + llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index()); + } + return llvmStruct; +} + +inline bool isLayoutMmaV1(Attribute layout) { + bool isMmaV1 = false; + if (auto mmaLayout = dyn_cast(layout)) { + isMmaV1 = mmaLayout.isVolta(); + } + if (auto sliceLayout = dyn_cast(layout)) { + isMmaV1 = isa(sliceLayout.getParent()) && + cast(sliceLayout.getParent()).isVolta(); + } + return isMmaV1; +} + +} // namespace mlir + +#endif diff --git a/third_party/xpu/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt b/third_party/xpu/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt new file mode 100644 index 000000000..99d90c4d7 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonGPU) +add_public_tablegen_target(TritonConversionPassIncGen) diff --git a/third_party/xpu/include/triton/Conversion/TritonToTritonGPU/Passes.h b/third_party/xpu/include/triton/Conversion/TritonToTritonGPU/Passes.h new file mode 100644 index 000000000..e159406b3 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonToTritonGPU/Passes.h @@ -0,0 +1,15 @@ +#ifndef TRITON_CONVERSION_PASSES_H +#define TRITON_CONVERSION_PASSES_H + +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/xpu/include/triton/Conversion/TritonToTritonGPU/Passes.td b/third_party/xpu/include/triton/Conversion/TritonToTritonGPU/Passes.td new file mode 100644 index 000000000..84150fe67 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonToTritonGPU/Passes.td @@ -0,0 +1,37 @@ +#ifndef TRITON_CONVERSION_PASSES +#define TRITON_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> { + let summary = "Convert Triton to TritonGPU"; + let description = [{ + + }]; + let constructor = "mlir::triton::createConvertTritonToTritonGPUPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + // TODO: Does this pass depend on SCF? + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect"]; + + let options = [ + Option<"numWarps", "num-warps", + "int32_t", /*default*/"4", + "number of warps">, + + Option<"threadsPerWarp", "threads-per-warp", + "int32_t", /*default*/"32", + "number of threads per warp">, + Option<"numCTAs", "num-ctas", + "int32_t", /*default*/"1", + "number of ctas in a cga">, + Option<"target", "target", + "std::string", /*default*/"\"\"", + "the GPU target, e.g., cuda:80, hip:gfx942"> + ]; +} + +#endif diff --git a/third_party/xpu/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h b/third_party/xpu/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h new file mode 100644 index 000000000..d3da1394e --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h @@ -0,0 +1,31 @@ +#ifndef TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H +#define TRITON_CONVERSION_TRITONTOTRITONGPU_TRITONTOTRITONGPUPASS_H + +#include +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps"; +constexpr static char AttrNumCTAsName[] = "triton_gpu.num-ctas"; +constexpr static char AttrTargetName[] = "triton_gpu.target"; + +constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp"; + +// Create the pass with numWarps passed from cl::opt. +std::unique_ptr> createConvertTritonToTritonGPUPass(); + +// Create the pass with numWarps set explicitly. +std::unique_ptr> +createConvertTritonToTritonGPUPass(const std::string &target, int numWarps, + int threadsPerWarp = 32, int numCTAs = 1); + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/xpu/include/triton/Conversion/TritonToTritonXPU/CMakeLists.txt b/third_party/xpu/include/triton/Conversion/TritonToTritonXPU/CMakeLists.txt new file mode 100644 index 000000000..fb51818c8 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonToTritonXPU/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonXPU) +add_public_tablegen_target(TT2TTXConversionPassIncGen) diff --git a/third_party/xpu/include/triton/Conversion/TritonToTritonXPU/Passes.h b/third_party/xpu/include/triton/Conversion/TritonToTritonXPU/Passes.h new file mode 100644 index 000000000..10ba1dae1 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonToTritonXPU/Passes.h @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TT2TTX_CONVERSION_PASSES_H +#define TT2TTX_CONVERSION_PASSES_H + +#include "mlir/Pass/Pass.h" +#include "triton/Conversion/TritonToTritonXPU/TritonToTritonXPUPass.h" + +namespace mlir { + +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/TritonToTritonXPU/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // TT2TTX_CONVERSION_PASSES_H diff --git a/third_party/xpu/include/triton/Conversion/TritonToTritonXPU/Passes.td b/third_party/xpu/include/triton/Conversion/TritonToTritonXPU/Passes.td new file mode 100644 index 000000000..83f213c46 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonToTritonXPU/Passes.td @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TT2TTX_CONVERSION_PASSES +#define TT2TTX_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTritonToTritonXPU: Pass<"convert-triton-to-triton-xpu", "mlir::ModuleOp"> { + let summary = "Convert Triton to TritonXPU"; + let description = [{ + + }]; + let constructor = "mlir::triton::createConvertTritonToTritonXPUPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + // TODO: Does this pass depend on SCF? + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + // "mlir::triton::gpu::TritonGPUDialect", // Does this pass depend on TritonGPU in Triton 3.0? + "mlir::triton::xpu::TritonXPUDialect"]; + + let options = [ + Option<"xpu_arch", "xpu_arch", + "uint32_t", /*default*/"3", + "XPU Architecture">, + Option<"buffer_size", "buffer_size", + "uint32_t", /*default*/"512", + "bytes for local memory buffer">, + Option<"core_num", "core_num", + "uint32_t", /*default*/"64", + "xpu spec core_num"> + ]; +} + + +#endif // TT2TTX_CONVERSION_PASSES diff --git a/third_party/xpu/include/triton/Conversion/TritonToTritonXPU/TritonToTritonXPUPass.h b/third_party/xpu/include/triton/Conversion/TritonToTritonXPU/TritonToTritonXPUPass.h new file mode 100644 index 000000000..fb7d4a933 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonToTritonXPU/TritonToTritonXPUPass.h @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TRITON_CONVERSION_TT2TTX_TRITONTOTRITONXPUPASS_H +#define TRITON_CONVERSION_TT2TTX_TRITONTOTRITONXPUPASS_H + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +constexpr static char AttrXPUTargetName[] = "triton_xpu.target"; + +// Create the pass with buffer_size passed from cl::opt. +std::unique_ptr> createConvertTritonToTritonXPUPass(); + +// Create the pass with buffer_size set explicitly. +std::unique_ptr> +createConvertTritonToTritonXPUPass(uint32_t xpu_arch, uint32_t buffer_size, + uint32_t core_num); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TT2TTX_TRITONTOTRITONXPUPASS_H diff --git a/third_party/xpu/include/triton/Conversion/TritonXPUToLLVM/CMakeLists.txt b/third_party/xpu/include/triton/Conversion/TritonXPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..abbabba58 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonXPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonXPUToLLVM) +add_public_tablegen_target(TTX2LLVMConversionPassIncGen) diff --git a/third_party/xpu/include/triton/Conversion/TritonXPUToLLVM/Passes.h b/third_party/xpu/include/triton/Conversion/TritonXPUToLLVM/Passes.h new file mode 100644 index 000000000..2e21b4d0b --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonXPUToLLVM/Passes.h @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TTX2LLVM_CONVERSION_TRITONXPUTOLLVM_PASSES_H +#define TTX2LLVM_CONVERSION_TRITONXPUTOLLVM_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +#define GEN_PASS_DECL +#include "triton/Conversion/TritonXPUToLLVM/Passes.h.inc" + +namespace xpu { + +// TODO[dyq]: can be used ? +// std::unique_ptr> +// createDecomposeUnsupportedConversionsPass(uint32_t xpu_arch); + +} // namespace xpu + +std::unique_ptr> createConvertTritonXPUToLLVMPass(); +std::unique_ptr> +createConvertTritonXPUToLLVMPass(uint32_t xpu_arch, uint32_t buffer_size); + +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/TritonXPUToLLVM/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // TTX2LLVM_CONVERSION_TRITONXPUTOLLVM_PASSES_H diff --git a/third_party/xpu/include/triton/Conversion/TritonXPUToLLVM/Passes.td b/third_party/xpu/include/triton/Conversion/TritonXPUToLLVM/Passes.td new file mode 100644 index 000000000..6bb6cc845 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonXPUToLLVM/Passes.td @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TTX2LLVM_CONVERSION_PASSES +#define TTX2LLVM_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTritonXPUToLLVM: Pass<"convert-triton-xpu-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert TritonXPU to LLVM"; + let description = [{ + + }]; + let constructor = "mlir::triton::createConvertTritonXPUToLLVMPass()"; + + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; + + let options = [ + Option<"xpu_arch", "xpu_arch", + "uint32_t", /*default*/"3", + "XPU Architecture">, + Option<"buffer_size", "buffer_size", + "uint32_t", /*default*/"128", + "Buffer Size of LM">, + ]; +} + + +#endif // TTX2LLVM_CONVERSION_PASSES diff --git a/third_party/xpu/include/triton/Conversion/TritonXPUToLLVM/TypeConverter.h b/third_party/xpu/include/triton/Conversion/TritonXPUToLLVM/TypeConverter.h new file mode 100644 index 000000000..c6e17d0e0 --- /dev/null +++ b/third_party/xpu/include/triton/Conversion/TritonXPUToLLVM/TypeConverter.h @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TRITON_CONVERSION_TRITONXPU_TO_LLVM_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITONXPU_TO_LLVM_TYPECONVERTER_H + +// clang-format off +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +// clang-format on + +using namespace mlir; +using namespace mlir::triton; + +class TritonXPUToLLVMTypeConverter : public LLVMTypeConverter { +public: + using TypeConverter::convertType; + + TritonXPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis = nullptr); + + Type getElementTypeForStruct(TensorOrMemDesc type); + Type convertTritonPointerType(triton::PointerType type); + Type convertTritonTensorType(RankedTensorType type); +}; + +#endif // TRITON_CONVERSION_TRITONXPU_TO_LLVM_TYPECONVERTER_H diff --git a/third_party/xpu/include/triton/Dialect/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/CMakeLists.txt new file mode 100644 index 000000000..b8ffa31f9 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(Triton) +add_subdirectory(TritonGPU) +add_subdirectory(TritonNvidiaGPU) +add_subdirectory(TritonXPU) +add_subdirectory(LLVMXPU) diff --git a/third_party/xpu/include/triton/Dialect/LLVMXPU/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/LLVMXPU/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/LLVMXPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/xpu/include/triton/Dialect/LLVMXPU/IR/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/LLVMXPU/IR/CMakeLists.txt new file mode 100644 index 000000000..48c65efb2 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/LLVMXPU/IR/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_dialect(LLVMXPUOps llvm_xpu) +# TODO[dyq]: add_mlir_doc permission +# add_mlir_doc(LLVMXPUOps LLVMXPUDialect Dialects/ -gen-dialect-doc -dialect=llvm_xpu) +set(LLVM_TARGET_DEFINITIONS LLVMXPUOps.td) +mlir_tablegen(LLVMXPUConversions.inc -gen-llvmir-conversions) +mlir_tablegen(LLVMXPUFromLLVMIRConversions.inc -gen-intr-from-llvmir-conversions) +mlir_tablegen(LLVMXPUConvertibleLLVMIRIntrinsics.inc -gen-convertible-llvmir-intrinsics) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=llvm_xpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=llvm_xpu) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRXPUConversionsIncGen) diff --git a/third_party/xpu/include/triton/Dialect/LLVMXPU/IR/Dialect.h b/third_party/xpu/include/triton/Dialect/LLVMXPU/IR/Dialect.h new file mode 100644 index 000000000..c76b51ad9 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/LLVMXPU/IR/Dialect.h @@ -0,0 +1,23 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_LLVMXPU_IR_DIALECT_H_ +#define MLIR_DIALECT_LLVMXPU_IR_DIALECT_H_ + +// LLVMXPUDialect +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Dialect/LLVMXPU/IR/Dialect.h.inc" + +// LLVMXPUOps +#define GET_OP_CLASSES +#include "triton/Dialect/LLVMXPU/IR/Ops.h.inc" + +namespace mlir { +namespace LLVM { +namespace XPU {} // namespace XPU +} // namespace LLVM +} // namespace mlir + +#endif // MLIR_DIALECT_LLVMXPU_IR_DIALECT_H_ diff --git a/third_party/xpu/include/triton/Dialect/LLVMXPU/IR/LLVMXPUOps.td b/third_party/xpu/include/triton/Dialect/LLVMXPU/IR/LLVMXPUOps.td new file mode 100644 index 000000000..564b5cca0 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/LLVMXPU/IR/LLVMXPUOps.td @@ -0,0 +1,408 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef LLVMXPUIR_OPS +#define LLVMXPUIR_OPS + +include "mlir/IR/OpBase.td" // Trait +include "mlir/IR/DialectBase.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" // LLVM_OpBase +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType + + +//===----------------------------------------------------------------------===// +// LLVMXPU dialect definitions +//===----------------------------------------------------------------------===// + +def LLVMXPU_Dialect : Dialect { + let name = "llvm_xpu"; + + let cppNamespace = "::mlir::LLVM::XPU"; + + let hasOperationAttrVerify = 1; + + let dependentDialects = ["LLVM::LLVMDialect"]; + + let extraClassDeclaration = [{ + /// Get the name of the attribute used to annotate external kernel + /// functions. + static StringRef getKernelFuncAttrName() { return "xpu.kernel"; } + }]; + + // let useDefaultAttributePrinterParser = 1; +} + + + +//===----------------------------------------------------------------------===// +// LLVMXPU op definitions +//===----------------------------------------------------------------------===// + +class XPU_Op traits = []> : + LLVM_OpBase { +} + + +//===----------------------------------------------------------------------===// +// LLVMXPU intrinsic operations +//===----------------------------------------------------------------------===// + +class XPU_IntrOp traits = [], + int numResults> + : LLVM_IntrOpBase overloadedResults=*/[], + /*list overloadedOperands=*/[], + traits, numResults>; + + +def XPU_GM2LMOp : XPU_IntrOp<"gm2lm", [], 0> { + let arguments = (ins LLVM_AnyPointer:$src, // TODO[dyq]: Check i8? + LLVM_AnyPointer:$dst, + I32:$offset, + I32:$size); + string llvmBuilder = [{ + // XPU2, offset only has 7 bits, so we let LLVM fold the value + auto addr = builder.CreatePtrToInt($src, $offset->getType()); + auto faddr = builder.CreateAdd(addr, $offset); + auto srcptr = builder.CreateIntToPtr(faddr, $src->getType()); + auto zero = builder.getInt32(0); + createIntrinsicCall(builder, llvm::Intrinsic::xpu_gm2lm, {$dst, srcptr, zero, $size}); + }]; +} + +def XPU_LM2GMOp : XPU_IntrOp<"lm2gm", [], 0> { + let arguments = (ins LLVM_AnyPointer:$src, + LLVM_AnyPointer:$dst, + I32:$offset, + I32:$size); + string llvmBuilder = [{ + // XPU2, offset only has 7 bits, so we let LLVM fold the value + auto addr = builder.CreatePtrToInt($dst, $offset->getType()); + auto faddr = builder.CreateAdd(addr, $offset); + auto dstptr = builder.CreateIntToPtr(faddr, $dst->getType()); + auto zero = builder.getInt32(0); + createIntrinsicCall(builder, llvm::Intrinsic::xpu_lm2gm, {dstptr, zero, $src, $size}); + }]; +} + +def XPU_SM2GMOp : XPU_IntrOp<"sm2gm", [], 0> { + let arguments = (ins LLVM_AnyPointer:$src, + LLVM_AnyPointer:$dst, + I32:$offset, + I32:$size); + string llvmBuilder = [{ + // XPU2, offset only has 7 bits, so we let LLVM fold the value + auto addr = builder.CreatePtrToInt($dst, $offset->getType()); + auto faddr = builder.CreateAdd(addr, $offset); + auto dstptr = builder.CreateIntToPtr(faddr, $dst->getType()); + auto zero = builder.getInt32(0); + createIntrinsicCall(builder, llvm::Intrinsic::xpu_sm2gm, {dstptr, zero, $src, $size}); + }]; +} + + +def XPU_GM2LMOp_v3 : XPU_IntrOp<"gm2lm_v3", [], 0> { + let arguments = (ins LLVM_AnyPointer:$src, + LLVM_AnyPointer:$dst, + I32:$offset, + I32:$size); + string llvmBuilder = [{ + auto isrc = builder.CreatePtrToInt($src, builder.getInt64Ty()); + auto ioft = builder.CreateZExtOrTrunc($offset, builder.getInt64Ty()); + auto asrc = builder.CreateAdd(isrc, ioft); + auto fsrc = builder.CreateIntToPtr(asrc, $src->getType()); + auto zero = builder.getInt32(0); + createIntrinsicCall(builder, llvm::Intrinsic::xpu_gm2lm_v3, {$dst, fsrc, zero, $size}); + }]; +} + +def XPU_LM2GMOp_v3 : XPU_IntrOp<"lm2gm_v3", [], 0> { + let arguments = (ins LLVM_AnyPointer:$src, + LLVM_AnyPointer:$dst, + I32:$offset, + I32:$size); + string llvmBuilder = [{ + auto idst = builder.CreatePtrToInt($dst, builder.getInt64Ty()); + auto ioft = builder.CreateZExtOrTrunc($offset, builder.getInt64Ty()); + auto adst = builder.CreateAdd(idst, ioft); + auto fdst = builder.CreateIntToPtr(adst, $dst->getType()); + auto zero = builder.getInt32(0); + createIntrinsicCall(builder, llvm::Intrinsic::xpu_lm2gm_v3, {fdst, $src, zero, $size}); + }]; +} + +def XPU_SM2GMOp_v3 : XPU_IntrOp<"sm2gm_v3", [], 0> { + let arguments = (ins LLVM_AnyPointer:$src, + LLVM_AnyPointer:$dst, + I32:$offset, + I32:$size); + string llvmBuilder = [{ + auto idst = builder.CreatePtrToInt($dst, builder.getInt64Ty()); + auto ioft = builder.CreateZExtOrTrunc($offset, builder.getInt64Ty()); + auto adst = builder.CreateAdd(idst, ioft); + auto fdst = builder.CreateIntToPtr(adst, $dst->getType()); + auto zero = builder.getInt32(0); + createIntrinsicCall(builder, llvm::Intrinsic::xpu_sm2gm_v3, {fdst, $src, zero, $size}); + }]; +} + +//===----------------------------------------------------------------------===// +// XPU special register op definitions +//===----------------------------------------------------------------------===// + +class XPU_SpecialRegisterOp traits = []> : + XPU_IntrOp { + let arguments = (ins); + let assemblyFormat = "attr-dict `:` type($res)"; +} + +//===----------------------------------------------------------------------===// +// Physiscal cluster index and range (0-7) +//===----------------------------------------------------------------------===// + +def XPU_ClusterIdOp : XPU_SpecialRegisterOp<"cluster_id">; + +//===----------------------------------------------------------------------===// +// Core index and range +//===----------------------------------------------------------------------===// + +def XPU_CoreIdOp : XPU_SpecialRegisterOp<"core_id">; + +//===----------------------------------------------------------------------===// +// XPU load parameters op definitions +//===----------------------------------------------------------------------===// + +def XPU_LoadParamOp : XPU_Op<"load_param"> { + let arguments = (ins I32:$num); + let results = (outs I32:$res); + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, llvm::Intrinsic::xpu_load_param, {$num}); + }]; + let assemblyFormat = "$num attr-dict"; +} + +//===----------------------------------------------------------------------===// +// XPU mfence v2 op definitions +//===----------------------------------------------------------------------===// + +def XPU_MfenceOp : XPU_Op<"mfence"> { + let arguments = (ins I32:$num); + string llvmBuilder = [{ + createIntrinsicCall(builder, llvm::Intrinsic::xpu2_mfence, {$num}); + }]; + let assemblyFormat = "$num attr-dict"; +} + +//===----------------------------------------------------------------------===// +// XPU barrier for inter-cluster +// mfence + sync_cluster +//===----------------------------------------------------------------------===// +def XPU_BarrierOp : XPU_Op<"barrier"> { + let arguments = (ins); + string llvmBuilder = [{ + auto five = builder.getInt32(7); + createIntrinsicCall(builder, llvm::Intrinsic::xpu2_mfence, {five}); + auto mask = builder.getInt32(65535); + createIntrinsicCall(builder, llvm::Intrinsic::xpu_csr_set_sync_group, {mask}); + }]; + let assemblyFormat = "attr-dict"; +} + +//===----------------------------------------------------------------------===// +// XPU set haddr op definitions +//===----------------------------------------------------------------------===// + +def XPU_SetHaddrOp : XPU_Op<"set_haddr"> { + let arguments = (ins I32:$num); + string llvmBuilder = [{ + createIntrinsicCall(builder, llvm::Intrinsic::xpu_set_haddr, {$num}); + }]; + let assemblyFormat = "$num attr-dict"; +} + +//===----------------------------------------------------------------------===// +// XPU log definitions +//===----------------------------------------------------------------------===// + +def XPU_LogOp : XPU_Op<"log", [Elementwise, SameOperandsAndResultType, SameOperandsAndResultShape]> { + let arguments = (ins F32:$operand); + let results = (outs F32:$result); + string llvmBuilder = [{ + $result = createIntrinsicCall(builder, llvm::Intrinsic::xpu_log2f, {$operand}); + }]; + let assemblyFormat = "$operand attr-dict `:` type($result)"; +} + +//===----------------------------------------------------------------------===// +// XPU min definitions +//===----------------------------------------------------------------------===// +def XPU_MinOp : XPU_IntrOp<"min", [], 0> { + let arguments = (ins I32:$lhs, + I32:$rhs); + let results = (outs I32:$res); + string llvmBuilder = [{ + // XTDK only supports llvm.xpu.min + $res = createIntrinsicCall(builder, llvm::Intrinsic::smin, {$lhs, $rhs}); + }]; +} + +//===----------------------------------------------------------------------===// +// XPU VGatherF definitions +//===----------------------------------------------------------------------===// +def XPU_VGatherFOp : XPU_IntrOp<"vgather_mask16_mr1", [], 0> { + let arguments = (ins LLVM_AnyPointer:$src, + LLVM_AnyVector:$mask); + let results = (outs LLVM_AnyVector:$res); + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, llvm::Intrinsic::xpu2_vgather_mask16_mr1, {$src, $mask}); + }]; +} + +def XPU_VGatherHFOp : XPU_IntrOp<"vgather_mr1", [], 0> { + let arguments = (ins LLVM_AnyPointer:$src, + LLVM_AnyVector:$mask); + let results = (outs LLVM_AnyVector:$res); + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, llvm::Intrinsic::xpu2_vgather_mr1, {$src, $mask}); + }]; +} + +//===----------------------------------------------------------------------===// +// XPU vvor_f_mh_rn definitions +//===----------------------------------------------------------------------===// +def XPU_VVOR_F_MHOp : XPU_IntrOp<"vvor_f_mh_rn", [], 0> { + let arguments = (ins LLVM_AnyVector:$mask, + LLVM_AnyVector:$a, + LLVM_AnyVector:$b, + LLVM_AnyVector:$c); + let results = (outs LLVM_AnyVector:$res); + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, llvm::Intrinsic::xpu2_vvor_f_mh_rn, {$mask, $a, $b, $c}); + }]; +} + +def XPU_VVOR_HF_MHOp : XPU_IntrOp<"vvor_hf_mh_rn", [], 0> { + let arguments = (ins LLVM_AnyVector:$mask, + LLVM_AnyVector:$a, + LLVM_AnyVector:$b, + LLVM_AnyVector:$c); + let results = (outs LLVM_AnyVector:$res); + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, llvm::Intrinsic::xpu3_vvor_hf_mh_rn, {$mask, $a, $b, $c}); + }]; +} + +def XPU_VVOR_S_MHOp : XPU_IntrOp<"vvor_s_mh", [], 0> { + let arguments = (ins LLVM_AnyVector:$mask, + LLVM_AnyVector:$a, + LLVM_AnyVector:$b, + LLVM_AnyVector:$c); + let results = (outs LLVM_AnyVector:$res); + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, llvm::Intrinsic::xpu2_vvor_s_mh, {$mask, $a, $b, $c}); + }]; +} + +//===----------------------------------------------------------------------===// +// XPU vload/vstore definitions +//===----------------------------------------------------------------------===// +def XPU_VLOAD_MZOp : XPU_IntrOp<"vload_mz", [], 0> { + let arguments = (ins LLVM_AnyPointer:$src_ptr, + LLVM_AnyVector:$mask); + let results = (outs LLVM_AnyVector:$res); + string llvmBuilder = [{ + auto zero = builder.getInt32(0); + $res = createIntrinsicCall(builder, llvm::Intrinsic::xpu2_vload_mz, {$mask, $src_ptr, zero}); + }]; +} + +def XPU_VLOAD_MHOp : XPU_IntrOp<"vload_mh", [], 0> { + let arguments = (ins LLVM_AnyPointer:$src_ptr, + LLVM_AnyVector:$dst_data, + LLVM_AnyVector:$mask); + let results = (outs LLVM_AnyVector:$res); + string llvmBuilder = [{ + auto zero = builder.getInt32(0); + $res = createIntrinsicCall(builder, llvm::Intrinsic::xpu2_vload_mh, {$mask, $src_ptr, $dst_data, zero}); + }]; +} + +def XPU_VSTORE_MHOp : XPU_IntrOp<"vstore_mh", [], 0> { + let arguments = (ins LLVM_AnyVector:$src, + LLVM_AnyPointer:$dst_Ptr, + LLVM_AnyVector:$mask); + string llvmBuilder = [{ + auto zero = builder.getInt32(0); + createIntrinsicCall(builder, llvm::Intrinsic::xpu2_vstore_mh, {$src, $mask, $dst_Ptr, zero}); + }]; +} + +//===----------------------------------------------------------------------===// +// XPU svsllp/svsrlp definitions +//===----------------------------------------------------------------------===// +def XPU_SVSLLPOp : XPU_IntrOp<"svsllp", [], 0> { + let arguments = (ins I32:$offset, + LLVM_AnyVector:$src); + let results = (outs LLVM_AnyVector:$res); + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, llvm::Intrinsic::xpu2_svsllp_s, {$offset, $src}); + }]; +} + +def XPU_SVSRLPOp : XPU_IntrOp<"svsrlp", [], 0> { + let arguments = (ins I32:$offset, + LLVM_AnyVector:$src); + let results = (outs LLVM_AnyVector:$res); + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, llvm::Intrinsic::xpu2_svsrlp_s, {$offset, $src}); + }]; +} + +//===----------------------------------------------------------------------===// +// XPU vmerge definitions +//===----------------------------------------------------------------------===// +def XPU_VMERGE_L_HFOp : XPU_IntrOp<"vmerge_l_hf", [], 0> { + let arguments = (ins LLVM_AnyVector:$src1, + LLVM_AnyVector:$src2); + let results = (outs LLVM_AnyVector:$res); + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, llvm::Intrinsic::xpu3_vmerge_l_hf, {$src1, $src2}); + }]; +} + +def XPU_VMERGE_H_HFOp : XPU_IntrOp<"vmerge_h_hf", [], 0> { + let arguments = (ins LLVM_AnyVector:$src1, + LLVM_AnyVector:$src2); + let results = (outs LLVM_AnyVector:$res); + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, llvm::Intrinsic::xpu3_vmerge_h_hf, {$src1, $src2}); + }]; +} + +//===----------------------------------------------------------------------===// +// XPU vscatter definitions +//===----------------------------------------------------------------------===// +def XPU_SCATTER_MHOp : XPU_IntrOp<"vscatter_mh", [], 0> { + let arguments = (ins LLVM_AnyVector:$value, + LLVM_AnyVector:$mask, + LLVM_AnyPointer:$dst_ptr, + LLVM_AnyVector:$offset); + string llvmBuilder = [{ + createIntrinsicCall(builder, llvm::Intrinsic::xpu2_vscatter_mh, {$value, $mask, $dst_ptr, $offset}); + }]; +} + +//===----------------------------------------------------------------------===// +// XPU vshuffle definitions +//===----------------------------------------------------------------------===// +def XPU_VSHUFFLE2Op : XPU_IntrOp<"vshuffle2", [], 0> { + let arguments = (ins LLVM_AnyVector:$src); + let results = (outs LLVM_AnyVector:$res); + string llvmBuilder = [{ + $res = createIntrinsicCall(builder, llvm::Intrinsic::xpu3_vshuffle2_hf, {$src}); + }]; +} + +#endif // LLVMXPUIR_OPS diff --git a/third_party/xpu/include/triton/Dialect/Triton/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/Triton/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/xpu/include/triton/Dialect/Triton/IR/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..f682f54a1 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,27 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS TritonDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS TritonTypes.td) +mlir_tablegen(Types.h.inc -gen-typedef-decls) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs) + +set(LLVM_TARGET_DEFINITIONS TritonInterfaces.td) +mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) + +set(LLVM_TARGET_DEFINITIONS TritonTypeInterfaces.td) +mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs) + +add_public_tablegen_target(TritonTableGen) diff --git a/third_party/xpu/include/triton/Dialect/Triton/IR/Dialect.h b/third_party/xpu/include/triton/Dialect/Triton/IR/Dialect.h new file mode 100644 index 000000000..b1f1597c5 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/IR/Dialect.h @@ -0,0 +1,83 @@ +#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITON_IR_DIALECT_H_ + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "triton/Dialect/Triton/IR/Dialect.h.inc" +#include "triton/Dialect/Triton/IR/OpsEnums.h.inc" +#include "triton/Dialect/Triton/IR/Traits.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/Triton/IR/Ops.h.inc" + +namespace mlir { +namespace triton { + +struct GlobalMemory : public SideEffects::Resource::Base { + StringRef getName() final { return ""; } +}; + +class DialectInferLayoutInterface + : public DialectInterface::Base { +public: + DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {} + + virtual LogicalResult + inferTransOpEncoding(Attribute operandEncoding, ArrayRef order, + Attribute &resultEncoding) const = 0; + + virtual LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding) const = 0; + + virtual LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional location) const = 0; + + // Note: This function only verifies the operand encoding. It doesn't infer + // the result encoding. + virtual LogicalResult + inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute retEncoding, + std::optional location) const = 0; + + // Tries to compute the encoding for the result of a reshape operation that + // makes the reshape a "nop", i.e. the same GPU threads contain the same + // elements as before the reshape. Note that this is not always possible (in + // which case you'd need to choose a different layout for the input to the + // reshape). + virtual LogicalResult + inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const = 0; + + virtual LogicalResult + inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const = 0; + + virtual LogicalResult + inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const = 0; + + // Verify that the encoding are compatible to be used together in a dot + // operation + virtual LogicalResult + verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, + Attribute operandEncodingB) const = 0; +}; + +} // namespace triton +} // namespace mlir + +#endif // TRITON_IR_DIALECT_H_ diff --git a/third_party/xpu/include/triton/Dialect/Triton/IR/Interfaces.h b/third_party/xpu/include/triton/Dialect/Triton/IR/Interfaces.h new file mode 100644 index 000000000..f8f3a6f74 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/IR/Interfaces.h @@ -0,0 +1,9 @@ +#ifndef TRITON_IR_INTERFACES_H_ +#define TRITON_IR_INTERFACES_H_ + +#include "mlir/IR/OpDefinition.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/xpu/include/triton/Dialect/Triton/IR/Traits.h b/third_party/xpu/include/triton/Dialect/Triton/IR/Traits.h new file mode 100644 index 000000000..f1b3ba7b4 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/IR/Traits.h @@ -0,0 +1,124 @@ +#ifndef TRITON_IR_TRAITS_H_ +#define TRITON_IR_TRAITS_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LogicalResult.h" + +#include + +namespace mlir { +namespace OpTrait { + +// These functions are out-of-line implementations of the methods in the +// corresponding trait classes. This avoids them being template +// instantiated/duplicated. +namespace impl { +// The rationale for this trait is to prevent users from creating programs +// that would have catastrophic register pressure and cause the compiler to +// hang. +// Since H100 has 256KB registers, we should allow users to create tensors +// of size up to 256K elements. It will spill for datatypes wider than 1B, +// but we probably should limit number of elements (rather than bytes) to +// keep specs simple + +//===-------------------- For Triton XPU -----------------------===// +// Triton XPU don't need the maxTensorNumElements (legalize pass) +int constexpr maxTensorNumElements = INT_MAX; +//===-----------------------------------------------------------===// + +LogicalResult verifyTensorSize(Operation *op); +LogicalResult verifyTensorLayouts(Operation *op); + +LogicalResult verifySameOperandsEncoding(Operation *op, + bool allowTensorPointerType = false); + +LogicalResult +verifySameOperandsAndResultEncoding(Operation *op, + bool allowTensorPointerType = false); + +LogicalResult verifySameLoadStoreOperandsShape(Operation *op); + +LogicalResult verifySameLoadStoreOperandsAndResultShape(Operation *op); + +} // namespace impl + +template +class TensorSizeTrait : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyTensorSize(op); + } +}; + +// Trait applied to all Triton MLIR ops. Checks that the layouts of tensors are +// valid. +template +class VerifyTensorLayoutsTrait + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyTensorLayouts(op); + } +}; + +template +class SameOperandsAndResultEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultEncoding(op); + } +}; + +template +class SameOperandsEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsEncoding(op); + } +}; + +template +class SameLoadStoreOperandsShape + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameLoadStoreOperandsShape(op); + } +}; + +template +class SameLoadStoreOperandsAndResultShape + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameLoadStoreOperandsAndResultShape(op); + } +}; + +template +class SameLoadStoreOperandsEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsEncoding(op, + /*allowTensorPointerType=*/true); + } +}; + +template +class SameLoadStoreOperandsAndResultEncoding + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifySameOperandsAndResultEncoding( + op, /*allowTensorPointerType=*/true); + } +}; + +} // namespace OpTrait +} // namespace mlir + +#endif diff --git a/third_party/xpu/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/third_party/xpu/include/triton/Dialect/Triton/IR/TritonAttrDefs.td new file mode 100644 index 000000000..adfeaff6f --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/IR/TritonAttrDefs.td @@ -0,0 +1,121 @@ +#ifndef TRITON_ATTR_DEFS +#define TRITON_ATTR_DEFS + +include "mlir/IR/EnumAttr.td" + +// Attributes for LoadOp and StoreOp +def TT_CacheModifierAttr : I32EnumAttr< + "CacheModifier", "", + [ + I32EnumAttrCase<"NONE", 1, "none">, + I32EnumAttrCase<"CA", 2, "ca">, + I32EnumAttrCase<"CG", 3, "cg">, + I32EnumAttrCase<"WB", 4, "wb">, + I32EnumAttrCase<"CS", 5, "cs">, + I32EnumAttrCase<"WT", 6, "wt">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_MemSemanticAttr : I32EnumAttr< + "MemSemantic", "", + [ + I32EnumAttrCase<"RELAXED", 1, "relaxed">, + I32EnumAttrCase<"ACQUIRE", 2, "acquire">, + I32EnumAttrCase<"RELEASE", 3, "release">, + I32EnumAttrCase<"ACQUIRE_RELEASE", 4, "acq_rel">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_EvictionPolicyAttr : I32EnumAttr< + "EvictionPolicy", "", + [ + I32EnumAttrCase<"NORMAL", 1, "evict_normal">, + I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">, + I32EnumAttrCase<"EVICT_LAST", 3, "evict_last"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_PaddingOptionAttr : I32EnumAttr< + "PaddingOption", "", + [ + I32EnumAttrCase<"PAD_ZERO", 1, "zero">, + // We can not set the string value to "NAN" because it is a keyword in C++ + I32EnumAttrCase<"PAD_NAN", 2, "nan"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +// atomic +def TT_AtomicRMWAttr : I32EnumAttr< + "RMWOp", "", + [ + I32EnumAttrCase<"AND", 1, "and">, + I32EnumAttrCase<"OR", 2, "or">, + I32EnumAttrCase<"XOR", 3, "xor">, + I32EnumAttrCase<"ADD", 4, "add">, + I32EnumAttrCase<"FADD", 5, "fadd">, + I32EnumAttrCase<"MAX", 6, "max">, + I32EnumAttrCase<"MIN", 7, "min">, + I32EnumAttrCase<"UMAX", 8, "umax">, + I32EnumAttrCase<"UMIN", 9, "umin">, + I32EnumAttrCase<"XCHG", 10, "exch"> + ]> { + let cppNamespace = "::mlir::triton"; +} + +def TT_MemSyncScopeAttr : I32EnumAttr< + "MemSyncScope", "", + [ + I32EnumAttrCase<"GPU", 1, "gpu">, + I32EnumAttrCase<"CTA", 2, "cta">, + I32EnumAttrCase<"SYSTEM", 3, "sys">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// Program ID dimensions. +def TT_ProgramDim : I32EnumAttr< + "ProgramIDDim", "", + [ + I32EnumAttrCase<"X", 0, "x">, + I32EnumAttrCase<"Y", 1, "y">, + I32EnumAttrCase<"Z", 2, "z">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// Rounding mode. +def TT_RoundingModeAttr : I32EnumAttr< + "RoundingMode", "", + [ + I32EnumAttrCase<"RTZ", 0, "rtz">, + I32EnumAttrCase<"RTNE", 1, "rtne">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// PropagateNan. +def TT_PropagateNanAttr : I32EnumAttr< + "PropagateNan", "", + [ + I32EnumAttrCase<"NONE", 0, "none">, + I32EnumAttrCase<"ALL", 0xFFFF, "all">, + ]> { + let cppNamespace = "::mlir::triton"; +} + +// InputPrecision +def TT_InputPrecisionAttr : I32EnumAttr< + "InputPrecision", "", + [ + I32EnumAttrCase<"TF32", 0, "tf32">, + I32EnumAttrCase<"TF32x3", 1, "tf32x3">, + I32EnumAttrCase<"IEEE", 2, "ieee"> + ]>{ + let cppNamespace = "::mlir::triton"; +} + +#endif diff --git a/third_party/xpu/include/triton/Dialect/Triton/IR/TritonDialect.td b/third_party/xpu/include/triton/Dialect/Triton/IR/TritonDialect.td new file mode 100644 index 000000000..c917538c7 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/IR/TritonDialect.td @@ -0,0 +1,46 @@ +#ifndef TRITON_DIALECT +#define TRITON_DIALECT + +include "mlir/IR/OpBase.td" + +def Triton_Dialect : Dialect { + let name = "tt"; + + let cppNamespace = "::mlir::triton"; + + let summary = "The Triton IR in MLIR"; + + let description = [{ + Triton Dialect. + + Dependent Dialects: + * Arith: + * addf, addi, andi, cmpf, cmpi, divf, fptosi, ... + * Math: + * exp, sin, cos, log, ... + * StructuredControlFlow: + * for, if, while, yield, condition + * ControlFlow: + * br, cond_br + }]; + + let dependentDialects = [ + "arith::ArithDialect", + "math::MathDialect", + "scf::SCFDialect", + "cf::ControlFlowDialect" + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + }]; + + let hasConstantMaterializer = 1; + let useDefaultTypePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +include "triton/Dialect/Triton/IR/TritonTypes.td" + + +#endif // TRITON_DIALECT diff --git a/third_party/xpu/include/triton/Dialect/Triton/IR/TritonInterfaces.td b/third_party/xpu/include/triton/Dialect/Triton/IR/TritonInterfaces.td new file mode 100644 index 000000000..cfc7d0032 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/IR/TritonInterfaces.td @@ -0,0 +1,15 @@ +#ifndef TRITON_INTERFACES +#define TRITON_INTERFACES + +include "mlir/IR/OpBase.td" + +def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">; +def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">; +def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">; +def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">; +def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">; +def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAndResultShape">; +def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">; +def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">; + +#endif // TRITON_INTERFACES diff --git a/third_party/xpu/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/xpu/include/triton/Dialect/Triton/IR/TritonOps.td new file mode 100644 index 000000000..5797659dd --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/IR/TritonOps.td @@ -0,0 +1,1148 @@ +#ifndef TRITON_OPS +#define TRITON_OPS + +include "triton/Dialect/Triton/IR/TritonDialect.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface +include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface +include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface +include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface +include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface +include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" + + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + +// +// Op Base +// +class TT_Op traits = []> : + Op { +} + +// +// Cast Ops +// +// Use cast ops in arith: +// bitcast +// fptoui, fptosi, uitofp, sitofp, +// extf, tructf, +// extui, extsi, tructi +def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast int64 to pointer"; + + let arguments = (ins TT_I64Like:$src); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast pointer to int64"; + + let arguments = (ins TT_PtrLike:$src); + + let results = (outs TT_I64Like:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +// arith.bitcast doesn't support pointers +def TT_BitcastOp : TT_Op<"bitcast", [Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Cast between types of the same bitwidth"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + // TODO: Add verifier +} + +def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + Pure, + /*DeclareOpInterfaceMethods*/]> { + let summary = "Floating point casting for custom types"; + + let description = [{ + Floating point casting for custom types (F8), and non-default rounding modes. + + F8 <-> FP16, BF16, FP32, FP64 + }]; + + let arguments = ( + ins TT_FloatTensor:$src, + OptionalAttr:$rounding + ); + + let results = (outs TT_FloatTensor:$result); + + let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)"; + + let hasVerifier = 1; +} + +// +// Arithmetic Ops +// + +def TT_ClampFOp : TT_Op<"clampf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Clamp operation for floating point types"; + + let description = [{ + Clamp operation for floating point types. + + The operation takes three arguments: x, min, and max. It returns a tensor of the same shape as x with its values clamped to the range [min, max]. + }]; + + let arguments = ( + ins + TT_FloatLike:$x, + TT_FloatLike:$min, + TT_FloatLike:$max, + TT_PropagateNanAttr:$propagateNan + ); + + let results = (outs TT_FloatLike:$result); + + // List $propagateNan explicitly rather than relying on attr-dict to pick it + // up, because if it's inside attr-dict, its value will be printed as a + // number rather than as a meaningful string. + let assemblyFormat = "$x `,` $min `,` $max `,` `propagateNan` `=` $propagateNan attr-dict `:` type($result)"; +} + +// +// Math Ops +// + +def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Precise sqrt for floating point types"; + + let description = [{ + Precise sqrt for floating point types. + }]; + + let arguments = (ins TT_FloatLike:$x); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x attr-dict `:` type($x)"; +} + +def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Precise div for floating point types"; + + let description = [{ + Precise div for floating point types. + }]; + + let arguments = (ins TT_FloatLike:$x, TT_FloatLike:$y); + + let results = (outs TT_FloatLike:$result); + + let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; +} + +def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise, + SameOperandsAndResultType, + Pure]> { + let summary = "Most significant N bits of the 2N-bit product of two integers"; + + let description = [{ + Most significant N bits of the 2N-bit product of two integers. + }]; + + let arguments = (ins TT_IntLike:$x, TT_IntLike:$y); + + let results = (outs TT_IntLike:$result); + + let assemblyFormat = "$x `,` $y attr-dict `:` type($x)"; +} + +// +// Pointer Arith Ops +// +def TT_AddPtrOp : TT_Op<"addptr", + [Pure, + Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { + let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)"; +} + +def TT_AdvanceOp : TT_Op<"advance", + [Pure, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { + let summary = "Advance a tensor pointer by offsets"; + + let arguments = (ins TT_TensorPtr:$ptr, Variadic:$offsets); + + let results = (outs TT_TensorPtr:$result); + + let assemblyFormat = "$ptr `,` `[` $offsets `]` attr-dict `:` type($result)"; +} + +// +// Load/Store Ops +// +def TT_LoadOp : TT_Op<"load", [ + SameLoadStoreOperandsAndResultShape, + SameLoadStoreOperandsAndResultEncoding, + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result matches ptr type", "ptr", "result", "getPointeeType($_self)">, + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 1) || std::equal_to<>()">, + TypesMatchWith<"other matches ptr type", "ptr", "other", "getPointeeType($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "Load from a tensor of pointers or from a tensor pointer"; + + let arguments = ( + ins + AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, + Optional:$mask, + Optional:$other, + + DefaultValuedAttr{}">:$boundaryCheck, + OptionalAttr:$padding, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile + ); + + let results = (outs TT_Type:$result); + + let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor pointer with boundary check and padding + OpBuilder<(ins "Value":$ptr, "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A tensor of pointers or a pointer to a scalar with mask and other + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + // A utility function to build the operation with all attributes + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, + "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)> + ]; + + // Specify `cacheModifier` and `evictionPolicy` explicitly in the + // assemblyFormat instead of as part of attr-dict so that they get printed + // as strings rather than opaque integers. + // + // Note there's no comma between `other` and `cacheModifier` and between + // `cacheModifier` and `evictionPolicy`. This is due to an apparent + // limitation in the MLIR custom-format parser. In oilist, the initial + // keywords of each clause have to be unique, so they can't be `,`. + // + // Even if we gave up on order-independence and used vanilla optional + // clauses, the format (`,` `foo` `=` $foo^)? (`,` `bar` `=` $bar^)? will + // not match the string ", bar = 0" because after the initial comma (first + // token of the first optional clause) we expect to see "foo". + let assemblyFormat = [{ + $ptr (`,` $mask^)? (`,` $other^)? + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` type($ptr) + }]; + + let hasCanonicalizer = 1; +} + +def TT_StoreOp : TT_Op<"store", [ + SameLoadStoreOperandsShape, + SameLoadStoreOperandsEncoding, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"value type matches ptr type", "ptr", "value", + "getPointeeType($_self)">, + TypesMatchWith<"mask type matches ptr type", "ptr", "mask", + "getI1SameShape(getPointeeType($_self))", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "Store by a tensor of pointers or by a tensor pointer"; + + let arguments = ( + ins + AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr, + TT_Type:$value, + Optional:$mask, + DefaultValuedAttr{}">:$boundaryCheck, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$value, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)>, + // A tensor pointer with boundary check + OpBuilder<(ins "Value":$ptr, "Value":$value, "ArrayRef":$boundaryCheck, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)> + ]; + + // Specify cacheModifier and evictionPolicy explicitly, instead of leaving + // them in attr-dict, because this way their values get printed as strings, + // rather than as opaque integers. + // + // Note there are no commas between mask, cacheModifier, and evictionPolicy, + // due to limitations in MLIR's asm parser. + let assemblyFormat = [{ + $ptr `,` $value (`,` $mask^)? + oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + attr-dict `:` type($ptr) + }]; + + let hasCanonicalizer = 1; +} + +// +// Atomic Ops +// +def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [ + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + TypesMatchWith<"ptr type matches value type", "val", "ptr", + "getPointerTypeSameShape($_self)">, + TypesMatchWith<"mask type matches value type", + "val", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 2) || std::equal_to<>()"> +]> { + let summary = "atomic rmw"; + + let description = [{ + load data at $ptr, do $rmw_op with $val, and store result to $ptr. + + return old value at $ptr + }]; + + let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr, + TT_Type:$val, Optional:$mask, + TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); + + let results = (outs TT_Type:$result); + + // Explicitly list $atomic_rmw_op, $sem, and $scope rather than relying on + // attr-dict so they're printed as strings rather than opaque integers. + let assemblyFormat = [{ + $atomic_rmw_op `,` $sem `,` $scope `,` $ptr `,` $val (`,` $mask^)? attr-dict `:` + functional-type(operands, $result) + }]; +} + +def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead]>, + MemoryEffects<[MemWrite]>, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding]> { + let summary = "atomic cas"; + + let description = [{ + compare $cmp with data $old at location $ptr, + + if $old == $cmp, store $val to $ptr, + + else store $old to $ptr, + + return $old + }]; + + let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val, + TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope); + + let results = (outs TT_Type:$result); + + // Explicitly list $sem and $scope rather than relying on attr-dict so + // they're printed as strings rather than opaque integers. + let assemblyFormat = [{ + $sem `,` $scope `,` $ptr `,` $cmp `,` $val attr-dict `:` + functional-type(operands, $result) + }]; +} + +// +// Shape Manipulation Ops +// +def TT_SplatOp : TT_Op<"splat", [Pure, + SameOperandsAndResultElementType, + SameOperandsAndResultEncoding]> { + let summary = "splat"; + + let arguments = (ins TT_Type:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasFolder = 1; +} + +def TT_ExpandDimsOp : TT_Op<"expand_dims", [Pure, + DeclareOpInterfaceMethods, + SameOperandsAndResultElementType]> { + let summary = "expand_dims"; + + let arguments = (ins TT_Tensor:$src, I32Attr:$axis); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizeMethod = 1; + let hasFolder = 1; +} + +def TT_ReshapeOp : TT_Op<"reshape", [Pure, + SameOperandsAndResultElementType]> { + let summary = "reinterpret a tensor to a different shape. It may change elements order if the attribute is set."; + let description = [{ + reinterpret a tensor to a different shape. + + If allow_reorder is set the compiler is free to change the order of + elements to generate more efficient code. + + If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason. + The compiler is still free to change it for better performance. + }]; + let arguments = (ins TT_Tensor:$src, BoolAttr:$allow_reorder, OptionalAttr:$efficient_layout); + let results = (outs TT_Tensor:$result); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + let hasCanonicalizeMethod = 1; + let hasFolder = 1; + let hasVerifier = 1; + let builders = [ + OpBuilder<(ins "Type":$type, "Value":$src, "bool":$allow_reorder), + [{ + build($_builder, $_state, type, src, allow_reorder, /*efficient_layout=*/UnitAttr()); + }]>]; +} + +def TT_BroadcastOp : TT_Op<"broadcast", [Pure, + SameOperandsAndResultElementType]> { + let summary = "broadcast a tensor"; + + let description = [{ + For a given tensor, broadcast changes one or more dimensions with size 1 + to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot + change the size of a non-1 dimension. + }]; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizeMethod = 1; + let hasFolder = 1; +} + +// cat is not `pure` because it may reorder elements +def TT_CatOp : TT_Op<"cat", [NoMemoryEffect, + SameTypeOperands, + SameOperandsAndResultElementType]> { + let summary = "concatenate 2 tensors"; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; +} + +def TT_JoinOp : TT_Op<"join", [ + NoMemoryEffect, SameTypeOperands, + DeclareOpInterfaceMethods, +]> { + let summary = "join two tensors along a new, minor dimension"; + let description = [{ + For example, if the two input tensors are 4x8xf32, returns a tensor of + shape 4x8x2xf32. + + Because Triton tensors always have a power-of-two number of elements, + the two input tensors must have the same shape. + }]; + + let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs); + let results = (outs TT_Tensor:$result); + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)"; +} + +def TT_SplitOp : TT_Op<"split", [ + NoMemoryEffect, + DeclareOpInterfaceMethods, + TypesMatchWith<"outLHS and outRHS types match", + "outLHS", "outRHS", "$_self">, +]> { + let summary = "splits a tensor into two, along its last dimension"; + let description = [{ + The input must be a tensor whose last dimension has size 2. Returns two + tensors, src[..., 0] and src[..., 1]. + + For example, if the input shape is 4x8x2xf32, returns two tensors of + shape 4x8xf32. + }]; + + let arguments = (ins TT_Tensor:$src); + let results = (outs TT_Tensor:$outLHS, TT_Tensor:$outRHS); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($outLHS)"; +} + +def TT_TransOp : TT_Op<"trans", [Pure, + DeclareOpInterfaceMethods, + SameOperandsAndResultElementType]> { + + let summary = "rearrange the dimensions of a tensor"; + let description = [{ + For example, given a tensor x with shape [1,2,4], transpose(x) with + order=[2,0,1] rearranges the tensor to have shape [4,1,2]. + + Although this op is called "trans", it implements both tl.trans() and + tl.permute(). ("permute" might be a better name, but it's called "trans" + because originally it only supported 2D tensors.) + + ## Implementation note on encodings: + + In the TritonGPU dialect (and probably others), an encoding is chosen for + this op's output so it's a nop from the perspective of code generation. + + For example, suppose tensor x has an encoding such that GPU thread [i,j,k] + has a register containing element [i,j,k] of the tensor. Now we transpose + x with order [2,1,0], i.e. we reverse the order of its dimensions. In + TritonGPU, we will choose a layout for the output of the transpose so that + GPU thread [i,j,k] has element [k,j,i] of transpose(x). But this is the + same element it had before! All we've done is "rename" the element that + thread [i,j,k] has. + + The "real" transpose -- i.e. moving data between GPU threads -- occurs in + convertLayout ops that appear before and/or after the operation. + + We do this so that you can chain multiple data-movement ops (e.g. + transpose+reshape+concat) without going to shared memory after each one. + }]; + + let arguments = ( + ins TT_TensorOrMemDesc:$src, + DenseI32ArrayAttr:$order + ); + + let results = (outs TT_TensorOrMemDesc:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasFolder = 1; + let hasVerifier = 1; +} + +// +// SPMD Ops +// +def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> { + let arguments = (ins TT_ProgramDim:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "$axis attr-dict `:` type($result)"; + + let extraClassDeclaration = [{ + int32_t getAxisAsInt() { + return static_cast(getAxis()); + } + }]; +} + +def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> { + let arguments = (ins TT_ProgramDim:$axis); + + let results = (outs I32:$result); + + let assemblyFormat = "$axis attr-dict `:` type($result)"; + + let extraClassDeclaration = [{ + int32_t getAxisAsInt() { + return static_cast(getAxis()); + } + }]; +} + +// +// Dot Op +// +def TT_DotOp : TT_Op<"dot", [Pure, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot"; + + let description = [{ + $d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to exercise the TC + when the inputs are f32. It can be one of: tf32, tf32x3, ieee. + tf32: use TC with tf32 ops. + tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp + ieee: don't use TC, implement dot in software. + If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored. + }]; + + let arguments = ( + ins + TT_TensorOrMemDesc:$a, + TT_TensorOrMemDesc:$b, + TT_FpIntTensor:$c, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc + ); + + let results = (outs TT_FpIntTensor:$d); + + // attr-dict prints enums as integers. To get inputPrecision printed as a + // string, we need to specify it explicitly. + let assemblyFormat = [{ + $a`,` $b`,` $c (`,` `inputPrecision` `=` $inputPrecision^)? attr-dict `:` + type($a) `*` type($b) `->` type($d) + }]; + let hasVerifier = 1; +} + +// +// Reduce Op +// +def TT_ReduceOp: TT_Op<"reduce", + [Pure, + SameOperandsEncoding, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Reduction using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$srcs, "int":$axis)>, + ]; + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + }]; +} + +def TT_ReduceReturnOp: TT_Op<"reduce.return", + [HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for reduce operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + +// +// Scan Op +// +def TT_ScanOp: TT_Op<"scan", + [Pure, + SameOperandsAndResultEncoding, + SameOperandsAndResultShape, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Associative scan using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis, BoolAttr:$reverse); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$srcs, "int":$axis, "bool":$reverse)>, + ]; + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + }]; +} + +def TT_ScanReturnOp: TT_Op<"scan.return", + [HasParent<"ScanOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for scan operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + + +// +// External Elementwise op +// +def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise, + SameOperandsAndResultEncoding, + SameVariadicOperandSize, + DeclareOpInterfaceMethods]> { + + let description = [{ + call an external function $symbol implemented in $libpath/$libname with $args + return $libpath/$libname:$symbol($args...) + }]; + + let arguments = (ins Variadic:$srcs, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; +} + +// +// Make Range Op +// +def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> { + let summary = "make range"; + + let description = [{ + Returns an 1D int32 tensor. + + Values span from $start to $end (exclusive), with step = 1 + }]; + + // WARNING: MLIR generates getStart()/getEnd() functions which return + // uint32_t, even though these arguments are to be interpreted as *signed* + // int32 values. If this matters, use get{Start,End}Attr().getInt(), which + // return int64_t. + let arguments = (ins I32Attr:$start, I32Attr:$end); + + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = "attr-dict `:` type($result)"; + + // let hasFolder = 1; + let hasVerifier = 1; +} + +// +// ElementwiseInlineAsm Op +// +def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [ + Elementwise, + SameOperandsAndResultEncoding, + DeclareOpInterfaceMethods +]> { + let summary = "inline assembly applying an elementwise operation to a group of packed elements."; + let description = [{ + Runs an inline asm block to generate one or more tensors. + + The asm block is given `packed_element` elements at a time. Exactly which + elems it receives is unspecified. + }]; + + let arguments = (ins StrAttr:$asm_string, StrAttr:$constraints, BoolAttr:$pure, I32Attr:$packed_element, Variadic>:$args); + let results = (outs Variadic:$result); + + let assemblyFormat = [{ + $asm_string attr-dict ($args^ `:` type($args))? `->` type($result) + }]; + + let hasVerifier = 1; +} + +// +// Histogram Op +// +def TT_HistogramOp : TT_Op<"histogram", [Pure]> { + let summary = "return a histgram of the inputs."; + let description = [{ + Return the histogram of the input tensor. The number of bins is equal to + the dimension of the output tensor. Each bins has a width of 1 and bins + start at 0. + }]; + + let arguments = (ins TT_IntTensor:$src); + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = [{ + $src attr-dict `:` type($src) `->` type($result) + }]; +} + +// +// Print Op +// +def TT_PrintOp : TT_Op<"print", [MemoryEffects<[MemWrite]>]>, + Arguments<(ins StrAttr:$prefix, BoolAttr:$hex, Variadic>:$args)> { + let summary = "Device-side print, as in CUDA for debugging"; + let description = [{ + `tt.print` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed. + format are generated automatically from the arguments. + }]; + let assemblyFormat = [{ + $prefix attr-dict (`:` $args^ `:` type($args))? + }]; +} + +// +// Assert Op +// +def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> { + let summary = "Device-side assert, as in CUDA for correctness checking"; + let description = [{ + `tt.assert` takes a condition tensor, a message string, a file string, a function string, and a line number. + If the condition is false, the message is printed, and the program is aborted. + }]; + let arguments = (ins TT_Tensor:$condition, StrAttr:$message, StrAttr:$file, StrAttr:$func, I32Attr:$line); + let assemblyFormat = "$condition `,` $message `,` $file `,` $func `,` $line attr-dict `:` type($condition)"; +} + +// +// Make Tensor Pointer Op +// +def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr", + [Pure, + SameVariadicOperandSize, + TypesMatchWith<"infer pointer type from the result type", + "result", "base", + "getPointerType(getElementTypeOfTensorPointerType($_self))">]> { + let summary = "Make a tensor pointer type with meta information of the parent tensor and the block specified"; + + let description = [{ + `tt.make_tensor_ptr` takes both meta information of the parent tensor and the block tensor, then it returns a + pointer to the block tensor, e.g. returns a type of `tt.ptr>`. + }]; + + // TODO(Chenggang): unify the integer types. Currently we cannot do that due to hardware constraints. + let arguments = (ins + TT_Ptr:$base, + Variadic:$shape, + Variadic:$strides, + Variadic:$offsets, + DenseI32ArrayAttr:$order + ); + + let results = (outs TT_TensorPtr:$result); + + // TODO(Keren): define a custom assembly format for this op because the result type cannot be printed correctly + // Add additional `[]` to increase readability and split variadic lists + let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)"; + + let builders = [ + OpBuilder<(ins + "Value":$base, + "ValueRange":$shape, + "ValueRange":$strides, + "ValueRange":$offsets, + "ArrayRef":$tensorShape, + "ArrayRef":$order + )> + ]; +} + +// The following ops, including `call`, `func`, and `return` are copied and modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +// We could revert it back once MLIR has a better inliner interface. +// +// Function Ops +// +def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpInterfaceMethods]> { + let summary = "call operation"; + let description = [{ + The `tt.call` operation represents a direct call to a function that is + within the same symbol scope as the call. The operands and result types of + the call must match the specified function type. The callee is encoded as a + symbol reference attribute named "callee". + + Example: + + ```mlir + %2 = tt.call @my_add(%0, %1) : (f32, f32) -> f32 + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", SymbolRefAttr::get(callee)); + $_state.addTypes(callee.getFunctionType().getResults()); + }]>, + OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", callee); + $_state.addTypes(results); + }]>, + OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, SymbolRefAttr::get(callee), results, operands); + }]>, + OpBuilder<(ins "StringRef":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), + results, operands); + }]>]; + + let extraClassDeclaration = [{ + FunctionType getCalleeType() { + return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); + } + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + /// Set the callee for this operation. + void setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); + } + + // Required by CallOpInterface. + MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + +def FuncOp : TT_Op<"func", [AffineScope, AutomaticAllocationScope, CallableOpInterface, FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface]> { + let summary = "An operation with a name containing a single `SSACFG` region"; + let description = [{ + Operations within the function cannot implicitly capture values defined + outside of the function, i.e. Functions are `IsolatedFromAbove`. All + external references must use function arguments or attributes that establish + a symbolic connection (e.g. symbols referenced by name via a string + attribute like SymbolRefAttr). An external function declaration (used when + referring to a function declared in some other module) has no body. While + the MLIR textual form provides a nice inline syntax for function arguments, + they are internally represented as “block arguments” to the first block in + the region. + + Only dialect attribute names may be specified in the attribute dictionaries + for function arguments, results, or the function itself. + + Example: + + ```mlir + // External function definitions. + tt.func @abort() + tt.func @scribble(i32, i64, memref) -> f64 + + // A function that returns its argument twice: + tt.func @count(%x: i64) -> (i64, i64) + attributes {fruit: "banana"} { + return %x, %x: i64, i64 + } + + // A function with an argument attribute + tt.func @example_fn_arg(%x: i32 {swift.self = unit}) + + // A function with a result attribute + tt.func @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64}) + + // A function with an attribute + tt.func @example_fn_attr() attributes {dialectName.attrName = false} + ``` + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); + let regions = (region AnyRegion:$body); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type, + CArg<"ArrayRef", "{}">:$attrs, + CArg<"ArrayRef", "{}">:$argAttrs) + >]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // CallableOpInterface + //===------------------------------------------------------------------===// + + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } + + /// Returns the results types that the callable region produces when + /// executed. + ArrayRef getCallableResults() { return getFunctionType().getResults(); } + + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or + /// null if there are none. + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } + + //===------------------------------------------------------------------===// + // FunctionOpInterface Methods + //===------------------------------------------------------------------===// + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + bool isDeclaration() { return isExternal(); } + }]; + let hasCustomAssemblyFormat = 1; +} + +def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable, */ReturnLike, Terminator]> { + let summary = "Function return operation"; + let description = [{ + The `tt.return` operation represents a return operation within a function. + The operation takes variable number of operands and produces no results. + The operand number and types must match the signature of the function + that contains the operation. + + Example: + + ```mlir + tt.func @foo() : (i32, f8) { + ... + tt.return %0, %1 : i32, f8 + } + ``` + }]; + + let arguments = (ins Variadic:$srcs); + + let builders = [OpBuilder<(ins), [{ + build($_builder, $_state, std::nullopt); + }]>]; + + let assemblyFormat = "attr-dict ($srcs^ `:` type($srcs))?"; + let hasVerifier = 1; +} + + +def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [ + MemoryEffects<[MemRead]>]> { + let summary = "Load from descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA load operation on targets supporting it. + `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + The destination tensor type and shape must match the descriptor otherwise the result is undefined. + + This is an escape hatch and is only there for testing/experimenting. + This op will be removed in the future. + }]; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + Variadic:$indices, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict + ); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = [{ + $desc_ptr `[` $indices `]` + oilist( + `cacheModifier` `=` $cache | + `evictionPolicy` `=` $evict + ) + attr-dict `:` qualified(type($desc_ptr)) `->` type($result) + }]; +} + +def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [ + MemoryEffects<[MemWrite]>]> { + let summary = "store value based on descriptor"; + let description = [{ + This operation will be lowered to Nvidia TMA store operation on targets supporting it. + `desc_ptr` is a pointer to the TMA descriptor allocated in global memory. + The shape and types of `src` must match the descriptor otherwise the result is undefined. + + This is an escape hatch and is only there for testing/experimenting. + This op will be removed in the future. + }]; + let arguments = ( + ins + TT_PtrType:$desc_ptr, + TT_Tensor:$src, + Variadic:$indices + ); + + let assemblyFormat = [{ + $desc_ptr `[` $indices `]` `,` $src + attr-dict `:` qualified(type($desc_ptr)) `,` type($src) + }]; +} + +#endif // Triton_OPS diff --git a/third_party/xpu/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td b/third_party/xpu/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td new file mode 100644 index 000000000..e3aed2262 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td @@ -0,0 +1,24 @@ +#ifndef TRITON_TYPE_INTERFACES +#define TRITON_TYPE_INTERFACES + +include "mlir/IR/OpBase.td" + +// Interface dynamically attached to RankedTensorType and MemDescType. +def TT_TensorOrMemDesc : TypeInterface<"TensorOrMemDesc"> { + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod<"Returns the encoding of the tensor or memory descriptor", + "mlir::Attribute", "getEncoding", (ins)>, + InterfaceMethod<"Returns element type", + "mlir::Type", "getElementType", (ins)>, + InterfaceMethod<"Returns the type shape", + "llvm::ArrayRef", "getShape", (ins)>, + InterfaceMethod<"Returns the tensor or buffer rank", + "int64_t", "getRank", (ins)>, + InterfaceMethod<"Returns the element type bit width", + "int64_t", "getElementTypeBitWidth", (ins)>, + + ]; +} + +#endif // TRITON_TYPE_INTERFACES diff --git a/third_party/xpu/include/triton/Dialect/Triton/IR/TritonTypes.td b/third_party/xpu/include/triton/Dialect/Triton/IR/TritonTypes.td new file mode 100644 index 000000000..7008d23c0 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -0,0 +1,146 @@ +#ifndef TRITON_TYPES +#define TRITON_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "triton/Dialect/Triton/IR/TritonDialect.td" + +// +// Types +// +class TritonTypeDef traits = []> + : TypeDef { + // Used by printer/parser + let mnemonic = _mnemonic; +} + +// Floating-point Type +def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">; +def TT_FloatTensor : RankedTensorOf<[TT_Float]>; +def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>; + +// Boolean Type +// TT_Bool -> I1 +def TT_BoolTensor : RankedTensorOf<[I1]>; +def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>; + +// Integer Type +def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">; +def TT_IntTensor : RankedTensorOf<[TT_Int]>; +def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>; + +// I32 Type +// TT_I32 -> I32 +// TT_I32Tensor -> I32Tensor +def TT_I32Like : AnyTypeOf<[I32, I32Tensor]>; + +// I64 Type +// TT_I64 -> I64 +// TT_I64Tensor -> I64Tensor +def TT_I64Like : AnyTypeOf<[I64, I64Tensor]>; + +// Pointer Type in TableGen +class TT_PtrOf pointeeTypes> : + DialectType($_self)">, + Concat<"[](::mlir::Type pointeeType) { return ", + SubstLeaves<"$_self", "pointeeType", AnyTypeOf.predicate>, + "; }(::mlir::cast<::mlir::triton::PointerType>($_self).getPointeeType())">]>, + "ptr", "::mlir::triton::PointerType">; + +// Pointer Type in C++ (corresponding to `TT_PtrOf`) +def TT_PtrType : TritonTypeDef<"Pointer", "ptr"> { + let summary = "Pointer type (`::mlir::triton::PointerType`) in Triton IR type system"; + + let description = [{ + Pointer type in Triton IR type system, which could be pointing to scalars or tensors. + }]; + + let parameters = (ins "Type":$pointeeType, "int":$addressSpace); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "Type":$pointeeType, + "int":$addressSpace + ), [{ + return $_get(pointeeType.getContext(), pointeeType, addressSpace); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +// Scalar Pointer Type: `ptr<>` +def TT_Ptr : TT_PtrOf<[AnyType]>; + +// Tensor of Pointer Type: `tensor>` +def TT_PtrTensor : RankedTensorOf<[TT_Ptr]>; + +// Tensor of Pointer Type or Pointer type: `tensor>` or `ptr<>` +def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>; + +// Tensor Type +def TT_FpIntTensor : RankedTensorOf<[TT_Float, TT_Int]>; +def TT_Tensor : RankedTensorOf<[TT_Float, TT_Int, TT_Ptr]>; + +// Pointer Type to Tensor Type: `ptr>` +def TT_TensorPtr : TT_PtrOf<[TT_Tensor]>; + +// Any Type in Triton IR +//===-------------------- For Triton XPU -----------------------===// +// For Vectorization +def TT_Vector : FixedVectorOf<[TT_FloatLike, TT_IntLike]>; +def TT_VectorTensor: TensorOf<[TT_Vector]>; +def TT_VectorLike: AnyTypeOf<[TT_Vector, TT_VectorTensor]>; +def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike, TT_TensorPtr, TT_VectorLike]>; +//===-----------------------------------------------------------===// + +// Memory descriptor type. +def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> { + let summary = "memory descriptor type (`::mlir::triton::MemDescType`) in Triton IR type system"; + + let description = [{ + Memory descriptor contains a base pointer (scalar) and a descriptor of the memory. + If mutable memory is false that means the memory is constant and can only be allocated and stored once. + A constant memory allocation is different than a tensor as it can have multiple views and the descriptor + can be changed without changing the underlying memory. + }]; + + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "Type":$elementType, + "Attribute":$encoding, + "bool":$mutable_memory + ); + let extraClassDeclaration = [{ + MemDescType cloneWith(std::optional> shape, + Type elementType) const { + return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding()); + } + + bool hasRank() const { return true; } + }]; + let builders = [ + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, /*mutableMemory=*/false); + }]>, + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$shape, + "Type":$elementType, + "Attribute":$encoding, + "bool":$mutableMemory + ), [{ + return $_get(elementType.getContext(), shape, elementType, encoding, mutableMemory); + }]> + ]; + let hasCustomAssemblyFormat = 1; +} + + +#endif diff --git a/third_party/xpu/include/triton/Dialect/Triton/IR/Types.h b/third_party/xpu/include/triton/Dialect/Triton/IR/Types.h new file mode 100644 index 000000000..bf1967f1b --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/IR/Types.h @@ -0,0 +1,39 @@ +#ifndef TRITON_IR_TYPES_H_ +#define TRITON_IR_TYPES_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/Types.h.inc" + +#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.h.inc" + +namespace mlir { + +namespace triton { + +bool isTensorPointerType(Type type); + +bool isTensorOrTensorPointerType(Type type); + +unsigned getPointeeBitWidth(Type type); + +Type getPointeeType(Type type); + +Type getPointerType(Type type); + +Type getElementTypeOfTensorPointerType(Type type); + +Type getI1SameShape(Type type); + +Type getI32SameShape(Type type); + +Type getPointerTypeSameShape(Type type); + +} // namespace triton + +} // namespace mlir + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/xpu/include/triton/Dialect/Triton/IR/Utility.h b/third_party/xpu/include/triton/Dialect/Triton/IR/Utility.h new file mode 100644 index 000000000..0ef597147 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/IR/Utility.h @@ -0,0 +1,190 @@ +#ifndef TRITON_IR_UTILITY_H_ +#define TRITON_IR_UTILITY_H_ + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include +#include + +namespace mlir { + +template SmallVector convertType(ArrayRef in) { + SmallVector out; + for (const auto &i : in) + out.push_back(T(i)); + return out; +} + +template +SmallVector convertType(const VecU &in) { + return convertType(ArrayRef(in)); +} + +template Int product(llvm::ArrayRef arr) { + return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{}); +} +template auto product(const VecT &vec) { + return product(llvm::ArrayRef(vec)); +} + +// TODO(jlebar): Rename to ceilOfRatio. +template Int ceil(Int m, Int n) { return (m + n - 1) / n; } + +/// Get the highest power of 2 divisor of an integer. +template T highestPowOf2Divisor(T n) { + if (n == 0) { + return (static_cast(1) << (sizeof(T) * 8 - 2)); + } + return (n & (~(n - 1))); +} + +/// Get the next power of 2 for an integer (or the integer itself if it is a +/// power of 2). +template T nextPowOf2(T n) { + if (n == 0) { + return 1; + } + n--; + for (unsigned i = 1; i < sizeof(T) * 8; i <<= 1) { + n |= n >> i; + } + return n + 1; +} + +namespace triton { + +// Many functions here have two overloads, fn(ArrayRef) and fn(const VecT&). +// This is helpful because C++ won't both convert a vector to ArrayRef *and* +// infer the proper type T in one step. So without the second overload, we +// would have to explicitly convert most arguments to ArrayRef at the callsite. + +template +SmallVector applyPermutation(ArrayRef vec, ArrayRef permutation) { + static_assert(std::is_integral_v); + assert(vec.size() == permutation.size()); + + // Check that `permutation` is actually a permutation. +#ifndef NDEBUG + SmallVector sortedPerm(permutation); + llvm::sort(sortedPerm); + for (U i = 0; i < static_cast(sortedPerm.size()); i++) { + assert(sortedPerm[i] == i); + } +#endif + + SmallVector ret; + ret.reserve(vec.size()); + for (const U &i : permutation) { + ret.push_back(vec[i]); + } + return ret; +} + +template +auto applyPermutation(const VecT &vec, const PermT &permutation) { + return applyPermutation(ArrayRef(vec), ArrayRef(permutation)); +} + +template +[[nodiscard]] SmallVector inversePermutation(ArrayRef permutation) { + // Check that `permutation` is actually a permutation. +#ifndef NDEBUG + SmallVector sortedPerm(permutation); + llvm::sort(sortedPerm); + for (int i = 0; i < sortedPerm.size(); ++i) { + assert(sortedPerm[i] == i); + } +#endif + + SmallVector ret(permutation.size()); + for (int i = 0; i < permutation.size(); ++i) { + ret[permutation[i]] = i; + } + return ret; +} + +template +[[nodiscard]] auto inversePermutation(const VecT &permutation) { + return inversePermutation(ArrayRef(permutation)); +} + +template +[[nodiscard]] SmallVector gather(ArrayRef elems, ArrayRef indices) { + SmallVector ret; + ret.reserve(indices.size()); + for (const U &i : indices) { + ret.push_back(elems[i]); + } + return ret; +} + +template +[[nodiscard]] auto gather(const VecT &elems, const IdxT &indices) { + return gather(ArrayRef(elems), ArrayRef(indices)); +} + +// Is `vec` [0, 1, ..., n]? Returns true on empty list. +template bool isIota(ArrayRef vec) { + static_assert(std::is_integral_v); + for (T i = 0; i < vec.size(); ++i) { + if (vec[i] != i) { + return false; + } + } + return true; +} + +template bool isIota(const VecT &vec) { + return isIota(ArrayRef(vec)); +} + +// Is `vals` some permutation of the numbers 0..(vals.size()-1)? +template bool isPermutationOfIota(ArrayRef vals) { + SmallVector sorted(vals); + llvm::sort(sorted); + return isIota(sorted); +} + +template bool IsPermutationOfIota(const VecT &vec) { + return isPermutationOfIota(ArrayRef(vec)); +} + +// Is `vec` [i, i+1, ..., i+n]? Returns true on empty list. +template bool isConsecutive(ArrayRef vec) { + static_assert(std::is_integral_v); + for (int i = 1; i < vec.size(); i++) { + if (vec[i] != vec[i - 1] + 1) { + return false; + } + } + return true; +} + +template bool isConsecutive(const VecT &vec) { + return isConsecutive(ArrayRef(vec)); +} + +// LLVM's STLExtras.h provides a bunch of functions that work over ranges, but +// it's missing min/max_element until +// https://github.com/llvm/llvm-project/commit/fab2bb8b makes it into Triton. +// TODO(jlebar): Remove this once we have the LLVM helpers. +template auto min_element(R &&Range) { + return std::min_element(llvm::adl_begin(Range), llvm::adl_end(Range)); +} +template +auto min_element(R &&Range, Compare &&C) { + return std::min_element(llvm::adl_begin(Range), llvm::adl_end(Range), + std::forward(C)); +} +template auto max_element(R &&Range) { + return std::max_element(llvm::adl_begin(Range), llvm::adl_end(Range)); +} +template +auto max_element(R &&Range, Compare &&C) { + return std::max_element(llvm::adl_begin(Range), llvm::adl_end(Range), + std::forward(C)); +} + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/xpu/include/triton/Dialect/Triton/Transforms/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 000000000..372a9ec11 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Triton) +add_public_tablegen_target(TritonTransformsIncGen) diff --git a/third_party/xpu/include/triton/Dialect/Triton/Transforms/Passes.h b/third_party/xpu/include/triton/Dialect/Triton/Transforms/Passes.h new file mode 100644 index 000000000..fde54fe17 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/Transforms/Passes.h @@ -0,0 +1,21 @@ +#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITON_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { + +std::unique_ptr createCombineOpsPass(); + +std::unique_ptr createReorderBroadcastPass(); +std::unique_ptr createRewriteTensorPointerPass(); + +} // namespace triton + +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +} // namespace mlir + +#endif diff --git a/third_party/xpu/include/triton/Dialect/Triton/Transforms/Passes.td b/third_party/xpu/include/triton/Dialect/Triton/Transforms/Passes.td new file mode 100644 index 000000000..4ebff63fa --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/Triton/Transforms/Passes.td @@ -0,0 +1,44 @@ +#ifndef TRITON_PASSES +#define TRITON_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonCombineOps : Pass { + let summary = "combine ops"; + let description = [{ + dot(a, b, 0) + c => dot(a, b, c) + + addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1)) + + select(cond, load(ptrs, broadcast(cond), ???), other) => + load(ptrs, broadcast(cond), other) + }]; + + let constructor = "mlir::triton::createCombineOpsPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect"]; +} + +def TritonReorderBroadcast : Pass { + let summary = "Moves broadcast and splat after elementwise operations"; + let description = [{ + elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) + }]; + let constructor = "mlir::triton::createReorderBroadcastPass()"; + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + +def TritonRewriteTensorPointer : Pass { + let summary = "Rewrite load/stores with tensor pointers into legacy load/stores"; + let description = [{ + This pass rewrites all load/store semantics initiated by a `tt.make_tensor_ptr` and `tt.advance` into legacy + semantics. After this pass, `tt.make_tensor_ptr` and `tt.advance` will disappear, and it generates logics to compute + the pointer/mask/other for each load/store. + }]; + + let constructor = "mlir::triton::createRewriteTensorPointerPass()"; + + let dependentDialects = ["mlir::triton::TritonDialect"]; +} + +#endif diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/IR/Attributes.h b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/Attributes.h new file mode 100644 index 000000000..a99ddfc17 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/Attributes.h @@ -0,0 +1,10 @@ +#ifndef TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ +#define TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ + +#include "mlir/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc" + +#endif // TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_ diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..73c9401c1 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt @@ -0,0 +1,21 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu) +add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(TritonGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td) +mlir_tablegen(TritonGPUAttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(TritonGPUAttrInterfaces.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(TritonGPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonGPUAttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonGPUAttrDefsIncGen) diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/IR/Dialect.h b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/Dialect.h new file mode 100644 index 000000000..5ae7848a0 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -0,0 +1,127 @@ +#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" + +// TritonGPU depends on Triton +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" +#include "triton/Dialect/TritonGPU/IR/Types.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonGPU/IR/Ops.h.inc" + +namespace mlir { +namespace triton { +namespace gpu { + +struct SharedMemory : public SideEffects::Resource::Base { + StringRef getName() final { return ""; } +}; + +unsigned getTotalElemsPerThread(Type type); + +unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape, + Type eltTy); + +SmallVector getElemsPerThread(Type type); + +// Returns the number of threads per warp that may have access to replicated +// elements. If you want non-replicated threads, use +// getThreadsPerWarpWithUniqueData. +SmallVector getThreadsPerWarp(Attribute layout); + +unsigned getWarpSize(Attribute layout); + +// Returns the number of warps per CTA that may have access to replicated +// elements. If you want non-replicated warps, use getWarpsPerCTAWithUniqueData. +SmallVector getWarpsPerCTA(Attribute layout); + +SmallVector getSizePerThread(Attribute layout); + +// Returns the number of contiguous elements that each thread +// has access to, on each dimension of the tensor. E.g. +// for a blocked layout with sizePerThread = [1, 4], returns [1, 4], +// regardless of the shape of the tensor. +SmallVector getContigPerThread(Attribute layout); + +// Returns the number of non-replicated contiguous elements that each thread +// has access to, on each dimension of the tensor. For a blocked layout +// with sizePerThread = [1, 4] and tensor shape = [128, 1], the elements +// for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], returns [1, +// 1]. Whereas for a tensor shape [128, 128], the elements for thread 0 would be +// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], returns [1, 4]. +SmallVector getUniqueContigPerThread(Attribute layout, + ArrayRef tensorShape); + +// Returns the number of threads per warp that have access to non-replicated +// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1, +// 1], threadsPerWarp = [2, 16] and tensor shape = [2, 2], threads 0, 1, 16, 17 +// have access to the full tensor, whereas the other threads have access to +// replicated elements, so this function returns [2, 2]. +SmallVector +getThreadsPerWarpWithUniqueData(Attribute layout, + ArrayRef tensorShape); + +// Returns the number of warps per CTA that have access to non-replicated +// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1, +// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2], +// returns [1, 1], since the first warp has access to the full tensor, whereas +// the other warps have access to replicated elements. +SmallVector +getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape); + +SmallVector getWarpOrder(Attribute layout); + +SmallVector getOrder(Attribute layout); + +CTALayoutAttr getCTALayout(Attribute layout); + +SmallVector getCTAsPerCGA(Attribute layout); + +SmallVector getCTASplitNum(Attribute layout); + +SmallVector getCTAOrder(Attribute layout); + +/* The difference between ShapePerCTATile and ShapePerCTA: + * (1) ShapePerCTATile is defined by SizePerThread * ThreadsPerWarp * + * WarpsPerCTA in each dimension and is independent from the tensor shape. + * (2) ShapePerCTA is defined by shape / CTASplitNum in each dimension. + * (3) In the implementation of emitIndices, ShapePerCTATile will + * be replicated or wrapped to fit ShapePerCTA. + */ +SmallVector +getShapePerCTATile(Attribute layout, + ArrayRef tensorShape = ArrayRef()); + +SmallVector getShapePerCTA(ArrayRef CTASplitNum, + ArrayRef shape); +SmallVector getShapePerCTA(Attribute layout, ArrayRef shape); +SmallVector getShapePerCTA(Type type); + +unsigned getNumWarpsPerCTA(Attribute layout); + +unsigned getNumCTAs(Attribute layout); + +bool isaDistributedLayout(Attribute layout); + +bool isExpensiveCat(CatOp cat, Attribute targetEncoding); + +// Return true if a view between the two types cannot be implemented as a no-op. +bool isExpensiveView(Type srcType, Type dstType); + +// Return a blocked encoding where the shape is distributed contiguously amongst +// the threads, warps, CTAs with 1 element per threads. +triton::gpu::BlockedEncodingAttr +getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, + int numWarps, int threadsPerWarp, int numCTAs); + +} // namespace gpu +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h new file mode 100644 index 000000000..d4f274742 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -0,0 +1,37 @@ +// Conversions from TritonGPU layouts (e.g. BlockedEncodingAttr) to +// LinearLayout. + +#include + +#include "triton/Tools/LinearLayout.h" + +namespace mlir::triton::gpu { + +// - BlockedEncodingAttrs have the following input dimensions. +// +// "register": elements in one thread +// "lane": threads in a warp +// "warp": warps in a block/CTA +// "block": blocks in a cluster +// +// - An n-dimensional SharedEncodingAttr has the following input dimensions. +// +// "offset": the n'th element in the allocation, within a particular block +// "block": blocks in a cluster +// +// All layouts have the following output dimensions. +// +// "dimi" for i in 0..n-1: the location in the n'th logical dimension of the +// output tensor. These also are not reordered according to the layout's +// `order`. +// +// You can flatten the input or output dimensions into a single dimension using +// LinearLayout::flattenIns/Outs(). +// +// Returns std::nullopt if the given layout can't be converted to an LL. +// TODO(jlebar): Remove the std::optional once all layouts are supported. +// +std::optional toLinearLayout(ArrayRef shape, + Attribute layout); + +} // namespace mlir::triton::gpu diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td new file mode 100644 index 000000000..ae23f9d13 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -0,0 +1,1301 @@ +#ifndef TRITONGPU_ATTRDEFS +#define TRITONGPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" + +//===----------------------------------------------------------------------===// +// TritonGPU Attribute Definitions +//===----------------------------------------------------------------------===// +def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let methods = [ + InterfaceMethod<"Return total element size per thread.", + "unsigned", + "getTotalElemsPerThread", + (ins "ArrayRef":$tensorShape, + "Type":$eltTy)>, + + InterfaceMethod<"Return element size per thread in each dimension.", + "SmallVector", + "getElemsPerThread", + (ins "ArrayRef":$tensorShape, + "Type":$eltTy)>, + ]; +} + +class TritonGPU_Attr traits = [], + Dialect dialect = TritonGPU_Dialect, + string baseCppClass = "::mlir::Attribute"> + : AttrDef { + + let description = [{ +TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines +how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function +\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding +to the indices of the CUDA threads allowed to access some data at index $i$. + +For example, let us consider the layout function: +\mathcal{L}(0, 0) = {0, 4} +\mathcal{L}(0, 1) = {1, 5} +\mathcal{L}(1, 0) = {2, 6} +\mathcal{L}(1, 1) = {3, 7} + +Then, attaching $\mathcal{L} to a tensor $T$ would mean that: +- T[0,0] is owned by both cuda thread 0 and 4 +- T[0,1] is owned by both cuda thread 1 and 5 +- T[1,0] is owned by both cuda thread 2 and 6 +- T[1,1] is owned by both cuda thread 3 and 7 + +Right now, Triton implements two main classes of layouts: shared, and distributed. + }]; + let attrName = "triton.gpu." # attrMnemonic; + + code extraBaseClassDeclaration = [{ + unsigned getTotalElemsPerThread(ArrayRef shape, Type eltTy) const; + SmallVector getElemsPerThread(ArrayRef shape, Type eltTy) const; + ::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const; + }]; +} + +//===----------------------------------------------------------------------===// +// CTA Layout +//===----------------------------------------------------------------------===// + +def CTALayoutAttr : TritonGPU_Attr<"CTALayout", "cta_layout"> { + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$CTAsPerCGA, + ArrayRefParameter<"unsigned">:$CTASplitNum, + ArrayRefParameter<"unsigned">:$CTAOrder + ); + + let description = [{ +Describes how blocks are distributed among the cooperate thread arrays (aka +CTAs, aka thread blocks) in a cooperate thread group (aka CTG, aka thread group +cluster). CGAs were introduced in Hopper (sm90). + +The tensor is divided up into CTASplitNum pieces, which are distributed among +the CTAsPerCGA thread blocks. Each CTA processes a subtensor of shape +`tensor_shape / CTASplitNum`. + +Example 0: The tensor shape is [64, 128] and, there are two CTAs, each +processing half the tensor [64, 64]. Then CTAsPerCGA = [1, 2] and +CTASplitNum = [1, 2]. + +Example 1: The tensor shape is [64, 128] and, there are two CTAs, both +processing the complete tensor [64, 128]. This happens when multicast is +enabled. In this case, CTAsPerCTA = [1, 2] but CTASplitNum = [1, 1]. + +Example 2: Consider a matmul AxB=C, where A=[M,K], B=[K,N], C=[M,N]. The +CTAsPerCGA for A, B, C are the same, [SplitM, SplitN], but the CTASplitNum are +different. CTASplitNum_A = [SplitM, 1], which means multicast on dim1, +CTASplitNum_B = [1, SplitN], which means multicast on dim0, CTASplitNum_C = +[SplitM, SplitN] which means no multicast. + +Currently programs with multiple CTAs per CGA are an experimental feature in +Triton, not enabled by default. + +You can leave off the CTALayout properties in the textual IR and Triton will +fill in the "default" CTALayout of CTAsPerCGA = CTASplitNum = [1...1]. In +addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to +[n-1,...,0] (it doesn't matter in this case). + }]; + + // CTALayout::get canonicalizes CTAOrder to [n,n-1,...,0] if CTAsPerCGA is + // [1...1]. The CTAOrder doesn't matter in this case. + // + // This is a little weird because if you write textual IR with a one order and + // then print it back out, you might get a different order. But it seems this + // is the best way to canonicalize an attribute in MLIR. + let builders = [ + AttrBuilder<(ins "ArrayRef":$CTAsPerCGA, + "ArrayRef":$CTASplitNum, + "ArrayRef":$CTAOrder), [{ + if (llvm::all_of(CTAsPerCGA, [](unsigned x) { return x == 1; })) { + SmallVector order; + for (int i = CTAsPerCGA.size() - 1; i >= 0; --i) + order.push_back(i); + return $_get(context, CTAsPerCGA, CTASplitNum, order); + } + return $_get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + }]>, + ]; + + let extraClassDeclaration = [{ + SmallVector getElemsPerThread(ArrayRef shape, Type eltTy) const { + llvm::report_fatal_error( + "Unsupported getElemsPerThread in CTALayoutAttr."); + } + unsigned getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { + llvm::report_fatal_error( + "Unsupported getTotalElemsPerThread in CTALayoutAttr."); + } + + static CTALayoutAttr getDefault(MLIRContext *context, int rank) { + SmallVector CTAsPerCGA(rank, 1); + SmallVector CTASplitNum(rank, 1); + SmallVector CTAOrder; + for (int i = rank - 1; i >= 0; --i) + CTAOrder.push_back(i); + return get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + } + }]; + + let genVerifyDecl = 1; + let skipDefaultBuilders = 1; +} + +//===----------------------------------------------------------------------===// +// Shared Layout Encoding +//===----------------------------------------------------------------------===// + +def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding", "shared_encoding"> { + let mnemonic = "shared"; + + let description = [{ +An encoding for tensors whose elements may be simultaneously accessed by +different cuda threads in the programs, via shared memory. In other words, +for all indices i \in Z^d, \mathcal{L}(i) = {0, 1, ..., 32*num_warps - 1}. + +In order to avoid shared memory bank conflicts, elements may be swizzled. +Here are some examples. In all cases, the input tensor is [0, 1, ..., n-1]. + +1. Basic swizzling + + #shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3], // xor with 0 + [ 5, 4, 7, 6], // xor with 1 + [10, 11, 8, 9], // xor with 2 + [15, 14, 13, 12] // xor with 3 + +Here elements of row r are xor'ed with r (or more properly, in[r][c] -> +out[r][c^r]). + +2. Multiple rows per phase + + #shared<{vec=1, perPhase=2, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 4, 5, 6, 7], + [ 9, 8, 11, 10], // phase 1 (xor with 1) + [13, 12, 15, 14] + +Elements of row r are xor'ed with r/2. In other words, perPhase=2 +means that pairs of 2 rows get the same swizzling. + +3. Max-phase applied + + $shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 5, 4, 7, 6], // phase 1 (xor with 1) + [ 8, 9, 10, 11], // phase 0 + [13, 12, 15, 14], // phase 1 + [16, 17, 18, 19], // ... + [21, 20, 23, 22], + [24, 25, 26, 27], + [29, 28, 31, 30] + +Elements of row r are xor'ed with (r/2) % 2. In other words, maxPhase=m has the +effect of limiting the maximum value of the xor to m-1. + +4. Max-phase and per-phase + + #shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}> + [ 0, 1, 2, 3], // phase 0 (xor with 0) + [ 4, 5, 6, 7], // phase 0 + [ 9, 8, 11, 10], // phase 1 (xor with 1) + [13, 12, 15, 14], // phase 1 + [16, 17, 18, 19], // phase 0 + [20, 21, 22, 23], // phase 0 + [25, 24, 27, 26], // phase 1 + [29, 28, 31, 30]] // phase 1 + +Here the xor value (the "phase", I guess?) changes every perPhase rows, up to a +maximum value of maxPhase-1. In other words, elements of row r are xor'ed with +(r/2) % 2. + +5. Adding vec + + #shared<{vec=2, perPhase=1, maxPhase=4, order=[1,0]}> + [ 0, 1, 2, 3, 4, 5, 6, 7], + [10, 11, 8, 9, 14, 15, 12, 13], + [20, 21, 22, 23, 16, 17, 18, 19], + [30, 31, 28, 29, 26, 27, 24, 25] + +When vec=2, elements are swizzled in pairs of 2. In other words, the element at +(r,c) has value + + ((c / 2) ^ r) * 2 + (c % 2). + +For MMAv3 eg Hopper GMMA, hasLeadingOffset should be true. In this case, +when the matrix is stored in shared memory, there will be an offset not +only in the stride dimension, but also in the leading dimension. For example, +a matrix of size 16x128 and data type I8 is stored in the shared memory with +64B-swizzle mode. The offset of the element with index (0, 64) will be 16*64, +compared to 1*64 when the hasLeadingOffset is false. + }]; + + // swizzle info: vec, perPhase, maxPhase + // order: the fastest-changing axis first + let parameters = ( + ins + "unsigned":$vec, + "unsigned":$perPhase, + "unsigned":$maxPhase, + ArrayRefParameter<"unsigned">:$order, + "CTALayoutAttr":$CTALayout, + "bool":$hasLeadingOffset + ); + + let builders = [ + AttrBuilder<(ins "unsigned":$vec, + "unsigned":$perPhase, + "unsigned":$maxPhase, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout), [{ + bool hasLeadingOffset = false; // default value + return $_get(context, vec, perPhase, maxPhase, order, CTALayout, hasLeadingOffset); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$typeWidthInBit), [{ + bool needTrans = false; // default value + return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans); + }]>, + + // TODO(jlebar): This should not be an overload of + // SharedEncodingAttr::get(). It's misleading, because it does a bunch of + // nontrivial work based on the given dotOpEnc. + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$typeWidthInBit, + "bool":$needTrans), [{ + + // ---- begin GFX908/GFX90A ---- + if (auto mfmaEnc = mlir::dyn_cast(dotOpEnc.getParent())) { + int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0; + if (needTrans) + kDimNum = 1 - kDimNum; + bool isKDimInner = (order[0] == kDimNum); + if (isKDimInner) { + const int numBanks = 32; + const int bankBitWidth = 32; + const int SIMDWidth = 16; + + // number of inner dimension rows per one pattern repeat + int innerDimLength = shape[order[0]]; + int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; + + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + // vecSize is set to kWidth of the dotop layout + int vecSize = dotOpEnc.getKWidth(); + int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize); + + // TODO (zhanglx): figure out better parameters for mfma4 + if (mfmaEnc.getMDim() == 4) + maxPhase = 4; + + return get(context, vecSize, perPhase, maxPhase, order, CTALayout); + } else { + // Do not swizzle in case k dimension is not innermost. + // In this case accesses will go in different banks even without swizzling. + return get(context, 1, 1, 1, order, CTALayout); + } + } + + // ---- begin GFX11 ---- + if (mlir::isa(dotOpEnc.getParent())) { + if (dotOpEnc.getOpIdx() == 0) { + const int numBanks = 32; + const int bankBitWidth = 32; + + // number of inner dimension rows per one pattern repeat + int innerDimLength = shape[order[0]]; + int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; + + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit; + int maxPhase = 16 / perPhase; + + return get(context, vecSize, perPhase, maxPhase, order, CTALayout); + } else { + // Do not swizzle in case k dimension is not innermost. + // In this case accesses will go in different banks even without swizzling. + return get(context, 1, 1, 1, order, CTALayout); + } + } + + + auto mmaEnc = mlir::dyn_cast(dotOpEnc.getParent()); + + if(!mmaEnc) + return get(context, 1, 1, 1, order, CTALayout); + + int opIdx = dotOpEnc.getOpIdx(); + auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + + // number of rows per phase + + // index of the inner dimension in `order` + unsigned inner = (opIdx == 0) ? 0 : 1; + + // ---- begin Volta ---- + if (mmaEnc.isVolta()) { + int perPhase = 128 / (shapePerCTA[order[0]] * (typeWidthInBit / 8)); + perPhase = std::max(perPhase, 1); + bool is_row = order[0] != 0; + bool is_vec4 = opIdx == 0 ? !is_row && (shapePerCTA[order[0]] <= 16) : + is_row && (shapePerCTA[order[0]] <= 16); + int pack_size = opIdx == 0 ? ((is_row || is_vec4) ? 1 : 2) : + ((is_row && !is_vec4) ? 2 : 1); + int rep = 2 * pack_size; + int maxPhase = (order[inner] == 1 ? 8 : 4) / perPhase; + int vec = 2 * rep; + return get(context, vec, perPhase, maxPhase, order, CTALayout); + } + + // ---- begin Ampere ---- + if (mmaEnc.isAmpere()) { + int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()); + perPhase = std::max(perPhase, 1); + std::vector matShape = {8, 8, 4 * dotOpEnc.getKWidth()}; + int vecWidth = 32 / typeWidthInBit; + if (vecWidth != dotOpEnc.getKWidth() && order[0] == inner) { + perPhase = std::max(perPhase, 2 * vecWidth); + } + int rank = order.size(); + // --- handle A operand --- + if (opIdx == 0) { // compute swizzling for A operand + int m = (needTrans) ? matShape[2] : matShape[0]; + int k = (needTrans) ? matShape[0] : matShape[2]; + int vec = (order[0] == rank-1) ? k : m; + int mmaStride = (order[0] == rank-1) ? m : k; + int maxPhase = mmaStride / perPhase; + return get(context, vec, perPhase, maxPhase, order, CTALayout); + } + + // --- handle B operand --- + if (opIdx == 1) { + // we compute vec and maxPhase m, n and k size of the mma + // instruction. when matmul operands is transposed, we should + // consider that to get m, n and k. + int n = needTrans ? matShape[2] : matShape[1]; + int k = needTrans ? matShape[1] : matShape[2]; + int vec = (order[0] == rank-1) ? n : k; + int mmaStride = (order[0] == rank-1) ? k : n; + int maxPhase = mmaStride / perPhase; + return get(context, vec, perPhase, maxPhase, order, CTALayout); + } + + llvm_unreachable("invalid operand index"); + } + + // ---- begin version 3 ---- + if (mmaEnc.isHopper()) { + llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr" + " is Hopper has not been implemented yet"); + return $_get(context, 1, 1, 1, order, CTALayout, true); + } + + // ---- not implemented ---- + llvm_unreachable("unsupported swizzling for provided MMA version"); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CTALayout, bitwidth); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy, + "bool":$needTrans), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans); + }]>, + + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy), [{ + auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + + int32_t eleBitWidth = eltTy.getIntOrFloatBitWidth(); + int32_t vec = 128 / eleBitWidth, perPhase = 1, maxPhase = 1; + + // get proper shared memory swizzling mode from the contiguous dimension + // size of the origin blocked layout. + auto contigDimSizeInByte = shapePerCTA[order[0]] * eleBitWidth / 8; + if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) { + perPhase = 1; + maxPhase = 8; + } else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) { + perPhase = 2; + maxPhase = 4; + } else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) { + perPhase = 4; + maxPhase = 2; + } else { + llvm_unreachable("unsupported shared memory layout for MMAv3"); + } + + return $_get(context, vec, perPhase, maxPhase, order, CTALayout, true); + }]> + ]; + + let extraClassDeclaration = extraBaseClassDeclaration; + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// Distributed Layout Encoding +//===----------------------------------------------------------------------===// +def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + + let description = [{ +The Distributed encoding describes the layout L with the 4-level compute hierarchy on GPU. +It is abstracted from the top to the bottom as CTAs Per CGA->Warps Per CTA->Threads Per Warp->Values Per Thread. + +For CTAs Per CGA and Warps Per CTA level, the linear id is distributed contiguously with the shape and order. +For example, for a shape/order pair defines a distribution layout +shape = [4, 4] +order = [0, 1] // The fastest-changing axis first +-> +layout = [0 4 8 12] + [1 5 9 13] + [2 6 10 14] + [3 7 11 15] + +For the Threads Per Warp and Values Per Thread level, the linear id distribution is variant for each sub-class encoding. + }]; + + let methods = [ + // Interface for the meta information about the multiple thread hierarchy. + InterfaceMethod<"Get the shape of the CTAs per CGA.", + "SmallVector", + "getCTAsPerCGA">, + + InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first", + "SmallVector", + "getCTAOrder">, + + InterfaceMethod<"Get the shape of the warps per CTA.", + "SmallVector", + "getWarpsPerCTA">, + + InterfaceMethod<"Get the order of the warps per CTA. The fastest-changing axis first", + "SmallVector", + "getWarpOrder">, + + InterfaceMethod<"Get the shape of the threads per warp", + "SmallVector", + "getThreadsPerWarp">, + + InterfaceMethod<"Get the order of the threads per warp. The fastest-changing axis first", + "SmallVector", + "getThreadOrder">, + + InterfaceMethod<"Get the shape of the values per thread.", + "SmallVector", + "getSizePerThread">, + + InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.", + "SmallVector", + "getCTASplitNum">, + + InterfaceMethod<"Gets the shape of the encoding's tile, e.g. sizePerThread * threadsPerWarp * warpsPerCTA", + "SmallVector", + "getShapePerCTATile", + (ins "ArrayRef":$tensorShape)>, + + InterfaceMethod<"Gets the number of contiguous elements per thread.", + "SmallVector", + "getContigPerThread">, + ]; +} + +class DistributedEncoding traits = [], + Dialect dialect = TritonGPU_Dialect> + : TritonGPU_Attr { + + let description = [{ +Distributed encodings have a layout function L that is entirely characterized +by a d-dimensional tensor T. Note that L doesn't need to have the same shape +(or even the same rank) as the tensor it is encoding. + +The layout function \mathcal{L} of this layout is then defined, for an +index `i` \in Z^d, as follows: + +\mathcal{L}(T)[i_d] = L[(i_d + k_d*T.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*T.shape[d] < L.shape[d] + +Intuitively, when the tensor dim size T.shape[d] is larger than the layout +dim size L.shape[d], on that particular dim, we distribute values from the +tensor to threads mapped in the layout in a "wrapped around" manner, with +each thread owning multiple values. + +OTOH, when the tensor dim size T.shape[d] is smaller than the layout +dim size L.shape[d], on that particular dim, we distribute values from the +tensor to threads mapped in the layout in a "broadcasted" manner, with +each value owned by multiple threads. + +For example, for a tensor/layout pair +T = [x x x x x x x x] + [x x x x x x x x] +L = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] + +Then the data of T would be distributed as follow between the 16 CUDA threads: +L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, + {4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ] + }]; + + code extraDistributedDeclaration = extraBaseClassDeclaration # [{ + SmallVector getCTAsPerCGA() const; + SmallVector getCTAOrder() const; + SmallVector getCTASplitNum() const; + SmallVector getWarpsPerCTA() const; + SmallVector getWarpOrder() const; + SmallVector getThreadsPerWarp() const; + SmallVector getThreadOrder() const; + + SmallVector getSizePerThread() const; + SmallVector getShapePerCTATile(ArrayRef tensorShape = ArrayRef()) const; + }]; +} + +//===----------------------------------------------------------------------===// +// Blocked Layout Encoding +//===----------------------------------------------------------------------===// + +def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding", "blocked_encoding"> { + let mnemonic = "blocked"; + + let description = [{ +An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout +used to promote memory coalescing in LoadInst and StoreInst. +It is characterized by three tuples -- thread tile size, warp tile size, and block tile size -- which +specify the amount of elements owned by each CUDA thread, warp and CTA respectively. + +Example 1, a row-major coalesced layout may partition a 16x16 tensor over 2 warps (i.e. 64 threads) as follows: + +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] + +for + +#triton_gpu.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {1, 1} + CTASplitNum = {1, 1} +}> + +Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) as follows: + +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +for + +#triton_gpu.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {1, 1} + CTASplitNum = {1, 1} +}> + +Example 3, A row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) and +4 CTAs (taking 2x2 for example) as follows: + +CTA [0,0] CTA [0,1] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] + +CTA [1,0] CTA [1,1] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] [ 0 0 1 1 2 2 3 3 ; 32 32 33 33 34 34 35 35 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +[ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] [ 4 4 5 5 6 6 7 7 ; 36 36 37 37 38 38 39 39 ] +... ... +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +[ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] [ 28 28 29 29 30 30 31 31 ; 60 60 61 61 62 62 63 63 ] +for + +#triton_gpu.blocked_layout<{ + sizePerThread = {2, 2} + threadsPerWarp = {8, 4} + warpsPerCTA = {1, 2} + CTAsPerCGA = {2, 2} + CTASplitNum = {2, 2} +}> +}]; + + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$sizePerThread__, + ArrayRefParameter<"unsigned">:$threadsPerWarp__, + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first + + // CTALayout is optional in the textual IR. If omitted, we infer it to be a + // single CTA (so CTAsPerCGA = [1,...,1], CTASplitNum = [1,...,1], + // CTAOrder=[n,n-1,...,0]). + "CTALayoutAttr":$CTALayout + ); + let genVerifyDecl = 1; + + let builders = [ + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$sizePerThread, + "ArrayRef":$order, + "unsigned":$numWarps, + "unsigned":$numThreadsPerWarp, + "CTALayoutAttr":$CTALayout), [{ + unsigned rank = sizePerThread.size(); + SmallVector threadsPerWarp(rank); + SmallVector warpsPerCTA(rank); + SmallVector shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); + + unsigned remainingLanes = numThreadsPerWarp; + unsigned remainingThreads = numWarps * numThreadsPerWarp; + unsigned remainingWarps = numWarps; + unsigned prevLanes = 1; + unsigned prevWarps = 1; + + // starting from the contiguous dimension + for (unsigned d = 0; d < rank - 1; ++d) { + unsigned i = order[d]; + unsigned threadsPerCTA = std::clamp(remainingThreads, 1, shapePerCTA[i] / sizePerThread[i]); + threadsPerWarp[i] = std::clamp(threadsPerCTA, 1, remainingLanes); + warpsPerCTA[i] = std::clamp(threadsPerCTA / threadsPerWarp[i], 1, remainingWarps); + remainingWarps /= warpsPerCTA[i]; + remainingLanes /= threadsPerWarp[i]; + remainingThreads /= threadsPerCTA; + prevLanes *= threadsPerWarp[i]; + prevWarps *= warpsPerCTA[i]; + } + + // Expand the last dimension to fill the remaining lanes and warps + threadsPerWarp[order[rank - 1]] = numThreadsPerWarp / prevLanes; + warpsPerCTA[order[rank - 1]] = numWarps / prevWarps; + + return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout); + }]>, + + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$sizePerThread, + "ArrayRef":$order, + "unsigned":$numWarps, + "unsigned":$numThreadsPerWarp, + "unsigned":$numCTAs), [{ + unsigned rank = sizePerThread.size(); + SmallVector CTAsPerCGA(rank); + SmallVector CTASplitNum(rank); + ArrayRef CTAOrder = order; + + unsigned remainingCTAs = numCTAs; + + // starting from the most strided dimension + for (int d = rank - 1; d >= 0; --d) { + unsigned i = order[d]; + CTAsPerCGA[i] = std::clamp(remainingCTAs, 1, shape[i] / sizePerThread[i]); + CTASplitNum[i] = CTAsPerCGA[i]; + remainingCTAs /= CTAsPerCGA[i]; + } + + CTAsPerCGA[rank - 1] *= remainingCTAs; // wrap at CTA level + + CTALayoutAttr CTALayout = CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + return get(context, shape, sizePerThread, order, numWarps, numThreadsPerWarp, CTALayout); + }]> + ]; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + SliceEncodingAttr squeeze(int axis); + + SmallVector getContigPerThread() { + // Block encoding is dense stride layout. The elements per thread are contiguous. + return getSizePerThread(); + }; + }]; + + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// MMA Layout Encoding +//===----------------------------------------------------------------------===// +// TODO: MMAv1 and MMAv2 should be two instances of the same class +def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + let methods = [ + + InterfaceMethod<"Return whether the layout support reduction op.", + "bool", + "supportReduction">, + + InterfaceMethod<"Return shape per CTA.", + "SmallVector", + "getShapePerCTATileForDotOperands", + (ins "ArrayRef":$tensorShape, + "unsigned":$opIdx)>, + + InterfaceMethod<"Return total element size per thread for dot operands.", + "unsigned", + "getTotalElemsPerThreadForOperands", + (ins "ArrayRef":$tensorShape, + "Type":$eltTy, + "unsigned":$kWidth, + "unsigned":$opIdx)>, + + InterfaceMethod<"Return size per thread for dot operands.", + "SmallVector", + "getSizePerThreadForOperands", + (ins "unsigned":$opIdx)>, + ]; +} + +def AMDMfmaEncodingAttr : DistributedEncoding<"AMDMfmaEncoding", "amd_mfma_encoding", [MmaEncodingTrait]> { + let mnemonic = "amd_mfma"; + + let description = [{ +An encoding for tensors that have been produced by MFMA matrix core instructions, +available on AMD Instinct GPUs of CDNA architectures. + +It is characterized by the following parameters: +- `versionMajor` and `versionMinor` indicates the GPU architecture: + - 1.0: gfx908, i.e. MI100 + - 2.0: gfx90a: i.e. MI200, MI210, MI250 + - 3.0: gfx940, gfx941, gfx942: MI300 +- `warpsPerCTA` indicates the wave layout in the workgroup. +- `MDim` and `NDim` indicate the dimension of the output of the mfma instruction. +- `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout +without going to shared memory. This is used in the case of chained dot (E.g. Flash-Attention kernel). + +Example 1: +Suppose we have a tensor with a shape of [32, 64], warpsPerCTA set to [1, 2] and MDim=NDim=32. +The data will be distributed between threads as follows: + + wave 0 wave 1 +-----------------/\-------------- -----------------/\-------------- +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 0 1 2 3 ...... 30 31 ] [ 64 65 66 67 ...... 94 95 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] +[ 32 33 34 35 ...... 62 63 ] [ 96 97 98 99 ...... 126 127 ] + +Example 2: +Suppose we have a tensor with a shape of [16, 32], warpsPerCTA set to [1, 2] and MDim=NDim=16. +The data will be distributed between threads as follows: + + wave 0 wave 1 +-----------------/\------------- ------------------/\--------------- +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 0 1 2 3 ...... 14 15 ] [ 64 65 66 67 ...... 78 79 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 16 17 18 19 ...... 30 31 ] [ 80 81 82 83 ...... 94 95 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 32 33 34 35 ...... 46 47 ] [ 96 97 98 99 ...... 110 111 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] +[ 48 49 50 51 ...... 62 63 ] [ 112 113 114 115 ...... 126 127 ] + +Example 3: +Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and nonKDim set to 4. +The data will be distributed between threads as follows(note that each element is duploicated in 16 threads): +Suppose we have a tensor with a shape of [8, 8], warpsPerCTA set to [2, 2] and MDim=NDim=4. +The data will be distributed between threads as follows(note that each element is duplicated in 16 threads): + +M N -> wave 0 wave 2 +| --------------------------/\-------------------------- ------------------------------/\------------------------------ +V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,133...189 130,134...190 131,135...191 ] + wave 1 wave 3 + --------------------------/\-------------------------- ------------------------------/\------------------------------ + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] + [ 64,68...124 65,69...125 66,70...126 67,71...127 ] [ 192,196...252 193,197...253 194,198...254 195,199...255 ] +}]; + + let parameters = ( + ins + "unsigned": $versionMajor, + "unsigned": $versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + "unsigned":$MDim, + "unsigned":$NDim, + "bool":$isTransposed, + "CTALayoutAttr":$CTALayout + ); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool supportReduction() const { + return true; + } + SmallVector getSizePerThreadForOperands(unsigned opIdx) const; + SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; + unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + SmallVector getMFMAInstrShapeForOperands(int kWidth, int opIdx) const; + SmallVector getMFMARepForOperands(ArrayRef operandShape, int kWidth, int opIdx) const; + + SmallVector getContigPerThread() { + auto rank = getWarpsPerCTA().size(); + SmallVector contigPerThread(rank, 1); + if (getIsTransposed()) + contigPerThread[rank - 1] = 4; + else + contigPerThread[rank - 2] = 4; + return contigPerThread; + }; + + }]; + + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; +} + +def AMDWmmaEncodingAttr : DistributedEncoding<"AMDWmmaEncoding", "amd_wmma_encoding", [MmaEncodingTrait]> { + let mnemonic = "amd_wmma"; + + let description = [{ +An important limitation of WMMA for layout is a shape for tiles proccessed +by a single wave. It is [16, 16]. +This encoding assumes specific access to matrix elements by threads. + +Example: +Suppose we have a tensor with shape [32, 48], `warpsPerCTA` set to [2, 3]. + + wave 0 [16, 16] wave 1 [16, 16] wave 2 [16, 16] +-----------/\---------- -----------/\---------- -----------/\---------- +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +... ... ... +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + + wave 3 [16, 16] wave 4 [16, 16] wave 5 [16, 16] +-----------/\---------- -----------/\---------- -----------/\---------- +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] +... ... ... +[0 1 2 ... 14 15] [0 1 2 ... 14 15] [0 1 2 ... 14 15] +[16 17 18 ... 30 31] [16 17 18 ... 30 31] [16 17 18 ... 30 31] + }]; + + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + "CTALayoutAttr":$CTALayout + ); + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool supportReduction() const { + return true; + } + SmallVector getSizePerThreadForOperands(unsigned opIdx) const; + SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; + unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + SmallVector getWMMAElemsPerInstrForOperands() const; + SmallVector getWMMARepForOperands(ArrayRef operandShape, + Type elemType, int kWidth, int opIdx) const; + static SmallVector getMNKDimPerWMMAInstr(); + + SmallVector getContigPerThread() { + auto rank = getWarpsPerCTA().size(); + SmallVector contigPerThread(rank, 1); + return contigPerThread; + }; + }]; +} + +def NvidiaMmaEncodingAttr : DistributedEncoding<"NvidiaMmaEncoding", "nvidia_mma_encoding", [MmaEncodingTrait]> { + let mnemonic = "nvidia_mma"; + + let description = [{ +An encoding for tensors that have been produced by tensor cores. + +It is characterized by two parameters: +- A 'versionMajor' which specifies the generation the tensor cores + whose output is being partitioned: + - 1 for first-gen tensor cores (Volta), and + - 2 for second-gen tensor cores (Turing/Ampere). +- A 'versionMinor' which indicates the specific layout of a tensor core + generation, e.g. for Volta, there might be multiple kinds of layouts + annotated by 0,1,2 and so on. +- A `blockTileSize` to indicate how data should be partitioned between warps. + +// -------------------------------- version = 1 --------------------------- // + +For first-gen tensor cores, the implicit warpTileSize is [16, 16]. +Note: the layout is different from the recommended in PTX ISA +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html +(mma.884 section, FP32 accumulator). + +For example, when versionMinor=1, the matrix L corresponding to +blockTileSize=[32,16] is: + + warp 0 +--------------------------------/\------------------------------- +[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ] +[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ] +[ 0 0 2 2 8 8 10 10 0 0 2 2 8 8 10 10 ] +[ 1 1 3 3 9 9 11 11 1 1 3 3 9 9 11 11 ] +[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ] +[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ] +[ 4 4 6 6 12 12 14 14 4 4 6 6 12 12 14 14 ] +[ 5 5 7 7 13 13 15 15 5 5 7 7 13 13 15 15 ] +[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ] +[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ] +[ 16 16 18 18 20 20 22 22 16 16 18 18 20 20 22 22 ] +[ 17 17 19 19 21 21 23 23 17 17 19 19 21 21 23 23 ] +[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ] +[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ] +[ 24 24 26 26 28 28 30 30 24 24 26 26 28 28 30 30 ] +[ 25 25 27 27 29 29 31 31 25 25 27 27 29 29 31 31 ] + + warp 1 = warp0 + 32 +--------------------------------/\------------------------------- +[ 32 32 34 34 40 40 42 42 32 32 34 34 40 40 42 42 ] +[ 33 33 35 35 41 41 43 43 33 33 35 35 41 41 43 43 ] +[ ............................................................... ] + + +// -------------------------------- version = 2 --------------------------- // + +For second-gen tensor cores, the implicit warpTileSize is [16, 8]. +Information about this layout can be found in the official PTX documentation +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html +(mma.16816 section, FP32 accumulator). + +For example, the matrix L corresponding to blockTileSize=[32,16] is: + warp 0 warp 2 +-----------------/\------------- ----------------/\------------- +[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35 +[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39 +[ .............................. .............................. +[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63 +[ 0 0 1 1 2 2 3 3 32 32 33 33 34 34 35 35 +[ 4 4 5 5 6 6 7 7 36 36 37 37 38 38 39 39 +[ .............................. .............................. +[ 28 28 29 29 30 30 31 31 60 60 61 61 62 62 63 63 + + warp 1 warp 3 +----------------/\------------- ----------------/\------------- +[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99 +[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103 +[ .............................. ............................... +[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127 +[ 64 64 65 65 66 66 67 67 96 96 97 97 98 98 99 99 +[ 68 68 69 69 70 70 71 71 100 100 101 101 102 102 103 103 +[ .............................. ............................... +[ 92 92 93 93 94 94 95 95 124 124 125 125 126 126 127 127 + +}]; + + let parameters = ( + ins + "unsigned":$versionMajor, + "unsigned":$versionMinor, + ArrayRefParameter<"unsigned">:$warpsPerCTA__, + "CTALayoutAttr":$CTALayout, + ArrayRefParameter<"unsigned">:$instrShape + ); + + let builders = [ + // Specially for MMAV1(Volta) + AttrBuilder<(ins "int":$versionMajor, + "int":$numWarps, + "CTALayoutAttr":$CTALayout, + "ArrayRef":$instrShape, + "ArrayRef":$shapeC, + "bool":$isARow, + "bool":$isBRow, + "bool":$isAVec4, + "bool":$isBVec4, + "int":$id), [{ + assert(versionMajor == 1 && "This builder is specially for versionMajor==1"); + // 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4] + int versionMinor = (isARow * (1<<0)) |\ + (isBRow * (1<<1)) |\ + (isAVec4 * (1<<2)) |\ + (isBVec4 * (1<<3)); + + // TODO: Share code with + // DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the + // rep,spw and fpw. + SmallVector wpt({1, 1}); + SmallVector wpt_nm1; + + SmallVector rep(2), spw(2); + std::array fpw{{2, 2, 1}}; + int packSize0 = (isARow || isAVec4) ? 1 : 2; + rep[0] = 2 * packSize0; + spw[0] = fpw[0] * 4 * rep[0]; + + int packSize1 = (isBRow && !isBVec4) ? 2 : 1; + rep[1] = 2 * packSize1; + spw[1] = fpw[1] * 4 * rep[1]; + + do { + wpt_nm1 = wpt; + if (wpt[0] * wpt[1] < numWarps) + wpt[0] = std::clamp(wpt[0] * 2, 1, shapeC[0] / spw[0]); + if (wpt[0] * wpt[1] < numWarps) + wpt[1] = std::clamp(wpt[1] * 2, 1, shapeC[1] / spw[1]); + } while (wpt_nm1 != wpt); + + return $_get(context, versionMajor, versionMinor, wpt, CTALayout, instrShape); + }]>, + + + AttrBuilder<(ins "int":$versionMajor, + "int":$numWarps, + "CTALayoutAttr":$CTALayout, + "ArrayRef":$instrShape, + "ArrayRef":$shapeA, + "ArrayRef":$shapeB, + "ArrayRef":$shapeC, + "bool":$isARow, + "bool":$isBRow, + "int":$id), [{ + assert(versionMajor == 1 && "This builder is specially for versionMajor==1"); + bool isAVec4 = !isARow && (shapeA[isARow] <= 16); + bool isBVec4 = isBRow && (shapeB[isBRow] <= 16); + return get(context, versionMajor, numWarps, CTALayout, instrShape, shapeC, isARow, isBRow, isAVec4, isBVec4, id); + }]> + ]; + + let extraClassDeclaration = extraDistributedDeclaration # [{ + bool isVolta() const; + bool isTuring() const; + bool isAmpere() const; + bool isHopper() const; + + unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef shape) const; + + // Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor + std::tuple decodeVoltaLayoutStates() const; + + // Number of bits in versionMinor to hold the ID of the MMA encoding instance. + // Here 5 bits can hold 32 IDs in a single module. + static constexpr int numBitsToHoldMmaV1ID{5}; + + // For MMA v1, method `getMMAv1IsRow` returns whether e.g. the a operand is used + // in the context of an mma.884.row.col or an mma.884.col.col operation. See the PTX ISA documentation + // section 9.7.13.4.1 for more details. + bool getMMAv1IsRow(int opIdx) const; + bool getMMAv1IsVec4(int opIdx) const; + int getMMAv1NumOuter(ArrayRef shape, int opIdx) const; + SmallVector getMMAv1Rep(int opIdx) const; + SmallVector getMMAv1ShapePerWarp(int opIdx) const; + int getMMAv1Vec(int opIdx) const; + SmallVector getMMAv2Rep(ArrayRef shape, + int bitwidth, int opIdx) const; + + bool supportReduction() const { + if (isAmpere() || isHopper()) { + return true; + } + return false; + }; + SmallVector getSizePerThreadForOperands(unsigned opIdx) const; + SmallVector getShapePerCTATileForDotOperands(ArrayRef shape, int opIdx) const; + unsigned getTotalElemsPerThreadForOperands(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; + + SmallVector getContigPerThread() { + assert(isVolta() || isAmpere() || isHopper()); + auto rank = getWarpsPerCTA().size(); + SmallVector contigPerThread(rank, 1); + contigPerThread[rank - 1] = 2; + return contigPerThread; + }; + + }]; + + let hasCustomAssemblyFormat = 1; +} + +def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> { + let mnemonic = "slice"; + + let description = [{ + Given a `parent` layout and a `dim`, squeezes the given `dim` in the `parent` + layout and distributes values in a tensor T according to the new layout. + + For example, given + + T = [x x x x x x x x] + L_parent = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] (with 16 CUDA threads) + + With dim = 0, squeezing out dim 0, we have + L = [{0,4,8,12}, {1,5,9,13}, {2,6,10,14}, {3,7,11,15} ] + + Then the data of T would be distributed as follow between the 16 CUDA threads: + L(T) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15}, {0,4,8,12} , ..., {3,7,11,15} ] + + With dim = 1, squeezing out dim 1, we have + L = [ {0,1,2,3}, {4,5,6,7}, {8,9,10,11}, {12,13,14,15} ] + + Then the data of T would be distributed as follow between the 16 CUDA threads: + L = [ {0,1,2,3}, {4,5,6,7}, ..., {12,13,14,15}, {0,1,2,3}, ..., {12,13,14,15} ] + + This is useful for constructing the inverse layout of an expand_dims operation + during some optimization passes. + }]; + + let parameters = ( + ins + "unsigned":$dim, + // TODO: constraint here to only take distributed encodings + "Attribute":$parent + ); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + template + SmallVector paddedShape(ArrayRef shape) const; + + SmallVector getContigPerThread() { + auto parentLayout = mlir::cast(getParent()); + auto parentContigPerThread = parentLayout.getContigPerThread(); + parentContigPerThread.erase(parentContigPerThread.begin() + getDim()); + return parentContigPerThread; + }; + }]; + + let hasCustomAssemblyFormat = 1; +} + +def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> { + let mnemonic = "dot_op"; + + let description = [{ +In the TritonGPU dialect, given `d = tt.dot a, b, c` tt.dot's operands a and b +must be of DotOperandEncodingAttr layout, if the dot is MMA v1 or v2 (i.e. +pre-Hopper). For MMA v3, the operands are *almost always* in a regular shared +encoding, but sometimes the LHS is also a dot-operand encoding. + +a's opIdx is 0, b's opIdx is 1. + +The parent field is the layout of d. + +kWidth defines number of consecutive elements stored by one thread along k dimension. +Some layouts do not use this parameter, either because they have a fixed number of +elements along the K dim, or they use all elements of the tensor along the K dim. + }]; + + let parameters = ( + ins + "unsigned":$opIdx, + "Attribute":$parent, + DefaultValuedParameter<"unsigned", "0">:$kWidth + ); + + let builders = [ + // Specially for MMAV1(Volta) + AttrBuilder<(ins "unsigned":$opIdx, + "Attribute":$parent, + "Type":$eltTy), [{ + NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast(parent); + if (!parentAttr || !parentAttr.isAmpere()) + return $_get(context, opIdx, parent, 0); + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + unsigned MMAv2kWidth = 32 / bitwidth; + return $_get(context, opIdx, parent, MMAv2kWidth); + }]> + ]; + + let assemblyFormat = "`<` `{` struct(params) `}` `>`"; + let genVerifyDecl = 1; + let extraClassDeclaration = extraDistributedDeclaration # [{ + SmallVector getContigPerThread() { + return getSizePerThread(); + }; + }]; +} + +#endif diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td new file mode 100644 index 000000000..10f2c8c68 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -0,0 +1,54 @@ +#ifndef TRITONGPU_DIALECT +#define TRITONGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonGPU_Dialect : Dialect { + let name = "triton_gpu"; + + let cppNamespace = "::mlir::triton::gpu"; + + let hasOperationAttrVerify = 1; + + let description = [{ + Triton GPU Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect", + "mlir::gpu::GPUDialect", + "tensor::TensorDialect", + ]; + + let extraClassDeclaration = [{ + static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; } + static int getNumWarps(ModuleOp mod) { + if (!mod->hasAttr("triton_gpu.num-warps")) + llvm::report_fatal_error( + "TritonGPU module should contain a triton_gpu.num-warps attribute"); + return cast(mod->getAttr("triton_gpu.num-warps")).getInt(); + } + static int getNumCTAs(ModuleOp mod) { + if (!mod->hasAttr("triton_gpu.num-ctas")) + return 1; + return cast(mod->getAttr("triton_gpu.num-ctas")).getInt(); + } + void registerTypes(); + + static std::string getThreadsPerWarpAttrName() { return "triton_gpu.threads-per-warp"; } + + static int getThreadsPerWarp(ModuleOp mod) { + Attribute threadsPerWarp = mod->getDiscardableAttr("triton_gpu.threads-per-warp"); + if(!threadsPerWarp) { + return 32; + } + return cast(threadsPerWarp).getInt(); + } + }]; + + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +#endif diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h new file mode 100644 index 000000000..0ee2cfeca --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h @@ -0,0 +1,6 @@ +#ifndef TRITON_GPU_DIALECT_INTERFACES_H +#define TRITON_GPU_DIALECT_INTERFACES_H + +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc" + +#endif // TRITON_GPU_DIALECT_INTERFACES_H diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td new file mode 100644 index 000000000..2530009cb --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -0,0 +1,235 @@ +#ifndef TRITONGPU_OPS +#define TRITONGPU_OPS + +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/ViewLikeInterface.td" + +// +// Interfaces +// +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">; + +class TTG_Op traits = []> : + Op { +} + +def TTG_ConvertLayoutOp : TTG_Op<"convert_layout", + [SameOperandsAndResultShape, + SameOperandsAndResultElementType, + Pure]> { + let summary = "convert layout"; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let hasCanonicalizer = 1; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TTG_AsyncWaitOp : TTG_Op<"async_wait"> { + let summary = "async wait"; + + let arguments = (ins Variadic:$asyncToken, I32Attr:$num); + + let results = (outs TTG_AsyncToken:$retToken); + + let assemblyFormat = "$asyncToken attr-dict"; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 80; + } + }]; +} + +def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> { + let summary = "async commit group"; + + let results = (outs TTG_AsyncToken:$asyncToken); + let arguments = (ins Variadic:$inputTokens); + + let assemblyFormat = [{ + $inputTokens attr-dict + }]; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 80; + } + }]; +} + +def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + TypesMatchWith<"infer mask type from src type", + "src", "mask", "getI1SameShape($_self)", + "($_op.getOperands().size() <= 3) || std::equal_to<>()">, + TypesMatchWith<"infer other type from src type", + "src", "other", "getPointeeType($_self)", + "($_op.getOperands().size() <= 4) || std::equal_to<>()"> +]> { + let summary = "copy data from global memory to local memory asynchronously"; + + let description = [{ + This operation copies data from global memory to local memory asynchronously. + This is analogue to tt.load except the data are copied to local memory pointed + by by the memory descriptor instread of a distributed tensor. The rest of the + operands are the same as tt.load. + }]; + + let arguments = ( + ins TT_PtrTensor:$src, + TT_MemDescType:$result, + Optional:$mask, + Optional:$other, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile + ); + + let builders = [ + OpBuilder<(ins "Value":$src, "Value":$result, + "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, + ]; + + let results = (outs TTG_AsyncToken:$token); + + let extraClassDeclaration = [{ + static DenseSet getEligibleLoadByteWidth(int computeCapability) { + DenseSet validLoadBytes; + if (computeCapability >= 80) { + validLoadBytes = {4, 8, 16}; + } + return validLoadBytes; + } + }]; + + // Specify cacheModifier and evictionPolicy explicitly, instead of leaving + // them in attr-dict, because this way their values get printed as strings, + // rather than as opaque integers. + // + // Note there are no commas between other, cacheModifier, and evictionPolicy, + // due to limitations in MLIR's asm parser. + let assemblyFormat = [{ + $src `,` $result (`mask` $mask^)? (`other` $other^)? + oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + attr-dict `:` type($src) `->` type($result) + }]; +} + + +// Allocate shared memory +def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods]> { + let summary = "allocate tensor"; + let description = [{ + This operation allocates buffer in shared memory and return a descriptor + containing the address and a view of the buffer. + + Explicitly deallocating a buffer is optional; see local_dealloc. + }]; + let arguments = (ins Optional:$src); + + let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}]; + + let results = (outs TT_MemDescType:$result); +} + +// Deallocate shared memory +def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree]>]> { + let summary = "dealloc buffer"; + + let description = [{ + This operation deallocates a buffer explicitly. Using the buffer after this + operation is undefined. + + This operation is optional. If you don't explicitly dealloc a buffer, the + compiler assumes it's deallocated at the first point that post-dominates all + uses of the alloc. + + Because we assume a memdesc is dead at the first point that post-dominates + its uses, ops that wait for an async operation on a memdesc to complete + (such as triton_nvidia_gpu.dot_wait) should also take the memdesc as an + operand. + }]; + + let arguments = (ins TT_MemDescType:$src); + + // Use qualified() otherwise "!tt.memdesc" is printed as "". + let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}]; +} + +def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> { + let summary = "take a subview of the descriptor."; + + let description = [{ + This operation returns a new descriptor representing a subview of the buffer. + It doesn't affect the underlying memory. The subview can be rank-reduced. + + For example, suppose that + - the input shape is 2x4x16xf16, + - the output shape is 4x4xf16, and + - offsets = [1, 0, 4]. + + Then in Python syntax, the subview covers input[1][0:4][4:8]. + }]; + let arguments = ( + ins TT_MemDescType:$src, Variadic:$offsets); + + // Use qualified() otherwise "!tt.memdesc" is printed as "". + let assemblyFormat = [{$src `[` $offsets `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}]; + + let results = (outs TT_MemDescType:$result); + + let hasVerifier = 1; +} + +def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods]> { + let summary = "Load a buffer from local memory into a distributed tensor"; + + let description = [{ + Load a tensor from the local memory descriptor into a distributed tensor. + }]; + let arguments = (ins TT_MemDescType:$src, Optional :$token); + + let builders = [ + OpBuilder<(ins "Type":$retType, "Value":$src), + [{ + build($_builder, $_state, retType, src, /*token=*/static_cast(nullptr)); + }]>]; + + // Use qualified() otherwise "!tt.memdesc" is printed as "". + let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}]; + + let results = (outs TT_Tensor:$result); +} + +def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods]> { + let summary = "Store a distributed tensor into a buffer in local memory"; + + let description = [{ + Store a distributed tensor into a buffer in local memory. + }]; + let arguments = (ins TT_Tensor:$src, TT_MemDescType:$dst); + + // Use qualified() otherwise "!tt.memdesc" is printed as "". + let assemblyFormat = [{ + $src `,` $dst attr-dict `:` type($src) `->` qualified(type($dst)) + }]; +} + +#endif diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td new file mode 100644 index 000000000..6765ac40c --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td @@ -0,0 +1,36 @@ +#ifndef TRITONGPU_TYPES +#define TRITONGPU_TYPES + +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "mlir/IR/AttrTypeBase.td" + +class TTG_TypeDef traits = []> + : TypeDef { + let mnemonic = _mnemonic; +} + +def TTG_TokenType : TTG_TypeDef<"Token", "token"> { + let parameters = (ins "int32_t":$type); + + let builders = [ + TypeBuilder<(ins "unsigned":$type), [{ + return $_get($_ctxt, type); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +def TTG_AsyncToken : TTG_TypeDef<"AsyncToken", + "async.token", []> { + let summary = "async token type"; + let description = [{ + `ttg.async.token` is a type returned by an asynchronous operation. + It is used to establish an SSA-based link between async operations + and operations that group or synchronize the async operations. + }]; +} + +#endif diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/IR/Types.h b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/Types.h new file mode 100644 index 000000000..edf37fef6 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/IR/Types.h @@ -0,0 +1,10 @@ +#ifndef TRITONGPU_IR_TYPES_H_ +#define TRITONGPU_IR_TYPES_H_ + +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/Types.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..6be94d1a8 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonGPU) +add_public_tablegen_target(TritonGPUTransformsIncGen) diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/Passes.h new file mode 100644 index 000000000..c50d24a08 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -0,0 +1,22 @@ +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace gpu { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +} // namespace gpu +} // namespace triton +} // namespace mlir +#endif diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/Passes.td new file mode 100644 index 000000000..fdceb2cfe --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -0,0 +1,148 @@ +#ifndef TRITONGPU_PASSES +#define TRITONGPU_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> { + let summary = "pipeline"; + + let description = [{ + Applies software pipelining to loops in the module based on number of stages. + This may convert some load into asynchronous loads, and multi-buffer the data. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; + + let options = [ + Option<"numStages", "num-stages", + "int32_t", /*default*/"3", + "number of pipeline stages"> + ]; +} + +def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> { + let summary = "3xTF32 trick"; + + let description = [{ + Decompose fp32 `DotOp` instructions into 4 pointwise ops and 3 fp16 `DotOp`s + to allow using TensorCores. See https://github.com/NVIDIA/cutlass/discussions/385 + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; +} + +def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> { + let summary = "prefetch"; + + let description = [{ + Decompose `DotOp` instructions in loops into several finer-grained `DotOp` + that may have their operands constructed at the end of the previous iteration + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect"]; +} + +def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> { + let summary = "accelerate matmul"; + + let description = [{ + Optimize the input/output layout of `dot` instruction to make them compatible hardware accelerators + (e.g., Nvidia tensor cores) + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> { + let summary = "fuse transpositions"; + + let description = [{ + Re-arranged layouts of tensors used as matrix multiplication operands so as to promote the use of + hardware-accelerated transpositions. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::triton::TritonDialect"]; + + let options = [ + Option<"hoistLayoutConversion", "hoist-layout-conversion", + "bool", /*default*/"true", + "whether to move conver to dot operand earlier pass elementwise ops"> + ]; +} + +def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> { + let summary = "coalesce"; + + let description = [{ + TODO + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; +} + + +def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions", "mlir::ModuleOp"> { + let summary = "remove superfluous layout conversions"; + + let description = [{ + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + +} + +def TritonGPUOptimizeThreadLocality : Pass<"tritongpu-optimize-thread-locality", "mlir::ModuleOp"> { + let summary = "Reduce the cost of synchronization between threads in an SM"; + + let description = [{ + Today, this optimizes reduction yielded by loop to be thread-local until after the loop completes. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUReorderInstructions: Pass<"tritongpu-reorder-instructions", "mlir::ModuleOp"> { + let summary = "Reorder instructions"; + + let description = "This pass reorder instructions so as to (1) decrease register pressure (e.g., by moving " + "conversions from shared memory before their first use) and (2) promote LLVM instruction " + "order more friendly to `ptxas`."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUReduceDataDuplication: Pass<"tritongpu-reduce-data-duplication", "mlir::ModuleOp"> { + let summary = "Reduce data duplication in register by decomposing convert[distributed -> dotOperand] " + "into convert[distributed -> shared -> dotOperand]"; + + let description = "Decomposing conversions this way makes it possible to use CSE and reuse #shared tensors"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +def TritonGPUCombineTensorSelectAndIf: Pass<"tritongpu-combine-tensor-select-and-if", "mlir::ModuleOp"> { + let summary = "Combine tensor select and if"; + + let description = "For select instruction that uses the same condidtion as the if instruction in the same block " + "this pass combines the select into the if instruction, making the select operands returned by the " + "then/else yields."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + +#endif diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h b/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h new file mode 100644 index 000000000..fbfa235fc --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// +// Defines utilities to use while converting to the TritonGPU dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +class TritonGPUTypeConverter : public TypeConverter { +public: + TritonGPUTypeConverter(MLIRContext *context, int numWarps, int threadsPerWarp, + int numCTAs); + int getNumWarps() const { return numWarps; } + int getThreadsPerWarp() const { return threadsPerWarp; } + int getNumCTAs() const { return numCTAs; } + +private: + MLIRContext *context; + int numWarps; + int threadsPerWarp; + int numCTAs; +}; + +class TritonGPUConversionTarget : public ConversionTarget { + +public: + explicit TritonGPUConversionTarget(MLIRContext &ctx, + TritonGPUTypeConverter &typeConverter); +}; + +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_TRITONGPUCONVERSION_H_ diff --git a/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/Utility.h new file mode 100644 index 000000000..114c18142 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -0,0 +1,177 @@ +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include +#include + +namespace mlir { + +namespace triton { +class ModuleAxisInfoAnalysis; +class LoadOp; +class StoreOp; +class FuncOp; +namespace gpu { +class SharedEncodingAttr; +} +} // namespace triton + +SmallVector mmaVersionToInstrShape(int version, + const ArrayRef &shape, + TensorOrMemDesc type, + int numWarps); + +/// Returns true if the Load uses block pointer. +bool isLoadFromTensorPtr(triton::LoadOp op); + +// Return an array of indices enumerating the elements of 'arr' in descending +// order (so that result[i] is the index of the i-th largest element of 'arr') +SmallVector argSort(const SmallVector &arr); + +// Return the operand used to access the memory in the operation +Value getMemAccessPtr(Operation *op); + +// Return bitwidth of tensor element +unsigned getElementBitWidth(RankedTensorType type); + +// Calculate the optimal number of elements per thread for a given operation +// along an axis with greatest continuity. +unsigned +getNumElementsPerThread(Operation *op, SmallVector order, + triton::ModuleAxisInfoAnalysis &axisInfoAnalysis); + +/* Dump Triton IR in graphviz dot format. + * + * You can override `onValue` and `onOperation` in a subclass to mark + * specific Values and Operations. The below subclass + * GraphLayoutMarker is an example. + * + * Default NodeInfo for Value nodes: + * {{"shape": "box"}, + * {"style", "filled"}, + * {"fillcolor", "white"}, + * {"label", shapeStr}} + * + * Default NodeInfo for Operation nodes: + * {{"shape": "ellipse"}, + * {"style", "filled"}, + * {"fillcolor", "white"}, + * {"label", operationName}} + * + * If the key "label" is not set by `onValue` or `onOperation`, default labels + * will be generated. For Value node, the default label is the shape string and + * for Operation node, it is the operation name. + * + * Reference: + * https://graphviz.org/doc/info/shapes.html + * https://graphviz.org/doc/info/colors.html + * + * Usage: + * C++: GraphDumper().dumpToFile(func, "func.dot"); + * Shell: dot -Tjpg func.dot -o func.jpg + */ +class GraphDumper { +public: + using NodeInfo = std::map; + + // Override this function to mark specific Values + virtual NodeInfo onValue(Value value) const; + // Override this function to mark specific Operations + virtual NodeInfo onOperation(Operation *op) const; + + std::string dump(triton::FuncOp func) const; + void dumpToFile(triton::FuncOp func, const std::string &filename) const; + +protected: + std::string getShapeStr(const Type &type) const; + + std::string getUniqueId(Value value) const; + std::string getUniqueId(Operation *op) const; + + std::string emitNode(const std::string &id, const NodeInfo style) const; + std::string emitEdge(const std::string &srcId, + const std::string &destId) const; + + std::string emitValueNode(Value value) const; + std::string emitOperationNode(Operation *op) const; +}; + +/* A subclass of GraphDumper that marks different layout kinds in different + * colors.*/ +class GraphLayoutMarker : public GraphDumper { +public: + NodeInfo onValue(Value value) const override; + +protected: + std::string getColor(const Type &type) const; +}; + +// Infers the encoding of the result of op given the source encoding. +std::optional inferDstEncoding(Operation *op, Attribute encoding); + +// Infers the encoding of the source of op given the result encoding. +std::optional inferSrcEncoding(Operation *op, Attribute encoding); + +bool isExpensiveLoadOrStore(Operation *op); + +bool canFoldIntoConversion(Operation *op, Attribute targetEncoding); + +// Replace ForOp with a new ForOp with extra operands. The YieldOp is not +// updated and needs to be updated separately for the loop to be correct. +scf::ForOp replaceForOpWithNewSignature( + RewriterBase &rewriter, scf::ForOp loop, ValueRange newIterOperands, + SmallVectorImpl> &replacements); +scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, + ValueRange newIterOperands); + +// Replace IfOp with a new IfOp with extra results operands. The YieldOp is not +// updated and needs to be updated separately for the bodies to be correct. +scf::IfOp replaceIfOpWithNewSignature( + RewriterBase &rewriter, scf::IfOp loop, TypeRange newResultTypes, + SmallVectorImpl> &replacements); + +Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, + IRMapping &mapping); + +// Get backward slice of tensor values starting from the root node along with +// encoding propagation. +LogicalResult getConvertBackwardSlice( + Value root, SetVector &slice, Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation = nullptr); + +// Populate pattern to remove dead cycles in ForOp. +void populateForOpDeadArgumentElimination(RewritePatternSet &patterns); + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape, + ArrayRef order); + +SmallVector delinearize(OpBuilder &b, Location loc, unsigned linear, + ArrayRef shape); + +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape); +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order); + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape); + +// Return true if the op is a pure elementwise_inline_asm op with a single +// operand and single result. +bool isPureUnaryInlineAsm(Operation *op); + +// read the compute capability from the module attributes +int getNVIDIAComputeCapability(Operation *module); + +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..b7ce83fe7 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ -0,0 +1,19 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_nvidia_gpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_nvidia_gpu) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_nvidia_gpu) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_nvidia_gpu) +add_mlir_doc(TritonNvidiaGPUDialect TritonNvidiaGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonNvidiaGPUOps TritonNvidiaGPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(TritonNvidiaGPUTableGen) + +set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUAttrDefs.td) +mlir_tablegen(TritonNvidiaGPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonNvidiaGPUAttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonNvidiaGPUAttrDefsIncGen) diff --git a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.h b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h similarity index 63% rename from unittest/Conversion/TritonGPUToLLVM/DumpLayout.h rename to third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h index e0014ad3d..279faf9a4 100644 --- a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.h +++ b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h @@ -21,26 +21,24 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -#ifndef TRITON_UNITTEST_CONVERSION_TRITONGPU_TO_LLVM_DUMP_LAYOUT_H -#define TRITON_UNITTEST_CONVERSION_TRITONGPU_TO_LLVM_DUMP_LAYOUT_H +#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_ -#include "triton/Dialect/TritonGPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -namespace gpu { +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" -// Try to "interpret" the MLIR Value into an integer. -int evalValue(Value value, int ctaid, int tid); - -std::string dumpDistributedLayout(Attribute layout, - llvm::ArrayRef shape, bool multiCTA); +// TritonNvidiaGPU depends on Triton +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc" +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h" -std::string dumpSharedLayout(Attribute layout, llvm::ArrayRef shape, - Type elemTy, bool multiCTA); +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc" -} // namespace gpu -} // namespace triton -} // namespace mlir +#define GET_OP_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc" -#endif +#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_ diff --git a/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td new file mode 100644 index 000000000..936535bb0 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td @@ -0,0 +1,29 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_ATTRDEFS +#define TRITONNVIDIAGPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" + +#endif diff --git a/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td new file mode 100644 index 000000000..67ece715d --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td @@ -0,0 +1,67 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_DIALECT +#define TRITONNVIDIAGPU_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonNvidiaGPU_Dialect : Dialect { + let name = "triton_nvidia_gpu"; + + let cppNamespace = "::mlir::triton::nvidia_gpu"; + + let hasOperationAttrVerify = 1; + + let description = [{ + Triton Nvidia GPU Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect", + "triton::gpu::TritonGPUDialect", + "mlir::gpu::GPUDialect", + "tensor::TensorDialect", + ]; + + let extraClassDeclaration = [{ + static std::string getNumWarpsAttrName() { return "triton_gpu.num-warps"; } + static int getNumWarps(ModuleOp mod) { + if(!mod->hasAttr("triton_gpu.num-warps")) + llvm::report_fatal_error( + "TritonGPU module should contain a triton_gpu.num-warps attribute"); + return cast(mod->getAttr("triton_gpu.num-warps")).getInt(); + } + static int getNumCTAs(ModuleOp mod) { + if(!mod->hasAttr("triton_gpu.num-ctas")) + llvm::report_fatal_error( + "TritonGPU module should contain a triton_gpu.num-ctas attribute"); + return cast(mod->getAttr("triton_gpu.num-ctas")).getInt(); + } + void registerTypes(); + }]; + + let useDefaultTypePrinterParser = 1; +} + +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td" + +#endif diff --git a/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td new file mode 100644 index 000000000..486bbf553 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -0,0 +1,246 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_OPS +#define TRITONNVIDIAGPU_OPS + +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td" +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td" +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/ViewLikeInterface.td" + +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">; + +class TTNG_Op traits = []> : + Op { +} + +def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> { + let arguments = (ins BoolAttr:$bCluster); + + let summary = "fence proxy async"; + + let assemblyFormat = "attr-dict"; + + let extraClassDeclaration = [{ + static bool isSupported(int computeCapability) { + return computeCapability >= 90; + } + }]; +} + +def TTNG_ClusterArriveOp : TTNG_Op<"cluster_arrive", []> { + let arguments = (ins I1Attr:$relaxed); + let assemblyFormat = "attr-dict"; +} + +def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> { + let assemblyFormat = "attr-dict"; +} + +// +// DotAsync Op +// +def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot async"; + + let description = [{ + $d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp + }]; + + let arguments = (ins TT_TensorOrMemDesc:$a, + TT_TensorOrMemDesc:$b, + TT_FpIntTensor:$c, + TT_InputPrecisionAttr:$inputPrecision, + I32Attr:$maxNumImpreciseAcc); + + let results = (outs TT_FpIntTensor:$d); + + let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)"; +} + +def TTNG_DotWaitOp : TTNG_Op<"dot_wait", [DeclareOpInterfaceMethods, + AllTypesMatch<["inputs", "outputs"]>]> { + let summary = "dot wait"; + let arguments = (ins Variadic:$inputs, I32Attr:$pendings); + let results = (outs Variadic:$outputs); + let description = [{ + Waits until there are $pendings or fewer outstanding async dot operations. + + $inputs must be the tensors corresponding to the async dot ops that we're + waiting on. For example, if there are N pending async dot ops and we call + `dot_wait 1`, then $inputs must be the result of the first dot op. + }]; + + let assemblyFormat = "$inputs attr-dict `:` type($inputs)"; +} + +def TTNG_InitBarrierOp : TTNG_Op<"init_barrier", [DeclareOpInterfaceMethods]> { + let summary = "Initialize a barrier in the given shared memory allocation."; + + let description = [{ + Initializes a shared memory allocation with mbarrier information. + `alloc` is a descriptor to the shared memory allocation. `count` is the + number of arrives expected by the barrier. + + This lowers to PTX mbarrier.init.shared::cta.b64. + }]; + + let hasVerifier = 1; + let arguments = (ins TT_MemDescType:$alloc, + I32Attr:$count); + let assemblyFormat = "$alloc `,` $count attr-dict `:` type($alloc)"; +} + +def TTNG_InvalBarrierOp : TTNG_Op<"inval_barrier", [DeclareOpInterfaceMethods]> { + let summary = "Invalidate a barrier allocation."; + + let description = [{ + Invalidate a barrier allocation so that it can be re-used. According to PTX + spec this has to be done before any re-use of the memory used by mbarrier. + + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval + }]; + + let hasVerifier = 1; + let arguments = (ins TT_MemDescType:$alloc); + let assemblyFormat = "$alloc attr-dict `:` type($alloc)"; +} + +def TTNG_BarrierExpectOp : TTNG_Op<"barrier_expect", [DeclareOpInterfaceMethods]> { + let summary = "Signal a barrier of an expected number of bytes to be copied."; + + let description = [{ + This signal the barrier that `size` bytes are expected to be copied. The + associated barrier wait will block until the expected number of bytes are copied. + }]; + + let hasVerifier = 1; + let arguments = ( + ins TT_MemDescType:$alloc, + I32Attr:$size, + I1:$pred + ); + + let assemblyFormat = [{ + $alloc `,` $size attr-dict `,` $pred `:` type($alloc) + }]; +} + +def TTNG_WaitBarrierOp : TTNG_Op<"wait_barrier", [DeclareOpInterfaceMethods]> { + let summary = "wait until the mbarrier phase completes."; + + let description = [{ + Blocks the program progress until the mbarrier object in `alloc` completes + its current phase. + + This lowers a waitloop using PTX instruction + mbarrier.try_wait.parity.shared.b64. + + The barrier behavior is described here: + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-asynchronous-copy-completion-mechanisms + }]; + + let hasVerifier = 1; + let arguments = (ins TT_MemDescType:$alloc, + I32:$phase); + let assemblyFormat = "$alloc `,` $phase attr-dict `:` type($alloc)"; +} + + +def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local", [DeclareOpInterfaceMethods]> { + let summary = "copy data based on descriptor from global memory to local memory asynchronously"; + + let description = [{ + This operation copies data from global memory to local memory + asynchronously. This is analogue to tt.load except the data are copied to + local memory pointed by the memory descriptor instread of a distributed + tensor. The data copied depends on the global memory descriptor pointed to + by `desc_ptr`. + }]; + + let hasVerifier = 1; + let arguments = ( + ins TT_PtrType:$desc_ptr, + Variadic:$coord, + TT_MemDescType:$barrier, + TT_MemDescType:$result, + I1:$pred, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict, + DefaultValuedAttr:$isVolatile + ); + + let assemblyFormat = [{ + $desc_ptr `[` $coord `]` $result `,` $barrier `,` $pred + oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict) + attr-dict `:` type($desc_ptr) `,` type($barrier) `->` type($result) + }]; +} + +def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global", [DeclareOpInterfaceMethods]> { + let summary = "copy data based on descriptor from local memory to global memory asynchronously"; + + let description = [{ + This operation copies data from local memory to global memory + asynchronously. This is analogue to tt.store except the data are copied from + local memory pointed by the memory descriptor instread of a distributed + tensor. The data copied depends on the global memory descriptor pointed to + by `desc_ptr`. + }]; + + let arguments = ( + ins TT_PtrType:$desc_ptr, + Variadic:$coord, + TT_MemDescType:$src); + + let assemblyFormat = [{ + $desc_ptr `[` $coord `]` $src + attr-dict `:` type($desc_ptr) `,` type($src) + }]; +} + +def TTNG_TMAStoreWait : TTNG_Op<"async_tma_store_wait"> { + let summary = "wait until all the inputs are read."; + let arguments = (ins I32Attr:$pendings); + let description = [{ + Wait until all the read operations are done from the associated store operations. + This is needed before the shared memory can be written to. + }]; + + let assemblyFormat = "attr-dict"; +} + + +#endif diff --git a/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td new file mode 100644 index 000000000..d3126f8a0 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td @@ -0,0 +1,37 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_TYPES +#define TRITONNVIDIAGPU_TYPES + +include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td" +include "mlir/IR/AttrTypeBase.td" + +class TTNG_TypeDef + : TypeDef { + let mnemonic = _mnemonic; +} + +def TTNG_TokenType : TTNG_TypeDef<"Token", "token">; + +def TTNG_MutexType : TTNG_TypeDef<"Mutex", "mutex">; + +#endif diff --git a/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/Types.h b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/Types.h new file mode 100644 index 000000000..63c7a091a --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/IR/Types.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITONNVIDIAGPU_IR_TYPES_H_ +#define TRITONNVIDIAGPU_IR_TYPES_H_ + +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..d4b5c097f --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonNvidiaGPU) +add_public_tablegen_target(TritonNvidiaGPUTransformsIncGen) diff --git a/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h new file mode 100644 index 000000000..174403138 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +// Used by Triton runtime +struct ClusterInfo { + ClusterInfo() : clusterDimX(1), clusterDimY(1), clusterDimZ(1) {} + int clusterDimX; + int clusterDimY; + int clusterDimZ; +}; + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir + +namespace mlir { + +std::unique_ptr createTritonNvidiaGPUPlanCTAPass( + mlir::triton::nvidia_gpu::ClusterInfo *clusterInfo = nullptr); + +std::unique_ptr +createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90); + +std::unique_ptr createTritonNvidiaGPUTMALoweringPass(); + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +} // namespace mlir +#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_ diff --git a/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td new file mode 100644 index 000000000..6fe71ade2 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td @@ -0,0 +1,77 @@ +// Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files +// (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of the Software, +// and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#ifndef TRITONNVIDIAGPU_PASSES +#define TRITONNVIDIAGPU_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonGPUPlanCTAPass : Pass<"triton-nvidia-gpu-plan-cta", "mlir::ModuleOp"> { + let summary = "plan CTA"; + + let description = [{ + Plan CTAs in CGA + }]; + + let constructor = "mlir::createTritonNvidiaGPUPlanCTAPass()"; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::ModuleOp"> { + let summary = "Insert fences across generic and async proxy"; + + let description = [{ + }]; + + let constructor = "mlir::createTritonNvidiaGPUFenceInsertionPass()"; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; + + let options = [ + Option<"computeCapability", "compute-capability", + "int32_t", /*default*/"90", + "device compute capability"> + ]; +} + + +def TritonNvidiaGPUTMALoweringPass : Pass<"triton-nvidia-tma-lowering", "mlir::ModuleOp"> { + let summary = "lower to TMA load/store operations"; + + let description = [{ + Lower Triton experimental descriptor load to TMA load/store operations in TritonNvidiaGPUDialect. + }]; + + let constructor = "mlir::createTritonNvidiaGPUTMALoweringPass()"; + + let dependentDialects = [ + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; +} + +#endif diff --git a/third_party/xpu/include/triton/Dialect/TritonXPU/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/TritonXPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonXPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/xpu/include/triton/Dialect/TritonXPU/IR/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/TritonXPU/IR/CMakeLists.txt new file mode 100644 index 000000000..829dd31ea --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonXPU/IR/CMakeLists.txt @@ -0,0 +1,28 @@ +set(LLVM_TARGET_DEFINITIONS TritonXPUOps.td) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +# add_mlir_doc(TritonXPUOps TritonXPUOps dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS TritonXPUDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_xpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_xpu) +# add_mlir_doc(TritonXPUDialect TritonXPUDialect dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS TritonXPUTypes.td) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_xpu) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_xpu) + +set(LLVM_TARGET_DEFINITIONS TritonXPUInterfaces.td) +mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) +add_public_tablegen_target(TritonXPUTableGen) + + +set(LLVM_TARGET_DEFINITIONS TritonXPUAttrDefs.td) +mlir_tablegen(TritonXPUAttrInterfaces.h.inc -gen-attr-interface-decls) # TritonXPU_AttrTrait +mlir_tablegen(TritonXPUAttrInterfaces.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(TritonXPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonXPUAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(TritonXPUAttrDefsIncGen) diff --git a/third_party/xpu/include/triton/Dialect/TritonXPU/IR/Dialect.h b/third_party/xpu/include/triton/Dialect/TritonXPU/IR/Dialect.h new file mode 100644 index 000000000..1bf7059ac --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonXPU/IR/Dialect.h @@ -0,0 +1,63 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TRITON_DIALECT_TRITONXPU_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITONXPU_IR_DIALECT_H_ + +// TritonXPUDialect +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // cf +#include "triton/Dialect/Triton/IR/Dialect.h" // arith/scf/math/triton +#include "triton/Dialect/TritonGPU/IR/Dialect.h" // SliceEncodingAttr + +#include "triton/Dialect/TritonXPU/IR/Dialect.h.inc" // TritonXPUDialect + +// TritonXPUAttr +#include "mlir/IR/Attributes.h" +#include "triton/Dialect/TritonXPU/IR/TritonXPUAttrInterfaces.h.inc" +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonXPU/IR/TritonXPUAttrDefs.h.inc" + +// TritonXPUOps +#define GET_OP_CLASSES +#include "triton/Dialect/TritonXPU/IR/Ops.h.inc" + +// TritonXPUTypes +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonXPU/IR/Types.h.inc" + +namespace mlir { +namespace triton { +namespace xpu { + +unsigned getTotalElemsPerThread(Type eltTy); + +unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape, + Type eltTy); + +unsigned getGroupSize(Attribute layout); + +// Return a blocked encoding where the shape is distributed contiguously amongst +// the threads, warps, CTAs with 1 element per threads. +triton::xpu::ClusterLayoutAttr +getDefaultClusterEncoding(MLIRContext *context, ArrayRef shape, + uint32_t buffer_size, uint32_t core_num); + +SmallVector +getCoresPerClusterWithUniqueData(Attribute layout, + ArrayRef tensorShape); + +SmallVector +getCoresPerGroupWithUniqueData(Attribute layout, ArrayRef tensorShape); + +SmallVector getUniqueContigPerCore(Attribute layout, + ArrayRef shape); + +} // namespace xpu +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONXPU_IR_DIALECT_H_ diff --git a/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUAttrDefs.td b/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUAttrDefs.td new file mode 100644 index 000000000..b54924e97 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUAttrDefs.td @@ -0,0 +1,200 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TRITONXPU_ATTRDEFS +#define TRITONXPU_ATTRDEFS + +include "mlir/IR/Interfaces.td" // AttrInterface +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/TritonXPU/IR/TritonXPUDialect.td" + +//===----------------------------------------------------------------------===// +// TritonXPU Attribute Definitions +//===----------------------------------------------------------------------===// +def TritonXPU_AttrTrait : AttrInterface<"TritonXPU_AttrTrait"> { + let cppNamespace = "::mlir::triton::xpu"; + + let methods = [ + InterfaceMethod<"Get the shape of the values per core.", + "SmallVector", + "getSizePerCoreInterface">, + + InterfaceMethod<"Get the shape of the core per group", + "SmallVector", + "getCoresPerGroupInterface">, + + InterfaceMethod<"Get the shape of the groups per cluster.", + "SmallVector", + "getGroupsPerClusterInterface">, + + InterfaceMethod<"Get the shape of the cores per cluster.", + "SmallVector", + "getCoresPerClusterInterface">, + + InterfaceMethod<"Return total element size per thread.", + "unsigned", + "getTotalElemsPerThread", + (ins "ArrayRef":$tensorShape, + "Type":$eltTy)>, + + InterfaceMethod<"Return element size per thread in each dimension.", + "SmallVector", + "getElemsPerThread", + (ins "ArrayRef":$tensorShape, + "Type":$eltTy)>, + ]; +} + +class TritonXPU_Attr traits = [], + Dialect dialect = TritonXPU_Dialect, + string baseCppClass = "::mlir::Attribute"> + : AttrDef { + + let description = [{ + The base class of TritonXPU Encoding Attribute. + }]; + let attrName = "triton.xpu." # attrMnemonic; + + code extraBaseClassDeclaration = [{ + SmallVector getSizePerCoreInterface() const; + SmallVector getCoresPerGroupInterface() const; + SmallVector getGroupsPerClusterInterface() const; + + SmallVector getCoresPerClusterInterface() const; + + unsigned getTotalElemsPerThread(ArrayRef shape, Type eltTy) const; + SmallVector getElemsPerThread(ArrayRef shape, Type eltTy) const; + ::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const; + }]; +} + + +//===----------------------------------------------------------------------===// +// Cluster Layout +//===----------------------------------------------------------------------===// + + +def ClusterLayoutAttr : TritonXPU_Attr<"ClusterLayout", "cluster_layout"> { + let mnemonic = "cluster"; + + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$sizePerCore, + ArrayRefParameter<"unsigned">:$coresPerGroup, + ArrayRefParameter<"unsigned">:$groupsPerCluster, + ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first + "bool":$isReduceOpt + ); + + let description = [{ +For XPU hardware, the number of cores is the number of threads used. +We need try to make one thread to cover a continuous memory as long as +possible. + +Example 1, a row-major coalesced layout may partition a 32x16 tensor over 64 threads as follows. + +[ 0 0 0 0 0 0 0 0 ; 1 1 1 1 1 1 1 1 ] +[ 2 2 2 2 2 2 2 2 ; 3 3 3 3 3 3 3 3 ] +... +[ 60 60 60 60 60 60 60 60 ; 61 61 61 61 61 61 ] +[ 62 62 62 62 62 62 62 62 ; 63 63 63 63 63 63 ] + +for + +#triton_xpu.cluster_layout<{ + sizePerCore = {8} + coresPerGroup = {4} + groupsPerCluster = {16} +}> + + + +Example 2, a row-major coalesced layout may partition a 32x16 1D tensor over 64 threads as follows. + +[ 0 0 0 0 0 0 0 0 ; 0 0 0 0 0 0 0 0 ] +[ 1 1 1 1 1 1 1 1 ; 1 1 1 1 1 1 1 1 ] +... +[ 30 30 30 30 30 30 30 30 ; 30 30 30 30 30 30 30 30 ] +[ 31 31 31 31 31 31 31 31 ; 31 31 31 31 31 31 31 31 ] + +for + +#triton_xpu.cluster_layout<{ + sizePerCore = {16} + coresPerGroup = {4} + groupsPerCluster = {16} +}> + +core_32-core_63 will be idle + + + +Example 3, a row-major coalesced layout may partition a [32, 16] 2D tensor over 64 threads as follows. + +[ 0 0 0 0 0 0 0 0 ; 1 1 1 1 1 1 1 1 ] +[ 2 2 2 2 2 2 2 2 ; 3 3 3 3 3 3 3 3 ] +... +[ 60 60 60 60 60 60 60 60 ; 61 61 61 62 62 62 ] +[ 62 62 62 62 62 62 62 62 ; 63 63 63 63 63 63 ] + +for + +#triton_xpu.cluster_layout<{ + sizePerCore = {1, 8} + coresPerGroup = {1, 4} + groupsPerCluster = {1, 16} +}> + + + +Example 4, a row-major coalesced layout may partition a [32, 16] 2D tensor over 64 threads as follows. + +[ 0 0 0 0 0 0 0 0 ; 0 0 0 0 0 0 0 0 ] +[ 1 1 1 1 1 1 1 1 ; 1 1 1 1 1 1 1 1 ] +... +[ 30 30 30 30 30 30 30 30 ; 30 30 30 30 30 30 30 30 ] +[ 31 31 31 31 31 31 31 31 ; 31 31 31 31 31 31 31 31 ] + +for + +#triton_xpu.cluster_layout<{ + sizePerCore = {1, 16} + coresPerGroup = {4, 1} + groupsPerCluster = {16, 1} +}> + +core_32-core_63 will be idle + + + }]; + + let genVerifyDecl = 1; + + let builders = [ + AttrBuilder<(ins "ArrayRef":$shape, + "ArrayRef":$order, + "unsigned":$buffer_size, + "unsigned":$core_num), [{ + int rank = shape.size(); + SmallVector sizePerCore(rank, 1u); + SmallVector coresPerGroup(rank, 1u); + SmallVector groupsPerCluster(rank, 1u); + + coresPerGroup[rank-1] = core_num; + groupsPerCluster[rank-1] = 1; + sizePerCore[rank-1] = std::min(buffer_size, static_cast( + std::ceil(static_cast(shape[rank-1]) / core_num))); + + return $_get(context, sizePerCore, coresPerGroup, groupsPerCluster, order, /*isReduceOpt*/false); + }]>, + ]; + + let extraClassDeclaration = extraBaseClassDeclaration; + + // let skipDefaultBuilders = 1; // will skip get method(use parameters directly) + let hasCustomAssemblyFormat = 1; +} + +#endif // TRITONXPU_ATTRDEFS diff --git a/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUDialect.td b/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUDialect.td new file mode 100644 index 000000000..50f438577 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUDialect.td @@ -0,0 +1,34 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TRITONXPU_DIALECT +#define TRITONXPU_DIALECT + +include "mlir/IR/DialectBase.td" + + +//===----------------------------------------------------------------------===// +// TRITONXPU dialect definitions +//===----------------------------------------------------------------------===// + +def TritonXPU_Dialect : Dialect { + let name = "triton_xpu"; + + let cppNamespace = "::mlir::triton::xpu"; + + let hasOperationAttrVerify = 1; + + let description = [{ + TRITON XPU Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect", + ]; + + let useDefaultAttributePrinterParser = 1; +} + +#endif // TRITONXPU_DIALECT diff --git a/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUInterfaces.td b/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUInterfaces.td new file mode 100644 index 000000000..15891cc54 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUInterfaces.td @@ -0,0 +1,16 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TRITONXPU_INTERFACES +#define TRITONXPU_INTERFACES + + +include "mlir/Interfaces/SideEffectInterfaces.td" // MemoryEffects +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/ControlFlowInterfaces.td" // ReturnLike + +def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">; + +#endif // TRITONXPU_INTERFACES diff --git a/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUOps.td b/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUOps.td new file mode 100644 index 000000000..f2f9d3285 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUOps.td @@ -0,0 +1,414 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TRITONXPU_OPS +#define TRITONXPU_OPS + +include "mlir/IR/OpBase.td" // Trait +include "mlir/Dialect/Arith/IR/ArithBase.td" // Arith_CmpFPredicateAttr +include "triton/Dialect/TritonXPU/IR/TritonXPUDialect.td" // For TritonXPU_Dialect +include "triton/Dialect/TritonXPU/IR/TritonXPUAttrDefs.td" // For Attr +include "triton/Dialect/TritonXPU/IR/TritonXPUTypes.td" // For types +include "triton/Dialect/TritonXPU/IR/TritonXPUInterfaces.td" // For types + + +//===----------------------------------------------------------------------===// +// TRITONXPU op definitions +//===----------------------------------------------------------------------===// + + +class TTX_Op traits = []> : + Op; + + +//===----------------------------------------------------------------------===// +// Memory Ops +//===----------------------------------------------------------------------===// + +def TTX_AllocaOp : TTX_Op<"alloca"> { + let summary = "alloca"; + let arguments = (ins SI64Attr:$size); + let results = (outs TT_PtrLike:$result); +} + +def TTX_GM2LMOp : TTX_Op<"gm2lm", [AttrSizedOperandSegments]> { + let summary = "gm2lm"; + let arguments = (ins TT_PtrLike:$ptr, + Optional:$len, + Optional:$bufPtr, + SI32Attr:$offsetState, + SI32Attr:$fixedStride, + SI64Attr:$rowLen, + SI64Attr:$rowStride, + SI32Attr:$lrie, + SI32Attr:$tensorColSize, + BoolAttr:$SVOpt, + BoolAttr:$async, + BoolAttr:$atomicSim); + let results = (outs TT_PtrLike:$result); +} + +def TTX_LM2GMOp : TTX_Op<"lm2gm", [AttrSizedOperandSegments]> { + let summary = "lm2gm"; + let arguments = (ins TT_PtrLike:$ptr, + TTX_Type:$value, + Optional:$len, + Optional:$bufPtr, + SI32Attr:$tensorColSize, + SI32Attr:$offsetState, + SI64Attr:$rowLen, + SI64Attr:$rowStride, + BoolAttr:$atomicSim); +} + +def TTX_LoadOp : TTX_Op<"load", [AttrSizedOperandSegments, + MemoryEffects<[MemRead]>]> { + let summary = "load"; + let arguments = (ins TT_PtrLike:$ptr, + Optional:$mask, + Optional:$other, + Optional:$index, + SI32Attr:$stride, + SI32Attr:$tensorColSize, + BoolAttr:$isDiscrete, + BoolAttr:$SVOpt, + BoolAttr:$bf16Tofp32Unordered); + let results = (outs TT_Type:$result); +} + +def TTX_StoreOp : TTX_Op<"store", [AttrSizedOperandSegments, + MemoryEffects<[MemWrite]>]> { + let summary = "store"; + let arguments = (ins TT_PtrLike:$ptr, + TT_Type:$value, + Optional:$mask, + Optional:$index, + SI32Attr:$tensorColSize, + BoolAttr:$bf16Tofp32Unordered); +} + +def TTX_SM2GMOp : TTX_Op<"sm2gm", [AttrSizedOperandSegments]> { + let summary = "sm2gm"; + let arguments = (ins TT_PtrLike:$ptr, + Optional:$value, + Optional:$len, + Optional:$bufPtr, + SI32Attr:$offsetState); +} + +def TTX_StoreSMOp : TTX_Op<"storeSM", [MemoryEffects<[MemWrite]>, AttrSizedOperandSegments]> { + let summary = "store to SM"; + let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional:$mask, Optional:$index); +} + + +//===----------------------------------------------------------------------===// +// Context Ops +//===----------------------------------------------------------------------===// + +def TTX_GetCoreIdOp : TTX_Op<"get_core_id"> { + let summary = "get_core_id"; + let results = (outs TTX_Type:$result); +} + +def TTX_GetThreadIdOp : TTX_Op<"get_thread_id"> { + let summary = "get_thread_id"; + let description = [{ + threadType(0): tid = core_id() * cluster_num() + cluster_id() + threadType(1): tid = core_num() * cluster_id() + core_id() + }]; + let arguments = (ins SI32Attr:$threadType); + let results = (outs TTX_Type:$result); +} + +def TTX_GetClusterIdOp : TTX_Op<"get_cluster_id", [NoMemoryEffect]> { + let summary = "Get the ID of the current Cluster"; + let results = (outs I32:$result); + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TTX_GetNumClusterOp : TTX_Op<"get_num_cluster", [NoMemoryEffect]> { + let summary = "Get the number of launched Cluster"; + let results = (outs I32:$result); + let assemblyFormat = "attr-dict `:` type($result)"; +} + +//===----------------------------------------------------------------------===// +// MakeRange-Liks Ops +//===----------------------------------------------------------------------===// +// TODO[dyq]: combine TTX_MakeRangeOp && TTX_InterleaveOp && TTX_OutRangeOp + +// borrowed from triton/include/triton/Dialect/Triton/IR/TritonOps.td +def TTX_MakeRangeOp : TTX_Op<"make_range", [AttrSizedOperandSegments, + Pure]> { + let summary = "make range"; + + let description = [{ + Returns an 1D int32 tensor. + + Values span from $start to $end (exclusive), with step = 1 + }]; + + // WARNING: MLIR generates getStart()/getEnd() functions which return + // uint32_t, even though these arguments are to be interpreted as *signed* + // int32 values. If this matters, use get{Start,End}Attr().getInt(), which + // return int64_t. + let arguments = (ins I32Attr:$start, + I32Attr:$end, + I32Attr:$realSize, + Optional:$loopIndex, + Optional:$unrollIndex); + + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = "$loopIndex $unrollIndex attr-dict `:` type($loopIndex) type($unrollIndex) `->` type($result)"; + + // let hasFolder = 1; + let hasVerifier = 1; +} + +def TTX_InterleaveOp : TTX_Op<"interleave", [AttrSizedOperandSegments, + Pure]> { + let summary = "interleave"; + + let description = [{ + Returns an 1D int32 tensor. + + Values span from $start to $end (exclusive), with step = 1 + }]; + + // WARNING: MLIR generates getStart()/getEnd() functions which return + // uint32_t, even though these arguments are to be interpreted as *signed* + // int32 values. If this matters, use get{Start,End}Attr().getInt(), which + // return int64_t. + let arguments = (ins I32Attr:$start, + I32Attr:$end, + Optional:$loopIndex, + Optional:$unrollIndex); + + let results = (outs TT_IntTensor:$result); + + let assemblyFormat = "$loopIndex $unrollIndex attr-dict `:` type($loopIndex) type($unrollIndex) `->` type($result)"; + + // let hasFolder = 1; + let hasVerifier = 1; +} + +def TTX_OutRangeOp : TTX_Op<"out_range"> { + let summary = "out_range"; + let description = [{ + Returns idx. + }]; + let arguments = (ins I32Attr:$groupsize, I32Attr:$rowspercore, TT_Int:$index); + let results = (outs TT_IntTensor:$result); +} + +//===----------------------------------------------------------------------===// +// Vectorization Ops +//===----------------------------------------------------------------------===// + +class TTX_VUnaryFOp traits = []> : + TTX_Op, + Arguments<(ins TTX_VectorLike:$value)>, + Results<(outs TTX_VectorLike:$result)> { + let assemblyFormat = "$value `,` attr-dict `:` type($result)"; +} + +class TTX_BinaryOp traits = []> : + TTX_Op { + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; +} + +class TTX_VVIntBinaryOp traits = []> : + TTX_BinaryOp, + Arguments<(ins TTX_VectorLike:$lhs, TTX_VectorLike:$rhs)>, + Results<(outs TTX_VectorLike:$result)>; + +class TTX_VVFloatBinaryOp traits = []> : + TTX_BinaryOp, + Arguments<(ins TTX_VectorLike:$lhs, TTX_VectorLike:$rhs)>, + Results<(outs TTX_VectorLike:$result)>; + +class TTX_SVFloatBinaryOp traits = []> : + TTX_Op, + Arguments<(ins TTX_Type:$lhs, TTX_Type:$rhs, I32Attr:$elemState)>, + Results<(outs TTX_VectorLike:$result)> { + let summary = mnemonic; +} + +def TTX_VExpFOp : TTX_VUnaryFOp<"vexpf">; +def TTX_VLogFOp : TTX_VUnaryFOp<"vlogf">; +def TTX_VSinFOp : TTX_VUnaryFOp<"vsinf">; +def TTX_VCosFOp : TTX_VUnaryFOp<"vcosf">; +def TTX_VSqrtFOp : TTX_VUnaryFOp<"vsqrtf">; +def TTX_VAbsFOp : TTX_VUnaryFOp<"vabsf">; +def TTX_VSIToFPOp : TTX_Op<"vsitofp"> { + let summary = "vsitofp"; + let arguments = (ins TTX_VectorLike:$value); + let results = (outs TTX_VectorLike:$result); +} + +def TTX_VvaddFOp : TTX_VVFloatBinaryOp<"vvaddf">; +def TTX_VvsubFOp : TTX_VVFloatBinaryOp<"vvsubf">; +def TTX_VvmulFOp : TTX_VVFloatBinaryOp<"vvmulf">; +def TTX_VvdivFOp : TTX_VVFloatBinaryOp<"vvdivf">; +def TTX_VvmaxFOp : TTX_VVFloatBinaryOp<"vvmaxf">; +def TTX_VvminFOp : TTX_VVFloatBinaryOp<"vvminf">; +def TTX_VvmaxNumFOp : TTX_VVFloatBinaryOp<"vvmaxnumf">; +def TTX_VvminNumFOp : TTX_VVFloatBinaryOp<"vvminnumf">; + +def TTX_SvaddFOp : TTX_SVFloatBinaryOp<"svaddf">; +def TTX_SvmulFOp : TTX_SVFloatBinaryOp<"svmulf">; +def TTX_SvsubFOp : TTX_SVFloatBinaryOp<"svsubf">; +def TTX_SvmaxFOp : TTX_SVFloatBinaryOp<"svmaxf">; + +def TTX_VvorIOp : TTX_VVIntBinaryOp<"vvori">; +def TTX_VvxorIOp : TTX_VVIntBinaryOp<"vvxori">; +def TTX_VvandIOp : TTX_VVIntBinaryOp<"vvandi">; +def TTX_VvaddIOp : TTX_VVIntBinaryOp<"vvaddi">; +def TTX_VvsubIOp : TTX_VVIntBinaryOp<"vvsubi">; +def TTX_VvmulIOp : TTX_VVIntBinaryOp<"vvmuli">; + +def TTX_VMacFOp : TTX_Op<"vmacf", [SameOperandsAndResultType]> { + let summary = "vmacf"; + let arguments = (ins TTX_VectorLike:$value, TTX_VectorLike:$mulData, TTX_VectorLike:$addData); + let results = (outs TTX_VectorLike:$result); + let assemblyFormat = "$value `,` $mulData `,` $addData attr-dict `:` type($result)"; +} + +def TTX_VConstOp : TTX_Op<"vconst"> { + let summary = "vconst"; + let arguments = (ins AnyAttr:$value); + let results = (outs TTX_VectorLike:$result); +} + +def TTX_VSplatOp : TTX_Op<"vsplat"> { + let summary = "vsplat"; + let arguments = (ins TTX_Type:$src); + let results = (outs TTX_VectorLike:$result); +} + +def TTX_VSelectOp : TTX_Op<"vselect"> { + let summary = "vselect"; + let arguments = (ins TTX_Type:$condition, + TTX_Type:$true_value, + TTX_Type:$false_value); + let results = (outs TTX_VectorLike:$result); +} + +def TTX_VExtFOp : TTX_Op<"vextf"> { + let summary = "vextf"; + let arguments = (ins TTX_VectorLike:$value); + let results = (outs TTX_VectorLike:$result); +} + +def TTX_VTruncFOp : TTX_Op<"vtruncf"> { + let summary = "vtruncf"; + let arguments = (ins TTX_VectorLike:$value); + let results = (outs TTX_VectorLike:$result); +} + +def TTX_VCmpFOp : TTX_Op<"vcmpf"> { + let summary = "vcmpf"; + let arguments = (ins Arith_CmpFPredicateAttr:$predicate, + TTX_VectorLike:$lhs, + TTX_VectorLike:$rhs); + let results = (outs TTX_VectorLike:$result); +} + +//===----------------------------------------------------------------------===// +// Other Ops +//===----------------------------------------------------------------------===// + +def TTX_MinOp : TTX_Op<"min"> { + let summary = "min"; + let arguments = (ins TTX_Type:$lhs, TTX_Type:$rhs); + let results = (outs TTX_Type:$res); +} + +def TTX_ExtractOp : TTX_Op<"extract"> { + let summary = "extract the element from the tensor"; + let description = [{ + Returns an 1D int32. + Extract the element from the tensor according to index. + }]; + let arguments = (ins I32Attr:$index, TTX_Type:$tensor); + let results = (outs TTX_Type:$result); +} + +def TTX_ExtractSliceOp : TTX_Op<"extract_slice"> { + let summary = "extract slice a part of elements from the tensor"; + let description = [{ + Returns an 1D int32. + Truncate a part of elements from the tensor. + }]; + let arguments = (ins TTX_Type:$tensor); + let results = (outs TTX_Type:$result); +} + +// borrowed from triton/include/triton/Dialect/Triton/IR/TritonOps.td +def TTX_ConvertLayoutOp : TTX_Op<"convert_layout", + [SameOperandsAndResultShape, + SameOperandsAndResultElementType, + Pure]> { + let summary = "convert layout"; + + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +// borrowed from triton/include/triton/Dialect/Triton/IR/TritonOps.td +def TTX_ReduceOp: TTX_Op<"reduce", + [Pure, + SameOperandsEncoding, + SingleBlock, + DeclareOpInterfaceMethods]> { + let summary = "Reduction using generic combination algorithm"; + let arguments = (ins Variadic:$srcs, I32Attr:$axis, TT_Int:$loopIndex); + let results = (outs Variadic:$result); + let regions = (region SizedRegion<1>:$combineOp); + let builders = [ + OpBuilder<(ins "ValueRange":$srcs, "int":$axis, "Value":$loopIndex)>, + ]; + let hasVerifier = 1; + let hasRegionVerifier = 1; + let extraClassDeclaration = [{ + llvm::SmallVector getInputTypes(); + llvm::SmallVector getElementTypes(); + unsigned getNumOperands(); + }]; +} + +def TTX_ReduceReturnOp: TTX_Op<"reduce.return", + [HasParent<"ReduceOp">, Pure, Terminator, ReturnLike]> { + let summary = "terminator for reduce operator"; + let arguments = (ins Variadic:$result); + let assemblyFormat = "$result attr-dict `:` type($result)"; +} + +// borrowed from triton/include/triton/Dialect/Triton/IR/TritonOps.td +def TTX_BroadcastOp : TTX_Op<"broadcast", [Pure]> { + let summary = "broadcast a tensor"; + + let description = [{ + For a given tensor, broadcast changes one or more dimensions with size 1 + to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot + change the size of a non-1 dimension. + }]; + + let arguments = (ins TTX_TensorVector:$src); + + let results = (outs TTX_TensorVector:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + + let hasCanonicalizeMethod = 1; + let hasFolder = 1; +} + +#endif // TRITONXPU_OPS diff --git a/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUTypes.td b/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUTypes.td new file mode 100644 index 000000000..4056d70a1 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonXPU/IR/TritonXPUTypes.td @@ -0,0 +1,17 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TRITONXPU_TYPES +#define TRITONXPU_TYPES + +include "triton/Dialect/Triton/IR/TritonTypes.td" // For TT_FloatLike + + +def TTX_Vector : FixedVectorOf<[TT_FloatLike, TT_IntLike]>; +def TTX_TensorVector: AnyTypeOf<[TT_Tensor, TensorOf<[TTX_Vector]>]>; +def TTX_VectorLike: AnyTypeOf<[TTX_Vector, TensorOf<[TTX_Vector]>]>; +def TTX_Type: AnyTypeOf<[TT_Type, TTX_VectorLike]>; + +#endif // TRITONXPU_TYPES diff --git a/third_party/xpu/include/triton/Dialect/TritonXPU/Transforms/CMakeLists.txt b/third_party/xpu/include/triton/Dialect/TritonXPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..3ed451055 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonXPU/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonXPU) +add_public_tablegen_target(TritonXPUTransformsIncGen) diff --git a/third_party/xpu/include/triton/Dialect/TritonXPU/Transforms/Passes.h b/third_party/xpu/include/triton/Dialect/TritonXPU/Transforms/Passes.h new file mode 100644 index 000000000..9c00c6369 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonXPU/Transforms/Passes.h @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TRITON_DIALECT_TRITONXPU_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITONXPU_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/UtilityXPU.h" // helper +#include "triton/Dialect/TritonXPU/IR/Dialect.h" // dependentDialects +#include "llvm/ADT/TypeSwitch.h" // TypeSwitch +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" // llvm_unreachable + +namespace mlir { +namespace triton { +namespace xpu { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +} // namespace xpu +} // namespace triton +} // namespace mlir +#endif diff --git a/third_party/xpu/include/triton/Dialect/TritonXPU/Transforms/Passes.td b/third_party/xpu/include/triton/Dialect/TritonXPU/Transforms/Passes.td new file mode 100644 index 000000000..3beffb78b --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonXPU/Transforms/Passes.td @@ -0,0 +1,207 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TRITONXPU_PASSES +#define TRITONXPU_PASSES + +include "mlir/Pass/PassBase.td" + +//===----------------------------------------------------------------------===// +// Functionality Pass +//===----------------------------------------------------------------------===// + +def TritonXPUCreateGM2LM : Pass<"tritonxpu-create-gm2lm", "mlir::ModuleOp"> { + let summary = "Create GM2LM for XPU."; + + let description = [{ + tt.load => triton_xpu.gm2lm + triton_xpu.load; + tt.store => triton_xpu.store + triton_xpu.lm2gm. + }]; + + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; + + let options = [ + Option<"xpuArch", "xpu-arch", + "uint32_t", /*default*/"3", + "XPU architecture">, + Option<"atomicSim", "atomic-sim", + "bool", /*default*/"1", + "Atomic Simulation"> + ]; +} + +def TritonXPULegalize : Pass<"tritonxpu-legalize", "mlir::ModuleOp"> { + let summary = "Legalize for XPU."; + + let description = [{ + Insert scf.for when local memory is smaller than block size. + }]; + + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; + + let options = [ + Option<"bufferSize", "lm-buflen", + "uint32_t", /*default*/"128", + "buffer size for local memory">, + Option<"coreNum", "core-num", + "uint32_t", /*default*/"64", + "core num"> + ]; +} + +def TritonXPUMask : Pass<"tritonxpu-mask", "mlir::ModuleOp"> { + let summary = "Mask for Calculation."; + + let description = [{ + }]; + + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; +} + +def TritonXPUAlloca : Pass<"tritonxpu-alloca", "mlir::ModuleOp"> { + let summary = "Alloca buffer for gm2lm and lm2gm."; + + let description = [{ + triton_xpu.gm2lm => triton_xpu.alloca + triton_xpu.gm2lm + triton_xpu.lm2gm => triton_xpu.alloca + triton_xpu.lm2gm + }]; + + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; +} + + +def TritonXPUDtypeConvert : Pass<"tritonxpu-dtype-convert", "mlir::ModuleOp"> { + let summary = "Dtype Convert for XPU."; + + let description = [{ + XPU2: FP16 => FP32 + XPU3: BF16 => FP32 + }]; + + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; + let options = [ + Option<"xpuArch", "xpu-arch", + "uint32_t", /*default*/"3", + "XPU architecture"> + ]; +} + +def TritonXPULoopGrid : Pass<"tritonxpu-loop-grid", "mlir::ModuleOp"> { + let summary = "Create loop on triton programs for out-of-bounds grid_size."; + + let description = [{ + }]; + + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; +} + +def TritonXPUUnrollControl : Pass<"tritonxpu-unroll-control", "mlir::ModuleOp"> { + let summary = "Control the unroll size."; + + let description = [{ + }]; + + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; +} + +//===----------------------------------------------------------------------===// +// Optimization Pass +//===----------------------------------------------------------------------===// + +def TritonXPUOffsetAnalysis : Pass<"tritonxpu-offset-analysis", "mlir::ModuleOp"> { + let summary = "Analysis Ptr's Offset State."; + + let description = [{ + Given buffer_size = 8 + + offsets = [0, 1, 2, 3, 4, 5, 6, 7] -> Continuous + offsets = [0, 0, 0, 0, 0, 0, 0, 0] -> DiscreteSame + offsets = [4, 5, 6, 7, 0, 1, 2 ,3] -> Discrete + offsets = [0, 1, 1001, 1002, 3, 4, 2001, 2002] -> Unknown + }]; + + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; + + let options = [ + Option<"dumpFlag", "dump-flag", + "bool", /*default*/"0", + "detail dump flag"> + ]; +} + +def TritonXPUCoreTiling : Pass<"tritonxpu-core-tiling", "mlir::ModuleOp"> { + let summary = "Core Tiling Optimization."; + + let description = [{ + }]; + + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; + + let options = [ + Option<"dumpFlag", "dump-flag", + "bool", /*default*/"0", + "detail dump flag">, + Option<"bufferSize", "lm-buflen", + "uint32_t", /*default*/"128", + "buffer size for local memory"> + ]; +} + + +def TritonXPUVectorize : Pass<"tritonxpu-vectorize", "mlir::ModuleOp"> { + let summary = "Vectorize Calculation."; + + let description = [{ + }]; + + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; + + let options = [ + Option<"dumpFlag", "dump-flag", + "bool", /*default*/"0", + "detail dump flag"> + ]; +} + +def TritonXPUMemoryAsync : Pass<"tritonxpu-memory-async", "mlir::ModuleOp"> { + let summary = "Memory Async Optimization."; + + let description = [{ + }]; + + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; + + let options = [ + Option<"dumpFlag", "dump-flag", + "bool", /*default*/"0", + "detail dump flag">, + ]; +} + +def TritonXPUInterleave : Pass<"tritonxpu-interleave", "mlir::ModuleOp"> { + let summary = "Interleave for XPU."; + let description = [{ + Convert triton.make_range triton_xpu.interleave. + }]; + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; +} + +def TritonXPUStoreControl : Pass<"tritonxpu-store-control", "mlir::ModuleOp"> { + let summary = "Store Control for XPU."; + let description = [{ + Only Store isCoreId0InsideGroup=0 for ReduceOp + }]; + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; +} + +def TritonXPUOtherSim : Pass<"tritonxpu-other-sim", "mlir::ModuleOp"> { + let summary = "Simulate Other for XPU."; + let description = [{ + Simulate Other in LoadOp/StoreOp. + }]; + let dependentDialects = ["mlir::triton::xpu::TritonXPUDialect"]; +} + +#endif // TRITONXPU_PASSES diff --git a/third_party/xpu/include/triton/Dialect/TritonXPU/Transforms/TritonXPUConversion.h b/third_party/xpu/include/triton/Dialect/TritonXPU/Transforms/TritonXPUConversion.h new file mode 100644 index 000000000..74cc308a4 --- /dev/null +++ b/third_party/xpu/include/triton/Dialect/TritonXPU/Transforms/TritonXPUConversion.h @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Defines utilities to use while converting to the TritonXPU dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_DIALECT_TRITONXPU_TRANSFORMS_TRITONGPUCONVERSION_H_ +#define TRITON_DIALECT_TRITONXPU_TRANSFORMS_TRITONGPUCONVERSION_H_ + +#include "mlir/IR/MLIRContext.h" +#include "mlir/Transforms/DialectConversion.h" // TypeConverter + +namespace mlir { + +class TritonXPUTypeConverter : public TypeConverter { +public: + TritonXPUTypeConverter(MLIRContext *context, uint32_t buffer_size, + uint32_t core_num); + uint32_t getBufferSize() const { return buffer_size; } + uint32_t getCoreNum() const { return core_num; } + +private: + MLIRContext *context; + uint32_t buffer_size; + uint32_t core_num; +}; + +class TritonXPUConversionTarget : public ConversionTarget { +public: + explicit TritonXPUConversionTarget(MLIRContext &ctx, + TritonXPUTypeConverter &typeConverter); +}; + +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONXPU_TRANSFORMS_TRITONGPUCONVERSION_H_ diff --git a/third_party/xpu/include/triton/Target/CMakeLists.txt b/third_party/xpu/include/triton/Target/CMakeLists.txt new file mode 100644 index 000000000..39d31dc9b --- /dev/null +++ b/third_party/xpu/include/triton/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(LLVMIR) diff --git a/third_party/xpu/include/triton/Target/LLVMIR/CMakeLists.txt b/third_party/xpu/include/triton/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 000000000..1f6c1b351 --- /dev/null +++ b/third_party/xpu/include/triton/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name LLVMIR) +add_public_tablegen_target(LLVMIRIncGen) diff --git a/third_party/xpu/include/triton/Target/LLVMIR/Passes.h b/third_party/xpu/include/triton/Target/LLVMIR/Passes.h new file mode 100644 index 000000000..27ecb5c3d --- /dev/null +++ b/third_party/xpu/include/triton/Target/LLVMIR/Passes.h @@ -0,0 +1,17 @@ +#ifndef TRITON_TARGET_LLVM_IR_PASSES_H +#define TRITON_TARGET_LLVM_IR_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +/// Create a pass to add DIScope +std::unique_ptr createLLVMDIScopePass(); + +/// Generate the code for registering conversion passes. +#define GEN_PASS_REGISTRATION +#include "triton/Target/LLVMIR/Passes.h.inc" + +} // namespace mlir + +#endif // TRITON_TARGET_LLVM_IR_PASSES_H diff --git a/third_party/xpu/include/triton/Target/LLVMIR/Passes.td b/third_party/xpu/include/triton/Target/LLVMIR/Passes.td new file mode 100644 index 000000000..999b0b889 --- /dev/null +++ b/third_party/xpu/include/triton/Target/LLVMIR/Passes.td @@ -0,0 +1,15 @@ +#ifndef TRITON_TARGET_LLVMIR_PASSES +#define TRITON_TARGET_LLVMIR_PASSES + +include "mlir/Pass/PassBase.td" + +def LLVMDIScope: Pass<"enable-line-info", "mlir::ModuleOp"> { + let summary = "Materialize LLVM line info"; + let description = [{ + This pass materializes line mapping information for LLVM IR dialect operations. + }]; + + let constructor = "mlir::createLLVMDIScopePass()"; +} + +#endif diff --git a/third_party/xpu/include/triton/Target/LLVMXPU/LLVMXPUToLLVMIRTranslation.h b/third_party/xpu/include/triton/Target/LLVMXPU/LLVMXPUToLLVMIRTranslation.h new file mode 100644 index 000000000..b97ab8d92 --- /dev/null +++ b/third_party/xpu/include/triton/Target/LLVMXPU/LLVMXPUToLLVMIRTranslation.h @@ -0,0 +1,29 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for LLVMXPU dialect to LLVM IR translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_XPU_XPUTOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_XPU_XPUTOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the LLVMXPU dialect and the translation from it to the LLVM IR in +/// the given registry; +void registerLLVMXPUDialectTranslation(DialectRegistry ®istry); + +/// Register the LLVMXPU dialect and the translation from it in the registry +/// associated with the given context. +void registerLLVMXPUDialectTranslation(MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_XPU_XPUTOLLVMIRTRANSLATION_H diff --git a/third_party/xpu/include/triton/Tools/LinearLayout.h b/third_party/xpu/include/triton/Tools/LinearLayout.h new file mode 100644 index 000000000..fb2680241 --- /dev/null +++ b/third_party/xpu/include/triton/Tools/LinearLayout.h @@ -0,0 +1,532 @@ +#ifndef TRITON_TOOLS_LINEARLAYOUT_H +#define TRITON_TOOLS_LINEARLAYOUT_H + +#include +#include +#include +#include +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir::triton { + +// # High-level overview of linear layouts +// +// The idea for linear layouts is due to Adam P. Goucher. +// +// In Triton, a linear layout (LL) is a function that maps from a "hardware +// location" to a "logical tensor index". +// +// For example, suppose we have a 2D tensor T stored in GPU registers. T's +// layout is the function that, given a "hardware location" tuple of (thread-id, +// warp-id), returns an index (x,y) into T. In other words, if L(t,w) = (x,y) +// is our linear layout func, then a register in thread t in warp w contains the +// value T[x,y]. +// +// The key fact about LLs is, the mapping from (t,w) to (x,y) is not arbitrary. +// We only need to specify the value of L(t,w) at certain special points +// (namely, the values L(t,0) and L(0,w) where t and w are powers of 2), and +// from those we can compute all the other values of L. +// +// Here's an example LL where we have 4 warps and 4 threads per warp, and the +// tensor T has shape 4x4. We define the function L by choosing the values of +// L(0,1), L(0,2), L(1,0), and L(2,0). Our choices are shown below. +// +// t/w 0 1 2 3 +// 0 ? (0,1) (0,2) ? +// L(t,w) = 1 (1,1) ? ? ? +// 2 (2,2) ? ? ? +// 3 ? ? ? ? +// +// You only need to specify these four values to define the whole linear layout. +// These special values are called the "basis vectors" or "bases" of the layout. +// We complete the table by xor'ing together the bases, according to the +// following rule. (I write "⊕" for xor.) +// +// L(t1 ⊕ t2, w1 ⊕ w2) = L(t1, w1) ⊕ L(t2, w2) (linearity rule). +// +// The linearity rule plus our four choices allows us to fill in the whole +// table. Here's how we might compute some of the values. +// +// L(0,0) = L(1 ⊕ 1, 0 ⊕ 0) = L(1,0) ⊕ L(1,0) = (1,1) ⊕ (1,1) = (0,0) +// L(0,3) = L(0 ⊕ 0, 2 ⊕ 1) = L(0,2) ⊕ L(0,1) = (0,2) ⊕ (0,1) = (0,3) +// L(3,0) = L(2 ⊕ 1, 0 ⊕ 0) = L(2,0) ⊕ L(1,0) = (2,2) ⊕ (1,1) = (3,3) +// L(3,3) = L(3 ⊕ 0, 0 ⊕ 3) = L(3,0) ⊕ L(0,3) = (3,3) ⊕ (0,3) = (3,0). +// +// (Notice it's a consequence of the linearity rule that L(0,0) = (0,0), no +// matter what values we chose for the table.) +// +// The whole table looks like this. +// +// t/w 0 1 2 3 +// 0 (0,0) (0,1) (0,2) (0,3) +// L(t,w) = 1 (1,1) (1,0) (1,3) (1,2) +// 2 (2,2) (2,3) (2,0) (2,1) +// 3 (3,3) (3,2) (3,1) (3,0). +// +// Careful readers will recognize this as a classic "swizzled" layout where +// (t, w) -> (t, w ⊕ t). To go from this formula to an LL, you only need to +// compute the results at input points (0,1), (0,2), (1,0), and (2,0). + +// Indeed the whole point of LLs is that they allow us to specify transposed and +// swizzled layouts as a "general case". Instead of a layout class for +// registers in a thread, and another layout for registers in a thread but in +// MMAv2 order, and so on, all of these can be represented by different LLs. +// This gets rid of special cases and lets us write more general code. +// +// In this example, L was a 2D -> 2D function, but LLs are general MD -> ND +// functions. In practice, a GPU register layout usually has input dims (reg, +// thread-id, warp-id, block-id), where reg represents the fact that one thread +// may store values for the tensor in multiple registers. +// +// To summarize, a linear layout is a function from tuples of integers to tuples +// of integers. We specify some key values of the function, and then we can +// compute all the other values using the linearity rule. +// +// Here are the key things you can do with linear layout objects. +// +// 1. Given an LL, construct a new LL by modifying it or combining it with +// another LL. +// +// 2. "Apply" an LL, i.e. use it to map an input index to an output index. +// A function for this that uses LLVM-dialect MLIR as its input and output +// lives in TritonGPUToLLVM.h. +// +// 3. Convert an existing Triton layout (e.g. BlockedLayoutAttr) to an LL. +// These functions live in TritonGPU/LinearLayoutConversions.h. During +// TTGIR -> LLVM codegen, we convert Triton layouts to linear layouts and +// then apply them. In the future, we intend to remove the Triton layouts +// entirely. +// +// # Examples of linear layouts +// +// 1. The 1D identity layout. This maps L(x) = x. +// +// Recall that our bases are the values of L(x) where x is a power of two. +// So for e.g. an 8-element layout, we have L(1) = 1, L(2) = 2, L(4) = 4, and +// therefore our bases are [1, 2, 4]. +// +// 2. The 1D zeros layout. This maps L(x) = 0. +// +// For an 8-element layout, we have L(1) = L(2) = L(4) = 0, so our bases are +// [0, 0, 0]. +// +// 3. A 2D -> 2D identity layout. Our basis vectors are the values of L(x,0) +// and L(0,y) where x and y are powers of two. The bases are +// +// - L(0,1) = (0,1) +// - L(0,2) = (0,2) +// - L(1,0) = (1,0) +// - L(2,0) = (2,0). +// +// 4. A 2D -> 2D transpose layout. For a 4x4 layout, we have: +// +// - L(0,1) = (1,0) +// - L(0,2) = (2,0) +// - L(1,0) = (0,1) +// - L(2,0) = (0,2). +// +// 5. A 1D -> 1D "transpose" layout. Consider the 16-element layout that maps +// +// x = 0 1 2 3 4 5 6 7 8 9 A B C D E F +// L(x) = 0 4 8 C 1 5 9 D 2 6 A E 3 7 B F. +// +// The bases are [L(1), L(2), L(4), L(8)] = [4, 8, 1, 2]. You can also think +// of this as a rearrangement of the 1D identity layout [1, 2, 4, 8]. +// +// 6. A 2D -> 1D broadcasted layout. L(x,y) = x. For a 4x4 -> 4 layout, our +// bases are +// +// - L(0,1) = 0 +// - L(0,2) = 0 +// - L(1,0) = 1 +// - L(2,0) = 2. +// +// # Implementation notes +// +// ## Dimension order +// +// An LL's input and output dimensions have an order. This order only affects +// the reshapeIns/Outs operations, where the layout is logically flattened +// according to the dimension order and then chopped up again. +// +// ## Surjectivity +// +// We require that all output values are covered by some input value, i.e. the +// function L is surjective. But multiple input values can map to the same +// output value. This represents the idea that the same logical tensor element +// can be stored in multiple places in the hardware. +// +// ## Why map hardware loc -> tensor index and not the other way around? +// +// In Triton, a linear layout usually tells us which logical tensor value is +// stored at a particular place in the hardware. For example, an LL might map +// the tuple (thread-id, warp-id, block-id) to a 2D index into a tensor, (x,y), +// meaning that the register at (t,w,b) has value tensor[x,y]. Or it might map +// from a shared memory (offset, block) to a tensor index. +// +// It might seem more natural to go the other way around, from tensor index to +// place in the hardware. But a particular tensor[x,y] value might be stored in +// more than one place in the hardware, so if we went in this direction, the +// layout would no longer be a proper function. This would complicate +// everything else. +// +// # Optional mathematical background: Linear functions over GF(2) +// +// (You shouldn't need to understand this math to use linear layouts, but it +// helps with the implementation.) +// +// One way to define a linear function is to say it's any function F that can be +// written as +// +// L(a) = a1 * B1 + a2 * B2 + ... + aM * BM, +// +// where +// +// - a is a vector [a1...aM], and ai is a scalar in some field 𝔽 (for +// example, ai might be a real number), and +// - each Bj is a vector [b1j, b1j, ..., bNj] of N scalars in 𝔽. +// +// We can also write this as a matrix-vector product Ba, where +// +// - a is the column vector [a1, ..., aM] and +// +// - B is the matrix formed by concatenating the column vectors B1, ..., BM: +// +// | ↑ ↑ ↑ | +// B = | B1, B2, ..., BM| +// | ↓ ↓ ↓ | +// +// |b11, b12, ..., b1M| +// |b21, b22, ..., b2M| +// = | ↓ ↓ ↓ | +// |bN1, bN2, ..., bNM|. +// +// Usually when we do linear algebra, the field 𝔽 from which `ai` and `bij` are +// drawn is the real or complex numbers. But in linear layouts, we let 𝔽 be a +// different field: GF(2). +// +// GF(2) is the two-element field of bits. To define a field, I need to give +// you the set of elements and also addition and multiplication operations. For +// GF(2) the elements are simply {0,1}. We define addition as xor, and +// multiplication as binary `and`. +// +// Here's an example of a 4x4 matrix-vector multiply where the elements are in +// GF(2). I'm using ⊕ to represent GF(2)'s addition operation (i.e xor) and × +// to represent multiplication (i.e. binary `and`). +// +// | 1 0 0 0 | | 0 | | 1 | | 0 | | 0 | | 0 | +// | 0 1 1 0 | | 1 | = | 0 | × 0 ⊕ | 1 | × 1 ⊕ | 1 | × 1 ⊕ | 0 | × 0 +// | 0 0 1 1 | | 1 | | 0 | | 0 | | 1 | | 1 | +// | 0 0 1 1 | | 0 | | 0 | | 0 | | 1 | | 1 | +// +// | 0 | | 0 | +// = | 1 | ⊕ | 1 | +// | 0 | | 1 | +// | 0 | | 1 | +// +// | 0 | +// = | 0 |. +// | 1 | +// | 1 | +// +// This works, but it's cumbersome. It's more compact to think of the vector +// `a` as an M-bit integer, and each column Bi of the matrix B as an N-bit +// integer. Here's the same matrix-vector product written this way. +// +// = | 1 2 14 12 | × 6 +// = | 1 2 14 12 | × 0b0110 +// = (1 × 0) ⊕ (2 × 1) ⊕ (14 × 1) ⊕ (12 × 0) +// = 2 ⊕ 14 +// = 12. +// +// And we confirm that our answer of 12 is equal to the binary value 0b1100 we +// got before. +// +// Notice that the function F(a) is fully specified by the matrix B, and that +// the four columns of B tell us the values of F at power-of-two values for `a`, +// namely F(1), F(2), F(4), and F(8). In other words, we specify four results +// of F(x) (we call these the function's "basis vectors" or its "bases") and we +// can then compute any other value by xor'ing together subsets of the bases. +// +// In the case of a 1D -> 1D layout, the implementation of an LL is +// straightforward from the mathematical description. If the LL is +// higher-dimensional, we can "stack" the bit vectors to create 1D vectors. +// For example, if we have a 2D LL and we're given input tuple (0b0011, 0b1100), +// we can treat this like a 1D input 0b0011'1100 and then do the regular 1D LL +// computation. Similarly we can "unstack" the output from 1D to ND. +// +// The linearity rule presented earlier is perhaps misleading at this point. In +// the 1D view of things, we really only need +// +// L(x ⊕ y) = L(x) ⊕ L(y) (1D linearity rule), +// +// which is part of the definition of L being a linear function. The new 1D +// linearity rule plus stacking/unstacking is equivalent to the earlier +// N-dimensional linearity rule. +// +// That's all we need in order to define linear layouts mathematically! +// +// # Comaprison to Nvidia CuTe +// +// (Note, I'm not an expert on CuTe; this is my best understanding.) +// +// CuTe is a programmatic layout system that's part of Nvidia CUTLASS; see +// https://github.com/NVIDIA/cutlass/blob/629f465/media/docs/cute/00_quickstart.md +// +// LLs and CuTe solve similar problems. Before CuTe, CUTLASS v2 had many +// handcrafted layouts, "RowMajor", "VoltaTensorOpMultiplicandCongruous", etc, +// see https://www.youtube.com/watch?v=QLdUML5MCfE&t=574s. Each of these was a +// special case. CUTLASS v3 introduced CuTe layouts, which are programmable and +// subsume all of these special cases. The CUTLASS folks say this simplified +// CUTLASS, in the same way that we hope LLs will simplify Triton. +// +// Like CuTe layouts, LLs are also programmable and composible. But there are +// also some differences. +// +// - Dimensions in LLs are named; CuTe dimensions are numbered. +// - CuTe layouts can be nested; LLs cannot be. (Nesting doesn't give CuTe +// layouts additional power; any nested layout can be flattened.) +// - CuTe layouts support non-power-of-two shapes; LLs do not. In particular +// this means that LLs cannot represent padded layouts. +// - In CuTe, swizzling is a separate step applied after specifying a layout. +// In LLs, swizzling is part of the layout itself. +// - The structure of LLs allows us to programmatically search for layouts that +// satisfy certain requirements, for example a shared layout that doesn't +// have bank conflicts when read into a particular register layout. CuTe +// expects a human to choose the layout using their brain. +// - CuTe emits code that is in the critical path of your CPU and GPU programs, +// therefore it needs to be fast. It uses C++ template magic to specialize +// on known-sized dimensions, and so on. LLs themselves do not need to be +// fast; only the emitted `apply` code is on the critical path. +// - CuTe requires a CUDA compiler such as nvcc; LLs do not. +// +class LinearLayout { +private: + // bases[inDim][i] = L(0, ..., inDim=2^i, ..., 0). All other values of L are + // computed by xor'ing bases together, using the linearity rule. In addition: + // + // - Each inDim has the same set of outDims, in the same order. + // - The order of dims is minor-to-major, although this only affects reshape. + llvm::MapVector /*size=getNumOutDims()*/> + /*size=getInDimSizeLog2(inDim)*/> + bases; + + llvm::SetVector outDimNames; + +public: + using BasesT = decltype(bases); + + // The 0-dimensional layout that maps everything to 0. This is useful as a + // starting point when doing something like + // + // LinearLayout ret = LinearLayout::empty(); + // for (...) ret *= ...; + // return ret; + static LinearLayout empty() { return LinearLayout(BasesT{}, {}); } + + // Creates a 1D -> 1D layout that's the identity function, i.e. L(x) = x + // for x in [0, size). + static LinearLayout identity1D(int32_t size, StringAttr inDim, + StringAttr outDim); + + // Creates a 1D -> 1D layout that maps every input value to 0, i.e. L(x) = 0 + // for x in [0, size). + static LinearLayout zeros1D(int32_t size, StringAttr inDim, + StringAttr outDim); + + // Creates a LinearLayout from a list of bases. These are interpreted + // according to the rules written for the member variable `bases`. + explicit LinearLayout(BasesT bases, ArrayRef outDimNames); + + // Construct a LinearLayout from an explicit list of bases. (This constructor + // is needed because llvm::MapVector does not have a constructor that accepts + // an initializer_list.) + // + // For example, given these bases + // + // L(in1=1, in2=0) = (out1=0, out2=1) + // L(in1=2, in2=0) = (out1=0, out2=2) + // L(in1=0, in2=1) = (out1=0, out2=4) + // L(in1=0, in2=2) = (out1=0, out2=8) + // L(in1=0, in2=4) = (out1=1, out2=1) + // + // we can use this constructor to build an equivalent LL: + // + // LinearLayout({ + // {"in1", {/*L(in1=1)=*/{0,1}, /*L(in1=2)=*/{0,2}}}, + // {"in2", {/*L(in2=1)=*/{0,4}, /*L(in2=2)=*/{0,8}, /*L(in2=4)=*/{1,1}}}, + // }, + // {"out1", "out2"}) + explicit LinearLayout( + ArrayRef>>> bases, + ArrayRef outDimNames); + + const BasesT &getBases() const { return bases; } + + // Get the pos'th basis vector for the inDim -> outDim mapping. + // getBasis(inDim, pos) = L(0, ..., inDim = 2^pos, ..., 0). + ArrayRef getBasis(StringAttr inDim, int32_t pos) const { + auto it = bases.find(inDim); + assert(it != bases.end()); + assert(pos < it->second.size()); + return it->second[pos]; + } + + int32_t getBasis(StringAttr inDim, int32_t pos, StringAttr outDim) const { + return getBasis(inDim, pos)[getOutDimIndex(outDim)]; + ; + } + + // These are in minor-to-major order, although if you don't flatten the dims + // (e.g. by reshaping) then the order doesn't really affect anything. + auto getInDimNames() const { return llvm::make_first_range(bases); } + ArrayRef getOutDimNames() const { + return outDimNames.getArrayRef(); + } + + // Gets the position that this outDim occupies in getOutDimNames(). Asserts + // if the dim is not present. + int32_t getOutDimIndex(StringAttr outDim) const; + + bool hasInDim(StringAttr inDim) const { return bases.contains(inDim); } + bool hasOutDim(StringAttr outDim) const { + return outDimNames.contains(outDim); + } + + int32_t getNumInDims() const { return bases.size(); } + int32_t getNumOutDims() const { return outDimNames.size(); } + + // Asserts if the dimension is not present. + int32_t getInDimSizeLog2(StringAttr inDim) const; + int32_t getInDimSize(StringAttr inDim) const { + return 1 << getInDimSizeLog2(inDim); + } + + // getOutDimSize(dim) == s means that there exists an input value that will + // produce each output value in [0,s). + // + // For example, if our bases are + // + // L(in0=1) = 1 + // L(in0=2) = 4 + // L(in1=1) = 2 + // L(in1=2) = 8 + // + // then the largest value we can produce is L(3,3) = 1 ⊕ 4 ⊕ 2 ⊕ 8 = 15 (and + // indeed we can produce all values in [0,16) by xor'ing subsets of the bases + // 1,2,4,8), so getOutDimSize(out_dim0) == 16. + // + // Asserts if the dimension is not present. + int32_t getOutDimSizeLog2(StringAttr outDim) const; + int32_t getOutDimSize(StringAttr outDim) const { + return 1 << getOutDimSizeLog2(outDim); + } + + // Reorders the in/out dimensions of the layout. This is mostly cosmetic + // (affecting e.g. the order of getIn/OutDimNames), but it also affects the + // behavior of reshape. + [[nodiscard]] LinearLayout + transposeIns(ArrayRef newInDimOrder) const; + [[nodiscard]] LinearLayout + transposeOuts(ArrayRef newOutDimOrder) const; + + // Creates a new layout which, roughly speaking, is equivalent to one where + // every element of the `outer` layout is replaced by a full instance of the + // `inner` layout. + // + // Examples: + // + // - empty() is the multiplicative identity: + // + // L * empty() == empty() * L == L. + // + // - Multiplying two identity1D layouts with disjoint in/out dimensions gives + // a 2D identity layout: + // + // identity1D(4, "i1", "o1") * identity1D(8, "i2", "o2") => + // L(i1,i2) = (i1,i2), + // + // with in-dims ("i1", "i2") and out-dims ("o1", "o2"), in that order. + // + // - If out-dims overlap, they are combined, as in the following examples. + // + // - identity1D(4, "i", "o") * identity1D(2, "i", "o") == + // identity1D(8, "i", "o") + // + // - identity1D(4, "i", "o") * zeros1D(2, "i", "o") => L(x) = x % 4 + // for x in [0,8). + // + // - zeros1D(2, "i", "o") * identity1D(4, "i", "o") => L(x) = x / 2 + // for x in [0,8). + // + // - identity1D(4, "i", "o1") * identity1D(8, "i", "o2") => + // L(x) = (x % 4, x / 4) for x in [0,32). + // + // Notice that this operation is not commutative. It's also not associative. + // TODO(jlebar): Can I modify the definition to make it associative? Pretty + // confusing if not. If I can't, add an example. + // + // Requires: Any in/out dimensions which are in both outer and inner appear in + // the same relative order. + friend LinearLayout operator*(LinearLayout inner, LinearLayout outer); + LinearLayout &operator*=(LinearLayout outer) { + *this = *this * outer; + return *this; + } + + // Computes and returns L(x, y, z). + // + // If you want to apply the layout to mlir Values instead of integers, that + // function lives in TritonGPUToLLVM/Utility.h. + SmallVector> + apply(ArrayRef> ins) const; + + // Creates a new layout which is equivalent to running this layout, then + // running `outer`. That is, + // + // - let this layout be L(x), and + // - let `outer` be O(x). + // - Then compose(outer) returns the layout (O∘L)(x), aka O(L(x)). + // + // Requires: The output dimensions of this layout equal the input dimensions + // of outer (order doesn't matter). + [[nodiscard]] LinearLayout compose(const LinearLayout &outer) const; + + // TODO(jlebar): Not yet implemented. + // [[nodiscard]] LinearLayout reshapeIns( + // std::vector> + // newInDims) const; + + // TODO(jlebar): Not yet implemented. + // [[nodiscard]] LinearLayout reshapeOuts( + // std::vector> + // newOutDims) const; + + std::string toString() const; + + friend bool operator==(LinearLayout lhs, LinearLayout rhs); + friend bool operator!=(LinearLayout lhs, LinearLayout rhs) { + return !(lhs == rhs); + } +}; + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const LinearLayout &layout) { + os << layout.toString(); + return os; +} + +inline std::ostream &operator<<(std::ostream &os, const LinearLayout &layout) { + os << layout.toString(); + return os; +} + +} // namespace mlir::triton + +#endif diff --git a/third_party/xpu/include/triton/Tools/StrUtil.h b/third_party/xpu/include/triton/Tools/StrUtil.h new file mode 100644 index 000000000..8b59f7d2b --- /dev/null +++ b/third_party/xpu/include/triton/Tools/StrUtil.h @@ -0,0 +1,54 @@ +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir::triton { + +// Better version of llvm::join. This one works when T is an integer or any +// other type which defines operator<<(raw_ostream). +template +std::string join(C &&container, llvm::StringRef sep = ", ") { + std::string ret; + llvm::raw_string_ostream s(ret); + for (const auto &elem : container) { + if (!ret.empty()) + s << sep; + s << elem; + } + return ret; +} + +// Joins a container of elements into a string, using `sep` as a separator. +// +// fn is called to transform each element of the container before it's added to +// the string. fn must have one of the following two signatures. +// +// - void fn(llvm::raw_ostream&, E), where E is the element type of the +// container, or +// - T fn(E), where T is a type which can be passed to +// raw_ostream::operator<<. +// +template +std::string join(C &&container, llvm::StringRef sep, Fn &&fn) { + std::string ret; + llvm::raw_string_ostream s(ret); + for (const auto &elem : container) { + if (!ret.empty()) + s << sep; + + if constexpr (std::is_invocable_v) { + static_assert( + std::is_void_v< + std::invoke_result_t>); + fn(s, elem); + } else { + s << fn(elem); + } + } + return ret; +} + +} // namespace mlir::triton diff --git a/third_party/xpu/include/triton/Tools/Sys/GetEnv.hpp b/third_party/xpu/include/triton/Tools/Sys/GetEnv.hpp new file mode 100644 index 000000000..b4c36278a --- /dev/null +++ b/third_party/xpu/include/triton/Tools/Sys/GetEnv.hpp @@ -0,0 +1,84 @@ +#ifndef TRITON_TOOLS_SYS_GETENV_HPP +#define TRITON_TOOLS_SYS_GETENV_HPP + +#include +#include +#include +#include +#include +#include + +namespace mlir::triton { + +inline const std::set CACHE_INVALIDATING_ENV_VARS = { + // clang-format off + "AMDGCN_ENABLE_DUMP", + "DISABLE_FAST_REDUCTION", + "DISABLE_LLVM_OPT", + "DISABLE_MMA_V3", + "DISABLE_PTXAS_OPT", + "LLVM_IR_ENABLE_DUMP", + "LLVM_ENABLE_TIMING", + "MLIR_ENABLE_DIAGNOSTICS", + "MLIR_ENABLE_DUMP", + "MLIR_ENABLE_TIMING", + "TRITON_DISABLE_LINE_INFO", + "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE", + "TRITON_ENABLE_LLVM_DEBUG", + "TRITON_LLVM_DEBUG_ONLY", + "USE_TTGIR_LOC", + "NVPTX_ENABLE_DUMP", + "TRITONXPU_BF16_ROUND_MID", + "LLVM_ERROR_LM_SIZE", + "TRITON_TUNE_BUFFER_LM_SIZE" + // clang-format on +}; + +inline const std::set CACHE_NEUTRAL_ENV_VARS = { + "TRITON_REPRODUCER_PATH", +}; + +namespace tools { + +inline void assertIsRecognized(const std::string &env) { + bool is_invalidating = CACHE_INVALIDATING_ENV_VARS.find(env.c_str()) != + CACHE_INVALIDATING_ENV_VARS.end(); + bool is_neutral = + CACHE_NEUTRAL_ENV_VARS.find(env.c_str()) != CACHE_NEUTRAL_ENV_VARS.end(); + std::string errmsg = env + "is not recognized. " + "Please add it to triton/tools/sys/getenv.hpp"; + assert((is_invalidating || is_neutral) && errmsg.c_str()); +} + +inline std::string getStrEnv(const std::string &env) { + assertIsRecognized(env); + const char *cstr = std::getenv(env.c_str()); + if (!cstr) + return ""; + std::string result(cstr); + return result; +} + +// return value of a cache-invalidating boolean environment variable +inline bool getBoolEnv(const std::string &env) { + assertIsRecognized(env); + const char *s = std::getenv(env.c_str()); + std::string str(s ? s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return str == "on" || str == "true" || str == "1"; +} + +inline std::optional isEnvValueBool(std::string str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + if (str == "on" || str == "true" || str == "1") + return true; + if (str == "off" || str == "false" || str == "0") + return false; + return std::nullopt; +} +} // namespace tools +} // namespace mlir::triton + +#endif diff --git a/third_party/xpu/lib/Analysis/Alias.cpp b/third_party/xpu/lib/Analysis/Alias.cpp new file mode 100644 index 000000000..5b3910013 --- /dev/null +++ b/third_party/xpu/lib/Analysis/Alias.cpp @@ -0,0 +1,65 @@ +#include "triton/Analysis/Alias.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +namespace mlir { + +AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) { + if (lhs == rhs) + return lhs; + AliasInfo ret; + for (auto value : lhs.allocs) { + ret.insert(value); + } + for (auto value : rhs.allocs) { + ret.insert(value); + } + return ret; +} + +void SharedMemoryAliasAnalysis::visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { + AliasInfo aliasInfo; + bool pessimistic = true; + // These ops may allocate a new shared memory buffer. + auto result = op->getResult(0); + + // Only LocalAllocOp creates a new buffer. + if (isa(op)) { + aliasInfo.insert(result); + pessimistic = false; + } else if (isa(op)) { + // extract_slice %src + // trans %src + aliasInfo = AliasInfo(operands[0]->getValue()); + pessimistic = false; + } else { + assert(!isa(result.getType()) && + "unknown operation creating memory descriptor"); + } + + if (pessimistic) { + return setAllToEntryStates(results); + } + // Join all lattice elements + for (auto *result : results) + propagateIfChanged(result, result->join(aliasInfo)); +} + +AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) { + // TODO: implement + return AliasResult::MayAlias; +} + +ModRefResult SharedMemoryAliasAnalysis::getModRef(Operation *op, + Value location) { + // TODO: implement + return ModRefResult::getModAndRef(); +} + +} // namespace mlir diff --git a/third_party/xpu/lib/Analysis/Allocation.cpp b/third_party/xpu/lib/Analysis/Allocation.cpp new file mode 100644 index 000000000..aafc30eb3 --- /dev/null +++ b/third_party/xpu/lib/Analysis/Allocation.cpp @@ -0,0 +1,651 @@ +#include "triton/Analysis/Allocation.h" + +#include +#include +#include + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Alias.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" + +using ::mlir::triton::gpu::AMDMfmaEncodingAttr; +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getContigPerThread; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getShapePerCTATile; +using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::getUniqueContigPerThread; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::SharedEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; + +namespace mlir { + +//===----------------------------------------------------------------------===// +// Shared Memory Allocation Analysis +//===----------------------------------------------------------------------===// +namespace triton { + +// Bitwidth of pointers +constexpr int kPtrBitWidth = 64; + +static std::pair, SmallVector> +getCvtOrder(Attribute srcLayout, Attribute dstLayout) { + auto srcMmaLayout = mlir::dyn_cast(srcLayout); + auto srcDotLayout = mlir::dyn_cast(srcLayout); + auto dstMmaLayout = mlir::dyn_cast(dstLayout); + auto dstDotLayout = mlir::dyn_cast(dstLayout); + + assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) && + "mma -> mma layout conversion is only supported on Ampere"); + + // mma or dot layout does not have an order, so the order depends on the + // layout of the other operand. + auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout) + : getOrder(srcLayout); + auto outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout) + : getOrder(dstLayout); + + return {inOrd, outOrd}; +} + +SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) { + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + if (shouldUseDistSmem(srcLayout, dstLayout)) { + // TODO: padding to avoid bank conflicts + return convertType(getShapePerCTA(srcTy)); + } + + if (isMfmaToDotShortcut(srcTy, dstTy)) + return {}; + + // MmaToDotShortcut and MmaToMmaShortcut doesn't use shared mem + if (auto srcMmaLayout = mlir::dyn_cast(srcLayout)) { + if (mlir::isa(dstLayout)) { + if (isMmaToDotShortcut(srcTy, dstTy)) { + return {}; + } + } else if (auto dstMmaLayout = + mlir::dyn_cast(dstLayout)) { + if (isMmaToMmaShortcut(srcTy, dstTy)) { + return {}; + } + } + } + + assert(srcLayout && dstLayout && "Unexpected layout in getRepShape()"); + + auto srcShapePerCTA = getShapePerCTA(srcTy); + auto dstShapePerCTA = getShapePerCTA(dstTy); + auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); + auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape()); + + unsigned rank = dstTy.getRank(); + SmallVector repShape(rank); + for (unsigned d = 0; d < rank; ++d) { + repShape[d] = + std::max(std::min(srcShapePerCTA[d], srcShapePerCTATile[d]), + std::min(dstShapePerCTA[d], dstShapePerCTATile[d])); + } + return repShape; +} + +SmallVector +getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, + unsigned &outVec) { + auto repShape = getRepShapeForCvtLayout(op); + if (repShape.empty()) + return repShape; + auto rank = repShape.size(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + assert(!isMfmaToDotShortcut(srcTy, dstTy)); + + auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout); + unsigned srcContigPerThread = + getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]]; + unsigned dstContigPerThread = + getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]]; + // TODO: Fix the legacy issue that ourOrd[0] == 0 always means + // that we cannot do vectorization. + unsigned innerDim = rank - 1; + inVec = outOrd[0] != innerDim ? 1 + : inOrd[0] != innerDim ? 1 + : srcContigPerThread; + outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread; + + // For conversions to MmaV1 (Nvidia V100), this inVec is hardcoded in the + // codegen. + if (auto mma = mlir::dyn_cast(srcLayout)) { + if (mma.getVersionMajor() == 1) { + inVec = srcContigPerThread; + } else if (mlir::isa(dstLayout)) { + // when storing from mma layout and loading in blocked layout vectorizing + // the load back gives better performance even if there is a + // transposition. + outVec = dstContigPerThread; + } + } + + if (rank <= 1) + return repShape; + // pad the last dimension + unsigned paddedDim = rank - 1; + if (auto dstBlockedLayout = mlir::dyn_cast(dstLayout)) { + paddedDim = dstBlockedLayout.getOrder()[0]; + } + unsigned pad = std::max(inVec, outVec); + repShape[paddedDim] += pad; + return repShape; +} + +// TODO: extend beyond scalars +SmallVector getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) { + SmallVector smemShape; + if (isa(op.getPtr().getType())) { + // do nothing or just assert because shared memory is not used in tensor up + // to now + } else { + // need only bytes for scalar + // always vec = 1 and elemsPerThread = 1 for scalar? + smemShape.push_back(1); + } + return smemShape; +} + +SmallVector getScratchConfigForAtomicCAS(triton::AtomicCASOp op) { + return SmallVector{1}; +} + +class AllocationAnalysis { +public: + AllocationAnalysis(Operation *operation, + Allocation::FuncAllocMapT *funcAllocMap, + Allocation *allocation) + : operation(operation), funcAllocMap(funcAllocMap), + allocation(allocation) { + run(); + } + +private: + using BufferT = Allocation::BufferT; + + /// Value -> Liveness Range + /// Use MapVector to ensure determinism. + using BufferRangeMapT = llvm::MapVector>; + /// Nodes -> Nodes + using GraphT = DenseMap>; + + void run() { + getValuesAndSizes(); + resolveLiveness(); + computeOffsets(); + } + + /// Initializes explicitly defined shared memory values for a given operation. + void getExplicitValueSize(Operation *op) { + // Values returned from scf.yield will not be allocated even though they + // have the shared encoding. + // For example: %a = scf.if -> yield + // %a must be allocated elsewhere by other operations. + // FIXME(Keren): extract and insert are always alias for now + if (!maybeSharedAllocationOp(op)) + return; + + // XXX(Keren): Why this hard-coded alignment? + size_t kAlignment = 8; + for (Value result : op->getResults()) { + if (auto alloc = result.getDefiningOp()) { + // Bytes could be a different value once we support padding or other + // allocation policies. + auto allocType = alloc.getType(); + auto shapePerCTA = triton::gpu::getShapePerCTA(allocType); + auto bytes = product(shapePerCTA) * + allocType.getElementTypeBitWidth() / 8; + + // XXX(Keren): magic numbers 256 and 1024 + // benzh@maybe alignment should be passed in. + // Software swizzling calculates phase based on offset, while hardware + // swizzling do that based on physical address. Thus only by setting the + // alignment to 1024 can ensure the correctness.  + if (bytes > 256) + kAlignment = 1024; + allocation->addBuffer(result, bytes, + kAlignment); + } + } + } + + template + void maybeAddScratchBuffer(Operation *op, unsigned bytes, + unsigned alignment) { + if (bytes > 0) + allocation->addBuffer(op, bytes, alignment); + } + + template + void maybeAddScratchBuffer(Operation *op, unsigned bytes) { + if (bytes > 0) + allocation->addBuffer(op, bytes); + } + + /// Initializes temporary shared memory for a given operation. + void getScratchValueSize(Operation *op) { + const size_t scratchAlignment = 128; + // TODO[dyq]: can we use an alternative method to bypass allocation? + if (auto xpuReduceOp = dyn_cast(op)) { + ReduceOpHelper helper(xpuReduceOp); + unsigned bytes = helper.getXPUScratchSizeInBytes(); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } else if (auto reduceOp = dyn_cast(op)) { + ReduceOpHelper helper(reduceOp); + unsigned bytes = helper.getScratchSizeInBytes(); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } else if (auto scanOp = dyn_cast(op)) { + ScanLoweringHelper helper(scanOp); + unsigned bytes = helper.getScratchSizeInBytes(); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } else if (auto histogram = dyn_cast(op)) { + auto dstTy = histogram.getType(); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp( + op->getParentOfType()); + auto bytes = std::max(dstTy.getNumElements(), threadsPerWarp) * + std::max(8, dstTy.getElementTypeBitWidth()) / 8; + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } else if (auto cvtLayout = dyn_cast(op)) { + auto srcTy = cvtLayout.getSrc().getType(); + auto dstTy = cvtLayout.getType(); + auto srcEncoding = srcTy.getEncoding(); + auto dstEncoding = dstTy.getEncoding(); + if (mlir::isa(srcEncoding) || + mlir::isa(dstEncoding)) { + // Conversions from/to shared memory do not need scratch memory. + return; + } + // ConvertLayoutOp with both input/output non-shared_layout + // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's + // also possible to realize it with other approaches in restricted + // conditions, such as warp-shuffle + unsigned inVec = 0; + unsigned outVec = 0; + auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto bytes = + isa(srcTy.getElementType()) + ? elems * kPtrBitWidth / 8 + : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } else if (auto atomicRMWOp = dyn_cast(op)) { + auto value = op->getOperand(0); + // only scalar requires scratch memory + // make it explicit for readability + if (dyn_cast(value.getType())) { + // nothing to do + } else { + auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto elemTy = + cast(value.getType()).getPointeeType(); + auto bytes = + isa(elemTy) + ? elems * kPtrBitWidth / 8 + : elems * std::max(8, elemTy.getIntOrFloatBitWidth()) / 8; + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } + } else if (auto atomicCASOp = dyn_cast(op)) { + // only scalar requires scratch memory + // make it explicit for readability + auto value = op->getOperand(0); + if (dyn_cast(value.getType())) { + // nothing to do + } else { + auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp); + unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, + std::multiplies{}); + auto elemTy = + cast(value.getType()).getPointeeType(); + auto bytes = isa(elemTy) + ? elems * kPtrBitWidth / 8 + : elems * elemTy.getIntOrFloatBitWidth() / 8; + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } + } else if (auto callOp = dyn_cast(op)) { + auto callable = callOp.resolveCallable(); + auto funcOp = dyn_cast(callable); + auto *funcAlloc = &(*funcAllocMap)[funcOp]; + auto bytes = funcAlloc->getSharedMemorySize(); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); + } + } + + void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) { + dataflow::Lattice *latticeElement = + analysis.getLatticeElement(value); + if (latticeElement) { + AliasInfo &info = latticeElement->getValue(); + if (!info.getAllocs().empty()) { + for (auto alloc : info.getAllocs()) { + allocation->addAlias(value, alloc); + } + } + } + } + + /// Extract all shared memory values and their sizes + void getValuesAndSizes() { + // Get the alloc values + operation->walk([&](Operation *op) { + getExplicitValueSize(op); + getScratchValueSize(op); + }); + // Get the alias values + std::unique_ptr solver = createDataFlowSolver(); + SharedMemoryAliasAnalysis *aliasAnalysis = + solver->load(); + if (failed(solver->initializeAndRun(operation))) { + // TODO: return error instead of bailing out.. + llvm_unreachable("failed to run SharedMemoryAliasAnalysis"); + } + operation->walk([&](Operation *op) { + for (auto operand : op->getOperands()) { + getValueAlias(operand, *aliasAnalysis); + } + for (auto value : op->getResults()) { + getValueAlias(value, *aliasAnalysis); + } + }); + } + + /// Computes the liveness range of the allocated value. + /// Each buffer is allocated only once. + void resolveExplicitBufferLiveness( + function_ref(Value value)> getLiveness) { + for (auto valueBufferIter : allocation->valueBuffer) { + auto value = valueBufferIter.first; + auto *buffer = valueBufferIter.second; + bufferRange[buffer] = getLiveness(value); + } + } + + /// Extends the liveness range by unionizing the liveness range of the aliased + /// values because each allocated buffer could be an alias of others, if block + /// arguments are involved. + void resolveAliasBufferLiveness( + function_ref(Value value)> getLiveness) { + for (auto aliasBufferIter : allocation->aliasBuffer) { + auto value = aliasBufferIter.first; + auto buffers = aliasBufferIter.second; + auto range = getLiveness(value); + for (auto *buffer : buffers) { + auto minId = range.start(); + auto maxId = range.end(); + if (bufferRange.count(buffer)) { + // Extend the allocated buffer's range + minId = std::min(minId, bufferRange[buffer].start()); + maxId = std::max(maxId, bufferRange[buffer].end()); + } + bufferRange[buffer] = Interval(minId, maxId); + } + } + } + + /// Computes the liveness range of scratched buffers. + /// Some operations may have a temporary buffer that is not explicitly + /// allocated, but is used to store intermediate results. + void resolveScratchBufferLiveness( + const DenseMap &operationId) { + // Analyze liveness of scratch buffers and virtual buffers. + auto processScratchMemory = [&](const auto &container) { + for (auto opScratchIter : container) { + // Any scratch memory's live range is the current operation's live + // range. + auto *op = opScratchIter.first; + auto *buffer = opScratchIter.second; + bufferRange.insert({buffer, Interval(operationId.lookup(op), + operationId.lookup(op) + 1)}); + } + }; + processScratchMemory(allocation->opScratch); + processScratchMemory(allocation->opVirtual); + } + + /// Resolves liveness of all values involved under the root operation. + void resolveLiveness() { + // Assign an ID to each operation using post-order traversal. + // To achieve the correct liveness range, the parent operation's ID + // should be greater than each of its child operation's ID . + // Example: + // ... + // %5 = triton.convert_layout %4 + // %6 = scf.for ... iter_args(%arg0 = %0) -> (i32) { + // %2 = triton.convert_layout %5 + // ... + // scf.yield %arg0 + // } + // For example, %5 is defined in the parent region and used in + // the child region, and is not passed as a block argument. + // %6 should should have an ID greater than its child operations, + // otherwise %5 liveness range ends before the child operation's liveness + // range ends. + DenseMap operationId; + operation->walk( + [&](Operation *op) { operationId[op] = operationId.size(); }); + + // Analyze liveness of explicit buffers + Liveness liveness(operation); + auto getValueLivenessRange = [&](Value value) { + auto liveOperations = liveness.resolveLiveness(value); + auto minId = std::numeric_limits::max(); + auto maxId = std::numeric_limits::min(); + std::for_each(liveOperations.begin(), liveOperations.end(), + [&](Operation *liveOp) { + if (operationId[liveOp] < minId) { + minId = operationId[liveOp]; + } + if ((operationId[liveOp] + 1) > maxId) { + maxId = operationId[liveOp] + 1; + } + }); + return Interval(minId, maxId); + }; + + resolveExplicitBufferLiveness(getValueLivenessRange); + resolveAliasBufferLiveness(getValueLivenessRange); + resolveScratchBufferLiveness(operationId); + } + + /// Computes the shared memory offsets for all related values. + /// Paper: Algorithms for Compile-Time Memory Optimization + /// (https://dl.acm.org/doi/pdf/10.5555/314500.315082) + void computeOffsets() { + SmallVector buffers; + for (auto bufferIter : bufferRange) { + buffers.emplace_back(bufferIter.first); + } + + calculateStarts(buffers); + + // NOTE: The original paper doesn't consider interference between + // the bumped ranges. Buffers that previously do not interfere with + // could interfere after offset bumping if their liveness ranges overlap. + // Therefore, we rerun the interference graph algorithm after bumping so + // that we regroup the buffers and color them again. Since we always + // increase the buffer offset and keep reducing conflicts, we will + // eventually reach a fixed point. + GraphT interference; + buildInterferenceGraph(buffers, interference); + do { + allocate(buffers, interference); + buildInterferenceGraph(buffers, interference); + } while (!interference.empty()); + } + + /// Computes the initial shared memory offsets. + void calculateStarts(const SmallVector &buffers) { + // v = values in shared memory + // t = triplet of (size, start, end) + // shared memory space + // - + // | *******t4 + // | /|\ v2 inserts t4, t5, and t6 + // | | + // | ******t5 ************t6 + // | ^^^^^v2^^^^^^ + // | | *********************t2 + // | \|/ v2 erases t1 + // | ******t1 ^^^^^^^^^v1^^^^^^^^^ ************t3 + // |---------------------------------------------| liveness range + // 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ... + // If the available triple's range is less than a given buffer range, + // we won't know if there has been an overlap without using graph coloring. + // Start -> Liveness Range + using TripleMapT = std::multimap>; + TripleMapT tripleMap; + tripleMap.insert(std::make_pair(0, Interval())); + SmallVector xBuffers = buffers; + while (!xBuffers.empty()) { + auto tripleIt = tripleMap.begin(); + auto offset = tripleIt->first; + auto range = tripleIt->second; + tripleMap.erase(tripleIt); + auto bufferIt = + std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) { + auto xRange = bufferRange[buffer]; + bool res = xRange.intersects(range); + for (auto val : tripleMap) + res = res && + !val.second.intersects(xRange); // only one buffer intersect + return res; + }); + if (bufferIt != xBuffers.end()) { + auto buffer = *bufferIt; + auto xSize = buffer->size; + auto xRange = bufferRange.lookup(buffer); + // TODO(Keren): A buffer's size shouldn't be determined here, have to + // clean it up + size_t alignOffset = buffer->setOffsetAligned(offset); + tripleMap.insert({alignOffset + xSize, + Interval{std::max(range.start(), xRange.start()), + std::min(range.end(), xRange.end())}}); + // We could either insert (range.start, xRange.start) or (range.start, + // xRange.end), both are correct and determine the potential buffer + // offset, and the graph coloring algorithm will solve the interference, + // if any + if (range.start() < xRange.start()) + tripleMap.insert({offset, Interval{range.start(), xRange.end()}}); + if (xRange.end() < range.end()) + tripleMap.insert({offset, Interval{xRange.start(), range.end()}}); + xBuffers.erase(bufferIt); + } + } + } + + /// Builds a graph of all shared memory values. Edges are created between + /// shared memory values that are overlapping. + void buildInterferenceGraph(const SmallVector &buffers, + GraphT &interference) { + // Reset interference graph + interference.clear(); + for (auto x : buffers) { + for (auto y : buffers) { + if (x == y) + continue; + auto xStart = x->offset; + auto yStart = y->offset; + auto xSize = x->size; + auto ySize = y->size; + Interval xSizeRange = {xStart, xStart + xSize}; + Interval ySizeRange = {yStart, yStart + ySize}; + auto xOpRange = bufferRange.lookup(x); + auto yOpRange = bufferRange.lookup(y); + if (xOpRange.intersects(yOpRange) && + xSizeRange.intersects(ySizeRange)) { + interference[x].insert(y); + } + } + } + } + + /// Finalizes shared memory offsets considering interference. + void allocate(const SmallVector &buffers, + const GraphT &interference) { + // Reset shared memory size + allocation->sharedMemorySize = 0; + // First-fit graph coloring + // Neighbors are nodes that interfere with each other. + // We color a node by finding the index of the first available + // non-neighboring node or the first neighboring node without any color. + // Nodes with the same color do not interfere with each other. + DenseMap colors; + for (auto value : buffers) { + colors[value] = (value == buffers[0]) ? 0 : -1; + } + SmallVector available(buffers.size()); + for (auto x : buffers) { + std::fill(available.begin(), available.end(), true); + for (auto y : interference.lookup(x)) { + int color = colors[y]; + if (color >= 0) { + available[color] = false; + } + } + auto it = std::find(available.begin(), available.end(), true); + colors[x] = std::distance(available.begin(), it); + } + // Finalize allocation + // color0: [0, 7), [0, 8), [0, 15) -> [0, 7), [0, 8), [0, 15) + // color1: [7, 9) -> [0 + 1 * 15, 9 + 1 * 15) -> [15, 24) + // color2: [8, 12) -> [8 + 2 * 15, 12 + 2 * 15) -> [38, 42) + // TODO(Keren): We are wasting memory here. + // Nodes with color2 can actually start with 24. + for (auto x : buffers) { + size_t newOffset = 0; + for (auto y : interference.lookup(x)) { + newOffset = std::max(newOffset, y->offset + y->size); + } + if (colors.lookup(x) != 0) + x->setOffsetAligned(newOffset); + allocation->sharedMemorySize = + std::max(allocation->sharedMemorySize, x->offset + x->size); + } + } + +private: + Operation *operation; + Allocation::FuncAllocMapT *funcAllocMap; + Allocation *allocation; + BufferRangeMapT bufferRange; +}; + +} // namespace triton + +void Allocation::run(FuncAllocMapT &funcAllocMap) { + triton::AllocationAnalysis(getOperation(), &funcAllocMap, this); +} + +} // namespace mlir diff --git a/third_party/xpu/lib/Analysis/AxisInfo.cpp b/third_party/xpu/lib/Analysis/AxisInfo.cpp new file mode 100644 index 000000000..49d559618 --- /dev/null +++ b/third_party/xpu/lib/Analysis/AxisInfo.cpp @@ -0,0 +1,1316 @@ +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#define DEBUG_TYPE "axis-info" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir::triton { +namespace { + +int64_t gcdImpl(int64_t a, int64_t b, int64_t *x, int64_t *y) { + // Base Case + if (a == 0) { + *x = 0; + *y = 1; + return b; + } + int64_t x1, y1; // To store results of recursive call + int64_t gcd = gcdImpl(b % a, a, &x1, &y1); + // Update x and y using results of + // recursive call + *x = y1 - (b / a) * x1; + *y = x1; + return gcd; +} + +int64_t gcd(int64_t a, int64_t b) { + if (a == 0) + return b; + if (b == 0) + return a; + int64_t x, y; + return gcdImpl(a, b, &x, &y); +} + +constexpr int log2Int(int64_t num) { + return (num > 1) ? 1 + log2Int(num / 2) : 0; +} + +// If lhs * rhs overflows, return max value possible value for the type +int64_t multiplyDivisor(int64_t lhs, int64_t rhs) { + int64_t maxDivisor = highestPowOf2Divisor(0); + if (lhs > maxDivisor / rhs) + return maxDivisor; + return lhs * rhs; +} + +class AxisInfoVisitor { +public: + AxisInfoVisitor() = default; + virtual ~AxisInfoVisitor() = default; + + static bool isContiguousDim(const AxisInfo &info, ArrayRef shape, + int dim) { + return info.getContiguity(dim) == shape[dim]; + } + + static bool isConstantDim(const AxisInfo &info, ArrayRef shape, + int dim) { + return info.getConstancy(dim) == shape[dim]; + } + + virtual AxisInfo + getAxisInfo(Operation *op, + ArrayRef *> operands) = 0; + + virtual bool match(Operation *op) = 0; +}; + +// Base class for all operations +template class AxisInfoVisitorImpl : public AxisInfoVisitor { +public: + using AxisInfoVisitor::AxisInfoVisitor; + + AxisInfo + getAxisInfo(Operation *op, + ArrayRef *> operands) final { + return getAxisInfo(cast(op), operands); + } + + bool match(Operation *op) final { return isa(op); } + + virtual AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) = 0; +}; + +// Binary operations +template +class BinaryOpVisitorImpl : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + assert(operands.size() == 2 && "Expected two operands"); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + auto constantValue = getConstantValue(op, lhsInfo, rhsInfo); + for (auto d = 0; d < rank; ++d) { + if (constantValue.has_value()) { + contiguity.push_back(1); + constancy.push_back( + std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + highestPowOf2Divisor(constantValue.value())); + } else { + contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d)); + constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d)); + divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d)); + } + } + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } + +protected: + virtual int64_t getContiguity(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual int64_t getDivisibility(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) { + return 1; + } + + virtual std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) { + return {}; + } +}; + +class AxisInfoVisitorList { +public: + template > + void append() { + (visitors.emplace_back(std::make_unique()), ...); + } + + AxisInfo apply(Operation *op, + ArrayRef *> operands) { + for (auto &visitor : visitors) + if (visitor->match(op)) + return visitor->getAxisInfo(op, operands); + return AxisInfo(); + } + +private: + std::vector> visitors; +}; + +class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +private: + AxisInfoVisitorList visitors; + + void setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged( + lattice, + lattice->join(AxisInfo::getPessimisticValueState(lattice->getPoint()))); + } + + void visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &successor, + ArrayRef *> argLattices, + unsigned firstIndex) override { + if (auto forOp = dyn_cast(op)) { + visitForOpInductionVar(forOp, argLattices); + } else { + setAllToEntryStates(argLattices.take_front(firstIndex)); + setAllToEntryStates(argLattices.drop_front( + firstIndex + successor.getSuccessorInputs().size())); + } + } + +public: + AxisInfoAnalysis(DataFlowSolver &solver); + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + using FuncAxisInfoMapT = DenseMap; + + void visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; + void + visitForOpInductionVar(scf::ForOp op, + ArrayRef *> argLattices); +}; + +template +class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + return operands[0]->getValue(); + } +}; + +class MakeRangeOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::MakeRangeOp op, + ArrayRef *> operands) override { + auto start = op.getStart(); + auto end = op.getEnd(); + return AxisInfo(/*contiguity=*/{end - start}, + /*divisibility=*/{highestPowOf2Divisor(start)}, + /*constancy=*/{1}); + } +}; + +template +class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto intAttr = dyn_cast(op.getValue()); + auto boolAttr = dyn_cast(op.getValue()); + if (intAttr || boolAttr) { + int64_t value{}; + if (intAttr) + value = intAttr.getValue().getZExtValue(); + else + value = boolAttr.getValue() ? 1 : 0; + return AxisInfo(/*contiguity=*/{1}, + /*divisibility=*/{highestPowOf2Divisor(value)}, + /*constancy=*/{1}, + /*knownConstantValue=*/{value}); + } + // TODO: generalize to dense attr + auto splatAttr = dyn_cast(op.getValue()); + if (splatAttr && splatAttr.getElementType().isIntOrIndex()) { + int64_t value = splatAttr.template getSplatValue().getZExtValue(); + TensorType ty = cast(splatAttr.getType()); + return AxisInfo( + /*contiguity=*/AxisInfo::DimVectorT(ty.getRank(), 1), + /*divisibility=*/ + AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)), + /*constancy=*/ + AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()), + /*knownConstantValue=*/{value}); + } + return AxisInfo(); + } +}; + +template +class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return std::max(gcd(lhs.getConstancy(dim), rhs.getContiguity(dim)), + gcd(lhs.getContiguity(dim), rhs.getConstancy(dim))); + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs = k * d_lhs = k * k' * gcd(d_lhs, d_rhs) + // rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs) + // lhs + rhs = k * d_lhs + p * d_rhs = (k * d_lhs + p * d_rhs) * + // gcd(d_lhs, d_rhs) + auto rhsDivisibility = rhs.getDivisibility(dim); + if constexpr (std::is_same_v) { + // %ptr = addptr %lhs, %rhs + // is equivalent to + // %0 = mul %rhs, %elemSize + // %ptr = add %lhs, %0 + // The result will still be contiguous in terms of elements but not bytes + // For example: + // addptr [16] : !ptr, [0, 1, 2, 3] : i32 -> !ptr + // returns: + // [16, 20, 24, 28] : !ptr + // with element locations: + // [4, 5, 6, 7] + // It is "strided contiguous" with a divisilibity of 16 bytes + auto rank = lhs.getRank(); + auto elemSize = std::max( + 1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8); + rhsDivisibility = multiplyDivisor(rhs.getDivisibility(dim), elemSize); + } + return gcd(lhs.getDivisibility(dim), rhsDivisibility); + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) { + if constexpr (std::is_same_v || + std::is_same_v) { + return {lhs.getConstantValue().value() + + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() - + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + auto rank = lhs.getRank(); + auto elemSize = std::max( + 1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8); + auto rhsValue = rhs.getConstantValue().value() * elemSize; + return {lhs.getConstantValue().value() + rhsValue}; + } + } + return {}; + } +}; + +class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + // lhs * 1 = lhs + auto lhsContiguity = + rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1 + ? lhs.getContiguity(dim) + : 1; + // 1 * rhs = rhs + auto rhsContiguity = + lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1 + ? rhs.getContiguity(dim) + : 1; + return std::max(lhsContiguity, rhsContiguity); + } + + int64_t getConstancy(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && + !(rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1)) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + auto rhsDivisibility = rhs.getDivisibility(dim); + if (rhs.getContiguity(dim) > 1 && + !(lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1)) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + rhsDivisibility = 1; + } + return multiplyDivisor(lhsDivisibility, rhsDivisibility); + } + + std::optional getConstantValue(arith::MulIOp op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() * rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs / 1 = lhs + return rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1 + ? lhs.getContiguity(dim) + : 1; + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + // Case 1: both lhs and rhs are constants. + auto constancy = gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + // Case 2: lhs contiguous, rhs constant. + // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n + // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p + // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p), + // ..., (d_lhs * k + n) / (d_rhs * p) + // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0, + // the minimal constancy is gcd(d_lhs, d_rhs). + // Since gcd(d_lhs, d_rhs) maybe > len(lhs), + // we need to use another gcd to get the actual constancy. + if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) && + AxisInfoVisitor::isConstantDim(rhs, shape, dim)) { + constancy = std::max(constancy, gcd(lhs.getContiguity(dim), + gcd(lhs.getDivisibility(dim), + rhs.getDivisibility(dim)))); + } + return constancy; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // Case 1: lhs is 0 + if (lhs.getConstantValue().has_value() && + lhs.getConstantValue().value() == 0) + return lhs.getDivisibility(dim); + // Case 2: rhs is 1 + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return lhs.getDivisibility(dim); + // otherwise: return 1 + return 1; + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() / rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return BinaryOpVisitorImpl::getContiguity(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + int64_t contiguity = 1; + // lhs contiguous, rhs constant + // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n + // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p + // lhs % rhs = d_lhs * k % (d_rhs * p), (d_lhs * k + 1) % (d_rhs * p), + // ..., (d_lhs * k + n) % (d_rhs * p) + // Because d_lhs % d_rhs = 0 || d_rhs % d_lhs = 0, + // The minimal contiguity is gcd(d_lhs, d_rhs). + // Since gcd(d_lhs, d_rhs) maybe > len(lhs), + // we need to use another gcd to get the actual contiguity. + if (AxisInfoVisitor::isContiguousDim(lhs, shape, dim) && + AxisInfoVisitor::isConstantDim(rhs, shape, dim)) { + contiguity = std::max(contiguity, gcd(lhs.getContiguity(dim), + gcd(lhs.getDivisibility(dim), + rhs.getDivisibility(dim)))); + } + return contiguity; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + // lhs: d_lhs * k = gcd(d_lhs, d_rhs) * k' * k = gcd(d_lhs, d_rhs) * k'' + // rhs: d_rhs * p = gcd(d_lhs, d_rhs) * p' * p = gcd(d_lhs, d_rhs) * p'' + // lhs = gcd(d_lhs, d_rhs) * k'' = gcd(d_lhs, d_rhs) * d + r + // r must be divisible by gcd(d_lhs, d_rhs) + return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim)); + }; + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); + auto shape = resTy.getShape(); + // lhs % 1 = 0 + return rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1 + ? shape[dim] + : gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() % rhs.getConstantValue().value()}; + else if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return {0}; + return {}; + } +}; + +class SplatOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::SplatOp op, + ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + TensorType retTy = cast(_retTy); + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + for (int d = 0; d < retTy.getRank(); ++d) { + contiguity.push_back(1); + divisibility.push_back(opInfo.getDivisibility(0)); + constancy.push_back(retTy.getShape()[d]); + } + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +class LoadOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::LoadOp op, + ArrayRef *> operands) override { + // If pointers and mask both have constancy properties, those properties + // will also extend to output. + AxisInfo ptrInfo = operands[0]->getValue(); + std::optional maskInfo; + if (operands.size() > 1) { + maskInfo = operands[1]->getValue(); + } + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + + for (int d = 0; d < ptrInfo.getRank(); ++d) { + contiguity.push_back(1); + divisibility.push_back(1); + constancy.push_back( + gcd(ptrInfo.getConstancy(d), + maskInfo.has_value() ? maskInfo->getConstancy(d) : 0)); + } + + return AxisInfo(contiguity, divisibility, constancy); + } +}; + +class ExpandDimsOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::ExpandDimsOp op, + ArrayRef *> operands) override { + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity = opInfo.getContiguity(); + AxisInfo::DimVectorT divisibility = opInfo.getDivisibility(); + AxisInfo::DimVectorT constancy = opInfo.getConstancy(); + int64_t newDivisibility = 1; + if (opInfo.getConstantValue().has_value()) { + // The tensor is constant, same as ConstantOpAxisInfoVisitor + newDivisibility = highestPowOf2Divisor(opInfo.getConstantValue().value()); + } else if (opInfo.getRank()) { + // Otherwise, calculate the GCD as the new divisibility + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + newDivisibility = + opInfo.getContiguity(0) > 1 ? 1 : opInfo.getDivisibility(0); + for (int d = 1; d < opInfo.getRank(); ++d) { + newDivisibility = + gcd(newDivisibility, + opInfo.getContiguity(d) > 1 ? 1 : opInfo.getDivisibility(d)); + } + } + contiguity.insert(contiguity.begin() + op.getAxis(), 1); + divisibility.insert(divisibility.begin() + op.getAxis(), newDivisibility); + constancy.insert(constancy.begin() + op.getAxis(), 1); + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +class BroadcastOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(triton::BroadcastOp op, + ArrayRef *> operands) override { + Type _retTy = *op->result_type_begin(); + Type _opTy = *op->operand_type_begin(); + TensorType retTy = cast(_retTy); + TensorType opTy = cast(_opTy); + ArrayRef retShape = retTy.getShape(); + ArrayRef opShape = opTy.getShape(); + AxisInfo opInfo = operands[0]->getValue(); + AxisInfo::DimVectorT contiguity; + AxisInfo::DimVectorT divisibility; + AxisInfo::DimVectorT constancy; + for (int d = 0; d < retTy.getRank(); ++d) { + contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d)); + divisibility.push_back(opInfo.getDivisibility(d)); + constancy.push_back(opShape[d] == 1 ? retShape[d] + : opInfo.getConstancy(d)); + } + return AxisInfo(contiguity, divisibility, constancy, + operands[0]->getValue().getConstantValue()); + } +}; + +template +class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto resTy = dyn_cast(op.getType()); + if (!resTy) + return AxisInfo(); + auto shape = resTy.getShape(); + short rank = resTy.getRank(); + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + + AxisInfo::DimVectorT contiguity, divisibility, constancy; + std::optional constantValue; + for (short d = 0; d < rank; ++d) { + int64_t constHint = 1; + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + constHint = lhsInfo.getConstancy(d); + constantValue = + compare(getPredicate(op), lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value()) + ? 1 + : 0; + } else { + // Case 1: lhs and rhs are both partial constants + constHint = gcd(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)); + if ((gtPredicate(getPredicate(op)) || lePredicate(getPredicate(op))) && + AxisInfoVisitor::isConstantDim(lhsInfo, shape, d)) { + // Case 2: lhs all constant, rhs all contiguous + // NOTE: + // lhs: 4 4 4 4 + // rhs: 4 5 6 7 + // lhs eq rhs: 1, 0, 0, 0 + // lhs ne rhs: 0, 1, 1, 1 + // lhs lt rhs: 0, 1, 1, 1 + // lhs le rhs: 1, 1, 1, 1 + // lhs ge rhs: 1, 0, 0, 0 + // lhs gt rhs: 0, 0, 0, 0 + constHint = std::max(constHint, gcd(rhsInfo.getContiguity(d), + gcd(lhsInfo.getDivisibility(d), + rhsInfo.getDivisibility(d)))); + } else if ((ltPredicate(getPredicate(op)) || + gePredicate(getPredicate(op))) && + AxisInfoVisitor::isConstantDim(rhsInfo, shape, d)) { + // Case 3: lhs all contiguous, rhs all constant + // NOTE + // lhs: 4 5 6 7 + // rhs: 4 4 4 4 + // lhs eq rhs: 1, 0, 0, 0 + // lhs ne rhs: 0, 1, 1, 1 + // lhs le rhs: 1, 0, 0, 0 + // lhs lt rhs: 0, 0, 0, 0 + // lhs gt rhs: 0, 1, 1, 1 + // lhs ge rhs: 1, 1, 1, 1 + constHint = std::max(constHint, gcd(lhsInfo.getContiguity(d), + gcd(lhsInfo.getDivisibility(d), + rhsInfo.getDivisibility(d)))); + } + } + + constancy.push_back(constHint); + divisibility.push_back(1); + contiguity.push_back(1); + } + + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } + +private: + static arith::CmpIPredicate getPredicate(arith::CmpIOp op) { + return op.getPredicate(); + } + + static bool gtPredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sgt || + predicate == arith::CmpIPredicate::ugt; + } + + static bool gePredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sge || + predicate == arith::CmpIPredicate::uge; + } + + static bool ltPredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::slt || + predicate == arith::CmpIPredicate::ult; + } + + static bool lePredicate(arith::CmpIPredicate predicate) { + return predicate == arith::CmpIPredicate::sle || + predicate == arith::CmpIPredicate::ule; + } + + static bool compare(arith::CmpIPredicate predicate, int64_t lhs, + int64_t rhs) { + switch (predicate) { + case arith::CmpIPredicate::eq: + return lhs == rhs; + case arith::CmpIPredicate::ne: + return lhs != rhs; + case arith::CmpIPredicate::slt: + return lhs < rhs; + case arith::CmpIPredicate::sle: + return lhs <= rhs; + case arith::CmpIPredicate::sgt: + return lhs > rhs; + case arith::CmpIPredicate::sge: + return lhs >= rhs; + case arith::CmpIPredicate::ult: + return (uint64_t)lhs < (uint64_t)rhs; + case arith::CmpIPredicate::ule: + return (uint64_t)lhs <= (uint64_t)rhs; + case arith::CmpIPredicate::ugt: + return (uint64_t)lhs > (uint64_t)rhs; + case arith::CmpIPredicate::uge: + return (uint64_t)lhs >= (uint64_t)rhs; + default: + break; + } + llvm_unreachable("unknown comparison predicate"); + } +}; + +template +class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto condConstancy = operands[0]->getValue().getConstancy(); + auto lhsInfo = operands[1]->getValue(); + auto rhsInfo = operands[2]->getValue(); + auto rank = lhsInfo.getRank(); + + AxisInfo::DimVectorT contiguity, divisibility, constancy; + std::optional constantValue; + if (operands[0]->getValue().getConstantValue().has_value()) { + if (operands[0]->getValue().getConstantValue() == 0) { + contiguity = rhsInfo.getContiguity(); + divisibility = rhsInfo.getDivisibility(); + constancy = rhsInfo.getConstancy(); + constantValue = rhsInfo.getConstantValue(); + } else { + contiguity = lhsInfo.getContiguity(); + divisibility = lhsInfo.getDivisibility(); + constancy = lhsInfo.getConstancy(); + constantValue = lhsInfo.getConstantValue(); + } + } else { + // The condition can be either a tensor or i1. + // If i1 is used as the condition, the entire tensor of either + // lhs or rhs is selected. + bool i1Cond = isa(op.getOperand(0).getType()); + for (auto d = 0; d < rank; ++d) { + if (i1Cond) { + constancy.push_back( + std::min(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + contiguity.push_back( + std::min(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d))); + } else { + constancy.push_back( + std::min(gcd(lhsInfo.getConstancy(d), condConstancy[d]), + gcd(rhsInfo.getConstancy(d), condConstancy[d]))); + contiguity.push_back( + std::min(gcd(lhsInfo.getContiguity(d), condConstancy[d]), + gcd(rhsInfo.getContiguity(d), condConstancy[d]))); + if (contiguity.back() == lhsInfo.getContiguity(d) && + contiguity.back() == rhsInfo.getContiguity(d)) { + // Contiguity not changed + divisibility.push_back( + gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + } else { + // Contiguity changed, we cannot use only divisibility. + // For example, the following example should have contiguity 2 and + // divisibility 2 + // [[0, 1], [4, 5]] + // [[16, 17, 18, 19]] + divisibility.push_back( + std::min(gcd(lhsInfo.getDivisibility(d), contiguity.back()), + gcd(rhsInfo.getDivisibility(d), contiguity.back()))); + } + } + } + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value() && + lhsInfo.getConstantValue() == rhsInfo.getConstantValue()) + constantValue = lhsInfo.getConstantValue(); + } + + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } +}; + +template +class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) { + if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() & + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() | + rhs.getConstantValue().value()}; + } else if constexpr (std::is_same_v) { + return {lhs.getConstantValue().value() ^ + rhs.getConstantValue().value()}; + } + } + return {}; + } +}; + +class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 0) + return lhs.getContiguity(dim); + else + return 1; + } + + int64_t getDivisibility(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + auto shift = rhs.getConstantValue().has_value() + ? rhs.getConstantValue().value() + : rhs.getDivisibility(dim); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + auto numBits = log2Int(lhsDivisibility); + return multiplyDivisor(lhsDivisibility, 1 << shift); + } + + int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs, int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(arith::ShLIOp op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() << rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl { +public: + using BinaryOpVisitorImpl::BinaryOpVisitorImpl; + +private: + int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 0) + return lhs.getContiguity(dim); + else + return 1; + } + + int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + auto shift = rhs.getConstantValue().has_value() + ? rhs.getConstantValue().value() + : rhs.getDivisibility(dim); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + return std::max(1, lhsDivisibility / (1 << shift)); + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, + int dim) override { + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + } + + std::optional getConstantValue(OpTy op, const AxisInfo &lhs, + const AxisInfo &rhs) override { + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value()) + return {lhs.getConstantValue().value() >> rhs.getConstantValue().value()}; + return {}; + } +}; + +template +class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(OpTy op, + ArrayRef *> operands) override { + auto lhsInfo = operands[0]->getValue(); + auto rhsInfo = operands[1]->getValue(); + auto rank = lhsInfo.getRank(); + std::optional constantValue; + if (lhsInfo.getConstantValue().has_value() && + rhsInfo.getConstantValue().has_value()) { + if constexpr (std::is_same_v || + std::is_same_v) { + constantValue = {std::max(lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value())}; + } else if constexpr (std::is_same_v || + std::is_same_v) { + constantValue = {std::min(lhsInfo.getConstantValue().value(), + rhsInfo.getConstantValue().value())}; + } + return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1), + /*knownDivisibility=*/AxisInfo::DimVectorT(rank, 1), + /*knownConstancy=*/AxisInfo::DimVectorT(rank, 1), + /*constantValue=*/constantValue); + } else { + AxisInfo::DimVectorT contiguity, divisibility, constancy; + for (auto d = 0; d < rank; ++d) { + constancy.push_back( + std::min(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); + divisibility.push_back( + std::min(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d))); + contiguity.push_back( + std::min(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d))); + } + return AxisInfo(contiguity, divisibility, constancy, std::nullopt); + } + } +}; + +//===----------------------------------------------------------------------===// +// AxisInfoAnalysis +//===----------------------------------------------------------------------===// + +AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) + : dataflow::SparseForwardDataFlowAnalysis>( + solver) { + // UnrealizedConversionCast: + // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is + // in the process of a PartialConversion, where UnrealizedConversionCast + // may exist + visitors.append, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor>(); + // TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp + // when scf.for supports integer induction variables + visitors.append(); + visitors.append, + ConstantOpAxisInfoVisitor>(); + visitors.append, + AddSubOpAxisInfoVisitor, + AddSubOpAxisInfoVisitor, + AddSubOpAxisInfoVisitor>(); + visitors.append(); + visitors.append, + DivOpAxisInfoVisitor>(); + visitors.append, + RemOpAxisInfoVisitor>(); + visitors.append(); + visitors.append(); + visitors.append(); + visitors.append>(); + visitors.append, + LogicalOpAxisInfoVisitor, + LogicalOpAxisInfoVisitor>(); + visitors.append>(); + visitors.append, + ShROpAxisInfoVisitor>(); + visitors.append, + MaxMinOpAxisInfoVisitor, + MaxMinOpAxisInfoVisitor, + MaxMinOpAxisInfoVisitor>(); + visitors.append(); +} + +void AxisInfoAnalysis::visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { + // TODO: For sure not the right way to do this + // but why is scf.if not initialized otherwise? + for (auto op : operands) + if (op->getValue().getRank() == 0) + setToEntryState((dataflow::Lattice *)op); + AxisInfo curr = visitors.apply(op, operands); + if (curr.getRank() == 0) + return setAllToEntryStates(results); + // override with hint + auto newContiguity = curr.getContiguity(); + auto newDivisibility = curr.getDivisibility(); + auto newConstancy = curr.getConstancy(); + if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) { + auto vals = cast(attr).getValues(); + newContiguity = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) { + auto vals = cast(attr).getValues(); + newDivisibility = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.constancy")) { + auto vals = cast(attr).getValues(); + newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + curr = AxisInfo(newContiguity, newDivisibility, newConstancy, + curr.getConstantValue()); + // join all lattice elements + for (auto *result : results) + propagateIfChanged(result, result->join(curr)); +} + +void AxisInfoAnalysis::visitForOpInductionVar( + scf::ForOp op, ArrayRef *> argLattices) { + auto lb = getLatticeElementFor(op, op.getLowerBound())->getValue(); + auto step = getLatticeElementFor(op, op.getStep())->getValue(); + + AxisInfo::DimVectorT knownContiguity(1, 1); + AxisInfo::DimVectorT knownDivisibility(1, 1); + AxisInfo::DimVectorT knownConstancy(1, 1); + knownDivisibility[0] = gcd(lb.getDivisibility(0), step.getDivisibility(0)); + auto inductionVar = + AxisInfo(knownContiguity, knownDivisibility, knownConstancy); + (void)argLattices[0]->join(inductionVar); +} + +} // anonymous namespace + +template +void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp, + DimVectorT *contiguity, + DimVectorT *divisibility, + DimVectorT *constancy) { + // liast of attributes that we care about + SmallVector> retVecs; + retVecs.push_back({contiguity, "tt.contiguity"}); + retVecs.push_back({divisibility, "tt.divisibility"}); + retVecs.push_back({constancy, "tt.constancy"}); + // initialize attributes one by one + for (auto [vec, attrName] : retVecs) { + Attribute attr = funcOp.getArgAttr(argNumber, attrName); + if (auto int_attr = dyn_cast_or_null(attr)) + *vec = DimVectorT(contiguity->size(), int_attr.getValue().getZExtValue()); + if (auto dense_attr = dyn_cast_or_null(attr)) { + auto vals = dense_attr.getValues(); + *vec = DimVectorT(vals.begin(), vals.end()); + } + } +} + +/*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) { + auto rank = 1; + if (TensorType ty = dyn_cast(value.getType())) + rank = ty.getRank(); + if (triton::PointerType ty = dyn_cast(value.getType())) + if (TensorType elemTy = dyn_cast(ty.getPointeeType())) + rank = elemTy.getRank(); + + DimVectorT knownContiguity(rank, 1); + DimVectorT knownDivisibility(rank, 1); + DimVectorT knownConstancy(rank, 1); + + BlockArgument blockArg = dyn_cast(value); + + if (blockArg && blockArg.getOwner()->isEntryBlock()) { + Operation *op = blockArg.getOwner()->getParentOp(); + if (auto fun = dyn_cast(op)) + initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, + &knownContiguity, &knownDivisibility, + &knownConstancy); + // llvm codegen check alignment to generate vector load/store + // would be nice if this wasn't the case + else if (auto fun = dyn_cast(op)) + initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, + &knownContiguity, &knownDivisibility, + &knownConstancy); + } else if (Operation *op = value.getDefiningOp()) { + if (isa(op)) { + // scf::ForOp, scf::IfOp, scf::WhileOp + // Control flow operations are initialized with "unknown" state: + // the maximum possible divisibility, contiguity, and constancy. + knownDivisibility = DimVectorT(rank, highestPowOf2Divisor(0)); + knownConstancy = DimVectorT(rank, highestPowOf2Divisor(0)); + knownContiguity = DimVectorT(rank, highestPowOf2Divisor(0)); + } + // Other operations are conservatively initialized with the lowest possible + // divisibility, contiguity, and constancy unless they have specified. + if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) { + auto vals = cast(attr).getValues(); + knownDivisibility = DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) { + auto vals = cast(attr).getValues(); + knownContiguity = DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getDiscardableAttr("tt.constancy")) { + auto vals = cast(attr).getValues(); + knownConstancy = DimVectorT(vals.begin(), vals.end()); + } + } + + return AxisInfo(knownContiguity, knownDivisibility, knownConstancy); +} + +/*static*/ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) { + // If one argument is not initialized, return the other. + if (lhs.getRank() == 0) + return rhs; + if (rhs.getRank() == 0) + return lhs; + DimVectorT contiguity; + DimVectorT divisibility; + DimVectorT constancy; + for (auto d = 0; d < lhs.getRank(); ++d) { + contiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d))); + divisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); + constancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d))); + } + std::optional constantValue; + if (lhs.getConstantValue().has_value() && + rhs.getConstantValue().has_value() && + lhs.getConstantValue() == rhs.getConstantValue()) + constantValue = lhs.getConstantValue(); + return AxisInfo(contiguity, divisibility, constancy, constantValue); +} + +unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + auto layout = tensorTy.getEncoding(); + + // Here order should be ordered by contiguous first, so the first element + // should have the largest contiguous. + auto order = triton::gpu::getOrder(layout); + unsigned align = getPtrAlignment(ptr); + + auto uniqueContigPerThread = + triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape()); + assert(order[0] < uniqueContigPerThread.size() && + "Unexpected uniqueContigPerThread size"); + unsigned contiguity = uniqueContigPerThread[order[0]]; + LDBG("getPtrContiguity uniqueContigPerThread = " << contiguity); + contiguity = std::min(align, contiguity); + + return contiguity; +} + +unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + auto *axisInfo = getAxisInfo(ptr); + if (!axisInfo) + return 1; + auto layout = tensorTy.getEncoding(); + auto order = triton::gpu::getOrder(layout); + auto maxMultipleBytes = axisInfo->getDivisibility(order[0]); + auto maxContig = axisInfo->getContiguity(order[0]); + auto elemNumBits = triton::getPointeeBitWidth(tensorTy); + auto elemNumBytes = std::max(elemNumBits / 8, 1); + auto maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1); + unsigned alignment = std::min(maxMultiple, maxContig); + LDBG("getPtrAlignment order[0] " + << order[0] << " maxMultipleBytes = " << maxMultipleBytes + << " maxContig = " << maxContig << " elemNumBits = " << elemNumBits + << " maxMultiple = " << maxMultiple << " alignment " << alignment); + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo->print(os); + LDBG("-- " << axisStr); + }); + return alignment; +} + +unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { + auto tensorTy = dyn_cast(mask.getType()); + if (!tensorTy) + return 1; + auto *axisInfo = getAxisInfo(mask); + if (!axisInfo) + return 1; + auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding()); + auto alignment = std::max(axisInfo->getConstancy(maskOrder[0]), 1); + LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment " + << alignment); + LLVM_DEBUG({ + std::string axisStr; + llvm::raw_string_ostream os(axisStr); + axisInfo->print(os); + LDBG("-- " << axisStr); + }); + return alignment; +} + +void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp) { + std::unique_ptr solver = createDataFlowSolver(); + AxisInfoAnalysis *analysis = solver->load(); + if (failed(solver->initializeAndRun(funcOp))) + return; + auto *axisInfoMap = getFuncData(funcOp); + auto updateAxisInfoMap = [&](Value value) { + auto axisInfo = analysis->getLatticeElement(value)->getValue(); + AxisInfo curAxisInfo; + if (axisInfoMap->count(value)) { + curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value)); + } else { + curAxisInfo = axisInfo; + } + (*axisInfoMap)[value] = curAxisInfo; + }; + funcOp.walk([&](Operation *op) { + for (auto value : op->getResults()) { + updateAxisInfoMap(value); + } + }); + funcOp.walk([&](Block *block) { + for (auto value : block->getArguments()) { + updateAxisInfoMap(value); + } + }); +} + +void ModuleAxisInfoAnalysis::update(CallOpInterface callOp, + FunctionOpInterface callee) { + auto caller = callOp->getParentOfType(); + auto *axisInfoMap = getFuncData(caller); + for (auto entry : llvm::enumerate(callOp->getOperands())) { + auto index = entry.index(); + auto value = entry.value(); + auto setAttrFn = [&](StringRef attrName, int64_t prevValue) { + auto curValue = highestPowOf2Divisor(0); + if (callee.getArgAttrOfType(index, attrName)) { + curValue = + callee.getArgAttrOfType(index, attrName).getInt(); + } + auto attr = IntegerAttr::get(IntegerType::get(callee.getContext(), 64), + gcd(prevValue, curValue)); + callee.setArgAttr(index, attrName, attr); + }; + auto axisInfo = axisInfoMap->lookup(value); + assert(axisInfo.getRank() == 1 && "only scalar arguments are supported"); + setAttrFn("tt.contiguity", axisInfo.getContiguity(0)); + setAttrFn("tt.divisibility", axisInfo.getDivisibility(0)); + setAttrFn("tt.constancy", axisInfo.getConstancy(0)); + } +} + +} // namespace mlir::triton diff --git a/third_party/xpu/lib/Analysis/CMakeLists.txt b/third_party/xpu/lib/Analysis/CMakeLists.txt new file mode 100644 index 000000000..09ce99dfa --- /dev/null +++ b/third_party/xpu/lib/Analysis/CMakeLists.txt @@ -0,0 +1,19 @@ +add_triton_library(TritonAnalysis + AxisInfo.cpp + Allocation.cpp + Membar.cpp + Alias.cpp + Utility.cpp + UtilityXPU.cpp + + DEPENDS + TritonTableGen + TritonGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRAnalysis + MLIRLLVMDialect + TritonIR + TritonGPUIR + TritonNvidiaGPUIR +) diff --git a/third_party/xpu/lib/Analysis/Membar.cpp b/third_party/xpu/lib/Analysis/Membar.cpp new file mode 100644 index 000000000..407a5ae15 --- /dev/null +++ b/third_party/xpu/lib/Analysis/Membar.cpp @@ -0,0 +1,178 @@ +#include "triton/Analysis/Membar.h" +#include "triton/Analysis/Alias.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include + +namespace mlir { + +void MembarAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) { + FunctionOpInterface funcOp = + dyn_cast(allocation->getOperation()); + OpBuilder builder(funcOp.getContext()); + resolve(funcOp, &funcBlockInfoMap, &builder); +} + +void MembarAnalysis::resolve(FunctionOpInterface funcOp, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) { + // Initialize the blockList + DenseMap inputBlockInfoMap; + DenseMap outputBlockInfoMap; + std::deque blockList; + funcOp.walk([&](Block *block) { + for (auto &op : block->getOperations()) { + // Check if the operation belongs to scf dialect, if so, we need to + // throw an error + if (op.getDialect()->getNamespace() == "scf") { + llvm::report_fatal_error( + "scf dialect is not supported in membar. Please lower it " + "to cf dialect first."); + return; + } + } + if (block->isEntryBlock()) + blockList.emplace_back(block); + }); + + // A fixed point algorithm + while (!blockList.empty()) { + auto *block = blockList.front(); + blockList.pop_front(); + // Make a copy of the inputblockInfo but not update + auto inputBlockInfo = inputBlockInfoMap[block]; + SmallVector successors; + for (auto &op : block->getOperations()) { + if (op.hasTrait()) { + visitTerminator(&op, successors); + } else { + update(&op, &inputBlockInfo, funcBlockInfoMap, builder); + } + } + // Get the reference because we want to update if it changed + if (outputBlockInfoMap.count(block) && + inputBlockInfo == outputBlockInfoMap[block]) { + // If we have seen the block before and the inputBlockInfo is the same as + // the outputBlockInfo, we skip the successors + continue; + } + // Update the current block + outputBlockInfoMap[block].join(inputBlockInfo); + // Update the successors + for (auto *successor : successors) { + inputBlockInfoMap[successor].join(outputBlockInfoMap[block]); + blockList.emplace_back(successor); + } + } + + // Update the final dangling buffers that haven't been synced + auto &funcBlockInfo = (*funcBlockInfoMap)[funcOp]; + funcOp.walk([&](Block *block) { + block->walk([&](triton::ReturnOp returnOp) { + funcBlockInfo.join(outputBlockInfoMap[block]); + }); + }); +} + +void MembarAnalysis::visitTerminator(Operation *op, + SmallVector &successors) { + if (auto branchInterface = dyn_cast(op)) { + Block *parentBlock = branchInterface->getBlock(); + successors.append(std::begin(parentBlock->getSuccessors()), + std::end(parentBlock->getSuccessors())); + return; + } + // Otherwise, it could be a return op + if (op->hasTrait()) + return; + llvm_unreachable("Unknown terminator encountered in membar analysis"); +} + +void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) { + OpBuilder::InsertionGuard g(*builder); + auto barrierOp = builder->create(op->getLoc()); +} + +void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) { + if (isa(op)) { + // If the current op is a barrier, we sync previous reads and writes + blockInfo->sync(); + return; + } + + if (isa(op) && + !isa(op->getNextNode())) { + // If the current op is an async wait and the next op is not a barrier we + // insert a barrier op and sync + builder->setInsertionPointAfter(op); + insertBarrier(op, builder); + blockInfo->sync(); + return; + } + + BlockInfo curBlockInfo; + if (isa(op)) { + // Inter-function dependencies + auto callOpInterface = dyn_cast(op); + if (auto callee = + dyn_cast(callOpInterface.resolveCallable())) + curBlockInfo = funcBlockInfoMap->lookup(callee); + } else { + // Intra-function dependencies + if (auto memoryEffectOpInterface = dyn_cast(op)) { + // Explicit buffer + SmallVector> + effectInstances; + memoryEffectOpInterface.getEffects(effectInstances); + for (auto effectInstance : effectInstances) { + if (auto value = effectInstance.getValue()) { + for (auto bufferId : allocation->getBufferIds(value)) { + if (bufferId != Allocation::InvalidBufferId) { + if (isa(effectInstance.getEffect())) + curBlockInfo.syncWriteIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + else if (isa(effectInstance.getEffect())) + curBlockInfo.syncReadIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + } + } + } + } + } + // XXX(Keren): This is a hack as we cannot set side effects for dot ops, but + // on hopper they do have side effects. Need to clean it up + if (auto dotOp = dyn_cast(op)) { + for (auto value : dotOp.getOperands()) { + for (auto bufferId : allocation->getBufferIds(value)) { + if (bufferId != Allocation::InvalidBufferId) + curBlockInfo.syncReadIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + } + } + } + // Scratch buffer is considered as both shared memory write & read + auto bufferId = allocation->getBufferId(op); + if (bufferId != Allocation::InvalidBufferId) { + curBlockInfo.syncWriteIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + curBlockInfo.syncReadIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + } + } + + if (blockInfo->isIntersected(curBlockInfo)) { + builder->setInsertionPoint(op); + insertBarrier(op, builder); + blockInfo->sync(); + } + // Update the region info, even if barrier is inserted, we have to maintain + // the current op's read/write buffers. + blockInfo->join(curBlockInfo); +} +} // namespace mlir diff --git a/third_party/xpu/lib/Analysis/Utility.cpp b/third_party/xpu/lib/Analysis/Utility.cpp new file mode 100644 index 000000000..c4f32032c --- /dev/null +++ b/third_party/xpu/lib/Analysis/Utility.cpp @@ -0,0 +1,989 @@ +#include "triton/Analysis/Utility.h" + +#include + +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace mlir { +namespace { + +using namespace triton; +using namespace triton::gpu; + +int getParentAxis(Attribute layout, int axis) { + if (auto sliceEncoding = dyn_cast(layout)) { + axis = axis < sliceEncoding.getDim() ? axis : axis + 1; + return getParentAxis(sliceEncoding.getParent(), axis); + } + return axis; +} + +SmallVector getParentOrder(Attribute layout) { + if (auto sliceEncoding = mlir::dyn_cast(layout)) { + return getParentOrder(sliceEncoding.getParent()); + } + return getOrder(layout); +} + +} // namespace + +// TODO(jlebar): Move this class into namespace triton. +bool ReduceOpHelper::isReductionOnLayoutFastAxis() { + return getParentAxis(getSrcLayout(), axis) == + getParentOrder(getSrcLayout())[0]; +} + +SmallVector ReduceOpHelper::getOrderWithAxisAtBeginning() { + auto srcLayout = getSrcLayout(); + auto order = getOrder(srcLayout); + auto it = std::find(order.begin(), order.end(), axis); + // delete the axis from order + order.erase(it); + // insert axis at the beginning of order + order.insert(order.begin(), axis); + return order; +} + +// Thread offset is the thread index offset of two adjacent threads on the +// reduction axis within the warp. +unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { + auto srcLayout = getSrcLayout(); + + // If the reduction axis is the fast axis of the parent layout + if (isReductionOnLayoutFastAxis()) { + return 1; + } + + unsigned threadOffset = 1; + if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + auto parentLayout = sliceLayout.getParent(); + auto threadsPerWarp = getThreadsPerWarp(parentLayout); + threadOffset = threadsPerWarp[sliceLayout.getDim()]; + } else { + auto threadsPerWarp = getThreadsPerWarp(srcLayout); + auto order = getOrder(srcLayout); + for (unsigned i = 0; i < order.size(); i++) { + if (order[i] == axis) + break; + threadOffset *= threadsPerWarp[order[i]]; + } + } + return threadOffset; +} + +// Cases where distributed shared memory is not required in ConvertLayout: +// (1) numCTAs == 1 +// (2) numCTAs > 1 but srcCTALayout == dstCTALayout +// TODO: Case with SliceLayout as srcLayout and numCTAs > 1 is to be implemented +// in the future +bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) { + unsigned numCTAs = getNumCTAs(srcLayout); + assert(numCTAs == getNumCTAs(dstLayout) && + "Invalid layout conversion: the numbers of CTAs of src and dst " + "layouts are different"); + + // Case (1): Never use dsmem when numCTAs == 1 + if (numCTAs == 1) + return false; + + // Case where CTAsPerCGA of srcLayout in the sliced dim is not 1 is not + // implemented yet + if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] != 1) + llvm::report_fatal_error("Layout conversion to be implemented"); + } + + // Case where CTAsPerCGA of dstLayout in the sliced dim is not 1 is supported + if (auto sliceLayout = mlir::dyn_cast(dstLayout)) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] != 1) + return true; + } + + // The above two branches make sure that it is legal to call getCTALayout of + // srcLayout and dstLayout + + // Case (2): Do not use dsmem when srcCTALayout == dstCTALayout + auto srcCTALayout = getCTALayout(srcLayout); + auto dstCTALayout = getCTALayout(dstLayout); + if (srcCTALayout == dstCTALayout) + return false; + + // Dsmem access is required when srcCTALayout != dstCTALayout + return true; +} + +unsigned ReduceOpHelper::getInterWarpSize() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + unsigned sizeIntraWarps = getIntraWarpSize(); + return std::min(srcReduceDimSize / sizeIntraWarps, + getWarpsPerCTA(getSrcLayout())[axis]); +} + +unsigned ReduceOpHelper::getIntraWarpSize() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + return std::min(srcReduceDimSize, getThreadsPerWarp(getSrcLayout())[axis]); +} + +unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + unsigned sizeIntraWarps = getIntraWarpSizeWithUniqueData(); + return std::min( + srcReduceDimSize / sizeIntraWarps, + getWarpsPerCTAWithUniqueData(getSrcLayout(), getSrcShape())[axis]); +} + +unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + unsigned elementPerThreads = + getUniqueContigPerThread(getSrcLayout(), getSrcShape())[axis]; + return std::min( + srcReduceDimSize / elementPerThreads, + getThreadsPerWarpWithUniqueData(getSrcLayout(), getSrcShape())[axis]); +} + +unsigned ReduceOpHelper::getThreadsReductionAxis() { + auto srcLayout = getSrcLayout(); + auto srcShape = getSrcShape(); + return getThreadsPerWarpWithUniqueData(srcLayout, srcShape)[axis] * + getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis]; +} + +bool ReduceOpHelper::isWarpSynchronous() { + auto srcLayout = getSrcLayout(); + auto srcShape = getSrcShape(); + return getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] == 1; +} + +SmallVector ReduceOpHelper::getScratchConfig() { + SmallVector smemShape; + // that case doesn't need inter-warp communication + if (isWarpSynchronous()) + return {0, 0}; + + smemShape = convertType(getSrcShape()); + smemShape[axis] = getInterWarpSizeWithUniqueData(); + + return smemShape; +} + +unsigned ReduceOpHelper::getScratchSizeInBytes() { + auto smemShape = getScratchConfig(); + auto elems = product(smemShape); + + unsigned bytesPerElem = 0; + for (const auto &ty : srcElementTypes) { + bytesPerElem += ceil(ty.getIntOrFloatBitWidth(), 8); + } + return bytesPerElem * elems; +} + +bool ReduceOpHelper::isReduceWithinCTA() { + auto axis = getAxis(); + auto srcLayout = getSrcLayout(); + auto CTASplitNum = getCTASplitNum(srcLayout); + assert(axis < CTASplitNum.size()); + return CTASplitNum[axis] == 1; +} + +bool ReduceOpHelper::isSupportedLayout() { + // Layout optimization passes such as PlanCTAPass and + // RemoveLayoutConversionPass should avoid cross-CTA reduction + if (!isReduceWithinCTA()) { + return false; + } + + auto srcLayout = getSrcLayout(); + if (isa(srcLayout)) { + return true; + } + if (auto mmaLayout = dyn_cast(srcLayout)) { + return mmaLayout.supportReduction(); + } + if (auto sliceLayout = dyn_cast(srcLayout)) { + return true; + } + return false; +} + +unsigned ScanLoweringHelper::getAxisNumElementsPerThread() { + return getEncoding().getSizePerThread()[getAxis()]; +} + +unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() { + SmallVector sizePerThreads = getContigPerThread(getEncoding()); + sizePerThreads[getAxis()] = 1; + return product(sizePerThreads); +} + +Region &ScanLoweringHelper::getCombineOp() { return scanOp.getCombineOp(); } + +unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() { + return getThreadsPerWarp(getEncoding())[getAxis()]; +} + +unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() { + return getThreadsPerWarpWithUniqueData(getEncoding(), getShape())[getAxis()]; +} + +unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() { + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + threadsPerWarp[getAxis()] = 1; + return product(threadsPerWarp); +} + +// Return the flat numbers of threads computing independent scan results. +unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() { + unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp(); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + warpsPerCTA[getAxis()] = 1; + unsigned numParallelWarpsPerCTA = product(warpsPerCTA); + return numParallelThreadsPerWarp * numParallelWarpsPerCTA; +} + +unsigned ScanLoweringHelper::getAxisNumWarps() { + return getWarpsPerCTA(getEncoding())[getAxis()]; +} + +unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() { + return getWarpsPerCTAWithUniqueData(getEncoding(), getShape())[getAxis()]; +} + +unsigned ScanLoweringHelper::getAxisNumBlocks() { + auto sizePerThreads = getSizePerThread(getEncoding()); + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + unsigned axis = getAxis(); + return ceil( + getShape()[axis], + (sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis])); +} + +unsigned ScanLoweringHelper::getNonAxisNumBlocks() { + auto sizePerThreads = getSizePerThread(getEncoding()); + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + unsigned axis = getAxis(); + unsigned numBlocks = 1; + for (unsigned i = 0; i < sizePerThreads.size(); i++) { + if (i == axis) + continue; + numBlocks *= + ceil(getShape()[i], (sizePerThreads[i] * threadsPerWarp[i] * + warpsPerCTA[i])); + } + return numBlocks; +} + +bool ScanLoweringHelper::isSupported() { + // TODO: Support the following cases: + // 1. Scan on non-blocking encodings + if (!isa(getEncoding())) + return false; + return true; +} + +unsigned ScanLoweringHelper::getScratchSizeInElems() { + auto mod = scanOp->getParentOfType(); + unsigned numWarps = TritonGPUDialect::getNumWarps(mod); + unsigned numNonAxisElementsPerWarp = + getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread(); + unsigned numElements = numWarps * numNonAxisElementsPerWarp * + getAxisNumBlocks() * getNonAxisNumBlocks(); + return numElements; +} + +unsigned ScanLoweringHelper::getScratchSizeInBytes() { + unsigned axisNumWarps = getAxisNumWarpsWithUniqueData(); + if (axisNumWarps == 1) + return 0; + unsigned elementSizeInBytes = 0; + for (const auto &ty : srcElementTypes) { + elementSizeInBytes += ceil(ty.getIntOrFloatBitWidth(), 8); + } + return elementSizeInBytes * getScratchSizeInElems(); +} + +SmallVector, SmallVector>> +getReshapeDecomposition(ArrayRef srcShape, + ArrayRef dstShape) { + SmallVector, SmallVector>> ret; + + if (srcShape.empty()) { + assert(dstShape.empty()); + return ret; + } + ret.push_back({}); + + int srcIdx = 0; + int dstIdx = 0; + int srcNElems = 1; + int dstNElems = 1; + while (srcIdx < srcShape.size() || dstIdx < dstShape.size()) { + if (srcNElems < dstNElems || // + (srcIdx < srcShape.size() && srcNElems == 1) || + (srcIdx < srcShape.size() && srcShape[srcIdx] == 1)) { + assert(srcIdx < srcShape.size()); + srcNElems *= srcShape[srcIdx]; + ret.back().first.push_back(srcIdx); + srcIdx++; + } else if (dstNElems < srcNElems || + (dstIdx < dstShape.size() && dstShape[dstIdx] == 1)) { + assert(dstIdx < dstShape.size()); + dstNElems *= dstShape[dstIdx]; + ret.back().second.push_back(dstIdx); + dstIdx++; + } else { + ret.push_back({}); + srcNElems = 1; + dstNElems = 1; + } + } + return ret; +} + +BlockedEncodingAttr ScanLoweringHelper::getEncoding() { + return cast(srcEncoding); +} + +unsigned ScanLoweringHelper::getAxisElementStride() { + auto order = getOrder(getEncoding()); + unsigned stride = 1; + for (unsigned dim : order) { + if (dim == getAxis()) + return stride; + stride *= getContigPerThread(getEncoding())[dim]; + } + llvm_unreachable("Axis not found in order"); +} + +unsigned ScanLoweringHelper::getAxisThreadStride() { + auto order = getOrder(getEncoding()); + unsigned stride = 1; + for (unsigned dim : order) { + if (dim == getAxis()) + return stride; + stride *= getEncoding().getThreadsPerWarp()[dim]; + } + llvm_unreachable("Axis not found in order"); +} + +unsigned ScanLoweringHelper::getAxisBlockStride() { + auto order = getOrder(getEncoding()); + unsigned stride = 1; + auto sizePerThreads = getSizePerThread(getEncoding()); + auto threadsPerWarp = getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = getWarpsPerCTA(getEncoding()); + for (unsigned dim : order) { + if (dim == getAxis()) + return stride; + stride *= ceil(getShape()[dim], sizePerThreads[dim] * + threadsPerWarp[dim] * + warpsPerCTA[dim]); + } + llvm_unreachable("Axis not found in order"); +} + +bool maybeSharedAllocationOp(Operation *op) { + // TODO(Keren): This function can be replaced by adding + // MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to + // query the memory effects of the op. + auto *dialect = op->getDialect(); + return dialect && + (dialect->getTypeID() == TypeID::get() || + dialect->getTypeID() == + TypeID::get() || + dialect->getTypeID() == TypeID::get() || + dialect->getTypeID() == TypeID::get() || + dialect->getTypeID() == TypeID::get()); +} + +static bool supportMFMAGranularity(int m, int n, int k) { + // these limitations are dtype dependent, in future we may relax them + const static std::pair mfmaTypes[2] = {{32, 8}, {16, 16}}; + for (const auto &mfmaType : mfmaTypes) { + auto [granularityMN, granularityK] = mfmaType; + if (m % granularityMN != 0 || n % granularityMN != 0) + continue; + if (k % granularityK != 0) + continue; + return true; + } + return false; +} + +bool supportMFMATypes(Type a, Type b) { + if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth()) + return false; + + auto F8E4M3FNUZ = TypeID::get(); + auto F8E5M2FNUZ = TypeID::get(); + auto F16 = TypeID::get(); + auto BF16 = TypeID::get(); + auto F32 = TypeID::get(); + auto Int = TypeID::get(); + DenseSet> supportedTypes = { + {F32, F32}, + {F16, F16}, + {BF16, BF16}, + {F8E4M3FNUZ, F8E4M3FNUZ}, + {F8E4M3FNUZ, F8E5M2FNUZ}, + {F8E5M2FNUZ, F8E4M3FNUZ}, + {F8E5M2FNUZ, F8E5M2FNUZ}, + {Int, Int}}; + + if (!supportedTypes.contains({a.getTypeID(), b.getTypeID()})) + return false; + + if (a.isIntOrIndex() && a.getIntOrFloatBitWidth() != 8) + return false; + return true; +} + +bool supportMFMA(triton::DotOp op) { + auto aTy = cast(op.getA().getType()); + auto bTy = cast(op.getB().getType()); + + auto aElemTy = aTy.getElementType(); + auto bElemTy = bTy.getElementType(); + + if (!supportMFMATypes(aElemTy, bElemTy)) + return false; + + auto aShape = aTy.getShape(); + auto bShape = bTy.getShape(); + + auto rank = aShape.size(); + assert(bShape.size() == rank); + auto M = aShape[rank - 2]; + auto N = bShape[rank - 1]; + auto K = aShape[rank - 1]; + assert(K == bShape[rank - 2]); + if (!supportMFMAGranularity(M, N, K)) + return false; + + return true; +} + +static bool supportWMMAGranularity(int m, int n, int k) { + return m % 16 == 0 && n % 16 == 0 && k % 16 == 0; +} + +static bool supportWMMATypes(Type a, Type b, Type c, Type d) { + if (a != b || c != d) + return false; + auto aWidth = a.getIntOrFloatBitWidth(); + auto cWidth = c.getIntOrFloatBitWidth(); + if (a.isIntOrIndex()) { + if (!c.isIntOrIndex()) + return false; + bool aValid = aWidth <= 8; + bool cValid = cWidth <= 32; + return aValid && cValid; + } else if (isa(a) && isa(c)) { + if (a.isBF16()) + return c.isBF16() || c.isF32(); + if (a.isF16()) + return c.isF16() || c.isF32(); + return aWidth <= cWidth && aWidth <= 16; + } + return false; +} + +bool supportWMMA(triton::DotOp op) { + auto aTy = cast(op.getA().getType()); + auto bTy = cast(op.getB().getType()); + auto cTy = cast(op.getC().getType()); + auto dTy = cast(op.getResult().getType()); + + auto aElemTy = aTy.getElementType(); + auto bElemTy = bTy.getElementType(); + auto cElemTy = cTy.getElementType(); + auto dElemTy = dTy.getElementType(); + + if (!supportWMMATypes(aElemTy, bElemTy, cElemTy, dElemTy)) + return false; + + auto aShape = aTy.getShape(); + auto bShape = bTy.getShape(); + + auto rank = aShape.size(); + assert(bShape.size() == rank); + assert(aShape[rank - 1] == bShape[rank - 2]); + if (!supportWMMAGranularity(aShape[rank - 2], bShape[rank - 1], + aShape[rank - 1])) + return false; + + return true; +} + +bool supportMMA(triton::DotOp op, int version) { + // Refer to mma section for the data type supported by Volta and Hopper + // Tensor Core in + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + auto aElemTy = op.getA().getType().getElementType(); + auto bElemTy = op.getB().getType().getElementType(); + if (version == 3) { + if (triton::tools::getBoolEnv("DISABLE_MMA_V3")) + return false; + auto retType = op.getType(); + auto retShapePerCTA = getShapePerCTA(retType); + auto rank = retShapePerCTA.size(); + auto mod = op->getParentOfType(); + int numWarps = TritonGPUDialect::getNumWarps(mod); + if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 && + retShapePerCTA[rank - 1] % 8 == 0 && + (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ() || + aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() || + aElemTy.isF32()))) { + return false; + } + // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op. + if (op.getMaxNumImpreciseAcc() < 32 && + (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) && + cast(op.getType()).getElementType().isF32()) { + return false; + } + } + if (aElemTy.isF32() && bElemTy.isF32()) { + return op.getInputPrecision() == InputPrecision::TF32 && version >= 2; + } + return supportMMA(op.getA(), version) && supportMMA(op.getB(), version); +} + +bool supportMMA(Value value, int version) { + // Tell whether a DotOp support MMA by the operand type(either $a or $b). + // We cannot get both the operand types(in TypeConverter), here we assume the + // types of both the operands are identical here. + assert((version == 1 || version == 2 || version == 3) && + "Unexpected MMA layout version found"); + auto elemTy = cast(value.getType()).getElementType(); + // FP8 is not natively supported on all mma versions but it can always be + // promoted to fp16 therefore we can always support it. + bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() || + elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ(); + return isFP8 || elemTy.isF16() || elemTy.isBF16() || + (elemTy.isF32() && version >= 2) || + (elemTy.isInteger(8) && version >= 2); +} + +bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + auto mfmaLayout = dyn_cast(srcLayout); + auto dotOperandLayout = dyn_cast(dstLayout); + if (mfmaLayout == nullptr || dotOperandLayout == nullptr) + return false; + // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is + // improved. In addition, we can enable this shortcut for regular MFMA + // layout when opIdx == 1. + return mfmaLayout.getWarpsPerCTA()[1] == 1 && + dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && + dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] && + dotOperandLayout.getParent() == mfmaLayout && + (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) && + (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()); +} + +static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) { + auto src = dyn_cast(srcEncoding); + auto dst = dyn_cast(dstEncoding); + if (!src || !dst) + return false; + // when #mma = MmaEncoding + return src && dst && src.getVersionMajor() == 3 && + src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 && + dst.getWarpsPerCTA()[1] == 1; +} + +bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { + return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding()); +} + +// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases. +bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, + RankedTensorType dstTy) { + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + auto mmaLayout = cast(srcLayout); + auto dotOperandLayout = cast(dstLayout); + int elementTypeSize = srcTy.getElementType().getIntOrFloatBitWidth(); + auto ans = mmaLayout.getVersionMajor() == 3 && + dotOperandLayout.getOpIdx() == 0 && + isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) && + (elementTypeSize == 16 || elementTypeSize == 8); + return ans; +} + +bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { + if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) + return true; + // dot_op = #mma + // when #mma = MmaEncoding + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + auto mmaLayout = mlir::cast(srcLayout); + auto dotOperandLayout = mlir::cast(dstLayout); + return mmaLayout.getVersionMajor() == 2 && + mmaLayout.getWarpsPerCTA()[1] == 1 && + dotOperandLayout.getOpIdx() == 0 && + dotOperandLayout.getParent() == mmaLayout && + !srcTy.getElementType().isF32(); +} + +namespace { + +/// A data structure similar to SetVector but maintains +/// a deque instead of a vector to allow for efficient +/// push_back and pop_front operations. +/// Using SetVector doesn't suffice our needs because +/// it only pushes and pops from the back. +/// For example, if we have a queue like this: +/// 0->4 1->2->3 +/// ^-------- +/// where 3 depends on 4, once we pop 3, we found +/// 4 is not ready, so we check 2 and push 3 back +/// to the queue. +struct DFSSubgraphState { + DFSSubgraphState() : set(), deque() {} + DenseSet set; + std::deque deque; + + bool push_back(Operation *op) { + if (set.insert(op).second) { + deque.push_back(op); + return true; + } + return false; + } + + Operation *pop_front() { + Operation *op = deque.front(); + deque.pop_front(); + set.erase(op); + return op; + } + + bool empty() { return deque.empty(); } +}; + +/// DFS post-order implementation that maintains a global count to work across +/// multiple invocations, to help implement topological sort on multi-root DAGs. +/// We traverse all operations but only record the ones that appear in +/// `toSort` for the final result. +struct DFSState { + DFSState(const SetVector &set) : toSort(set), seen() {} + const SetVector &toSort; + SmallVector topologicalCounts; + DenseSet seen; + + /// We mark each op as ready if all its operands and parents ops are seen. If + /// an op is ready, we add it to the queue. Otherwise, we keep adding its + /// operands to the ancestors set. + /// We always want an op to be scheduled after all its parents to handle + /// correctly cases with scf operations. + void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph, + SmallVector &readyQueue) { + bool ready = true; + for (Value operand : op->getOperands()) { + auto def = operand.getDefiningOp(); + if (def && !seen.count(def)) { + subGraph.push_back(def); + ready = false; + } + } + Operation *parent = op->getParentOp(); + while (parent) { + if (!seen.count(parent)) { + subGraph.push_back(parent); + ready = false; + } + parent = parent->getParentOp(); + } + if (ready) + readyQueue.push_back(op); + } +}; + +void dfsPostorder(Operation *root, DFSState *state) { + DFSSubgraphState subGraph; + subGraph.push_back(root); + SmallVector ops; + while (!subGraph.empty()) { + // Nodes in the ready queue are ready to be processed. + // Meaning that either their operands are all seen or they have null + // operands. + SmallVector readyQueue; + auto *current = subGraph.pop_front(); + state->addToReadyQueue(current, subGraph, readyQueue); + while (!readyQueue.empty()) { + Operation *current = readyQueue.pop_back_val(); + if (!state->seen.insert(current).second) + continue; + ops.push_back(current); + for (Value result : current->getResults()) { + for (Operation *op : result.getUsers()) + state->addToReadyQueue(op, subGraph, readyQueue); + } + for (Region ®ion : current->getRegions()) { + for (Operation &op : region.getOps()) + state->addToReadyQueue(&op, subGraph, readyQueue); + } + } + } + + for (Operation *op : llvm::reverse(ops)) { + if (state->toSort.count(op) > 0) + state->topologicalCounts.push_back(op); + } +} + +} // namespace + +SetVector +multiRootTopologicalSort(const SetVector &toSort) { + if (toSort.empty()) { + return toSort; + } + + // Run from each root with global count and `seen` set. + DFSState state(toSort); + for (auto *s : toSort) { + assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); + dfsPostorder(s, &state); + } + + // Reorder and return. + SetVector res; + for (auto it = state.topologicalCounts.rbegin(), + eit = state.topologicalCounts.rend(); + it != eit; ++it) { + res.insert(*it); + } + return res; +} + +SetVector multiRootGetSlice(Operation *op, + TransitiveFilter backwardFilter, + TransitiveFilter forwardFilter) { + SetVector slice; + slice.insert(op); + + unsigned currentIndex = 0; + SetVector backwardSlice; + SetVector forwardSlice; + while (currentIndex != slice.size()) { + auto *currentOp = (slice)[currentIndex]; + // Compute and insert the backwardSlice starting from currentOp. + backwardSlice.clear(); + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = backwardFilter; + getBackwardSlice(currentOp, &backwardSlice, opt); + slice.insert(backwardSlice.begin(), backwardSlice.end()); + + // Compute and insert the forwardSlice starting from currentOp. + forwardSlice.clear(); + getForwardSlice(currentOp, &forwardSlice, forwardFilter); + slice.insert(forwardSlice.begin(), forwardSlice.end()); + ++currentIndex; + } + return multiRootTopologicalSort(slice); +} + +namespace { +// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis +// interacts with constant propagation, but SparseConstantPropagation +// doesn't seem to be sufficient. +class ConstantAnalysis : public DataFlowAnalysis { +public: + using DataFlowAnalysis::DataFlowAnalysis; + + LogicalResult initialize(Operation *top) override { + WalkResult result = top->walk([&](Operation *op) { + if (failed(visit(op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return success(!result.wasInterrupted()); + } + + LogicalResult visit(ProgramPoint point) override { + Operation *op = point.get(); + Attribute value; + if (matchPattern(op, m_Constant(&value))) { + auto *constant = getOrCreate>( + op->getResult(0)); + propagateIfChanged(constant, constant->join(dataflow::ConstantValue( + value, op->getDialect()))); + return success(); + } + // Dead code analysis requires every operands has initialized ConstantValue + // state before it is visited. + // https://github.com/llvm/llvm-project/blob/2ec1aba2b69faa1de5f71832a48e25aa3b5d5314/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp#L322 + // That's why we need to set all operands to unknown constants. + setAllToUnknownConstants(op->getResults()); + for (Region ®ion : op->getRegions()) { + for (Block &block : region.getBlocks()) + setAllToUnknownConstants(block.getArguments()); + } + return success(); + } + +private: + /// Set all given values as not constants. + void setAllToUnknownConstants(ValueRange values) { + dataflow::ConstantValue unknownConstant(nullptr, nullptr); + for (Value value : values) { + auto *constant = + getOrCreate>(value); + propagateIfChanged(constant, constant->join(unknownConstant)); + } + } +}; +} // namespace + +std::unique_ptr createDataFlowSolver() { + auto solver = std::make_unique(); + solver->load(); + solver->load(); + return solver; +} + +static MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) { + + if (auto makeTensorPtrOp = dyn_cast(op)) { + return makeTensorPtrOp; + } + + if (auto advanceOp = dyn_cast(op)) { + return getMakeTensorPtrOp(advanceOp.getPtr()); + } + + if (auto branch = dyn_cast(op)) { + auto idx = cast(v).getResultNumber(); + llvm::SmallVector yieldOps; + op->walk([&](Operation *op) { + if (auto yieldOp = dyn_cast(op)) + yieldOps.push_back(yieldOp); + }); + + // benzh@ if multi yields, all yields operand should come from same arg. + Value newValue = yieldOps[0].getOperands()[idx]; + return getMakeTensorPtrOp(newValue); + } + + llvm_unreachable("Unable to getMakeTensorPtr()"); +} + +MakeTensorPtrOp getMakeTensorPtrOp(Value v) { + using BranchOps = llvm::SetVector>; + llvm::DenseMap blockToCFOps; + auto moduleOp = + v.getParentBlock()->getParentOp()->getParentOfType(); + + moduleOp.walk([&](Operation *op) { + if (auto br = dyn_cast(op)) { + Block *block = br.getDest(); + blockToCFOps[block].insert({op, -1}); + } + if (auto condBr = dyn_cast(op)) { + Block *blockT = condBr.getTrueDest(); + Block *blockF = condBr.getFalseDest(); + blockToCFOps[blockT].insert({condBr, 1}); + blockToCFOps[blockF].insert({condBr, 0}); + } + }); + + if (Operation *definingOp = v.getDefiningOp()) + return getMakeTensorPtrOpImpl(definingOp, v); + + // If there is no defining op, v must be a BlockArgument. + BlockArgument arg = cast(v); + unsigned argNum = arg.getArgNumber(); + Operation *argOwner = arg.getOwner()->getParentOp(); + + if (auto forOp = dyn_cast(argOwner)) + return getMakeTensorPtrOp( + forOp.getOperand(argNum + forOp.getNumControlOperands() - 1)); + if (auto funcOp = dyn_cast(argOwner)) { + Block *block = arg.getOwner(); + Operation *op; + int tOrF; + std::tie(op, tOrF) = blockToCFOps[block][0]; + if (auto br = dyn_cast(op)) + return getMakeTensorPtrOp(br.getDestOperands()[argNum]); + if (auto condBr = dyn_cast(op)) + return getMakeTensorPtrOp(tOrF ? condBr.getTrueDestOperands()[argNum] + : condBr.getFalseDestOperands()[argNum]); + return getMakeTensorPtrOp(argOwner->getOperand(argNum)); + } + llvm_unreachable("Unable to getMakeTensorPtr()"); +} + +} // namespace mlir + +//===-------------------- For Triton XPU -----------------------===// + +namespace mlir { +std::map ReduceOpHelper::reduceIdMap; +unsigned ReduceOpHelper::reduceNum = 0; +std::map> + ReduceOpHelper::reduceSMOffsetMap; + +bool ReduceOpHelper::isCoreSynchronous() { + auto srcLayout = getSrcLayout(); + auto srcShape = getSrcShape(); + return triton::xpu::getCoresPerClusterWithUniqueData(srcLayout, + srcShape)[axis] == 1; +} + +unsigned ReduceOpHelper::getIntraGroupSizeWithUniqueData() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + unsigned elementPerThreads = + triton::xpu::getUniqueContigPerCore(getSrcLayout(), getSrcShape())[axis]; + // getUniqueContigPerThread(getSrcLayout(), getSrcShape())[axis]; + return std::min(srcReduceDimSize / elementPerThreads, + triton::xpu::getCoresPerGroupWithUniqueData( + getSrcLayout(), getSrcShape())[axis]); +} + +SmallVector ReduceOpHelper::getXPUScratchConfig() { + SmallVector smemShape; + + smemShape = convertType(getSrcShape()); + smemShape[axis] = 64; // max_group_size = core_num + + return smemShape; +} + +unsigned ReduceOpHelper::getXPUScratchSizeInBytes() { + auto smemShape = getXPUScratchConfig(); + auto elems = product(smemShape); + + unsigned bytesPerElem = 0; + for (const auto &ty : srcElementTypes) { + bytesPerElem += + ceil(getElementTypeOrSelf(ty).getIntOrFloatBitWidth(), 8); + } + return bytesPerElem * elems; +} + +} // namespace mlir + +//===-----------------------------------------------------------===// diff --git a/third_party/xpu/lib/Analysis/UtilityXPU.cpp b/third_party/xpu/lib/Analysis/UtilityXPU.cpp new file mode 100644 index 000000000..d3408c44a --- /dev/null +++ b/third_party/xpu/lib/Analysis/UtilityXPU.cpp @@ -0,0 +1,269 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "triton/Analysis/UtilityXPU.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { + +std::map SMHelper::smOffsetMap; + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const OffsetState &state) { + switch (state) { + case OffsetState::Unknown: + os << "Unknown"; + break; + case OffsetState::DiscreteSame: + os << "Discrete Same"; + break; + case OffsetState::Continuous: + os << "Continuous"; + break; + case OffsetState::Discrete: + os << "Discrete"; + break; + case OffsetState::LocallyContinuous: + os << "Locally Continuous"; + break; + default: + os << "Invalid State"; + break; + } + return os; +} + +Type addrspaceCast(Type type, int addressSpace) { + if (auto tensorType = mlir::dyn_cast(type)) { + auto elemTy = tensorType.getElementType(); + auto ptrTy = mlir::cast(elemTy); + auto valTy = ptrTy.getPointeeType(); + auto ptrTyWithNewAS = triton::PointerType::get(valTy, addressSpace); + return RankedTensorType::get(tensorType.getShape(), ptrTyWithNewAS, + tensorType.getEncoding()); + } else if (auto pointerType = mlir::dyn_cast(type)) { + auto valTy = pointerType.getPointeeType(); + auto ptrTyWithNewAS = triton::PointerType::get(valTy, addressSpace); + return ptrTyWithNewAS; + } else { + llvm_unreachable("`type` must be a PointerType or RankedTensorType whose " + "element type is PointerType."); + } +} + +bool inOpChain(llvm::SetVector &opChain, Operation *op) { + if (!op || opChain.empty()) + return false; + for (int i = 0; i < opChain.size(); ++i) { + if (op == opChain[i]) { + return true; + } + } + return false; +} + +void getOpChainBwd(llvm::SetVector &opChain, Operation *op) { + if (!op) { + return; + } + opChain.insert(op); + + int noDefCnt = 0; + for (auto operand : op->getOperands()) { + if (!operand.getDefiningOp()) { + noDefCnt++; + } + } + + if (isa(op) || isa(op) || + noDefCnt == op->getNumOperands()) { + return; + } + + for (auto operand : op->getOperands()) { + getOpChainBwd(opChain, operand.getDefiningOp()); + } +} + +void getOpChainFwd(llvm::SetVector &opChain, Operation *op) { + opChain.insert(op); + + if (isa(op)) { + return; + } + + for (auto userOp : op->getUsers()) { + if (!opChain.contains(userOp)) { + getOpChainFwd(opChain, userOp); + } + } +} + +void getOpTreeBwd(llvm::SetVector &opTree, + llvm::SetVector &visitedOps, Operation *op) { + if (!op) { + return; + } + visitedOps.insert(op); + opTree.insert(op); + + if (isa(op)) { + // Do nothing + } else { + for (auto operand : op->getOperands()) { + if (!visitedOps.contains(operand.getDefiningOp())) { + getOpTreeBwd(opTree, visitedOps, operand.getDefiningOp()); + } + } + } + + for (auto userOp : op->getUsers()) { + if (!visitedOps.contains(userOp)) { + getOpTreeBwd(opTree, visitedOps, userOp); + } + } + + return; +} + +void getOpTreeBwd(llvm::SetVector &opTree, + llvm::SetVector &visitedOps, Operation *op, + Block *block) { + if (!op) { + return; + } + visitedOps.insert(op); + + if (isa(op)) { + // Do nothing + } else { + if (auto forOp = dyn_cast(op)) { + forOp->walk([&](Operation *innerOp) { + if (!visitedOps.contains(innerOp)) { + getOpTreeBwd(opTree, visitedOps, innerOp, block); + } + }); + } + for (auto operand : op->getOperands()) { + if (!visitedOps.contains(operand.getDefiningOp())) { + getOpTreeBwd(opTree, visitedOps, operand.getDefiningOp(), block); + } + } + } + + for (auto userOp : op->getUsers()) { + if (!visitedOps.contains(userOp)) { + getOpTreeBwd(opTree, visitedOps, userOp, block); + } + } + if (op->getBlock() == block) { + opTree.insert(op); + } + return; +} + +llvm::SmallVector +sortOpTreeBwd(llvm::SmallVector &opTree) { + auto compareOps = [](Operation *op1, Operation *op2) { + auto *block1 = op1->getBlock(); + auto *block2 = op2->getBlock(); + + if (block1 == block2) { + return op2->isBeforeInBlock(op1); + } + + auto *region = block1->getParent(); + assert(region == block2->getParent() && + "Operations are in different regions!"); + return std::distance(region->begin(), Region::iterator(block1)) > + std::distance(region->begin(), Region::iterator(block2)); + }; + auto sortedOpTree = opTree; + llvm::stable_sort(sortedOpTree, compareOps); + return sortedOpTree; +} + +llvm::SetVector +sortOpTreeBwd(llvm::SetVector &opTree) { + auto compareOps = [](Operation *op1, Operation *op2) { + auto *block1 = op1->getBlock(); + auto *block2 = op2->getBlock(); + + if (block1 == block2) { + return op2->isBeforeInBlock(op1); + } + + auto *region = block1->getParent(); + assert(region == block2->getParent() && + "Operations are in different regions!"); + return std::distance(region->begin(), Region::iterator(block1)) > + std::distance(region->begin(), Region::iterator(block2)); + }; + llvm::SmallVector opTreeVec; + for (auto op : opTree) { + opTreeVec.emplace_back(op); + } + llvm::stable_sort(opTreeVec, compareOps); + llvm::SetVector sortedOpTree; + for (auto op : opTreeVec) { + sortedOpTree.insert(op); + } + return sortedOpTree; +} + +llvm::SetVector sortOpTree(llvm::SetVector &opTree) { + auto compareOps = [](Operation *op1, Operation *op2) { + auto *parentOp1 = op1; + auto *parentOp2 = op2; + for (mlir::Block *block1 = parentOp1->getBlock(); block1 != nullptr;) { + for (mlir::Block *block2 = parentOp2->getBlock(); block2 != nullptr;) { + if (block1 == block2) { + return parentOp1->isBeforeInBlock(parentOp2); + } + parentOp2 = block2->getParentOp(); + if (parentOp2 == nullptr) { + break; + } + block2 = parentOp2->getBlock(); + } + parentOp2 = op2; // reset for next iteration + parentOp1 = block1->getParentOp(); + if (parentOp1 == nullptr) { + break; + } + block1 = parentOp1->getBlock(); + } + assert(0 && "Sort Op Tree Failed!"); + return false; + }; + llvm::SmallVector opTreeVec; + for (auto op : opTree) { + opTreeVec.emplace_back(op); + } + llvm::stable_sort(opTreeVec, compareOps); + llvm::SetVector sortedOpTree; + for (auto op : opTreeVec) { + sortedOpTree.insert(op); + } + return sortedOpTree; +} + +// Only Create Loop Once If StoreOp in SCF.IF +bool inSameSCFIfBlock(llvm::SetVector &storeOps, + Operation *storeOp) { + auto block1 = storeOp->getBlock(); + Operation *parentOp1 = block1->getParentOp(); + + for (auto otherStoreOp : storeOps) { + auto block2 = otherStoreOp->getBlock(); + Operation *parentOp2 = block2->getParentOp(); + if (parentOp1 == parentOp2 && dyn_cast(parentOp1) && + dyn_cast(parentOp2)) { + return true; + } + } + return false; +} + +} // namespace mlir diff --git a/third_party/xpu/lib/CMakeLists.txt b/third_party/xpu/lib/CMakeLists.txt new file mode 100644 index 000000000..c58b7fa0a --- /dev/null +++ b/third_party/xpu/lib/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(Analysis) +add_subdirectory(Conversion) +add_subdirectory(Dialect) +add_subdirectory(Target) +add_subdirectory(Tools) diff --git a/third_party/xpu/lib/Conversion/CMakeLists.txt b/third_party/xpu/lib/Conversion/CMakeLists.txt new file mode 100644 index 000000000..28175de8d --- /dev/null +++ b/third_party/xpu/lib/Conversion/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(TritonToTritonGPU) +add_subdirectory(TritonGPUToLLVM) +add_subdirectory(TritonToTritonXPU) +add_subdirectory(TritonXPUToLLVM) diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp new file mode 100644 index 000000000..aae9faf0e --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp @@ -0,0 +1,69 @@ +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_ALLOCATESHAREDMEMORY +#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace { + +struct AllocateSharedMemory + : public mlir::triton::impl::AllocateSharedMemoryBase< + AllocateSharedMemory> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + ModuleAllocation allocation(mod); + + mod.walk([&](FunctionOpInterface funcOp) { + funcOp.walk([&](Operation *op) { + auto *funcAllocation = allocation.getFuncData(funcOp); + auto oBufferId = funcAllocation->getBufferId(op); + int offset = -1; + if (oBufferId != Allocation::InvalidBufferId) + offset = funcAllocation->getOffset(oBufferId); + else if (op->getNumResults() == 1) { + Value value = op->getResult(0); + auto vBufferId = funcAllocation->getBufferId(value); + if (vBufferId != Allocation::InvalidBufferId) + offset = funcAllocation->getOffset(vBufferId); + } + if (offset == -1) + return; + op->setAttr("allocation.offset", + IntegerAttr::get(IntegerType::get(ctx, 32), offset)); + }); + }); + mod->setAttr("triton_gpu.shared", + mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), + allocation.getSharedMemorySize())); + } +}; + +} // namespace + +namespace mlir { + +namespace triton { + +namespace gpu { + +std::unique_ptr> createAllocateSharedMemoryPass() { + return std::make_unique(); +} + +} // namespace gpu + +} // namespace triton + +} // namespace mlir diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp new file mode 100644 index 000000000..a3f55f1e7 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp @@ -0,0 +1,80 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; + +struct AssertOpConversion : public ConvertOpToLLVMPattern { + explicit AssertOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + auto elems = unpackLLElements(loc, adaptor.getCondition(), rewriter); + auto elemTy = elems[0].getType(); + Value condition = int_val(elemTy.getIntOrFloatBitWidth(), 0); + for (auto elem : elems) { + if (elemTy.isSignedInteger() || elemTy.isSignlessInteger()) { + condition = + or_(condition, + icmp_eq(elem, rewriter.create( + loc, elemTy, rewriter.getZeroAttr(elemTy)))); + } else { + assert(false && "Unsupported type for assert"); + return failure(); + } + } + llAssert(op, condition, adaptor.getMessage(), adaptor.getFile(), + adaptor.getFunc(), adaptor.getLine(), rewriter); + rewriter.eraseOp(op); + return success(); + } + // op: the op at which the assert is inserted. Unlike printf, we need to + // know about the op to split the block. + void llAssert(Operation *op, Value condition, StringRef message, + StringRef file, StringRef func, int line, + ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + auto ctx = rewriter.getContext(); + auto loc = op->getLoc(); + // #block1 + // if (condition) { + // #block2 + // __assertfail(message); + // } + // #block3 + Block *prevBlock = op->getBlock(); + + Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator()); + rewriter.setInsertionPointToStart(ifBlock); + targetInfo.assertFail(rewriter, loc, message, file, func, line); + + // Split a block after the call. + Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator()); + rewriter.setInsertionPointToEnd(ifBlock); + rewriter.create(loc, thenBlock); + rewriter.setInsertionPointToEnd(prevBlock); + rewriter.create(loc, condition, ifBlock, thenBlock); + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateAssertOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..4d57131d0 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -0,0 +1,37 @@ +add_triton_library(TritonGPUToLLVM + ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp + DotOpToLLVM/FMA.cpp + TypeConverter.cpp + Utility.cpp + ElementwiseOpToLLVM.cpp + MemoryOpToLLVM.cpp + AssertOpToLLVM.cpp + ViewOpToLLVM.cpp + MakeRangeOpToLLVM.cpp + HistogramOpToLLVM.cpp + AllocateSharedMemory.cpp + ReduceOpToLLVM.cpp + ScanOpToLLVM.cpp + ConvertLayoutOpToLLVM.cpp + ControlFlowOpToLLVM.cpp + FuncOpToLLVM.cpp + SPMDOpToLLVM.cpp + DecomposeUnsupportedConversions.cpp + PrintOpToLLVM.cpp + + DEPENDS + TritonGPUConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRGPUDialect + MLIRGPUToNVVMTransforms + MLIRGPUToROCDLTransforms + MLIRGPUTransforms + TritonAnalysis + TritonIR + TritonGPUIR + TritonGPUTransforms + TritonNvidiaGPUTransforms +) diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp new file mode 100644 index 000000000..9765d7bf0 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp @@ -0,0 +1,141 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct ReturnOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = op->getParentOfType(); + if (funcOp->hasAttr("nvvm.kernel")) { + // A GPU kernel + if (op.getNumOperands() > 0) { + return rewriter.notifyMatchFailure( + op, "Kernel functions do not support return with operands"); + } + rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), + op->getAttrs()); + } else { + // A device function + LLVM::ReturnOp newOp; + if (adaptor.getOperands().size() < 2) { + // Single or no return value. + newOp = + rewriter.create(op.getLoc(), adaptor.getOperands()); + } else { + // Pack the results into a struct. + auto packedResultsTy = this->getTypeConverter()->packFunctionResults( + funcOp.getResultTypes()); + Value packedResults = + rewriter.create(op.getLoc(), packedResultsTy); + auto loc = op.getLoc(); + for (auto it : llvm::enumerate(adaptor.getOperands())) { + packedResults = insert_val(packedResultsTy, packedResults, it.value(), + it.index()); + } + newOp = rewriter.create(op.getLoc(), packedResults); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + } + return success(); + } +}; + +// CallOpInterfaceLowering is adapted from +// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 +struct CallOpConversion : public ConvertOpToLLVMPattern { + CallOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto promotedOperands = promoteOperands(callOp, adaptor, rewriter); + auto newCallOp = + convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter); + if (!newCallOp) + return failure(); + auto results = getCallOpResults(callOp, newCallOp, rewriter); + rewriter.replaceOp(callOp, results); + return success(); + } + +private: + SmallVector + promoteOperands(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Get the last argument of the caller, which is the current stack pointer + // of shared memory and append it to the operands of the callOp. + auto loc = callOp.getLoc(); + auto caller = callOp->getParentOfType(); + auto promotedOperands = this->getTypeConverter()->promoteOperands( + callOp.getLoc(), /*opOperands=*/callOp->getOperands(), + adaptor.getOperands(), rewriter); + if (!caller->hasAttr("allocation.offset")) { + auto base = LLVM::getStackPointer(rewriter, caller); + promotedOperands.push_back(base); + return promotedOperands; + } + promotedOperands.push_back( + LLVM::getSharedMemoryBase(callOp->getLoc(), rewriter, callOp)); + return promotedOperands; + } + + LLVM::CallOp + convertCallOpToLLVMCallOp(triton::CallOp callOp, + ArrayRef promotedOperands, + ConversionPatternRewriter &rewriter) const { + // Pack the result types into a struct. + Type packedResult = nullptr; + unsigned numResults = callOp.getNumResults(); + auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); + + if (numResults != 0) { + if (!(packedResult = + this->getTypeConverter()->packFunctionResults(resultTypes))) + return nullptr; + } + auto newCallOp = rewriter.create( + callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), + promotedOperands, callOp->getAttrs()); + return newCallOp; + } + + SmallVector + getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp, + ConversionPatternRewriter &rewriter) const { + auto numResults = callOp.getNumResults(); + SmallVector results; + if (numResults < 2) { + // If < 2 results, packing did not do anything and we can just return. + results.append(newCallOp.result_begin(), newCallOp.result_end()); + } else { + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + results.push_back(rewriter.create( + callOp.getLoc(), newCallOp->getResult(0), i)); + } + } + return results; + } +}; + +} // namespace + +void mlir::triton::populateControlFlowOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp new file mode 100644 index 000000000..94894ceb1 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -0,0 +1,324 @@ +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" + +using mlir::isLayoutMmaV1; +using mlir::LLVM::getMultiDimOffset; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::LLVM::getStridesFromShapeAndOrder; +using mlir::LLVM::getWrappedMultiDimOffset; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getShapePerCTATile; +using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::isaDistributedLayout; +using ::mlir::triton::gpu::SharedEncodingAttr; + +namespace { + +struct LocalLoadOpConversion + : public ConvertOpToLLVMPattern { +public: + LocalLoadOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MemDescType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + // TODO: do we need to check if src is shared ? + if (isa(srcLayout) && isaDistributedLayout(dstLayout)) { + return lowerSharedToDistributed(op, adaptor, getTypeConverter(), + rewriter); + } + if (isa(dstLayout) && + isa( + cast(dstLayout).getParent())) { + return lowerSharedToDotOpFMA(op, adaptor, getTypeConverter(), rewriter); + } + return failure(); + } + +private: + LogicalResult + lowerSharedToDotOpFMA(triton::gpu::LocalLoadOp op, + triton::gpu::LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + RankedTensorType dstTy = op.getType(); + Attribute dstLayout = dstTy.getEncoding(); + auto dotLayout = cast(dstLayout); + auto blockedLayout = cast( + cast(dstLayout).getParent()); + auto thread = getThreadId(rewriter, loc); + Value res = SharedToDotOperandFMA::convertLayout( + dotLayout.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout, + thread, loc, getTypeConverter(), rewriter); + rewriter.replaceOp(op, res); + return success(); + } + LogicalResult + lowerSharedToDistributed(triton::gpu::LocalLoadOp op, + triton::gpu::LocalLoadOpAdaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto srcTy = op.getSrc().getType(); + auto dstTy = op.getResult().getType(); + auto dstShape = dstTy.getShape(); + assert(dstShape.size() <= 2 && + "Unexpected rank of ConvertLayout(shared->blocked)"); + auto srcSharedLayout = cast(srcTy.getEncoding()); + auto dstLayout = dstTy.getEncoding(); + auto inOrd = getOrder(srcSharedLayout); + + auto smemObj = getSharedMemoryObjectFromStruct( + loc, adaptor.getSrc(), + typeConverter->convertType(srcTy.getElementType()), rewriter); + auto elemTy = typeConverter->convertType(dstTy.getElementType()); + + auto srcStrides = + getStridesFromShapeAndOrder(srcTy.getShape(), inOrd, loc, rewriter); + + SmallVector outVals = + loadSharedToDistributed(op.getResult(), op.getSrc(), smemObj, elemTy, + loc, rewriter, targetInfo); + + Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +struct ConvertLayoutOpConversion + : public ConvertOpToLLVMPattern { +public: + ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + if (isSupported(srcLayout, dstLayout)) { + return lowerDistributedToDistributed(op, adaptor, rewriter); + } + return failure(); + } + +private: + bool isSupported(Attribute srcLayout, Attribute dstLayout) const { + return isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout) && + !isLayoutMmaV1(srcLayout) && !isLayoutMmaV1(dstLayout); + } + // shared memory rd/st for blocked or mma layout with data padding + void processReplica(Location loc, ConversionPatternRewriter &rewriter, + bool stNotRd, RankedTensorType type, + ArrayRef numCTAsEachRep, + ArrayRef multiDimRepId, unsigned vec, + ArrayRef paddedRepShape, + ArrayRef origRepShape, + ArrayRef outOrd, SmallVector &vals, + Value smemBase) const { + auto accumNumCTAsEachRep = product(numCTAsEachRep); + auto layout = type.getEncoding(); + auto rank = type.getRank(); + auto sizePerThread = getSizePerThread(layout); + auto accumSizePerThread = product(sizePerThread); + SmallVector numCTATiles(rank); + auto shapePerCTATile = getShapePerCTATile(layout); + auto shapePerCTA = getShapePerCTA(layout, type.getShape()); + auto order = getOrder(layout); + for (unsigned d = 0; d < rank; ++d) { + numCTATiles[d] = ceil(shapePerCTA[d], shapePerCTATile[d]); + } + auto elemTy = type.getElementType(); + bool isInt1 = elemTy.isInteger(1); + bool isPtr = isa(elemTy); + auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy); + if (isInt1) + elemTy = IntegerType::get(elemTy.getContext(), 8); + else if (isPtr) + elemTy = IntegerType::get(elemTy.getContext(), 64); + + auto llvmElemTy = getTypeConverter()->convertType(elemTy); + + for (unsigned ctaId = 0; ctaId < accumNumCTAsEachRep; ++ctaId) { + auto multiDimCTAInRepId = + getMultiDimIndex(ctaId, numCTAsEachRep, order); + SmallVector multiDimCTAId(rank); + for (const auto &it : llvm::enumerate(multiDimCTAInRepId)) { + auto d = it.index(); + multiDimCTAId[d] = multiDimRepId[d] * numCTAsEachRep[d] + it.value(); + } + + auto linearCTAId = + getLinearIndex(multiDimCTAId, numCTATiles, order); + // TODO: This is actually redundant index calculation, we should + // consider of caching the index calculation result in case + // of performance issue observed. + for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { + SmallVector multiDimOffset = + getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, type, + multiDimCTAInRepId, shapePerCTATile); + SmallVector multiDimOffsetWrapped = getWrappedMultiDimOffset( + rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile, + shapePerCTA); + Value offset = linearize(rewriter, loc, multiDimOffsetWrapped, + paddedRepShape, outOrd); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset); + auto vecTy = vec_ty(llvmElemTy, vec); + ptr = bitcast(ptr, ptr_ty(rewriter.getContext(), 3)); + if (stNotRd) { + Value valVec = undef(vecTy); + for (unsigned v = 0; v < vec; ++v) { + auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v]; + if (isInt1) + currVal = zext(llvmElemTy, currVal); + else if (isPtr) + currVal = ptrtoint(llvmElemTy, currVal); + valVec = insert_element(vecTy, valVec, currVal, i32_val(v)); + } + store(valVec, ptr); + } else { + Value valVec = load(vecTy, ptr); + for (unsigned v = 0; v < vec; ++v) { + Value currVal = extract_element(llvmElemTy, valVec, i32_val(v)); + if (isInt1) + currVal = icmp_ne(currVal, + rewriter.create( + loc, i8_ty, rewriter.getI8IntegerAttr(0))); + else if (isPtr) + currVal = inttoptr(llvmElemTyOrig, currVal); + vals[elemId + linearCTAId * accumSizePerThread + v] = currVal; + } + } + } + } + } + // blocked/mma -> blocked/mma. + // Data padding in shared memory to avoid bank conflict. + LogicalResult + lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto typeConverter = getTypeConverter(); + RankedTensorType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + if (product(srcTy.getShape()) == 1) { + auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + SmallVector outVals(getTotalElemsPerThread(dstTy), inVals[0]); + Value result = + packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + return success(); + } + + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + smemBase = bitcast(smemBase, elemPtrTy); + auto shape = dstTy.getShape(); + unsigned rank = dstTy.getRank(); + SmallVector numReplicates(rank); + SmallVector inNumCTAsEachRep(rank); + SmallVector outNumCTAsEachRep(rank); + SmallVector inNumCTAs(rank); + SmallVector outNumCTAs(rank); + auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); + auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape); + auto shapePerCTA = getShapePerCTA(srcLayout, shape); + + for (unsigned d = 0; d < rank; ++d) { + unsigned inPerCTA = + std::min(shapePerCTA[d], srcShapePerCTATile[d]); + unsigned outPerCTA = + std::min(shapePerCTA[d], dstShapePerCTATile[d]); + unsigned maxPerCTA = std::max(inPerCTA, outPerCTA); + numReplicates[d] = ceil(shapePerCTA[d], maxPerCTA); + inNumCTAsEachRep[d] = maxPerCTA / inPerCTA; + outNumCTAsEachRep[d] = maxPerCTA / outPerCTA; + assert(maxPerCTA % inPerCTA == 0 && maxPerCTA % outPerCTA == 0); + inNumCTAs[d] = ceil(shapePerCTA[d], inPerCTA); + outNumCTAs[d] = ceil(shapePerCTA[d], outPerCTA); + } + // Potentially we need to store for multiple CTAs in this replication + auto accumNumReplicates = product(numReplicates); + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + unsigned inVec = 0; + unsigned outVec = 0; + auto origRepShape = getRepShapeForCvtLayout(op); + auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec); + + unsigned outElems = getTotalElemsPerThread(dstTy); + auto outOrd = getOrder(dstLayout); + SmallVector outVals(outElems); + + for (unsigned repId = 0; repId < accumNumReplicates; ++repId) { + auto multiDimRepId = + getMultiDimIndex(repId, numReplicates, outOrd); + if (repId != 0) { + barrier(); + } + auto successful = targetInfo.processReplicaUsingStMatrix( + rewriter, loc, smemBase, vals, srcTy, + getTypeConverter()->convertType(srcTy.getElementType()), + paddedRepShape, origRepShape, outOrd, accumNumReplicates); + if (!successful) { + processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, + multiDimRepId, inVec, paddedRepShape, origRepShape, + outOrd, vals, smemBase); + } + barrier(); + processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep, + multiDimRepId, outVec, paddedRepShape, origRepShape, + outOrd, outVals, smemBase); + } + + Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); + rewriter.replaceOp(op, result); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; +} // namespace + +void mlir::triton::populateConvertLayoutOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp new file mode 100644 index 000000000..b7bd5fbc3 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -0,0 +1,234 @@ +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using ValueTable = std::map, Value>; +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::LLVM::getStridesFromShapeAndOrder; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getContigPerThread; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::isaDistributedLayout; +using ::mlir::triton::gpu::SharedEncodingAttr; + +SmallVector +getThreadIds(Value threadId, ArrayRef shapePerCTATile, + ArrayRef sizePerThread, ArrayRef order, + ConversionPatternRewriter &rewriter, Location loc) { + int dim = order.size(); + SmallVector threadIds(dim); + for (unsigned k = 0; k < dim - 1; k++) { + Value dimK = i32_val(shapePerCTATile[order[k]] / sizePerThread[order[k]]); + Value rem = urem(threadId, dimK); + threadId = udiv(threadId, dimK); + threadIds[order[k]] = rem; + } + Value dimK = i32_val(shapePerCTATile[order[dim - 1]]); + threadIds[order[dim - 1]] = urem(threadId, dimK); + return threadIds; +} + +// Get shapePerCTATile for M or N axis. +int getShapePerCTATileForMN(BlockedEncodingAttr layout, bool isM) { + auto order = layout.getOrder(); + auto shapePerCTATile = getShapePerCTATile(layout); + + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + return isM ? mShapePerCTATile : nShapePerCTATile; +} + +// Get sizePerThread for M or N axis. +int getSizePerThreadForMN(BlockedEncodingAttr layout, bool isM) { + auto order = layout.getOrder(); + auto sizePerThread = getSizePerThread(layout); + + int mSizePerThread = + order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int nSizePerThread = + order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + return isM ? mSizePerThread : nSizePerThread; +} + +Value getStructFromValueTable(ArrayRef vals, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter, + Type elemTy) { + SmallVector elemTypes(vals.size(), elemTy); + SmallVector elems; + elems.reserve(vals.size()); + for (auto &val : vals) { + elems.push_back(val); + } + MLIRContext *ctx = elemTy.getContext(); + Type structTy = struct_ty(elemTypes); + return packLLElements(loc, typeConverter, elems, rewriter, structTy); +} + +ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA, + int sizePerThread, + ConversionPatternRewriter &rewriter, + Location loc, + const LLVMTypeConverter *typeConverter, + Type type) { + ValueTable res; + auto elems = unpackLLElements(loc, val, rewriter); + int index = 0; + for (unsigned k = 0; k < K; ++k) { + for (unsigned m = 0; m < n0; m += shapePerCTA) + for (unsigned mm = 0; mm < sizePerThread; ++mm) { + res[{m + mm, k}] = elems[index++]; + } + } + return res; +} + +Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, + Location loc, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto aTensorTy = cast(A.getType()); + auto aLayout = cast(aTensorTy.getEncoding()); + auto aShapePerCTA = getShapePerCTA(aTensorTy); + + auto aOrder = aLayout.getOrder(); + auto order = dLayout.getOrder(); + + bool isARow = aOrder[0] == 1; + + auto aSmem = getSharedMemoryObjectFromStruct( + loc, llA, typeConverter->convertType(aTensorTy.getElementType()), + rewriter); + Value strideAM = aSmem.strides[0]; + Value strideAK = aSmem.strides[1]; + Value strideA0 = isARow ? strideAK : strideAM; + Value strideA1 = isARow ? strideAM : strideAK; + int aNumPtr = 8; + int K = aShapePerCTA[1]; + int M = aShapePerCTA[0]; + + auto shapePerCTATile = getShapePerCTATile(dLayout); + auto sizePerThread = getSizePerThread(dLayout); + + Value _0 = i32_val(0); + + Value mContig = i32_val(sizePerThread[order[1]]); + + // threadId in blocked layout + auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, + rewriter, loc); + Value threadIdM = threadIds[0]; + + Value offA0 = isARow ? _0 : mul(threadIdM, mContig); + Value offA1 = isARow ? mul(threadIdM, mContig) : _0; + SmallVector aOff(aNumPtr); + for (int i = 0; i < aNumPtr; ++i) { + aOff[i] = add(mul(offA0, strideA0), mul(offA1, strideA1)); + } + auto elemTy = typeConverter->convertType(aTensorTy.getElementType()); + + Type ptrTy = ptr_ty(rewriter.getContext(), 3); + SmallVector aPtrs(aNumPtr); + for (int i = 0; i < aNumPtr; ++i) + aPtrs[i] = gep(ptrTy, elemTy, aSmem.base, aOff[i]); + + SmallVector vas; + + int mShapePerCTATile = getShapePerCTATileForMN(dLayout, true /*isM*/); + int mSizePerThread = getSizePerThreadForMN(dLayout, true /*isM*/); + + for (unsigned k = 0; k < K; ++k) + for (unsigned m = 0; m < M; m += mShapePerCTATile) + for (unsigned mm = 0; mm < mSizePerThread; ++mm) { + Value offset = + add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK)); + Value pa = gep(ptrTy, elemTy, aPtrs[0], offset); + Value va = load(elemTy, pa); + vas.emplace_back(va); + } + + return getStructFromValueTable(vas, rewriter, loc, typeConverter, elemTy); +} + +Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, + Location loc, const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto bTensorTy = cast(B.getType()); + auto bLayout = cast(bTensorTy.getEncoding()); + auto bShapePerCTA = getShapePerCTA(bTensorTy); + + auto bOrder = bLayout.getOrder(); + auto order = dLayout.getOrder(); + + bool isBRow = bOrder[0] == 1; + + auto bSmem = getSharedMemoryObjectFromStruct( + loc, llB, typeConverter->convertType(bTensorTy.getElementType()), + rewriter); + Value strideBN = bSmem.strides[1]; + Value strideBK = bSmem.strides[0]; + Value strideB0 = isBRow ? strideBN : strideBK; + Value strideB1 = isBRow ? strideBK : strideBN; + int bNumPtr = 8; + int K = bShapePerCTA[0]; + int N = bShapePerCTA[1]; + + auto shapePerCTATile = getShapePerCTATile(dLayout); + auto sizePerThread = getSizePerThread(dLayout); + + Value _0 = i32_val(0); + + Value nContig = i32_val(sizePerThread[order[0]]); + + // threadId in blocked layout + auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, + rewriter, loc); + Value threadIdN = threadIds[1]; + + Value offB0 = isBRow ? mul(threadIdN, nContig) : _0; + Value offB1 = isBRow ? _0 : mul(threadIdN, nContig); + SmallVector bOff(bNumPtr); + for (int i = 0; i < bNumPtr; ++i) { + bOff[i] = add(mul(offB0, strideB0), mul(offB1, strideB1)); + } + auto elemTy = typeConverter->convertType(bTensorTy.getElementType()); + + Type ptrTy = ptr_ty(rewriter.getContext(), 3); + SmallVector bPtrs(bNumPtr); + for (int i = 0; i < bNumPtr; ++i) + bPtrs[i] = gep(ptrTy, elemTy, bSmem.base, bOff[i]); + + SmallVector vbs; + + int nShapePerCTATile = getShapePerCTATileForMN(dLayout, false /*isM*/); + int nSizePerThread = getSizePerThreadForMN(dLayout, false /*isM*/); + + for (unsigned k = 0; k < K; ++k) + for (unsigned n = 0; n < N; n += nShapePerCTATile) + for (unsigned nn = 0; nn < nSizePerThread; ++nn) { + Value offset = + add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK)); + Value pb = gep(ptrTy, elemTy, bPtrs[0], offset); + Value vb = load(elemTy, pb); + vbs.emplace_back(vb); + } + + return getStructFromValueTable(vbs, rewriter, loc, typeConverter, elemTy); +} + +namespace SharedToDotOperandFMA { +Value convertLayout(int opIdx, Value val, Value llVal, + BlockedEncodingAttr dLayout, Value thread, Location loc, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + if (opIdx == 0) + return loadAFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter); + else + return loadBFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter); +} +} // namespace SharedToDotOperandFMA diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp new file mode 100644 index 000000000..690155ee5 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -0,0 +1,116 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/Patterns.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +static void addAttrs(Operation *op, ArrayRef attrs) { + for (const NamedAttribute attr : attrs) + op->setAttr(attr.getName(), attr.getValue()); +} + +} // namespace + +namespace mlir::triton::gpu { + +void decomposeSplatOpToSharedLayoutConversion(ModuleOp module) { + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module); + module.walk([&](triton::SplatOp splatOp) -> void { + auto dstType = cast(splatOp.getType()); + auto shared = + dyn_cast(dstType.getEncoding()); + if (shared) { + OpBuilder builder(splatOp); + SmallVector sizePerThread(dstType.getRank(), 1); + auto newType = RankedTensorType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::BlockedEncodingAttr::get( + module.getContext(), dstType.getShape(), sizePerThread, + getOrder(shared), numWarps, threadsPerWarp, numCTAs)); + auto newSplat = builder.create(splatOp.getLoc(), newType, + splatOp.getSrc()); + auto newConvert = builder.create( + splatOp.getLoc(), dstType, newSplat.getResult()); + splatOp.replaceAllUsesWith(newConvert.getResult()); + splatOp.erase(); + } + }); +} + +template +void decomposeTensorCoreToDotLayoutConversion(ModuleOp module, + ShortcutFn shortcutFn) { + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module); + + module.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cast(cvtOp.getSrc().getType()); + auto dstType = cast(cvtOp.getType()); + auto srcMma = dyn_cast(srcType.getEncoding()); + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (srcMma && dstDotOp && !shortcutFn(srcType, dstType)) { + auto tmpType = RankedTensorType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::BlockedEncodingAttr::get( + module.getContext(), srcType.getShape(), getSizePerThread(srcMma), + getOrder(srcMma), numWarps, threadsPerWarp, numCTAs)); + auto tmp = builder.create( + cvtOp.getLoc(), tmpType, cvtOp.getSrc()); + addAttrs(tmp, cvtOp->getAttrs()); + auto newConvert = builder.create( + cvtOp.getLoc(), dstType, tmp); + addAttrs(newConvert, cvtOp->getAttrs()); + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + } + }); +} + +template void decomposeTensorCoreToDotLayoutConversion< + triton::gpu::NvidiaMmaEncodingAttr>(ModuleOp, ShortcutFn); +template void + decomposeTensorCoreToDotLayoutConversion( + ModuleOp, ShortcutFn); + +void decomposeBlockedToDotLayoutConversion(ModuleOp module) { + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module); + module.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cast(cvtOp.getSrc().getType()); + auto dstType = cast(cvtOp.getType()); + auto srcBlocked = + dyn_cast(srcType.getEncoding()); + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (srcBlocked && dstDotOp) { + auto tmpType = MemDescType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::SharedEncodingAttr::get( + module.getContext(), dstDotOp, srcType.getShape(), + srcBlocked.getOrder(), srcBlocked.getCTALayout(), + srcType.getElementType())); + auto tmp = builder.create( + cvtOp.getLoc(), tmpType, cvtOp.getSrc()); + addAttrs(tmp, cvtOp->getAttrs()); + auto newConvert = builder.create(cvtOp.getLoc(), + dstType, tmp); + addAttrs(newConvert, cvtOp->getAttrs()); + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + } + }); +} + +} // namespace mlir::triton::gpu diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp new file mode 100644 index 000000000..afb5bf01d --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -0,0 +1,102 @@ +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; + +using ValueTableFMA = std::map, Value>; + +static ValueTableFMA +getValueTableFromStructFMA(Value val, int K, int n0, int shapePerCTATile, + int sizePerThread, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter, Type type) { + ValueTableFMA res; + auto elems = unpackLLElements(loc, val, rewriter); + int index = 0; + for (unsigned k = 0; k < K; ++k) { + for (unsigned m = 0; m < n0; m += shapePerCTATile) + for (unsigned mm = 0; mm < sizePerThread; ++mm) { + res[{m + mm, k}] = elems[index++]; + } + } + return res; +} + +LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter) { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + + auto A = op.getA(); + auto B = op.getB(); + auto C = op.getC(); + auto D = op.getResult(); + + auto aTensorTy = cast(A.getType()); + auto bTensorTy = cast(B.getType()); + auto dTensorTy = cast(D.getType()); + + auto aShapePerCTA = getShapePerCTA(aTensorTy); + auto bShapePerCTA = getShapePerCTA(bTensorTy); + + BlockedEncodingAttr dLayout = + cast(dTensorTy.getEncoding()); + auto order = dLayout.getOrder(); + auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); + + Value llA = adaptor.getA(); + Value llB = adaptor.getB(); + + auto sizePerThread = getSizePerThread(dLayout); + auto shapePerCTATile = getShapePerCTATile(dLayout); + + int K = aShapePerCTA[1]; + int M = aShapePerCTA[0]; + int N = bShapePerCTA[1]; + + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int mSizePerThread = + order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int nSizePerThread = + order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + + auto has = + getValueTableFromStructFMA(llA, K, M, mShapePerCTATile, mSizePerThread, + rewriter, loc, typeConverter, aTensorTy); + auto hbs = + getValueTableFromStructFMA(llB, K, N, nShapePerCTATile, nSizePerThread, + rewriter, loc, typeConverter, bTensorTy); + + SmallVector ret = cc; + bool isCRow = order[0] == 1; + + for (unsigned k = 0; k < K; k++) { + for (unsigned m = 0; m < M; m += mShapePerCTATile) + for (unsigned n = 0; n < N; n += nShapePerCTATile) + for (unsigned mm = 0; mm < mSizePerThread; ++mm) + for (unsigned nn = 0; nn < nSizePerThread; ++nn) { + int mIdx = m / mShapePerCTATile * mSizePerThread + mm; + int nIdx = n / nShapePerCTATile * nSizePerThread + nn; + + int z = isCRow + ? mIdx * N / nShapePerCTATile * mSizePerThread + nIdx + : nIdx * M / mShapePerCTATile * nSizePerThread + mIdx; + ret[z] = rewriter.create(loc, has[{m + mm, k}], + hbs[{n + nn, k}], ret[z]); + } + } + + auto res = packLLElements(loc, typeConverter, ret, rewriter, dTensorTy); + rewriter.replaceOp(op, res); + + return success(); +} diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp new file mode 100644 index 000000000..0287207be --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -0,0 +1,839 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir::triton::gpu; + +namespace mlir::triton::gpu { + +Type getElementType(Value value) { + auto type = value.getType(); + if (auto tensorType = dyn_cast(type)) + return tensorType.getElementType(); + return type; +} +// MMA encoding has a different order depending on the element's bit width; +// reorder if we're in this case. +SmallVector reorderValues(const SmallVector &values, Type inType, + Type ouType) { + auto inTensorTy = dyn_cast(inType); + auto ouTensorTy = dyn_cast(ouType); + if (!inTensorTy || !ouTensorTy) + return values; + auto inEncoding = dyn_cast(inTensorTy.getEncoding()); + auto ouEncoding = dyn_cast(ouTensorTy.getEncoding()); + assert(inEncoding == ouEncoding); + if (!inEncoding) + return values; + // If the parent of the dot operand is in block encoding, we don't need to + // reorder elements + auto parentEncoding = dyn_cast(ouEncoding.getParent()); + if (!parentEncoding) + return values; + size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth(); + size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth(); + auto ouEltTy = ouTensorTy.getElementType(); + if (inBitWidth == ouBitWidth) + return values; + if (inBitWidth == 16 && ouBitWidth == 32) { + SmallVector ret; + for (unsigned i = 0; i < values.size(); i += 8) { + ret.push_back(values[i]); + ret.push_back(values[i + 1]); + ret.push_back(values[i + 4]); + ret.push_back(values[i + 5]); + ret.push_back(values[i + 2]); + ret.push_back(values[i + 3]); + ret.push_back(values[i + 6]); + ret.push_back(values[i + 7]); + } + return ret; + } + if (inBitWidth == 8 && ouBitWidth == 16) { + SmallVector ret; + for (unsigned i = 0; i < values.size(); i += 16) { + ret.push_back(values[i + 0]); + ret.push_back(values[i + 1]); + ret.push_back(values[i + 2]); + ret.push_back(values[i + 3]); + ret.push_back(values[i + 8]); + ret.push_back(values[i + 9]); + ret.push_back(values[i + 10]); + ret.push_back(values[i + 11]); + ret.push_back(values[i + 4]); + ret.push_back(values[i + 5]); + ret.push_back(values[i + 6]); + ret.push_back(values[i + 7]); + ret.push_back(values[i + 12]); + ret.push_back(values[i + 13]); + ret.push_back(values[i + 14]); + ret.push_back(values[i + 15]); + } + return ret; + } + llvm_unreachable("unimplemented code path"); +} + +SmallVector unpackI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter) { + auto tensorTy = dyn_cast(srcTy); + if (!tensorTy) + return inValues; + auto encoding = dyn_cast(tensorTy.getEncoding()); + if (!(encoding && isa(encoding.getParent()))) + return inValues; + SmallVector outValues; + for (auto v : inValues) { + // cast i32 to appropriate eltType vector and extract elements + auto eltType = typeConverter->convertType(tensorTy.getElementType()); + auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth()); + auto vec = bitcast(v, vecType); + for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) { + outValues.push_back(extract_element(vec, i32_val(i))); + } + } + return outValues; +} + +SmallVector packI32(const SmallVector &inValues, Type srcTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter) { + auto tensorTy = dyn_cast(srcTy); + if (!tensorTy) + return inValues; + auto encoding = dyn_cast(tensorTy.getEncoding()); + if (!(encoding && isa(encoding.getParent()))) + return inValues; + SmallVector outValues; + auto eltType = typeConverter->convertType(tensorTy.getElementType()); + int vecWidth = 32 / eltType.getIntOrFloatBitWidth(); + auto vecType = vec_ty(eltType, vecWidth); + for (int i = 0; i < inValues.size(); i += vecWidth) { + Value vec = undef(vecType); + for (int j = 0; j < vecWidth; j++) { + vec = insert_element(vec, inValues[i + j], i32_val(j)); + } + outValues.push_back(bitcast(vec, i32_ty)); + } + return outValues; +} + +int getNumElementsPerThreads(Type type, + const LLVMTypeConverter *typeConverter) { + int numElemsPerThread = 1; + auto tensorTy = dyn_cast(type); + if (!tensorTy) + return numElemsPerThread; + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (structType) { + numElemsPerThread = structType.getBody().size(); + } + auto encoding = dyn_cast(tensorTy.getEncoding()); + if (!(encoding && isa(encoding.getParent()))) + return numElemsPerThread; + auto eltType = tensorTy.getElementType(); + assert(eltType.getIntOrFloatBitWidth() <= 32 && + "Only support element type with bit width <= 32 in dot operand mma " + "layout"); + // dot operand data are packed into i32 elements so use the following formula + // to get the number of elements per thread. + return (32 / eltType.getIntOrFloatBitWidth()) * numElemsPerThread; +} + +} // namespace mlir::triton::gpu + +namespace { +struct AddPtrOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = op.getType(); + auto typeConverter = getTypeConverter(); + auto resultTensorTy = dyn_cast(resultTy); + if (resultTensorTy) { + unsigned elems = getTotalElemsPerThread(resultTy); + Type elemTy = typeConverter->convertType( + cast(resultTensorTy.getElementType()).getPointeeType()); + Type ptrTy = typeConverter->convertType(resultTensorTy.getElementType()); + auto ptrs = unpackLLElements(loc, adaptor.getPtr(), rewriter); + auto offsets = unpackLLElements(loc, adaptor.getOffset(), rewriter); + SmallVector resultVals(elems); + for (unsigned i = 0; i < elems; ++i) { + resultVals[i] = gep(ptrTy, elemTy, ptrs[i], offsets[i]); + } + Value view = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, view); + } else { + assert(isa(resultTy)); + auto resultPtrTy = typeConverter->convertType(resultTy); + auto resultElemTy = typeConverter->convertType( + cast(resultTy).getPointeeType()); + Value result = + gep(resultPtrTy, resultElemTy, adaptor.getPtr(), adaptor.getOffset()); + rewriter.replaceOp(op, result); + } + return success(); + } +}; + +struct CmpIOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + SmallVector createDestOps(arith::CmpIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, + MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create( + loc, elemTy, ArithCmpIPredicateToLLVM(op.getPredicate()), + operands[0][0], operands[0][1])}; + } + + static LLVM::ICmpPredicate + ArithCmpIPredicateToLLVM(arith::CmpIPredicate predicate) { + switch (predicate) { +#define __PRED_ENUM(item__) \ + case arith::CmpIPredicate::item__: \ + return LLVM::ICmpPredicate::item__ + + __PRED_ENUM(eq); + __PRED_ENUM(ne); + __PRED_ENUM(sgt); + __PRED_ENUM(sge); + __PRED_ENUM(slt); + __PRED_ENUM(sle); + __PRED_ENUM(ugt); + __PRED_ENUM(uge); + __PRED_ENUM(ult); + __PRED_ENUM(ule); + +#undef __PRED_ENUM + } + llvm_unreachable("Unknown arith::CmpIPredicate"); + } +}; + +struct CmpFOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + static SmallVector + createDestOps(arith::CmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Type elemTy, + MultipleOperandsRange operands, Location loc) { + return {rewriter.create( + loc, elemTy, ArithCmpFPredicateToLLVM(op.getPredicate()), + operands[0][0], operands[0][1])}; + } + + static LLVM::FCmpPredicate + ArithCmpFPredicateToLLVM(arith::CmpFPredicate predicate) { + switch (predicate) { +#define __PRED_ENUM(item__, item1__) \ + case arith::CmpFPredicate::item__: \ + return LLVM::FCmpPredicate::item1__ + + __PRED_ENUM(OEQ, oeq); + __PRED_ENUM(ONE, one); + __PRED_ENUM(OGT, ogt); + __PRED_ENUM(OGE, oge); + __PRED_ENUM(OLT, olt); + __PRED_ENUM(OLE, ole); + __PRED_ENUM(ORD, ord); + __PRED_ENUM(UEQ, ueq); + __PRED_ENUM(UGT, ugt); + __PRED_ENUM(UGE, uge); + __PRED_ENUM(ULT, ult); + __PRED_ENUM(ULE, ule); + __PRED_ENUM(UNE, une); + __PRED_ENUM(UNO, uno); + __PRED_ENUM(AlwaysTrue, _true); + __PRED_ENUM(AlwaysFalse, _false); + +#undef __PRED_ENUM + } + llvm_unreachable("Unknown arith::CmpFPredicate"); + } +}; + +struct MulhiUIOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + explicit MulhiUIOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + targetInfo(targetInfo) {} + + SmallVector createDestOps(MulhiUIOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + + Type resultElementTy = getElementTypeOrSelf(op.getResult().getType()); + assert(resultElementTy.isInteger(32) || resultElementTy.isInteger(64)); + + auto funcName = targetInfo.getMulhiFuncName(resultElementTy); + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + return { + rewriter.create(loc, funcOp, operands[0]).getResult()}; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +struct ExternElementwiseOpConversion + : public ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + typedef typename Base::OpAdaptor OpAdaptor; + + SmallVector createDestOps(ExternElementwiseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + StringRef funcName = op.getSymbol(); + if (funcName.empty()) + llvm::errs() << "ExternElementwiseOpConversion"; + + Type funcType = getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp( + rewriter, op, funcName, funcType, op.getLibname(), op.getLibpath()); + return { + rewriter.create(loc, funcOp, operands[0]).getResult()}; + } +}; + +template +struct ElementwiseOpConversion + : public ElementwiseOpConversionBase< + SourceOp, ElementwiseOpConversion> { + using Base = + ElementwiseOpConversionBase>; + using Base::Base; + using OpAdaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + SmallVector createDestOps(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create(loc, elemTy, operands[0], + adaptor.getAttributes().getValue())}; + } +}; + +struct ElementwiseInlineAsmOpConversion + : public ConvertOpToLLVMPattern { + using Base = ConvertOpToLLVMPattern; + + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + typedef typename Base::OpAdaptor OpAdaptor; + + // If operand size is smaller than 32 bits, pack in groups of 32 bits. + SmallVector packOperands(ElementwiseInlineAsmOp op, + MultipleOperandsRange operands, + ConversionPatternRewriter &rewriter, + Location loc) const { + SmallVector packedOperands; + unsigned numPackedElements = op.getPackedElement(); + for (int i = 0, e = op.getNumOperands(); i < e; i++) { + Type elemTy = getElementType(op.getOperand(i)); + unsigned bitWidth = + elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : 64; + unsigned numElementPerReg = bitWidth < 32 ? 32 / bitWidth : 1; + numElementPerReg = std::min(numElementPerReg, numPackedElements); + for (int j = 0; j < numPackedElements; j += numElementPerReg) { + if (numElementPerReg == 1) { + packedOperands.push_back(operands[j][i]); + continue; + } + Type t = + vec_ty(getTypeConverter()->convertType(elemTy), numElementPerReg); + Value packed = undef(t); + for (int k = 0; k < numElementPerReg; k++) { + packed = insert_element(packed, operands[j + k][i], i32_val(k)); + } + packedOperands.push_back(packed); + } + } + return packedOperands; + } + + SmallVector> + createDestOps(ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + MultipleOperandsRange operands, Location loc) const { + auto ctx = op->getContext(); + + if (operands.size() % op.getPackedElement() != 0) + llvm::report_fatal_error("Inline asm op has more packed elements than " + "number of elements per thread."); + + // Pack elems smaller than 32 bits into 32-bit registers. + SmallVector packedOperands = + packOperands(op, operands, rewriter, loc); + + // Types returned by the LLVM asm op. If there's more than one, they'll be + // wrapped in a struct. + SmallVector asmRetTypes; + for (auto result : op.getResult()) { + auto ty = getTypeConverter()->convertType(getElementType(result)); + + // Pack return elements into 32-bits. + unsigned bitWidth = ty.isIntOrFloat() ? ty.getIntOrFloatBitWidth() : 64; + unsigned numElemsPerReg = + std::min(bitWidth < 32 ? 32 / bitWidth : 1, op.getPackedElement()); + assert(op.getPackedElement() % numElemsPerReg == 0); + if (numElemsPerReg > 1) { + ty = vec_ty(ty, numElemsPerReg); + } + for (unsigned i = 0; i < op.getPackedElement() / numElemsPerReg; i++) { + asmRetTypes.push_back(ty); + } + } + Type asmRetType = + asmRetTypes.size() > 1 ? struct_ty(asmRetTypes) : asmRetTypes[0]; + + Value asmResults = + rewriter + .create( + loc, asmRetType, + /*operands=*/packedOperands, + /*asm_string=*/op.getAsmString(), + /*constraints=*/op.getConstraints(), + /*has_side_effects=*/!op.getPure(), + /*is_align_stack=*/false, + /*asm_dialect=*/ + LLVM::AsmDialectAttr::get(rewriter.getContext(), + LLVM::AsmDialect::AD_ATT), + /*operand_attrs=*/ArrayAttr()) + ->getResult(0); + + // asmResults is a flat struct; pack its values into + // [return_value][op.getPackedElement()]. + SmallVector> ret(op->getNumResults()); + for (int i = 0; i < op->getNumResults(); i++) { + int structIdx = 0; + for (int j = 0; j < op.getPackedElement(); j++) { + Value val; + if (asmRetTypes.size() > 1) { + val = + extract_val(asmResults, i * op.getPackedElement() + structIdx++); + } else { + val = asmResults; + } + if (auto vectorTy = dyn_cast(val.getType())) { + for (int k = 0; k < vectorTy.getNumElements(); k++) { + ret[i].push_back(extract_element(val, i32_val(k))); + } + j += vectorTy.getNumElements() - 1; + } else { + ret[i].push_back(val); + } + } + } + return ret; + } + + LogicalResult + matchAndRewrite(ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + // Layout is unpackedOperands[operand][elem]. + SmallVector> unpackedOperands; + for (auto operand : adaptor.getOperands()) { + auto argTy = op->getOperand(0).getType(); + auto subOperands = unpackLLElements(loc, operand, rewriter); + unpackedOperands.push_back( + unpackI32(subOperands, argTy, rewriter, loc, getTypeConverter())); + } + + int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(), + getTypeConverter()); + + // These are checked by the verifier, so we don't need to raise a nice + // error. + assert(all_of(unpackedOperands, [&](auto &operands) { + return operands.size() == numElemsPerThread; + })); + if (numElemsPerThread % op.getPackedElement() != 0) { + // Pad with the undef for each operand to have a multiple of + // op.getPackedElement() elements. + int numPaddedValue = + op.getPackedElement() - numElemsPerThread % op.getPackedElement(); + for (auto &operands : unpackedOperands) { + for (int i = 0; i < numPaddedValue; i++) { + operands.push_back(undef(operands[0].getType())); + } + } + } + + // Run the inline asm op on each block of elements. + // + // Layout is unpackedResults[result_idx][elem]. + // + // This loop always runs at least once, even when the asm has no input + // elements. + SmallVector> unpackedResults(op->getNumResults()); + for (unsigned i = 0; i < numElemsPerThread; i += op.getPackedElement()) { + // Block of elements to process with one call to the inline asm. This is + // ordered opposite `unpackedResults`: The outer dim is + // op.getPackedElement(), and the inner dim is the operand. + SmallVector> block(op.getPackedElement()); + for (auto &os : unpackedOperands) { + for (int j = 0; j < op.getPackedElement(); j++) { + block[j].push_back(os[i + j]); + } + } + auto cur = createDestOps(op, adaptor, rewriter, block, loc); + assert(cur.size() == unpackedResults.size()); + for (unsigned j = 0; j < cur.size(); j++) { + unpackedResults[j].insert(unpackedResults[j].end(), cur[j].begin(), + cur[j].end()); + } + } + for (auto &results : unpackedResults) { + results.resize(numElemsPerThread); + } + // Reorder and pack the results. + SmallVector outs; + for (int i = 0; i < unpackedResults.size(); i++) { + // We reordered all the inputs so they match operand 0. Reorder the + // outputs accordingly. + if (op->getNumOperands() > 0) { + unpackedResults[i] = reorderValues( + unpackedResults[i], /*inType=*/op->getOperand(0).getType(), + /*ouType=*/op->getResult(i).getType()); + } + auto packed = packI32(unpackedResults[i], op->getResult(i).getType(), + rewriter, loc, getTypeConverter()); + outs.push_back(packLLElements(loc, getTypeConverter(), packed, rewriter, + op->getResult(i).getType())); + } + + rewriter.replaceOp(op, outs); + return success(); + } +}; + +struct AbsIOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(math::AbsIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create(loc, elemTy, operands[0][0], + /*is_int_min_poison=*/false)}; + } +}; + +struct AbsFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(math::AbsFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + if (llvm::isa(elemTy)) { + // Mask out the sign bit + auto num_bits = + getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); + assert(num_bits <= 16); + auto mask = (1u << (num_bits - 1u)) - 1u; + auto maskAttr = rewriter.getIntegerAttr(elemTy, mask); + auto maskConst = rewriter.create(loc, maskAttr); + return {and_(operands[0][0], maskConst)}; + } + + return {rewriter.create(loc, elemTy, operands[0][0])}; + } +}; +/// The lowering of index_cast becomes an integer conversion since index +/// becomes an integer. If the bit width of the source and target integer +/// types is the same, just erase the cast. If the target type is wider, +/// sign-extend the value, otherwise truncate it. +struct IndexCastOpLowering + : public ElementwiseOpConversionBase { + using Base = + ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::IndexCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + auto inElemTy = + this->getTypeConverter()->convertType(getElementType(op.getIn())); + unsigned targetBits = elemTy.getIntOrFloatBitWidth(); + unsigned sourceBits = inElemTy.getIntOrFloatBitWidth(); + + if (targetBits == sourceBits) + return {operands[0][0]}; + if (targetBits < sourceBits) + return {rewriter.replaceOpWithNewOp(op, elemTy, + operands[0][0])}; + return { + rewriter.replaceOpWithNewOp(op, elemTy, operands[0][0])}; + } +}; + +struct SelectOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + SmallVector createDestOps(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + std::array llvmOperands; + if (operands[0].size() == 2) { + // Case of scalar condition with tensor operands. + assert(op.getCondition().getType().isInteger(1)); + llvmOperands = {adaptor.getCondition(), operands[0][0], operands[0][1]}; + } else { + llvmOperands = {operands[0][0], operands[0][1], operands[0][2]}; + } + return {rewriter.create( + loc, llvmOperands[1].getType(), llvmOperands, + adaptor.getAttributes().getValue())}; + } +}; +template +struct MinMaxFOpConversion + : ElementwiseOpConversionBase> { + using Base = ElementwiseOpConversionBase>; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + static_assert(std::is_same::value || + std::is_same::value, + "OpTy must be arith::MinimumFOp or arith::MaximumFOp"); + + // Choose the destination op based on the OpTy. + using DestOpNanProp = + typename std::conditional::value, + LLVM::MinimumOp, LLVM::MaximumOp>::type; + using DestOpNoNanProp = + typename std::conditional::value, + LLVM::MinNumOp, LLVM::MaxNumOp>::type; + + explicit MinMaxFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + bool hwNanPropagationSupported, + PatternBenefit benefit = 1) + : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, + benefit), + hwNanPropagationSupported(hwNanPropagationSupported) {} + + SmallVector createDestOps(OpTy op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + if (hwNanPropagationSupported) { + return {rewriter.create(loc, elemTy, operands[0][0], + operands[0][1])}; + } + // Handle workaround for NaN propagation, i.e. software emulation of NaN + // propagation. If any of the operands is NaN, return NaN. + auto lhs = operands[0][0]; + auto rhs = operands[0][1]; + auto lhsIsNan = + rewriter.create(loc, LLVM::FCmpPredicate::une, lhs, lhs); + auto rhsIsNan = + rewriter.create(loc, LLVM::FCmpPredicate::une, rhs, rhs); + auto isNan = rewriter.create(loc, lhsIsNan, rhsIsNan); + auto nonNanRes = rewriter.create(loc, elemTy, lhs, rhs); + + auto nan = LLVM::createNaNConstant(loc, rewriter, elemTy); + + // Select the result based on the isNan flag. + return {rewriter.create(loc, isNan, nan, nonNanRes)}; + } + +private: + bool hwNanPropagationSupported; +}; + +struct ClampFOpConversion + : ElementwiseOpConversionBase { + using Base = ElementwiseOpConversionBase; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + explicit ClampFOpConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + targetInfo(targetInfo) {} + + SmallVector createDestOps(ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, MultipleOperandsRange operands, + Location loc) const { + // Clip pattern not found, use min/max. + if (op.getPropagateNan() == PropagateNan::ALL) { + if (targetInfo.supportMaximumMinimum()) { + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + return {rewriter.create(loc, v, operands[0][2])}; + } + // On pre-80 compute capability, we need to handle NaN propagation + // manually. We need to check only the first operand for clamp. + auto lhs = operands[0][0]; + auto isNan = rewriter.create(loc, LLVM::FCmpPredicate::une, + lhs, lhs); + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + auto nonNanRes = rewriter.create(loc, v, operands[0][2]); + auto nan = LLVM::createNaNConstant(loc, rewriter, elemTy); + // Select the result based on the isNan flag. + return {rewriter.create(loc, isNan, nan, nonNanRes)}; + } + + // No NaN propagation. + assert(op.getPropagateNan() == PropagateNan::NONE); + auto v = rewriter.create(loc, elemTy, operands[0][0], + operands[0][1]); + return {rewriter.create(loc, v, operands[0][2])}; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMinMaxFOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, bool hwNanPropagationSupported, + PatternBenefit benefit) { + patterns.add>( + typeConverter, axisInfoAnalysis, hwNanPropagationSupported, benefit); + patterns.add>( + typeConverter, axisInfoAnalysis, hwNanPropagationSupported, benefit); +} + +void mlir::triton::populateClampFOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, axisInfoAnalysis, targetInfo, + benefit); +} + +void mlir::triton::populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo, + PatternBenefit benefit) { +#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + + POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp) + POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp) + POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp) + POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp) + POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp) + POPULATE_UNARY_OP(math::FloorOp, math::FloorOp) + POPULATE_UNARY_OP(math::CeilOp, math::CeilOp) + POPULATE_UNARY_OP(math::LogOp, math::LogOp) + POPULATE_UNARY_OP(math::Log2Op, math::Log2Op) + POPULATE_UNARY_OP(math::CosOp, math::CosOp) + POPULATE_UNARY_OP(math::SinOp, math::SinOp) + POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp) + POPULATE_UNARY_OP(math::RsqrtOp, math::RsqrtOp) + POPULATE_UNARY_OP(math::ExpOp, math::ExpOp) + POPULATE_UNARY_OP(math::Exp2Op, math::Exp2Op) + POPULATE_UNARY_OP(math::ErfOp, math::ErfOp) + POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp) + POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp) + POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp) +#undef POPULATE_UNARY_OP + +#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + + POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // - + POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // + + POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // * + POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp) + POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp) + POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // % + POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp) + POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp) + POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // & + POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // | + POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^ + POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // << + POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >> + POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >> + // fmin (return non-NaN if either op is non-NaN) + POPULATE_BINARY_OP(arith::MinNumFOp, LLVM::MinNumOp) + // fmax (return non-NaN if either op is non-NaN) + POPULATE_BINARY_OP(arith::MaxNumFOp, LLVM::MaxNumOp) + POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin + POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax + POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin + POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax +#undef POPULATE_BINARY_OP + + patterns.add>( + typeConverter, axisInfoAnalysis, benefit); + + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, targetInfo, + benefit); + patterns.add(typeConverter, axisInfoAnalysis, + benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, axisInfoAnalysis, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 000000000..47f40ebec --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,118 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +/// FuncOp legalization pattern that converts MemRef arguments to pointers to +/// MemRef descriptors (LLVM struct data types) containing all the MemRef type +/// information. +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, int numWarps, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), numWarps(numWarps) {} + + /// Only retain those attributes that are not constructed by + /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument + /// attributes. + static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result) { + + for (const auto &attr : op->getAttrs()) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == op.getFunctionTypeAttrName() || + attr.getName() == "std.varargs" || + (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) + continue; + result.push_back(attr); + } + } + + triton::FuncOp amendFuncOp(triton::FuncOp funcOp, + ConversionPatternRewriter &rewriter) const { + // Push back a variable that indicates the current stack pointer of shared + // memory to the function arguments. + auto loc = funcOp.getLoc(); + auto ctx = funcOp->getContext(); + auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); + // 1. Modify the function type to add the new argument. + auto funcTy = funcOp.getFunctionType(); + auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); + amendedInputTy.push_back(ptrTy); + auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, + funcTy.getResults()); + // 2. Modify the argument attributes to add the new argument. + SmallVector amendedAttrs; + filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); + auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedAttrs.push_back(rewriter.getNamedAttr( + funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); + // 3. Add a new argument to the region + auto amendedFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); + auto ®ion = funcOp.getBody(); + region.addArgument(ptrTy, loc); + rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), + amendedFuncOp.end()); + return amendedFuncOp; + } + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Prevent LLVM's inliner to inline this function + auto amendedFuncOp = funcOp; + if (!LLVM::isKernel(funcOp)) + amendedFuncOp = amendFuncOp(funcOp, rewriter); + + LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( + amendedFuncOp, rewriter, *getTypeConverter()); + if (!newFuncOp) { + return failure(); + } + + auto ctx = funcOp->getContext(); + + if (LLVM::isKernel(funcOp)) { + // Set an attribute to indicate this function is a kernel entry. + newFuncOp->setAttr("nvvm.kernel", + rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + newFuncOp.setLinkage(LLVM::Linkage::External); + } else { + // The noinline attribute will be used by the LLVM codegen to prevent + // inlining. + // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 + newFuncOp.setPassthroughAttr( + ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); + rewriter.eraseOp(amendedFuncOp); + newFuncOp.setLinkage(LLVM::Linkage::Internal); + } + // Set an attribute for maxntidx, it could be used in latter LLVM codegen + // for `nvvm.annotation` metadata. + newFuncOp->setAttr("nvvm.maxntid", + rewriter.getDenseI32ArrayAttr(32 * numWarps)); + rewriter.eraseOp(funcOp); + return success(); + } + +private: + int numWarps{0}; +}; + +} // namespace + +void mlir::triton::populateFuncOpConversionPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, + PatternBenefit benefit) { + patterns.add(typeConverter, numWarps, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp new file mode 100644 index 000000000..acf940b3e --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp @@ -0,0 +1,212 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +static int log2Int(int64_t num) { return (num > 1) ? 1 + log2Int(num / 2) : 0; } + +// Compute a histogram within a warp. This uses an algorithm by @apgoucher +// that does the following: +// Create a ballot for each bit of the bin index (there +// are only log2(num_bins) of these) and then apply bitwise operations to get +// the indicator functions for the bins owned by this particular thread, and +// only popcount those. +static SmallVector computeWarpLevelHistogram( + Location loc, RankedTensorType srcType, SmallVector &srcValues, + int numBins, int numThreadPerWarp, Value threadId, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo) { + assert(numBins % numThreadPerWarp == 0 && + "numBins must be divisible by numThreadPerWarp"); + Value zero = i32_val(0); + int numBits = log2Int(numBins); + int numBitsLaneId = log2Int(numThreadPerWarp); + unsigned numElementsPerThreads = triton::gpu::getTotalElemsPerThread(srcType); + unsigned numThreadWithUniqueData = + triton::gpu::getThreadsPerWarpWithUniqueData(srcType.getEncoding(), + srcType.getShape())[0]; + // The histogram is distributed across threads, each thread owns `numBins / + // numThreadPerWarp` bins. + SmallVector warpLevelHistogram(numBins / numThreadPerWarp, zero); + for (int i = 0; i < numElementsPerThreads; ++i) { + Value value = srcValues[i]; + SmallVector ballotBits; + for (int j = 0; j < numBits; ++j) { + Value bitSet = and_(value, i32_val(1 << j)); + Value cmp = icmp_ne(bitSet, zero); + Value bit = + targetInfo.ballot(rewriter, loc, int_ty(numThreadPerWarp), cmp); + ballotBits.push_back(bit); + } + uint64_t fullMaskValue = + numThreadPerWarp == 32 ? 0xFFFFFFFF : 0xFFFFFFFFFFFFFFFF; + Value fullMask = int_val(numThreadPerWarp, fullMaskValue); + Value mask = fullMask; + // If not all threads have unique data, mask out the redundant ones. + if (numThreadWithUniqueData < numThreadPerWarp) { + mask = int_val(numThreadPerWarp, (1ULL << numThreadWithUniqueData) - 1); + } + for (int i = 0; i < numBitsLaneId; i++) { + Value updateMask = select(icmp_ne(and_(threadId, i32_val(1 << i)), zero), + int_val(numThreadPerWarp, 0), fullMask); + mask = + and_(mask, xor_(ballotBits[i + numBits - numBitsLaneId], updateMask)); + } + // at this point, 'mask' tells you which elements are in a bin owned by this + // thread. + for (int k = 0; k < warpLevelHistogram.size(); k++) { + Value binMask = mask; + for (int j = 0; j < numBits - numBitsLaneId; j++) { + Value updateMask = + int_val(numThreadPerWarp, ((k & (1 << j)) ? 0 : fullMaskValue)); + binMask = and_(binMask, xor_(ballotBits[j], updateMask)); + } + // at this point, 'bin_mask' tells you which elements are in the kth bin + // owned by this thread. + Value bitCount = rewriter.create( + loc, int_ty(numThreadPerWarp), binMask); + if (numThreadPerWarp > 32) + bitCount = trunc(i32_ty, bitCount); + warpLevelHistogram[k] = add(warpLevelHistogram[k], bitCount); + } + } + return warpLevelHistogram; +} + +static void atomicAdd(Value ptr, Value val, Location loc, + ConversionPatternRewriter &rewriter) { + rewriter.create(loc, LLVM::AtomicBinOp::add, ptr, val, + LLVM::AtomicOrdering::monotonic); +} + +static SmallVector computeCrossWarpHistogram( + Location loc, ConversionPatternRewriter &rewriter, RankedTensorType srcType, + Value baseSharedMemPtr, const SmallVector &warpLevelHistogram, + int numBins, int numThreadPerWarp, const SmallVector &indices, + Value threadId, int numWarps) { + SmallVector histogramValues; + unsigned numWarpsWithUniqueData = + mlir::triton::gpu::getWarpsPerCTAWithUniqueData(srcType.getEncoding(), + srcType.getShape())[0]; + Value laneId = and_(threadId, i32_val(numThreadPerWarp - 1)); + // Initialize the shared memory with zeros. + int64_t numElementPerThread = + ceil(numBins, numThreadPerWarp * numWarps); + for (int i = 0; i < numElementPerThread; ++i) { + Value offset = add(threadId, i32_val((i * numWarps * numThreadPerWarp))); + offset = urem(offset, i32_val(numBins)); + Value sharedMemPtr = + gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); + store(i32_val(0), sharedMemPtr); + } + barrier(); + Block *afterAtomics = nullptr; + // If some warps have replicated data we need to skip those warps when + // accumulating. + if (numWarpsWithUniqueData < numWarps) { + Block *currentBlock = rewriter.getInsertionBlock(); + afterAtomics = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + Block *atomicBlock = rewriter.createBlock(afterAtomics); + rewriter.setInsertionPointToEnd(currentBlock); + Value cond = + icmp_ult(threadId, i32_val(numWarpsWithUniqueData * numThreadPerWarp)); + rewriter.create(loc, cond, atomicBlock, afterAtomics); + rewriter.setInsertionPointToStart(atomicBlock); + } + // Apply atomic add to update the histogram in shared memory. + for (int i = 0; i < warpLevelHistogram.size(); ++i) { + Value warpLevelHistogramValue = warpLevelHistogram[i]; + Value offset = + add(mul(laneId, i32_val(warpLevelHistogram.size())), i32_val(i)); + Value sharedMemPtr = + gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset); + atomicAdd(sharedMemPtr, warpLevelHistogramValue, loc, rewriter); + } + if (afterAtomics) { + rewriter.create(loc, afterAtomics); + rewriter.setInsertionPointToStart(afterAtomics); + } + barrier(); + // load the histogram to register with the right layout. + for (Value index : indices) { + Value sharedMemPtr = + gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, index); + Value val = load(i32_ty, sharedMemPtr); + histogramValues.push_back(val); + } + return histogramValues; +} + +namespace { +struct HistogramOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + explicit HistogramOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(triton::HistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = adaptor.getSrc(); + auto typeConverter = getTypeConverter(); + SmallVector srcValues = unpackLLElements(loc, input, rewriter); + int numBins = op.getType().getDimSize(0); + auto mod = op->getParentOfType(); + int numThreadsPerWarp = + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + assert(numThreadsPerWarp == 32 || + numThreadsPerWarp == 64 && + "Only supports 32 or 64 threads per warp"); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + // Pad out the bins so that we have at least one bin per thread within a + // warp. + numBins = std::max(numBins, numThreadsPerWarp); + Value threadId = getThreadId(rewriter, loc); + auto srcType = op.getSrc().getType(); + // First compute a warp local histogram based on values owned by each warps. + SmallVector warpLevelHistogram = computeWarpLevelHistogram( + loc, srcType, srcValues, numBins, numThreadsPerWarp, threadId, rewriter, + targetInfo); + + // Then use atomic to update the histogram in shared memory. + // TODO: we could skip this for cases with num_warps=1 as long as we can + // generate the right layout. Currently the warp level histogram generates + // data in the default blocked layout. + Value baseSharedMemPtr = + LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + auto dstType = op.getType(); + Attribute dstEncoding = dstType.getEncoding(); + auto indices = emitIndices(op.getLoc(), rewriter, targetInfo, dstEncoding, + dstType, true); + SmallVector innerDimIndices; + for (int i = 0; i < indices.size(); ++i) + innerDimIndices.push_back(indices[i][0]); + SmallVector histogramValue = computeCrossWarpHistogram( + loc, rewriter, srcType, baseSharedMemPtr, warpLevelHistogram, numBins, + numThreadsPerWarp, innerDimIndices, threadId, numWarps); + + Value results = packLLElements(loc, typeConverter, histogramValue, rewriter, + op.getType()); + rewriter.replaceOp(op, results); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; +} // namespace + +void mlir::triton::populateHistogramOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp new file mode 100644 index 000000000..43120c791 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp @@ -0,0 +1,53 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +struct MakeRangeOpConversion + : public ConvertOpToLLVMPattern { + MakeRangeOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + RankedTensorType ty = op.getType(); + auto shape = ty.getShape(); + auto layout = ty.getEncoding(); + auto elemTy = ty.getElementType(); + assert(elemTy.isInteger(32)); + Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.getStart()); + auto idxs = emitIndices(loc, rewriter, targetInfo, layout, ty, true); + unsigned elems = idxs.size(); + SmallVector retVals(elems); + // TODO: slice layout has more elements than expected. + // Unexpected behavior for make range, but generally OK when followed by + // expand dims + broadcast. very weird behavior otherwise potentially. + for (const auto &multiDim : llvm::enumerate(idxs)) { + assert(multiDim.value().size() == 1); + retVals[multiDim.index()] = add(multiDim.value()[0], start); + } + auto typeConverter = getTypeConverter(); + Value result = packLLElements(loc, typeConverter, retVals, rewriter, ty); + rewriter.replaceOp(op, result); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMakeRangeOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp new file mode 100644 index 000000000..12ab6684c --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -0,0 +1,145 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// blocked -> shared. +// Swizzling in shared memory to avoid bank conflict. Normally used for +// A/B operands of dots. +void lowerDistributedToShared(Location loc, Value src, Value dst, + Value adaptorSrc, + const SharedMemoryObject &smemObj, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) { + auto srcTy = cast(src.getType()); + auto dstTy = cast(dst.getType()); + auto dstShapePerCTA = triton::gpu::getShapePerCTA(dstTy); + auto srcLayout = srcTy.getEncoding(); + auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); + assert(srcTy.getShape().size() <= 2 || + (srcTy.getShape().size() == 3 && outOrd[2] == 0) && + "Unexpected rank of ConvertLayout(blocked->shared)"); + auto elemTy = typeConverter->convertType(srcTy.getElementType()); + + auto smemBase = smemObj.getBase(); + int32_t elemSize = elemTy.getIntOrFloatBitWidth(); + unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); + auto dstStrides = smemObj.getStrides(); + auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); + storeDistributedToShared(src, inVals, dstStrides, dst, smemBase, elemTy, loc, + rewriter, targetInfo); +} + +struct LocalAllocOpConversion + : public ConvertOpToLLVMPattern { + LocalAllocOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + Value smemBase = + LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + auto resultTy = cast(op.getType()); + auto typeConverter = getTypeConverter(); + auto sharedLayout = + cast(resultTy.getEncoding()); + auto order = sharedLayout.getOrder(); + // Workaround for 3D tensors + // TODO: we need to modify the pipeline pass to give a proper shared + // encoding to 3D tensors + SmallVector newOrder; + if (resultTy.getShape().size() != order.size()) { + for (auto i = 0; i < order.size(); ++i) + newOrder.push_back(order[i] + 1); + newOrder.push_back(0); + } else { + newOrder = SmallVector(order.begin(), order.end()); + } + + auto llvmElemTy = typeConverter->convertType(resultTy.getElementType()); + auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape()); + auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA, + newOrder, loc, rewriter); + // If there is an initial tensor, store it into the shared memory. + if (op.getSrc()) { + lowerDistributedToShared(loc, op.getSrc(), op.getResult(), + adaptor.getSrc(), smemObj, typeConverter, + rewriter, targetInfo); + } + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +struct LocalDeallocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::LocalDeallocOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::LocalDeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +struct LocalStoreOpConversion + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern< + triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern; + + LocalStoreOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value memDescVal = op.getDst(); + auto llvmElemTy = + getTypeConverter()->convertType(op.getDst().getType().getElementType()); + auto smemObj = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter); + lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(), + adaptor.getSrc(), smemObj, getTypeConverter(), + rewriter, targetInfo); + rewriter.eraseOp(op); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateMemoryOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp new file mode 100644 index 000000000..32c7835c2 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp @@ -0,0 +1,243 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace { + +// The input print op contains: +// - a "prefix" (string) specified by the user, and +// - one or more "operands" (tensors). +// +// For each operand, we print all of the values contained in this GPU thread, +// one per line, along with the index of the value in its tensor. +struct PrintOpConversion : public ConvertOpToLLVMPattern { + explicit PrintOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + auto getPid = [&](int axis) { + return targetInfo.programId(rewriter, loc, + op->getParentOfType(), axis); + }; + std::array pid = {getPid(0), getPid(1), getPid(2)}; + + // Simple printf of a string without any tensors. + if (op.getNumOperands() == 0) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << "pid (" << getFormatSubstr(pid[0]) << ", " + << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")" + << op.getPrefix(); + llPrintf(formatStr, {pid[0], pid[1], pid[2]}, rewriter); + rewriter.eraseOp(op); + return success(); + } + + for (size_t i = 0; i < op.getNumOperands(); i++) { + // Elements of the tensor that are resident in this GPU thread. + auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); + + // Get the indices of `elems` within the tensor. Note that if `elems` + // has an "interesting" layout, then these will not be in any + // particularly nice order. + + // Extract the shape of the tensor being printed and use it to figure + // out how many digits we need for each of the dimensions. + SmallVector dimWidths; + SmallVector> indices; + if (auto rankedTy = + dyn_cast(op.getOperand(i).getType())) { + indices = emitIndices(loc, rewriter, targetInfo, rankedTy.getEncoding(), + rankedTy, true); + for (int64_t dim : rankedTy.getShape()) { + if (dim > 0) { + dimWidths.push_back(static_cast(std::ceil(std::log10(dim)))); + } else { + dimWidths.push_back(0); + } + } + } else { + // We're printing a scalar. + assert(elems.size() == 1); + indices.push_back({}); + } + + if (!elems.empty()) { + printTensor(op.getPrefix(), /*operand=*/i, + /*numOperands=*/op.getNumOperands(), elems, pid, indices, + dimWidths, op.getHex(), rewriter); + } + } + rewriter.eraseOp(op); + return success(); + } + + void printTensor(StringRef prefixStr, size_t operand, size_t numOperands, + ArrayRef elems, std::array pid, + ArrayRef> indices, + ArrayRef dimWidths, bool hex, + ConversionPatternRewriter &rewriter) const { + assert(!elems.empty()); + assert(elems.size() == indices.size()); + assert(dimWidths.size() == indices.front().size()); + + size_t rank = dimWidths.size(); + + // Format is: + // pid (, , ) idx (, , ...) (operand ) + // where we leave off "(operand )" if there's only one operand. + // + // The Python wrapper munges `prefix` so that it prints nicely (e.g. starts + // with " " and ends with ": "). + + Value formatStrValue; + int formatStrByteCount = 0; + for (int i = 0; i < elems.size(); i++) { + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + + // nvptx printf can only accept 32 args; if we pass more than that, it + // will print garbage for the trailing args. + constexpr int kMaxPrintfOperands = 32; + SmallVector printfOperands; + + // TODO(jlebar): We really should pad the pid, but because the max pid is + // not known at compile-time, this would require nontrivial device-side + // work. + os << "pid ("; + for (int j = 0; j < pid.size(); j++) { + if (j != 0) { + os << ", "; + } + os << getFormatSubstr(pid[j]); + printfOperands.push_back(pid[j]); + } + os << ") "; + + // If `rank` is large enough, we could end up exceeding + // kMaxPrintfOperands. In that case, just truncate the index. + // (Subtract 2 because we're going to add two operands after the index.) + int maxAllowedRank = kMaxPrintfOperands - printfOperands.size() - 2; + + os << "idx ("; + const auto &index = indices[i]; + for (size_t dim = 0; dim < index.size(); dim++) { + if (dim != 0) { + os << ", "; + } + if (dim == maxAllowedRank) { + os << "... (truncated)"; + break; + } + os << getFormatSubstr(index[dim], /*hex=*/false, + /*width=*/dimWidths[dim]); + printfOperands.push_back(index[dim]); + } + os << ")" << prefixStr; + + if (numOperands > 1) { + os << "(operand " << operand << ") "; + } + + auto elem = elems[i]; + os << getFormatSubstr(elem, hex); + printfOperands.push_back(elem); + + // It's the same format string each iteration, but it's a lot easier if we + // construct the format string at the same time as we populate + // printfOperands. But we don't want to create BLOCK_SIZE duplicate + // strings, so we cache the Value. + if (i == 0) { + formatStrValue = + llPrintf(formatStr, printfOperands, rewriter, &formatStrByteCount); + } else { + targetInfo.printf(rewriter, formatStrValue, formatStrByteCount, + printfOperands); + } + } + } + + std::string getFormatSubstr(Value value, bool hex = false, + std::optional width = std::nullopt) const { + Type type = value.getType(); + if (isa(type)) { + return "%p"; + } + // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the + // type (so 4 for fp16, 8 for int32, 16 for int64). + if (hex) { + // Ignore `width` for `hex` values, pad to typeWidth. + std::string ret = + "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4); + if (type.getIntOrFloatBitWidth() > 32) { + ret += "ll"; + } + ret += "x"; + return ret; + } + + std::string prefix = "%"; + if (width.has_value()) { + prefix += std::to_string(*width); + } else if (hex) { + prefix += "0"; + prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4); + } + + if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { + return prefix + "f"; + } else if (type.isSignedInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + "lli"; + else + return prefix + "i"; + } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + "llu"; + else + return prefix + "u"; + } + assert(false && "not supported type"); + return ""; + } + + // Returns a Value for the format string, which you can reuse. Writes the byte + // count for the string to |formatStrByteCount| if not null. + Value llPrintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter, + int *formatStrByteCount = nullptr) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), + rewriter, "printfFormat_", msgNewline); + targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); + if (formatStrByteCount) + *formatStrByteCount = msgNewline.size_in_bytes(); + return msgValue; + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populatePrintOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp new file mode 100644 index 000000000..4d036c21a --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -0,0 +1,436 @@ +#include "ReduceScanCommon.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getTotalElemsPerThread; + +namespace { +struct ReduceOpConversion + : public ConvertTritonGPUReduceScanToLLVMPattern { +public: + ReduceOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertTritonGPUReduceScanToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ReduceOpHelper helper(op); + assert(helper.isSupportedLayout() && + "Unexpected srcLayout in ReduceOpConversion"); + Location loc = op->getLoc(); + + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + std::map, SmallVector> accs; + std::map, SmallVector> indices; + // First reduce all the values along axis within each thread. + reduceWithinThreads(helper, srcValues, accs, indices, rewriter); + + // Then reduce across threads within a warp. + reduceWithinWarps(helper, accs, rewriter); + + if (helper.isWarpSynchronous()) { + // If all the values to be reduced are within the same warp there is + // nothing left to do. + packResults(helper, accs, rewriter); + return success(); + } + + // Compute a shared memory base per operand. + auto smemShape = helper.getScratchConfig(); + + SmallVector smemBases = + getSmemBases(op, product(smemShape), rewriter); + + storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter); + + sync(rewriter, loc, op); + + // The second round of shuffle reduction + // now the problem size: sizeInterWarps, s1, s2, .. , sn + // where sizeInterWarps is 2^m + // + // Each thread needs to process: + // elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads + accumulatePartialReductions(helper, smemBases, rewriter); + + // We could avoid this barrier in some of the layouts, however this is not + // the general case. + // TODO: optimize the barrier in case the layouts are accepted. + sync(rewriter, loc, op); + + // set output values + loadReductionAndPackResult(helper, smemShape, smemBases, rewriter); + + return success(); + } + +private: + const TargetInfoBase &targetInfo; + + void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, + SmallVector &acc, ValueRange cur, bool isFirst) const { + if (isFirst) { + acc = SmallVector(cur.begin(), cur.end()); + return; + } + + // Create a new copy of the reduce block, and inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + rewriter.cloneRegionBefore(combineOp, &parent.front()); + auto &newReduce = parent.front(); + auto returnOp = dyn_cast(newReduce.getTerminator()); + + llvm::SmallVector combineArgs(2 * acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; + } + + rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), + combineArgs); + + auto results = returnOp.getResult(); + for (unsigned i = 0; i < acc.size(); ++i) { + acc[i] = results[i]; + } + + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + } + + SmallVector> + unpackInputs(Location loc, triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getTotalElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = unpackLLElements(loc, operands[i], rewriter); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } + } + return srcValues; + } + + void sync(ConversionPatternRewriter &rewriter, Location loc, + triton::ReduceOp op) const { + barrier(); + } + + // Reduce along op axis for elements that are in the same thread. The + // accumulated value is stored in accs. + void reduceWithinThreads( + ReduceOpHelper &helper, SmallVector> &srcValues, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + RankedTensorType operandType = op.getInputTypes()[0]; + // Assumes offsets don't actually depend on type + SmallVector> offsets = + emitOffsetForLayout(helper.getSrcLayout(), operandType); + + // Thread X might hold the same input value in two registers. Get the + // indices in `offsets` that hold unique values, and only accumualte over + // those. + llvm::MapVector, int> uniqueOffsets; + for (int i = 0; i < offsets.size(); ++i) { + uniqueOffsets.insert({offsets[i], i}); + } + + unsigned srcElems = getTotalElemsPerThread(operandType); + auto *combineOp = &op.getCombineOp(); + auto srcIndices = emitIndices(op.getLoc(), rewriter, targetInfo, + helper.getSrcLayout(), operandType, true); + // reduce within threads + for (const auto &[_, i] : uniqueOffsets) { + SmallVector key = offsets[i]; + key[op.getAxis()] = 0; + bool isFirst = accs.find(key) == accs.end(); + accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); + if (isFirst) + indices[key] = srcIndices[i]; + } + } + + // Apply warp reduction across the given number of contiguous lanes using op + // region and the accumulator values as source. + void warpReduce(ConversionPatternRewriter &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce, unsigned interleave) const { + auto success = + targetInfo.warpReduce(rewriter, loc, acc, op, numLaneToReduce); + if (success) + return; + for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) { + SmallVector shfl(acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave); + } + accumulate(rewriter, op.getCombineOp(), acc, shfl, false); + } + } + + // Reduce across threads within each warp. + void + reduceWithinWarps(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData(); + unsigned threadOffsetOnReductionAxis = + helper.getThreadOffsetOnReductionAxis(); + for (auto it : accs) { + const SmallVector &key = it.first; + SmallVector &acc = accs[key]; + warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps, + threadOffsetOnReductionAxis); + } + } + + // Pack the accumulator values and replace the reduce op with the result. + void packResults(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + unsigned axis = op.getAxis(); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + if (auto resultTy = + dyn_cast(op.getResult()[i].getType())) { + auto resultLayout = cast(resultTy.getEncoding()); + unsigned resultElems = getTotalElemsPerThread(resultTy); + SmallVector> resultOffset = + emitOffsetForLayout(resultLayout, resultTy); + SmallVector resultVals; + for (int j = 0; j < resultElems; j++) { + auto key = resultOffset[j]; + key.insert(key.begin() + axis, 0); + resultVals.push_back(accs[key][i]); + } + results[i] = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, resultTy); + } else + results[i] = accs.begin()->second[i]; + } + rewriter.replaceOp(op, results); + } + + SmallVector + getMultiDimWarpId(ReduceOpHelper &helper, Value &warpId, Location &loc, + ConversionPatternRewriter &rewriter) const { + auto srcLayout = helper.getSrcLayout(); + auto srcShape = helper.getSrcShape(); + auto order = triton::gpu::getWarpOrder(srcLayout); + SmallVector multiDimWarpId; + + // 2x2 warps with slice dim = 0, warpId = 2 ends up writing at the same + // address as warpId = 0 since the warpsPerCTA is [1, 2], need to figure out + // a way to properly delinearize warpId in the slice case + if (auto sliceLayout = mlir::dyn_cast(srcLayout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(parentLayout); + auto parentOrder = triton::gpu::getWarpOrder(parentLayout); + multiDimWarpId = + delinearize(rewriter, loc, warpId, parentWarpsPerCTA, parentOrder); + multiDimWarpId.erase(multiDimWarpId.begin() + sliceLayout.getDim()); + } else { + SmallVector warpsPerCTA = + triton::gpu::getWarpsPerCTA(srcLayout); + warpsPerCTA[helper.getAxis()] = triton::gpu::getWarpsPerCTAWithUniqueData( + srcLayout, srcShape)[helper.getAxis()]; + multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, order); + } + return multiDimWarpId; + } + + void storeWarpReduceToSharedMemory( + ReduceOpHelper &helper, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + Value threadId = getThreadId(rewriter, loc); + auto srcLayout = helper.getSrcLayout(); + Value warpSize = i32_val(triton::gpu::getWarpSize(srcLayout)); + Value warpId = udiv(threadId, warpSize); + Value laneId = urem(threadId, warpSize); + auto srcShape = helper.getSrcShape(); + unsigned axis = op.getAxis(); + auto smemShape = helper.getScratchConfig(); + + auto threadsPerWarp = + triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape); + auto order = getOrder(srcLayout); + SmallVector multiDimLaneId = + delinearize(rewriter, loc, laneId, threadsPerWarp, order); + Value laneIdAxis = multiDimLaneId[axis]; + Value zero = i32_val(0); + Value laneZero = icmp_eq(laneIdAxis, zero); + + SmallVector multiDimWarpId = + getMultiDimWarpId(helper, warpId, loc, rewriter); + Value warpIdAxis = multiDimWarpId[axis]; + + auto smemOrder = helper.getOrderWithAxisAtBeginning(); + for (auto it : accs) { + const SmallVector &key = it.first; + SmallVector &acc = it.second; + + SmallVector writeIdx = indices[key]; + writeIdx[axis] = warpIdAxis; + Value writeOffset = + linearize(rewriter, loc, writeIdx, smemShape, smemOrder); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + Value writePtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, + smemBases[i], writeOffset); + targetInfo.storeShared(rewriter, loc, writePtr, acc[i], laneZero); + } + } + } + + // Load the reduction of each warp and accumulate them to a final value and + // store back to shared memory. + void accumulatePartialReductions(ReduceOpHelper &helper, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + auto srcLayout = helper.getSrcLayout(); + auto smemShape = helper.getScratchConfig(); + unsigned elems = product(smemShape); + unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData(); + Location loc = op.getLoc(); + + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(triton::gpu::getWarpSize(srcLayout)); + Value laneId = urem(threadId, warpSize); + Value zero = i32_val(0); + + auto mod = op.getOperation()->getParentOfType(); + unsigned numThreads = + product(triton::gpu::getWarpsPerCTA(srcLayout)) * + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + unsigned elemsPerThread = std::max(elems / numThreads, 1); + Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); + Value readOffset = threadId; + for (unsigned round = 0; round < elemsPerThread; ++round) { + SmallVector acc(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, + smemBases[i], readOffset); + acc[i] = targetInfo.loadShared(rewriter, loc, getTypeConverter(), + readPtr, elemTy, threadIsNeeded); + } + warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */); + // only the first thread in each sizeInterWarps is writing + Value writeOffset = readOffset; + SmallVector writePtrs(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + writePtrs[i] = gep(ptr_ty(rewriter.getContext(), 3), elemTy, + smemBases[i], writeOffset); + } + + Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); + Value laneIdModSizeInterWarpsIsZero = + icmp_eq(laneIdModSizeInterWarps, zero); + Value pred = and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero); + + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + targetInfo.storeShared(rewriter, loc, writePtrs[i], acc[i], pred); + } + + if (round != elemsPerThread - 1) { + readOffset = add(readOffset, i32_val(numThreads)); + } + } + } + + // Load the final reduction from shared memory and replace the reduce result + // with it. + void loadReductionAndPackResult(ReduceOpHelper &helper, + SmallVector smemShape, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op = helper.getOperation(); + Location loc = op.getLoc(); + auto srcLayout = helper.getSrcLayout(); + auto axis = op.getAxis(); + auto smemOrder = helper.getOrderWithAxisAtBeginning(); + SmallVector results(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto elemTy = getElementType(op, i); + if (auto resultTy = + dyn_cast(op.getResult()[i].getType())) { + // nd-tensor where n >= 1 + auto resultLayout = cast(resultTy.getEncoding()); + unsigned resultElems = getTotalElemsPerThread(resultTy); + auto resultIndices = emitIndices(loc, rewriter, targetInfo, + resultLayout, resultTy, true); + auto resultShape = resultTy.getShape(); + auto resultCTATile = getShapePerCTATile(resultLayout, resultShape); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (size_t j = 0; j < resultElems; ++j) { + SmallVector readIdx = resultIndices[j]; + readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0)); + for (size_t resultIdx = 0, resultDim = resultShape.size(); + resultIdx < resultDim; ++resultIdx) { + auto smemIdx = resultIdx < op.getAxis() ? resultIdx : resultIdx + 1; + if (resultCTATile[resultIdx] > smemShape[smemIdx] || + resultShape[resultIdx] > smemShape[smemIdx]) { + // When srcShape smaller then src sizePerThread, only srcShape + // elements is accumulated in smem. Modulo smemShape effectively + // replicates srcShape elements to src sizePerThread. + readIdx[smemIdx] = + urem(readIdx[smemIdx], i32_val(smemShape[smemIdx])); + } + } + Value readOffset = + linearize(rewriter, loc, readIdx, smemShape, smemOrder); + Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, + smemBases[i], readOffset); + resultVals[j] = load(elemTy, readPtr); + } + + results[i] = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, resultTy); + } else { + // 0d-tensor -> scalar + results[i] = load(elemTy, smemBases[i]); + } + } + rewriter.replaceOp(op, results); + } +}; +} // namespace + +void mlir::triton::populateReduceOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h new file mode 100644 index 000000000..3130001cc --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h @@ -0,0 +1,85 @@ +#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCESCANCOMMON_H +#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCESCANCOMMON_H + +// TODO: refactor so that it doesn't fail if Allocation.h +// is included after utility.h (due to conflict in `store` macro +// and +#include "triton/Analysis/Allocation.h" + +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +// +#include "mlir/IR/TypeUtilities.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include +#include + +#define DEBUG_TYPE "ttgpu_to_llvm" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::SharedMemoryObject; +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::CTALayoutAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; +namespace ttng = ::mlir::triton::nvidia_gpu; + +namespace mlir::triton { +class ReduceOp; +class ScanOp; +} // namespace mlir::triton + +template +class ConvertTritonGPUReduceScanToLLVMPattern + : public ConvertOpToLLVMPattern { +public: + // Make sure the class is only instantiated with Reduce and Scan + static_assert(std::is_same_v || + std::is_same_v); + + using ConvertOpToLLVMPattern::getTypeConverter; + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + // Return the pointee type of the shared memory pointer for operand i. + Type getElementType(SourceOp op, int i) const { + auto ty = op.getInputTypes()[i].getElementType(); + return getTypeConverter()->convertType(ty); + } + + // Helper to compute the smem bases in both reductions and scans + SmallVector getSmemBases(SourceOp op, unsigned elems, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + // indices will store the index of the op operands in descending order + // of their bitwidths + std::vector indices(op.getNumOperands()); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) { + return op.getElementTypes()[i].getIntOrFloatBitWidth() > + op.getElementTypes()[j].getIntOrFloatBitWidth(); + }); + // Assign base index to each operand in their order in indices + std::map indexToBase; + indexToBase[indices[0]] = + LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + for (unsigned i = 1; i < op.getNumOperands(); ++i) { + indexToBase[indices[i]] = gep( + ptr_ty(rewriter.getContext(), 3), getElementType(op, indices[i - 1]), + indexToBase[indices[i - 1]], i32_val(elems)); + } + // smemBases[k] is the base pointer for the k-th operand + SmallVector smemBases(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + smemBases[i] = indexToBase[i]; + } + return smemBases; + } +}; + +#endif diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 000000000..972fc5592 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,38 @@ +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct GetProgramIdOpConversion + : public ConvertOpToLLVMPattern { + explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value programId = targetInfo.programId(rewriter, op->getLoc(), + op->getParentOfType(), + op.getAxisAsInt()); + rewriter.replaceOp(op, programId); + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp new file mode 100644 index 000000000..675bf5a34 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -0,0 +1,589 @@ +#include + +#include "ReduceScanCommon.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::LLVM::delinearize; +using ::mlir::LLVM::linearize; +using ::mlir::triton::gpu::getTotalElemsPerThread; + +// apply combine region to acc and cur and accumulate it into acc +// TODO(Lezcano) This is now duplicated with ReduceOpConversion::reduce. +// Deduplicate +static SmallVector accumulate(ConversionPatternRewriter &rewriter, + Region &combineOp, ValueRange acc, + ValueRange cur) { + // Allows for passing an unitialized acc and use cur as the neutral element + if (acc.size() == 0) { + return cur; + } + assert(cur.size() == acc.size()); + // Create a new copy of the reduce block, and inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + rewriter.cloneRegionBefore(combineOp, &parent.front()); + auto &newScan = parent.front(); + auto returnOp = dyn_cast(newScan.getTerminator()); + + SmallVector combineArgs(2 * acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; + } + + rewriter.inlineBlockBefore(&newScan, &*rewriter.getInsertionPoint(), + combineArgs); + SmallVector results; + llvm::transform(returnOp.getResult(), std::back_inserter(results), + [&](Value res) { return rewriter.getRemappedValue(res); }); + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + return results; +} + +// Scan a contiguous elements within a thread and update `srcValues` in place. +static void +scanThreadContiguousElements(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper) { + // Depending on layout contiguous elements along axis dim may not be + // contiguous in srcValues. Keep track of what elements belong to the same + // chunk of contiguous elements. + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned numChunks = srcValues.size() / scanElementsPerThreads; + unsigned stride = helper.getAxisElementStride(); + SmallVector> accs(numChunks); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + // Change this into emitOffsetForLayout? + unsigned accIndex = (srcIndex % stride) + + ((srcIndex / stride) / scanElementsPerThreads) * stride; + + accs[accIndex] = accumulate(rewriter, helper.getCombineOp(), accs[accIndex], + srcValues[srcIndex]); + srcValues[srcIndex] = accs[accIndex]; + } +} + +// Apply a scan across threads of the warp for the last element of each +// contiguous group of elements. +static void warpScan(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, Value laneIdAxis) { + Location loc = helper.getLoc(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + // Reduce within warps. + SmallVector acc = srcValues[srcIndex]; + for (unsigned i = 1; i <= scanDim / 2; i <<= 1) { + SmallVector shfl(acc.size()); + for (unsigned j = 0; j < acc.size(); ++j) { + shfl[j] = targetInfo.shuffleUp(rewriter, loc, acc[j], i * threadStride); + } + SmallVector tempAcc = + accumulate(rewriter, helper.getCombineOp(), shfl, acc); + Value mask = icmp_slt(laneIdAxis, i32_val(i)); + for (unsigned j = 0; j < acc.size(); ++j) { + acc[j] = select(mask, acc[j], tempAcc[j]); + } + } + srcValues[srcIndex] = acc; + } +} + +// For each set of contiguous elements within a thread we store the partial +// reduction into shared memory. Each parallel scan and each warp will store its +// own partial reductions. The shared memory is organized as follow: +// ----------------------------------------------------------------- +// chunk 0: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 | +// chunk 1: | acc[0] warp 0 | acc[1] warp 0 | acc[0] warp 1 | acc[1] warp 1 | +static void storeWarpAccumulator(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId, SmallVector smemBases, + SmallVector smemTypes, + Value parallelLaneId, + const TargetInfoBase &targetInfo) { + Location loc = helper.getLoc(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned chunkId = 0; + unsigned elementStride = helper.getAxisElementStride(); + + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + auto lastElement = srcValues[srcIndex]; + Value mask = icmp_eq(laneId, i32_val(scanDim - 1)); + Value index = add(parallelLaneId, mul(warpId, i32_val(numParallelLane))); + index = add(index, i32_val(chunkId * numParallelLane * axisNumWarps)); + for (unsigned i = 0; i < lastElement.size(); ++i) { + Value writePtr = gep(ptr_ty(rewriter.getContext(), 3), smemTypes[i], + smemBases[i], index); + targetInfo.storeShared(rewriter, loc, writePtr, lastElement[i], mask); + } + chunkId++; + } +} + +// Read the partial reductions from shared memory from each chunk of contiguous +// elements for each warp and parallel scan. Then combine the partial reduction +// with the right elements. Within a given contiguous element chunk we update +// all the elements by accumulating the value from the last element of the +// reduced value from the previous lane. +static void AddPartialReduce(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, + SmallVector smemBases, + SmallVector smemTypes, Value warpId, + Value laneIdAxis, Value parallelLaneId) { + Location loc = helper.getLoc(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); + Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); + Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); + struct Accumulator { + SmallVector acc; + SmallVector maskedAcc; + }; + unsigned numScanBlocks = helper.getAxisNumBlocks(); + unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); + assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * + scanElementsPerThreads == + srcValues.size()); + SmallVector accumulators(numParallelBlocks * + parallelElementsPerThread); + unsigned chunkId = 0; + unsigned blockStride = helper.getAxisBlockStride(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + // Accumulate the partial reduction from shared memory. Decide which + // accumulator to combine based on whether the elements belong to the same + // dimension along axis. + unsigned blockId = chunkId / parallelElementsPerThread; + unsigned parallelBlockId = + blockId % blockStride + + ((blockId / blockStride) / numScanBlocks) * blockStride; + unsigned accumulatorIndex = chunkId % parallelElementsPerThread + + parallelBlockId * parallelElementsPerThread; + Accumulator &accumulator = accumulators[accumulatorIndex]; + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + for (unsigned i = 0; i < axisNumWarps; ++i) { + Value index = add(parallelLaneId, i32_val(numParallelLane * + (i + chunkId * axisNumWarps))); + SmallVector partialReduce(helper.getNumOperands()); + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + auto elemTy = smemTypes[j]; + Value ptr = + gep(ptr_ty(rewriter.getContext(), 3), elemTy, smemBases[j], index); + partialReduce[j] = load(elemTy, ptr); + } + + if (accumulator.acc.size() == 0) { + accumulator.acc = partialReduce; + accumulator.maskedAcc = partialReduce; + continue; + } + accumulator.acc = accumulate(rewriter, helper.getCombineOp(), + accumulator.acc, partialReduce); + Value mask = icmp_slt(warpId, i32_val(i + 1)); + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + accumulator.maskedAcc[j] = + select(mask, accumulator.maskedAcc[j], accumulator.acc[j]); + } + } + auto temp = accumulate(rewriter, helper.getCombineOp(), + accumulator.maskedAcc, srcValues[srcIndex]); + if (axisBlockId == 0) { + // For the first warp and first chunk we don't have anything to + // accumulate. + auto val = srcValues[srcIndex]; + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + temp[i] = select(maskFirstWarp, val[i], temp[i]); + } + } + srcValues[srcIndex] = temp; + // Update the rest of the contiguous elements. + SmallVector lastElement(helper.getNumOperands()); + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + auto elem = targetInfo.shuffleUp(rewriter, loc, temp[i], threadStride); + lastElement[i] = select(maskFirstLane, accumulator.maskedAcc[i], elem); + } + for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + auto laneValue = srcValues[srcIndex - i * elementStride]; + laneValue = + accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); + if (axisBlockId == 0) { + // For the first warp and first chunk we don't have anything to + // accumulate. + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + laneValue[j] = + select(maskFirstThread, + srcValues[srcIndex - i * elementStride][j], laneValue[j]); + } + } + srcValues[srcIndex - i * elementStride] = laneValue; + } + // For the next chunk start back from the value containing the + // accumulated value of all the warps. + accumulator.maskedAcc = accumulator.acc; + chunkId++; + } +} + +static void AddPartialReduceOneWarp(SmallVector> &srcValues, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + ScanLoweringHelper &helper, Value warpId, + Value laneIdAxis, Value laneIdLast) { + Location loc = helper.getLoc(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); + Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); + Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); + unsigned numScanBlocks = helper.getAxisNumBlocks(); + unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); + assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * + scanElementsPerThreads == + srcValues.size()); + SmallVector> accumulators(numParallelBlocks * + parallelElementsPerThread); + unsigned chunkId = 0; + unsigned blockStride = helper.getAxisBlockStride(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + unsigned blockId = chunkId / parallelElementsPerThread; + unsigned parallelBlockId = + blockId % blockStride + + ((blockId / blockStride) / numScanBlocks) * blockStride; + unsigned accumulatorIndex = chunkId % parallelElementsPerThread + + parallelBlockId * parallelElementsPerThread; + auto &accumulator = accumulators[accumulatorIndex]; + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + if (axisBlockId == 0) // First chunk and first block + accumulator = srcValues[srcIndex]; + else + srcValues[srcIndex] = accumulate(rewriter, helper.getCombineOp(), + accumulator, srcValues[srcIndex]); + // Update the rest of the contiguous elements. + auto lastElement = srcValues[srcIndex]; + if (scanDim > 1) { + for (unsigned i = 0; i < helper.getNumOperands(); ++i) { + lastElement[i] = targetInfo.shuffleUp( + rewriter, loc, srcValues[srcIndex][i], threadStride); + lastElement[i] = select(maskFirstLane, accumulator[i], lastElement[i]); + if (numScanBlocks > 1) + // Update accumulator with the value from the last lane. + accumulator[i] = targetInfo.shuffleIdx( + rewriter, loc, srcValues[srcIndex][i], laneIdLast); + } + } else if (numScanBlocks > 1) { + accumulator = srcValues[srcIndex]; + } + for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + auto laneValue = srcValues[srcIndex - i * elementStride]; + laneValue = + accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); + if (axisBlockId == 0) { + for (unsigned j = 0; j < helper.getNumOperands(); ++j) { + // For the first warp and first chunk we don't have anything to + // accumulate. + laneValue[j] = + select(maskFirstThread, + srcValues[srcIndex - i * elementStride][j], laneValue[j]); + } + } + srcValues[srcIndex - i * elementStride] = laneValue; + } + // For the next chunk start back from the value containing the + // accumulated value of all the warps. + chunkId++; + } +} + +namespace { +struct ScanOpConversion + : public ConvertTritonGPUReduceScanToLLVMPattern { +public: + using ConvertTritonGPUReduceScanToLLVMPattern< + triton::ScanOp>::ConvertTritonGPUReduceScanToLLVMPattern; + explicit ScanOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertTritonGPUReduceScanToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (succeeded(emitFastScan(op, adaptor, rewriter))) + return success(); + return failure(); + } + +private: + const TargetInfoBase &targetInfo; + SmallVector getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value laneId) const; + SmallVector getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value warpId) const; + std::tuple + getDelinearizedIds(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId) const; + LogicalResult emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const; +}; + +SmallVector +ScanOpConversion::getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value laneId) const { + auto loc = helper.getLoc(); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto order = triton::gpu::getOrder(srcEncoding); + return delinearize(rewriter, loc, laneId, threadsPerWarp, order); +} + +SmallVector +ScanOpConversion::getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value warpId) const { + auto loc = helper.getLoc(); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto warpOrder = triton::gpu::getWarpOrder(srcEncoding); + return delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); +} + +// Break up the threadId into lane and warp id along the scan dimension and +// compute a flat id for the parallel dimensions. +std::tuple +ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value laneId, + Value warpId) const { + auto loc = helper.getLoc(); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto order = triton::gpu::getOrder(srcEncoding); + auto warpOrder = triton::gpu::getWarpOrder(srcEncoding); + SmallVector multiDimLaneId = + delinearize(rewriter, loc, laneId, threadsPerWarp, order); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); + + Value laneIdAxis = multiDimLaneId[axis]; + Value warpIdAxis = multiDimWarpId[axis]; + + multiDimLaneId[axis] = i32_val(0); + threadsPerWarp[axis] = 1; + Value laneIdParallel = + linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, order); + multiDimWarpId[axis] = i32_val(0); + warpsPerCTA[axis] = 1; + Value warpIdParallel = + linearize(rewriter, loc, multiDimWarpId, warpsPerCTA, warpOrder); + Value flatIdParallel = + add(laneIdParallel, + mul(warpIdParallel, i32_val(helper.getNonAxisNumThreadsPerWarp()))); + return std::make_tuple(laneIdAxis, warpIdAxis, flatIdParallel); +} + +SmallVector> +unpackInputs(Location loc, triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter) { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getTotalElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto values = unpackLLElements(loc, operands[i], rewriter); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } + } + return srcValues; +} + +// Flip the srcValues. Both reverses the chunks and reverses the lanes. +// Lane reversal is done with a butterfly shuffle flip (divide and flip). +SmallVector> +flipSrcValues(Location loc, triton::ScanOp op, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + SmallVector> srcValues, int iWarpSize) { + SmallVector> values(srcValues.size()); + for (int i = 0; i < srcValues.size(); ++i) { + int revIndex = srcValues.size() - i - 1; + for (unsigned j = 0; j < op.getNumOperands(); ++j) { + for (unsigned k = iWarpSize / 2; k >= 1; k = k / 2) { + srcValues[revIndex][j] = + targetInfo.shuffleXor(rewriter, loc, srcValues[revIndex][j], k); + } + values[i].push_back(srcValues[revIndex][j]); + } + } + return values; +} + +// Lowering using warp shuffle operations to do warp level scan. +LogicalResult +ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + ScanLoweringHelper helper(op); + auto loc = helper.getLoc(); + if (!helper.isSupported()) + return failure(); + + Value threadId = getThreadId(rewriter, loc); + auto mod = op->getParentOfType(); + unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + Value warpSize = i32_val(iWarpSize); + Value warpId = udiv(threadId, warpSize); + Value laneId = urem(threadId, warpSize); + + auto [laneIdAxis, warpIdAxis, flatIdParallel] = + getDelinearizedIds(rewriter, helper, laneId, warpId); + auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + warpIdAxis = urem(warpIdAxis, i32_val(axisNumWarps)); + auto srcValues = + unpackInputs(loc, op, adaptor, rewriter, *getTypeConverter()); + + // For the reverse option we apply flip(scan(flip()) in + // order to avoid having a separate code path in the reverse direction. + // We do this by 1) reversing chunks, 2) reversing lanes, 3) reversing + // warp ids and then undoing this below. + // (Note: Tried pretty hard to get shflDownSync to work but I ended up + // having to add a lot of the complex cross warp code (if rev switch + // first/last etc). Reverse first seems more maintainable.) + if (op.getReverse()) { + warpIdAxis = sub(i32_val(axisNumWarps - 1), warpIdAxis); + srcValues = + flipSrcValues(loc, op, rewriter, targetInfo, srcValues, iWarpSize); + } + + // Scan contiguous elements in a thread and update `srcValues`. + scanThreadContiguousElements(srcValues, rewriter, helper); + // Apply warp level scan to the last element of each chunk of contiguous + // elements. + warpScan(srcValues, rewriter, targetInfo, helper, laneIdAxis); + + if (axisNumWarps > 1) { + // Slow path for the case where there are multiple warps with unique data on + // the axis. + auto elems = helper.getScratchSizeInElems(); + SmallVector smemBases = getSmemBases(op, elems, rewriter); + SmallVector smemTypes(op.getNumOperands()); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + smemTypes[i] = getElementType(op, i); + } + + // Store the partial reducing for each warp into shared memory. + storeWarpAccumulator(srcValues, rewriter, helper, laneIdAxis, warpIdAxis, + smemBases, smemTypes, flatIdParallel, targetInfo); + barrier(); + // Read back the partial reduction of each warp and accumulate them based on + // warpId. Then update each chunk of contiguous elements by adding the + // accumulated value from the previous lane. + AddPartialReduce(srcValues, rewriter, targetInfo, helper, smemBases, + smemTypes, warpIdAxis, laneIdAxis, flatIdParallel); + } else if (srcValues.size() > 1) { + // Fast path for the case where there is only one warp with unique data on + // the axis. + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + auto multiDimLaneId = getMultiDimLaneId(rewriter, helper, laneId); + multiDimLaneId[helper.getAxis()] = i32_val(scanDim - 1); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(helper.getEncoding()); + auto laneIdLast = linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, + triton::gpu::getOrder(helper.getEncoding())); + AddPartialReduceOneWarp(srcValues, rewriter, targetInfo, helper, warpIdAxis, + laneIdAxis, laneIdLast); + } // else axisNumWarps == 1 and srcValues.size() == 1, nothing to do. + + auto transpose = [](const SmallVector> &v) { + assert(v.size() > 0 && v[0].size() > 0); + auto ret = SmallVector>(v[0].size(), + SmallVector(v.size())); + for (int i = 0; i < v.size(); ++i) { + for (int j = 0; j < v[0].size(); ++j) { + ret[j][i] = v[i][j]; + } + } + return ret; + }; + + SmallVector results(op.getNumOperands()); + if (op.getReverse()) { + srcValues = + flipSrcValues(loc, op, rewriter, targetInfo, srcValues, iWarpSize); + } + + auto valuesTransposed = transpose(srcValues); + for (unsigned i = 0; i < op.getNumOperands(); ++i) { + auto resultTy = dyn_cast(op.getResult()[i].getType()); + results[i] = packLLElements(loc, getTypeConverter(), valuesTransposed[i], + rewriter, resultTy); + } + rewriter.replaceOp(op, results); + return success(); +} +} // namespace + +void mlir::triton::populateScanOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp new file mode 100644 index 000000000..908aa1e2b --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -0,0 +1,137 @@ +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getTotalElemsPerThread; +using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::SharedEncodingAttr; +using ::mlir::triton::gpu::SliceEncodingAttr; + +TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter( + MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis) + : LLVMTypeConverter(ctx, option, analysis) { + addConversion([&](triton::PointerType type) -> std::optional { + return convertTritonPointerType(type); + }); + addConversion([&](RankedTensorType type) -> std::optional { + return convertTritonTensorType(type); + }); + addConversion([&](MemDescType type) -> std::optional { + return convertMemDescType(type); + }); + addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional { + return convertAsyncToken(type); + }); + addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }); + addConversion([&](mlir::Float8E5M2Type type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }); + addConversion([&](mlir::Float8E5M2FNUZType type) -> std::optional { + return IntegerType::get(type.getContext(), 8); + }); + // Internally store bfloat16 as int16 + addConversion([&](BFloat16Type type) -> std::optional { + return IntegerType::get(type.getContext(), 16); + }); +} + +Type TritonGPUToLLVMTypeConverter::convertTritonPointerType( + triton::PointerType type) { + auto ctx = type.getContext(); + auto pointeeType = type.getPointeeType(); + if (isa(pointeeType)) { + auto rankedTensorType = cast(pointeeType); + // struct { offset0, offset1, shape0, shape1, stride0, + // stride1, base_ptr}; + auto eleType = rankedTensorType.getElementType(); + auto shape = rankedTensorType.getShape(); + SmallVector types; + // offsets + for (size_t i = 0; i < shape.size(); ++i) + types.push_back(IntegerType::get(ctx, 32)); + // shapes, strides + for (size_t i = 0; i < 2 * shape.size(); ++i) + types.push_back(IntegerType::get(ctx, 64)); + + types.push_back(LLVM::LLVMPointerType::get(ctx, type.getAddressSpace())); + + return LLVM::LLVMStructType::getLiteral(ctx, types); + } + return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); +} + +Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct( + TensorOrMemDesc type) { + auto ctx = type.getContext(); + Attribute layout = type.getEncoding(); + Type elemTy = convertType(type.getElementType()); + auto dotOpLayout = mlir::dyn_cast(layout); + if (!dotOpLayout) + return elemTy; + auto mmaParent = + mlir::dyn_cast(dotOpLayout.getParent()); + if (!mmaParent || mmaParent.isHopper()) + return elemTy; + int bitwidth = elemTy.getIntOrFloatBitWidth(); + assert(bitwidth <= 32); + return IntegerType::get(ctx, 32); +} + +Type TritonGPUToLLVMTypeConverter::convertTritonTensorType( + RankedTensorType type) { + auto ctx = type.getContext(); + Attribute layout = type.getEncoding(); + SmallVector shape(type.getShape().begin(), type.getShape().end()); + Type eltType = getElementTypeForStruct(cast(type)); + + if (auto shared_layout = mlir::dyn_cast(layout)) { + SmallVector types; + // base ptr + auto ptrType = LLVM::LLVMPointerType::get(ctx, 3); + types.push_back(ptrType); + // shape dims + auto rank = type.getRank(); + // offsets + strides + for (auto i = 0; i < rank * 2; i++) { + types.push_back(IntegerType::get(ctx, 32)); + } + return LLVM::LLVMStructType::getLiteral(ctx, types); + } + + unsigned numElementsPerThread = getTotalElemsPerThread(type); + SmallVector types(numElementsPerThread, eltType); + return LLVM::LLVMStructType::getLiteral(ctx, types); +} + +Type TritonGPUToLLVMTypeConverter::convertMemDescType(MemDescType type) { + auto ctx = type.getContext(); + Attribute layout = type.getEncoding(); + SmallVector shape(type.getShape().begin(), type.getShape().end()); + SmallVector types; + // base ptr + auto ptrType = LLVM::LLVMPointerType::get(ctx, 3); + types.push_back(ptrType); + // shape dims + auto rank = type.getShape().size(); + // offsets + strides + for (auto i = 0; i < rank * 2; i++) { + types.push_back(IntegerType::get(ctx, 32)); + } + return LLVM::LLVMStructType::getLiteral(ctx, types); +} + +Type TritonGPUToLLVMTypeConverter::convertAsyncToken( + triton::gpu::AsyncTokenType type) { + return IntegerType::get(type.getContext(), 32); +} diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/Utility.cpp new file mode 100644 index 000000000..a80158a46 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -0,0 +1,619 @@ +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "llvm/ADT/STLExtras.h" + +namespace SharedToDotOperandMMAv1 { +using CoordTy = SmallVector; +using ValueTable = std::map, std::pair>; + +static SmallVector +getMNCoords(Value thread, Location loc, ConversionPatternRewriter &rewriter, + ArrayRef wpt, const NvidiaMmaEncodingAttr &mmaLayout, + ArrayRef shape, bool isARow, bool isBRow, bool isAVec4, + bool isBVec4) { + static constexpr std::array fpw{{2, 2, 1}}; + + auto *ctx = thread.getContext(); + Value _1 = i32_val(1); + Value _2 = i32_val(2); + Value _4 = i32_val(4); + Value _16 = i32_val(16); + Value _32 = i32_val(32); + Value _fpw0 = i32_val(fpw[0]); + Value _fpw1 = i32_val(fpw[1]); + + // A info + auto aRep = mmaLayout.getMMAv1Rep(0); + auto aSpw = mmaLayout.getMMAv1ShapePerWarp(0); + // B info + auto bSpw = mmaLayout.getMMAv1ShapePerWarp(1); + auto bRep = mmaLayout.getMMAv1Rep(1); + + SmallVector rep({aRep[0], bRep[1]}); + SmallVector spw({aSpw[0], bSpw[1]}); + SmallVector shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]}); + + Value lane = urem(thread, _32); + Value warp = udiv(thread, _32); + + Value warp0 = urem(warp, i32_val(wpt[0])); + Value warp12 = udiv(warp, i32_val(wpt[0])); + Value warp1 = urem(warp12, i32_val(wpt[1])); + + // warp offset + Value offWarpM = mul(warp0, i32_val(spw[0])); + Value offWarpN = mul(warp1, i32_val(spw[1])); + // quad offset + Value offQuadM = mul(udiv(and_(lane, _16), _4), _fpw0); + Value offQuadN = mul(udiv(and_(lane, _16), _4), _fpw1); + // pair offset + Value offPairM = udiv(urem(lane, _16), _4); + offPairM = urem(offPairM, _fpw0); + offPairM = mul(offPairM, _4); + Value offPairN = udiv(urem(lane, _16), _4); + offPairN = udiv(offPairN, _fpw0); + offPairN = urem(offPairN, _fpw1); + offPairN = mul(offPairN, _4); + + // sclare + offPairM = mul(offPairM, i32_val(rep[0] / 2)); + offQuadM = mul(offQuadM, i32_val(rep[0] / 2)); + offPairN = mul(offPairN, i32_val(rep[1] / 2)); + offQuadN = mul(offQuadN, i32_val(rep[1] / 2)); + + // quad pair offset + Value offLaneM = add(offPairM, offQuadM); + Value offLaneN = add(offPairN, offQuadN); + // a, b offset + Value offsetAM = add(offWarpM, offLaneM); + Value offsetBN = add(offWarpN, offLaneN); + // m indices + Value offsetCM = add(and_(lane, _1), offsetAM); + SmallVector idxM; + for (unsigned m = 0; m < shape[0]; m += shapePerCTA[0]) + for (unsigned mm = 0; mm < rep[0]; ++mm) + idxM.push_back(add(offsetCM, i32_val(m + mm * 2))); + + // n indices + Value offsetCN = add((and_(lane, _2)), (add(offWarpN, offPairN))); + SmallVector idxN; + for (int n = 0; n < shape[1]; n += shapePerCTA[1]) { + for (int nn = 0; nn < rep[1]; ++nn) { + idxN.push_back(add( + offsetCN, i32_val(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1]))); + idxN.push_back( + add(offsetCN, + i32_val(n + nn / 2 * 4 + (nn % 2) * 2 * fpw[1] * rep[1] + 1))); + } + } + + SmallVector> axes({idxM, idxN}); + + // product the axis M and axis N to get coords, ported from + // generator::init_idx method from triton2.0 + + // TODO[Superjomn]: check the order. + SmallVector coords; + for (Value x1 : axes[1]) { // N + for (Value x0 : axes[0]) { // M + SmallVector idx(2); + idx[0] = x0; // M + idx[1] = x1; // N + coords.push_back(std::move(idx)); + } + } + + return coords; // {M,N} in row-major +} +} // namespace SharedToDotOperandMMAv1 +namespace mlir { + +namespace triton::gpu { +Type getFunctionType(Type resultType, ValueRange operands) { + SmallVector operandTypes(operands.getTypes()); + return LLVM::LLVMFunctionType::get(resultType, operandTypes); +} + +LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter, + Operation *op, StringRef funcName, + Type funcType, + StringRef libname /*= ""*/, + StringRef libpath /*= ""*/) { + using LLVM::LLVMFuncOp; + + auto funcAttr = StringAttr::get(op->getContext(), funcName); + Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + if (funcOp) + return cast(*funcOp); + + Operation *parent = op; + if (!isa(op)) + parent = op->getParentOfType(); + OpBuilder b(parent); + auto ret = b.create(op->getLoc(), funcName, funcType); + ret.getOperation()->setAttr("libname", + StringAttr::get(op->getContext(), libname)); + ret.getOperation()->setAttr("libpath", + StringAttr::get(op->getContext(), libpath)); + return ret; +} +} // namespace triton::gpu + +SmallVector> +applyLinearLayout(Location loc, RewriterBase &rewriter, + const LinearLayout &layout, + ArrayRef> indices) { + assert(layout.getNumInDims() == indices.size()); + for (auto [inDimName, idx] : indices) { + assert(layout.hasInDim(inDimName) && "Invalid inDimName"); + } + + // This function can emit a lot of MLIR code, which ultimately makes + // compilation slow. (We think this shouldn't be the case -- it's not *that* + // much code -- but we're not clear on how to fix the slowness, which happens + // in the bowels of MLIR.) + // + // As a result we go through some contortions to avoid emitting code where + // possible. + + // Manually constant-fold the layout where possible. + SmallVector> constantIns; + for (auto [inDimName, idx] : indices) { + if (auto constant = dyn_cast(idx.getDefiningOp())) { + constantIns.push_back( + {inDimName, cast(constant.getValue()).getInt()}); + } else { + constantIns.push_back({inDimName, 0}); + } + } + SmallVector constantComponent = + llvm::to_vector(llvm::make_second_range(layout.apply(constantIns))); + + Value zero = i32_val(0); + SmallVector> outIndices; + for (auto [i, outDimName] : llvm::enumerate(layout.getOutDimNames())) { + if (constantComponent[i] == 0) + outIndices.push_back({outDimName, zero}); + else + outIndices.push_back({outDimName, i32_val(constantComponent[i])}); + } + + for (auto [inDimName, idx] : indices) { + if (isa(idx.getDefiningOp())) { + continue; + } + + int nBits = layout.getInDimSizeLog2(inDimName); + for (int i = 0; i < nBits; i++) { + Value bit = and_(idx, i32_val(1 << i)); + Value bit_is_zero = icmp_eq(bit, zero); + for (auto &[outDimName, outIdx] : outIndices) { + int32_t basis = layout.getBasis(inDimName, i, outDimName); + if (basis == 0) + continue; + outIdx = xor_(outIdx, select(bit_is_zero, zero, i32_val(basis))); + } + } + } + + return outIndices; +} + +std::optional>> +emitIndicesUsingLinearLayouts(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, bool withCTAOffset) { + MLIRContext *ctx = rewriter.getContext(); + auto shape = type.getShape(); + + std::optional ll = triton::gpu::toLinearLayout(shape, layout); + if (!ll.has_value()) { + return std::nullopt; + } + + // TODO(jlebar): We could add strong typing if we wanted; for now this is + // "stringly typed". + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + + Value threadId = getThreadId(rewriter, loc); + Value threadsPerWarp = i32_val(ll->getInDimSize(kLane)); + Value laneId = urem(threadId, threadsPerWarp); + Value warpId = udiv(threadId, threadsPerWarp); + Value blockId = + withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0); + unsigned rank = shape.size(); + SmallVector> ret; + for (unsigned reg = 0; reg < ll->getInDimSize(str_attr("register")); reg++) { + auto idxs = applyLinearLayout(loc, rewriter, *ll, + {{kRegister, i32_val(reg)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, blockId}}); + assert(idxs.size() == rank); + for (unsigned k = 0; k < rank; ++k) { + assert(idxs[k].first == str_attr("dim" + std::to_string(k))); + } + ret.push_back(llvm::to_vector(llvm::make_second_range(idxs))); + } + + return ret; +} + +namespace LLVM { +using namespace mlir::triton; +using mlir::triton::gpu::getOrder; +using mlir::triton::gpu::getSizePerThread; + +Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) { + auto i32ty = rewriter.getIntegerType(32); + return rewriter.create(loc, i32ty, + IntegerAttr::get(i32ty, v)); +} + +Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v) { + auto i64ty = rewriter.getIntegerType(64); + return rewriter.create(loc, i64ty, + IntegerAttr::get(i64ty, v)); +} + +Value createConstantF16(Location loc, OpBuilder &rewriter, float v) { + auto type = type::f16Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF16FloatAttr(v)); +} + +Value createConstantF32(Location loc, OpBuilder &rewriter, float v) { + auto type = type::f32Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF32FloatAttr(v)); +} + +Value createConstantF64(Location loc, OpBuilder &rewriter, double v) { + auto type = type::f64Ty(rewriter.getContext()); + return rewriter.create(loc, type, + rewriter.getF64FloatAttr(v)); +} + +Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type) { + if (!isa(type)) { + llvm::report_fatal_error("Creating NaN constant for non-float type!"); + } + return rewriter.create( + loc, type, APFloat::getNaN(cast(type).getFloatSemantics())); +} + +// Create an index type constant. +Value createIndexConstant(OpBuilder &builder, Location loc, + const TypeConverter *converter, int64_t value) { + Type ty = converter->convertType(builder.getIndexType()); + return builder.create(loc, ty, + builder.getIntegerAttr(ty, value)); +} + +// Create an integer constant of \param width bits. +Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, + int64_t value) { + Type ty = builder.getIntegerType(width); + return builder.create(loc, ty, + builder.getIntegerAttr(ty, value)); +} + +SharedMemoryObject +getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, + ConversionPatternRewriter &rewriter) { + ArrayRef types = + cast(llvmStruct.getType()).getBody(); + SmallVector elems(types.size()); + for (unsigned i = 0; i < types.size(); ++i) { + Type type = types[i]; + elems[i] = extract_val(type, llvmStruct, i); + } + + auto rank = (elems.size() - 1) / 2; + return {/*base=*/elems[0], + /*baseElemType=*/elemTy, + /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank}, + /*offsets=*/{elems.begin() + 1 + rank, elems.end()}}; +} + +SmallVector getStridesFromShapeAndOrder(ArrayRef shape, + ArrayRef order, + Location loc, + RewriterBase &rewriter) { + auto rank = shape.size(); + SmallVector strides(rank); + int64_t stride = 1; + for (auto idx : order) { + strides[idx] = i32_val(stride); + stride *= shape[idx]; + } + return strides; +} + +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape, + ArrayRef order) { + unsigned rank = shape.size(); + assert(rank == order.size()); + auto reordered = applyPermutation(shape, order); + SmallVector reorderedMultiDim(rank); + if (auto constantOp = linear.getDefiningOp()) { + unsigned intVal = mlir::cast(constantOp.getValue()) + .getValue() + .getSExtValue(); + reorderedMultiDim = delinearize(rewriter, loc, intVal, reordered); + } else { + reorderedMultiDim = delinearize(rewriter, loc, linear, reordered); + } + SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + unsigned linear, ArrayRef shape) { + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + unsigned remained = linear; + for (auto &&en : llvm::enumerate(shape)) { + unsigned dimSize = en.value(); + multiDim[en.index()] = i32_val(remained % dimSize); + remained = remained / dimSize; + } + return multiDim; +} + +SmallVector delinearize(RewriterBase &rewriter, Location loc, + Value linear, ArrayRef shape) { + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + Value remained = linear; + for (auto &&en : llvm::enumerate(shape)) { + Value dimSize = i32_val(en.value()); + multiDim[en.index()] = urem(remained, dimSize); + remained = udiv(remained, dimSize); + } + return multiDim; +} + +Value linearize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDim, ArrayRef shape, + ArrayRef order) { + return linearize(rewriter, loc, applyPermutation(multiDim, order), + applyPermutation(shape, order)); +} + +Value linearize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDim, ArrayRef shape) { + auto rank = multiDim.size(); + Value linear = i32_val(0); + if (rank > 0) { + linear = multiDim.back(); + for (auto [dim, dimShape] : + llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { + Value dimSize = i32_val(dimShape); + linear = add(mul(linear, dimSize), dim); + } + } + return linear; +} + +Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, + StringRef key, StringRef content) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + auto ctx = moduleOp.getContext(); + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (key + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + + llvm::SmallString<64> contentStr(content); + size_t contentSize = contentStr.size_in_bytes(); + auto globalType = LLVM::LLVMArrayType::get(i8_ty, contentSize); + + LLVM::GlobalOp global; + { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + global = rewriter.create( + UnknownLoc::get(ctx), globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, + rewriter.getStringAttr(contentStr)); + } + + Value zero = i32_val(0); + Type globalPtrType = LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()); + Value globalPtr = rewriter.create( + UnknownLoc::get(ctx), globalPtrType, global.getSymName()); + Value stringStart = + gep(ptr_ty(ctx), i8_ty, globalPtr, SmallVector({zero})); + return stringStart; +} + +SmallVector getMultiDimOffset(Attribute layout, Location loc, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, + unsigned elemId, RankedTensorType type, + ArrayRef multiDimCTAInRepId, + ArrayRef shapePerCTATile) { + auto shape = type.getShape(); + unsigned rank = shape.size(); + if (auto blockedLayout = dyn_cast(layout)) { + auto multiDimOffsetFirstElem = emitBaseIndexForLayout( + loc, rewriter, targetInfo, blockedLayout, type, false); + SmallVector multiDimOffset(rank); + SmallVector multiDimElemId = getMultiDimIndex( + elemId, getSizePerThread(layout), getOrder(layout)); + for (unsigned d = 0; d < rank; ++d) { + multiDimOffset[d] = + add(multiDimOffsetFirstElem[d], + i32_val(multiDimCTAInRepId[d] * shapePerCTATile[d] + + multiDimElemId[d])); + } + return multiDimOffset; + } + if (auto sliceLayout = mlir::dyn_cast(layout)) { + unsigned dim = sliceLayout.getDim(); + auto parentEncoding = sliceLayout.getParent(); + auto parentSizePerThread = getSizePerThread(parentEncoding); + auto parentShape = sliceLayout.paddedShape(shape); + auto parentTy = RankedTensorType::get(parentShape, type.getElementType(), + parentEncoding); + auto offsets = emitOffsetForLayout(layout, type); + auto parentOffset = emitOffsetForLayout(parentEncoding, parentTy); + SmallVector idxs; + for (SmallVector off : offsets) { + off.insert(off.begin() + dim, 0); + auto it = std::find(parentOffset.begin(), parentOffset.end(), off); + idxs.push_back(std::distance(parentOffset.begin(), it)); + } + auto multiDimOffsetParent = getMultiDimOffset( + parentEncoding, loc, rewriter, targetInfo, idxs[elemId], parentTy, + sliceLayout.paddedShape(multiDimCTAInRepId), + sliceLayout.paddedShape(shapePerCTATile)); + SmallVector multiDimOffset(rank); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d == dim) + continue; + unsigned slicedD = d < dim ? d : (d - 1); + multiDimOffset[slicedD] = multiDimOffsetParent[d]; + } + return multiDimOffset; + } + if (auto mmaLayout = mlir::dyn_cast(layout)) { + assert(rank == 2 || + (rank == 3 && mmaLayout.isAmpere()) && "Unexpected rank"); + auto shapePerCTA = getShapePerCTA(mmaLayout, shape); + auto instrShape = mmaLayout.getInstrShape(); + SmallVector mmaColIdx(2); + SmallVector mmaRowIdx(2); + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(32); + Value laneId = urem(threadId, warpSize); + Value warpId = udiv(threadId, warpSize); + // TODO: fix the bug in MMAEncodingAttr document + SmallVector multiDimWarpId(2); + auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); + auto warpOrder = triton::gpu::getWarpOrder(mmaLayout); + multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); + Value _1 = i32_val(1); + Value _2 = i32_val(2); + Value _4 = i32_val(4); + Value _8 = i32_val(8); + Value _16 = i32_val(16); + if (mmaLayout.isAmpere() || mmaLayout.isHopper()) { + multiDimWarpId[rank - 1] = urem( + multiDimWarpId[rank - 1], + i32_val(ceil(shapePerCTA[rank - 1], instrShape[rank - 1]))); + multiDimWarpId[rank - 2] = urem( + multiDimWarpId[rank - 2], + i32_val(ceil(shapePerCTA[rank - 2], instrShape[rank - 2]))); + + Value mmaGrpId = udiv(laneId, _4); + Value mmaGrpIdP8 = add(mmaGrpId, _8); + Value mmaThreadIdInGrp = urem(laneId, _4); + Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2); + Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1); + Value rowWarpOffset = + mul(multiDimWarpId[rank - 2], i32_val(instrShape[rank - 2])); + mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset); + mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset); + Value colWarpOffset = + mul(multiDimWarpId[rank - 1], i32_val(instrShape[rank - 1])); + mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset); + mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset); + } else if (mmaLayout.isVolta()) { + // Volta doesn't follow the pattern here. + } else { + llvm_unreachable("Unexpected MMALayout version"); + } + + SmallVector multiDimOffset(rank); + if (mmaLayout.isHopper()) { + unsigned elemIdRem4 = elemId % 4; + unsigned nGrpId = elemId / 4; + multiDimOffset[0] = elemIdRem4 < 2 ? mmaRowIdx[0] : mmaRowIdx[1]; + multiDimOffset[1] = elemIdRem4 % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1]; + multiDimOffset[1] = add(multiDimOffset[1], i32_val(8 * nGrpId)); + multiDimOffset[0] = add(multiDimOffset[0], i32_val(multiDimCTAInRepId[0] * + shapePerCTATile[0])); + multiDimOffset[1] = add(multiDimOffset[1], i32_val(multiDimCTAInRepId[1] * + shapePerCTATile[1])); + } else if (mmaLayout.isAmpere()) { + if (rank == 3) + multiDimOffset[0] = + add(multiDimWarpId[0], + i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0])); + multiDimOffset[rank - 2] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1]; + multiDimOffset[rank - 1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1]; + multiDimOffset[rank - 2] = + add(multiDimOffset[rank - 2], i32_val(multiDimCTAInRepId[rank - 2] * + shapePerCTATile[rank - 2])); + multiDimOffset[rank - 1] = + add(multiDimOffset[rank - 1], i32_val(multiDimCTAInRepId[rank - 1] * + shapePerCTATile[rank - 1])); + } else if (mmaLayout.isVolta()) { + auto [isARow, isBRow, isAVec4, isBVec4, _] = + mmaLayout.decodeVoltaLayoutStates(); + auto coords = SharedToDotOperandMMAv1::getMNCoords( + threadId, loc, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout, shape, + isARow, isBRow, isAVec4, isBVec4); + return coords[elemId]; + } else { + llvm_unreachable("Unexpected MMALayout version"); + } + return multiDimOffset; + } + if (isa(layout)) { + auto multiDimBase = + emitBaseIndexForLayout(loc, rewriter, targetInfo, layout, type, false); + SmallVector> offsets; + assert(rank == 2); + SmallVector multiDimOffset(rank); + if (auto mfmaLayout = dyn_cast(layout)) { + emitMfmaOffsetForCTA(mfmaLayout, offsets, 0, multiDimCTAInRepId[0], + multiDimCTAInRepId[1]); + } else if (auto wmmaLayout = dyn_cast(layout)) { + emitWmmaOffsetForCTA(wmmaLayout, offsets, 0, multiDimCTAInRepId[0], + multiDimCTAInRepId[1]); + } + multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0])); + multiDimOffset[1] = add(multiDimBase[1], i32_val(offsets[elemId][1])); + return multiDimOffset; + } + llvm_unreachable("unexpected layout in getMultiDimOffset"); +} + +SmallVector getWrappedMultiDimOffset( + ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDimOffset, ArrayRef shape, + SmallVector shapePerCTATile, SmallVector shapePerCTA) { + unsigned rank = shape.size(); + SmallVector multiDimOffsetWrapped(rank); + for (unsigned d = 0; d < rank; ++d) { + if (shapePerCTATile[d] > shapePerCTA[d]) + multiDimOffsetWrapped[d] = urem(multiDimOffset[d], i32_val(shape[d])); + else + multiDimOffsetWrapped[d] = multiDimOffset[d]; + } + return multiDimOffsetWrapped; +} + +} // namespace LLVM +} // namespace mlir diff --git a/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp new file mode 100644 index 000000000..e0f6e9377 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -0,0 +1,398 @@ +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +namespace { +struct SplatOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a + // LLVM::StructType value. + // + // @elemType: the element type in operand. + // @resType: the return type of the Splat-like op. + // @constVal: a LLVM::ConstantOp or other scalar value. + static Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + Location loc) { + auto tensorTy = cast(resType); + // Check the converted type for the tensor as depending on the encoding the + // converter may pick different element types. + auto srcType = typeConverter->convertType(tensorTy); + if (auto structTy = dyn_cast(srcType)) + srcType = structTy.getBody()[0]; + // If the type sizes don't match we need to pack constants. + if (srcType.isIntOrFloat() && constVal.getType().getIntOrFloatBitWidth() != + srcType.getIntOrFloatBitWidth()) { + unsigned cstBitWidth = constVal.getType().getIntOrFloatBitWidth(); + unsigned srcBitWidth = srcType.getIntOrFloatBitWidth(); + assert(cstBitWidth <= srcBitWidth && srcBitWidth % cstBitWidth == 0); + unsigned ratio = srcBitWidth / cstBitWidth; + Type intTy = IntegerType::get(elemType.getContext(), cstBitWidth); + VectorType vecType = VectorType::get(ratio, intTy); + Value intCst = bitcast(constVal, intTy); + Value vec = undef(vecType); + for (unsigned i = 0; i < ratio; ++i) + vec = insert_element(vecType, vec, intCst, int_val(32, i)); + constVal = vec; + } + auto llSrc = bitcast(constVal, srcType); + size_t elemsPerThread = getTotalElemsPerThread(tensorTy); + llvm::SmallVector elems(elemsPerThread, llSrc); + return packLLElements(loc, typeConverter, elems, rewriter, resType); + } + LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op->getLoc(); + auto src = adaptor.getSrc(); + auto typeConverter = getTypeConverter(); + auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src, + typeConverter, rewriter, loc); + rewriter.replaceOp(op, {llStruct}); + return success(); + } +}; +// This pattern helps to convert arith::ConstantOp(with SplatElementsAttr), +// the logic is the same as triton::SplatOp, so the underlying implementation +// is reused. +struct ArithConstantSplatOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto value = op.getValue(); + if (!mlir::dyn_cast(value)) + return failure(); + auto loc = op->getLoc(); + LLVM::ConstantOp arithConstantOp; + auto values = mlir::dyn_cast(op.getValue()); + auto elemType = values.getElementType(); + Attribute val; + if (type::isFloat(elemType)) { + val = values.getValues()[0]; + } else if (type::isInt(elemType)) { + val = values.getValues()[0]; + } else { + llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: " + << value.getType() << "\n"; + return failure(); + } + auto constOp = rewriter.create(loc, elemType, val); + auto typeConverter = getTypeConverter(); + auto llStruct = SplatOpConversion::convertSplatLikeOp( + elemType, op.getType(), constOp, typeConverter, rewriter, loc); + rewriter.replaceOp(op, llStruct); + return success(); + } +}; +struct CatOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename CatOp::Adaptor; + explicit CatOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + unsigned elems = getTotalElemsPerThread(resultTy); + auto typeConverter = getTypeConverter(); + Type elemTy = typeConverter->convertType(resultTy.getElementType()); + SmallVector types(elems, elemTy); + // unpack input values + auto lhsVals = unpackLLElements(loc, adaptor.getLhs(), rewriter); + auto rhsVals = unpackLLElements(loc, adaptor.getRhs(), rewriter); + // concatenate (and potentially reorder) values + SmallVector retVals; + for (Value v : lhsVals) + retVals.push_back(v); + for (Value v : rhsVals) + retVals.push_back(v); + // pack and replace + Value ret = packLLElements(loc, typeConverter, retVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct JoinOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename JoinOp::Adaptor; + explicit JoinOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // We rely on the following invariants of this op (which are checked by its + // verifier): + // + // - The op has a blocked encoding. + // - The last dimension (the one we're joining) is also the most minor + // dimension. + // - The input and output encodings are the same, except the output has + // 2 elements per thread in the last dim. + // + // With these invariants, join is trivial: We just return the i'th element + // from lhs, followed by the i'th elem from rhs. + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + auto typeConverter = getTypeConverter(); + SmallVector lhsVals = + unpackLLElements(loc, adaptor.getLhs(), rewriter); + SmallVector rhsVals = + unpackLLElements(loc, adaptor.getRhs(), rewriter); + assert(lhsVals.size() == rhsVals.size()); + SmallVector joinedVals; + for (int i = 0; i < lhsVals.size(); i++) { + joinedVals.push_back(lhsVals[i]); + joinedVals.push_back(rhsVals[i]); + } + Value ret = + packLLElements(loc, typeConverter, joinedVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct SplitOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename SplitOp::Adaptor; + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // We rely on the following invariants of this op (which are checked by its + // verifier): + // + // - The op has a blocked encoding. + // - The last dimension (the one we're spliting) is also the most minor + // dimension, and has sizePerThread=2. + // + // With these invariants, split is trivial: Every other value goes into + // return value 0, and every other goes into return value 1. + Location loc = op->getLoc(); + auto typeConverter = getTypeConverter(); + SmallVector srcVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + assert(srcVals.size() % 2 == 0); + SmallVector outLhsVals; + SmallVector outRhsVals; + for (int i = 0; i < srcVals.size(); i += 2) { + outLhsVals.push_back(srcVals[i]); + outRhsVals.push_back(srcVals[i + 1]); + } + auto resultTy = cast(op.getResult(0).getType()); + Value retLhs = + packLLElements(loc, typeConverter, outLhsVals, rewriter, resultTy); + Value retRhs = + packLLElements(loc, typeConverter, outRhsVals, rewriter, resultTy); + rewriter.replaceOp(op, {retLhs, retRhs}); + return success(); + } +}; +struct ReshapeOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename ReshapeOp::Adaptor; + explicit ReshapeOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + if (triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType())) { + return emitOptionalError(loc, + "expensive view not supported on reshape op"); + } + auto resultTy = cast(op.getType()); + auto srcTy = cast(op.getSrc().getType()); + auto typeConverter = getTypeConverter(); + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + Value ret = packLLElements(loc, typeConverter, vals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct ExpandDimsOpConversion : public ConvertOpToLLVMPattern { + using OpAdaptor = typename ExpandDimsOp::Adaptor; + explicit ExpandDimsOpConversion( + LLVMTypeConverter &typeConverter, + PatternBenefit benefit = patternBenefitDefault) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + LogicalResult + matchAndRewrite(ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto typeConverter = getTypeConverter(); + auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto srcTy = cast(op.getSrc().getType()); + auto resultTy = cast(op.getType()); + auto srcLayout = dyn_cast(srcTy.getEncoding()); + if (!srcLayout) { + return emitOptionalError( + loc, "ExpandDimsOp only supports SliceEncodingAttr as its input"); + } + auto resultLayout = resultTy.getEncoding(); + auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); + auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); + std::map, Value> srcValues; + for (size_t i = 0; i < srcOffsets.size(); i++) { + srcValues[srcOffsets[i]] = srcVals[i]; + } + SmallVector resultVals; + for (size_t i = 0; i < resultOffsets.size(); i++) { + auto offset = resultOffsets[i]; + offset.erase(offset.begin() + srcLayout.getDim()); + resultVals.push_back(srcValues.at(offset)); + } + Value ret = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; +struct TransOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getType()); + if (auto enc = dyn_cast(resultTy.getEncoding())) { + auto llvmElemTy = + getTypeConverter()->convertType(resultTy.getElementType()); + auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto dstSmemObj = SharedMemoryObject( + srcSmemObj.base, srcSmemObj.baseElemType, + /*strides=*/applyPermutation(srcSmemObj.strides, op.getOrder()), + /*offsets=*/applyPermutation(srcSmemObj.offsets, op.getOrder())); + auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } else if (auto enc = mlir::dyn_cast( + resultTy.getEncoding())) { + // If the dst encoding is blocked, then TransOp::inferReturnTypes + // ensures that: + // - the src encoding is also blocked, and + // - the translation from src to dst is just a "renaming" of the + // registers, i.e. each thread has exactly the same values. + // Thus the transpose op simply returns the same values it got. + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + Value ret = packLLElements(loc, this->getTypeConverter(), vals, rewriter, + resultTy); + rewriter.replaceOp(op, ret); + return success(); + } + return emitOptionalError(loc, "unsupported encoding for TransOp"); + } +}; +struct BroadcastOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Following the order of indices in the legacy code, a broadcast of: + // [s(0), s(1) ... s(k-1), 1, s(k+1), s(k+2) ... s(n-1)] + // => + // [s(0), s(1) ... s(k-1), s(k), s(k+1), s(k+2) ... s(n-1)] + // + // logically maps to a broadcast within a thread's scope: + // [cta(0)..cta(k-1), 1,cta(k+1)..cta(n-1),spt(0)..spt(k-1), + // 1,spt(k+1)..spt(n-1)] + // => + // [cta(0)..cta(k-1),cta(k),cta(k+1)..cta(n-1),spt(0)..spt(k-1),spt(k),spt(k+1)..spt(n-1)] + // + // regardless of the order of the layout + // + Location loc = op->getLoc(); + Value src = adaptor.getSrc(); + Value result = op.getResult(); + auto srcTy = cast(op.getSrc().getType()); + auto resultTy = cast(result.getType()); + auto srcLayout = srcTy.getEncoding(); + auto resultLayout = resultTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto resultShape = resultTy.getShape(); + unsigned rank = srcTy.getRank(); + auto typeConverter = getTypeConverter(); + assert(rank == resultTy.getRank()); + auto order = triton::gpu::getOrder(srcLayout); + auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); + auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); + SmallVector srcVals = unpackLLElements(loc, src, rewriter); + std::map, Value> srcValues; + for (size_t i = 0; i < srcOffsets.size(); i++) { + srcValues[srcOffsets[i]] = srcVals[i]; + } + SmallVector resultVals; + for (size_t i = 0; i < resultOffsets.size(); i++) { + auto offset = resultOffsets[i]; + for (size_t j = 0; j < srcShape.size(); j++) + if (srcShape[j] == 1) + offset[j] = 0; + resultVals.push_back(srcValues.at(offset)); + } + Value resultStruct = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct MemDescSubviewOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::gpu::MemDescSubviewOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // %dst = extract_slice %src[%offsets] + Location loc = op->getLoc(); + auto srcTy = op.getSrc().getType(); + auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType()); + + // newBase = base + offset + auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + SmallVector opOffsetVals = op.getOffsets(); + size_t destRank = op.getResult().getType().getRank(); + SmallVector offsetVals; + SmallVector strides; + int rankReduced = srcTy.getRank() - destRank; + for (int i = rankReduced; i < opOffsetVals.size(); i++) { + strides.push_back(smemObj.strides[i]); + offsetVals.push_back(opOffsetVals[i]); + } + // Compute the offset based on the original strides of the shared memory + // object + auto offset = dot(rewriter, loc, opOffsetVals, smemObj.strides); + auto elemPtrTy = ptr_ty(rewriter.getContext(), 3); + smemObj = + SharedMemoryObject(gep(elemPtrTy, llvmElemTy, smemObj.base, offset), + llvmElemTy, strides, offsetVals); + auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + rewriter.replaceOp(op, retVal); + return success(); + } +}; +} // namespace + +void mlir::triton::populateViewOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/third_party/xpu/lib/Conversion/TritonToTritonGPU/CMakeLists.txt new file mode 100644 index 000000000..1b629ba16 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(TritonToTritonGPU + TritonGPUConversion.cpp + TritonToTritonGPUPass.cpp + + DEPENDS + TritonConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTransforms + TritonIR + TritonGPUIR + TritonGPUTransforms +) diff --git a/third_party/xpu/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/third_party/xpu/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp new file mode 100644 index 000000000..34fb89954 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -0,0 +1,123 @@ +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" + +#include +#include + +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +using namespace mlir::triton::gpu; + +// +// TypeConverter +// +TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, + int numWarps, int threadsPerWarp, + int numCTAs) + : context(context), numWarps(numWarps), threadsPerWarp(threadsPerWarp), + numCTAs(numCTAs) { + addConversion([](Type type) { return type; }); + + // Add encoding for tensor + addConversion([this](RankedTensorType tensorType) -> RankedTensorType { + // types with encoding are already in the right format + // TODO: check for layout encodings more specifically + if (tensorType.getEncoding()) + return tensorType; + ArrayRef shape = tensorType.getShape(); + triton::gpu::BlockedEncodingAttr encoding = + getDefaultBlockedEncoding(this->context, shape, this->numWarps, + this->threadsPerWarp, this->numCTAs); + return RankedTensorType::get(shape, tensorType.getElementType(), encoding); + }); + + // Add encoding for tensor pointer + addConversion([this](triton::PointerType ptrType) -> triton::PointerType { + // Check whether tensor pointer `tt.ptr>` + auto pointeeTensorType = + dyn_cast(ptrType.getPointeeType()); + if (pointeeTensorType == nullptr) + return ptrType; + + // Add layout into the tensor + auto convertedTensorType = convertType(pointeeTensorType); + return triton::PointerType::get(convertedTensorType, + ptrType.getAddressSpace()); + }); + + // + // Materializations + // + // This will be called when (newArgType != origArgType) + // This will create newArg, and map(origArg, newArg) + addArgumentMaterialization([&](OpBuilder &builder, + RankedTensorType tensorType, ValueRange inputs, + Location loc) -> std::optional { + llvm_unreachable("Argument rematerialization should not happen in Triton " + "-> TritonGPU conversion"); + return std::nullopt; + }); + + // If the origValue still has live user(s), use this to + // convert origValue to newValue + addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, + Location loc) -> std::optional { + llvm_unreachable("Source rematerialization should not happen in Triton -> " + "TritonGPU Conversion"); + return std::nullopt; + }); + + // This will be called when (desiredType != newOperandType) + // where, desiredType = typeConverter->convertType(origType) + // NOTE: only for remapped values. + addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, Location loc) { + auto cast = + builder.create(loc, tensorType, inputs); + return std::optional(cast.getResult()); + }); +} + +// +// TritonGPUConversion +// +TritonGPUConversionTarget::TritonGPUConversionTarget( + MLIRContext &context, TritonGPUTypeConverter &typeConverter) + : ConversionTarget(context) { + // TODO: we should also verify ops of TritonGPUDialect + addLegalDialect(); + + // Some ops from SCF are illegal + addIllegalOp(); + + addDynamicallyLegalDialect([&](Operation *op) { + bool hasLegalRegions = true; + for (auto ®ion : op->getRegions()) { + hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion); + } + if (hasLegalRegions && typeConverter.isLegal(op)) { + return true; + } + return false; + }); + + // We have requirements for the data layouts + addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { + Attribute aEncoding = + cast(dotOp.getA().getType()).getEncoding(); + Attribute bEncoding = + cast(dotOp.getB().getType()).getEncoding(); + if (aEncoding && isa(aEncoding) && + bEncoding && isa(bEncoding)) + return true; + return false; + }); +} diff --git a/third_party/xpu/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/third_party/xpu/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp new file mode 100644 index 000000000..4aa2712ec --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -0,0 +1,821 @@ +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "llvm/ADT/APSInt.h" +#include + +#define GEN_PASS_CLASSES +#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// pass named attrs (e.g., tt.contiguity) from Triton to Triton +static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) { + for (const NamedAttribute attr : dictAttrs.getValue()) + if (!op->hasAttr(attr.getName())) + op->setAttr(attr.getName(), attr.getValue()); +} + +template struct GenericOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector retTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + retTypes))) + return failure(); + rewriter.replaceOpWithNewOp(op, retTypes, adaptor.getOperands(), + op->getAttrs()); + + return success(); + } +}; + +class ArithConstantPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = getTypeConverter()->convertType(op.getType()); + auto retShapedType = cast(retType); + auto value = dyn_cast(adaptor.getValue()); + if (dyn_cast(retShapedType)) { + assert(value); + if (value.getElementType().isInteger(1) && value.isSplat()) + // Workaround until https://reviews.llvm.org/D133743 is included. + value = + DenseElementsAttr::get(retShapedType, value.getSplatValue()); + else + // This is a hack. We just want to add encoding + value = value.reshape(retShapedType); + } + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retShapedType, value), + adaptor.getAttributes()); + return success(); + } +}; + +void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { + // -------------- + // Add legality and rewrite pattern rules for operations + // from the Arith dialect. The basic premise is that + // Arith operations require both inputs to have the same + // non-null encoding + // -------------- + MLIRContext *context = patterns.getContext(); + // TODO: there's probably a better way to avoid adding all ops one-by-one + patterns.add< + ArithConstantPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, // NegFOp + // Floating point + GenericOpPattern, GenericOpPattern, + // MaxMin + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + // Floating point + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + // Cmp + GenericOpPattern, GenericOpPattern, + // Select + GenericOpPattern, + // Cast Ops + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern>(typeConverter, context); +} + +void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonGPUConversionTarget &target) { + MLIRContext *context = patterns.getContext(); + // Rewrite rule + patterns.add, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern>( + typeConverter, context); +} + +// +// Triton patterns +// +struct TritonExpandDimsPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Type retType = op.getType()); + RankedTensorType argType = + cast(adaptor.getSrc().getType()); + Attribute _argEncoding = argType.getEncoding(); + if (!_argEncoding) + return failure(); + auto argEncoding = cast(_argEncoding); + // return shape + auto retShape = argType.getShape().vec(); + retShape.insert(retShape.begin() + op.getAxis(), 1); + // return encoding + auto retSizePerThread = argEncoding.getSizePerThread(); + retSizePerThread.insert(retSizePerThread.begin() + op.getAxis(), 1); + auto retThreadsPerWarp = argEncoding.getThreadsPerWarp(); + retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.getAxis(), 1); + auto retWarpsPerCTA = argEncoding.getWarpsPerCTA(); + retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1); + SmallVector retOrder(retShape.size()); + std::iota(retOrder.begin(), retOrder.end(), 0); + + auto argCTALayout = argEncoding.getCTALayout(); + auto retCTAsPerCGA = insertOne(argCTALayout.getCTAsPerCGA(), op.getAxis()); + auto retCTASplitNum = + insertOne(argCTALayout.getCTASplitNum(), op.getAxis()); + auto retCTAOrder = insertOrder(argCTALayout.getCTAOrder(), op.getAxis()); + auto retCTALayout = triton::gpu::CTALayoutAttr::get( + getContext(), retCTAsPerCGA, retCTASplitNum, retCTAOrder); + + triton::gpu::BlockedEncodingAttr retEncoding = + triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread, + retThreadsPerWarp, retWarpsPerCTA, + retOrder, retCTALayout); + // convert operand to slice of return type + Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get( + getContext(), op.getAxis(), retEncoding); + RankedTensorType newArgType = RankedTensorType::get( + argType.getShape(), argType.getElementType(), newArgEncoding); + // construct new op + auto newSrc = rewriter.create( + op.getLoc(), newArgType, adaptor.getSrc()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, newSrc, adaptor.getAxis()), + adaptor.getAttributes()); + return success(); + } + +private: + template + SmallVector insertOne(ArrayRef vec, unsigned axis) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + axis, 1); + return res; + } + + // Example: order = [ 0, 2, 1, 3], dim = 2 + // resOrder = [2, 0, 3, 1, 4] + SmallVector insertOrder(ArrayRef order, + unsigned axis) const { + SmallVector resOrder(order.begin(), order.end()); + for (unsigned i = 0; i < resOrder.size(); ++i) + if (resOrder[i] >= axis) + ++resOrder[i]; + resOrder.insert(resOrder.begin(), axis); + return resOrder; + } +}; + +struct TritonDotPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType origType = op.getType(); + auto origShape = origType.getShape(); + auto typeConverter = getTypeConverter(); + int numWarps = typeConverter->getNumWarps(); + int threadsPerWarp = typeConverter->getThreadsPerWarp(); + int numCTAs = typeConverter->getNumCTAs(); + auto rank = origShape.size(); + SmallVector retSizePerThread(rank, 1); + auto numElements = product(origShape); + if (numElements / (numWarps * threadsPerWarp) >= 4) { + retSizePerThread[rank - 1] = 2; + retSizePerThread[rank - 2] = 2; + } + if (numElements / (numWarps * threadsPerWarp) >= 16) { + retSizePerThread[rank - 1] = 4; + retSizePerThread[rank - 2] = 4; + } + SmallVector retOrder(rank); + for (unsigned i = 0; i < rank; ++i) + retOrder[i] = rank - 1 - i; + Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get( + getContext(), origShape, retSizePerThread, retOrder, numWarps, + threadsPerWarp, numCTAs); + RankedTensorType retType = + RankedTensorType::get(origShape, origType.getElementType(), dEncoding); + // a & b must be of smem layout + auto aType = cast(adaptor.getA().getType()); + auto bType = cast(adaptor.getB().getType()); + Type aEltType = aType.getElementType(); + Type bEltType = bType.getElementType(); + Attribute aEncoding = aType.getEncoding(); + Attribute bEncoding = bType.getEncoding(); + if (!aEncoding || !bEncoding) + return failure(); + Value a = adaptor.getA(); + Value b = adaptor.getB(); + Value c = adaptor.getC(); + if (!mlir::isa(aEncoding)) { + Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( + getContext(), 0, dEncoding, aEltType); + auto dstType = + RankedTensorType::get(aType.getShape(), aEltType, encoding); + a = rewriter.create(a.getLoc(), dstType, a); + } + if (!mlir::isa(bEncoding)) { + Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( + getContext(), 1, dEncoding, bEltType); + auto dstType = + RankedTensorType::get(bType.getShape(), bEltType, encoding); + b = rewriter.create(b.getLoc(), dstType, b); + } + c = rewriter.create(c.getLoc(), retType, c); + + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, a, b, c, adaptor.getInputPrecision(), + adaptor.getMaxNumImpreciseAcc()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonCatPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // The cat op satisfy two conditions: + // 1. output.numel = lhs.numel + rhs.numel + // 2. output.total_elems_per_thread = + // next_power_of_2(lhs.total_elems_per_thread + rhs.total_elems_per_thread) + // For now, this behaves like generic, but this + // will evolve when we add support for `can_reorder=False`. + auto retType = cast( + this->getTypeConverter()->convertType(op.getType())); + auto retEncoding = + cast(retType.getEncoding()); + auto lhsType = adaptor.getLhs().getType(); + auto rhsType = adaptor.getRhs().getType(); + auto lhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(lhsType); + auto rhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(rhsType); + auto retTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(retType); + auto retShape = retType.getShape(); + auto retOrder = retEncoding.getOrder(); + auto retSizePerThread = retEncoding.getSizePerThread(); + auto retThreadsPerWarp = retEncoding.getThreadsPerWarp(); + auto retWarpsPerCTA = retEncoding.getWarpsPerCTA(); + // Get new retSizePerThread if ret elems per thread is not enough. + // We have to round it up to the next power of 2 due to triton's tensor size + // constraint. + auto newRetTotalElemsPerThread = + nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread); + auto newRetSizePerThread = retSizePerThread; + newRetSizePerThread[retOrder[0]] *= + newRetTotalElemsPerThread / retTotalElemsPerThread; + triton::gpu::BlockedEncodingAttr newRetEncoding = + triton::gpu::BlockedEncodingAttr::get( + getContext(), newRetSizePerThread, retThreadsPerWarp, + retWarpsPerCTA, retOrder, retEncoding.getCTALayout()); + auto newRetType = RankedTensorType::get(retShape, retType.getElementType(), + newRetEncoding); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, newRetType, adaptor.getOperands()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonJoinOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Simply rely on type inference for this op. (Notably, GenericOpPattern + // does not do this, instead it assigns the default layout to the ins and + // outs.) + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, adaptor.getLhs(), adaptor.getRhs()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonSplitOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto src = adaptor.getSrc(); + auto srcTy = cast(src.getType()); + auto srcEnc = dyn_cast(srcTy.getEncoding()); + int rank = srcEnc.getOrder().size(); + auto typeConverter = getTypeConverter(); + + // The operand to split must have: + // - a blocked layout, with + // - sizePerThread = 2 in the last dimension, + // - threadsPerWarp, warpsPerCTA, and CTAsPerCGA = 1 in the last dim, and + // - the last dimension minor. + // If that's not the case, add a convert before the split. + if (!srcEnc || srcEnc.getSizePerThread().back() != 2 || + srcEnc.getOrder().front() != rank - 1) { + // If we take the default encoding for the op's result (i.e. post-split) + // and add 1 to the end of each dim, that gives us what we want. Other + // than making a legal src encoding, our choice of layout doesn't matter; + // it'll get fixed by RemoveLayoutConversions. + auto defaultEnc = getDefaultBlockedEncoding( + getContext(), + cast(op.getResult(0).getType()).getShape(), + typeConverter->getNumWarps(), typeConverter->getThreadsPerWarp(), + typeConverter->getNumCTAs()); + + auto append = [&](ArrayRef vals, unsigned val) { + SmallVector res(vals); + res.push_back(val); + return res; + }; + auto prepend = [&](ArrayRef vals, unsigned val) { + SmallVector res; + res.push_back(val); + res.append(vals.begin(), vals.end()); + return res; + }; + + srcEnc = BlockedEncodingAttr::get( + getContext(), append(defaultEnc.getSizePerThread(), 2), + append(defaultEnc.getThreadsPerWarp(), 1), + append(defaultEnc.getWarpsPerCTA(), 1), + prepend(defaultEnc.getOrder(), rank - 1), + CTALayoutAttr::get(getContext(), + append(defaultEnc.getCTAsPerCGA(), 1), + append(defaultEnc.getCTASplitNum(), 1), + prepend(defaultEnc.getCTAOrder(), rank - 1))); + srcTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), + srcEnc); + src = rewriter.create(op.getLoc(), srcTy, src); + } + + addNamedAttrs(rewriter.replaceOpWithNewOp(op, src), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonTransPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = adaptor.getSrc(); + auto srcTy = cast(src.getType()); + auto srcEnc = srcTy.getEncoding(); + if (!srcEnc) + return failure(); + addNamedAttrs(rewriter.replaceOpWithNewOp(op, src, op.getOrder()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonBroadcastPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // This creates a tensor with the new shape but the argument's layout + LogicalResult + matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcType = cast(adaptor.getSrc().getType()); + auto srcEncoding = srcType.getEncoding(); + if (!srcEncoding) + return failure(); + Type retType = RankedTensorType::get( + op.getType().getShape(), op.getType().getElementType(), srcEncoding); + // Type retType = this->getTypeConverter()->convertType(op.getType()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, adaptor.getOperands()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonReducePattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newReduce = rewriter.create( + op.getLoc(), adaptor.getOperands(), adaptor.getAxis()); + addNamedAttrs(newReduce, adaptor.getAttributes()); + + auto &newCombineOp = newReduce.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newReduce.getResult()); + return success(); + } +}; + +struct TritonScanPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newScan = rewriter.create( + op.getLoc(), adaptor.getOperands(), adaptor.getAxis(), op.getReverse()); + addNamedAttrs(newScan, adaptor.getAttributes()); + + auto &newCombineOp = newScan.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newScan.getResult()); + return success(); + } +}; + +class TritonFuncOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getName(), op.getFunctionType()); + addNamedAttrs(newOp, adaptor.getAttributes()); + rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(), + newOp.getBody().end()); + if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter))) + return failure(); + + return success(); + } +}; + +class TritonCallOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getCallee(), op.getResultTypes(), adaptor.getOperands()); + addNamedAttrs(newOp, adaptor.getAttributes()); + return success(); + } +}; + +class TritonReturnOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp op, ReturnOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } +}; + +void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns, unsigned numCTAs) { + MLIRContext *context = patterns.getContext(); + patterns.insert< // TODO: view should have custom pattern that views the + // layout + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + TritonBroadcastPattern, GenericOpPattern, + TritonCatPattern, TritonJoinOpPattern, TritonSplitOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, TritonReducePattern, + GenericOpPattern, TritonScanPattern, + GenericOpPattern, + GenericOpPattern, TritonExpandDimsPattern, + TritonTransPattern, TritonDotPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, TritonFuncOpPattern>(typeConverter, + context); +} + +// +// SCF patterns +// +// This is borrowed from ConvertForOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +struct SCFForPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + // Ref: ConvertForOpTypes + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + + // Now, update all the types. + + // Convert the types of block arguments within the given region. This + // replaces each block with a new block containing the updated signature. + // The entry block may have a special conversion if `entryConversion` is + // provided. On success, the new entry block to the region is returned for + // convenience. Otherwise, failure is returned. + if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), + *getTypeConverter()))) { + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + // Change the clone to use the updated operands. We could have cloned with + // a IRMapping, but this seems a bit more direct. + newOp->setOperands(adaptor.getOperands()); + // Update the result types to the new converted types. + SmallVector newResultTypes; + for (Type type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + + rewriter.replaceOp(op, newOp.getResults()); + + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +class SCFIfPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::IfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // TODO: Generalize this to any type conversion, not just 1:1. + // + // We need to implement something more sophisticated here that tracks which + // types convert to which other types and does the appropriate + // materialization logic. + // For example, it's possible that one result type converts to 0 types and + // another to 2 types, so newResultTypes would at least be the right size to + // not crash in the llvm::zip call below, but then we would set the the + // wrong type on the SSA values! These edge cases are also why we cannot + // safely use the TypeConverter::convertTypes helper here. + SmallVector newResultTypes; + for (auto type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + + // See comments in the ForOp pattern for why we clone without regions and + // then inline. + scf::IfOp newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), + newOp.getThenRegion().end()); + rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), + newOp.getElseRegion().end()); + + // Update the operands and types. + newOp->setOperands(adaptor.getOperands()); + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +class SCFWhilePattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + assert(converter); + SmallVector newResultTypes; + if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes))) + return failure(); + + auto newOp = rewriter.create(op.getLoc(), newResultTypes, + adaptor.getOperands()); + for (auto i : {0u, 1u}) { + auto &dstRegion = newOp.getRegion(i); + rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); + if (failed(rewriter.convertRegionTypes(&dstRegion, *converter))) + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +class SCFConditionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::ConditionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.modifyOpInPlace(op, + [&]() { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + +void populateSCFPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add, SCFForPattern, SCFIfPattern, + SCFWhilePattern, SCFConditionPattern>(typeConverter, context); +} + +// CF + +class CFBranchPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getSuccessor(), adaptor.getOperands()); + if (failed(rewriter.convertRegionTypes(newOp.getSuccessor()->getParent(), + *converter))) + return failure(); + return success(); + } +}; + +class CFCondBranchPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), op.getTrueDest(), + adaptor.getTrueDestOperands(), op.getFalseDest(), + adaptor.getFalseDestOperands()); + addNamedAttrs(newOp, adaptor.getAttributes()); + + if (failed(rewriter.convertRegionTypes(newOp.getTrueDest()->getParent(), + *converter))) + return failure(); + if (failed(rewriter.convertRegionTypes(newOp.getFalseDest()->getParent(), + *converter))) + return failure(); + return success(); + } +}; + +void populateCFPatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add(typeConverter, context); +} +// + +class ConvertTritonToTritonGPU + : public ConvertTritonToTritonGPUBase { +public: + ConvertTritonToTritonGPU() = default; + // constructor with some parameters set explicitly. + ConvertTritonToTritonGPU(const std::string &target, int numWarps, + int threadsPerWarp, int numCTAs) { + this->numWarps = numWarps; + this->threadsPerWarp = threadsPerWarp; + this->numCTAs = numCTAs; + this->target = target; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + // type converter + TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp, + numCTAs); + TritonGPUConversionTarget target(*context, typeConverter); + // rewrite patterns + RewritePatternSet patterns(context); + // add rules + populateArithPatternsAndLegality(typeConverter, patterns, target); + populateMathPatternsAndLegality(typeConverter, patterns, target); + populateTritonPatterns(typeConverter, patterns, numCTAs); + // TODO: can we use + // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? + populateSCFPatterns(typeConverter, patterns); + populateCFPatterns(typeConverter, patterns); + + auto inti = llvm::APSInt(32, false); + auto i32_ty = IntegerType::get(mod->getContext(), 32); + + mod->setAttr( + AttrNumWarpsName, + IntegerAttr::get(i32_ty, llvm::APInt(32, numWarps.getValue()))); + mod->setAttr( + AttrNumThreadsPerWarp, + IntegerAttr::get(i32_ty, llvm::APInt(32, threadsPerWarp.getValue()))); + + mod->setAttr(AttrNumCTAsName, + IntegerAttr::get(i32_ty, llvm::APInt(32, numCTAs.getValue()))); + + if (this->target.getValue().empty()) { + mod.emitError("expected target specification to attach to the module op"); + return signalPassFailure(); + } + mod->setAttr(AttrTargetName, + StringAttr::get(context, this->target.getValue())); + + if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); + + // update layouts + // broadcast src => multicast, dst => broadcasted + // if (failed(target.refineLayouts(mod, numWarps))) + // return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +mlir::triton::createConvertTritonToTritonGPUPass(const std::string &target, + int numWarps, + int threadsPerWarp, + int numCTAs) { + return std::make_unique<::ConvertTritonToTritonGPU>(target, numWarps, + threadsPerWarp, numCTAs); +} + +std::unique_ptr> +mlir::triton::createConvertTritonToTritonGPUPass() { + return std::make_unique<::ConvertTritonToTritonGPU>(); +} diff --git a/third_party/xpu/lib/Conversion/TritonToTritonXPU/CMakeLists.txt b/third_party/xpu/lib/Conversion/TritonToTritonXPU/CMakeLists.txt new file mode 100644 index 000000000..c58083d48 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonToTritonXPU/CMakeLists.txt @@ -0,0 +1,7 @@ +add_triton_library(TritonToTritonXPU + TritonXPUConversion.cpp + TritonToTritonXPUPass.cpp + + DEPENDS + TT2TTXConversionPassIncGen +) diff --git a/third_party/xpu/lib/Conversion/TritonToTritonXPU/TritonToTritonXPUPass.cpp b/third_party/xpu/lib/Conversion/TritonToTritonXPU/TritonToTritonXPUPass.cpp new file mode 100644 index 000000000..4e689e343 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonToTritonXPU/TritonToTritonXPUPass.cpp @@ -0,0 +1,643 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +// clang-format off +#include +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "triton/Dialect/Triton/IR/Dialect.h" // mlir::triton::op + +#include "triton/Conversion/TritonToTritonXPU/Passes.h" +#include "triton/Dialect/TritonXPU/IR/Dialect.h" + +#include "triton/Dialect/TritonXPU/Transforms/TritonXPUConversion.h" // TritonXPUTypeConverter + TritonXPUConversionTarget +#include "llvm/Support/ErrorHandling.h" // TODO[dyq]: Check All Pattern And Remove It + +#define GEN_PASS_CLASSES +#include "triton/Conversion/TritonToTritonXPU/Passes.h.inc" +// clang-format on + +namespace { +using namespace mlir; +using namespace mlir::triton; + +// pass named attrs (e.g., tt.contiguity) from Triton to Triton +static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) { + for (const NamedAttribute attr : dictAttrs.getValue()) + if (!op->hasAttr(attr.getName())) + op->setAttr(attr.getName(), attr.getValue()); +} + +template struct GenericOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector retTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + retTypes))) + return failure(); + rewriter.replaceOpWithNewOp(op, retTypes, adaptor.getOperands(), + op->getAttrs()); + + return success(); + } +}; + +class ArithConstantPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = getTypeConverter()->convertType(op.getType()); + auto retShapedType = cast(retType); + auto value = dyn_cast(adaptor.getValue()); + if (dyn_cast(retShapedType)) { + assert(value); + if (value.getElementType().isInteger(1) && value.isSplat()) + // Workaround until https://reviews.llvm.org/D133743 is included. + value = + DenseElementsAttr::get(retShapedType, value.getSplatValue()); + else + // This is a hack. We just want to add encoding + value = value.reshape(retShapedType); + } + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retShapedType, value), + adaptor.getAttributes()); + return success(); + } +}; + +void populateArithPatternsAndLegality(TritonXPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonXPUConversionTarget &target) { + // -------------- + // Add legality and rewrite pattern rules for operations + // from the Arith dialect. The basic premise is that + // Arith operations require both inputs to have the same + // non-null encoding + // -------------- + MLIRContext *context = patterns.getContext(); + // TODO: there's probably a better way to avoid adding all ops one-by-one + patterns.add< + ArithConstantPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, // NegFOp + // Floating point + GenericOpPattern, GenericOpPattern, + // MaxMin + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + // Floating point + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + // Cmp + GenericOpPattern, GenericOpPattern, + // Select + GenericOpPattern, + // Cast Ops + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern>(typeConverter, context); +} + +void populateMathPatternsAndLegality(TritonXPUTypeConverter &typeConverter, + RewritePatternSet &patterns, + TritonXPUConversionTarget &target) { + MLIRContext *context = patterns.getContext(); + // Rewrite rule + patterns.add, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern>( + typeConverter, context); +} + +// +// Triton patterns +// +struct TritonExpandDimsPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Type retType = op.getType()); + RankedTensorType argType = + cast(adaptor.getSrc().getType()); + Attribute _argEncoding = argType.getEncoding(); + if (!_argEncoding) + return failure(); + auto argEncoding = cast(_argEncoding); + // return shape + auto retShape = argType.getShape().vec(); + retShape.insert(retShape.begin() + op.getAxis(), 1); + // return encoding + auto retSizePerCore = argEncoding.getSizePerCore().vec(); + retSizePerCore.insert(retSizePerCore.begin() + op.getAxis(), 1); + auto retCoresPerGroup = argEncoding.getCoresPerGroup().vec(); + retCoresPerGroup.insert(retCoresPerGroup.begin() + op.getAxis(), 1); + auto retGroupsPerCluster = argEncoding.getGroupsPerCluster().vec(); + retGroupsPerCluster.insert(retGroupsPerCluster.begin() + op.getAxis(), 1); + SmallVector retOrder(retShape.size()); + std::iota(retOrder.begin(), retOrder.end(), 0); + + bool isReduceOpt = argEncoding.getIsReduceOpt(); + + triton::xpu::ClusterLayoutAttr retEncoding = + triton::xpu::ClusterLayoutAttr::get( + getContext(), retSizePerCore, retCoresPerGroup, retGroupsPerCluster, + retOrder, isReduceOpt); + + // convert operand to slice of return type + Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get( + getContext(), op.getAxis(), retEncoding); + RankedTensorType newArgType = RankedTensorType::get( + argType.getShape(), argType.getElementType(), newArgEncoding); + // construct new op + auto newSrc = rewriter.create( + op.getLoc(), newArgType, adaptor.getSrc()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, newSrc, adaptor.getAxis()), + adaptor.getAttributes()); + return success(); + } + +private: + template + SmallVector insertOne(ArrayRef vec, unsigned axis) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + axis, 1); + return res; + } + + // Example: order = [ 0, 2, 1, 3], dim = 2 + // resOrder = [2, 0, 3, 1, 4] + SmallVector insertOrder(ArrayRef order, + unsigned axis) const { + SmallVector resOrder(order.begin(), order.end()); + for (unsigned i = 0; i < resOrder.size(); ++i) + if (resOrder[i] >= axis) + ++resOrder[i]; + resOrder.insert(resOrder.begin(), axis); + return resOrder; + } +}; + +struct TritonDotPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm_unreachable("TODO[dyq]: XPUSDNN-CHECK Add " + "triton::xpu::GlobalEncodingAttr Calculation Logic"); + return failure(); + } +}; + +struct TritonCatPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm_unreachable( + "TODO[dyq]: Add triton::xpu::GlobalEncodingAttr Calculation Logic"); + return failure(); + } +}; + +struct TritonJoinOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + llvm_unreachable("TODO[dyq]: Check Logic"); + // Simply rely on type inference for this op. (Notably, GenericOpPattern + // does not do this, instead it assigns the default layout to the ins and + // outs.) + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, adaptor.getLhs(), adaptor.getRhs()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonSplitOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + llvm_unreachable( + "TODO[dyq]: Add triton::xpu::GlobalEncodingAttr Calculation Logic"); + return failure(); + } +}; + +struct TritonTransPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm_unreachable("TODO[dyq]: Check Logic"); + Value src = adaptor.getSrc(); + auto srcTy = cast(src.getType()); + auto srcEnc = srcTy.getEncoding(); + if (!srcEnc) + return failure(); + addNamedAttrs(rewriter.replaceOpWithNewOp(op, src, op.getOrder()), + adaptor.getAttributes()); + return success(); + } +}; + +struct TritonBroadcastPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + // This creates a tensor with the new shape but the argument's layout + LogicalResult + matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcType = cast(adaptor.getSrc().getType()); + auto srcEncoding = srcType.getEncoding(); + if (!srcEncoding) + return failure(); + Type retType = RankedTensorType::get( + op.getType().getShape(), op.getType().getElementType(), srcEncoding); + // Type retType = this->getTypeConverter()->convertType(op.getType()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, adaptor.getOperands()), + adaptor.getAttributes()); + return success(); + } +}; +struct TritonReducePattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newReduce = rewriter.create( + op.getLoc(), adaptor.getOperands(), adaptor.getAxis()); + addNamedAttrs(newReduce, adaptor.getAttributes()); + + auto &newCombineOp = newReduce.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newReduce.getResult()); + return success(); + } +}; + +struct TritonScanPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm_unreachable("TODO[dyq]: Check Logic"); + auto newScan = rewriter.create( + op.getLoc(), adaptor.getOperands(), adaptor.getAxis(), op.getReverse()); + addNamedAttrs(newScan, adaptor.getAttributes()); + + auto &newCombineOp = newScan.getCombineOp(); + rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp, + newCombineOp.end()); + rewriter.replaceOp(op, newScan.getResult()); + return success(); + } +}; + +class TritonFuncOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm_unreachable("TODO[dyq]: Check Logic"); + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getName(), op.getFunctionType()); + addNamedAttrs(newOp, adaptor.getAttributes()); + rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(), + newOp.getBody().end()); + if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter))) + return failure(); + + return success(); + } +}; + +void populateTritonPatterns(TritonXPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.insert< // TODO: view should have custom pattern that views the + // layout + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + TritonBroadcastPattern, GenericOpPattern, + TritonCatPattern, TritonJoinOpPattern, TritonSplitOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, TritonReducePattern, + GenericOpPattern, TritonScanPattern, + GenericOpPattern, + GenericOpPattern, TritonExpandDimsPattern, + TritonTransPattern, TritonDotPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, TritonFuncOpPattern>(typeConverter, + context); +} + +// +// SCF patterns +// +// This is borrowed from ConvertForOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +struct SCFForPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + // Ref: ConvertForOpTypes + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + + // Now, update all the types. + + // Convert the types of block arguments within the given region. This + // replaces each block with a new block containing the updated signature. + // The entry block may have a special conversion if `entryConversion` is + // provided. On success, the new entry block to the region is returned for + // convenience. Otherwise, failure is returned. + if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), + *getTypeConverter()))) { + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + // Change the clone to use the updated operands. We could have cloned with + // a IRMapping, but this seems a bit more direct. + newOp->setOperands(adaptor.getOperands()); + // Update the result types to the new converted types. + SmallVector newResultTypes; + for (Type type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + + rewriter.replaceOp(op, newOp.getResults()); + + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +class SCFIfPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::IfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // TODO: Generalize this to any type conversion, not just 1:1. + // + // We need to implement something more sophisticated here that tracks which + // types convert to which other types and does the appropriate + // materialization logic. + // For example, it's possible that one result type converts to 0 types and + // another to 2 types, so newResultTypes would at least be the right size to + // not crash in the llvm::zip call below, but then we would set the the + // wrong type on the SSA values! These edge cases are also why we cannot + // safely use the TypeConverter::convertTypes helper here. + SmallVector newResultTypes; + for (auto type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + + // See comments in the ForOp pattern for why we clone without regions and + // then inline. + scf::IfOp newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), + newOp.getThenRegion().end()); + rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), + newOp.getElseRegion().end()); + + // Update the operands and types. + newOp->setOperands(adaptor.getOperands()); + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +class SCFWhilePattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + assert(converter); + SmallVector newResultTypes; + if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes))) + return failure(); + + auto newOp = rewriter.create(op.getLoc(), newResultTypes, + adaptor.getOperands()); + for (auto i : {0u, 1u}) { + auto &dstRegion = newOp.getRegion(i); + rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); + if (failed(rewriter.convertRegionTypes(&dstRegion, *converter))) + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +class SCFConditionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::ConditionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.modifyOpInPlace(op, + [&]() { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + +void populateSCFPatterns(TritonXPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add, SCFForPattern, SCFIfPattern, + SCFWhilePattern, SCFConditionPattern>(typeConverter, context); +} + +// CF + +class CFBranchPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm_unreachable("TODO[dyq]: Check Logic"); + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getSuccessor(), adaptor.getOperands()); + if (failed(rewriter.convertRegionTypes(newOp.getSuccessor()->getParent(), + *converter))) + return failure(); + return success(); + } +}; + +class CFCondBranchPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm_unreachable("TODO[dyq]: Check Logic"); + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), op.getTrueDest(), + adaptor.getTrueDestOperands(), op.getFalseDest(), + adaptor.getFalseDestOperands()); + addNamedAttrs(newOp, adaptor.getAttributes()); + + if (failed(rewriter.convertRegionTypes(newOp.getTrueDest()->getParent(), + *converter))) + return failure(); + if (failed(rewriter.convertRegionTypes(newOp.getFalseDest()->getParent(), + *converter))) + return failure(); + return success(); + } +}; + +void populateCFPatterns(TritonXPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add(typeConverter, context); +} +// + +class ConvertTritonToTritonXPU + : public ConvertTritonToTritonXPUBase { +public: + ConvertTritonToTritonXPU() = default; + // constructor with some parameters set explicitly. + ConvertTritonToTritonXPU(uint32_t xpu_arch, uint32_t buffer_size, + uint32_t core_num) { + this->xpu_arch = xpu_arch; + this->buffer_size = buffer_size; + this->core_num = core_num; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + // type converter. the reason that we cant use TT2TTGPass directly + TritonXPUTypeConverter typeConverter(context, buffer_size, core_num); + TritonXPUConversionTarget target(*context, typeConverter); + // rewrite patterns + RewritePatternSet patterns(context); + // add rules + populateArithPatternsAndLegality(typeConverter, patterns, target); + populateMathPatternsAndLegality(typeConverter, patterns, target); + populateTritonPatterns(typeConverter, patterns); + // TODO: can we use + // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? + populateSCFPatterns(typeConverter, patterns); + populateCFPatterns(typeConverter, patterns); + + auto inti = llvm::APSInt(32, false); + auto i32_ty = IntegerType::get(mod->getContext(), 32); + + if (!this->xpu_arch.getValue()) { + mod.emitError("expected target specification to attach to the module op"); + return signalPassFailure(); + } + mod->setAttr( + AttrXPUTargetName, + StringAttr::get(context, + "xpu:" + std::to_string(this->xpu_arch.getValue()))); + + if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); + + // update layouts + // broadcast src => multicast, dst => broadcasted + // if (failed(target.refineLayouts(mod, numWarps))) + // return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +mlir::triton::createConvertTritonToTritonXPUPass(uint32_t xpu_arch, + uint32_t buffer_size, + uint32_t core_num) { + return std::make_unique<::ConvertTritonToTritonXPU>(xpu_arch, buffer_size, + core_num); +} + +std::unique_ptr> +mlir::triton::createConvertTritonToTritonXPUPass() { + return std::make_unique<::ConvertTritonToTritonXPU>(); +} diff --git a/third_party/xpu/lib/Conversion/TritonToTritonXPU/TritonXPUConversion.cpp b/third_party/xpu/lib/Conversion/TritonToTritonXPU/TritonXPUConversion.cpp new file mode 100644 index 000000000..0bf832019 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonToTritonXPU/TritonXPUConversion.cpp @@ -0,0 +1,101 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "triton/Dialect/TritonXPU/Transforms/TritonXPUConversion.h" + +#include "mlir/IR/MLIRContext.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include + +using namespace mlir; + +// +// TypeConverter +// +TritonXPUTypeConverter::TritonXPUTypeConverter(MLIRContext *context, + uint32_t buffer_size, + uint32_t core_num) + : context(context), buffer_size(buffer_size), core_num(core_num) { + + addConversion([](Type type) { return type; }); + + addConversion([this](RankedTensorType tensorType) -> RankedTensorType { + if (tensorType.getEncoding()) + return tensorType; + + ArrayRef shape = tensorType.getShape(); + triton::xpu::ClusterLayoutAttr encoding = + triton::xpu::getDefaultClusterEncoding( + this->context, shape, this->buffer_size, this->core_num); + return RankedTensorType::get(shape, tensorType.getElementType(), encoding); + }); + + // TODO[dyq]: check addConversion for triton::PointerType + + // + // Materializations + // + // This will be called when (newArgType != origArgType) + // This should create newArg, and map(origArg, newArg) + addArgumentMaterialization([&](OpBuilder &builder, + RankedTensorType tensorType, ValueRange inputs, + Location loc) -> std::optional { + llvm_unreachable("Argument rematerialization should not happen in Triton " + "-> TritonXPU conversion"); + return std::nullopt; + }); + + // If the origValue still has live user(s), use this to + // convert origValue to newValue + addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, + Location loc) -> std::optional { + llvm_unreachable("Source rematerialization should not happen in Triton -> " + "TritonXPU Conversion"); + return std::nullopt; + }); + + // This will be called when (desiredType != newOperandType) + // where, desiredType = typeConverter->convertType(origType) + // NOTE: only for remapped values. + addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, Location loc) { + auto cast = + builder.create(loc, tensorType, inputs); + return std::optional(cast.getResult()); + }); +} + +// +// TritonXPUConversion +// +TritonXPUConversionTarget::TritonXPUConversionTarget( + MLIRContext &context, TritonXPUTypeConverter &typeConverter) + : ConversionTarget(context) { + + addLegalDialect(); + + // Some ops from SCF are illegal + // TODO[dyq]: addIllegalOp necessary? + // addIllegalOp(); + + addDynamicallyLegalDialect([&](Operation *op) { + bool hasLegalRegions = true; + for (auto ®ion : op->getRegions()) { + hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion); + } + if (hasLegalRegions && typeConverter.isLegal(op)) { + return true; + } + return false; + }); + + // TODO[dyq]: XPUSDNN-CHECK check addDynamicallyLegalDialect for triton::DotOp +} diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/CMakeLists.txt b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000..62770ec92 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/CMakeLists.txt @@ -0,0 +1,24 @@ +add_triton_library(TritonXPUToLLVM + ConvertLayoutOpToLLVM.cpp + ElementwiseOpToLLVM.cpp + FuncOpToLLVM.cpp + GPUOpToLLVMXPU.cpp + LoadStoreOpToLLVM.cpp + MakeRangeOpToLLVM.cpp + ReduceOpToLLVM.cpp + SPMDOpToLLVM.cpp + TritonXPUToLLVM.cpp + VectorizedOpToLLVM.cpp + ViewOpToLLVM.cpp + XPUUtilityOpToLLVM.cpp + + TargetInfo.cpp + TypeConverter.cpp + Utility.cpp + + DEPENDS + TTX2LLVMConversionPassIncGen + + LINK_LIBS PUBLIC + TritonGPUToLLVM +) diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/ConvertLayoutOpToLLVM.cpp new file mode 100644 index 000000000..b1db9f093 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -0,0 +1,65 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct XPUConvertLayoutOpConversion + : public ConvertOpToLLVMPattern { + XPUConvertLayoutOpConversion(LLVMTypeConverter &converter, + const xpu::TargetInfo &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit) {} + + bool isaXPUValidLayout(const Attribute &layout) const { + return mlir::isa(layout) || + mlir::isa( + mlir::cast(layout).getParent()); + } + + LogicalResult + matchAndRewrite(triton::xpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = op.getSrc(); + Value dst = op.getResult(); + auto srcTy = cast(src.getType()); + auto dstTy = cast(dst.getType()); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + if (isaXPUValidLayout(srcLayout) && isaXPUValidLayout(dstLayout)) { + return lowerOperand(op, adaptor, rewriter); + } + return failure(); + }; + + LogicalResult lowerOperand(triton::xpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto typeConverter = getTypeConverter(); + RankedTensorType srcTy = op.getSrc().getType(); + RankedTensorType dstTy = op.getType(); + + auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + Value ret = packLLElements(loc, typeConverter, vals, rewriter, dstTy); + + rewriter.replaceOp(op, ret); + return success(); + } +}; + +} // namespace + +void mlir::triton::xpu::populateConvertLayoutOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, + benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/ElementwiseOpToLLVM.cpp new file mode 100644 index 000000000..ab2a7a5bb --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/ElementwiseOpToLLVM.cpp @@ -0,0 +1,99 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" +#include "xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h" + +namespace { + +template +struct ElementwiseOpConversion + : public mlir::triton::gpu::ElementwiseOpConversionBase< + SourceOp, ElementwiseOpConversion> { + + using Base = mlir::triton::gpu::ElementwiseOpConversionBase< + SourceOp, ElementwiseOpConversion>; + using Base::Base; + using OpAdaptor = typename Base::OpAdaptor; + + // An interface to support variant DestOp builder. + SmallVector + createDestOps(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Type elemTy, + mlir::triton::gpu::MultipleOperandsRange operands, + Location loc) const { + return {rewriter.create(loc, elemTy, operands[0], + adaptor.getAttributes().getValue())}; + } +}; + +template +struct OpToExternCallConversion + : public triton::gpu::ElementwiseOpConversionBase< + TritonOp, OpToExternCallConversion> { + using Base = triton::gpu::ElementwiseOpConversionBase< + TritonOp, OpToExternCallConversion>; + using Base::Base; + using Adaptor = typename Base::OpAdaptor; + + explicit OpToExternCallConversion(LLVMTypeConverter &typeConverter, + ModuleAxisInfoAnalysis &axisAnalysisPass, + StringRef externFuncName, + PatternBenefit benefit) + : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, + benefit), + funcName(externFuncName) {} + + SmallVector createDestOps(TritonOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Type elemTy, + triton::gpu::MultipleOperandsRange operands, + Location loc) const { + Type funcType = triton::gpu::getFunctionType(elemTy, operands[0]); + LLVM::LLVMFuncOp funcOp = + triton::gpu::appendOrGetExternFuncOp(rewriter, op, funcName, funcType); + return { + rewriter.create(loc, funcOp, operands[0]).getResult()}; + } + +private: + StringRef funcName; +}; + +} // namespace + +void mlir::triton::xpu::populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfo &targetInfo, + PatternBenefit benefit) { + + patterns.add>( + typeConverter, axisInfoAnalysis, "_ZN3xpu10__fsqrt_rnEf", benefit); + +#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + POPULATE_UNARY_OP(arith::NegFOp, LLVM::FNegOp) + POPULATE_UNARY_OP(arith::ExtFOp, LLVM::FPExtOp) + POPULATE_UNARY_OP(arith::TruncFOp, LLVM::FPTruncOp) + POPULATE_UNARY_OP(arith::SIToFPOp, LLVM::SIToFPOp) + POPULATE_UNARY_OP(arith::FPToSIOp, LLVM::FPToSIOp) + POPULATE_UNARY_OP(math::ExpOp, LLVM::Exp2Op) + POPULATE_UNARY_OP(math::LogOp, LLVM::Log2Op) +#undef POPULATE_UNARY_OP + +#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \ + patterns.add>( \ + typeConverter, axisInfoAnalysis, benefit); + POPULATE_BINARY_OP(arith::AddFOp, LLVM::FAddOp) // addf + POPULATE_BINARY_OP(arith::SubFOp, LLVM::FSubOp) // subf + POPULATE_BINARY_OP(arith::MulFOp, LLVM::FMulOp) // mulf + POPULATE_BINARY_OP(arith::DivFOp, LLVM::FDivOp) // divf + POPULATE_BINARY_OP(arith::MaximumFOp, LLVM::MaximumOp) // maximum + POPULATE_BINARY_OP(arith::MinimumFOp, LLVM::MinimumOp) // minimum + POPULATE_BINARY_OP(triton::PreciseDivFOp, LLVM::FDivOp) +#undef POPULATE_BINARY_OP +} diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/FuncOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 000000000..d0af561dd --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,84 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h" + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +/// FuncOp legalization pattern that converts MemRef arguments to pointers to +/// MemRef descriptors (LLVM struct data types) containing all the MemRef type +/// information. +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + /// Only retain those attributes that are not constructed by + /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument + /// attributes. + static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result) { + + for (const auto &attr : op->getAttrs()) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == op.getFunctionTypeAttrName() || + attr.getName() == "std.varargs" || + (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) + continue; + result.push_back(attr); + } + } + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Prevent LLVM's inliner to inline this function + auto amendedFuncOp = funcOp; + if (!LLVM::isKernel(funcOp)) + llvm_unreachable("Unsuppprted funcOp is not Kernel"); + + LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( + amendedFuncOp, rewriter, *getTypeConverter()); + if (!newFuncOp) { + return failure(); + } + + auto ctx = funcOp->getContext(); + + if (LLVM::isKernel(funcOp)) { + // Set an attribute to indicate this function is a kernel entry. + newFuncOp->setAttr("xpu.kernel", + rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + newFuncOp.setLinkage(LLVM::Linkage::External); + } else { + // The noinline attribute will be used by the LLVM codegen to prevent + // inlining. + // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 + newFuncOp.setPassthroughAttr( + ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); + rewriter.eraseOp(amendedFuncOp); + newFuncOp.setLinkage(LLVM::Linkage::Internal); + } + rewriter.eraseOp(funcOp); + return success(); + } +}; + +} // namespace + +void mlir::triton::xpu::populateFuncOpConversionPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/GPUOpToLLVMXPU.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/GPUOpToLLVMXPU.cpp new file mode 100644 index 000000000..92dfe447e --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/GPUOpToLLVMXPU.cpp @@ -0,0 +1,100 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h" + +namespace mlir { + +struct XPUIntrinsicOpConversionBase { + explicit XPUIntrinsicOpConversionBase(LLVMTypeConverter &typeConverter, + const xpu::TargetInfo &targetInfo) + : targetInfo(targetInfo), + indexBitwidth(typeConverter.getIndexTypeBitwidth()) {} + +protected: + const xpu::TargetInfo &targetInfo; + unsigned indexBitwidth; +}; + +// for physical id +template +struct XPUIndexIntrinsicOpLowering : public ConvertOpToLLVMPattern, + public XPUIntrinsicOpConversionBase { + +public: + explicit XPUIndexIntrinsicOpLowering(LLVMTypeConverter &typeConverter, + const xpu::TargetInfo &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), + XPUIntrinsicOpConversionBase(typeConverter, targetInfo) {} + + // Convert the kernel arguments to an LLVM type, preserve the rest. + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + MLIRContext *context = rewriter.getContext(); + Value newOp = rewriter.create(loc, type::i32Ty(context)); + + if (indexBitwidth > 32) { + newOp = rewriter.create( + loc, IntegerType::get(context, indexBitwidth), newOp); + } else if (indexBitwidth < 32) { + newOp = rewriter.create( + loc, IntegerType::get(context, indexBitwidth), newOp); + } + + rewriter.replaceOp(op, {newOp}); + return success(); + } +}; + +// for logical id: refer xtdk_sys.h +template +struct XPULoadParamIntrinsicOpLowering : public ConvertOpToLLVMPattern, + public XPUIntrinsicOpConversionBase { + +public: + explicit XPULoadParamIntrinsicOpLowering(LLVMTypeConverter &typeConverter, + const xpu::TargetInfo &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), + XPUIntrinsicOpConversionBase(typeConverter, targetInfo) {} + + // Convert the kernel arguments to an LLVM type, preserve the rest. + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + MLIRContext *context = rewriter.getContext(); + Value newOp = rewriter.create( + loc, type::i32Ty(context), i32_val(Num)); + + if (indexBitwidth > 32) { + newOp = rewriter.create( + loc, IntegerType::get(context, indexBitwidth), newOp); + } else if (indexBitwidth < 32) { + newOp = rewriter.create( + loc, IntegerType::get(context, indexBitwidth), newOp); + } + + rewriter.replaceOp(op, {newOp}); + return success(); + } +}; + +} // namespace mlir + +void mlir::triton::xpu::populateGPUToXPUConversionPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfo &targetInfo, PatternBenefit benefit) { + patterns.add, + XPULoadParamIntrinsicOpLowering, + XPULoadParamIntrinsicOpLowering, + XPULoadParamIntrinsicOpLowering>( + typeConverter, targetInfo, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/LoadStoreOpToLLVM.cpp new file mode 100644 index 000000000..19bddd016 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/LoadStoreOpToLLVM.cpp @@ -0,0 +1,1979 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h" + +using namespace mlir; +using namespace mlir::triton; + +using ::mlir::triton::gpu::getTotalElemsPerThread; + +namespace { + +struct LoadStoreConversionBase { + explicit LoadStoreConversionBase(const xpu::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass) + : targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass) { + isBf16RoundToMid = ::triton::tools::getBoolEnv("TRITONXPU_BF16_ROUND_MID"); + } + + unsigned getContiguity(Value ptr) const { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + return axisAnalysisPass.getPtrContiguity(ptr); + } + + unsigned getVectorSize(Value ptr) const { + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + return 1; + auto contiguity = getContiguity(ptr); + auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy); + LDBG("getVectorSize contiguity = " << contiguity << " pointeeBitWidth = " + << pointeeBitWidth); + // The maximum vector size is 512 bits on XPUs. + return std::min(512 / pointeeBitWidth, contiguity); + } + + unsigned getMaskAlignment(Value mask) const { + return axisAnalysisPass.getMaskAlignment(mask); + } + + void getVectorInfo(Type tensorType, unsigned &vecSize, + unsigned &elemNbits) const { + vecSize = 1u; + elemNbits = 32u; + if (auto vecType = mlir::dyn_cast(tensorType)) { + unsigned numElems = vecType.getNumElements(); + Type elemTy = vecType.getElementType(); + elemNbits = isa(elemTy) + ? 64u + : elemTy.getIntOrFloatBitWidth(); + // The maximum vector size is 512 bits on XPU2. + vecSize = std::min(512 / elemNbits, numElems); + } + } + + void setHaddr(ConversionPatternRewriter &rewriter, mlir::Location &loc, + Value ptr) const { + Value ptrInt = ptrtoint(i64_ty, ptr); + Value ptrIntH32 = lshr(ptrInt, int_val(64, 32)); + Value ptrIntS32 = trunc(i32_ty, ptrIntH32); + rewriter.create(loc, ptrIntS32); + } + + void createGM2LMOp(ConversionPatternRewriter &rewriter, + mlir::MLIRContext *ctx, mlir::Location &loc, Value src, + Value dst, Value offset, Value size) const { + switch (static_cast(targetInfo.getXPUArch())) { + case XPUArch::XPU2: { + setHaddr(rewriter, loc, src); + Value srcAs0 = addrspace_cast(ptr_ty(ctx, 0), src); + rewriter.create(loc, srcAs0, dst, offset, size); + break; + } + case XPUArch::XPU3: { + rewriter.create(loc, src, dst, offset, size); + break; + } + default: + llvm_unreachable( + "Failed to create GM2LMOp with unsupported xpu architecture."); + } + } + + void createLM2GMOp(ConversionPatternRewriter &rewriter, + mlir::MLIRContext *ctx, mlir::Location &loc, Value src, + Value dst, Value offset, Value size) const { + switch (static_cast(targetInfo.getXPUArch())) { + case XPUArch::XPU2: { + setHaddr(rewriter, loc, dst); + Value dstAs0 = addrspace_cast(ptr_ty(ctx, 0), dst); + rewriter.create(loc, src, dstAs0, offset, size); + break; + } + case XPUArch::XPU3: { + rewriter.create(loc, src, dst, offset, size); + break; + } + default: + llvm_unreachable( + "Failed to create LM2GMOp with unsupported xpu architecture."); + } + } + + void createSM2GMOp(ConversionPatternRewriter &rewriter, + mlir::MLIRContext *ctx, mlir::Location &loc, Value src, + Value dst, Value offset, Value size) const { + switch (static_cast(targetInfo.getXPUArch())) { + case XPUArch::XPU2: { + setHaddr(rewriter, loc, dst); + Value dstAs0 = addrspace_cast(ptr_ty(ctx, 0), dst); + rewriter.create(loc, src, dstAs0, offset, size); + break; + } + case XPUArch::XPU3: { + rewriter.create(loc, src, dst, offset, size); + break; + } + default: + llvm_unreachable( + "Failed to create LM2GMOp with unsupported xpu architecture."); + } + } + + void createMemOp(ConversionPatternRewriter &rewriter, mlir::MLIRContext *ctx, + mlir::Location &loc, Value bufPtr, Value gmPtr, Value offset, + Value size, MemCpyType memCpyType) const { + switch (static_cast(memCpyType)) { + case MemCpyType::GM2LM: + createGM2LMOp(rewriter, ctx, loc, bufPtr, gmPtr, offset, size); + break; + case MemCpyType::LM2GM: + createLM2GMOp(rewriter, ctx, loc, gmPtr, bufPtr, offset, size); + break; + case MemCpyType::SM2GM: + createSM2GMOp(rewriter, ctx, loc, bufPtr, gmPtr, offset, size); + break; + default: + llvm_unreachable("Memory Op only includes GM2LM, LM2GM, SM2GM"); + } + } + + void createMfenceOp(ConversionPatternRewriter &rewriter, + mlir::Location &loc) const { + // The magic number 5(101) of MfenceOp means mfencing on LM and GM + rewriter.create(loc, i32_val(5)); + } + + Value getStartPtr(ConversionPatternRewriter &rewriter, mlir::MLIRContext *ctx, + mlir::Location &loc, Value gmPtr, Value zeroPtr, + Value rowLen, Value elemBytes) const { + Value gmPtrInt = ptrtoint(i64_ty, gmPtr); + Value zeroPtrInt = ptrtoint(i64_ty, zeroPtr); + Value offset = sdiv(sub(gmPtrInt, zeroPtrInt), elemBytes); + Value startOffsetBytes = mul(mul(sdiv(offset, rowLen), rowLen), elemBytes); + Value startPtr = gep(ptr_ty(ctx, 0), i8_ty, zeroPtr, + startOffsetBytes); // convert ptr first, then move + return startPtr; + } + + void lowerLocallyContinuousUnfixedStride( + Operation *op, Location loc, ConversionPatternRewriter &rewriter, + int64_t _rowLen, int64_t _bufLen, int64_t _elemBytes, Value llGMPtr, + Value llLMPtr, Value llLen, Value offsetBytes, MemCpyType memCpyType, + Block *oldBlock, Block *newBlock) const { + // clang-format off + /* ***************************************************************************** + def getStartPtr(gmPtr, zeroPtr, rowLen, elemBytes): + offset = (gmPtr - zeroPtr) / elemBytes + startOffsetBytes = (offset / rowLen) * rowLen * elemBytes + return zeroPtr + startOffsetBytes + + _rowMaxTail = _bufLen % _rowLen + _rowNum = _bufLen / _rowLen + rowBytes = rowLen * elemBytes + tailLen = min(rowLen - (gmPtr.front() - zeroPtr) / elemBytes % rowLen, bufLen) + if _rowMaxTail == 0: + for i in range(_rowNum): + gmStartPtr = llGMPtrs[i * _rowLen] + lmOffsetBytes = (i * _rowLen) * elemBytes + lmStartPtr = lmPtr + lmOffsetBytes; + gm2lm(gmStartPtr, lmStartPtr, remainBytes) + else: + if 0 < tailLen < rowMaxTail: + gm2lm(gmPtr.front(), lmPtr, tailBytes) + for i in range(_rowNum): + gmStartPtr = getStartPtr(gmPtr[_rowMaxTail+i*_rowLen], zeroPtr, rowLen, elemBytes) + lmOffsetBytes = (tailLen + i * rowLen) * elemBytes + lmStartPtr = lmPtr + lmOffsetBytes + gm2lm(gmStartPtr, lmStartPtr, rowBytes) + gmStartPtr = getStartPtr(gmPtr.back(), zeroPtr, rowLen, elemBytes) + offset = tailLen + rowNum * rowLen + lmOffsetBytes = offset * elemBytes + lmStartPtr = lmPtr + lmOffsetBytes + remainBytes = (bufLen - offset) * elemBytes + gm2lm(gmStartPtr, lmStartPtr, remainBytes) + else: + gm2lm(gmPtr.front(), lmPtr, tailBytes) + if _rowNum >= 1: + for i in range(_rowNum-1): + gmPtr1 = gmPtr[_rowMaxTail+i*_rowLen] + gmPtr2 = gmPtr[_rowMaxTail+(i+1)*_rowLen] + gmPtr = select(tailLen == rowMaxTail, gmPtr1, gmPtr2) + gmStartPtr = getStartPtr(gmPtr[_rowMaxTail+(i+1)*_rowLen], zeroPtr, rowLen, elemBytes) + lmOffsetBytes = (tailLen + i * rowLen) * elemBytes + lmStartPtr = lmPtr + lmOffsetBytes + gm2lm(gmStartPtr, lmStartPtr, rowBytes) + gmStartPtr = getStartPtr(gmPtr.back(), zeroPtr, rowLen, elemBytes) + offset = tailLen + (rowNum - 1) * rowLen + lmOffsetBytes = offset * elemBytes + lmStartPtr = lmPtr + lmOffsetBytes + remainBytes = (bufLen - offset) * elemBytes + gm2lm(gmStartPtr, lmStartPtr, remainBytes) + ********************************************************************************/ + // clang-format on + MLIRContext *ctx = rewriter.getContext(); + + auto llGMPtrs = unpackLLElements(loc, llGMPtr, rewriter); + auto llLMPtrs = unpackLLElements(loc, llLMPtr, rewriter); + Value gmFrontPtr = llGMPtrs.front(); + Value gmBackPtr = llGMPtrs.back(); + Value lmPtr = llLMPtrs.front(); + + auto zeroOp = findDefOpBwd(gmFrontPtr); + Value zeroPtr = cast(zeroOp).getBase(); + Value zeroPtrInt = ptrtoint(i64_ty, zeroPtr); + Value gmFrontPtrInt = ptrtoint(i64_ty, gmFrontPtr); + + int64_t _rowMaxTail = _bufLen % _rowLen; + int64_t _rowNum = _bufLen / _rowLen; + Value rowMaxTail = i64_val(_rowMaxTail); + Value rowNum = i64_val(_rowNum); + Value rowLen = i64_val(_rowLen); + Value bufLen = i64_val(_bufLen); + Value elemBytes = i64_val(_elemBytes); + Value rowBytes = trunc(i32_ty, mul(rowLen, elemBytes)); + + if (_rowMaxTail == 0) { + // GM2LM/LM2GM Row Data + for (int64_t i = 0; i < _rowNum; ++i) { + Value gmStartPtr = llGMPtrs[i * _rowLen]; + Value lmOffsetBytes = mul(i64_val(i * _rowLen), elemBytes); + Value lmStartPtr = gep(ptr_ty(ctx, 0), i8_ty, lmPtr, lmOffsetBytes); + createMemOp(rewriter, ctx, loc, gmStartPtr, lmStartPtr, offsetBytes, + rowBytes, memCpyType); + } + } else { + Value gmFrontOffset = sdiv(sub(gmFrontPtrInt, zeroPtrInt), elemBytes); + Value tailLen = smin(sub(rowLen, srem(gmFrontOffset, rowLen)), bufLen); + + Block *thenBB = rewriter.createBlock(newBlock); + Block *elseBB = rewriter.createBlock(newBlock); + Block *mfenceBB = rewriter.createBlock(newBlock); + rewriter.setInsertionPointToEnd(oldBlock); + + Value condTailSgt = icmp_sgt(tailLen, i64_val(0)); + Value condTailSlt = icmp_slt(tailLen, rowMaxTail); + Value condTailDiff = and_(condTailSgt, condTailSlt); + rewriter.create(loc, condTailDiff, thenBB, elseBB); + // 1. ThenBB + rewriter.setInsertionPointToEnd(thenBB); + { + // 1.1 GM2LM/LM2GM Tail Data + Value tailBytes = trunc(i32_ty, mul(tailLen, elemBytes)); + createMemOp(rewriter, ctx, loc, gmFrontPtr, lmPtr, offsetBytes, + tailBytes, memCpyType); + // 1.2 GM2LM/LM2GM Row Data + for (int64_t i = 0; i < _rowNum; ++i) { + Value gmPtr = llGMPtrs[_rowMaxTail + i * _rowLen]; + Value gmStartPtr = getStartPtr(rewriter, ctx, loc, gmPtr, zeroPtr, + rowLen, elemBytes); + Value lmOffsetBytes = + mul(add(tailLen, i64_val(i * _rowLen)), elemBytes); + Value lmStartPtr = gep(ptr_ty(ctx, 0), i8_ty, lmPtr, lmOffsetBytes); + createMemOp(rewriter, ctx, loc, gmStartPtr, lmStartPtr, offsetBytes, + rowBytes, memCpyType); + } + // 1.3 GM2LM/LM2GM Remain Data + Value gmPtr = llGMPtrs.back(); + Value gmStartPtr = + getStartPtr(rewriter, ctx, loc, gmPtr, zeroPtr, rowLen, elemBytes); + Value offset = add(tailLen, i64_val(_rowNum * _rowLen)); + Value lmOffsetBytes = mul(offset, elemBytes); + Value lmStartPtr = gep(ptr_ty(ctx, 0), i8_ty, lmPtr, lmOffsetBytes); + Value remainBytes = trunc(i32_ty, mul(sub(bufLen, offset), elemBytes)); + createMemOp(rewriter, ctx, loc, gmStartPtr, lmStartPtr, offsetBytes, + remainBytes, memCpyType); + } + rewriter.create(loc, ValueRange{}, + mfenceBB); // Jump to mfenceBB + + // 2. elseBB + rewriter.setInsertionPointToEnd(elseBB); + { + // 1.1 GM2LM/LM2GM Tail Data + Value tailBytes = trunc(i32_ty, mul(tailLen, elemBytes)); + createMemOp(rewriter, ctx, loc, gmFrontPtr, lmPtr, offsetBytes, + tailBytes, memCpyType); + if (_rowNum >= 1) { + // 1.2 GM2LM/LM2GM Row Data + Value gmCond = icmp_eq(tailLen, rowMaxTail); + for (int64_t i = 0; i < _rowNum - 1; ++i) { + Value gmPtr1 = llGMPtrs[_rowMaxTail + i * _rowLen]; + Value gmPtr2 = llGMPtrs[_rowMaxTail + (i + 1) * _rowLen]; + Value gmPtr = select(gmCond, gmPtr1, gmPtr2); + Value gmStartPtr = getStartPtr(rewriter, ctx, loc, gmPtr, zeroPtr, + rowLen, elemBytes); + Value lmOffsetBytes = + mul(add(tailLen, i64_val(i * _rowLen)), elemBytes); + Value lmStartPtr = gep(ptr_ty(ctx, 0), i8_ty, lmPtr, lmOffsetBytes); + createMemOp(rewriter, ctx, loc, gmStartPtr, lmStartPtr, offsetBytes, + rowBytes, memCpyType); + } + // 1.3 GM2LM/LM2GM Remain Data + Value gmPtr = llGMPtrs.back(); + Value gmStartPtr = getStartPtr(rewriter, ctx, loc, gmPtr, zeroPtr, + rowLen, elemBytes); + Value offset = add(tailLen, i64_val((_rowNum - 1) * _rowLen)); + Value lmOffsetBytes = mul(offset, elemBytes); + Value lmStartPtr = gep(ptr_ty(ctx, 0), i8_ty, lmPtr, lmOffsetBytes); + Value remainBytes = + trunc(i32_ty, mul(sub(bufLen, offset), elemBytes)); + createMemOp(rewriter, ctx, loc, gmStartPtr, lmStartPtr, offsetBytes, + remainBytes, memCpyType); + } + } + rewriter.create(loc, ValueRange{}, + mfenceBB); // Jump to mfenceBB + + // 3. mefenceBB + rewriter.setInsertionPointToEnd(mfenceBB); + } + } + + void lowerLocallyContinuousLargeRow(Operation *op, Location loc, + ConversionPatternRewriter &rewriter, + size_t rowSize, size_t rowStride, + Value llGMPtr, Value llLMPtr, Value llLen, + Value bufLen, Value elemBytes, + Value offsetBytes, MemCpyType memCpyType, + Block *oldBlock, Block *newBlock) const { + + /* ************************************************* + gapLen = strideLen - rowLen + bankOffset = (bankPtrInt - zeroPtrInt) / elemBytes + rowOffset = bankOffset / strideLen * strideLen + blockOffset = ((bankOffset - rowOffset) / rowLen) * rowLen + realTailLen = rowLen - (bankOffset - (blockOffset + rowOffset)) + + if 0 < realTailLen < bufLen: + gm2lm(bankPtr, lmPtr, realTailLen * elemBytes) + gm2lm(bankPtr + (realTailLen + gapLen) * elemBytes, lmPtr + realTailLen, + elemBytes,(bufLen - realTailLen)* elemBytes) + + else : + gm2lm(bankPtr, lmPtr, bufLen * elemBytes) + * ************************************************/ + + MLIRContext *ctx = rewriter.getContext(); + + auto llGMPtrs = unpackLLElements(loc, llGMPtr, rewriter); + auto llLMPtrs = unpackLLElements(loc, llLMPtr, rewriter); + auto bankPtr = llGMPtrs[0]; + auto lmBuf = llLMPtrs[0]; + if (bufLen.getType().isInteger(64)) { + bufLen = trunc(i32_ty, bufLen); + } + + auto zeroOp = findDefOpBwd(bankPtr); + auto zeroPtr = cast(zeroOp).getBase(); + Value zeroPtrInt = ptrtoint(i64_ty, zeroPtr); + Value bankPtrInt = ptrtoint(i64_ty, bankPtr); + + size_t gapSize = rowStride - rowSize; + Value rowLen = i32_val(rowSize); + Value strideLen = i32_val(rowStride); + Value gapLen = i32_val(gapSize); + Value gapBytes = mul(gapLen, elemBytes); + Value bankOffset = + sdiv(trunc(i32_ty, sub(bankPtrInt, zeroPtrInt)), elemBytes); + Value rowOffset = rowStride == 0 + ? i32_val(0) + : mul(sdiv(bankOffset, strideLen), strideLen); + Value blockOffset = + rowStride == 0 ? i32_val(0) + : mul(sdiv(sub(bankOffset, rowOffset), rowLen), rowLen); + Value realTailLen = + sub(rowLen, sub(bankOffset, add(blockOffset, rowOffset))); + Value realTailBytes = mul(realTailLen, elemBytes); + + zeroPtr = gep(ptr_ty(ctx, 1), i8_ty, zeroPtr, i32_val(0)); + bankPtr = gep(ptr_ty(ctx, 1), i8_ty, bankPtr, i32_val(0)); + Value lmPtr = gep(ptr_ty(ctx, 0), i8_ty, lmBuf, i32_val(0)); + + Block *thenBB = rewriter.createBlock(newBlock); + Block *elseBB = rewriter.createBlock(newBlock); + Block *mfenceBB = rewriter.createBlock(newBlock); + rewriter.setInsertionPointToEnd(oldBlock); + + Value condRemSgt = icmp_sgt(realTailLen, i32_val(0)); + Value condRemSlt = icmp_slt(realTailLen, bufLen); + Value condRemDiff = and_(condRemSgt, condRemSlt); + rewriter.create(loc, condRemDiff, thenBB, elseBB); + rewriter.setInsertionPointToEnd(thenBB); + // 1. ThenBB + // 1.1 GM2LM Tail Data + Value tailLen = realTailLen; + if (llLen) { + auto llLens = unpackLLElements(loc, llLen, rewriter); + if (llLens[0].getType().isInteger(64)) { + Value limitedLen = + smin(smax(llLens[0], i64_val(0)), sext(i64_ty, bufLen)); + tailLen = smin(realTailLen, trunc(i32_ty, limitedLen)); + } else if (llLens[0].getType().isInteger(1)) { + Value limitedLen = bufLen; + tailLen = smin(realTailLen, limitedLen); + } else { + Value limitedLen = smin(smax(llLens[0], i32_val(0)), bufLen); + tailLen = smin(realTailLen, limitedLen); + } + } + Value tailBytes = mul(tailLen, elemBytes); + createMemOp(rewriter, ctx, loc, bankPtr, lmPtr, offsetBytes, tailBytes, + memCpyType); + + // 1.2 GM2LM Remain Data + Value startCond; + if (llLen) { + auto llLens = unpackLLElements(loc, llLen, rewriter); + if (llLens[0].getType().isInteger(64)) { + startCond = icmp_sge(sext(i64_ty, realTailLen), llLens[0]); + } else if (llLens[0].getType().isInteger(1)) { + startCond = icmp_sge(realTailLen, bufLen); + } else { + startCond = icmp_sge(realTailLen, llLens[0]); + } + } + Value startPtrInt = + add(bankPtrInt, zext(i64_ty, add(realTailBytes, gapBytes))); + Value startPtr = + rowStride == 0 ? zeroPtr : inttoptr(ptr_ty(ctx, 1), startPtrInt); + startPtr = startCond ? select(startCond, zeroPtr, startPtr) : startPtr; + Value dstStartPtr = gep(ptr_ty(ctx, 0), i8_ty, lmPtr, + realTailBytes); // convert ptr first, then move + + Value remainLen = sub(bufLen, realTailLen); + Value remainBytes = mul(remainLen, elemBytes); + createMemOp(rewriter, ctx, loc, startPtr, dstStartPtr, offsetBytes, + remainBytes, memCpyType); + rewriter.create(loc, ValueRange{}, + mfenceBB); // Jump to mfenceBB + + // 2. elseBB + rewriter.setInsertionPointToEnd(elseBB); + // GM2LM the whole bufLen + Value readBytes = mul(bufLen, elemBytes); + createMemOp(rewriter, ctx, loc, bankPtr, lmPtr, offsetBytes, readBytes, + memCpyType); + rewriter.create(loc, ValueRange{}, + mfenceBB); // Jump to mfenceBB + + // 3. mefenceBB + rewriter.setInsertionPointToEnd(mfenceBB); + } + + void lowerLocallyContinuousSmallRow(Operation *op, Location loc, + ConversionPatternRewriter &rewriter, + size_t rowSize, size_t rowStride, + Value llGMPtr, Value llLMPtr, Value llLen, + Value bufLen, Value elemBytes, + Value offsetBytes, MemCpyType memCpyType, + Block *oldBlock, Block *newBlock) const { + + /* ************************************************* + bankOffset = (bankPtrInt - zeroPtrInt) / elemBytes + rowOffset = bankOffset / strideLen * strideLen + blockOffset = ((bankOffset - rowOffset) / rowLen) + rowHeadLen = bankOffset - (blockOffset + rowOffset) + realTailLen = rowLen - rowHeadLen + rowNum = (bufLen - realTailLen - 1) / rowLen + + gm2lm(bankPtr, lmPtr, realTailLen * elemBytes) + + for(i = 0; i < rowNum; i++) { + gm2lm(bankPtr + ((i + 1) * strideLen - rowHeadLen) * elemBytes, lmPtr + + (realTailLen + i * rowLen) * elemBytes, rowLen * elemBytes) + } + + remLen = bufLen - realTailLen - rowNum * rowLen + gm2lm(bankPtr + ((rowNum + 1) * strideLen - rowHeadLen) * elemBytes, lmPtr + + (realTailLen + rowNum * rowLen) * elemBytes, (remLen * elemBytes) + *************************************************/ + + MLIRContext *ctx = rewriter.getContext(); + + auto llGMPtrs = unpackLLElements(loc, llGMPtr, rewriter); + auto llLMPtrs = unpackLLElements(loc, llLMPtr, rewriter); + auto bankPtr = llGMPtrs[0]; + auto lmBuf = llLMPtrs[0]; + if (bufLen.getType().isInteger(64)) { + bufLen = trunc(i32_ty, bufLen); + } + auto zeroOp = findDefOpBwd(bankPtr); + auto zeroPtr = cast(zeroOp).getBase(); + Value zeroPtrInt = ptrtoint(i64_ty, zeroPtr); + Value bankPtrInt = ptrtoint(i64_ty, bankPtr); + + Value rowLen = i32_val(rowSize); + Value strideLen = i32_val(rowStride); + Value bankOffset = + sdiv(trunc(i32_ty, sub(bankPtrInt, zeroPtrInt)), elemBytes); + Value rowOffset = rowStride == 0 + ? i32_val(0) + : mul(sdiv(bankOffset, strideLen), strideLen); + Value blockOffset = + rowStride == 0 ? i32_val(0) + : mul(sdiv(sub(bankOffset, rowOffset), rowLen), rowLen); + Value realTailLen = + sub(rowLen, sub(bankOffset, add(blockOffset, rowOffset))); + Value realTailBytes = mul(realTailLen, elemBytes); + Value rowBytes = mul(rowLen, elemBytes); + Value rowHeadLen = sub(rowLen, realTailLen); + Value rowHeadBytes = sub(rowBytes, realTailBytes); + Value realRemainLen = sub(sub(bufLen, realTailLen), i32_val(1)); + Value rowNum = sdiv(realRemainLen, rowLen); + + zeroPtr = gep(ptr_ty(ctx, 1), i8_ty, zeroPtr, i32_val(0)); + bankPtr = gep(ptr_ty(ctx, 1), i8_ty, bankPtr, i32_val(0)); + Value lmPtr = gep(ptr_ty(ctx, 0), i8_ty, lmBuf, i32_val(0)); + + Block *judgeBB = rewriter.createBlock(newBlock, TypeRange{i32_ty}, {loc}); + Block *gm2lmRowBB = rewriter.createBlock(newBlock); + Block *stepBB = rewriter.createBlock(newBlock); + Block *gm2lmRemBB = rewriter.createBlock(newBlock); + + // 1. GM2LM Tail Data + rewriter.setInsertionPointToEnd(oldBlock); + Value tailLen = realTailLen; + if (llLen) { + auto llLens = unpackLLElements(loc, llLen, rewriter); + if (llLens[0].getType().isInteger(64)) { + tailLen = smin(realTailLen, trunc(i32_ty, llLens[0])); + } else { + tailLen = smin(realTailLen, llLens[0]); + } + } + Value tailBytes = mul(tailLen, elemBytes); + createMemOp(rewriter, ctx, loc, bankPtr, lmPtr, offsetBytes, tailBytes, + memCpyType); + + Value _init = i32_val(0); + Value _step = i32_val(1); + rewriter.create(loc, ValueRange{_init}, + judgeBB); // Jump to judgeBB + Value iter = judgeBB->getArgument(0); + + // 2. GM2LM Row Data + rewriter.setInsertionPointToEnd(judgeBB); + Value condSlt = icmp_slt(iter, rowNum); + rewriter.create(loc, condSlt, gm2lmRowBB, gm2lmRemBB); + + rewriter.setInsertionPointToEnd(gm2lmRowBB); + Value skipStride = mul(add(iter, i32_val(1)), strideLen); + Value skipStrideBytes = mul(skipStride, elemBytes); + Value skipRowLen = mul(iter, rowLen); + Value startPtrInt = + add(bankPtrInt, zext(i64_ty, sub(skipStrideBytes, rowHeadBytes))); + Value startPtr = + rowStride == 0 ? zeroPtr : inttoptr(ptr_ty(ctx, 1), startPtrInt); + startPtr = gep(ptr_ty(ctx, 1), i8_ty, startPtr, i32_val(0)); + Value startCond; + if (llLen) { + auto llLens = unpackLLElements(loc, llLen, rewriter); + if (llLens[0].getType().isInteger(64)) { + startCond = + icmp_sgt(sext(i64_ty, add(skipRowLen, realTailLen)), llLens[0]); + } else { + startCond = icmp_sgt(add(skipRowLen, realTailLen), llLens[0]); + } + } + startPtr = startCond ? select(startCond, zeroPtr, startPtr) : startPtr; + Value dstOffset = add(realTailLen, skipRowLen); + Value dstOffsetBytes = mul(dstOffset, elemBytes); + Value dstStartPtr = gep(ptr_ty(ctx, 0), i8_ty, lmPtr, + dstOffsetBytes); // convert ptr first, then move + createMemOp(rewriter, ctx, loc, startPtr, dstStartPtr, offsetBytes, + rowBytes, memCpyType); + rewriter.create(loc, ValueRange{}, stepBB); // Jump to stepBB + + rewriter.setInsertionPointToEnd(stepBB); + Value _index = add(iter, _step); + rewriter.create(loc, ValueRange{_index}, + judgeBB); // Jump back to judgeBB + + // 3 GM2LM Remain Data + rewriter.setInsertionPointToEnd(gm2lmRemBB); + { + Value skipStride = mul(add(rowNum, i32_val(1)), strideLen); + Value skipStrideBytes = mul(skipStride, elemBytes); + Value skipRowLen = mul(rowNum, rowLen); + Value remainBytes = + mul(sub(bufLen, add(realTailLen, skipRowLen)), elemBytes); + Value startPtrInt = + add(bankPtrInt, zext(i64_ty, sub(skipStrideBytes, rowHeadBytes))); + Value startPtr = + rowStride == 0 ? zeroPtr : inttoptr(ptr_ty(ctx, 1), startPtrInt); + startPtr = gep(ptr_ty(ctx, 1), i8_ty, startPtr, i32_val(0)); + Value startCond; + if (llLen) { + auto llLens = unpackLLElements(loc, llLen, rewriter); + if (llLens[0].getType().isInteger(64)) { + startCond = + icmp_sgt(sext(i64_ty, add(skipRowLen, realTailLen)), llLens[0]); + } else { + startCond = icmp_sgt(add(skipRowLen, realTailLen), llLens[0]); + } + } + startPtr = startCond ? select(startCond, zeroPtr, startPtr) : startPtr; + Value dstOffset = add(realTailLen, skipRowLen); + Value dstOffsetBytes = mul(dstOffset, elemBytes); + Value dstStartPtr = gep(ptr_ty(ctx, 0), i8_ty, lmPtr, + dstOffsetBytes); // convert ptr first, then move + createMemOp(rewriter, ctx, loc, startPtr, dstStartPtr, offsetBytes, + remainBytes, memCpyType); + } + } + +protected: + const xpu::TargetInfo &targetInfo; + ModuleAxisInfoAnalysis &axisAnalysisPass; + bool isBf16RoundToMid = false; +}; + +struct XPULoadOpConversion : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + XPULoadOpConversion(LLVMTypeConverter &converter, + const xpu::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + void VecBF16ToFP32Unordered(triton::xpu::LoadOp op, mlir::MLIRContext *ctx, + Location &loc, + ConversionPatternRewriter &rewriter, + Type &resElemTy, int numElems, int resVecSize, + int ptrDataVecSize, Value &lmBasePtr, + SmallVector &loadedVals) const { + VectorType vecBf16Ty = VectorType::get(ptrDataVecSize, bf16_ty); + VectorType veci16Ty = VectorType::get(ptrDataVecSize, i16_ty); + VectorType veci32Ty = VectorType::get(resVecSize, i32_ty); + VectorType vec1Ty = VectorType::get(ptrDataVecSize, i1_ty); + VectorType halfVecBf16Ty = VectorType::get(resVecSize, bf16_ty); + VectorType VecFp16Ty = VectorType::get(ptrDataVecSize, f16_ty); + lmBasePtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); + int mask = 0xaaaaaaaa; + Value maskVal = i32_val(mask); + maskVal = bitcast(maskVal, vec1Ty); + Value maskNegVal = i32_val(~mask); + maskNegVal = bitcast(maskNegVal, vec1Ty); + int16_t pad = 0x8000; + Value padVec = rewriter.create(loc, veci16Ty); + for (size_t elemIdx = 0; elemIdx < ptrDataVecSize; ++elemIdx) { + padVec = insert_element(veci16Ty, padVec, i16_val(pad), i16_val(elemIdx)); + } + for (int i = 0; i < numElems / 2; ++i) { + Value elemPtr = gep(ptr_ty(ctx, 0), vecBf16Ty, lmBasePtr, i32_val(i)); + Value veven; + if (isBf16RoundToMid) { + veven = rewriter.create( + loc, veci16Ty, elemPtr, padVec, maskVal); + } else { + veven = rewriter.create(loc, veci16Ty, + elemPtr, maskVal); + } + veven = bitcast(veven, resElemTy); + loadedVals.emplace_back(veven); + Value vodd; + if (isBf16RoundToMid) { + vodd = rewriter.create( + loc, veci16Ty, elemPtr, padVec, maskNegVal); + } else { + vodd = rewriter.create( + loc, veci16Ty, elemPtr, maskNegVal); + } + vodd = bitcast(vodd, VecFp16Ty); + Value voddSl = + rewriter.create(loc, VecFp16Ty, vodd); + voddSl = bitcast(voddSl, resElemTy); + loadedVals.emplace_back(voddSl); + } + if (numElems % 2 == 1) { + int remainedIdx = numElems - 1; + lmBasePtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); + Value elemPtr = + gep(ptr_ty(ctx, 0), halfVecBf16Ty, lmBasePtr, i32_val(remainedIdx)); + Value loaded = load(halfVecBf16Ty, elemPtr); + loaded = rewriter.create(loc, resElemTy, loaded); + loadedVals.emplace_back(loaded); + } + return; + } + + void VecBF16ToFP32(triton::xpu::LoadOp op, mlir::MLIRContext *ctx, + Location &loc, ConversionPatternRewriter &rewriter, + Type &resElemTy, int numElems, int resVecSize, + int ptrDataVecSize, SmallVector &loadedVals) const { + VectorType vecFp16Ty = VectorType::get(ptrDataVecSize, f16_ty); + Value padVec = rewriter.create(loc, vecFp16Ty); + int16_t pad = isBf16RoundToMid ? 0x8000 : 0; + for (size_t elemIdx = 0; elemIdx < ptrDataVecSize; ++elemIdx) { + padVec = + insert_element(vecFp16Ty, padVec, f16_val(pad), i16_val(elemIdx)); + } + SmallVector newLoadedVals; + for (int i = 0; i < numElems / 2; ++i) { + Value val = bitcast(loadedVals[i], vecFp16Ty); + Value vl = rewriter.create(loc, vecFp16Ty, + padVec, val); + vl = bitcast(vl, resElemTy); + newLoadedVals.emplace_back(vl); + Value vh = rewriter.create(loc, vecFp16Ty, + padVec, val); + vh = bitcast(vh, resElemTy); + newLoadedVals.emplace_back(vh); + } + if (numElems % 2 == 1) { + int remainedIdx = numElems - 1; + Value ext = rewriter.create(loc, resElemTy, + loadedVals[remainedIdx]); + newLoadedVals.emplace_back(ext); + } + loadedVals = newLoadedVals; + return; + } + + LogicalResult + matchAndRewrite(triton::xpu::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + + // original values + Value res = op.getResult(); + Value ptr = op.getPtr(); + Value mask = op.getMask(); + Value other = op.getOther(); + Value index = op.getIndex(); + + int32_t stride = op.getStride(); + int32_t tensor_col_size = op.getTensorColSize(); + bool coreDealMultiRows = tensor_col_size != -1; + bool isDiscreteSame = (stride == 0); + bool isUnknown = stride != 0 && stride != 1 && !op.getIsDiscrete(); + bool bf16Tofp32Unordered = op.getBf16Tofp32Unordered(); + + LDBG("Lower LoadOp for " << ptr); + + // adaptor values + assert(!isTensorPointerType(ptr.getType()) && + "Cannot convert load with a tensor pointer into LLVM; " + "this case should be transformed to normal load before lowering"); + Value llPtr = adaptor.getPtr(); + Value llMask = adaptor.getMask(); + Value llOther = adaptor.getOther(); + Value llIndex = adaptor.getIndex(); + + // Determine Type + Type ptrTy = ptr.getType(); + Type resTy = res.getType(); + + Type ptrElemTy = typeConverter->convertType(getElementTypeOrSelf(ptrTy)); + Type resElemTy = typeConverter->convertType(getElementTypeOrSelf(resTy)); + Type ptrDataVecTy = resElemTy; + + unsigned ptrNumElems = getTotalElemsPerThread(ptrTy); + unsigned resNumElems = getTotalElemsPerThread(resTy); + + Type ptrElemScalarTy; + if (auto ptrTensorTy = mlir::dyn_cast(ptrTy)) { + // Tensor + ptrElemScalarTy = + mlir::cast(ptrTensorTy.getElementType()) + .getPointeeType(); + } else { + // Scalar + ptrElemScalarTy = mlir::cast(ptrTy).getPointeeType(); + } + + Type resElemScalarTy = getElementTypeOrSelf(resElemTy); + + // Get the LLVM values + auto llPtrs = unpackLLElements(loc, llPtr, rewriter); + stride = + (stride >= 0 && ptrNumElems * stride <= targetInfo.getXPUBufferSize()) + ? stride + : 1; + + assert(llPtrs.size() == ptrNumElems); + bool isVectorized = false; + unsigned vecSize = 1u; + unsigned elemNbits = + isa(resElemScalarTy) + ? 64u + : resElemScalarTy.getIntOrFloatBitWidth(); + if (mlir::isa(resElemTy)) { + isVectorized = true; + getVectorInfo(resElemTy, vecSize, elemNbits); + } + + // fp16Tofp32 + if (resElemScalarTy.isF32() && ptrElemScalarTy.isF16()) { + Value fp16LM = bitcast(llPtrs[0], ptr_ty(ctx, 0)); + Value fp32LM = bitcast(llPtrs[0], ptr_ty(ctx, 0)); + ValueRange singleOperandRange( + {fp16LM, fp32LM, i32_val(ptrNumElems * stride)}); + mlir::LLVM::XPU::createDeviceCall("_ZN3xpu10fp16tofp32EPKNS_7float16EPfi", + rewriter, op, singleOperandRange, loc); + createMfenceOp(rewriter, loc); + ptrElemScalarTy = resElemScalarTy; + } + // bf16Tofp32 + bool bf16Tofp32 = false; + if (resElemScalarTy.isF32() && ptrElemScalarTy.isBF16()) { + int ptrVecSize = std::min(ptrNumElems, vecSize * 2); + ptrDataVecTy = isVectorized ? VectorType::get(ptrVecSize, ptrElemScalarTy) + : ptrElemScalarTy; + bf16Tofp32 = true; + } + + unsigned ptrDataVecSize = 1u; + unsigned ptrDataNbits = + isa(ptrElemScalarTy) + ? 64u + : ptrElemScalarTy.getIntOrFloatBitWidth(); + if (mlir::isa(ptrDataVecTy)) { + getVectorInfo(ptrDataVecTy, ptrDataVecSize, ptrDataNbits); + } + + SmallVector loadedVals; + Value lmBasePtr = bitcast(llPtrs[0], ptr_ty(ctx, 0)); + if (index) { + ptrNumElems = resNumElems; + unsigned _stride = (bf16Tofp32 && isVectorized) ? ptrNumElems * stride / 2 + : ptrNumElems * stride; + Value idx = mul(llIndex, i32_val(_stride)); + lmBasePtr = gep(ptr_ty(ctx, 0), ptrDataVecTy, lmBasePtr, idx); + } + + if (op.getSVOpt()) { + Value elemPtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); + Value loaded = load(ptrElemScalarTy, elemPtr); + if (bf16Tofp32) { + loaded = rewriter.create(loc, resElemScalarTy, loaded); + } + loadedVals.push_back(loaded); + } else if (isDiscreteSame) { + Value elemPtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); + Value loaded = load(ptrElemScalarTy, elemPtr); + if (bf16Tofp32) { + loaded = rewriter.create(loc, resElemScalarTy, loaded); + } + for (size_t elemIdx = 0; elemIdx < resNumElems; elemIdx++) { + if (isVectorized) { + Value newVector = rewriter.create(loc, resElemTy); + for (size_t idx = 0; idx < vecSize; ++idx) { + newVector = + insert_element(resElemTy, newVector, loaded, i32_val(idx)); + } + loadedVals.push_back(newVector); + } else { + loadedVals.push_back(loaded); + } + } + } else if (op.getIsDiscrete()) { + if (isVectorized) { + for (size_t vecIdx = 0; vecIdx < resNumElems; ++vecIdx) { + Value newVector = rewriter.create(loc, resElemTy); + for (size_t elemIdx = 0; elemIdx < vecSize; ++elemIdx) { + auto idx = vecIdx * vecSize + elemIdx; + Value elemPtr = bitcast(llPtrs[idx], ptr_ty(ctx, 0)); + Value loaded = load(ptrElemScalarTy, elemPtr); + if (bf16Tofp32) { + loaded = + rewriter.create(loc, resElemScalarTy, loaded); + } + // insert val to newVector + newVector = + insert_element(resElemTy, newVector, loaded, i32_val(elemIdx)); + } + loadedVals.push_back(newVector); + } + } else { + for (size_t elemIdx = 0; elemIdx < resNumElems; elemIdx++) { + Value elemPtr = bitcast(llPtrs[elemIdx], ptr_ty(ctx, 0)); + Value loaded = load(ptrElemScalarTy, elemPtr); + if (bf16Tofp32) { + loaded = + rewriter.create(loc, resElemScalarTy, loaded); + } + loadedVals.push_back(loaded); + } + } + } else if (stride > 1 && isVectorized) { + // Vgather + VectorType offsetTy = + VectorType::get(ptrDataVecSize, int_ty(ptrDataNbits)); + Value offsetVec = rewriter.create(loc, offsetTy); + for (size_t elemIdx = 0; elemIdx < ptrDataVecSize; ++elemIdx) { + Value offsetVal = + int_val(ptrDataNbits, (ptrDataNbits / 8u) * stride * elemIdx); + offsetVec = insert_element(offsetTy, offsetVec, offsetVal, + int_val(ptrDataNbits, elemIdx)); + } + lmBasePtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); + for (size_t vecIdx = 0; + vecIdx < (index ? ptrNumElems : (ptrNumElems / ptrDataVecSize)); + ++vecIdx) { + Value vecPtr = gep(ptr_ty(ctx, 0), ptrDataVecTy, lmBasePtr, + int_val(ptrDataNbits, vecIdx * stride)); + Value tmpPtr = bitcast(vecPtr, ptr_ty(ctx, 0)); + Value vgather; + if (ptrElemScalarTy.isF32()) { + vgather = rewriter.create( + loc, offsetTy, tmpPtr, offsetVec); + } else if (ptrElemScalarTy.isF16() || ptrElemScalarTy.isBF16()) { + vgather = rewriter.create( + loc, offsetTy, tmpPtr, offsetVec); + } else { + llvm_unreachable("Only support FP16/BF16/FP32 in VGather!"); + } + Value loaded = bitcast(vgather, resElemTy); + loadedVals.push_back(loaded); + } + if (bf16Tofp32) { + VecBF16ToFP32(op, ctx, loc, rewriter, resElemTy, resNumElems, vecSize, + ptrDataVecSize, loadedVals); + } + } else { // Continuous || Unknown(No VGather) + if (coreDealMultiRows && !isUnknown) { + // Unknown Manipulation in GM2LMOp Conversion + /* Small Col Size Opt GM2LM (14 legal data) + + Before Opt: + 1 1 1 1 1 1 1 1 + 1 1 1 1 1 1 0 0 + + After Opt: + 1 1 1 1 1 1 1 0 + 1 1 1 1 1 1 1 0 + */ + Value bufPtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); + auto mem_col_size = + mlir::cast(resTy).getShape()[1] * vecSize; + auto tensor_row_size = + mlir::cast(resTy).getShape()[0] / 64; + unsigned rowRemainElem = mem_col_size - tensor_col_size; + + for (size_t row_idx = 0; row_idx < tensor_row_size; ++row_idx) { + for (size_t col_idx = 0; col_idx < tensor_col_size; ++col_idx) { + auto buf_global_idx = row_idx * tensor_col_size + col_idx; + Value elemPtr = gep(ptr_ty(ctx, 0), ptrElemScalarTy, bufPtr, + i32_val(buf_global_idx)); + Value loaded = load(ptrElemScalarTy, elemPtr); + if (bf16Tofp32) { + loaded = + rewriter.create(loc, resElemScalarTy, loaded); + } + loadedVals.push_back(loaded); + } + + for (size_t remainElem = rowRemainElem; remainElem > 0; + --remainElem) { + Value loaded = int_val(ptrDataNbits, 0); + loaded = bitcast(loaded, ptrElemScalarTy); + if (bf16Tofp32) { + loaded = + rewriter.create(loc, resElemScalarTy, loaded); + } + loadedVals.push_back(loaded); + } + } + + if (isVectorized) { + SmallVector loadedVecVals; + for (size_t vecStart = 0; vecStart < resNumElems; ++vecStart) { + Value newVector = rewriter.create(loc, resElemTy); + for (size_t elemStart = 0; elemStart < vecSize; ++elemStart) { + // insert val to newVector + newVector = + insert_element(resElemTy, newVector, + loadedVals[vecStart * vecSize + elemStart], + i32_val(elemStart)); + } + loadedVecVals.push_back(newVector); + } + loadedVals = loadedVecVals; + } + } else { + if (isVectorized) { + if (bf16Tofp32) { + if (bf16Tofp32Unordered) { + VecBF16ToFP32Unordered(op, ctx, loc, rewriter, resElemTy, + resNumElems, vecSize, ptrDataVecSize, + lmBasePtr, loadedVals); + } else { + lmBasePtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); + for (size_t elemIdx = 0; elemIdx < resNumElems / 2; elemIdx++) { + Value elemPtr = gep(ptr_ty(ctx, 0), ptrDataVecTy, lmBasePtr, + i32_val(elemIdx * stride)); + Value loaded = load(ptrDataVecTy, elemPtr); + loadedVals.push_back(loaded); + } + int remainedIdx = 2 * (resNumElems / 2); + if (resNumElems - remainedIdx) { + VectorType halfVecBf16Ty = VectorType::get(vecSize, bf16_ty); + Value elemPtr = gep(ptr_ty(ctx, 0), halfVecBf16Ty, lmBasePtr, + i32_val(remainedIdx * stride)); + Value loaded = load(halfVecBf16Ty, elemPtr); + loadedVals.push_back(loaded); + } + VecBF16ToFP32(op, ctx, loc, rewriter, resElemTy, resNumElems, + vecSize, ptrDataVecSize, loadedVals); + } + } else { + for (size_t elemIdx = 0; elemIdx < resNumElems; elemIdx++) { + lmBasePtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); + Value elemPtr = gep(ptr_ty(ctx, 0), ptrDataVecTy, lmBasePtr, + i32_val(elemIdx * stride)); + Value loaded = load(ptrDataVecTy, elemPtr); + loadedVals.push_back(loaded); + } + } + } else { + for (size_t elemIdx = 0; elemIdx < ptrNumElems; elemIdx++) { + lmBasePtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); + Value elemPtr = gep(ptr_ty(ctx, 0), ptrElemScalarTy, lmBasePtr, + i32_val(elemIdx * stride)); + Value loaded = load(ptrElemScalarTy, elemPtr); + if (bf16Tofp32) { + loaded = + rewriter.create(loc, resElemScalarTy, loaded); + } + loadedVals.push_back(loaded); + } + } + } + } + + Type llvmResultStructTy = typeConverter->convertType(op.getType()); + Value resultStruct = packLLElements(loc, typeConverter, loadedVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct XPUStoreOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + XPUStoreOpConversion(LLVMTypeConverter &converter, + const xpu::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + void VecFP32ToBF16Unordered(triton::xpu::StoreOp op, mlir::MLIRContext *ctx, + Location &loc, + ConversionPatternRewriter &rewriter, int numElems, + int valueVecSize, int ptrDataVecSize, + SmallVector &valueElems, + Value &lmBasePtr) const { + VectorType vecBf16Ty = VectorType::get(ptrDataVecSize, bf16_ty); + VectorType veci16Ty = VectorType::get(ptrDataVecSize, i16_ty); + VectorType veci32Ty = VectorType::get(valueVecSize, i32_ty); + VectorType vec1Ty = VectorType::get(ptrDataVecSize, i1_ty); + VectorType halfVecBf16Ty = VectorType::get(valueVecSize, bf16_ty); + lmBasePtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); // vecBf16Ty + int mask = 0xaaaaaaaa; + Value maskVal = i32_val(mask); + maskVal = bitcast(maskVal, vec1Ty); + Value maskNegVal = i32_val(~mask); + maskNegVal = bitcast(maskNegVal, vec1Ty); + Value poseVal = i32_val(16); + uint32_t one = 0x0001; + uint32_t magic = 0x7fff; + for (int i = 0; i < numElems / 2; ++i) { + Value veven = bitcast(valueElems[2 * i], veci32Ty); + if (!isBf16RoundToMid) { + SmallVector vevenAndOperands({i32_val(one), veven}); + auto vevenAnd = rewriter.create( + loc, veci32Ty, vevenAndOperands, "vand.u.mz $0{mr1}, $1, $2", + "=&v,r,v", + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr()); + SmallVector evenOperands({i32_val(magic), vevenAnd.getRes()}); + auto vevenSvAdd = rewriter.create( + loc, veci32Ty, evenOperands, "vadd.u.mz $0{mr1}, $1, $2", "=&v,r,v", + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr()); + veven = add(veven, vevenSvAdd.getRes()); + } + veven = bitcast(veven, veci16Ty); + Value elemPtr = gep(ptr_ty(ctx, 0), vecBf16Ty, lmBasePtr, i32_val(i)); + rewriter.create(loc, veven, elemPtr, + maskVal); + Value vodd = bitcast(valueElems[2 * i + 1], veci32Ty); + if (!isBf16RoundToMid) { + SmallVector oddAndOperands({i32_val(one), vodd}); + auto voddAnd = rewriter.create( + loc, veci32Ty, oddAndOperands, "vand.u.mz $0{mr1}, $1, $2", + "=&v,r,v", + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr()); + SmallVector oddOperands({i32_val(magic), voddAnd.getRes()}); + auto voddSvAdd = rewriter.create( + loc, veci32Ty, oddOperands, "vadd.u.mz $0{mr1}, $1, $2", "=&v,r,v", + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr()); + vodd = add(vodd, voddSvAdd.getRes()); + } + Value voddSr = rewriter.create(loc, veci32Ty, + poseVal, vodd); + voddSr = bitcast(voddSr, veci16Ty); + rewriter.create(loc, voddSr, elemPtr, + maskNegVal); + } + if (numElems % 2 == 1) { + int remainedIdx = numElems - 1; + lmBasePtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); + Value elemPtr = + gep(ptr_ty(ctx, 0), halfVecBf16Ty, lmBasePtr, i32_val(remainedIdx)); + Value elem = valueElems[remainedIdx]; + Value trunc = rewriter.create(loc, halfVecBf16Ty, elem); + store(trunc, elemPtr); + } + return; + } + + void VecFP32ToBF16(triton::xpu::StoreOp op, mlir::MLIRContext *ctx, + Location &loc, ConversionPatternRewriter &rewriter, + int numElems, int valueVecSize, int ptrDataVecSize, + SmallVector &valueElems, Value &lmBasePtr) const { + VectorType vecBf16Ty = VectorType::get(ptrDataVecSize, bf16_ty); + VectorType vecI16Ty = VectorType::get(ptrDataVecSize, i16_ty); + VectorType vec1Ty = VectorType::get(ptrDataVecSize, i1_ty); + VectorType halfVecBf16Ty = VectorType::get(valueVecSize, bf16_ty); + VectorType veci32Ty = VectorType::get(valueVecSize, i32_ty); + constexpr int mask = 0xaaaaaaaa; // 0b10101010101010101010101010101010 + Value maskVal = i32_val(mask); + maskVal = bitcast(maskVal, vec1Ty); + SmallVector offset_v = {0, 0, 0, 2, 0, 4, 0, 6, 0, 8, 0, + 10, 0, 12, 0, 14, 0, 16, 0, 18, 0, 20, + 0, 22, 0, 24, 0, 26, 0, 28, 0, 30}; + Value offsetVec = rewriter.create(loc, vecI16Ty); + for (size_t elemIdx = 0; elemIdx < ptrDataVecSize; ++elemIdx) { + offsetVec = insert_element(vecI16Ty, offsetVec, + i16_val(offset_v[elemIdx]), i16_val(elemIdx)); + } + lmBasePtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); // halfVecBf16Ty + uint32_t one = 0x0001; + uint32_t magic = 0x7fff; + for (int i = 0; i < numElems / 2; ++i) { + Value dstPtr1 = + gep(ptr_ty(ctx, 0), halfVecBf16Ty, lmBasePtr, i16_val(2 * i)); + Value vl = bitcast(valueElems[2 * i], veci32Ty); + if (!isBf16RoundToMid) { + SmallVector vlAndOperands({i32_val(one), vl}); + auto vlAnd = rewriter.create( + loc, veci32Ty, vlAndOperands, "vand.u.mz $0{mr1}, $1, $2", + "=&v,r,v", + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr()); + SmallVector vlOperands({i32_val(magic), vlAnd.getRes()}); + auto vlSvAdd = rewriter.create( + loc, veci32Ty, vlOperands, "vadd.u.mz $0{mr1}, $1, $2", "=&v,r,v", + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr()); + vl = add(vl, vlSvAdd.getRes()); + } + vl = bitcast(vl, vecI16Ty); + rewriter.create(loc, vl, maskVal, dstPtr1, + offsetVec); + Value vh = bitcast(valueElems[2 * i + 1], veci32Ty); + if (!isBf16RoundToMid) { + SmallVector vhAndOperands({i32_val(one), vh}); + auto vhAnd = rewriter.create( + loc, veci32Ty, vhAndOperands, "vand.u.mz $0{mr1}, $1, $2", + "=&v,r,v", + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr()); + SmallVector vhOperands({i32_val(magic), vhAnd.getRes()}); + auto vhSvAdd = rewriter.create( + loc, veci32Ty, vhOperands, "vadd.u.mz $0{mr1}, $1, $2", "=&v,r,v", + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr()); + vh = add(vh, vhSvAdd.getRes()); + } + vh = bitcast(vh, vecI16Ty); + Value dstPtr2 = + gep(ptr_ty(ctx, 0), halfVecBf16Ty, lmBasePtr, i16_val(2 * i + 1)); + rewriter.create(loc, vh, maskVal, dstPtr2, + offsetVec); + } + if (numElems % 2 == 1) { + int remainedIdx = numElems - 1; + Value elemPtr = + gep(ptr_ty(ctx, 0), halfVecBf16Ty, lmBasePtr, i32_val(remainedIdx)); + Value elem = valueElems[remainedIdx]; + Value trunc = rewriter.create(loc, halfVecBf16Ty, elem); + store(trunc, elemPtr); + } + return; + } + + LogicalResult + matchAndRewrite(triton::xpu::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + // original values + Value ptr = op.getPtr(); + Value value = op.getValue(); + Value index = op.getIndex(); + + int32_t tensor_col_size = op.getTensorColSize(); + bool coreDealMultiRows = tensor_col_size != -1; + bool bf16Tofp32Unordered = op.getBf16Tofp32Unordered(); + + // adaptor values + Value llPtr = adaptor.getPtr(); + Value llMask = adaptor.getMask(); + Value llValue = adaptor.getValue(); + Value llIndex = adaptor.getIndex(); + + // Determine Type + Type ptrTy = ptr.getType(); + auto valueTy = value.getType(); + + Type ptrElemTy = typeConverter->convertType(getElementTypeOrSelf(ptrTy)); + Type valueElemTy = + typeConverter->convertType(getElementTypeOrSelf(valueTy)); + Type valueElemScalarTy = getElementTypeOrSelf(valueElemTy); + Type ptrElemScalarTy; + if (auto ptrTensorTy = mlir::dyn_cast(ptrTy)) { + // Tensor + ptrElemScalarTy = + mlir::cast(ptrTensorTy.getElementType()) + .getPointeeType(); + } else { + // Scalar + ptrElemScalarTy = mlir::cast(ptrTy).getPointeeType(); + } + Type ptrDataVecTy = valueElemTy; + + unsigned valueNumElems = getTotalElemsPerThread(valueTy); + unsigned ptrNumElems = getTotalElemsPerThread(ptrTy); + + // Get the LLVM values + auto llPtrs = unpackLLElements(loc, llPtr, rewriter); + auto llVals = unpackLLElements(loc, llValue, rewriter); + // Determine the vectorization size + bool isVectorized = mlir::isa(valueElemTy); + unsigned valueVecSize = 1u; + unsigned valueScalarNbits = 32u; + if (mlir::isa(valueElemTy)) { + isVectorized = true; + getVectorInfo(valueElemTy, valueVecSize, valueScalarNbits); + } + + // fp32 to bf16 + bool fp32Tobf16 = false; + if (valueElemScalarTy.isF32() && ptrElemScalarTy.isBF16()) { + int ptrVecSize = std::min(ptrNumElems, valueVecSize * 2); + ptrDataVecTy = isVectorized ? VectorType::get(ptrVecSize, ptrElemScalarTy) + : ptrElemScalarTy; + fp32Tobf16 = true; + } + + bool fp32Tofp16 = false; + if (valueElemScalarTy.isF32() && ptrElemScalarTy.isF16()) { + ptrElemScalarTy = valueElemScalarTy; + fp32Tofp16 = true; + } + + unsigned ptrDataVecSize = 1u; + unsigned ptrDataNbits = + isa(ptrElemScalarTy) + ? 64u + : ptrElemScalarTy.getIntOrFloatBitWidth(); + if (mlir::isa(ptrDataVecTy)) { + getVectorInfo(ptrDataVecTy, ptrDataVecSize, ptrDataNbits); + } + + Value lmBasePtr = bitcast(llPtrs[0], ptr_ty(ctx, 0)); + if (index) { + ptrNumElems = valueNumElems; + unsigned _stride = + (fp32Tobf16 && isVectorized) ? ptrNumElems / 2 : ptrNumElems; + Value idx = mul(llIndex, i32_val(_stride)); + lmBasePtr = gep(ptr_ty(ctx, 0), ptrDataVecTy, lmBasePtr, idx); + } + + if (coreDealMultiRows) { + /* Small Col Size Opt LM2GM (14 legal data) + + Before Opt: + 1 1 1 1 1 1 1 0 + 1 1 1 1 1 1 1 0 + + After Opt: + 1 1 1 1 1 1 1 1 + 1 1 1 1 1 1 0 0 + */ + if (isVectorized) { + lmBasePtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); + SmallVector valueElemsScalar; + for (size_t vecStart = 0; vecStart < valueNumElems; ++vecStart) { + for (size_t elemStart = 0; elemStart < valueVecSize; ++elemStart) { + // extract val to ptrElemScalarTy + Value ext_val = extract_element(ptrElemScalarTy, llVals[vecStart], + i32_val(elemStart)); + valueElemsScalar.push_back(ext_val); + } + } + llVals = valueElemsScalar; + } + + auto mem_col_size = mlir::cast(valueTy).getShape()[1] * + valueVecSize; // 16 + auto tensor_row_size = + mlir::cast(valueTy).getShape()[0] / + 64; // 128 / 64 = 2 + + for (size_t row_idx = 0; row_idx < tensor_row_size; ++row_idx) { + for (size_t col_idx = 0; col_idx < tensor_col_size; ++col_idx) { + auto mem_global_idx = row_idx * mem_col_size + col_idx; + auto buf_global_idx = row_idx * tensor_col_size + col_idx; + Value elem = llVals[mem_global_idx]; + if (fp32Tobf16) { + elem = rewriter.create(loc, ptrElemScalarTy, elem); + } + Value elemPtr = gep(ptr_ty(ctx, 0), ptrElemScalarTy, lmBasePtr, + i32_val(buf_global_idx)); + store(elem, elemPtr); + } + } + } else { + if (isVectorized) { + if (valueElemScalarTy.isF32() && ptrElemScalarTy.isBF16()) { + if (bf16Tofp32Unordered) { + VecFP32ToBF16Unordered(op, ctx, loc, rewriter, valueNumElems, + valueVecSize, ptrDataVecSize, llVals, + lmBasePtr); + } else { + VecFP32ToBF16(op, ctx, loc, rewriter, valueNumElems, valueVecSize, + ptrDataVecSize, llVals, lmBasePtr); + } + } else { + lmBasePtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); + for (size_t elemIdx = 0; elemIdx < valueNumElems; elemIdx++) { + Value elem = llVals[elemIdx]; + Value elemPtr = + gep(ptr_ty(ctx, 0), ptrDataVecTy, lmBasePtr, i32_val(elemIdx)); + store(elem, elemPtr); + } + } + } else { + lmBasePtr = bitcast(lmBasePtr, ptr_ty(ctx, 0)); + for (size_t elemIdx = 0; elemIdx < ptrNumElems; elemIdx++) { + Value elem = llVals[elemIdx]; + if (fp32Tobf16) { + elem = rewriter.create(loc, ptrElemScalarTy, elem); + } + Value elemPtr = + gep(ptr_ty(ctx, 0), ptrElemScalarTy, lmBasePtr, i32_val(elemIdx)); + store(elem, elemPtr); + } + } + } + createMfenceOp(rewriter, loc); + + // fp32 to fp16 + if (fp32Tofp16) { + Value fp16LM = bitcast(llPtrs[0], ptr_ty(ctx, 0)); + Value fp32LM = bitcast(llPtrs[0], ptr_ty(ctx, 0)); + ValueRange singleOperandRange({fp32LM, fp16LM, i32_val(ptrNumElems)}); + mlir::LLVM::XPU::createDeviceCall("_ZN3xpu10fp32tofp16EPKfPNS_7float16Ei", + rewriter, op, singleOperandRange, loc); + createMfenceOp(rewriter, loc); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct XPUAllocaOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + XPUAllocaOpConversion(LLVMTypeConverter &converter, + const xpu::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::xpu::AllocaOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + + auto resTy = op.getType(); + Type valueElemTy; + if (auto resTensorTy = mlir::dyn_cast(resTy)) { + // Tensor + valueElemTy = + mlir::cast(resTensorTy.getElementType()) + .getPointeeType(); + } else { + // Scalar + valueElemTy = mlir::cast(resTy).getPointeeType(); + } + + unsigned numElems = getTotalElemsPerThread(resTy); + + auto allocNumElems = numElems; + if (static_cast(targetInfo.getXPUArch()) == XPUArch::XPU2 && + valueElemTy.isF16()) { + // algin to 32, cause fp16tofp32 use vector<32*fp16> instruction + allocNumElems = (allocNumElems + 31) / 32 * 32; + // double space to accommodate 32*fp32 + allocNumElems *= 2; + } + for (auto user : op->getUsers()) { + if (auto gm2lmOp = dyn_cast(user)) { + auto fixedStride = gm2lmOp.getFixedStride(); + if (fixedStride > 0 && + fixedStride * numElems <= targetInfo.getXPUBufferSize()) { + allocNumElems *= fixedStride; + } + } + } + + allocNumElems = + align(allocNumElems, valueElemTy, 64); // 64 bytes aligned for LM + auto lmPtrTy = LLVM::LLVMPointerType::get(ctx, 0); + auto lmBuf = allocate(lmPtrTy, valueElemTy, i32_val(allocNumElems)); + + SmallVector lmPtrs; + for (int i = 0; i < numElems; i++) { + lmPtrs.push_back(gep(lmPtrTy, valueElemTy, lmBuf, i32_val(i))); + } + Type llvmResultStructTy = typeConverter->convertType(resTy); + Value resultStruct = packLLElements(loc, typeConverter, lmPtrs, rewriter, + llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct XPUGM2LMOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + XPUGM2LMOpConversion(LLVMTypeConverter &converter, + const xpu::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::xpu::GM2LMOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + + // original values + Value ptr = op.getPtr(); + Value len = op.getLen(); + int32_t tensor_col_size = op.getTensorColSize(); + bool coreDealMultiRows = tensor_col_size != -1; + bool async = op.getAsync(); + + // adaptor values + Value llLen = adaptor.getLen(); + Value llGMPtr = adaptor.getPtr(); + Value llLMPtr = adaptor.getBufPtr(); + Value resultStruct = llLMPtr; + + Type ptrTy = ptr.getType(); + Type elemTy; + if (auto ptrTensorTy = mlir::dyn_cast(ptrTy)) { + // Tensor + elemTy = mlir::cast(ptrTensorTy.getElementType()) + .getPointeeType(); + } else { + // Scalar + elemTy = mlir::cast(ptrTy).getPointeeType(); + } + unsigned elemNbits = isa(elemTy) + ? 64u + : elemTy.getIntOrFloatBitWidth(); + unsigned numElems = getTotalElemsPerThread(ptrTy); + + assert(llLMPtr && "llBufPtr should not be null."); + auto llGMPtrs = unpackLLElements(loc, llGMPtr, rewriter); + auto llLMPtrs = unpackLLElements(loc, llLMPtr, rewriter); + + Value elemBytes = i32_val(elemNbits / 8u); + Value offsetBytes = i32_val(0); + + bool mask = false; + unsigned lenElemBit = 32; + llvm::SmallVector llLens; + if (op.getLen()) { + auto lenElemTy = getElementTypeOrSelf(op.getLen().getType()); + lenElemBit = lenElemTy.getIntOrFloatBitWidth(); + mask = llGMPtrs.size() > 1 ? mlir::isa(lenElemTy) && + (lenElemBit == 32 || lenElemBit == 64) + : false; + } + + Value bufLen = i32_val(numElems); + Value readLen = bufLen; + if (mask) { + llLens = unpackLLElements(loc, llLen, rewriter); + bufLen = int_val(lenElemBit, numElems); + readLen = smin(smax(llLens[0], int_val(lenElemBit, 0)), bufLen); + if (lenElemBit == 64) { + readLen = trunc(i32_ty, readLen); + } + } + Value readBytes = mul(readLen, elemBytes); + + OffsetState offsetState = static_cast(op.getOffsetState()); + int32_t fixedStride = op.getFixedStride(); + if (offsetState == OffsetState::Unknown) { + /* Small Col Size Opt Mask(14 < 16) + + Before Opt: + T T T T T T T T + T T T T T T F F + + After Opt: + T T T T T T T F + T T T T T T T F + */ + SmallVector maskLists; + if (coreDealMultiRows) { + auto mem_col_size = + mlir::cast(ptrTy).getShape()[1]; // 16 + auto tensor_row_size = + mlir::cast(ptrTy).getShape()[0] / + 64; // 128 / 64 = 2 + unsigned rowRemainElem = mem_col_size - tensor_col_size; // 16 - 15 = 1 + + for (size_t row_idx = 0; row_idx < tensor_row_size; ++row_idx) { + for (size_t col_idx = 0; col_idx < tensor_col_size; ++col_idx) { + maskLists.push_back(true); + } + + for (size_t remainElem = rowRemainElem; remainElem > 0; + --remainElem) { + maskLists.push_back(false); + } + } + } + + if (fixedStride > 0 && + numElems * fixedStride <= targetInfo.getXPUBufferSize()) { + // Unknown FixedStride Vgather + readBytes = mul(i32_val(fixedStride), readBytes); + } else { + // Unknown + readBytes = elemBytes; + for (size_t i = 0; i < llGMPtrs.size(); ++i) { + // Protect Ptr Boundary Condition + Value base; + if (coreDealMultiRows) { + base = mask ? select(int_val(1, maskLists[i]), llGMPtrs[i], + llGMPtrs[0]) + : llGMPtrs[i]; + } else { + if (mask) { // Has Mask + if (llLens[0].getType().isInteger(32)) { + base = select(icmp_slt(i32_val(i), llLens[0]), llGMPtrs[i], + llGMPtrs[0]); + } else if (llLens[0].getType().isInteger(64)) { + base = select(icmp_slt(i64_val(i), llLens[0]), llGMPtrs[i], + llGMPtrs[0]); + } else { + llvm_unreachable("Unsupported Mask Int Type"); + } + } else { + base = llGMPtrs[i]; + } + } + Value dstPtr = bitcast(llLMPtrs[i], ptr_ty(ctx, 0)); + Value srcPtr = bitcast(base, ptr_ty(ctx, 1)); + createGM2LMOp(rewriter, ctx, loc, srcPtr, dstPtr, offsetBytes, + readBytes); + if (!async) + createMfenceOp(rewriter, loc); + } + + rewriter.replaceOp(op, {resultStruct}); + return success(); + } + } else if (offsetState == OffsetState::Discrete) { + // Reorder the local buffer ptrs. + SmallVector newLmBufPtrs(llGMPtrs.size()); + Value basePtrInt = ptrtoint(i64_ty, llGMPtrs[0]); + for (size_t idx = 0; idx < llGMPtrs.size(); ++idx) { + Value elemPtrInt = ptrtoint(i64_ty, llGMPtrs[idx]); // convert to int + Value offsetBytes = + sub(elemPtrInt, basePtrInt); // get the offset(Bytes) + Value elemPtr = gep(ptr_ty(ctx, 0), i8_ty, llLMPtrs[0], offsetBytes); + newLmBufPtrs[idx] = elemPtr; + } + resultStruct = packLLElements(loc, typeConverter, newLmBufPtrs, rewriter, + llLMPtr.getType()); + } else if (offsetState == OffsetState::DiscreteSame) { + readBytes = elemBytes; + SmallVector newLmBufPtrs(llLMPtrs.size(), llLMPtrs[0]); + resultStruct = packLLElements(loc, typeConverter, newLmBufPtrs, rewriter, + llLMPtr.getType()); + } else if (offsetState == OffsetState::LocallyContinuous) { + int64_t _rowLen = op.getRowLen(); + int64_t _rowStride = op.getRowStride(); + if (_rowLen % numElems == 0) { + offsetState = OffsetState::Continuous; + LLVM_DEBUG( + llvm::dbgs() + << "[OffsetState]: GM2LM Update LocallyContinuous to Continuous\n"); + } else { + auto oldBlock = op->getBlock(); + auto newBlock = oldBlock->splitBlock(op->getNextNode()); + int64_t _elemBytes = elemNbits / 8u; + int64_t _bufLen = static_cast(numElems); + LLVM_DEBUG(llvm::dbgs() << "[GM2LM LocallyContinuous]: rowLen is " + << _rowLen << ", rowStride is " << _rowStride + << ", bufLen is " << _bufLen << "\n"); + if (_rowStride == -1) { + lowerLocallyContinuousUnfixedStride( + op, loc, rewriter, _rowLen, _bufLen, _elemBytes, llGMPtr, llLMPtr, + llLen, offsetBytes, MemCpyType::GM2LM, oldBlock, newBlock); + } else { + if (_rowLen > _bufLen) { + lowerLocallyContinuousLargeRow( + op, loc, rewriter, _rowLen, _rowStride, llGMPtr, llLMPtr, llLen, + bufLen, elemBytes, offsetBytes, MemCpyType::GM2LM, oldBlock, + newBlock); + } else { + lowerLocallyContinuousSmallRow( + op, loc, rewriter, _rowLen, _rowStride, llGMPtr, llLMPtr, llLen, + bufLen, elemBytes, offsetBytes, MemCpyType::GM2LM, oldBlock, + newBlock); + } + } + + if (!async) + createMfenceOp(rewriter, loc); + + resultStruct = packLLElements(loc, typeConverter, llLMPtrs, rewriter, + llLMPtr.getType()); + rewriter.replaceOp(op, {resultStruct}); + rewriter.create(loc, ValueRange{}, newBlock); + return success(); + } + } + + Value dstPtr = bitcast(llLMPtrs[0], ptr_ty(ctx, 0)); + Value srcPtr = bitcast(llGMPtrs[0], ptr_ty(ctx, 1)); + createGM2LMOp(rewriter, ctx, loc, srcPtr, dstPtr, offsetBytes, readBytes); + if (!async) + createMfenceOp(rewriter, loc); + + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +struct XPULM2GMOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + + XPULM2GMOpConversion(LLVMTypeConverter &converter, + const xpu::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::xpu::LM2GMOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + + // original values + Value ptr = op.getPtr(); + Value value = op.getValue(); + Value len = op.getLen(); + int32_t offsetStateInt = op.getOffsetState(); + OffsetState offsetState = static_cast(offsetStateInt); + auto tensor_col_size = op.getTensorColSize(); + + // adaptor values + Value llPtr = adaptor.getPtr(); + Value llLen = adaptor.getLen(); + Value llBufPtr = adaptor.getBufPtr(); + assert(llBufPtr && "llBufPtr should not be null."); + + // Get elemTy and numElems + Type ptrTy = ptr.getType(); + Type ptrElemTy = typeConverter->convertType(getElementTypeOrSelf(ptrTy)); + Type elemTy; + if (auto ptrTensorTy = mlir::dyn_cast(ptrTy)) { + // Tensor + elemTy = mlir::cast(ptrTensorTy.getElementType()) + .getPointeeType(); + } else { + // Scalar + elemTy = mlir::cast(ptrTy).getPointeeType(); + } + unsigned elemNbits = isa(elemTy) + ? 64u + : elemTy.getIntOrFloatBitWidth(); + Value elemBytes = i32_val(elemNbits / 8u); + unsigned numElems = getTotalElemsPerThread(ptrTy); + + // Get base, readBytes and offsetBytes + auto llPtrs = unpackLLElements(loc, llPtr, rewriter); + + Value base = llPtrs[0]; + Value offsetBytes = i32_val(0); + + llvm::SmallVector llLens; + bool mask = false; + unsigned lenElemBit = 32; + if (op.getLen()) { + auto lenElemTy = getElementTypeOrSelf(op.getLen().getType()); + lenElemBit = lenElemTy.getIntOrFloatBitWidth(); + mask = llPtrs.size() > 1 ? mlir::isa(lenElemTy) && + (lenElemBit == 32 || lenElemBit == 64) + : false; + } + Value bufLen = i32_val(numElems); + Value readLen = bufLen; + if (mask) { + llLens = unpackLLElements(loc, llLen, rewriter); + bufLen = int_val(lenElemBit, numElems); + readLen = smin(smax(llLens[0], int_val(lenElemBit, 0)), bufLen); + if (lenElemBit == 64) { + readLen = trunc(i32_ty, readLen); + } + } + Value readBytes = mul(readLen, elemBytes); + auto lmBufPtrs = unpackLLElements(loc, llBufPtr, rewriter); + Value lmBuf = lmBufPtrs[0]; + + // Create LM2GM and mfence + switch (offsetState) { + case OffsetState::Continuous: { + Value srcPtr = bitcast(lmBuf, ptr_ty(ctx, 0)); + Value basePtr = bitcast(base, ptr_ty(ctx, 1)); + createLM2GMOp(rewriter, ctx, loc, srcPtr, basePtr, offsetBytes, + readBytes); + break; + } + case OffsetState::LocallyContinuous: { + int64_t _rowLen = op.getRowLen(); + int64_t _rowStride = op.getRowStride(); + if (_rowLen % numElems == 0) { + offsetState = OffsetState::Continuous; + LLVM_DEBUG( + llvm::dbgs() + << "[OffsetState]: LM2GM Update LocallyContinuous to Continuous\n"); + } else { + auto oldBlock = op->getBlock(); + auto newBlock = oldBlock->splitBlock(op->getNextNode()); + int64_t _elemBytes = elemNbits / 8u; + int64_t _bufLen = static_cast(numElems); + LLVM_DEBUG(llvm::dbgs() << "[LM2GM LocallyContinuous]: rowLen is " + << _rowLen << ", rowStride is " << _rowStride + << ", bufLen is " << _bufLen << "\n"); + + if (_rowStride == -1) { + lowerLocallyContinuousUnfixedStride( + op, loc, rewriter, _rowLen, _bufLen, _elemBytes, llPtr, llBufPtr, + llLen, offsetBytes, MemCpyType::LM2GM, oldBlock, newBlock); + } else { + if (_rowLen > _bufLen) { + lowerLocallyContinuousLargeRow( + op, loc, rewriter, _rowLen, _rowStride, llPtr, llBufPtr, llLen, + bufLen, elemBytes, offsetBytes, MemCpyType::LM2GM, oldBlock, + newBlock); + } else { + lowerLocallyContinuousSmallRow( + op, loc, rewriter, _rowLen, _rowStride, llPtr, llBufPtr, llLen, + bufLen, elemBytes, offsetBytes, MemCpyType::LM2GM, oldBlock, + newBlock); + } + } + createMfenceOp(rewriter, loc); + rewriter.eraseOp(op); + rewriter.create(loc, ValueRange{}, newBlock); + return success(); + } + break; + } + case OffsetState::Unknown: { + for (size_t llPtrIdx = 0; llPtrIdx < llPtrs.size(); ++llPtrIdx) { + Value maskedIdx; + if (mask) { + auto llLenTy = llLens[0].getType(); + if (llLenTy.isInteger(32)) { + maskedIdx = select(icmp_slt(i32_val(llPtrIdx), llLens[0]), + i32_val(llPtrIdx), i32_val(0)); + } else if (llLenTy.isInteger(64)) { + maskedIdx = select(icmp_slt(i64_val(llPtrIdx), llLens[0]), + i64_val(llPtrIdx), i64_val(0)); + } else { + llvm_unreachable("Unsupported Mask Int Type"); + } + } else { + maskedIdx = i32_val(llPtrIdx); + } + + lmBuf = bitcast(lmBuf, ptr_ty(ctx, 0)); + Value elemPtr = gep(ptr_ty(ctx, 0), elemTy, lmBuf, maskedIdx); + Value srcPtr = bitcast(elemPtr, ptr_ty(ctx, 0)); + // Protect Ptr Boundary Condition + Value dstPtr; + if (mask) { + auto llLenTy = llLens[0].getType(); + if (llLenTy.isInteger(32)) { + dstPtr = select(icmp_slt(i32_val(llPtrIdx), llLens[0]), + llPtrs[llPtrIdx], llPtrs[0]); + } else if (llLenTy.isInteger(64)) { + dstPtr = select(icmp_slt(i64_val(llPtrIdx), llLens[0]), + llPtrs[llPtrIdx], llPtrs[0]); + } else { + llvm_unreachable("Unsupported Mask Int Type"); + } + } else { + dstPtr = llPtrs[llPtrIdx]; + } + createLM2GMOp(rewriter, ctx, loc, srcPtr, dstPtr, offsetBytes, + elemBytes); + } + break; + } + default: + llvm_unreachable("Unknown offset state"); + break; + } + createMfenceOp(rewriter, loc); + rewriter.eraseOp(op); + + return success(); + } +}; + +struct XPUAtomicRMWOpConversion + : public ConvertOpToLLVMPattern, + public LoadStoreConversionBase { + + XPUAtomicRMWOpConversion(LLVMTypeConverter &converter, + const xpu::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + + LogicalResult + matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + + Value ptr = op.getPtr(); + Value val = op.getVal(); + Value mask = op.getMask(); + auto atomicRmwAttr = op.getAtomicRmwOp(); + + Value llPtr = adaptor.getPtr(); + Value llValue = adaptor.getVal(); + Value llMask = adaptor.getMask(); + + auto llPtrs = unpackLLElements(loc, llPtr, rewriter); + auto llValues = unpackLLElements(loc, llValue, rewriter); + + auto resTy = op.getType(); + Type valueElemTy = getElementTypeOrSelf(getElementTypeOrSelf(resTy)); + unsigned numElems = getTotalElemsPerThread(resTy); + + std::string funcName; + if (valueElemTy.isF16()) { + switch (atomicRmwAttr) { + case RMWOp::ADD: + funcName = "_ZN3xpu9atomicAddEPU3AS1DF16_DF16_"; + break; + case RMWOp::FADD: + funcName = "_ZN3xpu9atomicAddEPU3AS1DF16_DF16_"; + break; + default: + return failure(); + } + } else { + switch (atomicRmwAttr) { + case RMWOp::ADD: + funcName = "_ZN3xpu9atomicAddEPU3AS1ff"; + break; + case RMWOp::FADD: + funcName = "_ZN3xpu9atomicAddEPU3AS1ff"; + break; + default: + return failure(); + } + } + + SmallVector resultVals(numElems); + for (unsigned i = 0; i < numElems; ++i) { + ValueRange operandRange({llPtrs[i], llValues[i]}); + Value devCall = mlir::LLVM::XPU::createDeviceCall( + funcName, rewriter, op, valueElemTy, operandRange, loc); + resultVals[i] = devCall; + } + + Type structTy = this->getTypeConverter()->convertType(resTy); + Value resultStruct = + packLLElements(loc, typeConverter, resultVals, rewriter, structTy); + rewriter.replaceOp(op, resultStruct); + + return success(); + } +}; + +} // namespace + +void mlir::triton::xpu::populateLoadStoreOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, + axisInfoAnalysis, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/MakeRangeOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/MakeRangeOpToLLVM.cpp new file mode 100644 index 000000000..7d5c0bff3 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/MakeRangeOpToLLVM.cpp @@ -0,0 +1,447 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h" + +template +struct RangeOpConversionBase : public ConvertOpToLLVMPattern { + explicit RangeOpConversionBase(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + converter(converter), targetInfo(targetInfo), benefit(benefit) {} + + using ConvertOpToLLVMPattern::getTypeConverter; + +protected: + LLVMTypeConverter &converter; + const TargetInfoBase &targetInfo; + PatternBenefit benefit; +}; + +struct XPUMakeRangeOpConversion + : public RangeOpConversionBase { + XPUMakeRangeOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : RangeOpConversionBase(converter, targetInfo, + benefit) {} + + // Emit indices calculation within each ConversionPattern, and returns a + // [elemsPerThread X rank] index matrix. + inline SmallVector> + emitIndices(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type) const { + + auto clusterLayout = mlir::cast(layout); + auto shape = type.getShape(); + unsigned rank = shape.size(); + unsigned elemsPerCore = clusterLayout.getTotalElemsPerThread(shape, type); + SmallVector> indices(elemsPerCore, + SmallVector(rank)); + + // offset = idInsideGroup * elemsPerCore + n + Value coreId = getThreadId(rewriter, loc); + unsigned groupSize = product(clusterLayout.getCoresPerGroup()); + Value idInsideGroup = srem(coreId, i32_val(groupSize)); + Value base = mul(idInsideGroup, i32_val(elemsPerCore)); + + for (unsigned n = 0; n < elemsPerCore; ++n) { + for (unsigned k = 0; k < rank; ++k) { + indices[n][k] = add(base, idx_val(n)); + } + } + return indices; + } + + inline SmallVector> + emitIndices(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, const Value &loopIdx) const { + + auto clusterLayout = mlir::cast(layout); + auto shape = type.getShape(); + unsigned rank = shape.size(); + unsigned elemsPerCore = clusterLayout.getTotalElemsPerThread(shape, type); + SmallVector> indices(elemsPerCore, + SmallVector(rank)); + + // const int nthreads = core_num() * cluster_num(); + // const int tid = cluster_id() * core_num() + core_id(); + // for (int i = 0; i < iterCount; ++i) { + // const int idx = tid + nthreads * i; + // const int indice = idx * buf_len; + Value coreNum = mlir::LLVM::XPU::getBlockDim(rewriter, loc); + auto coresPerGroup = clusterLayout.getCoresPerGroup(); + auto groupsPerCluster = clusterLayout.getGroupsPerCluster(); + bool atomicSim = (llvm::find_if(coresPerGroup, + [](unsigned int num) { + return num != 1; + }) == coresPerGroup.end()) && + (llvm::find_if(groupsPerCluster, [](unsigned int num) { + return num != 1; + }) == groupsPerCluster.end()); + + Value coreId = getThreadId(rewriter, loc); + Value bufLen = i32_val(elemsPerCore); + Value base; + if (atomicSim) { + base = mul(loopIdx, bufLen); + } else { + base = mul(add(coreId, mul(loopIdx, coreNum)), bufLen); + } + + for (unsigned n = 0; n < elemsPerCore; ++n) { + for (unsigned k = 0; k < rank; ++k) { + indices[n][k] = add(base, idx_val(n)); + } + } + return indices; + } + + inline SmallVector> + emitIndices(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, const Value &loopIdx, + const Value &unrollIdx, uint32_t range) const { + + auto clusterLayout = mlir::cast(layout); + auto shape = type.getShape(); + unsigned rank = shape.size(); + + unsigned _unrollNum = clusterLayout.getTotalElemsPerThread(shape, type); + auto coresPerGroup = clusterLayout.getCoresPerGroup().back(); + unsigned _elemsPerCore = range / coresPerGroup; + SmallVector> indices(_unrollNum, + SmallVector(rank)); + + // (idInsideGroup + loopIdx * groupSize) * elemsPerCore + unrollIdx * + // unrollNum + (0, unrollNum) + Value coreId = getThreadId(rewriter, loc); + Value elemsPerCore = i32_val(_elemsPerCore); + Value unrollNum = i32_val(_unrollNum); + Value _loopIdx = loopIdx ? loopIdx : i32_val(0); + Value groupSize = i32_val(coresPerGroup); + Value idInsideGroup = srem(coreId, groupSize); + Value base = + mul(add(idInsideGroup, mul(_loopIdx, groupSize)), elemsPerCore); + base = add(base, mul(unrollIdx, unrollNum)); + + for (unsigned n = 0; n < _unrollNum; ++n) { + for (unsigned k = 0; k < rank; ++k) { + indices[n][k] = add(base, idx_val(n)); + } + } + return indices; + } + + LogicalResult + matchAndRewrite(triton::xpu::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + RankedTensorType ty = op.getType(); + auto shape = ty.getShape(); + auto layout = ty.getEncoding(); + auto elemTy = ty.getElementType(); + assert(elemTy.isInteger(32)); + uint32_t _start = op.getStart(); + uint32_t _end = op.getEnd(); + Value start = createIndexAttrConstant(rewriter, loc, elemTy, _start); + uint32_t _range = _end - _start; + + auto loopIndex = + adaptor.getLoopIndex(); // TODO[dyq]: check loopIndex Lowering Logic + auto unrollIndex = adaptor.getUnrollIndex(); + SmallVector> idxs; + if (unrollIndex) { + idxs = emitIndices(loc, rewriter, targetInfo, layout, ty, loopIndex, + unrollIndex, _range); + } else if (loopIndex) { + idxs = emitIndices(loc, rewriter, targetInfo, layout, ty, loopIndex); + } else { + idxs = emitIndices(loc, rewriter, targetInfo, layout, ty); + } + + unsigned elems = idxs.size(); + SmallVector retVals(elems); + // TODO: slice layout has more elements than expected. + // Unexpected behavior for make range, but generally OK when followed by + // expand dims + broadcast. very weird behavior otherwise potentially. + for (const auto &multiDim : llvm::enumerate(idxs)) { + assert(multiDim.value().size() == 1); + retVals[multiDim.index()] = add(multiDim.value()[0], start); + } + auto typeConverter = getTypeConverter(); + Value result = packLLElements(loc, typeConverter, retVals, rewriter, ty); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct XPUOutRangeOpConversion + : public RangeOpConversionBase { + + SmallVector> + emitIndices(Location loc, RewriterBase &rewriter, const Attribute &layout, + RankedTensorType type, int groupsize, int rowspercore, + const Value &index) const { + + auto clusterLayout = mlir::cast(layout); + auto shape = type.getShape(); + unsigned rank = shape.size(); + unsigned elemsPerCore = clusterLayout.getTotalElemsPerThread(shape, type); + SmallVector> indices(elemsPerCore, + SmallVector(rank)); + + // offset = (idx * group_num + group_id) * rowspercore + (0 ... + // rowspercore-1) + unsigned ngroup = product(clusterLayout.getGroupsPerCluster()); + + Value coreId = getThreadId(rewriter, loc); + Value groupId = sdiv(coreId, i32_val(groupsize)); + Value base = + mul(add(mul(index, i32_val(ngroup)), groupId), i32_val(rowspercore)); + + for (unsigned n = 0; n < elemsPerCore; ++n) { + for (unsigned k = 0; k < rank; ++k) { + indices[n][k] = add(base, idx_val(n)); + } + } + return indices; + } + + XPUOutRangeOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : RangeOpConversionBase(converter, targetInfo, + benefit) {} + + LogicalResult + matchAndRewrite(triton::xpu::OutRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + RankedTensorType ty = op.getType(); + auto shape = ty.getShape(); + auto layout = ty.getEncoding(); + auto elemTy = ty.getElementType(); + assert(elemTy.isInteger(32)); + + auto groupsize = adaptor.getGroupsize(); + auto rowspercore = adaptor.getRowspercore(); + auto index = adaptor.getIndex(); + + auto idxs = + emitIndices(loc, rewriter, layout, ty, groupsize, rowspercore, index); + + unsigned elems = idxs.size(); + SmallVector retVals(elems); + // TODO: slice layout has more elements than expected. + // Unexpected behavior for make range, but generally OK when followed by + // expand dims + broadcast. very weird behavior otherwise potentially. + for (const auto &multiDim : llvm::enumerate(idxs)) { + assert(multiDim.value().size() == 1); + retVals[multiDim.index()] = multiDim.value()[0]; + } + + auto typeConverter = getTypeConverter(); + Value result = packLLElements(loc, typeConverter, retVals, rewriter, ty); + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct XPUInterleaveOpConversion + : public RangeOpConversionBase { + + XPUInterleaveOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : RangeOpConversionBase(converter, targetInfo, + benefit) {} + + inline SmallVector> + emitIndices(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, const Value &loopIdx) const { + + auto clusterLayout = mlir::cast(layout); + auto shape = type.getShape(); + auto elemBit = type.getElementType().getIntOrFloatBitWidth(); + unsigned rank = shape.size(); + unsigned elemsPerCore = clusterLayout.getTotalElemsPerThread(shape, type); + + // const int nthreads = core_num() * cluster_num(); + // const int tid = core_id() * cluster_num() + cluster_id(); + // for (int i = 0; i < iterCount; ++i) { + // const int idx = tid + nthreads * i; + // const int offset = idx * buf_len; + Value coreNum = mlir::LLVM::XPU::getBlockDim(rewriter, loc); + auto coresPerGroup = clusterLayout.getCoresPerGroup(); + auto groupsPerCluster = clusterLayout.getGroupsPerCluster(); + + bool atomicSim = (llvm::find_if(coresPerGroup, + [](unsigned int num) { + return num != 1; + }) == coresPerGroup.end()) && + (llvm::find_if(groupsPerCluster, [](unsigned int num) { + return num != 1; + }) == groupsPerCluster.end()); + + Value base; + if (atomicSim) { + base = mul(loopIdx, i32_val(elemsPerCore)); + } else { + Value clusterNum = mlir::LLVM::XPU::getGridDim(rewriter, loc); + Value coreNum = mlir::LLVM::XPU::getBlockDim(rewriter, loc); + Value clusterId = mlir::LLVM::XPU::getBlockId(rewriter, loc); + Value coreId = getThreadId(rewriter, loc); + Value bufLen = i32_val(elemsPerCore); + Value _loopIdx = loopIdx; + if (elemBit == 64) { + bufLen = i64_val(elemsPerCore); + clusterNum = rewriter.create(loc, i64_ty, clusterNum); + coreNum = rewriter.create(loc, i64_ty, coreNum); + clusterId = rewriter.create(loc, i64_ty, clusterId); + coreId = rewriter.create(loc, i64_ty, coreId); + _loopIdx = rewriter.create(loc, i64_ty, loopIdx); + } + Value nThread = mul(clusterNum, coreNum); + Value tid = add(mul(coreId, clusterNum), clusterId); + Value idx = add(tid, mul(nThread, _loopIdx)); + base = mul(idx, bufLen); + } + + SmallVector> indices(elemsPerCore, + SmallVector(rank)); + for (unsigned n = 0; n < elemsPerCore; ++n) { + for (unsigned k = 0; k < rank; ++k) { + indices[n][k] = add(base, int_val(elemBit, n)); + } + } + return indices; + } + + inline SmallVector> + emitIndices(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, const Value &loopIdx, + const Value &unrollIdx, uint32_t range) const { + + auto clusterLayout = mlir::cast(layout); + auto shape = type.getShape(); + auto elemBit = type.getElementType().getIntOrFloatBitWidth(); + unsigned rank = shape.size(); + + unsigned _unrollNum = clusterLayout.getTotalElemsPerThread(shape, type); + auto unrollNum = i32_val(_unrollNum); + auto _coresPerGroup = product(clusterLayout.getCoresPerGroup()); + auto _groupsPerCluster = product(clusterLayout.getGroupsPerCluster()); + auto _coreNum = _coresPerGroup * _groupsPerCluster; + unsigned _elemsPerCore = ceil(range, _coreNum); + + // const int nthreads = core_num() * cluster_num(); + // const int tid = core_id() * cluster_num() + cluster_id(); + // for (int loopIdx = 0; loopIdx < iterCount; ++loopIdx) { + // for (int unrollIdx = 0; unrollIdx < unrollNum; ++unrollIdx) { + // const int idx = tid + nthreads * loopIdx; + // const int offset = idx * bufLen + unrollIdx * unrollNum; + Value coreNum = mlir::LLVM::XPU::getBlockDim(rewriter, loc); + auto coresPerGroup = clusterLayout.getCoresPerGroup(); + auto groupsPerCluster = clusterLayout.getGroupsPerCluster(); + + bool atomicSim = (llvm::find_if(coresPerGroup, + [](unsigned int num) { + return num != 1; + }) == coresPerGroup.end()) && + (llvm::find_if(groupsPerCluster, [](unsigned int num) { + return num != 1; + }) == groupsPerCluster.end()); + + Value base; + Value newUnrollIdx = unrollIdx; + Value bufLen = int_val(elemBit, _elemsPerCore); + if (atomicSim) { + base = mul(loopIdx, bufLen); + } else { + Value clusterNum = mlir::LLVM::XPU::getGridDim(rewriter, loc); + Value coreNum = mlir::LLVM::XPU::getBlockDim(rewriter, loc); + Value clusterId = mlir::LLVM::XPU::getBlockId(rewriter, loc); + Value coreId = getThreadId(rewriter, loc); + Value _loopIdx = loopIdx; + if (elemBit == 64) { + clusterNum = rewriter.create(loc, i64_ty, clusterNum); + coreNum = rewriter.create(loc, i64_ty, coreNum); + clusterId = rewriter.create(loc, i64_ty, clusterId); + coreId = rewriter.create(loc, i64_ty, coreId); + _loopIdx = rewriter.create(loc, i64_ty, loopIdx); + newUnrollIdx = rewriter.create(loc, i64_ty, unrollIdx); + unrollNum = rewriter.create(loc, i64_ty, unrollNum); + } + Value nThread = mul(clusterNum, coreNum); + Value tid = add(mul(coreId, clusterNum), clusterId); + Value idx = add(tid, mul(nThread, _loopIdx)); + base = add(mul(idx, bufLen), mul(newUnrollIdx, unrollNum)); + } + + SmallVector> indices(_unrollNum, + SmallVector(rank)); + for (unsigned n = 0; n < _unrollNum; ++n) { + for (unsigned k = 0; k < rank; ++k) { + indices[n][k] = add(base, int_val(elemBit, n)); + } + } + return indices; + } + + LogicalResult + matchAndRewrite(triton::xpu::InterleaveOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + RankedTensorType ty = op.getType(); + auto shape = ty.getShape(); + auto layout = ty.getEncoding(); + auto elemTy = ty.getElementType(); + assert(elemTy.isInteger(32) || elemTy.isInteger(64)); + uint32_t _start = op.getStart(); + uint32_t _end = op.getEnd(); + Value start = createIndexAttrConstant(rewriter, loc, elemTy, _start); + uint32_t _range = _end - _start; + + auto loopIndex = + adaptor.getLoopIndex(); // TODO[dyq]: check loopIndex Lowering Logic + auto unrollIndex = adaptor.getUnrollIndex(); + SmallVector> idxs; + if (unrollIndex) { + idxs = emitIndices(loc, rewriter, targetInfo, layout, ty, loopIndex, + unrollIndex, _range); + } else { + idxs = emitIndices(loc, rewriter, targetInfo, layout, ty, loopIndex); + } + + unsigned elems = idxs.size(); + SmallVector retVals(elems); + // TODO: slice layout has more elements than expected. + // Unexpected behavior for make range, but generally OK when followed by + // expand dims + broadcast. very weird behavior otherwise potentially. + for (const auto &multiDim : llvm::enumerate(idxs)) { + assert(multiDim.value().size() == 1); + retVals[multiDim.index()] = add(multiDim.value()[0], start); + } + auto typeConverter = getTypeConverter(); + Value result = packLLElements(loc, typeConverter, retVals, rewriter, ty); + rewriter.replaceOp(op, result); + return success(); + } +}; + +void mlir::triton::xpu::populateMakeRangeOpToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h new file mode 100644 index 000000000..4c37421d2 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h @@ -0,0 +1,135 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TRITON_CONVERSION_TRITONXPU_TO_LLVM_PATTERNS_TRITON_XPU_OP_TO_LLVM_H +#define TRITON_CONVERSION_TRITONXPU_TO_LLVM_PATTERNS_TRITON_XPU_OP_TO_LLVM_H + +// clang-format off +// Dialect +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "triton/Dialect/LLVMXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/IR/Dialect.h" + +#include "triton/Analysis/Membar.h" // ModuleMembarAnalysis +#include "triton/Analysis/AxisInfo.h" // ModuleAxisInfoAnalysis +#include "triton/Analysis/Allocation.h" // ModuleAllocation + + +#include "triton/Analysis/Utility.h" +#include "xpu/lib/Conversion/TritonXPUToLLVM/Utility.h" + +#include "xpu/lib/Conversion/TritonXPUToLLVM/TargetInfo.h" // TargetInfo +#include "triton/Conversion/TritonXPUToLLVM/TypeConverter.h" // TritonXPUToLLVMTypeConverter + +#include "llvm/Support/ErrorHandling.h" +// clang-format on + +namespace mlir { +namespace triton { +namespace xpu { + +//===----------------------------------------------------------------------===// +// triton::xpu::LoadOp, triton::xpu::StoreOp, triton::xpu::AllocaOp, +// triton::xpu::GM2LMOp, triton::xpu::LM2GMOp +//===----------------------------------------------------------------------===// +void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit); + +//===----------------------------------------------------------------------===// +// mlir::gpu::ThreadIdOp -> mlir::LLVM::XPU::CoreIdOp +// mlir::gpu::BlockIdOp -> mlir::LLVM::XPU::LoadParamOp[0] +// mlir::gpu::GridDimOp -> mlir::LLVM::XPU::LoadParamOp[1] +// mlir::gpu::BlockDimOp -> mlir::LLVM::XPU::LoadParamOp[2] +// +// Collect a set of patterns to convert from the mlir::gpu dialect to +// mlir::LLVM::XPU. +//===----------------------------------------------------------------------===// +void populateGPUToXPUConversionPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfo &targetInfo, + PatternBenefit benefit); + +//===----------------------------------------------------------------------===// +// triton::xpu::MakeRangeOp +//===----------------------------------------------------------------------===// +void populateMakeRangeOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +//===----------------------------------------------------------------------===// +// triton::xpu::ExtractOp +//===----------------------------------------------------------------------===// +void populateTTXPUUtilityOpToLLVMConversionPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit); + +//===----------------------------------------------------------------------===// +// triton::xpu::ExtractOp +//===----------------------------------------------------------------------===// +void populateTTXPUVectorizedOpToLLVMConversionPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit); + +//===----------------------------------------------------------------------===// +// arith::ExtFOp -> LLVM::FPExtOp arith::TruncFOp -> LLVM::FPTruncOp +// arith::SIToFPOp -> LLVM::SIToFPOp arith::FPToSIOp -> LLVM::FPToSIOp +// triton::PreciseSqrtOp -> LLVM::SqrtOp +// arith::AddFOp -> LLVM::FAddOp arith::SubFOp, LLVM::FSubOp +// arith::MulFOp -> LLVM::FMulOp arith::DivFOp, LLVM::FDivOp +// triton::PreciseDivFOp -> LLVM::FDivOp +//===----------------------------------------------------------------------===// +void populateElementwiseOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfo &targetInfo, + PatternBenefit benefit); + +//===----------------------------------------------------------------------===// +// triton::xpu::ConvertLayoutOp -> <> +//===----------------------------------------------------------------------===// +void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +//===----------------------------------------------------------------------===// +// triton::ExpandDimsOp -> <> +//===----------------------------------------------------------------------===// +void populateViewOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +//===----------------------------------------------------------------------===// +// triton::xpu::ReduceOp -> calculation logic +//===----------------------------------------------------------------------===// +void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfo &targetInfo, + PatternBenefit benefit); + +//===----------------------------------------------------------------------===// +// triton::FuncOp -> LLVM::FuncOp +//===----------------------------------------------------------------------===// +void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +//===----------------------------------------------------------------------===// +// triton::GetNumProgramsOp -> LLVM::LoadParamOp +//===----------------------------------------------------------------------===// +void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +} // namespace xpu +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITONXPU_TO_LLVM_PATTERNS_TRITON_XPU_OP_TO_LLVM_H diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/ReduceOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/ReduceOpToLLVM.cpp new file mode 100644 index 000000000..ae0452bbd --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/ReduceOpToLLVM.cpp @@ -0,0 +1,1002 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h" +#include "llvm/ADT/TypeSwitch.h" +#include +namespace { + +using ::mlir::triton::gpu::getTotalElemsPerThread; + +struct XPUReduceOpConversion + : public ConvertOpToLLVMPattern { + + XPUReduceOpConversion(LLVMTypeConverter &converter, + const xpu::TargetInfo &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} + + inline SmallVector> + emitIndices(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, const Value &loopIdx) const { + SmallVector shape(type.getShape()); + + // for sliceEncoding, we need to set layout as its parent + if (auto slice = dyn_cast(layout)) { + layout = slice.getParent(); + shape.insert(shape.begin() + slice.getDim(), 1); + } + + auto clusterLayout = mlir::cast(layout); + unsigned rank = shape.size(); + unsigned elemsPerCore = clusterLayout.getTotalElemsPerThread(shape, type); + SmallVector> indices(elemsPerCore, + SmallVector(rank)); + + // const int nthreads = core_num() * cluster_num(); + // const int tid = cluster_id() * core_num() + core_id(); + // for (int i = 0; i < iterCount; ++i) { + // const int idx = tid + nthreads * i; + // const int indice = idx * buf_len; + Value coreNum = mlir::LLVM::XPU::getBlockDim(rewriter, loc); + Value coreId = getThreadId(rewriter, loc); + Value bufLen = i32_val(elemsPerCore); + Value base = mul(add(coreId, mul(loopIdx, coreNum)), bufLen); + + for (unsigned n = 0; n < elemsPerCore; ++n) { + for (unsigned k = 0; k < rank; ++k) { + indices[n][k] = add(base, idx_val(n)); + } + } + return indices; + } + + // Return the pointee type of the shared memory pointer for operand i. + Type getElementType(triton::xpu::ReduceOp op, int i) const { + auto ty = getElementTypeOrSelf(op.getInputTypes()[i].getElementType()); + return getTypeConverter()->convertType(ty); + } + + // Helper to compute the smem bases in both reductions and scans + std::pair, SmallVector> + getSmemBases(triton::xpu::ReduceOp op, unsigned elems, + ConversionPatternRewriter &rewriter) const { + ReduceOpHelper helper(op); + SmallVector offsets; + // auto curIdx = helper.getReduceId(); + auto prevSMOffset = + helper.getReduceId() == 0 + ? 0 + : helper.getSMOffsets(helper.getReduceId() - 1)->endOffset; + // op->dump(); + // LLVM_DEBUG(llvm::dbgs() << "\nprevSMOffset = " << prevSMOffset << "\n"); + + auto loc = op.getLoc(); + // indices will store the index of the op operands in descending order + // of their bitwidths + std::vector indices(op.getNumOperands() - 1); // skip loopIndex + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) { + if (i == op.getNumOperands() - 1 || + j == op.getNumOperands() - 1) { // skip loopIndex + return false; + } + return op.getElementTypes()[i].getIntOrFloatBitWidth() > + op.getElementTypes()[j].getIntOrFloatBitWidth(); + }); + + // Assign base index to each operand in their order in indices + std::map indexToBase; + indexToBase[indices[0]] = + LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); + // add prev reduceOp used sm bytes offset + indexToBase[indices[0]] = + gep(ptr_ty(rewriter.getContext(), 2), getElementType(op, indices[0]), + indexToBase[indices[0]], i32_val(prevSMOffset)); + + offsets.push_back( + (getElementType(op, indices[0]).getIntOrFloatBitWidth() * elems) / 8); + for (unsigned i = 1; i < (op.getNumOperands() - 1); ++i) { // skip loopIndex + indexToBase[indices[i]] = gep( + ptr_ty(rewriter.getContext(), 2), getElementType(op, indices[i - 1]), + indexToBase[indices[i - 1]], i32_val(elems)); + offsets.push_back( + (getElementType(op, indices[i - 1]).getIntOrFloatBitWidth() * elems) / + 8); + } + + // smemBases[k] is the base pointer for the k-th operand + SmallVector smemBases(op.getNumOperands() - 1); // skip loopIndex + for (unsigned i = 0; i < (op.getNumOperands() - 1); ++i) { // skip loopIndex + smemBases[i] = indexToBase[i]; + } + + // loopResCacheSmemBases[k] is the base pointer for the k-th operand which + // is the prev loop reduce result + SmallVector loopResCacheSmemBases(op.getNumOperands() - + 1); // skip loopIndex + std::map indexToBaseForLoopCache; + indexToBaseForLoopCache[indices[0]] = gep( + ptr_ty(rewriter.getContext(), 2), getElementType(op, indices.back()), + smemBases.back(), i32_val(elems)); + offsets.push_back( + (getElementType(op, indices[0]).getIntOrFloatBitWidth() * 1) / 8); + for (unsigned i = 1; i < (op.getNumOperands() - 1); ++i) { // skip loopIndex + indexToBaseForLoopCache[indices[i]] = gep( + ptr_ty(rewriter.getContext(), 2), getElementType(op, indices[i - 1]), + indexToBaseForLoopCache[indices[i - 1]], i32_val(1)); + offsets.push_back( + (getElementType(op, indices[0]).getIntOrFloatBitWidth() * 1) / 8); + } + + for (unsigned i = 0; i < (op.getNumOperands() - 1); ++i) { // skip loopIndex + loopResCacheSmemBases[i] = indexToBaseForLoopCache[i]; + } + + helper.setSMOffsets(helper.getReduceId(), offsets); + return {smemBases, loopResCacheSmemBases}; + } + + LogicalResult + matchAndRewrite(triton::xpu::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ReduceOpHelper helper(op); + assert(cast(helper.getSrcLayout()) && + "Unexpected srcLayout in ReduceOpConversion"); + Location loc = op->getLoc(); + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + + // Init shared memory for mask. + initSharedMemory(helper, rewriter, loc); + + std::map, SmallVector> accs; + std::map, SmallVector> indices; + // First reduce all the values along axis within each thread. + reduceWithinThreads(helper, srcValues, accs, indices, rewriter); + + if (helper.isCoreSynchronous()) { + LLVM_DEBUG(llvm::dbgs() << "\nisCoreSynchronous=True\n"); + + // If all the values to be reduced are within the same warp there is + // nothing left to do. + packResults(helper, accs, rewriter); + return success(); + } + LLVM_DEBUG(llvm::dbgs() << "\nisCoreSynchronous=False\n"); + + // Then reduce across threads within a group. + reduceWithinGroups(helper, accs, rewriter); + // LLVM_DEBUG(llvm::dbgs() << "\n After matchAndRewrite xpu::ReduceOp:" + // << op->getParentOfType() << "\n"); + // helper.dumpSMOffsets(); + return success(); + } + +private: + const TargetInfoBase &targetInfo; + + inline bool isNeedLoopCacheResult(triton::xpu::ReduceOp op) const { + if (op.getAxis() == 1) { + LLVM_DEBUG(llvm::dbgs() << "\nDont Need SM Loop Result Cache"); + return false; + } + + if (auto resultTy = + dyn_cast(op.getResult()[0].getType())) { + auto resultShape = resultTy.getShape(); + auto rank = resultShape.size(); + + if (rank == 1) { // TODO[dyq]: add [op.getAxis() == None] logic + LLVM_DEBUG(llvm::dbgs() << "\nNeed SM Loop Result Cache"); + return true; + } + + LLVM_DEBUG(llvm::dbgs() << "\nDont Need SM Loop Result Cache"); + return false; + } + + // scalar + LLVM_DEBUG(llvm::dbgs() << "\nNeed SM Loop Result Cache"); + return true; + } + + SmallVector initVals(triton::xpu::ReduceOp &op, + ConversionPatternRewriter &rewriter, + Location loc) const { + + auto naiveInit = [&](Type elemTy, int value) { + Value val; + if (elemTy.isInteger(1)) { + val = int_val(1, value); + } else if (elemTy.isInteger(8)) { + val = int_val(8, value); + } else if (elemTy.isInteger(16)) { + val = int_val(16, value); + } else if (elemTy.isInteger(32)) { + val = i32_val(value); + } else if (elemTy.isInteger(64)) { + val = int_val(64, value); + } else if (elemTy.isF16()) { + val = f16_val(value); + } else if (elemTy.isF32()) { + val = f32_val(value); + } else if (elemTy.isF64()) { + val = f64_val(value); + } else { + LLVM_DEBUG(elemTy.dump()); + llvm_unreachable("[Reduce Init]: Unsupported ElemTy in Naive Init"); + } + return val; + }; + + auto maxInit = [&](Type elemTy) { + Value val; + if (elemTy.isInteger(1)) { + val = int_val(1, 1); + } else if (elemTy.isInteger(8)) { + val = int_val(8, INT8_MAX); + } else if (elemTy.isInteger(16)) { + val = int_val(16, INT16_MAX); + } else if (elemTy.isInteger(32)) { + val = i32_val(INT32_MAX); + } else if (elemTy.isInteger(64)) { + val = int_val(64, INT64_MAX); + } else if (elemTy.isF16()) { + val = f16_val(65504); + } else if (elemTy.isF32()) { + val = f32_val(FLT_MAX); + } else if (elemTy.isF64()) { + val = f64_val(DBL_MAX); + } else { + LLVM_DEBUG(elemTy.dump()); + llvm_unreachable("[Reduce Init]: Unsupported ElemTy in Max Init"); + } + return val; + }; + + auto minInit = [&](Type elemTy) { + Value val; + if (elemTy.isInteger(1)) { + val = int_val(1, 0); + } else if (elemTy.isInteger(8)) { + val = int_val(8, -INT8_MAX); + } else if (elemTy.isInteger(16)) { + val = int_val(16, -INT16_MAX); + } else if (elemTy.isInteger(32)) { + val = i32_val(-INT32_MAX); + } else if (elemTy.isInteger(64)) { + val = int_val(64, -INT64_MAX); + } else if (elemTy.isF16()) { + val = f16_val(-65504); + } else if (elemTy.isF32()) { + val = f32_val(-FLT_MAX); + } else if (elemTy.isF64()) { + val = f64_val(-DBL_MAX); + } else { + LLVM_DEBUG(elemTy.dump()); + llvm_unreachable("[Reduce Init]: Unsupported ElemTy in Min Init"); + } + return val; + }; + + auto &combineBlock = op.getCombineOp().getBlocks().front(); + SmallVector blockArgDefOps; + for (int i = 0; i < combineBlock.getArguments().size() / 2; ++i) { + auto arg = combineBlock.getArgument(i); + bool isBreak = false; + for (auto user : arg.getUsers()) { + TypeSwitch(user) + .Case([&](auto andIOp) { + blockArgDefOps.emplace_back(andIOp); + isBreak = true; + }) + .Case([&](auto orIOp) { + blockArgDefOps.emplace_back(orIOp); + isBreak = true; + }) + .Case([&](auto addFOp) { + blockArgDefOps.emplace_back(addFOp); + isBreak = true; + }) + .Case([&](auto addIOp) { + blockArgDefOps.emplace_back(addIOp); + isBreak = true; + }) + .Case([&](auto subFOp) { + blockArgDefOps.emplace_back(subFOp); + isBreak = true; + }) + .Case([&](auto subIOp) { + blockArgDefOps.emplace_back(subIOp); + isBreak = true; + }) + .Case([&](auto mulFOp) { + blockArgDefOps.emplace_back(mulFOp); + isBreak = true; + }) + .Case([&](auto mulIOp) { + blockArgDefOps.emplace_back(mulIOp); + isBreak = true; + }) + .Case([&](auto divFOp) { + blockArgDefOps.emplace_back(divFOp); + isBreak = true; + }) + .Case([&](auto divSIOp) { + blockArgDefOps.emplace_back(divSIOp); + isBreak = true; + }) + .Case([&](auto divUIOp) { + blockArgDefOps.emplace_back(divUIOp); + isBreak = true; + }) + .Case([&](auto maxNumFOp) { + blockArgDefOps.emplace_back(maxNumFOp); + isBreak = true; + }) + .Case([&](auto maxSIOp) { + blockArgDefOps.emplace_back(maxSIOp); + isBreak = true; + }) + .Case([&](auto maxUIOp) { + blockArgDefOps.emplace_back(maxUIOp); + isBreak = true; + }) + .Case([&](auto minNumFOp) { + blockArgDefOps.emplace_back(minNumFOp); + isBreak = true; + }) + .Case([&](auto minSIOp) { + blockArgDefOps.emplace_back(minSIOp); + isBreak = true; + }) + .Case([&](auto minUIOp) { + blockArgDefOps.emplace_back(minUIOp); + isBreak = true; + }) + .Case([&](auto maximumFOp) { + blockArgDefOps.emplace_back(maximumFOp); + isBreak = true; + }) + .Case([&](auto minimumFOp) { + blockArgDefOps.emplace_back(minimumFOp); + isBreak = true; + }) + .Case([&](auto cmpFOp) { + if (cmpFOp.getPredicate() == arith::CmpFPredicate::OGT || + cmpFOp.getPredicate() == arith::CmpFPredicate::OGE || + cmpFOp.getPredicate() == arith::CmpFPredicate::OLT || + cmpFOp.getPredicate() == arith::CmpFPredicate::OLE || + cmpFOp.getPredicate() == arith::CmpFPredicate::UGT || + cmpFOp.getPredicate() == arith::CmpFPredicate::UGE || + cmpFOp.getPredicate() == arith::CmpFPredicate::ULT || + cmpFOp.getPredicate() == arith::CmpFPredicate::ULE) { + blockArgDefOps.emplace_back(cmpFOp); + isBreak = true; + } + }) + .Case([&](auto cmpIOp) { + if (cmpIOp.getPredicate() == arith::CmpIPredicate::slt || + cmpIOp.getPredicate() == arith::CmpIPredicate::sle || + cmpIOp.getPredicate() == arith::CmpIPredicate::sgt || + cmpIOp.getPredicate() == arith::CmpIPredicate::sge || + cmpIOp.getPredicate() == arith::CmpIPredicate::ult || + cmpIOp.getPredicate() == arith::CmpIPredicate::ule || + cmpIOp.getPredicate() == arith::CmpIPredicate::ugt || + cmpIOp.getPredicate() == arith::CmpIPredicate::uge) { + blockArgDefOps.emplace_back(cmpIOp); + isBreak = true; + } + }); + if (isBreak) { + break; + } + } + } + + auto types = op.getInputTypes(); + assert(blockArgDefOps.size() == types.size() && + "[Reduce Init]: BlockArgDefOps Size() != Types Size"); + SmallVector vals; + for (int i = 0; i < blockArgDefOps.size(); ++i) { + auto elemTy = getElementTypeOrSelf(types[i]); + if (auto vecTy = dyn_cast(elemTy)) { + elemTy = vecTy.getElementType(); + } + Value val = naiveInit(elemTy, 0); + auto blockArgDefOp = blockArgDefOps[i]; + TypeSwitch(blockArgDefOp) + .Case([&](auto andIOp) { val = naiveInit(elemTy, 1); }) + .Case([&](auto orIOp) { val = naiveInit(elemTy, 0); }) + .Case([&](auto addFOp) { val = naiveInit(elemTy, 0); }) + .Case([&](auto addIOp) { val = naiveInit(elemTy, 0); }) + .Case([&](auto subFOp) { val = naiveInit(elemTy, 0); }) + .Case([&](auto subIOp) { val = naiveInit(elemTy, 0); }) + .Case([&](auto mulFOp) { val = naiveInit(elemTy, 1); }) + .Case([&](auto mulIOp) { val = naiveInit(elemTy, 1); }) + .Case([&](auto divFOp) { val = naiveInit(elemTy, 1); }) + .Case( + [&](auto divSIOp) { val = naiveInit(elemTy, 1); }) + .Case( + [&](auto divUIOp) { val = naiveInit(elemTy, 1); }) + .Case( + [&](auto maxNumFOp) { val = minInit(elemTy); }) + .Case([&](auto maxSIOp) { val = minInit(elemTy); }) + .Case( + [&](auto maxUIOp) { val = naiveInit(elemTy, 0); }) + .Case( + [&](auto minNumFOp) { val = maxInit(elemTy); }) + .Case([&](auto minSIOp) { val = maxInit(elemTy); }) + .Case([&](auto minUIOp) { val = maxInit(elemTy); }) + .Case( + [&](auto maximumFOp) { val = minInit(elemTy); }) + .Case( + [&](auto minimumFOp) { val = maxInit(elemTy); }) + .Case([&](auto cmpFOp) { + if (cmpFOp.getPredicate() == arith::CmpFPredicate::OGT || + cmpFOp.getPredicate() == arith::CmpFPredicate::OGE || + cmpFOp.getPredicate() == arith::CmpFPredicate::UGT || + cmpFOp.getPredicate() == arith::CmpFPredicate::UGE) { + val = minInit(elemTy); + } else if (cmpFOp.getPredicate() == arith::CmpFPredicate::OLT || + cmpFOp.getPredicate() == arith::CmpFPredicate::OLE || + cmpFOp.getPredicate() == arith::CmpFPredicate::ULT || + cmpFOp.getPredicate() == arith::CmpFPredicate::ULE) { + val = maxInit(elemTy); + } else { + llvm_unreachable( + "[Reduce Init]: Unsupported CmpFPredicate in CmpFOp"); + } + }) + .Case([&](auto cmpIOp) { + if (cmpIOp.getPredicate() == arith::CmpIPredicate::sgt || + cmpIOp.getPredicate() == arith::CmpIPredicate::sge || + cmpIOp.getPredicate() == arith::CmpIPredicate::ugt || + cmpIOp.getPredicate() == arith::CmpIPredicate::uge) { + val = minInit(elemTy); + } else if (cmpIOp.getPredicate() == arith::CmpIPredicate::slt || + cmpIOp.getPredicate() == arith::CmpIPredicate::sle || + cmpIOp.getPredicate() == arith::CmpIPredicate::ult || + cmpIOp.getPredicate() == arith::CmpIPredicate::ule) { + val = maxInit(elemTy); + } else { + llvm_unreachable( + "[Reduce Init]: Unsupported CmpFPredicate in CmpIOp"); + } + }) + .Default([&](auto defaultOp) { + LLVM_DEBUG(defaultOp->dump()); + llvm_unreachable( + "[Reduce Init]: Unsupported Operation in Reduce Output"); + }); + vals.emplace_back(val); + } + + return vals; + } + + void initSharedMemory(ReduceOpHelper &helper, + ConversionPatternRewriter &rewriter, + Location loc) const { + // Init shared memory + ConversionPatternRewriter::InsertionGuard guard( + rewriter); // save reduceOpPtr to restore + auto reduceOpPtr = rewriter.saveInsertionPoint(); + triton::xpu::ReduceOp op = helper.getXPUOperation(); + auto func = op->template getParentOfType(); + rewriter.setInsertionPointToStart(&(func.front())); + // Compute a shared memory base per operand. + Value coreId = getThreadId(rewriter, loc); + SmallVector vals = initVals(op, rewriter, loc); + + auto smemShape = helper.getXPUScratchConfig(); + auto [smemBases, loopResCacheSmemBases] = + getSmemBases(op, product(smemShape), rewriter); + for (unsigned i = 0; i < (op.getNumOperands() - 1); ++i) { // skip loopIndex + auto elemTy = getElementType(op, i); + Value initPtr = + gep(ptr_ty(rewriter.getContext(), 2), elemTy, smemBases[i], coreId); + store_sm(vals[i], initPtr); + } + xpu_barrier(); + rewriter.restoreInsertionPoint(reduceOpPtr); // restore reduceOpPtr + } + + void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, + SmallVector &acc, ValueRange cur, bool isFirst) const { + if (isFirst) { + acc = SmallVector(cur.begin(), cur.end()); + return; + } + + // Create a new copy of the reduce block, and inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + rewriter.cloneRegionBefore(combineOp, &parent.front()); + auto &newReduce = parent.front(); + auto returnOp = + dyn_cast(newReduce.getTerminator()); + + llvm::SmallVector combineArgs(2 * acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; + } + + rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), + combineArgs); + + auto results = returnOp.getResult(); + for (unsigned i = 0; i < acc.size(); ++i) { + acc[i] = results[i]; + } + + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + } + + void calculate(ConversionPatternRewriter &rewriter, const Location &loc, + Operation *op, Value &acc, const Value &cur) const { + TypeSwitch(op) + .Case([&](auto addfOp) { acc = fadd(acc, cur); }) + .Case([&](auto mulfOp) { acc = fmul(acc, cur); }) + .Case([&](auto maxfOp) { acc = fmax(acc, cur); }) + .Case([&](auto minfOp) { acc = fmin(acc, cur); }) + .Case([&](auto oriOp) { acc = or_(acc, cur); }) + .Case([&](auto xoriOp) { acc = xor_(acc, cur); }) + .Case([&](auto andOp) { acc = and_(acc, cur); }) + .Default([&](auto defaultOp) { + LLVM_DEBUG(defaultOp->dump()); + llvm_unreachable("[Vectorization]: Unsupported Operation Type " + "To VecType in Reduce"); + }); + } + + void accmulateWithinVector(ConversionPatternRewriter &rewriter, + const Location &loc, Operation *op, + Value &accVec) const { + auto accTy = cast(accVec.getType()); + size_t vecSize = accTy.getNumElements(); + Type elemTy = getElementTypeOrSelf(accTy); + Value acc = extract_element(elemTy, accVec, i32_val(0)); + for (size_t i = 1; i < vecSize; ++i) { + auto cur = extract_element(elemTy, accVec, i32_val(i)); + calculate(rewriter, loc, op, acc, cur); + } + accVec = acc; + } + + void accmulateNaive(ConversionPatternRewriter &rewriter, const Location &loc, + SmallVector &ops, SmallVector &accs, + SmallVector &curs, bool isFirst) const { + for (unsigned i = 0; i < accs.size(); ++i) { + if (isFirst) { + accs[i] = curs[i]; + } else { + calculate(rewriter, loc, ops[i], accs[i], curs[i]); + } + } + } + + SmallVector> + unpackInputs(Location loc, triton::xpu::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto types = op.getInputTypes(); + auto operands = adaptor.getOperands(); + unsigned srcElems = getTotalElemsPerThread(types[0]); + SmallVector> srcValues(srcElems); + for (unsigned i = 0; i < (op.getNumOperands() - 1); ++i) { // skip loopIndex + auto values = unpackLLElements(loc, operands[i], rewriter); + + assert(values.size() == srcValues.size()); + for (unsigned j = 0; j < srcValues.size(); ++j) { + srcValues[j].push_back(values[j]); + } + } + + return srcValues; + } + + // Reduce along op axis for elements that are in the same thread. The + // accumulated value is stored in accs. + void reduceWithinThreads( + ReduceOpHelper &helper, SmallVector> &srcValues, + std::map, SmallVector> &accs, + std::map, SmallVector> &indices, + ConversionPatternRewriter &rewriter) const { + triton::xpu::ReduceOp op = helper.getXPUOperation(); + RankedTensorType operandType = op.getInputTypes()[0]; + // Assumes offsets don't actually depend on type + SmallVector> offsets = + emitOffsetForLayout(helper.getSrcLayout(), operandType); + + // Thread X might hold the same input value in two registers. Get the + // indices in `offsets` that hold unique values, and only accumualte over + // those. + llvm::MapVector, int> uniqueOffsets; + for (int i = 0; i < offsets.size(); ++i) { + uniqueOffsets.insert({offsets[i], i}); + } + + unsigned srcElems = getTotalElemsPerThread(operandType); + auto *combineOp = &op.getCombineOp(); + auto srcIndices = + emitIndices(op.getLoc(), rewriter, targetInfo, helper.getSrcLayout(), + operandType, op.getLoopIndex()); + + // reduce within threads + for (const auto &[_, i] : uniqueOffsets) { + SmallVector key = offsets[i]; + key[op.getAxis()] = 0; + bool isFirst = accs.find(key) == accs.end(); + accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); + if (isFirst) + indices[key] = srcIndices[i]; + } + + // Accumulate within Vector + if (helper.isVectorized()) { + SmallVector returnDefOps = helper.getReturnDefOps(); + for (auto &it : accs) { + SmallVector &accVecs = it.second; + assert(accVecs.size() == returnDefOps.size() && + "accVecs.size() !=returnDefOps.size()"); + for (unsigned i = 0; i < returnDefOps.size(); ++i) { + accmulateWithinVector(rewriter, op.getLoc(), returnDefOps[i], + accVecs[i]); + } + } + } + } + + void storeCoreReduceToSharedMemory( + ReduceOpHelper &helper, + std::map, SmallVector> &accs, + SmallVector &smemBases, + ConversionPatternRewriter &rewriter) const { + triton::xpu::ReduceOp op = helper.getXPUOperation(); + Location loc = op.getLoc(); + Value coreId = getThreadId(rewriter, loc); + + for (auto it : accs) { + const SmallVector &key = it.first; + SmallVector &acc = it.second; + + for (unsigned i = 0; i < (op.getNumOperands() - 1); + ++i) { // skip loopIndex + auto elemTy = getElementType(op, i); + Value writePtr = + gep(ptr_ty(rewriter.getContext(), 2), elemTy, smemBases[i], coreId); + store_sm(acc[i], writePtr); + } + } + } + + mlir::RewriterBase::InsertPoint + getPreviousInsertionPoint(PatternRewriter &rewriter) const { + auto currentInsertionPoint = rewriter.getInsertionPoint(); + auto oldInsertionPoint = currentInsertionPoint; + Block *currentBlock_0 = rewriter.getInsertionBlock(); + // Move the iterator one step backward + if (currentInsertionPoint != currentBlock_0->begin()) { + --currentInsertionPoint; + } + // Set the insertion point to the adjusted position and save it + rewriter.setInsertionPoint(currentBlock_0, currentInsertionPoint); + auto startInsertionPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(currentBlock_0, oldInsertionPoint); + return startInsertionPoint; + } + + void createCondBr(PatternRewriter &rewriter, Location loc, Value condition, + Block *&trueDest, Block *&falseDest) const { + Block *currentBlock = rewriter.getInsertionBlock(); + falseDest = rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + trueDest = rewriter.createBlock(falseDest); + + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create(loc, condition, trueDest, falseDest); + rewriter.setInsertionPointToStart(trueDest); + + rewriter.create(loc, falseDest); + rewriter.setInsertionPointToStart(falseDest); + } + + void moveOpsBetweenInsertionPoints(ReduceOpHelper &helper, + PatternRewriter &rewriter, Block *trueDest, + Block *falseDest, Block::iterator start, + Block::iterator end) const { + triton::xpu::ReduceOp op = helper.getXPUOperation(); + + auto terminator = &*(trueDest->begin()); + + // LLVM_DEBUG(llvm::dbgs() << "\n Before moveBefore:" << + // op->getParentOfType() + // << "\n"); + + // LLVM_DEBUG(llvm::dbgs() << "\n movedOp: " << "\n"); + start++; // skip the previous op + while (start != end) { + Operation *op = &*start++; + // op->dump(); + op->moveBefore(trueDest, trueDest->end()); + } + + // deal the previous op + Operation *endOp = &*end; + endOp->moveBefore(trueDest, trueDest->end()); + terminator->moveBefore(trueDest, trueDest->end()); + } + + void accumulatePartialReductions(ReduceOpHelper &helper, + SmallVector &smemBases, + SmallVector &loopResCacheSmemBases, + ConversionPatternRewriter &rewriter) const { + triton::xpu::ReduceOp op = helper.getXPUOperation(); + auto srcLayout = helper.getSrcLayout(); + Location loc = op.getLoc(); + unsigned groupSizeInt = helper.getIntraGroupSizeWithUniqueData(); + unsigned operandNum = op.getInputTypes().size(); + + Value coreId = getThreadId(rewriter, loc); + Value groupSize = i32_val(groupSizeInt); + Value coreIdInGroup = urem(coreId, groupSize); + Value groupId = udiv(coreId, groupSize); + Value groupSkip = mul(groupId, groupSize); + // Value laneId = add(mul(groupId, groupSize), coreIdInGroup); + Value zero = i32_val(0); + Value coreIdInGroupZero = icmp_eq(coreIdInGroup, zero); + + auto startInsertionPoint = getPreviousInsertionPoint(rewriter); + + SmallVector acc(operandNum); // skip loopIndex + SmallVector> readValues(groupSizeInt); + for (unsigned i = 0; i < operandNum; ++i) { // skip loopIndex + auto elemTy = getElementType(op, i); + + for (unsigned readOffset = 0; readOffset < groupSizeInt; ++readOffset) { + Value laneId = add(groupSkip, i32_val(readOffset)); + Value readPtr = + gep(ptr_ty(rewriter.getContext(), 2), elemTy, smemBases[i], laneId); + readValues[readOffset].push_back(load_sm(elemTy, readPtr)); + } + } + + SmallVector returnDefOps = helper.getReturnDefOps(); + for (auto [i, v] : llvm::enumerate(readValues)) { + if (helper.isVectorized()) { + accmulateNaive(rewriter, loc, returnDefOps, acc, v, i == 0); + } else { + accumulate(rewriter, op.getCombineOp(), acc, v, i == 0); + } + } + + SmallVector writePtrs(operandNum); // skip loopIndex + for (unsigned i = 0; i < operandNum; ++i) { // skip loopIndex + auto elemTy = getElementType(op, i); + // TODO[dyq]: check writeOffset == i32_val(0) + Value writeOffset = groupSkip; + writePtrs[i] = gep(ptr_ty(rewriter.getContext(), 2), elemTy, smemBases[i], + /*writeOffset*/ writeOffset); + } + + for (unsigned i = 0; i < operandNum; ++i) { // skip loopIndex + store_sm(acc[i], writePtrs[i]); + } + + // reduce calcution(cur loop) finsh, now move it to the trueDest + auto endInsertionPoint = getPreviousInsertionPoint(rewriter); + + Block *trueDest = nullptr; + Block *falseDest = nullptr; + createCondBr(rewriter, op->getLoc(), coreIdInGroupZero, trueDest, + falseDest); + + moveOpsBetweenInsertionPoints(helper, rewriter, trueDest, falseDest, + startInsertionPoint.getPoint(), + endInsertionPoint.getPoint()); + + if (isNeedLoopCacheResult(op)) { + // reduce with prev loopResult + Value laneId = coreId; + Value zero = i32_val(0); + Value laneZero = icmp_eq(laneId, zero); + Value loopIndex = op.getLoopIndex(); + Value loopNonZero = icmp_ne(loopIndex, zero); + Value cond = and_(laneZero, loopNonZero); + + Block *loopReduceTrueDest = nullptr; + Block *loopReduceFalseDest = nullptr; + createCondBr(rewriter, op->getLoc(), cond, loopReduceTrueDest, + loopReduceFalseDest); + + auto curInsertionPoint = rewriter.getInsertionPoint(); + Block *curBlock = rewriter.getInsertionBlock(); + + rewriter.setInsertionPointToStart(loopReduceTrueDest); + SmallVector curResSmemValues(operandNum); + SmallVector loopResCacheSmemValues(operandNum); + for (unsigned i = 0; i < (operandNum); ++i) { // skip loopIndex + auto elemTy = getElementType(op, i); + curResSmemValues[i] = load_sm(elemTy, smemBases[i]); + loopResCacheSmemValues[i] = load_sm(elemTy, loopResCacheSmemBases[i]); + } + + SmallVector returnDefOps = helper.getReturnDefOps(); + if (helper.isVectorized()) { + accmulateNaive(rewriter, loc, returnDefOps, loopResCacheSmemValues, + loopResCacheSmemValues, false); + } else { + accumulate(rewriter, op.getCombineOp(), curResSmemValues, + loopResCacheSmemValues, false); + } + + // store the final result + for (unsigned i = 0; i < operandNum; ++i) { // skip loopIndex + store_sm(curResSmemValues[i], smemBases[i]); + } + + rewriter.setInsertionPoint(curBlock, curInsertionPoint); + } + } + + // Load the final reduction from shared memory and replace the reduce result + // with it. + void loadReductionAndPackResult(ReduceOpHelper &helper, + SmallVector smemShape, + SmallVector &smemBases, + SmallVector &loopResCacheSmemBases, + ConversionPatternRewriter &rewriter) const { + triton::xpu::ReduceOp op = helper.getXPUOperation(); + Location loc = op.getLoc(); + auto srcLayout = helper.getSrcLayout(); + auto axis = op.getAxis(); + auto smemOrder = helper.getOrderWithAxisAtBeginning(); + unsigned groupSizeInt = helper.getIntraGroupSizeWithUniqueData(); + + Value coreId = getThreadId(rewriter, loc); + Value groupSize = i32_val(groupSizeInt); + Value groupId = udiv(coreId, groupSize); + Value groupSkip = mul(groupId, groupSize); + SmallVector results(op.getNumOperands() - 1); // skip loopIndex + for (unsigned i = 0; i < (op.getNumOperands() - 1); ++i) { // skip loopIndex + auto elemTy = getElementType(op, i); + if (auto resultTy = + dyn_cast(op.getResult()[i].getType())) { + // nd-tensor where n >= 1 + auto resultLayout = cast(resultTy.getEncoding()); + unsigned resultElems = getTotalElemsPerThread(resultTy); + auto resultIndices = + emitIndices(loc, rewriter, targetInfo, resultLayout, resultTy, + op.getLoopIndex()); + auto resultShape = resultTy.getShape(); + auto resultCTATile = getShapePerCTATile(resultLayout, resultShape); + assert(resultIndices.size() == resultElems); + + SmallVector resultVals(resultElems); + for (size_t j = 0; j < resultElems; ++j) { + SmallVector readIdx = resultIndices[j]; + // readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0)); + // for (size_t resultIdx = 0, resultDim = resultShape.size(); + // resultIdx < resultDim; ++resultIdx) { + // auto smemIdx = resultIdx < op.getAxis() ? resultIdx : + // resultIdx + // + 1; if (resultCTATile[resultIdx] > smemShape[smemIdx] || + // resultShape[resultIdx] > smemShape[smemIdx]) { + // // When srcShape smaller then src sizePerThread, only + // srcShape + // // elements is accumulated in smem. Modulo smemShape + // effectively + // // replicates srcShape elements to src sizePerThread. + // readIdx[smemIdx] = + // urem(readIdx[smemIdx], i32_val(smemShape[smemIdx])); + // } + // } + + Value readOffset = groupSkip; + // Value readOffset = linearize(rewriter, loc, readIdx, smemShape, + // smemOrder); + Value readPtr = gep(ptr_ty(rewriter.getContext(), 2), elemTy, + smemBases[i], readOffset); + resultVals[j] = load_sm(elemTy, readPtr); + } + + results[i] = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, resultTy); + } else { + // 0d-tensor -> scalar + results[i] = load_sm(elemTy, smemBases[i]); + // save reduce result in cur loop + if (isNeedLoopCacheResult(op)) + store_sm(results[i], loopResCacheSmemBases[i]); + } + } + + rewriter.replaceOp(op, results); + } + + // Reduce across threads within each group. + void + reduceWithinGroups(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::xpu::ReduceOp op = helper.getXPUOperation(); + Location loc = op.getLoc(); + unsigned axis = op.getAxis(); + + // unsigned sizeIntraGroups = helper.getIntraGroupSizeWithUniqueData(); + // unsigned threadOffsetOnReductionAxis = + // helper.getThreadOffsetOnReductionAxis(); LLVM_DEBUG(llvm::dbgs() << + // "\nsizeIntraGroups=" << sizeIntraGroups << "\n"); LLVM_DEBUG(llvm::dbgs() + // << + // "\nthreadOffsetOnReductionAxis=" + // << threadOffsetOnReductionAxis << "\n"); + + // Compute a shared memory base per operand. + auto smemShape = helper.getXPUScratchConfig(); + + auto [smemBases, loopResCacheSmemBases] = + getSmemBases(op, product(smemShape), rewriter); + + storeCoreReduceToSharedMemory(helper, accs, smemBases, rewriter); + // LLVM_DEBUG(llvm::dbgs() << "\n After storeCoreReduceToSharedMemory:" + // << op->getParentOfType() << "\n"); + + xpu_barrier(); + + accumulatePartialReductions(helper, smemBases, loopResCacheSmemBases, + rewriter); + // LLVM_DEBUG(llvm::dbgs() << "\n After accumulatePartialReductions:" + // << op->getParentOfType() << "\n"); + + xpu_barrier(); + + loadReductionAndPackResult(helper, smemShape, smemBases, + loopResCacheSmemBases, rewriter); + + // llvm_unreachable("Not Supported"); + } + + // Pack the accumulator values and replace the reduce op with the result. + void packResults(ReduceOpHelper &helper, + std::map, SmallVector> &accs, + ConversionPatternRewriter &rewriter) const { + triton::xpu::ReduceOp op = helper.getXPUOperation(); + Location loc = op.getLoc(); + unsigned axis = op.getAxis(); + SmallVector results((op.getNumOperands() - 1)); // skip loopIndex + for (unsigned i = 0; i < (op.getNumOperands() - 1); ++i) { // skip loopIndex + if (auto resultTy = + dyn_cast(op.getResult()[i].getType())) { + auto resultLayout = cast(resultTy.getEncoding()); + unsigned resultElems = getTotalElemsPerThread(resultTy); + SmallVector> resultOffset = + emitOffsetForLayout(resultLayout, resultTy); + SmallVector resultVals; + for (int j = 0; j < resultElems; j++) { + auto key = resultOffset[j]; + key.insert(key.begin() + axis, 0); + resultVals.push_back(accs[key][i]); + } + results[i] = packLLElements(loc, getTypeConverter(), resultVals, + rewriter, resultTy); + } else + results[i] = accs.begin()->second[i]; + } + rewriter.replaceOp(op, results); + } +}; + +} // namespace + +void mlir::triton::xpu::populateReduceOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfo &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/SPMDOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 000000000..237645b51 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct XPUGetNumProgramsOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::GetNumProgramsOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + Value retVal = rewriter.create( + loc, type::i32Ty(ctx), i32_val(1)); + + rewriter.replaceOp(op, {retVal}); + return success(); + } +}; + +} // namespace + +void mlir::triton::xpu::populateSPMDOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/TargetInfo.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/TargetInfo.cpp new file mode 100644 index 000000000..26a6d460e --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/TargetInfo.cpp @@ -0,0 +1,117 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +// clang-format off +#include "xpu/lib/Conversion/TritonXPUToLLVM/TargetInfo.h" // TargetInfo + +#include "triton/Analysis/UtilityXPU.h" +#include "xpu/lib/Conversion/TritonXPUToLLVM/Utility.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +// clang-format on + +using namespace mlir; + +namespace mlir { +namespace triton { +namespace xpu { + +bool TargetInfo::supportMaximumMinimum() const { + llvm_unreachable("not impl"); + return false; +} + +Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { + llvm_unreachable("not impl"); + return Value(); +} + +Value TargetInfo::ballot(ConversionPatternRewriter &rewriter, Location loc, + Type type, Value cmp) const { + llvm_unreachable("not impl"); + return Value(); +} + +void TargetInfo::storeShared(ConversionPatternRewriter &rewriter, Location loc, + Value ptr, Value val, Value pred) const { + llvm_unreachable("not impl"); +} + +Value TargetInfo::loadShared(ConversionPatternRewriter &rewriter, Location loc, + const TypeConverter *converter, Value ptr, + Type elemTy, Value pred) const { + llvm_unreachable("not impl"); + return Value(); +} + +Value TargetInfo::shuffleXor(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const { + llvm_unreachable("not impl"); + return Value(); +} + +Value TargetInfo::shuffleUp(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const { + llvm_unreachable("not impl"); + return Value(); +} + +Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, + Value val, int i) const { + llvm_unreachable("not impl"); + return Value(); +} + +Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, + Value val, Value i) const { + llvm_unreachable("not impl"); + return Value(); +} + +Value TargetInfo::programId(ConversionPatternRewriter &rewriter, Location loc, + ModuleOp moduleOp, int axis) const { + return LLVM::XPU::llGetPid(loc, rewriter, moduleOp, axis); +} + +bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce) const { + llvm_unreachable("not impl"); + return false; +} + +bool TargetInfo::processReplicaUsingStMatrix( + ConversionPatternRewriter &rewriter, Location loc, Value smemBase, + SmallVector &vals, RankedTensorType srcTy, Type elemTy, + ArrayRef paddedRepShape, ArrayRef origRepShape, + ArrayRef outOrd, unsigned accumNumReplicates, + int swizzlingByteWidth) const { + llvm_unreachable("not impl"); + return false; +} + +std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { + std::string funcName = + resultElementTy.isInteger(32) ? "_ZN3xpu6umulhiEjj" : "Unsupported"; + return funcName; +} + +void TargetInfo::printf(ConversionPatternRewriter &rewriter, + Value formatStrStart, int /*formatStrByteCount*/, + ValueRange args) const { + llvm_unreachable("not impl"); +} + +void TargetInfo::assertFail(ConversionPatternRewriter &rewriter, Location loc, + StringRef message, StringRef file, StringRef func, + int line) const { + llvm_unreachable("not impl"); +} + +uint32_t TargetInfo::getXPUArch() const { return this->xpu_arch; } +uint32_t TargetInfo::getXPUBufferSize() const { return this->buffer_size; } + +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/TargetInfo.h b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/TargetInfo.h new file mode 100644 index 000000000..4953159c2 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/TargetInfo.h @@ -0,0 +1,76 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TRITON_CONVERSION_TRITONXPU_TO_LLVM_TARGETINFOXPU_H +#define TRITON_CONVERSION_TRITONXPU_TO_LLVM_TARGETINFOXPU_H + +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" + +namespace mlir { +namespace triton { +namespace xpu { + +class TargetInfo : public mlir::triton::TargetInfoBase { +public: + TargetInfo(uint32_t xpu_arch, uint32_t buffer_size) + : xpu_arch(xpu_arch), buffer_size(buffer_size) {} + + bool supportMaximumMinimum() const override; + + Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; + + Value ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, + Value cmp) const override; + + void storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, + Value val, Value pred) const override; + + Value loadShared(ConversionPatternRewriter &rewriter, Location loc, + const TypeConverter *converter, Value ptr, Type elemTy, + Value pred) const override; + + Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, Value val, + int i) const override; + Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc, Value val, + int i) const override; + Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + int i) const override; + Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value i) const override; + + Value programId(ConversionPatternRewriter &rewriter, Location loc, + ModuleOp moduleOp, int axis) const override; + + bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, + SmallVector &acc, triton::ReduceOp op, + unsigned numLaneToReduce) const override; + + bool processReplicaUsingStMatrix( + ConversionPatternRewriter &rewriter, Location loc, Value smemBase, + SmallVector &vals, RankedTensorType srcTy, Type elemTy, + ArrayRef paddedRepShape, ArrayRef origRepShape, + ArrayRef outOrd, unsigned accumNumReplicates, + int swizzleByteWidth) const override; + + std::string getMulhiFuncName(Type resultElementTy) const override; + + void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args) const override; + void assertFail(ConversionPatternRewriter &rewriter, Location loc, + StringRef message, StringRef file, StringRef func, + int line) const override; + + uint32_t getXPUArch() const; + uint32_t getXPUBufferSize() const; + +private: + uint32_t xpu_arch; + uint32_t buffer_size; +}; +} // namespace xpu +} // namespace triton +} // namespace mlir +#endif // TRITON_CONVERSION_TRITONXPU_TO_LLVM_TARGETINFOXPU_H diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/TritonXPUToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/TritonXPUToLLVM.cpp new file mode 100644 index 000000000..bdebd6c87 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/TritonXPUToLLVM.cpp @@ -0,0 +1,229 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "triton/Conversion/TritonXPUToLLVM/Passes.h" +// clang-format off +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h" + +// #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" + +// clang-format on + +namespace mlir { +namespace triton { + +#define GEN_PASS_DEF_CONVERTTRITONXPUTOLLVM +#include "triton/Conversion/TritonXPUToLLVM/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +using namespace mlir; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addLegalOp(); + } +}; + +class TritonLLVMFunctionConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + // addLegalDialect(); // TODO[dyq]: necessary? + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + } +}; + +struct ConvertTritonXPUToLLVM + : public triton::impl::ConvertTritonXPUToLLVMBase { + using ConvertTritonXPUToLLVMBase::ConvertTritonXPUToLLVMBase; + + ConvertTritonXPUToLLVM(uint32_t xpu_arch, uint32_t buffer_size) + : ConvertTritonXPUToLLVMBase({xpu_arch, buffer_size}) {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + option.overrideIndexBitwidth(32); + TritonXPUToLLVMTypeConverter typeConverter(context, + option); // we can reuse it + TritonLLVMConversionTarget convTarget(*context); + + // Allocate shared memory and set barrier + // TODO[dyq]: necessary to open? + // ModuleAllocation allocation(mod); + // ModuleMembarAnalysis membarPass(&allocation); + // membarPass.run(); + + // Lower functions + { + mlir::LowerToLLVMOptions option(context); + option.overrideIndexBitwidth(32); + TritonXPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMFunctionConversionTarget funcTarget(*context); + RewritePatternSet funcPatterns(context); + // TODO[dyq]: add [nvvm.maxntid, nvvm.kernel] attr + mlir::triton::xpu::populateFuncOpConversionPattern( + typeConverter, funcPatterns, patternBenefitDefault); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + funcPatterns); + if (failed( + applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) + return signalPassFailure(); + } + + // LLVM_DEBUG(llvm::dbgs() << "\nAfter Lower Functions:\n" << mod << "\n"); + + // initSharedMemory is run before the conversion of call and ret ops, + // because the call op has to know the shared memory base address of + // each function + initSharedMemory(typeConverter); + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + OpBuilder::InsertPoint indexInsertPoint; + + RewritePatternSet patterns(context); + triton::xpu::TargetInfo targetInfo(xpu_arch, buffer_size); + int benefit = patternBenefitPrioritizeOverLLVMConversions; + // Make benefit for XPU specific patterns higher so they apply before common + // patterns + int xpuBenefit = benefit + 1; + + // TODO[dyq]: Open allToLLVMPatterns + mlir::triton::xpu::populateConvertLayoutOpToLLVMPatterns( + typeConverter, targetInfo, patterns, benefit); + + // TODO[dyq]: XPUSDNN-CHECK add DotOp Lowering Pattern + // mlir::triton::xpu::populateDotOpToLLVMPatterns(typeConverter, patterns, + // benefit); + mlir::triton::xpu::populateElementwiseOpToLLVMPatterns( + typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); + + mlir::triton::populateElementwiseOpToLLVMPatterns( + typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); + + mlir::triton::xpu::populateTTXPUVectorizedOpToLLVMConversionPatterns( + typeConverter, targetInfo, patterns, benefit); + + // TODO[dyq]: + mlir::triton::xpu::populateTTXPUUtilityOpToLLVMConversionPatterns( + typeConverter, targetInfo, patterns, axisInfoAnalysis, benefit); + + mlir::triton::xpu::populateLoadStoreOpToLLVMPatterns( + typeConverter, targetInfo, patterns, axisInfoAnalysis, benefit); + + mlir::triton::xpu::populateGPUToXPUConversionPatterns( + typeConverter, patterns, targetInfo, benefit); + + mlir::triton::xpu::populateReduceOpToLLVMPatterns(typeConverter, patterns, + targetInfo, benefit); + // mlir::triton::populateScanOpToLLVMPatterns(typeConverter, patterns, + // targetInfo, benefit); + + // mlir::triton::populateHistogramOpToLLVMPatterns(typeConverter, patterns, + // targetInfo, benefit); + // mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, + // targetInfo, benefit); + mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, + benefit); + mlir::triton::xpu::populateSPMDOpToLLVMPattern(typeConverter, patterns, + benefit); + mlir::triton::populateSPMDOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + + mlir::triton::xpu::populateViewOpToLLVMPatterns(typeConverter, patterns, + xpuBenefit); + mlir::triton::populateViewOpToLLVMPatterns(typeConverter, patterns, + benefit); + // mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, + // targetInfo, benefit); + // mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo, + // patterns, benefit); + mlir::triton::xpu::populateMakeRangeOpToLLVMPattern( + typeConverter, targetInfo, patterns, benefit); + + // TODO(thomas): this should probably be done in a separate step to not + // interfere with our own lowering of arith ops. Add arith/math's patterns + // to help convert scalar expression to LLVM. + mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); + mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + patterns); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } + +private: + Value smem; + + void initSharedMemory(LLVMTypeConverter &typeConverter) { + ModuleOp mod = getOperation(); + OpBuilder b(mod.getBodyRegion()); + auto ctx = mod.getContext(); + auto loc = mod.getLoc(); + auto elemTy = typeConverter.convertType(b.getIntegerType(8)); + // Set array size 0 and external linkage indicates that we use dynamic + // shared allocation to allow a larger shared memory size for each kernel. + + // XPU: the shared space used by Triton will be put at the end of + // SHARED_MEMORY section since only pointer is used without the real + // allocation + auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0); + auto global = b.create( + loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::Internal, + "global_smem", /*value=*/Attribute(), /*alignment=*/8, 2); + + // TODO[dyq]: llvm-18 don't need to set pointer type, will this logic be + // changed? + SmallVector funcs; + mod.walk([&](LLVM::LLVMFuncOp func) { funcs.push_back(func); }); + assert(funcs.size() == 1 && + "Inliner pass is expected before TritonXPUToLLVM"); + b.setInsertionPointToStart(&funcs[0].getBody().front()); + smem = b.create(loc, global); + + // TODO[dyq]: llvm-18 don't need to set pointer type, this type maybe cause + // error + // auto ptrTy = + // LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()), 2); + auto ptrTy = LLVM::LLVMPointerType::get(mod.getContext(), 2); + smem = b.create(loc, ptrTy, smem); + } +}; + +} // namespace + +namespace mlir { +namespace triton { + +std::unique_ptr> createConvertTritonXPUToLLVMPass() { + return std::make_unique(); +} + +std::unique_ptr> +createConvertTritonXPUToLLVMPass(uint32_t xpu_arch, uint32_t buffer_size) { + return std::make_unique(xpu_arch, buffer_size); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/TypeConverter.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/TypeConverter.cpp new file mode 100644 index 000000000..5870adc56 --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/TypeConverter.cpp @@ -0,0 +1,68 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "triton/Conversion/TritonXPUToLLVM/TypeConverter.h" // TritonXPUToLLVMTypeConverter + +using namespace mlir; +using namespace mlir::triton; +using ::mlir::triton::xpu::getTotalElemsPerThread; + +TritonXPUToLLVMTypeConverter::TritonXPUToLLVMTypeConverter( + MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis) + : LLVMTypeConverter(ctx, option, analysis) { + addConversion([&](triton::PointerType type) -> std::optional { + return convertTritonPointerType(type); + }); + // deal RankedTensorType to calculate elemNum + addConversion([&](RankedTensorType type) -> std::optional { + return convertTritonTensorType(type); + }); +} + +Type TritonXPUToLLVMTypeConverter::convertTritonPointerType( + triton::PointerType type) { + auto ctx = type.getContext(); + auto pointeeType = type.getPointeeType(); + if (isa(pointeeType)) { + auto rankedTensorType = cast(pointeeType); + // struct { offset0, offset1, shape0, shape1, stride0, + // stride1, base_ptr}; + auto eleType = rankedTensorType.getElementType(); + auto shape = rankedTensorType.getShape(); + SmallVector types; + // offsets + for (size_t i = 0; i < shape.size(); ++i) + types.push_back(IntegerType::get(ctx, 32)); + // shapes, strides + for (size_t i = 0; i < 2 * shape.size(); ++i) + types.push_back(IntegerType::get(ctx, 64)); + + types.push_back(LLVM::LLVMPointerType::get(ctx, type.getAddressSpace())); + + return LLVM::LLVMStructType::getLiteral(ctx, types); + } + return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); +} + +Type TritonXPUToLLVMTypeConverter::getElementTypeForStruct( + TensorOrMemDesc type) { + auto ctx = type.getContext(); + Attribute layout = type.getEncoding(); + Type elemTy = convertType(type.getElementType()); + return elemTy; +} + +Type TritonXPUToLLVMTypeConverter::convertTritonTensorType( + RankedTensorType type) { + auto ctx = type.getContext(); + Attribute layout = type.getEncoding(); + SmallVector shape(type.getShape().begin(), type.getShape().end()); + Type eltType = getElementTypeForStruct(cast(type)); + + unsigned numElementsPerThread = getTotalElemsPerThread(type); + SmallVector types(numElementsPerThread, eltType); + return LLVM::LLVMStructType::getLiteral(ctx, types); +} diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/Utility.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/Utility.cpp new file mode 100644 index 000000000..6eb3fbfce --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/Utility.cpp @@ -0,0 +1,65 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "xpu/lib/Conversion/TritonXPUToLLVM/Utility.h" + +namespace mlir::LLVM::XPU { + +Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, + ModuleOp moduleOp, int axis) { + assert(axis >= 0); + assert(axis < 3); + assert(moduleOp); + static constexpr mlir::gpu::Dimension dims[] = {mlir::gpu::Dimension::x, + mlir::gpu::Dimension::y, + mlir::gpu::Dimension::z}; + + // TODO[dyq]: add Dimension:y & Dimension:z mapping + Value blockId; + switch (axis) { + case 0: { + blockId = rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[axis]); + break; + } + case 1: + case 2: { + blockId = i32_val(0); + break; + } + default: { + llvm_unreachable("ProgramIdOp Get Invalid Axis"); + } + } + + return rewriter.create(loc, i32_ty, blockId); +} + +Type getFunctionType(mlir::OpBuilder &builder, ValueRange operands) { + SmallVector operandTypes(operands.getTypes()); + mlir::MLIRContext *ctx = builder.getContext(); + auto voidTy = mlir::LLVM::LLVMVoidType::get(ctx); + return LLVM::LLVMFunctionType::get(voidTy, operandTypes); +} + +Value createDeviceCall(StringRef funcName, ConversionPatternRewriter &rewriter, + Operation *op, Type &elemTy, ValueRange &operands, + Location &loc) { + Type funcType = mlir::triton::gpu::getFunctionType(elemTy, operands); + LLVM::LLVMFuncOp funcOp = mlir::triton::gpu::appendOrGetExternFuncOp( + rewriter, op, funcName, funcType, "", ""); + return rewriter.create(loc, funcOp, operands).getResult(); +} + +void createDeviceCall(StringRef funcName, ConversionPatternRewriter &rewriter, + Operation *op, ValueRange &operands, Location &loc) { + OpBuilder builder(op); + Type funcType = getFunctionType(builder, operands); + LLVM::LLVMFuncOp funcOp = mlir::triton::gpu::appendOrGetExternFuncOp( + rewriter, op, funcName, funcType, "", ""); + rewriter.create(loc, funcOp, operands); + return; +} + +} // namespace mlir::LLVM::XPU diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/Utility.h b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/Utility.h new file mode 100644 index 000000000..fee5bad5a --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/Utility.h @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#ifndef TRITON_CONVERSION_TRITONXPU_TO_LLVM_UTILITY_H +#define TRITON_CONVERSION_TRITONXPU_TO_LLVM_UTILITY_H + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonXPU/IR/Dialect.h" + +#define addrspace_cast(...) \ + rewriter.create(loc, __VA_ARGS__) +#define allocate(...) rewriter.create(loc, __VA_ARGS__) +#define idx_val(...) \ + LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \ + __VA_ARGS__) +#define sdiv(...) rewriter.create(loc, __VA_ARGS__) +#define srem(...) rewriter.create(loc, __VA_ARGS__) +#define load_sm(...) rewriter.create(loc, __VA_ARGS__) +#define store_sm(...) rewriter.create(loc, __VA_ARGS__) +#define xpu_barrier() rewriter.create(loc) +#define i16_val(...) \ + LLVM::createLLVMIntegerConstant(rewriter, loc, 16, __VA_ARGS__) + +namespace mlir::triton { +inline size_t align(size_t elemNum, Type elemTy, size_t target) { + size_t elemBit = isa(elemTy) + ? 64u + : elemTy.getIntOrFloatBitWidth(); + size_t elemBytes = (elemBit / 8u) ? (elemBit / 8u) : 1; + size_t aligned = (elemNum * elemBytes + target - 1) / target * target; + return aligned / elemBytes; +} +} // namespace mlir::triton + +namespace mlir::LLVM::XPU { + +Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, + ModuleOp moduleOp, int axis); + +Value createDeviceCall(StringRef funcName, ConversionPatternRewriter &rewriter, + Operation *op, Type &elemTy, ValueRange &operands, + Location &loc); + +void createDeviceCall(StringRef funcName, ConversionPatternRewriter &rewriter, + Operation *op, ValueRange &operands, Location &loc); + +SmallVector> +emitOffsetForClusterLayout(const triton::xpu::ClusterLayoutAttr &clusterLayout, + RankedTensorType type); + +inline Value getGridDim(RewriterBase &rewriter, Location loc) { + Value gridDim = + rewriter.create<::mlir::gpu::GridDimOp>(loc, ::mlir::gpu::Dimension::x); + return rewriter.create(loc, i32_ty, gridDim); +} + +inline Value getBlockDim(RewriterBase &rewriter, Location loc) { + Value blockDim = + rewriter.create<::mlir::gpu::BlockDimOp>(loc, ::mlir::gpu::Dimension::x); + return rewriter.create(loc, i32_ty, blockDim); +} + +inline Value getBlockId(RewriterBase &rewriter, Location loc) { + Value blockId = + rewriter.create<::mlir::gpu::BlockIdOp>(loc, ::mlir::gpu::Dimension::x); + return rewriter.create(loc, i32_ty, blockId); +} + +} // namespace mlir::LLVM::XPU + +#endif // TRITON_CONVERSION_TRITONXPU_TO_LLVM_UTILITY_H diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/VectorizedOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/VectorizedOpToLLVM.cpp new file mode 100644 index 000000000..edd81416f --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/VectorizedOpToLLVM.cpp @@ -0,0 +1,1012 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +// clang-format off +#include "xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h" +// clang-format on + +namespace { +// TODO[dyq]: add to head file +enum class ElemState { + SS = 0, /*00*/ + SV = 1, /*01*/ + VS = 2, /*10*/ + VV = 3 /*11*/ +}; + +template struct SVOp2Str; + +#define SVOp2ASMStr(SrcType, ASM_STR) \ + template <> struct SVOp2Str { \ + static const llvm::StringRef value; \ + }; \ + const llvm::StringRef SVOp2Str::value = ASM_STR; + +SVOp2ASMStr(triton::xpu::SvaddFOp, "vadd.f.mz.rn $0{mr1}, $1, $2"); +SVOp2ASMStr(triton::xpu::SvmulFOp, "vmul.f.mz.rn $0{mr1}, $1, $2"); +SVOp2ASMStr(triton::xpu::SvsubFOp, "vsub.f.mz.rn $0{mr1}, $1, $2"); +SVOp2ASMStr(triton::xpu::SvmaxFOp, "vmax.f.mz $0{mr1}, $1, $2"); + +template struct SVOp2StrFP16; + +#define SVOp2ASMStrFP16(SrcType, ASM_STR) \ + template <> struct SVOp2StrFP16 { \ + static const llvm::StringRef value; \ + }; \ + const llvm::StringRef SVOp2StrFP16::value = ASM_STR; + +SVOp2ASMStrFP16(triton::xpu::SvaddFOp, "vadd.hf.mz.rn $0{mr1}, $1, $2"); +SVOp2ASMStrFP16(triton::xpu::SvmulFOp, "vmul.hf.mz.rn $0{mr1}, $1, $2"); +SVOp2ASMStrFP16(triton::xpu::SvsubFOp, "vsub.hf.mz.rn $0{mr1}, $1, $2"); +SVOp2ASMStrFP16(triton::xpu::SvmaxFOp, "vmax.hf.mz $0{mr1}, $1, $2"); + +template struct VLibOp; + +#define VLibOp2DevCall(SrcType, ASM_STR) \ + template <> struct VLibOp { \ + static const llvm::StringRef value; \ + }; \ + const llvm::StringRef VLibOp::value = ASM_STR; + +VLibOp2DevCall(triton::xpu::VSinFOp, "_ZN3xpu5vsinfEDv16_f"); +VLibOp2DevCall(triton::xpu::VCosFOp, "_ZN3xpu5vcosfEDv16_f"); + +template struct VLibOpFP16; + +#define VLibOpFP162DevCall(SrcType, ARCH, ASM_STR) \ + template <> struct VLibOpFP16 { \ + static const llvm::StringRef value; \ + }; \ + const llvm::StringRef VLibOpFP16::value = ASM_STR; + +VLibOpFP162DevCall(triton::xpu::VSinFOp, 2, "_ZN3xpu5vsinfEDv32_t"); +VLibOpFP162DevCall(triton::xpu::VCosFOp, 2, "_ZN3xpu5vcosfEDv32_t"); +VLibOpFP162DevCall(triton::xpu::VSinFOp, 3, "_ZN3xpu5vsinfEDv32_DF16_"); +VLibOpFP162DevCall(triton::xpu::VCosFOp, 3, "_ZN3xpu5vcosfEDv32_DF16_"); + +} // namespace + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using ::mlir::triton::gpu::getTotalElemsPerThread; + +struct XPUVectorizedOpsConversionBase { + + explicit XPUVectorizedOpsConversionBase( + const triton::xpu::TargetInfo &targetInfo) { + switch (static_cast(targetInfo.getXPUArch())) { + case XPUArch::XPU2: { + xpuArch = 2; + break; + } + case XPUArch::XPU3: { + xpuArch = 3; + break; + } + default: + llvm_unreachable( + "Failed to create GM2LMOp with unsupported xpu architecture."); + } + } + + unsigned getVectorSize(Type type) const { + auto vectorTy = mlir::dyn_cast(type); + if (!vectorTy) + return 1; + auto elemTy = vectorTy.getElementType(); + auto width = elemTy.getIntOrFloatBitWidth(); + + auto shape = vectorTy.getShape(); + if (shape[0] != 16) { // return vecSize = numElems for vector + return shape[0]; + } + + return 512 / width; + } + + Type convertVectorType(Type type) const { + auto vectorType = mlir::cast(type); + auto ctx = vectorType.getContext(); + auto elemTy = vectorType.getElementType(); + if (elemTy.isF16()) + return LLVM::getFixedVectorType(LLVM::type::f16Ty(ctx), + getVectorSize(type)); + else if (elemTy.isF32()) + return LLVM::getFixedVectorType(LLVM::type::f32Ty(ctx), + getVectorSize(type)); + else if (elemTy.isInteger(16)) + return LLVM::getFixedVectorType(LLVM::type::i16Ty(ctx), + getVectorSize(type)); + else if (elemTy.isInteger(32)) + return LLVM::getFixedVectorType(LLVM::type::i32Ty(ctx), + getVectorSize(type)); + else if (elemTy.isBF16()) + return LLVM::getFixedVectorType(LLVM::type::bf16Ty(ctx), + getVectorSize(type)); + + llvm_unreachable("Not implemented."); + } + +protected: + int xpuArch = 3; +}; + +template +struct VVBinOpsConversion : public ConvertOpToLLVMPattern, + public XPUVectorizedOpsConversionBase { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ConvertOpToLLVMPattern::getTypeConverter; + using OpAdaptor = typename SrcOp::Adaptor; + + VVBinOpsConversion(LLVMTypeConverter &converter, PatternBenefit benefit, + const triton::xpu::TargetInfo &targetInfo) + : ConvertOpToLLVMPattern(converter, benefit), + XPUVectorizedOpsConversionBase(targetInfo) {} + + LogicalResult + matchAndRewrite(SrcOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + Value lllhs = adaptor.getLhs(); + Value llrhs = adaptor.getRhs(); + + auto loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + auto valueTy = lhs.getType(); + + Type valueElemTy = + getTypeConverter()->convertType(getElementTypeOrSelf(valueTy)); + unsigned numElems = getTotalElemsPerThread(valueTy); + + auto lhsElems = unpackLLElements(loc, lllhs, rewriter); + auto rhsElems = unpackLLElements(loc, llrhs, rewriter); + assert(lhsElems.size() == rhsElems.size()); + + SmallVector calculatedVals; + for (size_t vecStart = 0; vecStart < numElems; vecStart += 1) { + Value vaddOp = + rewriter.create(loc, convertVectorType(valueElemTy), + lhsElems[vecStart], rhsElems[vecStart]); + calculatedVals.push_back(vaddOp); + } + + Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), calculatedVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + + return success(); + } +}; + +template +struct SVBinOpsConversion : public ConvertOpToLLVMPattern, + public XPUVectorizedOpsConversionBase { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ConvertOpToLLVMPattern::getTypeConverter; + using OpAdaptor = typename SrcOp::Adaptor; + + SVBinOpsConversion(LLVMTypeConverter &converter, PatternBenefit benefit, + const triton::xpu::TargetInfo &targetInfo) + : ConvertOpToLLVMPattern(converter, benefit), + XPUVectorizedOpsConversionBase(targetInfo) {} + + LogicalResult + matchAndRewrite(SrcOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + int32_t elemStateInt = op.getElemState(); + ElemState elemState = static_cast(elemStateInt); + + Value lllhs = adaptor.getLhs(); + Value llrhs = adaptor.getRhs(); + + auto loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + Type valueTy; + + if (elemState == ElemState::SV) { + valueTy = rhs.getType(); + } else if (elemState == ElemState::VS) { + valueTy = lhs.getType(); + } + + Type valueElemTy = + getTypeConverter()->convertType(getElementTypeOrSelf(valueTy)); + unsigned numElems = getTotalElemsPerThread(valueTy); + + // Get data from a struct + auto lhsElems = unpackLLElements(loc, lllhs, rewriter); + auto rhsElems = unpackLLElements(loc, llrhs, rewriter); + + // Create LLVM Op + SmallVector calculatedVals; + Type vecTy = getElementTypeOrSelf(valueTy); + Type elemTy = getElementTypeOrSelf(vecTy); + StringRef asm_string; + if (elemTy.isF32()) { + asm_string = SVOp2Str::value; + } else if (elemTy.isF16()) { + asm_string = SVOp2StrFP16::value; + } else { + llvm_unreachable("Only FP16 and FP32 are supported in SVBinary!"); + } + StringRef constraints = "=v,r,v"; + for (size_t vecStart = 0; vecStart < numElems; vecStart += 1) { + if (elemState == ElemState::SV) { + SmallVector operands({lhsElems[0], rhsElems[vecStart]}); + auto asmOp = rewriter.create( + loc, valueElemTy, operands, asm_string, constraints, + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr()); + calculatedVals.push_back(asmOp.getRes()); + } else if (elemState == ElemState::VS) { + SmallVector operands({rhsElems[0], lhsElems[vecStart]}); + auto asmOp = rewriter.create( + loc, valueElemTy, operands, asm_string, constraints, + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr()); + calculatedVals.push_back(asmOp.getRes()); + } + } + + // Wrap data into a struct + Type llvmResultStructTy = getTypeConverter()->convertType(valueTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), calculatedVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + + return success(); + } +}; + +template +struct UnaryOpConversion : public ConvertOpToLLVMPattern, + public XPUVectorizedOpsConversionBase { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ConvertOpToLLVMPattern::getTypeConverter; + using OpAdaptor = typename SrcOp::Adaptor; + + UnaryOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit, + const triton::xpu::TargetInfo &targetInfo) + : ConvertOpToLLVMPattern(converter, benefit), + XPUVectorizedOpsConversionBase(targetInfo) {} + + LogicalResult + matchAndRewrite(SrcOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Value value = op.getValue(); + Value result = op.getResult(); + + Value llvalue = adaptor.getValue(); + + auto loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + auto resultTy = result.getType(); + Type resultElemTy = + getTypeConverter()->convertType(getElementTypeOrSelf(resultTy)); + unsigned numElems = getTotalElemsPerThread(value.getType()); + + auto valueElems = unpackLLElements(loc, llvalue, rewriter); + + SmallVector calculatedVals; + for (size_t vecStart = 0; vecStart < numElems; vecStart += 1) { + Value vexpOp = rewriter.create( + loc, convertVectorType(resultElemTy), valueElems[vecStart]); + calculatedVals.push_back(vexpOp); + } + + Type llvmResultStructTy = getTypeConverter()->convertType(resultTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), calculatedVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + + return success(); + } +}; + +template +struct VOpConversionLibCall : public ConvertOpToLLVMPattern, + public XPUVectorizedOpsConversionBase { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ConvertOpToLLVMPattern::getTypeConverter; + using OpAdaptor = typename SrcOp::Adaptor; + + VOpConversionLibCall(LLVMTypeConverter &converter, PatternBenefit benefit, + const triton::xpu::TargetInfo &targetInfo) + : ConvertOpToLLVMPattern(converter, benefit), + XPUVectorizedOpsConversionBase(targetInfo) {} + + LogicalResult + matchAndRewrite(SrcOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resultTy = op.getType(); + Location loc = op->getLoc(); + unsigned vecElems = getTotalElemsPerThread(resultTy); + auto resultVecTy = getElementTypeOrSelf(resultTy); + Type vecTy = this->getTypeConverter()->convertType(resultVecTy); + auto elemTy = getElementTypeOrSelf(vecTy); + SmallVector types(vecElems, vecTy); + Type structTy = this->getTypeConverter()->convertType(resultTy); + + auto operands = getOperands(rewriter, adaptor, vecElems, loc); + SmallVector resultVals(vecElems); + for (unsigned i = 0; i < vecElems; ++i) { + ValueRange singleOperandRange(operands[i]); + if (elemTy.isF32()) { + Value devCall = mlir::LLVM::XPU::createDeviceCall( + VLibOp::value, rewriter, op, vecTy, singleOperandRange, loc); + resultVals[i] = devCall; + } else if (elemTy.isF16()) { + Value devCall; + switch (xpuArch) { + case 2: + devCall = mlir::LLVM::XPU::createDeviceCall( + VLibOpFP16::value, rewriter, op, vecTy, + singleOperandRange, loc); + break; + case 3: + devCall = mlir::LLVM::XPU::createDeviceCall( + VLibOpFP16::value, rewriter, op, vecTy, + singleOperandRange, loc); + break; + default: + llvm_unreachable("Failed to create device call with unsupported xpu " + "architecture."); + } + resultVals[i] = devCall; + } else { + llvm_unreachable("Only FP16 and FP32 are supported in LibDevice!"); + } + if (!bool(resultVals[i])) + return failure(); + } + Value view = + packLLElements(loc, getTypeConverter(), resultVals, rewriter, structTy); + rewriter.replaceOp(op, view); + + return success(); + } + +private: + SmallVector> + getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor, + const unsigned elems, Location loc) const { + SmallVector> operands(elems); + for (auto operand : adaptor.getOperands()) { + auto sub_operands = unpackLLElements(loc, operand, rewriter); + for (size_t i = 0; i < elems; ++i) { + operands[i].push_back(sub_operands[i]); + } + } + return operands; + } +}; + +template +struct VConstOpConversion : public ConvertOpToLLVMPattern, + public XPUVectorizedOpsConversionBase { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ConvertOpToLLVMPattern::getTypeConverter; + using OpAdaptor = typename SrcOp::Adaptor; + + VConstOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit, + const triton::xpu::TargetInfo &targetInfo) + : ConvertOpToLLVMPattern(converter, benefit), + XPUVectorizedOpsConversionBase(targetInfo) {} + + LogicalResult + matchAndRewrite(SrcOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Value res = op.getResult(); + Attribute attr = adaptor.getValueAttr(); + + auto loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + auto resTy = res.getType(); + Type resElemTy = + getTypeConverter()->convertType(getElementTypeOrSelf(resTy)); + unsigned numElems = getTotalElemsPerThread(res.getType()); + + SmallVector calculatedVals; + for (size_t vecStart = 0; vecStart < numElems; vecStart += 1) { + Value vconstOp = + rewriter.create(loc, convertVectorType(resElemTy), attr); + calculatedVals.push_back(vconstOp); + } + + Type llvmResultStructTy = getTypeConverter()->convertType(resTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), calculatedVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + + return success(); + } +}; + +template +struct VSplatOpConversion : public ConvertOpToLLVMPattern, + public XPUVectorizedOpsConversionBase { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ConvertOpToLLVMPattern::getTypeConverter; + using OpAdaptor = typename SrcOp::Adaptor; + + VSplatOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit, + const triton::xpu::TargetInfo &targetInfo) + : ConvertOpToLLVMPattern(converter, benefit), + XPUVectorizedOpsConversionBase(targetInfo) {} + + Value convertSplatLikeOp(Type resTy, Value llsrc, Type llvmResultStructTy, + ConversionPatternRewriter &rewriter, + Location loc) const { + auto resElemTy = + this->getTypeConverter()->convertType(getElementTypeOrSelf(resTy)); + size_t elemsPerThread = getTotalElemsPerThread(resTy); + + auto valueElems = unpackLLElements(loc, llsrc, rewriter); + + Value vector_1xTy = rewriter.create(loc, resElemTy); + vector_1xTy = + insert_element(resElemTy, vector_1xTy, valueElems[0], i32_val(0)); + + int32_t vecSize = cast(resElemTy).getNumElements(); + SmallVector zeroValues(vecSize, 0); + // TODO[dyq]: check getI32ArrayAttr -> getDenseI32ArrayAttr + auto zeroAttrs = rewriter.getDenseI32ArrayAttr(zeroValues); + Value shuffleVectorOp = rewriter.create( + loc, resElemTy, vector_1xTy, vector_1xTy, zeroAttrs); + + llvm::SmallVector elems(elemsPerThread, shuffleVectorOp); + + return packLLElements(loc, this->getTypeConverter(), elems, rewriter, + llvmResultStructTy); + } + + LogicalResult matchAndRewrite(SrcOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op->getLoc(); + auto llsrc = adaptor.getSrc(); + auto llvmResultStructTy = + this->getTypeConverter()->convertType(op.getType()); + auto llStruct = convertSplatLikeOp(op.getType(), llsrc, llvmResultStructTy, + rewriter, loc); + + rewriter.replaceOp(op, {llStruct}); + return success(); + } +}; + +template +struct VSelectOpConversion : public ConvertOpToLLVMPattern, + public XPUVectorizedOpsConversionBase { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ConvertOpToLLVMPattern::getTypeConverter; + using OpAdaptor = typename SrcOp::Adaptor; + + VSelectOpConversion(LLVMTypeConverter &converter, + + PatternBenefit benefit, + const triton::xpu::TargetInfo &targetInfo) + : ConvertOpToLLVMPattern(converter, benefit), + XPUVectorizedOpsConversionBase(targetInfo) {} + + LogicalResult matchAndRewrite(SrcOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // original values + Value condition = op.getCondition(); + Value true_value = op.getTrueValue(); + Value false_value = op.getFalseValue(); + + // adaptor values + Value llCondition = adaptor.getCondition(); + Value llTrue_value = adaptor.getTrueValue(); + Value llFalse_value = adaptor.getFalseValue(); + + MLIRContext *ctx = rewriter.getContext(); + auto loc = op->getLoc(); + auto resTy = op.getType(); + + Type resElemTy = getTypeConverter()->convertType( + getElementTypeOrSelf(resTy)); // vector<16xf32> + unsigned numElems = getTotalElemsPerThread(resTy); + + // Get data from a struct + auto conditionElems = unpackLLElements(loc, llCondition, rewriter); + auto trueValElems = unpackLLElements(loc, llTrue_value, rewriter); + auto falseValElems = unpackLLElements(loc, llFalse_value, rewriter); + + // Create LLVM Op + Type elemTy = getElementTypeOrSelf(resElemTy); + unsigned elemBits = elemTy.getIntOrFloatBitWidth(); + unsigned vecSize = mlir::cast(resElemTy).getNumElements(); + SmallVector resVals; + + for (size_t elemIter = 0; elemIter < numElems; ++elemIter) { + // Step 1. Convert Condition To v32i1/v16i1 Mask + Value orV = i32_val(0); + for (size_t conditionIter = 0; conditionIter < vecSize; ++conditionIter) { + Value boolVal = + isa(conditionElems[0].getType()) + ? extract_element(i1_ty, conditionElems[elemIter], + i32_val(conditionIter)) + : conditionElems[elemIter * vecSize + conditionIter]; + Value extV = zext(i32_ty, boolVal); + Value shlV = shl(extV, i32_val(conditionIter)); + orV = or_(orV, shlV); + } + VectorType maskTy = VectorType::get(32, i1_ty); + Value maskV = bitcast(orV, maskTy); + + if (elemTy.isF32()) { + // Step 2. vset_zero() + StringRef xor_asm_string = "vxor.s.mz $0{mr1}, $0, $0"; + StringRef xor_constraints = "=v"; + SmallVector xor_operands({}); + auto zerosIAsmOp = rewriter.create( + loc, resElemTy, xor_operands, xor_asm_string, xor_constraints, + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr()); + Value zerosFAsmOp = bitcast(zerosIAsmOp.getRes(), resElemTy); + // Step 3. vvor_float32x16_mh(mask, zero, a, b) + Value vvorFOp = rewriter.create( + loc, resElemTy, maskV, zerosFAsmOp, trueValElems[elemIter], + falseValElems[elemIter]); + resVals.push_back(vvorFOp); + } else if (elemTy.isInteger(32)) { + // Step 2. vset_zero() + StringRef xor_asm_string = "vxor.s.mz $0{mr1}, $0, $0"; + StringRef xor_constraints = "=v"; + SmallVector xor_operands({}); + auto zerosIAsmOp = rewriter.create( + loc, resElemTy, xor_operands, xor_asm_string, xor_constraints, + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr()); + Value zerosFAsmOp = bitcast(zerosIAsmOp.getRes(), resElemTy); + // Step 3. vvor_int32x16_mh(mask, zero, a, b) + Value vvorFOp = rewriter.create( + loc, resElemTy, maskV, zerosFAsmOp, trueValElems[elemIter], + falseValElems[elemIter]); + resVals.push_back(vvorFOp); + } else if (elemTy.isF16()) { + // Step 2. vset_zero() + StringRef xor_asm_string = "vxor.hf.mz $0{mr1}, $0, $0"; + StringRef xor_constraints = "=v"; + SmallVector xor_operands({}); + auto zerosFAsmOp = rewriter.create( + loc, resElemTy, xor_operands, xor_asm_string, xor_constraints, + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr()); + // Step 3. vvor_float16x32_mh(mask, zero, a, b) + Value vvorFOp = rewriter.create( + loc, resElemTy, maskV, zerosFAsmOp.getRes(), trueValElems[elemIter], + falseValElems[elemIter]); + resVals.push_back(vvorFOp); + } else { + llvm_unreachable("Only FP16 and FP32 are supported in VSelect!"); + } + } + + // Wrap data into a struct + auto llvmResultStructTy = getTypeConverter()->convertType(resTy); + auto llStruct = packLLElements(loc, getTypeConverter(), resVals, rewriter, + llvmResultStructTy); + rewriter.replaceOp(op, {llStruct}); + + return success(); + } +}; + +template +struct VMacFOpConversion : public ConvertOpToLLVMPattern, + public XPUVectorizedOpsConversionBase { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ConvertOpToLLVMPattern::getTypeConverter; + using OpAdaptor = typename SrcOp::Adaptor; + + VMacFOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit, + const triton::xpu::TargetInfo &targetInfo) + : ConvertOpToLLVMPattern(converter, benefit), + XPUVectorizedOpsConversionBase(targetInfo) {} + + LogicalResult matchAndRewrite(SrcOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // original values + Value value = op.getValue(); + Value mulData = op.getMulData(); + Value addData = op.getAddData(); + + // adaptor values + Value llValue = adaptor.getValue(); + Value llMulData = adaptor.getMulData(); + Value llAddData = adaptor.getAddData(); + auto attrs = adaptor.getAttributes(); + + MLIRContext *ctx = rewriter.getContext(); + auto loc = op->getLoc(); + auto resTy = op.getType(); + + auto resElemTy = getTypeConverter()->convertType( + getElementTypeOrSelf(resTy)); // vector<16xf32> + unsigned numElems = getTotalElemsPerThread(resTy); + + // Get data from a struct + auto valueElems = unpackLLElements(loc, llValue, rewriter); + auto mulElems = unpackLLElements(loc, llMulData, rewriter); + auto addElems = unpackLLElements(loc, llAddData, rewriter); + + // Create LLVM Op + SmallVector calculatedVals; + auto elemTy = getElementTypeOrSelf(resElemTy); + StringRef asm_string; + if (elemTy.isF32()) { + asm_string = "vmac.f.mz.rn $0{mr1}, $1, $2"; + } else if (elemTy.isF16()) { + asm_string = "vmac.hf.mz.rn $0{mr1}, $1, $2"; + } else { + llvm_unreachable("Only FP16 and FP32 are supported in VMac!"); + } + StringRef constraints = "=v,v,v,0"; + for (size_t vecStart = 0; vecStart < numElems; vecStart += 1) { + SmallVector operands( + {valueElems[vecStart], mulElems[vecStart], addElems[vecStart]}); + auto asmOp = rewriter.create( + loc, resElemTy, operands, asm_string, constraints, + /*has_side_effects=*/true, + /*is_align_stack=*/false, + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr()); + calculatedVals.push_back(asmOp.getRes()); + } + + // Wrap data into a struct + auto llvmResultStructTy = getTypeConverter()->convertType(resTy); + auto llStruct = packLLElements(loc, getTypeConverter(), calculatedVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {llStruct}); + + return success(); + } +}; + +struct VExtFOpConversion : public ConvertOpToLLVMPattern, + public XPUVectorizedOpsConversionBase { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ConvertOpToLLVMPattern::getTypeConverter; + using OpAdaptor = typename triton::xpu::VExtFOp::Adaptor; + + VExtFOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit, + const triton::xpu::TargetInfo &targetInfo) + : ConvertOpToLLVMPattern(converter, benefit), + XPUVectorizedOpsConversionBase(targetInfo) {} + + Value convertFp16ToFp32(Location loc, ConversionPatternRewriter &rewriter, + Value val, Value res, Type valElemTy, Type resElemTy, + Value llVal, Type llvmResultStructTy) const { + auto ctx = rewriter.getContext(); + auto llVals = unpackLLElements(loc, llVal, rewriter); + unsigned numElems = getTotalElemsPerThread(val.getType()); + + SmallVector fp32x16Vecs; + for (int i = 0; i < numElems; ++i) { + auto asml = rewriter.create( + loc, resElemTy, ValueRange{llVals[i]}, // operands + "vfp162float_l.rn $0, $1", // asm_string + "=&v,v", // constraints + false, // has_size_effects + false, // is_align_stack + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr::get(ctx, {})); + fp32x16Vecs.emplace_back(asml.getRes()); + auto asmh = rewriter.create( + loc, resElemTy, ValueRange{llVals[i]}, // operands + "vfp162float_h.rn $0, $1", // asm_string + "=&v,v", // constraints + false, // has_size_effects + false, // is_align_stack + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr::get(ctx, {})); + fp32x16Vecs.emplace_back(asmh.getRes()); + } + + Value resultStruct = packLLElements(loc, getTypeConverter(), fp32x16Vecs, + rewriter, llvmResultStructTy); + return resultStruct; + } + + Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter, + Value val, Value res, Type valElemTy, Type resElemTy, + Value llVal, Type llvmResultStructTy) const { + auto ctx = rewriter.getContext(); + auto llVals = unpackLLElements(loc, llVal, rewriter); + unsigned numElems = getTotalElemsPerThread(val.getType()); + + VectorType vecFp16Ty = VectorType::get(32, f16_ty); + Value padVec = rewriter.create(loc, vecFp16Ty); + for (size_t elemIdx = 0; elemIdx < 32; ++elemIdx) { + padVec = insert_element(vecFp16Ty, padVec, f16_val(0), i16_val(elemIdx)); + } + + SmallVector fp32x16Vecs; + for (int i = 0; i < numElems; ++i) { + Value val = bitcast(llVals[i], vecFp16Ty); + Value vl = rewriter.create(loc, vecFp16Ty, + padVec, val); + vl = bitcast(vl, resElemTy); + fp32x16Vecs.emplace_back(vl); + Value vh = rewriter.create(loc, vecFp16Ty, + padVec, val); + vh = bitcast(vh, resElemTy); + fp32x16Vecs.emplace_back(vh); + } + + Value resultStruct = packLLElements(loc, getTypeConverter(), fp32x16Vecs, + rewriter, llvmResultStructTy); + return resultStruct; + } + + LogicalResult + matchAndRewrite(triton::xpu::VExtFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op->getLoc(); + auto val = op.getValue(); + auto res = op.getResult(); + auto llVal = adaptor.getValue(); + + Type valTy = val.getType(); + Type resTy = res.getType(); + auto valElemTy = getElementTypeOrSelf(valTy); + auto _valElemTy = getElementTypeOrSelf(valElemTy); + auto resElemTy = getElementTypeOrSelf(resTy); + auto _resElemTy = getElementTypeOrSelf(resElemTy); + auto llValElemTy = typeConverter->convertType(valElemTy); + auto llResElemTy = typeConverter->convertType(resElemTy); + Type llvmResultStructTy = getTypeConverter()->convertType(resTy); + assert(_resElemTy.isF32() && "Only support F32 as target dtype inVExtF!"); + if (_valElemTy.isF16()) { + auto result = convertFp16ToFp32(loc, rewriter, val, res, valElemTy, + resElemTy, llVal, llvmResultStructTy); + rewriter.replaceOp(op, {result}); + } else if (_valElemTy.isBF16()) { + auto result = convertBf16ToFp32(loc, rewriter, val, res, valElemTy, + resElemTy, llVal, llvmResultStructTy); + rewriter.replaceOp(op, {result}); + } else { + assert(0 && "Only support FP16 as source dtype in VExtF!"); + } + return success(); + } +}; + +struct VTruncFOpConversion + : public ConvertOpToLLVMPattern, + public XPUVectorizedOpsConversionBase { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ConvertOpToLLVMPattern::getTypeConverter; + using OpAdaptor = typename triton::xpu::VTruncFOp::Adaptor; + + VTruncFOpConversion(LLVMTypeConverter &converter, + + PatternBenefit benefit, + const triton::xpu::TargetInfo &targetInfo) + : ConvertOpToLLVMPattern(converter, benefit), + XPUVectorizedOpsConversionBase(targetInfo) {} + + Value convertFp32ToFp16(Location loc, ConversionPatternRewriter &rewriter, + Value val, Value res, Type valElemTy, Type resElemTy, + Value llVal, Type llvmResultStructTy) const { + auto ctx = rewriter.getContext(); + auto llVals = unpackLLElements(loc, llVal, rewriter); + unsigned numElems = getTotalElemsPerThread(val.getType()); + + SmallVector fp16x32Vecs; + for (int i = 0; i < numElems; i += 2) { + auto asmlh = rewriter.create( + loc, resElemTy, ValueRange{llVals[i], llVals[i + 1]}, // operands + "vfloat2fp16_l.rn $0, $1\nvfloat2fp16_h.rn $0, $2", // asm_string + "=&v,v,v", // constraints + false, // has_size_effects + false, // is_align_stack + LLVM::AsmDialectAttr::get(ctx, LLVM::AsmDialect::AD_ATT), + ArrayAttr::get(ctx, {})); + fp16x32Vecs.push_back(asmlh.getRes()); + } + + Value resultStruct = packLLElements(loc, getTypeConverter(), fp16x32Vecs, + rewriter, llvmResultStructTy); + return resultStruct; + } + + LogicalResult + matchAndRewrite(triton::xpu::VTruncFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op->getLoc(); + auto val = op.getValue(); + auto res = op.getResult(); + auto llVal = adaptor.getValue(); + + Type valTy = val.getType(); + Type resTy = res.getType(); + auto valElemTy = getElementTypeOrSelf(valTy); + auto _valElemTy = getElementTypeOrSelf(valElemTy); + auto resElemTy = getElementTypeOrSelf(resTy); + auto _resElemTy = getElementTypeOrSelf(resElemTy); + auto llValElemTy = typeConverter->convertType(valElemTy); + auto llResElemTy = typeConverter->convertType(resElemTy); + Type llvmResultStructTy = getTypeConverter()->convertType(resTy); + assert(_valElemTy.isF32() && + "Only support F32 as source dtype in VTruncF!"); + if (_resElemTy.isF16()) { + auto result = convertFp32ToFp16(loc, rewriter, val, res, valElemTy, + resElemTy, llVal, llvmResultStructTy); + rewriter.replaceOp(op, {result}); + } else { + assert(0 && "Only support FP16 as target dtype in VTruncF!"); + } + return success(); + } +}; + +struct VCmpFOpConversion : public ConvertOpToLLVMPattern, + public XPUVectorizedOpsConversionBase { + + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using ConvertOpToLLVMPattern::getTypeConverter; + using OpAdaptor = typename triton::xpu::VCmpFOp::Adaptor; + + VCmpFOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit, + const triton::xpu::TargetInfo &targetInfo) + : ConvertOpToLLVMPattern(converter, benefit), + XPUVectorizedOpsConversionBase(targetInfo) {} + + LogicalResult matchAndRewrite(triton::xpu::VCmpFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + Value lllhs = adaptor.getLhs(); + Value llrhs = adaptor.getRhs(); + + auto loc = op->getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + auto valueTy = lhs.getType(); + + Type valueElemTy = + getTypeConverter()->convertType(getElementTypeOrSelf(valueTy)); + unsigned numElems = getTotalElemsPerThread(valueTy); + + auto lhsElems = unpackLLElements(loc, lllhs, rewriter); + auto rhsElems = unpackLLElements(loc, llrhs, rewriter); + assert(lhsElems.size() == rhsElems.size()); + + auto resTy = op.getResult().getType(); + Type resElemTy = + getTypeConverter()->convertType(getElementTypeOrSelf(resTy)); + + SmallVector calculatedVals; + for (size_t vecStart = 0; vecStart < numElems; vecStart += 1) { + Value vcmpfOp = rewriter.create( + loc, resElemTy, ArithCmpFPredicateToLLVM(op.getPredicate()), + lhsElems[vecStart], rhsElems[vecStart]); + calculatedVals.push_back(vcmpfOp); + } + + Type llvmResultStructTy = getTypeConverter()->convertType(resTy); + Value resultStruct = packLLElements(loc, getTypeConverter(), calculatedVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + + return success(); + } + + static LLVM::FCmpPredicate + ArithCmpFPredicateToLLVM(arith::CmpFPredicate predicate) { + switch (predicate) { +#define __PRED_ENUM(item__, item1__) \ + case arith::CmpFPredicate::item__: \ + return LLVM::FCmpPredicate::item1__ + + __PRED_ENUM(OEQ, oeq); + __PRED_ENUM(ONE, one); + __PRED_ENUM(OGT, ogt); + __PRED_ENUM(OGE, oge); + __PRED_ENUM(OLT, olt); + __PRED_ENUM(OLE, ole); + __PRED_ENUM(ORD, ord); + __PRED_ENUM(UEQ, ueq); + __PRED_ENUM(UGT, ugt); + __PRED_ENUM(UGE, uge); + __PRED_ENUM(ULT, ult); + __PRED_ENUM(ULE, ule); + __PRED_ENUM(UNE, une); + __PRED_ENUM(UNO, uno); + __PRED_ENUM(AlwaysTrue, _true); + __PRED_ENUM(AlwaysFalse, _false); + +#undef __PRED_ENUM + } + llvm_unreachable("Unknown arith::CmpFPredicate"); + } +}; + +} // namespace + +void mlir::triton::xpu::populateTTXPUVectorizedOpToLLVMConversionPatterns( + LLVMTypeConverter &typeConverter, const triton::xpu::TargetInfo &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add, + VVBinOpsConversion, + VVBinOpsConversion, + VVBinOpsConversion, + VVBinOpsConversion, + VVBinOpsConversion, + VVBinOpsConversion, + VVBinOpsConversion, + VVBinOpsConversion, + VVBinOpsConversion, + VVBinOpsConversion, + VVBinOpsConversion, + VVBinOpsConversion, + VVBinOpsConversion>( + typeConverter, benefit, targetInfo); + patterns.add, + SVBinOpsConversion, + SVBinOpsConversion, + SVBinOpsConversion>(typeConverter, + benefit, targetInfo); + patterns.add, + UnaryOpConversion, + UnaryOpConversion, + UnaryOpConversion>( + typeConverter, benefit, targetInfo); + patterns.add, + VOpConversionLibCall>(typeConverter, + benefit, targetInfo); + patterns.add>( + typeConverter, benefit, targetInfo); + patterns.add>(typeConverter, + benefit, targetInfo); + patterns.add>( + typeConverter, benefit, targetInfo); + patterns.add>(typeConverter, benefit, + targetInfo); + patterns.add(typeConverter, benefit, targetInfo); + patterns.add(typeConverter, benefit, targetInfo); + patterns.add(typeConverter, benefit, targetInfo); +} diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/ViewOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/ViewOpToLLVM.cpp new file mode 100644 index 000000000..727e6383a --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/ViewOpToLLVM.cpp @@ -0,0 +1,139 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h" + +namespace { +struct XPUExpandDimsOpConversion + : public ConvertOpToLLVMPattern { + + XPUExpandDimsOpConversion(LLVMTypeConverter &converter, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto typeConverter = getTypeConverter(); + auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto srcTy = cast(op.getSrc().getType()); + auto resultTy = cast(op.getType()); + auto srcLayout = dyn_cast(srcTy.getEncoding()); + if (!srcLayout) { + return emitOptionalError( + loc, "ExpandDimsOp only supports SliceEncodingAttr as its input"); + } + auto resultLayout = resultTy.getEncoding(); + + Value ret = packLLElements(loc, typeConverter, srcVals, rewriter, resultTy); + rewriter.replaceOp(op, ret); + return success(); + } +}; + +struct XPUBroadcastOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::xpu::BroadcastOp>::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(triton::xpu::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Following the order of indices in the legacy code, a broadcast of: + // [s(0), s(1) ... s(k-1), 1, s(k+1), s(k+2) ... s(n-1)] + // => + // [s(0), s(1) ... s(k-1), s(k), s(k+1), s(k+2) ... s(n-1)] + // + // logically maps to a broadcast within a thread's scope: + // [cta(0)..cta(k-1), 1,cta(k+1)..cta(n-1),spt(0)..spt(k-1), + // 1,spt(k+1)..spt(n-1)] + // => + // [cta(0)..cta(k-1),cta(k),cta(k+1)..cta(n-1),spt(0)..spt(k-1),spt(k),spt(k+1)..spt(n-1)] + // + // regardless of the order of the layout + // + Location loc = op->getLoc(); + Value src = adaptor.getSrc(); + Value result = op.getResult(); + auto srcTy = cast(op.getSrc().getType()); + auto resultTy = cast(result.getType()); + Type resElemTy = getElementTypeOrSelf(resultTy); + bool isVectorized = isa(resElemTy); + + auto srcLayout = srcTy.getEncoding(); + auto resultLayout = resultTy.getEncoding(); + auto srcShape = srcTy.getShape(); + auto resultShape = resultTy.getShape(); + unsigned rank = srcTy.getRank(); + auto typeConverter = getTypeConverter(); + assert(rank == resultTy.getRank()); + auto order = triton::gpu::getOrder(srcLayout); + auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); + auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); + SmallVector srcVals = unpackLLElements(loc, src, rewriter); + std::map, Value> srcValues; + if (isVectorized) { + // TODO: Enhance VBroadcastOp + assert(rank == 2 && "BroadcastOp is Vectorized, But Rank != 2"); + auto rowsPerCore = cast(resultLayout) + .getSizePerCore()[0]; + Type elemTy = getElementTypeOrSelf(resElemTy); + unsigned vecSize = 512 / elemTy.getIntOrFloatBitWidth(); + if (srcShape[1] == 1) { + for (size_t i = 0; i < srcOffsets.size(); i++) { + Value srcVals_0_vector = + rewriter.create(loc, resElemTy); + for (size_t elemStart = 0; elemStart < vecSize; ++elemStart) { + srcVals_0_vector = insert_element(resElemTy, srcVals_0_vector, + srcVals[i], i32_val(elemStart)); + } + srcValues[srcOffsets[i]] = srcVals_0_vector; + } + } else if (srcShape[0] == 1) { + SmallVector srcVectorVals; + for (size_t i = 0; i < resultOffsets.size() / rowsPerCore; i++) { + Value srcVals_vector = rewriter.create(loc, resElemTy); + for (size_t elemStart = 0; elemStart < vecSize; ++elemStart) { + srcVals_vector = insert_element(resElemTy, srcVals_vector, + srcVals[i * vecSize + elemStart], + i32_val(elemStart)); + } + srcVectorVals.push_back(srcVals_vector); + } + for (size_t i = 0; i < resultOffsets.size() / rowsPerCore; i++) { + srcValues[resultOffsets[i]] = srcVectorVals[i]; + } + } else { + llvm_unreachable("Only Support Vectorized BroadcastOp: [Mx1xTy] -> " + "[MxNxTy] or [1xNxTy] -> [MxNxTy]"); + } + } else { + for (size_t i = 0; i < srcOffsets.size(); i++) { + srcValues[srcOffsets[i]] = srcVals[i]; + } + } + SmallVector resultVals; + for (size_t i = 0; i < resultOffsets.size(); i++) { + auto offset = resultOffsets[i]; + for (size_t j = 0; j < srcShape.size(); j++) + if (srcShape[j] == 1) + offset[j] = 0; + resultVals.push_back(srcValues.at(offset)); + } + Value resultStruct = + packLLElements(loc, typeConverter, resultVals, rewriter, resultTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } +}; + +} // namespace + +void mlir::triton::xpu::populateViewOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); +} diff --git a/third_party/xpu/lib/Conversion/TritonXPUToLLVM/XPUUtilityOpToLLVM.cpp b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/XPUUtilityOpToLLVM.cpp new file mode 100644 index 000000000..e7e59548f --- /dev/null +++ b/third_party/xpu/lib/Conversion/TritonXPUToLLVM/XPUUtilityOpToLLVM.cpp @@ -0,0 +1,227 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "xpu/lib/Conversion/TritonXPUToLLVM/PatternTritonXPUOpToLLVM.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using ::mlir::triton::gpu::getTotalElemsPerThread; + +struct XPUExtractOpConversion + : public ConvertOpToLLVMPattern { + + XPUExtractOpConversion(LLVMTypeConverter &converter, + const xpu::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::xpu::ExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + + // original values + auto index = op.getIndex(); + + // adaptor values + auto llTensor = adaptor.getTensor(); + + // Get the LLVM values + auto llTensors = unpackLLElements(loc, llTensor, rewriter); + + // TODO[dyq]: necessary? + // assert(index >= 0 && index < llTensors.size() && + // "Get Invalid Index For triton::xpu::ExtractOp"); + + // Modifition Logic + rewriter.replaceOp(op, {llTensors[index]}); + return success(); + }; +}; + +struct XPUExtractSliceOpConversion + : public ConvertOpToLLVMPattern { + + XPUExtractSliceOpConversion(LLVMTypeConverter &converter, + const xpu::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit) {} + + LogicalResult + matchAndRewrite(triton::xpu::ExtractSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + auto llTensor = adaptor.getTensor(); + auto llTensors = unpackLLElements(loc, llTensor, rewriter); + + auto resType = op.getResult().getType(); + auto rankedTy = mlir::dyn_cast(resType); + unsigned elems = getTotalElemsPerThread(resType); + SmallVector retVals(elems); + for (unsigned i = 0; i < elems; ++i) { + retVals[i] = llTensors[i]; + } + Type llvmResultStructTy = getTypeConverter()->convertType(resType); + Value resultStruct = packLLElements(loc, getTypeConverter(), retVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + }; +}; + +struct XPUGetThreadIdOpConversion + : public ConvertOpToLLVMPattern { + + XPUGetThreadIdOpConversion(LLVMTypeConverter &converter, + const xpu::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) { + } + + LogicalResult + matchAndRewrite(triton::xpu::GetThreadIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + + auto threadType = op.getThreadType(); + auto resType = op.getResult().getType(); + auto rankedTy = mlir::dyn_cast(resType); + unsigned elems = getTotalElemsPerThread(resType); + + SmallVector retVals(elems); + Value clusterNum = mlir::LLVM::XPU::getGridDim(rewriter, loc); + Value coreNum = mlir::LLVM::XPU::getBlockDim(rewriter, loc); + Value clusterId = mlir::LLVM::XPU::getBlockId(rewriter, loc); + Value coreId = getThreadId(rewriter, loc); + Value threadId; + switch (threadType) { + case 0: { + // tid = core_id() * cluster_num() + cluster_id() + threadId = add(mul(coreId, clusterNum), clusterId); + break; + } + case 1: { + // tid = core_num() * cluster_id() + core_id() + threadId = add(mul(coreNum, clusterId), coreId); + break; + } + default: + llvm_unreachable("Unknown threadId Type"); + } + + for (unsigned i = 0; i < elems; ++i) { + retVals[i] = threadId; + } + + Type llvmResultStructTy = getTypeConverter()->convertType(resType); + Value resultStruct = packLLElements(loc, getTypeConverter(), retVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + }; +}; + +struct XPUGetClusterIdOpConversion + : public ConvertOpToLLVMPattern { + + XPUGetClusterIdOpConversion(LLVMTypeConverter &converter, + const xpu::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit) {} + + LogicalResult + matchAndRewrite(triton::xpu::GetClusterIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + Value retVal = rewriter.create( + loc, type::i32Ty(ctx), i32_val(0)); + + rewriter.replaceOp(op, {retVal}); + return success(); + }; +}; + +struct XPUGetCoreIdOpConversion + : public ConvertOpToLLVMPattern { + + XPUGetCoreIdOpConversion(LLVMTypeConverter &converter, + const xpu::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::xpu::GetCoreIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + Value retVal = + rewriter.create(loc, type::i32Ty(ctx)); + + rewriter.replaceOp(op, {retVal}); + return success(); + }; +}; + +struct XPUGetNumClusterOpConversion + : public ConvertOpToLLVMPattern { + + XPUGetNumClusterOpConversion(LLVMTypeConverter &converter, + const xpu::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, + benefit) {} + + LogicalResult + matchAndRewrite(triton::xpu::GetNumClusterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + Value retVal = rewriter.create( + loc, type::i32Ty(ctx), i32_val(1)); + + rewriter.replaceOp(op, {retVal}); + return success(); + }; +}; + +} // namespace + +void mlir::triton::xpu::populateTTXPUUtilityOpToLLVMConversionPatterns( + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, + axisInfoAnalysis, benefit); + patterns.add(typeConverter, targetInfo, + axisInfoAnalysis, benefit); + patterns.add(typeConverter, targetInfo, + axisInfoAnalysis, benefit); + patterns.add(typeConverter, targetInfo, + axisInfoAnalysis, benefit); + patterns.add(typeConverter, targetInfo, + axisInfoAnalysis, benefit); + patterns.add(typeConverter, targetInfo, + axisInfoAnalysis, benefit); +} diff --git a/third_party/xpu/lib/Dialect/CMakeLists.txt b/third_party/xpu/lib/Dialect/CMakeLists.txt new file mode 100644 index 000000000..b8ffa31f9 --- /dev/null +++ b/third_party/xpu/lib/Dialect/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(Triton) +add_subdirectory(TritonGPU) +add_subdirectory(TritonNvidiaGPU) +add_subdirectory(TritonXPU) +add_subdirectory(LLVMXPU) diff --git a/third_party/xpu/lib/Dialect/LLVMXPU/CMakeLists.txt b/third_party/xpu/lib/Dialect/LLVMXPU/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/third_party/xpu/lib/Dialect/LLVMXPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/xpu/lib/Dialect/LLVMXPU/IR/CMakeLists.txt b/third_party/xpu/lib/Dialect/LLVMXPU/IR/CMakeLists.txt new file mode 100644 index 000000000..0d52d028e --- /dev/null +++ b/third_party/xpu/lib/Dialect/LLVMXPU/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(MLIRLLVMXPUDialect + Dialect.cpp + +# ADDITIONAL_HEADER_DIRS +# ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR +# ${PROJECT_SOURCE_DIR}/third_party/triton/Dialect/LLVMXPU/IR/ + + DEPENDS + MLIRXPUConversionsIncGen + +# LINK_LIBS PUBLIC +# MLIRIR +# MLIRLLVMDialect +# MLIRSideEffectInterfaces +) diff --git a/third_party/xpu/lib/Dialect/LLVMXPU/IR/Dialect.cpp b/third_party/xpu/lib/Dialect/LLVMXPU/IR/Dialect.cpp new file mode 100644 index 000000000..ac863009f --- /dev/null +++ b/third_party/xpu/lib/Dialect/LLVMXPU/IR/Dialect.cpp @@ -0,0 +1,36 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// + +#include "triton/Dialect/LLVMXPU/IR/Dialect.h" // before cpp.inc + +#include "triton/Dialect/LLVMXPU/IR/Dialect.cpp.inc" + +//===----------------------------------------------------------------------===// +// Dialect Initialization +//===----------------------------------------------------------------------===// + +void ::mlir::LLVM::XPU::LLVMXPUDialect::initialize() { + addOperations< +#define GET_OP_LIST // declare +#include "triton/Dialect/LLVMXPU/IR/Ops.cpp.inc" + >(); +} + +#define GET_OP_CLASSES // define +#include "triton/Dialect/LLVMXPU/IR/Ops.cpp.inc" + +mlir::LogicalResult +mlir::LLVM::XPU::LLVMXPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // Kernel function attribute should be attached to functions. + if (attr.getName() == LLVMXPUDialect::getKernelFuncAttrName()) { + if (!isa(op)) { + return op->emitError() << "'" << LLVMXPUDialect::getKernelFuncAttrName() + << "' attribute attached to unexpected op"; + } + } + return success(); +} diff --git a/third_party/xpu/lib/Dialect/Triton/CMakeLists.txt b/third_party/xpu/lib/Dialect/Triton/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/xpu/lib/Dialect/Triton/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/xpu/lib/Dialect/Triton/IR/CMakeLists.txt b/third_party/xpu/lib/Dialect/Triton/IR/CMakeLists.txt new file mode 100644 index 000000000..752daa7ff --- /dev/null +++ b/third_party/xpu/lib/Dialect/Triton/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(TritonIR + Dialect.cpp + Ops.cpp + Traits.cpp + Types.cpp + + DEPENDS + TritonTableGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRArithDialect + MLIRMathDialect + MLIRSCFDialect +) diff --git a/third_party/xpu/lib/Dialect/Triton/IR/Dialect.cpp b/third_party/xpu/lib/Dialect/Triton/IR/Dialect.cpp new file mode 100644 index 000000000..8f46e8ca8 --- /dev/null +++ b/third_party/xpu/lib/Dialect/Triton/IR/Dialect.cpp @@ -0,0 +1,138 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/DialectImplementation.h" + +#include "mlir/Transforms/InliningUtils.h" +#include "triton/Dialect/Triton/IR/Dialect.cpp.inc" +#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc" + +using namespace mlir; +using namespace mlir::triton; + +//===----------------------------------------------------------------------===// +// TritonDialect Dialect Interfaces +//===----------------------------------------------------------------------===// + +namespace { +struct TritonInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + auto funcOp = dyn_cast(callable); + if (!funcOp) + return true; + if (funcOp->hasAttr("noinline")) + return !funcOp->getAttrOfType("noinline").getValue(); + return true; + } + + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + + bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, + IRMapping &) const final { + return true; + } + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, Block *newDest) const final { + // Only return needs to be handled here. + auto returnOp = dyn_cast(op); + if (!returnOp) + return; + + // Replace the return with a branch to the dest. + OpBuilder builder(op); + builder.create(op->getLoc(), newDest, + returnOp.getOperands()); + op->erase(); + } + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { + // Only return needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } +}; + +struct TensorModel + : public TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getRank(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementTypeBitWidth(); + } +}; + +struct MemDescModel + : public TensorOrMemDesc::ExternalModel { + Type getElementType(Type pointer) const { + return cast(pointer).getElementType(); + } + Attribute getEncoding(Type pointer) const { + return cast(pointer).getEncoding(); + } + ArrayRef getShape(Type pointer) const { + return cast(pointer).getShape(); + } + int64_t getRank(Type pointer) const { + return cast(pointer).getShape().size(); + } + int64_t getElementTypeBitWidth(Type pointer) const { + return cast(pointer).getElementType().getIntOrFloatBitWidth(); + } +}; + +} // namespace + +void TritonDialect::initialize() { + registerTypes(); + + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + >(); + + // We can also add interface here. + addInterfaces(); + + RankedTensorType::attachInterface(*getContext()); + MemDescType::attachInterface(*getContext()); +} + +Operation *TritonDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return arith::ConstantOp::materialize(builder, value, type, loc); +} diff --git a/third_party/xpu/lib/Dialect/Triton/IR/Ops.cpp b/third_party/xpu/lib/Dialect/Triton/IR/Ops.cpp new file mode 100644 index 000000000..4c4480c58 --- /dev/null +++ b/third_party/xpu/lib/Dialect/Triton/IR/Ops.cpp @@ -0,0 +1,982 @@ +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +namespace mlir { +namespace triton { + +void LoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), getPtr(), + triton::GlobalMemory::get()); + if (getIsVolatile()) + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); +} + +} // namespace triton +} // namespace mlir + +#define GET_OP_CLASSES +#include "triton/Dialect/Triton/IR/Ops.cpp.inc" + +// enum attribute definitions +#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc" + +namespace mlir { +namespace triton { + +//-- LoadOp -- +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + CacheModifier cache, EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, /*padding=*/std::nullopt, + cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck, + padding, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, CacheModifier cache, EvictionPolicy evict, + bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + LoadOp::build(builder, state, ptr, mask, other, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile) { + auto paddingAttr = + padding.has_value() + ? PaddingOptionAttr::get(builder.getContext(), padding.value()) + : PaddingOptionAttr(); + LoadOp::build(builder, state, ptr, mask, other, + builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr, cache, + evict, isVolatile); +} + +// load(ptr, splat(1), ...) -> load(ptr, ...) +// load(ptr, splat(0), other, ...) -> other +struct CanonicalizeMaskedLoadPattern : public OpRewritePattern { + CanonicalizeMaskedLoadPattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const override { + auto mask = loadOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = + llvm::dyn_cast_or_null(mask.getDefiningOp()); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(), + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + } else { + // mask = splat(0) + + // If there's no "other", the value is "undef". Perhaps we want to + // optimize it in the future.x + auto otherVal = loadOp.getOther(); + if (!otherVal) + return failure(); + rewriter.replaceOp(loadOp, otherVal); + } + return success(); + } +}; + +void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- StoreOp -- +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + /*boundaryCheck=*/{}, cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, Value mask, CacheModifier cache, + EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, mask, /*boundaryCheck=*/{}, + cache, evict); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value value, ArrayRef boundaryCheck, + CacheModifier cache, EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, + builder.getDenseI32ArrayAttr(boundaryCheck), cache, + evict); +} + +// store(ptr, value, splat(1), ...) -> store(ptr, value, ...) +// store(ptr, value, splat(0), ...) -> [none] +struct CanonicalizeMaskedStorePattern : public OpRewritePattern { + CanonicalizeMaskedStorePattern(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const override { + auto mask = storeOp.getMask(); + if (!mask) + return failure(); + + auto constantMask = + llvm::dyn_cast_or_null(mask.getDefiningOp()); + if (!constantMask) + return failure(); + + auto splatMask = mlir::dyn_cast(constantMask.getValue()); + if (!splatMask) + return failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(), + storeOp.getEvict()); + } else { + // mask = splat(0) + rewriter.eraseOp(storeOp); + } + return success(); + } +}; + +void StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//-- TransOp -- +OpFoldResult TransOp::fold(FoldAdaptor adaptor) { + // transpose(x, order=[0, 1, ...]) -> x + if (isIota(getOrder())) { + return getSrc(); + } + + // transpose(transpose(x)) -> transpose(x) + if (auto innerTrans = getSrc().getDefiningOp()) { + setOrder(applyPermutation(innerTrans.getOrder(), getOrder())); + setOperand(innerTrans.getSrc()); + return getResult(); + } + + return {}; +} + +LogicalResult TransOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the input + auto argTy = cast(operands[0].getType()); + auto order = properties.as()->order.asArrayRef(); + SmallVector retShape = applyPermutation(argTy.getShape(), order); + + auto retEltTy = argTy.getElementType(); + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = dyn_cast(&dialect); + if (inferLayoutInterface + ->inferTransOpEncoding(argEncoding, order, retEncoding) + .failed()) { + return failure(); + } + } + if (isa(argTy)) { + inferredReturnTypes.push_back( + MemDescType::get(retShape, retEltTy, retEncoding)); + } else { + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + } + return success(); +} + +LogicalResult TransOp::verify() { + // Check that the op's `order` attribute is a permutation of the right length. + auto srcTy = getSrc().getType(); + + ArrayRef order = getOrder(); + if (order.size() != srcTy.getRank()) { + return emitError("order must have the same size as the rank of the " + "operand and result"); + } + + SmallVector sortedOrder(order); + llvm::sort(sortedOrder); + for (int32_t i = 0; i < sortedOrder.size(); i++) { + if (sortedOrder[i] != i) { + return emitError("order must be a permutation of [0, ..., rank - 1]"); + } + } + + return success(); +} + +//-- DotOp -- +LogicalResult +DotOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + + // verify encodings + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc); + Dialect &dialect = aEnc.getDialect(); + auto interface = dyn_cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return failure(); + } + return success(); +} + +LogicalResult DotOp::verify() { + auto aTy = getA().getType(); + auto bTy = getB().getType(); + if (aTy.getElementType().getIntOrFloatBitWidth() != + bTy.getElementType().getIntOrFloatBitWidth()) + return emitError( + "element types of operands A and B must have same bit width"); + auto aEncoding = aTy.getEncoding(); + auto bEncoding = bTy.getEncoding(); + if (!aEncoding && !bEncoding) + return success(); + // Verify that the encodings are valid. + if (!aEncoding || !bEncoding) + return emitError("mismatching encoding between A and B operands"); + Dialect &dialect = aEncoding.getDialect(); + auto interface = cast(&dialect); + return interface->verifyDotOpEncodingCompatibility(getOperation(), aEncoding, + bEncoding); +} + +// //-- MakeRangeOp -- +// OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) { +// // make_range(start, start + 1) -> constant(start) +// if (adaptor.getStart() + 1 == adaptor.getEnd()) { +// auto shapedType = cast(getType()); +// return SplatElementsAttr::get(shapedType, adaptor.getStartAttr()); +// } +// return {}; +// } + +LogicalResult MakeRangeOp::verify() { + int64_t start = getStartAttr().getInt(); + int64_t end = getEndAttr().getInt(); + if (start > end) { + return this->emitOpError() << "start must be less than or equal to end"; + } + auto ty = getType(); + if (ty.getShape().size() != 1) { + return this->emitOpError() << "return type must be a 1D tensor"; + } + if (end - start != ty.getShape()[0]) { + return this->emitOpError() + << "number of elements in returned tensor, " << ty.getShape()[0] + << ", must match size of range [" << start << ", " << end + << "), which has " << end - start << " elements"; + } + if (!ty.getElementType().isInteger(32)) { + return this->emitOpError() << "returned tensor must have i32 elements"; + } + return success(); +} + +//-- ReduceOp -- +static LogicalResult +inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy, + int axis, SmallVectorImpl &inferredReturnTypes) { + auto retShape = argTy.getShape().vec(); + retShape.erase(retShape.begin() + axis); + if (retShape.empty()) { + // 0d-tensor -> scalar + inferredReturnTypes.push_back(retEltTy); + } else { + // nd-tensor where n >= 1 + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = + dyn_cast(&dialect); + if (inferLayoutInterface + ->inferReduceOpEncoding(argEncoding, axis, retEncoding) + .failed()) { + llvm::report_fatal_error("failed to infer layout for ReduceOp"); + return failure(); + } + } + // create type + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + } + return success(); +} + +void ReduceOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, int axis) { + SmallVector inferredReturnTypes; + for (unsigned i = 0; i < operands.size(); ++i) { + auto argTy = cast(operands[i].getType()); + auto retEltTy = argTy.getElementType(); + (void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes); + } + + ReduceOp::build(builder, state, inferredReturnTypes, operands, axis); +} + +LogicalResult ReduceOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + for (auto arg : operands) { + auto argTy = cast(arg.getType()); + auto retEltTy = argTy.getElementType(); + if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes) + .failed()) { + return failure(); + } + } + return success(); +} + +// Helpers for Reductions and Scans +template LogicalResult verifyReduceScan(Op &op) { + if (op.getOperands().empty()) { + return op.emitOpError() << "must have at least 1 operand"; + } + if (op.getNumOperands() != op.getNumResults()) { + return op.emitOpError() << "must have the same number of inputs as outputs"; + } + + auto getElementType = [](Type ty) { + if (auto tensorType = dyn_cast(ty)) { + return tensorType.getElementType(); + } + return ty; + }; + + for (auto [opElemTy, resTy] : + llvm::zip(op.getElementTypes(), op.getResultTypes())) { + if (opElemTy != getElementType(resTy)) { + return op.emitOpError() << "operand types and result types must agree"; + } + } + return success(); +} + +template +static LogicalResult verifyRegionsImpl(Op &op) { + auto argElementTypes = op.getElementTypes(); + const auto &operands = op.getOperands(); + const auto numArgs = 2 * operands.size(); + auto &block = *op.getBody(); + if (block.getNumArguments() != numArgs) { + return op.emitOpError() << "nested block must take " << numArgs + << " arguments, but given block with " + << block.getNumArguments() << " arguments"; + } + unsigned i = 0; + const auto &blockArgTypes = block.getArgumentTypes(); + for (unsigned i = 0; i < numArgs; ++i) { + const auto &blockArgTy = blockArgTypes[i]; + const auto &argElemTy = argElementTypes[i % operands.size()]; + if (blockArgTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << blockArgTy; + } + } + + auto terminator = dyn_cast(block.getTerminator()); + if (!terminator) { + return op.emitOpError() + << "combine operation must be terminated " + << "with a ReduceReturnOp but got " << block.getTerminator(); + } + const auto &combineResults = terminator->getOperands(); + if (combineResults.size() != operands.size()) { + return op.emitOpError() + << "expected combine operation to return " << operands.size() + << " values but got " << combineResults.size(); + } + for (unsigned i = 0; i < combineResults.size(); ++i) { + const auto &resultTy = combineResults[i].getType(); + const auto &argElemTy = argElementTypes[i]; + if (resultTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << resultTy; + } + } + return success(); +} + +static llvm::SmallVector +getInputTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcTys; + srcTys.reserve(operands.size()); + for (const auto &ty : operands.getTypes()) { + srcTys.push_back(cast(ty)); + } + return srcTys; +} + +static llvm::SmallVector +getElementTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcElemTys; + srcElemTys.reserve(operands.size()); + for (const auto &op : operands) { + srcElemTys.push_back(cast(op.getType()).getElementType()); + } + return srcElemTys; +} + +LogicalResult ReduceOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ReduceOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ReduceOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ReduceOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } + +//-- ScanOp -- +void ScanOp::build(OpBuilder &builder, OperationState &state, + ValueRange operands, int axis, bool reverse) { + SmallVector inferredReturnTypes; + state.addAttribute("reverse", builder.getBoolAttr(reverse)); + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + ReduceOp::build(builder, state, inferredReturnTypes, operands, axis); +} + +LogicalResult +ScanOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + for (auto arg : operands) + inferredReturnTypes.push_back(arg.getType()); + return success(); +} + +LogicalResult ScanOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ScanOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ScanOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ScanOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +unsigned ScanOp::getNumOperands() { return this->getOperands().size(); } + +//-- SplatOp -- +OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { + auto value = adaptor.getSrc(); + if (!value) + return {}; + auto shapedType = cast(getType()); + auto ret = SplatElementsAttr::get(shapedType, ArrayRef(value)); + return ret; +} + +//-- ExpandDimsOp -- +LogicalResult ExpandDimsOp::inferReturnTypes( + MLIRContext *context, std::optional loc, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // infer shape + auto arg = operands[0]; + auto argTy = cast(arg.getType()); + auto retShape = argTy.getShape().vec(); + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + retShape.insert(retShape.begin() + axis, 1); + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = dyn_cast(&dialect); + if (inferLayoutInterface + ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc) + .failed()) + return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp"); + } + // create type + auto argEltTy = argTy.getElementType(); + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, argEltTy, retEncoding)); + return success(); +} + +LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op, + PatternRewriter &rewriter) { + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + // expand_dims(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + // expand_dims(broadcast(x)) -> broadcast(expand_dims(x)) + // + // On its own this doesn't do much, but consider + // broadcast(expand_dims(broadcast)) + // -> broadcast(broadcast(expand_dims)) + // -> broadcast(expand_dims) + if (auto broadcast = dyn_cast(definingOp)) { + auto src = broadcast.getSrc(); + auto srcTy = src.getType(); + SmallVector newExpandShape(srcTy.getShape()); + newExpandShape.insert(newExpandShape.begin() + op.getAxis(), 1); + + // Infer the encoding of the new expand op, if encodings are present. + Attribute newExpandEnc; + if (auto srcEnc = srcTy.getEncoding()) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferExpandDimsOpEncoding(srcEnc, op.getAxis(), newExpandEnc, + op.getLoc()) + .failed()) { + return emitOptionalError(op.getLoc(), + "failed to infer layout for ExpandDimsOp"); + } + } + + auto newExpandTy = RankedTensorType::get( + newExpandShape, srcTy.getElementType(), newExpandEnc); + auto newExpand = rewriter.create(op.getLoc(), newExpandTy, + src, op.getAxis()); + auto newBroadcast = rewriter.create( + broadcast.getLoc(), op.getType(), newExpand.getResult()); + rewriter.replaceOp(op, {newBroadcast.getResult()}); + return success(); + } + + return failure(); +} + +template +static OpFoldResult foldViewLikeOp(ViewLikeOp op, Attribute value) { + if (!value) + return {}; + + auto shapedType = cast(op.getType()); + if (auto denseElemsAttr = dyn_cast(value)) { + if (denseElemsAttr.isSplat()) { + return denseElemsAttr.resizeSplat(shapedType); + } else { + return denseElemsAttr.reshape(shapedType); + } + } + return {}; +} + +OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) { + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +//-- ReshapeOp -- +template +LogicalResult canonicalizeViewOrBroadcast(OpType op, + PatternRewriter &rewriter) { + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + + // view(view) -> view + if (auto parentView = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, TypeRange({op.getType()}), + parentView->getOperands(), + parentView->getAttrs()); + return success(); + } + + // view(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + + return failure(); +} + +LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) { + if (!op.getAllowReorder() || op.getEfficientLayout().has_value()) + return failure(); + return canonicalizeViewOrBroadcast(op, rewriter); +} + +OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) { + // no-op + return getSrc(); + } + + return foldViewLikeOp(*this, adaptor.getSrc()); +} + +LogicalResult ReshapeOp::verify() { + auto dstTy = getType(); + auto srcTy = getSrc().getType(); + if (getType().getNumElements() != srcTy.getNumElements()) { + return emitError( + "number of src and dst elements of reshape must be the same"); + } + + Attribute srcEnc = srcTy.getEncoding(); + Attribute dstEnc = dstTy.getEncoding(); + if (!!srcEnc != !!dstEnc) { + return emitError("Op requires that either (a) src and dst both have " + "encodings, or (b) neither does."); + } + + if (srcEnc && !getAllowReorder()) { + Attribute inferredDstEnc; + if (cast(&srcEnc.getDialect()) + ->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc, + dstTy.getShape(), inferredDstEnc, + getLoc()) + .failed()) { + return emitError("This reshape is impossible without reordering, but " + "reordering is not allowed. Try choosing a different " + "encoding for the input tensor (or allow reordering)."); + } + if (inferredDstEnc != dstEnc) { + return emitError("Expected result encoding ") + << inferredDstEnc << " but was " << dstEnc; + } + } + + return success(); +} + +//-- FpToFpOp -- +LogicalResult FpToFpOp::verify() { + auto dstType = getType().getElementType(); + auto srcType = getSrc().getType().getElementType(); + if ((dstType.getIntOrFloatBitWidth() < srcType.getIntOrFloatBitWidth()) && + (!getRounding().has_value())) { + return emitError("Rounding mode is required for FP downcast"); + } + return success(); +} + +//-- BroadcastOp -- +LogicalResult BroadcastOp::canonicalize(BroadcastOp op, + PatternRewriter &rewriter) { + return canonicalizeViewOrBroadcast(op, rewriter); +} + +OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) { + // no-op + return getSrc(); + } + + auto value = adaptor.getSrc(); + if (!value) + return {}; + + if (auto denseElemsAttr = dyn_cast(value)) { + auto shapedType = cast(getType()); + return denseElemsAttr.resizeSplat(shapedType); + } + return {}; +} + +//-- MakeTensorPtrOp -- +void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange shape, ValueRange strides, + ValueRange offsets, ArrayRef tensorShape, + ArrayRef order) { + // Get pointer type from `base` + auto pointerType = cast(base.getType()); + assert(pointerType != nullptr); + + // Build type `tt.ptr>` + auto tensorType = RankedTensorType::get( + SmallVector(tensorShape.begin(), tensorShape.end()), + pointerType.getPointeeType()); + auto result = PointerType::get(tensorType, 1); + + return build(builder, state, result, base, shape, strides, offsets, + builder.getDenseI32ArrayAttr(order)); +} + +// The following ops, including `call`, `func`, and `return` are copied and +// modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp +// We could revert it back once MLIR has a better inliner interface. +//-- FuncOp -- +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs, + ArrayRef argAttrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); + + if (argAttrs.empty()) + return; + assert(type.getNumInputs() == argAttrs.size()); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); +} + +ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void FuncOp::print(OpAsmPrinter &printer) { + function_interface_impl::printFunctionOp( + printer, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +// -- CallOp -- +LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee attribute was specified. + auto fnAttr = (*this).getProperties().callee; + if (!fnAttr) + return emitOpError("requires a 'callee' symbol reference attribute"); + FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. + auto fnType = fn.getFunctionType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (getOperand(i).getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " + << getOperand(i).getType() << " for operand number " << i; + + if (fnType.getNumResults() != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (getResult(i).getType() != fnType.getResult(i)) { + auto diag = emitOpError("result type mismatch at index ") << i; + diag.attachNote() << " op result types: " << getResultTypes(); + diag.attachNote() << "function result types: " << fnType.getResults(); + return diag; + } + + return success(); +} + +// -- ReturnOp -- +LogicalResult ReturnOp::verify() { + auto function = cast((*this)->getParentOp()); + + // The operand number and types must match the function signature. + const auto &results = function.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing function (@" + << function.getName() << ") returns " << results.size(); + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match function result type (" + << results[i] << ")" + << " in function @" << function.getName(); + + return success(); +} + +// -- JoinOp -- +LogicalResult +JoinOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // These should have been checked by tablegen-generated code. + assert(operands.size() == 2); + assert(operands[0].getType() == operands[1].getType()); + assert(isa(operands[0].getType())); + assert(isa(operands[1].getType())); + + Value lhs = operands[0]; + Value rhs = operands[1]; + auto srcTy = cast(lhs.getType()); + + SmallVector retShape(srcTy.getShape()); + retShape.push_back(2); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferJoinOpEncoding(srcEnc, retEnc, location) + .failed()) { + return failure(); + } + } + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, srcTy.getElementType(), retEnc)); + return success(); +} + +// -- SplitOp -- +LogicalResult SplitOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // These should have been checked by tablegen-generated code. + assert(operands.size() == 1); + assert(isa(operands[0].getType())); + + Value src = operands[0]; + auto srcTy = cast(src.getType()); + auto srcShape = srcTy.getShape(); + + if (srcShape.empty() || srcShape.back() != 2) { + return emitOptionalError(location, + "last dimension of input tensor must be 2"); + } + ArrayRef retShape(srcShape.begin(), srcShape.end() - 1); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferSplitOpEncoding(srcEnc, retEnc, location) + .failed()) { + return failure(); + } + } + auto retTy = RankedTensorType::get(retShape, srcTy.getElementType(), retEnc); + inferredReturnTypes.push_back(retTy); + inferredReturnTypes.push_back(retTy); + return success(); +} + +// -- ElementwiseInlineAsmOp -- +void ElementwiseInlineAsmOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +LogicalResult ElementwiseInlineAsmOp::verify() { + if (getNumOperands() >= 1) { + auto tensorType = dyn_cast(getOperand(0).getType()); + size_t numInputElems = tensorType ? tensorType.getNumElements() : 0; + if (numInputElems % this->getPackedElement() != 0) { + return emitError("number of input elements ") + << numInputElems + << " must be a multiple of the op's packed_element attribute, " + << getPackedElement(); + } + } + return success(); +} + +// -- ExternElementwiseOp -- +void ExternElementwiseOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/Triton/IR/Traits.cpp b/third_party/xpu/lib/Dialect/Triton/IR/Traits.cpp new file mode 100644 index 000000000..1345718ee --- /dev/null +++ b/third_party/xpu/lib/Dialect/Triton/IR/Traits.cpp @@ -0,0 +1,247 @@ +#include "triton/Dialect/Triton/IR/Traits.h" + +#include + +#include "mlir/IR/TypeUtilities.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +namespace ttg = mlir::triton::gpu; + +static LogicalResult verifySameEncoding(Type typeA, Type typeB, + bool allowTensorPointerType) { + // TODO(Keren): the allowTensorPointerType argument is a hack to allow. + // The type checking code is kind of a mess with the current design. + auto getEncoding = [=](Type type) -> Attribute { + Attribute ret; + if (auto tensorType = dyn_cast(type)) { + ret = tensorType.getEncoding(); + } + if (!allowTensorPointerType) { + assert(!triton::isTensorPointerType(type)); + } + return ret; + }; + auto encodingA = getEncoding(typeA); + auto encodingB = getEncoding(typeB); + if (!encodingA || !encodingB) + return success(); + return encodingA == encodingB ? success() : failure(); +} + +LogicalResult +OpTrait::impl::verifySameOperandsEncoding(Operation *op, + bool allowTensorPointerType) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifySameEncoding(opType, type, allowTensorPointerType))) + return op->emitOpError() << "requires the same encoding for all operands"; + + return success(); +} + +LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding( + Operation *op, bool allowTensorPointerType) { + if (op->getNumOperands() == 0) + return success(); + + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto type = op->getOperand(0).getType(); + for (auto resultType : op->getResultTypes()) + if (failed(verifySameEncoding(resultType, type, allowTensorPointerType))) + return op->emitOpError() + << "requires the same encoding for all operands and results"; + + return verifySameOperandsEncoding(op, allowTensorPointerType); +} + +LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) { + for (auto opType : op->getOperandTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + //===-------------------- For Triton XPU -----------------------===// + // Triton XPU Don't need the power-of-two limitation + // if ((numElements & (numElements - 1)) != 0) + // return op->emitError("Number of elements must be power-of-two, but + // ") + // << *op << " doesn't follow the rule (" << numElements << ")" + // << " elements"; + //===-----------------------------------------------------------===// + } + } + for (auto opType : op->getResultTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + //===-------------------- For Triton XPU -----------------------===// + // Triton XPU Don't need the power-of-two limitation + // if ((numElements & (numElements - 1)) != 0) + // return op->emitError("Number of elements must be power-of-two, but + // ") + // << *op << " doesn't follow the rule (" << numElements << ")" + // << " elements"; + //===-----------------------------------------------------------===// + } + } + return success(); +} + +// Check that the Triton layouts on op's operands and return types are valid. +// For example, we check that the number of warps per block in a Triton GPU +// blocked layout matches that of its module. +// +// It's a little weird to check these properties of a layout only when the +// layout is used in an op, since most of the properties don't actually depend +// on the op. They do depend on the *module*, though, and a layout is attached +// to a module only by virtue of being used in one of the module's ops. +LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) { + auto module = op->getParentOfType(); + auto checkLayout = [&](Value val, auto makeErr) -> LogicalResult { + // Only ranked tensors can have layouts. + auto rankedTy = dyn_cast(val.getType()); + if (!rankedTy) + return success(); + + mlir::Attribute layout = rankedTy.getEncoding(); + if (!layout) + return success(); + + if (isa(layout)) + return makeErr() << "Shared layout is not allowed on tensor type."; + // TODO(jlebar): Currently this only checks blocked layouts, but other + // layouts also have invariants! + + // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr. + if (auto blocked = dyn_cast(layout)) { + // A different verifier should have checked that the layout itself is + // valid, including that threads-per-warp has the same rank as + // warps-per-block etc. + auto layoutRank = blocked.getThreadsPerWarp().size(); + if (layoutRank != rankedTy.getRank()) { + return makeErr() << layout << ".\nLayout has rank " << layoutRank + << ", but the tensor it's attached to has rank " + << rankedTy.getRank() << "."; + } + + int moduleThreadsPerWarp = + ttg::TritonGPUDialect::getThreadsPerWarp(module); + int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp()); + if (layoutThreadsPerWarp != moduleThreadsPerWarp) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutThreadsPerWarp + << " threads per warp, but the module specifies " + << moduleThreadsPerWarp << " threads per warp."; + } + + int moduleWarpsPerCTA = ttg::TritonGPUDialect::getNumWarps(module); + int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA()); + if (layoutWarpsPerCTA != moduleWarpsPerCTA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutWarpsPerCTA + << " warps per CTA, but the module specifies " + << moduleWarpsPerCTA << " warps per CTA."; + } + + if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) { + int moduleCTAsPerCGA = ttg::TritonGPUDialect::getNumCTAs(module); + int64_t layoutCTAsPerCGA = + product(blocked.getCTALayout().getCTAsPerCGA()); + if (layoutCTAsPerCGA != moduleCTAsPerCGA) { + return makeErr() << layout << ".\nLayout has a total of " + << layoutCTAsPerCGA + << " CTAs per CGA, but the module specifies " + << moduleCTAsPerCGA << " CTAs per CGA."; + } + } + } + + return success(); + }; + + for (size_t i = 0; i < op->getNumOperands(); i++) { + auto operand = op->getOperand(i); + auto err = checkLayout(operand, [&]() { + // Stringify the operand using `printAsOperand`. This prints e.g. "%42" + // rather than the full definition. + std::string operandStr; + llvm::raw_string_ostream os(operandStr); + // If we don't assume verified, dump() will recursively call this + // function! + operand.printAsOperand(os, OpPrintingFlags().assumeVerified()); + + return op->emitError("Operand ") + << i << " (" << operand << ") has an invalid layout: "; + }); + if (!err.succeeded()) + return err; + } + + for (size_t i = 0; i < op->getNumResults(); i++) { + auto result = op->getResult(i); + auto err = checkLayout(result, [&]() { + if (op->getNumResults() == 1) { + return op->emitError("Result has an invalid layout: "); + } else { + return op->emitError("Result ") << i << " has an invalid layout: "; + } + }); + if (!err.succeeded()) + return err; + } + + return success(); +} + +static ArrayRef getTypeShape(Type type) { + auto rankedType = dyn_cast(type); + if (auto ptrType = dyn_cast(type)) + rankedType = dyn_cast(ptrType.getPointeeType()); + return rankedType ? rankedType.getShape() : ArrayRef(); +} + +LogicalResult OpTrait::impl::verifySameLoadStoreOperandsShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : llvm::drop_begin(op->getOperandTypes(), 1)) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() << "requires the same shape for all operands"; + + return success(); +} + +LogicalResult +OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) { + if (failed(verifyAtLeastNOperands(op, 1)) || + failed(verifyAtLeastNResults(op, 1))) + return failure(); + + auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); + for (auto type : op->getResultTypes()) + if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) + return op->emitOpError() + << "requires the same shape for all operands and results"; + + return verifySameLoadStoreOperandsShape(op); +} diff --git a/third_party/xpu/lib/Dialect/Triton/IR/Types.cpp b/third_party/xpu/lib/Dialect/Triton/IR/Types.cpp new file mode 100644 index 000000000..0e1df5b74 --- /dev/null +++ b/third_party/xpu/lib/Dialect/Triton/IR/Types.cpp @@ -0,0 +1,171 @@ +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void TritonDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/Triton/IR/Types.cpp.inc" + >(); +} + +Type PointerType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + Type pointeeType; + if (parser.parseType(pointeeType)) + return Type(); + + int addressSpace = 1; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseInteger(addressSpace)) + return Type(); + } + + if (parser.parseGreater()) + return Type(); + + return PointerType::get(pointeeType, addressSpace); +} + +void PointerType::print(AsmPrinter &printer) const { + if (getAddressSpace() == 1) { + printer << "<" << getPointeeType() << ">"; + } else { + printer << "<" << getPointeeType() << ", " << getAddressSpace() << ">"; + } +} + +static constexpr llvm::StringRef kMutableMemory = "mutable"; + +Type MemDescType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + SmallVector dimensions; + if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false)) + return Type(); + + // Parse the element type. + Type elementType; + if (parser.parseType(elementType)) + return Type(); + + Attribute encoding; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseAttribute(encoding)) + return Type(); + } + bool mutableMemory = false; + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseOptionalKeyword(kMutableMemory)) + return Type(); + mutableMemory = true; + } + if (parser.parseGreater()) + return Type(); + + return MemDescType::get(parser.getContext(), dimensions, elementType, + encoding, mutableMemory); +} + +void MemDescType::print(AsmPrinter &printer) const { + printer << "<"; + for (auto dim : getShape()) + printer << dim << "x"; + printer << getElementType(); + if (getEncoding()) + printer << ", " << getEncoding(); + if (getMutableMemory()) + printer << ", " << kMutableMemory; + printer << ">"; +} + +namespace mlir { + +namespace triton { + +unsigned getPointeeBitWidth(Type type) { + auto pointeeType = getPointeeType(type); + if (auto tensorTy = dyn_cast(pointeeType)) + return tensorTy.getElementType().getIntOrFloatBitWidth(); + return pointeeType.getIntOrFloatBitWidth(); +} + +Type getI1SameShape(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorTy = dyn_cast(type)) + return RankedTensorType::get(tensorTy.getShape(), i1Type, + tensorTy.getEncoding()); + return i1Type; +} + +Type getPointeeType(Type type) { + if (auto tensorTy = dyn_cast(type)) { + // Tensor of pointers + auto shape = tensorTy.getShape(); + auto ptrType = dyn_cast(tensorTy.getElementType()); + Type pointeeType = ptrType.getPointeeType(); + return RankedTensorType::get(shape, pointeeType, tensorTy.getEncoding()); + } else if (auto ptrType = dyn_cast(type)) { + // scalar pointer + Type pointeeType = ptrType.getPointeeType(); + return pointeeType; + } + return type; +} + +Type getI32SameShape(Type type) { + auto i32Type = IntegerType::get(type.getContext(), 32); + if (auto tensorTy = dyn_cast(type)) + return RankedTensorType::get(tensorTy.getShape(), i32Type, + tensorTy.getEncoding()); + return i32Type; +} + +Type getPointerTypeSameShape(Type type) { + if (auto tensorTy = dyn_cast(type)) { + Type elementType = tensorTy.getElementType(); + auto shape = tensorTy.getShape(); + PointerType ptrType = PointerType::get(elementType, 1); + return RankedTensorType::get(shape, ptrType, tensorTy.getEncoding()); + } else { + return PointerType::get(type, 1); + } +} + +Type getPointerType(Type type) { return PointerType::get(type, 1); } + +bool isTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + return isa(ptrType.getPointeeType()); + return false; +} + +bool isTensorOrTensorPointerType(Type type) { + return isa(type) || isTensorPointerType(type); +} + +Type getElementTypeOfTensorPointerType(Type type) { + if (auto ptrType = dyn_cast(type)) + if (auto tensorTy = dyn_cast(ptrType.getPointeeType())) + return tensorTy.getElementType(); + return {}; +} + +} // namespace triton + +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/Triton/Transforms/CMakeLists.txt b/third_party/xpu/lib/Dialect/Triton/Transforms/CMakeLists.txt new file mode 100644 index 000000000..298398750 --- /dev/null +++ b/third_party/xpu/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -0,0 +1,18 @@ +set(LLVM_TARGET_DEFINITIONS Combine.td) +mlir_tablegen(TritonCombine.inc -gen-rewriters) +add_public_tablegen_target(TritonCombineIncGen) + +add_triton_library(TritonTransforms + Combine.cpp + ReorderBroadcast.cpp + RewriteTensorPointer.cpp + + DEPENDS + TritonTransformsIncGen + TritonCombineIncGen + + LINK_LIBS PUBLIC + MLIRPass + MLIRTransformUtils + TritonIR +) diff --git a/third_party/xpu/lib/Dialect/Triton/Transforms/Combine.cpp b/third_party/xpu/lib/Dialect/Triton/Transforms/Combine.cpp new file mode 100644 index 000000000..c5d638754 --- /dev/null +++ b/third_party/xpu/lib/Dialect/Triton/Transforms/Combine.cpp @@ -0,0 +1,255 @@ +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +#define GEN_PASS_CLASSES +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace mlir::triton { +namespace { + +bool isZero(Value val) { + if (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat())) + return true; + // broadcast(constant_0) + if (auto bc = val.getDefiningOp()) { + if (matchPattern(bc.getSrc(), m_Zero()) || + matchPattern(bc.getSrc(), m_AnyZeroFloat())) + return true; + } + return false; +} + +bool isBroadcastConstantCombinable(Attribute value) { + if (auto denseValue = dyn_cast(value)) { + return denseValue.isSplat(); + } + return isa(value); +} + +DenseElementsAttr getConstantValue(Builder &builder, Attribute value, + Value bcast_res) { + auto resType = cast(bcast_res.getType()); + DenseElementsAttr res; + if (auto denseValue = dyn_cast(value)) { + res = + DenseElementsAttr::get(resType, denseValue.getSplatValue()); + } else { + res = DenseElementsAttr::get(resType, value); + } + return res; +} + +bool isAddPtrOffsetCombinable(Value first, Value second) { + auto GetConstantIntValue = [](Value val) -> std::optional { + DenseElementsAttr constAttr; + auto defOp = val.getDefiningOp(); + if (defOp) { + if (auto splatOp = llvm::dyn_cast(defOp)) + val = splatOp.getSrc(); + else if (matchPattern(defOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto attr = constAttr.getSplatValue(); + // Check IntegerAttr + if (auto intAttr = dyn_cast_or_null(attr)) + return intAttr.getValue(); + } + } + + // Check constant value. + llvm::APInt intVal; + if (matchPattern(val, m_ConstantInt(&intVal))) + return intVal; + + return std::nullopt; + }; + + if (first.getType() == second.getType()) { + // Whether bitwidth of element type is equal to pointer + if (getElementTypeOrSelf(first.getType()).getIntOrFloatBitWidth() == 64) + return true; + + // first + second does not overflow + auto firstVal = GetConstantIntValue(first); + auto secondVal = GetConstantIntValue(second); + if (firstVal && secondVal) { + bool overflow = false; + auto resVal = firstVal->sadd_ov(*secondVal, overflow); + return !overflow; + } + } + return false; +} + +// TODO(csigg): remove after next LLVM integrate. +using FastMathFlags = arith::FastMathFlags; + +#include "TritonCombine.inc" + +// select(cond, load(ptrs, splat(cond), ???), other) +// => load(ptrs, splat(cond), other) +class CombineSelectMaskedLoadPattern : public RewritePattern { +public: + CombineSelectMaskedLoadPattern(MLIRContext *context) + : RewritePattern(arith::SelectOp::getOperationName(), 3, context, + {LoadOp::getOperationName()}) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto selectOp = llvm::dyn_cast(op); + if (!selectOp) + return failure(); + + Value trueValue = selectOp.getTrueValue(); + Value falseValue = selectOp.getFalseValue(); + Value condSelect = selectOp.getCondition(); + + auto *loadOpCandidate = trueValue.getDefiningOp(); + auto loadOp = llvm::dyn_cast_or_null(loadOpCandidate); + if (!loadOp) + return failure(); + + Value mask = loadOp.getMask(); + if (!mask) + return failure(); + + auto *splatOpCandidate = mask.getDefiningOp(); + auto splatOp = llvm::dyn_cast_or_null(splatOpCandidate); + if (!splatOp) + return failure(); + + auto splatCond = splatOp.getSrc(); + if (splatCond != condSelect) + return failure(); + + rewriter.replaceOpWithNewOp( + op, loadOp.getPtr(), loadOp.getMask(), /*other=*/falseValue, + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + return success(); + } +}; + +// sum(x[:, :, None] * y[None, :, :], 1) +// -> dot(x, y) +class CombineBroadcastMulReducePattern : public RewritePattern { +private: + static bool isAddF32(const Operation *op) { + if (auto addf = dyn_cast_or_null(op)) + return addf.getType().getIntOrFloatBitWidth() <= 32; + return false; + } + + static SmallVector getEqualIndices(ArrayRef x, + ArrayRef y) { + SmallVector res; + for (int i = 0; i < x.size(); ++i) + if (x[i] == y[i]) + res.push_back(i); + return res; + } + +public: + CombineBroadcastMulReducePattern(MLIRContext *context) + : RewritePattern(ReduceOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + auto reduceOp = llvm::dyn_cast(op); + if (!reduceOp) + return failure(); + // only support reduce with simple addition + Region &combineOp = reduceOp.getCombineOp(); + bool isReduceAdd = combineOp.hasOneBlock() && + combineOp.front().getOperations().size() == 2 && + isAddF32(&*combineOp.front().getOperations().begin()); + if (!isReduceAdd) + return failure(); + // operand of reduce has to be mul + auto mulOp = llvm::dyn_cast_or_null( + reduceOp.getOperand(0).getDefiningOp()); + if (!mulOp) + return failure(); + // mul operand has to be broadcast + auto broadcastLhsOp = llvm::dyn_cast_or_null( + mulOp.getOperand(0).getDefiningOp()); + if (!broadcastLhsOp) + return failure(); + auto broadcastRhsOp = llvm::dyn_cast_or_null( + mulOp.getOperand(1).getDefiningOp()); + if (!broadcastRhsOp) + return failure(); + // broadcast operand is expand dims + auto expandLhsOp = llvm::dyn_cast_or_null( + broadcastLhsOp.getSrc().getDefiningOp()); + if (!expandLhsOp) + return failure(); + auto expandRhsOp = llvm::dyn_cast_or_null( + broadcastRhsOp.getSrc().getDefiningOp()); + if (!expandRhsOp) + return failure(); + // get not-broadcast dimensions + int expandLhsAxis = expandLhsOp.getAxis(); + int expandRhsAxis = expandRhsOp.getAxis(); + if (expandLhsAxis != 2 || expandRhsAxis != 0) + return failure(); + auto broadcastLhsShape = + cast(broadcastLhsOp.getType()).getShape(); + auto broadcastRhsShape = + cast(broadcastLhsOp.getType()).getShape(); + if (broadcastLhsShape[2] < 16 || broadcastRhsShape[0] < 16) + return failure(); + Type newAccType = RankedTensorType::get( + {broadcastLhsShape[0], broadcastRhsShape[2]}, + cast(broadcastLhsOp.getSrc().getType()).getElementType()); + rewriter.setInsertionPoint(op); + auto newAcc = rewriter.create( + op->getLoc(), newAccType, + rewriter.create(op->getLoc(), + rewriter.getF32FloatAttr(0))); + rewriter.replaceOpWithNewOp(op, expandLhsOp.getSrc(), + expandRhsOp.getSrc(), newAcc, + InputPrecision::TF32, 0); + return success(); + } +}; + +class CombineOpsPass : public TritonCombineOpsBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + + // Dot Add %{ + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + // %} + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // anonymous namespace + +std::unique_ptr createCombineOpsPass() { + return std::make_unique(); +} + +} // namespace mlir::triton diff --git a/third_party/xpu/lib/Dialect/Triton/Transforms/Combine.td b/third_party/xpu/lib/Dialect/Triton/Transforms/Combine.td new file mode 100644 index 000000000..5a2fcecfa --- /dev/null +++ b/third_party/xpu/lib/Dialect/Triton/Transforms/Combine.td @@ -0,0 +1,54 @@ +#ifndef TRITON_PATTERNS +#define TRITON_PATTERNS + +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "triton/Dialect/Triton/IR/TritonOps.td" +include "mlir/IR/PatternBase.td" + + +// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) + +// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) +// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) +def CombineDotAddIPattern : Pat< + (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $overflow), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; +def CombineDotAddFPattern : Pat< + (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; + +def CombineDotAddIRevPattern : Pat< + (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $overflow), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; +def CombineDotAddFRevPattern : Pat< + (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath), + (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc), + (ConstrainthasOneUse()">, "dot result has a single use">)]>; + +// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1)) +// Note: leave (sub %c0, %c0) canceling to ArithDialect +// (ref: ArithCanonicalization.td) +defvar DefOverflow = ConstantEnumCase; +def CombineAddPtrPattern : Pat< + (TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1), + (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)), + [(Constraint> $idx0, $idx1)]>; + +// broadcast(cst) => cst +def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">; +def CombineBroadcastConstantPattern : Pat< + (TT_BroadcastOp:$bcast_res (Arith_ConstantOp $value)), + (Arith_ConstantOp (getConstantValue $value, $bcast_res), (location $bcast_res)), + [(Constraint> $value)]>; + +#endif diff --git a/third_party/xpu/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp b/third_party/xpu/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp new file mode 100644 index 000000000..43479a3d9 --- /dev/null +++ b/third_party/xpu/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp @@ -0,0 +1,247 @@ +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +// TODO(jlebar): Move this and all other generatede code into namespace +// mlir::triton. +#define GEN_PASS_DEF_TRITONREORDERBROADCAST +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace mlir::triton { +namespace { + +Operation *cloneWithNewArgsAndResultTypes(PatternRewriter &rewriter, + Operation *op, ValueRange newOperands, + TypeRange newTypes) { + OperationState newElementwiseState(op->getLoc(), op->getName()); + newElementwiseState.addOperands(newOperands); + newElementwiseState.addTypes(newTypes); + newElementwiseState.addAttributes(op->getAttrs()); + return rewriter.create(newElementwiseState); +} + +bool isSplat(Operation *op) { + if (auto splatOp = llvm::dyn_cast(op)) { + return true; + } + DenseElementsAttr constAttr; + return (matchPattern(op, m_Constant(&constAttr)) && constAttr.isSplat()); +} + +// elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) +struct MoveSplatAfterElementwisePattern + : public OpTraitRewritePattern { + + MoveSplatAfterElementwisePattern(MLIRContext *context) + : OpTraitRewritePattern(context) {} + + LogicalResult match(Operation *op) const override { + if (!isMemoryEffectFree(op)) { + return failure(); + } + + for (auto operand : op->getOperands()) { + auto definingOp = operand.getDefiningOp(); + if (!definingOp) + return failure(); + + if (!isSplat(definingOp)) { + return failure(); + } + } + return success(op->getNumOperands() > 0); + } + + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto operands = op->getOperands(); + + llvm::SmallVector scalarOperands(operands.size()); + for (unsigned iOp = 0; iOp < operands.size(); ++iOp) { + auto definingOp = operands[iOp].getDefiningOp(); + + DenseElementsAttr constAttr; + if (auto splatOp = llvm::dyn_cast(definingOp)) { + scalarOperands[iOp] = splatOp.getSrc(); + } else if (matchPattern(definingOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto value = constAttr.getSplatValue(); + scalarOperands[iOp] = arith::ConstantOp::materialize( + rewriter, value, constAttr.getElementType(), loc); + } else { + llvm_unreachable("Expected a splat"); + } + } + + auto resultTypes = op->getResultTypes(); + llvm::SmallVector scalarResultTys; + for (auto resultTy : resultTypes) { + auto elemTy = dyn_cast(resultTy).getElementType(); + scalarResultTys.push_back(elemTy); + } + + auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, scalarOperands, + scalarResultTys); + + for (unsigned iRes = 0; iRes < resultTypes.size(); ++iRes) { + auto newResult = rewriter.create(loc, resultTypes[iRes], + newOp->getResult(iRes)); + rewriter.replaceAllUsesWith(op->getResult(iRes), newResult); + } + } +}; + +// elementwise(broadcast(a)) => broadcast(elementwise(a)) +// This also generalizes to multiple arguments when the rest are splat-like +// Not handled: multiple broadcasted arguments +struct MoveBroadcastAfterElementwisePattern + : public OpTraitRewritePattern { + + MoveBroadcastAfterElementwisePattern(MLIRContext *context) + : OpTraitRewritePattern(context) {} + + LogicalResult match(Operation *op) const override { + if (!isMemoryEffectFree(op)) { + return failure(); + } + + auto operands = op->getOperands(); + bool seenBroadcast = false; + ArrayRef srcShape; + for (auto operand : operands) { + auto definingOp = operand.getDefiningOp(); + if (!definingOp) { + return failure(); + } + auto getSrcShape = [](BroadcastOp b) { + return b.getSrc().getType().getShape(); + }; + if (auto broadcastOp = llvm::dyn_cast(definingOp)) { + if (!seenBroadcast) { + seenBroadcast = true; + srcShape = getSrcShape(broadcastOp); + } else if (srcShape != getSrcShape(broadcastOp)) { + // If the broadcast have different types we cannot re-order. + return failure(); + } + } else if (!isSplat(definingOp)) { + // Not splat or broadcast + return failure(); + } + } + return success(seenBroadcast); + } + + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // Find broadcast op + auto operands = op->getOperands(); + BroadcastOp broadcastOp; + for (auto operand : operands) { + broadcastOp = operand.getDefiningOp(); + if (broadcastOp) { + break; + } + } + + auto srcTy = broadcastOp.getSrc().getType(); + auto srcShape = srcTy.getShape(); + auto srcEncoding = srcTy.getEncoding(); + + // Reshape operands to match srcShape + llvm::SmallVector newOperands; + for (auto operand : operands) { + auto definingOp = operand.getDefiningOp(); + if (auto broadcastSrcOp = llvm::dyn_cast(definingOp)) { + newOperands.push_back(broadcastSrcOp.getSrc()); + continue; + } + auto elemTy = + dyn_cast(operand.getType()).getElementType(); + auto newTy = RankedTensorType::get(srcShape, elemTy, srcEncoding); + if (auto splatOp = llvm::dyn_cast(definingOp)) { + auto newSplat = rewriter.create(loc, newTy, splatOp.getSrc()); + newOperands.push_back(newSplat); + continue; + } + DenseElementsAttr constAttr; + if (matchPattern(definingOp, m_Constant(&constAttr)) && + constAttr.isSplat()) { + auto scalarValue = constAttr.getSplatValue(); + auto splatValue = SplatElementsAttr::get(newTy, scalarValue); + auto newConstant = + rewriter.create(loc, newTy, splatValue); + newOperands.push_back(newConstant); + continue; + } + llvm_unreachable("Expected broadcast or splat"); + } + + // Reshape results to match srcShape + llvm::SmallVector newResultTypes; + auto resultTypes = op->getResultTypes(); + for (auto resultTy : resultTypes) { + auto elemTy = dyn_cast(resultTy).getElementType(); + newResultTypes.push_back( + RankedTensorType::get(srcShape, elemTy, srcEncoding)); + } + + // Create new op and broadcast results + auto newOp = cloneWithNewArgsAndResultTypes(rewriter, op, newOperands, + newResultTypes); + for (unsigned iRes = 0; iRes < newResultTypes.size(); ++iRes) { + auto newResult = rewriter.create(loc, resultTypes[iRes], + newOp->getResult(iRes)); + rewriter.replaceAllUsesWith(op->getResult(iRes), newResult); + } + } +}; + +template +class CanonicalizePattern : public OpRewritePattern { +public: + explicit CanonicalizePattern(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(OpType op, + PatternRewriter &rewriter) const override { + return OpType::canonicalize(op, rewriter); + } +}; + +class ReorderBroadcastPass + : public ::impl::TritonReorderBroadcastBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + ModuleOp m = getOperation(); + + patterns.add>(context); + patterns.add>(context); + // elementwise(broadcast(a)) => broadcast(elementwise(a)) + patterns.add(context); + // elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) + patterns.add(context); + + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr createReorderBroadcastPass() { + return std::make_unique(); +} + +} // namespace mlir::triton diff --git a/third_party/xpu/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/third_party/xpu/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp new file mode 100644 index 000000000..52f4ba0b3 --- /dev/null +++ b/third_party/xpu/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -0,0 +1,572 @@ +#include +#include + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/Triton/Transforms/Passes.h.inc" + +namespace { + +/// An additional struct to record the meta information of operations +/// with tensor pointers +struct RewritedInfo { +private: + Value base; + SmallVector shape; + SmallVector strides; + SmallVector offsets; + ArrayRef tensorShape; + + // A cache to avoid generating the same offset with range + DenseMap cachedOffsetWithRange; + +public: + RewritedInfo() = default; + + RewritedInfo(const RewritedInfo &other) = default; + + RewritedInfo(Value base, const SmallVector &shape, + const SmallVector &strides, + const SmallVector &offsets, + const ArrayRef &tensorShape) + : base(base), shape(shape), strides(strides), offsets(offsets), + tensorShape(tensorShape) { + assert(shape.size() == strides.size() && shape.size() == offsets.size() && + shape.size() == tensorShape.size()); + } + + unsigned int length() const { return shape.size(); } + + Value getOffset(unsigned i) { return offsets[i]; } + + SmallVector getOffsets() { return offsets; } + + void setOffset(unsigned i, Value newOffset) { + offsets[i] = newOffset; + cachedOffsetWithRange.clear(); + } + + void setOffsets(const SmallVector &newOffsets) { + offsets = newOffsets; + cachedOffsetWithRange.clear(); + } + + Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc, + unsigned i) { + if (cachedOffsetWithRange.count(i)) + return cachedOffsetWithRange[i]; + + // Add range + auto indexI32RowType = + RankedTensorType::get({tensorShape[i]}, builder.getI32Type()); + auto indexRowType = + RankedTensorType::get({tensorShape[i]}, builder.getI64Type()); + Value splatOffset = + builder.create(loc, indexRowType, offsets[i]); + Value range = builder.create(loc, indexI32RowType, 0, + tensorShape[i]); + Value i64Range = builder.create(loc, indexRowType, range); + + // Expand dimensions + Value expandedResult = + builder.create(loc, splatOffset, i64Range); + for (int j = 0; j < tensorShape.size(); ++j) { + if (j == i) + continue; + expandedResult = + builder.create(loc, expandedResult, j); + } + + return cachedOffsetWithRange[i] = expandedResult; + } + + Value generatePtr(OpBuilder &builder, const Location &loc) { + assert(tensorShape.size() == offsets.size() && + tensorShape.size() == strides.size()); + auto indexTensorType = + RankedTensorType::get(tensorShape, builder.getI64Type()); + auto ptrType = cast(base.getType()); + auto ptrTensorType = RankedTensorType::get(tensorShape, ptrType); + + // Generate offsets per dimension + Value ptr = builder.create(loc, ptrTensorType, base); + for (unsigned i = 0; i < tensorShape.size(); ++i) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + + // We must splat strides into the expanded shape not a row for retaining + // the divisibility information given by strides + Value splatStride = builder.create( + loc, offsetWithRange.getType(), strides[i]); + Value offsetWithStride = + builder.create(loc, offsetWithRange, splatStride); + Value broadcasted = builder.create( + loc, indexTensorType, offsetWithStride); + + // Add to the pointer + ptr = builder.create(loc, ptrTensorType, ptr, + broadcasted); + } + + return ptr; + } + + Value generateMask(OpBuilder &builder, const Location &loc, + const std::optional> &boundaryCheck) { + if (!boundaryCheck.has_value()) + return {}; + + // Generate mask per dimension + auto maskTensorType = + RankedTensorType::get(tensorShape, builder.getI1Type()); + Value mask; + for (auto i : boundaryCheck.value()) { + auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); + + // Compare with lower bound + Value lowerBound = builder.create( + loc, 0, builder.getI64Type()); + Value splatLowerBound = builder.create( + loc, offsetWithRange.getType(), lowerBound); + Value cmpLower = builder.create( + loc, arith::CmpIPredicate::sge, offsetWithRange, splatLowerBound); + + // Compare with upper bound + Value splatUpperBound = builder.create( + loc, offsetWithRange.getType(), shape[i]); + Value cmpUpper = builder.create( + loc, arith::CmpIPredicate::slt, offsetWithRange, splatUpperBound); + + // And and broadcast + Value andResult = builder.create(loc, cmpLower, cmpUpper); + Value broadcasted = + builder.create(loc, maskTensorType, andResult); + + // And up all results + if (!mask) { + mask = broadcasted; + } else { + mask = builder.create(loc, mask, broadcasted); + } + } + + return mask; + } + + Value generateOther(OpBuilder &builder, const Location &loc, + const std::optional &padding) { + if (!padding.has_value()) + return Value(); + + // Create element attribute + auto elementType = + cast(base.getType()).getPointeeType(); + auto otherTensorType = RankedTensorType::get(tensorShape, elementType); + + // Set zero padding value + TypedAttr attr = + elementType.isIntOrIndex() + ? cast(builder.getIntegerAttr(elementType, 0)) + : cast(builder.getFloatAttr(elementType, 0)); + + // Float NaN padding case + if (padding.value() == triton::PaddingOption::PAD_NAN) { + assert(!elementType.isIntOrIndex()); + auto apNaN = llvm::APFloat::getNaN( + cast(attr).getValue().getSemantics()); + attr = builder.getFloatAttr(elementType, apNaN); + } + + // Create tensor + Value constant = builder.create(loc, attr); + return builder.create(loc, otherTensorType, constant); + } +}; + +} // namespace + +// TODO: this pass relies on assumptions of how block pointers are created and +// on pattern matches that walks the SSA links to find the base/strides. This is +// very fragile and to solve we should expose convert Ptr of tensor to a +// structure containins all values and not only offsets. +class RewriteTensorPointerPass + : public TritonRewriteTensorPointerBase { +private: + DenseMap rewritedInfo; + +public: + static bool needRewrite(Operation *op) { + return std::any_of(op->getOperands().begin(), op->getOperands().end(), + [](Value operand) { + return triton::isTensorPointerType(operand.getType()); + }); + } + + static SmallVector + generateNewOperands(const SmallVector &oldOperands, unsigned index, + const SmallVector &newValues) { + assert(index < oldOperands.size()); + SmallVector newOperands; + for (int i = 0; i < index; ++i) + newOperands.push_back(oldOperands[i]); + for (auto value : newValues) + newOperands.push_back(value); + for (auto i = index + 1; i < oldOperands.size(); ++i) + newOperands.push_back(oldOperands[i]); + return newOperands; + } + + Operation *rewriteMakeTensorPtrOp(OpBuilder &builder, + triton::MakeTensorPtrOp op, + std::stack &eraser) { + // Save info for later use + auto ptrType = cast(op.getType()); + auto tensorType = cast(ptrType.getPointeeType()); + + // Cast I32 offsets into I64 + SmallVector i64Offsets; + for (auto offset : op.getOffsets()) { + auto i64Offset = builder.create( + op.getLoc(), builder.getI64Type(), offset); + i64Offsets.push_back(i64Offset); + } + + // Save information + rewritedInfo[op.getResult()] = + RewritedInfo(op.getBase(), op.getShape(), op.getStrides(), i64Offsets, + tensorType.getShape()); + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteAdvanceOp(OpBuilder &builder, triton::AdvanceOp op, + std::stack &eraser) { + // Get info from previous results + assert(rewritedInfo.count(op.getPtr())); + auto info = rewritedInfo[op.getPtr()]; + + // Calculate new offsets + assert(info.length() == op.getOffsets().size()); + SmallVector newOffsets; + for (int i = 0; i < info.length(); ++i) { + Value i64Offset = builder.create( + op.getLoc(), builder.getI64Type(), op.getOffsets()[i]); + Value newOffset = builder.create( + op.getLoc(), info.getOffset(i), i64Offset); + newOffsets.push_back(newOffset); + } + + // Save info for later use + info.setOffsets(newOffsets); + rewritedInfo[op.getResult()] = info; + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteLoadStoreOp(OpBuilder &builder, Operation *op, + std::stack &eraser) { + assert(isa(op) || isa(op)); + + // We only have to rewrite load/stores with tensor pointers + auto ptr = op->getOperand(0); + if (!triton::isTensorPointerType(ptr.getType())) + return nullptr; + + // Get info from previous results + assert(rewritedInfo.count(ptr)); + auto info = rewritedInfo[ptr]; + + // Load/store with tensor pointers implicitly will check the bound while + // accessing memory, so we should set `mask` and `other` (according to the + // padding). Also note that load with tensor pointers do not have `mask` and + // `other` while building IR from Python AST + std::optional> boundaryCheck; + if (auto loadOp = dyn_cast(op)) { + assert(!loadOp.getMask() && !loadOp.getOther()); + boundaryCheck = loadOp.getBoundaryCheck(); + } else if (auto storeOp = dyn_cast(op)) { + assert(!storeOp.getMask()); + boundaryCheck = storeOp.getBoundaryCheck(); + } + + // Generate new `ptr`, `mask` and `other` + auto newPtr = info.generatePtr(builder, op->getLoc()); + auto newMask = info.generateMask(builder, op->getLoc(), boundaryCheck); + Value newOther; + if (auto loadOp = dyn_cast(op)) + newOther = info.generateOther(builder, op->getLoc(), loadOp.getPadding()); + + // Create a new operation + if (auto loadOp = dyn_cast(op)) { + auto newResult = builder.create( + loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(), + loadOp.getEvict(), loadOp.getIsVolatile()); + op->getResult(0).replaceAllUsesWith(newResult); + } else if (auto storeOp = dyn_cast(op)) { + builder.create(storeOp.getLoc(), newPtr, + storeOp.getValue(), newMask, + storeOp.getCache(), storeOp.getEvict()); + } + + // Erase the original operation + eraser.push(op); + return nullptr; + } + + Operation *rewriteIfOp(OpBuilder &builder, scf::IfOp op, + std::stack &eraser) { + auto thenYieldOp = op.thenYield(); + assert(op.getNumResults() == thenYieldOp.getNumOperands()); + SmallVector results = thenYieldOp.getOperands(); + + // get new result types + SmallVector newRetTypes; + bool needRewrite = false; + for (unsigned i = 0; i < results.size(); ++i) { + if (!triton::isTensorPointerType(results[i].getType())) { + newRetTypes.push_back(results[i].getType()); + continue; + } + needRewrite = true; + auto makeTensorPtrOp = getMakeTensorPtrOp(results[i]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + auto info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + newRetTypes.push_back(builder.getI64Type()); + } + } + if (!needRewrite) + return op; + // create and clone new IfOp + bool hasElse = !op.getElseRegion().empty(); + scf::IfOp newOp = builder.create(op.getLoc(), newRetTypes, + op.getCondition(), hasElse); + IRMapping mapping; + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + mapping.map(op->getOperand(i), newOp->getOperand(i)); + } + auto rematerialize = [&](Block *block) { + for (Operation &opInIf : block->getOperations()) { + auto newOp = builder.clone(opInIf, mapping); + } + }; + builder.setInsertionPointToStart(newOp.thenBlock()); + rematerialize(op.thenBlock()); + if (hasElse) { + builder.setInsertionPointToStart(newOp.elseBlock()); + rematerialize(op.elseBlock()); + } + + // update rewritedInfo + unsigned oldResIdx = 0, newResIdx = 0; + while (oldResIdx < results.size()) { + if (!triton::isTensorPointerType(results[oldResIdx].getType())) { + oldResIdx++; + newResIdx++; + } else { + auto makeTensorPtrOp = getMakeTensorPtrOp(results[oldResIdx]); + assert(rewritedInfo.count(makeTensorPtrOp.getResult())); + auto info = rewritedInfo[makeTensorPtrOp.getResult()]; + for (unsigned j = 0; j < info.length(); ++j) { + info.setOffset(j, newOp->getResult(newResIdx++)); + } + rewritedInfo[op.getResult(oldResIdx)] = info; + oldResIdx++; + } + } + + eraser.push(op); + return newOp; + } + + Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op, + std::stack &eraser) { + // Generate new iteration operands and set rewrited information + SmallVector oldIterOperands = llvm::to_vector(op.getInitArgs()); + SmallVector newIterOperands = llvm::to_vector(op.getInitArgs()); + for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size; + ++i, ++oldI) { + if (!triton::isTensorPointerType(newIterOperands[i].getType())) + continue; + + // Expand the tensor pointer into offsets + assert(rewritedInfo.count(newIterOperands[i])); + auto info = rewritedInfo[newIterOperands[i]]; + newIterOperands = + generateNewOperands(newIterOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + + // Rebuild the loop type + auto newForOp = builder.create(op.getLoc(), op.getLowerBound(), + op.getUpperBound(), op.getStep(), + newIterOperands); + + // Create value mapping. Note that for tensor pointers, we use identity + // mapping. It may refer to a value in the old loop, but we will rewrite it + // later + IRMapping mapping; + for (unsigned i = 0, oldI = 0, sz = op.getInitArgs().size(); oldI < sz; + ++i, ++oldI) { + auto oldRegionIterArg = op.getRegionIterArg(oldI); + if (triton::isTensorPointerType(oldRegionIterArg.getType())) { + // Pass rewrited info inside + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + mapping.map(oldRegionIterArg, oldRegionIterArg); + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getRegionIterArg(i + j)); + rewritedInfo[oldRegionIterArg] = info; + i += info.length() - 1; + } else { + mapping.map(oldRegionIterArg, newForOp.getRegionIterArg(i)); + } + } + mapping.map(op.getInductionVar(), newForOp.getInductionVar()); + + // Clone body + builder.setInsertionPointToStart(newForOp.getBody()); + for (auto &opInFor : *op.getBody()) { + auto *newOp = builder.clone(opInFor, mapping); + for (unsigned i = 0; i < opInFor.getNumResults(); ++i) + mapping.map(op->getResult(i), newOp->getResult(i)); + } + + // Replace later usages + assert(op.getNumResults() == op.getInitArgs().size()); + for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { + auto oldResult = op.getResult(oldI); + if (triton::isTensorPointerType(oldResult.getType())) { + // Pack new offsets into rewrited info + assert(rewritedInfo.count(oldIterOperands[oldI])); + auto info = rewritedInfo[oldIterOperands[oldI]]; + for (unsigned j = 0; j < info.length(); ++j) + info.setOffset(j, newForOp.getResult(i + j)); + i += info.length() - 1; + rewritedInfo[oldResult] = info; + } else { + oldResult.replaceAllUsesWith(newForOp.getResult(i)); + } + } + + // Erase later + eraser.push(op); + return newForOp; + } + + Operation *rewriteYieldOp(OpBuilder &builder, scf::YieldOp op, + std::stack &eraser) { + // Replace tensor pointers with offsets + SmallVector newOperands = op->getOperands(); + for (unsigned i = 0, size = op.getNumOperands(); i < size; ++i) { + if (!triton::isTensorPointerType(newOperands[i].getType())) + continue; + + assert(rewritedInfo.count(newOperands[i])); + auto info = rewritedInfo[newOperands[i]]; + newOperands = generateNewOperands(newOperands, i, info.getOffsets()); + i += info.length() - 1; + size += info.length() - 1; + } + op->setOperands(newOperands); + + // No need to erase + return nullptr; + } + + Operation *rewriteOp(Operation *op, std::stack &eraser) { + OpBuilder builder(op); + + // Rewrite `make_tensor_ptr` and `advance` and make a tensor of pointers + // Rewriting functions return the next operation to visit, if there is no + // next one, simply return `nullptr` + std::pair rewrited; + if (auto makeTensorPtrOp = dyn_cast(op)) { + return rewriteMakeTensorPtrOp(builder, makeTensorPtrOp, eraser); + } else if (auto advanceOp = dyn_cast(op)) { + return rewriteAdvanceOp(builder, advanceOp, eraser); + } else if (isa(op) || isa(op)) { + return rewriteLoadStoreOp(builder, op, eraser); + } else if (op->getDialect()->getNamespace() == "scf" || + op->getDialect()->getNamespace() == "cf") { + if (auto ifOp = dyn_cast(op)) { + return rewriteIfOp(builder, ifOp, eraser); + } + if (!needRewrite(op)) + return op; + + if (auto forOp = dyn_cast(op)) { + return rewriteForOp(builder, forOp, eraser); + } else if (auto yieldOp = dyn_cast(op)) { + return rewriteYieldOp(builder, yieldOp, eraser); + } else { + llvm_unreachable("Currently we only support tensor pointer usages " + "inside a `scf::ForOp` or `scf::IfOp`, others such as " + "`scf::WhileOp`, `cf::BranchOp` or `cf::CondBranchOp` " + "are not supported yet"); + } + } + + // Otherwise return the original one + return op; + } + + void visitOperation(Operation *op, std::stack &eraser) { + for (auto ®ion : op->getRegions()) { + for (auto &block : region) { + // We need an extra copy because erasing operations may break the + // iterator behavior + SmallVector blockCopy; + for (auto &nestedOp : block) + blockCopy.push_back(&nestedOp); + + // Rewrite and recursively visit + for (auto &nestedOp : blockCopy) { + if (auto newOp = rewriteOp(nestedOp, eraser)) + visitOperation(newOp, eraser); + } + } + } + } + + void runOnOperation() override { + // NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because + // MLIR does not support one-multiple value mapping. For example, if we use + // `ConversionPatternRewriter`, we can not make a type converter, which + // converts `ptr` into multiple types `ptr<>, int64, int64, ...` + // (containing the base/offsets/strides...). What we can do is to convert + // `ptr` into a single type `Tuple, int64, int64, ...>`. But + // in this way, we also have to define `PackTuple` and `UnpackTuple` + // operations and make a canonicalization pass to optimize, which is much + // So here we recursively build the IR, to be specific, we have to rewrite + // `tt.make_tensor_ptr`, `tt.advance`, `tt.load`, `tt.store`, + // `scf.for` (tensor pointer usages may be in a loop fashion) + std::stack eraser; + visitOperation(getOperation(), eraser); + + // The operation could not be erased during visit, because they may have + // later usages, so we erase after visit + rewritedInfo.clear(); + while (!eraser.empty()) { + auto op = eraser.top(); + eraser.pop(); + op->erase(); + } + } +}; + +std::unique_ptr triton::createRewriteTensorPointerPass() { + return std::make_unique(); +} diff --git a/third_party/xpu/lib/Dialect/TritonGPU/CMakeLists.txt b/third_party/xpu/lib/Dialect/TritonGPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/xpu/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/third_party/xpu/lib/Dialect/TritonGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..baeb92b48 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(TritonGPUIR + Dialect.cpp + LinearLayoutConversions.cpp + Types.cpp + + DEPENDS + TritonGPUTableGen + TritonGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRGPUDialect + TritonIR + TritonXPUIR + TritonTools +) diff --git a/third_party/xpu/lib/Dialect/TritonGPU/IR/Dialect.cpp b/third_party/xpu/lib/Dialect/TritonGPU/IR/Dialect.cpp new file mode 100644 index 000000000..ac1732b19 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -0,0 +1,2999 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/StrUtil.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/TypeSwitch.h" + +// Include TableGen'erated code +#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" + +#include "triton/Dialect/TritonXPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +// Utility +namespace mlir { +namespace triton { + +static Type getI1SameShapeFromTensorOrTensorPtr(Type type) { + auto i1Type = IntegerType::get(type.getContext(), 1); + if (auto tensorType = dyn_cast(type)) { + return RankedTensorType::get(tensorType.getShape(), i1Type, + tensorType.getEncoding()); + } else if (auto ptrType = dyn_cast(type)) { + Type pointeeType = ptrType.getPointeeType(); + if (auto tensorType = dyn_cast(pointeeType)) { + return RankedTensorType::get(tensorType.getShape(), i1Type, + tensorType.getEncoding()); + } + } + return Type(); +} + +namespace gpu { + +// TODO: Inheritance of layout attributes +// so that all distributed layouts implement +// these utilities + +unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape, + Type eltTy) { + if (auto tritonXPUAttr = + mlir::dyn_cast(layout)) { + return tritonXPUAttr.getTotalElemsPerThread(shape, eltTy); + } else if (auto tritonGPUAttr = mlir::dyn_cast(layout)) { + return tritonGPUAttr.getTotalElemsPerThread(shape, eltTy); + } else { + llvm::report_fatal_error("getTotalElemsPerThread not implemented"); + return 0; + } +} + +SmallVector getElemsPerThread(Attribute layout, + ArrayRef shape, Type eltTy) { + if (auto tritonXPUAttr = + mlir::dyn_cast(layout)) { + return tritonXPUAttr.getElemsPerThread(shape, eltTy); + } else if (auto tritonGPUAttr = mlir::dyn_cast(layout)) { + return tritonGPUAttr.getElemsPerThread(shape, eltTy); + } else { + llvm::report_fatal_error("getElemsPerThread not implemented"); + return SmallVector(); + } +} + +SmallVector getElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || isa(type)) + return SmallVector(1, 1); + auto tensorType = cast(type); + return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape(), + tensorType.getElementType()); +} + +unsigned getTotalElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || isa(type)) + return 1; + auto tensorType = cast(type); + return getTotalElemsPerThread(tensorType.getEncoding(), tensorType.getShape(), + tensorType.getElementType()); +} + +SmallVector getThreadsPerWarp(Attribute layout) { + if (auto distributedLayout = dyn_cast(layout)) { + return distributedLayout.getThreadsPerWarp(); + } else { + llvm::report_fatal_error("getThreadsPerWarp not implemented"); + return SmallVector(); + } +} + +unsigned getWarpSize(Attribute layout) { + unsigned size = 1; + auto threadsPerWarp = getThreadsPerWarp(layout); + for (auto e : threadsPerWarp) { + size *= e; + } + return size; +} + +SmallVector +getThreadsPerWarpWithUniqueData(Attribute layout, + ArrayRef tensorShape) { + if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(tensorShape); + auto parentThreadsPerWarp = + getThreadsPerWarpWithUniqueData(parentLayout, parentShape); + SmallVector threadsPerWarp = parentThreadsPerWarp; + threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim()); + return threadsPerWarp; + } + auto threadsPerWarp = getThreadsPerWarp(layout); + assert(threadsPerWarp.size() == tensorShape.size() && + "layout and tensor shape must have the same rank"); + for (unsigned i = 0; i < threadsPerWarp.size(); i++) { + threadsPerWarp[i] = std::min(threadsPerWarp[i], tensorShape[i]); + } + + return threadsPerWarp; +} + +SmallVector getWarpsPerCTA(Attribute layout) { + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return distributedLayout.getWarpsPerCTA(); + } + + llvm::report_fatal_error("getWarpsPerCTA not implemented"); + return SmallVector(); +} + +SmallVector +getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape) { + if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(tensorShape); + auto parentWarpsPerCTA = + getWarpsPerCTAWithUniqueData(parentLayout, parentShape); + SmallVector warpsPerCTA = parentWarpsPerCTA; + warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim()); + return warpsPerCTA; + } + auto warpsPerCTA = getWarpsPerCTA(layout); + assert(warpsPerCTA.size() == tensorShape.size() && + "layout and tensor shape must have the same rank"); + for (unsigned i = 0; i < warpsPerCTA.size(); i++) { + auto sizePerWarp = + getSizePerThread(layout)[i] * getThreadsPerWarp(layout)[i]; + auto maxWarpsPerDim = ceil(tensorShape[i], sizePerWarp); + warpsPerCTA[i] = std::min(warpsPerCTA[i], maxWarpsPerDim); + } + + return warpsPerCTA; +} + +SmallVector getSizePerThread(Attribute layout) { + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return distributedLayout.getSizePerThread(); + } else { + llvm::report_fatal_error("getSizePerThread not implemented"); + return {}; + } +} + +SmallVector getContigPerThread(Attribute layout) { + if (auto distributedLayout = dyn_cast(layout)) { + return distributedLayout.getContigPerThread(); + } else { + llvm::report_fatal_error("getContigPerThread not implemented"); + return {}; + } +} + +SmallVector getUniqueContigPerThread(Attribute layout, + ArrayRef shape) { + // If slice layout, call recursively on parent layout, and drop + // sliced dim + if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(shape); + auto parentUniqueContigPerThread = + getUniqueContigPerThread(parentLayout, parentShape); + parentUniqueContigPerThread.erase(parentUniqueContigPerThread.begin() + + sliceLayout.getDim()); + return parentUniqueContigPerThread; + } + // Base case + auto rank = shape.size(); + SmallVector ret(rank); + auto contigPerThread = getContigPerThread(layout); + assert(contigPerThread.size() == rank && "Unexpected contigPerThread size"); + for (int d = 0; d < rank; ++d) { + ret[d] = std::min(shape[d], contigPerThread[d]); + } + return ret; +} + +SmallVector getShapePerCTATile(Attribute layout, + ArrayRef tensorShape) { + if (auto clusterLayout = + mlir::dyn_cast(layout)) { + SmallVector shape; + for (unsigned d = 0, n = clusterLayout.getOrder().size(); d < n; ++d) + shape.push_back(clusterLayout.getSizePerCore()[d] * + clusterLayout.getCoresPerGroup()[d] * + clusterLayout.getGroupsPerCluster()[d]); + return shape; + } else if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return distributedLayout.getShapePerCTATile(tensorShape); + } else { + llvm::report_fatal_error("getShapePerCTATile not implemented"); + return SmallVector(); + } +} + +bool isExpensiveView(Type srcType, Type dstType) { + return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType); +} + +/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr. + * Erase dim and decrease all values larger than dim by 1. + * Example: order = [0, 2, 4, 3, 1], dim = 2 + * resOrder = [0, 3, 2, 1] + */ +static SmallVector eraseOrder(ArrayRef order, + unsigned dim) { + unsigned rank = order.size(); + assert(dim < rank && "Invalid dim to erase"); + SmallVector resOrder; + for (unsigned i : order) + if (i < dim) + resOrder.push_back(i); + else if (i > dim) + resOrder.push_back(i - 1); + return resOrder; +} + +SmallVector getWarpOrder(Attribute layout) { + auto order = getOrder(layout); + if (auto mmaLayout = dyn_cast(layout)) { + if (mmaLayout.isHopper()) { + // Hopper MMA instructions force a warp order of [0, 1]. See docs: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8 + auto it = std::find(order.begin(), order.end(), 0); + order.erase(it); + order.insert(order.begin(), 0); + } + } + return order; +} + +SmallVector getOrder(Attribute layout) { + if (auto clusterLayout = dyn_cast(layout)) { + return SmallVector(clusterLayout.getOrder().begin(), + clusterLayout.getOrder().end()); + } else if (auto blockedLayout = dyn_cast(layout)) { + return SmallVector(blockedLayout.getOrder().begin(), + blockedLayout.getOrder().end()); + } else if (auto mmaLayout = dyn_cast(layout)) { + auto distributedLayout = cast(layout); + auto rank = distributedLayout.getWarpsPerCTA().size(); + SmallVector order(rank); + for (auto i = 0; i < rank; ++i) + order[i] = rank - 1 - i; + if (auto mfmaLayout = dyn_cast(layout)) { + if (mfmaLayout.getIsTransposed()) { + std::swap(order[rank - 2], order[rank - 1]); + } + } + return order; + } else if (auto dotLayout = dyn_cast(layout)) { + auto rank = getWarpsPerCTA(dotLayout.getParent()).size(); + SmallVector order(rank); + for (auto i = 0; i < rank; ++i) + order[i] = rank - 1 - i; + return order; + } else if (auto sliceLayout = dyn_cast(layout)) { + SmallVector parentOrder = getOrder(sliceLayout.getParent()); + unsigned dim = sliceLayout.getDim(); + SmallVector order; + for (unsigned d : parentOrder) { + if (d == dim) + continue; + else if (d > dim) + order.push_back(d - 1); + else + order.push_back(d); + } + return order; + } else if (auto sharedLayout = mlir::dyn_cast(layout)) { + return SmallVector(sharedLayout.getOrder().begin(), + sharedLayout.getOrder().end()); + } else { + llvm::report_fatal_error("Unimplemented usage of getOrder"); + } + return {}; +}; + +CTALayoutAttr getCTALayout(Attribute layout) { + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return CTALayoutAttr::get( + layout.getContext(), getCTAsPerCGA(distributedLayout), + getCTASplitNum(distributedLayout), getCTAOrder(distributedLayout)); + } else if (auto sharedLayout = mlir::dyn_cast(layout)) + return sharedLayout.getCTALayout(); + else + llvm::report_fatal_error("Unimplemented usage of getCTALayout"); + return {}; +} + +SmallVector getCTAsPerCGA(Attribute layout) { + ArrayRef ref; + if (auto distributedLayout = mlir::dyn_cast(layout)) + return distributedLayout.getCTAsPerCGA(); + else if (mlir::isa(layout)) + return {1, 1}; + else if (auto sharedLayout = mlir::dyn_cast(layout)) + ref = sharedLayout.getCTALayout().getCTAsPerCGA(); + else + llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA"); + return SmallVector(ref.begin(), ref.end()); +} + +SmallVector getCTASplitNum(Attribute layout) { + SmallVector res; + if (auto clusterLayout = + mlir::dyn_cast(layout)) { + return SmallVector(clusterLayout.getSizePerCore().size(), 1); + } else if (auto distributedLayout = + mlir::dyn_cast(layout)) { + return distributedLayout.getCTASplitNum(); + } else if (mlir::isa(layout)) { + res.resize(2); + res[0] = res[1] = 1; + } else if (auto sharedLayout = mlir::dyn_cast(layout)) { + res.assign(sharedLayout.getCTALayout().getCTASplitNum().begin(), + sharedLayout.getCTALayout().getCTASplitNum().end()); + } else { + assert(false && "Unimplemented usage of getCTASplitNum"); + } + return res; +} + +SmallVector getCTAOrder(Attribute layout) { + SmallVector res; + if (auto distributedLayout = + mlir::dyn_cast(layout)) { + res = distributedLayout.getCTAOrder(); + } else if (mlir::isa(layout)) { + return {0, 1}; + } else if (auto sharedLayout = mlir::dyn_cast(layout)) { + res = SmallVector(sharedLayout.getCTALayout().getCTAOrder()); + } else { + llvm::report_fatal_error("Unimplemented usage of getCTAOrder"); + } + return res; +} + +SmallVector getShapePerCTA(ArrayRef CTASplitNum, + ArrayRef shape) { + unsigned rank = shape.size(); + SmallVector shapePerCTA(rank); + for (unsigned i = 0; i < rank; ++i) { + // This wrapping rule must be consistent with emitCTAOffsetForLayout + unsigned splitNum = std::min(shape[i], CTASplitNum[i]); + shapePerCTA[i] = shape[i] / splitNum; + } + return shapePerCTA; +} + +SmallVector getShapePerCTA(Attribute layout, ArrayRef shape) { + if (auto sharedLayout = mlir::dyn_cast(layout)) { + // Special logic for pipeline pass, where shape is 3D and CTALayout is 2D. + // The first dim of shape is numStages. This is a work around, otherwise too + // many places would have to be modified in pipeline pass. Maybe we need to + // refactor this logic in the future. + auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum(); + if (shape.size() == CTASplitNum.size() + 1) { + auto res = getShapePerCTA(CTASplitNum, shape.drop_front()); + res.insert(res.begin(), shape.front()); + return res; + } + } + return getShapePerCTA(getCTASplitNum(layout), shape); +} + +SmallVector getShapePerCTA(Type type) { + auto tensorType = cast(type); + return getShapePerCTA(tensorType.getEncoding(), tensorType.getShape()); +} + +unsigned getNumWarpsPerCTA(Attribute layout) { + SmallVector warpsPerCTA; + if (auto blockedLayout = dyn_cast(layout)) + warpsPerCTA = blockedLayout.getWarpsPerCTA(); + else if (auto sliceLayout = dyn_cast(layout)) + return getNumWarpsPerCTA(sliceLayout.getParent()); + else if (auto mmaLayout = dyn_cast(layout)) { + // Use the distributed layout interface to get the number of warps per CTA. + auto distributedLayout = cast(layout); + warpsPerCTA = distributedLayout.getWarpsPerCTA(); + } else if (auto mfmaLayout = dyn_cast(layout)) + warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + else if (auto wmmaLayout = dyn_cast(layout)) + warpsPerCTA = wmmaLayout.getWarpsPerCTA(); + else if (auto dotLayout = dyn_cast(layout)) + return getNumWarpsPerCTA(dotLayout.getParent()); + else if (auto sharedLayout = dyn_cast(layout)) + llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr"); + else + llvm::report_fatal_error("Unimplemented usage of getNumWarpsPerCTA"); + return product(warpsPerCTA); +} + +unsigned getNumCTAs(Attribute layout) { + return product(getCTAsPerCGA(layout)); +} + +bool isaDistributedLayout(Attribute layout) { + return isa(layout); +} + +template bool hasEncoding(Value value) { + auto type = value.getType(); + if (auto tensorType = dyn_cast(type)) { + auto encoding = tensorType.getEncoding(); + return encoding && isa(encoding); + } + return false; +} + +bool hasDotOperandEncoding(Value value) { + return hasEncoding(value); +} + +bool isExpensiveCat(CatOp cat, Attribute targetEncoding) { + // If the new elements per thread is less than the old one, we will need to do + // convert encoding that goes through shared memory anyway. So we consider it + // as expensive. + RankedTensorType tensorTy = cat.getType(); + auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy); + auto shape = tensorTy.getShape(); + auto elemTy = tensorTy.getElementType(); + auto newTotalElemsPerThread = + gpu::getTotalElemsPerThread(targetEncoding, shape, elemTy); + return newTotalElemsPerThread < totalElemsPerThread; +} + +LogicalResult CTALayoutAttr::verify( + function_ref emitError, ArrayRef CTAsPerCGA, + ArrayRef CTASplitNum, ArrayRef CTAOrder) { + if (CTAsPerCGA.size() != CTASplitNum.size() || + CTASplitNum.size() != CTAOrder.size()) { + return emitError() << "CTAsPerCGA, CTASplitNum, and CTAOrder must all have " + "the same rank."; + } + + if (!isPermutationOfIota(CTAOrder)) { + return emitError() + << "CTAOrder must be a permutation of 0..(rank-1), but was [" + << CTAOrder << "]"; + } + + return success(); +} + +LogicalResult +BlockedEncodingAttr::verify(function_ref emitError, + ArrayRef sizePerThread, + ArrayRef threadsPerWarp, + ArrayRef warpsPerCTA, + ArrayRef order, CTALayoutAttr CTALayout) { + if (sizePerThread.size() != threadsPerWarp.size() || + threadsPerWarp.size() != warpsPerCTA.size() || + warpsPerCTA.size() != order.size()) { + return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and " + "order must all have the same rank."; + } + + // Empty CTALayout is allowed, but if it's present its rank must match the + // BlockedEncodingAttr's rank. + if (CTALayout.getCTASplitNum().size() != 0 && + sizePerThread.size() != CTALayout.getCTASplitNum().size()) { + return emitError() << "BlockedEncodingAttr and CTALayout's fields must " + "have the same rank."; + } + if (!isPermutationOfIota(order)) { + return emitError() + << "order must be a permutation of 0..(rank-1), but was [" << order + << "]"; + } + return success(); +} + +// 1 element per thread +// order = reverse(arange(rank)) +triton::gpu::BlockedEncodingAttr +getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, + int numWarps, int threadsPerWarp, int numCTAs) { + int rank = shape.size(); + llvm::SmallVector order(rank); + std::iota(order.begin(), order.end(), 0); + std::reverse(order.begin(), order.end()); + llvm::SmallVector sizePerThread(rank, 1); + triton::gpu::BlockedEncodingAttr encoding = + triton::gpu::BlockedEncodingAttr::get(context, shape, sizePerThread, + order, numWarps, threadsPerWarp, + numCTAs); + return encoding; +} + +} // namespace gpu +} // namespace triton +} // namespace mlir + +static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr, + unsigned &value, StringRef desc) { + auto intAttr = mlir::dyn_cast(attr); + if (!intAttr) { + parser.emitError(parser.getNameLoc(), "expected an integer type in ") + << desc; + return failure(); + } + if (intAttr.getType().isSignedInteger()) { + int64_t attrVal = intAttr.getSInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else if (intAttr.getType().isSignlessInteger()) { + int64_t attrVal = intAttr.getInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else { + value = intAttr.getUInt(); + } + return success(); +} + +static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr, + bool &value, StringRef desc) { + auto boolAttr = mlir::dyn_cast(attr); + if (!boolAttr) { + parser.emitError(parser.getNameLoc(), "expected an bool type in ") << desc; + return failure(); + } + value = boolAttr.getValue(); + return success(); +} + +// parse an array of integers +static LogicalResult parseIntArrayAttr(AsmParser &parser, + const NamedAttribute &attr, + SmallVector &res, + StringRef desc) { + auto arrayAttr = mlir::dyn_cast(attr.getValue()); + if (!arrayAttr) { + parser.emitError(parser.getNameLoc(), "expected an array for ") << desc; + return failure(); + } + for (Attribute i : arrayAttr) { + unsigned value; + if (parseIntAttrValue(parser, i, value, desc).failed()) + return failure(); + res.push_back(value); + } + return success(); +}; + +static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr, + unsigned &value, StringRef desc) { + return parseIntAttrValue(parser, attr.getValue(), value, desc); +}; + +static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr, + bool &value, StringRef desc) { + return parseBoolAttrValue(parser, attr.getValue(), value, desc); +}; + +// Print the CTALayout if it's not equal to the default. +static void maybePrintCTALayout(mlir::MLIRContext *context, + mlir::AsmPrinter &printer, CTALayoutAttr layout, + unsigned rank) { + if (layout != CTALayoutAttr::getDefault(context, rank)) { + printer << ", CTAsPerCGA = [" << ArrayRef(layout.getCTAsPerCGA()) << "]" + << ", CTASplitNum = [" << ArrayRef(layout.getCTASplitNum()) << "]" + << ", CTAOrder = [" << ArrayRef(layout.getCTAOrder()) << "]"; + } +} + +//===----------------------------------------------------------------------===// +// Attribute methods +//===----------------------------------------------------------------------===// +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" + +SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) { + return SliceEncodingAttr::get(getContext(), axis, *this); +} +SmallVector +BlockedEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + size_t rank = shape.size(); + auto sizePerThread = getSizePerThread(); + auto warpsPerCTA = getWarpsPerCTA(); + auto threadsPerWarp = getThreadsPerWarp(); + auto shapePerCTA = getShapePerCTA(*this, shape); + assert(rank == sizePerThread.size() && + "unexpected rank in BlockedEncodingAttr::getElemsPerThread"); + SmallVector elemsPerThread(rank); + for (size_t i = 0; i < rank; ++i) { + unsigned t = sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i]; + elemsPerThread[i] = ceil(shapePerCTA[i], t) * sizePerThread[i]; + } + return elemsPerThread; +} +unsigned BlockedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + +// If we only had BlockedEncodingAttr, we could simply return ArrayRefs here. +// But we need to have a consistent interface with e.g. SliceEncodingAttr, which +// computes some of these fields. +SmallVector BlockedEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector BlockedEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector BlockedEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector BlockedEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector BlockedEncodingAttr::getWarpOrder() const { + return SmallVector(getOrder()); +} +SmallVector BlockedEncodingAttr::getThreadsPerWarp() const { + return SmallVector(getThreadsPerWarp__()); +} +SmallVector BlockedEncodingAttr::getThreadOrder() const { + return SmallVector(getOrder()); +} +SmallVector BlockedEncodingAttr::getSizePerThread() const { + return SmallVector(getSizePerThread__()); +} +SmallVector +BlockedEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + SmallVector shape; + for (unsigned d = 0, n = getOrder().size(); d < n; ++d) + shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] * + getWarpsPerCTA()[d]); + return shape; +} + +template +SmallVector SliceEncodingAttr::paddedShape(ArrayRef shape) const { + size_t rank = shape.size(); + unsigned dim = getDim(); + SmallVector retShape(rank + 1); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d < dim) + retShape[d] = shape[d]; + else if (d == dim) + retShape[d] = 1; + else + retShape[d] = shape[d - 1]; + } + return retShape; +} +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; + +SmallVector +SliceEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + auto parent = getParent(); + auto parentElemsPerThread = + ::getElemsPerThread(parent, paddedShape(shape), eltTy); + parentElemsPerThread.erase(parentElemsPerThread.begin() + getDim()); + return parentElemsPerThread; +} +unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} +SmallVector SliceEncodingAttr::getCTASplitNum() const { + SmallVector res = ::getCTASplitNum(getParent()); + res.erase(res.begin() + getDim()); + return res; +} +SmallVector SliceEncodingAttr::getCTAOrder() const { + auto parentCTAOrder = ::getCTAOrder(getParent()); + return eraseOrder(parentCTAOrder, getDim()); +} +SmallVector SliceEncodingAttr::getCTAsPerCGA() const { + auto parentCTAsPerCGA = ::getCTAsPerCGA(getParent()); + if (parentCTAsPerCGA[getDim()] == 1) { + parentCTAsPerCGA.erase(parentCTAsPerCGA.begin() + getDim()); + return parentCTAsPerCGA; + } + /* For getCTAsPerCGA of a slice layout, we have two choices: + * (1) Return CTAsPerCGA of its parent. This is not a perfect solution + * because the rank of the returned CTAsPerCGA does not match the rank of + * tensorShape. + * (2) Get CTAsPerCGA of its parent and erase the sliced dim. This is not a + * perfect solution because the product of the returned CTAsPerCGA might not + * match numCTAs. + * To avoid introducing inconsistencies to the shape and + * layout system, the usage of directly getting CTAsPerCGA of a slice layout + * in which the sliced dim is not 1 is banned. You should always consider + * slice layout as a special case and use getCTAsPerCGA(layout.getParent()) + * in the branch where layout is an instance of SliceEncodingAttr. This is + * inconvenient but safe. + */ + llvm::report_fatal_error( + "getCTAsPerCGA for SliceEncodingAttr is not well-defined"); +} +SmallVector SliceEncodingAttr::getWarpsPerCTA() const { + auto parent = getParent(); + auto parentWarpsPerCTA = ::getWarpsPerCTA(parent); + SmallVector warpsPerCTA = parentWarpsPerCTA; + warpsPerCTA.erase(warpsPerCTA.begin() + getDim()); + int32_t nextDim = getDim() < warpsPerCTA.size() ? getDim() : getDim() - 1; + warpsPerCTA[nextDim] *= parentWarpsPerCTA[getDim()]; + return warpsPerCTA; +} +SmallVector SliceEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector SliceEncodingAttr::getThreadsPerWarp() const { + auto parent = getParent(); + auto parentThreadsPerWarp = ::getThreadsPerWarp(parent); + SmallVector threadsPerWarp = parentThreadsPerWarp; + threadsPerWarp.erase(threadsPerWarp.begin() + getDim()); + int32_t nextDim = getDim() < threadsPerWarp.size() ? getDim() : getDim() - 1; + threadsPerWarp[nextDim] *= parentThreadsPerWarp[getDim()]; + return threadsPerWarp; +} +SmallVector SliceEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector SliceEncodingAttr::getSizePerThread() const { + auto sizePerThread = ::getSizePerThread(getParent()); + sizePerThread.erase(sizePerThread.begin() + getDim()); + return sizePerThread; +} +SmallVector +SliceEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + SmallVector shape = ::getShapePerCTATile(getParent(), tensorShape); + shape.erase(shape.begin() + getDim()); + return shape; +} + +// + +SmallVector +AMDMfmaEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + size_t rank = shape.size(); + assert((rank == 2 || rank == 3) && "Unexpected rank of mfma layout"); + + SmallVector elemsPerThread(rank); + auto nonKDim = getMDim(); + auto elemsPerThreadPerTile = (nonKDim == 16 ? 4 : 16); + if (rank == 3) + elemsPerThread[0] = ceil(shape[0], getWarpsPerCTA()[0]); + if (getIsTransposed()) { + unsigned elemsCol = + ceil(shape[rank - 1], nonKDim * getWarpsPerCTA()[rank - 1]) * + elemsPerThreadPerTile; + unsigned elemsRow = + ceil(shape[rank - 2], nonKDim * getWarpsPerCTA()[rank - 2]); + elemsPerThread[rank - 2] = elemsRow; + elemsPerThread[rank - 1] = elemsCol; + } else { + unsigned elemsCol = + ceil(shape[rank - 1], nonKDim * getWarpsPerCTA()[rank - 1]); + unsigned elemsRow = + ceil(shape[rank - 2], nonKDim * getWarpsPerCTA()[rank - 2]) * + elemsPerThreadPerTile; + elemsPerThread[rank - 2] = elemsRow; + elemsPerThread[rank - 1] = elemsCol; + } + return elemsPerThread; +} + +unsigned AMDMfmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + +// + +SmallVector +AMDWmmaEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + size_t rank = shape.size(); + assert((rank == 2 || rank == 3) && "Unexpected rank of wmma layout"); + + SmallVector elemsPerThread(rank); + auto mnkDim = getMNKDimPerWMMAInstr(); + auto elemsPerThreadPerTile = getSizePerThread(); + auto warpsPerCTA = getWarpsPerCTA(); + + if (rank == 3) + elemsPerThread[0] = ceil(shape[0], getWarpsPerCTA()[0]); + elemsPerThread[rank - 2] = + ceil(shape[rank - 2], mnkDim[0] * warpsPerCTA[rank - 2]) * + elemsPerThreadPerTile[rank - 2]; + elemsPerThread[rank - 1] = + ceil(shape[rank - 1], mnkDim[1] * warpsPerCTA[rank - 1]) * + elemsPerThreadPerTile[rank - 1]; + return elemsPerThread; +} + +unsigned AMDWmmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + +// + +SmallVector +NvidiaMmaEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + size_t rank = shape.size(); + assert(rank == 2 || + (rank == 3 && isAmpere()) && "Unexpected rank of mma layout"); + assert((isVolta() || isAmpere() || isHopper()) && + "For NvidiaMmaEncodingAttr only version 1~3 is supported"); + + auto shapePerCTA = getShapePerCTA(getCTALayout().getCTASplitNum(), shape); + + SmallVector elemsPerThread(rank); + if (isVolta()) { + auto [isARow, isBRow, isAVec4, isBVec4, id] = decodeVoltaLayoutStates(); + static constexpr std::array fpw{{2, 2}}; + unsigned packSize0 = (isARow || isAVec4) ? 1 : 2; + unsigned packSize1 = (isBRow && !isBVec4) ? 2 : 1; + unsigned repM = 2 * packSize0; + unsigned repN = 2 * packSize1; + unsigned spwM = fpw[0] * 4 * repM; + unsigned spwN = fpw[1] * 4 * repN; + unsigned wptM = getWarpsPerCTA()[0]; + unsigned wptN = getWarpsPerCTA()[1]; + unsigned resM = repM * std::max(1, shapePerCTA[0] / (spwM * wptM)); + unsigned resN = 2 * repN * std::max(1, shapePerCTA[1] / (spwN * wptN)); + elemsPerThread[0] = resM; + elemsPerThread[1] = resN; + } else if (isAmpere()) { + unsigned elemsRow = + ceil(shapePerCTA[rank - 2], 16 * getWarpsPerCTA()[rank - 2]) * + 2; + unsigned elemsCol = + ceil(shapePerCTA[rank - 1], 8 * getWarpsPerCTA()[rank - 1]) * + 2; + if (rank == 3) + elemsPerThread[0] = ceil(shapePerCTA[0], getWarpsPerCTA()[0]); + elemsPerThread[rank - 2] = elemsRow; + elemsPerThread[rank - 1] = elemsCol; + } else if (isHopper()) { + auto wpt = getWarpsPerCTA(); + auto instrMNK = getInstrShape(); + int repM = ceil(shapePerCTA[0], instrMNK[0] * wpt[0]); + int repN = ceil(shapePerCTA[1], instrMNK[1] * wpt[1]); + elemsPerThread[0] = 2 * repM; + elemsPerThread[1] = (instrMNK[1] / 4) * repN; + } else { + llvm_unreachable("Unexpected mma version"); + } + + return elemsPerThread; +} + +unsigned NvidiaMmaEncodingAttr::getElemsPerThreadOfOperand( + int opIdx, ArrayRef shape) const { + size_t rank = shape.size(); + assert(rank == 2 && "Unexpected rank of mma layout"); + auto shapePerCTA = getShapePerCTA(*this, shape); + int res = 0; + if (isVolta()) { + llvm_unreachable( + "getElemsPerThreadOfOperand() not supported for version 1"); + } else if (isAmpere()) { + llvm_unreachable( + "getElemsPerThreadOfOperand() not supported for version 2"); + } else if (isHopper()) { + auto wpt = getWarpsPerCTA(); + auto instrMNK = getInstrShape(); + if (opIdx == 0) { + int repM = ceil(shapePerCTA[0], instrMNK[0] * wpt[0]); + int repK = ceil(shapePerCTA[1], instrMNK[2]); + return 8 * repM * repK; + + } else if (opIdx == 1) { + int repK = ceil(shapePerCTA[0], instrMNK[2]); + int repN = ceil(shapePerCTA[1], instrMNK[1] * wpt[1]); + // benzh@ here need more check + return 4 * std::max(instrMNK[1] / 32, 1) * repK * repN; + } + } + return res; +} + +unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + +// + +SmallVector +SharedEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + llvm_unreachable("getElemsPerThread is not supported for shared layout"); + return SmallVector(); +} +unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + llvm_unreachable("getElemsPerThread is not supported for shared layout"); + return 0; +} + +SmallVector +DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + llvm_unreachable("getElemsPerThread is not supported for dot operand"); + return SmallVector(); +} + +unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + if (auto mmaParent = mlir::dyn_cast(getParent())) { + return mmaParent.getTotalElemsPerThreadForOperands(shape, eltTy, + getKWidth(), getOpIdx()); + } + if (auto blockedLayout = mlir::dyn_cast(getParent())) { + auto shapePerCTA = getShapePerCTA(*this, shape); + auto shapePerCTATile = ::getShapePerCTATile(blockedLayout); + auto order = blockedLayout.getOrder(); + auto sizePerThread = ::getSizePerThread(blockedLayout); + + int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0]; + int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0]; + + bool isM = getOpIdx() == 0; + + int mSizePerThread = + order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int nSizePerThread = + order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; + int sizePerThreadMN = isM ? mSizePerThread : nSizePerThread; + + int mShapePerCTATile = + order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int nShapePerCTATile = + order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; + int shapePerCTAMNTile = isM ? mShapePerCTATile : nShapePerCTATile; + + return K * std::max(otherDim / shapePerCTAMNTile, 1) * sizePerThreadMN; + } + llvm_unreachable("unknown dot operand parent layout"); + return 0; +} +SmallVector DotOperandEncodingAttr::getCTAsPerCGA() const { + return ::getCTAsPerCGA(getParent()); +} +SmallVector DotOperandEncodingAttr::getCTAOrder() const { + return ::getCTAOrder(getParent()); +} +SmallVector DotOperandEncodingAttr::getCTASplitNum() const { + SmallVector res = ::getCTASplitNum(getParent()); + auto rank = res.size(); + assert(rank == 2 || rank == 3 && "Invalid dotLayout"); + + // Do not split CTA in K dimension + getOpIdx() == 0 ? res[rank - 1] = 1 : res[rank - 2] = 1; + return res; +} +SmallVector DotOperandEncodingAttr::getWarpsPerCTA() const { + auto parentLayout = getParent(); + assert(parentLayout && "DotOperandEncodingAttr must have a parent"); + if (auto distributedLayout = + mlir::dyn_cast(parentLayout)) { + return distributedLayout.getWarpsPerCTA(); + } else { + llvm::report_fatal_error( + "DotOperandEncodingAttr non-DistributedEncodingAttr parent not " + "supported yet"); + } +} +SmallVector DotOperandEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector DotOperandEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector DotOperandEncodingAttr::getShapePerCTATile( + ArrayRef tensorShape) const { + auto parentLayout = getParent(); + assert(parentLayout && "DotOperandEncodingAttr must have a parent"); + if (auto parentMmaLayout = mlir::dyn_cast(parentLayout)) { + return parentMmaLayout.getShapePerCTATileForDotOperands(tensorShape, + getOpIdx()); + } else { + llvm::report_fatal_error( + "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not " + "supported yet"); + } +} + +LogicalResult DotOperandEncodingAttr::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + unsigned opIdx, Attribute parent, unsigned kWidth) { + if (opIdx != 0 && opIdx != 1) { + return emitError() + << "triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: " + << opIdx; + } + if (!parent) { + return emitError() << "triton_gpu.dot_op parent paramenter cannot be null"; + } + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0 && !parentAttr.isAmpere()) + return emitError() << "triton_gpu.dot_op kWidth parameter can only be " + "non-zero for Ampere MMA parent"; + if (kWidth == 0 && parentAttr.isAmpere()) + return emitError() + << "triton_gpu.dot_op kWidth parameter is mandatory for " + "Ampere MMA parent"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + // TODO: remove this condition if new values are supported + if (kWidth != 16) + return emitError() << "triton_gpu.dot_op kWidth parameter supports " + "only 16 for WMMA parent"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth == 0) + return emitError() + << "triton_gpu.dot_op kWidth parameter is mandatory for " + "MFMA parent"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0) + return emitError() + << "triton_gpu.dot_op kWidth parameter is not supported " + "when the parent is a blocked layout"; + return success(); + } + + return emitError() << "triton_gpu.dot_op unexpected parent layout: " + << parent; +} + +//===----------------------------------------------------------------------===// +// Blocked Encoding +//===----------------------------------------------------------------------===// + +static std::optional getCTALayoutOrError( + AsmParser &parser, std::optional> CTAsPerCGA, + std::optional> CTASplitNum, + std::optional> CTAOrder, unsigned rank) { + if (CTAsPerCGA && CTASplitNum && CTAOrder) { + return CTALayoutAttr::get(parser.getContext(), *CTAsPerCGA, *CTASplitNum, + *CTAOrder); + } + if (!CTAsPerCGA && !CTASplitNum && !CTAOrder) { + return CTALayoutAttr::getDefault(parser.getContext(), rank); + } + parser.emitError(parser.getNameLoc(), "CTAsPerCGA, CTASplitNum, and CTAOrder " + "must all be present or all be absent"); + return std::nullopt; +} + +Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + SmallVector sizePerThread; + SmallVector threadsPerWarp; + SmallVector warpsPerCTA; + SmallVector order; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "sizePerThread") { + if (parseIntArrayAttr(parser, attr, sizePerThread, + "number of elements per thread") + .failed()) + return {}; + } else if (attr.getName() == "threadsPerWarp") { + if (parseIntArrayAttr(parser, attr, threadsPerWarp, + "number of threads per warp") + .failed()) + return {}; + } else if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, + "number of warps per CTA") + .failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/sizePerThread.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked(parser.getContext(), + sizePerThread, threadsPerWarp, + warpsPerCTA, order, *CTALayout); +} + +void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "sizePerThread = [" << ArrayRef(getSizePerThread()) << "]" + << ", threadsPerWarp = [" << ArrayRef(getThreadsPerWarp()) << "]" + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]" + << ", order = [" << getOrder() << "]"; + + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getSizePerThread().size()); + + printer << "}>"; +} + +//===----------------------------------------------------------------------===// +// MMA encoding +//===----------------------------------------------------------------------===// + +Attribute NvidiaMmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + SmallVector instrShape; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { + if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) + return {}; + } + if (attr.getName() == "versionMinor") { + if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) { + return {}; + } + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, *CTALayout, + instrShape); +} + +void NvidiaMmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() + << ", versionMinor = " << getVersionMinor() // + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]"; + + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + + printer << ", instrShape = [" << getInstrShape() << "]}>"; +} + +//===----------------------------------------------------------------------===// +// MFMA encoding +//===----------------------------------------------------------------------===// + +Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned versionMajor = 0; + unsigned versionMinor = 0; + SmallVector warpsPerCTA; + SmallVector instrShape; + bool isTransposed; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "versionMajor") { + if (parseUInt(parser, attr, versionMajor, "versionMajor").failed()) + return {}; + } + if (attr.getName() == "versionMinor") { + if (parseUInt(parser, attr, versionMinor, "versionMinor").failed()) + return {}; + } + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "instrShape") { + if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) + return {}; + } + if (attr.getName() == "isTransposed") { + if (parseBool(parser, attr, isTransposed, "isTransposed").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked( + parser.getContext(), versionMajor, versionMinor, warpsPerCTA, + instrShape[0], instrShape[1], isTransposed, *CTALayout); +} + +void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "versionMajor = " << getVersionMajor() // + << ", versionMinor = " << getVersionMinor() // + << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]" // + << ", instrShape = [" << ArrayRef{getMDim(), getNDim()} << "]" // + << ", isTransposed = " << getIsTransposed(); + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + printer << "}>"; +} + +LogicalResult +AMDMfmaEncodingAttr::verify(function_ref emitError, + unsigned versionMajor, unsigned versionMinor, + llvm::ArrayRef warpsPerCTA, + unsigned mDim, unsigned nDim, bool isTransposed, + mlir::triton::gpu::CTALayoutAttr) { + if (!(versionMajor >= 0 && versionMajor <= 3)) { + return emitError() << "major version must be in the [0, 3] range"; + } + if (versionMinor != 0) { + return emitError() << "minor version must be 0"; + } + if (!((mDim == 32 && nDim == 32) || (mDim == 16 && nDim == 16))) { + return emitError() + << "(M, N) cases other than (32, 32) or (16, 16) unimplemented"; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// WMMA encoding +//===----------------------------------------------------------------------===// + +Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + SmallVector warpsPerCTA; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "warpsPerCTA") { + if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) + return {}; + } + if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } + if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } + if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked(parser.getContext(), + warpsPerCTA, *CTALayout); +} + +void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]"; + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getWarpsPerCTA().size()); + printer << "}>"; +} + +//===----------------------------------------------------------------------===// +// Sliced Encoding +//===----------------------------------------------------------------------===// + +Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + NamedAttrList attrs; + if (parser.parseOptionalAttrDict(attrs).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + unsigned dim = mlir::cast(attrs.get("dim")).getInt(); + Attribute parent = attrs.get("parent"); + return parser.getChecked(parser.getContext(), dim, parent); +} + +void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "dim = " << getDim() << ", " + << "parent = " << getParent() << "}>"; +} + +//===----------------------------------------------------------------------===// +// Shared encoding +//===----------------------------------------------------------------------===// + +Attribute SharedEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned vec = 0; + unsigned perPhase = 0; + unsigned maxPhase = 0; + SmallVector order; + std::optional> CTAsPerCGA; + std::optional> CTASplitNum; + std::optional> CTAOrder; + bool hasLeadingOffset = false; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "vec") { + if (parseUInt(parser, attr, vec, "vec").failed()) + return {}; + } else if (attr.getName() == "perPhase") { + if (parseUInt(parser, attr, perPhase, "perPhase").failed()) + return {}; + } else if (attr.getName() == "maxPhase") { + if (parseUInt(parser, attr, maxPhase, "maxPhase").failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else if (attr.getName() == "CTAsPerCGA") { + if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") + .failed()) + return {}; + } else if (attr.getName() == "CTASplitNum") { + if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") + .failed()) + return {}; + } else if (attr.getName() == "CTAOrder") { + if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") + .failed()) + return {}; + } else if (attr.getName() == "hasLeadingOffset") { + if (parseBool(parser, attr, hasLeadingOffset, "hasLeadingOffset") + .failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + std::optional CTALayout = getCTALayoutOrError( + parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/order.size()); + if (!CTALayout.has_value()) + return {}; + + return parser.getChecked(parser.getContext(), vec, + perPhase, maxPhase, order, + *CTALayout, hasLeadingOffset); +} + +void SharedEncodingAttr::print(AsmPrinter &printer) const { + printer << "<{" + << "vec = " << getVec() // + << ", perPhase = " << getPerPhase() + << ", maxPhase = " << getMaxPhase() // + << ", order = [" << getOrder() << "]"; + maybePrintCTALayout(getContext(), printer, getCTALayout(), + /*rank=*/getOrder().size()); + printer << ", hasLeadingOffset = " << getHasLeadingOffset() << "}>"; +} + +//===----------------------------------------------------------------------===// +// Mfma encoding +//===----------------------------------------------------------------------===// +// TODO: there is a lot of common code with MmaEncoding here + +SmallVector +AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = warpsPerCTA.size(); + SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); + shapePerCTATile[rank - 1] *= getMDim(); + shapePerCTATile[rank - 2] *= getNDim(); + return shapePerCTATile; +} + +SmallVector AMDMfmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector AMDMfmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector AMDMfmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector AMDMfmaEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector AMDMfmaEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector AMDMfmaEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector AMDMfmaEncodingAttr::getThreadsPerWarp() const { + unsigned rows, cols; + auto rank = ::getOrder(*this).size(); + SmallVector res(rank, 1); + if (getMDim() == 32) { + cols = 2; + rows = 32; + } else { + assert(getMDim() == 16); + cols = 4; + rows = 16; + } + if (getIsTransposed()) { + res[rank - 1] = cols; + res[rank - 2] = rows; + } else { + res[rank - 1] = rows; + res[rank - 2] = cols; + } + return res; +} + +SmallVector AMDMfmaEncodingAttr::getSizePerThread() const { + unsigned rows, cols; + auto rank = ::getOrder(*this).size(); + SmallVector res(rank, 1); + if (getMDim() == 32) { + rows = 16; + cols = 1; + } else if (getMDim() == 16) { + rows = 4; + cols = 1; + } else + llvm_unreachable("Unexpected mfma non-k dim"); + + if (getIsTransposed()) { + res[rank - 1] = rows; + res[rank - 2] = cols; + } else { + res[rank - 1] = cols; + res[rank - 2] = rows; + } + return res; +} + +SmallVector +AMDMfmaEncodingAttr::getMFMAInstrShapeForOperands(int kWidth, int opIdx) const { + unsigned mDim = getMDim(); + unsigned nDim = getNDim(); + assert((mDim == nDim) && (mDim == 32 || mDim == 16 || mDim == 4) || + (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + constexpr int waveSize = 64; // MFMA is used on wave64 architectures only + int kGroups = -1; + if (mDim == nDim) + kGroups = waveSize / mDim; + if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64) + kGroups = 1; + int64_t kDim = kWidth * kGroups; + if (opIdx == 0) + return {mDim, kDim}; + else + assert(opIdx == 1); + return {kDim, nDim}; +} + +SmallVector +AMDMfmaEncodingAttr::getMFMARepForOperands(ArrayRef operandShape, + int kWidth, int opIdx) const { + auto operandTileShape = getMFMAInstrShapeForOperands(kWidth, opIdx); + auto rank = operandShape.size(); + auto warpsPerCTA = getWarpsPerCTA(); + int numRepBatch = + rank == 3 ? std::max(1, operandShape[0] / warpsPerCTA[0]) : 1; + if (opIdx == 0) + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / + (operandTileShape[0] * warpsPerCTA[rank - 2])), + std::max(1, operandShape[rank - 1] / operandTileShape[1])}; + else { + assert(opIdx == 1); + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / operandTileShape[0]), + std::max(1, operandShape[rank - 1] / (operandTileShape[1] * + warpsPerCTA[rank - 1]))}; + } +} + +unsigned AMDMfmaEncodingAttr::getTotalElemsPerThreadForOperands( + ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { + auto rep = getMFMARepForOperands(shape, kWidth, opIdx); + return product(rep) * kWidth; +} + +SmallVector +AMDMfmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { + if (opIdx == 0) { + return {4, 1}; + } else if (opIdx == 1) { + return {1, 4}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + return {}; + } +} + +SmallVector +AMDMfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, + int opIdx) const { + assert(getMDim() == 32 || getMDim() == 16); + auto parentShapePerCTATile = getShapePerCTATile(shape); + auto rank = parentShapePerCTATile.size(); + if (opIdx == 0) { + if (rank == 2) + return {parentShapePerCTATile[rank - 2], 32}; + else + return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32}; + } else if (opIdx == 1) { + if (rank == 2) + return {32, parentShapePerCTATile[rank - 1]}; + else + return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + } + llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1"); +} + +SmallVector +AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = warpsPerCTA.size(); + SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); + + auto mnkDim = getMNKDimPerWMMAInstr(); + shapePerCTATile[rank - 2] *= mnkDim[0]; + shapePerCTATile[rank - 1] *= mnkDim[1]; + return shapePerCTATile; +} +SmallVector AMDWmmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector AMDWmmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector AMDWmmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector AMDWmmaEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector AMDWmmaEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector AMDWmmaEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector AMDWmmaEncodingAttr::getThreadsPerWarp() const { + auto rank = getWarpsPerCTA().size(); + SmallVector threads(rank, 1); + auto mnkInstr = getMNKDimPerWMMAInstr(); + threads[rank - 2] = mnkInstr[0] / getSizePerThread()[rank - 2]; + threads[rank - 1] = mnkInstr[1] / getSizePerThread()[rank - 1]; + return threads; +} + +SmallVector AMDWmmaEncodingAttr::getSizePerThread() const { + auto rank = getWarpsPerCTA().size(); + SmallVector sizePerThread(rank, 1); + sizePerThread[rank - 2] = 8; + sizePerThread[rank - 1] = 1; + return sizePerThread; +} +SmallVector +AMDWmmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { + auto rank = getWarpsPerCTA().size(); + SmallVector sizePerThread(rank, 1); + if (opIdx == 0) { + sizePerThread[rank - 2] = 1; + sizePerThread[rank - 1] = 16; + } else if (opIdx == 1) { + sizePerThread[rank - 2] = 16; + sizePerThread[rank - 1] = 1; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + } + return sizePerThread; +} + +SmallVector +AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, + int opIdx) const { + auto parentShapePerCTA = getShapePerCTATile(shape); + auto rank = shape.size(); + assert(rank = 2); + if (opIdx == 0) { + return {parentShapePerCTA[0], static_cast(shape[1])}; + } else if (opIdx == 1) { + return {static_cast(shape[0]), parentShapePerCTA[1]}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + } +} + +unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperands( + ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { + auto rep = getWMMARepForOperands(shape, eltTy, kWidth, opIdx); + return product(rep) * kWidth; +} + +SmallVector +AMDWmmaEncodingAttr::getWMMAElemsPerInstrForOperands() const { + return {16, 16}; +} + +SmallVector +AMDWmmaEncodingAttr::getWMMARepForOperands(ArrayRef operandShape, + Type elemType, int kWidth, + int opIdx) const { + auto operandTileShape = getWMMAElemsPerInstrForOperands(); + assert(operandTileShape.size() == 2); + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = operandShape.size(); + assert(rank == 2 || rank == 3); + int numRepBatch = + rank == 3 ? std::max(1, operandShape[0] / warpsPerCTA[0]) : 1; + if (opIdx == 0) + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / + (operandTileShape[0] * warpsPerCTA[rank - 2])), + std::max(1, operandShape[rank - 1] / operandTileShape[1])}; + else { + assert(opIdx == 1); + return { + numRepBatch, + std::max(1, operandShape[rank - 2] / operandTileShape[0]), + std::max(1, operandShape[rank - 1] / (operandTileShape[1] * + warpsPerCTA[rank - 1]))}; + } +} + +SmallVector AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr() { + // TODO: move magic numbers out of the code + return {16, 16, 16}; +} + +//===----------------------------------------------------------------------===// +// Mma encoding +//===----------------------------------------------------------------------===// + +bool NvidiaMmaEncodingAttr::isVolta() const { return getVersionMajor() == 1; } + +bool NvidiaMmaEncodingAttr::isTuring() const { + return getVersionMajor() == 2 && getVersionMinor() == 1; +} + +bool NvidiaMmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; } + +bool NvidiaMmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; } + +SmallVector NvidiaMmaEncodingAttr::getCTAsPerCGA() const { + return SmallVector(getCTALayout().getCTAsPerCGA()); +} +SmallVector NvidiaMmaEncodingAttr::getCTAOrder() const { + return SmallVector(getCTALayout().getCTAOrder()); +} +SmallVector NvidiaMmaEncodingAttr::getCTASplitNum() const { + return SmallVector(getCTALayout().getCTASplitNum()); +} +SmallVector NvidiaMmaEncodingAttr::getWarpsPerCTA() const { + return SmallVector(getWarpsPerCTA__()); +} +SmallVector NvidiaMmaEncodingAttr::getWarpOrder() const { + return ::getWarpOrder(*this); +} +SmallVector NvidiaMmaEncodingAttr::getThreadsPerWarp() const { + auto rank = getWarpsPerCTA().size(); + SmallVector res(rank, 1); + if (isVolta()) { + res[rank - 2] = 4; + res[rank - 1] = 8; + return res; + } + if (isAmpere()) { + res[rank - 2] = 8; + res[rank - 1] = 4; + return res; + } + if (isHopper()) { + res[rank - 2] = 8; + res[rank - 1] = 4; + return res; + } + llvm::report_fatal_error( + "getThreadsPerWarp not implemented for unknown Mma version "); +} +SmallVector NvidiaMmaEncodingAttr::getThreadOrder() const { + return ::getOrder(*this); +} +SmallVector NvidiaMmaEncodingAttr::getSizePerThread() const { + auto rank = ::getOrder(*this).size(); + SmallVector res(rank, 1); + if (isAmpere()) { + res[rank - 2] = 2; + res[rank - 1] = 2; + return res; + } + if (isVolta()) { + res[rank - 2] = 1; + res[rank - 1] = 2; + return res; + } + if (isHopper()) { + auto instrShape = getInstrShape(); + // WGMMA instructions have an order of [0, 1] with 4 warps, each with 8 + // unique thread ids (32 in a warp group) per column. It is 1 warp wide with + // 4 unique thread ids in the row. So the size per thread is the instruction + // size divided by the number of unique thread ids. + return SmallVector{instrShape[0] * 4 / 32, instrShape[1] / 4}; + } + llvm_unreachable("Unexpected mma version"); +} + +SmallVector +NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { + if (isAmpere()) { + auto warpsPerCTA = getWarpsPerCTA(); + auto rank = warpsPerCTA.size(); + SmallVector shapePerCTATile(warpsPerCTA.begin(), + warpsPerCTA.end()); + shapePerCTATile[rank - 1] *= 8; + shapePerCTATile[rank - 2] *= 16; + return shapePerCTATile; + } + if (isVolta()) { + assert(!tensorShape.empty() && "Volta needs the tensorShape"); + if (tensorShape.size() == 1) // must be SliceEncoding + return {static_cast(tensorShape[0]), + static_cast(tensorShape[0])}; + return {static_cast(tensorShape[0]), + static_cast(tensorShape[1])}; + } + if (isHopper()) { + auto instrShape = getInstrShape(); + return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]}; + } + llvm::report_fatal_error("Unexpected MMA layout version found"); +} + +// Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor +std::tuple +NvidiaMmaEncodingAttr::decodeVoltaLayoutStates() const { + unsigned versionMinor = getVersionMinor(); + bool isARow = versionMinor & (1 << 0); + bool isBRow = versionMinor & (1 << 1); + bool isAVec4 = versionMinor & (1 << 2); + bool isBVec4 = versionMinor & (1 << 3); + + int id = 0; + for (int i = numBitsToHoldMmaV1ID - 1; i >= 0; --i) + id = (id << 1) + static_cast(versionMinor & (1 << (4 + i))); + + return std::make_tuple(isARow, isBRow, isAVec4, isBVec4, id); +} + +bool NvidiaMmaEncodingAttr::getMMAv1IsRow(int opIdx) const { + auto [isARow, isBRow, _0, _1, _2] = decodeVoltaLayoutStates(); + return opIdx == 0 ? isARow : isBRow; +} +bool NvidiaMmaEncodingAttr::getMMAv1IsVec4(int opIdx) const { + auto [_0, _1, isAVec4, isBVec4, _2] = decodeVoltaLayoutStates(); + return opIdx == 0 ? isAVec4 : isBVec4; +} +int NvidiaMmaEncodingAttr::getMMAv1NumOuter(ArrayRef shape, + int opIdx) const { + auto spw = getMMAv1ShapePerWarp(opIdx); + auto rep = getMMAv1Rep(opIdx); + auto warpsPerCTA = getWarpsPerCTA(); + if (opIdx == 0) { + return rep[0] * shape[0] / (spw[0] * warpsPerCTA[0]); + } else { + return rep[1] * shape[1] / (spw[1] * warpsPerCTA[1]); + } +} +SmallVector NvidiaMmaEncodingAttr::getMMAv1Rep(int opIdx) const { + auto [isARow, isBRow, isAVec4, isBVec4, _] = decodeVoltaLayoutStates(); + // A + if (opIdx == 0) { + int packSize = (isARow || isAVec4) ? 1 : 2; + return {2 * packSize, 0, 1}; + } + // B + else { + int packSize = (isBRow && !isBVec4) ? 2 : 1; + return {0, 2 * packSize, 1}; + } +} +SmallVector NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const { + auto rep = getMMAv1Rep(opIdx); + if (opIdx == 0) { + return {8 * rep[0], 0, 1}; + } else { + return {0, 8 * rep[1], 1}; + } +} +int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const { + return 2 * getMMAv1Rep(opIdx)[opIdx]; +} +SmallVector NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef shape, + int bitwidth, + int opIdx) const { + auto rank = shape.size(); + auto warpsPerCTA = getWarpsPerCTA(); + SmallVector shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth}; + int numRepBatch = + rank == 3 + ? std::max(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])) + : 1; + assert(isAmpere()); + + if (opIdx == 0) + return {numRepBatch, + std::max(1, shape[rank - 2] / + (shapePerWarp[1] * warpsPerCTA[rank - 2])), + std::max(1, shape[rank - 1] / shapePerWarp[3])}; + else { + assert(opIdx == 1); + return {numRepBatch, + std::max(1, shape[rank - 2] / shapePerWarp[3]), + std::max(1, shape[rank - 1] / (shapePerWarp[2] * + warpsPerCTA[rank - 1]))}; + } +} +unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands( + ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { + auto shapePerCTA = getShapePerCTA(*this, shape); + int warpsPerCTAM = getWarpsPerCTA()[0]; + int warpsPerCTAN = getWarpsPerCTA()[1]; + // H100 + if (isHopper()) { + return getTotalElemsPerThread(shape, eltTy); + } + // A100 + if (isAmpere()) { + auto rep = getMMAv2Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth(), opIdx); + if (opIdx == 0) + return 4 * rep[0] * rep[1] * rep[2]; + if (opIdx == 1) + return 4 * rep[0] * rep[1] * std::max(rep[2] / 2, 1); + } + // V100 + if (isVolta()) { + bool isRow = getMMAv1IsRow(opIdx); + bool isVec4 = getMMAv1IsVec4(opIdx); + if (opIdx == 0) { + int packSizeM = (isRow || isVec4) ? 1 : 2; + int repM = 2 * packSizeM; + int spwM = 2 * 4 * repM; + int numM = getMMAv1NumOuter(shape, opIdx); + int NK = shape[1]; + int vec = 2 * repM; + // Here we mimic the logic in loadA, the result cannot be calculated + // directly. + llvm::DenseSet> visited; + auto ld = [&](int m, int k) { + visited.insert({m, k}); + if (vec > 4) { + if (isRow) + visited.insert({m, k + 4}); + else + visited.insert({m + 1, k}); + } + }; + for (unsigned k = 0; k < NK; k += 4) + for (unsigned m = 0; m < numM / 2; ++m) + if (!visited.count({m, k})) + ld(m, k); + return visited.size() * 2; + } + if (opIdx == 1) { + int packSizeN = (isRow && !isVec4) ? 2 : 1; + int repN = 2 * packSizeN; + int spwN = 2 * 4 * repN; + int numN = getMMAv1NumOuter(shape, opIdx); + int vec = 2 * repN; + + int NK = shape[0]; + // Here we mimic the logic in loadA, the result cannot be calculated + // directly. + llvm::DenseSet> visited; + int elemsPerLd = vec > 4 ? 4 : 2; + auto ld = [&](int n, int k) { + visited.insert({n, k}); + if (vec > 4) { + if (isRow) + visited.insert({n + 1, k}); + else + visited.insert({n, k + 4}); + } + }; + + for (unsigned k = 0; k < NK; k += 4) + for (unsigned n = 0; n < numN / 2; ++n) { + if (!visited.count({n, k})) + ld(n, k); + } + + return visited.size() * 2; + } + } + llvm_unreachable("unknown mma layout"); +} +SmallVector +NvidiaMmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, + int opIdx) const { + assert(isAmpere() && "mmaLayout version = 1 is not implemented yet"); + auto parentShapePerCTATile = getShapePerCTATile(shape); + auto rank = parentShapePerCTATile.size(); + if (opIdx == 0) { + if (rank == 2) + return {parentShapePerCTATile[rank - 2], 16}; + else + return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 16}; + } else if (opIdx == 1) { + if (rank == 2) + return {16, parentShapePerCTATile[rank - 1]}; + else + return {parentShapePerCTATile[0], 16, parentShapePerCTATile[rank - 1]}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + } +} +SmallVector +NvidiaMmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { + assert(isAmpere() && "mmaLayout version = 1 is not implemented yet"); + if (opIdx == 0) { + return {2, 4}; + } else if (opIdx == 1) { + return {4, 1}; + } else { + llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); + return {}; + } +} + +//===----------------------------------------------------------------------===// +// DotOperand Encoding +//===----------------------------------------------------------------------===// +SmallVector DotOperandEncodingAttr::getThreadsPerWarp() const { + llvm::report_fatal_error( + "getThreadsPerWarp not implemented for DotOperandEncodingAttr"); +} +SmallVector DotOperandEncodingAttr::getSizePerThread() const { + auto parentLayout = getParent(); + assert(parentLayout && "DotOperandEncodingAttr must have a parent"); + if (auto parentMmaLayout = mlir::dyn_cast(parentLayout)) { + return parentMmaLayout.getSizePerThreadForOperands(getOpIdx()); + } else { + llvm::report_fatal_error( + "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not " + "supported yet"); + return {}; + } +} + +//===----------------------------------------------------------------------===// +// ASM Interface (i.e.: alias) +//===----------------------------------------------------------------------===// + +class TritonGPUOpAsmInterface : public OpAsmDialectInterface { +public: + using OpAsmDialectInterface::OpAsmDialectInterface; + + AliasResult getAlias(Attribute attr, raw_ostream &os) const override { + if (auto mmaAttr = mlir::dyn_cast(attr)) { + os << "mma"; + return AliasResult::FinalAlias; + } else if (auto sharedAttr = mlir::dyn_cast(attr)) { + os << "shared"; + return AliasResult::FinalAlias; + } else if (auto blockedAttr = mlir::dyn_cast(attr)) { + os << "blocked"; + return AliasResult::FinalAlias; + } /* else if (auto sliceAttr = dyn_cast(attr)) { + os << "slice"; + return AliasResult::FinalAlias; + } */ + return OpAsmDialectInterface::getAlias(attr, os); + } +}; + +struct TritonGPUInferLayoutInterface + : public triton::DialectInferLayoutInterface { + using DialectInferLayoutInterface::DialectInferLayoutInterface; + + LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding) const override { + resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis, + operandEncoding); + return success(); + } + + // Infer the encoding of a tt.trans(x) given the encoding of x. + // + // Our goal is to choose an encoding so that the trans is a "nop". For + // example, in a blocked encoding, the same GPU threads hold the same + // elements, they're just "renamed" -- what was element [i,j] of the tensor is + // now element [j,i], but that element is held by the same GPU thread. + // + // For most properties of the encoding, we let + // outputEnc.prop = inputEnc.prop * trans.order, + // where `x * y` means we apply permutation y to x. + // + // This works because prop[i] tells you something about the i'th dimension of + // the tensor. (For example, sizePerThread[2] == 4 means that one GPU thread + // contains 4 elements along dim 2 of the tensor.) The transpose reorders the + // dimensions according to the perm trans.order, so we achieve our goal of + // having a "nop" transpose by reordering the values in the prop the same way. + // + // The big exception to this is the encoding's `order`. + // + // An encoding's order is a list of dimensions, from fastest moving (most + // minor) to slowest moving. Thus enc.order[i] does not tell you something + // about the i'th dimension of the tensor, and it would be disasterously + // incorrect to do enc.order * trans.order. + // + // But! If we invert enc.order, it *does* meet this criterion. For example, + // if enc.order = [2,0,1], inverse(enc.order) = [1,2,0]. If you stare at it, + // you'll see that inverse(enc.order)[i] == j means that dimension i is the + // j'th most minor. Therefore we can safely permute *this* by trans.order. + // + // Thus we have + // + // outputEnc.order = inverse(inverse(inputEnc.order) * trans.order) + // = inverse(trans.order) * inputEnc.order. + // + LogicalResult inferTransOpEncoding(Attribute operandEncoding, + ArrayRef order, // trans order + Attribute &resultEncoding) const override { + // Note: inferFooOpEncoding should not crash if given invalid inputs, which + // happens when someone creates invalid IR. If we return failure() on + // error, then MLIR will generate a helpful error message. + + auto invOrder = inversePermutation(order); + SmallVector invOrderUnsigned(invOrder.begin(), invOrder.end()); + + auto permuteCTALayout = + [&](const CTALayoutAttr &layout) -> FailureOr { + auto n = order.size(); + if (layout.getCTAsPerCGA().size() != n || + layout.getCTASplitNum().size() != n || + layout.getCTAOrder().size() != n) { + return failure(); + } + + return CTALayoutAttr::get( + getDialect()->getContext(), + applyPermutation(layout.getCTAsPerCGA(), order), + applyPermutation(layout.getCTASplitNum(), order), + applyPermutation(invOrderUnsigned, layout.getCTAOrder())); + }; + + if (auto enc = mlir::dyn_cast(operandEncoding)) { + if (enc.getOrder().size() != order.size()) { + return failure(); + } + FailureOr ctaLayout = permuteCTALayout(enc.getCTALayout()); + if (failed(ctaLayout)) { + return failure(); + } + resultEncoding = SharedEncodingAttr::get( + getDialect()->getContext(), enc.getVec(), enc.getPerPhase(), + enc.getMaxPhase(), applyPermutation(invOrderUnsigned, enc.getOrder()), + *ctaLayout, enc.getHasLeadingOffset()); + return success(); + } + + if (auto enc = mlir::dyn_cast(operandEncoding)) { + auto n = order.size(); + if (enc.getSizePerThread().size() != n || + enc.getThreadsPerWarp().size() != n || + enc.getWarpsPerCTA().size() != n || enc.getOrder().size() != n) { + return failure(); + } + FailureOr ctaLayout = permuteCTALayout(enc.getCTALayout()); + if (failed(ctaLayout)) { + return failure(); + } + resultEncoding = BlockedEncodingAttr::get( + getDialect()->getContext(), + applyPermutation(enc.getSizePerThread(), order), + applyPermutation(enc.getThreadsPerWarp(), order), + applyPermutation(enc.getWarpsPerCTA(), order), + applyPermutation(invOrderUnsigned, enc.getOrder()), *ctaLayout); + return success(); + } + + return failure(); // unhandled encoding + } + + LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional location) const override { + auto sliceEncoding = mlir::dyn_cast(operandEncoding); + if (!sliceEncoding) + return emitOptionalError( + location, "ExpandDimsOp operand encoding must be SliceEncodingAttr"); + if (sliceEncoding.getDim() != axis) + return emitOptionalError( + location, "Incompatible slice dimension for ExpandDimsOp operand"); + resultEncoding = sliceEncoding.getParent(); + return success(); + } + + LogicalResult + inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute retEncoding, + std::optional location) const override { + auto mmaRetEncoding = mlir::dyn_cast(retEncoding); + if (mmaRetEncoding && mmaRetEncoding.isHopper()) { + auto dotOpEnc = mlir::dyn_cast(operandEncoding); + if (!mlir::isa(operandEncoding) && + !(opIdx == 0 && dotOpEnc && dotOpEnc.getOpIdx() == 0 && + mlir::isa(dotOpEnc.getParent()))) { + return emitOptionalError( + location, "unexpected operand layout for NvidiaMmaEncodingAttr v3"); + } + } else if (auto dotOpEnc = + mlir::dyn_cast(operandEncoding)) { + if (opIdx != dotOpEnc.getOpIdx()) + return emitOptionalError(location, "Wrong opIdx"); + if (retEncoding != dotOpEnc.getParent()) + return emitOptionalError(location, "Incompatible parent encoding"); + } else + return emitOptionalError( + location, "Dot's a/b's encoding should be of DotOperandEncodingAttr"); + return success(); + } + + LogicalResult + verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, + Attribute operandEncodingB) const override { + auto aEncoding = + mlir::dyn_cast(operandEncodingA); + auto bEncoding = + mlir::dyn_cast(operandEncodingB); + if (!aEncoding && !bEncoding) + return mlir::success(); + auto mmaAEncoding = + mlir::dyn_cast_or_null(aEncoding.getParent()); + if (mmaAEncoding && mmaAEncoding.isHopper()) + return success(); + // Verify that the encodings are valid. + if (!aEncoding || !bEncoding) + return op->emitError("mismatching encoding between A and B operands"); + if (aEncoding.getKWidth() != bEncoding.getKWidth()) + return op->emitError("mismatching kWidth between A and B operands"); + return success(); + } + + // Given a src shape + encoding and a dst shape, our goal is to compute a dst + // encoding that makes the reshape a "nop". That is, if GPU thread [x,y,z] + // contains elements [a,b,c,d] before the reshape, it contains those same + // elements after the reshape, they're just "renamed". + // + // A dst encoding that satisfies this property does not exist for all inputs. + // Here are some positive and negative examples. + // + // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so + // dim 1 is the fastest-changing in the dst, but the src has the opposite + // order. + // - OK: 2x2x32 order=[1,0,2] -> 4x32. We choose dst order [0,1]. + // What's important is that the 2x2 dimensions appear in major-to-minor + // order. + // - NOT OK: 32x32 sizePerThread=[2,2] -> 1024. Thread 0 in the src + // contains elements [(0,0), (0,1), (1,0), and (1,1)]. We cannot express + // this with an encoding based on the dst shape. + // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will + // contain the same elements as before. + // + // Users of this function require that it is symmetrical: if + // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) => + // srcEnc. + LogicalResult + inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const override { + auto src = mlir::dyn_cast(srcEnc); + if (!src) { + return emitOptionalError( + loc, "Non-reordering reshape only supports BlockedEncoding"); + } + + // Nop reshape; we can always infer an encoding. + if (srcShape == dstShape) { + dstEnc = srcEnc; + return success(); + } + + // default -> default encoding is always a nop. + auto context = srcEnc.getContext(); + int32_t numWarps = product(src.getWarpsPerCTA()); + int32_t threadsPerWarp = product(src.getThreadsPerWarp()); + int32_t numCTAs = product(src.getCTALayout().getCTAsPerCGA()); + if (srcEnc == getDefaultBlockedEncoding(context, srcShape, numWarps, + threadsPerWarp, numCTAs)) { + dstEnc = getDefaultBlockedEncoding(context, dstShape, numWarps, + threadsPerWarp, numCTAs); + return success(); + } + + // Feature flag to disable this routine while it's relatively new. + // TODO(jlebar): Remove this once we're confident in the code. + if (triton::tools::getBoolEnv( + "TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE")) { + return failure(); + } + + // Cowardly refuse to handle encodings with multiple CTAs. CTAsPerCGA + // should be like the other fields in blocked encoding, but I'm not sure how + // to handle CTASplitNum. + if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) || + !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) { + return emitOptionalError( + loc, "Non-reordering reshape does not currently support multi-CTA " + "layouts other than the default layout."); + } + + // Cowardly refuse to handle encodings where shape[dim] is not divisible by + // sizePerThread[dim], threadsPerWarp[dim], and warpsPerCTA[dim]. (We make + // an exception if the block is larger than the shape.) + auto checkDivisibility = [&](StringRef name, ArrayRef subblock) { + for (int dim = 0; dim < srcShape.size(); dim++) { + if (srcShape[dim] >= subblock[dim] && + srcShape[dim] % subblock[dim] != 0) { + return emitOptionalError(loc, + "Can't do a non-reordering reshape because " + "the size of dimension ", + dim, " (", srcShape[dim], ")", + " is not divisible by ", name, "[", dim, "]", + " = ", subblock[dim]); + } + } + return success(); + }; + if (!succeeded( + checkDivisibility("sizePerThread", src.getSizePerThread())) || + !succeeded( + checkDivisibility("threadsPerWarp", src.getThreadsPerWarp())) || + !succeeded(checkDivisibility("warpsPerCTA", src.getWarpsPerCTA()))) { + return failure(); + } + + SmallVector, SmallVector>> decomp = + getReshapeDecomposition(srcShape, dstShape); + + // enc.order[i] == j means that dimension j is the enc.order[i]'th most + // minor. But what we usually want is the inverse: inverse(enc.order)[i] = j + // means that dimension i is the j'th most minor (larger means more major). + auto srcInvOrder = inversePermutation(src.getOrder()); + + // If src dims [a,b,c] are to be merged, then they must be consecutive in + // physical order, with `a` being the most major. + for (const auto &[srcDims, dstDims] : decomp) { + if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) { + return emitOptionalError(loc, + "Cannot do a non-reordering reshape given " + "this src encoding order. Dimensions [", + join(srcDims), + "] must be physically consecutive."); + } + } + + // If src dims [a,b,c] are to be merged, then `c` must fill up sizePerThread + // / threadsPerWarp / blocksPerCTA before `b` can have any non-1 values. + // Examples: + // + // - NOT OK: shape=[4,4,4], sizePerThread=[1,2,2]. + // The total sizePerThread for dim 2 is 2, which is less than dim 2's + // size of 4. Therefore dim 1 cannot have non-1 sizePerThread. + // + // - OK: shape=[4,4,4], sizePerThread=[1,2,4]. + // Dim 2's sizePerThread covers its whole size, so dim 1 is allowed to + // have non-1 sizePerThread. + // + // - NOT OK: shape=[4,4,4], sizePerThread=[2,1,4]. + // Dim 1's sizePerThread does not cover its whole size, so dim 0 is not + // allowed to have non-1 sizePerThread. + // + // - NOT OK: shape=[4,4,4], sizePerThread=[1,1,2], + // threadsPerWarp=[1,2,1]. + // Dim 2 has 2 elems per thread and 1 thread per warp. 2*1 is less than + // dim 2's size. Therefore dim 1 must have threadsPerWarp=1. + // + // In addition, the encoding's block can be larger than the shape, but only + // in the most-major dimension of each decomposed chunk, and only after + // we've "used up" the more minor dims. Examples: + // + // - OK: shape=[4,4,4], sizePerThread=[1,2,4], threadsPerWarp=[16,2,1], + // warpsPerCTA=[4,1,1]. + // The whole size of dims 0 and 1 are covered by sizePerThread * + // threadsPerWarp. Therefore dim 2 is allowed to have threadsPerWarp and + // warpsPerCTA larger than its size. + for (const auto &[srcDims, dstDims] : decomp) { + auto shapeRemaining = gather(srcShape, srcDims); + auto checkSubblock = [&, srcDims = srcDims](ArrayRef subblock) { + // Iterate minor-to-major (i==0 is most major). + for (int i = srcDims.size() - 1; i >= 0; i--) { + int dim = srcDims[i]; + if (subblock[dim] == 1) { + continue; + } + + // Check that more-minor dims all have 1 in shapeRemaining. + for (int j = i + 1; j < srcDims.size(); j++) { + if (shapeRemaining[j] != 1) { + return emitOptionalError( + loc, + "Invalid src encoding for non-reordering reshape. Must use " + "up sizePerThread / threadsPerWarp / warpsPerCTA for " + "more-minor dimensions before more major-dims can use them."); + } + } + + if (shapeRemaining[i] >= subblock[dim]) { + assert(shapeRemaining[i] % subblock[dim] == 0); // checked earlier + shapeRemaining[i] /= subblock[dim]; + } else { + shapeRemaining[i] = 0; + } + + // Is the block larger than the shape in this dimension? This is OK + // only if we're the most-major dimension of the chunk and in all + // future chunks, only this most-major dim has a non-1 size. + if (shapeRemaining[i] == 0 && i != 0) { + return emitOptionalError( + loc, + "Invalid src encoding for non-reordering reshape. Block " + "size in dimension ", + dim, + " is larger than the shape that dimension, but this is only " + "allowed for the most-major dimension of a reshape chunk"); + } + } + return success(); + }; + if (!succeeded(checkSubblock(src.getSizePerThread())) || + !succeeded(checkSubblock(src.getThreadsPerWarp())) || + !succeeded(checkSubblock(src.getWarpsPerCTA()))) { + return failure(); + } + } + + // Given e.g. src.getSizePerThread(), computeSubblockSize computes e.g. + // dst.getSizePerThread(). This should be called for each of sizePerThread, + // threadsPerWarp, and warpsPerCTA, in that order. + SmallVector dstShapeRemaining(dstShape); + auto computeSubblockSize = [&](ArrayRef srcSubblock, + SmallVector &dstSubblock, + StringRef fieldName) -> LogicalResult { + // The dst subblock is "filled up" greedily starting with the most minor + // dim. When we're done, we are left with a smaller shape, of size + // dstShape / dstSubblock, which we store in dstShapeRemaining and use for + // the next call to computeSubblockSize. + dstSubblock.resize(dstShape.size()); + for (const auto &[srcDims, dstDims] : decomp) { + int64_t subblockRemaining = product(gather(srcSubblock, srcDims)); + for (int i = dstDims.size() - 1; i >= 0; i--) { + auto &val = dstSubblock[dstDims[i]]; + auto &shapeRemaining = dstShapeRemaining[dstDims[i]]; + val = std::min(subblockRemaining, shapeRemaining); + + assert(shapeRemaining % val == 0); // Checked earlier. + subblockRemaining /= val; + shapeRemaining /= val; + } + + // If there are any elems remaining in the subblock, it must be because + // the block is larger than the shape. This excess goes into the + // most-major dim of the subblock. + dstSubblock[dstDims[0]] *= subblockRemaining; + } + return success(); + }; + + SmallVector dstSizePerThread; + SmallVector dstThreadsPerWarp; + SmallVector dstWarpsPerCTA; + if (!succeeded(computeSubblockSize(src.getSizePerThread(), dstSizePerThread, + "sizePerThread")) || + !succeeded(computeSubblockSize(src.getThreadsPerWarp(), + dstThreadsPerWarp, "threadsPerWarp")) || + !succeeded(computeSubblockSize(src.getWarpsPerCTA(), dstWarpsPerCTA, + "warpsPerCTA"))) { + return failure(); + } + + // Since we know that each set of srcDims is consecutive, we can + // meaningfully sort decomp by the physical order of the src dimensions, + // major-to-minor. This will also be the order of the dst dimensions. + llvm::sort(decomp, [&](const auto &a, const auto &b) { + const auto &[srcDimsA, dstDimsA] = a; + const auto &[srcDimsB, dstDimsB] = b; + return srcInvOrder[srcDimsA.front()] < srcInvOrder[srcDimsB.front()]; + }); + + // Compute the dst order. Make the dimensions appear in the same order as + // their corresponding src dimensions. + SmallVector dstInvOrder(dstShape.size()); + int i = 0; + for (const auto &[srcDims, dstDims] : decomp) { + for (auto dim : reverse(dstDims)) { + dstInvOrder[dim] = i++; + } + } + auto dstOrder = inversePermutation(dstInvOrder); + + // CTALayout can be all 1's because we bailed on multi-CTA layouts above. + auto CTALayout = CTALayoutAttr::get( + src.getContext(), + /*CTAsPerCGA=*/SmallVector(dstShape.size(), 1), + /*CTASplitNum=*/SmallVector(dstShape.size(), 1), + /*CTAOrder=*/llvm::to_vector(llvm::seq(dstShape.size()))); + + dstEnc = BlockedEncodingAttr::get(src.getContext(), dstSizePerThread, + dstThreadsPerWarp, dstWarpsPerCTA, + dstOrder, CTALayout); + + return success(); + } + + LogicalResult + inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const override { + auto enc = mlir::dyn_cast(srcEnc); + if (!enc) { + return emitOptionalError(loc, + "JoinOp can only operate on BlockedEncoding"); + } + + // JoinOp takes two tensors of shape AxBxC and generates a tensor of shape + // AxBxCx2. The encoding is the same as the input, but with 2 elems per + // thread in the new dimension. The new dimension is most-minor. + auto append = [](ArrayRef vals, int val) { + SmallVector ret(vals); + ret.push_back(val); + return ret; + }; + auto appendMinorDim = [](ArrayRef order) { + SmallVector ret(order); + ret.insert(ret.begin(), ret.size()); + return ret; + }; + dstEnc = BlockedEncodingAttr::get( + enc.getContext(), // + append(enc.getSizePerThread(), 2), // + append(enc.getThreadsPerWarp(), 1), // + append(enc.getWarpsPerCTA(), 1), // + appendMinorDim(enc.getOrder()), // + CTALayoutAttr::get(enc.getContext(), // + append(enc.getCTAsPerCGA(), 1), + append(enc.getCTASplitNum(), 1), + appendMinorDim(enc.getCTAOrder()))); + return success(); + } + + LogicalResult + inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const override { + auto enc = mlir::dyn_cast(srcEnc); + if (!enc) { + return emitOptionalError(loc, + "SplitOp can only operate on BlockedEncoding"); + } + + // SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of + // shape AxBxC. The input must have 2 elements per thread in the last + // dimension, which must be most-minor. The result encoding is the same as + // the input, but with the last dimension removed. + if (enc.getSizePerThread().back() != 2) { + return emitOptionalError(loc, + "SplitOp requires 2 elements per thread in the " + "last dimension of the input"); + } + if (enc.getThreadsPerWarp().back() != 1 || + enc.getWarpsPerCTA().back() != 1 || enc.getCTAsPerCGA().back() != 1) { + return emitOptionalError( + loc, "SplitOp requires threadsPerWarp, warpsPerCTA, " + "and CTAsPerCGA = 1 for the last dimension of the input"); + } + if (enc.getOrder().front() != enc.getOrder().size() - 1) { + return emitOptionalError( + loc, "SplitOp requires the last dimension to be most-minor in order"); + } + if (enc.getCTALayout().getCTAsPerCGA().back() != 1) { + return emitOptionalError( + loc, + "SplitOp requires the last dimension to be most-minor in CTAOrder"); + } + + dstEnc = BlockedEncodingAttr::get( + enc.getContext(), // + ArrayRef(enc.getSizePerThread()).drop_back(1), + ArrayRef(enc.getThreadsPerWarp()).drop_back(1), + ArrayRef(enc.getWarpsPerCTA()).drop_back(1), + ArrayRef(enc.getOrder()).drop_front(1), + CTALayoutAttr::get(enc.getContext(), // + ArrayRef(enc.getCTAsPerCGA()).drop_back(1), + ArrayRef(enc.getCTASplitNum()).drop_back(1), + ArrayRef(enc.getCTAOrder()).drop_front(1))); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Canonicalizer +//===----------------------------------------------------------------------===// + +// reshape(cvt) -> reshape +struct CanonicalizeConvertFromReshape + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::ReshapeOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + if (isExpensiveView(convert.getSrc().getType(), op.getType())) + return failure(); + if (!op.getAllowReorder() || op.getEfficientLayout().has_value()) + return failure(); + + rewriter.replaceOpWithNewOp( + op, op.getType(), convert.getSrc(), op.getAllowReorder()); + return mlir::success(); + } +}; + +// histogram(cvt) -> histogram +struct CanonicalizeConvertFromHistogram + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::HistogramOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), convert.getSrc()); + return mlir::success(); + } +}; + +// alloc(cvt) -> alloc +struct CanonicalizeConvertFromAlloc + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::gpu::LocalAllocOp op, + PatternRewriter &rewriter) const override { + if (!op.getSrc()) + return failure(); + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), convert.getSrc()); + return mlir::success(); + } +}; + +// local_store(cvt) -> local_store +struct CanonicalizeConvertFromLocalStore + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::gpu::LocalStoreOp op, + PatternRewriter &rewriter) const override { + auto convert = op.getSrc().getDefiningOp(); + if (!convert) + return failure(); + rewriter.replaceOpWithNewOp(op, convert.getSrc(), + op.getDst()); + return mlir::success(); + } +}; + +struct CanonicalizeConvertFromConvert + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(ConvertLayoutOp op, + PatternRewriter &rewriter) const override { + // Convert to the same layout is redundant. + if (op->getResultTypes() == op->getOperandTypes()) { + rewriter.replaceOp(op, op->getOperands()); + return success(); + } + + // We don't handle conversions to DotOperandEncodingAttr. This is a + // heuristic to accommodate fused attention. + auto srcType = op.getSrc().getType(); + auto dstType = op.getType(); + if (mlir::isa(dstType.getEncoding()) && + mlir::isa(srcType.getEncoding())) + return failure(); + + // for hopper MMAv3 + if (mlir::isa(dstType.getEncoding()) && + mlir::isa(srcType.getEncoding()) && + llvm::any_of(op.getResult().getUsers(), + [](Operation *dot) { return isa(dot); })) { + return failure(); + } + + Operation *arg = op.getSrc().getDefiningOp(); + if (!arg) + return failure(); + + // cvt(reshape) -> reshape + if (auto reshape = dyn_cast(arg)) { + if (!reshape.getAllowReorder() || + reshape.getEfficientLayout().has_value() || + isExpensiveView(reshape.getSrc().getType(), op.getType())) + return failure(); + + // In TritonGPUToLLVM phase, ViewOp is converted to unpacking and packing + // operations, which requires the element type to match between unpacking + // and packing. However, part of values with dot operand encoding will be + // packed/unpacked as i32 elements instead of the underlying element type. + // To avoid errors, skip this folding when either the operand or result + // of view has a dot operand encoding. + if (hasDotOperandEncoding(op->getOperand(0)) || + hasDotOperandEncoding(op->getResult(0))) + return failure(); + + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + reshape.getResult(), + reshape.getAllowReorder()); + return success(); + } + + // cvt(histogram) -> histogram + if (auto histogram = dyn_cast(arg)) { + // For histogram ops the input and output layouts are independent, so we + // can always fold convert into the histogram op. + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + histogram.getSrc()); + return success(); + } + + // cvt(local_load) -> local_load. + if (auto sharedLoad = dyn_cast(arg)) { + // Shared_load can load to any layout so we can always fold convert into + // it. + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + sharedLoad.getSrc()); + return success(); + } + + // cvt(cat) -> cat + if (auto cat = dyn_cast(arg)) { + if (isExpensiveCat(cat, op.getType().getEncoding())) + return failure(); + + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + cat.getOperands()); + return success(); + } + + // cvt(cvt(x, type1), type2) -> cvt(x, type2) + if (auto cvt = dyn_cast(arg)) { + auto srcType = op.getSrc().getType(); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes().front(), cvt.getSrc()); + return success(); + } + + // cvt(type1, splat(type2, x)) -> splat(type1, x) + if (auto splat = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + splat.getSrc()); + return success(); + } + + // cvt(type1, make_range(type2, x)) -> make_range(type1, x) + if (auto range = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), range.getStart(), range.getEnd()); + return success(); + } + + // cvt(type, constant) -> constant + if (auto cst = llvm::dyn_cast(arg)) + if (auto ret = dyn_cast(cst.getValue())) { + auto ty = cast(op->getResultTypes().front()); + auto newRet = + SplatElementsAttr::get(ty, ret.getSplatValue()); + rewriter.replaceOpWithNewOp(op, newRet); + return success(); + } + return failure(); + } +}; + +void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); +} + +// LocalAllocOp +void LocalAllocOp::getEffects( + SmallVectorImpl> + &effects) { + Operation *op = getOperation(); + // If allocation is immutable, mark it as no side effect allow things like + // CSE, DCE to work in early compiler passes. + // After the memory offset is computed, we attach the true side effect to the + // op. + if (!getType().getMutableMemory() && !op->hasAttr("allocation.offset")) + return; + effects.emplace_back(MemoryEffects::Allocate::get(), + mlir::triton::gpu::SharedMemory::get()); + if (getSrc()) + effects.emplace_back(MemoryEffects::Write::get(), getResult(), + mlir::triton::gpu::SharedMemory::get()); +} + +// LocalLoadOp +void LocalLoadOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), getSrc(), + mlir::triton::gpu::SharedMemory::get()); +} + +// LocalStoreOp +void LocalStoreOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), getDst(), + mlir::triton::gpu::SharedMemory::get()); +} + +// AsyncCopyGlobalToLocalOp +void AsyncCopyGlobalToLocalOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), getSrc(), + mlir::triton::GlobalMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), getResult(), + mlir::triton::gpu::SharedMemory::get()); +} + +LogicalResult MemDescSubviewOp::verify() { + auto srcTy = getSrc().getType(); + auto dstTy = getType(); + + if (srcTy.getElementType() != dstTy.getElementType()) { + return emitError("result element type must match desc element type"); + } + if (getOffsets().size() != srcTy.getRank()) { + return emitError("offsets must have the same rank as input"); + } + if (srcTy.getRank() < dstTy.getRank()) { + return emitError("result rank must be less than or equal to input rank"); + } + auto rankDiff = srcTy.getRank() - dstTy.getRank(); + for (int i = 0; i < dstTy.getRank(); i++) { + if (dstTy.getDimSize(i) > srcTy.getDimSize(i + rankDiff)) { + return emitError( + "result shape cannot be larger than input shape at dimension ") + << i; + } + } + + auto srcEnc = srcTy.getEncoding(); + auto dstEnc = dstTy.getEncoding(); + if (!!srcEnc != !!dstEnc) { + return emitError("src and result must both have or not have an encoding"); + } + + if (!isa(srcEnc)) { + return emitError("src encoding must be SharedEncodingAttr"); + } + if (!isa(dstEnc)) { + return emitError("result encoding must be SharedEncodingAttr"); + } + + // TODO(jlebar): Currently we generate illegal encodings, so we can't add a + // verifier for them. In particular, we use the same encoding for the src and + // dst of a subview op, when the subview removes a dimension. That generates + // an illegal shared encoding (because the size of `order` doesn't match the + // rank of the tensor), but it's not checked anywhere, and we believe the + // resulting code ultimately works. + + return success(); +} + +void TritonGPUDialect::initialize() { + registerTypes(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc" + >(); + addInterfaces(); + addInterfaces(); +} + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc" + +// verify TritonGPU ops +LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // TODO: fill this. + return success(); +} diff --git a/third_party/xpu/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/third_party/xpu/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp new file mode 100644 index 000000000..ae34598ae --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -0,0 +1,489 @@ +#include + +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Tools/LinearLayout.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir::triton::gpu { +namespace { + +// We use the following nomenclature in this file. +// +// - ctaLayout: A layout for one block, i.e. input dims (register, lane, warp). +// - cgaLayout: Arrangement of multiple blocks, i.e. input dims (block). +// +// Note that this is inconsistent with the type name CTALayoutAttr. That type +// is equivalent to our cgaLayout. +// +// IMO the type name is wrong. If we tried to be consistent anyway, then we'd +// have to rename ctaLayout to "warpLayout". I think that's more confusing than +// being inconsistent about "cgaLayout", especially when we have to consider the +// size of the warpLayout (surely that's not the "warpSize"). + +#define S(v) StringAttr::get(ctx, (v)) + +// Returns ["out0", "out1", ..., "out"]. +SmallVector standardOutDimNames(MLIRContext *ctx, int rank) { + SmallVector ret; + for (int i = 0; i < rank; i++) { + ret.push_back(S("dim" + llvm::Twine(i))); + } + return ret; +} + +// Returns a 1D -> ND layout that's equivalent to creating a 1D -> 1D mapping of +// size product(shape) and then reshaping to permute(shape, order). +LinearLayout identityND(StringAttr inDimName, ArrayRef shape, + ArrayRef order, + ArrayRef outDimNames) { + assert(shape.size() == order.size()); + + MLIRContext *ctx = inDimName.getContext(); + LinearLayout ret = LinearLayout::empty(); + for (int i = 0; i < shape.size(); i++) { + // Start with the most-minor dimension, which is order[0]. + int dim = order[i]; + ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]); + } + return ret; +} + +// Make a LinearLayout that maps a block-id to an N-dimensional index. +// +// The tensor is split up into CTAsPerCGA pieces, which are distributed among +// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. groups). +// +// See the nomenclature note at the top of the file for an explanation of why +// this is called makeCgaLayout when it accepts a CTALayoutAttr. +LinearLayout makeCgaLayout(CTALayoutAttr layout) { + MLIRContext *ctx = layout.getContext(); + StringAttr kBlock = S("block"); + + int rank = layout.getCTAOrder().size(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + LinearLayout ret = LinearLayout::empty(); + for (int i = 0; i < rank; i++) { + // Start with the most minor dimension, which is order[0]. + int dim = layout.getCTAOrder()[i]; + int split = layout.getCTASplitNum()[dim]; + int ctas = layout.getCTAsPerCGA()[dim]; + assert(ctas % split == 0); + ret *= LinearLayout::identity1D(split, kBlock, outDimNames[dim]) * + LinearLayout::zeros1D(ctas / split, kBlock, outDimNames[dim]); + } + + // Transpose to standard order (dim0, dim1, ...). + return ret.transposeOuts(outDimNames); +} + +// Shrinks the output set of a layout function while leaving the input set +// unchanged, by making high-order inputs in inDimName map to the same output. +// Attempts to shrink down to desiredSize, but this is not always possible just +// by modifying one the specified input dimension. +// +// We do this by making the most-major inputs to the layout map to 0. This +// effectively duplicates data along that input dimension. For example, this +// layout has out-dim size 32: +// +// L(register=1) = 8 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 16. +// +// If we shrink it to size 16 along the `lane` dimension, we set L(lane=2) to 0: +// +// L(register=1) = 8 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0. +// +// This means that lane=2 has the same data as lane=0. +// +// If we shrink to size 8 along the lane dimension, we set L(lane=1) = 0 as +// well. But when we do this, we have to remove bit 1 (the value of L(lane=1)) +// from all other bases: +// +// L(register=1) = 4 +// L(register=2) = 2 +// L(register=1) = 1 +// L(lane=1) = 0 +// L(lane=2) = 0. +// +// Note this only works because the bases are powers of two. I don't quite know +// what to do when they're not. +LinearLayout shrinkCodomain(const LinearLayout &layout, StringAttr inDimName, + StringAttr outDimName, int desiredSize) { + assert(llvm::isPowerOf2_32(desiredSize)); + int outDimIdx = layout.getOutDimIndex(outDimName); + int desiredZeros = + llvm::Log2_32(layout.getOutDimSize(outDimName) / desiredSize); + if (desiredZeros == 0) { + return layout; + } + + // Find the desiredZeros most-major basis vectors that are not already zero. + // These are the ones we will set to zero. + SmallVector basesToZero; + for (int i = layout.getInDimSizeLog2(inDimName) - 1; + i >= 0 && basesToZero.size() < desiredZeros; i--) { + int basis = layout.getBasis(inDimName, i, outDimName); + if (basis != 0) { + basesToZero.push_back(basis); + } + } + + // Bail if all the bases are already zero; nothing more we can do. + if (basesToZero.empty()) { + return layout; + } + + // The algorithm below only works because the bases are powers of two. I'm + // not sure what to do otherwise. + assert(llvm::all_of(basesToZero, + [&](int basis) { return llvm::isPowerOf2_32(basis); })); + + // We want to zero out the bases in `basesToZero`, and also "shift out" the + // corresponding bits from all other bases. For example if we remove the + // basis with value 8 = 0b100, then if another basis has value 26 = 0b11010, + // the 1 in its 3rd position gets removed and it becomes 10 = 0b1010. + // + // We could manually alter the bases in `layout` to achieve this, but it's + // perhaps simpler to use the linearity of LLs to our advantage. + // + // Consider the function O which is the identity map from out-dims to + // out-dims. We can easily calculate what happens when we remove the relevant + // bases from O. Call this new function O'. + // + // Because of linearity, removing the bases from L is equivalent to composing + // L with O'. So that's what we do below. + + // Construct the out-dims -> out-dims identity layout O. + LinearLayout outputIdentity = LinearLayout::empty(); + for (StringAttr dim : layout.getOutDimNames()) { + outputIdentity *= + LinearLayout::identity1D(layout.getOutDimSize(dim), dim, dim); + } + + // Modify O to remove the relevant bases. + // + // TODO(jlebar): I don't like manually modifying bases here. Perhaps this + // should be a function on LinearLayout. + LinearLayout::BasesT newBases = outputIdentity.getBases(); + llvm::sort(basesToZero); + for (int basis : basesToZero) { + int idx = llvm::Log2_32(basis); + for (int i = newBases[outDimName].size() - 1; i > idx; i--) { + newBases[outDimName][i][outDimIdx] = + newBases[outDimName][i - 1][outDimIdx]; + } + newBases[outDimName][idx][outDimIdx] = 0; + } + + // Construct O'. + LinearLayout transform(std::move(newBases), layout.getOutDimNames()); + + // Compose O' with L. + return layout.compose(transform); +} + +// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no +// larger than shape[d]. Do this without changing the size of the layout's +// inputs (i.e. leave its domain unchanged). +// +// This function is invariant to the order of the layout's input and output +// dimensions. +LinearLayout ensureLayoutNotLargerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + MLIRContext *ctx = shape.begin()->first.getContext(); + + // For the purposes of this function, "block" is the "most-minor" dimension. + // This is just a consequence of how legacy layouts work: We only put the same + // tensor element into two different blocks as a last resort, only after all + // the registers in all the lanes in all the warps in a block already have the + // same tensor element. + SmallVector inDimNames = { + S("block"), + S("register"), + S("lane"), + S("warp"), + }; + + LinearLayout ret = layout; + for (auto outDimName : layout.getOutDimNames()) { + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + if (actualSize <= desiredSize) { + continue; + } + assert(actualSize % desiredSize == 0); + // TODO: We claim this is invariant to the order of dims, so can we get rid + // of llvm::reverse? + for (StringAttr inDimName : llvm::reverse(inDimNames)) { + if (ret.hasInDim(inDimName)) { + ret = shrinkCodomain(ret, inDimName, outDimName, desiredSize); + } + } + assert(ret.getOutDimSize(outDimName) == desiredSize); + } + return ret; +} + +// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no +// smaller than shape[d]. Do this by increasing the size of the layout's inputs +// along the "register" dimension. +// +// This function is invariant to the order of the layout's input dimensions, but +// it cares about the order of the output dims, which should be minor-to-major. +LinearLayout ensureLayoutNotSmallerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + + MLIRContext *ctx = shape.begin()->first.getContext(); + StringAttr kRegister = S("register"); + + LinearLayout ret = layout; + for (StringAttr outDimName : layout.getOutDimNames()) { + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + assert(actualSize > desiredSize || desiredSize % actualSize == 0); + ret *= LinearLayout::identity1D(desiredSize / actualSize, kRegister, + outDimName); + assert(ret.getOutDimSize(outDimName) >= desiredSize); + } + return ret; +} + +// Combines the layout of a CTA (input dims [register, lane, warp]) with the +// layout of a CGA (i.e. a block), and ensures that the resulting layout has the +// given shape. +// +// See the nomenclature note at the top of the file for why the variable with +// type CTALayoutAttr is called cgaLayoutAttr. +LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, + CTALayoutAttr cgaLayoutAttr, + ArrayRef shape) { + int rank = shape.size(); + assert(ctaLayout.getNumOutDims() == rank); + assert(cgaLayoutAttr.getCTAOrder().size() == rank); + MLIRContext *ctx = cgaLayoutAttr.getContext(); + + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + llvm::SmallDenseMap labeledShape; + for (auto [dim, size] : llvm::zip(outDimNames, shape)) { + labeledShape[dim] = size; + } + + LinearLayout cgaLayout = + ensureLayoutNotLargerThan(makeCgaLayout(cgaLayoutAttr), labeledShape) + .transposeOuts(ctaLayout.getOutDimNames()); + + // Calculate the shape of the ctaLayout, which is `shape` divided by the + // cgaLayout's size. + llvm::SmallDenseMap ctaShape; + assert(ctaLayout.getOutDimNames() == cgaLayout.getOutDimNames()); + for (auto dim : ctaLayout.getOutDimNames()) { + ctaShape[dim] = + std::max(int64_t{1}, labeledShape[dim] / cgaLayout.getOutDimSize(dim)); + } + + ctaLayout = ensureLayoutNotSmallerThan(ctaLayout, ctaShape); + ctaLayout = ensureLayoutNotLargerThan(ctaLayout, ctaShape); + + LinearLayout ret = (ctaLayout * cgaLayout).transposeOuts(outDimNames); + for (auto dim : ret.getOutDimNames()) { + assert(ret.getOutDimSize(dim) == labeledShape[dim]); + } + return ret; +} + +LinearLayout blockedToLinearLayout(ArrayRef shape, + BlockedEncodingAttr blocked) { + assert(shape.size() == blocked.getOrder().size()); + + int rank = shape.size(); + MLIRContext *ctx = blocked.getContext(); + SmallVector outDimNames = standardOutDimNames(ctx, rank); + + const auto &order = blocked.getOrder(); + LinearLayout ctaLayout = + identityND(S("register"), blocked.getSizePerThread(), order, + outDimNames) * + identityND(S("lane"), blocked.getThreadsPerWarp(), order, outDimNames) * + identityND(S("warp"), blocked.getWarpsPerCTA(), order, outDimNames); + + return combineCtaCgaWithShape(ctaLayout, blocked.getCTALayout(), shape); +} + +LinearLayout ampereMmaToLinearLayout(ArrayRef shape, + NvidiaMmaEncodingAttr mma) { + int rank = shape.size(); + + assert(mma.isAmpere()); + assert(rank == 2 || rank == 3); + assert(mma.getInstrShape().size() == rank); + assert((rank == 2 && mma.getInstrShape() == ArrayRef({16, 8})) || + (rank == 3 && mma.getInstrShape() == ArrayRef({1, 16, 8}))); + + MLIRContext *ctx = mma.getContext(); + SmallVector dimNames = standardOutDimNames(ctx, rank); + + LinearLayout ctaLayout( + {{S("register"), {{1, 0}, {0, 8}}}, + {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}}, + llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2)))); + + ctaLayout *= identityND( + S("warp"), mma.getWarpsPerCTA(), + llvm::to_vector(llvm::reverse(llvm::seq(rank))), dimNames); + + return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); +} + +LinearLayout hopperMmaToLinearLayout(ArrayRef shape, + NvidiaMmaEncodingAttr mma) { + int rank = shape.size(); + assert(mma.isHopper()); + assert(rank == 2); + + // wgmma operates on groups of 4 warps. + assert(product(mma.getWarpsPerCTA()) % 4 == 0); + + // Check that it's a known MMA layout. + assert(mma.getInstrShape().size() == 3); + int m = mma.getInstrShape()[0]; + int n = mma.getInstrShape()[1]; + int k = mma.getInstrShape()[2]; + assert(m == 16); + assert(n == 16 || n == 32 || n == 64 || n == 128 || n == 256); + assert(k == 8 || k == 16 || k == 32); + + MLIRContext *ctx = mma.getContext(); + LinearLayout ctaLayout( + {{S("register"), {{1, 0}, {0, 8}}}, + {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}}, + {S("dim1"), S("dim0")}); + + // Expand the `register` dimension so the size of dim1 matches `n`. + ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")), + S("register"), S("dim1")); + + // Expand the `warp` dimension according to warpsPerCTA. + // + // It's weird that this is order [0,1] when MMAv2's warpsPerCTA is [1,0], but + // this really does seem to be correct. + ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1}, + {S("dim0"), S("dim1")}) + .transposeOuts(ctaLayout.getOutDimNames()); + + return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); +} + +std::optional toLinearLayout(ArrayRef shape, + SliceEncodingAttr slice) { + MLIRContext *ctx = slice.getContext(); + + // First compute the linear layout for this layout's parent. + SmallVector parentShape(shape); + parentShape.insert(parentShape.begin() + slice.getDim(), 1); + std::optional parentLL = + triton::gpu::toLinearLayout(parentShape, slice.getParent()); + if (!parentLL) { + return std::nullopt; + } + + // Remove dimension slice.getDim() from the parent layout. + // + // 1. Construct a layout `transform` from parent-out-dims to slice-out-dims + // that removes the relevant out-dim. + // 2. Compute linearSlice = parent.compose(transform). Now linearSlice maps + // from parent in-dims to slice out-dims. + // 3. Fix up duplicate registers introduced by slicing. + auto outDimNames = standardOutDimNames(ctx, shape.size() + 1); + LinearLayout transform = LinearLayout::empty(); + for (auto [idx, outDim] : llvm::enumerate(parentLL->getOutDimNames())) { + if (idx == slice.getDim()) { + // Because we're multiplying by all zeros, we could replace outDimNames[0] + // with any other valid out-dim; the layout will be the same. + transform *= LinearLayout::zeros1D(parentLL->getOutDimSize(outDim), + outDim, outDimNames[0]); + } else { + transform *= LinearLayout::identity1D( + parentLL->getOutDimSize(outDim), outDim, + outDimNames[idx - (idx < slice.getDim() ? 0 : 1)]); + } + } + LinearLayout sliceLL = parentLL->compose(transform); + + // Step 3: Along the "register" dim, remove any all-zero bases. + auto bases = sliceLL.getBases(); + std::vector> newRegBases; + for (const auto &basis : bases[S("register")]) { + if (llvm::any_of(basis, [](int b) { return b != 0; })) { + newRegBases.push_back(basis); + } + } + bases[S("register")] = newRegBases; + + LinearLayout ret = LinearLayout(std::move(bases), sliceLL.getOutDimNames()); + + // Match a hack in the legacy code that ensures that the number of registers + // matches getTotalElemsPerThread. Yup: We just removed all the zeros, now + // we're (maybe) adding some back. :) + // + // TODO(jlebar): Once getTotalElemsPerThread uses LLs instead of the existing + // legacy code, I think we can remove this. + int expectedNumRegisters = getTotalElemsPerThread(RankedTensorType::get( + shape, IntegerType::get(ctx, 32) /*dummy type*/, slice)); + if (ret.getInDimSize(S("register")) != expectedNumRegisters) { + int extraZeros = expectedNumRegisters / ret.getInDimSize(S("register")); + // Our use of "dim0" here is arbitrary; because we're adding zeros, any + // output dimension would work. + ret *= LinearLayout::zeros1D(extraZeros, S("register"), S("dim0")); + } + return ret; +} + +} // anonymous namespace + +std::optional toLinearLayout(ArrayRef shape, + Attribute layout) { + if (auto blocked = dyn_cast(layout)) { + return blockedToLinearLayout(shape, blocked); + } + if (auto mma = dyn_cast(layout)) { + if (mma.isAmpere()) { + return ampereMmaToLinearLayout(shape, mma); + } + if (mma.isHopper()) { + return hopperMmaToLinearLayout(shape, mma); + } + } + if (auto slice = dyn_cast(layout)) { + return toLinearLayout(shape, slice); + } + + // TODO(jlebar): Other layouts + return std::nullopt; +} + +} // namespace mlir::triton::gpu diff --git a/third_party/xpu/lib/Dialect/TritonGPU/IR/Types.cpp b/third_party/xpu/lib/Dialect/TritonGPU/IR/Types.cpp new file mode 100644 index 000000000..77f673cc2 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/IR/Types.cpp @@ -0,0 +1,38 @@ +#include "triton/Dialect/TritonGPU/IR/Types.h" +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton::gpu; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonGPU/IR/Types.cpp.inc" + +Type TokenType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + int type = 1; + if (parser.parseInteger(type)) + return Type(); + + if (parser.parseGreater()) + return Type(); + + return TokenType::get(parser.getContext(), type); +} + +void TokenType::print(AsmPrinter &printer) const { + printer << "<" << getType() << ">"; +} + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void ::mlir::triton::gpu::TritonGPUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/TritonGPU/IR/Types.cpp.inc" + >(); +} diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp new file mode 100644 index 000000000..df84c4e62 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -0,0 +1,405 @@ +#include + +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace triton { +namespace gpu { + +namespace { + +// Get the highest version supported for the hardware and the dot. +static int getMMAVersionSafe(int computeCapability, DotOp op) { + int baseVersion = 0; + if (computeCapability < 75) { + baseVersion = 1; + } else if (computeCapability < 90) { + baseVersion = 2; + } else if (computeCapability < 100) { + baseVersion = 3; + } else { + assert(false && "computeCapability not supported"); + } + + for (; baseVersion >= 1; baseVersion--) { + if (supportMMA(op, baseVersion)) { + return baseVersion; + } + } + + return 0; +} + +SmallVector warpsPerTileV2(DotOp dotOp, const ArrayRef shape, + int numWarps) { + auto rank = shape.size(); + // Early exit for batched matmul + if (rank == 3) + return {(unsigned)numWarps, 1, 1}; + + auto filter = [&dotOp](Operation *op) { + return op->getParentRegion() == dotOp->getParentRegion() && + !isa(op); + }; + auto slices = multiRootGetSlice(dotOp, {filter}, {filter}); + bool hasChainedDot = false; + for (Operation *op : slices) { + if (isa(op) && (op != dotOp)) { + auto chainedDot = cast(op); + auto resTy = chainedDot.getResult().getType(); + if (resTy.getRank() != rank) { + continue; + } + if (auto mmaEncoding = + dyn_cast(resTy.getEncoding())) { + return getWarpsPerCTA(mmaEncoding); + } + hasChainedDot = true; + } + } + if (hasChainedDot) { + if (shape[0] >= shape[1]) { + return {(unsigned)numWarps, 1}; + } else { + return {1, (unsigned)numWarps}; + } + } + + SmallVector ret(rank, 1); + SmallVector shapePerWarp(rank, 1); + shapePerWarp[rank - 1] = 8; + shapePerWarp[rank - 2] = 16; + // TODO (@daadaada): double-check. + // original logic in + // https://github.com/triton-lang/triton/blob/master/lib/codegen/analysis/layout.cc#L252 + // seems buggy for shape = [32, 16] ? + do { + if (ret[0] * ret[1] >= numWarps) + break; + if (shape[0] / shapePerWarp[0] / ret[0] >= + shape[1] / (shapePerWarp[1] * 2) / ret[1]) { + if (ret[0] < shape[0] / shapePerWarp[0]) { + ret[0] *= 2; + } else + ret[1] *= 2; + } else { + ret[1] *= 2; + } + } while (true); + return ret; +} + +SmallVector +warpsPerTileV3(DotOp dotOp, const ArrayRef shape, int numWarps, + const SmallVector &instrShape) { + SetVector slices; + mlir::getForwardSlice(dotOp.getResult(), &slices); + if (llvm::find_if(slices, [](Operation *op) { return isa(op); }) != + slices.end()) + return {(unsigned)numWarps, 1}; + + // For MMAv3, the smallest indivisible unit of warp shape is (4, 1). + SmallVector ret = {4, 1}; + SmallVector shapePerWarp = {16, instrShape[1]}; + do { + if (ret[0] * ret[1] >= numWarps) + break; + if (shape[0] > shapePerWarp[0] * ret[0]) { + ret[0] *= 2; + } else { + ret[1] *= 2; + } + } while (true); + return ret; +} + +class BlockedToMMA : public mlir::RewritePattern { + int computeCapability; + mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding + mutable llvm::DenseMap dotOpInstNs; + + static bool bwdFilter(Operation *op) { + return op->getNumOperands() == 1 && + (isa(op) || + isPureUnaryInlineAsm(op) || + op->getDialect()->getTypeID() == + mlir::TypeID::get()); + } + + // Finds the first different bitwidth in the chain of shape-preserving + // unary ops that x depends on. + // There are two primary scenarios: + // (1) Upcasting: A sequence such as loading an fp16, followed by arithmetic + // operations, then bitcasting to fp32, and finally computing in fp32. + // (2) Downcasting: This might involve loading an fp32, performing arithmetic + // operations, bitcasting to fp16, and finally computing in fp16. + // In the upcasting scenario, element reordering converts the original + // elements distribution to the order of higher precision primitives. As a + // result, kwidth can be the bitwidth of the lower precision primitive. + // Conversely, in the downcasting scenario, no reordering is performed, + // making it directory use the lower precision primitive. + static int computeOrigBitWidth(Value x) { + int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth(); + int origBitWidth = finalBitWidth; + SetVector slice; + mlir::BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = bwdFilter; + getBackwardSlice(x, &slice, opt); + for (auto op : slice) { + if (Value arg = op->getOperand(0)) + if (auto argTy = dyn_cast(arg.getType())) { + auto argBitWidth = argTy.getElementType().getIntOrFloatBitWidth(); + if (argBitWidth != origBitWidth) { + origBitWidth = std::min(origBitWidth, argBitWidth); + break; + } + } + } + return origBitWidth; + } + +public: + BlockedToMMA(mlir::MLIRContext *context, int computeCapability) + : mlir::RewritePattern(DotOp::getOperationName(), 2, context), + computeCapability(computeCapability) {} + + static SmallVector + getWarpsPerTile(DotOp dotOp, const ArrayRef shape, int version, + int numWarps, const SmallVector &instrShape) { + switch (version) { + case 2: + return warpsPerTileV2(dotOp, shape, numWarps); + case 3: + return warpsPerTileV3(dotOp, shape, numWarps, instrShape); + default: + assert(false && "not supported version"); + return {0, 0}; + } + } + + static Value getMMAv3Operand(Value v, mlir::PatternRewriter &rewriter, + int opIdx) { + OpBuilder::InsertionGuard g(rewriter); + Value arg = v; + if (auto cvtOp = v.getDefiningOp()) + arg = cvtOp.getSrc(); + auto argType = cast(arg.getType()); + auto eltType = argType.getElementType(); + assert(argType.getEncoding() && "unexpected tensor type"); + auto newOrder = getOrder(argType.getEncoding()); + + // MMAv3 with transpose only supports f16 and bf16 data type + // fallback to MMAv3 without transpose for other data types + if (!eltType.isF16() && !eltType.isBF16()) { + if (opIdx == 1) { + newOrder = {0, 1}; + } else { + newOrder = {1, 0}; + } + } + + auto CTALayout = getCTALayout(argType.getEncoding()); + auto newLayout = + SharedEncodingAttr::get(argType.getContext(), argType.getShape(), + newOrder, CTALayout, argType.getElementType()); + auto newType = MemDescType::get(argType.getShape(), + argType.getElementType(), newLayout); + rewriter.setInsertionPointAfterValue(arg); + return rewriter.create(arg.getLoc(), newType, arg); + } + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + if (computeCapability < 70) + return failure(); + auto dotOp = cast(op); + auto ctx = op->getContext(); + // TODO: Check data-types and SM compatibility + RankedTensorType oldRetType = dotOp.getType(); + if (!oldRetType.getEncoding() || + mlir::isa(oldRetType.getEncoding())) + return failure(); + + // get MMA encoding for the given number of warps + auto retShapePerCTA = getShapePerCTA(oldRetType); + auto mod = op->getParentOfType(); + int numWarps = TritonGPUDialect::getNumWarps(mod); + auto CTALayout = getCTALayout(oldRetType.getEncoding()); + + int versionMajor = getMMAVersionSafe(computeCapability, dotOp); + if (!versionMajor) + return failure(); + + auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA, + dotOp.getA().getType(), numWarps); + // operands + Value a = dotOp.getA(); + Value b = dotOp.getB(); + auto oldAType = dotOp.getA().getType(); + auto oldBType = dotOp.getB().getType(); + + NvidiaMmaEncodingAttr mmaEnc; + if (versionMajor == 1) { + SetVector aBwdSlices, bBwdSlices; + auto isCvt = [](Operation *op) { return isa(op); }; + mlir::BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = isCvt; + getBackwardSlice(a, &aBwdSlices, opt); + getBackwardSlice(b, &bBwdSlices, opt); + // get the source of the first conversion found in slices + auto getCvtArgOrder = [](Operation *op) { + return mlir::cast( + cast(op).getSrc().getType().getEncoding()) + .getOrder(); + }; + bool isARow = true; + bool isBRow = true; + Operation *aOp = a.getDefiningOp(); + Operation *bOp = b.getDefiningOp(); + if (!aBwdSlices.empty()) + aOp = aBwdSlices[0]; + if (!bBwdSlices.empty()) + bOp = bBwdSlices[0]; + if (aOp) + isARow = getCvtArgOrder(aOp)[0] == 1; + if (bOp) + isBRow = getCvtArgOrder(bOp)[0] == 1; + + mmaEnc = NvidiaMmaEncodingAttr::get( + oldRetType.getContext(), versionMajor, numWarps, CTALayout, + instrShape, oldAType.getShape(), oldBType.getShape(), retShapePerCTA, + isARow, isBRow, mmaV1Counter++); + } else if (versionMajor == 2 || versionMajor == 3) { + int versionMinor = computeCapability == 75 ? 1 : 0; + auto warpsPerTile = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, + numWarps, instrShape); + mmaEnc = NvidiaMmaEncodingAttr::get(oldRetType.getContext(), versionMajor, + versionMinor, warpsPerTile, CTALayout, + instrShape); + } + auto newRetType = RankedTensorType::get( + oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); + // convert accumulator + auto oldAcc = dotOp.getOperand(2); + auto newAcc = + rewriter.create(oldAcc.getLoc(), newRetType, oldAcc); + + if (versionMajor == 3) { + a = getMMAv3Operand(a, rewriter, 0); + b = getMMAv3Operand(b, rewriter, 1); + } else { + + // convert operands + int minBitwidth = + std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)); + Type minType = IntegerType::get(ctx, minBitwidth); + // convert A operand + auto newAEncoding = DotOperandEncodingAttr::get( + oldAType.getContext(), 0, newRetType.getEncoding(), + minBitwidth > 0 ? minType : oldAType.getElementType()); + auto newAType = RankedTensorType::get( + oldAType.getShape(), oldAType.getElementType(), newAEncoding); + a = rewriter.create(a.getLoc(), newAType, a); + // convert B operand + auto newBEncoding = DotOperandEncodingAttr::get( + oldBType.getContext(), 1, newRetType.getEncoding(), + minBitwidth > 0 ? minType : oldBType.getElementType()); + auto newBType = RankedTensorType::get( + oldBType.getShape(), oldBType.getElementType(), newBEncoding); + b = rewriter.create(b.getLoc(), newBType, b); + } + // convert dot instruction + auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, + newAcc, dotOp.getInputPrecision(), + dotOp.getMaxNumImpreciseAcc()); + + rewriter.replaceOpWithNewOp(op, oldRetType, + newDot.getResult()); + return success(); + } +}; +} // namespace + +static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, + Type promotedType) { + Type tensorPromotedType = cast(operand.getType()) + .cloneWith(std::nullopt, promotedType); + return builder.create(loc, tensorPromotedType, operand); +} + +// promote operands of dot op if the existing combination is not natively +// supported. +static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) { + mod.walk([=](DotOp dotOp) -> void { + auto D = dotOp.getD(); + OpBuilder builder(dotOp); + Type AElType = dotOp.getA().getType().getElementType(); + Type promoteType; + NvidiaMmaEncodingAttr mmaLayout = + dyn_cast(D.getType().getEncoding()); + if (mmaLayout) { + bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ(); + // promote operands for sm < 89 since fp8 mma is not natively supported + // promote operands for sm >= 90 when mma is not v3 + if (!isNativeFP8 || + (isNativeFP8 && (computeCapability == 89 || mmaLayout.isHopper()))) + return; + promoteType = builder.getF16Type(); + } else { + // FMA case. + Type AElType = dotOp.getA().getType().getElementType(); + Type DElType = D.getType().getElementType(); + if (AElType == DElType) + return; + promoteType = DElType; + } + Location loc = dotOp.getLoc(); + Value promotedA = promoteOperand(builder, loc, dotOp.getA(), promoteType); + Value promotedB = promoteOperand(builder, loc, dotOp.getB(), promoteType); + dotOp.setOperand(0, promotedA); + dotOp.setOperand(1, promotedB); + }); +} + +#define GEN_PASS_DEF_TRITONGPUACCELERATEMATMUL +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUAccelerateMatmulPass + : public impl::TritonGPUAccelerateMatmulBase< + TritonGPUAccelerateMatmulPass> { +public: + using impl::TritonGPUAccelerateMatmulBase< + TritonGPUAccelerateMatmulPass>::TritonGPUAccelerateMatmulBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + auto computeCapability = getNVIDIAComputeCapability(m); + + mlir::RewritePatternSet patterns(context); + patterns.add(context, computeCapability); + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { + signalPassFailure(); + } + // Now that we have picked the mma type, decompose dot that are not natively + // supported. + decomposeMixedModeDotOp(m, computeCapability); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..7b2ab63e8 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -0,0 +1,31 @@ +add_triton_library(TritonGPUTransforms + AccelerateMatmul.cpp + Coalesce.cpp + F32DotTC.cpp + CombineTensorSelectAndIf.cpp + ReduceDataDuplication.cpp + OptimizeDotOperands.cpp + OptimizeThreadLocality.cpp + Pipeliner/MatmulLoopPipeline.cpp + Pipeliner/OuterLoopPipeline.cpp + Pipeliner/PipelineExpander.cpp + Pipeliner/SoftwarePipeliner.cpp + Pipeliner/TMAStoresPipeline.cpp + Pipeliner/PipeliningUtility.cpp + Prefetch.cpp + RemoveLayoutConversions.cpp + ReorderInstructions.cpp + Utility.cpp + + DEPENDS + TritonGPUTransformsIncGen + + LINK_LIBS PUBLIC + MLIRTransforms + MLIRTransformUtils + TritonAnalysis + TritonIR + TritonGPUIR + TritonNvidiaGPUIR + MLIRTransformUtils +) diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp new file mode 100644 index 000000000..06a7d963d --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -0,0 +1,198 @@ +#include +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "tritongpu-coalesce" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUCOALESCE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +struct CoalescePass : public impl::TritonGPUCoalesceBase { + void + setCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op, + int numWarps, int threadsPerWarp, + llvm::MapVector &layoutMap) { + Value ptr = getMemAccessPtr(op); + auto refTensorType = cast(ptr.getType()); + + LDBG("Considering op: " << *op); + LLVM_DEBUG({ + DBGS() << "axis info of pointer: "; + axisInfoAnalysis.getAxisInfo(ptr)->print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity(); + SmallVector order = argSort(contiguity); + LDBG("order=[" << triton::join(order, ", ") << "]"); + + auto matchesShape = [&refTensorType](const Value &val) { + auto rttType = dyn_cast(val.getType()); + return rttType && rttType.getShape() == refTensorType.getShape(); + }; + + // The desired divisibility is the maximum divisibility among all dependent + // pointers which have the same shape and order as `ptr`. + llvm::SmallSetVector memAccessesSameOrder; + memAccessesSameOrder.insert(op); + if (ptr.getDefiningOp()) { + for (Operation *use : mlir::multiRootGetSlice(op)) { + Value val = getMemAccessPtr(use); + if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use)) + continue; + auto currOrder = + argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity()); + if (order == currOrder) { + LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use); + memAccessesSameOrder.insert(use); + } + } + } + + auto shapePerCTA = triton::gpu::getShapePerCTA(refTensorType); + LDBG("shapePerCTA=[" << triton::join(shapePerCTA, ", ") << "]"); + + int numElems = product(shapePerCTA); + int numThreads = numWarps * threadsPerWarp; + + unsigned perThread = getNumElementsPerThread(op, order, axisInfoAnalysis); + LDBG("perThread for op: " << perThread); + + for (Operation *opSameOrder : memAccessesSameOrder) { + if (opSameOrder == op) + continue; + unsigned currPerThread = + getNumElementsPerThread(opSameOrder, order, axisInfoAnalysis); + LDBG("perThread for opSameOrder: " << currPerThread); + perThread = std::max(perThread, currPerThread); + } + + perThread = std::min(perThread, std::max(numElems / numThreads, 1)); + LDBG("perThread: " << perThread); + + if (!dyn_cast(op)) { + // For ops that can result in a global memory write, we should enforce + // that each thread handles at most 128 bits, which is the widest + // available vectorized store op; otherwise, the store will have "gaps" + // in the memory write at the warp level, resulting in worse performance. + // For loads, we can expect that the gaps won't matter due to the L1 + // cache. + unsigned elemNumBits = getElementBitWidth(refTensorType); + perThread = std::min( + perThread, getNumElementsPerThread(op, order, axisInfoAnalysis)); + } + SmallVector sizePerThread(refTensorType.getRank(), 1); + sizePerThread[order[0]] = perThread; + + auto CTALayout = triton::gpu::getCTALayout(refTensorType.getEncoding()); + layoutMap[op] = triton::gpu::BlockedEncodingAttr::get( + &getContext(), refTensorType.getShape(), sizePerThread, order, numWarps, + threadsPerWarp, CTALayout); + } + + static Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = cast(type); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + } + + void coalesceOp(Attribute encoding, Operation *op) { + OpBuilder builder(op); + // Convert operands + // For load/store with tensor pointers, we don't have to change the + // operands' type, we do this by changing the outputs' type of + // `make_tensor_ptr` + SmallVector newArgs; + for (auto operand : op->getOperands()) { + auto tensorType = dyn_cast(operand.getType()); + if (tensorType && + !isa(tensorType.getEncoding())) { + Type newType = getNewType(tensorType, encoding); + newArgs.push_back(builder.create( + op->getLoc(), newType, operand)); + } else { + newArgs.push_back(operand); + } + } + + // Convert output types + SmallVector newTypes; + for (auto t : op->getResultTypes()) { + bool isAsync = isa(op); + newTypes.push_back(isAsync ? t : getNewType(t, encoding)); + } + + // Construct new op with the new encoding + Operation *newOp = + builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs, + newTypes, op->getAttrs()); + + // Cast the results back to the original layout + for (size_t i = 0; i < op->getNumResults(); i++) { + Value newResult = newOp->getResult(i); + if (newTypes[i] != op->getResultTypes()[i]) { + newResult = builder.create( + op->getLoc(), op->getResult(i).getType(), newResult); + } + op->getResult(i).replaceAllUsesWith(newResult); + } + op->erase(); + } + + void runOnOperation() override { + // Run axis info analysis + ModuleOp moduleOp = getOperation(); + ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // For each i/o operation, we determine what layout + // the pointers should have for best memory coalescing + llvm::MapVector layoutMap; + moduleOp.walk([&](Operation *curr) { + Value ptr = getMemAccessPtr(curr); + if (!ptr) + return; + // We only convert `tensor>` load/store + bool isPtrTensor = false; + if (auto tensorType = dyn_cast(ptr.getType())) + isPtrTensor = isa(tensorType.getElementType()); + if (!isPtrTensor) + return; + auto mod = curr->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp, + layoutMap); + }); + + // For each memory op that has a layout L1: + // 1. Create a coalesced memory layout L2 of the pointer operands + // 2. Convert all operands from layout L1 to layout L2 + // 3. Create a new memory op that consumes these operands and + // produces a tensor with layout L2 + // 4. Convert the output of this new memory op back to L1 + // 5. Replace all the uses of the original memory op by the new one + for (auto &kv : layoutMap) { + coalesceOp(kv.second, kv.first); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp new file mode 100644 index 000000000..16183b1af --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp @@ -0,0 +1,124 @@ +#include "mlir/IR/Dominance.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +#include + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUCOMBINETENSORSELECTANDIF +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// Return true if the select could be merged into the If without breaking SSA +// rules. +static bool canMergeIntoIf(arith::SelectOp selectOp, scf::IfOp ifOp, + DominanceInfo &dom) { + // If needs to be dominated by the select. + if (!dom.dominates(selectOp.getOperation(), ifOp.getOperation())) { + return false; + } + // If needs to dominate all the select's users. + for (auto user : selectOp.getResult().getUsers()) { + if (!dom.dominates(ifOp, user)) { + return false; + } + } + return true; +} + +class CombineTensorSelectAndIfPass + : public impl::TritonGPUCombineTensorSelectAndIfBase< + CombineTensorSelectAndIfPass> { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + DominanceInfo dom(m); + + // Go over the arith.select ops, look if there is an if + // with the same condition. + llvm::MapVector> selectToIf; + m.walk([&](arith::SelectOp selectOp) { + // Look if there is an if in the same block, with the same condition. + auto *parentBlock = selectOp->getBlock(); + Value condition = selectOp.getOperand(0); + SetVector conditionUsers(condition.getUsers().begin(), + condition.getUsers().end()); + // sort the users in topological order. + conditionUsers = multiRootTopologicalSort(conditionUsers); + // Get condition's users + for (Operation *user : conditionUsers) { + auto ifOp = dyn_cast(user); + if (!ifOp || ifOp->getBlock() != parentBlock) + continue; + if (canMergeIntoIf(selectOp, ifOp, dom)) { + selectToIf[ifOp].push_back(selectOp); + break; + } + } + }); + + for (auto [ifOp, selectOps] : selectToIf) { + // Add new return value to the if (and create else block if necessary), + // then yield the select value in the then block and the else block. + OpBuilder builder(ifOp); + auto loc = ifOp.getLoc(); + // Create an scf::IfOp with extra return value. + SmallVector newResultTypes = {ifOp.getResultTypes().begin(), + ifOp.getResultTypes().end()}; + for (arith::SelectOp selectOp : selectOps) { + newResultTypes.push_back(selectOp.getResult().getType()); + } + auto newIfOp = builder.create( + loc, newResultTypes, ifOp.getCondition(), /*hasElse*/ true); + // Move the existing blocks to the new if. + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + + if (ifOp.elseBlock()) { + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + } else { + // Create an empty yield + auto yieldOp = newIfOp.getElseBodyBuilder().create(loc); + } + + SmallVector ifYieldOperands = newIfOp.thenYield().getOperands(); + SmallVector elseYieldOperands = newIfOp.elseYield().getOperands(); + for (arith::SelectOp selectOp : selectOps) { + Value thenValue = selectOp.getTrueValue(); + Value elseValue = selectOp.getFalseValue(); + ifYieldOperands.push_back(thenValue); + elseYieldOperands.push_back(elseValue); + } + // Update yields + auto updateYield = [&](scf::YieldOp yield, SmallVector &operands) { + builder.setInsertionPoint(yield); + builder.create(loc, operands); + yield.erase(); + }; + updateYield(newIfOp.thenYield(), ifYieldOperands); + updateYield(newIfOp.elseYield(), elseYieldOperands); + + int resultIdx = 0; + // Replace old if with the new one. + for (auto result : ifOp.getResults()) { + result.replaceAllUsesWith(newIfOp->getResult(resultIdx++)); + } + // Replace the select with the new return value. + for (arith::SelectOp selectOp : selectOps) { + selectOp.replaceAllUsesWith(newIfOp->getResult(resultIdx++)); + selectOp.erase(); + } + + ifOp.erase(); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp new file mode 100644 index 000000000..f701634d4 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp @@ -0,0 +1,90 @@ +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUF32DOTTC +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +// nb. We call the trick TF32x3 as C++ disallows varaibles starting with numbers +// Implement 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385 +// For a, b f32 +// dot(a, b, inputPrecision="tf32x3") -> +// let aBig = f32ToTF32(a), aSmall = a - aBig; +// let bBig = f32ToTF32(b), bSmall = b - bBig; +// dot(aSmall, bBig, inputPrecision="tf32") + +// dot(aBig, bSmall, inputPrecision="tf32") + +// dot(aBig, bBig, inputPrecision="tf32") +class TF32x3 : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DotOp dotOp, + PatternRewriter &rewriter) const override { + + auto isF32 = [](Value operand) { + return cast(operand.getType()).getElementType().isF32(); + }; + + if (!(dotOp.getInputPrecision() == InputPrecision::TF32x3 && + isF32(dotOp.getA()) && isF32(dotOp.getB()))) { + return failure(); + } + + // Aux functions + auto f32ToTF32 = [&](Value value) -> Value { + return rewriter + .create(dotOp.getLoc(), value.getType(), + "cvt.rna.tf32.f32 $0, $1;", "=r,r", + /*isPure=*/true, /*pack=*/1, + ArrayRef{value}) + .getResult()[0]; + }; + auto sub = [&](Value a, Value b) -> Value { + return rewriter.create(dotOp.getLoc(), a, b); + }; + auto dot = [&](Value a, Value b, Value c) -> Value { + return rewriter.create(dotOp->getLoc(), c.getType(), a, b, c, + InputPrecision::TF32, + dotOp.getMaxNumImpreciseAcc()); + }; + + auto aBig = f32ToTF32(dotOp.getA()); + auto aSmall = sub(dotOp.getA(), aBig); + + auto bBig = f32ToTF32(dotOp.getB()); + auto bSmall = sub(dotOp.getB(), bBig); + + auto dot1 = dot(aSmall, bBig, dotOp.getC()); + auto dot2 = dot(aBig, bSmall, dot1); + auto dot3 = dot(aBig, bBig, dot2); + + rewriter.replaceOp(dotOp, dot3); + return success(); + } +}; + +} // anonymous namespace + +struct F32DotTCPass : public impl::TritonGPUF32DotTCBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + RewritePatternSet decomposePatterns(context); + decomposePatterns.add(context); + if (applyPatternsAndFoldGreedily(m, std::move(decomposePatterns)) + .failed()) { + signalPassFailure(); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp new file mode 100644 index 000000000..4a30bf9f3 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -0,0 +1,340 @@ +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include + +namespace mlir { +namespace triton { +namespace gpu { + +namespace { + +// Given +// convert(trans(src)) #dot_operand -> +// convert(local_load(trans(alloc(src)))) +// change the encoding of the inner convert to a special, swizzled shared +// encoding. +class SwizzleShmemConvert : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp, + PatternRewriter &rewriter) const override { + // Match outerCvt(trans(innerCvt(x))). + auto trans = cvtOp.getSrc().getDefiningOp(); + if (!trans || trans.getOrder() != ArrayRef{1, 0}) + return failure(); + + auto srcTy = dyn_cast(trans.getSrc().getType()); + + if (auto srcCvt = trans.getSrc().getDefiningOp()) { + srcTy = srcCvt.getSrc().getType(); + } + auto sharedLoadTy = cast(cvtOp.getType()); + auto cvtEncoding = + dyn_cast(sharedLoadTy.getEncoding()); + if (!cvtEncoding) + return failure(); + + // TODO(Qingyi): need to check whether the CTALayout of innerCvtEnc should + // be used here. For tests where numCTAs = 1, this is not a problem since + // all CTALayouts are the same. + // + // Set needTrans to true here. newInnerCvtEnc is computed based on + // argEncoding which is before the transpose. Without needTrans we will + // compute vec and maxPhase based on incorrect m, n and k size of mma. The + // type inference of TransOp simply swap the order but doesn't fix the vec + // and maxPhase for the YType, hence it would causing incorrect swizzling + // code. + auto newInnerCvtEnc = + SharedEncodingAttr::get(getContext(), cvtEncoding, srcTy.getShape(), + /*order=*/getOrder(srcTy.getEncoding()), + triton::gpu::getCTALayout(srcTy.getEncoding()), + srcTy.getElementType(), /*needTrans=*/true); + if (newInnerCvtEnc == cvtEncoding) + return failure(); + + rewriter.setInsertionPoint(trans); + auto alloc = rewriter.create( + trans.getLoc(), + MemDescType::get(srcTy.getShape(), srcTy.getElementType(), + newInnerCvtEnc), + trans.getSrc()); + auto newTrans = rewriter.create(trans.getLoc(), alloc, + ArrayRef({1, 0})); + rewriter.replaceOpWithNewOp(trans, sharedLoadTy, newTrans); + return success(); + } +}; + +// Move convert-to-dot-operand "up" past elementwise ops: +// +// convert(elementwise(x)) #dot_operand -> +// elementwise(convert(x, #dot_operand)). +// +// The goal is to put the convert right next to the originating load. If we can +// accomplish this, then we can save a shmem round-trip: +// +// Before: +// +// - Load from global into shmem using an async copy. +// - Load from shmem into a #blocked layout. +// - Do elementwise ops over #blocked layout. +// - Convert to #dot_operand (round-trip through shmem). +// - Do dot. +// +// After: +// +// - Load from global into shmem using an async copy (same as before). +// - Load from shmem into a #dot_operand layout. +// - Do elementwise ops over #dot_operand layout. +// - Do dot. +// +// Eliminating the shmem round-trip is such a big win, we're willing to do it +// even if this duplicates work because some of the elementwise ops have uses +// that don't flow into the dot. On the other hand, we only want to do this if +// we can in fact reduce shmem round-trips: For example, simply moving a convert +// up above e.g. an `add` now means we have *two* converts. That's worse, +// unless we can continue moving the converts upwards and eventually merge them. +// So we try to check that this will be beneficial before making any changes. +class HoistLayoutConversion : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConvertLayoutOp cvt, + PatternRewriter &rewriter) const override { + // Only consider conversions to dot operand. + auto cvtTy = cast(cvt.getType()); + if (!isa(cvtTy.getEncoding())) + return failure(); + + auto src = cvt.getSrc().getDefiningOp(); + if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1) + return failure(); + + auto srcTy = dyn_cast(src->getResult(0).getType()); + if (!srcTy) + return failure(); + + if (!all_of(src->getOperandTypes(), + [](Type ty) { return isa(ty); })) + return failure(); + + // Only consider custom conversions or arith ops. + // TODO(jlebar): Is this too restrictive? + if (!isa(src) && !isPureUnaryInlineAsm(src) && + src->getDialect()->getTypeID() != TypeID::get()) + return failure(); + + // Currently, these instructions are not supported during lowering of + // shared -> dot_operand layout. Not all types and type conversions are + // supported. + if (isa(src)) + return failure(); + + // Check that the conversion is transitively dependent on a load, and all + // operations between the load and the conversion are layout preserving. + // + // TODO(jlebar): This is accidentally quadratic; we iterate over the whole + // slice but then at the end we only modify one op! + SetVector slice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + // TODO(jlebar): Is this filter redundant with omitBlockArguments == true? + // That is, is it possible to get into a different region without going + // through a block argument? + opt.filter = [&](Operation *op) { + return op->getParentRegion() == cvt->getParentRegion(); + }; + getBackwardSlice(cvt.getOperation(), &slice, opt); + + // TODO(jlebar): This is too conservative when there are multiple loads in + // the chain (e.g. cvt(load(x) + load(y))). The intent is to check that all + // of the ops between the loads and the convert are elementwise. But + // actually we set foundLoad = true once we see the first load, and so we + // will reject the chain if the *second* load we encounter uses a + // non-elementwise op to calculate its pointers. + bool foundLoad = false; + for (Operation *currOp : slice) { + if (isa(currOp)) { + foundLoad = true; + } else if (foundLoad) { + // Bail out if there exists an op after Load that is not FpToFp, + // Bitcast, or Arith. + if (!isa(currOp) && + !isPureUnaryInlineAsm(currOp) && + currOp->getDialect()->getTypeID() != + TypeID::get()) + return failure(); + } + } + if (!foundLoad) + return failure(); + + SmallVector newOperands; + for (auto operand : src->getOperands()) { + // We checked earlier that all operands are ranked tensors. + auto operandTy = cast(operand.getType()); + Type newCvtTy = RankedTensorType::get( + srcTy.getShape(), operandTy.getElementType(), cvtTy.getEncoding()); + newOperands.push_back( + rewriter.create(cvt.getLoc(), newCvtTy, operand)); + } + auto newRet = rewriter.clone(*src); + for (int i = 0; i < newOperands.size(); i++) + newRet->setOperand(i, newOperands[i]); + newRet->getResult(0).setType(RankedTensorType::get( + srcTy.getShape(), srcTy.getElementType(), cvtTy.getEncoding())); + + rewriter.replaceOp(cvt, newRet->getResults()); + return success(); + } +}; + +// Rewrite +// +// dot(alloc(trans() #shared1) -> +// dot(trans(alloc() #shared2)) +// +// if dot is an MMAv3 (because MMAv3 allows us to fold transposes). +class FuseTransHopper : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LocalAllocOp allocOp, + PatternRewriter &rewriter) const override { + if (!allocOp->hasOneUse() || + !isa(*allocOp->getUsers().begin())) + return failure(); + + auto dot = *allocOp->getUsers().begin(); + auto dotEnc = dyn_cast( + cast(dot->getResult(0).getType()).getEncoding()); + if (!dotEnc || dotEnc.getVersionMajor() != 3) + return failure(); + + if (!allocOp.getSrc()) + return failure(); + + // Match outerCvt(trans(innerCvt(x))). + auto trans = allocOp.getSrc().getDefiningOp(); + if (!trans || trans.getOrder() != ArrayRef({1, 0})) + return failure(); + + MemDescType allocType = allocOp.getType(); + auto allocEncoding = cast(allocType.getEncoding()); + TensorOrMemDesc srcTy = trans.getSrc().getType(); + + // MMAv3 with transpose only supports f16 and bf16. Fall back to MMAv3 + // without transpose for other data types.) + auto newInnerCvtOrder = getOrder(srcTy.getEncoding()); + if (auto cvt = trans.getSrc().getDefiningOp()) { + newInnerCvtOrder = getOrder(cvt.getSrc().getType().getEncoding()); + } + auto srcElemTy = allocType.getElementType(); + if (!srcElemTy.isF16() && !srcElemTy.isBF16()) { + if (allocOp.getResult() == dot->getOperand(0)) { + newInnerCvtOrder = {0, 1}; + } else if (allocOp.getResult() == dot->getOperand(1)) { + newInnerCvtOrder = {1, 0}; + } + } + + // TODO(Qingyi): need to check whether the CTALayout of innerCvtEnc should + // be used here. For tests where numCTAs = 1, this is not a problem since + // all CTALayouts are the same. + auto newInnerEnc = SharedEncodingAttr::get( + getContext(), srcTy.getShape(), newInnerCvtOrder, + allocEncoding.getCTALayout(), srcTy.getElementType()); + + MemDescType innerTy = + MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc); + auto newAlloc = rewriter.create(allocOp.getLoc(), innerTy, + trans.getSrc()); + rewriter.replaceOpWithNewOp(allocOp, newAlloc, + ArrayRef({1, 0})); + return success(); + } +}; + +// Rewrite +// dot(convert(lhs #mma) #shared, rhs) #mma -> +// dot(convert(lhs #mma) #dot_operand, rhs) #mma, +// for fp16 or bf16 MMAv3 dots. +struct MMAV3UseRegOperand : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DotOp dotOp, + PatternRewriter &rewriter) const override { + auto alloc = dotOp.getOperand(0).getDefiningOp(); + if (!alloc || !alloc.getSrc()) + return failure(); + + auto getEncoding = [](Value v) { + return cast(v.getType()).getEncoding(); + }; + + if (!isa(getEncoding(dotOp.getOperand(0)))) + return failure(); + auto srcEnc = dyn_cast(getEncoding(alloc.getSrc())); + auto dstEnc = + dyn_cast(getEncoding(dotOp.getResult())); + if (!srcEnc || srcEnc.getVersionMajor() != 3 || !dstEnc || + dstEnc.getVersionMajor() != 3) + return failure(); + auto srcTy = cast(alloc.getSrc().getType()); + auto dotOperandEnc = DotOperandEncodingAttr::get( + dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/0); + auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), + dotOperandEnc); + if (!isMmaToDotShortcut(srcTy, newTy)) + return failure(); + + Value newOperand = + rewriter.create(dotOp.getLoc(), newTy, alloc.getSrc()); + rewriter.modifyOpInPlace(dotOp, [&]() { dotOp.setOperand(0, newOperand); }); + return success(); + } +}; + +} // namespace + +#define GEN_PASS_DEF_TRITONGPUOPTIMIZEDOTOPERANDS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUOptimizeDotOperandsPass + : public impl::TritonGPUOptimizeDotOperandsBase< + TritonGPUOptimizeDotOperandsPass> { +public: + using impl::TritonGPUOptimizeDotOperandsBase< + TritonGPUOptimizeDotOperandsPass>::TritonGPUOptimizeDotOperandsBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + mlir::PassManager pm(m.getContext()); + pm.addPass(mlir::createCanonicalizerPass()); + auto ret = pm.run(m); + + mlir::RewritePatternSet patterns(context); + patterns.add(context); + if (this->hoistLayoutConversion.getValue()) + patterns.add(context); + patterns.add(context); + patterns.add(context); + ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); + if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp new file mode 100644 index 000000000..30211da08 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp @@ -0,0 +1,436 @@ +#include +#include + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUOPTIMIZETHREADLOCALITY +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { +// Change the destination layout of reshape ops allowing reorder when used by a +// reduction in order to minimize the amount of cross thread communication for +// the reduction. +struct OptimizeReshapeLayoutPattern + : public mlir::OpRewritePattern { + OptimizeReshapeLayoutPattern(mlir::MLIRContext *context) + : OpRewritePattern(context, 1) {} + + mlir::LogicalResult + matchAndRewrite(triton::ReshapeOp viewOp, + mlir::PatternRewriter &rewriter) const override { + if (!viewOp.getAllowReorder()) + return failure(); + std::optional reductionAxis; + for (Operation *user : viewOp.getResult().getUsers()) { + if (auto reduceOp = dyn_cast(user)) { + if (reductionAxis) { + if (reductionAxis != reduceOp.getAxis()) + return failure(); + } else { + reductionAxis = reduceOp.getAxis(); + } + } + } + if (!reductionAxis) + return failure(); + RankedTensorType tensorType = viewOp.getType(); + if (auto blocked = mlir::dyn_cast( + tensorType.getEncoding())) { + // If the layout already has all the elements along the reduction + // dimension in the same thread we can skip. + if (blocked.getThreadsPerWarp()[*reductionAxis] == 1 && + blocked.getWarpsPerCTA()[*reductionAxis] == 1 && + blocked.getCTAsPerCGA()[*reductionAxis] == 1) + return failure(); + } + ArrayRef shape = tensorType.getShape(); + llvm::SmallVector order; + for (int i : triton::gpu::getOrder(tensorType.getEncoding())) { + if (i != *reductionAxis) + order.push_back(i); + } + // Make the reduction axis last so that elements won't be distributed + // amongst threads along this dimension. + order.push_back(*reductionAxis); + llvm::SmallVector sizePerThread(shape.size(), 1); + auto mod = viewOp->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); + triton::gpu::BlockedEncodingAttr encoding = + triton::gpu::BlockedEncodingAttr::get(viewOp.getContext(), shape, + sizePerThread, order, numWarps, + threadsPerWarp, numCTAs); + if (encoding == tensorType.getEncoding()) + return failure(); + RankedTensorType newType = + RankedTensorType::get(shape, tensorType.getElementType(), encoding); + if (triton::gpu::isExpensiveView(viewOp.getSrc().getType(), newType)) + return failure(); + rewriter.setInsertionPointAfter(viewOp); + rewriter.modifyOpInPlace(viewOp, [&]() { + viewOp.getResult().setType(newType); + viewOp.setEfficientLayout(true); + }); + auto cvt = rewriter.create( + viewOp.getLoc(), tensorType, viewOp.getResult()); + rewriter.replaceAllUsesExcept(viewOp.getResult(), cvt.getResult(), cvt); + return mlir::success(); + } +}; + +} // namespace + +class TritonGPUOptimizeThreadLocalityPass + : public impl::TritonGPUOptimizeThreadLocalityBase< + TritonGPUOptimizeThreadLocalityPass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + + // First try to optimize the layout of existing views. + mlir::RewritePatternSet viewLayoutPatterns(&getContext()); + viewLayoutPatterns.add(&getContext()); + if (mlir::applyPatternsAndFoldGreedily(mod, std::move(viewLayoutPatterns)) + .failed()) { + signalPassFailure(); + } + + DenseSet reduceOps; + mod.walk([&](triton::ReduceOp reduce) -> void { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + auto srcEncoding = srcType.getEncoding(); + auto reductionOp = getReductionOp(reduce); + if (!reductionOp || + !isa( + reductionOp.value())) + return; + // TODO: relax this restriction + if (!(isa(srcEncoding) && rank > 1)) + return; + for (auto operand : reduce->getOperands()) { + auto def = operand.getDefiningOp(); + if (!isa(def)) + return; + } + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + // Not worth applying this optimization if there is only one element per + // thread on the reduction axis + if (elemsPerThread == 1) + return; + if (!reduce->hasOneUse()) + return; + Operation *user = *(reduce->getUsers().begin()); + if (!user->hasOneUse()) + return; + OpOperand &yieldOpOperand = *(user->getUses().begin()); + auto yieldOp = dyn_cast(yieldOpOperand.getOwner()); + if (!yieldOp) + return; + auto operandNumber = yieldOpOperand.getOperandNumber(); + Block *block = reduce->getBlock(); + Operation *parentOp = block->getParentOp(); + auto forOp = dyn_cast(parentOp); + if (!forOp) + return; + auto argNum = yieldOpOperand.getOperandNumber(); + auto oldAccum = forOp.getInitArgs()[argNum]; + auto cstOp = dyn_cast(oldAccum.getDefiningOp()); + if (!cstOp) + return; + reduceOps.insert(reduce); + }); + + IRRewriter builder(&getContext()); + for (auto reduce : reduceOps) { + builder.setInsertionPoint(reduce); + auto srcType = cast(reduce.getOperands()[0].getType()); + auto srcShape = srcType.getShape(); + auto srcEncoding = srcType.getEncoding(); + assert(isa(srcEncoding) && + "Thread locality optimization only supports blocked encoding"); + auto blocked = dyn_cast(srcEncoding); + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + auto rank = srcShape.size(); + // create new layouts + auto blocked3d = getThreadLocalityOptimizedEncoding(reduce); + auto viewOpTensorShape = getThreadLocalityOptimizedShape(reduce); + auto viewOpTensorType = RankedTensorType::get( + viewOpTensorShape, srcType.getElementType(), blocked3d); + auto slice2d = triton::gpu::SliceEncodingAttr::get(mod.getContext(), rank, + blocked3d); + // Get forOp + assert(reduce->hasOneUse()); + OpOperand &use = *(reduce->getUses().begin()); + auto operandNumber = use.getOperandNumber(); + auto oldUpdate = use.getOwner(); + assert(oldUpdate->getNumOperands() == 2); + auto accumOperandNumber = (operandNumber == 0) ? 1 : 0; + auto accumOperand = oldUpdate->getOperand(accumOperandNumber); + assert(isa(accumOperand)); + auto blockArg = dyn_cast(accumOperand); + auto blockArgNum = blockArg.getArgNumber(); + auto forOp = dyn_cast(blockArg.getOwner()->getParentOp()); + // get oldAccum + auto oldAccum = + forOp.getInitArgs()[blockArgNum - forOp.getNumInductionVars()]; + // get old loop user + Value loopResult = + forOp.getResult(blockArgNum - forOp.getNumInductionVars()); + assert(loopResult.hasOneUse()); + OpOperand &loopUse = *(loopResult.getUses().begin()); + Operation *loopUser = loopUse.getOwner(); + // get old loop yield + auto oldYield = cast(forOp.getBody()->getTerminator()); + // create newAccum initialization + auto newAccum = + createAccum(builder, reduce, oldAccum, viewOpTensorShape, slice2d); + // create new loop by copying the old for op signature and appending + // newAccum to the block arguments + auto newLoop = replaceForOpWithNewSignature( + builder, forOp, ValueRange{newAccum->getResult(0)}); + // create thread local reduction (also adds viewOps) + auto newReduce = createReduce(builder, reduce, viewOpTensorType); + + // create new accum update + auto newUpdate = createUpdate(builder, newLoop, newReduce, oldUpdate); + // create new yield + auto newYield = createYield(builder, newLoop, oldYield, + newUpdate->getResult(0), blockArgNum); + // create post loop reduction on the original reduce axis + auto newReduce2 = createPostLoopReduce(builder, newLoop, reduce); + // add convert_layout to get back to original layout, the result layout + // should now match the layout of the old accumulator (%cst) + Type destType = loopResult.getType(); + auto cvtLayout = createConvertLayout(builder, destType, newReduce2); + // incorporate the original accumulator value into the final result + auto finalOp = incorporateOriginalAccumulatorValue(builder, oldUpdate, + cvtLayout, oldAccum); + // Replace the old loop user with the final result + loopUser->setOperand(loopUse.getOperandNumber(), finalOp->getResult(0)); + + // cleanup + oldYield.erase(); + forOp.erase(); + } + }; + +private: + std::optional getReductionOp(triton::ReduceOp reduce) const { + auto numRegions = reduce->getNumRegions(); + if (numRegions != 1) + return std::nullopt; + Region ®ion = reduce->getRegion(0); + auto numBlocks = region.getBlocks().size(); + if (numBlocks != 1) + return std::nullopt; + Block &block = region.front(); + auto blockWithoutTerminator = block.without_terminator(); + auto blockSizeWithoutTerminator = std::distance( + blockWithoutTerminator.begin(), blockWithoutTerminator.end()); + if (blockSizeWithoutTerminator != 1) + return std::nullopt; + Operation *op = &block.front(); + return std::optional(op); + } + Operation *incorporateOriginalAccumulatorValue(OpBuilder &builder, + Operation *oldUpdate, + Operation *cvtLayout, + Value oldAccum) const { + builder.setInsertionPointAfter(cvtLayout); + IRMapping mapping; + mapping.map(oldUpdate->getOperand(0), oldAccum); + mapping.map(oldUpdate->getOperand(1), cvtLayout->getResult(0)); + auto finalOp = cloneWithInferType(builder, &(*oldUpdate), mapping); + return finalOp; + } + Operation *createConvertLayout(OpBuilder &builder, Type destType, + Operation *newReduce) const { + builder.setInsertionPointAfter(newReduce); + auto newCvt = builder.create( + newReduce->getLoc(), destType, newReduce->getResult(0)); + return newCvt; + } + + Operation *createPostLoopReduce(OpBuilder &builder, scf::ForOp &loop, + triton::ReduceOp &reduce) const { + auto resultIndex = + loop.getBody()->getNumArguments() - 1 - loop.getNumInductionVars(); + auto newLoopResult = loop.getResult(resultIndex); + builder.setInsertionPointAfter(loop); + IRMapping mapping; + mapping.map(*(reduce.getOperands().begin()), newLoopResult); + auto newReduce2 = cloneWithInferType(builder, &(*reduce), mapping); + return newReduce2; + } + + Operation *createYield(OpBuilder &builder, scf::ForOp &loop, + scf::YieldOp &oldYield, Value newUpdate, + int oldAccumBlockArgNum) const { + builder.setInsertionPoint(oldYield); + SmallVector yieldValues = llvm::to_vector(oldYield.getOperands()); + yieldValues[oldAccumBlockArgNum - 1] = + loop.getBody()->getArgument(oldAccumBlockArgNum); + yieldValues.push_back(newUpdate); + auto newYield = + builder.create(oldYield.getLoc(), yieldValues); + return newYield; + } + + Operation *createUpdate(OpBuilder &builder, scf::ForOp &loop, + Operation *newReduce, Operation *oldUpdate) const { + auto blockArgNum = loop.getBody()->getNumArguments() - 1; + auto newArg = loop.getBody()->getArgument(blockArgNum); + builder.setInsertionPointAfter(newReduce); + IRMapping mapping; + mapping.map(oldUpdate->getOperand(0), newArg); + mapping.map(oldUpdate->getOperand(1), newReduce->getResult(0)); + auto newUpdate = cloneWithInferType(builder, oldUpdate, mapping); + return newUpdate; + } + + Operation *createReduce(OpBuilder &builder, triton::ReduceOp reduce, + Type viewOpTensorType) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + builder.setInsertionPointAfter(reduce); + IRMapping mapping; + for (auto operand : reduce.getOperands()) { + auto viewOp = builder.create( + reduce.getLoc(), viewOpTensorType, operand, /*allowReorder=*/true); + viewOp.setEfficientLayout(true); + mapping.map(operand, viewOp); + } + + auto newReduce = cloneWithInferType(builder, &(*reduce), mapping); + newReduce->setAttr("axis", builder.getI32IntegerAttr(rank)); + auto typeInfer = dyn_cast(newReduce); + if (typeInfer) { + SmallVector newTypes; + auto success = typeInfer.inferReturnTypes( + newReduce->getContext(), newReduce->getLoc(), + newReduce->getOperands(), newReduce->getAttrDictionary(), + newReduce->getPropertiesStorage(), newReduce->getRegions(), newTypes); + if (succeeded(success)) { + for (size_t i = 0; i < newTypes.size(); i++) + newReduce->getResult(i).setType(newTypes[i]); + } + } + return newReduce; + } + + // Work around the lack of support for MaxNumFOp and MinNumFOp in + // arith::getNeutralElement. + std::optional getNeutralElement(Operation *op) const { + if (isa(op)) { + OpBuilder builder(op->getContext()); + + Type resultType = op->getResult(0).getType(); + const llvm::fltSemantics &semantic = + llvm::cast(resultType).getFloatSemantics(); + if (isa(op)) { + return builder.getFloatAttr( + resultType, APFloat::getInf(semantic, /*Negative=*/true)); + } + if (isa(op)) { + return builder.getFloatAttr( + resultType, APFloat::getInf(semantic, /*Negative=*/false)); + } + } else { + return mlir::arith::getNeutralElement(op); + } + llvm_unreachable("Unhandled reduction op"); + return std::nullopt; + } + + Operation *createAccum(OpBuilder &builder, triton::ReduceOp reduce, + Value &oldAccum, SmallVector &shape, + Attribute &slice2d) const { + // Drop the last dimension (thread locality dimension) + SmallVector accumShape(shape.begin(), shape.end() - 1); + auto elemType = cast(oldAccum.getType()).getElementType(); + // Create tensor type for the new accumulator + auto accumType = RankedTensorType::get(accumShape, elemType, slice2d); + // Create new accumulator + builder.setInsertionPointAfter(oldAccum.getDefiningOp()); + auto reductionOp = getReductionOp(reduce); + assert(reductionOp && "Processing a reduce that is not supported!"); + auto neutralVal = getNeutralElement(reductionOp.value()); + assert(neutralVal && "Could not find neutral value for reduction op!"); + auto denseAttr = DenseElementsAttr::get(accumType, neutralVal.value()); + auto newAccum = builder.create(oldAccum.getLoc(), + accumType, denseAttr); + return newAccum; + } + + SmallVector + getThreadLocalityOptimizedShape(triton::ReduceOp reduce) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto srcShape = srcType.getShape(); + auto rank = srcShape.size(); + auto elemsPerThread = + triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; + auto viewOpTensorShape = insertValue(srcShape, rank, 1); + viewOpTensorShape[reduce.getAxis()] /= elemsPerThread; + viewOpTensorShape[rank] = elemsPerThread; + return viewOpTensorShape; + } + + Attribute getThreadLocalityOptimizedEncoding(triton::ReduceOp reduce) const { + auto srcType = cast(reduce.getOperands()[0].getType()); + auto rank = srcType.getShape().size(); + auto srcEncoding = srcType.getEncoding(); + auto blocked = dyn_cast(srcEncoding); + auto sizePerThread3d = + insertValue(blocked.getSizePerThread(), rank, + blocked.getSizePerThread()[reduce.getAxis()]); + sizePerThread3d[reduce.getAxis()] = 1; + auto threadsPerWarp3d = insertValue(blocked.getThreadsPerWarp(), rank, 1); + auto warsPerCTA3d = insertValue(blocked.getWarpsPerCTA(), rank, 1); + auto order3d = insertValue(blocked.getOrder(), 0, rank); + auto ctasPerCGA3d = + insertValue(blocked.getCTALayout().getCTAsPerCGA(), rank, 1); + auto ctasSplitNum3d = + insertValue(blocked.getCTALayout().getCTASplitNum(), rank, 1); + auto ctaOrder3d = + insertValue(blocked.getCTALayout().getCTAOrder(), rank, rank); + auto ctaLayout3d = triton::gpu::CTALayoutAttr::get( + reduce.getContext(), ctasPerCGA3d, ctasSplitNum3d, ctaOrder3d); + auto blocked3d = triton::gpu::BlockedEncodingAttr::get( + reduce.getContext(), sizePerThread3d, threadsPerWarp3d, warsPerCTA3d, + order3d, ctaLayout3d); + return blocked3d; + } + + template + SmallVector insertValue(ArrayRef vec, unsigned index, int value) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + index, static_cast(value)); + return res; + } + template + SmallVector insertValue(const SmallVector &vec, unsigned index, + int value) const { + SmallVector res(vec.begin(), vec.end()); + res.insert(res.begin() + index, static_cast(value)); + return res; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp new file mode 100644 index 000000000..457d42f4e --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -0,0 +1,1762 @@ +#include "PipelineExpander.h" +#include "PipeliningUtility.h" +#include "Schedule.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" + +#include + +#define DEBUG_TYPE "triton-matmul-loop-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +#define int_attr(num) builder.getI64IntegerAttr(num) + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +// TODO: We can extra some helpers into common utilities once we add more +// schedules. + +namespace { + +struct LoadInfo { + // Layout of the data in the shared memory. + ttg::SharedEncodingAttr sharedEncoding = nullptr; + // Blocked encoding is used for loads not used by the dot. + ttg::BlockedEncodingAttr blockedEncoding = nullptr; + bool loadIsMMAV3 = false; + int distToUse = 0; + bool usedByDot = false; +}; + +} // namespace + +class CoarseSchedule { +public: + class ClusterList { + std::list orderClusters; + + public: + using iterator = decltype(orderClusters)::iterator; + ClusterList() = default; + iterator begin() { return orderClusters.begin(); } + iterator end() { return orderClusters.end(); } + size_t size() { return orderClusters.size(); } + iterator newAtBack() { + orderClusters.push_back(orderClusters.size()); + return std::prev(orderClusters.end()); + } + iterator newAtFront() { + orderClusters.push_front(-1); + for (auto &clusterId : orderClusters) { + clusterId++; + } + return orderClusters.begin(); + } + iterator newBefore(iterator cluster) { + auto ret = orderClusters.insert(cluster, *cluster); + for (auto &clusterId : llvm::make_range(cluster, orderClusters.end())) { + clusterId++; + } + return ret; + } + }; + + CoarseSchedule(int numStages) : numStages(numStages) {} + int numStages; + ClusterList clusters; + using Cluster = decltype(clusters)::iterator; + + DenseMap> opToStageAndCluster; + + void insert(Operation *op, int stage, Cluster cluster) { + opToStageAndCluster[op] = {stage, cluster}; + } + + bool insertIfAbsent(Operation *op, int stage, Cluster cluster) { + if (opToStageAndCluster.count(op)) + return false; + insert(op, stage, cluster); + return true; + } + + void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster, + bool includeArg) { + for (Value operand : op->getOperands()) { + Value v = operand; + llvm::SmallDenseSet seen; + while (auto arg = dyn_cast(v)) { + if (!includeArg) + break; + if (!seen.insert(v).second) + break; + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + v = yieldOp->getOperand(arg.getArgNumber() - 1); + continue; + } + break; + } + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + if (insertIfAbsent(defOp, stage, cluster)) { + insertDepsOfOp(defOp, stage, cluster, includeArg); + } + } + } + } + + void erase(Operation *op) { opToStageAndCluster.erase(op); } + + int count(Operation *op) { return opToStageAndCluster.count(op); } + + std::pair operator[](Operation *op) { + return opToStageAndCluster[op]; + } + + SmallVector> + getOpsInOrder(scf::ForOp forOp) { + SmallVector>, 8> + orderClusters(clusters.size()); + for (auto &op : forOp.getBody()->without_terminator()) { + if (opToStageAndCluster.count(&op) == 0) { + continue; + } + assert(opToStageAndCluster[&op].first < numStages && + "Op with invalid stage!"); + int clusterId = *opToStageAndCluster[&op].second; + assert(clusterId == std::distance(clusters.begin(), + opToStageAndCluster[&op].second) && + "Cluster ID mismatch!"); + orderClusters[clusterId].push_back( + make_tuple(&op, opToStageAndCluster[&op].first, + opToStageAndCluster[&op].second)); + } + SmallVector> opsInOrder; + for (int i = 0; i < orderClusters.size(); i++) { + for (auto [op, stage, cluster] : orderClusters[i]) { + opsInOrder.push_back({op, stage, cluster}); + } + } + + return opsInOrder; + } + + std::vector> + createFinalSchedule(scf::ForOp forOp) { + SmallVector> opsInOrder = + getOpsInOrder(forOp); + std::vector> schedule; + for (auto [op, stage, cluster] : opsInOrder) { + LDBG("Adding op to schedule at stage " << stage << " cluster " << *cluster + << ":" << *op); + schedule.push_back({op, stage}); + } + return schedule; + } + + void dump() { + for (int i = 0; i < numStages; i++) { + LDBG("- Ops in stage " << i); + for (auto &[op, stageAndCluster] : opToStageAndCluster) { + if (i == stageAndCluster.first) { + llvm::outs() << " cluster: " << *stageAndCluster.second << " "; + op->dump(); + } + } + } + } +}; + +static bool isMMAv3Dot(Operation *op) { + auto dot = dyn_cast(op); + if (!dot) + return false; + auto enc = + mlir::dyn_cast(dot.getType().getEncoding()); + return enc && enc.isHopper(); +} + +// Replace the ForOp's yield with a new one with the given operands appended. +static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { + // Fix up the yield op. + Operation *yieldOp = forOp.getBody()->getTerminator(); + SmallVector operands(yieldOp->getOperands()); + operands.append(newOperands.begin(), newOperands.end()); + + OpBuilder builder(yieldOp); + builder.create(yieldOp->getLoc(), operands); + yieldOp->erase(); +} + +static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, + CoarseSchedule &schedule, + CoarseSchedule::Cluster prefetchCluster, + llvm::MapVector &loadToInfo, + int numStages) { + OpBuilder builder(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + // Replace the load with insert/extract slice. + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + Value src = loadOp.getPtr(); + Value mask = loadOp.getMask(); + Value other = loadOp.getOther(); + if (!isExpensiveLoadOrStore(loadOp) && loadToInfo[loadOp].blockedEncoding) { + // For inexpensive loads that do not directly feed into dot ops + // we want to use optimal layout for the data. + ttg::BlockedEncodingAttr encoding = loadToInfo[loadOp].blockedEncoding; + auto convertBlockLayout = [&](Value src) { + auto ty = cast(src.getType()); + auto newTy = + RankedTensorType::get(ty.getShape(), ty.getElementType(), encoding); + auto cvt = + builder.create(loadOp->getLoc(), newTy, src); + return cvt.getResult(); + }; + src = convertBlockLayout(src); + if (mask) + mask = convertBlockLayout(mask); + if (other) + other = convertBlockLayout(other); + } + + tt::MemDescType allocTy = cast(alloc.getType()); + SmallVector copyOffsets(allocTy.getRank(), zero); + copyOffsets[0] = insertIdx; + tt::MemDescType subviewTy = tt::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), /*mutableMemory=*/true); + auto view = + builder.create(loc, subviewTy, alloc, copyOffsets); + Operation *copy = builder.create( + loc, src, view, mask, other, loadOp.getCache(), loadOp.getEvict(), + loadOp.getIsVolatile()); + Operation *commmit = + builder.create(loc, copy->getResult(0)); + Operation *wait = + builder.create(loc, commmit->getResult(0), 0); + + bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3; + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(copy, stage, cluster); + schedule.insert(commmit, stage, cluster); + + // Extract part. + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + if (isMMV3Load) { + auto alloc = cast((*loadOp->getUsers().begin())); + alloc.replaceAllUsesWith(viewLoad.getResult()); + alloc.erase(); + } else { + SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto alloc = dyn_cast(user)) { + alloc.replaceAllUsesWith(viewLoad.getResult()); + allocsToErase.push_back(alloc); + } + } + for (auto alloc : allocsToErase) { + alloc.erase(); + } + + auto sharedLoad = builder.create( + loc, loadOp.getType(), viewLoad, wait->getResult(0)); + auto result = sharedLoad->getResults(); + + // Create a select for non-zero other values as they are not handled by + // AsyncCopyGlobalToLocalOp for now. + Value other = loadOp.getOther(); + if (other && !isZeroConst(other)) { + auto select = builder.create( + loc, loadOp.getType(), mask, sharedLoad.getResult(), other); + result = select->getResults(); + } + + loadOp->replaceAllUsesWith(result); + + // Prefetch load if is not MMAV3 and is used by the dot. + if (loadToInfo[loadOp].usedByDot) { + schedule.insert(wait, numStages - 2, prefetchCluster); + schedule.insert(viewLoad, numStages - 2, prefetchCluster); + } + } + loadOp.erase(); +} + +static void createTMAAsyncCopy( + scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp, + Value phase, CoarseSchedule &schedule, + llvm::MapVector &loadToInfo, int numStages) { + assert(phase && "Phase value is required for TMA async copy."); + OpBuilder builder(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + tt::MemDescType allocTy = cast(alloc.getType()); + SmallVector copyOffsets(allocTy.getRank(), zero); + copyOffsets[0] = insertIdx; + tt::MemDescType subviewTy = tt::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), /*mutableMemory=*/true); + auto view = + builder.create(loc, subviewTy, alloc, copyOffsets); + + Value pred = builder.create(loc, 1, 1); + Operation *copy = builder.create( + loc, loadOp.getDescPtr(), loadOp.getIndices(), barrier, view, pred); + + bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3; + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(copy, stage, cluster); + + builder.setInsertionPointAfter(waitOp); + // Extract part. + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + if (isMMV3Load) { + auto alloc = cast((*loadOp->getUsers().begin())); + alloc.replaceAllUsesWith(viewLoad.getResult()); + alloc.erase(); + } else { + SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto alloc = dyn_cast(user)) { + alloc.replaceAllUsesWith(viewLoad.getResult()); + allocsToErase.push_back(alloc); + } + } + for (auto alloc : allocsToErase) { + alloc.erase(); + } + + auto sharedLoad = builder.create( + loc, loadOp.getType(), viewLoad /*,wait->getResult(0)*/); + auto result = sharedLoad->getResults(); + loadOp->replaceAllUsesWith(result); + } + loadOp.erase(); +} + +// If all the transitive uses of the given value have are used by a convert to +// the same dot operand encoding, return true and get the shared encoding that +// needs to be used to be compatible with users' layouts. +static std::optional +getSharedEncIfAllUsersAreDotEnc(Value val) { + ttg::SharedEncodingAttr attr; + for (Operation *user : val.getUsers()) { + ttg::SharedEncodingAttr tempAttr; + if (user->getNumResults() != 1) + return std::nullopt; + if (auto memDesc = + dyn_cast(user->getResult(0).getType())) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = cast(memDesc.getEncoding()); + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value()) + return std::nullopt; + } else { + if (!isa(user)) + return std::nullopt; + auto dotOpEnc = dyn_cast( + cast(user->getResult(0).getType()).getEncoding()); + if (!dotOpEnc) + return std::nullopt; + auto srcTy = cast(val.getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy.getEncoding()); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + tempAttr = ttg::SharedEncodingAttr::get( + val.getContext(), dotOpEnc, srcTy.getShape(), + ttg::getOrder(srcTy.getEncoding()), + ttg::getCTALayout(srcTy.getEncoding()), + srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); + } + // Check that the shared encodings needed by the users are compatible. + if (!tempAttr || (attr != nullptr && attr != tempAttr)) + return std::nullopt; + attr = tempAttr; + } + return attr; +} + +static ttg::BlockedEncodingAttr +getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) { + Value src = loadOp.getPtr(); + auto ty = cast(src.getType()); + auto mod = loadOp->getParentOfType(); + int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); + tt::AxisInfo::DimVectorT contiguity = + axisInfo.getAxisInfo(src)->getContiguity(); + SmallVector order = argSort(contiguity); + unsigned currPerThread = getNumElementsPerThread(loadOp, order, axisInfo); + SmallVector sizePerThread(order.size(), 1); + sizePerThread[order[0]] = currPerThread; + ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(ty.getEncoding()); + return ttg::BlockedEncodingAttr::get(loadOp->getContext(), ty.getShape(), + sizePerThread, order, numWarps, + threadsPerWarp, ctaLayout); +} + +static std::optional +getSharedEncoding(Operation *loadOp, bool isMMAV3) { + auto ty = cast(loadOp->getResultTypes()[0]); + auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); + auto blockedOrder = ttg::getOrder(ty.getEncoding()); + SmallVector order; + if (blockedOrder.size() == 3) { + for (unsigned i = 0; i < blockedOrder.size(); ++i) { + if (blockedOrder[i] == 0) + continue; + order.push_back(blockedOrder[i]); + } + order.push_back(0); + } else { + order = blockedOrder; + } + if (isMMAV3) { + return ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), order, + ctaLayout, ty.getElementType()); + } + + // If the load is used by a LocalAllocOp, use the same encoding as the allocs. + // If the allocs don't all have the same encoding, bail. + if (llvm::any_of(loadOp->getUsers(), [&](Operation *user) { + return isa(user); + })) { + ttg::SharedEncodingAttr localAllocEnc; + for (auto user : loadOp->getUsers()) { + auto localAlloc = dyn_cast(user); + if (!localAlloc) + continue; + auto enc = mlir::cast( + localAlloc.getType().getEncoding()); + if (!localAllocEnc) { + localAllocEnc = enc; + } + if (enc != localAllocEnc) + return std::nullopt; + } + return localAllocEnc; + } + + // Use non-swizzled layout for loads that do not feed into dot ops. + // TODO: This won't be optimal for 2D tensors. + return ttg::SharedEncodingAttr::get(ty.getContext(), 1, 1, 1, order, + ctaLayout); +} + +// Create a map from load ops to their indirection level and the +// final use of the load op (another load op, or a dot op). +// Indirection level is "0" for the load op directly used by the dot op, +// "1" for the load op used by the load op used by the dot op, and so on. +static llvm::SmallVector> +loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { + llvm::SmallVector> + loadOpToIndLevelAndUse; + DenseSet seen; + + std::function dfs = + [&](Operation *op, int distance, Operation *use) { + if (!seen.insert(op).second) + return; + if (isa(op)) { + // TODO: What if there are multiple uses at different distances? + loadOpToIndLevelAndUse.push_back(std::make_tuple(op, distance, use)); + use = op; + distance++; + } + for (Value operand : op->getOperands()) { + Value v = operand; + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, distance, use); + } + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + continue; + seen.clear(); + dfs(&op, 0, &op); + } + + // If the loop has numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (forOp->hasAttr(tt::kNumStagesAttrName)) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, 0, &op); + } + } + + return loadOpToIndLevelAndUse; +} + +static bool loadIsMMAv3(Operation *loadOp) { + if (!loadOp->hasOneUse()) + return false; + auto alloc = dyn_cast(*loadOp->getUsers().begin()); + if (!alloc) + return false; + auto sharedEnc = cast(alloc.getType().getEncoding()); + if (!sharedEnc.getHasLeadingOffset()) + return false; + + // MMA V3 case. + auto newOrder = sharedEnc.getOrder(); + auto ty = cast(loadOp->getResultTypes()[0]); + auto oldOrder = ttg::getOrder(ty.getEncoding()); + + // The operand of MMAv3 is in SharedEncoding and its order should not + // be changed after FuseTranspositions Pass. So we only pipeline the + // load if the order of the loaded BlockedEncoding is the same as the + // order of the SharedEncoding it is converted to. + return oldOrder == newOrder; +} + +static llvm::MapVector +assignMemoryLayouts(llvm::SmallVector> + &loadOpToIndLevelAndUse, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + llvm::MapVector loadToInfo; + + for (auto &[op, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(op)) + // TODO pawel: err, we'd need to verify that the distance is the same + continue; + LoadInfo loadInfo; + + if (auto loadOp = dyn_cast(op)) { + assert(!isLoadFromTensorPtr(loadOp) && + "Block ptr should have been lowered before this pass."); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) + continue; + auto ty = + cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * ty.getIntOrFloatBitWidth(); + + // We do not pipeline all loads for the following reasons: + // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16. + // 2. It's likely that pipling small loads won't offer much performance + // improvement and may even hurt performance by increasing register + // pressure. + LDBG("Load " << *loadOp << " has width " << width); + if (width < 32) + continue; + } + + if (auto dot = dyn_cast(use)) { + loadInfo.usedByDot = true; + if (loadIsMMAv3(op)) { + loadInfo.loadIsMMAV3 = true; + loadInfo.sharedEncoding = + getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); + } else if (isa(op)) { + loadInfo.sharedEncoding = + getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); + } else { + loadInfo.sharedEncoding = + getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); + + // HACK: Triton LLVM codegen has a bug where local_loads from #shared to + // #mma layout can lead to invalid code if the loaded shape is smaller + // than the mma tile (e.g. loading a 128x1 tensor for an MMAv2 dot with + // tile {16,8} is bad because 1 < 8). To work around this, don't + // pipeline such loads. + // + // The codegen bug is caught by an assertion, so if you think you've + // fixed it, feel free to delete this code and see if the assert still + // fails. :) + if (!loadInfo.sharedEncoding) { + if (auto dotEnc = dyn_cast( + dot.getResult().getType().getEncoding())) { + auto loadTy = cast(op->getResultTypes()[0]); + auto mmaInstrShape = dotEnc.getInstrShape(); + if (loadTy.getRank() < mmaInstrShape.size()) + continue; + bool ok = true; + for (int i = 0; i < mmaInstrShape.size(); i++) { + if (loadTy.getShape()[loadTy.getRank() - mmaInstrShape.size() + + i] < mmaInstrShape[i]) { + ok = false; + break; + } + } + // If this load might trigger the bug, don't do the fallback logic + // below, which might allow the load to be pipelined. + if (!ok) + continue; + } + } + } + } else if (auto loadOp = dyn_cast(use)) { + // The use of this loadOp is another loadOp. If the use is not in the + // loadsToPipeline already, it means that the use is not valid for + // pipelining for some reason. We should skip this loadOp, too. Note that + // we have an assumption that distAndUse.second (i.e. the use of this + // loadOp) has already be processed in a previous loop iteration. This + // assumption is held by how loadOpsToIndirectionLevelAndUse recursively + // collects loadOpToIndLevelAndUse using DFS. + if (loadToInfo.count(loadOp) == 0) { + continue; + } + } + + // If we still don't have a shared encoding, try a "generic" shared + // encoding. + if (!loadInfo.sharedEncoding && !isMMAv3Dot(use)) { + loadInfo.sharedEncoding = + getSharedEncoding(op, /*isMMAV3=*/loadInfo.loadIsMMAV3) + .value_or(nullptr); + if (auto loadOp = dyn_cast(op)) { + loadInfo.blockedEncoding = getBlockedEncoding(loadOp, axisInfoAnalysis); + } + } + + // If that still didn't work, bail on pipelining this load. + if (!loadInfo.sharedEncoding) { + continue; + } + loadToInfo[op] = loadInfo; + } + + return loadToInfo; +} + +static llvm::MapVector +scheduleLoads(scf::ForOp forOp, CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + ModuleOp moduleOp = forOp->getParentOfType(); + tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + + // Get all loads that are (transitively) used by dot ops and their distance + // to the dot op. + llvm::SmallVector> + loadOpToIndLevelAndUse = loadOpsToIndirectionLevelAndUse(forOp); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + if (loadOpToIndLevelAndUse.empty()) + return {}; + + // Check which loads are good for pipelining, and assign them + // memory layouts. + llvm::MapVector loadToInfo = + assignMemoryLayouts(loadOpToIndLevelAndUse, axisInfoAnalysis); + + if (loadToInfo.empty()) + return {}; + + // Calculate the stage distance between applicable loads. + int maxIndirectionLevel = -1; + for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + maxIndirectionLevel = std::max(maxIndirectionLevel, dist); + } + unsigned stagesBetweenLoads = + ceil(numStages - 2, maxIndirectionLevel + 1); + + CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); + // Put the root uses of the loads in the last stage. + for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + // Non-LoadOp(s) are the root uses of all LoadOp(s) and should be + // always present in the opInfo + if (!isa(use)) { + schedule.insert(use, numStages - 1, rootUsersCluster); + rootUsers.insert(use); + } + } + + SmallVector loadsClusters; + for (int i = 0; i < maxIndirectionLevel + 1; i++) { + loadsClusters.push_back(schedule.clusters.newAtBack()); + } + // Assign stages to the loads. + for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; + schedule.insert(loadOp, stage, loadsClusters[indLevel]); + } + + // Distance from the load to the use. + for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; + } + + return loadToInfo; +} + +// Schedule the prologue and epilogue `if` ops in the loop, pushing them as +// close to the loop boundaries as possible. Return the cluster after the +// prologue (or the beginning of the loop if there is no prologue). +static CoarseSchedule::Cluster +schedulePrologueAndEpilogue(scf::ForOp forOp, CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + + // Look for the IfOp that is in the backward slice any of the currently + // scheduled ops and put it at the beginning of the loop. + DenseMap ifsToStage; + // Go stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : schedule.getOpsInOrder(forOp)) { + if (stage_ != stage) + continue; + SetVector backwardSlice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + getBackwardSlice((Operation *)op, &backwardSlice, opt); + + for (auto op : backwardSlice) { + if (auto ifOp = dyn_cast(op)) { + ifsToStage.insert({ifOp, stage}); + } + } + } + } + CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); + for (auto [ifOp, stage] : ifsToStage) { + schedule.insert(ifOp, stage, prologueCluster); + } + + // Look for the IfOp that is in the forward slice of the root users and put it + // at the end of the loop. + CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); + for (auto rootUser : rootUsers) { + SetVector forwardSlice; + getForwardSlice(rootUser, &forwardSlice); + + int stage = schedule[rootUser].first; + for (auto op : forwardSlice) { + scf::IfOp ifOp = dyn_cast(op); + if (ifOp == nullptr) { + // check if the op is in the body of an if op that's part of the loop + auto parentOp = op->getParentOp(); + if (parentOp != nullptr && + parentOp->getParentOp() == forOp.getOperation()) { + ifOp = dyn_cast(parentOp); + } + } + if (ifOp) { + schedule.insertIfAbsent(ifOp, stage, + epilogueCluster); // after prefetch extracts + } + } + } + return afterPrologue; +} + +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +static void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule, + int numStages) { + SmallVector> + opsInOrder = schedule.getOpsInOrder(forOp); + // Schedule dependencies stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : opsInOrder) { + if (stage_ != stage) + continue; + schedule.insertDepsOfOp(op, stage, cluster, false); + } + } +} + +// Find dependencies with distance of 1. They will go to the next stage, +// but in the cluster before the current op. +static void scheduleDistanceOneDependencies(scf::ForOp forOp, + CoarseSchedule &schedule, + int numStages) { + auto getNestedOperands = [](Operation *op) -> SmallVector { + SmallVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + if (operand.getParentBlock()->getParentOp()->isAncestor(nestedOp)) + operands.push_back(operand); + } + }); + return operands; + }; + + // Mapping from the cluster to the cluster before it. + DenseMap dist1Cluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + continue; + auto [stage, cluster] = schedule[&op]; + // Can't schedule past the last stage. + if (stage == numStages - 1) + continue; + for (Value operand : getNestedOperands(&op)) { + if (auto arg = dyn_cast(operand)) { + if (arg.getArgNumber() > 0 && arg.getOwner() == op.getBlock()) { + auto yieldOp = op.getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (defOp && schedule.count(defOp) == 0) { + if (isa(defOp)) { + // Exception: Schedule loads with a distance of 1 together + // with the current op. + schedule.insertIfAbsent(defOp, stage, cluster); + schedule.insertDepsOfOp(defOp, stage, cluster, true); + } else { + if (dist1Cluster.count(&cluster) == 0) { + dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster); + } + schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]); + schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster], + true); + } + } + } + } + } + } +} + +static void scheduleRemainingToLastStage(scf::ForOp forOp, + CoarseSchedule &schedule, + CoarseSchedule::Cluster afterPrologue, + int numStages) { + // Assign the rest of the ops to the last stage. + // Take care of the ordering of the ops - uses cannot be scheduled to the + // cluster before the definition. + DenseMap opToCluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) { + opToCluster[&op] = afterPrologue; + } + } + SmallVector queue; + for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { + // We really only care about the producers from the last stage. + // Others will be scheduled before these ops anyway. + if (stage == numStages - 1) { + queue.push_back(op); + } + } + while (!queue.empty()) { + Operation *op = queue.pop_back_val(); + for (auto user : op->getUsers()) { + if (opToCluster.count(user)) { + CoarseSchedule::Cluster userCluster = opToCluster[user]; + CoarseSchedule::Cluster opCluster = schedule[op].second; + if (*userCluster < *opCluster) { + opToCluster[user] = opCluster; + queue.push_back(user); + } + } + } + } + for (auto [op, cluster] : opToCluster) { + schedule.insert(op, numStages - 1, cluster); + } +} + +// Create an allocation that can hold distance number of loadOp shapes. +static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, + ttg::SharedEncodingAttr sharedEnc, unsigned distance) { + OpBuilder builder(forOp); + auto ty = cast(loadOp->getResultTypes()[0]); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), distance); + Type memdescType = mlir::triton::MemDescType::get( + bufferShape, ty.getElementType(), sharedEnc, /*mutableMemory*/ true); + Value alloc = builder.create( + loadOp->getLoc(), memdescType, Value()); + return alloc; +} + +// Create an allocation to hold the mbarriers. +static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) { + OpBuilder builder(forOp); + Location loc = forOp.getLoc(); + auto context = forOp.getContext(); + auto barrierCTALayout = + ttg::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = + ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout); + Type barrierMemDescType = + tt::MemDescType::get({distance}, builder.getI64Type(), barrierEncoding, + /*mutableMemory=*/true); + Type singleBarrierMemDescType = tt::MemDescType::get( + {1}, builder.getI64Type(), barrierEncoding, /*mutableMemory=*/true); + Value barrierAlloc = builder.create( + loc, barrierMemDescType, Value()); + for (unsigned i = 0; i < distance; i++) { + Value idx = builder.create(loc, i, 32); + Value barrierView = builder.create( + loc, singleBarrierMemDescType, barrierAlloc, idx); + builder.create(forOp->getLoc(), barrierView, 1); + } + return barrierAlloc; +} + +struct AsyncLoad { + AsyncLoad(Operation *loadOp, Value alloc) : loadOp(loadOp), alloc(alloc) {} + Operation *loadOp; + Value alloc; + Value barrier; + Operation *waitOp = nullptr; + bool isTMALoad = false; +}; + +// Create barriers and wait ops for the async loads. Barriers may be shared by +// multiple loads is the schedule allows it. +static void createTMABarrierAndWait( + scf::ForOp &forOp, SmallVector &asyncLoads, Value insertIdx, + Value extractIdx, Value phase, int numBuffers, CoarseSchedule &schedule, + SmallVector &barriers, + const llvm::MapVector &loadToInfo) { + llvm::SmallDenseMap loadToAsyncLoad; + for (AsyncLoad &asyncLoad : asyncLoads) { + loadToAsyncLoad[asyncLoad.loadOp] = &asyncLoad; + } + SmallVector> loadGroups; + llvm::SmallDenseSet visited; + // Find groups of loads that can share the same barrier. We look consecutive + // loads and check that there are uses in between. + for (AsyncLoad &asyncLoad : asyncLoads) { + if (!asyncLoad.isTMALoad || visited.count(asyncLoad.loadOp)) + continue; + llvm::SmallDenseSet users; + SmallVector group; + Block *loadBlock = asyncLoad.loadOp->getBlock(); + auto addToGroup = [&](AsyncLoad *loadInfo) { + group.push_back(loadInfo); + visited.insert(loadInfo->loadOp); + for (Operation *user : loadInfo->loadOp->getUsers()) { + auto it = loadToInfo.find(loadInfo->loadOp); + if (it != loadToInfo.end()) { + // Special case for MMAv3 loads, we can ignore the alloc and only + // consider uses of the alloc op since it will be removed. + if (it->second.loadIsMMAV3) { + auto alloc = cast( + (*loadInfo->loadOp->getUsers().begin())); + if (alloc->getBlock() == loadBlock) { + users.insert(alloc->getUsers().begin(), alloc->getUsers().end()); + continue; + } + } + } + Operation *userInBlock = loadBlock->findAncestorOpInBlock(*user); + if (userInBlock) + users.insert(userInBlock); + } + }; + addToGroup(&asyncLoad); + Operation *nextOp = asyncLoad.loadOp->getNextNode(); + while (nextOp) { + if (users.count(nextOp) || visited.count(nextOp)) + break; + if (isa(nextOp)) { + auto it = loadToAsyncLoad.find(nextOp); + if (it != loadToAsyncLoad.end() && it->second->isTMALoad) { + addToGroup(it->second); + } + } + nextOp = nextOp->getNextNode(); + } + loadGroups.push_back(group); + } + + // For each group calculate the size and insert the barrier after the last + // load. + for (SmallVector &group : loadGroups) { + int sizeInBytes = 0; + for (AsyncLoad *asyncLoad : group) { + auto tensorTy = + cast(asyncLoad->loadOp->getResult(0).getType()); + int loadSize = product(tensorTy.getShape()); + sizeInBytes += + loadSize * tensorTy.getElementType().getIntOrFloatBitWidth() / 8; + } + + Value barrierAlloc = createBarrierAlloc(forOp, numBuffers); + barriers.push_back(barrierAlloc); + Location loc = forOp.getLoc(); + OpBuilder builder(forOp); + tt::MemDescType barrierTy = tt::MemDescType::get( + {1}, builder.getI64Type(), + cast(barrierAlloc.getType()).getEncoding(), + /*mutableMemory=*/true); + builder.setInsertionPoint(group[0]->loadOp); + Value barrier = builder.create( + loc, barrierTy, barrierAlloc, ArrayRef({insertIdx})); + Value pred = builder.create(loc, 1, 1); + Operation *expect = builder.create( + forOp.getLoc(), barrier, sizeInBytes, pred); + auto [stage, cluster] = schedule[asyncLoads[0].loadOp]; + schedule.insert(expect, stage, cluster); + + builder.setInsertionPointAfter(group.back()->loadOp); + Value barrierViewWait = builder.create( + loc, barrierTy, barrierAlloc, ArrayRef({extractIdx})); + Operation *wait = + builder.create(loc, barrierViewWait, phase); + // Update the async loads info. + for (AsyncLoad *asyncLoad : group) { + asyncLoad->barrier = barrier; + asyncLoad->waitOp = wait; + } + } +} + +// Convert load ops into their asyn version and apply multi-buffering based on +// the required number of buffers. +static SmallVector +createAsyncOps(scf::ForOp &forOp, CoarseSchedule &schedule, + llvm::MapVector &loadToInfo, + SmallVector &barriers, int numStages) { + // Calculate the number of buffers needed for each load. + // TODO pawel: we could do more fine-grained allocation here and + // allocate only the number of buffers that specific loads need. + // Instead, we allocate the maximum number of buffers needed by any load. + int numBuffers = + llvm::max_element(llvm::make_second_range(loadToInfo), [](auto &lhs, + auto &rhs) { + return lhs.distToUse < rhs.distToUse; + })->distToUse; + bool hasMMAV3 = + llvm::any_of(loadToInfo, [](auto &kv) { return kv.second.loadIsMMAV3; }); + if (hasMMAV3) { + // For MMAv3, we need an extra buffer as this is assumed in the wgmma + // pipelining post-processing. + numBuffers++; + }; + + SmallVector asyncLoads; + SmallVector allocs; + bool hasTMALoad = false; + for (auto &[loadOp, info] : loadToInfo) { + assert(info.sharedEncoding && "LoadOp shared encoding not defined."); + Value alloc = createAlloc(forOp, loadOp, info.sharedEncoding, numBuffers); + assert(alloc && "Failed to create alloc for the async load."); + allocs.push_back(alloc); + asyncLoads.emplace_back(loadOp, alloc); + if (isa(loadOp)) { + hasTMALoad = true; + asyncLoads.back().isTMALoad = true; + } + } + + IRRewriter builder(forOp.getContext()); + builder.setInsertionPoint(forOp); + + Location loc = forOp.getLoc(); + // Create two new counters to index into the allocs. + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + Value insertIdx = minusOne; + Value extractIdx = minusOne; + Value phase = Value(); + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + SmallVector newOperands; + newOperands.push_back(insertIdx); + newOperands.push_back(extractIdx); + if (hasTMALoad) { + phase = builder.create(loc, 0, 32); + newOperands.push_back(phase); + } + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Patch the loop to add the new loop carried dependencies. + scf::ForOp newForOp = + replaceForOpWithNewSignature(builder, forOp, newOperands); + forOp.erase(); + forOp = newForOp; + insertIdx = newForOp.getBody()->getArgument(newOperandIndex); + extractIdx = newForOp.getBody()->getArgument(newOperandIndex + 1); + if (phase) { + phase = newForOp.getBody()->getArgument(newOperandIndex + 2); + } + + // Create two counters for the insert and extract indices to avoid creating + // long liverange. + builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + insertIdx = builder.create(loc, insertIdx, one); + Value cndIns = builder.create(loc, arith::CmpIPredicate::slt, + insertIdx, numBuffersVal); + insertIdx = builder.create(loc, cndIns, insertIdx, zero); + + extractIdx = builder.create(loc, extractIdx, one); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + if (phase) { + Value nextPhase = builder.create(loc, phase, one); + phase = builder.create(loc, cndExt, phase, nextPhase); + } + createTMABarrierAndWait(forOp, asyncLoads, insertIdx, extractIdx, phase, + numBuffers, schedule, barriers, loadToInfo); + + // Create a cluster for the prefetches. It may end up being empty, but this + // is OK. + CoarseSchedule::Cluster prefetchCluster = schedule.clusters.newAtBack(); + + for (AsyncLoad &asyncLoad : asyncLoads) { + if (auto loadOp = dyn_cast(asyncLoad.loadOp)) { + createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx, + schedule, prefetchCluster, loadToInfo, numStages); + } else { + auto descLoad = cast(asyncLoad.loadOp); + createTMAAsyncCopy(forOp, descLoad, asyncLoad.alloc, insertIdx, + extractIdx, asyncLoad.barrier, asyncLoad.waitOp, phase, + schedule, loadToInfo, numStages); + } + } + SmallVector newYieldOperands = {insertIdx, extractIdx}; + if (phase) + newYieldOperands.push_back(phase); + // Patch the yield with the updated counters. + appendToYield(forOp, newYieldOperands); + + return allocs; +} + +static void invalidateBarriers(OpBuilder &builder, + SmallVector &barriers) { + for (Value barrier : barriers) { + int numBarriers = cast(barrier.getType()).getShape()[0]; + for (int i = 0; i < numBarriers; i++) { + Value idx = builder.create(barrier.getLoc(), i, 32); + tt::MemDescType barrierTy = tt::MemDescType::get( + {1}, builder.getI64Type(), + cast(barrier.getType()).getEncoding(), + /*mutableMemory=*/true); + Value barrierView = builder.create( + barrier.getLoc(), barrierTy, barrier, idx); + builder.create(barrier.getLoc(), barrierView); + } + } +} + +bool mlir::triton::preProcessLoopAndGetSchedule( + scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) { + // Schedule the loads and root ops (dot ops) in the loop. This will give us + // a scaffold for the final schedule. + DenseSet rootUsers; + CoarseSchedule coarseSchedule(numStages); + llvm::MapVector loadToInfo = + scheduleLoads(forOp, coarseSchedule, rootUsers, numStages); + if (loadToInfo.empty()) + return false; + + LLVM_DEBUG({ + LDBG("Coarse schedule loads only:"); + coarseSchedule.dump(); + }); + + SmallVector barriers; + // Convert the loads into async loads and create the allocs. + SmallVector allocs = + createAsyncOps(forOp, coarseSchedule, loadToInfo, barriers, numStages); + + LLVM_DEBUG({ + LDBG("Coarse schedule with async loads:"); + coarseSchedule.dump(); + }); + + CoarseSchedule::Cluster afterPrologue = + schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with prologue and epilogue:"); + coarseSchedule.dump(); + }); + + scheduleDependencies(forOp, coarseSchedule, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with dependencies:"); + coarseSchedule.dump(); + }); + + scheduleDistanceOneDependencies(forOp, coarseSchedule, numStages); + LLVM_DEBUG({ + LDBG("Coarse schedule with dist 1:"); + coarseSchedule.dump(); + }); + + scheduleRemainingToLastStage(forOp, coarseSchedule, afterPrologue, numStages); + LLVM_DEBUG({ + LDBG("Final coarse schedule:"); + coarseSchedule.dump(); + }); + + // Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> schedule = + coarseSchedule.createFinalSchedule(forOp); + + // Fill out the pipeline options. + options.getScheduleFn = + [schedule](scf::ForOp forOp, + std::vector> &s) { + s = std::move(schedule); + }; + options.peelEpilogue = false; + options.predicateFn = tt::predicateOp; + options.supportDynamicLoops = true; + options.annotateFn = [](Operation *op, + mlir::triton::PipeliningOption::PipelinerPart part, + unsigned iteration) {}; + // Insert a wait 0 after the loop + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + builder.create(forOp.getLoc(), ValueRange({}), 0); + // Invalidate any mbarrier create + invalidateBarriers(builder, barriers); + // Explicitly deallocate allocated tensors after the wait op + for (auto alloc : allocs) + builder.create(forOp.getLoc(), alloc); + return true; +} + +/// Find the minimum number of async_commit_group ops between the wait +/// and the associated async_commit_group. This can be safely used as the wait +/// number. +static int minNumInterleavedCommitOps(Operation *waitOp) { + auto countCommitsBetween = [](Operation *op1, Operation *op2) { + int count = 0; + for (auto op = op1; op != op2; op = op->getNextNode()) { + if (isa(op)) + count++; + // Intentionally skip block ops' children. This will give us + // convervatively low number of insert ops. + } + return count; + }; + + int minCommitNumber = INT_MAX; + + // DFS the def chain of the extract op to find the insert op. On each path + // we calculate the number of async_commit. Then we select the minimum number + // of async_commit ops among all the paths. + std::function minOverHistories = + [&](Value val, Operation *sinkOp, int thisHistorySum) -> int { + if (Operation *defOp = val.getDefiningOp()) { + thisHistorySum += countCommitsBetween(defOp->getNextNode(), sinkOp); + minCommitNumber = std::min(minCommitNumber, thisHistorySum); + return minCommitNumber; + } + if (auto arg = mlir::dyn_cast(val)) { + Block *block = arg.getOwner(); + auto forOp = dyn_cast(block->getParentOp()); + + // Failed to track, return 0 conservatively. + if (!forOp) + return 0; + + Operation *firstForInst = &*forOp.getBody()->begin(); + int insertsBetween = countCommitsBetween(firstForInst, sinkOp); + thisHistorySum += insertsBetween; + if (thisHistorySum >= minCommitNumber) + return minCommitNumber; + + // get the value value assigned to the argument coming from outside the + // loop + Value incomingVal = forOp.getInitArgs()[arg.getArgNumber() - 1]; + int min1 = minOverHistories(incomingVal, forOp, thisHistorySum); + + // get the value value assigned to the argument coming from the previous + // iteration + Operation *yieldOp = block->getTerminator(); + Value prevVal = yieldOp->getOperand(arg.getArgNumber() - 1); + int min2 = minOverHistories(prevVal, yieldOp, thisHistorySum); + return std::min(std::min(min1, min2), minCommitNumber); + } + // Failed to track, return 0 conservatively. + return 0; + }; + + if (waitOp->getNumOperands() != 1) + return 0; + int minCommits = minOverHistories(waitOp->getOperand(0), waitOp, 0); + return minCommits; +} + +// Look for consecutive wait ops and combine them into a single wait op. +static void +combineRedundantWaitOps(llvm::SmallSetVector &waitOps) { + llvm::MapVector toDelete; + for (auto waitOp : waitOps) { + if (toDelete.count(waitOp)) + continue; + SmallVector waitGroup = {waitOp}; + SmallVector depTokens; + unsigned minWaitNumber = waitOp.getNum(); + Operation *next = waitOp->getNextNode(); + while (next && isa(next)) { + if (auto nextWait = dyn_cast(next)) { + waitGroup.push_back(nextWait); + minWaitNumber = std::min(minWaitNumber, nextWait.getNum()); + depTokens.append(nextWait.getOperands().begin(), + nextWait.getOperands().end()); + } + next = next->getNextNode(); + } + if (waitGroup.size() == 1) + continue; + OpBuilder builder(waitGroup.back()); + auto newWaitOp = builder.create(waitOp.getLoc(), + depTokens, minWaitNumber); + for (auto waitOp : waitGroup) { + toDelete[waitOp] = newWaitOp; + } + } + for (auto waitOp : toDelete) { + waitOp.first->replaceAllUsesWith(waitOp.second); + waitOp.first->erase(); + } +} + +/// Update wait op number by analyzing the number of async_commit_group ops +/// along all paths. +void mlir::triton::updateWaits(ModuleOp module) { + llvm::SmallSetVector waitOps; + module.walk([&](ttg::AsyncWaitOp waitOp) { + int minNumCommits = minNumInterleavedCommitOps(waitOp); + waitOp.setNum(minNumCommits); + waitOps.insert(waitOp); + }); + combineRedundantWaitOps(waitOps); +} + +// Add the given values as operands of the given wait, and replace all uses of +// the values with the wait. Also adds related MemDesc's to the wait. +// +// Threading %a through the wait transforms +// +// %a = <...> +// (%x', %y') = ttng.async_wait %x, %y +// %b = fn(%a) +// +// into +// +// %a = <...> +// (%x', %y', %a') = ttng.async_wait %x, %y, %a +// %b = fn(%a') +// +// The wait must dominate all uses of the elements of `values`. +// +// In addition to adding each value from `values` to the wait, this function +// also adds some MemDesc's to the wait. The idea is that if you have +// +// %alloc = ttg.local_alloc ... +// %a = ttng.dot_async %alloc +// %a1 = ttng.dot_wait %a +// +// then we want the wait to depend on %alloc as well as %a. This extends the +// live range of %alloc, so that it won't be destroyed until after the dot is +// waited on. +// +// Specifically, this function finds all dot_async ops that elements of `values` +// depend on. Then it adds the MemDesc operands of those dots to the wait. +static void threadValuesThroughWait(ttng::DotWaitOp wait, + MutableArrayRef values) { + IRRewriter builder(wait.getContext()); + builder.setInsertionPoint(wait); + + // Operands are only added to the wait through this function, so we can have + // the invariant that the wait has no duplicates. This makes things a bit + // easier below. + size_t origNumOperands = wait.getNumOperands(); + SetVector newOperands(wait.getOperands().begin(), + wait.getOperands().end()); + assert(newOperands.size() == origNumOperands && + "Wait op has duplicate operands."); + + newOperands.insert(values.begin(), values.end()); + + // Find memdefs depended on by `values` through async dot ops. + SmallVector asyncDots; + for (Value v : values) { + BackwardSliceOptions options; + options.omitBlockArguments = true; + options.filter = [&](Operation *op) { + if (auto dot = dyn_cast(op)) { + asyncDots.push_back(dot); + return false; + } + return op->getBlock() == wait->getBlock(); + }; + SetVector slice; + getBackwardSlice(v, &slice, options); + } + + for (ttng::DotAsyncOp dot : asyncDots) { + for (Value operand : dot.getOperands()) { + if (isa(operand.getType())) { + newOperands.insert(operand); + } + } + } + + // We can't use replaceWithNewOp because we're changing the number of return + // values in the operation. + auto newWait = builder.create( + wait.getLoc(), llvm::to_vector(newOperands), wait.getPendings()); + + auto dominatedByNewWait = [&](OpOperand &operand) { + auto opInThisBlock = + newWait->getBlock()->findAncestorOpInBlock(*operand.getOwner()); + return opInThisBlock && newWait->isBeforeInBlock(opInThisBlock); + }; + for (int i = 0; i < origNumOperands; i++) { + Value operand = wait.getResult(i); + if (!isa(operand.getType())) + operand.replaceAllUsesWith(newWait.getResult(i)); + } + for (int i = origNumOperands; i < newOperands.size(); i++) { + Value operand = newWait.getOperand(i); + if (!isa(operand.getType())) + operand.replaceUsesWithIf(newWait.getResult(i), dominatedByNewWait); + } + wait->erase(); +} + +// Determines whether a given MMAv3 dot op, represented as ttng.dot_async, needs +// a wait immediately after it. +// +// In PTX, MMAv3 exists only as an asynchronous op. In Triton, we can represent +// MMAv3 ops as either tt.dot (synchronous) or ttng.dot_async. But even if we +// use ttng.dot_async, the conservative thing is to make a dot "effectively +// synchronous" by inserting a `ttng.dot_wait {pendings=0}` right after it. +// +// We can omit the wait and create a "properly async" dot if all of the +// following are true. +// +// 1. All operands that touch shared memory are multi-buffered, i.e. can't read +// an incomplete value while it's being written asynchronously by a load. +// +// 2. If the dot is used by any op in the loop, it must be used under an `if`, +// and will be synced with a `wait 0` at the beginning of the `if` block. +// +// 3. During iteration i, between the start of the loop up until the first +// `ttng.dot_wait {pendings=0}` op, the result of the dot from iteration i-1 +// is consumed only by other MMAv3 dots as the `c` operand. +// +// This is safe because the following pseudo-PTX is valid: +// +// %accum = dot_async %a1, %b1, %c1 +// %accum = dot_async %a2, %b2, %accum +// +// That is, the second async dot can use the result of the first one without +// an intervening wait. However, the only operation that can legally read +// %accum before the wait is another dot_async, and this only works for the +// `c` operand, not `a` or `b`. See +// https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence +// (ttng::DotAsyncOp corresponds to wgmma.fence followed by one or more +// wgmma.async ops, so our understanding is that the two ttng::DotAsyncOps +// don't have to correspond to wgmma.async ops with the same shapes as +// specified in the docs, because there's an intervening fence.) +// +// If the op can be properly async, this function returns the index of the dot +// in the loop's iter_args. (Rule (2) above ensures this is well-defined.) +// +static std::optional dotCanBeProperlyAsync(ttng::DotAsyncOp dotOp, + scf::ForOp forOp) { + LDBG("Considering whether to make MMAv3 dot properly async: " << dotOp); + + // Rule 1: All shmem operands are multi-buffered. + auto checkOperand = [&](Value operand) { + if (!isa( + cast(operand.getType()).getEncoding())) { + return true; + } + + // If it's a shmem operand, it must either be defined outside the loop, or + // come from an MemDescSubview op. Only ConvertLayout and Trans ops are + // allowed in between. + Value transitiveOperand = operand; + while (isa_and_nonnull( + transitiveOperand.getDefiningOp())) { + transitiveOperand = transitiveOperand.getDefiningOp()->getOperand(0); + } + return forOp.isDefinedOutsideOfLoop(transitiveOperand) || + isa(transitiveOperand.getDefiningOp()); + }; + + // We don't have to call checkOperand on getC() because it's always in + // registers, never in shmem. + assert(isa(dotOp.getC().getType().getEncoding())); + if (!checkOperand(dotOp.getA()) || !checkOperand(dotOp.getB())) { + LDBG("Can't make dot async because shmem operands aren't multi-buffered"); + return std::nullopt; + } + + // Rule 2: The dot cannot be unconditionally used by any op in the loop. + // Uses under `if` are allowed, as can be explicitly synced with a `wait 0`. + int iterArgIdx = -1; + Value iterArg = nullptr; + SmallVector> queue; + for (auto &use : dotOp->getUses()) { + queue.push_back({use.getOwner(), use.getOperandNumber()}); + } + while (!queue.empty()) { + auto [user, argIdx] = queue.pop_back_val(); + if (user->getParentOp() == forOp) { + if (isa(user)) { + if (iterArg) { + // The dot is used by the loop's yield, but we can't have any other + // uses. + LDBG("Can't make dot async because dot is used by multiple ops in " + "the loop."); + return std::nullopt; + } + iterArgIdx = argIdx; + iterArg = forOp.getRegionIterArg(argIdx); + continue; + } + LDBG("Can't make dot async because dot is unconditionally used in the " + "loop."); + return std::nullopt; + } + if (auto ifOp = dyn_cast(user->getParentOp())) { + if (isa(user)) { + // The result is returned by the if, follow it further. + auto uses = ifOp.getResult(argIdx).getUses(); + for (auto &use : uses) { + queue.push_back({use.getOwner(), use.getOperandNumber()}); + } + } + } else { + return std::nullopt; + } + } + + // Rule 3a: Are the only users of the dot's result from iteration i-1 other + // MMAv3 dots? If so, we're done, this dot can be properly async. + if (llvm::all_of(iterArg.getUses(), [&](OpOperand &use) { + return isa(use.getOwner()) && + use.getOperandNumber() == 2; + })) { + return iterArgIdx; + } + + // Rule 3b: Are all users of the dot's result from iteration i-1 after the + // first `dot_wait {pendings=0}` op? If so, the dot can be properly async, + // but we have to thread its result from iteration i-1 through the wait. + auto waitOps = forOp.getBody()->getOps(); + auto firstWaitOpIter = llvm::find_if( + waitOps, [&](auto waitOp) { return waitOp.getPendings() == 0; }); + if (firstWaitOpIter != waitOps.end() && + llvm::all_of(iterArg.getUsers(), [&](Operation *user) { + assert(forOp->isAncestor(user)); + while (user->getParentOp() != forOp) { + user = user->getParentOp(); + } + return (*firstWaitOpIter)->isBeforeInBlock(user); + })) { + LDBG("MMAv3 dot can be properly async because it follows a dot_wait " + "{pendings=0}.\n" + << " wait: " << *firstWaitOpIter << "\n" + << " dot: " << dotOp); + threadValuesThroughWait(*firstWaitOpIter, {iterArg}); + return iterArgIdx; + } + + LDBG("Can't make dot async because its result from i-1 is used by " + "something other than another MMAv3 dot as the `c` operand."); + return std::nullopt; +} + +// If necessary, insert a dot-wait inside the loop, waiting for the results of +// the properly-async dots from iteration i-1 to complete. (We pipeline to +// depth 2, so there are at most 2 copies of each dot_async in flight at a +// time.) +// +// We can skip inserting the wait if we have a `dot_wait {pendings=0}` somewhere +// in the loop. To see why, consider: +// +// dot_async +// dot_async; wait 0 // synchronous dot +// dot_async +// dot_async +// +// In this example, there are three properly-async dots, so we'd normally put +// `wait 3` at the end of the loop, meaning "wait until there are 3 or fewer +// pending async dots". But note that when this iteration of the loop +// completes, there are only *two* pending async dots from this iteration, so +// this wait would do nothing. This is true in general, no matter where the +// `wait 0` appears. +static void insertAsyncDotWaitInLoop( + scf::ForOp forOp, + const llvm::MapVector &properlyAsyncDots) { + if (properlyAsyncDots.empty()) + return; + + if (llvm::any_of(forOp.getBody()->getOps(), + [](auto wait) { return wait.getPendings() == 0; })) { + return; + } + + // Insert waits before the users of the properly async dots other than loop + // yield. + for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) { + SmallVector uses; + for (auto &use : asyncDot->getUses()) { + if (auto yieldOp = dyn_cast(use.getOwner())) { + continue; + } + uses.push_back(&use); + } + + DenseMap> blockToUsers; + for (auto use : uses) { + auto block = use->getOwner()->getBlock(); + blockToUsers[block].push_back(use->get()); + } + + for (auto [block, users] : blockToUsers) { + OpBuilder builder(block, block->begin()); + auto newWait = builder.create(asyncDot->getLoc(), + ArrayRef{}, 0); + + threadValuesThroughWait(newWait, users); + } + } + + // Add the wait right after the last properly-async dot. This only needs to + // wait for all properly-async dots from the i-1'th iteration to complete, IOW + // we wait until there are most `asyncDots.size()` dots in flight. + // + // (You might want to put the wait at the end of the loop instead of right + // after the last dot, but there could be a load into shmem between the last + // async dot and the end of the loop, and that could clobber memory being used + // by a dot.) + IRRewriter builder(forOp.getContext()); + auto lastAsyncDot = properlyAsyncDots.back().first; + builder.setInsertionPointAfter(lastAsyncDot); + auto wait = builder.create(lastAsyncDot->getLoc(), + /*inputs=*/ArrayRef{}, + properlyAsyncDots.size()); + + // Thread the results of the async dots through the wait. + SmallVector addlWaitOperands; + for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) { + addlWaitOperands.push_back(asyncDot->getResult(0)); + } + threadValuesThroughWait(wait, addlWaitOperands); +} + +// Convert MMAv3 tt::DotOps (i.e. Hopper wgmma) into ttng::DotAsyncOps and +// insert ttng::DotWaitOps as necessary. +// +// We assume we have space for each dot to be pipelined to depth 2, i.e. each +// dot op in the loop can have at most 2 dot_async ops in flight at once. (Each +// dot_async op usually corresponds to a series of wgmma.async ops.) +void triton::asyncLaunchDots(scf::ForOp forOp) { + LDBG("Original loop:\n" << *forOp); + + // First, change every MMAv3 tt.dot into ttng.dot_async. The rest of this + // function is concerned with inserting ttng.dot_wait ops in the appropriate + // places. + // + // It's not strictly necessary to convert every dot into dot_async: + // Synchronous MMAv3 dots can be represented equally well as `tt.dot` or + // `ttng.dot_async; wait 0`. But this makes things easier elsewhere. + // + // We call those dots that don't need to be followed immediately by a `wait 0` + // "properly async", or sometimes just "async". + IRRewriter builder(forOp.getContext()); + for (auto dotOp : llvm::to_vector(forOp.getBody()->getOps())) { + if (isMMAv3Dot(dotOp)) { + builder.setInsertionPoint(dotOp); + builder.replaceOpWithNewOp( + dotOp, dotOp.getA(), dotOp.getB(), dotOp.getC(), + dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); + } + } + + // For each dot, determine whether it can be properly async, or if it needs a + // sync immediately after. If it can be properly async, we know its only use + // is in the loop's `yield` statement; asyncDots maps the op to its index in + // the yield op. + llvm::MapVector properlyAsyncDots; + for (auto dotOp : forOp.getBody()->getOps()) { + if (auto iterArgIdx = dotCanBeProperlyAsync(dotOp, forOp)) { + properlyAsyncDots[dotOp] = *iterArgIdx; + } else { + builder.setInsertionPointAfter(dotOp); + auto wait = + builder.create(dotOp.getLoc(), ArrayRef{}, + /*pendings=*/0); + SmallVector waitOperands = {dotOp.getResult()}; + threadValuesThroughWait(wait, waitOperands); + } + } + + if (properlyAsyncDots.empty()) { + LDBG("No properly async dots."); + return; + } + + // Next, insert a wait inside the loop. We pipeline to depth 2, so the third + // iteration's set of asynchronous dots (and their corresponding async copies + // from global to shmem) can't start until the first iteration's set has + // completed. + insertAsyncDotWaitInLoop(forOp, properlyAsyncDots); + + // Finally, insert a wait after the loop, waiting for dots from the final + // iteration of the loop. + SmallVector waitOperands; + for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) { + waitOperands.push_back(forOp.getResult(iterArgIdx)); + } + // Wait until there are 0 outstanding async dot ops. + builder.setInsertionPointAfter(forOp); + auto dotWaitAfterLoop = + builder.create(forOp.getLoc(), ArrayRef{}, 0); + threadValuesThroughWait(dotWaitAfterLoop, waitOperands); +} diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp new file mode 100644 index 000000000..8b3f55bb8 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp @@ -0,0 +1,131 @@ +#include "PipelineExpander.h" +#include "PipeliningUtility.h" +#include "Schedule.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +// create the schedule for a matmul loop. This is ad hoc based on how we know +// matmul loops should be pipelined and is not a generic scheduler. +static std::vector> +createSchedule(scf::ForOp forOp, int numStages) { + SmallVector insertOps; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (isa(op)) + insertOps.emplace_back(&op); + } + DenseSet insertAndDeps; + for (Operation *op : insertOps) { + tt::addDep(op, insertAndDeps, true); + } + + DenseSet epilogue; + bool foundLoop = false; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (insertAndDeps.count(&op)) + continue; + if (isa(op)) + foundLoop = true; + if (isa(op)) + continue; + if (foundLoop) + epilogue.insert(&op); + } + + std::vector> schedule; + // Schedule stage 1 first. + tt::addOps(forOp, 1, schedule, [&](Operation *op) { + return insertAndDeps.count(op) == 0 && epilogue.count(op) == 0; + }); + + // Then Schedule stage 0. + tt::addOps(forOp, 0, schedule, + [&](Operation *op) { return insertAndDeps.count(op); }); + + // Then schedule the epilogue in stage 1 + tt::addOps(forOp, 1, schedule, + [&](Operation *op) { return epilogue.count(op); }); + return schedule; +} + +// pre-process the loop by hosting allocations/deallocation out of the +// loop. +static void hoistAllocAndConst(scf::ForOp forOp) { + SmallVector toHoist; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (auto allocOp = dyn_cast(op)) { + // We hoist the allocOp only if it is created by the inner loop + // pipelining. + if (!allocOp.getSrc()) + toHoist.push_back(&op); + } else if (isa(op)) { + toHoist.push_back(&op); + } + } + for (Operation *op : toHoist) { + op->moveBefore(forOp); + auto allocOp = dyn_cast(op); + if (!allocOp) + continue; + for (Operation *user : allocOp->getUsers()) { + if (auto dealloc = dyn_cast(user)) { + dealloc->moveAfter(forOp); + } + } + } +} + +static bool preCondition(scf::ForOp forOp) { + // Check if there is a dependency from the loop to the async copy op. In this + // case we cannot pipeline the async copy. + SmallVector insertOps; + int numForOps = 0; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (isa(op)) + insertOps.emplace_back(&op); + if (isa(op)) + numForOps++; + } + if (insertOps.empty() || numForOps != 1) + return false; + DenseSet insertAndDeps; + for (Operation *op : insertOps) { + tt::addDep(op, insertAndDeps, true); + } + // If there is a recurrence containing both the async and the for op we cannot + // pipeline. + for (Operation *op : insertAndDeps) { + if (isa(op)) + return false; + } + return true; +} + +bool mlir::triton::getOuterLoopSchedule( + scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) { + assert(numStages == 2 && "only support 2 stage pipelining for now"); + // 1. Check precondition, we cannot have a recurrence involving async cp ops + if (!preCondition(forOp)) + return false; + + // 2. pre-process the loop by hosting allocations. + hoistAllocAndConst(forOp); + + // 3. Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> schedule = + createSchedule(forOp, numStages); + + // 4. Fill out the pipeline options. + options.getScheduleFn = + [schedule](scf::ForOp forOp, + std::vector> &s) { + s = std::move(schedule); + }; + options.peelEpilogue = false; + options.predicateFn = mlir::triton::predicateOp; + options.supportDynamicLoops = true; + return true; +} diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp new file mode 100644 index 000000000..6dfd0e344 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -0,0 +1,776 @@ +//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements loop software pipelining +// +//===----------------------------------------------------------------------===// + +// Fork of upstream pipeliner. This will be merged upstream once things are +// stable. Modifications so far are: +// -Bug fix for def with a distance of 1 scheduled in stage 0. +// -Support dynamic loops and predicate operations in the prologue. +// -Support for non-index type for induction variable. +// -Support source with distance of 1 used multiple stages later. +// -Fix bug when a value yield is used outside the loop and the value def is not +// in the last stage. If we are not peeling the epilgue we need to remap the +// output correctly. + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/MathExtras.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/Support/Debug.h" + +#include "PipelineExpander.h" + +#define DEBUG_TYPE "triton-loop-pipelining" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::scf; +using namespace mlir::triton; + +namespace { + +/// Helper to keep internal information during pipelining transformation. +struct LoopPipelinerInternal { + /// Coarse liverange information for ops used across stages. + struct LiverangeInfo { + unsigned lastUseStage = 0; + unsigned defStage = 0; + }; + +protected: + ForOp forOp; + unsigned maxStage = 0; + DenseMap stages; + std::vector opOrder; + Value ub; + Value lb; + Value step; + bool dynamicLoop; + triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr; + bool peelEpilogue; + triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr; + + // When peeling the kernel we generate several version of each value for + // different stage of the prologue. This map tracks the mapping between + // original Values in the loop and the different versions + // peeled from the loop. + DenseMap> valueMapping; + + /// Assign a value to `valueMapping`, this means `val` represents the version + /// `idx` of `key` in the epilogue. + void setValueMapping(Value key, Value el, int64_t idx); + + /// Return the defining op of the given value, if the Value is an argument of + /// the loop return the associated defining op in the loop and its distance to + /// the Value. + std::pair getDefiningOpAndDistance(Value value); + + /// Return true if the schedule is possible and return false otherwise. A + /// schedule is correct if all definitions are scheduled before uses. + bool verifySchedule(); + +public: + /// Initialize the information for the given `op`, return true if it + /// satisfies the pre-condition to apply pipelining. + bool initializeLoopInfo(ForOp op, const triton::PipeliningOption &options); + /// Emits the prologue, this creates `maxStage - 1` part which will contain + /// operations from stages [0; i], where i is the part index. + void emitPrologue(RewriterBase &rewriter); + /// Gather liverange information for Values that are used in a different stage + /// than its definition. + llvm::MapVector analyzeCrossStageValues(); + scf::ForOp createKernelLoop( + const llvm::MapVector &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap); + /// Emits the pipelined kernel. This clones loop operations following user + /// order and remaps operands defined in a different stage as their use. + LogicalResult createKernel( + scf::ForOp newForOp, + const llvm::MapVector &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter); + /// Emits the epilogue, this creates `maxStage - 1` part which will contain + /// operations from stages [i; maxStage], where i is the part index. + void emitEpilogue(RewriterBase &rewriter, + llvm::SmallVector &returnValues); +}; + +bool LoopPipelinerInternal::initializeLoopInfo( + ForOp op, const triton::PipeliningOption &options) { + LDBG("Start initializeLoopInfo"); + forOp = op; + ub = forOp.getUpperBound(); + lb = forOp.getLowerBound(); + step = forOp.getStep(); + + dynamicLoop = true; + auto upperBoundCst = ub.getDefiningOp(); + auto lowerBoundCst = lb.getDefiningOp(); + auto stepCst = step.getDefiningOp(); + if (!upperBoundCst || !lowerBoundCst || !stepCst) { + if (!options.supportDynamicLoops) { + LDBG("--dynamic loop not supported -> BAIL"); + return false; + } + } else { + int64_t ubImm = upperBoundCst.value(); + int64_t lbImm = lowerBoundCst.value(); + int64_t stepImm = stepCst.value(); + int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm); + if (numIteration > maxStage) { + dynamicLoop = false; + } else if (!options.supportDynamicLoops) { + LDBG("--fewer loop iterations than pipeline stages -> BAIL"); + return false; + } + } + peelEpilogue = options.peelEpilogue; + predicateFn = options.predicateFn; + if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { + LDBG("--no epilogue or predicate set -> BAIL"); + return false; + } + if (dynamicLoop && peelEpilogue) { + LDBG("--dynamic loop doesn't support epilogue yet -> BAIL"); + return false; + } + std::vector> schedule; + options.getScheduleFn(forOp, schedule); + if (schedule.empty()) { + LDBG("--empty schedule -> BAIL"); + return false; + } + + opOrder.reserve(schedule.size()); + for (auto &opSchedule : schedule) { + maxStage = std::max(maxStage, opSchedule.second); + stages[opSchedule.first] = opSchedule.second; + opOrder.push_back(opSchedule.first); + } + + // All operations need to have a stage. + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!stages.contains(&op)) { + op.emitOpError("not assigned a pipeline stage"); + LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL"); + return false; + } + } + + if (!verifySchedule()) { + LDBG("--invalid schedule: " << op << " -> BAIL"); + return false; + } + + // Currently, we do not support assigning stages to ops in nested regions. The + // block of all operations assigned a stage should be the single `scf.for` + // body block. + for (const auto &[op, stageNum] : stages) { + (void)stageNum; + if (op == forOp.getBody()->getTerminator()) { + op->emitError("terminator should not be assigned a stage"); + LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL"); + return false; + } + if (op->getBlock() != forOp.getBody()) { + op->emitOpError("the owning Block of all operations assigned a stage " + "should be the loop body block"); + LDBG("--the owning Block of all operations assigned a stage " + "should be the loop body block: " + << *op << " -> BAIL"); + return false; + } + } + + // Support only loop-carried dependencies with a distance of one iteration or + // those defined outside of the loop. This means that any dependency within a + // loop should either be on the immediately preceding iteration, the current + // iteration, or on variables whose values are set before entering the loop. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [this](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def || + (!stages.contains(def) && forOp->isAncestor(def)); + })) { + LDBG("--only support loop carried dependency with a distance of 1 or " + "defined outside of the loop -> BAIL"); + return false; + } + annotateFn = options.annotateFn; + return true; +} + +/// Find operands of all the nested operations within `op`. +static SetVector getNestedOperands(Operation *op) { + SetVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + operands.insert(operand); + } + }); + return operands; +} + +/// Compute unrolled cycles of each op (consumer) and verify that each op is +/// scheduled after its operands (producers) while adjusting for the distance +/// between producer and consumer. +bool LoopPipelinerInternal::verifySchedule() { + int64_t numCylesPerIter = opOrder.size(); + // Pre-compute the unrolled cycle of each op. + DenseMap unrolledCyles; + for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) { + Operation *def = opOrder[cycle]; + auto it = stages.find(def); + assert(it != stages.end()); + int64_t stage = it->second; + unrolledCyles[def] = cycle + stage * numCylesPerIter; + } + for (Operation *consumer : opOrder) { + int64_t consumerCycle = unrolledCyles[consumer]; + for (Value operand : getNestedOperands(consumer)) { + auto [producer, distance] = getDefiningOpAndDistance(operand); + if (!producer) + continue; + auto it = unrolledCyles.find(producer); + // Skip producer coming from outside the loop. + if (it == unrolledCyles.end()) + continue; + int64_t producerCycle = it->second; + if (consumerCycle < producerCycle - numCylesPerIter * distance) { + consumer->emitError("operation scheduled before its operands"); + return false; + } + } + } + return true; +} + +/// Clone `op` and call `callback` on the cloned op's operands as well as any +/// operands of nested ops that: +/// 1) aren't defined within the new op or +/// 2) are block arguments. +static Operation * +cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, + function_ref callback) { + Operation *clone = rewriter.clone(*op); + clone->walk([&](Operation *nested) { + // 'clone' itself will be visited first. + for (OpOperand &operand : nested->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if ((def && !clone->isAncestor(def)) || isa(operand.get())) + callback(&operand); + } + }); + return clone; +} + +void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { + // Initialize the iteration argument to the loop initiale values. + for (auto [arg, operand] : + llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) { + setValueMapping(arg, operand.get(), 0); + } + auto yield = cast(forOp.getBody()->getTerminator()); + Location loc = forOp.getLoc(); + SmallVector predicates(maxStage); + for (int64_t i = 0; i < maxStage; i++) { + if (dynamicLoop) { + Type t = ub.getType(); + // pred = ub > lb + (i * step) + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, i)))); + predicates[i] = rewriter.create( + loc, arith::CmpIPredicate::slt, iv, ub); + } + + // special handling for induction variable as the increment is implicit. + // iv = lb + i * step + Type t = lb.getType(); + Value iv = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, + rewriter.getIntegerAttr(t, i)))); + setValueMapping(forOp.getInductionVar(), iv, i); + for (Operation *op : opOrder) { + if (stages[op] > i) + continue; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[i - stages[op]]; + newOperand->set(replacement); + } + }); + int predicateIdx = i - stages[op]; + if (predicates[predicateIdx]) { + newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]); + assert(newOp && "failed to predicate op."); + } + rewriter.setInsertionPointAfter(newOp); + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Prologue, i); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(destId), newOp->getResult(destId), + i - stages[op]); + // If the value is a loop carried dependency update the loop argument + // mapping. + for (OpOperand &operand : yield->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), i - stages[op] + 1); + } + } + } + } +} + +llvm::MapVector +LoopPipelinerInternal::analyzeCrossStageValues() { + llvm::MapVector crossStageValues; + for (Operation *op : opOrder) { + unsigned stage = stages[op]; + + auto analyzeOperand = [&](OpOperand &operand) { + auto [def, distance] = getDefiningOpAndDistance(operand.get()); + if (!def) + return; + auto defStage = stages.find(def); + if (defStage == stages.end() || defStage->second == stage || + defStage->second == stage + distance) + return; + assert(stage > defStage->second); + LiverangeInfo &info = crossStageValues[operand.get()]; + info.defStage = defStage->second; + info.lastUseStage = std::max(info.lastUseStage, stage); + }; + + for (OpOperand &operand : op->getOpOperands()) + analyzeOperand(operand); + visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) { + analyzeOperand(*operand); + }); + } + return crossStageValues; +} + +std::pair +LoopPipelinerInternal::getDefiningOpAndDistance(Value value) { + int64_t distance = 0; + if (auto arg = dyn_cast(value)) { + if (arg.getOwner() != forOp.getBody()) + return {nullptr, 0}; + // Ignore induction variable. + if (arg.getArgNumber() == 0) + return {nullptr, 0}; + distance++; + value = + forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); + } + Operation *def = value.getDefiningOp(); + if (!def) + return {nullptr, 0}; + return {def, distance}; +} + +scf::ForOp LoopPipelinerInternal::createKernelLoop( + const llvm::MapVector + &crossStageValues, + RewriterBase &rewriter, + llvm::DenseMap, unsigned> &loopArgMap) { + // Creates the list of initial values associated to values used across + // stages. The initial values come from the prologue created above. + // Keep track of the kernel argument associated to each version of the + // values passed to the kernel. + llvm::SmallVector newLoopArg; + // For existing loop argument initialize them with the right version from the + // prologue. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance of 1 or " + "outside the loop"); + auto defStage = stages.find(def); + if (defStage != stages.end()) { + Value valueVersion = + valueMapping[forOp.getRegionIterArgs()[retVal.index()]] + [maxStage - defStage->second]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + } else + newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]); + } + for (auto escape : crossStageValues) { + LiverangeInfo &info = escape.second; + Value value = escape.first; + for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage; + stageIdx++) { + Value valueVersion = + valueMapping[value][maxStage - info.lastUseStage + stageIdx]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage - + stageIdx)] = newLoopArg.size() - 1; + } + } + + // Create the new kernel loop. When we peel the epilgue we need to peel + // `numStages - 1` iterations. Then we adjust the upper bound to remove those + // iterations. + Value newUb = forOp.getUpperBound(); + if (peelEpilogue) { + Type t = ub.getType(); + Location loc = forOp.getLoc(); + // newUb = ub - maxStage * step + Value maxStageValue = rewriter.create( + loc, rewriter.getIntegerAttr(t, maxStage)); + Value maxStageByStep = + rewriter.create(loc, step, maxStageValue); + newUb = rewriter.create(loc, ub, maxStageByStep); + } + auto newForOp = + rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, + forOp.getStep(), newLoopArg); + // When there are no iter args, the loop body terminator will be created. + // Since we always create it below, remove the terminator if it was created. + if (!newForOp.getBody()->empty()) + rewriter.eraseOp(newForOp.getBody()->getTerminator()); + return newForOp; +} + +LogicalResult LoopPipelinerInternal::createKernel( + scf::ForOp newForOp, + const llvm::MapVector + &crossStageValues, + const llvm::DenseMap, unsigned> &loopArgMap, + RewriterBase &rewriter) { + valueMapping.clear(); + + // Create the kernel, we clone instruction based on the order given by + // user and remap operands coming from a previous stages. + rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) { + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + } + SmallVector predicates(maxStage + 1, nullptr); + if (!peelEpilogue) { + // Create a predicate for each stage except the last stage. + Location loc = newForOp.getLoc(); + Type t = ub.getType(); + for (unsigned i = 0; i < maxStage; i++) { + // c = ub - (maxStage - i) * step + Value c = rewriter.create( + loc, ub, + rewriter.create( + loc, step, + rewriter.create( + loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i))))); + + Value pred = rewriter.create( + newForOp.getLoc(), arith::CmpIPredicate::slt, + newForOp.getInductionVar(), c); + predicates[i] = pred; + } + } + for (Operation *op : opOrder) { + int64_t useStage = stages[op]; + auto *newOp = rewriter.clone(*op, mapping); + SmallVector operands; + // Collect all the operands for the cloned op and its nested ops. + op->walk([&operands](Operation *nestedOp) { + for (OpOperand &operand : nestedOp->getOpOperands()) { + operands.push_back(&operand); + } + }); + for (OpOperand *operand : operands) { + Operation *nestedNewOp = mapping.lookup(operand->getOwner()); + // Special case for the induction variable uses. We replace it with a + // version incremented based on the stage where it is used. + if (operand->get() == forOp.getInductionVar()) { + rewriter.setInsertionPoint(newOp); + + // offset = (maxStage - stages[op]) * step + Type t = step.getType(); + Value offset = rewriter.create( + forOp.getLoc(), step, + rewriter.create( + forOp.getLoc(), + rewriter.getIntegerAttr(t, maxStage - stages[op]))); + Value iv = rewriter.create( + forOp.getLoc(), newForOp.getInductionVar(), offset); + nestedNewOp->setOperand(operand->getOperandNumber(), iv); + rewriter.setInsertionPointAfter(newOp); + continue; + } + Value source = operand->get(); + auto arg = dyn_cast(source); + if (arg && arg.getOwner() == forOp.getBody()) { + Value ret = forOp.getBody()->getTerminator()->getOperand( + arg.getArgNumber() - 1); + Operation *dep = ret.getDefiningOp(); + if (!dep) + continue; + auto stageDep = stages.find(dep); + if (stageDep == stages.end() || stageDep->second == useStage) + continue; + // If the value is a loop carried value coming from stage N + 1 remap, + // it will become a direct use. + if (stageDep->second == useStage + 1) { + nestedNewOp->setOperand(operand->getOperandNumber(), + mapping.lookupOrDefault(ret)); + continue; + } + source = ret; + } + // For operands defined in a previous stage we need to remap it to use + // the correct region argument. We look for the right version of the + // Value based on the stage where it is used. + Operation *def = source.getDefiningOp(); + if (!def) + continue; + auto stageDef = stages.find(def); + if (stageDef == stages.end() || stageDef->second == useStage) + continue; + auto remap = loopArgMap.find( + std::make_pair(operand->get(), useStage - stageDef->second)); + assert(remap != loopArgMap.end()); + nestedNewOp->setOperand(operand->getOperandNumber(), + newForOp.getRegionIterArgs()[remap->second]); + } + + if (predicates[useStage]) { + newOp = predicateFn(rewriter, newOp, predicates[useStage]); + if (!newOp) + return failure(); + // Remap the results to the new predicated one. + for (auto values : llvm::zip(op->getResults(), newOp->getResults())) + mapping.map(std::get<0>(values), std::get<1>(values)); + } + rewriter.setInsertionPointAfter(newOp); + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Kernel, 0); + } + + // Collect the Values that need to be returned by the forOp. For each + // value we need to have `LastUseStage - DefStage` number of versions + // returned. + // We create a mapping between original values and the associated loop + // returned values that will be needed by the epilogue. + llvm::SmallVector yieldOperands; + for (OpOperand &yieldOperand : + forOp.getBody()->getTerminator()->getOpOperands()) { + Value source = mapping.lookupOrDefault(yieldOperand.get()); + // When we don't peel the epilogue and the yield value is used outside the + // loop we need to make sure we return the version from numStages - + // defStage. + if (!peelEpilogue && + !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) { + Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first; + if (def) { + auto defStage = stages.find(def); + if (defStage != stages.end() && defStage->second < maxStage) { + Value pred = predicates[defStage->second]; + source = rewriter.create( + pred.getLoc(), pred, source, + newForOp.getBody() + ->getArguments()[yieldOperand.getOperandNumber() + 1]); + } + } + } + yieldOperands.push_back(source); + } + + for (auto &it : crossStageValues) { + int64_t version = maxStage - it.second.lastUseStage + 1; + unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; + // add the original version to yield ops. + // If there is a live range spanning across more than 2 stages we need to + // add extra arg. + for (unsigned i = 1; i < numVersionReturned; i++) { + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back( + newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + + newForOp.getNumInductionVars()]); + } + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back(mapping.lookupOrDefault(it.first)); + } + // Map the yield operand to the forOp returned value. + for (const auto &retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation *def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance of 1 or " + "defined outside the loop"); + auto defStage = stages.find(def); + if (defStage == stages.end()) { + for (unsigned int stage = 1; stage <= maxStage; stage++) + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + retVal.value(), stage); + } else if (defStage->second > 0) { + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + newForOp->getResult(retVal.index()), + maxStage - defStage->second + 1); + } + } + rewriter.create(forOp.getLoc(), yieldOperands); + return success(); +} + +void LoopPipelinerInternal::emitEpilogue( + RewriterBase &rewriter, llvm::SmallVector &returnValues) { + // Emit different versions of the induction variable. They will be + // removed by dead code if not used. + for (int64_t i = 0; i < maxStage; i++) { + Location loc = forOp.getLoc(); + Type t = lb.getType(); + Value minusOne = + rewriter.create(loc, rewriter.getIntegerAttr(t, -1)); + // number of iterations = ((ub - 1) - lb) / step + Value totalNumIteration = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create(loc, ub, minusOne), lb), + step); + // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i) + Value minusI = + rewriter.create(loc, rewriter.getIntegerAttr(t, -i)); + Value newlastIter = rewriter.create( + loc, lb, + rewriter.create( + loc, step, + rewriter.create(loc, totalNumIteration, minusI))); + setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i); + } + // Emit `maxStage - 1` epilogue part that includes operations from stages + // [i; maxStage]. + for (int64_t i = 1; i <= maxStage; i++) { + for (Operation *op : opOrder) { + if (stages[op] < i) + continue; + Operation *newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[maxStage - stages[op] + i]; + newOperand->set(replacement); + } + }); + if (annotateFn) + annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Epilogue, + i - 1); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(destId), newOp->getResult(destId), + maxStage - stages[op] + i); + // If the value is a loop carried dependency update the loop argument + // mapping and keep track of the last version to replace the original + // forOp uses. + for (OpOperand &operand : + forOp.getBody()->getTerminator()->getOpOperands()) { + if (operand.get() != op->getResult(destId)) + continue; + unsigned version = maxStage - stages[op] + i + 1; + // If the version is greater than maxStage it means it maps to the + // original forOp returned value. + if (version > maxStage) { + returnValues[operand.getOperandNumber()] = newOp->getResult(destId); + continue; + } + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), version); + } + } + } + } +} + +void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { + auto it = valueMapping.find(key); + // If the value is not in the map yet add a vector big enough to store all + // versions. + if (it == valueMapping.end()) + it = + valueMapping + .insert(std::make_pair(key, llvm::SmallVector(maxStage + 1))) + .first; + it->second[idx] = el; +} + +} // namespace + +FailureOr +mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, + const triton::PipeliningOption &options, + bool *modifiedIR) { + if (modifiedIR) + *modifiedIR = false; + LoopPipelinerInternal pipeliner; + if (!pipeliner.initializeLoopInfo(forOp, options)) + return failure(); + + if (modifiedIR) + *modifiedIR = true; + + // 1. Emit prologue. + pipeliner.emitPrologue(rewriter); + + // 2. Track values used across stages. When a value cross stages it will + // need to be passed as loop iteration arguments. + // We first collect the values that are used in a different stage than where + // they are defined. + llvm::MapVector + crossStageValues = pipeliner.analyzeCrossStageValues(); + + // Mapping between original loop values used cross stage and the block + // arguments associated after pipelining. A Value may map to several + // arguments if its liverange spans across more than 2 stages. + llvm::DenseMap, unsigned> loopArgMap; + // 3. Create the new kernel loop and return the block arguments mapping. + ForOp newForOp = + pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); + // Create the kernel block, order ops based on user choice and remap + // operands. + if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, + rewriter))) + return failure(); + + llvm::SmallVector returnValues = + newForOp.getResults().take_front(forOp->getNumResults()); + if (options.peelEpilogue) { + // 4. Emit the epilogue after the new forOp. + rewriter.setInsertionPointAfter(newForOp); + pipeliner.emitEpilogue(rewriter, returnValues); + } + // 5. Erase the original loop and replace the uses with the epilogue output. + if (forOp->getNumResults() > 0) + rewriter.replaceOp(forOp, returnValues); + else + rewriter.eraseOp(forOp); + + return newForOp; +} diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h new file mode 100644 index 000000000..0a3d736c6 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h @@ -0,0 +1,101 @@ +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ +#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ + +// This is a fork of upstream pipeline transformation. This will be merged back +// upstream once we have a stable solution. + +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { + +class RewriterBase; +class Operation; +class Value; + +namespace scf { +class ForOp; +} + +namespace triton { + +/// Options to dictate how loops should be pipelined. +struct PipeliningOption { + /// Lambda returning all the operation in the forOp, with their stage, in the + /// order picked for the pipelined loop. + using GetScheduleFnType = std::function> &)>; + GetScheduleFnType getScheduleFn = nullptr; + enum class PipelinerPart { + Prologue, + Kernel, + Epilogue, + }; + /// Lambda called by the pipeliner to allow the user to annotate the IR while + /// it is generated. + /// The callback passes the operation created along with the part of the + /// pipeline and the iteration index. The iteration index is always 0 for the + /// kernel. For the prologue and epilogue, it corresponds to the iteration + /// peeled out of the loop in the range [0, maxStage[. + using AnnotationlFnType = + std::function; + AnnotationlFnType annotateFn = nullptr; + + /// Control whether the epilogue should be peeled out of the loop or + /// operations should be predicated to skip the early stages in the last loop + /// iterations. If the epilogue is predicated; the user needs to provide a + /// lambda to generate the predicated version of operations. + bool peelEpilogue = true; + + /// Control whether the transformation checks that the number of iterations is + /// greater or equal to the number of stages and skip the transformation if + /// this is not the case. If the loop is dynamic and this is set to true the + /// pipeliner will have to predicate operations in the the prologue/epilogue. + bool supportDynamicLoops = false; + + // Callback to predicate operations when the prologue or epilogue are not + // peeled. This takes the original operation, an i1 predicate value and the + // pattern rewriter. It is expected to replace the given operation with + // the predicated equivalent and return it, or return nullptr if the + // predication is impossible. In the latter case, pipelining will fail and + // may leave IR in a partially transformed state. + using PredicateOpFnType = + std::function; + PredicateOpFnType predicateFn = nullptr; + + // TODO: add option to decide if the prologue should be peeled. +}; + +/// Generate a pipelined version of the scf.for loop based on the schedule given +/// as option. This applies the mechanical transformation of changing the loop +/// and generating the prologue/epilogue for the pipelining and doesn't make any +/// decision regarding the schedule. +/// Based on the options the loop is split into several stages. +/// The transformation assumes that the scheduling given by user is valid. +/// For example if we break a loop into 3 stages named S0, S1, S2 we would +/// generate the following code with the number in parenthesis as the iteration +/// index: +/// +/// S0(0) // Prologue +/// S0(1) S1(0) // Prologue +/// scf.for %I = %C0 to %N - 2 { +/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel +/// } +/// S1(N) S2(N-1) // Epilogue +/// S2(N) // Epilogue +/// +/// If `modifiedIR` is provided, it will be set to a value that indicates +/// whether pipelining modified the IR before failing, signaling to the caller +/// whether they can proceed with different transformations. +FailureOr pipelineForLoop(RewriterBase &rewriter, scf::ForOp forOp, + const PipeliningOption &options, + bool *modifiedIR = nullptr); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_ diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp new file mode 100644 index 000000000..c773d808c --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -0,0 +1,123 @@ +#include "PipeliningUtility.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +// Combine the current mask with the given predicate. +static Value getPredMask(RewriterBase &rewriter, Type typeLike, + Value currentMask, Value pred) { + Type maskType = tt::getI1SameShape(typeLike); + Location loc = pred.getLoc(); + Value mask = pred; + if (isa(maskType)) { + mask = rewriter.create(loc, maskType, pred); + } + if (currentMask) { + mask = rewriter.create(loc, mask, currentMask); + } + return mask; +} + +// Function to mask operations during scheduling. +Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, + Value pred) { + OpBuilder::InsertionGuard guard(rewriter); + if (mlir::isMemoryEffectFree(op)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (isa(op)) + return op; + if (auto ifOp = dyn_cast(op)) { + rewriter.setInsertionPoint(op); + Value cnd = getPredMask(rewriter, ifOp.getCondition().getType(), + ifOp.getCondition(), pred); + ifOp.getConditionMutable().assign(cnd); + return op; + } + if (auto asyncCopyOp = dyn_cast(op)) { + rewriter.setInsertionPoint(asyncCopyOp); + Value mask = getPredMask(rewriter, asyncCopyOp.getSrc().getType(), + asyncCopyOp.getMask(), pred); + asyncCopyOp.getMaskMutable().assign(mask); + return op; + } + if (auto loadOp = dyn_cast(op)) { + rewriter.setInsertionPoint(loadOp); + Value mask = getPredMask(rewriter, loadOp.getPtr().getType(), + loadOp.getMask(), pred); + loadOp.getMaskMutable().assign(mask); + return op; + } + if (auto copyOp = dyn_cast(op)) { + rewriter.setInsertionPoint(copyOp); + Value mask = getPredMask(rewriter, copyOp.getPred().getType(), + copyOp.getPred(), pred); + copyOp.getPredMutable().assign(mask); + return op; + } + if (auto expectOp = dyn_cast(op)) { + rewriter.setInsertionPoint(expectOp); + Value mask = getPredMask(rewriter, expectOp.getPred().getType(), + expectOp.getPred(), pred); + expectOp.getPredMutable().assign(mask); + return op; + } + + assert("don't know how to predicate this op" && false); + return op; +} + +/// Helper to recursively add dependencies to the same stage. +void mlir::triton::addDep(Operation *op, DenseSet &deps, + bool includeArg, DenseSet *filter) { + if (filter && filter->count(op)) + return; + if (!deps.insert(op).second) + return; + for (Value operand : op->getOperands()) { + Value v = operand; + llvm::SmallDenseSet seen; + while (auto arg = mlir::dyn_cast(v)) { + if (!includeArg) + break; + if (!seen.insert(v).second) + break; + if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) { + auto yieldOp = op->getBlock()->getTerminator(); + v = yieldOp->getOperand(arg.getArgNumber() - 1); + continue; + } + break; + } + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + addDep(defOp, deps, includeArg, filter); + } + } +} + +// Add operations to the schedule with the given stage based on the filter +// function. +void mlir::triton::addOps( + scf::ForOp forOp, int stage, + std::vector> &schedule, + std::function filter) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!filter(&op)) + continue; + schedule.emplace_back(&op, stage); + } +} diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.h b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.h new file mode 100644 index 000000000..25f0806db --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.h @@ -0,0 +1,29 @@ +#ifndef TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_ +#define TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_ + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include + +namespace mlir { +namespace triton { + +static const char *kNumStagesAttrName = "tt.num_stages"; + +/// Function to mask operations during scheduling. +Operation *predicateOp(RewriterBase &rewriter, Operation *op, Value pred); + +/// Collect ssa dependencies of `op` in `deps`. if `includeArg` is true, +/// continue looking through loop block arguments. +void addDep(Operation *op, DenseSet &deps, bool includeArg = true, + DenseSet *filter = nullptr); + +/// Add operations from `forOp` into a pipeline schedule with the the given +/// `stage` when filter is true. This will add operation in the original loop +/// order. +void addOps(scf::ForOp forOp, int stage, + std::vector> &schedule, + std::function filter); +} // namespace triton +} // namespace mlir + +#endif // TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h similarity index 100% rename from lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h rename to third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp new file mode 100644 index 000000000..e5ed6ed37 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -0,0 +1,164 @@ +#include "PipelineExpander.h" +#include "PipeliningUtility.h" +#include "Schedule.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// This file will create a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces, one that create a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue. This pass first calls a helper that will pre-process the IR +// to create async operations and create a modulo schedule. Then we call the +// expander to generate the prologue and new loop. +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUPIPELINE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// Return true if the preconditions for pipelining the loop are met. +static bool preCondition(scf::ForOp forOp) { + // Skip loop with distance > 1 for now. + // TODO: relax the constraint in the expander. + if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [](Value operand) { + Operation *def = operand.getDefiningOp(); + return !def; + })) + return false; + // Don't pipeline outer loops. + if (forOp + ->walk([&](Operation *op) { + if (forOp.getOperation() == op) + return WalkResult::advance(); + if (isa(op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }) + .wasInterrupted()) + return false; + return true; +} + +static void tryAndPipelineOuterLoop(scf::ForOp forOp) { + mlir::triton::PipeliningOption options; + bool foundSchedule = false; + // Limit 2 stages to not require extra shared memory. + foundSchedule = getOuterLoopSchedule(forOp, /*numStage=*/2, options); + if (!foundSchedule) + return; + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + FailureOr newForOp = + mlir::triton::pipelineForLoop(rewriter, forOp, options); +} + +static bool pipelineLoop(scf::ForOp forOp, int numStages) { + mlir::triton::PipeliningOption options; + if (!preCondition(forOp)) + return false; + + bool foundSchedule = false; + foundSchedule = preProcessLoopAndGetSchedule(forOp, numStages, options); + + // TODO: add more pipelines strategy. + if (!foundSchedule) + return false; + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + FailureOr newForOp = + mlir::triton::pipelineForLoop(rewriter, forOp, options); + + if (failed(newForOp)) + return false; + mlir::triton::asyncLaunchDots(newForOp.value()); + return true; +} + +struct PipelinePass : public impl::TritonGPUPipelineBase { + + using impl::TritonGPUPipelineBase::TritonGPUPipelineBase; + + int getNumStagesOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists otherwise use the + // global control. + if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName)) + return numStages; + return mlir::cast( + forOp->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); + } + + void runOnOperation() override { + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1. + if (getNumStagesOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + + if (loops.empty()) + return; + + llvm::SmallSetVector outerLoops; + for (scf::ForOp forOp : loops) { + auto outerLoop = dyn_cast(forOp->getParentOp()); + int loopNumStages = getNumStagesOrDefault(forOp); + bool pipelined = pipelineLoop(forOp, loopNumStages); + if (pipelined && outerLoop && getNumStagesOrDefault(outerLoop) > 1) + outerLoops.insert(outerLoop); + } + + // schedule the waits + mlir::triton::updateWaits(getOperation()); + + // Clean up arithmetic before applying the next level of pipelining to + // simplify the IR. + auto arithDialect = + getOperation().getContext()->getLoadedDialect(); + RewritePatternSet patterns(getOperation().getContext()); + arithDialect->getCanonicalizationPatterns(patterns); + if (applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)) + .failed()) + return signalPassFailure(); + + // Try to pipeline the outer loop to overlap the prologue and epilogue of + // the inner loop. + for (scf::ForOp outerLoop : outerLoops) + tryAndPipelineOuterLoop(outerLoop); + + // Re-collect loop ops + loops.clear(); + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1. + if (getNumStagesOrDefault(forOp) > 1) + loops.push_back(forOp); + }); + + for (scf::ForOp forOp : loops) { + mlir::triton::pipelineTMAStores(forOp); + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp new file mode 100644 index 000000000..6318b178d --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -0,0 +1,93 @@ +#include "Schedule.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = mlir::triton::nvidia_gpu; + +static SmallVector +getTMAStores(scf::ForOp forOp) { + SmallVector tmaStores; + + // Do not use walk, as we don't want to walk into nested loops. + std::function collectTMAStores = [&](Operation *op) { + if (auto storeOp = dyn_cast(op)) { + tmaStores.push_back(storeOp); + } + for (Region ®ion : op->getRegions()) { + for (Operation &op : region.getOps()) { + if (!isa(op)) + collectTMAStores(&op); + } + } + }; + collectTMAStores(forOp); + return tmaStores; +} + +static Value createAlloc(scf::ForOp &forOp, + tt::ExperimentalDescriptorStoreOp storeOp) { + OpBuilder builder(forOp); + auto ty = cast(storeOp.getSrc().getType()); + auto order = ttg::getOrder(ty.getEncoding()); + auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); + Attribute encoding = + ttg::SharedEncodingAttr::get(ty.getContext(), 1, 1, 1, order, ctaLayout); + if (ty.getRank() > 1) { + encoding = ttg::SharedEncodingAttr::get( + ty.getContext(), ty.getShape(), order, ctaLayout, ty.getElementType()); + } + + Type memdescType = tt::MemDescType::get(ty.getShape(), ty.getElementType(), + encoding, /*mutableMemory*/ true); + Value alloc = builder.create(storeOp->getLoc(), + memdescType, Value()); + return alloc; +} + +static void createTMAAsyncCopy(scf::ForOp &forOp, + tt::ExperimentalDescriptorStoreOp storeOp, + Value alloc) { + OpBuilder builder(storeOp); + auto loc = storeOp.getLoc(); + auto ty = cast(storeOp.getSrc().getType()); + auto order = ttg::getOrder(ty.getEncoding()); + auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); + + // Put wait before the local_store make the store truly async. We know + // that we are the only user of the CopyLocalToGlobal. + builder.create(loc, 0); + builder.create(loc, storeOp.getSrc(), alloc); + builder.create(loc, false); + builder.create( + loc, storeOp.getDescPtr(), storeOp.getIndices(), alloc); + + storeOp->erase(); +} + +bool mlir::triton::pipelineTMAStores(scf::ForOp forOp) { + SmallVector tmaStores = + getTMAStores(forOp); + if (tmaStores.empty()) + return false; + + DenseMap storeToAlloc; + for (tt::ExperimentalDescriptorStoreOp op : tmaStores) { + storeToAlloc[op] = createAlloc(forOp, op); + } + + for (tt::ExperimentalDescriptorStoreOp op : tmaStores) { + createTMAAsyncCopy(forOp, op, storeToAlloc[op]); + } + + // Deallocate shared memory buffers. + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + builder.create(forOp->getLoc(), 0); + for (auto it : storeToAlloc) { + builder.create(forOp->getLoc(), it.second); + } + return true; +} diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp new file mode 100644 index 000000000..85a95aaa7 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -0,0 +1,407 @@ +//===----------------------------------------------------------------------===// +// +// This pass tries to prefetch operands (a and b) of tt.dot. +// Those ConvertLayoutOps will be lowered to shared memory loads. +// +// For example: +// %a: tensor<128x32xf16, #enc> +// scf.for %iv = ... iter_args(%a_arg = %a, ...) { +// %d = tt.dot %a_arg, %b, %c +// ... +// scf.yield %a_next, ... +// } +// +// will be translated to +// +// %a: tensor<128x32xf16, #enc> +// %a_tmp = tensor.subview %a[0, 0] [128, 16] +// %a_prefetch = triton_gpu.local_load %a_tmp +// scf.for %iv = ... iter_args(%a_buf = %a, ..., %a_prefetch_arg = %a_prefetch) +// { +// %x = tt.dot %a_prefetch_arg, %b, %c +// %a_tmp_rem = tensor.subview %a_buf[0, 16] [128, 16] +// %a_prefetch_next = triton_gpu.local_load %a_tmp_rem +// ... +// scf.yield %next_a, ..., %a_prefetch_next +// } +//===----------------------------------------------------------------------===// + +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUPREFETCH +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +class Prefetcher { + /// cache the ForOp we are working on + scf::ForOp forOp; + /// cache the YieldOp of this ForOp + scf::YieldOp yieldOp; + /// + // TODO: add a hook to infer prefetchWidth + unsigned prefetchWidth = 32; + + /// dots to be prefetched + SetVector dots; + /// dot => dot operand + DenseMap dot2aLoopArg; + DenseMap dot2aHeaderDef; + DenseMap dot2bLoopArg; + DenseMap dot2bHeaderDef; + DenseMap dot2aYield; + DenseMap dot2bYield; + DenseMap> dot2aVals; + DenseMap> dot2bVals; + /// operand => defining + DenseMap operand2headPrefetch; + + LogicalResult isForOpOperand(Value v); + + Value generatePrefetch(Value v, unsigned opIdx, bool isPrologue, + Attribute dotEncoding, OpBuilder &builder, + std::optional offsetK = std::nullopt, + std::optional shapeK = std::nullopt); + + void cloneElementwiseOps(Value &bRem, const SmallVector &vals, + OpBuilder &builder); + +public: + Prefetcher() = delete; + + Prefetcher(scf::ForOp forOp) : forOp(forOp) { + yieldOp = cast(forOp.getBody()->getTerminator()); + } + + LogicalResult initialize(); + + void emitPrologue(); + + scf::ForOp createNewForOp(); +}; + +void Prefetcher::cloneElementwiseOps(Value &ret, const SmallVector &vals, + OpBuilder &builder) { + IRMapping mapping; + mapping.map(vals[1], ret); + for (int i = 2; i < vals.size(); i++) { + Value v = vals[i]; + Value curr = builder.clone(*v.getDefiningOp(), mapping)->getResult(0); + if (isa(curr.getType())) { + auto retType = RankedTensorType::get( + cast(ret.getType()).getShape(), + cast(curr.getType()).getElementType(), + cast(curr.getDefiningOp()->getOperand(0).getType()) + .getEncoding()); + curr.setType(retType); + } + mapping.map(v, curr); + } + if (vals.size() > 1) + ret = mapping.lookup(vals.back()); +} + +Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, + Attribute dotEncoding, OpBuilder &builder, + std::optional offsetK, + std::optional shapeK) { + // opIdx: 0 => a, 1 => b + auto type = cast(v.getType()); + SmallVector shape{type.getShape().begin(), type.getShape().end()}; + SmallVector offset{0, 0}; + Type elementType = type.getElementType(); + + // k => (prefetchWidth, k - prefetchWidth) + int64_t kIdx = opIdx == 0 ? 1 : 0; + + offset[kIdx] = isPrologue ? 0 : prefetchWidth; + shape[kIdx] = isPrologue ? prefetchWidth : (shape[kIdx] - prefetchWidth); + + if (shapeK) + shape[kIdx] = *shapeK; + if (offsetK) + offset[kIdx] = *offsetK; + + SmallVector offsetsVal; + for (int64_t off : offset) + offsetsVal.push_back( + builder.create(v.getLoc(), off, 32)); + Value newSmem = builder.create( + v.getLoc(), + triton::MemDescType::get(shape, elementType, type.getEncoding()), v, + offsetsVal); + + auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( + builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8); + Value prefetchSlice = builder.create( + v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), + newSmem); + + return prefetchSlice; +} + +LogicalResult Prefetcher::initialize() { + Block *loop = forOp.getBody(); + + auto getEncoding = [](Value v) { + return cast(v.getType()).getEncoding(); + }; + + SmallVector dotsInFor; + for (Operation &op : *loop) + if (auto dotOp = dyn_cast(op)) { + // bail out if there exist non v2 dots. + auto dstEnc = + dyn_cast(getEncoding(dotOp.getResult())); + if (!dstEnc || dstEnc.getVersionMajor() != 2) + return failure(); + dotsInFor.push_back(dotOp); + } + + if (dotsInFor.empty()) + return failure(); + + // TODO: segfault (original for still has uses) + // when used in flash attention that has 2 dots in the loop + if (dotsInFor.size() > 1) + return failure(); + + // returns source of cvt + + // returns source of cvt + auto getPrefetchSrc = [](Value v) -> SmallVector { + // walk back to conversion + Operation *op = v.getDefiningOp(); + bool foundConvertFromShared = false; + SmallVector rets; + rets.push_back(op->getResult(0)); + while (op) { + if (op->getNumOperands() != 1) + break; + if (!op->getResult(0).hasOneUse()) + break; + rets.push_back(op->getOperand(0)); + if (auto cvt = dyn_cast(op)) { + foundConvertFromShared = true; + break; + } + op = op->getOperand(0).getDefiningOp(); + } + std::reverse(rets.begin(), rets.end()); + + if (foundConvertFromShared) + return rets; + return {}; + }; + + auto getIncomingOp = [this](Value v) -> Value { + if (auto arg = mlir::dyn_cast(v)) + if (arg.getOwner()->getParentOp() == forOp.getOperation()) + return forOp.getTiedLoopInit(arg)->get(); + return Value(); + }; + + auto getYieldOp = [this](Value v) -> Value { + auto arg = mlir::cast(v); + unsigned yieldIdx = arg.getArgNumber() - forOp.getNumInductionVars(); + return yieldOp.getOperand(yieldIdx); + }; + + for (triton::DotOp dot : dotsInFor) { + auto aType = dot.getA().getType(); + auto bType = dot.getB().getType(); + auto aEnc = + mlir::cast(aType.getEncoding()); + auto bEnc = + mlir::cast(bType.getEncoding()); + int aKWidth = aEnc.getKWidth(); + int bKWidth = bEnc.getKWidth(); + assert(aKWidth == bKWidth); + + auto kSize = aType.getShape()[1]; + + // works better with nvidia tensor cores + unsigned elementWidth = aType.getElementTypeBitWidth(); + if (aKWidth == 0) + prefetchWidth = 256 / elementWidth; + else + prefetchWidth = 8 * aKWidth; + + // Skip prefetching if kSize is less than prefetchWidth + if (kSize < prefetchWidth) + continue; + auto aVals = getPrefetchSrc(dot.getA()); + auto bVals = getPrefetchSrc(dot.getB()); + + if (aVals.size() && bVals.size()) { + Value aSmem = aVals.front(); + Value bSmem = bVals.front(); + Value aHeaderDef = getIncomingOp(aSmem); + Value bHeaderDef = getIncomingOp(bSmem); + // Only prefetch loop arg + if (aHeaderDef && bHeaderDef) { + dots.insert(dot); + dot2aVals[dot] = aVals; + dot2bVals[dot] = bVals; + dot2aHeaderDef[dot] = aHeaderDef; + dot2bHeaderDef[dot] = bHeaderDef; + dot2aLoopArg[dot] = aSmem; + dot2bLoopArg[dot] = bSmem; + dot2aYield[dot] = getYieldOp(aSmem); + dot2bYield[dot] = getYieldOp(bSmem); + } + } + } + + return success(); +} + +void Prefetcher::emitPrologue() { + OpBuilder builder(forOp); + + for (triton::DotOp dot : dots) { + Attribute dotEncoding = dot.getType().getEncoding(); + Value aPrefetched = + generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder); + cloneElementwiseOps(aPrefetched, dot2aVals[dot], builder); + Value bPrefetched = + generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder); + cloneElementwiseOps(bPrefetched, dot2bVals[dot], builder); + + operand2headPrefetch[dot.getA()] = aPrefetched; + operand2headPrefetch[dot.getB()] = bPrefetched; + } +} + +scf::ForOp Prefetcher::createNewForOp() { + OpBuilder builder(forOp); + + SmallVector loopArgs; + for (auto v : forOp.getInitArgs()) + loopArgs.push_back(v); + for (triton::DotOp dot : dots) { + loopArgs.push_back(operand2headPrefetch[dot.getA()]); + loopArgs.push_back(operand2headPrefetch[dot.getB()]); + } + + auto newForOp = builder.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), loopArgs); + + builder.setInsertionPointToStart(newForOp.getBody()); + IRMapping mapping; + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + + for (Operation &op : forOp.getBody()->without_terminator()) { + Operation *newOp = builder.clone(op, mapping); + auto dot = dyn_cast(&op); + if (dot && dots.contains(dot)) { + Attribute dotEncoding = dot.getType().getEncoding(); + // prefetched dot + Operation *firstDot = builder.clone(*dot, mapping); + if (Value a = operand2headPrefetch.lookup(dot.getA())) + firstDot->setOperand( + 0, newForOp.getTiedLoopRegionIterArg(&*a.use_begin())); + if (Value b = operand2headPrefetch.lookup(dot.getB())) + firstDot->setOperand( + 1, newForOp.getTiedLoopRegionIterArg(&*b.use_begin())); + + // remaining part + int64_t kOff = prefetchWidth; + int64_t kRem = dot.getA().getType().getShape()[1] - prefetchWidth; + Operation *prevDot = firstDot; + while (kRem != 0) { + // int64_t kShape = largestPow2(kRem); + int64_t kShape = prefetchWidth; + auto insertionPoint = builder.saveInsertionPoint(); + builder.setInsertionPoint(prevDot); + Value aRem = + generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false, + dotEncoding, builder, kOff, kShape); + cloneElementwiseOps(aRem, dot2aVals[dot], builder); + Value bRem = + generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false, + dotEncoding, builder, kOff, kShape); + cloneElementwiseOps(bRem, dot2bVals[dot], builder); + builder.restoreInsertionPoint(insertionPoint); + newOp = builder.clone(*dot, mapping); + newOp->setOperand(0, aRem); + newOp->setOperand(1, bRem); + newOp->setOperand(2, prevDot->getResult(0)); + prevDot = newOp; + kOff += kShape; + kRem -= kShape; + } + } + // update mapping of results + for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults())) + mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx)); + } + + // prefetch next iteration + SmallVector yieldValues; + for (Value v : forOp.getBody()->getTerminator()->getOperands()) + yieldValues.push_back(mapping.lookupOrDefault(v)); + for (triton::DotOp dot : dots) { + Attribute dotEncoding = dot.getType().getEncoding(); + Value aToYield = generatePrefetch(mapping.lookup(dot2aYield[dot]), 0, true, + dotEncoding, builder); + cloneElementwiseOps(aToYield, dot2aVals[dot], builder); + yieldValues.push_back(aToYield); + // bToYield + Value bToYield = generatePrefetch(mapping.lookup(dot2bYield[dot]), 1, true, + dotEncoding, builder); + cloneElementwiseOps(bToYield, dot2bVals[dot], builder); + yieldValues.push_back(bToYield); + } + // Update ops of yield + if (!yieldValues.empty()) + builder.create(yieldOp.getLoc(), yieldValues); + return newForOp; +} + +} // anonymous namespace + +struct PrefetchPass : public impl::TritonGPUPrefetchBase { + void runOnOperation() override { + + // Canonicalize convert ops to make the pattern matching easier. + RewritePatternSet cleanUpPatterns(&getContext()); + triton::gpu::ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, + &getContext()); + if (mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(cleanUpPatterns)) + .failed()) { + signalPassFailure(); + } + getOperation()->walk([&](scf::ForOp forOp) { + Prefetcher prefetcher(forOp); + + if (prefetcher.initialize().failed()) + return; + + prefetcher.emitPrologue(); + + scf::ForOp newForOp = prefetcher.createNewForOp(); + + // replace the original loop + for (unsigned i = 0; i < forOp->getNumResults(); ++i) + forOp->getResult(i).replaceAllUsesWith(newForOp->getResult(i)); + forOp->erase(); + }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp new file mode 100644 index 000000000..c0b586d60 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -0,0 +1,91 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUREDUCEDATADUPLICATION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUReduceDataDuplicationPass + : public impl::TritonGPUReduceDataDuplicationBase< + TritonGPUReduceDataDuplicationPass> { +public: + void runOnOperation() override { + ModuleOp mod = getOperation(); + mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void { + OpBuilder builder(cvtOp); + auto srcType = cast(cvtOp.getSrc().getType()); + auto dstType = cast(cvtOp.getType()); + auto srcEncoding = srcType.getEncoding(); + if (isa(srcEncoding)) + return; + auto dstDotOp = + dyn_cast(dstType.getEncoding()); + if (!dstDotOp) + return; + if (auto srcMmaEncoding = + dyn_cast(srcEncoding)) { + + if (srcMmaEncoding.getVersionMajor() != 2 || + (srcMmaEncoding.getWarpsPerCTA()[1] == 1 && + dstDotOp.getParent() == srcMmaEncoding)) + return; + } + if (auto srcMfmaEncoding = + dyn_cast(srcEncoding)) { + + if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 && + srcMfmaEncoding.getIsTransposed() && + dstDotOp.getParent() == srcMfmaEncoding) + return; + } + auto srcOrder = triton::gpu::getOrder(srcEncoding); + auto rank = srcOrder.size(); + SmallVector sharedOrder; + if (rank == 3) { + // add all elements except the element that is zero + for (unsigned i = 0; i < rank; ++i) + if (srcOrder[i] != 0) + sharedOrder.emplace_back(srcOrder[i]); + sharedOrder.emplace_back(0); + } else { + sharedOrder = srcOrder; + } + auto tmpType = triton::MemDescType::get( + dstType.getShape(), dstType.getElementType(), + triton::gpu::SharedEncodingAttr::get( + mod.getContext(), dstDotOp, srcType.getShape(), sharedOrder, + triton::gpu::getCTALayout(srcEncoding), + srcType.getElementType())); + auto tmp = builder.create( + cvtOp.getLoc(), tmpType, cvtOp.getSrc()); + auto newConvert = builder.create(cvtOp.getLoc(), + dstType, tmp); + cvtOp.replaceAllUsesWith(newConvert.getResult()); + cvtOp.erase(); + }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp new file mode 100644 index 000000000..967d34c8f --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -0,0 +1,1321 @@ +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUREMOVELAYOUTCONVERSIONS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "tritongpu-remove-layout-conversions" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { + +// ----------------------------------------------------------------------------- +// +// ----------------------------------------------------------------------------- + +// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0)) +class ConvertDotConvert : public RewritePattern { +public: + ConvertDotConvert(MLIRContext *context) + : RewritePattern(ConvertLayoutOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto dstOp = cast(op); + auto dotOp = dstOp.getSrc().getDefiningOp(); + if (!dotOp) + return failure(); + if (std::distance(dstOp->user_begin(), dstOp->user_end()) != 1 || + std::distance(dotOp->user_begin(), dotOp->user_end()) != 1) + return failure(); + auto cvtOp = dotOp.getOperand(2).getDefiningOp(); + if (!cvtOp) + return failure(); + if (!cvtOp.getSrc().getDefiningOp()) + return failure(); + RankedTensorType dstTy = dstOp.getType(); + RankedTensorType srcTy = cvtOp.getSrc().getType(); + if (dstTy != srcTy) + return failure(); + + auto _0f = rewriter.create( + op->getLoc(), dstTy.getElementType(), + rewriter.getZeroAttr(dstTy.getElementType())); + auto _0 = rewriter.create(op->getLoc(), dotOp.getType(), _0f); + auto newDot = rewriter.create( + op->getLoc(), dotOp.getType(), dotOp.getOperand(0), dotOp.getOperand(1), + _0, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc()); + auto newCvt = rewriter.create(op->getLoc(), dstTy, + newDot.getResult()); + rewriter.replaceOpWithNewOp(op, newCvt, cvtOp.getSrc()); + return success(); + } +}; + +// The current algorithm works by analyzing the IR and doing a one-shot rewrite +// based on the analysis. The algorithm is as follows. +// +// 1. Find all the anchor ops. These are ops that have a layout we want to +// preserve. +// +// 2. For each anchor, propagate its layout to all its descendants. +// An op can have multiple ancestors that are anchors, so at this stage an op +// may have multiple layouts associated with it. +// +// 3. Resolve conflicts by deciding which of the multiple layouts the op should +// keep, inserting convert-layout ops to resolve conflicts. After this +// stage, each value has only one layout associated with it. +// +// 4. Rewrite the IR by walking the function in dominance order. Since we +// assume the IR is structured we just need to process the regions in the +// correct order. For each op, rewrite it using the layout decided by the +// analysis phase. +class LayoutPropagation { +public: + // Structure to keep track of the layout associated to a value. + struct LayoutInfo { + LayoutInfo(Attribute encoding) { encodings.insert(encoding); } + LayoutInfo() {} + llvm::SmallSetVector encodings; + }; + LayoutPropagation(FuncOp F) : funcOp(F) {} + // Find the anchor ops and set their layout in the data structure. + void initAnchorLayout(); + // Recursively Propagate the layout to all the users of the anchor ops until + // we reach a fix point. + void propagateLayout(); + // Add layouts given in `Info` to the uses of `value`. + SmallVector propagateToUsers(Value value, LayoutInfo &info); + // Set the encoding to all the values and fill out the values with new layout + // in `changed`. + void setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, Operation *op); + // Resolve cases where a value has multiple layouts associated to it. + void resolveConflicts(); + // Rewrite the IR for the full module. + void rewrite(); + // Rewrite the IR for a region. + void rewriteRegion(Region &R); + // Rewrite an op based on the layout picked by the analysis. + Operation *rewriteOp(Operation *op); + // Rewrite a for op based on the layout picked by the analysis. + Operation *rewriteForOp(scf::ForOp forOp); + Operation *rewriteWhileOp(scf::WhileOp whileOp); + Operation *rewriteIfOp(scf::IfOp ifOp); + void rewriteYieldOp(scf::YieldOp yieldOp); + void rewriteConditionOp(scf::ConditionOp conditionOp); + void rewriteReduceToScalar(Operation *reduceOp); + void rewriteAssertOp(AssertOp assertOp); + Operation *cloneElementwise(OpBuilder &rewriter, Operation *op, + Attribute encoding); + // Map the original value to the rewritten one. + void map(Value old, Value newV); + // Return the mapped value in the given encoding. This will insert a convert + // if the encoding is different than the encoding decided at resolve time. + Value getValueAs(Value value, Attribute encoding); + // Dump the current stage of layout information. + void dump(); + +private: + // map from value to layout information. + llvm::MapVector layouts; + // map of the values rewrite based on their encoding. + DenseMap, Value> rewriteMapping; + SetVector opToDelete; + FuncOp funcOp; +}; + +class LayoutRematerialization { +public: + LayoutRematerialization(FuncOp F) : funcOp(F) {} + // Map the original value to the remat'ed one. + void addRematValue(Value old, Attribute encoding, Value newV); + bool hasRematValue(Value value, Attribute encoding) { + return rematMapping.contains({value, encoding}); + } + // Return the remat'ed value in the given encoding. + Value getRematValue(Value value, Attribute encoding) { + auto it = rematMapping.find({value, encoding}); + assert(it != rematMapping.end()); + return it->second; + } + void cleanup(); + void backwardRematerialization(); + void backwardRematerialization(ConvertLayoutOp convertOp); + void hoistConvertOnTopOfExtOrBroadcast(); + void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp); + void rewriteSlice(SetVector &slice, DenseMap &layout, + ConvertLayoutOp convertOp, IRMapping &mapping); + void rewriteSlice(SetVector &slice, DenseMap &layout, + ConvertLayoutOp convertOp); + +private: + void updateRematMapping(SmallVector> &values); + // Existing tuples of (value, layout) that needs to be updated when recreating + // scf ops. This prevents keeping track of Values that have been delete when + // rewriting slices. + DenseMap mappedValues; + // map of the values remat based on encoding. + DenseMap, Value> rematMapping; + // DenseMap, Operation*> + SetVector opToDelete; + FuncOp funcOp; +}; + +void LayoutRematerialization::addRematValue(Value old, Attribute encoding, + Value newV) { + LDBG("addRematValue " << old << " encoding " << encoding << " " << newV); + rematMapping[{old, encoding}] = newV; + mappedValues[old] = encoding; +} + +// Remove unneeded values now that we are done with the rematMapping. +void LayoutRematerialization::cleanup() { + for (Operation *op : llvm::reverse(opToDelete)) + op->erase(); +} + +// Look ahead to at the transitive uses and see if there is a convert to mma +// operations. +bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { + SmallVector queue = {op->getResult(0)}; + SetVector forwardSlice; + llvm::SmallDenseSet seen; + while (!queue.empty()) { + Value currentValue = queue.back(); + queue.pop_back(); + getForwardSlice(currentValue, &forwardSlice); + for (Operation *op : forwardSlice) { + // HACK: Stop propagation if the ReduceOp is using mma layout but is + // producing tensor smaller than the layout we would like to propagate. + // This is to avoid stepping into the known bug. + if (isa(op)) { + auto tensorType = + dyn_cast(op->getOperand(0).getType()); + if (tensorType && + isa(tensorType.getEncoding())) { + auto mmaInstrShape = + cast(encoding).getInstrShape(); + if (tensorType.getShape()[tensorType.getRank() - 2] < + mmaInstrShape[0] || + tensorType.getShape()[tensorType.getRank() - 1] < + mmaInstrShape[1]) { + return false; + } + } + } + + if (auto convertOp = dyn_cast(op)) { + Attribute dstEncoding = convertOp.getType().getEncoding(); + if (auto mmaLayout = dyn_cast(dstEncoding)) + return (mmaLayout.getVersionMajor() > 1) ? true + : mmaLayout == encoding; + if (isa(dstEncoding)) + return true; + if (isa(dstEncoding)) { + if (auto mmaLayout = dyn_cast(encoding)) { + return mmaLayout.getVersionMajor() > 1; + } else { + assert((mlir::isa(encoding))); + return true; + } + } + } + bool isMMAV3 = + isa(encoding) && + cast(encoding).getVersionMajor() == 3; + if (isMMAV3 && (isa(op) || isa(op))) + return true; + auto yield = dyn_cast(op); + if (!yield) + continue; + if (auto ifOp = dyn_cast(yield->getParentOp())) { + for (OpOperand &operand : yield->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if (def && + (forwardSlice.count(def) || operand.get() == currentValue) && + (seen.insert(operand.get()).second == true)) + queue.push_back(ifOp.getResult(operand.getOperandNumber())); + } + } + auto forOp = dyn_cast(yield.getOperation()->getParentOp()); + if (!forOp) + continue; + for (OpOperand &operand : yield->getOpOperands()) { + Operation *def = operand.get().getDefiningOp(); + if (def && (forwardSlice.count(def) || operand.get() == currentValue) && + (seen.insert(operand.get()).second == true)) + queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber())); + } + } + } + return false; +} + +// Return true if the op is an op with a layout we don't want to change. We will +// propagate the layout starting from anchor ops. +bool isLayoutAnchor(Operation *op) { + if (isa(op)) + return isExpensiveLoadOrStore(op); + if (isa(op)) + return true; + + // Heuristic: Mark permuting reshape as a layout anchor. Its dst can be + // anything, so it stops forward-propagation of layouts. We rely on the + // backwards pass to fix it up if necessary. (If we didn't do this, then + // anything following the reshape won't be covered by the forward pass at + // all.) + if (auto reshape = dyn_cast(op)) + return reshape.getAllowReorder(); + + return false; +} + +void LayoutPropagation::initAnchorLayout() { + auto maybeAddAnchor = [&](Value v) { + if (auto tensorType = dyn_cast(v.getType())) { + // Workaround, don't popagate MMA layout unless there is a convert + // back to mma further down to avoid generating reduction with MMA + // layout that may have lower performance. + // This can be improved with more aggressive backward propagation. + if (isa(tensorType.getEncoding()) && + v.getDefiningOp() && + !hasConvertToMMATransisitiveUse(v.getDefiningOp(), + tensorType.getEncoding())) { + return; + } + layouts.insert({v, LayoutInfo(tensorType.getEncoding())}); + } + }; + + // Consider function args as anchors. This makes it easier to write tests -- + // you can pass a tensor with an encoding as an arg, instead of explicitly + // calling tt.load. + for (auto arg : funcOp.getArguments()) { + maybeAddAnchor(arg); + } + + funcOp.walk([&](Operation *op) { + if (isLayoutAnchor(op)) { + for (auto result : op->getResults()) { + maybeAddAnchor(result); + } + } + }); +} + +void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info, + SmallVector &changed, + Operation *op) { + for (Value value : values) { + if (!isa(value.getType())) + continue; + bool hasChanged = false; + for (auto encoding : info.encodings) { + std::optional dstEncoding; + if (isa(op)) { + // Try to remove the convert by making the dst encoding match the source + // encoding. + dstEncoding = encoding; + } else { + dstEncoding = inferDstEncoding(op, encoding); + } + if (dstEncoding) + hasChanged |= layouts[value].encodings.insert(*dstEncoding); + } + if (hasChanged) + changed.push_back(value); + } +} + +SmallVector LayoutPropagation::propagateToUsers(Value value, + LayoutInfo &info) { + SmallVector changed; + for (OpOperand &use : value.getUses()) { + Operation *user = use.getOwner(); + if (auto forOp = dyn_cast(user)) { + Value arg = forOp.getTiedLoopRegionIterArg(&use); + Value result = forOp.getTiedLoopResult(&use); + setEncoding({arg, result}, info, changed, user); + continue; + } + if (auto whileOp = dyn_cast(user)) { + Value arg = whileOp.getBeforeArguments()[use.getOperandNumber()]; + setEncoding({arg}, info, changed, user); + continue; + } + if (auto yieldOp = dyn_cast(user)) { + auto parent = yieldOp->getParentOp(); + SmallVector valuesToPropagate; + if (isa(parent)) + valuesToPropagate.push_back(parent->getResult(use.getOperandNumber())); + if (auto forOp = dyn_cast(parent)) + valuesToPropagate.push_back( + forOp.getRegionIterArg(use.getOperandNumber())); + if (auto whileOp = dyn_cast(parent)) { + valuesToPropagate.push_back( + whileOp.getBeforeArguments()[use.getOperandNumber()]); + valuesToPropagate.push_back( + whileOp->getOperand(use.getOperandNumber())); + } + if (isa(parent)) + setEncoding(valuesToPropagate, info, changed, user); + continue; + } + if (auto conditionOp = dyn_cast(user)) { + auto whileOp = cast(conditionOp->getParentOp()); + // Skip arg 0 as it is the condition. + unsigned argIndex = use.getOperandNumber() - 1; + Value afterArg = whileOp.getAfterArguments()[argIndex]; + Value result = whileOp->getResult(argIndex); + setEncoding({afterArg, result}, info, changed, user); + continue; + } + if (user->hasTrait() || + user->hasTrait() || + isa(user)) { + setEncoding(user->getResults(), info, changed, user); + continue; + } + } + return changed; +} + +void LayoutPropagation::propagateLayout() { + SmallVector queue; + for (auto it : layouts) { + queue.push_back(it.first); + } + while (!queue.empty()) { + Value currentValue = queue.back(); + LayoutInfo info = layouts[currentValue]; + queue.pop_back(); + SmallVector changed = propagateToUsers(currentValue, info); + + LLVM_DEBUG({ + DBGS() << "propagateLayout considering " << currentValue << ", which has " + << info.encodings.size() << " candidate encoding(s):\n"; + for (Attribute encoding : info.encodings) + DBGS() << " " << encoding << "\n"; + }); + + queue.insert(queue.end(), changed.begin(), changed.end()); + } +} + +void LayoutPropagation::resolveConflicts() { + for (auto &it : layouts) { + Operation *op = it.first.getDefiningOp(); + LayoutInfo &info = it.second; + if (info.encodings.size() <= 1) + continue; + // Hacky resolve, prefer block encoding. + // TODO: add a proper heuristic. + Attribute encoding = *info.encodings.begin(); + bool isLoadOrStore = + op && isa(op); + for (Attribute e : info.encodings) { + if ((isLoadOrStore && isa(e)) || + (!isLoadOrStore && isa(e))) { + encoding = e; + break; + } + } + info.encodings.clear(); + info.encodings.insert(encoding); + } +} + +void LayoutPropagation::dump() { + for (auto it : layouts) { + llvm::errs() << "Value: "; + OpPrintingFlags flags; + flags.skipRegions(); + it.first.print(llvm::errs(), flags); + llvm::errs() << " \n encoding:\n"; + for (auto encoding : it.second.encodings) { + encoding.print(llvm::errs()); + llvm::errs() << "\n"; + } + llvm::errs() << "--\n"; + } +} + +void LayoutPropagation::rewrite() { rewriteRegion(funcOp->getRegion(0)); } + +bool reduceToScalar(Operation *op) { + // For reductions returning a scalar we can change the src encoding without + // affecting the output. + return isa(op) && !isa(op->getResultTypes()[0]); +} + +void LayoutPropagation::rewriteRegion(Region ®ion) { + SmallVector queue = {®ion}; + while (!queue.empty()) { + Region *currentRegion = queue.back(); + queue.pop_back(); + for (Operation &op : currentRegion->getOps()) { + bool needRewrite = false; + SmallVector results = op.getResults(); + for (Value result : results) { + auto it = layouts.find(result); + // If we haven't mapped this value skip. + if (it == layouts.end()) + continue; + LayoutInfo &info = it->second; + assert(info.encodings.size() == 1 && + "we should have resolved to a single encoding"); + auto encoding = cast(result.getType()).getEncoding(); + // If the encoding is already what we want skip. + if (encoding == *info.encodings.begin()) + continue; + needRewrite = true; + } + if (needRewrite) { + Operation *newOp = rewriteOp(&op); + for (Region &R : newOp->getRegions()) + queue.push_back(&R); + } else if (auto yieldOp = dyn_cast(&op)) { + rewriteYieldOp(yieldOp); + } else if (auto conditionOp = dyn_cast(&op)) { + rewriteConditionOp(conditionOp); + } else if (reduceToScalar(&op)) { + rewriteReduceToScalar(&op); + } else if (auto assertOp = dyn_cast(&op)) { + rewriteAssertOp(assertOp); + } else { + // If we don't need to rewrite the op we still need to remap the + // operands. + for (OpOperand &operand : op.getOpOperands()) { + auto it = layouts.find(operand.get()); + if (it == layouts.end()) + continue; + Attribute encoding = + cast(operand.get().getType()).getEncoding(); + Value newOperand = getValueAs(operand.get(), encoding); + op.setOperand(operand.getOperandNumber(), newOperand); + } + for (Region &R : op.getRegions()) + queue.push_back(&R); + } + } + } + for (Operation *op : llvm::reverse(opToDelete)) + op->erase(); +} + +void LayoutPropagation::map(Value old, Value newV) { + rewriteMapping[{old, cast(newV.getType()).getEncoding()}] = + newV; +} + +Value LayoutPropagation::getValueAs(Value value, Attribute encoding) { + if (auto tensorType = dyn_cast(value.getType())) { + Value rewrittenValue; + auto layoutIt = layouts.find(value); + if (layoutIt == layouts.end()) { + rewrittenValue = value; + } else { + assert(layoutIt->second.encodings.size() == 1 && + "we should have resolved to a single encoding"); + Attribute encodingPicked = *(layoutIt->second.encodings.begin()); + if (encodingPicked == tensorType.getEncoding()) + rewrittenValue = value; + else + rewrittenValue = rewriteMapping[{value, encodingPicked}]; + } + assert(rewrittenValue); + if (cast(rewrittenValue.getType()).getEncoding() == + encoding) + return rewrittenValue; + OpBuilder rewriter(value.getContext()); + rewriter.setInsertionPointAfterValue(rewrittenValue); + auto tmpType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + Value converted = rewriter.create(value.getLoc(), tmpType, + rewrittenValue); + // TODO: we could cache the conversion. + return converted; + } + return value; +} + +Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter, + Operation *op, + Attribute encoding) { + Operation *newOp = rewriter.clone(*op); + + std::optional operandEnc; + if (op->getNumOperands() > 0) { + operandEnc = inferSrcEncoding(op, encoding); + assert(operandEnc.has_value()); + } + + for (OpOperand &operand : op->getOpOperands()) { + newOp->setOperand(operand.getOperandNumber(), + getValueAs(operand.get(), *operandEnc)); + } + + for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) { + auto origType = dyn_cast(op->getResult(i).getType()); + if (!origType) + continue; + auto newType = RankedTensorType::get(origType.getShape(), + origType.getElementType(), encoding); + newOp->getResult(i).setType(newType); + } + return newOp; +} + +Operation *LayoutPropagation::rewriteForOp(scf::ForOp forOp) { + SmallVector operands; + OpBuilder rewriter(forOp); + for (auto [operand, result] : + llvm::zip(forOp.getInitArgs(), forOp.getResults())) { + Value convertedOperand = operand; + if (layouts.count(result)) + convertedOperand = + getValueAs(operand, *layouts[result].encodings.begin()); + operands.push_back(convertedOperand); + } + auto newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), operands); + newForOp->setAttrs(forOp->getAttrs()); + newForOp.getBody()->getOperations().splice( + newForOp.getBody()->getOperations().begin(), + forOp.getBody()->getOperations()); + + for (auto [oldResult, newResult] : + llvm::zip(forOp.getResults(), newForOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + + for (auto [oldArg, newArg] : llvm::zip(forOp.getBody()->getArguments(), + newForOp.getBody()->getArguments())) { + if (oldArg.getType() == newArg.getType()) { + oldArg.replaceAllUsesWith(newArg); + continue; + } + map(oldArg, newArg); + } + return newForOp.getOperation(); +} + +Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) { + SmallVector operands; + SmallVector returnTypes; + OpBuilder rewriter(whileOp); + for (auto [operand, arg] : + llvm::zip(whileOp->getOperands(), whileOp.getBeforeArguments())) { + Value convertedOperand = operand; + if (layouts.count(arg)) + convertedOperand = getValueAs(operand, *layouts[arg].encodings.begin()); + operands.push_back(convertedOperand); + } + for (Value ret : whileOp.getResults()) { + auto it = layouts.find(ret); + if (it == layouts.end()) { + returnTypes.push_back(ret.getType()); + continue; + } + auto origType = dyn_cast(ret.getType()); + auto newType = + RankedTensorType::get(origType.getShape(), origType.getElementType(), + it->second.encodings[0]); + returnTypes.push_back(newType); + } + + auto newWhileOp = + rewriter.create(whileOp.getLoc(), returnTypes, operands); + SmallVector argsTypesBefore; + for (Value operand : operands) + argsTypesBefore.push_back(operand.getType()); + SmallVector bbArgLocsBefore(argsTypesBefore.size(), + whileOp.getLoc()); + SmallVector bbArgLocsAfter(returnTypes.size(), whileOp.getLoc()); + rewriter.createBlock(&newWhileOp.getBefore(), {}, argsTypesBefore, + bbArgLocsBefore); + rewriter.createBlock(&newWhileOp.getAfter(), {}, returnTypes, bbArgLocsAfter); + + for (int i = 0; i < whileOp.getNumRegions(); ++i) { + newWhileOp->getRegion(i).front().getOperations().splice( + newWhileOp->getRegion(i).front().getOperations().begin(), + whileOp->getRegion(i).front().getOperations()); + } + + auto remapArg = [&](Value oldVal, Value newVal) { + if (oldVal.getType() == newVal.getType()) + oldVal.replaceAllUsesWith(newVal); + else + map(oldVal, newVal); + }; + for (auto [oldResult, newResult] : + llvm::zip(whileOp.getResults(), newWhileOp.getResults())) + remapArg(oldResult, newResult); + for (auto [oldArg, newArg] : + llvm::zip(whileOp.getBeforeArguments(), newWhileOp.getBeforeArguments())) + remapArg(oldArg, newArg); + for (auto [oldArg, newArg] : + llvm::zip(whileOp.getAfterArguments(), newWhileOp.getAfterArguments())) + remapArg(oldArg, newArg); + return newWhileOp.getOperation(); +} + +Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) { + SmallVector operands; + OpBuilder rewriter(ifOp); + SmallVector newResultTypes(ifOp->getResultTypes()); + for (unsigned i = 0, e = ifOp->getNumResults(); i < e; ++i) { + auto it = layouts.find(ifOp->getResult(i)); + if (it == layouts.end()) + continue; + auto origType = cast(ifOp->getResult(i).getType()); + Attribute encoding = *(it->second.encodings.begin()); + newResultTypes[i] = RankedTensorType::get( + origType.getShape(), origType.getElementType(), encoding); + } + auto newIfOp = rewriter.create(ifOp.getLoc(), newResultTypes, + ifOp.getCondition(), true, true); + newIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); + newIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); + for (auto [oldResult, newResult] : + llvm::zip(ifOp.getResults(), newIfOp.getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + return newIfOp.getOperation(); +} + +void LayoutPropagation::rewriteYieldOp(scf::YieldOp yieldOp) { + Operation *parentOp = yieldOp->getParentOp(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + Type yieldType = operand.get().getType(); + if (isa(parentOp)) + yieldType = parentOp->getResult(operand.getOperandNumber()).getType(); + if (auto whileOp = dyn_cast(parentOp)) + yieldType = + whileOp.getBeforeArguments()[operand.getOperandNumber()].getType(); + auto tensorType = dyn_cast(yieldType); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + yieldOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) { + scf::WhileOp whileOp = cast(conditionOp->getParentOp()); + for (unsigned i = 1; i < conditionOp->getNumOperands(); ++i) { + OpOperand &operand = conditionOp->getOpOperand(i); + Type argType = whileOp->getResult(operand.getOperandNumber() - 1).getType(); + auto tensorType = dyn_cast(argType); + if (!tensorType) + continue; + Value newOperand = getValueAs(operand.get(), tensorType.getEncoding()); + conditionOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteReduceToScalar(Operation *reduceOp) { + OpBuilder rewriter(reduceOp); + Attribute srcEncoding; + // Since all the operands need to have the same encoding pick the first one + // and use it for all the operands. + for (Value operand : reduceOp->getOperands()) { + auto it = layouts.find(operand); + if (it != layouts.end()) { + srcEncoding = it->second.encodings[0]; + break; + } + } + if (!srcEncoding) + return; + for (OpOperand &operand : reduceOp->getOpOperands()) { + Value newOperand = getValueAs(operand.get(), srcEncoding); + reduceOp->setOperand(operand.getOperandNumber(), newOperand); + } +} + +void LayoutPropagation::rewriteAssertOp(AssertOp assertOp) { + Attribute srcEncoding; + // Only need to deal with the first operand which is the condition tensor. + Value operand = assertOp->getOperand(0); + auto it = layouts.find(operand); + if (it == layouts.end()) + return; + srcEncoding = it->second.encodings[0]; + Value newOperand = getValueAs(operand, srcEncoding); + assertOp->setOperand(0, newOperand); +} + +Operation *LayoutPropagation::rewriteOp(Operation *op) { + opToDelete.insert(op); + if (auto forOp = dyn_cast(op)) + return rewriteForOp(forOp); + if (auto whileOp = dyn_cast(op)) + return rewriteWhileOp(whileOp); + if (auto ifOp = dyn_cast(op)) + return rewriteIfOp(ifOp); + OpBuilder rewriter(op); + Attribute encoding = *layouts[op->getResult(0)].encodings.begin(); + if (auto convertOp = dyn_cast(op)) { + Attribute srcEncoding = convertOp.getSrc().getType().getEncoding(); + auto it = layouts.find(convertOp.getSrc()); + if (it != layouts.end()) + srcEncoding = *(it->second.encodings.begin()); + Value src = getValueAs(convertOp.getSrc(), srcEncoding); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + auto cvt = rewriter.create(op->getLoc(), newType, src); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (canFoldIntoConversion(op, encoding)) { + Operation *newOp = rewriter.clone(*op); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); + auto cvt = rewriter.create(op->getLoc(), newType, + newOp->getResult(0)); + map(op->getResult(0), cvt.getResult()); + return cvt.getOperation(); + } + if (op->hasTrait() || + op->hasTrait() || + isa(op)) { + Operation *newOp = cloneElementwise(rewriter, op, encoding); + for (auto [oldResult, newResult] : + llvm::zip(op->getResults(), newOp->getResults())) { + if (oldResult.getType() == newResult.getType()) { + oldResult.replaceAllUsesWith(newResult); + continue; + } + map(oldResult, newResult); + } + return newOp; + } + llvm::report_fatal_error("unexpected op in rewrite"); + return nullptr; +} + +bool canBeRemat(Operation *op) { + if (isa(op)) + return !isExpensiveLoadOrStore(op); + if (isa(op)) + return false; + if (isa(op)) + return false; + + return true; +} + +void LayoutRematerialization::updateRematMapping( + SmallVector> &values) { + for (auto [old, newV] : values) { + auto it = mappedValues.find(old); + if (it != mappedValues.end()) { + Attribute encoding = it->second; + auto rematIt = rematMapping.find({old, it->second}); + assert(rematIt != rematMapping.end()); + Value replacedValue = rematIt->second; + rematMapping.erase(rematIt); + mappedValues.erase(it); + // Loop through the replacement value to find the new version of remat + // value. This should be okay as the number of values should be small. + for (auto [before, after] : values) { + if (before == replacedValue) { + replacedValue = after; + break; + } + } + rematMapping[{newV, encoding}] = replacedValue; + mappedValues[newV] = encoding; + } + } +} + +void LayoutRematerialization::rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp, + IRMapping &mapping) { + SetVector opsToRewrite; + // Keep track of yield operands that need to be duplicated. + DenseMap> yieldOperandsMap; + // Keep these around to remove them from the slice after our collection pass + // This ensures we don't duplicate them during an for rewrite or causing the + // for/yield to fall out of sync + SetVector valuesWithExistingRemat; + for (Value v : slice) { + auto layoutIt = layout.find(v); + assert(layoutIt != layout.end()); + // If we already have a remat value for this value, use it. + if (hasRematValue(v, layoutIt->second)) { + mapping.map(v, getRematValue(v, layoutIt->second)); + valuesWithExistingRemat.insert(v); + continue; + } + if (v.getDefiningOp()) { + opsToRewrite.insert(v.getDefiningOp()); + if (auto ifOp = v.getDefiningOp()) { + unsigned operandIdx = cast(v).getResultNumber(); + opsToRewrite.insert(ifOp.thenYield().getOperation()); + yieldOperandsMap[ifOp.thenYield()].push_back(operandIdx); + opsToRewrite.insert(ifOp.elseYield().getOperation()); + yieldOperandsMap[ifOp.elseYield()].push_back(operandIdx); + } + } else { + BlockArgument blockArg = cast(v); + Operation *parentOp = blockArg.getOwner()->getParentOp(); + if (auto loopOp = cast(parentOp)) { + opsToRewrite.insert(loopOp.getOperation()); + OpOperand *operand = loopOp.getTiedLoopYieldedValue(blockArg); + auto yieldOp = blockArg.getOwner()->getTerminator(); + yieldOperandsMap[yieldOp].push_back(operand->getOperandNumber()); + opsToRewrite.insert(yieldOp); + } + } + } + slice.set_subtract(valuesWithExistingRemat); + opsToRewrite = multiRootTopologicalSort(opsToRewrite); + + // replaceAllUsesWith calls delayed until after initial rewrite. + // This is required for slice.count(value) to work mid rewrite. + SmallVector> replacements; + + SmallVector deadOps; + IRRewriter builder(slice.begin()->getContext()); + for (Operation *op : opsToRewrite) { + if (auto forOp = dyn_cast(op)) { + // Keep a mapping of the operands index to the new operands index. + SmallVector> argMapping; + SmallVector newOperands; + for (auto arg : forOp.getRegionIterArgs()) { + if (slice.count(arg)) { + OpOperand &initVal = *forOp.getTiedLoopInit(arg); + argMapping.push_back(std::make_pair( + forOp.getTiedLoopResult(&initVal).getResultNumber(), + forOp.getInitArgs().size() + newOperands.size())); + newOperands.push_back(mapping.lookup(initVal.get())); + } + } + // Create a new for loop with the new operands. + scf::ForOp newForOp = replaceForOpWithNewSignature( + builder, forOp, newOperands, replacements); + deadOps.push_back(forOp.getOperation()); + Block &loopBody = *newForOp.getBody(); + for (auto m : argMapping) { + mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second)); + int numIndVars = newForOp.getNumInductionVars(); + mapping.map(loopBody.getArgument(m.first + numIndVars), + loopBody.getArgument(m.second + numIndVars)); + LLVM_DEBUG({ + DBGS() << "mapping forOp " + << loopBody.getArgument(m.first + numIndVars) << " to " + << loopBody.getArgument(m.second + numIndVars) << '\n'; + }); + // The result is not in the layout/slice, the argument is. + Value oldArg = loopBody.getArgument(m.first + numIndVars); + addRematValue(newForOp.getResult(m.first), layout[oldArg], + newForOp.getResult(m.second)); + addRematValue(oldArg, layout[oldArg], + loopBody.getArgument(m.second + numIndVars)); + } + continue; + } + if (auto ifOp = dyn_cast(op)) { + SmallVector newTypes; + for (auto res : ifOp.getResults()) { + if (slice.count(res)) { + auto it = layout.find(res); + assert(it != layout.end()); + + auto oldType = cast(res.getType()); + auto newType = RankedTensorType::get( + oldType.getShape(), oldType.getElementType(), it->second); + newTypes.push_back(newType); + } + } + scf::IfOp newIfOp = + replaceIfOpWithNewSignature(builder, ifOp, newTypes, replacements); + unsigned oldIdx = 0; + unsigned newIdx = ifOp.getNumResults(); + for (auto res : ifOp.getResults()) { + if (slice.count(res)) { + // Why can't we use res instead of ifOp.getResult(oldIdx)? + mapping.map(ifOp.getResult(oldIdx), newIfOp.getResult(newIdx)); + addRematValue(ifOp.getResult(oldIdx), layout[res], + newIfOp.getResult(newIdx)); + ++newIdx; + } + ++oldIdx; + } + deadOps.push_back(ifOp.getOperation()); + continue; + } + builder.setInsertionPoint(op); + if (auto yieldOp = dyn_cast(op)) { + auto yieldOperands = llvm::to_vector(yieldOp.getOperands()); + SmallVector operandsToRewrite = yieldOperandsMap[op]; + // Sort so that operands are added in the same order as the new scf + // results/arguments. + std::sort(operandsToRewrite.begin(), operandsToRewrite.end()); + for (int operandIdx : operandsToRewrite) { + yieldOperands.push_back(mapping.lookup(yieldOp.getOperand(operandIdx))); + } + builder.create(op->getLoc(), yieldOperands); + op->erase(); + continue; + } + if (isa(op)) { + Operation *newOp = builder.clone(*op); + auto tensorType = cast(op->getResult(0).getType()); + auto newType = RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), + layout[op->getResult(0)]); + auto cvt = builder.create(op->getLoc(), newType, + newOp->getResult(0)); + mapping.map(op->getResult(0), cvt.getResult()); + addRematValue(op->getResult(0), layout[op->getResult(0)], + cvt.getResult()); + continue; + } + Operation *newOp = builder.clone(*op, mapping); + for (auto [old, newV] : llvm::zip(op->getResults(), newOp->getResults())) { + auto it = layout.find(old); + if (it == layout.end()) + continue; + auto newType = RankedTensorType::get( + cast(old.getType()).getShape(), + cast(old.getType()).getElementType(), it->second); + newV.setType(newType); + addRematValue(old, it->second, newV); + } + } + // Check mapping and see if there are existing convertOps on the old Argument + convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getSrc())); + opToDelete.insert(convertOp); + + updateRematMapping(replacements); + for (auto &kv : replacements) { + builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); + } + + for (Operation *op : deadOps) + opToDelete.insert(op); +} + +void LayoutRematerialization::rewriteSlice(SetVector &slice, + DenseMap &layout, + ConvertLayoutOp convertOp) { + IRMapping mapping; + rewriteSlice(slice, layout, convertOp, mapping); +} + +LogicalResult getRematerializableSlice( + Value root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation = nullptr) { + LogicalResult result = getConvertBackwardSlice(root, slice, rootEncoding, + layout, stopPropagation); + if (result.failed() || slice.empty()) + return failure(); + + // Check if all the operations in the slice can be rematerialized. + for (Value v : slice) { + if (Operation *op = v.getDefiningOp()) { + if (!canBeRemat(op)) + return failure(); + } + } + return success(); +} + +void LayoutRematerialization::backwardRematerialization() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + backwardRematerialization(convertOp); + } +} + +void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() { + // Go through each ConvertLayoutOp. + SmallVector convertOps; + funcOp.walk( + [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); + for (ConvertLayoutOp convertOp : convertOps) { + hoistConvertOnTopOfExtOrBroadcast(convertOp); + } +} + +void LayoutRematerialization::backwardRematerialization( + ConvertLayoutOp convertOp) { + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristic to accommodate fused attention + RankedTensorType targetType = convertOp.getType(); + if (isa(targetType.getEncoding())) + return; + Value oldV = convertOp->getOperand(0); + LDBG("check backward remat with source " << oldV << " encoding " + << targetType.getEncoding()); + // Check to see if there are existing remat'ed values for the pair of oldValue + // and encoding. + if (hasRematValue(oldV, targetType.getEncoding())) { + // Replace it with the remat'ed value. + Value newV = getRematValue(oldV, targetType.getEncoding()); + convertOp.replaceAllUsesWith(newV); + opToDelete.insert(convertOp); + LDBG("found remat'ed value" << newV); + return; + } + + // 1. Take a backward slice of all the tensor dependencies that can be + // rematerialized. + SetVector slice; + DenseMap layout; + LogicalResult result = getRematerializableSlice( + convertOp.getSrc(), targetType.getEncoding(), slice, layout); + if (result.failed()) { + LDBG(" getRematerializableSlice failed"); + return; + } + + LLVM_DEBUG({ + DBGS() << " remat convert op " << convertOp << '\n'; + for (Value v : slice) + DBGS() << " " << v << '\n'; + }); + // 2. Rewrite the slice. + rewriteSlice(slice, layout, convertOp); +} + +// For convert left we try to hoist them above type extension to reduce the cost +// of the convert. +void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( + ConvertLayoutOp convertOp) { + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristics to accommodate fused attention + RankedTensorType targetType = convertOp.getType(); + if (mlir::isa(targetType.getEncoding())) + return; + + auto isExtOrBroadcastOp = [](Operation *op) { + if (isa(op)) { + return true; + } + if (auto fpToFpOp = dyn_cast(op)) { + auto srcType = cast(fpToFpOp.getOperand().getType()); + return getElementBitWidth(srcType) < + getElementBitWidth(fpToFpOp.getType()); + } + return false; + }; + // 1. Take a backward slice of all the tensor dependencies. + SetVector slice; + DenseMap layout; + LogicalResult result = + getRematerializableSlice(convertOp.getSrc(), targetType.getEncoding(), + slice, layout, isExtOrBroadcastOp); + if (result.failed()) + return; + + Operation *extOrBroadcatOp = nullptr; + unsigned sliceSize = slice.size(); + for (unsigned i = 0; i < sliceSize; i++) { + Value v = slice[i]; + Operation *op = v.getDefiningOp(); + if (!op) + continue; + if (isExtOrBroadcastOp(op)) { + SetVector tempSlice; + DenseMap tempLayout; + std::optional srcEncoding = inferSrcEncoding(op, layout[v]); + if (!srcEncoding) + return; + LogicalResult result = getRematerializableSlice( + op->getOperand(0), *srcEncoding, tempSlice, tempLayout); + // If we can rematerialize the rest of the ext slice we can ignore this + // ext as it won't need a convert. + if (result.succeeded()) { + slice.insert(tempSlice.begin(), tempSlice.end()); + layout.insert(tempLayout.begin(), tempLayout.end()); + continue; + } + // Only apply it if there is a single ext op otherwise we would have to + // duplicate the convert. + if (extOrBroadcatOp != nullptr) + return; + extOrBroadcatOp = op; + } + } + + if (extOrBroadcatOp == nullptr) + return; + Attribute dstEncoding = layout[extOrBroadcatOp->getResult(0)]; + std::optional srcEncoding = + inferSrcEncoding(extOrBroadcatOp, dstEncoding); + if (!srcEncoding) + return; + // Move the convert before the ext op and rewrite the slice. + OpBuilder builder(extOrBroadcatOp); + auto tensorType = + cast(extOrBroadcatOp->getOperand(0).getType()); + auto newType = RankedTensorType::get( + tensorType.getShape(), tensorType.getElementType(), *srcEncoding); + auto newConvertOp = builder.create( + convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0)); + Operation *newExtOrBroadcast = builder.clone(*extOrBroadcatOp); + newExtOrBroadcast->setOperand(0, newConvertOp.getResult()); + auto oldExtOrBroadcastType = + cast(extOrBroadcatOp->getResult(0).getType()); + Type newExtOrBroadcasrType = RankedTensorType::get( + oldExtOrBroadcastType.getShape(), oldExtOrBroadcastType.getElementType(), + dstEncoding); + newExtOrBroadcast->getResult(0).setType(newExtOrBroadcasrType); + IRMapping mapping; + mapping.map(extOrBroadcatOp->getResult(0), newExtOrBroadcast->getResult(0)); + slice.remove(extOrBroadcatOp->getResult(0)); + // 3. Rewrite the slice. + rewriteSlice(slice, layout, convertOp, mapping); +} + +void backwardRematerialization(ModuleOp module) { + module.walk([](FuncOp funcOp) { + LayoutRematerialization layoutRemat(funcOp); + layoutRemat.backwardRematerialization(); + layoutRemat.cleanup(); + }); +} + +void hoistConvert(ModuleOp module) { + SmallVector convertOps; + module.walk([](FuncOp funcOp) { + LayoutRematerialization layoutRemat(funcOp); + layoutRemat.hoistConvertOnTopOfExtOrBroadcast(); + layoutRemat.cleanup(); + }); +} +} // namespace + +class TritonGPURemoveLayoutConversionsPass + : public impl::TritonGPURemoveLayoutConversionsBase< + TritonGPURemoveLayoutConversionsPass> { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + // 1. Propagate layout forward starting from "anchor" ops. + m.walk([](FuncOp funcOp) { + LayoutPropagation layoutPropagation(funcOp); + layoutPropagation.initAnchorLayout(); + layoutPropagation.propagateLayout(); + layoutPropagation.resolveConflicts(); + layoutPropagation.rewrite(); + }); + + LLVM_DEBUG({ + DBGS() << "Module after propagating layouts forward:\n"; + m.dump(); + }); + + RewritePatternSet cleanUpPatterns(context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns, context); + if (applyPatternsAndFoldGreedily(m, std::move(cleanUpPatterns)).failed()) { + signalPassFailure(); + } + + LLVM_DEBUG({ + DBGS() << "Module after canonicalizing:\n"; + m.dump(); + }); + + // 2. For remaining convert ops, try to rematerialize the slice of producer + // operation to avoid having to convert. + backwardRematerialization(m); + LLVM_DEBUG({ + DBGS() << "Module after backward remat:\n"; + m.dump(); + }); + + // 3. For remaining converts, try to hoist them above cast generating larger + // size types in order to reduce the cost of the convert op. + hoistConvert(m); + LLVM_DEBUG({ + DBGS() << "Module after hoisting converts:\n"; + m.dump(); + }); + + RewritePatternSet decomposePatterns(context); + decomposePatterns.add(context); + if (applyPatternsAndFoldGreedily(m, std::move(decomposePatterns)) + .failed()) { + signalPassFailure(); + } + LLVM_DEBUG({ + DBGS() << "Module after decomposing dot-converts:\n"; + m.dump(); + }); + + // 4. Apply clean up patterns to remove remove dead convert and dead code + // generated by the previous transformations. + RewritePatternSet cleanUpPatterns2(context); + populateForOpDeadArgumentElimination(cleanUpPatterns2); + scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + scf::IfOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + ConvertLayoutOp::getCanonicalizationPatterns(cleanUpPatterns2, context); + if (applyPatternsAndFoldGreedily(m, std::move(cleanUpPatterns2)).failed()) { + signalPassFailure(); + } + LLVM_DEBUG({ + DBGS() << "Module after final cleanups:\n"; + m.dump(); + }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp new file mode 100644 index 000000000..bff277c59 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp @@ -0,0 +1,140 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUREORDERINSTRUCTIONS +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +static bool willIncreaseRegisterPressure(Operation *op) { + if (isa(op)) + return true; + auto cvt = dyn_cast(op); + if (!cvt) + return false; + if (mlir::isa( + cvt.getType().getEncoding())) + return true; + return false; +} + +class TritonGPUReorderInstructionsPass + : public impl::TritonGPUReorderInstructionsBase< + TritonGPUReorderInstructionsPass> { +public: + TritonGPUReorderInstructionsPass() = default; + + Operation *getFirstUse(Operation *op) { + std::vector users; + for (auto user : op->getUsers()) { + if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user)) + users.push_back(ancestor); + } + auto minOpIt = std::min_element(users.begin(), users.end(), + [](mlir::Operation *a, mlir::Operation *b) { + return a->isBeforeInBlock(b); + }); + return minOpIt != users.end() ? *minOpIt : nullptr; + } + + void runOnOperation() override { + ModuleOp m = getOperation(); + mlir::DominanceInfo dom(m); + // sink conversion after the last dealloc + // before the first use ancestor in its block + m.walk([&](triton::gpu::ConvertLayoutOp op) { + auto curr = mlir::Block::iterator(op); + for (; &*curr != getFirstUse(op); curr++) + if (isa(&*curr)) + op->moveAfter(&*curr); + }); + // Sink conversions into loops when they will increase + // register pressure + DenseMap opToMove; + auto moveAfter = [](Operation *lhs, Operation *rhs) { + lhs->moveAfter(rhs); + }; + m.walk([&](Operation *op) { + if (!willIncreaseRegisterPressure(op)) + return; + auto user_begin = op->user_begin(); + auto user_end = op->user_end(); + if (std::distance(user_begin, user_end) != 1) + return; + if (user_begin->getParentOfType() == + op->getParentOfType()) + return; + opToMove.insert({op, *user_begin}); + }); + for (auto &kv : opToMove) + kv.first->moveBefore(kv.second); + // Move alloc(load) immediately after dependent load + m.walk([&](triton::gpu::LocalAllocOp op) { + if (!op.getSrc()) + return; + Operation *argOp = op.getSrc().getDefiningOp(); + if (!argOp) + return; + moveAfter(op, argOp); + }); + // Move transpositions just after their definition + opToMove.clear(); + m.walk([&](triton::TransOp op) { + Operation *argOp = op.getSrc().getDefiningOp(); + if (!argOp) + return; + moveAfter(op, argOp); + }); + // Move `dot` operand so that conversions to opIdx=1 happens after + // conversions to opIdx=0 + m.walk([&](triton::gpu::LocalLoadOp op) { + auto dstEncoding = mlir::dyn_cast( + op.getType().getEncoding()); + if (!dstEncoding) + return; + int opIdx = dstEncoding.getOpIdx(); + if (opIdx != 1) + return; + if (!op->hasOneUse()) + return; + auto dotUser = dyn_cast(*op->user_begin()); + if (!dotUser) + return; + auto AOp = + dotUser.getOperand(0).getDefiningOp(); + if (!AOp) + return; + // Check that the conversion to OpIdx=1 happens before and can be moved + // after the conversion to OpIdx=0. + if (!dom.dominates(op.getOperation(), AOp.getOperation())) + return; + moveAfter(op, AOp); + }); + return; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Utility.cpp new file mode 100644 index 000000000..1d6152417 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -0,0 +1,977 @@ +#include "triton/Analysis/Utility.h" + +#include + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" +#define DEBUG_TYPE "ttg-utility" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { + +using namespace triton; + +SmallVector mmaVersionToInstrShape(int version, + const ArrayRef &shape, + TensorOrMemDesc type, + int numWarps) { + if (version == 1) + return {16, 16}; + else if (version == 2) { + auto rank = shape.size(); + SmallVector ret(rank, 1); + ret[rank - 1] = 8; + ret[rank - 2] = 16; + return ret; + } else if (version == 3) { + unsigned k = 256 / type.getElementTypeBitWidth(); + if (shape[0] % 64 != 0 || shape[1] % 8 != 0) { + assert(false && "type not supported"); + return {0, 0, 0}; + } + auto eltType = type.getElementType(); + SmallVector validN; + + // MMAv3 with larger instruction shape is preferred. + if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FNUZ() || + eltType.isF16() || eltType.isBF16() || eltType.isF32()) { + validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, + 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, + 80, 72, 64, 56, 48, 40, 32, 24, 16, 8}); + } + + if (eltType.isInteger(8)) { + validN.assign({224, 208, 192, 176, 160, 144, 128, 112, 96, 80, 64, 48, 32, + 24, 16, 8}); + } + + unsigned m = 16; + unsigned mWarps = std::max(shape[0] / m, 1); + unsigned nWarps = std::max(numWarps / mWarps, 1); + unsigned maxN = std::max(shape[1] / nWarps, 8); + for (auto n : validN) { + if (shape[1] % n == 0 && n <= maxN) { + return {m, n, k}; + } + } + + assert(false && "type not supported"); + return {0, 0, 0}; + } else { + assert(false && "version not supported"); + return {0, 0}; + } +} + +bool isLoadFromTensorPtr(triton::LoadOp op) { + return mlir::triton::isTensorPointerType(op.getPtr().getType()); +} + +SmallVector argSort(const SmallVector &arr) { + SmallVector ret(arr.size()); + std::iota(ret.begin(), ret.end(), 0); + std::stable_sort(ret.begin(), ret.end(), + [&](unsigned x, unsigned y) { return arr[x] > arr[y]; }); + return ret; +} + +Value getMemAccessPtr(Operation *op) { + if (auto ld = dyn_cast(op)) + return ld.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto copy = dyn_cast(op)) + return copy.getSrc(); + if (auto store = dyn_cast(op)) + return store.getPtr(); + return nullptr; +} + +unsigned getElementBitWidth(RankedTensorType type) { + auto typeForMem = + isa(type.getElementType()) + ? cast(type.getElementType()).getPointeeType() + : type.getElementType(); + return typeForMem.getIntOrFloatBitWidth(); +} + +unsigned getNumElementsPerThread(Operation *op, SmallVector order, + ModuleAxisInfoAnalysis &axisInfoAnalysis) { + Value val = getMemAccessPtr(op); + auto ty = cast(val.getType()); + auto shapePerCTA = triton::gpu::getShapePerCTA(ty); + AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val); + unsigned elemNumBits = getElementBitWidth(ty); + unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); + unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]); + unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); + unsigned maxContig = + std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]); + unsigned alignment = std::min(maxMultiple, maxContig); + unsigned currPerThread = std::min(alignment, 128 / elemNumBits); + LDBG("elemNumBytes: " << elemNumBytes + << ", divisibility: " << maxMultipleBytes + << ", contig: " << valInfo.getContiguity(order[0]) + << ", alignment: " << alignment); + return currPerThread; +} + +//===----------------------------------------------------------------------===// +// GraphDumper +//===----------------------------------------------------------------------===// + +GraphDumper::NodeInfo GraphDumper::onValue(Value value) const { + return {{"shape", "box"}, {"style", "filled"}, {"fillcolor", "white"}}; +} + +GraphDumper::NodeInfo GraphDumper::onOperation(Operation *op) const { + return {{"shape", "ellipse"}, {"style", "filled"}, {"fillcolor", "white"}}; +} + +std::string GraphDumper::dump(triton::FuncOp func) const { + llvm::SetVector values; + llvm::SetVector operations; + + func.walk([&](Operation *op) { + operations.insert(op); + for (Value operand : op->getOperands()) + values.insert(operand); + for (Value result : op->getResults()) + values.insert(result); + }); + + std::ostringstream oss; + oss << "// Generated by Triton GraphDumper\n" + << "\n" + << "digraph {\n"; + + oss << " // Value Nodes\n"; + for (Value value : values) + oss << " " << emitValueNode(value) << "\n"; + oss << "\n"; + + oss << " // Operation Nodes\n"; + for (Operation *op : operations) + oss << " " << emitOperationNode(op) << "\n"; + oss << "\n"; + + oss << " // Edges\n"; + for (Operation *op : operations) { + for (Value operand : op->getOperands()) + oss << " " << emitEdge(getUniqueId(operand), getUniqueId(op)) << "\n"; + for (Value result : op->getResults()) + oss << " " << emitEdge(getUniqueId(op), getUniqueId(result)) << "\n"; + } + + oss << "}\n"; + return oss.str(); +} + +void GraphDumper::dumpToFile(triton::FuncOp func, + const std::string &filename) const { + std::ofstream ofs(filename); + ofs << dump(func); +} + +std::string GraphDumper::getShapeStr(const Type &type) const { + std::ostringstream oss; + oss << "["; + if (auto tensorTy = dyn_cast(type)) { + auto shape = tensorTy.getShape(); + for (unsigned i = 0; i < shape.size(); ++i) { + if (i > 0) + oss << ", "; + oss << shape[i]; + } + } + oss << "]"; + return oss.str(); +} + +std::string GraphDumper::getUniqueId(Value value) const { + std::ostringstream oss; + oss << value.getImpl(); + return oss.str(); +} + +std::string GraphDumper::getUniqueId(Operation *op) const { + std::ostringstream oss; + oss << op; + return oss.str(); +} + +std::string GraphDumper::emitNode(const std::string &id, + const GraphDumper::NodeInfo info) const { + std::ostringstream oss; + oss << "\"" << id << "\" ["; + for (auto it = info.begin(); it != info.end(); ++it) { + if (it != info.begin()) + oss << ", "; + oss << it->first << " = \"" << it->second << "\""; + } + oss << "];"; + return oss.str(); +} + +std::string GraphDumper::emitEdge(const std::string &srcId, + const std::string &destId) const { + std::ostringstream oss; + oss << "\"" << srcId << "\" -> \"" << destId << "\";"; + return oss.str(); +} + +std::string GraphDumper::emitValueNode(Value value) const { + NodeInfo info = onValue(value); + if (info.find("label") == info.end()) { + std::string shapeStr = getShapeStr(value.getType()); + if (auto arg = mlir::dyn_cast(value)) + info["label"] = + "BlockArg" + std::to_string(arg.getArgNumber()) + " " + shapeStr; + else + info["label"] = shapeStr; + } + return emitNode(getUniqueId(value), info); +} + +std::string GraphDumper::emitOperationNode(Operation *op) const { + NodeInfo info = onOperation(op); + if (info.find("label") == info.end()) + info["label"] = op->getName().getStringRef().str(); + return emitNode(getUniqueId(op), info); +} + +//===----------------------------------------------------------------------===// +// GraphLayoutMarker +//===----------------------------------------------------------------------===// + +GraphDumper::NodeInfo GraphLayoutMarker::onValue(Value value) const { + std::string color = getColor(value.getType()); + return {{"shape", "box"}, {"style", "filled"}, {"fillcolor", color}}; +} + +std::string GraphLayoutMarker::getColor(const Type &type) const { + if (auto tensorTy = dyn_cast(type)) { + auto layout = tensorTy.getEncoding(); + if (isa(layout)) + return "green"; + else if (isa(layout)) + return "yellow"; + else if (isa(layout)) + return "lightslateblue"; + else if (isa(layout)) + return "orange"; + else if (isa(layout)) + return "orangered"; + else { + llvm::report_fatal_error("Unrecognized layout"); + return "unknown"; + } + } else { + return "white"; + } +} +// -------------------------------------------------------------------------- // + +static std::optional inferDstEncoding(triton::ReduceOp op, + Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get(op->getContext(), op.getAxis(), + encoding); +} + +static std::optional inferDstEncoding(triton::ExpandDimsOp op, + Attribute encoding) { + auto sliceEncoding = mlir::dyn_cast(encoding); + if (!sliceEncoding) + return std::nullopt; + if (op.getAxis() != sliceEncoding.getDim()) + return std::nullopt; + return sliceEncoding.getParent(); +} + +static std::optional inferDstEncoding(JoinOp op, Attribute srcEnc) { + Attribute dstEnc; + if (srcEnc.getDialect() + .getRegisteredInterface() + ->inferJoinOpEncoding(srcEnc, dstEnc, + /*loc=*/std::nullopt) + .succeeded()) { + return dstEnc; + } + return std::nullopt; +} + +static std::optional inferDstEncoding(SplitOp op, Attribute srcEnc) { + Attribute dstEnc; + if (srcEnc.getDialect() + .getRegisteredInterface() + ->inferSplitOpEncoding(srcEnc, dstEnc, + /*loc=*/std::nullopt) + .succeeded()) { + return dstEnc; + } + return std::nullopt; +} + +static std::optional inferSrcEncoding(triton::ReduceOp op, + Attribute encoding) { + auto sliceEncoding = mlir::dyn_cast(encoding); + if (!sliceEncoding) + return std::nullopt; + if (op.getAxis() != sliceEncoding.getDim()) + return std::nullopt; + return sliceEncoding.getParent(); +} + +static std::optional inferSrcEncoding(triton::ExpandDimsOp op, + Attribute encoding) { + return triton::gpu::SliceEncodingAttr::get(op->getContext(), op.getAxis(), + encoding); +} + +static std::optional inferSrcEncoding(JoinOp op, Attribute dstEnc) { + // Split is the inverse of join. + Attribute srcEnc; + if (dstEnc.getDialect() + .getRegisteredInterface() + ->inferSplitOpEncoding(dstEnc, srcEnc, /*loc=*/std::nullopt) + .succeeded()) { + return srcEnc; + } + return std::nullopt; +} + +static std::optional inferSrcEncoding(SplitOp op, Attribute dstEnc) { + // Join is the inverse of split. + Attribute srcEnc; + if (dstEnc.getDialect() + .getRegisteredInterface() + ->inferJoinOpEncoding(dstEnc, srcEnc, /*loc=*/std::nullopt) + .succeeded()) { + return srcEnc; + } + return std::nullopt; +} + +static std::optional +inferTransOpDstEncoding(Attribute srcEnc, ArrayRef order) { + // Simply forward to the existing inferTransOpEncoding function. + Attribute retEncoding; + if (succeeded( + srcEnc.getDialect() + .getRegisteredInterface() + ->inferTransOpEncoding(srcEnc, order, retEncoding))) { + return retEncoding; + } + return std::nullopt; +} + +static std::optional inferDstEncoding(triton::TransOp op, + Attribute encoding) { + return inferTransOpDstEncoding(encoding, op.getOrder()); +} + +static std::optional inferSrcEncoding(triton::TransOp op, + Attribute encoding) { + // We want to solve for srcEnc in + // transpose(srcEnc, order) -> dstEnc. + // Given the identity + // transpose(transpose(x, order), inverse(order)) == x, + // we can see this is equivalent to + // transpose(dstEnc, inverse(order)) -> srcEnc. + return inferTransOpDstEncoding(encoding, + triton::inversePermutation(op.getOrder())); +} + +static std::optional +inferReshapeOpDstEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, bool allowReorder) { + // We don't do anything smart to allow-reorder reshapes here. They are + // handled in OptimizeThreadLocality. + if (allowReorder) + return std::nullopt; + + Attribute dstEnc; + if (succeeded( + srcEnc.getDialect() + .getRegisteredInterface() + ->inferReshapeOpNoReorderEncoding( + srcShape, srcEnc, dstShape, dstEnc, /*loc=*/std::nullopt))) { + return dstEnc; + } + return std::nullopt; +} + +static std::optional inferDstEncoding(triton::ReshapeOp op, + Attribute encoding) { + return inferReshapeOpDstEncoding(op.getSrc().getType().getShape(), encoding, + op.getType().getShape(), + op.getAllowReorder()); +} + +static std::optional inferSrcEncoding(triton::ReshapeOp op, + Attribute encoding) { + // The encoding of x given the encoding of y in `reshape(x) -> y` is the same + // as the encoding of x given the encoding of y in `reshape(y) -> x`. It's an + // invariant of inferReshapeOpNoReorderEncoding that it's symmetric in this + // way. + return inferReshapeOpDstEncoding(op.getType().getShape(), encoding, + op.getSrc().getType().getShape(), + op.getAllowReorder()); +} + +std::optional inferSrcEncoding(Operation *op, Attribute encoding) { + if (isa(op)) { + // Scan only supports blocked encoding at the moment. + if (!isa(encoding)) + return std::nullopt; + } + if (op->hasTrait() || + op->hasTrait() || + op->hasTrait() || + isa( + op)) { + return encoding; + } + + if (auto reduceOp = dyn_cast(op)) + return inferSrcEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferSrcEncoding(expand, encoding); + if (auto join = dyn_cast(op)) + return inferSrcEncoding(join, encoding); + if (auto split = dyn_cast(op)) + return inferSrcEncoding(split, encoding); + if (auto trans = dyn_cast(op)) + return inferSrcEncoding(trans, encoding); + if (auto reshape = dyn_cast(op)) + return inferSrcEncoding(reshape, encoding); + + return std::nullopt; +} + +std::optional inferDstEncoding(Operation *op, Attribute encoding) { + if (isa(op)) { + if (!isa(encoding)) + return std::nullopt; + } + if (op->hasTrait() || + op->hasTrait() || + op->hasTrait() || + isa(op)) + return encoding; + if (auto reduceOp = dyn_cast(op)) + return inferDstEncoding(reduceOp, encoding); + if (auto expand = dyn_cast(op)) + return inferDstEncoding(expand, encoding); + if (auto join = dyn_cast(op)) + return inferDstEncoding(join, encoding); + if (auto split = dyn_cast(op)) + return inferDstEncoding(split, encoding); + if (auto trans = dyn_cast(op)) + return inferDstEncoding(trans, encoding); + if (auto reshape = dyn_cast(op)) + return inferDstEncoding(reshape, encoding); + + return std::nullopt; +} + +bool isSingleValue(Value value) { + // Don't consider load as expensive if it is loading a scalar. + if (auto tensorTy = dyn_cast(value.getType())) + return tensorTy.getNumElements() == 1; + // TODO: Handle other cases. + // For example, when ptr is a tensor of single value. + // It means that ptr is a resultant of broadcast or generated through + // a chain of broadcast and other operations. + // Rematerialize it without considering contiguous memory access pattern is + // fine. + return true; +} + +bool isExpensiveLoadOrStore(Operation *op) { + // Case 1: Pointer of tensor is always expensive + auto operandType = op->getOperand(0).getType(); + if (triton::isTensorPointerType(operandType)) + return true; + // Case 2a: A size 1 tensor is not expensive since all threads will load the + // same + if (isSingleValue(op->getOperand(0))) + return false; + // Case 2b: Tensor of pointers has more threads than elements + // we can presume a high hit-rate that makes it cheap to load + auto ptrType = cast(op->getOperand(0).getType()); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + if (ptrType.getNumElements() < numWarps * threadsPerWarp) + return false; + return true; +} + +bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) { + if (!op) + return true; + if (isa(op)) + return isExpensiveLoadOrStore(op); + if (isa(op)) + return triton::gpu::isExpensiveCat(cast(op), targetEncoding); + if (isa(op)) + return true; + if (isa( + op)) + return true; + return false; +} + +bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { + if (isa(op)) + return !triton::gpu::isExpensiveCat(cast(op), + targetEncoding); + if (auto convert = dyn_cast(op)) { + if (mlir::isa(targetEncoding)) { + auto srcEncoding = convert.getSrc().getType().getEncoding(); + if (targetEncoding != srcEncoding) + return false; + } + return true; + } + + if (auto reshape = dyn_cast(op)) { + auto reshapeDstType = reshape.getType(); + RankedTensorType newDstType = + RankedTensorType::get(reshapeDstType.getShape(), + reshapeDstType.getElementType(), targetEncoding); + return reshape.getAllowReorder() && + !reshape.getEfficientLayout().has_value() && + !triton::gpu::isExpensiveView(reshape.getSrc().getType(), + newDstType); + } + return isa(op); +} + +scf::ForOp replaceForOpWithNewSignature( + RewriterBase &rewriter, scf::ForOp loop, ValueRange newIterOperands, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(loop); + + // Create a new loop before the existing one, with the extra operands. + auto operands = llvm::to_vector<4>(loop.getInitArgs()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + scf::ForOp newLoop = rewriter.create( + loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), + operands); + newLoop->setAttrs(loop->getAttrs()); + newLoop.getBody()->erase(); + newLoop.getRegion().getBlocks().splice( + newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); + for (Value operand : newIterOperands) + newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); + + for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( + loop.getNumResults()))) + replacements.push_back(it); + return newLoop; +} + +scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, + ValueRange newIterOperands) { + SmallVector> replacements; + auto newForOp = replaceForOpWithNewSignature(rewriter, loop, newIterOperands, + replacements); + for (auto &kv : replacements) { + rewriter.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv)); + } + return newForOp; +} + +scf::IfOp replaceIfOpWithNewSignature( + RewriterBase &rewriter, scf::IfOp ifOp, TypeRange newResultTypes, + SmallVectorImpl> &replacements) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(ifOp); + + // Create a new loop before the existing one, with the extra operands. + auto resultTypes = llvm::to_vector<4>(ifOp.getResults().getTypes()); + resultTypes.append(newResultTypes.begin(), newResultTypes.end()); + scf::IfOp newIf = rewriter.create( + ifOp.getLoc(), resultTypes, ifOp.getCondition(), /*withElse=*/true); + newIf->setAttrs(ifOp->getAttrs()); + + rewriter.inlineBlockBefore(ifOp.thenBlock(), newIf.thenBlock(), + newIf.thenBlock()->begin()); + rewriter.inlineBlockBefore(ifOp.elseBlock(), newIf.elseBlock(), + newIf.elseBlock()->begin()); + + for (auto it : llvm::zip(ifOp.getResults(), + newIf.getResults().take_front(ifOp.getNumResults()))) + replacements.push_back(it); + return newIf; +} + +Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, + IRMapping &mapping) { + Operation *newOp = rewriter.clone(*op, mapping); + // if input types haven't changed, we're done + bool preserveTypes = + std::all_of(op->operand_begin(), op->operand_end(), [&](Value v) { + return !mapping.contains(v) || + v.getType() == mapping.lookup(v).getType(); + }); + if (preserveTypes) + return newOp; + + if (newOp->getNumResults() == 0) + return newOp; + auto origType = dyn_cast(op->getResult(0).getType()); + auto argType = dyn_cast(newOp->getOperand(0).getType()); + if (!origType || !argType) + return newOp; + auto newType = RankedTensorType::get( + origType.getShape(), origType.getElementType(), argType.getEncoding()); + newOp->getResult(0).setType(newType); + auto typeInfer = dyn_cast(newOp); + if (typeInfer) { + SmallVector newTypes; + auto success = typeInfer.inferReturnTypes( + newOp->getContext(), newOp->getLoc(), newOp->getOperands(), + newOp->getAttrDictionary(), newOp->getPropertiesStorage(), + newOp->getRegions(), newTypes); + if (succeeded(success)) { + for (size_t i = 0; i < newTypes.size(); i++) + newOp->getResult(i).setType(newTypes[i]); + } + } + return newOp; +} + +// Check if the convert will be a no-op in codegen. +static bool isFreeConvert(Operation *op) { + auto convertOp = dyn_cast(op); + if (!convertOp) + return false; + return isMmaToMmaShortcut(convertOp.getSrc().getType(), convertOp.getType()); +} + +LogicalResult +getConvertBackwardSlice(Value root, SetVector &slice, + Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation) { + DenseSet visited; + SmallVector> queue = {{root, rootEncoding}}; + while (!queue.empty()) { + auto [currentValue, encoding] = queue.back(); + queue.pop_back(); + if (!visited.insert(currentValue).second) + continue; + if (!isa(currentValue.getType())) + continue; + // Skip propagating through for op results for now. + // TODO: enable this based on needs. + if (currentValue.getDefiningOp()) + return failure(); + slice.insert(currentValue); + if (layout.find(currentValue) != layout.end()) { + if (layout[currentValue] != encoding) + return failure(); + } + layout[currentValue] = encoding; + + if (auto ifOp = currentValue.getDefiningOp()) { + auto results = ifOp.getResults(); + unsigned argIdx = mlir::cast(currentValue).getResultNumber(); + + auto thenValue = ifOp.thenYield().getOperand(argIdx); + auto elseValue = ifOp.elseYield().getOperand(argIdx); + + queue.push_back({thenValue, encoding}); + queue.push_back({elseValue, encoding}); + + continue; + } + if (auto *definingOp = currentValue.getDefiningOp()) { + // If the op has multiple results we need to update all results layout. + for (Value result : definingOp->getResults()) { + if (result == currentValue || !isa(result.getType())) + continue; + if (layout.find(result) != layout.end()) { + if (layout[result] != encoding) + return failure(); + continue; + } + layout[result] = encoding; + } + if (!isFreeConvert(definingOp) && + canFoldIntoConversion(definingOp, encoding)) + continue; + if (stopPropagation && stopPropagation(definingOp)) + continue; + if (isa(definingOp)) + return failure(); + for (Value operand : definingOp->getOperands()) { + auto srcEncoding = inferSrcEncoding(definingOp, encoding); + if (!srcEncoding) + return failure(); + if (slice.count(operand) == 0) + queue.push_back({operand, *srcEncoding}); + } + continue; + } + auto blockArg = cast(currentValue); + Block *block = blockArg.getOwner(); + Operation *parentOp = block->getParentOp(); + if (auto forOp = dyn_cast(parentOp)) { + OpOperand *initOperand = forOp.getTiedLoopInit(blockArg); + Value yieldOperand = forOp.getBody()->getTerminator()->getOperand( + blockArg.getArgNumber() - forOp.getNumInductionVars()); + queue.push_back({initOperand->get(), encoding}); + queue.push_back({yieldOperand, encoding}); + continue; + } + // TODO: add support for WhileOp and other region types. + return failure(); + } + return success(); +} + +// TODO(thomas): this is duplicated with what is in GPUToLLVM +// Convert an \param index to a multi-dim coordinate given \param shape and +// \param order. +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape, + ArrayRef order) { + unsigned rank = shape.size(); + assert(rank == order.size()); + auto reordered = triton::applyPermutation(shape, order); + auto reorderedMultiDim = delinearize(b, loc, linear, reordered); + SmallVector multiDim(rank); + for (unsigned i = 0; i < rank; ++i) { + multiDim[order[i]] = reorderedMultiDim[i]; + } + return multiDim; +} + +SmallVector delinearize(OpBuilder &b, Location loc, Value linear, + ArrayRef shape) { + unsigned rank = shape.size(); + assert(rank > 0); + SmallVector multiDim(rank); + if (rank == 1) { + multiDim[0] = linear; + } else { + Value remained = linear; + for (auto &&en : llvm::enumerate(shape.drop_back())) { + auto dimSize = b.create(loc, en.value(), 32); + multiDim[en.index()] = b.create(loc, remained, dimSize); + remained = b.create(loc, remained, dimSize); + } + multiDim[rank - 1] = remained; + } + return multiDim; +} + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order) { + return linearize(b, loc, triton::applyPermutation(multiDim, order), + triton::applyPermutation(shape, order)); +} + +Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, + ArrayRef shape) { + auto rank = multiDim.size(); + Value linear = b.create(loc, 0, 32); + if (rank > 0) { + linear = multiDim.back(); + for (auto [dim, dimShape] : + llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) { + Value dimSize = b.create(loc, dimShape, 32); + linear = b.create( + loc, b.create(loc, linear, dimSize), dim); + } + } + return linear; +} + +bool isPureUnaryInlineAsm(Operation *op) { + auto inlineAsmOp = dyn_cast(op); + if (!inlineAsmOp) + return false; + return op->getNumOperands() == 1 && op->getNumResults() == 1 && + inlineAsmOp.getPure(); +} + +int getNVIDIAComputeCapability(Operation *module) { + assert(module->hasAttr(triton::AttrTargetName) && + "Expected a target attribute on the module operation"); + + StringAttr targetAttr = + cast(module->getAttr(triton::AttrTargetName)); + + StringRef ref = targetAttr.strref(); + assert(ref.starts_with("cuda:") && + "expected target attribute to be prefixed with \"cuda:\""); + + StringRef capabilityStr = ref.drop_front(5); // drop the "cuda:" + int computeCapability; + bool parseError = capabilityStr.getAsInteger(10, computeCapability); + assert(!parseError && + "invalid compute capability string in target attribute"); + + return computeCapability; +} + +namespace { + +/// Detect dead arguments in scf.for op by assuming all the values are dead and +/// propagate liveness property. +struct ForOpDeadArgElimination : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const final { + Block &block = *forOp.getBody(); + auto yieldOp = cast(block.getTerminator()); + // Assume that nothing is live at the beginning and mark values as live + // based on uses. + DenseSet aliveValues; + SmallVector queue; + // Helper to mark values as live and add them to the queue of value to + // propagate if it is the first time we detect the value as live. + auto markLive = [&](Value val) { + if (!forOp->isAncestor(val.getParentRegion()->getParentOp())) + return; + if (aliveValues.insert(val).second) + queue.push_back(val); + }; + // Mark all yield operands as live if the associated forOp result has any + // use. + for (auto result : llvm::enumerate(forOp.getResults())) { + if (!result.value().use_empty()) + markLive(yieldOp.getOperand(result.index())); + } + if (aliveValues.size() == forOp.getNumResults()) + return failure(); + // Operations with side-effects are always live. Mark all theirs operands as + // live. + block.walk([&](Operation *op) { + if (!isa(op) && !wouldOpBeTriviallyDead(op)) { + for (Value operand : op->getOperands()) + markLive(operand); + } + }); + // Propagate live property until reaching a fixed point. + while (!queue.empty()) { + Value value = queue.pop_back_val(); + if (auto nestedFor = value.getDefiningOp()) { + auto result = mlir::cast(value); + OpOperand &forOperand = *nestedFor.getTiedLoopInit(result); + markLive(forOperand.get()); + auto nestedYieldOp = + cast(nestedFor.getBody()->getTerminator()); + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + continue; + } + if (auto nestedIf = value.getDefiningOp()) { + auto result = mlir::cast(value); + for (scf::YieldOp nestedYieldOp : + {nestedIf.thenYield(), nestedIf.elseYield()}) { + Value nestedYieldOperand = + nestedYieldOp.getOperand(result.getResultNumber()); + markLive(nestedYieldOperand); + } + continue; + } + if (Operation *def = value.getDefiningOp()) { + // TODO: support while ops. + if (isa(def)) + return failure(); + for (Value operand : def->getOperands()) + markLive(operand); + continue; + } + // If an argument block is live then the associated yield operand and + // forOp operand are live. + auto arg = mlir::cast(value); + if (auto forOwner = dyn_cast(arg.getOwner()->getParentOp())) { + if (arg.getArgNumber() < forOwner.getNumInductionVars()) + continue; + unsigned iterIdx = arg.getArgNumber() - forOwner.getNumInductionVars(); + Value yieldOperand = + forOwner.getBody()->getTerminator()->getOperand(iterIdx); + markLive(yieldOperand); + markLive(forOwner.getInitArgs()[iterIdx]); + } + } + SmallVector deadArg; + for (auto yieldOperand : llvm::enumerate(yieldOp->getOperands())) { + if (aliveValues.contains(yieldOperand.value())) + continue; + if (yieldOperand.value() == block.getArgument(yieldOperand.index() + 1)) + continue; + + // The yield operand might live outside the loop, e.g. + // %init = ... + // %x = ... + // %y = for iter_args(%unused = %init) { + // yield %x + // } + // + // In this case, the loop returns %x if it runs 1 or more times, and + // otherwise it returns %init. We cowardly refuse to remove this operand + // from the yield. (We could, but we'd need to prove that the loop runs 0 + // or >=1 times.) + // + // As a special case, if it doesn't matter whether the loop runs 0 or >=1 + // times (because the loop returns the same value in both cases) then we + // can still mark the operand as dead. This occurs in the above example + // when %init is the same as %x. + if (!forOp->isAncestor( + yieldOperand.value().getParentRegion()->getParentOp()) && + yieldOperand.value() != forOp.getInitArgs()[yieldOperand.index()]) + continue; + + deadArg.push_back(yieldOperand.index()); + } + if (deadArg.empty()) + return failure(); + rewriter.modifyOpInPlace(forOp, [&]() { + // For simplicity we just change the dead yield operand to use the + // associated argument and leave the operations and argument removal to + // dead code elimination. + for (unsigned deadArgIdx : deadArg) { + BlockArgument arg = block.getArgument(deadArgIdx + 1); + yieldOp.setOperand(deadArgIdx, arg); + } + }); + return success(); + } +}; + +} // namespace + +void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonNvidiaGPU/CMakeLists.txt b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/xpu/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt new file mode 100644 index 000000000..b3def5dc8 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(TritonNvidiaGPUIR + Dialect.cpp + Ops.cpp + Types.cpp + + DEPENDS + TritonNvidiaGPUTableGen + TritonNvidiaGPUAttrDefsIncGen + + LINK_LIBS PUBLIC + TritonIR + TritonGPUIR +) diff --git a/third_party/xpu/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp new file mode 100644 index 000000000..0a982ce05 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc" + +using namespace mlir; +using namespace mlir::triton::nvidia_gpu; + +//===----------------------------------------------------------------------===// +// Attribute methods +//===----------------------------------------------------------------------===// +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc" + +//===----------------------------------------------------------------------===// + +void TritonNvidiaGPUDialect::initialize() { + registerTypes(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc" +#include "triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc" + >(); +} + +// verify TritonNvidiaGPU ops +LogicalResult +TritonNvidiaGPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // TODO: fill this. + return success(); +} diff --git a/third_party/xpu/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp new file mode 100644 index 000000000..0b06ee643 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -0,0 +1,184 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "mlir/IR/Builders.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc" + +namespace mlir { +namespace triton { +namespace nvidia_gpu { + +// -- DotAsyncOp -- +mlir::LogicalResult DotAsyncOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + + // verify encodings + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); + auto retEnc = accTy.getEncoding(); + if (aEnc) { + assert(bEnc); + Dialect &dialect = aEnc.getDialect(); + auto interface = dyn_cast(&dialect); + if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) + return mlir::failure(); + if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) + return mlir::failure(); + } + return mlir::success(); +} + +void DotAsyncOp::getEffects( + SmallVectorImpl> + &effects) { + auto a = getA(); + auto b = getB(); + if (isa(a.getType())) + effects.emplace_back(MemoryEffects::Read::get(), a, + mlir::triton::gpu::SharedMemory::get()); + if (isa(b.getType())) + effects.emplace_back(MemoryEffects::Read::get(), b, + mlir::triton::gpu::SharedMemory::get()); +} + +// -- DotWaitOp -- +LogicalResult DotWaitOp::inferReturnTypes( + ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + for (Value operand : operands) + inferredReturnTypes.push_back(operand.getType()); + return mlir::success(); +} + +static LogicalResult verifyBarrierType(Operation *op, MemDescType barrierType) { + if (!barrierType.getElementType().isInteger(64) || + barrierType.getShape() != ArrayRef({1})) + return op->emitOpError( + "barrier allocation must be a descriptor of 1xi64 type"); + return success(); +} + +// -- InitBarrierOp -- +LogicalResult InitBarrierOp::verify() { + if (failed(verifyBarrierType(*this, getAlloc().getType()))) + return failure(); + return success(); +} + +void InitBarrierOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), getAlloc(), + mlir::triton::gpu::SharedMemory::get()); +} + +// -- InvalBarrierOp -- +LogicalResult InvalBarrierOp::verify() { + if (failed(verifyBarrierType(*this, getAlloc().getType()))) + return failure(); + return success(); +} + +void InvalBarrierOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), getAlloc(), + mlir::triton::gpu::SharedMemory::get()); +} + +// -- BarrierExpectOp -- +LogicalResult BarrierExpectOp::verify() { + if (failed(verifyBarrierType(*this, getAlloc().getType()))) + return failure(); + return success(); +} + +void BarrierExpectOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), getAlloc(), + mlir::triton::gpu::SharedMemory::get()); +} + +// -- WaitBarrierOp -- +LogicalResult WaitBarrierOp::verify() { + if (failed(verifyBarrierType(*this, getAlloc().getType()))) + return failure(); + return success(); +} + +void WaitBarrierOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), getAlloc(), + mlir::triton::gpu::SharedMemory::get()); + // Need a side effect to prevent compiler from reordering and removing + // the wait operation. + effects.emplace_back(MemoryEffects::Write::get(), + mlir::SideEffects::DefaultResource::get()); +} + +// -- AsyncTMACopyGlobalToLocalOp -- +LogicalResult AsyncTMACopyGlobalToLocalOp::verify() { + if (failed(verifyBarrierType(*this, getBarrier().getType()))) + return failure(); + if (getCoord().size() < 1 || getCoord().size() > 5) + return emitOpError("TMA copies must have between 1 and 5 coordinates"); + return success(); +} + +void AsyncTMACopyGlobalToLocalOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), getDescPtr(), + mlir::triton::GlobalMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), getBarrier(), + mlir::triton::gpu::SharedMemory::get()); + effects.emplace_back(MemoryEffects::Write::get(), getResult(), + mlir::triton::gpu::SharedMemory::get()); +} + +// -- AsyncTMACopyLocalToGlobalOp -- +void AsyncTMACopyLocalToGlobalOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), getDescPtr(), + mlir::triton::GlobalMemory::get()); + effects.emplace_back(MemoryEffects::Read::get(), getSrc(), + mlir::triton::gpu::SharedMemory::get()); +} + +} // namespace nvidia_gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonNvidiaGPU/IR/Types.cpp b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/IR/Types.cpp new file mode 100644 index 000000000..326f4948a --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/IR/Types.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h" +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton::nvidia_gpu; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc" + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void ::mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc" + >(); +} diff --git a/third_party/xpu/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..5adebc352 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(TritonNvidiaGPUTransforms + FenceInsertion.cpp + PlanCTA.cpp + TMALowering.cpp + + DEPENDS + TritonNvidiaGPUTransformsIncGen + + LINK_LIBS PUBLIC + TritonIR + TritonGPUIR + TritonGPUTransforms + TritonNvidiaGPUIR + MLIRTransformUtils +) diff --git a/third_party/xpu/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp new file mode 100644 index 000000000..c7dd8d595 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -0,0 +1,138 @@ +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// +// This pass works after all other passes, inserting fences to ensure that +// memory operations are properly ordered across generic and async proxy. +// +//===----------------------------------------------------------------------===// + +using namespace mlir; +namespace tt = ::mlir::triton; +namespace ttg = ::mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +using ::mlir::triton::gpu::SharedEncodingAttr; + +namespace { + +struct FenceInsertionPass + : public TritonGPUFenceInsertionBase { + +public: + FenceInsertionPass() = default; + FenceInsertionPass(int computeCapability) { + this->computeCapability = computeCapability; + } + // TODO: support more general patterns to insert fences. eg. any op(generic) + // to shared in use-def chain which refers by async proxy. We have generic( + // convertlayout with sts/stmatix) + fence + async(wgmma) up to now + void runOnOperation() override { + // Only insert fences for compute capability 9.0 + if (computeCapability < 90) + return; + if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) + return; + ModuleOp mod = getOperation(); + mod.walk([&](Operation *op) { + if (!isa(op)) + return WalkResult::advance(); + OpBuilder builder(op); + auto a = op->getOperand(0); + auto b = op->getOperand(1); + auto mmaEncoding = dyn_cast( + cast(op->getResult(0).getType()).getEncoding()); + if (!mmaEncoding || !mmaEncoding.isHopper()) + return WalkResult::advance(); + bool aDependsOnShared = dependOnSharedEncOperand(a); + bool bDependsOnShared = dependOnSharedEncOperand(b); + if (!aDependsOnShared && !bDependsOnShared) + return WalkResult::advance(); + Operation *fence = builder.create( + op->getLoc(), /*bCluster=*/false); + // If there is all the dependencies are outside of the loop try to hoist + // the fence. + while (auto loopOp = fence->getParentOfType()) { + if (aDependsOnShared && + loopOp->isAncestor(a.getParentBlock()->getParentOp())) + break; + if (bDependsOnShared && + loopOp->isAncestor(b.getParentBlock()->getParentOp())) + break; + loopOp.moveOutOfLoop(fence); + } + return WalkResult::advance(); + }); + } + +private: + bool dependOnSharedEncOperand(Value operand) { + static DenseSet> trace; + auto op = operand.getDefiningOp(); + // avoid redundant insertion + if (op && isa(op)) + return false; + // reach convertlayout + if (op && isa(op) && + cast(op).getSrc()) + return true; + // root and not BlockArgument + if (!op && !isa(operand)) + return false; + // op and not BlockArgument + if (op && !isa(operand)) { + for (auto v : op->getOperands()) { + if (dependOnSharedEncOperand(v)) + return true; + } + } + // reach BlockArgument + // TODO: support other scf ops, IfOp, WhileOp, etc. + if (BlockArgument arg = dyn_cast(operand)) { + unsigned argNum = arg.getArgNumber(); + Operation *argOwner = arg.getOwner()->getParentOp(); + // support ForOp only + if (auto forOp = dyn_cast(argOwner)) { + // prologue + auto iterOperands = forOp.getInitArgs(); + if (argNum == 0) + return false; + if (dependOnSharedEncOperand(iterOperands[argNum - 1])) + return true; + // yield + auto yieldOp = forOp.getBody()->getTerminator(); + Value v = yieldOp->getOperand(argNum - 1); + auto entry = std::make_pair(std::move(yieldOp), + std::move(argNum)); + // avoid cyclic + if (trace.contains(entry)) + return false; + else + trace.insert(entry); + + if (dependOnSharedEncOperand(v)) + return true; + } else if (auto whileOp = dyn_cast(argOwner)) { + assert(false && "FenceInsertionPass does not supported WhileOp"); + } else if (auto ifOp = dyn_cast(argOwner)) { + assert(false && "FenceInsertionPass does not supported IfOp"); + } + } + return false; + } +}; +} // namespace + +std::unique_ptr +mlir::createTritonNvidiaGPUFenceInsertionPass(int computeCapability) { + return std::make_unique(computeCapability); +} diff --git a/third_party/xpu/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp new file mode 100644 index 000000000..e26af25a6 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp @@ -0,0 +1,1040 @@ +/* + * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include + +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +using namespace mlir; +namespace ttg = ::mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; + +// TODO: use ConvertLayoutOp +using CastOp = ::mlir::UnrealizedConversionCastOp; + +unsigned getNumUsers(Value value) { + return std::distance(value.user_begin(), value.user_end()); +} + +Type replaceLayout(const Type &type, const Attribute &newLayout) { + Type curType = type; + auto ptrTy = dyn_cast(curType); + if (ptrTy) + curType = ptrTy.getPointeeType(); + if (auto tensorTy = dyn_cast(curType)) + curType = RankedTensorType::get(tensorTy.getShape(), + tensorTy.getElementType(), newLayout); + if (ptrTy) + curType = triton::PointerType::get(curType, ptrTy.getAddressSpace()); + return curType; +} + +Attribute replaceCTALayout(Attribute layout, llvm::ArrayRef shape, + const ttg::CTALayoutAttr &newCTALayout) { + if (auto blockedLayout = mlir::dyn_cast(layout)) { + return ttg::BlockedEncodingAttr::get( + layout.getContext(), shape, blockedLayout.getSizePerThread(), + blockedLayout.getOrder(), ttg::getNumWarpsPerCTA(layout), 32, + newCTALayout); + } else if (auto sliceLayout = + mlir::dyn_cast(layout)) { + return ttg::SliceEncodingAttr::get( + layout.getContext(), sliceLayout.getDim(), + replaceCTALayout(sliceLayout.getParent(), shape, newCTALayout)); + } else { + // Other layouts are generated by passes after PlanCTAPass + llvm::report_fatal_error("replaceCTALayout not implemented"); + return layout; + } +} + +class CTAPlanner { +public: + CTAPlanner(ttng::ClusterInfo *clusterInfo_); + ~CTAPlanner(); + + void run(triton::FuncOp &funcOp); + +private: + CastOp markBackward(CastOp cast) const; + CastOp markForward(CastOp cast) const; + bool isBackward(CastOp cast) const; + bool isForward(CastOp cast) const; + + void setTiling(llvm::ArrayRef CTAsPerCGA); + bool processDot(triton::FuncOp &funcOp); + bool processReduce(triton::FuncOp &funcOp); + void processStoreLikeOps(triton::FuncOp &funcOp); + + bool propagate(CastOp cast); + bool propagateBackward(CastOp cast); + bool propagateForward(CastOp cast); + + void eraseCastOp(CastOp cast); + void eraseCastOpFromQueue(CastOp cast); + void eraseCastOpsFromQueue(llvm::ArrayRef casts); + + void insertCasts(Operation *op, llvm::ArrayRef newOperandLayouts, + llvm::ArrayRef newResultLayouts); + void eliminateAdjacentCasts(CastOp cast0, CastOp cast1); + + bool isLoadStoreOp(Operation *op) const; + bool processLoadStore(Operation *op, Attribute layout); + + bool isElementwiseOp(Operation *op) const; + bool processElementwise(Operation *op, Attribute layout); + + bool processConstant(arith::ConstantOp constant, Attribute layout); + bool processSplat(triton::SplatOp splat, Attribute layout); + bool processMakeRange(triton::MakeRangeOp makeRange, Attribute layout); + bool processMakeTensorPtr(triton::MakeTensorPtrOp makeTensorPtr, + Attribute layout); + + bool processBroadcast(triton::BroadcastOp broadcast, Attribute layout); + bool processExpandDimsBackward(triton::ExpandDimsOp expandDims, + Attribute newResultLayout); + bool processExpandDimsForward(triton::ExpandDimsOp expandDims, + Attribute newSrcLayout); + + bool processConvertLayoutBackward(ttg::ConvertLayoutOp convertLayout, + CastOp cast); + bool processConvertLayoutForward(ttg::ConvertLayoutOp convertLayout, + CastOp cast); + + bool processIfOp(scf::IfOp ifOp, int index, const Type &newType); + bool processForOp(scf::ForOp forOp, int index, const Type &newType); + + bool processIfOpBackward(scf::IfOp ifOp, CastOp cast); + bool processForOpBackward(scf::ForOp forOp, CastOp cast); + bool processBlockArgBackward(BlockArgument arg, CastOp cast); + bool processForOpForward(scf::ForOp forOp, CastOp cast); + bool processYieldOpForward(scf::YieldOp yieldOp, CastOp cast); + + bool processOpFallback(Operation *op); + + bool processMultiUsersBackward(Value input, CastOp cast); + bool processMultiUsersForward(Value output, CastOp cast); + + // This flag indicates whether clusterInfo needs to be deleted in the + // destructor of CTAPlanner. The flag `ownInfo` is set to false when a + // non-null pointer to clusterInfo is passed to the constructor of CTAPlanner. + // Otherwise, a self-managed ClusterInfo will be created and the ownInfo will + // be set to true. + bool ownInfo; + ttng::ClusterInfo *clusterInfo; + bool tiled; + unsigned step; + unsigned stepUnchanged; + std::queue queue; +}; + +CTAPlanner::CTAPlanner(ttng::ClusterInfo *clusterInfo_) + : ownInfo(false), clusterInfo(clusterInfo_), tiled(false), step(0), + stepUnchanged(0) { + if (clusterInfo == nullptr) { + clusterInfo = new ttng::ClusterInfo(); + ownInfo = true; + } +} + +CTAPlanner::~CTAPlanner() { + if (ownInfo) { + delete clusterInfo; + // Actually not necessary but safer + ownInfo = false; + clusterInfo = nullptr; + } +} + +void CTAPlanner::run(triton::FuncOp &funcOp) { + assert(!tiled && "Please create a new CTAPlanner"); + static const unsigned maxSteps = 10000; + + auto nextStep = [&]() { + ++step; + assert(step < maxSteps && "Maximum number of steps exceeded"); + }; + + processDot(funcOp); + nextStep(); + + processReduce(funcOp); + nextStep(); + + if (!tiled) { + processStoreLikeOps(funcOp); + nextStep(); + } + + while (!queue.empty()) { + CastOp cast = queue.front(); + queue.pop(); + bool changed = propagate(cast); + if (changed) { + stepUnchanged = 0; + } else { + queue.push(cast); + ++stepUnchanged; + } + nextStep(); + } +} + +CastOp CTAPlanner::markBackward(CastOp cast) const { + cast->setAttr("direction", StringAttr::get(cast.getContext(), "backward")); + return cast; +} + +CastOp CTAPlanner::markForward(CastOp cast) const { + cast->setAttr("direction", StringAttr::get(cast.getContext(), "forward")); + return cast; +} + +bool CTAPlanner::isBackward(CastOp cast) const { + return cast->getAttrOfType("direction") == "backward"; +} + +bool CTAPlanner::isForward(CastOp cast) const { + return cast->getAttrOfType("direction") == "forward"; +} + +void CTAPlanner::setTiling(llvm::ArrayRef CTAsPerCGA) { + assert(!tiled && "CTA tiling is already determinted"); + assert(clusterInfo && "ClusterInfo pointer is null"); + assert(CTAsPerCGA.size() <= 3 && "setTiling not implemented"); + if (CTAsPerCGA.size() > 0) + clusterInfo->clusterDimX = CTAsPerCGA[0]; + if (CTAsPerCGA.size() > 1) + clusterInfo->clusterDimY = CTAsPerCGA[1]; + if (CTAsPerCGA.size() > 2) + clusterInfo->clusterDimZ = CTAsPerCGA[2]; + tiled = true; +} + +bool CTAPlanner::processDot(triton::FuncOp &funcOp) { + // TODO: This is a naive implementation and should be refactored + auto getCTATiling = [](int64_t M, int64_t N, int64_t K, + unsigned numCTAs) -> std::pair { + // prefer a larger chunk size, at most 128; first assign splitM. + unsigned chunk_m = 128; + auto isLegal = [](unsigned chunk) { return chunk >= 64; }; + unsigned splitM, splitN; + for (; isLegal(chunk_m); chunk_m /= 2) { + splitM = std::clamp(M / chunk_m, 1, numCTAs); + splitN = numCTAs / splitM; + if (isLegal(N / splitN)) // chunk_n; + break; + } + return {splitM, splitN}; + }; + + funcOp.walk([&](triton::DotOp dot) { + MLIRContext *ctx = dot.getContext(); + + auto aTy = cast(dot.getA().getType()); + auto bTy = cast(dot.getB().getType()); + auto dTy = cast(dot.getD().getType()); + + assert(isa(aTy.getEncoding()) && + isa(bTy.getEncoding()) && + isa(dTy.getEncoding()) && + "PlanCTAPass should follow immediately after CoalescePass"); + + auto aLayout = cast(aTy.getEncoding()); + auto bLayout = cast(bTy.getEncoding()); + auto dLayout = cast(dTy.getEncoding()); + + unsigned M = dTy.getShape()[0]; + unsigned N = dTy.getShape()[1]; + unsigned K = aTy.getShape()[1]; + + unsigned splitM, splitN; + std::tie(splitM, splitN) = getCTATiling(M, N, K, ttg::getNumCTAs(dLayout)); + // FIXME: Should consider IR with more than one DotOps + setTiling({splitM, splitN, 1}); + + auto newCTALayout = ttg::CTALayoutAttr::get(ctx, {splitM, splitN}, + {splitM, splitN}, {1, 0}); + auto newDLayout = ttg::BlockedEncodingAttr::get( + ctx, dTy.getShape(), dLayout.getSizePerThread(), dLayout.getOrder(), + ttg::getNumWarpsPerCTA(dLayout), 32, newCTALayout); + auto newALayout = ttg::DotOperandEncodingAttr::get(ctx, aLayout.getOpIdx(), + newDLayout, 0); + auto newBLayout = ttg::DotOperandEncodingAttr::get(ctx, bLayout.getOpIdx(), + newDLayout, 0); + + insertCasts(dot.getOperation(), {newALayout, newBLayout, newDLayout}, + {newDLayout}); + }); + + return true; +} + +bool CTAPlanner::processReduce(triton::FuncOp &funcOp) { + ModuleOp mod = funcOp->getParentOfType(); + unsigned numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod); + + funcOp.walk([&](triton::ReduceOp reduce) { + MLIRContext *context = reduce.getContext(); + Value src = reduce.getOperands()[0]; + unsigned axis = reduce.getAxis(); + + auto srcTy = cast(src.getType()); + auto srcShape = srcTy.getShape(); + auto srcLayout = srcTy.getEncoding(); + + auto rank = srcShape.size(); + auto order = ttg::getOrder(srcLayout); + auto sizePerThread = ttg::getSizePerThread(srcLayout); + auto CTAOrder = ttg::getCTAOrder(srcLayout); + + llvm::SmallVector CTAsPerCGA(rank, 0); + unsigned remainingCTAs = numCTAs; + for (int i = rank - 1; i >= 0; --i) { + unsigned dim = order[i]; + if (dim == axis) { + CTAsPerCGA[dim] = 1; + } else { + CTAsPerCGA[dim] = std::min(srcShape[dim] / sizePerThread[dim], + remainingCTAs); + remainingCTAs /= CTAsPerCGA[dim]; + } + } + + for (int i = rank - 1; i >= 0; --i) { + unsigned dim = order[i]; + if (dim != axis) { + CTAsPerCGA[dim] *= remainingCTAs; + break; + } + } + + llvm::SmallVector CTASplitNum = CTAsPerCGA; + + // If numCTAs > 1 and the only dimension is the reduced dimension, after the + // above two for-loops, CTAsPerCGA = [0] and remainingCTAs = numCTAs. We set + // CTAsPerCGA[0] = numCTAs and keep CTASplitNum[0] = 1 to ensure that no + // cross-CTA reduction is required, although this will introduce duplicated + // calculation + if (remainingCTAs > 0) + CTAsPerCGA[order[rank - 1]] *= remainingCTAs; + + auto CTALayout = + ttg::CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + if (!tiled) + setTiling(CTALayout.getCTAsPerCGA()); + auto newSrcLayout = replaceCTALayout(srcLayout, srcShape, CTALayout); + auto newResultLayout = + ttg::SliceEncodingAttr::get(context, axis, newSrcLayout); + unsigned numOperands = reduce.getNumOperands(); + SmallVector newSrcLayoutVec(numOperands, newSrcLayout); + SmallVector newResultLayoutVec(numOperands, newResultLayout); + + insertCasts(reduce.getOperation(), newSrcLayoutVec, newResultLayoutVec); + }); + return true; +} + +void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) { + assert(!tiled && "CTA tiling is already determinted"); + + llvm::SmallVector stores; + funcOp.walk([&](Operation *op) { + if (llvm::isa( + op)) + stores.push_back(op); + }); + assert(stores.size() > 0 && "Cannot find store-like ops"); + + ttg::CTALayoutAttr CTALayout; + for (Operation *store : stores) { + if (auto tensorTy = + dyn_cast(store->getOperand(0).getType())) { + if (!tiled) { + // Use CTA tiling of the first store-like op as global CTA tiling + CTALayout = ttg::getCTALayout(tensorTy.getEncoding()); + setTiling(CTALayout.getCTAsPerCGA()); + } + auto newLayout = replaceCTALayout(tensorTy.getEncoding(), + tensorTy.getShape(), CTALayout); + processElementwise(store, newLayout); + } + } + + // If all store-like ops are processing scalar values and no ReduceOp is + // found, we can conclude that this is an all-scalar computation, since + // ReduceOp is the only op that converts tensor values to scalar values. + if (!tiled) + setTiling({1, 1, 1}); +} + +bool CTAPlanner::propagate(CastOp cast) { + return isBackward(cast) ? propagateBackward(cast) : propagateForward(cast); +} + +bool CTAPlanner::propagateBackward(CastOp cast) { + Value input = cast.getOperand(0); + Value output = cast.getResult(0); + unsigned numUsers = getNumUsers(input); + if (numUsers == 0) { + llvm::report_fatal_error("Unreachable branch"); + return false; + } else if (numUsers == 1) { + Type outTy = output.getType(); + if (auto ptrTy = dyn_cast(outTy)) + outTy = ptrTy.getPointeeType(); + Attribute layout = mlir::cast(outTy).getEncoding(); + Operation *op = input.getDefiningOp(); + if (op == nullptr) { + assert(isa(input) && + "Unexpected Value without defining op"); + processBlockArgBackward(llvm::cast(input), cast); + } else if (auto prevCast = llvm::dyn_cast(op)) { + eliminateAdjacentCasts(prevCast, cast); + } else if (isLoadStoreOp(op)) { + processLoadStore(op, layout); + } else if (isElementwiseOp(op)) { + processElementwise(op, layout); + } else if (auto constant = llvm::dyn_cast(op)) { + processConstant(constant, layout); + } else if (auto splat = llvm::dyn_cast(op)) { + processSplat(splat, layout); + } else if (auto makeRange = llvm::dyn_cast(op)) { + processMakeRange(makeRange, layout); + } else if (auto makeTensorPtr = + llvm::dyn_cast(op)) { + processMakeTensorPtr(makeTensorPtr, layout); + } else if (llvm::isa(op)) { + // ptr operand and result have the same layout, while other operands are + // scalar values + processElementwise(op, layout); + } else if (auto broadcast = llvm::dyn_cast(op)) { + processBroadcast(broadcast, layout); + } else if (auto expandDims = llvm::dyn_cast(op)) { + processExpandDimsBackward(expandDims, layout); + } else if (auto ifOp = llvm::dyn_cast(op)) { + processIfOpBackward(ifOp, cast); + } else if (auto forOp = llvm::dyn_cast(op)) { + processForOpBackward(forOp, cast); + } else if (auto convertLayout = llvm::dyn_cast(op)) { + return processConvertLayoutBackward(convertLayout, cast); + } else { + // Keep original layouts. This may result in a loss of performance. + return processOpFallback(op); + } + return true; + } else { + return processMultiUsersBackward(input, cast); + } +} + +bool CTAPlanner::propagateForward(CastOp cast) { + Value input = cast.getOperand(0); + Value output = cast.getResult(0); + unsigned numUsers = getNumUsers(output); + if (numUsers == 0) { + cast.erase(); + } else if (numUsers == 1) { + Type inTy = input.getType(); + if (auto ptrTy = dyn_cast(inTy)) + inTy = ptrTy.getPointeeType(); + Attribute layout = mlir::cast(inTy).getEncoding(); + Operation *op = *output.user_begin(); + if (auto nextCast = llvm::dyn_cast(op)) { + eliminateAdjacentCasts(cast, nextCast); + } else if (isLoadStoreOp(op)) { + processLoadStore(op, layout); + } else if (isElementwiseOp(op)) { + processElementwise(op, layout); + } else if (llvm::isa(op)) { + // ptr operand and result have the same layout, while other operands are + // scalar values + processElementwise(op, layout); + } else if (auto convertLayout = llvm::dyn_cast(op)) { + return processConvertLayoutForward(convertLayout, cast); + } else if (auto forOp = llvm::dyn_cast(op)) { + processForOpForward(forOp, cast); + } else if (auto yieldOp = llvm::dyn_cast(op)) { + processYieldOpForward(yieldOp, cast); + } else { + // Keep original layouts. This may result in a loss of performance. + return processOpFallback(op); + } + } else { + processMultiUsersForward(output, cast); + } + return true; +} + +void CTAPlanner::eraseCastOp(CastOp cast) { + Value output = cast.getResult(0); + assert(getNumUsers(output) == 0 && + "Cannot erase CastOp because it is still in use"); + cast.erase(); +} + +void CTAPlanner::eraseCastOpFromQueue(CastOp cast) { + eraseCastOpsFromQueue({cast}); +} + +void CTAPlanner::eraseCastOpsFromQueue(llvm::ArrayRef casts) { + llvm::DenseSet erased; + for (CastOp cast : casts) { + eraseCastOp(cast); + erased.insert(cast); + } + + decltype(queue) tempQueue; + std::swap(queue, tempQueue); + + // This is only a naive implementation. Should refactor with linked-list. + while (!tempQueue.empty()) { + auto cast = tempQueue.front(); + tempQueue.pop(); + if (!erased.contains(cast)) + queue.push(cast); + } +} + +void CTAPlanner::insertCasts(Operation *op, + llvm::ArrayRef newOperandLayouts, + llvm::ArrayRef newResultLayouts) { + assert(op->getNumOperands() == newOperandLayouts.size() && + "NumOperands mismatched"); + assert(op->getNumResults() == newResultLayouts.size() && + "NumResults mismatched"); + + Location loc = op->getLoc(); + OpBuilder builder(op->getContext()); + + builder.setInsertionPoint(op); + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + auto operandTy = operand.getType(); + if (triton::isTensorOrTensorPointerType(operandTy)) { + operandTy = replaceLayout(operandTy, newOperandLayouts[i]); + auto cast = markBackward(builder.create(loc, operandTy, operand)); + op->setOperand(i, cast.getResult(0)); + queue.push(cast); + } + } + + builder.setInsertionPointAfter(op); + for (unsigned i = 0; i < op->getNumResults(); ++i) { + Value result = op->getResult(i); + auto resultTy = result.getType(); + if (triton::isTensorOrTensorPointerType(resultTy)) { + resultTy = replaceLayout(resultTy, newResultLayouts[i]); + auto cast = + markForward(builder.create(loc, result.getType(), result)); + result.setType(resultTy); + result.replaceAllUsesExcept(cast.getResult(0), cast.getOperation()); + queue.push(cast); + } + } +} + +void CTAPlanner::eliminateAdjacentCasts(CastOp cast0, CastOp cast1) { + assert(cast0.getResult(0) == cast1.getOperand(0) && + "The two casts are not adjacent"); + assert(isForward(cast0) && isBackward(cast1) && + "Expected pattern of adjacent casts: forward + backward"); + + Value input = cast0.getOperand(0); + Value output = cast1.getResult(0); + + if (input.getType() == output.getType()) { + output.replaceAllUsesWith(input); + eraseCastOpsFromQueue({cast1, cast0}); + } else { + OpBuilder builder(cast1.getOperation()); + auto cvt = builder.create(cast1.getLoc(), + output.getType(), input); + output.replaceAllUsesWith(cvt.getResult()); + eraseCastOpsFromQueue({cast1, cast0}); + } +} + +bool CTAPlanner::isLoadStoreOp(Operation *op) const { + return llvm::isa(op); +} + +bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) { + // Special logic for: + // LoadOp -> SliceLayout + // Transform to: + // LoadOp -> originalLayout -> ConvertLayout(DSmem) -> SliceLayout + if (auto sliceLayout = mlir::dyn_cast(layout)) { + auto dim = sliceLayout.getDim(); + auto CTAsPerCGA = ttg::getCTAsPerCGA(sliceLayout.getParent()); + if (CTAsPerCGA[dim] > 1) { + // Find an input or output value of LoadOp or StoreOp to get its layout + Value val = + op->getNumResults() > 0 ? op->getResult(0) : op->getOperand(0); + Attribute originalLayout = + cast(val.getType()).getEncoding(); + // Insert casts using originalLayout. Adjacent casts will be eliminated + // and generate a ConvertLayoutOp with DSmem access + return processLoadStore(op, originalLayout); + } + } + + auto CTALayout = ttg::getCTALayout(layout); + + llvm::SmallVector newOperandLayouts; + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + auto type = op->getOperand(i).getType(); + if (auto ptrTy = dyn_cast(type)) + type = ptrTy.getPointeeType(); + auto tensorTy = cast(type); + auto newLayout = replaceCTALayout(tensorTy.getEncoding(), + tensorTy.getShape(), CTALayout); + newOperandLayouts.push_back(newLayout); + } + + llvm::SmallVector newResultLayouts; + for (unsigned i = 0; i < op->getNumResults(); ++i) { + auto type = op->getResult(i).getType(); + if (auto ptrTy = dyn_cast(type)) + type = ptrTy.getPointeeType(); + auto tensorTy = cast(type); + auto newLayout = replaceCTALayout(tensorTy.getEncoding(), + tensorTy.getShape(), CTALayout); + newResultLayouts.push_back(newLayout); + } + + insertCasts(op, newOperandLayouts, newResultLayouts); + return true; +} + +bool CTAPlanner::isElementwiseOp(Operation *op) const { + if (llvm::isa(op)) + return true; + if (llvm::isa(op)) + return true; + if (llvm::isa(op)) + return true; + if (auto externElementwiseOp = dyn_cast(op)) + return externElementwiseOp.getPure(); + if (llvm::isa(op)) + return true; + return false; +} + +bool CTAPlanner::processElementwise(Operation *op, Attribute layout) { + llvm::SmallVector newOperandLayouts(op->getNumOperands(), layout); + llvm::SmallVector newResultLayouts(op->getNumResults(), layout); + insertCasts(op, newOperandLayouts, newResultLayouts); + return true; +} + +bool CTAPlanner::processConstant(arith::ConstantOp constant, Attribute layout) { + if (auto tensorTy = dyn_cast(constant.getType())) { + if (auto attr = dyn_cast(constant.getValue())) { + + auto newTensorTy = RankedTensorType::get( + tensorTy.getShape(), tensorTy.getElementType(), layout); + constant.setValueAttr( + SplatElementsAttr::get(newTensorTy, attr.getSplatValue())); + } + } + insertCasts(constant.getOperation(), {}, {layout}); + return true; +} + +bool CTAPlanner::processSplat(triton::SplatOp splat, Attribute layout) { + insertCasts(splat.getOperation(), {{}}, {layout}); + return true; +} + +bool CTAPlanner::processMakeRange(triton::MakeRangeOp makeRange, + Attribute layout) { + insertCasts(makeRange.getOperation(), {}, {layout}); + return true; +} + +bool CTAPlanner::processMakeTensorPtr(triton::MakeTensorPtrOp makeTensorPtr, + Attribute layout) { + // All inputs of `makeTensorPtr` are scalar types + llvm::SmallVector dummyInAttrs(makeTensorPtr.getNumOperands(), {}); + insertCasts(makeTensorPtr.getOperation(), dummyInAttrs, {layout}); + return true; +} + +bool CTAPlanner::processBroadcast(triton::BroadcastOp broadcast, + Attribute layout) { + insertCasts(broadcast.getOperation(), {layout}, {layout}); + return true; +} + +bool CTAPlanner::processExpandDimsBackward(triton::ExpandDimsOp expandDims, + Attribute newResultLayout) { + auto newSrcLayout = ttg::SliceEncodingAttr::get( + newResultLayout.getContext(), expandDims.getAxis(), newResultLayout); + insertCasts(expandDims.getOperation(), {newSrcLayout}, {newResultLayout}); + return true; +} + +bool CTAPlanner::processExpandDimsForward(triton::ExpandDimsOp expandDims, + Attribute newSrcLayout) { + llvm::report_fatal_error("processExpandDimsForward not implemented yet"); + return true; +} + +bool CTAPlanner::processConvertLayoutBackward( + ttg::ConvertLayoutOp convertLayout, CastOp cast) { + Value src = convertLayout.getSrc(); + Value result = convertLayout.getResult(); + assert(getNumUsers(result) == 1 && + "Expect to call processMultiUsersBackward first"); + result.replaceAllUsesWith(src); + convertLayout.erase(); + queue.push(cast); + return true; +} + +bool CTAPlanner::processConvertLayoutForward(ttg::ConvertLayoutOp convertLayout, + CastOp cast) { + Value src = convertLayout.getSrc(); + Value result = convertLayout.getResult(); + assert(getNumUsers(src) == 1 && + "Expect to call processMultiUsersForward first"); + src.setType(result.getType()); + result.replaceAllUsesWith(src); + convertLayout.erase(); + queue.push(cast); + return true; +} + +bool CTAPlanner::processIfOp(scf::IfOp ifOp, int index, const Type &newType) { + // Check index + assert(index < ifOp.getNumResults() && "Invalid result index of IfOp"); + assert(index < ifOp.thenYield().getNumOperands() && + "Invalid operand index of YieldOp"); + assert(index < ifOp.elseYield().getNumOperands() && + "Invalid operand index of YieldOp"); + + Location loc = ifOp.getLoc(); + OpBuilder builder(ifOp.getContext()); + + // Insert forward cast after ifOp + Value result = ifOp.getResult(index); + builder.setInsertionPointAfter(ifOp.getOperation()); + auto newCast = + markForward(builder.create(loc, result.getType(), result)); + result.setType(newType); + result.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation()); + queue.push(newCast); + + // Insert backward casts before yield + for (scf::YieldOp yield : {ifOp.thenYield(), ifOp.elseYield()}) { + Value yieldSrc = yield.getOperand(index); + builder.setInsertionPoint(yield.getOperation()); + newCast = markBackward(builder.create(loc, newType, yieldSrc)); + yield->setOperand(index, newCast.getResult(0)); + queue.push(newCast); + } + + return true; +} + +bool CTAPlanner::processForOp(scf::ForOp forOp, int index, + const Type &newType) { + Block *body = forOp.getBody(); + auto yield = llvm::cast(forOp.getBody()->getTerminator()); + + // Check index + assert(index + forOp.getNumControlOperands() < forOp.getNumOperands() && + "Invalid operand index of ForOp"); + assert(index + forOp.getNumInductionVars() < body->getNumArguments() && + "Invalid block arg index of ForOp"); + assert(index < yield.getNumOperands() && "Invalid operand index of YieldOp"); + assert(index < forOp.getNumResults() && "Invalid result index of IfOp"); + + Location loc = forOp.getLoc(); + OpBuilder builder(forOp.getContext()); + + // Insert backward cast before forOp + OpOperand &operand = + forOp->getOpOperand(index + forOp.getNumControlOperands()); + builder.setInsertionPoint(forOp.getOperation()); + auto newCast = + markBackward(builder.create(loc, newType, operand.get())); + operand.set(newCast.getResult(0)); + queue.push(newCast); + + // Insert forward cast after block arg + Value arg = body->getArgument(index + forOp.getNumInductionVars()); + builder.setInsertionPointToStart(body); + newCast = markForward(builder.create(loc, arg.getType(), arg)); + arg.setType(newType); + arg.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation()); + queue.push(newCast); + + // Insert backward cast before yield + Value yieldSrc = yield.getOperand(index); + builder.setInsertionPoint(yield.getOperation()); + newCast = markBackward(builder.create(loc, newType, yieldSrc)); + yield->setOperand(index, newCast.getResult(0)); + queue.push(newCast); + + // Insert forward cast after forOp + Value result = forOp.getResult(index); + builder.setInsertionPointAfter(forOp.getOperation()); + newCast = markForward(builder.create(loc, result.getType(), result)); + result.setType(newType); + result.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation()); + queue.push(newCast); + + return true; +} + +int findResultIndex(Operation *op, Value result) { + for (int i = 0; i < op->getNumResults(); ++i) + if (op->getResult(i) == result) + return i; + llvm::report_fatal_error("Invalid index of op result"); + return -1; +} + +bool CTAPlanner::processIfOpBackward(scf::IfOp ifOp, CastOp cast) { + int index = findResultIndex(ifOp.getOperation(), cast.getOperand(0)); + auto newType = cast.getResult(0).getType(); + return processIfOp(ifOp, index, newType); +} + +bool CTAPlanner::processForOpBackward(scf::ForOp forOp, CastOp cast) { + int index = findResultIndex(forOp.getOperation(), cast.getOperand(0)); + auto newType = cast.getResult(0).getType(); + return processForOp(forOp, index, newType); +} + +bool CTAPlanner::processBlockArgBackward(BlockArgument arg, CastOp cast) { + if (auto forOp = llvm::dyn_cast(arg.getOwner()->getParentOp())) { + int index = int(arg.getArgNumber()) - forOp.getNumInductionVars(); + auto newType = cast.getResult(0).getType(); + return processForOp(forOp, index, newType); + } else { + llvm::report_fatal_error("Unexpected parent op of block argument"); + return true; + } +} + +bool CTAPlanner::processForOpForward(scf::ForOp forOp, CastOp cast) { + int index = cast.getResult(0).use_begin()->getOperandNumber() - + forOp.getNumControlOperands(); + auto newType = cast.getOperand(0).getType(); + return processForOp(forOp, index, newType); +} + +bool CTAPlanner::processYieldOpForward(scf::YieldOp yieldOp, CastOp cast) { + int index = cast.getResult(0).use_begin()->getOperandNumber(); + auto newType = cast.getOperand(0).getType(); + if (auto ifOp = llvm::dyn_cast(yieldOp->getParentOp())) + return processIfOp(ifOp, index, newType); + else if (auto forOp = llvm::dyn_cast(yieldOp->getParentOp())) + return processForOp(forOp, index, newType); + else + llvm::report_fatal_error("Unexpected parent op of YieldOp"); + return true; +} + +bool CTAPlanner::processOpFallback(Operation *op) { + Location loc = op->getLoc(); + OpBuilder builder(op->getContext()); + + builder.setInsertionPoint(op); + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + auto operandTy = operand.getType(); + if (triton::isTensorOrTensorPointerType(operandTy)) { + auto cast = markBackward(builder.create(loc, operandTy, operand)); + op->setOperand(i, cast.getResult(0)); + queue.push(cast); + } + } + + builder.setInsertionPointAfter(op); + for (unsigned i = 0; i < op->getNumResults(); ++i) { + Value result = op->getResult(i); + auto resultTy = result.getType(); + if (triton::isTensorOrTensorPointerType(resultTy)) { + auto cast = markForward(builder.create(loc, resultTy, result)); + result.replaceAllUsesExcept(cast.getResult(0), cast.getOperation()); + queue.push(cast); + } + } + + return true; +} + +bool CTAPlanner::processMultiUsersBackward(Value input, CastOp cast) { + Location loc = input.getLoc(); + OpBuilder builder(input.getContext()); + + llvm::DenseMap> typeToIndices; + for (OpOperand &operand : input.getUses()) { + auto brotherCast = llvm::dyn_cast(operand.getOwner()); + if (!brotherCast) { + if (stepUnchanged <= queue.size()) + return false; + builder.setInsertionPoint(operand.getOwner()); + brotherCast = markBackward( + builder.create(loc, cast.getResult(0).getType(), input)); + auto newCast = markForward(builder.create( + loc, input.getType(), brotherCast.getResult(0))); + operand.set(newCast.getResult(0)); + queue.push(brotherCast); + queue.push(newCast); + } + auto type = brotherCast.getResult(0).getType(); + typeToIndices[type].push_back(brotherCast); + } + + bool first = true; + for (auto it : typeToIndices) { + Type &type = it.first; + llvm::SmallVector &casts = it.second; + Value newInput = input; + if (!first) { + if (Operation *defOp = input.getDefiningOp()) { + builder.setInsertionPointAfter(defOp); + Operation *clonedOp = builder.clone(*defOp); + newInput = clonedOp->getResult(0); + } else { + llvm::report_fatal_error("Layout conflict for block arg"); // TODO + return false; + } + } + first = false; + if (Operation *defOp = newInput.getDefiningOp()) { + builder.setInsertionPointAfter(defOp); + } else { + assert(isa(newInput) && + "Unexpected Value without defining op"); + builder.setInsertionPointToStart( + llvm::cast(newInput).getOwner()); + } + auto newCast = markBackward(builder.create(loc, type, newInput)); + queue.push(newCast); + auto newResult = newCast.getResult(0); + for (CastOp &brotherCast : casts) { + brotherCast.getResult(0).replaceAllUsesWith(newResult); + eraseCastOpFromQueue(brotherCast); + } + } + return true; +} + +bool CTAPlanner::processMultiUsersForward(Value castResult, CastOp cast) { + Value castSrc = cast.getOperand(0); + + Location loc = cast.getLoc(); + OpBuilder builder(cast.getContext()); + builder.setInsertionPointAfter(cast.getOperation()); + + while (!castResult.use_empty()) { + auto newCast = + markForward(builder.create(loc, castResult.getType(), castSrc)); + castResult.use_begin()->set(newCast.getResult(0)); + queue.push(newCast); + } + + eraseCastOp(cast); + return true; +} + +struct PlanCTAPass : public TritonGPUPlanCTAPassBase { + PlanCTAPass(ttng::ClusterInfo *clusterInfo_ = nullptr) + : clusterInfo(clusterInfo_) {} + + void runOnOperation() override { + ModuleOp mod = getOperation(); + + // Skip PlanCTAPass when numCTAs == 1 + if (ttg::TritonGPUDialect::getNumCTAs(mod) == 1) + return; + + mod.walk([&](triton::FuncOp funcOp) { + CTAPlanner planner(clusterInfo); + planner.run(funcOp); + + // FIXME: Clone funcOp so that the IR change can be identified after + // PlanCTAPass. Without this, the change after PlanCTAPass will not be + // displayed when MLIR_ENABLE_DUMP=1. This is not reasonable and should + // be fixed later. + OpBuilder builder(funcOp); + builder.clone(*funcOp.getOperation()); + funcOp.erase(); + }); + } + + ttng::ClusterInfo *clusterInfo; +}; + +} // namespace + +std::unique_ptr +mlir::createTritonNvidiaGPUPlanCTAPass(ttng::ClusterInfo *clusterInfo) { + return std::make_unique(clusterInfo); +} + +/* TODO + * - Use ConvertLayoutOp instead of UnrealizedConversionCastOp. + * - Move PlanCTAPass to the front of CoalescePass. + * - Design better tiling strategy for DotOp and ReduceOp. + * - Consider cases where there are more than one DotOps. + * - Use better data structure for erasing CastOps from queue (linked list?). + * - Process eliminable CastOps in higher priority. + * - Fix the clone func bug in PlanCTAPass::runOnOperation. + * - Add some comments to introduce the overall idea of this pass. + * - Add some lit tests for this pass. + */ diff --git a/third_party/xpu/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp new file mode 100644 index 000000000..58e2888b7 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -0,0 +1,116 @@ +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +#include + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +namespace { + +using namespace mlir; +using namespace triton; +using namespace triton::gpu; +using namespace triton::nvidia_gpu; + +class TMALoadLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExperimentalDescriptorLoadOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto tensorType = op.getResult().getType(); + auto order = getOrder(tensorType.getEncoding()); + auto ctaLayout = getCTALayout(tensorType.getEncoding()); + Attribute encoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1, + 1, order, ctaLayout); + if (tensorType.getRank() > 1) { + encoding = SharedEncodingAttr::get( + tensorType.getContext(), tensorType.getShape(), order, ctaLayout, + tensorType.getElementType()); + } + MemDescType memDescType = + MemDescType::get(tensorType.getShape(), tensorType.getElementType(), + encoding, /*mutableMemory=*/true); + Value alloc = rewriter.create(loc, memDescType, Value()); + auto barrierCTALayout = CTALayoutAttr::get( + /*context=*/tensorType.getContext(), /*CTAsPerCGA=*/{1}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = SharedEncodingAttr::get(tensorType.getContext(), 1, + 1, 1, {0}, barrierCTALayout); + MemDescType barrierMemDescType = MemDescType::get( + {1}, rewriter.getI64Type(), barrierEncoding, /*mutableMemory=*/true); + Value barrierAlloc = + rewriter.create(loc, barrierMemDescType, Value()); + rewriter.create(loc, barrierAlloc, 1); + int sizeInBytes = product(tensorType.getShape()) * + tensorType.getElementType().getIntOrFloatBitWidth() / 8; + Value pred = rewriter.create(loc, 1, 1); + rewriter.create(loc, barrierAlloc, + sizeInBytes, pred); + rewriter.create( + loc, op.getDescPtr(), op.getIndices(), barrierAlloc, alloc, pred); + Value phase = rewriter.create(loc, 0, 32); + rewriter.create(loc, barrierAlloc, phase); + rewriter.create(loc, barrierAlloc); + rewriter.replaceOpWithNewOp(op, op.getType(), alloc); + return success(); + } +}; + +class TMAStoreLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExperimentalDescriptorStoreOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto tensorType = op.getSrc().getType(); + auto order = getOrder(tensorType.getEncoding()); + auto ctaLayout = getCTALayout(tensorType.getEncoding()); + Attribute encoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1, + 1, order, ctaLayout); + if (tensorType.getRank() > 1) { + encoding = SharedEncodingAttr::get( + tensorType.getContext(), tensorType.getShape(), order, ctaLayout, + tensorType.getElementType()); + } + MemDescType memDescType = + MemDescType::get(tensorType.getShape(), tensorType.getElementType(), + encoding, /*mutableMemory=*/true); + Value alloc = rewriter.create(loc, memDescType, op.getSrc()); + rewriter.create(loc, false); + rewriter.create( + loc, op.getDescPtr(), op.getIndices(), alloc); + rewriter.create(loc, 0); + rewriter.eraseOp(op); + return success(); + } +}; + +class TritonNvidiaGPUTMALoweringPass + : public TritonNvidiaGPUTMALoweringPassBase< + TritonNvidiaGPUTMALoweringPass> { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + mlir::RewritePatternSet patterns(context); + patterns.add(context); + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::createTritonNvidiaGPUTMALoweringPass() { + return std::make_unique(); +} diff --git a/third_party/xpu/lib/Dialect/TritonXPU/CMakeLists.txt b/third_party/xpu/lib/Dialect/TritonXPU/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/third_party/xpu/lib/Dialect/TritonXPU/IR/CMakeLists.txt b/third_party/xpu/lib/Dialect/TritonXPU/IR/CMakeLists.txt new file mode 100644 index 000000000..bf76aa3aa --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/IR/CMakeLists.txt @@ -0,0 +1,8 @@ +add_triton_library(TritonXPUIR + Dialect.cpp + Ops.cpp + + DEPENDS + TritonXPUTableGen + TritonXPUAttrDefsIncGen +) diff --git a/third_party/xpu/lib/Dialect/TritonXPU/IR/Dialect.cpp b/third_party/xpu/lib/Dialect/TritonXPU/IR/Dialect.cpp new file mode 100644 index 000000000..1abbc293e --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/IR/Dialect.cpp @@ -0,0 +1,574 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "triton/Dialect/Triton/IR/Dialect.h" // triton::PointType [Ops.cpp.inc] +#include "mlir/IR/DialectImplementation.h" // DialectAsmParser + +// clang-format off +#include "triton/Dialect/TritonXPU/IR/Dialect.h" // before cpp.inc +#include "triton/Dialect/TritonXPU/IR/Dialect.cpp.inc" +// clang-format on + +#include "triton/Dialect/Triton/IR/Utility.h" // ceil +#include "llvm/ADT/TypeSwitch.h" // TypeSwitch + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Utility +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace triton { +namespace xpu { + +SmallVector getSizePerCore(Attribute layout) { + if (auto tritonXPUAttr = mlir::dyn_cast(layout)) { + return tritonXPUAttr.getSizePerCoreInterface(); + } + + llvm::report_fatal_error("getSizePerCoreInterface not implemented"); + return SmallVector(); +} + +SmallVector getCoresPerGroup(Attribute layout) { + if (auto tritonXPUAttr = mlir::dyn_cast(layout)) { + return tritonXPUAttr.getCoresPerGroupInterface(); + } + + llvm::report_fatal_error("getCoresPerGroupInterface not implemented"); + return SmallVector(); +} + +SmallVector getGroupsPerCluster(Attribute layout) { + if (auto tritonXPUAttr = mlir::dyn_cast(layout)) { + return tritonXPUAttr.getGroupsPerClusterInterface(); + } + + llvm::report_fatal_error("getGroupsPerClusterInterface not implemented"); + return SmallVector(); +} + +SmallVector getCoresPerCluster(Attribute layout) { + if (auto tritonXPUAttr = mlir::dyn_cast(layout)) { + return tritonXPUAttr.getCoresPerClusterInterface(); + } + + llvm::report_fatal_error("getCoresPerClusterInterface not implemented"); + return SmallVector(); +} + +unsigned getTotalElemsPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || isa(type)) + return 1; + auto tensorType = cast(type); + return getTotalElemsPerThread(tensorType.getEncoding(), tensorType.getShape(), + tensorType.getElementType()); +} + +unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape, + Type eltTy) { + if (auto tritonXPUAttr = mlir::dyn_cast(layout)) { + return tritonXPUAttr.getTotalElemsPerThread(shape, eltTy); + } else if (auto tritonGPUAttr = + mlir::dyn_cast(layout)) { + return tritonGPUAttr.getTotalElemsPerThread(shape, eltTy); + } else { + llvm::report_fatal_error("getTotalElemsPerThread not implemented"); + return 0; + } +} + +unsigned getGroupSize(Attribute layout) { + unsigned size = 1; + auto coresPerGroup = getCoresPerGroup(layout); + for (auto e : coresPerGroup) { + size *= e; + } + return size; +} + +// 1 element per thread +// order = reverse(arange(rank)) +triton::xpu::ClusterLayoutAttr +getDefaultClusterEncoding(MLIRContext *context, ArrayRef shape, + uint32_t buffer_size, uint32_t core_num) { + int rank = shape.size(); + llvm::SmallVector order(rank); + std::iota(order.begin(), order.end(), 0); + // TODO[dyq]: why blockEncoding reverse order in triton 3.0 + triton::xpu::ClusterLayoutAttr encoding = triton::xpu::ClusterLayoutAttr::get( + context, shape, order, buffer_size, core_num); + return encoding; +} + +SmallVector +getCoresPerClusterWithUniqueData(Attribute layout, + ArrayRef tensorShape) { + if (auto sliceLayout = + mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(tensorShape); + auto parentCoresPerCluster = + getCoresPerClusterWithUniqueData(parentLayout, parentShape); + SmallVector coresPerCluster = parentCoresPerCluster; + coresPerCluster.erase(coresPerCluster.begin() + sliceLayout.getDim()); + return coresPerCluster; + } + auto coresPerCluster = getCoresPerCluster(layout); + assert(coresPerCluster.size() == tensorShape.size() && + "layout and tensor shape must have the same rank"); + for (unsigned i = 0; i < coresPerCluster.size(); i++) { + auto sizePerCore = getSizePerCore(layout)[i]; + auto maxCoresPerDim = ceil(tensorShape[i], sizePerCore); + coresPerCluster[i] = std::min(coresPerCluster[i], maxCoresPerDim); + } + + return coresPerCluster; +} + +SmallVector +getCoresPerGroupWithUniqueData(Attribute layout, + ArrayRef tensorShape) { + if (auto sliceLayout = + mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(tensorShape); + auto parentCoresPerGroup = + getCoresPerGroupWithUniqueData(parentLayout, parentShape); + SmallVector coresPerGroup = parentCoresPerGroup; + coresPerGroup.erase(coresPerGroup.begin() + sliceLayout.getDim()); + return coresPerGroup; + } + auto coresPerGroup = getCoresPerGroup(layout); + assert(coresPerGroup.size() == tensorShape.size() && + "layout and tensor shape must have the same rank"); + for (unsigned i = 0; i < coresPerGroup.size(); i++) { + coresPerGroup[i] = std::min(coresPerGroup[i], tensorShape[i]); + } + + return coresPerGroup; +} + +SmallVector getUniqueContigPerCore(Attribute layout, + ArrayRef shape) { + // If slice layout, call recursively on parent layout, and drop + // sliced dim + if (auto sliceLayout = + mlir::dyn_cast(layout)) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(shape); + auto parentUniqueContigPerCore = + triton::xpu::getUniqueContigPerCore(parentLayout, parentShape); + parentUniqueContigPerCore.erase(parentUniqueContigPerCore.begin() + + sliceLayout.getDim()); + return parentUniqueContigPerCore; + } + // Base case + auto rank = shape.size(); + SmallVector ret(rank); + auto contigPerCore = getSizePerCore(layout); + assert(contigPerCore.size() == rank && "Unexpected contigPerCore size"); + for (int d = 0; d < rank; ++d) { + ret[d] = std::min(shape[d], contigPerCore[d]); + } + return ret; +} + +LogicalResult +ClusterLayoutAttr::verify(function_ref emitError, + ArrayRef sizePerCore, + ArrayRef coresPerGroup, + ArrayRef groupsPerCluster, + ArrayRef order, bool isReduceOpt) { + if (sizePerCore.size() != coresPerGroup.size() || + coresPerGroup.size() != groupsPerCluster.size() || + groupsPerCluster.size() != order.size()) { + return emitError() << "sizePerCore, coresPerGroup, groupsPerCluster, and " + "order must all have the same rank."; + } + + if (!isPermutationOfIota(order)) { + return emitError() + << "order must be a permutation of 0..(rank-1), but was [" << order + << "]"; + } + return success(); +} + +} // namespace xpu +} // namespace triton +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Parse Utility +//===----------------------------------------------------------------------===// + +static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr, + unsigned &value, StringRef desc) { + auto intAttr = mlir::dyn_cast(attr); + if (!intAttr) { + parser.emitError(parser.getNameLoc(), "expected an integer type in ") + << desc; + return failure(); + } + if (intAttr.getType().isSignedInteger()) { + int64_t attrVal = intAttr.getSInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else if (intAttr.getType().isSignlessInteger()) { + int64_t attrVal = intAttr.getInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else { + value = intAttr.getUInt(); + } + return success(); +} + +static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr, + bool &value, StringRef desc) { + auto boolAttr = mlir::dyn_cast(attr); + if (!boolAttr) { + parser.emitError(parser.getNameLoc(), "expected an bool type in ") << desc; + return failure(); + } + value = boolAttr.getValue(); + return success(); +} + +// parse an array of integers +static LogicalResult parseIntArrayAttr(AsmParser &parser, + const NamedAttribute &attr, + SmallVector &res, + StringRef desc) { + auto arrayAttr = mlir::dyn_cast(attr.getValue()); + if (!arrayAttr) { + parser.emitError(parser.getNameLoc(), "expected an array for ") << desc; + return failure(); + } + for (Attribute i : arrayAttr) { + unsigned value; + if (parseIntAttrValue(parser, i, value, desc).failed()) + return failure(); + res.push_back(value); + } + return success(); +}; + +static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr, + unsigned &value, StringRef desc) { + return parseIntAttrValue(parser, attr.getValue(), value, desc); +}; + +static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr, + bool &value, StringRef desc) { + return parseBoolAttrValue(parser, attr.getValue(), value, desc); +}; + +//===----------------------------------------------------------------------===// +// Attribute methods +//===----------------------------------------------------------------------===// +#include "triton/Dialect/TritonXPU/IR/TritonXPUAttrInterfaces.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonXPU/IR/TritonXPUAttrDefs.cpp.inc" + +SmallVector +triton::xpu::ClusterLayoutAttr::getSizePerCoreInterface() const { + return SmallVector(getSizePerCore()); +} +SmallVector +triton::xpu::ClusterLayoutAttr::getCoresPerGroupInterface() const { + return SmallVector(getCoresPerGroup()); +} +SmallVector +triton::xpu::ClusterLayoutAttr::getGroupsPerClusterInterface() const { + return SmallVector(getGroupsPerCluster()); +} +SmallVector +triton::xpu::ClusterLayoutAttr::getCoresPerClusterInterface() const { + SmallVector coresPerCluster; + for (unsigned d = 0, n = getOrder().size(); d < n; ++d) + coresPerCluster.push_back(getCoresPerGroup()[d] * getGroupsPerCluster()[d]); + return coresPerCluster; +} + +SmallVector +triton::xpu::ClusterLayoutAttr::getElemsPerThread(ArrayRef shape, + Type eltTy) const { + size_t rank = shape.size(); + auto sizePerCore = getSizePerCore(); + auto coresPerGroup = getCoresPerGroup(); + auto groupsPerCluster = getGroupsPerCluster(); + assert(rank == sizePerCore.size() && + "unexpected rank in BlockedEncodingAttr::getElemsPerThread"); + SmallVector elemsPerThread(rank); + for (size_t i = 0; i < rank; ++i) { + unsigned t = groupsPerCluster[i] * coresPerGroup[i]; + elemsPerThread[i] = ceil(shape[i], t); + } + return elemsPerThread; +} + +unsigned +triton::xpu::ClusterLayoutAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + +Attribute triton::xpu::ClusterLayoutAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + SmallVector sizePerCore; + SmallVector coresPerGroup; + SmallVector groupsPerCluster; + SmallVector order; + bool isReduceOpt; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "sizePerCore") { + if (parseIntArrayAttr(parser, attr, sizePerCore, + "number of elements per core") + .failed()) + return {}; + } else if (attr.getName() == "coresPerGroup") { + if (parseIntArrayAttr(parser, attr, coresPerGroup, + "number of cores per group") + .failed()) + return {}; + } else if (attr.getName() == "groupsPerCluster") { + if (parseIntArrayAttr(parser, attr, groupsPerCluster, + "number of groups per cluster") + .failed()) + return {}; + } else if (attr.getName() == "order") { + if (parseIntArrayAttr(parser, attr, order, "order").failed()) + return {}; + } else if (attr.getName() == "isReduceOpt") { + if (parseBool(parser, attr, isReduceOpt, "isReduceOpt").failed()) + return {}; + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } + } + + return parser.getChecked(parser.getContext(), sizePerCore, + coresPerGroup, groupsPerCluster, + order, isReduceOpt); +} + +void triton::xpu::ClusterLayoutAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "sizePerCore = [" << ArrayRef(getSizePerCore()) << "]" + << ", coresPerGroup = [" << ArrayRef(getCoresPerGroup()) << "]" + << ", groupsPerCluster = [" << ArrayRef(getGroupsPerCluster()) << "]" + << ", order = [" << getOrder() << "]" + << ", isReduceOpt = " << getIsReduceOpt() << "}>"; + + printer << "}>"; +} + +//===----------------------------------------------------------------------===// +// ASM Interface (i.e.: alias) +//===----------------------------------------------------------------------===// + +class TritonXPUOpAsmInterface : public OpAsmDialectInterface { +public: + using OpAsmDialectInterface::OpAsmDialectInterface; + + AliasResult getAlias(Attribute attr, raw_ostream &os) const override { + if (auto clusterAttr = + mlir::dyn_cast(attr)) { + os << "cluster"; + return AliasResult::FinalAlias; + } + return OpAsmDialectInterface::getAlias(attr, os); + } +}; + +struct TritonXPUInferLayoutInterface + : public triton::DialectInferLayoutInterface { + using DialectInferLayoutInterface::DialectInferLayoutInterface; + + LogicalResult + inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding) const override { + resultEncoding = triton::gpu::SliceEncodingAttr::get( + getDialect()->getContext(), axis, operandEncoding); + return success(); + } + + // Infer the encoding of a tt.trans(x) given the encoding of x. + // + // Our goal is to choose an encoding so that the trans is a "nop". For + // example, in a blocked encoding, the same GPU threads hold the same + // elements, they're just "renamed" -- what was element [i,j] of the tensor is + // now element [j,i], but that element is held by the same GPU thread. + // + // For most properties of the encoding, we let + // outputEnc.prop = inputEnc.prop * trans.order, + // where `x * y` means we apply permutation y to x. + // + // This works because prop[i] tells you something about the i'th dimension of + // the tensor. (For example, sizePerThread[2] == 4 means that one GPU thread + // contains 4 elements along dim 2 of the tensor.) The transpose reorders the + // dimensions according to the perm trans.order, so we achieve our goal of + // having a "nop" transpose by reordering the values in the prop the same way. + // + // The big exception to this is the encoding's `order`. + // + // An encoding's order is a list of dimensions, from fastest moving (most + // minor) to slowest moving. Thus enc.order[i] does not tell you something + // about the i'th dimension of the tensor, and it would be disasterously + // incorrect to do enc.order * trans.order. + // + // But! If we invert enc.order, it *does* meet this criterion. For example, + // if enc.order = [2,0,1], inverse(enc.order) = [1,2,0]. If you stare at it, + // you'll see that inverse(enc.order)[i] == j means that dimension i is the + // j'th most minor. Therefore we can safely permute *this* by trans.order. + // + // Thus we have + // + // outputEnc.order = inverse(inverse(inputEnc.order) * trans.order) + // = inverse(trans.order) * inputEnc.order. + // + LogicalResult inferTransOpEncoding(Attribute operandEncoding, + ArrayRef order, // trans order + Attribute &resultEncoding) const override { + llvm_unreachable( + "TODO[dyq]: Add triton::xpu::GlobalEncodingAttr Calculation Logic"); + return failure(); // unhandled encoding + } + + LogicalResult + inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, + Attribute &resultEncoding, + std::optional location) const override { + auto sliceEncoding = + mlir::dyn_cast(operandEncoding); + if (!sliceEncoding) + return emitOptionalError( + location, "ExpandDimsOp operand encoding must be SliceEncodingAttr"); + if (sliceEncoding.getDim() != axis) + return emitOptionalError( + location, "Incompatible slice dimension for ExpandDimsOp operand"); + resultEncoding = sliceEncoding.getParent(); + return success(); + } + + LogicalResult + inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, + Attribute retEncoding, + std::optional location) const override { + llvm_unreachable("TODO[dyq]: XPUSDNN-CHECK Add " + "triton::xpu::GlobalEncodingAttr Calculation Logic"); + return failure(); + } + + LogicalResult + verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, + Attribute operandEncodingB) const override { + llvm_unreachable("TODO[dyq]: XPUSDNN-CHECK Add " + "triton::xpu::GlobalEncodingAttr Calculation Logic"); + return failure(); + } + + // Given a src shape + encoding and a dst shape, our goal is to compute a dst + // encoding that makes the reshape a "nop". That is, if GPU thread [x,y,z] + // contains elements [a,b,c,d] before the reshape, it contains those same + // elements after the reshape, they're just "renamed". + // + // A dst encoding that satisfies this property does not exist for all inputs. + // Here are some positive and negative examples. + // + // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so + // dim 1 is the fastest-changing in the dst, but the src has the opposite + // order. + // - OK: 2x2x32 order=[1,0,2] -> 4x32. We choose dst order [0,1]. + // What's important is that the 2x2 dimensions appear in major-to-minor + // order. + // - NOT OK: 32x32 sizePerThread=[2,2] -> 1024. Thread 0 in the src + // contains elements [(0,0), (0,1), (1,0), and (1,1)]. We cannot express + // this with an encoding based on the dst shape. + // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will + // contain the same elements as before. + // + // Users of this function require that it is symmetrical: if + // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) => + // srcEnc. + LogicalResult + inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const override { + llvm_unreachable( + "TODO[dyq]: Add triton::xpu::GlobalEncodingAttr Calculation Logic"); + return failure(); + } + + LogicalResult + inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const override { + llvm_unreachable( + "TODO[dyq]: Add triton::xpu::GlobalEncodingAttr Calculation Logic"); + return failure(); + } + + LogicalResult + inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, + std::optional loc) const override { + llvm_unreachable( + "TODO[dyq]: Add triton::xpu::GlobalEncodingAttr Calculation Logic"); + return failure(); + } +}; + +//===----------------------------------------------------------------------===// +// Dialect Initialization +//===----------------------------------------------------------------------===// + +void mlir::triton::xpu::TritonXPUDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/TritonXPU/IR/TritonXPUAttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST // declare +#include "triton/Dialect/TritonXPU/IR/Ops.cpp.inc" + >(); + + addInterfaces(); + addInterfaces(); +} + +#define GET_OP_CLASSES // define +#include "triton/Dialect/TritonXPU/IR/Ops.cpp.inc" + +// verify TritonXPU ops +LogicalResult +triton::xpu::TritonXPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // TODO: fill this. + return success(); +} diff --git a/third_party/xpu/lib/Dialect/TritonXPU/IR/Ops.cpp b/third_party/xpu/lib/Dialect/TritonXPU/IR/Ops.cpp new file mode 100644 index 000000000..b0cfb2125 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/IR/Ops.cpp @@ -0,0 +1,292 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace triton { +namespace xpu { +//-- MakeRangeOp -- +// OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) { +// // make_range(start, start + 1) -> constant(start) +// if (adaptor.getStart() + 1 == adaptor.getEnd()) { +// auto shapedType = cast(getType()); +// return SplatElementsAttr::get(shapedType, adaptor.getStartAttr()); +// } +// return {}; +// } + +LogicalResult MakeRangeOp::verify() { + int64_t start = getStartAttr().getInt(); + int64_t end = getEndAttr().getInt(); + if (start > end) { + return this->emitOpError() << "start must be less than or equal to end"; + } + auto ty = getType(); + if (ty.getShape().size() != 1) { + return this->emitOpError() << "return type must be a 1D tensor"; + } + if (end - start < ty.getShape()[0]) { // loopIdx change the verify logic + return this->emitOpError() + << "number of elements in returned tensor, " << ty.getShape()[0] + << ", must not exceed the size of range [" << start << ", " << end + << "), which has " << end - start << " elements"; + } + if (!ty.getElementType().isInteger(32)) { + return this->emitOpError() << "returned tensor must have i32 elements"; + } + return success(); +} + +// //-- InterleaveOp -- +// OpFoldResult InterleaveOp::fold(FoldAdaptor adaptor) { +// // make_range(start, start + 1) -> constant(start) +// if (adaptor.getStart() + 1 == adaptor.getEnd()) { +// auto shapedType = cast(getType()); +// return SplatElementsAttr::get(shapedType, adaptor.getStartAttr()); +// } +// return {}; +// } + +LogicalResult InterleaveOp::verify() { + int64_t start = getStartAttr().getInt(); + int64_t end = getEndAttr().getInt(); + if (start > end) { + return this->emitOpError() << "start must be less than or equal to end"; + } + auto ty = getType(); + if (ty.getShape().size() != 1) { + return this->emitOpError() << "return type must be a 1D tensor"; + } + if (end - start < ty.getShape()[0]) { // loopIdx change the verify logic + return this->emitOpError() + << "number of elements in returned tensor, " << ty.getShape()[0] + << ", must not exceed the size of range [" << start << ", " << end + << "), which has " << end - start << " elements"; + } + if (!ty.getElementType().isInteger(32) && + !ty.getElementType().isInteger(64)) { + return this->emitOpError() << "returned tensor must have i32/i64 elements"; + } + return success(); +} + +//-- ReduceOp -- +static LogicalResult +inferReduceReturnShape(const RankedTensorType &argTy, const Type &retEltTy, + int axis, SmallVectorImpl &inferredReturnTypes) { + auto retShape = argTy.getShape().vec(); + retShape.erase(retShape.begin() + axis); + if (retShape.empty()) { + // 0d-tensor -> scalar + inferredReturnTypes.push_back(retEltTy); + } else { + // nd-tensor where n >= 1 + // infer encoding + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = + dyn_cast(&dialect); + if (inferLayoutInterface + ->inferReduceOpEncoding(argEncoding, axis, retEncoding) + .failed()) { + llvm::report_fatal_error("failed to infer layout for ReduceOp"); + return failure(); + } + } + // create type + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); + } + return success(); +} + +// Helpers for Reductions and Scans +template LogicalResult verifyReduceScan(Op &op) { + if (op.getOperands().empty()) { + return op.emitOpError() << "must have at least 1 operand"; + } + if ((op.getNumOperands() - 1) != op.getNumResults()) { // -1 for loopIndex + return op.emitOpError() << "must have the same number of inputs as outputs"; + } + + auto getElementType = [](Type ty) { + if (auto tensorType = dyn_cast(ty)) { + return tensorType.getElementType(); + } + return ty; + }; + + for (auto [opElemTy, resTy] : + llvm::zip(op.getElementTypes(), op.getResultTypes())) { + auto _opElemTy = getElementTypeOrSelf(opElemTy); + if (_opElemTy != getElementType(resTy)) { + return op.emitOpError() << "operand types and result types must agree"; + } + } + return success(); +} + +template +static LogicalResult verifyRegionsImpl(Op &op) { + auto argElementTypes = op.getElementTypes(); + const auto &operands = op.getOperands(); + const auto numArgs = 2 * (operands.size() - 1); // -1 for loopIndex + auto &block = *op.getBody(); + if (block.getNumArguments() != numArgs) { + return op.emitOpError() << "nested block must take " << numArgs + << " arguments, but given block with " + << block.getNumArguments() << " arguments"; + } + unsigned i = 0; + const auto &blockArgTypes = block.getArgumentTypes(); + for (unsigned i = 0; i < numArgs; ++i) { + const auto &blockArgTy = blockArgTypes[i]; + const auto &argElemTy = + argElementTypes[i % (operands.size() - 1)]; // -1 for loopIndex + if (blockArgTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << blockArgTy; + } + } + + auto terminator = dyn_cast(block.getTerminator()); + if (!terminator) { + return op.emitOpError() + << "combine operation must be terminated " + << "with a ReduceReturnOp but got " << block.getTerminator(); + } + const auto &combineResults = terminator->getOperands(); + if (combineResults.size() != (operands.size() - 1)) { // -1 for loopIndex + return op.emitOpError() + << "expected combine operation to return " << operands.size() + << " values but got " << combineResults.size(); + } + for (unsigned i = 0; i < combineResults.size(); ++i) { + const auto &resultTy = combineResults[i].getType(); + const auto &argElemTy = argElementTypes[i]; + if (resultTy != argElemTy) { + return op.emitOpError() + << "type mismatch on combine operation. Expected argument " << i + << " to have type " << argElemTy << " but got " << resultTy; + } + } + return success(); +} + +static llvm::SmallVector +getInputTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcTys; + srcTys.reserve(operands.size()); + for (const auto &[i, ty] : llvm::enumerate(operands.getTypes())) { + if (i == (operands.size() - 1)) + continue; // skip loopIndex + srcTys.push_back(cast(ty)); + } + return srcTys; +} + +static llvm::SmallVector +getElementTypesImpl(const Operation::operand_range &operands) { + llvm::SmallVector srcElemTys; + srcElemTys.reserve(operands.size()); + for (const auto &[i, op] : llvm::enumerate(operands)) { + if (i == (operands.size() - 1)) + continue; // skip loopIndex + srcElemTys.push_back(cast(op.getType()).getElementType()); + } + return srcElemTys; +} + +LogicalResult ReduceOp::verify() { return verifyReduceScan(*this); } + +LogicalResult ReduceOp::verifyRegions() { + return verifyRegionsImpl(*this); +} + +llvm::SmallVector ReduceOp::getInputTypes() { + return getInputTypesImpl(this->getOperands()); +} + +llvm::SmallVector ReduceOp::getElementTypes() { + return getElementTypesImpl(this->getOperands()); +} + +unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } + +LogicalResult ReduceOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + Properties *prop = properties.as(); + int axis = prop->axis.getInt(); + for (auto [i, arg] : llvm::enumerate(operands)) { + if (i == (operands.size() - 1)) + continue; // skip loopIndex + auto argTy = cast(arg.getType()); + auto retEltTy = getElementTypeOrSelf(argTy.getElementType()); + if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes) + .failed()) { + return failure(); + } + } + return success(); +} + +//-- BroadcastOp -- +template +LogicalResult canonicalizeViewOrBroadcast(OpType op, + PatternRewriter &rewriter) { + auto definingOp = op.getSrc().getDefiningOp(); + if (!definingOp) { + return failure(); + } + + // view(view) -> view + if (auto parentView = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, TypeRange({op.getType()}), + parentView->getOperands(), + parentView->getAttrs()); + return success(); + } + + // view(splat) -> splat + if (auto splat = dyn_cast(definingOp)) { + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); + return success(); + } + + return failure(); +} + +LogicalResult BroadcastOp::canonicalize(BroadcastOp op, + PatternRewriter &rewriter) { + return canonicalizeViewOrBroadcast(op, rewriter); +} + +OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { + if (getType() == getSrc().getType()) { + // no-op + return getSrc(); + } + + auto value = adaptor.getSrc(); + if (!value) + return {}; + + if (auto denseElemsAttr = dyn_cast(value)) { + auto shapedType = cast(getType()); + return denseElemsAttr.resizeSplat(shapedType); + } + return {}; +} + +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Alloca.cpp b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Alloca.cpp new file mode 100644 index 000000000..459eb0f7e --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Alloca.cpp @@ -0,0 +1,114 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// TODO: Pass Description +//===----------------------------------------------------------------------===// + +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" + +namespace mlir { +namespace triton { +namespace xpu { + +#define GEN_PASS_DEF_TRITONXPUALLOCA +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +struct TritonXPUAllocaPass + : public impl::TritonXPUAllocaBase { + +public: + using impl::TritonXPUAllocaBase::TritonXPUAllocaBase; + + void runOnOperation() override { + mlir::ModuleOp m = getOperation(); + + m.walk([&](triton::xpu::LoadOp loadOp) { + auto loc = loadOp.getLoc(); + OpBuilder builder(loadOp); + auto resType = loadOp.getResult().getType(); + auto gmPtrType = loadOp.getPtr().getType(); + auto lmPtrType = addrspaceCast(gmPtrType, 0); + auto size = + mlir::isa(gmPtrType) + ? product(mlir::cast(gmPtrType).getShape()) + : 1; + if (auto gm2lmOp = + dyn_cast(loadOp->getPrevNode())) { + auto allocaOp = + builder.create(loc, lmPtrType, size); + + auto operandSegmentSizesAttr = + gm2lmOp->getAttrOfType("operandSegmentSizes"); + SmallVector operandSegmentSizes( + operandSegmentSizesAttr.asArrayRef()); + ++operandSegmentSizes[2]; // 0: ptr, 1: len, 2: bufPtr + gm2lmOp->setAttr("operandSegmentSizes", + builder.getDenseI32ArrayAttr(operandSegmentSizes)); + + gm2lmOp->insertOperands(gm2lmOp->getNumOperands(), {allocaOp}); + + allocaOp->moveBefore(gm2lmOp); + } else { + llvm_unreachable("Only support GM2LM as previous node of load"); + } + }); + + m.walk([&](triton::xpu::StoreOp storeOp) { + auto loc = storeOp.getLoc(); + OpBuilder builder(storeOp); + auto resType = storeOp.getValue().getType(); + auto gmPtrType = storeOp.getPtr().getType(); + auto lmPtrType = addrspaceCast(gmPtrType, 0); + auto size = + mlir::isa(gmPtrType) + ? product(mlir::cast(gmPtrType).getShape()) + : 1; + if (auto lm2gmOp = + dyn_cast(storeOp->getNextNode())) { + auto allocaOp = + builder.create(loc, lmPtrType, size); + + auto operandSegmentSizesAttr = + lm2gmOp->getAttrOfType("operandSegmentSizes"); + SmallVector operandSegmentSizes( + operandSegmentSizesAttr.asArrayRef()); + ++operandSegmentSizes[3]; // 0: ptr, 1: value, 2: len, 3: bufPtr + lm2gmOp->setAttr("operandSegmentSizes", + builder.getDenseI32ArrayAttr(operandSegmentSizes)); + lm2gmOp->insertOperands(lm2gmOp->getNumOperands(), {allocaOp}); + // remove value from lm2gm + --operandSegmentSizes[1]; + lm2gmOp->setAttr("operandSegmentSizes", + builder.getDenseI32ArrayAttr(operandSegmentSizes)); + lm2gmOp->eraseOperands(1); + + allocaOp->moveBefore(storeOp); + storeOp->setOperand(0, allocaOp); + } else { + llvm_unreachable("Only support LM2GM as next node of store"); + } + }); + + // Move Alloca in the Front of FuncOp Body + m.walk([&](triton::xpu::AllocaOp allocaOp) { + // 1.Find FuncOp + Operation *ancestorOp = allocaOp; + while (!isa(ancestorOp)) { + Block *block = ancestorOp->getBlock(); + ancestorOp = block->getParentOp(); + } + // 2. Move alloca in the Front of the First Op in the FuncOp Body + Operation *firstOp = + &(*(cast(ancestorOp).getBody().front().begin())); + allocaOp->moveBefore(firstOp); + }); + } +}; + +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/CMakeLists.txt b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000..7dad7ab0a --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/CMakeLists.txt @@ -0,0 +1,22 @@ +add_triton_library(TritonXPUTransforms + Alloca.cpp + CoreTiling.cpp + CreateGM2LM.cpp + DtypeConvert.cpp + Legalize.cpp + LoopGrid.cpp + Mask.cpp + OffsetAnalysis.cpp + Vectorize.cpp + MemoryAsync.cpp + UnrollControl.cpp + Interleave.cpp + StoreControl.cpp + OtherSim.cpp + + DEPENDS + TritonXPUTransformsIncGen + + LINK_LIBS PUBLIC + TritonAnalysis +) diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/CoreTiling.cpp b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/CoreTiling.cpp new file mode 100644 index 000000000..cc1a7b502 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/CoreTiling.cpp @@ -0,0 +1,706 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// TODO: Pass Description +//===----------------------------------------------------------------------===// + +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" + +#define DEBUG_TYPE "tritonxpu-core-tiling" + +namespace mlir { + +namespace triton { +namespace xpu { + +#define GEN_PASS_DEF_TRITONXPUCORETILING +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +struct TritonXPUCoreTilingPass + : public impl::TritonXPUCoreTilingBase { + + using impl::TritonXPUCoreTilingBase< + TritonXPUCoreTilingPass>::TritonXPUCoreTilingBase; + + TritonXPUCoreTilingPass() = default; + TritonXPUCoreTilingPass(bool dumpFlag, unsigned bufferSize) { + this->dumpFlag = dumpFlag; + this->bufferSize = bufferSize; + } + + inline bool isAxisNone(triton::ReduceOp &reduceOp) { + ReduceOpHelper helper(reduceOp); + auto reduceOpTensorShape = helper.getSrcShape(); + for (auto src : reduceOp.getSrcs()) { + if (auto defOp = src.getDefiningOp()) { + if (auto reshapeOp = dyn_cast(defOp)) { + if (auto reshapeResTy = + dyn_cast(reshapeOp.getResult().getType())) { + if (reshapeResTy.getShape().size() == 1) { + assert(reduceOp.getAxis() == 0); + return true; + } + } + } + } + } + return false; + } + + inline bool ifInInnerChains(SmallVector> innerChains, + Operation *iOp) { + for (auto innerChain : innerChains) { + return std::find(innerChain.begin(), innerChain.end(), iOp) != + innerChain.end(); + } + return false; + } + + Attribute + getOptimizedGEncoding(MLIRContext *context, RankedTensorType type, + SmallVector> innerChains, + Operation *op, unsigned ngroup, unsigned groupsize) { + Attribute newEncoding; + auto shape = type.getShape(); + unsigned rank = shape.size(); + if (auto globalEncoding = + dyn_cast(type.getEncoding())) { + std::vector newSizePerCore; + std::vector newCoresPerGroup; + std::vector newGroupsPerCluster; + std::vector order; + unsigned isReduceOpt = 1; + + if (rank == 1) { + order = {0}; + if (ifInInnerChains(innerChains, op)) { + newSizePerCore = {ceil(shape[0], groupsize)}; + newCoresPerGroup = {groupsize}; + newGroupsPerCluster = {1}; + } else { + newSizePerCore = {ceil(shape[0], ngroup)}; + newCoresPerGroup = {1}; + newGroupsPerCluster = {ngroup}; + } + } else if (rank == 2) { + newCoresPerGroup = {1, groupsize}; + newGroupsPerCluster = {ngroup, 1}; + order = {0, 1}; + if (rowsPerCore > 1 && shape[0] != 1) { + newSizePerCore = {rowsPerCore, ceil(shape[1], groupsize)}; + } else { + newSizePerCore = {1, ceil(shape[1], groupsize)}; + } + if (rowsPerCore > 1) { + if (auto broadcastOp = dyn_cast(op)) { + // BroadcastOp specifies the element size directly using the + // sizePerThread value. + order = {1, 0}; + } + } + } else { + llvm_unreachable("Reduce Optimization With Rank > 2 Unsupported"); + } + newEncoding = triton::xpu::ClusterLayoutAttr::get( + context, newSizePerCore, newCoresPerGroup, newGroupsPerCluster, order, + isReduceOpt); + } + + return newEncoding; + } + + bool getTensorColSize(ModuleOp &mod) { + mod.walk([&](arith::CmpIOp cmpiOp) { + auto lhs = cmpiOp.getLhs(); + auto rhs = cmpiOp.getRhs(); + + if (auto lhsTensorTy = dyn_cast(lhs.getType())) { + auto lhsShape = lhsTensorTy.getShape(); + if (cmpiOp.getPredicate() == arith::CmpIPredicate::slt && + lhsShape.size() == 2 && lhsShape[0] == 1) { // inner Cmp Calculation + if (auto rhsOp = rhs.getDefiningOp()) { + if (auto denseAttr = + mlir::dyn_cast(rhsOp.getValue())) { + auto values = denseAttr.getValues(); + if (!values.empty()) { + rawColSize = values[0].getZExtValue(); + } + } + } + } + } + }); + + return rawColSize ? true : false; + } + + int roundupPow2(int n) { + int ret = 1; + while (n > ret) { + ret *= 2; + } + return ret; + }; + + // Get ReduceOps' Shape To Check If Can Be Optimized + bool canBeOptimized(ModuleOp &mod) { + bool canBeOpt = false; + int colSize = 0; + + auto checkGroupInfo = [&](RankedTensorType &tensorType) { + if (auto globalEncoding = dyn_cast( + tensorType.getEncoding())) { + auto shape = tensorType.getShape(); + size_t _ngroup = product(globalEncoding.getGroupsPerCluster()); + size_t _groupsize = product(globalEncoding.getCoresPerGroup()); + size_t _sizePerCore = product(globalEncoding.getSizePerCore()); + size_t ncore = _ngroup * _groupsize; + size_t m = shape.front(); + size_t n = shape.back(); + size_t newgroupsize = ceil(n, static_cast(bufferSize)); + newgroupsize = roundupPow2(newgroupsize); + // min is for not using the whole 64 cores case + size_t newngroup = std::min(ceil(ncore, newgroupsize), m); + newgroupsize = std::min(ceil(ncore, newngroup), n); + + if (newngroup == 1 && newgroupsize == 64) + canBeOpt = false; + } + }; + + mod.walk([&](triton::ReduceOp reduceOp) { + ReduceOpHelper helper(reduceOp); + auto reduceOpTensorShape = helper.getSrcShape(); + + if (reduceOpTensorShape.size() == 2) { + assert(reduceOp.getAxis() == 1); + colSize = std::max(colSize, static_cast(reduceOpTensorShape[1])); + canBeOpt = true; + + // rowsPerCore Upper = [128 / 16, 128 / 32, 128 / 64] + unsigned rowsPerCoreUpper = bufferSize / reduceOpTensorShape[1]; + unsigned rowsPerCoreLower = 1; + + unsigned rowsPerCoreCal; + for (rowsPerCoreCal = rowsPerCoreUpper; + rowsPerCoreCal > rowsPerCoreLower; rowsPerCoreCal /= 2) { + if (reduceOpTensorShape[0] % (rowsPerCoreCal * core_num) == 0) + break; + } + + rowsPerCore = std::min(rowsPerCoreUpper, rowsPerCoreCal); + rowsPerCore = std::max(rowsPerCore, rowsPerCoreLower); + if (!getTensorColSize(mod) || colSize < rawColSize) + rowsPerCore = 1; + + auto tensorType = cast(reduceOp.getOperandTypes()[0]); + checkGroupInfo(tensorType); + } else if (isAxisNone(reduceOp)) { + canBeOpt = true; + } else if (canBeOpt) { + llvm_unreachable("Not All Reduce Op can be Optimized"); + } + }); + + if (!getTensorColSize(mod) || colSize < rawColSize) + rowsPerCore = 1; + + return canBeOpt; + } + + mlir::Operation *findRootOp(mlir::Operation *op) { + mlir::Operation *rootOp = op; + while (rootOp->getParentOp()) { + rootOp = rootOp->getParentOp(); + if (rootOp->getParentOp() && isa(rootOp->getParentOp())) { + return rootOp; + } + } + return op; + } + + void + getChains(const llvm::SmallVector> &allOpTrees, + llvm::SmallVector> &innerChains, + llvm::SmallVector> &outerChains) { + for (auto allOpTree : allOpTrees) { + llvm::SetVector innerChain; + llvm::SetVector outerChain; + for (auto op : allOpTree) { + if (auto rangeOp = dyn_cast(op)) { + for (auto user : rangeOp.getResult().getUsers()) { + if (auto userOp = findUserOp(user)) { + auto expandDimOp = cast(userOp); + if (expandDimOp.getAxis() == 1) { + outerChain.insert(rangeOp); + } + } + } + } + if (auto expandDimOp = dyn_cast(op)) { + auto src = expandDimOp.getSrc(); + auto result = expandDimOp.getResult(); + if (auto srcTy = mlir::dyn_cast(src.getType())) { + if (auto resTy = + mlir::dyn_cast(result.getType())) { + if (expandDimOp.getAxis() == 0) { + getOpChainBwd(innerChain, expandDimOp); + innerChain.remove(expandDimOp); + } + } + } + } + if (auto broadcastOp = dyn_cast(op)) { + auto src = broadcastOp.getSrc(); + auto result = broadcastOp.getResult(); + if (auto srcTy = mlir::dyn_cast(src.getType())) { + if (auto resTy = + mlir::dyn_cast(result.getType())) { + auto srcShape = srcTy.getShape(); + auto resShape = resTy.getShape(); + if (srcShape[0] != resShape[0]) { // unequal dim 0 shape means + // in the inner axis op chain + getOpChainBwd(innerChain, broadcastOp); + innerChain.remove(broadcastOp); + } + } + } + } + if (auto reduceOp = dyn_cast(op)) { + if (reduceOp.getAxis() == 0) { + getOpChainFwd(innerChain, reduceOp); + } + } + } + outerChains.emplace_back(outerChain); + innerChains.emplace_back(innerChain); + } + } + + // The common mrOp will be shared while row_size = col_size. + // In this case, we need to create a new mrOp for innerChain. + // The two mrOp will be modified with different [inner/outer] encodings. + void recoverMakeRange(SmallVector> &innerChains, + SmallVector> &outerChains) { + for (int i = 0; i < innerChains.size(); ++i) { + llvm::SetVector innerChain = innerChains[i]; + llvm::SetVector outerChain = outerChains[i]; + + for (auto it = outerChain.begin(); it != outerChain.end(); ++it) { + Operation *outerOp = *it; + if (inOpChain(innerChain, outerOp)) { // Common MROp + if (auto rangeOp = dyn_cast(outerOp)) { + // Find MROp's Whose User is ExpandDimsOp(dim=0) + SmallVector recoverUses; + for (auto &use : rangeOp->getUses()) { + if (auto op = findUserOp(use.getOwner())) { + auto expandDimsOp = cast(op); + if (expandDimsOp.getAxis() == 0) { + recoverUses.push_back(&use); + } + } + } + + if (!recoverUses.empty()) { + // Recover MakeRangeOp + OpBuilder builder(rangeOp); + auto loc = builder.getUnknownLoc(); + auto newMakeRangeOp = builder.create( + loc, rangeOp.getType(), rangeOp.getStart(), rangeOp.getEnd()); + // Link To InnerChain + for (auto use : recoverUses) { + use->assign(newMakeRangeOp); + } + // Now the old common mrOp is only used by outerChain + innerChains[i].insert(newMakeRangeOp); + innerChains[i].remove(rangeOp); + } + } + } + } + } + } + + // Modify All Op Encoding + void + modifyOpEncoding(ModuleOp &mod, MLIRContext *context, + const SmallVector> &innerChains) { + size_t ngroup = 1; + size_t groupsize = 64; + bool isFirst = true; + + auto getGroupInfo = [&](RankedTensorType &tensorType) { + if (auto globalEncoding = dyn_cast( + tensorType.getEncoding())) { + auto shape = tensorType.getShape(); + size_t _ngroup = product(globalEncoding.getGroupsPerCluster()); + size_t _groupsize = product(globalEncoding.getCoresPerGroup()); + size_t _sizePerCore = product(globalEncoding.getSizePerCore()); + size_t ncore = _ngroup * _groupsize; + size_t m = shape.front(); + size_t n = shape.back(); + size_t newgroupsize = ceil(n, static_cast(bufferSize)); + newgroupsize = roundupPow2(newgroupsize); + // min is for not using the whole 64 cores case + size_t newngroup = std::min(ceil(ncore, newgroupsize), m); + newgroupsize = std::min(ceil(ncore, newngroup), n); + if (isFirst) { + ngroup = newngroup; + groupsize = newgroupsize; + } else { + assert(ngroup == newngroup && "reduce ngroup is not consistent"); + assert(groupsize == newgroupsize && + "reduce groupsize is not consistent"); + } + isFirst = false; + } + }; + + // Step 0. Get Group Info + mod.walk([&](mlir::Operation *op) { + if (auto reduceOp = dyn_cast(op)) { + if (auto tensorType = + dyn_cast(reduceOp.getOperandTypes()[0])) { + if (tensorType.getShape().size() == 2) { + getGroupInfo(tensorType); + } else if (isAxisNone(reduceOp)) { + auto defOp = reduceOp.getSrcs()[0].getDefiningOp(); + if (auto reshapeOp = dyn_cast(defOp)) { + if (auto reshapeResTy = dyn_cast( + reshapeOp.getResult().getType())) { + if (reshapeResTy.getShape().size() == 1) { + auto reshapeSrcTy = + cast(reshapeOp.getOperand().getType()); + getGroupInfo(reshapeSrcTy); + } + } + } + } + } + } + }); + LLVM_DEBUG(llvm::dbgs() << "[Reduction SoftGroup]: " + << "GroupNum = " << ngroup + << ", GroupSize = " << groupsize << "\n"); + // Step 1. Modify All Op Encoding + mod.walk([&](mlir::Operation *op) { + auto opResults = op->getResults(); + for (auto opResult : opResults) { + if (auto resTy = dyn_cast(opResult.getType())) { + auto shape = resTy.getShape(); + auto elemTy = resTy.getElementType(); + auto encoding = resTy.getEncoding(); + Attribute newEncoding; // newEncoding + + if (auto globalEncoding = + dyn_cast(encoding)) { + newEncoding = getOptimizedGEncoding(context, resTy, innerChains, op, + ngroup, groupsize); + } else if (auto sliceEncoding = + dyn_cast(encoding)) { + // must be globalEncoding + if (auto parentEncoding = dyn_cast( + sliceEncoding.getParent())) { + auto newParentEncoding = getOptimizedGEncoding( + context, resTy, innerChains, op, ngroup, groupsize); + newEncoding = triton::gpu::SliceEncodingAttr::get( + context, sliceEncoding.getDim(), newParentEncoding); + } else { + llvm_unreachable("Unsupported SliceEncoding's Parent Attribute"); + } + } else { + llvm_unreachable("Unsupported Encoding Attribute"); + } + + auto newResTy = RankedTensorType::get(shape, elemTy, newEncoding); + opResult.setType(newResTy); + } + } + }); + + // Step 2. Special Modification For [constOp, expandDimsOp, reduceOp, forOp] + // Step 2.1. ConstOp: value's encoding is not modified before this walk + mod.walk([&](arith::ConstantOp constOp) { + auto newValue = constOp.getValue(); + if (auto attr = dyn_cast(constOp.getValue())) { + newValue = DenseElementsAttr::getFromRawBuffer( + mlir::cast(constOp.getType()), attr.getRawData()); + } + OpBuilder builder(constOp); + auto loc = constOp.getLoc(); + auto newConstOp = builder.create( + loc, constOp.getType(), newValue); + + constOp.replaceAllUsesWith(newConstOp.getResult()); + constOp.erase(); + }); + + // Step 2.2. ExpandDimsOp: it expands the data dimension, so its prev + // cvtOp's correct encoding should be inferd by its operand. cvtOp is + // actually generated after expandDimsOp, so we need to modify the encoding + // of the previous cvtOp after determining the shape of expandDimsOp. + mod.walk([&](triton::ExpandDimsOp expandOp) { + auto expandOpType = cast(expandOp.getType()); + auto globalEncoding = + cast(expandOpType.getEncoding()); + + if (auto cvtOp = + expandOp.getSrc().getDefiningOp()) { + auto cvtOpType = cast(cvtOp.getType()); + auto sliceEncoding = + cast(cvtOpType.getEncoding()); + + auto newSliceEncoding = triton::gpu::SliceEncodingAttr::get( + context, sliceEncoding.getDim(), globalEncoding); + auto newResTy = RankedTensorType::get( + cvtOpType.getShape(), cvtOpType.getElementType(), newSliceEncoding); + + cvtOp->getResult(0).setType(newResTy); + } else { + llvm_unreachable("ExpandDimsOp With Error Operand"); + } + }); + + // Step 2.3. ForOp: we need to modify forOp's argTy, args can't be walked. + mod.walk([&](scf::ForOp forOp) { + auto forBody = forOp.getBody(); + // modify forOp's argTy + auto forArgs = forBody->getArguments(); + for (auto forArg : forArgs) { + if (auto argTy = dyn_cast(forArg.getType())) { + auto shape = argTy.getShape(); + auto elemTy = argTy.getElementType(); + auto argEncoding = + cast(argTy.getEncoding()); + + auto newArgEncoding = getOptimizedGEncoding( + context, argTy, innerChains, forOp, ngroup, groupsize); + auto newArgTy = RankedTensorType::get(shape, elemTy, newArgEncoding); + + forArg.setType(newArgTy); + } + } + + // modify forOp's resTy + auto forResults = forOp->getResults(); + for (auto forRes : forResults) { + if (auto argTy = dyn_cast(forRes.getType())) { + auto shape = argTy.getShape(); + auto elemTy = argTy.getElementType(); + auto argEncoding = + cast(argTy.getEncoding()); + + auto newArgEncoding = getOptimizedGEncoding( + context, argTy, innerChains, forOp, ngroup, groupsize); + auto newArgTy = RankedTensorType::get(shape, elemTy, newArgEncoding); + + forRes.setType(newArgTy); + } + } + }); + + // Step 2.4. ReduceOp: it reduces the data dimension, so its correct + // encoding should be inferd by its input type. + mod.walk([&](triton::ReduceOp redOp) { + assert(redOp->getNumResults() == redOp->getNumOperands()); + for (int i = 0; i < redOp->getNumResults(); ++i) { + if (auto resTy = + dyn_cast(redOp.getResult()[i].getType())) { + // auto resTy = + // cast(redOp.getResult()[i].getType()); + auto srcTy = cast(redOp.getOperandTypes()[i]); + + auto resSliceEncoding = + cast(resTy.getEncoding()); + auto srcGlobalEncoding = + cast(srcTy.getEncoding()); + + auto newEncoding = triton::gpu::SliceEncodingAttr::get( + context, resSliceEncoding.getDim(), srcGlobalEncoding); + auto newResTy = RankedTensorType::get( + resTy.getShape(), resTy.getElementType(), newEncoding); + + redOp->getResult(i).setType(newResTy); + } + } + }); + + // Step 2.5. ReshapeOp: it changes the data dimension, so its correct + // encoding should be inferd by its input type. + mod.walk([&](triton::ReshapeOp reshapeOp) { + if (auto reshapeResTy = + dyn_cast(reshapeOp.getResult().getType())) { + auto reshapeResShape = reshapeResTy.getShape(); + if (reshapeResShape.size() == 1) { + unsigned ncore = ngroup * groupsize; + std::vector newSizePerCore = { + ceil(reshapeResShape[0], ncore)}; + std::vector newCoresPerGroup = {ncore}; + std::vector newGroupsPerCluster = {1}; + std::vector order = {0}; + unsigned isReduceOpt = 1; + Attribute newReshapeResEncoding = triton::xpu::ClusterLayoutAttr::get( + context, newSizePerCore, newCoresPerGroup, newGroupsPerCluster, + order, isReduceOpt); + auto newReshapeResTy = RankedTensorType::get( + reshapeResShape, reshapeResTy.getElementType(), + newReshapeResEncoding); + reshapeOp.getResult().setType(newReshapeResTy); + } + } + }); + } + + // Add ConvertLayout For Braoadcast + void addCvtForBCOp(ModuleOp &mod, MLIRContext *context) { + mod.walk([&](triton::xpu::BroadcastOp bcOp) { + auto resTy = cast(bcOp.getResult().getType()); + auto resEncoding = + cast(resTy.getEncoding()); + auto finEncoding = triton::xpu::ClusterLayoutAttr::get( + context, resEncoding.getSizePerCore(), resEncoding.getCoresPerGroup(), + resEncoding.getGroupsPerCluster(), {0, 1}, 1); + auto finTy = RankedTensorType::get( + resTy.getShape(), getElementTypeOrSelf(resTy), finEncoding); + + OpBuilder builder(bcOp); + auto newBCOp = builder.create( + bcOp->getLoc(), resTy, bcOp.getSrc()); + auto cvt = builder.create(bcOp->getLoc(), + finTy, newBCOp); + + bcOp.replaceAllUsesWith(cvt.getResult()); + bcOp->erase(); + }); + } + + void addTensorColSizeForMemoryOp(ModuleOp &mod, MLIRContext *context) { + mod.walk([&](triton::xpu::GM2LMOp gm2lmOp) { + auto resTy = cast(gm2lmOp.getResult().getType()); + auto resShape = resTy.getShape(); + + if (resShape.size() == 2 && resShape[0] > core_num) { + OpBuilder builder(gm2lmOp); + gm2lmOp->setAttr("tensorColSize", + builder.getSI32IntegerAttr( + std::min((unsigned)resShape[1], rawColSize))); + } + }); + + mod.walk([&](triton::xpu::LM2GMOp lm2gmOp) { + auto resTy = cast(lm2gmOp.getValue().getType()); + auto resShape = resTy.getShape(); + + if (resShape.size() == 2 && resShape[0] > core_num) { + OpBuilder builder(lm2gmOp); + lm2gmOp->setAttr("tensorColSize", + builder.getSI32IntegerAttr( + std::min((unsigned)resShape[1], rawColSize))); + } + }); + + mod.walk([&](triton::xpu::LoadOp loadOp) { + auto resTy = cast(loadOp.getResult().getType()); + auto resShape = resTy.getShape(); + + if (resShape.size() == 2 && resShape[0] > core_num) { + OpBuilder builder(loadOp); + loadOp->setAttr("tensorColSize", + builder.getSI32IntegerAttr( + std::min((unsigned)resShape[1], rawColSize))); + } + }); + + mod.walk([&](triton::xpu::StoreOp storeOp) { + auto resTy = cast(storeOp.getValue().getType()); + auto resShape = resTy.getShape(); + + if (resShape.size() == 2 && resShape[0] > core_num) { + OpBuilder builder(storeOp); + storeOp->setAttr("tensorColSize", + builder.getSI32IntegerAttr( + std::min((unsigned)resShape[1], rawColSize))); + } + }); + } + + void runOnOperation() override { + mlir::MLIRContext *context = &getContext(); + mlir::ModuleOp mod = getOperation(); + + // Step 1. Check If Can Be Optimized + if (!canBeOptimized(mod)) + return; + + // Step 2. Collect allOpTrees && innerChains && outerChains + llvm::SmallVector> opTrees; + llvm::SetVector visitedOps; + mod.walk([&](triton::xpu::LM2GMOp currStoreOp) { + auto currStoreRootOp = findRootOp(currStoreOp); + auto currStoreRootBlock = currStoreRootOp->getBlock(); + llvm::SetVector opTree; + getOpTreeBwd(opTree, visitedOps, currStoreRootOp, currStoreRootBlock); + opTrees.emplace_back(opTree); + }); + + llvm::SmallVector> innerChains; + llvm::SmallVector> outerChains; + getChains(opTrees, innerChains, outerChains); + + // Step 3. Recover MakeRange If It's A Common Op + recoverMakeRange(innerChains, outerChains); + + // Step 4. Modify All Op Encoding With Optimization Rules + modifyOpEncoding(mod, context, innerChains); + + // Step 5. Add ConvertLayout For Braoadcast + // This step can be eliminated if we set sizePerBank with its shape + if (rowsPerCore > 1) { + addCvtForBCOp(mod, context); + } + + // Step 6. Add tensorColSize Attr for GM2LMOp + // This step can be eliminated if we set sizePerBank with its shape + if (rowsPerCore > 1) { + addTensorColSizeForMemoryOp(mod, context); + } + + if (dumpFlag) { + LLVM_DEBUG({ + llvm::dbgs() << "\n[InnerChain]:\n"; + for (auto innerChain : innerChains) { + for (auto op : innerChain) { + op->dump(); + } + llvm::dbgs() << "\n"; + } + + llvm::dbgs() << "\n[OuterChain]:\n"; + for (auto outerChain : outerChains) { + for (auto op : outerChain) { + op->dump(); + } + llvm::dbgs() << "\n"; + } + }); + } + if (rowsPerCore == 1) + LLVM_DEBUG(llvm::dbgs() << "Core Tiling M-ColSize Opt Hit!\n"); + else if (rowsPerCore > 1) + LLVM_DEBUG(llvm::dbgs() << "Core Tiling S-ColSize Opt Hit!\n"); + } + +private: + unsigned rowsPerCore = 1; + unsigned rawColSize = 0; + unsigned core_num = 64; +}; + +} // namespace xpu +} // namespace triton + +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/CreateGM2LM.cpp b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/CreateGM2LM.cpp new file mode 100644 index 000000000..27a24998b --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/CreateGM2LM.cpp @@ -0,0 +1,367 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// TODO[dyq]: Pass Description +//===----------------------------------------------------------------------===// + +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" + +namespace mlir { +namespace triton { +namespace xpu { + +#define GEN_PASS_DEF_TRITONXPUCREATEGM2LM +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +bool replaceAtomicOp(mlir::ModuleOp m) { + bool getAtomicRMWOp = false; + + m.walk([&](triton::AtomicRMWOp atomicRMWOp) { + getAtomicRMWOp = true; + OpBuilder builder(atomicRMWOp); + auto loc = atomicRMWOp.getLoc(); + Value ptr = atomicRMWOp.getPtr(); + Value val = atomicRMWOp.getVal(); + Value mask = atomicRMWOp.getMask(); + Value emptyBufPtr; + RMWOp atomic_rmw_op = atomicRMWOp.getAtomicRmwOp(); + auto dtype = val.getType(); + auto loadOp = builder.create( + loc, dtype, ptr, mask, Value(), Value(), 1, -1, false, false, false); + + Operation *arithOp; + switch (atomic_rmw_op) { + case RMWOp::AND: { + arithOp = builder.create(loc, loadOp.getResult(), val); + break; + } + case RMWOp::OR: { + arithOp = builder.create(loc, loadOp.getResult(), val); + break; + } + case RMWOp::XOR: { + arithOp = builder.create(loc, loadOp.getResult(), val); + break; + } + case RMWOp::ADD: { + arithOp = builder.create(loc, loadOp.getResult(), val); + break; + } + case RMWOp::FADD: { + arithOp = builder.create(loc, loadOp.getResult(), val); + break; + } + case RMWOp::MAX: { + arithOp = builder.create(loc, loadOp.getResult(), val); + break; + } + case RMWOp::MIN: { + arithOp = builder.create(loc, loadOp.getResult(), val); + break; + } + case RMWOp::UMAX: { + arithOp = builder.create(loc, loadOp.getResult(), val); + break; + } + case RMWOp::UMIN: { + arithOp = builder.create(loc, loadOp.getResult(), val); + break; + } + case RMWOp::XCHG: { + assert(0 && "The RMWOp::XCHG is not supported in RMWOp"); + break; + } + default: { + assert(0 && "The atomic_rmw_op only could be 1-10 in RMWOp"); + } + } + + auto storeOp = builder.create( + loc, ptr, arithOp->getResults()[0], mask, Value(), -1, false); + atomicRMWOp.erase(); + }); + + return getAtomicRMWOp; +} + +Attribute getOneCoreGEncoding(Operation *op, ArrayRef shape) { + Attribute newEncoding; + unsigned rank = shape.size(); + llvm::SmallVector sizePerBank; + llvm::SmallVector coresPerGroup; + llvm::SmallVector groupsPerCluster; + llvm::SmallVector order; + bool isReduceOpt = false; + + if (rank == 1) { + sizePerBank = {static_cast(shape[0])}; + coresPerGroup = {1}; + groupsPerCluster = {1}; + order = {0}; + } else if (rank == 2) { + sizePerBank = {1, static_cast(shape[1])}; + coresPerGroup = {1, 1}; + groupsPerCluster = {1, 1}; + order = {0, 1}; + } else { + llvm_unreachable("AtomicOp Simulation With Rank > 2 Unsupported"); + } + + newEncoding = triton::xpu::ClusterLayoutAttr::get( + op->getContext(), sizePerBank, coresPerGroup, groupsPerCluster, order, + isReduceOpt); + return newEncoding; +} + +bool atomicSimulation(mlir::ModuleOp m) { + + // Step 1. Replace AtomicRMWOp with GM2LMOp + Arith.xxx + LM2GMOp + if (replaceAtomicOp(m)) { + // Step 2. Modify All Op Encoding + m.walk([&](mlir::Operation *op) { + auto opResult = op->getResults(); + if (opResult.size() == 1) { // SSA Assert + // Only TensorType Has Encoding + if (auto resTy = + mlir::dyn_cast(opResult[0].getType())) { + auto shape = resTy.getShape(); + auto elemTy = resTy.getElementType(); + auto encoding = resTy.getEncoding(); + Attribute newEncoding; // newEncoding + + auto globalEncoding = + mlir::dyn_cast(encoding); + auto sliceEncoding = + mlir::dyn_cast(encoding); + + if (globalEncoding) { + newEncoding = getOneCoreGEncoding(op, shape); + } else if (sliceEncoding) { + // must be globalEncoding + auto parentEncoding = + mlir::dyn_cast( + sliceEncoding.getParent()); + + if (parentEncoding) { + auto newParentEncoding = getOneCoreGEncoding(op, shape); + newEncoding = triton::gpu::SliceEncodingAttr::get( + op->getContext(), sliceEncoding.getDim(), newParentEncoding); + } else { + llvm_unreachable("Unsupported SliceEncoding's Parent Attribute"); + } + } else { + llvm_unreachable("Unsupported Encoding Attribute"); + } + + auto newResTy = RankedTensorType::get(shape, elemTy, newEncoding); + opResult[0].setType(newResTy); + } + } + }); + + // Step 3. Special Modification For [constOp] + // Step 3.1. ConstOp: value's encoding is not modified before this walk + m.walk([&](arith::ConstantOp constOp) { + auto newValue = constOp.getValue(); + if (auto attr = + mlir::dyn_cast(constOp.getValue())) { + newValue = DenseElementsAttr::getFromRawBuffer( + mlir::cast(constOp.getType()), attr.getRawData()); + } + OpBuilder builder(constOp); + auto loc = constOp.getLoc(); + auto newConstOp = builder.create( + loc, constOp.getType(), newValue); + constOp.replaceAllUsesWith(newConstOp.getResult()); + constOp.erase(); + }); + + // Step 3.2. ExpandDimsOp: it expands the data dimension, so its prev + // cvtOp's correct encoding should be inferd by its operand. cvtOp is + // actually generated after expandDimsOp, so we need to modify the + // encoding of the previous cvtOp after determining the shape of + // expandDimsOp. + m.walk([&](triton::ExpandDimsOp expandOp) { + auto expandOpType = mlir::cast(expandOp.getType()); + auto globalEncoding = mlir::cast( + expandOpType.getEncoding()); + + if (auto cvtOp = + expandOp.getSrc().getDefiningOp()) { + auto cvtOpType = mlir::cast(cvtOp.getType()); + auto sliceEncoding = + mlir::cast(cvtOpType.getEncoding()); + + auto newSliceEncoding = triton::gpu::SliceEncodingAttr::get( + expandOp->getContext(), sliceEncoding.getDim(), globalEncoding); + auto newResTy = RankedTensorType::get( + cvtOpType.getShape(), cvtOpType.getElementType(), newSliceEncoding); + + cvtOp->getResult(0).setType(newResTy); + } else { + llvm_unreachable("ExpandDimsOp With Error Operand"); + } + }); + + // Step 3.3. ForOp: we need to modify forOp's argTy, args can't be + m.walk([&](scf::ForOp forOp) { + auto forBody = forOp.getBody(); + // modify forOp's argTy + auto forArgs = forBody->getArguments(); + for (auto forArg : forArgs) { + if (auto argTy = mlir::dyn_cast(forArg.getType())) { + auto shape = argTy.getShape(); + auto elemTy = argTy.getElementType(); + auto argEncoding = + mlir::cast(argTy.getEncoding()); + + auto newArgEncoding = getOneCoreGEncoding(forOp, shape); + auto newArgTy = RankedTensorType::get(shape, elemTy, newArgEncoding); + + forArg.setType(newArgTy); + } + } + + // modify forOp's resTy + auto forResults = forOp->getResults(); + for (auto forRes : forResults) { + if (auto argTy = mlir::dyn_cast(forRes.getType())) { + auto shape = argTy.getShape(); + auto elemTy = argTy.getElementType(); + auto argEncoding = + mlir::cast(argTy.getEncoding()); + + auto newArgEncoding = getOneCoreGEncoding(forOp, shape); + auto newArgTy = RankedTensorType::get(shape, elemTy, newArgEncoding); + + forRes.setType(newArgTy); + } + } + }); + + // Step 3.4. ReduceOp: it reduces the data dimension, so its correct + // encoding should be inferd by its input type. + m.walk([&](triton::ReduceOp redOp) { + llvm_unreachable("TODO[dyq]: new reduceOp has multi operands and " + "results, we need to modify all Tys"); + auto resTy = mlir::cast(redOp.getType(0)); + auto srcTy = mlir::cast(redOp.getOperand(0).getType()); + + auto resSliceEncoding = + mlir::cast(resTy.getEncoding()); + auto srcGlobalEncoding = + mlir::cast(srcTy.getEncoding()); + + auto newEncoding = triton::gpu::SliceEncodingAttr::get( + redOp.getContext(), resSliceEncoding.getDim(), srcGlobalEncoding); + auto newResTy = RankedTensorType::get( + resTy.getShape(), resTy.getElementType(), newEncoding); + + redOp->getResult(0).setType(newResTy); + }); + + return true; + } + + return false; +} + +struct TritonXPUCreateGM2LMPass + : public impl::TritonXPUCreateGM2LMBase { + + using impl::TritonXPUCreateGM2LMBase< + TritonXPUCreateGM2LMPass>::TritonXPUCreateGM2LMBase; + + void runOnOperation() override { + mlir::ModuleOp m = getOperation(); + bool hasAtomicSim = false; + + // Replace AtomicRMWOp with GM2LMOp + Arith.xxx + LM2GMOp(Embedding + // Backward) + if (atomicSim) + hasAtomicSim = atomicSimulation(m); + + llvm::SmallSetVector opToErase; + Value emptyBufPtr; + + // FIXME: Sometimes (test_core.py::test_bin_op_constexpr) + // `triton::LoadOp` and `triton::StoreOp` can not be replaced + // with `triton::xpu::LoadOp` and `triton::xpu::StoreOp` in + // TritonToTritonXPUPass, So we workaround to replace it here. + m.walk([&](triton::LoadOp loadOp) { + auto loc = loadOp.getLoc(); + OpBuilder builder(loadOp); + auto newLoadOp = builder.create( + loc, loadOp.getType(), loadOp.getPtr(), loadOp.getMask(), + loadOp.getOther(), Value(), 1, -1, false, false, false); + loadOp.replaceAllUsesWith(newLoadOp.getResult()); + opToErase.insert(loadOp); + }); + + m.walk([&](triton::StoreOp storeOp) { + auto loc = storeOp.getLoc(); + OpBuilder builder(storeOp); + auto newLoadOp = builder.create( + loc, storeOp.getPtr(), storeOp.getValue(), storeOp.getMask(), Value(), + -1, false); + opToErase.insert(storeOp); + }); + + m.walk([&](triton::xpu::LoadOp loadOp) { + OpBuilder builder(loadOp); + auto loc = loadOp.getLoc(); + auto lmPtrType = addrspaceCast(loadOp.getPtr().getType(), 0); + if (xpuArch == 2) { + if (loadOp.getResult().hasOneUse()) { + if (auto extFOp = dyn_cast(*(loadOp->user_begin()))) { + auto gm2lmOp = builder.create( + loc, lmPtrType, loadOp.getPtr(), loadOp.getMask(), emptyBufPtr, + static_cast(OffsetState::Unknown), -1, -1, -1, -1, -1, + false, false, hasAtomicSim); + loadOp.setOperand(0, gm2lmOp.getResult()); + loadOp.getResult().setType(extFOp.getType()); + extFOp.getResult().replaceAllUsesWith(loadOp.getResult()); + opToErase.insert(extFOp); + return; + } + } + } + auto gm2lmOp = builder.create( + loc, lmPtrType, loadOp.getPtr(), loadOp.getMask(), emptyBufPtr, + static_cast(OffsetState::Unknown), -1, -1, -1, -1, -1, false, + false, hasAtomicSim); + loadOp.setOperand(0, gm2lmOp.getResult()); + }); + + m.walk([&](triton::xpu::StoreOp storeOp) { + OpBuilder builder(storeOp); + auto loc = storeOp.getLoc(); + auto storeVal = storeOp.getValue(); + if (xpuArch == 2) { + if (storeVal.getDefiningOp()) { + if (auto truncFOp = + dyn_cast(storeVal.getDefiningOp())) { + storeVal = truncFOp.getIn(); + } + } + } + storeOp->setOperand(1, storeVal); + auto lm2gmOp = builder.create( + loc, storeOp.getPtr(), storeVal, storeOp.getMask(), emptyBufPtr, + static_cast(OffsetState::Unknown), -1, -1, -1, hasAtomicSim); + lm2gmOp->moveAfter(storeOp); + }); + + for (auto op : opToErase) { + op->erase(); + } + } +}; + +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/DtypeConvert.cpp b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/DtypeConvert.cpp new file mode 100644 index 000000000..ce61b0850 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/DtypeConvert.cpp @@ -0,0 +1,116 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// TODO[dyq]: Pass Description +//===----------------------------------------------------------------------===// + +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" + +namespace mlir { +namespace triton { +namespace xpu { +#define GEN_PASS_DEF_TRITONXPUDTYPECONVERT +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +struct TritonXPUDtypeConvert + : public impl::TritonXPUDtypeConvertBase { + + using impl::TritonXPUDtypeConvertBase< + TritonXPUDtypeConvert>::TritonXPUDtypeConvertBase; + + void runOnOperation() override { + mlir::ModuleOp m = getOperation(); + llvm::SetVector visitedOps; + llvm::SmallVector> allOpTrees; + + m.walk([&](triton::xpu::StoreOp currStoreOp) { + if (!visitedOps.contains(currStoreOp)) { + // Get the opTree on the storeOp val path + llvm::SetVector allOpTree; + getOpTreeBwd(allOpTree, visitedOps, + currStoreOp.getValue().getDefiningOp()); + allOpTrees.emplace_back(allOpTree); + } + }); + // fp16tofp32/bf16tofp32 + for (auto allOpTree : allOpTrees) { + for (auto op : allOpTree) { + auto builder = mlir::OpBuilder(op); + auto loc = op->getLoc(); + if (auto constOp = dyn_cast(op)) { + auto constResTy = getElementTypeOrSelf(constOp.getType()); + if ((xpuArch == 2 && mlir::isa(constResTy)) || + (xpuArch == 3 && mlir::isa(constResTy))) { + SmallVector constUsers; + for (auto user : constOp.getResult().getUsers()) { + constUsers.emplace_back(user); + } + Operation *extfOp; + if (mlir::isa(constOp.getType())) { + auto tensorType = RankedTensorType::get( + mlir::cast(constOp.getType()).getShape(), + builder.getF32Type(), + mlir::cast(constOp.getType()) + .getEncoding()); + extfOp = builder.create(loc, tensorType, constOp); + } else { + extfOp = builder.create(loc, builder.getF32Type(), + constOp); + } + extfOp->moveAfter(constOp); + for (auto op : constUsers) { + for (int i = 0; i < op->getOperands().size(); ++i) { + if (op->getOperands()[i] == constOp.getResult()) { + op->setOperand(i, extfOp->getResult(0)); + } + } + } + } + } else { + for (auto res : op->getResults()) { + auto resElemTy = getElementTypeOrSelf(res.getType()); + if ((xpuArch == 2 && mlir::isa(resElemTy)) || + (xpuArch == 3 && mlir::isa(resElemTy))) { + if (mlir::isa(res.getType())) { + auto tensorType = RankedTensorType::get( + mlir::cast(res.getType()).getShape(), + builder.getF32Type(), + mlir::cast(res.getType()).getEncoding()); + res.setType(tensorType); + } else { + res.setType(builder.getF32Type()); + } + } + } + } + } + } + + m.walk([&](arith::ExtFOp extfOp) { + auto inTy = extfOp.getIn().getType(); + auto resTy = extfOp.getType(); + if (getElementTypeOrSelf(inTy) == getElementTypeOrSelf(resTy)) { + extfOp.getOut().replaceAllUsesWith(extfOp.getIn()); + extfOp.erase(); + } + }); + + m.walk([&](arith::TruncFOp truncfOp) { + auto inTy = truncfOp.getIn().getType(); + auto resTy = truncfOp.getType(); + if (getElementTypeOrSelf(inTy) == getElementTypeOrSelf(resTy)) { + truncfOp.getOut().replaceAllUsesWith(truncfOp.getIn()); + truncfOp.erase(); + } + }); + } +}; + +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Interleave.cpp b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Interleave.cpp new file mode 100644 index 000000000..55d1ed658 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Interleave.cpp @@ -0,0 +1,196 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" + +#define DEBUG_TYPE "tritonxpu-interleave" + +namespace mlir { +namespace triton { +namespace xpu { + +#define GEN_PASS_DEF_TRITONXPUINTERLEAVE +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +struct TritonXPUInterleave + : public impl::TritonXPUInterleaveBase { + +public: + using impl::TritonXPUInterleaveBase< + TritonXPUInterleave>::TritonXPUInterleaveBase; + + bool isSameSize(Value mulValue, triton::xpu::MakeRangeOp makeRangeOp) { + auto mulValDefOp = findDefOpBwd(mulValue); + if (mulValDefOp) { + auto constOp = cast(mulValDefOp); + auto type = constOp.getResult().getType(); + int64_t constValue = 0; + if (auto tensorType = dyn_cast(type)) { + auto denseAttr = dyn_cast(constOp.getValue()); + auto elementType = tensorType.getElementType(); + if (elementType.isInteger(32)) { + constValue = *denseAttr.getValues().begin(); + } else if (elementType.isInteger(64)) { + constValue = *denseAttr.getValues().begin(); + } else { + llvm_unreachable( + "[Offset Analysis] Unsupported Element Type in ConstOp"); + } + } else { + constValue = + cast(constOp.getValue()).getValue().getZExtValue(); + } + int64_t rangeSize = makeRangeOp.getRealSize(); + if (constValue != rangeSize) { + return false; + } + } else { + return false; + } + return true; + } + + Operation *findInterleavePatternOp(Operation *lhs, Operation *rhs) { + llvm::SetVector visitedDownwards, visitedUpwards; + + int mulCnt = 0; + std::function findDownwards = + [&](Operation *op) -> Operation * { + if (mulCnt > 1 || (isa(op) || isa(op))) { + return nullptr; + } + if (auto muliOp = dyn_cast(op)) { + mulCnt += 1; + auto makeRangeOp = cast(rhs); + if (!isSameSize(muliOp.getLhs(), makeRangeOp) && + !isSameSize(muliOp.getRhs(), makeRangeOp)) { + return nullptr; + } + } + if (!visitedDownwards.insert(op)) { + return nullptr; + } + if (isa(op)) { + return op; + } + for (auto user : op->getUsers()) { + if (Operation *foundOp = findDownwards(user)) { + return foundOp; + } + } + return nullptr; + }; + + std::function findUpwards = [&](Operation *op) -> bool { + if (isa(op) || isa(op) || + isa(op) || isa(op)) { + return false; + } + if (op == rhs) { + return true; + } + if (!visitedUpwards.insert(op)) { + return false; + } + for (auto operand : op->getOperands()) { + if (auto *defOp = operand.getDefiningOp()) { + if (findUpwards(defOp)) { + return true; + } + } + } + return false; + }; + + Operation *targetOp = findDownwards(lhs); + Operation *upStartOp = nullptr; + if (targetOp) { + for (auto operand : targetOp->getOperands()) { + if (operand.getDefiningOp() && + !visitedDownwards.count(operand.getDefiningOp())) { + upStartOp = operand.getDefiningOp(); + } + } + } + if (upStartOp && findUpwards(upStartOp)) { + return targetOp; + } + + return nullptr; + } + + void runOnOperation() override { + mlir::ModuleOp m = getOperation(); + llvm::DenseMap addiRangeMap; + + // 1. Get the map of AddIOp and MakeRangeOp for create InterleaveOp + m.walk([&](triton::xpu::LoadOp loadOp) { + auto res = loadOp.getResult(); + if (auto tensorTy = dyn_cast(res.getType())) { + if (tensorTy.getShape().size() == 1) { + auto getProgramIdOp = + findDefOpBwd(loadOp.getPtr()); + auto makeRangeOp = + findDefOpBwd(loadOp.getPtr()); + auto getNumProgramsOp = + findDefOpBwd(loadOp.getPtr()); + if (getProgramIdOp && makeRangeOp && !getNumProgramsOp) { + if (auto addIOp = + findInterleavePatternOp(getProgramIdOp, makeRangeOp)) { + addiRangeMap[addIOp] = makeRangeOp; + } + } + } + } + }); + + m.walk([&](triton::xpu::StoreOp storeOp) { + auto val = storeOp.getValue(); + if (auto tensorTy = dyn_cast(val.getType())) { + if (tensorTy.getShape().size() == 1) { + auto getProgramIdOp = + findDefOpBwd(storeOp.getPtr()); + auto makeRangeOp = + findDefOpBwd(storeOp.getPtr()); + auto valueMakeRangeOp = + findDefOpBwd(storeOp.getValue()); + auto getNumProgramsOp = + findDefOpBwd(storeOp.getPtr()); + if (getProgramIdOp && makeRangeOp && !valueMakeRangeOp && + !getNumProgramsOp) { + if (auto addIOp = + findInterleavePatternOp(getProgramIdOp, makeRangeOp)) + addiRangeMap[addIOp] = makeRangeOp; + } + } + } + }); + + // 2. Remove GetProgramIdOp * BLOCK_SIZE and replace MakeRangeOp with + // InterleaveOp for PointWise + for (const auto &pair : addiRangeMap) { + auto addIOp = cast(pair.first); + auto makeRangeOp = cast(pair.second); + OpBuilder builder(makeRangeOp); + auto loc = builder.getUnknownLoc(); + auto start = makeRangeOp.getStart(); + auto end = makeRangeOp.getEnd(); + auto idx = makeRangeOp.getLoopIndex(); + auto reduceOp = findUserOp(makeRangeOp); + // Interleaving is not suitable for ReduceOp + if (!reduceOp && idx) { + LLVM_DEBUG(llvm::dbgs() << "[Interleave] Hit Interleave\n"); + auto interleaveOp = builder.create( + loc, addIOp.getType(), start, end, idx, Value()); + addIOp.getResult().replaceAllUsesWith(interleaveOp); + addIOp.erase(); + } + } + } +}; +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Legalize.cpp b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Legalize.cpp new file mode 100644 index 000000000..a11b6265d --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Legalize.cpp @@ -0,0 +1,999 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// TODO[dyq]: Pass Description +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinOps.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" + +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/IRMapping.h" + +#define DEBUG_TYPE "tritonxpu-legalize" + +namespace mlir { +namespace triton { +namespace xpu { + +#define GEN_PASS_DEF_TRITONXPULEGALIZE +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +struct TritonXPULegalizePass + : public impl::TritonXPULegalizeBase { + + using impl::TritonXPULegalizeBase< + TritonXPULegalizePass>::TritonXPULegalizeBase; + + TritonXPULegalizePass() = default; + TritonXPULegalizePass(unsigned bufferSize, unsigned coreNum) { + this->bufferSize = bufferSize; + this->coreNum = coreNum; + } + + mlir::Operation *findRootOp(mlir::Operation *op) { + mlir::Operation *rootOp = op; + while (rootOp->getParentOp()) { + rootOp = rootOp->getParentOp(); + if (rootOp->getParentOp() && isa(rootOp->getParentOp())) { + return rootOp; + } + } + return op; + } + + void getGroupInfo(llvm::SetVector(opTree), size_t &ngroup, + size_t &groupsize, size_t &rowspercore) { + bool isFirst = true; + + auto _getGroupInfo = [&](triton::xpu::ClusterLayoutAttr &encoding) { + int _ngroup = product(encoding.getGroupsPerCluster()); + int _groupsize = product(encoding.getCoresPerGroup()); + int _rowspercore = encoding.getSizePerCore()[0]; + if (isFirst) { + ngroup = _ngroup; + groupsize = _groupsize; + rowspercore = _rowspercore; + } else { + assert(ngroup == _ngroup && "reduction ngroup is not consistent"); + assert(groupsize == _groupsize && + "reduction groupsize is not consistent"); + assert(rowspercore == _rowspercore && + "reduction rowspercore is not consistent"); + } + isFirst = false; + }; + for (auto op : opTree) { + op->walk([&](triton::ReduceOp reduceOp) { + auto defOp = reduceOp.getSrcs()[0].getDefiningOp(); + if (auto reshapeOp = dyn_cast(defOp)) { + if (auto reshapeResTy = + dyn_cast(reshapeOp.getResult().getType())) { + if (reshapeResTy.getShape().size() == 1) { + auto reshapeSrcTy = + cast(reshapeOp.getOperand().getType()); + if (auto globalEncoding = + mlir::dyn_cast( + reshapeSrcTy.getEncoding())) { + _getGroupInfo(globalEncoding); + } + } + } + } else { + if (auto tensorType = mlir::dyn_cast( + reduceOp.getOperandTypes()[0])) { + if (auto globalEncoding = + mlir::dyn_cast( + tensorType.getEncoding())) { + _getGroupInfo(globalEncoding); + } + } + } + }); + } + return; + } + + size_t previousPowerOf2(size_t n) { + size_t exp = std::log2(n); + return std::pow(2, exp); + } + + size_t getSizePerCluster(Type &type, bool unrollOpt) { + if (auto tensorType = mlir::dyn_cast(type)) { + if (auto global = mlir::dyn_cast( + tensorType.getEncoding())) { + size_t sizePerCluster = 1; + auto tensorShape = tensorType.getShape(); + auto groupsPerCluster = global.getCoresPerGroup(); + auto coresPerGroup = global.getGroupsPerCluster(); + auto sizePerCore = global.getSizePerCore(); + + auto rank = tensorShape.size(); + assert(rank == groupsPerCluster.size()); + for (auto i = 0; i < rank; ++i) { + sizePerCluster *= + unrollOpt ? groupsPerCluster[i] * coresPerGroup[i] + : std::min(sizePerCore[i], (unsigned)tensorShape[i]) * + groupsPerCluster[i] * coresPerGroup[i]; + } + return sizePerCluster; + } else { + llvm_unreachable("Only Support ClusterEncodingAttr"); + } + } + return 1; + } + + llvm::SmallVector getSlicedShape(const std::vector &shape, + size_t spaceSize, + bool isReduceMultiGroup) { + size_t dimSize = 1u; + size_t rank = shape.size(); + llvm::SmallVector slicedShape(rank, 1u); + // think about col mask, slice multi-dim tensor to 1-dim tensor + assert(rank <= 2 && "only 1-dim or 2-dim tensor is supported"); + + if (!isReduceMultiGroup) // no opt + spaceSize = std::min(static_cast(spaceSize), shape[rank - 1]); + + for (int i = rank - 1; i >= 0; --i) { + dimSize *= shape[i]; + const double sliceNum = static_cast(dimSize) / spaceSize; + if (sliceNum > 1) { + slicedShape[i] = std::ceil(shape[i] / sliceNum); + break; + } + slicedShape[i] = shape[i]; + } + return slicedShape; + } + + // get greatest common divisor + size_t gcd(size_t a, size_t b) { + if (b == 0) + return a; + else + return gcd(b, a % b); + } + + size_t getLCM(llvm::SmallVector &datas) { + for (int i = 1; i < datas.size(); ++i) { + datas[i] = datas[i - 1] / gcd(datas[i - 1], datas[i]) * datas[i]; + } + return datas.back(); + } + + llvm::SmallVector getIterationCount(llvm::SmallVector &types, + bool isReduceMultiGroup, + bool unrollOpt) { + // bytesPerCluster = bytesPerCore * coresPerGroup * groupsPerCluster + // 8KB local memory per core, reserve 2KB for parameters + const size_t bytesPerCluster = 4 * 16 * (6 << 10); + size_t typesBytes = 0; + size_t tensorDim = 1; + for (size_t i = 0; i < types.size(); ++i) { + Type type = types[i]; + Type valueElemType = getElementTypeOrSelf(type); + size_t valueElemBytes = + std::max(valueElemType.getIntOrFloatBitWidth(), 8) / 8u; + typesBytes += valueElemBytes; + if (auto tensorType = mlir::dyn_cast(type)) { + tensorDim = std::max(tensorDim, tensorType.getShape().size()); + } + } + llvm::SmallVector> iterCounts( + tensorDim, llvm::SmallVector(types.size(), 1u)); + for (size_t i = 0; i < types.size(); ++i) { + if (auto tensorType = mlir::dyn_cast(types[i])) { + size_t spaceSize = + std::min(previousPowerOf2(bytesPerCluster / typesBytes), + getSizePerCluster(types[i], unrollOpt)); + auto tensorShape = tensorType.getShape(); + + if (tensorShape.size() > 2) { + llvm_unreachable("3D Shape Unsupported."); + } else if (tensorShape.size() == 2 && + tensorShape[1] > coreNum * bufferSize) { + LLVM_DEBUG(llvm::dbgs() << "2D Shape[-1] = " << tensorShape[1] + << "); coreNum * bufferSize = " + << coreNum * bufferSize << "\n"); + llvm_unreachable("2D Shape[-1] <= core_num * buffer_size Limit. " + "Please Adjust COL_BLOCK_SIZE"); + } + + llvm::SmallVector slicedShape = + getSlicedShape(tensorShape, spaceSize, isReduceMultiGroup); + for (size_t j = 0; j < tensorShape.size(); ++j) { + iterCounts[j][i] = + std::ceil(static_cast(tensorShape[j]) / slicedShape[j]); + } + } + } + + llvm::SmallVector lcmIterCount(tensorDim, 1u); + for (size_t j = 0; j < iterCounts.size(); ++j) { + lcmIterCount[j] = getLCM(iterCounts[j]); + } + return lcmIterCount; + } + + Type getSlicedType(const Type &type, llvm::SmallVector iterCount, + bool isInner, bool needNewEncoding = false) { + // Slice tensor according iteration count + if (auto tensorType = mlir::dyn_cast(type)) { + auto tensorShape = tensorType.getShape(); + llvm::SmallVector slicedShape(tensorShape.size(), 1u); + if (tensorShape.size() == iterCount.size()) { + for (int i = 0; i < tensorShape.size(); ++i) { + slicedShape[i] = std::max(tensorShape[i] / iterCount[i], size_t(1)); + } + } else if (tensorShape.size() < iterCount.size()) { + assert(tensorShape.size() == 1 && iterCount.size() == 2); + size_t count = isInner ? iterCount[1] : iterCount[0]; + slicedShape[0] = std::max(tensorShape[0] / count, size_t(1)); + } else { + llvm_unreachable( + "tensorShape.size() is not more than iterCount.size()"); + } + ArrayRef sliceTensorShape(slicedShape); + Attribute encoding; + if (needNewEncoding) { + int rank = sliceTensorShape.size(); + llvm::SmallVector order(rank); + std::iota(order.begin(), order.end(), 0); + encoding = triton::xpu::ClusterLayoutAttr::get( + &getContext(), sliceTensorShape, order, 128, 64); + } else { + encoding = tensorType.getEncoding(); + } + return RankedTensorType::get(sliceTensorShape, + tensorType.getElementType(), encoding); + } + return type; + } + + void + getChains(const llvm::SmallVector> &allOpTrees, + llvm::SmallVector> &innerChains, + llvm::SmallVector> &outerChains) { + for (auto allOpTree : allOpTrees) { + llvm::SetVector innerChain; + llvm::SetVector outerChain; + for (auto op : allOpTree) { + if (auto rangeOp = dyn_cast(op)) { + for (auto user : rangeOp.getResult().getUsers()) { + if (auto userOp = findUserOp(user)) { + auto expandDimOp = cast(userOp); + if (expandDimOp.getAxis() == 1) { + outerChain.insert(rangeOp); + } + } + } + } + if (auto expandDimOp = dyn_cast(op)) { + auto src = expandDimOp.getSrc(); + auto result = expandDimOp.getResult(); + if (auto srcTy = mlir::dyn_cast(src.getType())) { + if (auto resTy = + mlir::dyn_cast(result.getType())) { + if (expandDimOp.getAxis() == 0) { + getOpChainBwd(innerChain, expandDimOp); + innerChain.remove(expandDimOp); + } + } + } + } + if (auto broadcastOp = dyn_cast(op)) { + auto src = broadcastOp.getSrc(); + auto result = broadcastOp.getResult(); + if (auto srcTy = mlir::dyn_cast(src.getType())) { + if (auto resTy = + mlir::dyn_cast(result.getType())) { + auto srcShape = srcTy.getShape(); + auto resShape = resTy.getShape(); + if (srcShape[0] != resShape[0]) { // unequal dim 0 shape means + // in the inner axis op chain + getOpChainBwd(innerChain, broadcastOp); + innerChain.remove(broadcastOp); + } + } + } + } + if (auto reduceOp = dyn_cast(op)) { + if (reduceOp.getAxis() == 0) { + getOpChainFwd(innerChain, reduceOp); + } + } + } + outerChains.emplace_back(outerChain); + innerChains.emplace_back(innerChain); + } + } + + void runOnOperation() override { + mlir::MLIRContext *context = &getContext(); + mlir::ModuleOp m = getOperation(); + llvm::SmallVector> sortedOpTrees; + llvm::SetVector visitedOps; + llvm::SmallVector> allOpTrees; + llvm::SetVector visitedAllOps; + llvm::SmallVector> iterCounts; + llvm::SmallVector> allTensorTypes; + llvm::SetVector storeOps; + SmallVector reduceNGroups; + SmallVector reduceGroupSizes; + SmallVector reduceRowsPerCores; + + // Find SM2GM ptr op chain + llvm::SetVector sm2gmPtrLenOpChain; + m.walk([&](triton::xpu::SM2GMOp sm2gmOp) { + sm2gmPtrLenOpChain.insert(sm2gmOp); + getOpChainBwd(sm2gmPtrLenOpChain, sm2gmOp.getPtr().getDefiningOp()); + if (sm2gmOp.getLen()) { + getOpChainBwd(sm2gmPtrLenOpChain, sm2gmOp.getLen().getDefiningOp()); + } + }); + + llvm::SetVector endOps; + m.walk([&](triton::xpu::LM2GMOp lm2gmOp) { endOps.insert(lm2gmOp); }); + m.walk([&](triton::xpu::SM2GMOp sm2gmOp) { endOps.insert(sm2gmOp); }); + + for (auto currStoreOp : endOps) { + if (!visitedOps.contains(currStoreOp) && + !inSameSCFIfBlock(storeOps, currStoreOp)) { + storeOps.insert(currStoreOp); + + // Get the opTree on the storeOp path + auto currStoreRootOp = findRootOp(currStoreOp); + auto currStoreRootBlock = currStoreRootOp->getBlock(); + llvm::SetVector opTree; + getOpTreeBwd(opTree, visitedOps, currStoreRootOp, currStoreRootBlock); + llvm::SetVector sortedOpTree = sortOpTreeBwd(opTree); + sortedOpTrees.emplace_back(sortedOpTree); + + llvm::SetVector allOpTree; + getOpTreeBwd(allOpTree, visitedAllOps, currStoreOp); + allOpTrees.emplace_back(allOpTree); + + // Get all tensors types of loadOp or storeOp + llvm::SmallVector tensorTypes; + for (auto op : allOpTree) { + if (auto loadOp = dyn_cast(op)) { + auto loadResType = loadOp.getResult().getType(); + tensorTypes.emplace_back(loadResType); + } + if (auto storeOp = dyn_cast(op)) { + auto storeValType = storeOp.getValue().getType(); + tensorTypes.emplace_back(storeValType); + } + } + // Get the iteration count + allTensorTypes.emplace_back(tensorTypes); + } + } + + assert(allTensorTypes.size() == sortedOpTrees.size() && + "iteration count != the number of opTrees"); + + // 0. Get reduceId/reduceNum for shared memory init + unsigned reduceId = 0; + unsigned reduceNum = 0; + for (auto sortedOpTree : sortedOpTrees) { + size_t ngroup = 1; + size_t groupsize = 64; + size_t rowspercore = 1; + getGroupInfo(sortedOpTree, ngroup, groupsize, rowspercore); + reduceNGroups.emplace_back(ngroup); + reduceGroupSizes.emplace_back(groupsize); + reduceRowsPerCores.emplace_back(rowspercore); + for (auto op : sortedOpTree) { + if (auto reduceOp = dyn_cast(op)) { + reduceNum++; + } + } + } + + for (int i = 0; i < allTensorTypes.size(); ++i) { + auto tensorTypes = allTensorTypes[i]; + auto opTree = allOpTrees[i]; + // unrollOpt only for 1D tensor, fixed stride and + // OffsetState::Unknown(resnet max pool infer) + bool unrollOpt = false; + // TODO[dyq]: choose [renge]for1 control rather than unrollOpt + // for (auto op : opTree) { + // if (auto gm2lmOp = dyn_cast(op)) { + // auto tensorType = gm2lmOp.getPtr().getType(); + // auto rank = + // mlir::isa(tensorType) + // ? + // mlir::cast(tensorType).getShape().size() + // : 1; + // auto cond_1 = rank == 1 ? true : false; + // auto cond_2 = gm2lmOp.getFixedStride() == -1 ? true : false; + // auto cond_3 = static_cast( + // gm2lmOp.getOffsetState()) == + // OffsetState::Unknown ? true : false; + // if (cond_1 && cond_2 && cond_3) { + // unrollOpt = true; + // } else { + // unrollOpt = false; + // break; + // } + // } + // } + + bool atomicSim = false; + size_t simIterCount; + for (auto tensorTy : tensorTypes) { + if (auto rankTensorTy = mlir::dyn_cast(tensorTy)) { + auto gEncoding = mlir::cast( + rankTensorTy.getEncoding()); + auto coresPerGroup = gEncoding.getCoresPerGroup(); + auto groupsPerCluster = gEncoding.getGroupsPerCluster(); + + auto oneCoreAct = + (llvm::find_if(coresPerGroup, + [](unsigned int num) { return num != 1; }) == + coresPerGroup.end()) && + (llvm::find_if(groupsPerCluster, [](unsigned int num) { + return num != 1; + }) == groupsPerCluster.end()); + + if (oneCoreAct) { + atomicSim = true; + auto shape = rankTensorTy.getShape(); + simIterCount = product(shape); + } else { + atomicSim = false; + break; + } + } + } + bool isReduceMultiGroup = reduceNGroups[i] > 1 ? true : false; + llvm::SmallVector iterCount = + getIterationCount(tensorTypes, isReduceMultiGroup, unrollOpt); + + m.walk([&](triton::xpu::GM2LMOp gm2lmOp) { + if (findUserOp(gm2lmOp) || + findUserOp(gm2lmOp)) { + atomicSim = false; + } + }); + + if (atomicSim) { + int32_t lrie = 1; + m.walk([&](triton::xpu::GM2LMOp gm2lmOp) { lrie = gm2lmOp.getLrie(); }); + iterCount.assign(iterCount.size(), + mlir::ceil(simIterCount, lrie)); + } + + iterCounts.emplace_back(iterCount); + } + + // For reduce2d, inner axis tensor is not sliced + llvm::SmallVector> innerChains; + llvm::SmallVector> outerChains; + getChains(allOpTrees, innerChains, outerChains); + + // Duplicate the MakeRangeOp to avoid conflict when innerChains and + // outerChains all include it(bilibli_mul_reducesum dyanmic 2x14x3x256) + for (int i = 0; i < innerChains.size(); ++i) { + llvm::SetVector innerChain = innerChains[i]; + llvm::SetVector outerChain = outerChains[i]; + + for (auto it = outerChain.begin(); it != outerChain.end(); ++it) { + Operation *outerOp = *it; + if (inOpChain(innerChain, outerOp)) { // Common MROp + if (auto rangeOp = dyn_cast(outerOp)) { + // Find MROp's Whose User is ExpandDimsOp(dim=0) + for (auto user : rangeOp->getUsers()) { + if (auto op = findUserOp(user)) { + auto expandDimsOp = cast(op); + if (expandDimsOp.getAxis() == 0) { + // Recover MakeRangeOp + OpBuilder builder(rangeOp); + auto loc = builder.getUnknownLoc(); + auto newMakeRangeOp = builder.create( + loc, rangeOp.getType(), rangeOp.getStart(), + rangeOp.getEnd()); + + // Link To InnerChain + auto operands = user->getOperands(); + for (auto _it = operands.begin(); _it != operands.end(); + ++_it) { + auto operand = *_it; + if (operand == rangeOp) { + user->setOperand(std::distance(operands.begin(), _it), + newMakeRangeOp); + } + } + + // Now the old common mrOp is only used by outerChain + innerChains[i].insert(newMakeRangeOp); + innerChains[i].remove(rangeOp); + sortedOpTrees[i].insert(newMakeRangeOp); + sortedOpTrees[i] = sortOpTreeBwd(sortedOpTrees[i]); + } + } + } + } + } + } + } + + // for (auto [i, opTree] : llvm::enumerate(allOpTrees)) { + // LLVM_DEBUG(llvm::dbgs() << "\nDump OpTree-" << i << ":\n"); + // for (auto op : opTree) { + // op->dump(); + // } + + // LLVM_DEBUG(llvm::dbgs() << "\nDump outerChain-" << i << ":\n"); + // for (auto op : outerChains[i]) { + // op->dump(); + // } + + // LLVM_DEBUG(llvm::dbgs() << "\nDump innerChain-" << i << ":\n"); + // for (auto op : innerChains[i]) { + // op->dump(); + // } + // } + + auto getInnerChainInfo = [&](Operation *op) -> std::string { + for (size_t i = 0; i < innerChains.size(); ++i) { + if (innerChains[i].count(op)) { + return "InnerChain"; + } + } + return ""; + }; + + auto printCSV = [&](mlir::ModuleOp &mod) { + LLVM_DEBUG(llvm::dbgs() << "{\n"); + LLVM_DEBUG(llvm::dbgs() << "Operation,Chain Info\n"); + + // 遍历 mod 中的所有操作 + mod.walk([&](mlir::Operation *op) { + if (dyn_cast(op) || dyn_cast(op)) + return; + // 获取操作的字符串表示,记得处理逗号和换行符 + std::string opStr; + llvm::raw_string_ostream os(opStr); + op->print(os); + // 替换逗号和换行符 + std::replace(opStr.begin(), opStr.end(), ',', ';'); + std::replace(opStr.begin(), opStr.end(), '\n', ' '); + + // 获取 InnerChain 信息 + std::string chainInfo = getInnerChainInfo(op); + + // 输出一行 + LLVM_DEBUG(llvm::dbgs() << opStr << "," << chainInfo << "\n"); + }); + LLVM_DEBUG(llvm::dbgs() << "}\n"); + }; + + // printCSV(m); + + // 1. Create loop for GM2LM/LM2GM + for (size_t i = 0; i < iterCounts.size(); ++i) { + llvm::SetVector sortedOpTree = sortedOpTrees[i]; + llvm::SmallVector iterCount = iterCounts[i]; + size_t outIterCount = iterCount[0]; + size_t reduceNGroup = reduceNGroups[i]; + size_t reduceGroupSize = reduceGroupSizes[i]; + size_t reduceRowsPerCore = reduceRowsPerCores[i]; + bool isReduceMultiGroup = reduceNGroup > 1 ? true : false; + llvm::SetVector innerChain = innerChains[i]; + llvm::SetVector outerChain = outerChains[i]; + auto endOp = sortedOpTree[0]; + OpBuilder builder(endOp); + auto loc = builder.getUnknownLoc(); + + // Set loop args and create for loop. + auto low = + builder.create(loc, builder.getIndexAttr(0)); + auto upper = builder.create( + loc, builder.getIndexAttr(outIterCount)); + auto step = + builder.create(loc, builder.getIndexAttr(1)); + // Control elem_size per Loop To Avoid the Mem Overflow + auto forLoopOp = builder.create(loc, low, upper, step); + builder.setInsertionPointToStart(forLoopOp.getBody()); + + // Create loop body + Value idx = builder.create( + loc, builder.getI32Type(), forLoopOp.getInductionVar()); + + Operation *yieldOp = forLoopOp.getBody()->getTerminator(); + + // LLVM_DEBUG(llvm::dbgs() << "\nBefore Loop Move:\n" << m << " \n"); + + for (auto op : llvm::reverse(sortedOpTree)) { + // op->dump(); + + bool isInner = inOpChain(innerChain, op); + // LLVM_DEBUG(llvm::dbgs() << "\nisInner: " << isInner); + + if (!isa(op)) { + op->moveBefore(yieldOp); + } + + auto setSlicedResTy = [&](Operation *op, bool isInner = false, + bool needNewEncoding = false) { + for (auto [i, resTy] : llvm::enumerate(op->getResultTypes())) { + // LLVM_DEBUG(llvm::dbgs() << "\nOrigin Type: " << resTy); + auto slicedResTy = + getSlicedType(resTy, iterCount, isInner, needNewEncoding); + // LLVM_DEBUG(llvm::dbgs() << "\nSlicedResType Type: " << + // slicedResTy); + op->getResult(i).setType(slicedResTy); + } + }; + + if (auto makeRangeOp = dyn_cast(op)) { + auto type = makeRangeOp.getType(); + if (outerChain.count(makeRangeOp)) { + auto slicedResTy = getSlicedType(type, iterCount, false); + auto newOutRangeOp = builder.create( + loc, slicedResTy, reduceGroupSize, reduceRowsPerCore, idx); + op->replaceAllUsesWith(newOutRangeOp->getResults()); + } else { + Value index = isInner ? Value() : idx; + uint32_t start = makeRangeOp.getStart(); + uint32_t end = makeRangeOp.getEnd(); + uint32_t realSize = end - start; + auto newMakeRangeOp = builder.create( + loc, type, builder.getI32IntegerAttr(start), + builder.getI32IntegerAttr(end), + builder.getI32IntegerAttr(realSize), index, Value()); + setSlicedResTy(newMakeRangeOp, isInner); + uint32_t newEnd = + start + product(newMakeRangeOp.getType().getShape()); + newMakeRangeOp.setEnd(newEnd); + op->replaceAllUsesWith(newMakeRangeOp->getResults()); + } + } else if (auto reduceOp = dyn_cast(op)) { + // LLVM_DEBUG(llvm::dbgs() << "\nbefore modify reduceOp\n" << m); + auto newReduceIdxOp = builder.create( + loc, reduceOp->getResultTypes(), reduceOp.getSrcs(), + reduceOp.getAxis(), idx); + auto &newCombineOp = newReduceIdxOp.getCombineOp(); + builder.cloneRegionBefore(reduceOp.getCombineOp(), newCombineOp, + newCombineOp.end()); + setSlicedResTy(newReduceIdxOp, isInner); + op->replaceAllUsesWith(newReduceIdxOp->getResults()); + // LLVM_DEBUG(llvm::dbgs() << "\nAfter modify reduceOp\n" << m); + for (auto &opInCombine : newCombineOp.getOps()) { + if (auto redReturnOp = + dyn_cast(opInCombine)) { + auto oldInsertionPoint = builder.saveInsertionPoint(); + builder.setInsertionPoint(redReturnOp); + auto newRedReturnOp = builder.create( + loc, redReturnOp.getOperands()); + builder.restoreInsertionPoint(oldInsertionPoint); + redReturnOp->replaceAllUsesWith(newRedReturnOp->getResults()); + opInCombine.erase(); // avoid the HasParent Trait + break; // stop the loop in combine region + } + } + op->erase(); + } else if (auto constOp = dyn_cast(op)) { + if (auto attr = + mlir::dyn_cast(constOp.getValue())) { + auto slicedResTy = + getSlicedType(constOp.getType(), iterCount, isInner); + ShapedType slicedShapedType = mlir::cast(slicedResTy); + auto newValue = DenseElementsAttr::getFromRawBuffer( + slicedShapedType, attr.getRawData()); + auto newConstOp = builder.create( + loc, slicedResTy, newValue); + op->replaceAllUsesWith(newConstOp->getResults()); + } + } else if (auto forOp = dyn_cast(op)) { + + if (iterCount.back() != 1) { + if (auto stepOp = dyn_cast( + forOp.getStep().getDefiningOp())) { + setSlicedResTy(stepOp, isInner); + } + } + + // Set forOp Result Type + setSlicedResTy(forOp, isInner); + + // Set forOp Arg Type + auto forBody = forOp.getBody(); + auto forArgs = forBody->getArguments(); + for (auto forArg : forArgs) { + bool isInnerArg = inOpChain(innerChain, forArg.getDefiningOp()); + auto slicedArgType = + getSlicedType(forArg.getType(), iterCount, isInnerArg); + forArg.setType(slicedArgType); + } + + // Set forOp's childOp Result Type + auto &forRegion = forOp.getRegion(); + auto &forBlock = forRegion.front(); + SetVector erasedOps; + for (auto &inBlockOp : forBlock) { + bool inBlockIsInner = inOpChain(innerChain, &inBlockOp); + if (auto reduceOpInFor = + mlir::dyn_cast(inBlockOp)) { + OpBuilder builderInFor(reduceOpInFor); + auto newReduceIdxOp = builderInFor.create( + reduceOpInFor->getLoc(), reduceOpInFor->getResultTypes(), + reduceOpInFor.getSrcs(), reduceOpInFor.getAxis(), idx); + auto &newCombineOp = newReduceIdxOp.getCombineOp(); + builderInFor.cloneRegionBefore(reduceOpInFor.getCombineOp(), + newCombineOp, newCombineOp.end()); + setSlicedResTy(newReduceIdxOp, inBlockIsInner); + reduceOpInFor.replaceAllUsesWith(newReduceIdxOp->getResults()); + erasedOps.insert(reduceOpInFor); + for (auto &opInCombine : newCombineOp.getOps()) { + if (auto redReturnOp = + dyn_cast(opInCombine)) { + auto oldInsertionPoint = builderInFor.saveInsertionPoint(); + builderInFor.setInsertionPoint(redReturnOp); + auto newRedReturnOp = + builderInFor.create( + redReturnOp.getLoc(), redReturnOp.getOperands()); + builderInFor.restoreInsertionPoint(oldInsertionPoint); + redReturnOp->replaceAllUsesWith(newRedReturnOp->getResults()); + erasedOps.insert(redReturnOp); + } + } + } else if (auto ifOp = mlir::dyn_cast(inBlockOp)) { + // Set IfOp's childOp Arg Type(Then) + auto &thenRegion = ifOp.getThenRegion(); + auto &thenBlock = thenRegion.front(); + for (auto &inBlockOp : thenBlock) { + bool isInnerOp = inOpChain(innerChain, &inBlockOp); + + auto newInBlockOp = thenBlock.begin(); + auto dist = std::distance(thenBlock.begin(), + Block::iterator(inBlockOp)); + std::advance(newInBlockOp, dist); + + for (auto newInBlockOpRes : newInBlockOp->getResults()) { + auto slicedOpType = getSlicedType(newInBlockOpRes.getType(), + iterCount, isInnerOp); + newInBlockOpRes.setType(slicedOpType); + } + } + + // Set IfOp's childOp Arg Type(Else) + auto &elseRegion = ifOp.getElseRegion(); + if (!elseRegion.empty()) { + auto &elseBlock = elseRegion.front(); + for (auto &inBlockOp0 : elseBlock) { + bool isInnerOp = inOpChain(innerChain, &inBlockOp0); + setSlicedResTy(&inBlockOp0, isInnerOp); + } + for (auto newInBlockOpRes : ifOp->getResults()) { + auto slicedOpType = + getSlicedType(newInBlockOpRes.getType(), iterCount, true); + newInBlockOpRes.setType(slicedOpType); + } + } + } else { + setSlicedResTy(&inBlockOp, inBlockIsInner); + } + } + for (auto op : erasedOps) { + if (op->use_empty()) + op->erase(); + } + } else if (auto ifOp = dyn_cast(op)) { + // iterCount.back() != 1 check? + // Set IfOp Result Type + setSlicedResTy(ifOp, isInner); + + // Set IfOp Arg Type + auto newIfBody = ifOp.getBody(); + auto newIfArgs = newIfBody->getArguments(); + for (auto newIfArg : newIfArgs) { + bool isInnerArg = inOpChain(innerChain, newIfArg.getDefiningOp()); + auto slicedArgType = + getSlicedType(newIfArg.getType(), iterCount, isInnerArg); + newIfArg.setType(slicedArgType); + } + + // Set IfOp's childOp Arg Type(Then) + auto &newIfThenRegion = ifOp.getThenRegion(); + auto &newIfThenBlock = newIfThenRegion.front(); + auto &oldIfThenRegion = ifOp.getThenRegion(); + auto &oldIfThenBlock = oldIfThenRegion.front(); + for (auto &inBlockOp : oldIfThenBlock) { + bool isInnerOp = inOpChain(innerChain, &inBlockOp); + + auto newInBlockOp = newIfThenBlock.begin(); + auto dist = std::distance(oldIfThenBlock.begin(), + Block::iterator(inBlockOp)); + std::advance(newInBlockOp, dist); + + for (auto newInBlockOpRes : newInBlockOp->getResults()) { + auto slicedOpType = getSlicedType(newInBlockOpRes.getType(), + iterCount, isInnerOp); + newInBlockOpRes.setType(slicedOpType); + } + } + + // Set IfOp's childOp Arg Type(Else) + auto &newIfElseRegion = ifOp.getElseRegion(); + if (!newIfElseRegion.empty()) { + auto &newIfElseBlock = newIfElseRegion.front(); + auto &oldIfElseRegion = ifOp.getElseRegion(); + auto &oldIfElseBlock = oldIfElseRegion.front(); + for (auto &inBlockOp : oldIfElseBlock) { + bool isInnerOp = inOpChain(innerChain, &inBlockOp); + + auto newInBlockOp = newIfElseBlock.begin(); + auto dist = std::distance(oldIfElseBlock.begin(), + Block::iterator(inBlockOp)); + std::advance(newInBlockOp, dist); + + for (auto newInBlockOpRes : newInBlockOp->getResults()) { + auto slicedOpType = getSlicedType(newInBlockOpRes.getType(), + iterCount, isInnerOp); + newInBlockOpRes.setType(slicedOpType); + } + } + } + + } else if (auto reshapeOp = dyn_cast(op)) { + if (auto reshapeResTy = + dyn_cast(reshapeOp.getResult().getType())) { + auto reshapeResShape = reshapeResTy.getShape(); + if (reshapeResShape.size() == 1) { + auto reshapeSrcTy = + cast(reshapeOp.getOperand().getType()); + auto reshapeSrcShape = reshapeSrcTy.getShape(); + size_t reshapeSrcSize = product(reshapeSrcShape); + llvm::SmallVector slicedShape(1, 1u); + + slicedShape[0] = + std::ceil(static_cast(reshapeSrcSize / iterCount[0])); + auto slicedReshapeSrcTy = RankedTensorType::get( + slicedShape, reshapeSrcTy.getElementType(), + reshapeResTy.getEncoding()); + reshapeOp.getResult().setType(slicedReshapeSrcTy); + } + } + } else { + setSlicedResTy(op, isInner); + } + + // LLVM_DEBUG(llvm::dbgs() << "After Deal:\n" << m << "\n"); + } + } + + // Create sm2gmPtrLenOpChain before func.returnOp + if (!sm2gmPtrLenOpChain.empty()) { + SmallVector funcRetures; + m.walk([&](func::ReturnOp funcReture) { + funcRetures.push_back(funcReture); + }); + assert(funcRetures.size() == 1 && + "Only one func.return is expected in the module"); + auto sortedSm2gmPtrLenOpChain = sortOpTreeBwd(sm2gmPtrLenOpChain); + for (int j = sortedSm2gmPtrLenOpChain.size() - 1; j >= 0; --j) { + auto op = sortedSm2gmPtrLenOpChain[j]; + OpBuilder builder(op); + op->moveBefore(funcRetures[0]); + // Set encoding, only core0 sm2gm + for (auto res : op->getResults()) { + auto resTy = res.getType(); + + if (auto resTensorTy = mlir::dyn_cast(resTy)) { + auto resShape = resTensorTy.getShape(); + auto elemTy = resTensorTy.getElementType(); + + if (auto resEncoding = + mlir::dyn_cast( + resTensorTy.getEncoding())) { + auto sizePerCore = resEncoding.getSizePerCore(); + auto coresPerGroup = resEncoding.getCoresPerGroup(); + auto groupsPerCluster = resEncoding.getGroupsPerCluster(); + auto order = resEncoding.getOrder(); + auto isReduceOpt = resEncoding.getIsReduceOpt(); + + SmallVector newCoresPerGroup(coresPerGroup.size(), 1); + SmallVector newGroupsPerCluster(groupsPerCluster.size(), + 1); + SmallVector newSizePerCore(resShape.begin(), + resShape.end()); + + auto newEncoding = triton::xpu::ClusterLayoutAttr::get( + context, newSizePerCore, newCoresPerGroup, + newGroupsPerCluster, order, isReduceOpt); + + auto newResTy = + RankedTensorType::get(resShape, elemTy, newEncoding); + res.setType(newResTy); + if (auto constOp = dyn_cast(op)) { + if (auto attr = mlir::dyn_cast( + constOp.getValue())) { + auto newValue = DenseElementsAttr::getFromRawBuffer( + newResTy, attr.getRawData()); + constOp.setValueAttr(newValue); + } + } + } else if (auto resEncoding = + mlir::dyn_cast( + resTensorTy.getEncoding())) { + auto resGlobalEncoding = + mlir::cast( + resEncoding.getParent()); + + auto dim = resEncoding.getDim(); + auto sizePerCore = resGlobalEncoding.getSizePerCore(); + auto coresPerGroup = resGlobalEncoding.getCoresPerGroup(); + auto groupsPerCluster = resGlobalEncoding.getGroupsPerCluster(); + auto order = resGlobalEncoding.getOrder(); + auto isReduceOpt = resGlobalEncoding.getIsReduceOpt(); + + SmallVector newCoresPerGroup(coresPerGroup.size(), 1); + SmallVector newGroupsPerCluster(groupsPerCluster.size(), + 1); + SmallVector newSizePerCore(sizePerCore.size(), 1); + assert(sizePerCore.size() < 3 && resShape.size() < 3 && + resShape.size() <= sizePerCore.size()); + if (sizePerCore.size() == 2 && resShape.size() == 1) { + newSizePerCore[1 - dim] = resShape[0]; + } else { + for (int i = 0; i < resShape.size(); ++i) { + newSizePerCore[i] = resShape[i]; + } + } + + auto newGlobalEncoding = triton::xpu::ClusterLayoutAttr::get( + context, newSizePerCore, newCoresPerGroup, + newGroupsPerCluster, order, isReduceOpt); + auto newEncoding = triton::gpu::SliceEncodingAttr::get( + context, resEncoding.getDim(), newGlobalEncoding); + + auto newResTy = + RankedTensorType::get(resShape, elemTy, newEncoding); + res.setType(newResTy); + } else { + assert(0 && "Unexpected tensor encoding in SM Optimization"); + } + } + } + } + } + + // Set ReduceOpHelper + m.walk([&](triton::xpu::ReduceOp redOp) { + ReduceOpHelper helper(redOp); + helper.setReduceId(reduceId); + helper.setReduceNum(reduceNum); + reduceId++; + }); + + // MakeRange Replace Protection + m.walk([&](triton::MakeRangeOp mrOp) { + OpBuilder builder(mrOp); + auto loc = mrOp->getLoc(); + uint32_t start = mrOp.getStart(); + uint32_t end = mrOp.getEnd(); + uint32_t realSize = end - start; + auto newMakeRangeOp = builder.create( + loc, mrOp.getType(), builder.getI32IntegerAttr(start), + builder.getI32IntegerAttr(end), builder.getI32IntegerAttr(realSize), + Value(), Value()); + mrOp->replaceAllUsesWith(newMakeRangeOp->getResults()); + }); + } +}; + +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/LoopGrid.cpp b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/LoopGrid.cpp new file mode 100644 index 000000000..0ee155701 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/LoopGrid.cpp @@ -0,0 +1,107 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// TODO[dyq]: Pass Description +//===----------------------------------------------------------------------===// + +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" + +namespace mlir { +namespace triton { +namespace xpu { + +#define GEN_PASS_DEF_TRITONXPULOOPGRID +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +struct TritonXPULoopGrid + : public impl::TritonXPULoopGridBase { + + using impl::TritonXPULoopGridBase::TritonXPULoopGridBase; + + static unsigned int constexpr TRITON_PROGRAM_INFO_ARG_COUNT = 3; + + Value ceilDiv(OpBuilder &builder, Location loc, Value lhs, Value rhs) { + auto c1 = builder.create(loc, 1, lhs.getType()); + auto sub = builder.create(loc, rhs, c1); + auto add = builder.create(loc, lhs, sub); + auto div = builder.create(loc, add, rhs); + return div.getResult(); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + m.walk([&](triton::FuncOp func) { + OpBuilder b(func); + + auto i32Ty = b.getI32Type(); + auto origFuncType = func.getFunctionType(); + auto origInputTypes = origFuncType.getInputs(); + SmallVector newInputTypes(origInputTypes.begin(), + origInputTypes.end()); + newInputTypes.append(TRITON_PROGRAM_INFO_ARG_COUNT, i32Ty); + + auto newFuncType = + b.getFunctionType(newInputTypes, origFuncType.getResults()); + + func.setType(newFuncType); + + // Add the corresponding arguments to function body + auto &body = func.getBody().front(); + for (unsigned int i = 0; i < TRITON_PROGRAM_INFO_ARG_COUNT; i++) { + body.addArgument(i32Ty, func.getLoc()); + } + + // collect op that will be move to the loopGridFor + SmallVector operations; + for (auto &op : body.getOperations()) { + if (&op != body.getTerminator()) + operations.push_back(&op); + } + + b.setInsertionPoint(&body, body.begin()); + auto loc = b.getUnknownLoc(); + auto idxTy = b.getIndexType(); + auto argIdx = func.getNumArguments() - TRITON_PROGRAM_INFO_ARG_COUNT; + auto idxCluster = b.create(loc, i32Ty); + auto numCluster = b.create(loc, i32Ty); + auto gridX = func.getArgument(argIdx + 0); + auto gridY = func.getArgument(argIdx + 1); + auto gridZ = func.getArgument(argIdx + 2); + auto gridXY = b.create(loc, gridX, gridY); + auto gridXYZ = b.create(loc, gridXY, gridZ); + auto numProgramsPerCluster = ceilDiv(b, loc, gridXYZ, numCluster); + auto lower = b.create(loc, idxTy, idxCluster); + auto upper = b.create(loc, idxTy, gridXYZ); + auto step = b.create(loc, idxTy, numCluster); + auto loopGrid = b.create(loc, lower, upper, step); + for (auto op : operations) { + op->moveBefore(loopGrid.getBody()->getTerminator()); + } + b.setInsertionPointToStart(loopGrid.getBody()); + Value index = + b.create(loc, i32Ty, loopGrid.getInductionVar()); + auto pidZ = b.create(loc, index, gridZ); + index = b.create(loc, index, gridZ); + auto pidY = b.create(loc, index, gridY); + auto pidX = b.create(loc, index, gridY); + + SmallVector programId{pidX, pidY, pidZ}; + func.walk([&](triton::GetProgramIdOp op) { + op.replaceAllUsesWith(programId[op.getAxisAsInt()]); + }); + func.walk([&](triton::GetNumProgramsOp op) { + op.replaceAllUsesWith(func.getArgument(argIdx + op.getAxisAsInt())); + }); + }); + } +}; + +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Mask.cpp b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Mask.cpp new file mode 100644 index 000000000..de9eccf7b --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Mask.cpp @@ -0,0 +1,707 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// TODO: Pass Description +//===----------------------------------------------------------------------===// + +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" + +#include "mlir/IR/IRMapping.h" + +namespace mlir { +namespace triton { +namespace xpu { + +#define GEN_PASS_DEF_TRITONXPUMASK +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +struct TritonXPUMaskPass : public impl::TritonXPUMaskBase { + +public: + using impl::TritonXPUMaskBase::TritonXPUMaskBase; + + void getOpChain(llvm::SetVector &opChain, Operation *op) { + if (!op || opChain.contains(op)) + return; + + opChain.insert(op); + + if (op->use_empty()) + return; + + for (Operation *user : op->getUsers()) { + getOpChain(opChain, user); + } + } + + bool isFindUserOpImpl(Operation *startingOp, Operation *targetOp, + llvm::SetVector &visitedOps) { + if (startingOp == targetOp) { + return true; + } + + if (!startingOp || visitedOps.contains(startingOp)) { + return false; + } + + visitedOps.insert(startingOp); + + for (auto userOp : startingOp->getUsers()) { + if (isFindUserOpImpl(userOp, targetOp, visitedOps)) { + return true; + } + } + + return false; + } + + bool isFindUserOp(Operation *startingOp, Operation *targetOp) { + llvm::SetVector visitedOps; + return isFindUserOpImpl(startingOp, targetOp, visitedOps); + } + + // Block Access Optimization Degradation + // If the loadOp/storeOp .len() is not computed by subIOp, the block + // read/write optimizations cannot be applied. + void blockAccessOptDeGrad(mlir::ModuleOp m) { + m.walk([&](triton::xpu::GM2LMOp gm2lmOp) { + auto len = gm2lmOp.getLen(); + if (len) { + auto subIOp = len.getDefiningOp(); + if (!subIOp) { + OpBuilder builder(gm2lmOp); + auto loc = gm2lmOp->getLoc(); + gm2lmOp->setAttr("offsetState", + builder.getSI32IntegerAttr( + static_cast(OffsetState::Unknown))); + } + } + }); + + m.walk([&](triton::xpu::LM2GMOp lm2gmOp) { + auto len = lm2gmOp.getLen(); + if (len) { + auto subIOp = len.getDefiningOp(); + if (!subIOp) { + OpBuilder builder(lm2gmOp); + auto loc = lm2gmOp->getLoc(); + lm2gmOp->setAttr("offsetState", + builder.getSI32IntegerAttr( + static_cast(OffsetState::Unknown))); + } + } + }); + + m.walk([&](triton::xpu::SM2GMOp sm2gmOp) { + auto len = sm2gmOp.getLen(); + if (len) { + auto subIOp = len.getDefiningOp(); + if (!subIOp) { + OpBuilder builder(sm2gmOp); + auto loc = sm2gmOp->getLoc(); + sm2gmOp->setAttr("offsetState", + builder.getSI32IntegerAttr( + static_cast(OffsetState::Unknown))); + } + } + }); + } + + void addThreadIdMask(mlir::scf::IfOp ifOp, triton::xpu::LoadOp loadOp) { + + auto resTensorType = mlir::cast(loadOp.getType()); + + OpBuilder builder(ifOp); + auto loc = ifOp->getLoc(); + + // Step 1. get val + auto rcvVal = builder.create( + loc, builder.getI64Type(), builder.getI32IntegerAttr(0), loadOp); + auto rcvValI32 = builder.create( + loc, builder.getI32Type(), rcvVal); + + // Step 2. get ThreadNum + auto clusterNum = builder.create( + loc, builder.getIndexType(), mlir::gpu::Dimension::x); + auto clusterNum_cast = builder.create( + loc, builder.getI32Type(), clusterNum); + auto coreNum = builder.create( + loc, builder.getIndexType(), mlir::gpu::Dimension::x); + auto coreNum_cast = builder.create( + loc, builder.getI32Type(), coreNum); + auto threadNum = builder.create( + loc, builder.getI32Type(), clusterNum_cast, coreNum_cast); + + // Step 3. val % ThreadNum + auto remFOp = + builder.create(loc, rcvValI32, threadNum); + + // Step 4. get ThreadId + auto threadIdOp = builder.create( + loc, builder.getI32Type(), builder.getSI32IntegerAttr(1)); + + // Step 5. threadId == val % ThreadNum + auto threadCond = builder.create( + loc, builder.getI1Type(), mlir::arith::CmpIPredicate::eq, remFOp, + threadIdOp); + + auto originCondOp = ifOp.getCondition(); + + auto newCondOp = builder.create( + loc, builder.getI1Type(), originCondOp, threadCond); + + ifOp->setOperand(0, newCondOp); + } + + void atoNaiveMask(mlir::ModuleOp m) { + m.walk([&](scf::IfOp ifOp) { + OpBuilder builder(ifOp); + auto loc = ifOp->getLoc(); + + auto threadIdOp = builder.create( + loc, builder.getI32Type()); + + auto core0Op = builder.create(loc, 0, 32); + + auto atomicCondOp = builder.create( + loc, builder.getI1Type(), mlir::arith::CmpIPredicate::eq, core0Op, + threadIdOp); + + auto originCondOp = ifOp.getCondition(); + + auto newCondOp = builder.create( + loc, builder.getI1Type(), originCondOp, atomicCondOp); + + ifOp->setOperand(0, newCondOp); + }); + } + + void atoOptimizationMask(mlir::ModuleOp m) { + /*********************** Atomic Mask Opt ************************** + * %cst_0 = arith.constant dense<-1> : tensor<2048xi64> + * %cst_1 = arith.constant dense<0.000000e+00> : tensor<2048xf32> + * + * Before Optimization: + * %1 = tt.load: tensor<2048xi64> + * %2 = tt.load: tensor<2048xf32> + * %3 = arith.cmpi eq, %1, %cst_0 -> tensor<2048xi1> + * %4 = arith.select %3, %cst_1, %2 -> tensor<2048xf32>, + * %5 = tt.atomic_add %4 -> tensor<2048xf32> + + * After Optimization: + * %1 = tt.load: tensor<2048xi64> + * %3 = arith.cmpi eq, %1, %cst_0 -> tensor<2048xi1> + * %4 = arith.neg %3 ------> negativeCond + * scf.if %4 { + * %2 = tt.load: tensor<2048xf32> + * %5 = tt.atomic_add %4 -> tensor<2048xf32> + * } + * + *****************************************************************/ + arith::CmpIOp rcvCmpIOp; + triton::xpu::LoadOp rcvLoadOp; + arith::SelectOp rcvSelectOp; + AtomicMaskCond atoMaskCond = AtomicMaskCond::NonActivate; + bool atoMaskOpt = false; + + m.walk([&](arith::CmpIOp cmpIOp) { + if (cmpIOp.getPredicate() == arith::CmpIPredicate::eq) { + auto lhs = cmpIOp.getLhs(); + auto rhs = cmpIOp.getRhs(); + auto res = cmpIOp.getResult(); + + // Assert Only Have One SelectOp User + auto cmpiop_user_begin = cmpIOp->user_begin(); + auto cmpiop_user_end = cmpIOp->user_end(); + if (std::distance(cmpiop_user_begin, cmpiop_user_end) != 1) + return; + + if (auto selectOp = + dyn_cast(*cmpiop_user_begin)) { + + // Assert Only Have One AddFOp User + auto selectop_user_begin = selectOp->user_begin(); + auto selectop_user_end = selectOp->user_end(); + if (std::distance(selectop_user_begin, selectop_user_end) != 1) + return; + + if (auto addFOp = dyn_cast(*selectop_user_begin)) { + auto trueVal = selectOp.getTrueValue(); + auto falseVal = selectOp.getFalseValue(); + + auto trueLoadOp = + dyn_cast(trueVal.getDefiningOp()); + auto falseConstOp = + dyn_cast(falseVal.getDefiningOp()); + + auto trueConstOp = + dyn_cast(trueVal.getDefiningOp()); + auto falseLoadOp = + dyn_cast(falseVal.getDefiningOp()); + + if (trueLoadOp && falseConstOp) { + if (auto denseAttr = mlir::dyn_cast( + falseConstOp.getValue())) { + float constValue = *denseAttr.getValues().begin(); + if (constValue == 0.0f) + atoMaskCond = AtomicMaskCond::PostiveCond; + } + } else if (trueConstOp && falseLoadOp) { + if (auto denseAttr = mlir::dyn_cast( + trueConstOp.getValue())) { + float constValue = *denseAttr.getValues().begin(); + if (constValue == 0.0f) + atoMaskCond = AtomicMaskCond::NegativeCond; + } + } + + auto lhsLoadOp = dyn_cast(lhs.getDefiningOp()); + auto rhsConstOp = dyn_cast(rhs.getDefiningOp()); + + auto lhsConstOp = dyn_cast(lhs.getDefiningOp()); + auto rhsLoadOp = dyn_cast(rhs.getDefiningOp()); + + if ((lhsLoadOp && rhsConstOp) || (lhsConstOp && rhsLoadOp)) { + rcvCmpIOp = cmpIOp; + rcvLoadOp = lhsLoadOp ? lhsLoadOp : rhsLoadOp; + rcvSelectOp = selectOp; + atoMaskOpt = true; + } + } + } + } + }); + + // Step 1. Create If + if (rcvCmpIOp && rcvLoadOp && atoMaskOpt && + atoMaskCond != AtomicMaskCond::NonActivate) { + Operation *nextOp = rcvLoadOp->getNextNode(); + // Move the cmpi operation right after the load operation + OpBuilder builder(rcvLoadOp); + auto loc = rcvLoadOp->getLoc(); + builder.setInsertionPointAfter(rcvLoadOp); + + auto newCmpiOp = builder.create( + loc, rcvCmpIOp.getType(), mlir::arith::CmpIPredicate::eq, + rcvCmpIOp.getLhs(), rcvCmpIOp.getRhs()); + rcvCmpIOp.replaceAllUsesWith(newCmpiOp.getResult()); + rcvCmpIOp.erase(); + + auto trueValue = builder.create( + loc, builder.getI1Type(), + builder.getIntegerAttr(builder.getI1Type(), 1)); + + auto posCond = builder.create( + loc, builder.getI1Type(), builder.getI32IntegerAttr(0), newCmpiOp); + + auto negCond = + builder.create(loc, trueValue, posCond); + + // Create the SCF if operation + scf::IfOp scfIfOp; + if (atoMaskCond == AtomicMaskCond::PostiveCond) { + scfIfOp = builder.create(loc, posCond, + /*withElseRegion=*/false); + } else if (atoMaskCond == AtomicMaskCond::NegativeCond) { + scfIfOp = builder.create(loc, negCond, + /*withElseRegion=*/false); + } + + addThreadIdMask(scfIfOp, rcvLoadOp); + + // Move subsequent operations inside the ifOp's then block + builder.setInsertionPointToStart(&scfIfOp.getThenRegion().front()); + Operation *yieldOp = scfIfOp.getThenRegion().front().getTerminator(); + while (nextOp) { + Operation *currentOp = nextOp; + nextOp = nextOp->getNextNode(); + if (!isa(currentOp)) { + currentOp->moveBefore(yieldOp); + } + } + } + + // Step 2. Eliminate Unvalid SelectOp + if (rcvCmpIOp && rcvLoadOp && atoMaskOpt && + atoMaskCond != AtomicMaskCond::NonActivate) { + auto true_value = rcvSelectOp.getTrueValue(); + auto false_value = rcvSelectOp.getFalseValue(); + + if (atoMaskCond == AtomicMaskCond::PostiveCond) { + rcvSelectOp.replaceAllUsesWith(true_value); + } else if (atoMaskCond == AtomicMaskCond::NegativeCond) { + rcvSelectOp.replaceAllUsesWith(false_value); + } + rcvSelectOp->erase(); + } + } + + bool isOptMask(mlir::ModuleOp m) { + bool isOpt = false; + m.walk([&](triton::xpu::LM2GMOp lm2gmOp) { + auto op = findDefOpBwd(lm2gmOp.getValue()); + if (op) { + auto selectOp = cast(op); + if (auto cmpIOp = dyn_cast( + selectOp.getCondition().getDefiningOp())) { + if (cmpIOp.getPredicate() == arith::CmpIPredicate::eq) { + isOpt = true; + } + } + } + }); + return isOpt; + } + + // Add AtomicOp Simulation Condition + // We must use the specifiy core to avoid access race + void addAtomicSimulationCond(mlir::ModuleOp m) { + bool atomicSim = false; + m->walk([&](triton::xpu::GM2LMOp gm2lmOp) { + auto tensorTy = gm2lmOp.getResult().getType(); + if (auto rankTensorTy = mlir::dyn_cast(tensorTy)) { + auto gEncoding = mlir::cast( + rankTensorTy.getEncoding()); + auto coresPerGroup = gEncoding.getCoresPerGroup(); + auto groupsPerCluster = gEncoding.getGroupsPerCluster(); + + atomicSim = (llvm::find_if(coresPerGroup, + [](unsigned int num) { return num != 1; }) == + coresPerGroup.end()) && + (llvm::find_if(groupsPerCluster, [](unsigned int num) { + return num != 1; + }) == groupsPerCluster.end()); + + if (findUserOp(gm2lmOp) || + findUserOp(gm2lmOp)) { + atomicSim = false; + } + } + }); + + if (!atomicSim) + return; + + AtomicMaskType atoMaskTy = AtomicMaskType::NaiveMask; + if (maskValue != -1 && isOptMask(m)) { + atoMaskTy = AtomicMaskType::OptimizationMask; + } + + switch (atoMaskTy) { + case AtomicMaskType::NaiveMask: { + atoNaiveMask(m); + break; + } + case AtomicMaskType::OptimizationMask: { + atoOptimizationMask(m); + break; + } + default: + llvm_unreachable("Unknown Atomic Mask Type"); + } + } + + void runOnOperation() override { + mlir::ModuleOp m = getOperation(); + + // Check Core Tiling Optimization + // TODO[dyq]: Open core tiling pass + // m.walk([&](triton::ReduceOp redOp) { + // isReduceOpt = isReduceOptimized(redOp.operand().getType()); + + // if (auto reduceSrcTy = + // redOp.operand().getType().dyn_cast()) { + // auto sizePerCore = reduceSrcTy.getEncoding() + // .cast() + // .getSizePerCore(); + // rowsPerCore = sizePerCore[0]; + // } + // }); + + // Step 0. Convert CmpOp+SplatOp to SplatOp+CmpOp + m.walk([&](triton::SplatOp splatOp) { + if (auto splatDefOp = splatOp.getSrc().getDefiningOp()) { + if (auto cmpiOp = dyn_cast(splatDefOp)) { + OpBuilder builder(splatOp); + auto loc = cmpiOp->getLoc(); + auto lhs = cmpiOp.getLhs(); + auto rhs = cmpiOp.getRhs(); + auto lhsElemTy = lhs.getType(); + if (auto lhsTensorTy = dyn_cast(lhs.getType())) { + lhsElemTy = lhsTensorTy.getElementType(); + } + auto resTy = cast(splatOp.getType()); + auto newSplatTy = RankedTensorType::get(resTy.getShape(), lhsElemTy, + resTy.getEncoding()); + auto lhsSplatOp = + builder.create(loc, newSplatTy, lhs); + auto rhsSplatOp = + builder.create(loc, newSplatTy, rhs); + auto newCmpiOp = builder.create( + loc, resTy, cmpiOp.getPredicate(), lhsSplatOp, rhsSplatOp); + splatOp.replaceAllUsesWith(newCmpiOp.getResult()); + } + } + }); + + // Step 1. Replace CmiOp with SubOp(replace mask with len) + llvm::DenseMap> maskUsersMap; + llvm::SetVector cmpiOps; + + m.walk([&](mlir::arith::CmpIOp cmpiOp) { + llvm::SetVector maskUsers; + llvm::SetVector cmpiOpUsers; + if (cmpiOp.getPredicate() == arith::CmpIPredicate::slt || + cmpiOp.getPredicate() == arith::CmpIPredicate::ult) { + getOpChain(cmpiOpUsers, cmpiOp); + for (auto user : cmpiOpUsers) { + Value mask; + if (auto loadOp = dyn_cast(user)) { + mask = loadOp.getLen(); + } + if (auto storeOp = dyn_cast(user)) { + mask = storeOp.getLen(); + } + if (auto storeOp = dyn_cast(user)) { + mask = storeOp.getLen(); + } + if (mask) { + if (mask == cmpiOp.getResult()) { + cmpiOps.insert(cmpiOp); + maskUsers.insert(user); + } + if (auto andIOp = + dyn_cast(mask.getDefiningOp())) { + if (isFindUserOp(cmpiOp, andIOp.getLhs().getDefiningOp())) { + if (auto cmpiOpDef = + findDefOpBwd(andIOp.getLhs())) { + cmpiOps.insert(cmpiOp); + auto expandDimOp = cast(cmpiOpDef); + // expand axis = 0 means inner dim + if (expandDimOp.getAxis() == 0) { + maskUsers.insert(user); + } + } + } else if (isFindUserOp(cmpiOp, + andIOp.getRhs().getDefiningOp())) { + if (auto cmpiOpDef = + findDefOpBwd(andIOp.getRhs())) { + cmpiOps.insert(cmpiOp); + auto expandDimOp = cast(cmpiOpDef); + // expand axis = 0 means inner dim + if (expandDimOp.getAxis() == 0) { + maskUsers.insert(user); + } + } + } + } + } + } + if (!maskUsers.empty()) { + maskUsersMap[cmpiOp] = maskUsers; + } + } + }); + + // Step 2 Replace the mask of LoadOp/StoreOp + for (const auto &pair : maskUsersMap) { + Value rhs; + Value lhs; + auto op = pair.first; + auto users = pair.second; + + if (auto cmpiOp = dyn_cast(op)) { + rhs = cmpiOp.getRhs(); + lhs = cmpiOp.getLhs(); + } else { + llvm_unreachable( + "cmpiOp only is mlir::arith::CmpIOp/triton::gpu::CmpIOp"); + } + + OpBuilder builder(op); + auto loc = op->getLoc(); + // TODO[dyq]: Open coretiling pass + // if (isReduceOpt && rowsPerCore != 1) { + // auto resTensorType = rhs.getType().cast(); + // SmallVector values( + // resTensorType.getNumElements(), + // builder.getI32IntegerAttr(rowsPerCore)); + // auto denseValues = DenseElementsAttr::get(resTensorType, values); + // auto rowsPerCoreValue = + // builder.create(loc, denseValues); + // rhs = builder.create(loc, rhs, + // rowsPerCoreValue); + // } + + auto elemLenOp = + builder.create(loc, rhs.getType(), rhs, lhs); + + // get maskValue from subIOp.rhs() + if (rhs.getDefiningOp()) { + if (auto rMaskConstOp = + dyn_cast(rhs.getDefiningOp())) { + auto resTy = rMaskConstOp.getResult().getType(); + if (auto tensorType = mlir::dyn_cast(resTy)) { + if (auto denseAttr = mlir::dyn_cast( + rMaskConstOp.getValue())) { + auto values = denseAttr.getValues(); + if (!values.empty()) { + maskValue = values[0].getZExtValue(); + } + } + } + } + } + + for (auto user : users) { + if (auto loadOp = dyn_cast(user)) { + loadOp.setOperand(1, elemLenOp.getResult()); + } else if (auto storeOp = dyn_cast(user)) { + storeOp.setOperand(2, elemLenOp.getResult()); + } else if (auto storeOp = dyn_cast(user)) { + storeOp.setOperand(1, elemLenOp.getResult()); + } + } + } + + // Step 3. Create scf::IfOp to deal with tailing + cmpiOps = multiRootTopologicalSort(cmpiOps); + llvm::SmallVector sortedCmpiOps; + for (auto it = cmpiOps.begin(); it != cmpiOps.end(); ++it) { + sortedCmpiOps.emplace_back(*it); + } + llvm::SmallVector> ifBlockTrees; + for (int i = 0; i < sortedCmpiOps.size(); ++i) { + Operation *op = sortedCmpiOps[i]; + OpBuilder builder(op); + mlir::Block *block = op->getBlock(); + auto loc = builder.getUnknownLoc(); + // ops to be moved into ifblock and later earsed + llvm::SetVector opsToMoveAndErase; + // Get the ops that from current to the end of the block + mlir::Operation *terminator = block->getTerminator(); + auto it = op->getIterator(); + ++it; + for (; &*it != terminator; ++it) { + opsToMoveAndErase.insert(&*it); + } + if (auto yieldOp = dyn_cast(terminator)) { + if (yieldOp.getResults().size() > 0) { + opsToMoveAndErase.insert(terminator); + } + } + auto sortedOpsToMoveAndErase = sortOpTreeBwd(opsToMoveAndErase); + ifBlockTrees.emplace_back(sortedOpsToMoveAndErase); + // Create scf::IfOp + builder.setInsertionPoint(terminator); + mlir::scf::IfOp newIfOp; + if (auto cmpiOp = dyn_cast(op)) { + auto cond = builder.create( + loc, builder.getI1Type(), builder.getI32IntegerAttr(0), + cmpiOp.getResult()); + if (auto yieldOp = dyn_cast(terminator)) { + if (yieldOp.getResults().size() > 0) { + newIfOp = builder.create( + loc, yieldOp.getResults().getType(), cond, + /*withElseRegion=*/true); + builder.setInsertionPointToStart(newIfOp.thenBlock()); + } else { + newIfOp = builder.create(loc, cond, + /*withElseRegion=*/false); + builder.setInsertionPointToStart(newIfOp.thenBlock()); + } + } else if (auto funcRetureOp = dyn_cast(terminator)) { + newIfOp = builder.create(loc, cond, + /*withElseRegion=*/false); + builder.setInsertionPointToStart(newIfOp.thenBlock()); + } + } else { + llvm_unreachable("cmpiOp only is mlir::arith::CmpIOp"); + } + + mlir::IRMapping mapping; + for (int j = sortedOpsToMoveAndErase.size() - 1; j >= 0; --j) { + auto bodyOp = sortedOpsToMoveAndErase[j]; + auto newBodyOp = builder.clone(*bodyOp, mapping); // Clone bodyOps + // TODO[dyq]: Open core tiling pass + // if (auto reduceOp = dyn_cast(bodyOp)) { + // // for shared memory init + // auto newReduceOp = cast(newBodyOp); + // ReduceOpHelper helper(reduceOp); + // ReduceOpHelper newHelper(newReduceOp); + // newHelper.setReduceId(helper.getReduceId()); + // newHelper.setOriginResShape(helper.getOriginResShape()); + // } else + if (auto storeOp = dyn_cast(bodyOp)) { + SMHelper helper(storeOp); + SMHelper newHelper(newBodyOp); + newHelper.setOffset(helper.getOffset()); + } else if (auto sm2gmOp = dyn_cast(bodyOp)) { + SMHelper helper(sm2gmOp); + SMHelper newHelper(newBodyOp); + newHelper.setOffset(helper.getOffset()); + } + } + + if (auto yieldOp = dyn_cast(terminator)) { + if (yieldOp.getResults().size() > 0) { + builder.setInsertionPointToStart(newIfOp.elseBlock()); + auto block = yieldOp->getBlock(); + auto *region = block->getParent(); + auto *parentOp = region->getParentOp(); + if (auto parentForOp = dyn_cast(parentOp)) { + // Create YiledOp For newIfOp + if (parentForOp.getRegionIterArgs().size() > 0) { + builder.create( + loc, parentForOp.getRegionIterArgs()); + } + // Create YiledOp For YieldOp's ParentOp + builder.setInsertionPointToEnd(parentForOp.getBody()); + auto ifResults = newIfOp.getResults(); + builder.create(loc, ifResults); + } else if (auto parentIfOp = dyn_cast(parentOp)) { + // Create YiledOp For newIfOp + auto elseYieldOp = parentIfOp.elseYield(); + auto elseResults = elseYieldOp.getResults(); + builder.create(loc, elseResults); + // Create YiledOp For YieldOp's ParentOp + builder.setInsertionPointToEnd(parentIfOp.getBody()); + builder.create(loc, newIfOp.getResults()); + } else { + llvm_unreachable("Unknown Mask YiledOp Pattern"); + } + } + } + + // update next iteration of sortedCmpiOps + for (int k = i + 1; k < sortedCmpiOps.size(); ++k) { + auto mappedMaskVal = + mapping.lookupOrDefault(sortedCmpiOps[k]->getResult(0)); + sortedCmpiOps[k] = mappedMaskVal.getDefiningOp(); + } + // Erase Old Ops + for (auto op : sortedOpsToMoveAndErase) { + if (op->getParentOp() != nullptr) { + op->erase(); + } + } + } + + // Step 3. Block Read/Write Optimizations Degradation + blockAccessOptDeGrad(m); + + // Step 4. Add AtomicOp Simulation Conditon + addAtomicSimulationCond(m); + } + +private: + int32_t maskValue = -1; +}; + +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/MemoryAsync.cpp b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/MemoryAsync.cpp new file mode 100644 index 000000000..4eb224fdd --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/MemoryAsync.cpp @@ -0,0 +1,111 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" + +#define DEBUG_TYPE "tritonxpu-memory-async" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace triton { +namespace xpu { + +#define GEN_PASS_DEF_TRITONXPUMEMORYASYNC +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +class TritonXPUMemoryAsyncPass + : public impl::TritonXPUMemoryAsyncBase { +public: + using impl::TritonXPUMemoryAsyncBase< + TritonXPUMemoryAsyncPass>::TritonXPUMemoryAsyncBase; + + TritonXPUMemoryAsyncPass() = default; + TritonXPUMemoryAsyncPass(bool dumpFlag) { this->dumpFlag = dumpFlag; } + + void loadOpAsyncCheck(triton::xpu::LoadOp loadOp_1, + triton::xpu::LoadOp loadOp_2) { + OpBuilder builder(loadOp_1); + + if (loadOp_1->getBlock() == loadOp_2->getBlock() && + loadOp_1->isBeforeInBlock(loadOp_2)) { + auto gm2lmOp_1 = cast(loadOp_1->getPrevNode()); + gm2lmOp_1->setAttr("async", builder.getBoolAttr(true)); + loadOp_1->moveAfter(loadOp_2); + if (dumpFlag) + LLVM_DEBUG(llvm::dbgs() << "Memory Async Optimization Hit!\n"); + } + } + + void runOnOperation() override { + ModuleOp mod = getOperation(); + + mod->walk([&](triton::xpu::GM2LMOp gm2lmOp_1) { + // Pruning + if (gm2lmOp_1.getAsync()) + return; + + auto loadOp_1 = cast(gm2lmOp_1->getNextNode()); + + auto loadop_user_begin = loadOp_1->user_begin(); + auto loadop_user_end = loadOp_1->user_end(); + + if (!loadOp_1->hasOneUse()) + return; + + llvm::TypeSwitch(*loadop_user_begin) + .Case([&](auto vBinOp) { + auto lhsOp = + vBinOp.getLhs().template getDefiningOp(); + auto rhsOp = + vBinOp.getRhs().template getDefiningOp(); + + if (lhsOp && rhsOp) { + triton::xpu::LoadOp loadOp_2 = lhsOp == loadOp_1 ? rhsOp : lhsOp; + loadOpAsyncCheck(loadOp_1, loadOp_2); + } + }) + .Case([&](auto binOp) { + auto lhsOp = + binOp.getLhs().template getDefiningOp(); + auto rhsOp = + binOp.getRhs().template getDefiningOp(); + + if (lhsOp && rhsOp) { + triton::xpu::LoadOp loadOp_2 = lhsOp == loadOp_1 ? rhsOp : lhsOp; + loadOpAsyncCheck(loadOp_1, loadOp_2); + } + }) + .Case( + [&](auto selectOp) { + auto tv = selectOp.getTrueValue() + .template getDefiningOp(); + auto fv = selectOp.getFalseValue() + .template getDefiningOp(); + + if (tv && fv) { + triton::xpu::LoadOp loadOp_2 = tv == loadOp_1 ? fv : tv; + loadOpAsyncCheck(loadOp_1, loadOp_2); + } + }) + .Default([&](auto &op) { + if (dumpFlag) { + LLVM_DEBUG({ + op->dump(); + llvm::dbgs() << "Unsupport Op For Memory Async Optimization\n"; + }); + } + }); + }); + } +}; + +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/OffsetAnalysis.cpp b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/OffsetAnalysis.cpp new file mode 100644 index 000000000..95f3dfa19 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/OffsetAnalysis.cpp @@ -0,0 +1,1290 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// TODO: Pass Description +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "triton/Analysis/UtilityXPU.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" + +#define DEBUG_TYPE "tritonxpu-offset-analysis" + +namespace mlir { + +using OffsetStateTransitionTable = + std::map>; + +namespace triton { +namespace xpu { + +#define GEN_PASS_DEF_TRITONXPUOFFSETANALYSIS +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +struct TritonXPUOffsetAnalysisPass + : public impl::TritonXPUOffsetAnalysisBase { + +public: + using impl::TritonXPUOffsetAnalysisBase< + TritonXPUOffsetAnalysisPass>::TritonXPUOffsetAnalysisBase; + + struct MockData { + Operation *mockOp; + int mockVal; + SmallVector mockVals; + + MockData(Operation *op, int val, SmallVector &vals) { + mockOp = op; + mockVal = val; + mockVals = vals; + } + + std::string getToken() { + std::string prefix = ""; + TypeSwitch(mockOp) + .Case( + [&](auto indexCastOp) { prefix = "index_"; }) + .Case( + [&](auto threadIdOp) { prefix = "coreId_"; }) + .Case( + [&](auto getProgramIdOp) { prefix = "clusterId_"; }) + .Case([&](auto gm2lmOp) { prefix = "gm2lm_"; }); + return prefix + std::to_string(mockVal); + } + }; + + template ::value, bool> = true> + void legalizeOffset(T memoryOp) { + Value ptr = memoryOp.getPtr(); + Operation *ptrOp = findDefOpBwd(ptr); + if (ptrOp) { + if (auto memAddPtrOp = dyn_cast(ptrOp)) { + SetVector ptrOpChain; + getPtrChainBwd(ptrOpChain, memAddPtrOp, memAddPtrOp); + if (ptrOpChain.size() < 2) { + return; + } + Value addIRes = memAddPtrOp.getOffset(); + OpBuilder builder(memAddPtrOp); + auto loc = memAddPtrOp.getLoc(); + for (auto op : ptrOpChain) { + if (auto addPtrOp = dyn_cast(op)) { + auto addPtrDefOp = addPtrOp.getPtr().getDefiningOp(); + if (addPtrDefOp) { + if (auto preAddPtrOp = dyn_cast(addPtrDefOp)) { + if (getElementTypeOrSelf(addIRes).getIntOrFloatBitWidth() == + 32 && + getElementTypeOrSelf(preAddPtrOp.getOffset()) + .getIntOrFloatBitWidth() == 64) { + auto extIntOp = builder.create( + loc, preAddPtrOp.getOffset().getType(), addIRes); + auto addIOp = builder.create( + loc, preAddPtrOp.getOffset().getType(), extIntOp, + preAddPtrOp.getOffset()); + addIRes = addIOp.getResult(); + } else if (getElementTypeOrSelf(addIRes) + .getIntOrFloatBitWidth() == 64 && + getElementTypeOrSelf(preAddPtrOp.getOffset()) + .getIntOrFloatBitWidth() == 32) { + auto extIntOp = builder.create( + loc, addIRes.getType(), preAddPtrOp.getOffset()); + auto addIOp = builder.create( + loc, addIRes.getType(), addIRes, extIntOp); + addIRes = addIOp.getResult(); + } else { + auto addIOp = builder.create( + loc, addPtrOp.getOffset().getType(), addIRes, + preAddPtrOp.getOffset()); + addIRes = addIOp.getResult(); + } + } else { + memAddPtrOp.setOperand(0, addPtrOp.getPtr()); + } + } else { + memAddPtrOp.setOperand(0, addPtrOp.getPtr()); + } + } + } + memAddPtrOp.setOperand(1, addIRes); + } + } + } + + void getPtrChainBwd(llvm::SetVector &opChain, + triton::AddPtrOp startOp, Operation *op) { + if (!op) { + return; + } + opChain.insert(op); + + int noDefCnt = 0; + for (auto operand : op->getOperands()) { + if (!operand.getDefiningOp()) { + noDefCnt++; + } + } + + if (isa(op) || noDefCnt == op->getNumOperands()) { + return; + } + + if (auto addptrOp = dyn_cast(op)) { + if (startOp.getResult().getType() == addptrOp.getResult().getType()) { + getPtrChainBwd(opChain, startOp, addptrOp.getPtr().getDefiningOp()); + } + } + return; + } + + void getOpChainBwdBFS(llvm::SetVector &opChain, + Operation *startOp) { + if (!startOp) { + return; + } + std::queue opQueue; + opQueue.push(startOp); + + while (!opQueue.empty()) { + Operation *currentOp = opQueue.front(); + opQueue.pop(); + + if (!opChain.insert(currentOp)) { + continue; + } + + if (isa(currentOp) || + isa(currentOp) || + isa(currentOp) || + (isa(currentOp) && currentOp != startOp)) { + continue; + } + + if (isa(currentOp)) { + opQueue.push(currentOp->getOperands()[0].getDefiningOp()); + continue; + } + + for (auto operand : currentOp->getOperands()) { + if (Operation *operandOp = operand.getDefiningOp()) { + opQueue.push(operandOp); + } + } + } + } + + llvm::SetVector + sortByLine(llvm::SetVector &opChain) { + auto compareOpsByLine = [&](mlir::Operation *op1, mlir::Operation *op2) { + return op2Line[op1] > op2Line[op2]; + }; + + llvm::SmallVector opChainVec; + for (auto op : opChain) { + opChainVec.emplace_back(op); + } + llvm::sort(opChainVec, compareOpsByLine); + llvm::SetVector sortedOpChain; + for (auto op : opChainVec) { + sortedOpChain.insert(op); + } + return sortedOpChain; + } + + void dumpOpChain(SetVector &opChain) { + for (auto it = opChain.rbegin(), eit = opChain.rend(); it != eit; ++it) { + Operation *op = *it; + if (op) + op->dump(); + } + } + + SmallVector getMockDataItems(SetVector opChain) { + auto getProgramIdMockVals = []() { + SmallVector mockVals(12); + std::iota(mockVals.begin(), mockVals.end(), 0); + return mockVals; + }; + + auto getThreadIdMockVals = []() { + SmallVector mockVals(/*core_num*/ 64); + std::iota(mockVals.begin(), mockVals.end(), 0); + return mockVals; + }; + + auto getIndexMockVals = [](Operation *indexCastOp) { + // Get forOp + Value arg = indexCastOp->getOperand(0); // only have one operand + BlockArgument blockArg = mlir::dyn_cast(arg); + Block *block = blockArg.getOwner(); + auto forOp = cast(block->getParentOp()); + + // Get Induction Vars + auto lowerBoundOp = + forOp.getLowerBound().getDefiningOp(); + auto stepOp = forOp.getStep().getDefiningOp(); + auto upperBoundOp = + forOp.getUpperBound().getDefiningOp(); + + // Get mockVals + SmallVector mockVals; + if (lowerBoundOp && stepOp && upperBoundOp) { + auto lowerBoundVal = mlir::cast(lowerBoundOp.getValue()) + .getValue() + .getZExtValue(); + auto stepVal = mlir::cast(stepOp.getValue()) + .getValue() + .getZExtValue(); + auto upperBoundVal = mlir::cast(upperBoundOp.getValue()) + .getValue() + .getZExtValue(); + + for (auto i = lowerBoundVal; i < upperBoundVal; i += stepVal) + mockVals.emplace_back(i); + } + + return mockVals; + }; + + auto getGM2LMOpMockVals = []() { + SmallVector mockVals(1, 1); + return mockVals; + }; + + SmallVector mockDataItems; + for (auto it = opChain.rbegin(), eit = opChain.rend(); it != eit; ++it) { + TypeSwitch(*it) + .Case([&](auto indexCastOp) { + SmallVector mockVals = getIndexMockVals(indexCastOp); + mockDataItems.emplace_back(MockData(indexCastOp, 0, mockVals)); + }) + .Case([&](auto threadIdOp) { + SmallVector mockVals = getThreadIdMockVals(); + mockDataItems.emplace_back(MockData(threadIdOp, 0, mockVals)); + }) + .Case([&](auto getProgramIdOp) { + SmallVector mockVals = getProgramIdMockVals(); + mockDataItems.emplace_back(MockData(getProgramIdOp, 0, mockVals)); + }) + .Case([&](auto gm2lmOp) { + SmallVector mockVals = getGM2LMOpMockVals(); + mockDataItems.emplace_back(MockData(gm2lmOp, 0, mockVals)); + }); + } + return mockDataItems; + }; + + SmallVector + constOpCalFunc(arith::ConstantOp op, + DenseMap> &op2OffsetVal) { + auto type = op.getResult().getType(); + int intValue; + + if (auto tensorType = mlir::dyn_cast(type)) { // tensor + auto shape = tensorType.getShape(); + unsigned rank = shape.size(); + unsigned numElems = shape[rank - 1]; + + auto denseAttr = mlir::dyn_cast(op.getValue()); + + auto elementType = tensorType.getElementType(); + if (elementType.isF32()) { + intValue = *denseAttr.getValues().begin(); + } else if (elementType.isInteger(32)) { + intValue = *denseAttr.getValues().begin(); + } else if (elementType.isInteger(64)) { + intValue = *denseAttr.getValues().begin(); + } else if (elementType.isInteger(1)) { + intValue = *denseAttr.getValues().begin(); + } else { + llvm_unreachable( + "[Offset Analysis] Unsupported Element Type in ConstOp"); + } + + return SmallVector(numElems, intValue); + } + + if (type.isF32()) { + auto doubleVal = mlir::cast(op.getValue()).getValueAsDouble(); + intValue = static_cast(doubleVal); + } else { + intValue = + mlir::cast(op.getValue()).getValue().getZExtValue(); + } + + return SmallVector(1, intValue); + } + + SmallVector + indexCastOpCalFunc(arith::IndexCastOp op, + DenseMap> &op2OffsetVal, + int mockVal = 0) { + return SmallVector(1, mockVal); + } + + SmallVector + threadIdOpCalFunc(mlir::gpu::ThreadIdOp op, + DenseMap> &op2OffsetVal, + int mockVal = 0) { + return SmallVector(1, mockVal); + } + + SmallVector + getProgramIdOpCalFunc(triton::GetProgramIdOp op, + DenseMap> &op2OffsetVal, + int mockVal = 0) { + return SmallVector(1, mockVal); + } + + SmallVector + xpuGm2lmOpCalFunc(triton::xpu::GM2LMOp op, + DenseMap> &op2OffsetVal, + int mockVal = 0) { + + auto offsetState = static_cast(op.getOffsetState()); + int32_t op_lrie = op.getLrie(); + assert((offsetState == OffsetState::DiscreteSame || op_lrie > 1) && + "Mocked GM2LMOp must be DiscreteSame"); + + auto result = op.getResult(); + auto resTy = mlir::dyn_cast(result.getType()); + auto resShape = resTy.getShape(); + unsigned numElems = resShape[resShape.size() - 1]; + + return SmallVector(numElems, mockVal); + } + + SmallVector + xpuLoadOpCalFunc(triton::xpu::LoadOp op, + DenseMap> &op2OffsetVal, + int mockVal = 0) { + auto ptr = op.getPtr(); + auto ptrOp = ptr.getDefiningOp(); + return op2OffsetVal[ptrOp]; + } + + SmallVector + makeRangeOpCalFunc(triton::MakeRangeOp op, + DenseMap> &op2OffsetVal) { + + auto start = op.getStart(); + auto end = op.getEnd(); + + SmallVector res; + for (size_t i = start; i < end; ++i) { + res.push_back(i); + } + return res; + } + + SmallVector + splatOpCalFunc(triton::SplatOp op, + DenseMap> &op2OffsetVal) { + + auto operand = op.getOperand(); + Operation *operandOp; + SmallVector prevValue; + if (mlir::isa(operand)) { + prevValue.push_back(0); + } else { + operandOp = operand.getDefiningOp(); + assert(operandOp && op2OffsetVal.find(operandOp) != op2OffsetVal.end() && + "Operands must be present in op2OffsetVal map"); + prevValue = op2OffsetVal[operandOp]; + } + + assert(prevValue.size() == 1 && "[splatOpCalFunc] Only support 1->N splat"); + + auto src = op.getSrc(); + auto res = op.getResult(); + auto srcTy = mlir::dyn_cast(src.getType()); + auto resTy = mlir::dyn_cast(res.getType()); + auto resShape = resTy.getShape(); + + unsigned rank = resShape.size(); + unsigned numElems = resShape[rank - 1]; + + return SmallVector(numElems, prevValue[0]); + } + + SmallVector + expandDimsOpCalFunc(triton::ExpandDimsOp op, + DenseMap> &op2OffsetVal) { + + auto operand = op.getOperand(); + auto operandOp = operand.getDefiningOp(); + assert(operandOp && op2OffsetVal.find(operandOp) != op2OffsetVal.end() && + "Operands must be present in op2OffsetVal map"); + + auto src = op.getSrc(); + auto res = op.getResult(); + auto srcTy = mlir::dyn_cast(src.getType()); + auto resTy = mlir::dyn_cast(res.getType()); + auto srcShape = srcTy.getShape(); + auto resShape = resTy.getShape(); + + if (srcShape.size() == 1 && resShape.size() == 2 && + resShape[resShape.size() - 1] == 1) { // xmask make_range + return SmallVector(1, 0); + } + + return op2OffsetVal[operandOp]; + } + + SmallVector + broadcastOpCalFunc(triton::xpu::BroadcastOp op, + DenseMap> &op2OffsetVal) { + + auto operand = op.getOperand(); + auto operandOp = operand.getDefiningOp(); + assert(operandOp && op2OffsetVal.find(operandOp) != op2OffsetVal.end() && + "Operands must be present in op2OffsetVal map"); + + auto src = op.getSrc(); + auto res = op.getResult(); + auto srcTy = mlir::dyn_cast(src.getType()); + auto resTy = mlir::dyn_cast(res.getType()); + auto srcShape = srcTy.getShape(); + auto resShape = resTy.getShape(); + + if (!srcTy && resTy) { // [f32 -> 1xf32] + unsigned numElems = resShape[resShape.size() - 1]; + assert(op2OffsetVal[operandOp].size() == 1 && + "[broadcastOpCalFunc] Error Input Shape [f32 -> " + "1xf32]"); + return SmallVector(numElems, op2OffsetVal[operandOp][0]); + } + + if (srcShape.size() != resShape.size()) { + return op2OffsetVal[operandOp]; + } + + int broadNum = 0; + for (size_t i = 0; i < srcShape.size(); ++i) { + if (srcShape[i] != resShape[i]) { + if (++broadNum > 1) { // [1x1xf32 -> NxNxf32] + llvm_unreachable("[broadcastOpCalFunc] Unsupported broadcast 2 dims"); + } + } + } + + for (size_t i = 0; i < srcShape.size(); ++i) { + if (srcShape[i] != resShape[i]) { + if (srcShape[i] == 1) { // [1x1xf32 -> 1xNxf32] + unsigned numElems = resShape[resShape.size() - 1]; + if (i == srcShape[srcShape.size() - 1]) + return SmallVector(numElems, op2OffsetVal[operandOp][0]); + else + return op2OffsetVal[operandOp]; + } else { // [1x2xf32 -> 1xNxf32] + llvm_unreachable("[broadcastOpCalFunc] Only support broadcast 1->N"); + } + } + } + + return op2OffsetVal[operandOp]; + } + + SmallVector xpuConvertLayoutOpCalFunc( + triton::xpu::ConvertLayoutOp op, + DenseMap> &op2OffsetVal) { + + auto operand = op.getOperand(); + auto operandOp = operand.getDefiningOp(); + assert(operandOp && op2OffsetVal.find(operandOp) != op2OffsetVal.end() && + "Operands must be present in op2OffsetVal map"); + + return op2OffsetVal[operandOp]; + } + + template + SmallVector + unaryOpCalFunc(T op, DenseMap> &op2OffsetVal) { + auto operand = op.getOperand(); + + auto unaryOp = operand.getDefiningOp(); + if (!unaryOp) { + LLVM_DEBUG(llvm::dbgs() << "Operand Must Be Defined by Operations\n"); + findUnsupportedOp = true; + return {}; + } + assert(op2OffsetVal.find(unaryOp) != op2OffsetVal.end() && + "Operand Must Be Present in op2OffsetVal map"); + + SmallVector operandVal = op2OffsetVal[unaryOp]; + + SmallVector res; + if constexpr (std::is_same_v) { + for (size_t i = 0; i < operandVal.size(); ++i) { + res.push_back(operandVal[i]); + } + } else { + llvm_unreachable("Unknown binOpCalFunc Type"); + } + return res; + } + + template + SmallVector + binOpCalFunc(T op, DenseMap> &op2OffsetVal) { + auto operands = op.getOperands(); + assert(operands.size() == 2 && + "Expected binary operation with two operands"); + + auto lhsOp = operands[0].getDefiningOp(); + auto rhsOp = operands[1].getDefiningOp(); + if (!lhsOp || !rhsOp) { + findUnsupportedOp = true; + return {}; + } + assert(op2OffsetVal.find(lhsOp) != op2OffsetVal.end() && + op2OffsetVal.find(rhsOp) != op2OffsetVal.end() && + "Operands must be present in op2OffsetVal map"); + + SmallVector lhs = op2OffsetVal[lhsOp]; + SmallVector rhs = op2OffsetVal[rhsOp]; + if (lhs.size() != rhs.size()) { + LLVM_DEBUG(llvm::dbgs() << "lhs.size(): " << lhs.size() + << "/ rhs.size(): " << rhs.size() << "\n"); + } + assert(lhs.size() == rhs.size() && + "Two operands size must be equal"); // TODO: splat for binOp + + SmallVector res; + if constexpr (std::is_same_v) { + for (size_t i = 0; i < lhs.size(); ++i) { + res.push_back(lhs[i] + rhs[i]); + } + } else if constexpr (std::is_same_v) { + for (size_t i = 0; i < lhs.size(); ++i) { + res.push_back(lhs[i] - rhs[i]); + } + } else if constexpr (std::is_same_v) { + for (size_t i = 0; i < lhs.size(); ++i) { + if (rhs[i] == 0) { + LLVM_DEBUG(llvm::dbgs() + << "Div 0, Return OffsetState::Unknown for Protection\n"); + findUnsupportedOp = true; + return {}; + } + res.push_back(lhs[i] / rhs[i]); + } + } else if constexpr (std::is_same_v) { + for (size_t i = 0; i < lhs.size(); ++i) { + res.push_back(lhs[i] * rhs[i]); + } + } else if constexpr (std::is_same_v) { + for (size_t i = 0; i < lhs.size(); ++i) { + if (rhs[i] == 0) { + LLVM_DEBUG(llvm::dbgs() + << "Rem 0, Return OffsetState::Unknown for Protection\n"); + findUnsupportedOp = true; + return {}; + } + res.push_back(lhs[i] % rhs[i]); + } + } else { + llvm_unreachable("Unknown binOpCalFunc Type"); + } + return res; + } + + SmallVector getOffset(const SetVector &opChain, + Operation *offsetDefineOp, + const SmallVector &mockDataItems) { + auto hasDynamicInput = [](Operation *op) -> bool { + for (auto operand : op->getOperands()) { + if (mlir::isa(operand)) { + continue; + } + auto operandOp = operand.getDefiningOp(); + if (!operandOp) { + return true; + } + } + return false; + }; + + bool findGM2LMOp = + llvm::find_if(opChain, [](Operation *op) { + if (auto gm2lmOp = dyn_cast(op)) { + auto offsetState = + static_cast(gm2lmOp.getOffsetState()); + // We can mock DiscreteSame OffsetState + return offsetState != OffsetState::DiscreteSame; + } + return false; + }) != opChain.end(); + + // Step 1. Pruning + if (opChain.empty() || findGM2LMOp) + return {}; + + // Step 2. collect mock value + DenseMap op2MockVal; + for (auto mockData : mockDataItems) + op2MockVal[mockData.mockOp] = mockData.mockVal; + + // Step 3. get offset result + DenseMap> op2OffsetVal; + for (auto it = opChain.rbegin(), eit = opChain.rend(); it != eit; ++it) { + if (op2OffsetVal.find(*it) == op2OffsetVal.end()) { + TypeSwitch(*it) + .Case([&](auto indexCastOp) { + auto mockVal = op2MockVal[indexCastOp]; + op2OffsetVal[indexCastOp] = + indexCastOpCalFunc(indexCastOp, op2OffsetVal, mockVal); + }) + .Case([&](auto threadIdOp) { + auto mockVal = op2MockVal[threadIdOp]; + op2OffsetVal[threadIdOp] = + threadIdOpCalFunc(threadIdOp, op2OffsetVal, mockVal); + }) + .Case([&](auto getProgramIdOp) { + auto mockVal = op2MockVal[getProgramIdOp]; + op2OffsetVal[getProgramIdOp] = + getProgramIdOpCalFunc(getProgramIdOp, op2OffsetVal, mockVal); + }) + .Case([&](auto xpuGm2lmOp) { + auto mockVal = op2MockVal[xpuGm2lmOp]; + op2OffsetVal[xpuGm2lmOp] = + xpuGm2lmOpCalFunc(xpuGm2lmOp, op2OffsetVal, mockVal); + }) + .Case([&](auto xpuLoadOp) { + auto mockVal = op2MockVal[xpuLoadOp]; + op2OffsetVal[xpuLoadOp] = + xpuLoadOpCalFunc(xpuLoadOp, op2OffsetVal, mockVal); + }) + .Case([&](auto constOp) { + op2OffsetVal[constOp] = constOpCalFunc(constOp, op2OffsetVal); + }) + .Case([&](auto unaryOp) { + op2OffsetVal[unaryOp] = unaryOpCalFunc(unaryOp, op2OffsetVal); + }) + .Case([&](auto binOp) { + op2OffsetVal[binOp] = binOpCalFunc(binOp, op2OffsetVal); + }) + .Case([&](auto makeRangeOp) { + op2OffsetVal[makeRangeOp] = + makeRangeOpCalFunc(makeRangeOp, op2OffsetVal); + }) + .Case([&](auto splatOp) { + if (hasDynamicInput(splatOp)) { + findUnsupportedOp = true; + return; + } + op2OffsetVal[splatOp] = splatOpCalFunc(splatOp, op2OffsetVal); + }) + .Case([&](auto expandDimsOp) { + op2OffsetVal[expandDimsOp] = + expandDimsOpCalFunc(expandDimsOp, op2OffsetVal); + }) + .Case([&](auto broadcastOp) { + if (hasDynamicInput(broadcastOp)) { + findUnsupportedOp = true; + return; + } + op2OffsetVal[broadcastOp] = + broadcastOpCalFunc(broadcastOp, op2OffsetVal); + }) + .Case([&](auto convertLayoutOp) { + op2OffsetVal[convertLayoutOp] = + xpuConvertLayoutOpCalFunc(convertLayoutOp, op2OffsetVal); + }) + .Default([&](auto &op) { + findUnsupportedOp = true; + LLVM_DEBUG(llvm::dbgs() + << "[OffsetState]: Unsupported Operation Type: " + << op->getName().getStringRef() + << ". Return OffsetState::Unknown for Protection.\n"); + }); + if (findUnsupportedOp) { + return {}; + } + } + } + + assert(op2OffsetVal.find(offsetDefineOp) != op2OffsetVal.end() && + "Operands Must Be Present in op2OffsetVal Map\n"); + + SmallVector res = op2OffsetVal[offsetDefineOp]; + return res; + } + + void getAllOffset( + std::unordered_map> &allOffsetResults, + SmallVector &sortedTokens, + SmallVector &mockDataItems, size_t curOpIndex, + std::string token, Operation *offsetDefineOp, + const SetVector &opChain) { + + if (curOpIndex == mockDataItems.size()) { + sortedTokens.emplace_back(token); + allOffsetResults[token] = + getOffset(opChain, offsetDefineOp, mockDataItems); + return; + } + + for (int val : mockDataItems[curOpIndex].mockVals) { + mockDataItems[curOpIndex].mockVal = val; + getAllOffset(allOffsetResults, sortedTokens, mockDataItems, + curOpIndex + 1, + token + mockDataItems[curOpIndex].getToken() + "-", + offsetDefineOp, opChain); + } + } + + OffsetState memoryStateTransfer(const OffsetState &state1, + const OffsetState &state2) { + auto state1_it = offsetstate_transition_table.find(state1); + if (state1_it != offsetstate_transition_table.end()) { + auto state2_it = state1_it->second.find(state2); + if (state2_it != state1_it->second.end()) { + return state2_it->second; + } + } + llvm_unreachable("Invalid OffsetState Transition"); + return OffsetState::Unknown; + } + + OffsetState checkOffset(const SmallVector &res, + const unsigned &numElems) { + + auto isEquallyStride = [](const SmallVector &res, + size_t numElems) -> bool { + int step = res[1] - res[0]; + for (size_t start = 0; start < res.size(); start += numElems) { + for (size_t i = 2; i < numElems; ++i) { + if (start + i < res.size()) + if ((res[start + i] - res[start + i - 1]) != step) { + return false; + } + } + } + return true; + }; + + if (numElems == 1 || res.empty()) + return OffsetState::Unknown; + + SmallVector offsets(res.size()); + bool multiBank = false; + for (size_t start = 0; start < res.size(); start += numElems) { + for (size_t i = 0; i < numElems; ++i) { + if (start + i < res.size()) { + offsets[start + i] = res[start + i] - res[start]; + // check online + if (offsets[start + i] >= numElems) { + multiBank = true; + } + if (offsets[start + i] < 0) { + LLVM_DEBUG(llvm::dbgs() + << "[OffsetState]: The 0th Address Is Not the Beginning " + "of the Bank.\n"); + fixedStride = -1; + return OffsetState::Unknown; + } + } + } + } + + if (multiBank) { + if (isEquallyStride(res, numElems)) { + fixedStride = res[1] - res[0]; + LLVM_DEBUG(llvm::dbgs() + << "[OffsetState]: Addresses Are Not in the Same Bank, " + "but Have Fixed Stride-" + << fixedStride << ".\n"); + return OffsetState::Unknown; + } else { + LLVM_DEBUG(llvm::dbgs() + << "[OffsetState]: Addresses Are Not in the Same Bank.\n"); + fixedStride = -1; + return OffsetState::Unknown; + } + } + + bool hasDiscreateSame = false; + bool hasDiscrete = false; + for (size_t start = 0; start < res.size(); start += numElems) { + SmallVector coreOffset; + for (size_t i = 0; i < numElems; ++i) { + if (start + i < res.size()) { + if (offsets[start + i] != 0 && offsets[start + i] != i) + return OffsetState::Discrete; + + coreOffset.push_back(offsets[start + i]); + } + } + auto allZeros = std::all_of(coreOffset.begin(), coreOffset.end(), + [](int num) { return num == 0; }); + if (allZeros) { + hasDiscreateSame = true; + } else { + hasDiscrete = true; + } + } + + if (hasDiscreateSame && hasDiscrete) + return OffsetState::Discrete; + + fixedStride = res[1] - res[0]; + return offsets[1] == 1 ? OffsetState::Continuous + : OffsetState::DiscreteSame; + } + + void lrieDiscreteSameAnalysis(const SmallVector &res) { + analysisFlag = true; + + if (res.size() == 0) + return; + + SmallVector seq_lens; + + auto getSeqLens = [&]() { + int cur_val = res[0]; + int count = 1; + + for (size_t i = 1; i < res.size(); ++i) { + if (res[i] == cur_val) { + ++count; + } else { + seq_lens.push_back(count); + cur_val = res[i]; + count = 1; + } + } + seq_lens.push_back(count); + }; + // Step 1. Get All Same Sequence Lengths + getSeqLens(); + + // for (int i = 0; i < seq_lens.size(); ++i) { + // LLVM_DEBUG(llvm::dbgs() << "seq_lens[" << i << "]: " << seq_lens[i] << + // "\n"); + // } + + // Step 2. Get Sequence Lengths' GCD + auto gcd = [](int a, int b) { + while (b != 0) { + int t = b; + b = a % b; + a = t; + } + return a; + }; + + lrie = seq_lens[0]; + for (size_t i = 1; i < seq_lens.size(); i++) { + lrie = gcd(lrie, seq_lens[i]); + if (lrie == 1) + break; + } + + if (dumpFlag && lrie > 1) { + LLVM_DEBUG(llvm::dbgs() << "lrie: " << lrie << "\n"); + } + + if (lrie > 256 && lrie % 256 == 0) + lrie = 256; // Control LM BufferSize to Avoid Memory Exceed + else + assert("lrie is not the multiple of 256"); + + return; + } + + OffsetState locallyContinity( + std::unordered_map> &allOffsetResults, + SmallVector &sortedTokens, const int64_t numElems) { + + auto isLocallyContinuous = [](const SmallVector &res, + const int64_t numElems, int64_t &rowLen, + int64_t &rowStride) -> bool { + if (res.size() < 2) { + return false; + } + int64_t step = res[1] - res[0]; + if (step != 1) { + return false; + } + int64_t currRowLen = 2; + int64_t currRowStride = 1; + bool isFirst = true; + for (int64_t i = 2; i < res.size(); i++) { + if (res[i] - res[i - 1] == 1) { + currRowLen++; + } else { + currRowStride = res[i] - res[i - 1] + currRowLen - 1; + if (currRowStride < 0) { + return false; + } + + if (isFirst) { + rowLen = currRowLen; + rowStride = currRowStride; + isFirst = false; + } else { + if (rowStride == -1 && currRowStride % rowLen != 0) { + rowLen = -1; + return false; + } + if (currRowStride != rowStride) { + rowStride = -1; + } + + auto gcd = [](int64_t a, int64_t b) { + while (b != 0) { + int t = b; + b = a % b; + a = t; + } + return a; + }; + rowLen = gcd(rowLen, currRowLen); + } + currRowLen = 1; + currRowStride = 1; + } + } + bool _isLocallyContinuous = false; + if (rowLen >= 2 && rowLen % numElems != 0) { + _isLocallyContinuous = true; + } + return _isLocallyContinuous; + }; + + SmallVector _allOffsets; + for (auto token : sortedTokens) { + for (auto offset : allOffsetResults[token]) { + _allOffsets.emplace_back(offset); + } + } + if (isLocallyContinuous(_allOffsets, numElems, rowLen, rowStride)) { + LLVM_DEBUG( + llvm::dbgs() + << "[OffsetState]: The Address is Locally Continuous, rowLen is " + << rowLen << ", rowStride is " << rowStride << "\n"); + return OffsetState::LocallyContinuous; + } + return OffsetState::Unknown; + } + + // -1 for Unknown + // 0 for DiscreteSame + // 1 for Continuous + // 2 for Discrete + // 3 for LocallyContinuous + template ::value, bool> = true> + OffsetState getOffsetState(T memoryOp) { + Value ptr = memoryOp.getPtr(); + Operation *ptrOp = ptr.getDefiningOp(); + if (!ptrOp) { + // Case 1. inptr -> gm2lm + return OffsetState::Continuous; + } else if (isa(ptrOp) || isa(ptrOp)) { + // Case 2. inptr -> cal -> addptr -> bitcast -> gm2lm + // Case 3. inptr -> cal -> addptr -> splat -> gm2lm + Value prevVal = ptrOp->getOperand(0); + ptrOp = prevVal.getDefiningOp(); + if (!ptrOp) + return OffsetState::Continuous; + } else if (!isa(ptrOp)) { + // Case 4. inptr -> unknown -> gm2lm + LLVM_DEBUG( + llvm::dbgs() + << "[OffsetAnalysis]: Unsupported Offset Calculation Pattern\n"); + return OffsetState::Unknown; + } + + if (!isa(ptrOp)) { + // Case 5. inptr -> bitcast/splat -> unknown -> gm2lm + LLVM_DEBUG(llvm::dbgs() + << "[OffsetAnalysis]: Unsupported Offset Calculation " + "Pattern with Bitcast/Splat\n"); + return OffsetState::Unknown; + } + + // Case Normal. inptr -> cal -> addptr -> gm2lm + Value offset = ptrOp->getOperand(1); + Operation *offsetDefineOp = offset.getDefiningOp(); + + // Step 1. Get Offset opChain + SetVector opChain; + getOpChainBwdBFS(opChain, offsetDefineOp); + opChain = sortByLine(opChain); + + if (dumpFlag) + LLVM_DEBUG(dumpOpChain(opChain)); + + // Step 2. Cal offsetMock + // Step 2.1. Get All Mockdata Items + // [arith::IndexCastOp, mlir::gpu::ThreadIdOp, triton::GetProgramIdOp] + SmallVector mockDataItems = getMockDataItems(opChain); + + // Step 2.2. Check && Pruning Before getAllOffset + auto checkMockDataList = [](SmallVector &mockDataItems) { + for (auto mockData : mockDataItems) { + if (mockData.mockVals.empty()) + return false; + } + return true; + }; + if (!checkMockDataList(mockDataItems)) + return OffsetState::Unknown; + + // Step 2.3. Get allOffsetResults With Own Token + // token(key)-offsetResult(value) + std::unordered_map> allOffsetResults; + SmallVector sortedTokens; + getAllOffset(allOffsetResults, sortedTokens, mockDataItems, 0, "", + offsetDefineOp, opChain); + if (findUnsupportedOp) { + LLVM_DEBUG(llvm::dbgs() + << "Operands Must Be Defined By operations. " + "Check If Operand Is an Input Argument.\n" + "Return OffsetState::Unknown for Protection.\n"); + } + + // Step 3. Get OffsetState by Offset + // Step 3.1. calculate numElems + bool atomicSim = memoryOp.getAtomicSim(); + unsigned numElems = 1; + if (auto offsetTy = mlir::dyn_cast(offset.getType())) { + auto offsetShape = offsetTy.getShape(); + unsigned rank = offsetShape.size(); + auto gEncoding = + mlir::cast(offsetTy.getEncoding()); + auto sizePerCore = gEncoding.getSizePerCore(); + auto coresPerGroup = gEncoding.getCoresPerGroup(); + auto groupsPerCluster = gEncoding.getGroupsPerCluster(); + + if (atomicSim) { + numElems = lrie > 1 ? lrie : 1; + } else { + numElems = product(sizePerCore); + } + + // We Can Only Check 1st Row Ptr Offset While Small ColSize Opt Hit + if (memoryOp.getTensorColSize() != -1) + numElems = std::min(numElems, (unsigned)offsetShape[rank - 1]); + } + + // Step 3.2. Calculate all the offsetStates for each token, and then + // combine these offsetStates through memoryStateTransfer. + OffsetState memoryState = + checkOffset(allOffsetResults[sortedTokens[0]], numElems); + std::unordered_map allOffsetStateResult; + for (auto token : sortedTokens) { + auto offsetMock = allOffsetResults[token]; + allOffsetStateResult[token] = checkOffset(offsetMock, numElems); + if (dumpFlag) + LLVM_DEBUG(llvm::dbgs() << "\n" + << token << ":" << allOffsetStateResult[token] + << ", offsetMock.size() = " << offsetMock.size() + << ", numElems = " << numElems << "\n"); + memoryState = + memoryStateTransfer(memoryState, allOffsetStateResult[token]); + } + + if (atomicSim && !analysisFlag) { + // analysis only once + lrieDiscreteSameAnalysis(allOffsetResults[sortedTokens[0]]); + if (lrie > 1) { + fixedStride = 0; + return OffsetState::DiscreteSame; + } + } + + // Step 4. Optimize to OffsetState::LocallyContinuous + if (memoryState == OffsetState::Unknown && + memoryOp.getTensorColSize() == -1) { + return locallyContinity(allOffsetResults, sortedTokens, numElems); + } + + return memoryState; + } + + void runOnOperation() override { + mlir::ModuleOp m = getOperation(); + mlir::ModuleOp mod = getOperation(); + + mod.walk([&](mlir::Operation *op) { + TypeSwitch(op).Case( + [&](auto memoryOp) { legalizeOffset(memoryOp); }); + }); + + mod.walk([&](mlir::Operation *op) { op2Line[op] = line++; }); + + mod.walk([&](triton::xpu::GM2LMOp gm2lmOp) { + if (dumpFlag) + LLVM_DEBUG(llvm::dbgs() + << "\n=======================================\n"); + OffsetState offsetState = getOffsetState(gm2lmOp); + if (dumpFlag) { + LLVM_DEBUG(llvm::dbgs() + << "\n" + << gm2lmOp << "\n[OffsetState]: " << offsetState + << "\n=======================================\n"); + } + // In case `fixedStride` being modified by cluster(s) whose + // OffsetState is Continuous. + if (offsetState == OffsetState::Discrete) { + fixedStride = -1; + } else if (offsetState == OffsetState::Unknown && + (fixedStride == 1 | fixedStride == 0)) { + // Multi Memory State Like (Unknown & Continuous) + fixedStride = -1; + } + + OpBuilder builder(gm2lmOp); + int32_t offsetStateInt = static_cast(offsetState); + gm2lmOp->setAttr("offsetState", + builder.getSI32IntegerAttr(offsetStateInt)); + gm2lmOp->setAttr("fixedStride", builder.getSI32IntegerAttr(fixedStride)); + gm2lmOp->setAttr("rowLen", builder.getIntegerAttr( + builder.getIntegerType(64, true), rowLen)); + gm2lmOp->setAttr( + "rowStride", + builder.getIntegerAttr(builder.getIntegerType(64, true), rowStride)); + gm2lmOp->setAttr("lrie", builder.getSI32IntegerAttr(lrie)); + auto loadOp = cast(gm2lmOp->getNextNode()); + loadOp->setOperand(0, gm2lmOp); + loadOp->setAttr("stride", builder.getSI32IntegerAttr(fixedStride)); + loadOp->setAttr("isDiscrete", builder.getBoolAttr(offsetState == + OffsetState::Discrete)); + fixedStride = -1; // reset + rowLen = -1; + rowStride = -1; + findUnsupportedOp = false; + }); + + mod.walk([&](triton::xpu::LM2GMOp lm2gmOp) { + if (dumpFlag) + LLVM_DEBUG(llvm::dbgs() + << "\n=======================================\n"); + OffsetState offsetState = getOffsetState(lm2gmOp); + // Only able to handle continuous and unknown cases. + if (offsetState != OffsetState::Continuous && + offsetState != OffsetState::LocallyContinuous) { + offsetState = OffsetState::Unknown; + } + if (dumpFlag) { + LLVM_DEBUG(llvm::dbgs() + << "\n" + << lm2gmOp << "\n[OffsetState]: " << offsetState + << "\n=======================================\n"); + } + OpBuilder builder(lm2gmOp); + int32_t offsetStateInt = static_cast(offsetState); + lm2gmOp->setAttr("offsetState", + builder.getSI32IntegerAttr(offsetStateInt)); + lm2gmOp->setAttr("rowLen", builder.getIntegerAttr( + builder.getIntegerType(64, true), rowLen)); + lm2gmOp->setAttr( + "rowStride", + builder.getIntegerAttr(builder.getIntegerType(64, true), rowStride)); + findUnsupportedOp = false; // reset + rowLen = -1; + rowStride = -1; + }); + + mod.walk([&](triton::xpu::SM2GMOp sm2gmOp) { + // TODO: Deal with other offset states in SM2GM, especially 2D case in + // reduction + OffsetState offsetState = OffsetState::Continuous; + if (dumpFlag) { + LLVM_DEBUG(llvm::dbgs() + << "\n=======================================\n" + << sm2gmOp << "\n[OffsetState]: " << offsetState + << "\n=======================================\n"); + } + + OpBuilder builder(sm2gmOp); + int32_t offsetStateInt = static_cast(offsetState); + sm2gmOp->setAttr("offsetState", + builder.getSI32IntegerAttr(offsetStateInt)); + findUnsupportedOp = false; // reset + }); + } + +private: + llvm::DenseMap op2Line; + int32_t line = 0; + int32_t fixedStride = -1; + int64_t rowLen = -1; + int64_t rowStride = -1; + int32_t lrie = -1; // Longest Run Of Identical Elements + bool analysisFlag = false; + bool findUnsupportedOp = false; + + /* OffsetState Transition Table + * + * +---------------+----------+----------+--------------+-----------+ + * | | Unknown | Discrete | Discrete Same| Continuous| + * +---------------+----------+----------+--------------+-----------+ + * | Unknown | Unknown | Unknown | Unknown | Unknown | + * +---------------+----------+----------+--------------+-----------+ + * | Discrete | Unknown | Discrete | Discrete | Discrete | + * +---------------+----------+----------+--------------+-----------+ + * | Discrete Same | Unknown | Discrete | Discrete Same| Discrete | + * +---------------+----------+----------+--------------+-----------+ + * | Continuous | Unknown | Discrete | Discrete | Continuous| + * +---------------+----------+----------+--------------+-----------+ + * + * Unknown Case: + * 1. The Offset Mock Result's Size <= 1 + * 2. The 0th Address Is Not the Beginning of the Bank. + * 3. Addresses Are Not in the Same Bank. + * 4. MockDataList Check Failed (Empty MockData) + */ + + OffsetStateTransitionTable offsetstate_transition_table = { + {OffsetState::Unknown, + { + {OffsetState::Unknown, OffsetState::Unknown}, + {OffsetState::Discrete, OffsetState::Unknown}, + {OffsetState::DiscreteSame, OffsetState::Unknown}, + {OffsetState::Continuous, OffsetState::Unknown}, + {OffsetState::LocallyContinuous, OffsetState::Unknown}, + }}, + {OffsetState::Discrete, + { + {OffsetState::Unknown, OffsetState::Unknown}, + {OffsetState::Discrete, OffsetState::Discrete}, + {OffsetState::DiscreteSame, OffsetState::Discrete}, + {OffsetState::Continuous, OffsetState::Discrete}, + {OffsetState::LocallyContinuous, OffsetState::Discrete}, + }}, + {OffsetState::DiscreteSame, + { + {OffsetState::Unknown, OffsetState::Unknown}, + {OffsetState::Discrete, OffsetState::Discrete}, + {OffsetState::DiscreteSame, OffsetState::DiscreteSame}, + {OffsetState::Continuous, OffsetState::Discrete}, + {OffsetState::LocallyContinuous, OffsetState::Discrete}, + }}, + {OffsetState::Continuous, + { + {OffsetState::Unknown, OffsetState::Unknown}, + {OffsetState::Discrete, OffsetState::Discrete}, + {OffsetState::DiscreteSame, OffsetState::Discrete}, + {OffsetState::Continuous, OffsetState::Continuous}, + {OffsetState::LocallyContinuous, OffsetState::LocallyContinuous}, + }}, + {OffsetState::LocallyContinuous, + { + {OffsetState::Unknown, OffsetState::Unknown}, + {OffsetState::Discrete, OffsetState::Discrete}, + {OffsetState::DiscreteSame, OffsetState::Discrete}, + {OffsetState::Continuous, OffsetState::LocallyContinuous}, + {OffsetState::LocallyContinuous, OffsetState::LocallyContinuous}, + }}}; +}; + +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/OtherSim.cpp b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/OtherSim.cpp new file mode 100644 index 000000000..49a5b5c5c --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/OtherSim.cpp @@ -0,0 +1,100 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// TODO: Pass Description +//===----------------------------------------------------------------------===// + +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" + +namespace mlir { +namespace triton { +namespace xpu { + +#define GEN_PASS_DEF_TRITONXPUOTHERSIM +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +struct TritonXPUOtherSim + : public impl::TritonXPUOtherSimBase { + +public: + using impl::TritonXPUOtherSimBase::TritonXPUOtherSimBase; + + void runOnOperation() override { + mlir::ModuleOp m = getOperation(); + bool skip = true; + m.walk([&](triton::xpu::ReduceOp reduceOp) { skip = false; }); + if (skip) { + return; + } + + m.walk([&](triton::xpu::LoadOp loadOp) { + auto loc = loadOp.getLoc(); + OpBuilder builder(loadOp); + if (auto other = loadOp.getOther()) { + unsigned numElems = getTotalElemsPerThread(other.getType()); + Type elemTy = getElementTypeOrSelf(other.getType()); + unsigned vecSize = 1u; + if (auto vecType = dyn_cast(elemTy)) { + vecSize = vecType.getNumElements(); + } + int64_t _bufLen = numElems * vecSize; + Block *block = loadOp->getBlock(); + auto gm2lmOp = loadOp.getPtr().getDefiningOp(); + auto allocaOp = + gm2lmOp.getBufPtr().getDefiningOp(); + // Create If(len < bufLen) + auto len = gm2lmOp.getLen(); + auto lenElemTy = getElementTypeOrSelf(len); + auto extractLen = builder.create( + loc, lenElemTy, builder.getI32IntegerAttr(0), len); + auto bufLen = builder.create( + loc, _bufLen, lenElemTy.getIntOrFloatBitWidth()); + auto sltBufLen = builder.create( + loc, builder.getI1Type(), arith::CmpIPredicate::slt, extractLen, + bufLen); + auto ifOp = builder.create(loc, sltBufLen, + /*withElseRegion=*/false); + ifOp->moveBefore(gm2lmOp); + extractLen->moveBefore(ifOp); + bufLen->moveBefore(ifOp); + sltBufLen->moveBefore(ifOp); + // Create Constant/Store + if (auto otherDef = other.getDefiningOp()) { + if (auto constOp = dyn_cast(otherDef)) { + auto newConstOp = builder.create( + loc, constOp.getType(), constOp.getValue()); + auto storeOp = builder.create( + loc, allocaOp, newConstOp, Value(), Value(), -1, false); + newConstOp->moveBefore(ifOp.thenBlock()->getTerminator()); + storeOp->moveBefore(ifOp.thenBlock()->getTerminator()); + } else if (auto vconstOp = + dyn_cast(otherDef)) { + auto newVConstOp = builder.create( + loc, vconstOp.getType(), vconstOp.getValue()); + auto storeOp = builder.create( + loc, allocaOp, newVConstOp, Value(), Value(), -1, false); + newVConstOp->moveBefore(ifOp.thenBlock()->getTerminator()); + storeOp->moveBefore(ifOp.thenBlock()->getTerminator()); + } else { + auto storeOp = builder.create( + loc, allocaOp, otherDef->getResults()[0], Value(), Value(), -1, + false); + storeOp->moveBefore(ifOp.thenBlock()->getTerminator()); + } + } else { + auto storeOp = builder.create( + loc, allocaOp, other, Value(), Value(), -1, false); + storeOp->moveBefore(ifOp.thenBlock()->getTerminator()); + } + } + }); + } +}; + +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/StoreControl.cpp b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/StoreControl.cpp new file mode 100644 index 000000000..ea53e8a99 --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/StoreControl.cpp @@ -0,0 +1,169 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// TODO[dyq]: Pass Description +//===----------------------------------------------------------------------===// + +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" + +#define DEBUG_TYPE "tritonxpu-store-control" + +namespace mlir { +namespace triton { +namespace xpu { + +#define GEN_PASS_DEF_TRITONXPUSTORECONTROL +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +struct TritonXPUStoreControl + : public impl::TritonXPUStoreControlBase { + + using impl::TritonXPUStoreControlBase< + TritonXPUStoreControl>::TritonXPUStoreControlBase; + + void getGroupInfo(triton::xpu::ReduceOp &reduceOp, int64_t &groupSize, + int64_t &groupNum) { + auto types = reduceOp.getOperandTypes(); + assert(types.size() > 1); + for (int i = 0; i < types.size() - 1; ++i) { + if (auto tensorType = dyn_cast(types[i])) { + auto clusterEncoding = + cast(tensorType.getEncoding()); + if (i == 0) { + groupSize = product(clusterEncoding.getCoresPerGroup()); + groupNum = product(clusterEncoding.getGroupsPerCluster()); + } else { + assert(groupSize == product(clusterEncoding.getCoresPerGroup())); + assert(groupNum == product(clusterEncoding.getGroupsPerCluster())); + } + } + } + } + + bool findDefChain(Operation *startOp, Operation *endOp, + SetVector &chain, + SetVector &visitedOps) { + if (!startOp) { + return false; + } + chain.insert(startOp); + if (startOp == endOp) { + return true; + } + for (auto operand : startOp->getOperands()) { + auto defOp = operand.getDefiningOp(); + if (!visitedOps.count(defOp)) { + if (findDefChain(defOp, endOp, chain, visitedOps)) { + return true; + } + } + } + chain.pop_back(); + return false; + } + + bool hasBroadcast(Operation *startOp, Operation *endOp) { + SetVector chain; + SetVector visitedOps; + findDefChain(startOp, endOp, chain, visitedOps); + for (auto op : chain) { + if (isa(op)) { + return true; + } + } + return false; + } + + bool isSameSize(triton::xpu::ReduceOp &reduceOp, + triton::xpu::StoreOp storeOp) { + llvm::ArrayRef redResShape = {1}; + auto redRes = reduceOp.getResult()[0]; + if (auto redResTy = dyn_cast(redRes.getType())) { + auto sliceEncoding = + cast(redResTy.getEncoding()); + redResShape = redResTy.getShape(); + } + llvm::ArrayRef storeValShape = {1}; + auto storeVal = storeOp.getValue(); + if (auto storeValTy = dyn_cast(storeVal.getType())) { + storeValShape = storeValTy.getShape(); + } + if (product(redResShape) == product(storeValShape)) { + return true; + } + return false; + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + DenseMap> ifBodyMap; + m.walk([&](triton::xpu::StoreOp storeOp) { + OpBuilder builder(storeOp); + auto loc = storeOp->getLoc(); + SetVector ifBodyOps; + if (auto op = findDefOpBwd(storeOp.getValue())) { + auto reduceOp = cast(op); + ReduceOpHelper help(reduceOp); + auto srcShape = help.getSrcShape(); + if (srcShape.size() > 1 && reduceOp.getAxis() != srcShape.size() - 1) { + return; + } + if (hasBroadcast(storeOp, op) || !isSameSize(reduceOp, storeOp)) { + return; + } + auto allocaOp = storeOp.getPtr().getDefiningOp(); + for (Operation *user : allocaOp->getUsers()) { + if (auto lm2gmOp = dyn_cast(user)) { + ifBodyOps.insert(allocaOp); + ifBodyOps.insert(storeOp); + ifBodyOps.insert(lm2gmOp); + } + } + if (ifBodyOps.empty()) { + return; + } + int64_t _groupSize = 64, _groupNum = 1; + getGroupInfo(reduceOp, _groupSize, _groupNum); + auto coreId = + builder.create(loc, builder.getI32Type()); + auto groupSize = + builder.create(loc, _groupSize, 32); + auto coreIdInsideGroup = builder.create( + loc, builder.getI32Type(), coreId, groupSize); + auto zero = builder.create(loc, 0, 32); + auto isCoreId0InsideGroup = builder.create( + loc, builder.getI1Type(), arith::CmpIPredicate::eq, + coreIdInsideGroup, zero); + int64_t _usedCoreNum = _groupSize * _groupNum; + auto usedCoreNum = + builder.create(loc, _usedCoreNum, 32); + auto sltUsedCoreNum = builder.create( + loc, builder.getI1Type(), arith::CmpIPredicate::slt, coreId, + usedCoreNum); + auto cond = builder.create( + loc, builder.getI1Type(), isCoreId0InsideGroup, sltUsedCoreNum); + auto ifOp = builder.create(loc, cond, + /*withElseRegion=*/false); + ifBodyMap[ifOp] = ifBodyOps; + LLVM_DEBUG(llvm::dbgs() << "[StoreControl] GroupSize: " << _groupSize + << ", usedCoreNum: " << _usedCoreNum << "\n"); + } + }); + for (auto &pair : ifBodyMap) { + auto ifOp = cast(pair.first); + auto ifBodyOps = pair.second; + for (auto op : ifBodyOps) { + op->moveBefore(ifOp.thenBlock()->getTerminator()); + } + } + } +}; + +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/UnrollControl.cpp b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/UnrollControl.cpp new file mode 100644 index 000000000..1e64687ab --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/UnrollControl.cpp @@ -0,0 +1,1101 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include "mlir/IR/IRMapping.h" +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" + +#define DEBUG_TYPE "tritonxpu-unroll-control" + +namespace mlir { +namespace triton { +namespace xpu { + +#define GEN_PASS_DEF_TRITONXPUUNROLLCONTROL +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +#define COMBINE_OP \ + arith::AddFOp, arith::MulFOp, arith::MaxNumFOp, arith::MinNumFOp, \ + arith::OrIOp, arith::XOrIOp, arith::AndIOp + +template struct COMOp; + +#define COMOP(SrcType, DstType) \ + template <> struct COMOp { \ + typedef DstType type; \ + }; + +COMOP(arith::AddFOp, triton::xpu::VvaddFOp); +COMOP(arith::MulFOp, triton::xpu::VvmulFOp); +COMOP(arith::MaxNumFOp, triton::xpu::VvmaxNumFOp); +COMOP(arith::MinNumFOp, triton::xpu::VvminNumFOp); +COMOP(arith::OrIOp, triton::xpu::VvorIOp); +COMOP(arith::XOrIOp, triton::xpu::VvxorIOp); +COMOP(arith::AndIOp, triton::xpu::VvandIOp); + +struct TritonXPUUnrollControl + : public impl::TritonXPUUnrollControlBase { + +public: + using impl::TritonXPUUnrollControlBase< + TritonXPUUnrollControl>::TritonXPUUnrollControlBase; + + template static decltype(auto) createCombineVectorizedOp(T op) { + OpBuilder builder(op); + return builder.create::type>( + op.getLoc(), op.getResult().getType(), op.getLhs(), op.getRhs()); + } + + void processOpVecTy(ModuleOp &m) { + m.walk([&](Operation *op) { + TypeSwitch(op).Case([&](auto combineOp) { + if (auto tensorTy = + dyn_cast(combineOp.getResult().getType())) { + if (isa(getElementTypeOrSelf(tensorTy))) { + auto vecOp = createCombineVectorizedOp(combineOp); + combineOp.replaceAllUsesWith(vecOp.getResult()); + combineOp.erase(); + } + } + }); + }); + } + + bool isAncestorOf(Operation *op1, Operation *op2) { + Block *block1 = op1->getBlock(); + for (Block *block2 = op2->getBlock(); block2 != nullptr;) { + if (block1 == block2) { + return true; + } + Operation *parentOp = block2->getParentOp(); + if (parentOp == nullptr) { + break; + } + block2 = parentOp->getBlock(); + } + return false; + } + + void getUnrollTree(Operation *op, SetVector &opTree, + SetVector &visitedOps, + SetVector &excludeChainOps, Operation *rootOp, + bool isTop2Bottom = true) { + if (!op || visitedOps.count(op) || + isa(op)) { + return; + } + + visitedOps.insert(op); + if (isAncestorOf(op, rootOp) || op->getBlock() == rootOp->getBlock()) { + opTree.insert(op); + } + + // Search definedOp of childOp + if (auto ifOp = dyn_cast(op)) { + // Then + auto &ifThenBlock = ifOp.getThenRegion().front(); + for (auto &inBlockOp : ifThenBlock) { + getUnrollTree(&inBlockOp, opTree, visitedOps, excludeChainOps, rootOp, + isTop2Bottom); + } + // Else + auto &ifElseRegion = ifOp.getElseRegion(); + if (!ifElseRegion.empty()) { + auto &ifElseBlock = ifElseRegion.front(); + for (auto &inBlockOp : ifElseBlock) { + getUnrollTree(&inBlockOp, opTree, visitedOps, excludeChainOps, rootOp, + isTop2Bottom); + } + } + } + + // from bottom to top + if (isa( + op)) { + } else if (auto storeOp = dyn_cast(op)) { + auto defOp = storeOp.getValue().getDefiningOp(); + getUnrollTree(defOp, opTree, visitedOps, excludeChainOps, rootOp, + isTop2Bottom); + } else { + for (auto operand : op->getOperands()) { + auto defOp = operand.getDefiningOp(); + getUnrollTree(defOp, opTree, visitedOps, excludeChainOps, rootOp, + isTop2Bottom); + } + } + + if (isTop2Bottom) { + // from top to bottom + if (excludeChainOps.count(op) || + isa(op)) { + } else { + for (auto userOp : op->getUsers()) { + getUnrollTree(userOp, opTree, visitedOps, excludeChainOps, rootOp, + isTop2Bottom); + } + } + } + return; + } + + int64_t getNumCol(Type type) { + if (auto tensorTy = dyn_cast(type)) + return tensorTy.getShape().back(); + else + return 1; + } + + int64_t getNumInVector(Type type) { + if (auto vecType = dyn_cast(type)) + return vecType.getNumElements(); + else + return 1; + } + + int64_t getnumUnroll(Type type) { + int64_t numUnroll = numUnrollPerCore * 64; + if (auto tensorTy = dyn_cast(type)) { + auto clusterEncoding = + cast(tensorTy.getEncoding()); + numUnroll = numUnrollPerCore * clusterEncoding.getCoresPerGroup().back(); + } + return numUnroll; + } + + Type createPointerType(Type type, int64_t vecSize) { + if (auto tensorType = dyn_cast(type)) { + Type elemType = getElementTypeOrSelf(tensorType); + Type elemScalarType = getElementTypeOrSelf(elemType); + Type pointerType = triton::PointerType::get(elemScalarType, 0); + auto shape = tensorType.getShape().vec(); + shape[shape.size() - 1] = shape.back() * vecSize; + return RankedTensorType::get(shape, pointerType, + tensorType.getEncoding()); + } else { + return triton::PointerType::get(type, 0); + } + } + + triton::xpu::ClusterLayoutAttr + createEncoding(MLIRContext *context, triton::xpu::ClusterLayoutAttr &encoding, + int64_t iterNum) const { + auto sizePerCore = encoding.getSizePerCore().vec(); + sizePerCore[sizePerCore.size() - 1] = + ceil(sizePerCore.back(), iterNum); + auto newEncoding = triton::xpu::ClusterLayoutAttr::get( + context, sizePerCore, encoding.getCoresPerGroup(), + encoding.getGroupsPerCluster(), encoding.getOrder(), + encoding.getIsReduceOpt()); + return newEncoding; + } + + void setTensorType(MLIRContext *context, Operation *op, int64_t iterNum, + bool isOuter, bool sliceShape = true) const { + for (auto [i, resTy] : llvm::enumerate(op->getResultTypes())) { + if (isa(resTy) && !isOuter) { + auto tensorTy = cast(resTy); + auto shape = tensorTy.getShape().vec(); + if (sliceShape) { + shape[shape.size() - 1] = ceil(shape.back(), iterNum); + } + RankedTensorType controledTensorTy; + if (auto sliceEncoding = dyn_cast( + tensorTy.getEncoding())) { + auto clusterEncoding = + cast(sliceEncoding.getParent()); + auto newClusterEncoding = + createEncoding(context, clusterEncoding, iterNum); + auto newEncoding = triton::gpu::SliceEncodingAttr::get( + context, sliceEncoding.getDim(), newClusterEncoding); + controledTensorTy = RankedTensorType::get( + shape, tensorTy.getElementType(), newEncoding); + } else { + auto clusterEncoding = + cast(tensorTy.getEncoding()); + auto newClusterEncoding = + createEncoding(context, clusterEncoding, iterNum); + controledTensorTy = RankedTensorType::get( + shape, tensorTy.getElementType(), newClusterEncoding); + } + op->getResult(i).setType(controledTensorTy); + } + } + } + + triton::xpu::ExtractSliceOp + getExtractedOperand(MLIRContext *context, OpBuilder &builder, Location &loc, + mlir::Operation *op, unsigned operandIndex, + int64_t iterNum) const { + auto resTy = op->getOperand(operandIndex).getType(); + RankedTensorType tensorTy; + if (isa(resTy)) { + tensorTy = cast(resTy); + } + auto shape = tensorTy.getShape().vec(); + shape[shape.size() - 1] = ceil(shape.back(), iterNum); + auto clusterEncoding = + cast(tensorTy.getEncoding()); + auto newClusterEncoding = createEncoding(context, clusterEncoding, iterNum); + + RankedTensorType controledTensorTy = RankedTensorType::get( + shape, tensorTy.getElementType(), newClusterEncoding); + triton::xpu::ExtractSliceOp extractSliceOp = + builder.create( + loc, controledTensorTy, op->getOperand(operandIndex)); + return extractSliceOp; + } + + // Determine whether the operand has been hoisted + bool isOperandOperationInSameForBlock(mlir::Operation *op, + unsigned operandIndex) { + auto *parentOp = op->getParentOp(); + while (parentOp && !llvm::isa(parentOp)) { + parentOp = parentOp->getParentOp(); + } + if (!parentOp) + return false; + + auto forOp = llvm::cast(parentOp); + mlir::Value operand = op->getOperand(operandIndex); + if (auto blockArg = mlir::dyn_cast(operand)) { + mlir::Block *block = forOp.getBody()->front().getBlock(); + return blockArg.getOwner() == block; + } else { + mlir::Operation *definingOp = operand.getDefiningOp(); + if (definingOp) { + return definingOp->getBlock()->getParentOp() == forOp.getOperation(); + } + } + return false; + } + + void insertIndex(Operation *op, Value idxVar) { + OpBuilder builder(op); + auto operandSegmentSizesAttr = + op->getAttrOfType("operandSegmentSizes"); + SmallVector operandSegmentSizes( + operandSegmentSizesAttr.asArrayRef()); + // LoadOp: 0: ptr, 1: mask, 2: other, 3: index + // StoreOp: 0: ptr, 1: value, 2: mask, 3: index + // MakeRangeOp: 0: loopIndex, 1: unrollIndex + // InterleaveOp: 0: loopIndex, 1: unrollIndex + ++operandSegmentSizes[operandSegmentSizes.size() - 1]; + op->setAttr("operandSegmentSizes", + builder.getDenseI32ArrayAttr(operandSegmentSizes)); + op->insertOperands(op->getNumOperands(), {idxVar}); + } + + void getOuterChain(llvm::SetVector &allOpTree, + llvm::SetVector &outerChain) { + for (auto op : allOpTree) { + if (auto expandDimOp = dyn_cast(op)) { + auto src = expandDimOp.getSrc(); + auto result = expandDimOp.getResult(); + if (auto srcTy = dyn_cast(src.getType())) { + if (auto resTy = dyn_cast(result.getType())) { + if (expandDimOp.getAxis() == 1) { + getOpChainBwd(outerChain, expandDimOp); + outerChain.remove(expandDimOp); + } + } + } + } + if (auto broadcastOp = dyn_cast(op)) { + auto src = broadcastOp.getSrc(); + auto result = broadcastOp.getResult(); + if (auto srcTy = dyn_cast(src.getType())) { + if (auto resTy = dyn_cast(result.getType())) { + int64_t srcElemNum = 1; + if (auto vecTy = + dyn_cast(getElementTypeOrSelf(srcTy))) { + srcElemNum = vecTy.getNumElements(); + } + int64_t resElemNum = 1; + if (auto vecTy = + dyn_cast(getElementTypeOrSelf(resTy))) { + resElemNum = vecTy.getNumElements(); + } + auto srcShape = srcTy.getShape(); + auto resShape = resTy.getShape(); + int64_t srcInnerNum = srcElemNum * srcShape.back(); + int64_t resInnerNum = resElemNum * resShape.back(); + if (srcInnerNum != resInnerNum) { // unequal dim 1 shape means in + // the inner axis op chain + getOpChainBwd(outerChain, broadcastOp); + outerChain.remove(broadcastOp); + } + } + } + } + } + } + + void + getOuterChains(const SmallVector> &allOpTrees, + SmallVector> &outerChains) { + for (auto allOpTree : allOpTrees) { + SetVector outerChain; + getOuterChain(allOpTree, outerChain); + outerChains.emplace_back(outerChain); + } + } + + void getDAG(Operation *op, SetVector &visitedOps, + SmallVector> &unrollOpTrees, + SetVector &excludeChainOps, + bool isTop2Bottom = true) { + SetVector opTree; + getUnrollTree(op, opTree, visitedOps, excludeChainOps, op, isTop2Bottom); + if (!opTree.empty()) { + SetVector sortedOpTree = sortOpTree(opTree); + unrollOpTrees.push_back(sortedOpTree); + } + } + + void createFor(OpBuilder &builder, Location &loc, int64_t start, + int64_t iterNum, scf::ForOp &forOp, arith::IndexCastOp &idxVar, + ValueRange &iterArgs) { + auto lower = builder.create(loc, start); + auto upper = builder.create(loc, iterNum); + auto step = builder.create(loc, 1); + forOp = builder.create(loc, lower, upper, step, iterArgs); + builder.setInsertionPointToStart(forOp.getBody()); + idxVar = builder.create(loc, builder.getI32Type(), + forOp.getInductionVar()); + } + + void createLoopBody(MLIRContext *context, OpBuilder &builder, Location &loc, + int64_t iterNum, SetVector &unrollOpTree, + SetVector &outerChain, + arith::IndexCastOp &idxVar, IRMapping &mapping) { + for (auto op : unrollOpTree) { + bool isOuter = inOpChain(outerChain, op); + auto newOp = builder.clone(*op, mapping); + setTensorType(context, newOp, iterNum, isOuter); + TypeSwitch(newOp) + .Case([&](auto loadOp) { + if (auto tensorTy = + dyn_cast(loadOp.getPtr().getType())) { + if (!loadOp.getSVOpt() && !loadOp.getIsDiscrete()) { + insertIndex(newOp, idxVar); + } + } + }) + .Case([&](auto storeOp) { + if (auto tensorTy = + dyn_cast(storeOp.getPtr().getType())) { + insertIndex(newOp, idxVar); + } + }) + .Case([&](auto makeRangeOp) { + if (auto tensorTy = + dyn_cast(op->getResults()[0].getType())) { + insertIndex(newOp, idxVar); + } + }) + .Case([&](auto interleaveOp) { + if (auto tensorTy = + dyn_cast(op->getResults()[0].getType())) { + insertIndex(newOp, idxVar); + } + }) + .Case([&](auto addPtrOp) { + auto ptr = addPtrOp.getPtr(); + auto offset = addPtrOp.getOffset(); + if (ptr.getType() != offset.getType()) { + auto extractOp = builder.create( + loc, getElementTypeOrSelf(ptr), builder.getI32IntegerAttr(0), + ptr); + auto splatOp = builder.create(loc, ptr.getType(), + extractOp); + setTensorType(context, splatOp, iterNum, isOuter); + addPtrOp.setOperand(0, splatOp); + addPtrOp->moveAfter(splatOp); + } + }) + .Case([&](auto constantOp) { + auto value = constantOp.getValue(); + if (auto attr = dyn_cast(value)) { + value = DenseElementsAttr::getFromRawBuffer( + cast(constantOp.getType()), attr.getRawData()); + } + constantOp.setValueAttr(value); + }) + .Case([&](auto ifOp) { + // Set IfOp's childOp Type(Then) + auto &ifThenBlock = ifOp.getThenRegion().front(); + for (auto &inBlockOp : ifThenBlock) { + setTensorType(context, &inBlockOp, iterNum, isOuter); + } + // Set IfOp's childOp Type(Else) + auto &ifElseRegion = ifOp.getElseRegion(); + if (!ifElseRegion.empty()) { + auto &ifElseBlock = ifElseRegion.front(); + for (auto &inBlockOp : ifElseBlock) { + setTensorType(context, &inBlockOp, iterNum, isOuter); + } + } + }); + if (scf::IfOp ifOp = dyn_cast(newOp)) { + auto &ifThenBlock = ifOp.getThenRegion().front(); + for (auto &inBlockOp : ifThenBlock) { + unsigned numifOpResults = ifOp.getNumResults(); + if (auto yieldOp = llvm::dyn_cast(&inBlockOp)) { + // 1. needExtract denotes Extraction is required if the operand of + // YieldOp does not match the type expected by the result of IfOp + // 2. isSame denotes whether the operand of YieldOp is in the same + // ForBlock as IfOp. + unsigned numyieldOpOperands = yieldOp.getNumOperands(); + assert( + (numifOpResults == numyieldOpOperands) && + "The number of IfOp results and YieldOp operands must match."); + for (unsigned i = 0; i < numyieldOpOperands; ++i) { + bool needExtract = false; + bool isSame = true; + Value result = ifOp.getResult(i); + Type resultType = result.getType(); + Value operand = yieldOp.getOperand(i); + Type operandType = operand.getType(); + if (resultType != operandType) { + needExtract = true; + } + isSame = isOperandOperationInSameForBlock(&inBlockOp, i); + if (!isSame && needExtract) { + assert(isa( + inBlockOp.getOperand(i).getDefiningOp()) && + "Unable to extract the non-constant operand."); + auto extractSliceOp = getExtractedOperand(context, builder, loc, + yieldOp, i, iterNum); + extractSliceOp->moveBefore(ifOp); + inBlockOp.setOperand(i, extractSliceOp->getResult(0)); + } + } + } + } + } + } + } + + void eraseDAG(SetVector &unrollOpTree) { + SetVector eraseOpTree(unrollOpTree.rbegin(), + unrollOpTree.rend()); + for (auto op : eraseOpTree) { + SetVector users; + for (auto user : op->getUsers()) { + if (isa(user)) { + users.insert(user); + } + } + for (auto user : users) { + user->erase(); + } + if (op->use_empty()) { + op->erase(); + } + } + } + + void moveAllocaAndGM2LM(scf::ForOp forOp) { + ModuleOp m = getOperation(); + + SmallVector gm2lmOps; + m.walk([&](triton::xpu::GM2LMOp gm2lmOp) { gm2lmOps.push_back(gm2lmOp); }); + + for (auto gm2lmOp : gm2lmOps) { + if (gm2lmOp->getBlock() != forOp->getBlock()) + continue; + + if (gm2lmOp->isBeforeInBlock(forOp)) + continue; + + auto allocaOp = gm2lmOp.getBufPtr().getDefiningOp(); + + allocaOp->moveBefore(forOp); + gm2lmOp->moveBefore(forOp); + } + } + + void unrollControl(MLIRContext *context, + SmallVector> &unrollOpTrees) { + // Get outerChains + SmallVector> outerChains; + getOuterChains(unrollOpTrees, outerChains); + + for (int i = 0; i < unrollOpTrees.size(); ++i) { + auto outerChain = outerChains[i]; + auto unrollOpTree = unrollOpTrees[i]; + // 1. Prepare for unroll control + int64_t numCol = 1; + int64_t numUnroll = 1; + triton::xpu::StoreOp insertPt; + SmallVector allStoreOps; + for (auto op : unrollOpTree) { + // 1.1 Get insertPt and tensor num + if (auto storeOp = dyn_cast(op)) { + auto type = storeOp.getValue().getType(); + numUnroll = numUnroll == 1 ? getnumUnroll(type) + : std::min(numUnroll, getNumCol(type)); + numCol = + numCol == 1 ? getNumCol(type) : std::min(numCol, getNumCol(type)); + allStoreOps.emplace_back(storeOp); + //[TODO] To deal with the case that storeOps are in more than one + // block + if (insertPt && insertPt->getBlock() != storeOp->getBlock()) { + return; + } + if (!insertPt || storeOp->isBeforeInBlock(insertPt)) { + insertPt = storeOp; + } + } + } + if (insertPt) { + auto loc = insertPt.getLoc(); + int64_t iterNum = ceil(numCol, numUnroll); + if (iterNum <= 1) + return; + LLVM_DEBUG(llvm::dbgs() + << "[Unroll Control] Hit Unroll Control Pointwise\n"); + // 2. Unroll control + // 2.1 Create forOp + OpBuilder builder(insertPt); + scf::ForOp forOp; + arith::IndexCastOp idxVar; + ValueRange iterArgs; + createFor(builder, loc, 0, iterNum, forOp, idxVar, iterArgs); + // 2.2 Set Tensor Type + IRMapping mapping; + createLoopBody(context, builder, loc, iterNum, unrollOpTree, outerChain, + idxVar, mapping); + + // 3. Erase old DAG + eraseDAG(unrollOpTree); + + // 4. Move Alloca & GM2LM Op before ForOp + moveAllocaAndGM2LM(forOp); + } + } + } + + void unrollControlReduce(MLIRContext *context, + SetVector &unrollOpTree, + Operation *insertPt, ValueRange &iterArgs, + ValueRange &returnOperands) { + SetVector outerChain; + getOuterChain(unrollOpTree, outerChain); + if (auto reduceOp = dyn_cast(insertPt)) { + int64_t numCol = 1, numUnroll = 1; + getUnrollInfoReduce(reduceOp, numCol, numUnroll); + int64_t iterNum = ceil(numCol, numUnroll); + if (iterNum <= 1) + return; + OpBuilder builder(reduceOp); + auto loc = reduceOp.getLoc(); + // 1. Prepare for unroll control + // Insert ExtractSliceOp for TensorType + SmallVector newIterArgs(iterArgs.size()); + for (int i = 0; i < iterArgs.size(); ++i) { + auto iterArgDefOp = iterArgs[i].getDefiningOp(); + bool isOuter = inOpChain(outerChain, iterArgDefOp); + auto extractSliceOp = builder.create( + loc, iterArgs[i].getType(), iterArgs[i]); + setTensorType(context, extractSliceOp, iterNum, isOuter); + auto inUnrollOpTree = [&](OpOperand &operand) { + return unrollOpTree.count(operand.getOwner()); + }; + iterArgs[i].replaceUsesWithIf(extractSliceOp.getResult(), + inUnrollOpTree); + newIterArgs[i] = extractSliceOp.getResult(); + } + // 2. Unroll control + // 2.1 Create forOp + scf::ForOp forOp; + arith::IndexCastOp idxVar; + ValueRange newIterArgsRange(newIterArgs); + createFor(builder, loc, 1, iterNum, forOp, idxVar, newIterArgsRange); + // 2.2 Set Tensor Type + IRMapping mapping; + createLoopBody(context, builder, loc, iterNum, unrollOpTree, outerChain, + idxVar, mapping); + bool isOuterReduce = inOpChain(outerChain, reduceOp); + setTensorType(context, reduceOp, iterNum, isOuterReduce, false); + // 2.3 Modify users and defs + // replace initArgs with iterArgs + auto inForOp = [&](OpOperand &operand) { + return forOp == operand.getOwner()->getBlock()->getParentOp(); + }; + auto forBody = forOp.getBody(); + auto forArgs = forBody->getArguments(); + for (int i = 0; i < forOp.getInitArgs().size(); ++i) { + forOp.getInitArgs()[i].replaceUsesWithIf(forArgs[i + 1], inForOp); + } + SmallVector mapRes; + for (int i = 0; i < returnOperands.size(); ++i) { + mapRes.emplace_back(mapping.lookup(returnOperands[i])); + } + builder.create(loc, mapRes); + auto isReduceOp = [&](OpOperand &operand) { + return reduceOp == operand.getOwner(); + }; + for (int i = 0; i < forOp.getResults().size(); ++i) { + reduceOp.getOperands()[i].replaceUsesWithIf(forOp.getResults()[i], + isReduceOp); + } + // 3. Erase old DAG + eraseDAG(unrollOpTree); + } + } + + void getExcludeChainOps(ModuleOp &m, + SetVector &excludeChainOps) { + m.walk([&](Operation *op) { + TypeSwitch(op) + .Case([&](auto memoryOp) { + getOpChainBwd(excludeChainOps, memoryOp.getPtr().getDefiningOp()); + if (memoryOp.getLen()) { + getOpChainBwd(excludeChainOps, memoryOp.getLen().getDefiningOp()); + } + }) + .Case([&](auto acessOp) { + if (acessOp.getMask()) { + getOpChainBwd(excludeChainOps, acessOp.getMask().getDefiningOp()); + } + }); + }); + } + + void findDiscretePtrChain(SetVector &unrollOpTree, + SetVector &newUnrollOpTree) { + for (auto op : unrollOpTree) { + if (auto loadOp = dyn_cast(op)) { + bool isDiscrete = loadOp.getIsDiscrete(); + if (isDiscrete) { + OpBuilder builder(loadOp); + auto loc = loadOp.getLoc(); + auto resType = loadOp.getResult().getType(); + int64_t numCol = getNumCol(resType); + int64_t numUnroll = getnumUnroll(resType); + if (numCol > numUnroll && numCol % numUnroll == 0) { + auto lmPtr = loadOp.getPtr(); + auto gm2lmOp = cast( + findDefOpBwd(lmPtr)); + auto gmPtrOp = cast( + findDefOpBwd(gm2lmOp.getPtr())); + auto offset = gmPtrOp.getOffset(); + auto newLmPtr = builder.create( + loc, lmPtr.getType(), lmPtr, offset); + SetVector ptrVisitedOps; + SetVector ptrExcludeChainOps; + getUnrollTree(newLmPtr, newUnrollOpTree, ptrVisitedOps, + ptrExcludeChainOps, newLmPtr, false); + if (!newUnrollOpTree.empty()) { + newUnrollOpTree = sortOpTree(newUnrollOpTree); + } + gm2lmOp->setAttr("offsetState", + builder.getSI32IntegerAttr(static_cast( + OffsetState::Continuous))); + loadOp.setOperand(0, newLmPtr); + } + } + } + } + } + + void + findDiscretePtrChains(SmallVector> &unrollOpTrees, + SmallVector> &newUnrollOpTrees) { + for (auto [i, unrollOpTree] : llvm::enumerate(unrollOpTrees)) { + findDiscretePtrChain(unrollOpTree, newUnrollOpTrees[i]); + } + } + + void createDiscreteOffset(ModuleOp &m) { + m.walk([&](triton::xpu::LoadOp loadOp) { + bool isDiscrete = loadOp.getIsDiscrete(); + if (isDiscrete) { + OpBuilder builder(loadOp); + auto loc = builder.getUnknownLoc(); + auto lmPtr = loadOp.getPtr(); + auto lmAddPtr = + cast(findDefOpBwd(lmPtr)); + auto lmOffset = lmAddPtr.getOffset(); + auto gm2lmOp = cast( + findDefOpBwd(lmPtr)); + auto gmPtrOp = cast( + findDefOpBwd(gm2lmOp.getPtr())); + auto gmOffset = gmPtrOp.getOffset(); + auto extractOp = builder.create( + loc, getElementTypeOrSelf(gmOffset), builder.getI32IntegerAttr(0), + gmOffset); + auto splatOp = + builder.create(loc, lmOffset.getType(), extractOp); + auto offset = builder.create(loc, lmOffset.getType(), + lmOffset, splatOp); + lmAddPtr.setOperand(1, offset); + lmAddPtr->moveAfter(offset); + if (gm2lmOp->getOperand(0) == lmAddPtr.getResult()) + gm2lmOp->moveAfter(lmAddPtr); + } + }); + } + + void pointwiseUnrollControl(ModuleOp &m, MLIRContext *context) { + // 1. Data-flow Analysis: get load -> store DAG + // (op in ptrChain/lenChain/maskChain will not walk from top to down) + // 1.1 Get excludeChainOps + SetVector excludeChainOps; + getExcludeChainOps(m, excludeChainOps); + // 1.2 Get load -> store DAG + SetVector visitedOps; + SmallVector> unrollOpTrees; + m.walk([&](triton::xpu::StoreOp storeOp) { + auto valType = storeOp.getValue().getType(); + int64_t numCol = getNumCol(valType); + int64_t numUnroll = getnumUnroll(valType); + if (numCol > numUnroll && numCol % numUnroll == 0) { + getDAG(storeOp, visitedOps, unrollOpTrees, excludeChainOps); + } + for (auto visitedOp : visitedOps) { + if (isa(visitedOp)) { + visitedOps.remove(visitedOp); + } + } + }); + if (unrollOpTrees.size() == 0) + return; + + // 1.3 Find ptr chain of discrete for moving to loop body + SmallVector> newUnrollOpTrees(unrollOpTrees); + findDiscretePtrChains(unrollOpTrees, newUnrollOpTrees); + + // 2. Deal with unroll opTrees + unrollControl(context, newUnrollOpTrees); + + // 3. Calculate discrete offset in the runtime + createDiscreteOffset(m); + } + + void createLoadStore(scf::ForOp &forOp, scf::YieldOp &yieldOp, Value &yield, + int i, Block &block, + SmallVector &storeOps) { + OpBuilder builder(yieldOp); + auto loc = yieldOp->getLoc(); + Type yieldType = yield.getType(); + Type yieldElemType = getElementTypeOrSelf(yieldType); + int64_t vecSize = getNumInVector(yieldElemType); + Type ptrTy = createPointerType(yieldType, vecSize); + int64_t tensorSize = getNumCol(yieldType); + if (!forOp.getResults()[i].use_empty()) { + // Create Alloca Store for Init Args + auto initForArg = forOp.getInitArgs()[i]; + auto newAllocaOp = builder.create( + loc, ptrTy, tensorSize * vecSize); + auto initStoreOp = builder.create( + loc, newAllocaOp, initForArg, Value(), Value(), -1, false); + newAllocaOp->moveBefore(forOp); + initStoreOp->moveBefore(forOp); + // Create Load for Input + auto inputLoadOp = builder.create( + loc, yieldType, newAllocaOp, Value(), Value(), Value(), 1, -1, false, + false, false); + auto notUsedForYield = [&](OpOperand &operand) { + return !isa(operand.getOwner()); + }; + auto forArg = forOp.getRegionIterArgs()[i]; + forArg.replaceUsesWithIf(inputLoadOp, notUsedForYield); + inputLoadOp->moveBefore(&block.front()); + // Create Store for Output + auto outputStoreOp = builder.create( + loc, newAllocaOp, yield, Value(), Value(), -1, false); + outputStoreOp->moveBefore(yieldOp); + storeOps.emplace_back(outputStoreOp); + // Create Load for Reduce + auto reduceLoadOp = builder.create( + loc, yieldType, newAllocaOp, Value(), Value(), Value(), 1, -1, false, + false, false); + // Move Load closed to For user + reduceLoadOp->moveAfter(forOp); + Operation *insertPt = nullptr; + for (auto user : forOp.getResults()[i].getUsers()) { + if (!insertPt) { + insertPt = user; + } else { + if (insertPt->getBlock() == user->getBlock()) { + if (user->isBeforeInBlock(insertPt)) { + insertPt = user; + } + } + } + } + if (insertPt) { + reduceLoadOp->moveBefore(insertPt); + } + // Replace For Result with Load + auto notReduceLoadOp = [&](OpOperand &operand) { + return reduceLoadOp != operand.getOwner(); + }; + forOp.getResults()[i].replaceUsesWithIf(reduceLoadOp, notReduceLoadOp); + + // Discard Yield by setting initForArg to operand + yieldOp->setOperand(i, initForArg); + } + } + + void getUnrollInfoReduce(triton::xpu::ReduceOp &reduceOp, int64_t &numCol, + int64_t &numUnroll) { + auto types = reduceOp.getOperandTypes(); + assert(types.size() > 1); + for (int i = 0; i < types.size() - 1; ++i) { + if (i == 0) { + numCol = getNumCol(types[i]); + numUnroll = getnumUnroll(types[i]); + } else { + assert(numCol == getNumCol(types[i])); + assert(numUnroll == getnumUnroll(types[i])); + } + } + } + + void forUnrollControl(ModuleOp &m, MLIRContext *context) { + SetVector excludeChainOps; + getExcludeChainOps(m, excludeChainOps); + SetVector vistedForOps; + // 1. Create Store Load + m.walk([&](triton::xpu::ReduceOp reduceOp) { + int64_t numCol = 1, numUnroll = 1; + getUnrollInfoReduce(reduceOp, numCol, numUnroll); + if (numCol > numUnroll && numCol % numUnroll == 0) { + LLVM_DEBUG(llvm::dbgs() << "[Unroll Control] Hit Unroll Control For\n"); + for (auto operand : reduceOp.getOperands()) { + if (auto forOp = dyn_cast(operand.getDefiningOp())) { + if (!vistedForOps.count(forOp)) { + vistedForOps.insert(forOp); + auto &forBlock = forOp.getRegion().front(); + bool hasIf = false; + SetVector visitedOps; + for (auto &inForBlockOp : forBlock) { + if (auto ifOp = dyn_cast(inForBlockOp)) { + SmallVector storeOps; + auto &ifBlock = ifOp.getThenRegion().front(); + auto yieldOp = cast(ifBlock.getTerminator()); + for (auto [i, yield] : + llvm::enumerate(yieldOp.getOperands())) { + createLoadStore(forOp, yieldOp, yield, i, ifBlock, + storeOps); + } + // Unroll control + for (auto storeOp : storeOps) { + SmallVector> unrollOpTrees; + getDAG(storeOp, visitedOps, unrollOpTrees, excludeChainOps); + // Find ptr chain of discrete for moving to loop body + SmallVector> newUnrollOpTrees( + unrollOpTrees); + findDiscretePtrChains(unrollOpTrees, newUnrollOpTrees); + unrollControl(context, newUnrollOpTrees); + } + hasIf = true; + } + } + if (!hasIf) { + SmallVector storeOps; + auto yieldOp = cast(forBlock.getTerminator()); + for (auto [i, yield] : llvm::enumerate(yieldOp.getOperands())) { + createLoadStore(forOp, yieldOp, yield, i, forBlock, storeOps); + } + // Unroll control + for (auto storeOp : storeOps) { + SmallVector> unrollOpTrees; + getDAG(storeOp, visitedOps, unrollOpTrees, excludeChainOps); + // Find ptr chain of discrete for moving to loop body + SmallVector> newUnrollOpTrees( + unrollOpTrees); + findDiscretePtrChains(unrollOpTrees, newUnrollOpTrees); + unrollControl(context, newUnrollOpTrees); + } + } + } + } + } + } + }); + } + + void getInlineInfo(SetVector &inlineOps, Operation *startOp, + ValueRange &returnOperands) { + Operation *op = startOp; + while (!isa(op)) { + inlineOps.insert(op); + op = op->getNextNode(); + } + returnOperands = op->getOperands(); + } + + void createReduceWithinCore(ModuleOp &m, MLIRContext *context) { + SetVector excludeChainOps; + getExcludeChainOps(m, excludeChainOps); + m.walk([&](triton::xpu::ReduceOp reduceOp) { + ReduceOpHelper helper(reduceOp); + OpBuilder builder(reduceOp); + auto loc = reduceOp->getLoc(); + SetVector visitedOps; + auto reduceOperandNum = reduceOp.getNumOperands() - 1; + SmallVector> copyOpTrees; + SetVector unrollOpTree; + int64_t numCol = 1, numUnroll = 1; + getUnrollInfoReduce(reduceOp, numCol, numUnroll); + if (numCol > numUnroll && numCol % numUnroll == 0) { + LLVM_DEBUG(llvm::dbgs() + << "[Unroll Control] Hit Unroll Control Reduction\n"); + for (int i = 0; i < reduceOperandNum; ++i) { + if (auto reduceDefOp = reduceOp.getOperands()[i].getDefiningOp()) { + getDAG(reduceDefOp, visitedOps, copyOpTrees, excludeChainOps, + false); + } + } + // 1. Copy Defined Op Chain of Reduce Operand for InitArgs + IRMapping mapping; + for (auto ©OpTree : copyOpTrees) { + for (auto ©Op : copyOpTree) { + auto newOp = builder.clone(*copyOp, mapping); + unrollOpTree.insert(newOp); + } + } + // 2. Inline Combine Op of Reduce + // Clone Region + IRRewriter rewriter(builder); + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + rewriter.cloneRegionBefore(reduceOp.getCombineOp(), &parent.front()); + auto &newReduce = parent.front(); + // Set Type for Cloned Ops + auto tensorTy = reduceOp.getInputTypes()[0]; + auto shape = tensorTy.getShape(); + for (auto &op : newReduce) { + if (auto cmpfOp = dyn_cast(op)) { + auto tensorTy0 = cmpfOp.getODSOperands(0)[0].getType(); + auto tensorTy1 = cmpfOp.getODSOperands(1)[0].getType(); + int operandIndexNeedModify; + mlir::Type operandNeedReserved; + if (tensorTy0 != tensorTy1) { + if ((mlir::isa(tensorTy0) || + mlir::isa(tensorTy0)) && + mlir::isa(tensorTy1)) { + operandIndexNeedModify = 0; + operandNeedReserved = tensorTy1; + } else if ((mlir::isa(tensorTy1) || + mlir::isa(tensorTy1)) && + mlir::isa(tensorTy0)) { + operandIndexNeedModify = 1; + operandNeedReserved = tensorTy0; + } + assert(isa( + cmpfOp.getOperand(operandIndexNeedModify) + .getDefiningOp()) && + "Unable to extract the non-constant operand."); + auto splatOp = builder.create( + loc, operandNeedReserved, + cmpfOp.getOperand(operandIndexNeedModify)); + splatOp->moveBefore(&op); + cmpfOp.setOperand(operandIndexNeedModify, splatOp.getResult()); + } + } else if (auto selOp = dyn_cast(op)) { + auto tensorTy1 = selOp.getODSOperands(1)[0].getType(); + auto tensorTy2 = selOp.getODSOperands(2)[0].getType(); + int operandIndexNeedModify; + mlir::Type operandNeedReserved; + if (tensorTy1 != tensorTy2) { + if ((mlir::isa(tensorTy1) || + mlir::isa(tensorTy1)) && + mlir::isa(tensorTy2)) { + operandIndexNeedModify = 1; + operandNeedReserved = tensorTy2; + } else if ((mlir::isa(tensorTy2) || + mlir::isa(tensorTy2)) && + mlir::isa(tensorTy1)) { + operandIndexNeedModify = 2; + operandNeedReserved = tensorTy1; + } + assert(isa( + selOp.getOperand(operandIndexNeedModify) + .getDefiningOp()) && + "Unable to extract the non-constant operand."); + + auto splatOp = builder.create( + loc, operandNeedReserved, + selOp.getOperand(operandIndexNeedModify)); + splatOp->moveBefore(&op); + selOp.setOperand(operandIndexNeedModify, splatOp.getResult()); + } + } + for (auto [i, resTy] : llvm::enumerate(op.getResultTypes())) { + auto inlineTensorTy = + RankedTensorType::get(shape, resTy, tensorTy.getEncoding()); + op.getResult(i).setType(inlineTensorTy); + } + } + // Inline Ops + llvm::SmallVector combineArgs(2 * reduceOperandNum); + for (unsigned i = 0; i < reduceOperandNum; ++i) { + combineArgs[i] = reduceOp.getOperands()[i]; + combineArgs[reduceOperandNum + i] = + mapping.lookup(reduceOp.getOperands()[i]); + } + auto currOp = &*rewriter.getInsertionPoint(); + auto insertOp = currOp->getPrevNode(); + rewriter.inlineBlockBefore(&newReduce, currOp, combineArgs); + ValueRange returnOperands; + getInlineInfo(unrollOpTree, insertOp, returnOperands); + + auto isReduceOp = [&](OpOperand &operand) { + return reduceOp == operand.getOwner(); + }; + llvm::SmallVector iterArgs(reduceOperandNum); + for (auto [i, returnOperand] : llvm::enumerate(returnOperands)) { + iterArgs[i] = reduceOp.getOperands()[i]; + reduceOp.getOperands()[i].replaceUsesWithIf(returnOperand, + isReduceOp); + } + // Find ptr chain of discrete for moving to loop body + SetVector newUnrollOpTree(unrollOpTree); + findDiscretePtrChain(unrollOpTree, newUnrollOpTree); + // 3. Create Loop for ReduceWithinCore + ValueRange iterArgsRange(iterArgs); + unrollControlReduce(context, newUnrollOpTree, reduceOp, iterArgsRange, + returnOperands); + // 4. For Vectorize: triton.addf->triton_xpu.vvaddf + processOpVecTy(m); + } + }); + } + + void reductionUnrollControl(ModuleOp &m, MLIRContext *context) { + // 1. Unroll Control for Reduce For + forUnrollControl(m, context); + // 2. Create For for ReduceWithinCore + createReduceWithinCore(m, context); + // 3. Calculate discrete offset in the runtime + createDiscreteOffset(m); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + bool isReduce = false; + m.walk([&](triton::xpu::ReduceOp redOp) { isReduce = true; }); + if (isReduce) { + reductionUnrollControl(m, context); + } else { + pointwiseUnrollControl(m, context); + } + } + +private: + int64_t numUnrollPerCore = 2; +}; + +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Vectorize.cpp b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Vectorize.cpp new file mode 100644 index 000000000..4a4a16abd --- /dev/null +++ b/third_party/xpu/lib/Dialect/TritonXPU/Transforms/Vectorize.cpp @@ -0,0 +1,1279 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// TODO: Pass Description +//===----------------------------------------------------------------------===// + +// clang-format off +#include "triton/Dialect/TritonXPU/IR/Dialect.h" +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" +// clang-format on + +#define DEBUG_TYPE "tritonxpu-vectorize" + +namespace mlir { +namespace triton { +namespace xpu { + +enum class ElemState { + SS = 0, /*00*/ + SV = 1, /*01*/ + VS = 2, /*10*/ + VV = 3 /*11*/ +}; + +using OperationTree = llvm::SetVector; + +#define ARITH_BINARY_FLOAT_OP \ + arith::AddFOp, arith::SubFOp, arith::MulFOp, arith::DivFOp, \ + arith::MaximumFOp, arith::MinimumFOp + +#define ARITH_BINARY_INT_OP \ + arith::SubIOp, arith::AndIOp, arith::OrIOp, arith::MulIOp, arith::AddIOp, \ + arith::XOrIOp + +#define MATH_UNARY_OP \ + math::ExpOp, math::SqrtOp, math::SinOp, math::CosOp, arith::ExtFOp, \ + arith::TruncFOp, math::AbsFOp + +// TODO: VMin when LLVM can select +#define REDUCE_COMBINE_OP \ + arith::AddFOp, arith::MulFOp, arith::MaxNumFOp, arith::MinNumFOp, \ + arith::OrIOp, arith::XOrIOp, arith::AndIOp, triton::xpu::ReduceReturnOp + +template struct VOp; + +#define VOP(SrcType, DstType) \ + template <> struct VOp { \ + typedef DstType type; \ + }; + +VOP(arith::AddFOp, triton::xpu::VvaddFOp) +VOP(arith::SubFOp, triton::xpu::VvsubFOp) +VOP(arith::MulFOp, triton::xpu::VvmulFOp) +VOP(arith::DivFOp, triton::xpu::VvdivFOp) +VOP(arith::MaximumFOp, triton::xpu::VvmaxFOp) +VOP(arith::MinimumFOp, triton::xpu::VvminFOp) + +VOP(arith::AddIOp, triton::xpu::VvaddIOp) +VOP(arith::SubIOp, triton::xpu::VvsubIOp) +VOP(arith::MulIOp, triton::xpu::VvmulIOp) +VOP(arith::AndIOp, triton::xpu::VvandIOp) +VOP(arith::XOrIOp, triton::xpu::VvxorIOp) +VOP(arith::OrIOp, triton::xpu::VvorIOp) + +VOP(math::ExpOp, triton::xpu::VExpFOp) +VOP(math::AbsFOp, triton::xpu::VAbsFOp) +VOP(math::LogOp, triton::xpu::VLogFOp) +VOP(math::SqrtOp, triton::xpu::VSqrtFOp) +VOP(math::SinOp, triton::xpu::VSinFOp) +VOP(math::CosOp, triton::xpu::VCosFOp) +VOP(arith::ExtFOp, triton::xpu::VExtFOp) +VOP(arith::TruncFOp, triton::xpu::VTruncFOp) +VOP(arith::SIToFPOp, triton::xpu::VSIToFPOp) + +template struct VV2SVOp; + +#define VV2SVOp(SrcType, DstType) \ + template <> struct VV2SVOp { \ + typedef DstType type; \ + }; + +VV2SVOp(triton::xpu::VvaddFOp, triton::xpu::SvaddFOp); +VV2SVOp(triton::xpu::VvmulFOp, triton::xpu::SvmulFOp); +VV2SVOp(triton::xpu::VvsubFOp, triton::xpu::SvsubFOp); +VV2SVOp(triton::xpu::VvmaxFOp, triton::xpu::SvmaxFOp); + +} // namespace xpu +} // namespace triton +} // namespace mlir + +namespace mlir { + +namespace triton { +namespace xpu { + +#define GEN_PASS_DEF_TRITONXPUVECTORIZE +#include "triton/Dialect/TritonXPU/Transforms/Passes.h.inc" + +struct TritonXPUVectorizePass + : public impl::TritonXPUVectorizeBase { + + using impl::TritonXPUVectorizeBase< + TritonXPUVectorizePass>::TritonXPUVectorizeBase; + + template + static decltype(auto) createBinVectorizedOp(T op, Type vectorizedTensorTy) { + OpBuilder builder(op); + return builder.create::type>( + op.getLoc(), vectorizedTensorTy, op.getLhs(), op.getRhs()); + } + + template + static decltype(auto) createUnaryVectorizedOp(T op, Type vectorizedTensorTy) { + OpBuilder builder(op); + return builder.create::type>( + op.getLoc(), vectorizedTensorTy, op.getOperand()); + } + + static decltype(auto) createLibdeviceOp(triton::ExternElementwiseOp &op, + const llvm::StringRef &symbol, + Type vectorizedTensorTy) { + OpBuilder builder(op); + return builder.create( + op.getLoc(), vectorizedTensorTy, op.getOperands(), op.getLibname(), + op.getLibpath(), symbol, op.getPure()); + } + + // TODO[dyq]: open isMultipleOfBank + // bool isMultipleOfBank(ModuleOp &mod) { + // bool res = false; + // mod.walk([&](arith::CmpIOp cmpiOp) { + // auto lhs = cmpiOp.getLhs(); + // auto rhs = cmpiOp.getRhs(); + + // if (cmpiOp.predicate() == arith::CmpIPredicate::slt) { + // auto lhsShape = lhs.getType().cast().getShape(); + + // if (lhsShape.size() == 2 && lhsShape[0] == 1) { // inner Cmp + // Calculation + // if (auto rhsOp = + // rhs.getDefiningOp()) { // Static Rnumel + // auto denseAttr = rhsOp.getValue().dyn_cast(); + // auto elemPerCore = + // *denseAttr.getValues().begin(); // get rnumel int + // res = (elemPerCore & (bufferSize - 1)) == 0; // check multiple? + // } + // } + // } + // }); + // return res; + // } + + Operation *getBlockArgumentOp(Value arg) { + BlockArgument blockArg = mlir::dyn_cast(arg); + Block *block = blockArg.getOwner(); + unsigned argIndex = blockArg.getArgNumber(); + + if (auto forOp = dyn_cast(block->getParentOp())) { + // TODO[dyq]: check getIterOperands -> getInitArgs + Value initValue = + forOp.getInitArgs()[argIndex - forOp.getNumInductionVars()]; + return initValue.getDefiningOp(); + } + llvm_unreachable( + "[Vectorization]: Operand is Not a BlockArgument of scf::for."); + return nullptr; + } + + bool binLikeOpVectorize(Value lhs, Value rhs, OperationTree &visited, + OperationTree &vectorizedOps) { + bool isFP32Ty = getElementTypeOrSelf(lhs.getType()).isF32() && + getElementTypeOrSelf(rhs.getType()).isF32(); + bool isFP16Ty = getElementTypeOrSelf(lhs.getType()).isF16() && + getElementTypeOrSelf(rhs.getType()).isF16(); + bool isINT32Ty = getElementTypeOrSelf(lhs.getType()).isInteger(32) && + getElementTypeOrSelf(rhs.getType()).isInteger(32); + if (!isFP32Ty && !isFP16Ty && !isINT32Ty) { + return false; + } + + bool isVectorized = false; + + Operation *lhsOp = lhs.getDefiningOp(); + Operation *rhsOp = rhs.getDefiningOp(); + + Operation *lhsLoopInitOp = nullptr; + Operation *rhsLoopInitOp = nullptr; + + if (mlir::isa(lhs)) { + lhsLoopInitOp = getBlockArgumentOp(lhs); + } + + if (mlir::isa(rhs)) { + rhsLoopInitOp = getBlockArgumentOp(rhs); + } + + bool lhsVectorized = lhsOp + ? vectorize(lhsOp, visited, vectorizedOps) + : vectorize(lhsLoopInitOp, visited, vectorizedOps); + bool rhsVectorized = rhsOp + ? vectorize(rhsOp, visited, vectorizedOps) + : vectorize(rhsLoopInitOp, visited, vectorizedOps); + + isVectorized = lhsVectorized && rhsVectorized; + return isVectorized; + } + + bool vectorize(Operation *op, OperationTree &visited, + OperationTree &vectorizedOps) { + assert(op && "[Vectorization]: Empty Operation pointer"); + visited.insert(op); + + if (vectorizedOps.contains(op)) + return true; + + bool isVectorized = false; + TypeSwitch(op) + .Case([&](auto loadOp) { isVectorized = true; }) + .Case([&](auto loadOp) { isVectorized = true; }) + .Case( + [&](auto coreIdOp) { isVectorized = true; }) + .Case( + [&](auto programIdOp) { isVectorized = true; }) + .Case([&](auto constOp) { isVectorized = true; }) + .Case([&](auto unaryOp) { isVectorized = true; }) + .Case([&](auto loadOp) { + unsigned numElems = getTotalElemsPerThread(loadOp.getType()); + Type elemTy = getElementTypeOrSelf(loadOp.getType()); + auto elemWidth = elemTy.getIntOrFloatBitWidth(); + auto vectorWidth = 512 / elemWidth; + isVectorized = numElems % vectorWidth == 0 && numElems != 0; + }) + .Case([&](auto storeOp) { + isVectorized = vectorize(storeOp.getValue().getDefiningOp(), visited, + vectorizedOps); + }) + .Case([&](auto reduceOp) { + if (ReduceVec) { + isVectorized = true; + for (Block &block : reduceOp.getCombineOp().getBlocks()) { + for (auto &op : block) { + if (!isa(op)) { + isVectorized = false; + break; + } + } + } + } else { + isVectorized = false; + } + }) + .Case([&](auto extractOp) { + isVectorized = vectorize(extractOp.getTensor().getDefiningOp(), + visited, vectorizedOps); + }) + .Case([&](auto splatOp) { + auto defineOp = splatOp.getSrc().getDefiningOp(); + if (!defineOp) { // some splatOp deal in_ptr + isVectorized = true; + } else { // some splatOp deal tensor + auto srcTy = splatOp.getSrc().getType(); + isVectorized = getTotalElemsPerThread(srcTy) == 1; + } + }) + .Case([&](auto broadCastOp) { + // Some BroadcastOp From ReduceOp + auto srcTy = + mlir::dyn_cast(broadCastOp.getSrc().getType()); + auto resTy = mlir::dyn_cast( + broadCastOp.getResult().getType()); + + auto srcShape = srcTy.getShape(); + auto resShape = resTy.getShape(); + + auto rank = srcTy.getRank(); + unsigned resNumElems = getTotalElemsPerThread(resTy); + + if (rank == 2 && resNumElems >= 16) { + // srcShape[0] > 32: Scalar Calculations Perform Better than Vector + // Calculations When The Data Size is Small. + if ((srcShape[0] > 32 && srcShape[1] == 1) || + (srcShape[0] == 1 && srcShape[1] == resShape[1])) { + isVectorized = true; + } + } + }) + .Case([&](auto expandDimsOp) { + isVectorized = vectorize(expandDimsOp.getOperand().getDefiningOp(), + visited, vectorizedOps); + }) + .Case([&](auto addPtrOp) { + isVectorized = vectorize(addPtrOp.getPtr().getDefiningOp(), visited, + vectorizedOps) && + vectorize(addPtrOp.getOffset().getDefiningOp(), + visited, vectorizedOps); + }) + .Case([&](auto cvtOp) { + isVectorized = vectorize(cvtOp.getOperand().getDefiningOp(), visited, + vectorizedOps); + }) + .Case([&](auto selectOp) { + auto tv = selectOp.getTrueValue(); + auto fv = selectOp.getFalseValue(); + isVectorized = binLikeOpVectorize(tv, fv, visited, vectorizedOps); + }) + .Case([&](auto cmpIOp) { + isVectorized = false; + // TODO: Add vCmpIOp Support + // auto lhs = cmpIOp.getLhs(); + // auto rhs = cmpIOp.getRhs(); + // isVectorized = binLikeOpVectorize(lhs, rhs, visited, + // vectorizedOps); + }) + .Case([&](auto cmpFOp) { + auto lhs = cmpFOp.getLhs(); + auto rhs = cmpFOp.getRhs(); + isVectorized = binLikeOpVectorize(lhs, rhs, visited, vectorizedOps); + }) + .Case([&](auto ifOp) { + // For then Region + Region &thenRegion = ifOp.getThenRegion(); + Block &thenBlock = thenRegion.front(); + Operation *thenTerminator = thenBlock.getTerminator(); + + if (auto yieldOp = dyn_cast(thenTerminator)) { + if (auto prevOp = yieldOp.getOperands().front().getDefiningOp()) { + isVectorized = vectorize(prevOp, visited, vectorizedOps); + } + } + + // For Else Region + if (!ifOp.getElseRegion().empty()) { + Region &elseRegion = ifOp.getElseRegion(); + Block &elseBlock = elseRegion.front(); + Operation *elseTerminator = elseBlock.getTerminator(); + if (auto yieldOp = dyn_cast(elseTerminator)) { + if (auto prevOp = yieldOp.getOperands().front().getDefiningOp()) { + isVectorized &= vectorize(prevOp, visited, vectorizedOps); + } + } + } + }) + .Case([&](auto forOp) { + // TODO[dyq]: check getIterOperands -> getInitArgs + auto iterArgsInitValues = forOp.getInitArgs(); + Region ®ion = forOp.getRegion(); + Block &block = region.front(); + Operation *terminator = block.getTerminator(); + + if (auto yieldOp = dyn_cast(terminator)) { + if (auto prevOp = yieldOp.getOperands().front().getDefiningOp()) { + isVectorized = vectorize(prevOp, visited, vectorizedOps) && + iterArgsInitValues.size() == 1; + } + } + }) + .Case([&](auto yieldOp) { + if (auto prevOp = yieldOp.getOperands().front().getDefiningOp()) { + isVectorized = vectorize(prevOp, visited, vectorizedOps); + } + }) + .Case([&](auto extElemwiseOp) { + auto symbol = extElemwiseOp.getSymbol(); + auto prevOp = extElemwiseOp.getOperands().front().getDefiningOp(); + assert(extElemwiseOp.getOperands().size() > 0 && + "Unexcepted ExternElementwiseOp Operand"); + if (symbol == "_ZN3xpu5tanhfEf") { + isVectorized = false; + // isVectorized = true; + // for (auto operand : extElemwiseOp.getOperands()) { + // isVectorized = + // isVectorized && vectorize(prevOp, visited, vectorizedOps); + // } + } else if (symbol == "_ZN3xpu3erfEf") { + isVectorized = true; + for (auto operand : extElemwiseOp.getOperands()) { + isVectorized = + isVectorized && vectorize(prevOp, visited, vectorizedOps); + } + } else if (symbol == "_ZN3xpu5isinfEf") { + isVectorized = false; + // TODO: check visinf logic + // isVectorized = true; + // for (auto operand : extElemwiseOp.getOperands()) { + // isVectorized = + // isVectorized && vectorize(prevOp, visited, vectorizedOps); + // } + } else { + isVectorized = false; + LLVM_DEBUG(llvm::dbgs() + << "[Vectorization]: Unsupported LibDeviceOp Symbol" + << symbol << "\n"); + } + }) + .Case([&](arith::SIToFPOp unaryOp) { + auto inType = getElementTypeOrSelf(unaryOp.getIn().getType()); + isVectorized = inType.isInteger(32) && + vectorize(unaryOp.getOperand().getDefiningOp(), + visited, vectorizedOps); + }) + .Case([&](auto binOp) { + auto lhs = binOp.getLhs(); + auto rhs = binOp.getRhs(); + isVectorized = binLikeOpVectorize(lhs, rhs, visited, vectorizedOps); + }) + .Case([&](auto binOp) { + auto lhs = binOp.getLhs(); + auto rhs = binOp.getRhs(); + isVectorized = binLikeOpVectorize(lhs, rhs, visited, vectorizedOps); + }) + .Case([&](auto unaryOp) { + isVectorized = vectorize(unaryOp.getOperand().getDefiningOp(), + visited, vectorizedOps); + }); + + if (!isVectorized) { + if (dumpFlag) { + LLVM_DEBUG({ + op->dump(); + llvm_unreachable("[Vectorization]: Unsupported Operation"); + }); + } + return false; + } + + // Dont Need To Vectorize ReduceOp's Result + if (auto reduceOp = dyn_cast(op)) + return true; + + for (Operation *user : op->getUsers()) { + if (visited.contains(user)) + continue; + + // FIXME: We've omitted the `other` value of LoadOp when create GM2LMOp in + // the past. However, `other` value comes back as we are about to separate + // GM2LMOp and LoadOp, and it will lead to a user LoadOp be in the + // vectorization path. Actions should be taken to handle this case. Here + // we workaround to skip LoadOp's `other` value. + if (auto loadOp = dyn_cast(user)) { + if (op == loadOp.getOther().getDefiningOp()) { + continue; + } + } + + if (!vectorize(user, visited, vectorizedOps)) + return false; + } + + vectorizedOps.insert(op); + return true; + } + + RankedTensorType getVectorType(Type tensorType, unsigned _elemWidth = 0) { + unsigned numElems = getTotalElemsPerThread(tensorType); + Type elemTy = getElementTypeOrSelf(tensorType); + auto elemWidth = + _elemWidth == 0 ? elemTy.getIntOrFloatBitWidth() : _elemWidth; + auto vectorWidth = 512 / elemWidth; + + RankedTensorType newTensorTy; + + if (numElems % vectorWidth == 0 && + numElems != 0) { // normal vector<16xf32>/vector<32xf16> + // Step 1. getVectorType + VectorType newVectorType = mlir::VectorType::get(vectorWidth, elemTy); + + // Step 2. getShape + RankedTensorType oriTensorTy = mlir::cast(tensorType); + auto oriShape = oriTensorTy.getShape(); + llvm::SmallVector newShape(oriShape.begin(), oriShape.end()); + auto rank = oriShape.size(); + newShape[rank - 1] /= vectorWidth; + + // Step 3. getEncoding + auto oriEncoding = + mlir::cast(oriTensorTy.getEncoding()); + auto sizePerCore = oriEncoding.getSizePerCore().vec(); + auto corePerGroup = oriEncoding.getCoresPerGroup().vec(); + auto groupsPerCluster = oriEncoding.getGroupsPerCluster().vec(); + auto order = oriEncoding.getOrder().vec(); + auto isReduceOpt = oriEncoding.getIsReduceOpt(); + + sizePerCore[rank - 1] = + std::max(1, int(sizePerCore[rank - 1] / vectorWidth)); + + auto newEncoding = triton::xpu::ClusterLayoutAttr::get( + tensorType.getContext(), sizePerCore, corePerGroup, groupsPerCluster, + order, isReduceOpt); + + // Step 4. create RankedTensorType + newTensorTy = RankedTensorType::get(newShape, newVectorType, newEncoding); + } else if (numElems == 1) { // special vector<1xf32> + // Step 1. getVectorType + VectorType newVectorType = mlir::VectorType::get(1, elemTy); + // Step 2. getEncoding + auto newEncoding = triton::xpu::ClusterLayoutAttr::get( + tensorType.getContext(), {1}, {4}, {16}, {0}, false); + // Step 3. create RankedTensorType + newTensorTy = RankedTensorType::get(1, newVectorType, newEncoding); + } else { + llvm_unreachable( + "Only Supported vector<32xTy> or vector<16xTy> or vector<1xTy>"); + } + return newTensorTy; + } + + void processOpVecTy(OperationTree &vectorizedOps, ModuleOp &mod) { + for (auto *op : vectorizedOps) { + TypeSwitch(op) + .Case([&](auto loadOp) { + auto newVectorizedTensorTy = + getVectorType(loadOp.getResult().getType()); + loadOp.getResult().setType(newVectorizedTensorTy); + }) + .Case([&](auto lm2gmOp) { (void)lm2gmOp; }) + .Case([&](auto storeOp) { (void)storeOp; }) + .Case([&](auto binOp) { + auto newVectorizedTensorTy = + getVectorType(binOp.getResult().getType()); + auto newBinOp = createBinVectorizedOp(binOp, newVectorizedTensorTy); + binOp.replaceAllUsesWith(newBinOp.getResult()); + binOp.erase(); + }) + .Case([&](auto binOp) { + auto newVectorizedTensorTy = + getVectorType(binOp.getResult().getType()); + auto newBinOp = createBinVectorizedOp(binOp, newVectorizedTensorTy); + binOp.replaceAllUsesWith(newBinOp.getResult()); + binOp.erase(); + }) + .Case([&](auto unaryOp) { + auto newVectorizedTensorTy = + getVectorType(unaryOp.getResult().getType()); + auto newUnaryOp = + createUnaryVectorizedOp(unaryOp, newVectorizedTensorTy); + unaryOp.replaceAllUsesWith(newUnaryOp.getResult()); + unaryOp.erase(); + }) + .Case([&](auto unaryOp) { + auto newVectorizedTensorTy = + getVectorType(unaryOp.getResult().getType()); + auto newUnaryOp = + createUnaryVectorizedOp(unaryOp, newVectorizedTensorTy); + unaryOp.replaceAllUsesWith(newUnaryOp.getResult()); + unaryOp.erase(); + }) + .Case([&](auto constOp) { + auto newVectorizedTensorTy = + getVectorType(constOp.getResult().getType()); + OpBuilder builder(constOp); + auto newConstOp = builder.create( + constOp.getLoc(), newVectorizedTensorTy, constOp.getValue()); + constOp.replaceAllUsesWith(newConstOp.getResult()); + constOp.erase(); + }) + .Case([&](auto splatOp) { + auto newVectorizedTensorTy = + getVectorType(splatOp.getResult().getType()); + OpBuilder builder(splatOp); + auto newSplatOp = builder.create( + splatOp.getLoc(), newVectorizedTensorTy, splatOp.getOperand()); + splatOp.replaceAllUsesWith(newSplatOp.getResult()); + splatOp.erase(); + }) + .Case([&](auto forOp) { + auto forBody = forOp.getBody(); + auto forArgs = forBody->getArguments(); + // TODO[dyq]: check getIterOperands -> getInitArgs + auto iterArgsInitValues = forOp.getInitArgs(); + assert(iterArgsInitValues.size() == 1 && + "[Vectorization]: Only Support ForOp with One Iter Args"); + Value iterArgInitValue = iterArgsInitValues.front(); + auto newVectorizedTensorTy = iterArgInitValue.getType(); + + // 1. Change Input Iter Args Type + forArgs[1].setType(newVectorizedTensorTy); + + // 2. Change Output Type + forOp.getResult(0).setType(newVectorizedTensorTy); + }) + .Case([&](auto ifOp) { + // 1. Get Terminator Type + Region &thenRegion = ifOp.getThenRegion(); + Block &thenBlock = thenRegion.front(); + Operation *thenTerminator = thenBlock.getTerminator(); + + Type resType; + if (auto yieldOp = dyn_cast(thenTerminator)) { + if (auto prevOp = yieldOp.getOperands().front().getDefiningOp()) { + resType = prevOp->getResult(0).getType(); + } + } else { + resType = thenTerminator->getResult(0).getType(); + } + + // 2. Change Output Type + ifOp.getResult(0).setType(resType); + }) + .Case([&](auto yieldOp) { (void)yieldOp; }) + .Case([&](auto cvtOp) { + auto newVectorizedTensorTy = + getVectorType(cvtOp.getResult().getType()); + cvtOp.getResult().setType(newVectorizedTensorTy); + }) + .Case([&](auto selectOp) { + auto newVectorizedTensorTy = + getVectorType(selectOp.getResult().getType()); + OpBuilder builder(selectOp); + auto newSelectOp = builder.create( + selectOp.getLoc(), newVectorizedTensorTy, + selectOp.getCondition(), selectOp.getTrueValue(), + selectOp.getFalseValue()); + selectOp.replaceAllUsesWith(newSelectOp.getResult()); + selectOp.erase(); + }) + .Case([&](auto cmpFOp) { + auto rhsTy = cmpFOp.getRhs().getType(); + Type elemTy = getElementTypeOrSelf(getElementTypeOrSelf(rhsTy)); + auto newVectorizedTensorTy = getVectorType( + cmpFOp.getResult().getType(), elemTy.getIntOrFloatBitWidth()); + OpBuilder builder(cmpFOp); + auto newCmpFOp = builder.create( + cmpFOp.getLoc(), newVectorizedTensorTy, cmpFOp.getPredicate(), + cmpFOp.getLhs(), cmpFOp.getRhs()); + cmpFOp.replaceAllUsesWith(newCmpFOp.getResult()); + cmpFOp.erase(); + }) + .Case([&](auto broadCastOp) { + auto newVectorizedTensorTy = + getVectorType(broadCastOp.getResult().getType()); + broadCastOp.getResult().setType(newVectorizedTensorTy); + }) + .Case([&](auto expandOp) { + auto newVectorizedTensorTy = + getVectorType(expandOp.getResult().getType()); + expandOp.getResult().setType(newVectorizedTensorTy); + }) + .Case([&](auto extElemwiseOp) { + auto symbol = extElemwiseOp.getSymbol(); + OpBuilder builder(extElemwiseOp); + auto newVectorizedTensorTy = + getVectorType(extElemwiseOp.getResult().getType()); + if (symbol == "_ZN3xpu5tanhfEf") { + auto newExtElemwiseOp = + createLibdeviceOp(extElemwiseOp, "_ZN3xpu6vtanhfEDv16_f", + newVectorizedTensorTy); + extElemwiseOp.replaceAllUsesWith(newExtElemwiseOp.getResult()); + extElemwiseOp.erase(); + } else if (symbol == "_ZN3xpu3erfEf") { + auto newExtElemwiseOp = createLibdeviceOp( + extElemwiseOp, "_ZN3xpu4verfEDv16_f", newVectorizedTensorTy); + extElemwiseOp.replaceAllUsesWith(newExtElemwiseOp.getResult()); + extElemwiseOp.erase(); + } else if (symbol == "_ZN3xpu5isinfEf") { + auto newExtElemwiseOp = + createLibdeviceOp(extElemwiseOp, "_ZN3xpu6visinfEDv16_f", + newVectorizedTensorTy); + extElemwiseOp.replaceAllUsesWith(newExtElemwiseOp.getResult()); + extElemwiseOp.erase(); + } else { + LLVM_DEBUG(llvm::dbgs() + << "[Vectorization]: Can not Convert Symbol " << symbol + << " to Vfunc\n"); + } + }) + .Default([&](auto &op) { + LLVM_DEBUG(op->dump()); + llvm_unreachable( + "[Vectorization]: Unsupported Operation Type To VecType"); + }); + } + } + + bool inline vectorizedTyValid(Type elemTy) { + if (elemTy.isF16() || elemTy.isF32() || elemTy.isBF16() || + elemTy.isInteger(16) || elemTy.isInteger(32)) + return true; + return false; + } + + void vectorizeAndProcessOpVecTy(ModuleOp &mod, Operation *rootOp, + Type rootOpTy, std::string logMessage) { + auto rowsPerCore = 1; + if (auto rootOpTensorTy = mlir::dyn_cast(rootOpTy)) { + auto rank = rootOpTensorTy.getShape().size(); + if (rank > 1) { + rowsPerCore = mlir::cast( + rootOpTensorTy.getEncoding()) + .getSizePerCore()[0]; + } + } + + unsigned numElems = getTotalElemsPerThread(rootOpTy) / rowsPerCore; + Type vecTy = getElementTypeOrSelf(rootOpTy); + Type elemTy = getElementTypeOrSelf(vecTy); + auto elemWidth = elemTy.getIntOrFloatBitWidth(); + auto vectorWidth = 512 / elemWidth; + if (numElems < vectorWidth || numElems % vectorWidth > 0 || + !vectorizedTyValid(elemTy)) + return; + + OperationTree visited; + OperationTree vectorizedOps; + + if (!vectorize(rootOp, visited, vectorizedOps)) + return; + + LLVM_DEBUG({ + llvm::errs() << logMessage << "\n"; + if (dumpFlag) { + for (auto vecOp : vectorizedOps) + vecOp->dump(); + } + }); + + auto encoding = mlir::cast(rootOpTy).getEncoding(); + + processOpVecTy(vectorizedOps, mod); + } + + void maximumFusion(arith::SelectOp selectOp) { + if (auto orIOp = selectOp.getCondition().getDefiningOp()) { + if (orIOp.getResult().hasOneUse()) { + auto lhs = orIOp.getLhs().getDefiningOp(); + auto rhs = orIOp.getRhs().getDefiningOp(); + + bool isMax = (lhs.getPredicate() == arith::CmpFPredicate::OGT && + rhs.getPredicate() == arith::CmpFPredicate::UNE) || + (lhs.getPredicate() == arith::CmpFPredicate::UNE && + rhs.getPredicate() == arith::CmpFPredicate::OGT); + bool isMin = (lhs.getPredicate() == arith::CmpFPredicate::OLT && + rhs.getPredicate() == arith::CmpFPredicate::UNE) || + (lhs.getPredicate() == arith::CmpFPredicate::UNE && + rhs.getPredicate() == arith::CmpFPredicate::OLT); + + if (lhs && rhs && lhs.getResult().hasOneUse() && + rhs.getResult().hasOneUse()) { + OpBuilder builder(selectOp); + if (isMax) { + auto newMaxFOp = builder.create( + selectOp.getLoc(), selectOp.getType(), selectOp.getTrueValue(), + selectOp.getFalseValue()); + selectOp->replaceAllUsesWith(newMaxFOp); + selectOp->erase(); + orIOp->erase(); + lhs->erase(); + rhs->erase(); + LLVM_DEBUG(llvm::dbgs() + << "[Vectorization]: Apply Maximum Fusion Optimization " + "For VVMax.\n"); + } else if (isMin) { + auto newMinFOp = builder.create( + selectOp.getLoc(), selectOp.getType(), selectOp.getTrueValue(), + selectOp.getFalseValue()); + selectOp->replaceAllUsesWith(newMinFOp); + selectOp->erase(); + orIOp->erase(); + lhs->erase(); + rhs->erase(); + LLVM_DEBUG(llvm::dbgs() + << "[Vectorization]: Apply Minimum Fusion Optimization " + "For VVMin.\n"); + } + } + } + } + } + + bool isLoadVectorized(triton::xpu::LoadOp loadOp) { + Type resTy = loadOp.getType(); + Type resElemTy = getElementTypeOrSelf(resTy); + return mlir::isa(resElemTy); + } + + bool SVOptimization_Cond(Operation *op) { + bool canSVOpt = false; + // TODO: Check block Argument + if (!op) + return canSVOpt; + + TypeSwitch(op) + .Case([&](auto loadOp) { + auto gm2lmOp = cast(loadOp->getPrevNode()); + OffsetState offsetState = + static_cast(gm2lmOp.getOffsetState()); + if (offsetState == OffsetState::DiscreteSame && + isLoadVectorized(loadOp)) + canSVOpt = true; + }) + .Case([&](auto bcOp) { + auto src = bcOp.getSrc(); + if (auto srcTy = mlir::dyn_cast(src.getType())) { + auto srcShape = srcTy.getShape(); + if (srcShape.size() == 2 && srcShape[0] == 64 && srcShape[1] == 1) { + canSVOpt = true; + } + } + }) + .Case([&](auto vConstOp) { canSVOpt = true; }) + .Default([&](auto &op) { canSVOpt = false; }); + + return canSVOpt; + } + + bool collectVUser(Operation *op, DenseMap &vBinOps) { + // To check if the collection was successful. + bool canSVOpt = true; + for (auto user : op->getUsers()) { + TypeSwitch(user) + .Case([&](auto vBinOp) { + auto lDefineOp = + vBinOp.getLhs().getDefiningOp(); // getLhs define op + auto rDefineOp = vBinOp.getRhs().getDefiningOp(); + + bool lCond = SVOptimization_Cond(lDefineOp); + bool rCond = SVOptimization_Cond(rDefineOp); + + bool opIsLhs = lDefineOp == op; + + if ((opIsLhs ? lCond : rCond) && (lCond != rCond)) { + vBinOps[vBinOp] = opIsLhs ? ElemState::SV : ElemState::VS; + } else { + canSVOpt = false; + } + }) + .Default([&](auto &user) { canSVOpt = false; }); + + if (!canSVOpt) + break; + } + return canSVOpt; + } + + void SVOptimization_Modify(triton::xpu::LoadOp loadOp) { + // Get Information + Type tensorType = loadOp.getType(); + Type vecElemTy = getElementTypeOrSelf(tensorType); + + // vecNums / numElems (all vector<16xTy> use one same Ty) + unsigned vecNums = + mlir::cast(tensorType).getNumElements(); + unsigned numElems = getTotalElemsPerThread(tensorType); + + // elem type + Type elemTy = getElementTypeOrSelf(vecElemTy); + + // encoding + auto encoding = mlir::cast( + mlir::cast(tensorType).getEncoding()); + + std::vector sizePerCore = {1}; // 1 for scalar + Attribute newEncoding = triton::xpu::ClusterLayoutAttr::get( + encoding.getContext(), sizePerCore, encoding.getCoresPerGroup(), + encoding.getGroupsPerCluster(), encoding.getOrder(), + encoding.getIsReduceOpt()); + + Type newTensorType = RankedTensorType::get( + ceil(vecNums, numElems), elemTy, newEncoding); + + // Replace Origin Op + OpBuilder builder(loadOp); + loadOp->setAttr("SVOpt", builder.getBoolAttr(true)); + loadOp->getResult(0).setType(newTensorType); + } + + // To check if the SVOptimization(Own) was successful. + void SVOptimization_Modify(triton::xpu::BroadcastOp vBCOp) { + auto src = vBCOp.getSrc(); + vBCOp.replaceAllUsesWith(src); + vBCOp.erase(); + } + + // To check if the SVOptimization(Own) was successful. + void SVOptimization_Modify(triton::xpu::VConstOp vConstOp) { + auto res = vConstOp.getResult(); + auto resTy = mlir::cast(res.getType()); + triton::xpu::ClusterLayoutAttr vConstOpEncoding = + mlir::cast(resTy.getEncoding()); + unsigned rank = resTy.getRank(); + + auto elemTy = getElementTypeOrSelf(vConstOp.getType()); + auto _elemTy = getElementTypeOrSelf(elemTy); + RankedTensorType newSrcTy; + if (rank == 1) { + newSrcTy = + RankedTensorType::get({/*core_num=*/64}, _elemTy, vConstOpEncoding); + } else if (rank == 2) { + newSrcTy = RankedTensorType::get({/*core_num=*/64, 1}, _elemTy, + vConstOpEncoding); + } else { + llvm_unreachable("Got Unsupport Rank"); + } + + // TODO[dyq]: dyn_cast -> cast + auto oriDenseAttr = + mlir::dyn_cast(vConstOp.getValue()); + auto initValue = DenseElementsAttr::getFromRawBuffer( + newSrcTy, oriDenseAttr.getRawData()); + + OpBuilder builder(vConstOp); + auto newConstOp = builder.create(vConstOp.getLoc(), + newSrcTy, initValue); + vConstOp.replaceAllUsesWith(newConstOp.getResult()); + vConstOp.erase(); + } + + template void createSVBinOp(T vBinOp, ElemState elemStateInt) { + if (elemStateInt == ElemState::VS) { + // SVSUB Has A Strict Order Of Operations. + // V-S -> -S+V + if constexpr (std::is_same_v) { + OpBuilder builder(vBinOp); + auto negFOp = + builder.create(vBinOp.getLoc(), vBinOp.getRhs()); + auto svBinFOp = builder.create( + vBinOp.getLoc(), vBinOp.getType(), vBinOp.getLhs(), negFOp, + static_cast(elemStateInt)); + vBinOp.replaceAllUsesWith(svBinFOp.getResult()); + vBinOp.erase(); + LLVM_DEBUG(llvm::dbgs() + << "[Vectorization]: Apply VSSUB -> SVADD Optimization.\n"); + return; + } + } + + OpBuilder builder(vBinOp); + auto svBinFOp = builder.create::type>( + vBinOp.getLoc(), vBinOp.getType(), vBinOp.getLhs(), vBinOp.getRhs(), + static_cast(elemStateInt)); + vBinOp.replaceAllUsesWith(svBinFOp.getResult()); + vBinOp.erase(); + } + + void VvOpToSvOp(DenseMap &vBinOps, + std::string logMessage) { + for (auto &pair : vBinOps) { + auto op = pair.first; + auto elemStateInt = pair.second; + TypeSwitch(op) + .Case( + [&](auto vBinOp) { createSVBinOp(vBinOp, elemStateInt); }) + .Default([&](auto &op) { + llvm_unreachable( + "[Vectorization]: Got An Unexpected SV Operation Type"); + }); + } + LLVM_DEBUG(llvm::dbgs() << logMessage); + } + + template void SVOptimization(T op, std::string logMessage) { + // Step 1. collect all vUser + DenseMap vBinOps; + if (!collectVUser(op, vBinOps)) + return; + + // Step 2. Deal Input Op Own Modification + SVOptimization_Modify(op); + + // Step 3. Deal Input Op's User Modification + VvOpToSvOp(vBinOps, logMessage); + } + + // Simpify Mod Graph + // TODO[dyq]: use canonicalizer + void cvtOpclean(triton::gpu::ConvertLayoutOp cvtOp) { + auto src = cvtOp.getSrc(); + auto res = cvtOp.getResult(); + + if (src.getType() != res.getType()) + return; + + cvtOp.replaceAllUsesWith(src); + cvtOp.erase(); + } + + void VvdivToVvmul(triton::xpu::VvdivFOp vvdivOp) { + // Only can be optimized to vvmul when the denominator is a scalar, it can + // be further optimized to svmul + if (auto bcOp = + vvdivOp.getRhs().getDefiningOp()) { + auto src = bcOp.getSrc(); + auto res = bcOp.getResult(); + + // Check 1. Src Shape Must Be 64x1xf32 + if (auto srcTy = mlir::dyn_cast(src.getType())) { + auto srcShape = srcTy.getShape(); + if (srcShape.size() != 2 || !(srcShape[0] == 64 && srcShape[1] == 1)) { + return; + } else { + // Step 2. Create DivOp For Rhs + OpBuilder builder(bcOp); + SmallVector intValues(srcShape[1], + builder.getF32FloatAttr(1)); + DenseElementsAttr denseAttr = + DenseFPElementsAttr::get(srcTy, intValues); + auto ones = + builder.create(bcOp.getLoc(), denseAttr); + + auto oneDivByRhs = + builder.create(bcOp.getLoc(), srcTy, ones, src); + + bcOp->setOperand(0, oneDivByRhs); + + // Step 3. Change vvdiv by vvmul + OpBuilder builder_tmp(vvdivOp); + auto vvmulOp = builder_tmp.create( + vvdivOp.getLoc(), vvdivOp.getType(), vvdivOp.getLhs(), + vvdivOp.getRhs()); + vvdivOp.replaceAllUsesWith(vvmulOp->getResult(0)); + vvdivOp.erase(); + LLVM_DEBUG( + llvm::dbgs() + << "[Vectorization]: Apply VVDIV -> VVMUL Optimization.\n"); + } + } else { + return; + } + } + } + + void VVMacOpFusion(triton::xpu::VvmulFOp mulOp) { + for (auto nextOp : mulOp->getUsers()) { + if (auto addOp = dyn_cast(nextOp)) { + auto lDefineOp = addOp.getLhs().getDefiningOp(); // getLhs define op + OpBuilder builder(addOp); + auto newMacOp = builder.create( + mulOp.getLoc(), mulOp.getType(), mulOp.getLhs(), mulOp.getRhs(), + lDefineOp == mulOp ? addOp.getRhs() : addOp.getLhs()); + + addOp->replaceAllUsesWith(newMacOp); + addOp->erase(); + LLVM_DEBUG(llvm::dbgs() + << "[Vectorization]: Apply VVMacOp Fusion Optimization.\n"); + } + } + } + + void BF16ToFP32VecOptimize(ModuleOp &mod) { + // bf16Tofp32Unordered could only used in order-independent cases + bool bf16Tofp32Unordered = true; + int load_cnt = 0; + mod.walk([&](triton::xpu::LoadOp loadOp) { + load_cnt++; + Type ptrTy = loadOp.getPtr().getType(); + Type ptrElemTy = getElementTypeOrSelf(ptrTy); + Type ptrDataTy = mlir::cast(ptrElemTy).getPointeeType(); + Type resTy = loadOp.getResult().getType(); + Type resElemTy = getElementTypeOrSelf(resTy); + Type resScalarTy = getElementTypeOrSelf(resElemTy); + + if (resScalarTy.isF32() && ptrDataTy.isBF16()) { + auto stride = loadOp.getStride(); + auto tensorColSize = loadOp.getTensorColSize(); + bool isVector = mlir::isa(resElemTy); + bool isSvOpt = loadOp.getSVOpt(); + bool isDiscreteSame = stride == 0; + bool isContiguous = stride == 1; + bool notCoreDealMultiRows = tensorColSize == -1; + bf16Tofp32Unordered &= + (isVector && isContiguous && notCoreDealMultiRows) || isSvOpt || + isDiscreteSame; + } else { + bf16Tofp32Unordered &= false; + } + }); + + bf16Tofp32Unordered = load_cnt == 0 ? false : bf16Tofp32Unordered; + + mod.walk([&](triton::xpu::StoreOp storeOp) { + Value val = storeOp.getValue(); + Type valTy = val.getType(); + Type valElemTy = getElementTypeOrSelf(valTy); + if (bf16Tofp32Unordered && mlir::isa(valElemTy)) { + if (findDefOpBwd(val)) { + bf16Tofp32Unordered &= false; + } + } + }); + + mod.walk([&](triton::xpu::LoadOp loadOp) { + OpBuilder builder(loadOp); + loadOp->setAttr("bf16Tofp32Unordered", + builder.getBoolAttr(bf16Tofp32Unordered)); + }); + + mod.walk([&](triton::xpu::StoreOp storeOp) { + OpBuilder builder(storeOp); + storeOp->setAttr("bf16Tofp32Unordered", + builder.getBoolAttr(bf16Tofp32Unordered)); + }); + + if (bf16Tofp32Unordered) { + LLVM_DEBUG( + llvm::dbgs() + << "[Vectorization]: Apply BF16ToFP32VecUnordered Optimization.\n"); + } + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + // Maximum Fusion Online + // [cmpf, cmpf, ori, select] -> [fmax] + if (Maximum_Fusion) { + mod.walk([&](arith::SelectOp selectOp) { maximumFusion(selectOp); }); + } + + // Eliminate SelectOp For bufferSize X Col Size + // TODO[dyq]: open isMultipleOfBank + // if (isMultipleOfBank(mod)) { + // mod.walk([&](arith::SelectOp selectOp) { + // // Have Only One User(ReduceOp) + // if (selectOp.getResult().hasOneUse()) { + // auto userOp = *selectOp->user_begin(); + // if (auto redOp = dyn_cast(userOp)) { + // auto trueVal = selectOp.getTrueValue(); + // auto trueValOp = trueVal.getDefiningOp(); + + // selectOp->replaceAllUsesWith(trueValOp->getResults()); + // selectOp->erase(); + // LLVM_DEBUG(llvm::dbgs() << "[Vectorization]: Eliminate SelectOp + // For " + // "bufferSize X Col Size.\n"); + // } + // } + // }); + // } + + if (ReduceVec) { + // For [Load -> Reduce] || [Broadcast -> Reduce] + mod.walk([&](triton::xpu::ReduceOp redOp) { + for (int i = 0; i < redOp.getOperands().size() - 1; ++i) { + auto reduceOperand = redOp.getOperands()[i]; + auto reduceOperandOp = reduceOperand.getDefiningOp(); + auto reduceOperandTy = reduceOperand.getType(); + vectorizeAndProcessOpVecTy(mod, reduceOperandOp, reduceOperandTy, + "[Vectorization]: [Load -> " + "Reduce] || [Broadcast -> Reduce] Hit."); + } + + ReduceOpHelper help(redOp); + if (help.isVectorized()) { + // reduceop's correct encoding should be inferd by its input type. + auto srcLayout = help.getSrcLayout(); + for (Value redRes : redOp.getResults()) { + if (auto resTy = dyn_cast(redRes.getType())) { + auto resSliceEncoding = + cast(resTy.getEncoding()); + auto srcClusterEncoding = + cast(srcLayout); + auto newEncoding = triton::gpu::SliceEncodingAttr::get( + redOp.getContext(), resSliceEncoding.getDim(), + srcClusterEncoding); + auto newResTy = RankedTensorType::get( + resTy.getShape(), resTy.getElementType(), newEncoding); + redRes.setType(newResTy); + } + } + + for (Block &block : redOp.getCombineOp().getBlocks()) { + // Set Arg's Type to VecType + auto inputTypes = redOp.getInputTypes(); + auto inputSize = inputTypes.size(); + int vecSize = 16; + for (int i = 0; i < inputSize; ++i) { + auto vecTy = getElementTypeOrSelf(inputTypes[i]); + vecSize = cast(vecTy).getNumElements(); + auto arg1 = block.getArguments()[i]; + auto arg2 = block.getArguments()[inputSize + i]; + arg1.setType(vecTy); + arg2.setType(vecTy); + } + // Set CombineOp's Type to VecType + for (auto &op : block) { + TypeSwitch(&op) + .Case([&](auto redComOp) { + for (auto res : redComOp->getResults()) { + auto elemTy = res.getType(); + VectorType vecType = VectorType::get(vecSize, elemTy); + res.setType(vecType); + } + }) + .Default([&](auto defaultOp) { + LLVM_DEBUG(defaultOp->dump()); + llvm_unreachable( + "[Vectorization]: Unsupported Operation Type " + "To VecType in Reduce"); + }); + } + } + } + }); + } + + // For [Broadcast -> Store] + mod.walk([&](triton::xpu::StoreOp storeOp) { + auto storeOpValueTy = storeOp.getValue().getType(); + vectorizeAndProcessOpVecTy(mod, storeOp, storeOpValueTy, + "[Vectorization]: [Broadcast -> Store] Hit."); + }); + + // Eliminate CvtOp in VVOp Path + if (cvtOp_clean) { + mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) { cvtOpclean(cvtOp); }); + } + + // Div -> Mul + if (Div2Mul) { + mod.walk([&](triton::xpu::VvdivFOp vvdivFOp) { VvdivToVvmul(vvdivFOp); }); + } + + // SV Optimization offline + if (SV_Fusion) { + // SVOptimization For LoadOp + mod.walk([&](triton::xpu::LoadOp vLoadOp) { + vectorizedLoadOps.insert(vLoadOp); + }); + for (auto vLoadOp : vectorizedLoadOps) { + SVOptimization(vLoadOp, + "[Vectorization]: Apply SV Optimization For LoadOp.\n"); + } + + // SVOptimization For BroadcastOp + mod.walk([&](triton::xpu::BroadcastOp vBCOp) { + vectorizedBcOps.insert(vBCOp); + }); + for (auto vBCOp : vectorizedBcOps) { + SVOptimization( + vBCOp, "[Vectorization]: Apply SV Optimization For BroadcastOp.\n"); + } + + // SVOptimization For ConstOp + mod.walk([&](triton::xpu::VConstOp vConstOp) { + vectorizedConstOps.insert(vConstOp); + }); + for (auto vConstOp : vectorizedConstOps) { + SVOptimization( + vConstOp, "[Vectorization]: Apply SV Optimization For VConstOp.\n"); + } + } + + // MAC Optimization offline + if (VMAC_Fusion) { + mod.walk([&](triton::xpu::VvmulFOp vvmulFOp) { // must walk after svOpt + vvmulFOps.insert(vvmulFOp); + }); + + for (auto vvmulFOp : vvmulFOps) { + VVMacOpFusion(vvmulFOp); + } + } + + // bfloat16 -> float32 Vector Optimization + if (BF16ToFP32VecOpt) { + BF16ToFP32VecOptimize(mod); + } + } + +private: + llvm::SetVector vvmulFOps; + llvm::SetVector vectorizedBcOps; + llvm::SetVector vectorizedLoadOps; + llvm::SetVector vectorizedConstOps; + bool Maximum_Fusion = true; + bool SV_Fusion = true; + bool VMAC_Fusion = true; + bool cvtOp_clean = true; + bool Div2Mul = true; + bool ReduceVec = true; + bool BF16ToFP32VecOpt = true; +}; + +} // namespace xpu +} // namespace triton +} // namespace mlir diff --git a/third_party/xpu/lib/Target/CMakeLists.txt b/third_party/xpu/lib/Target/CMakeLists.txt new file mode 100644 index 000000000..85a4cf843 --- /dev/null +++ b/third_party/xpu/lib/Target/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(LLVMIR) +add_subdirectory(LLVMXPU) diff --git a/third_party/xpu/lib/Target/LLVMIR/CMakeLists.txt b/third_party/xpu/lib/Target/LLVMIR/CMakeLists.txt new file mode 100644 index 000000000..f2f9adf8f --- /dev/null +++ b/third_party/xpu/lib/Target/LLVMIR/CMakeLists.txt @@ -0,0 +1,28 @@ +add_triton_library(TritonLLVMIR + LLVMDIScope.cpp + LLVMIRBreakPhiStruct.cpp + + DEPENDS + LLVMIRIncGen + + LINK_LIBS + ${CMAKE_DL_LIBS} + PUBLIC + MLIRArithToLLVM + MLIRBuiltinToLLVMIRTranslation + MLIRIndexToLLVM + MLIRIR + MLIRLLVMDialect + MLIRLLVMToLLVMIRTranslation + MLIRNVVMToLLVMIRTranslation + MLIRROCDLToLLVMIRTranslation + MLIRSCFToControlFlow + MLIRSupport + MLIRTargetLLVMIRExport + TritonGPUToLLVM + ) + +set_source_files_properties( + LLVMIRTranslation.cpp + PROPERTIES + COMPILE_FLAGS "-D__BUILD_DIR__=\\\"${CMAKE_BINARY_DIR}\\\"") diff --git a/third_party/xpu/lib/Target/LLVMIR/LLVMDIScope.cpp b/third_party/xpu/lib/Target/LLVMIR/LLVMDIScope.cpp new file mode 100644 index 000000000..af7079060 --- /dev/null +++ b/third_party/xpu/lib/Target/LLVMIR/LLVMDIScope.cpp @@ -0,0 +1,161 @@ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "triton/Target/LLVMIR/Passes.h" +#include "llvm/BinaryFormat/Dwarf.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Path.h" + +//===----------------------------------------------------------------------===// +// This file implements a pass to add debug info scope to LLVM operations, and +// is inspired by the DIScopeForLLVMFuncOpPass in LLVM/MLIR. Different from the +// DIScopeForLLVMFuncOpPass, this pass also handles inlined functions. +//===----------------------------------------------------------------------===// + +using namespace mlir; + +#define GEN_PASS_CLASSES +#include "triton/Target/LLVMIR/Passes.h.inc" + +namespace { + +/// Attempt to extract a filename for the given loc. +FileLineColLoc extractFileLoc(Location loc) { + if (auto fileLoc = dyn_cast(loc)) + return fileLoc; + if (auto nameLoc = dyn_cast(loc)) + return extractFileLoc(nameLoc.getChildLoc()); + if (auto opaqueLoc = dyn_cast(loc)) + return extractFileLoc(opaqueLoc.getFallbackLocation()); + if (auto fusedLoc = dyn_cast(loc)) + return extractFileLoc(fusedLoc.getLocations().front()); + if (auto callerLoc = dyn_cast(loc)) + return extractFileLoc(callerLoc.getCaller()); + StringAttr unknownFile = mlir::StringAttr::get(loc.getContext(), ""); + return mlir::FileLineColLoc::get(unknownFile, 0, 0); +} + +/// Add a debug info scope to LLVMFuncOp that are missing it. +struct LLVMDIScopePass : public LLVMDIScopeBase { + LLVMDIScopePass() = default; + + void setSubprogramAttr(LLVM::LLVMFuncOp funcOp) { + Location loc = funcOp.getLoc(); + if (loc->findInstanceOf>()) + return; + + MLIRContext *context = &getContext(); + + // To find a DICompileUnitAttr attached to a parent (the module for + // example), otherwise create a default one. + LLVM::DICompileUnitAttr compileUnitAttr; + if (ModuleOp module = funcOp->getParentOfType()) { + auto fusedCompileUnitAttr = + module->getLoc() + ->findInstanceOf>(); + if (fusedCompileUnitAttr) + compileUnitAttr = fusedCompileUnitAttr.getMetadata(); + } + + // Filename, line and colmun to associate to the function. + LLVM::DIFileAttr fileAttr; + int64_t line = 1, col = 1; + FileLineColLoc fileLoc = extractFileLoc(loc); + if (!fileLoc && compileUnitAttr) { + fileAttr = compileUnitAttr.getFile(); + } else if (!fileLoc) { + fileAttr = LLVM::DIFileAttr::get(context, "", ""); + } else { + line = fileLoc.getLine(); + col = fileLoc.getColumn(); + StringRef inputFilePath = fileLoc.getFilename().getValue(); + fileAttr = LLVM::DIFileAttr::get( + context, llvm::sys::path::filename(inputFilePath), + llvm::sys::path::parent_path(inputFilePath)); + } + auto subroutineTypeAttr = + LLVM::DISubroutineTypeAttr::get(context, llvm::dwarf::DW_CC_normal, {}); + + // Figure out debug information (`subprogramFlags` and `compileUnitAttr`) to + // attach to the function definition / declaration. External functions are + // declarations only, and are defined in a different compile unit, so mark + // them appropriately in `subprogramFlags`, and set an empty + // `compileUnitAttr`. + DistinctAttr distinctId; + auto subprogramFlags = LLVM::DISubprogramFlags::Optimized; + if (!funcOp.isExternal()) { + distinctId = mlir::DistinctAttr::create(mlir::UnitAttr::get(context)); + if (!compileUnitAttr) { + compileUnitAttr = LLVM::DICompileUnitAttr::get( + distinctId, llvm::dwarf::DW_LANG_C, fileAttr, + StringAttr::get(context, "triton"), + /*isOptimized=*/true, LLVM::DIEmissionKind::LineTablesOnly); + } + subprogramFlags = subprogramFlags | LLVM::DISubprogramFlags::Definition; + } else { + compileUnitAttr = {}; + } + + StringAttr funcNameAttr = funcOp.getNameAttr(); + // Note that scopeline is set differently from LLVM's + // DIScopeForLLVMFuncOpPass. I don't find reasons why scopeline should be + // the column offset + auto subprogramAttr = LLVM::DISubprogramAttr::get( + context, distinctId, compileUnitAttr, fileAttr, funcNameAttr, + funcNameAttr, fileAttr, + /*line=*/line, + /*scopeline=*/line, subprogramFlags, subroutineTypeAttr); + funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr)); + } + + // Get a nested loc for inlined functions + Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr, + Location calleeLoc) { + auto calleeFileName = extractFileLoc(calleeLoc).getFilename(); + auto context = op->getContext(); + LLVM::DIFileAttr calleeFileAttr = LLVM::DIFileAttr::get( + context, llvm::sys::path::filename(calleeFileName), + llvm::sys::path::parent_path(calleeFileName)); + auto lexicalBlockFileAttr = LLVM::DILexicalBlockFileAttr::get( + context, scopeAttr, calleeFileAttr, /*discriminator=*/0); + Location loc = calleeLoc; + if (mlir::isa(calleeLoc)) { + auto nestedLoc = mlir::cast(calleeLoc).getCallee(); + loc = getNestedLoc(op, lexicalBlockFileAttr, nestedLoc); + } + return FusedLoc::get(context, {loc}, lexicalBlockFileAttr); + } + + void setLexicalBlockFileAttr(Operation *op) { + auto opLoc = op->getLoc(); + if (auto callSiteLoc = dyn_cast(opLoc)) { + auto callerLoc = callSiteLoc.getCaller(); + auto calleeLoc = callSiteLoc.getCallee(); + LLVM::DIScopeAttr scopeAttr; + // We assemble the full inline stack so the parent of this loc must be a + // function + auto funcOp = op->getParentOfType(); + auto funcOpLoc = mlir::cast(funcOp.getLoc()); + scopeAttr = mlir::cast(funcOpLoc.getMetadata()); + auto loc = + CallSiteLoc::get(getNestedLoc(op, scopeAttr, calleeLoc), callerLoc); + op->setLoc(loc); + } + } + + void runOnOperation() override { + getOperation()->walk([&](Operation *op) -> void { + if (isa(op)) + setSubprogramAttr(cast(op)); + else + setLexicalBlockFileAttr(op); + }); + } +}; + +} // end anonymous namespace + +std::unique_ptr mlir::createLLVMDIScopePass() { + return std::make_unique(); +} diff --git a/third_party/xpu/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp b/third_party/xpu/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp new file mode 100644 index 000000000..44afcfd21 --- /dev/null +++ b/third_party/xpu/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +/// Implements a trivial pass breaking up 1 level deep structure in phi nodes. +/// This handles the common case generated by Triton and allow better +/// optimizations down the compiler pipeline. +//===----------------------------------------------------------------------===// +#include "LLVMPasses.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +static bool processPhiStruct(PHINode *phiNode) { + StructType *STy = dyn_cast(phiNode->getType()); + if (!STy) + return false; + IRBuilder<> builder(phiNode); + unsigned numOperands = phiNode->getNumIncomingValues(); + unsigned numScalarEl = STy->getNumElements(); + Value *newStruct = UndefValue::get(STy); + builder.SetInsertPoint(phiNode->getParent()->getFirstNonPHI()); + llvm::IRBuilderBase::InsertPoint insertInsertPt = builder.saveIP(); + for (unsigned i = 0; i < numScalarEl; i++) { + builder.SetInsertPoint(phiNode); + PHINode *newPhiNode = + builder.CreatePHI(STy->getElementType(i), numOperands); + for (unsigned j = 0; j < numOperands; ++j) { + Value *operand = phiNode->getIncomingValue(j); + builder.SetInsertPoint(phiNode->getIncomingBlock(j)->getTerminator()); + newPhiNode->addIncoming(builder.CreateExtractValue(operand, i), + phiNode->getIncomingBlock(j)); + } + builder.restoreIP(insertInsertPt); + newStruct = builder.CreateInsertValue(newStruct, newPhiNode, i); + insertInsertPt = builder.saveIP(); + } + phiNode->replaceAllUsesWith(newStruct); + return true; +} + +static bool runOnFunction(Function &F) { + bool Changed = false; + SmallVector PhiNodes; + for (BasicBlock &BB : F) { + for (Instruction &inst : BB) { + if (PHINode *phiNode = dyn_cast(&inst)) { + Changed |= processPhiStruct(phiNode); + continue; + } + break; + } + } + return Changed; +} + +PreservedAnalyses BreakStructPhiNodesPass::run(Function &F, + FunctionAnalysisManager &AM) { + + bool b = runOnFunction(F); + return b ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/third_party/xpu/lib/Target/LLVMIR/LLVMPasses.h b/third_party/xpu/lib/Target/LLVMIR/LLVMPasses.h new file mode 100644 index 000000000..1dcdb2992 --- /dev/null +++ b/third_party/xpu/lib/Target/LLVMIR/LLVMPasses.h @@ -0,0 +1,16 @@ +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/CodeGen.h" + +namespace llvm { + +// Pass to pre-process LLVM IR before optimization and break up phi of struct. +// Breaking up those phis into elementary types allows better optimizations +// downstream. +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; + +} // namespace llvm diff --git a/third_party/xpu/lib/Target/LLVMXPU/CMakeLists.txt b/third_party/xpu/lib/Target/LLVMXPU/CMakeLists.txt new file mode 100644 index 000000000..349210924 --- /dev/null +++ b/third_party/xpu/lib/Target/LLVMXPU/CMakeLists.txt @@ -0,0 +1,16 @@ +add_triton_library(MLIRLLVMXPUToLLVMIRTranslation + LLVMXPUToLLVMIRTranslation.cpp + +# DEPENDS +# MLIRXPUConversionsIncGen + +# LINK_COMPONENTS +# Core + +# LINK_LIBS PUBLIC +# MLIRLLVMXPUDialect +# MLIRIR +# MLIRLLVMDialect +# MLIRSupport +# MLIRTargetLLVMIRExport +) diff --git a/third_party/xpu/lib/Target/LLVMXPU/LLVMXPUToLLVMIRTranslation.cpp b/third_party/xpu/lib/Target/LLVMXPU/LLVMXPUToLLVMIRTranslation.cpp new file mode 100644 index 000000000..dd2a1f000 --- /dev/null +++ b/third_party/xpu/lib/Target/LLVMXPU/LLVMXPUToLLVMIRTranslation.cpp @@ -0,0 +1,83 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +//===- LLVMXPUToLLVMIRTranslation.cpp - Translate LLVMXPU to LLVM IR +//------------===// +// +// This file implements a translation between the MLIR LLVMXPU dialect and +// LLVM IR. +// +//===----------------------------------------------------------------------===// + +// clang-format off +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicsXPU.h" //llvm::Intrinsic +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "triton/Dialect/LLVMXPU/IR/Dialect.h" +#include "triton/Target/LLVMXPU/LLVMXPUToLLVMIRTranslation.h" +// clang-format on + +using namespace mlir; +using namespace mlir::LLVM; +using mlir::LLVM::detail::createIntrinsicCall; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the LLVMXPU dialect to LLVM IR. +class LLVMXPUDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Translates the given operation to LLVM IR using the provided IR builder + /// and saving the state in `moduleTranslation`. + LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const final { + Operation &opInst = *op; +#include "triton/Dialect/LLVMXPU/IR/LLVMXPUConversions.inc" + return failure(); + } + + /// Attaches module-level metadata for functions marked as kernels. + LogicalResult + amendOperation(Operation *op, ArrayRef instructions, + NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) const final { + auto func = dyn_cast(op); + if (!func) + return failure(); + llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); + llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName()); + + auto generateMetadata = [&](int dim, StringRef name) { + llvm::Metadata *llvmMetadata[] = { + llvm::ValueAsMetadata::get(llvmFunc), + llvm::MDString::get(llvmContext, name), + llvm::ValueAsMetadata::get(llvm::ConstantInt::get( + llvm::Type::getInt32Ty(llvmContext), dim))}; + llvm::MDNode *llvmMetadataNode = + llvm::MDNode::get(llvmContext, llvmMetadata); + moduleTranslation.getOrInsertNamedModuleMetadata("xpu.annotations") + ->addOperand(llvmMetadataNode); + }; + + return success(); + } +}; +} // namespace + +void mlir::registerLLVMXPUDialectTranslation(DialectRegistry ®istry) { + registry.insert(); + registry.addExtension(+[](MLIRContext *ctx, XPU::LLVMXPUDialect *dialect) { + dialect->addInterfaces(); + }); +} + +void mlir::registerLLVMXPUDialectTranslation(MLIRContext &context) { + DialectRegistry registry; + registerLLVMXPUDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff --git a/third_party/xpu/lib/Tools/CMakeLists.txt b/third_party/xpu/lib/Tools/CMakeLists.txt new file mode 100644 index 000000000..4b021da33 --- /dev/null +++ b/third_party/xpu/lib/Tools/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(TritonTools + LinearLayout.cpp + + DEPENDS + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMDialect + f2reduce +) diff --git a/third_party/xpu/lib/Tools/LinearLayout.cpp b/third_party/xpu/lib/Tools/LinearLayout.cpp new file mode 100644 index 000000000..75e530db5 --- /dev/null +++ b/third_party/xpu/lib/Tools/LinearLayout.cpp @@ -0,0 +1,427 @@ +#include "triton/Tools/LinearLayout.h" + +#include +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "third_party/f2reduce/f2reduce.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir::triton { + +namespace { +using BasesT = LinearLayout::BasesT; +using llvm::Twine; + +BasesT makeBasesMap( + ArrayRef>>> bases) { + BasesT ret; + for (const auto &[inDim, inDimBases] : bases) { + ret[inDim] = inDimBases; + } + return ret; +} + +std::string stringifyBases(const BasesT &bases, + ArrayRef outDimNames) { + std::string ret; + + if (bases.empty()) + return "(empty layout)\n"; + + // TODO: Add spaces for alignment. + for (const auto &[inDim, inDimBases] : bases) { + if (inDimBases.empty()) { + ret += " - " + inDim.str() + " is a size 1 dimension\n"; + continue; + } + + ret += " - " + + join(llvm::seq(inDimBases.size()), "\n ", + [&, &inDim = inDim, &inDimBases = inDimBases](int i) { + return inDim.str() + "=" + std::to_string(1 << i) + " -> (" + + join(inDimBases[i], ", ") + ")"; + }) + + "\n"; + } + ret += "where out dims are: [" + + join(outDimNames, ", ", [](StringAttr s) { return s.str(); }) + "]\n"; + return ret; +} + +BasesT validateBases(BasesT bases, ArrayRef outDimNames) { + if (bases.empty()) + return bases; + + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + if (llvm::any_of(basis, [](int32_t b) { return b < 0; })) { + llvm::report_fatal_error( + "Invalid bases passed to LinearLayout. Expected all basis " + "values to be non-negative, but found a negative value for " + "in dimension '" + + Twine(inDim) + "'. Full list of bases:\n" + + stringifyBases(bases, outDimNames)); + } + } + } + + // Check that the bases all have length equal to outDimNames.size(). + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + if (basis.size() != outDimNames.size()) { + llvm::report_fatal_error( + "Invalid bases passed to LinearLayout. Expect all bases to have " + "the same size, equal to outDimNames.size() (" + + Twine(outDimNames.size()) + + "). But this failed for in dimension '" + Twine(inDim) + + "'. Full list of bases:\n" + stringifyBases(bases, outDimNames)); + } + } + } + + return bases; +} + +// Compute the rank of the matrix formed by taking the bases for the given +// outDim as columns. In other words, finds the number of linearly-independent +// bases for this output dimension. +int getMatrixRank(const LinearLayout &layout, StringAttr outDim) { + // Suppose we have a layout specified by the following key values. + // + // L(0,1) = 0b01 + // L(0,2) = 0b10 + // L(1,0) = 0b10 + // L(2,0) = 0b11 + // + // We will create one column per key value. The max bit width of these values + // is 2, so our matrix will have 2 rows. The final matrix will be + // + // | ↑ ↑ ↑ ↑ | | 0b0111 | + // | L(0,1) L(0,2) L(1,0) L(2,0) | = | 0b1001 | + // | ↓ ↓ ↓ ↓ | + int numRows = layout.getOutDimSizeLog2(outDim); + + int numCols = 0; + for (StringAttr inDim : layout.getInDimNames()) { + numCols += layout.getInDimSizeLog2(inDim); + } + + if (numCols == 0 || numRows == 0) + return 0; + + // Don't handle giant LLs. This makes some things easier; for example, each + // row can be a single uint64_t. + assert(numCols <= 64 && "LinearLayout too large"); + assert(numRows <= 64 && "LinearLayout too large"); + + // Note that `new int[n]()` is zero-initialized, whereas `new int[n]` is not. + std::unique_ptr m(new uint64_t[numRows]()); + + // Fill in the matrix. + int c = 0; + for (StringAttr inDim : layout.getInDimNames()) { + for (int i = 0; i < layout.getInDimSizeLog2(inDim); i++) { + uint64_t basis = layout.getBasis(inDim, i, outDim); + for (int j = 0; j < numRows; j++) { + m[j] |= ((basis >> j) & 1) << c; + } + c++; + } + } + + // stride is specified in number of 64-bit words per row. + f2reduce::inplace_rref_strided(m.get(), numRows, numCols, /*stride=*/1); + + // The rank of the reduced matrix is simply the number of nonzero rows. + int rank = 0; + for (int i = 0; i < numRows; i++) { + if (m[i] != 0) + rank++; + } + return rank; +} + +// Check that the given layout is surjective, i.e. that every `out` coordinate +// can be reached by some `in` coordinate. +// +// It's sufficient to check each output dimension indepedently. Still, +// it's prohibitively slow to calculate this naively. +// +// Thankfully, this is equivalent to checking that the number of +// linearly-independent bases for outDim d is equal to getOutDimSizeLog2(d). +// This can be computed by finding the rank of the matrix whose columns are +// those bases. We can compute the rank of our matrix using Gaussian +// elimination, which runs in O(n^3) for an n x n matrix. Our matrix size is +// log(product(inDimSize)) x log(outDimSize), and we do this numOutDims times, +// so this should be plenty fast overall. +void validateSurjectivity(const LinearLayout &layout) { + for (const auto &outDim : layout.getOutDimNames()) { + unsigned rank = getMatrixRank(layout, outDim); + unsigned expectedRank = layout.getOutDimSizeLog2(outDim); + if (rank != expectedRank) { + llvm::report_fatal_error( + "Invalid bases passed to LinearLayout. Expected bases to be " + "surjective, i.e. all possible output coordinates can be reached " + "by some input coordinates. But this failed for output dimension " + + Twine(outDim) + ", where we got rank " + Twine(rank) + + " instead of expected rank " + Twine(expectedRank) + + ". Full list of bases:\n" + + Twine(stringifyBases(layout.getBases(), layout.getOutDimNames()))); + } + } +} + +template +void assertDimsEqualIgnoringOrder(T &&a, U &&b) { + llvm::DenseSet as(a.begin(), a.end()); + llvm::DenseSet bs(b.begin(), b.end()); + if (as != bs) { + llvm::report_fatal_error("Dimensions must match, ignoring order, but they " + "don't. Got dims: [" + + Twine(triton::join(a, ", ")) + "] and [" + + triton::join(b, ", ") + "]"); + } +} + +} // anonymous namespace + +LinearLayout::LinearLayout(BasesT bases, ArrayRef outDimNames) + : bases(validateBases(std::move(bases), outDimNames)), + outDimNames(outDimNames.begin(), outDimNames.end()) { + validateSurjectivity(*this); +} + +LinearLayout::LinearLayout( + ArrayRef>>> bases, + ArrayRef outDimNames) + : LinearLayout(makeBasesMap(bases), outDimNames) {} + +/*static*/ LinearLayout LinearLayout::identity1D(int32_t size, + StringAttr inDimName, + StringAttr outDimName) { + if (size == 0) + return LinearLayout::empty(); + + assert(llvm::isPowerOf2_32(size)); + std::vector> powersOf2; + for (int32_t i = 1; i < size; i *= 2) { + powersOf2.emplace_back().push_back(i); + } + return LinearLayout({{inDimName, std::move(powersOf2)}}, {outDimName}); +} + +/*static*/ LinearLayout LinearLayout::zeros1D(int32_t size, + StringAttr inDimName, + StringAttr outDimName) { + if (size == 0) + return LinearLayout::empty(); + + assert(llvm::isPowerOf2_32(size)); + std::vector> zeros; + for (int i = 0; i < llvm::Log2_32(size); i++) { + zeros.emplace_back().push_back(0); + } + return LinearLayout({{inDimName, zeros}}, {outDimName}); +} + +int32_t LinearLayout::getOutDimIndex(StringAttr outDim) const { + // Sadly SetVector doesn't provide an O(1) way to do this. + for (int i = 0; i < outDimNames.size(); ++i) { + if (outDimNames[i] == outDim) { + return i; + } + } + llvm::report_fatal_error("outDim " + Twine(outDim) + " is not in layout\n" + + toString()); +} + +int32_t LinearLayout::getInDimSizeLog2(StringAttr inDim) const { + auto it = bases.find(inDim); + assert(it != bases.end()); + return it->second.size(); +} + +int32_t LinearLayout::getOutDimSizeLog2(StringAttr outDim) const { + // TODO(jlebar): Cache this? + int32_t outDimIdx = getOutDimIndex(outDim); + int32_t max = 0; + for (const auto &[inDim, inDimBases] : bases) { + for (const auto &basis : inDimBases) { + max = std::max(max, basis[outDimIdx]); + } + } + return max == 0 ? 0 : llvm::Log2_32(max) + 1; +} + +LinearLayout LinearLayout::transposeIns(ArrayRef newInDims) const { + assertDimsEqualIgnoringOrder(newInDims, getInDimNames()); + + BasesT newBases; + for (const auto &inDim : newInDims) { + newBases[inDim] = bases.find(inDim)->second; + } + return LinearLayout(std::move(newBases), outDimNames.getArrayRef()); +} + +LinearLayout +LinearLayout::transposeOuts(ArrayRef newOutDims) const { + assertDimsEqualIgnoringOrder(newOutDims, getOutDimNames()); + + std::vector permutation; + for (const auto &outDim : newOutDims) { + permutation.push_back(getOutDimIndex(outDim)); + } + + BasesT newBases; + for (const auto &[inDim, inDimBases] : bases) { + auto &newInDimBases = newBases[inDim]; + for (const auto &basis : inDimBases) { + std::vector newBasis; + for (int32_t i : permutation) { + newBasis.push_back(basis[i]); + } + newInDimBases.push_back(std::move(newBasis)); + } + } + return LinearLayout(std::move(newBases), newOutDims); +} + +LinearLayout operator*(LinearLayout inner, LinearLayout outer) { + // Check that elements common to both outerDimsRange and innerDimsRange appear + // in the same relative order. + auto checkCommonDims = [&](auto outerDimsRange, auto innerDimsRange) { + llvm::DenseSet outerDims(outerDimsRange.begin(), + outerDimsRange.end()); + llvm::DenseSet innerDims(innerDimsRange.begin(), + innerDimsRange.end()); + + std::vector outerCommonDims; + for (StringAttr dim : outerDimsRange) { + if (innerDims.contains(dim)) { + outerCommonDims.push_back(dim); + } + } + + std::vector innerCommonDims; + for (StringAttr dim : innerDimsRange) { + if (outerDims.contains(dim)) { + innerCommonDims.push_back(dim); + } + } + + if (outerCommonDims != innerCommonDims) { + llvm::report_fatal_error( + "Cannot multiply layouts. All in/out dimensions common to both " + "layouts must appear in the same relative order, but they " + "don't.\nOuter:\n" + + Twine(outer.toString()) + "\nInner:\n" + inner.toString()); + } + }; + + // Check that dims common to outer and inner have the same relative order. + checkCommonDims(outer.getInDimNames(), inner.getInDimNames()); + checkCommonDims(outer.getOutDimNames(), inner.getOutDimNames()); + + // Get the sizeLog2 of all input and output dimensions we're going to + // consider, in order. `inner` is more minor, so its dimensions come first. + llvm::MapVector inDimSizes; + llvm::SetVector outDimNames; + for (const auto &layout : {inner, outer}) { + for (StringAttr inDim : layout.getInDimNames()) { + inDimSizes[inDim] += layout.getInDimSizeLog2(inDim); + } + for (StringAttr outDim : layout.getOutDimNames()) { + outDimNames.insert(outDim); + } + } + BasesT allBases; + for (auto [inDimName, inDimSize] : inDimSizes) { + std::vector> &inDimBases = allBases[inDimName]; + + // Fill with zeros. + inDimBases = std::vector>( + inDimSize, std::vector(outDimNames.size(), 0)); + + for (auto [outDimIdx, outDimName] : llvm::enumerate(outDimNames)) { + if (inner.hasInDim(inDimName) && inner.hasOutDim(outDimName)) { + for (int i = 0; i < inner.getInDimSizeLog2(inDimName); i++) { + inDimBases[i][outDimIdx] = inner.getBasis(inDimName, i, outDimName); + } + } + if (outer.hasInDim(inDimName) && outer.hasOutDim(outDimName)) { + int offset = + inner.hasInDim(inDimName) ? inner.getInDimSizeLog2(inDimName) : 0; + int shift = inner.hasOutDim(outDimName) + ? inner.getOutDimSizeLog2(outDimName) + : 0; + for (int i = 0; i < outer.getInDimSizeLog2(inDimName); i++) { + inDimBases[offset + i][outDimIdx] = + outer.getBasis(inDimName, i, outDimName) << shift; + } + } + } + } + + return LinearLayout(std::move(allBases), outDimNames.getArrayRef()); +} + +SmallVector> +LinearLayout::apply(ArrayRef> ins) const { + assertDimsEqualIgnoringOrder(llvm::make_first_range(ins), getInDimNames()); + + SmallVector> ret; + for (StringAttr outDim : getOutDimNames()) { + int32_t outVal = 0; + for (auto &[inDim, val] : ins) { + for (int i = 0; i < getInDimSizeLog2(inDim); i++) { + if (val & (1 << i)) + outVal ^= getBasis(inDim, i, outDim); + } + } + ret.push_back({outDim, outVal}); + } + return ret; +} + +LinearLayout LinearLayout::compose(const LinearLayout &outer) const { + assertDimsEqualIgnoringOrder(getOutDimNames(), outer.getInDimNames()); + + BasesT newBases; + for (const auto &[inDim, inDimBases] : bases) { + auto &newInDimBases = newBases[inDim]; + for (const auto &basis : inDimBases) { + SmallVector> bases; + for (auto [outDim, b] : llvm::zip(getOutDimNames(), basis)) { + bases.push_back({outDim, b}); + } + auto newBases = outer.apply(bases); + auto newBasesRange = llvm::make_second_range(newBases); + newInDimBases.push_back( + std::vector(newBasesRange.begin(), newBasesRange.end())); + } + } + return LinearLayout(std::move(newBases), outer.getOutDimNames()); +} + +bool operator==(LinearLayout lhs, LinearLayout rhs) { + // llvm::MapVector doesn't have an operator== :(. + if (lhs.getOutDimNames() != rhs.getOutDimNames()) + return false; + if (lhs.bases.size() != rhs.bases.size()) + return false; + for (auto it1 = lhs.bases.begin(), it2 = rhs.bases.begin(); + it1 != lhs.bases.end(); ++it1, ++it2) { + if (*it1 != *it2) + return false; + } + return true; +} + +std::string LinearLayout::toString() const { + return stringifyBases(bases, getOutDimNames()); +} + +} // namespace mlir::triton diff --git a/third_party/xpu/python/src/interpreter.cc b/third_party/xpu/python/src/interpreter.cc new file mode 100644 index 000000000..6ab7c6c75 --- /dev/null +++ b/third_party/xpu/python/src/interpreter.cc @@ -0,0 +1,435 @@ +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +namespace { + +enum class MemSemantic { ACQUIRE_RELEASE, ACQUIRE, RELEASE, RELAXED }; + +enum class RMWOp { ADD, FADD, AND, OR, XOR, XCHG, MAX, MIN, UMIN, UMAX }; + +std::map mem_semantic_map = { + {MemSemantic::ACQUIRE_RELEASE, __ATOMIC_ACQ_REL}, + {MemSemantic::ACQUIRE, __ATOMIC_ACQUIRE}, + {MemSemantic::RELEASE, __ATOMIC_RELEASE}, + {MemSemantic::RELAXED, __ATOMIC_RELAXED}, +}; + +// Use compiler builtin atomics instead of std::atomic which requires +// each variable to be declared as atomic. +// Currently work for clang and gcc. +template T atomic_cmp(T *ptr, T val, int order) { + auto cmp = [](T old, T val) { + if constexpr (is_min) { + return old > val; + } else { + return old < val; + } + }; + // First load + T old_val = __atomic_load_n(ptr, order); + while (cmp(old_val, val)) { + if (__atomic_compare_exchange(ptr, &old_val, &val, false, order, order)) { + break; + } + } + return old_val; +} + +template T atomic_fadd(T *ptr, T val, int order) { + T old_val; + T new_val; + // First load + // Load ptr as if uint32_t or uint64_t and then memcpy to T + if constexpr (sizeof(T) == 4) { + uint32_t tmp = __atomic_load_n(reinterpret_cast(ptr), order); + std::memcpy(&old_val, &tmp, sizeof(T)); + } else if constexpr (sizeof(T) == 8) { + uint64_t tmp = __atomic_load_n(reinterpret_cast(ptr), order); + std::memcpy(&old_val, &tmp, sizeof(T)); + } else { + throw std::invalid_argument("Unsupported data type"); + } + while (true) { + new_val = old_val + val; + if (__atomic_compare_exchange(ptr, &old_val, &new_val, false, order, + order)) { + break; + } + } + return old_val; +} + +class AtomicOp { +public: + AtomicOp(const uint64_t *ptr, size_t numel, int order) + : ptr(ptr), numel(numel), order(order) {} + + void apply() { + for (size_t i = 0; i < numel; ++i) { + applyAt(reinterpret_cast(ptr[i]), i); + } + } + + virtual ~AtomicOp() = default; + +protected: + virtual void applyAt(void *, size_t i) = 0; + + const uint64_t *ptr; + size_t numel; + int order; +}; + +template class AtomicRMWOpBase : public AtomicOp { +public: + AtomicRMWOpBase(const uint64_t *ptr, const void *val, void *ret, + const bool *mask, size_t numel, int order) + : AtomicOp(ptr, numel, order), val(val), ret(ret), mask(mask) {} + +protected: + void applyAt(void *loc, size_t i) override final { + if (mask[i]) { + *(static_cast(ret) + i) = + applyAtMasked(static_cast(loc), + *(static_cast(val) + i), order); + } + } + + virtual DType applyAtMasked(DType *loc, const DType value, int order) = 0; + + const void *val; + void *ret; + const bool *mask; +}; + +template +class AtomicRMWOp : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_add(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return atomic_fadd(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_and(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_or(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_fetch_xor(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return atomic_cmp(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return atomic_cmp(loc, value, order); + } +}; + +template +class AtomicRMWOp> + : public AtomicRMWOpBase { +public: + using AtomicRMWOpBase::AtomicRMWOpBase; + +protected: + DType applyAtMasked(DType *loc, const DType value, int order) override { + return __atomic_exchange_n(loc, value, order); + } +}; + +class AtomicCASOp : public AtomicOp { +public: + AtomicCASOp(const uint64_t *ptr, void *expected, const void *desired, + size_t itemsize, size_t numel, int order) + : AtomicOp(ptr, numel, order), expected(expected), desired(desired), + itemsize(itemsize) {} + +protected: + void applyAt(void *loc, size_t i) override { + // Atomic operations perform bitwise comparison, so it's safe to + // use number of bytes (itemsize) to determine the type of pointers + if (itemsize == 1) { + uint8_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else if (itemsize == 2) { + uint16_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else if (itemsize == 4) { + uint32_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else if (itemsize == 8) { + uint64_t desired_val = *(static_cast(desired) + i); + __atomic_compare_exchange_n(static_cast(loc), + static_cast(expected) + i, + desired_val, false, order, order); + } else { + // The ‘__atomic’ builtins can be used with any integral scalar or pointer + // type that is 1, 2, 4, or 8 bytes in length. 16-byte integral types are + // also allowed if ‘__int128’ (see 128-bit Integers) is supported by the + // architecture. + // https://gcc.gnu.org/onlinedocs/gcc/_005f_005fatomic-Builtins.html + throw std::invalid_argument("Invalid byte size"); + } + } + +private: + void *expected; + const void *desired; + size_t itemsize; +}; + +// This is a workaround because explicit template parameter list for lambdas is +// a C++20 extension: +// auto try_make_op = [&]() { +// if (dtype.is(pybind11::dtype::of())) { +// atomic_op = std::make_unique>(ptr, val, ret, mask, +// numel, order); +// } +// }; +template struct OpCreator { + pybind11::dtype dtype; + const uint64_t *ptr; + const void *val; + void *ret; + const bool *mask; + size_t numel; + int order; + std::unique_ptr &atomic_op; + + template void create() { + if (!atomic_op && dtype.is(pybind11::dtype::of())) { + atomic_op = std::make_unique>(ptr, val, ret, mask, + numel, order); + } + } +}; + +template +std::unique_ptr +makeAtomicRMWOp(pybind11::dtype dtype, const uint64_t *ptr, const void *val, + void *ret, const bool *mask, size_t numel, int order) { + // Iterate over all supported data types, make one that matches, and return + std::unique_ptr atomic_op; + OpCreator try_make_op{dtype, ptr, val, ret, + mask, numel, order, atomic_op}; + + (try_make_op.template create(), ...); + if (!atomic_op) { + throw std::invalid_argument("Unsupported data type"); + } + // Make it a unique_ptr + return atomic_op; +} + +} // namespace + +void init_triton_interpreter(py::module &&m) { + using ret = py::return_value_policy; + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "RMW_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX) + .export_values(); + + m.def("load", + [](py::array_t ptr, py::array_t mask, py::array other, + py::dtype ret_dtype) -> py::array { + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_others = other.reshape({numel}); + for (size_t i = 0; i < ptr.size(); ++i) { + if (reshaped_mask.at(i)) + memcpy(ret.mutable_data(i), + reinterpret_cast(reshaped_ptr.at(i)), + ret_dtype.itemsize()); + else + memcpy(ret.mutable_data(i), reshaped_others.data(i), + ret_dtype.itemsize()); + } + return ret.reshape(shape); + }); + + m.def("store", + [](py::array_t ptr, py::array value, py::array_t mask) { + int numel = ptr.size(); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_value = value.reshape({numel}); + for (size_t i = 0; i < ptr.size(); ++i) { + if (reshaped_mask.at(i)) { + memcpy(reinterpret_cast(reshaped_ptr.mutable_at(i)), + reshaped_value.data(i), value.dtype().itemsize()); + } + } + }); + + m.def("atomic_rmw", + [](RMWOp rmw_op, py::array_t ptr, py::array val, + py::array_t mask, MemSemantic sem) -> py::array { + int order = mem_semantic_map[sem]; + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + auto ret_dtype = val.dtype(); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array_t reshaped_mask = mask.reshape({numel}); + py::array reshaped_val = val.reshape({numel}); + auto *ptr_data = reshaped_ptr.data(); + auto *mask_data = reshaped_mask.data(); + auto *val_data = static_cast(reshaped_val.data()); + auto *ret_data = static_cast(ret.mutable_data()); + + std::unique_ptr atomic_op; + +#define MAKE_ATOMIC_RMW_OP(OP_NAME, ...) \ + case OP_NAME: \ + atomic_op = makeAtomicRMWOp( \ + ret_dtype, ptr_data, val_data, ret_data, mask_data, numel, order); \ + break; + + switch (rmw_op) { + MAKE_ATOMIC_RMW_OP(RMWOp::ADD, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::FADD, float, double) + MAKE_ATOMIC_RMW_OP(RMWOp::AND, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::OR, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::XOR, int32_t, uint32_t, int64_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::MAX, int32_t, int64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::UMAX, uint32_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::MIN, int32_t, int64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::UMIN, uint32_t, uint64_t) + MAKE_ATOMIC_RMW_OP(RMWOp::XCHG, int32_t, uint32_t, int64_t, + uint64_t) + default: + throw std::invalid_argument("Unsupported RMW operation"); + } + +#undef MAKE_ATOMIC_RMW_OP + + atomic_op->apply(); + return ret.reshape(shape); + }); + + m.def("atomic_cas", + [](py::array_t ptr, py::array &cmp, py::array &val, + MemSemantic sem) -> py::array { + int order = mem_semantic_map[sem]; + int numel = ptr.size(); + auto shape = + std::vector(ptr.shape(), ptr.shape() + ptr.ndim()); + auto ret_dtype = cmp.dtype(); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptr = ptr.reshape({numel}); + py::array reshaped_cmp = cmp.reshape({numel}); + py::array reshaped_val = val.reshape({numel}); + auto itemsize = cmp.itemsize(); + memcpy(static_cast(ret.mutable_data()), + static_cast(reshaped_cmp.data()), + itemsize * numel); + AtomicCASOp(reshaped_ptr.data(), ret.mutable_data(), + static_cast(reshaped_val.data()), itemsize, + numel, order) + .apply(); + return ret.reshape(shape); + }); +} diff --git a/third_party/xpu/python/src/ir.cc b/third_party/xpu/python/src/ir.cc new file mode 100644 index 000000000..2b93ee0e0 --- /dev/null +++ b/third_party/xpu/python/src/ir.cc @@ -0,0 +1,1676 @@ +#include +#include +#include + +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Transforms/LocationSnapshot.h" +#include "mlir/Transforms/Passes.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace { + +namespace py = pybind11; +using namespace mlir; +using namespace triton; + +#if !defined(TRITON_CONCEAL_IR) || (TRITON_CONCEAL_IR == 0) +#define DUMP() self.dump() +#define PRINT(...) self.print(__VA_ARGS__) +#else +#define DUMP() +#define PRINT(...) +#endif + +// A custom op builder that keeps track of the last location +class TritonOpBuilder { +public: + TritonOpBuilder(MLIRContext *context) { + builder = std::make_unique(context); + lastLoc = std::make_unique(builder->getUnknownLoc()); + } + + OpBuilder &getBuilder() { return *builder; } + + bool isLineInfoEnabled() { return lineInfoEnabled; } + + void setLastLoc(Location loc) { + if (lineInfoEnabled) + lastLoc = std::make_unique(loc); + } + + void setLastLoc(const std::string &fileName, int line, int column) { + auto context = builder->getContext(); + setLastLoc(FileLineColLoc::get(context, fileName, line, column)); + } + + Location getLastLoc() { + assert(lastLoc); + return *lastLoc; + } + + void setInsertionPointToStart(Block &block) { + if (!block.empty()) + setLastLoc(block.begin()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToStart(&block); + } + + void setInsertionPointToEnd(Block &block) { + if (!block.empty()) + setLastLoc(block.back().getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToEnd(&block); + } + + void setInsertionPointAfter(Operation &op) { + setLastLoc(op.getLoc()); + builder->setInsertionPointAfter(&op); + } + + void restoreInsertionPoint(OpBuilder::InsertPoint pt) { + if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) + setLastLoc(pt.getPoint()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->restoreInsertionPoint(pt); + } + + template OpTy create(Args &&...args) { + auto loc = getLastLoc(); + return builder->create(loc, std::forward(args)...); + } + + // Overload to create or fold a single result operation. + template + std::enable_if_t(), Value> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + + // Overload to create or fold a zero result operation. + template + std::enable_if_t(), OpTy> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + +private: + std::unique_ptr builder; + std::unique_ptr lastLoc; + bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); +}; + +std::string locationToString(Location loc) { + std::string str; + llvm::raw_string_ostream os(str); + loc.print(os); + os.flush(); // Make sure all the content is dumped into the 'str' string + return str; +} + +void outputWarning(Location loc, const std::string &msg) { + std::string locStr = locationToString(loc); + + PyErr_WarnEx(PyExc_UserWarning, (locStr + ": " + msg).c_str(), + /*stack_level=*/2); +} + +} // anonymous namespace + +/*****************************************************************************/ +/* Python bindings for ir */ +/*****************************************************************************/ + +void init_triton_ir(py::module &&m) { + using ret = py::return_value_policy; + using namespace pybind11::literals; + + py::enum_(m, "PADDING_OPTION", py::module_local()) + .value("PAD_ZERO", PaddingOption::PAD_ZERO) + .value("PAD_NAN", PaddingOption::PAD_NAN) + .export_values(); + + py::enum_(m, "CACHE_MODIFIER", py::module_local()) + .value("NONE", CacheModifier::NONE) + .value("CA", CacheModifier::CA) + .value("CG", CacheModifier::CG) + .value("WB", CacheModifier::WB) + .value("CS", CacheModifier::CS) + .value("WT", CacheModifier::WT) + .export_values(); + + py::enum_(m, "MEM_SEMANTIC", py::module_local()) + .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) + .value("ACQUIRE", MemSemantic::ACQUIRE) + .value("RELEASE", MemSemantic::RELEASE) + .value("RELAXED", MemSemantic::RELAXED) + .export_values(); + + py::enum_(m, "MEM_SYNC_SCOPE", py::module_local()) + .value("GPU", MemSyncScope::GPU) + .value("CTA", MemSyncScope::CTA) + .value("SYSTEM", MemSyncScope::SYSTEM) + .export_values(); + + py::enum_(m, "EVICTION_POLICY", py::module_local()) + .value("NORMAL", EvictionPolicy::NORMAL) + .value("EVICT_FIRST", EvictionPolicy::EVICT_FIRST) + .value("EVICT_LAST", EvictionPolicy::EVICT_LAST) + .export_values(); + + py::enum_(m, "ATOMIC_OP", py::module_local()) + .value("ADD", RMWOp::ADD) + .value("FADD", RMWOp::FADD) + .value("AND", RMWOp::AND) + .value("OR", RMWOp::OR) + .value("XOR", RMWOp::XOR) + .value("XCHG", RMWOp::XCHG) + .value("MAX", RMWOp::MAX) + .value("MIN", RMWOp::MIN) + .value("UMIN", RMWOp::UMIN) + .value("UMAX", RMWOp::UMAX); + + py::enum_(m, "ROUNDING_MODE", py::module_local()) + .value("RTZ", RoundingMode::RTZ) + .value("RTNE", RoundingMode::RTNE); + + py::enum_(m, "PROPAGATE_NAN", py::module_local()) + .value("NONE", PropagateNan::NONE) + .value("ALL", PropagateNan::ALL); + + py::enum_(m, "INPUT_PRECISION", py::module_local()) + .value("TF32", InputPrecision::TF32) + .value("TF32x3", InputPrecision::TF32x3) + .value("IEEE", InputPrecision::IEEE) + .export_values(); + + py::class_(m, "context", py::module_local()).def(py::init<>()); + + m.def("load_dialects", [](MLIRContext &context) { + DialectRegistry registry; + registry.insert(); + registerBuiltinDialectTranslation(registry); + registerLLVMDialectTranslation(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_(m, "type", py::module_local()) + .def("is_integer", + [](Type &self, unsigned width) { return self.isInteger(width); }) + .def("is_fp16", &Type::isF16) + .def("__str__", [](Type &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "function_type", py::module_local()) + .def("param_types", [](FunctionType &self) { + return std::vector(self.getInputs().begin(), + self.getInputs().end()); + }); + + py::class_(m, "location", py::module_local()) + .def("__str__", [](Location &self) { + std::string str; + llvm::raw_string_ostream os(str); + self.print(os); + return os.str(); + }); + + py::class_(m, "value", py::module_local()) + .def("set_attr", + [](Value &self, std::string &name, Attribute &attr) -> void { + if (Operation *definingOp = self.getDefiningOp()) + definingOp->setAttr(name, attr); + else { + auto arg = mlir::cast(self); + int id = arg.getArgNumber(); + std::string attrName = name + "_arg" + std::to_string(id); + Block *owner = arg.getOwner(); + if (owner->isEntryBlock() && + !isa(owner->getParentOp())) { + owner->getParentOp()->setAttr(attrName, attr); + } + } + }) + .def("get_context", &Value::getContext) + .def("replace_all_uses_with", + [](Value &self, Value &newValue) { + self.replaceAllUsesWith(newValue); + }) + .def("get_type", &Value::getType) + .def("id", [](Value &self) { + // The Value is identified by and compared with + // other Values via the underlying ValueImpl + return (uint64_t)self.getImpl(); + }); + + py::class_(m, "op_result", py::module_local()); + + py::class_(m, "block_argument", py::module_local()); + + py::class_(m, "region", py::module_local()) + .def("get_parent_region", &Region::getParentRegion, ret::reference) + .def("size", [](Region &self) { return self.getBlocks().size(); }) + .def("empty", &Region::empty) + .def("id", [](Region &self) { return (uint64_t)&self; }); + + py::class_(m, "block", py::module_local()) + .def("arg", + [](Block &self, int index) -> BlockArgument { + if (index >= self.getNumArguments()) + throw pybind11::index_error("Block argument index out of range"); + return self.getArgument(index); + }) + .def("add_argument", + [](Block &self, Type ty) { + auto loc = UnknownLoc::get(ty.getContext()); + self.addArgument(ty, loc); + }) + .def("get_num_arguments", &Block::getNumArguments) + .def("get_argument", &Block::getArgument) + .def("dump", [](Block &self) { DUMP(); }) + .def("move_before", + [](Block &self, Block &dst) { self.moveBefore(&dst); }) + .def("insert_before", &Block::insertBefore) + .def("get_parent", &Block::getParent, ret::reference) + .def("merge_block_before", + [](Block &self, Block &dst) { + // ref: RewriterBase::mergeBlocks() + if (self.getNumArguments() != 0) + throw std::runtime_error( + "This block has arguments, don't merge"); + dst.getOperations().splice(dst.begin(), self.getOperations()); + self.dropAllUses(); + self.erase(); + }) + .def("replace_use_in_block_with", + [](Block &self, Value &v, Value &newVal) { + v.replaceUsesWithIf(newVal, [&](OpOperand &operand) { + Operation *user = operand.getOwner(); + Block *currentBlock = user->getBlock(); + while (currentBlock) { + if (currentBlock == &self) + return true; + // Move up one level + currentBlock = + currentBlock->getParent()->getParentOp()->getBlock(); + } + return false; + }); + }) + .def("__str__", + [](Block &self) { + std::string str; + llvm::raw_string_ostream os(str); + PRINT(os); + return str; + }) + .def("has_terminator", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("has_return", + [](Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) + .def("erase", [](Block &self) { self.erase(); }) + .def("id", [](Block &self) { return (uint64_t)&self; }); + + py::class_(m, "attribute", py::module_local()); + py::class_(m, "integer_attr", py::module_local()); + py::class_(m, "bool_attr", py::module_local()); + + // Ops + py::class_(m, "OpState", py::module_local()) + .def("set_attr", + [](OpState &self, std::string &name, Attribute &attr) -> void { + self->setAttr(name, attr); + }) + .def("get_num_results", + [](OpState &self) -> unsigned { return self->getNumResults(); }) + .def("get_result", + [](OpState &self, unsigned idx) -> Value { + if (idx >= self->getNumResults()) + throw pybind11::index_error("Op result index out of range"); + return self->getResult(idx); + }) + .def( + "get_region", + [](OpState &self, unsigned idx) -> Region & { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self->getRegion(idx); + }, + ret::reference) + .def( + "get_body", + [](scf::ForOp &self, unsigned idx) -> Block * { + if (idx >= self->getNumRegions()) + throw pybind11::index_error("Op region index out of range"); + return self.getBody(idx); + }, + ret::reference) + .def("dump", [](OpState &self) { DUMP(); }) + .def("__str__", + [](OpState &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + PRINT(os, printingFlags); + return str; + }) + .def("append_operand", + [](OpState &self, Value &val) { + self->insertOperands(self->getNumOperands(), val); + }) + .def("verify", [](OpState &self) -> bool { + return succeeded(verify(self.getOperation())); + }); + // scf Ops + py::class_(m, "ForOp", py::module_local()) + .def("get_induction_var", &scf::ForOp::getInductionVar); + + py::class_(m, "IfOp", py::module_local()) + .def("get_then_block", &scf::IfOp::thenBlock, ret::reference) + .def("get_else_block", &scf::IfOp::elseBlock, ret::reference) + .def("get_then_yield", &scf::IfOp::thenYield) + .def("get_else_yield", &scf::IfOp::elseYield); + py::class_(m, "YieldOp", py::module_local()); + py::class_(m, "WhileOp", py::module_local()) + .def("get_before", &scf::WhileOp::getBefore, ret::reference) + .def("get_after", &scf::WhileOp::getAfter, ret::reference); + py::class_(m, "ConditionOp", py::module_local()); + + py::class_>( + m, "operation", py::module_local()) + .def("get_name", + [](Operation &self) { + llvm::StringRef opName = self.getName().getStringRef(); + return opName.str(); + }) + .def("get_num_operands", &Operation::getNumOperands) + .def("get_operand", &Operation::getOperand) + .def("get_num_results", &Operation::getNumResults) + .def("get_result", &Operation::getResult) + .def("get_num_regions", &Operation::getNumRegions) + .def("get_region", &Operation::getRegion, ret::reference) + .def("get_block", &Operation::getBlock, ret::reference) + .def("get_str_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }) + .def("get_flat_symbol_ref_attr", + [](Operation &self, const std::string &name) -> py::object { + auto ret = self.getAttrOfType(name); + if (!ret) + return py::none(); + return py::str(ret.getValue().str()); + }); + + // dynamic_attr is used to transfer ownership of the MLIR context to the + // module + py::class_(m, "module", py::module_local(), + py::dynamic_attr()) + .def("dump", [](ModuleOp &self) { DUMP(); }) + .def("str", + [](ModuleOp &self) -> std::string { + std::string str; + llvm::raw_string_ostream os(str); + auto printingFlags = OpPrintingFlags(); + printingFlags.enableDebugInfo(); + PRINT(os, printingFlags); + return str; + }) + .def("push_back", + [](ModuleOp &self, FuncOp &funcOp) -> void { + self.push_back(funcOp); + }) + .def("has_function", + [](ModuleOp &self, std::string &funcName) -> bool { + if (self.lookupSymbol(funcName)) + return true; + return false; + }) + .def("get_function", + [](ModuleOp &self, std::string &funcName) -> FuncOp { + return self.lookupSymbol(funcName); + }) + .def("get_int_attr", + [](ModuleOp &self, std::string name) -> py::object { + auto ret = self->getAttrOfType(name); + if (!ret) + return py::none(); + return py::int_(ret.getInt()); + }) + .def("create_location_snapshot", + [](ModuleOp &self, const std::string &fileName) -> void { + generateLocationsFromIR(/*raw_ostream=*/llvm::nulls(), + /*fileName=*/fileName, + /*op=*/self, /*flags=*/{}); + }) + .def("walk", + [](ModuleOp &self, const std::function &fn) { + self.walk(fn); + }); + + m.def("make_attr", [](const std::vector &values, MLIRContext &context) { + return mlir::cast(DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(values.size())}, + IntegerType::get(&context, 32)), + values)); + }); + + m.def( + "parse_mlir_module", + [](const std::string &inputFilename, MLIRContext &context) { + // parse module + OwningOpRef module = + parseSourceFile(inputFilename, &context); + if (!module) + throw std::runtime_error("Parse MLIR file failed."); + return module->clone(); + }, + ret::take_ownership); + + py::class_(m, "function", py::module_local()) + // .def_property_readonly("attrs", &ir::function::attrs) + // .def("add_attr", &ir::function::add_attr); + .def("args", + [](FuncOp &self, unsigned idx) -> BlockArgument { + if (idx >= self.getNumArguments()) + throw pybind11::index_error( + "Function argument index out of range"); + return self.getArgument(idx); + }) + .def( + "add_entry_block", + [](FuncOp &self) -> Block * { return self.addEntryBlock(); }, + ret::reference) + .def( + "set_arg_attr", + [](FuncOp &self, int arg_no, const std::string &name, int val) { + // set arg attributes "name" to value "val" + auto attrTy = IntegerType::get(self.getContext(), 32); + self.setArgAttr(arg_no, name, IntegerAttr::get(attrTy, val)); + }, + ret::reference) + // .def("has_attr", &::FuncOp::hasAttr) + .def("finalize", + [](FuncOp &self) -> void { + // Remove dead code + // 1. Unreachable code after return + self.walk([&](Block *block) { + Operation *retOp = nullptr; + // It's better to not use walk here because we only want to + // check operations in the current block + for (auto &op : block->getOperations()) { + if (isa(op)) + if (retOp == nullptr) { + retOp = &op; + break; + } + } + if (retOp && retOp != &block->back()) { + auto pos = retOp->getIterator(); + pos++; + auto *newBlock = block->splitBlock(pos); + newBlock->erase(); + } + }); + // 2. Check if the result of tl.advance is used + self.walk([&](Operation *op) { + if (isa(op) && op->getResult(0).use_empty()) + outputWarning(op->getLoc(), "The result of tl.advance is not " + "being used. Note that tl.advance " + "does not have any side effects. " + "To move the block pointer, you " + "need to assign the result of " + "tl.advance to a variable."); + }); + }) + .def_property_readonly("type", &FuncOp::getFunctionType) + .def("reset_type", &FuncOp::setType); + + py::class_(m, "InsertPoint", py::module_local()); + + py::class_(m, "builder", py::module_local(), + py::dynamic_attr()) + .def(py::init()) + // getters + .def("create_module", + [](TritonOpBuilder &self) -> ModuleOp { + return self.create(); + }) + // insertion block/point + .def("set_insertion_point_to_start", + [](TritonOpBuilder &self, Block &block) -> void { + self.setInsertionPointToStart(block); + }) + .def("set_insertion_point_to_end", + [](TritonOpBuilder &self, Block &block) { + self.setInsertionPointToEnd(block); + }) + .def("set_insertion_point_after", + [](TritonOpBuilder &self, Operation &op) { + self.setInsertionPointAfter(op); + }) + .def( + "get_insertion_block", + [](TritonOpBuilder &self) -> Block * { + return self.getBuilder().getInsertionBlock(); + }, + ret::reference) + .def("get_insertion_point", + [](TritonOpBuilder &self) { + return self.getBuilder().saveInsertionPoint(); + }) + .def("restore_insertion_point", + [](TritonOpBuilder &self, OpBuilder::InsertPoint pt) { + self.restoreInsertionPoint(pt); + }) + // Attr + .def("get_bool_attr", + [](TritonOpBuilder &self, bool value) { + return self.getBuilder().getBoolAttr(value); + }) + .def("get_int32_attr", + [](TritonOpBuilder &self, int32_t value) { + return self.getBuilder().getI32IntegerAttr(value); + }) + // Use arith.ConstantOp to create constants + // Constants + .def("get_int1", + [](TritonOpBuilder &self, bool v) -> Value { + return Value(self.create( + v, self.getBuilder().getI1Type())); + }) + .def("get_int8", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_int16", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_int32", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_int64", + [](TritonOpBuilder &self, int64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_uint8", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI8Type())); + }) + .def("get_uint16", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI16Type())); + }) + .def("get_uint32", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI32Type())); + }) + .def("get_uint64", + [](TritonOpBuilder &self, uint64_t v) -> Value { + return Value(self.create( + v, self.getBuilder().getI64Type())); + }) + .def("get_bf16", + [](TritonOpBuilder &self, float v) -> Value { + auto type = self.getBuilder().getBF16Type(); + return self.create( + APFloat(type.getFloatSemantics(), std::to_string(v)), type); + }) + .def("get_fp16", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF16FloatAttr(v)); + }) + .def("get_fp32", + [](TritonOpBuilder &self, float v) -> Value { + return self.create( + self.getBuilder().getF32FloatAttr(v)); + }) + .def("get_fp64", + [](TritonOpBuilder &self, double v) -> Value { + return self.create( + self.getBuilder().getF64FloatAttr(v)); + }) + .def("get_null_value", + [](TritonOpBuilder &self, Type type) -> Value { + if (auto floatTy = dyn_cast(type)) + return self.create( + APFloat(floatTy.getFloatSemantics(), 0), floatTy); + else if (auto intTy = dyn_cast(type)) + return self.create(0, intTy); + else + throw std::runtime_error("Not implemented"); + }) + .def("get_all_ones_value", + [](TritonOpBuilder &self, Type type) -> Value { + uint64_t val = 0xFFFFFFFFFFFFFFFF; + if (auto intTy = dyn_cast(type)) + return self.create(val, intTy); + else + throw std::runtime_error("Not implemented"); + }) + + // Types + .def("get_void_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getNoneType(); + }) + .def("get_int1_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI1Type(); + }) // or ret::copy? + .def("get_int8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_int16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(16); + }) + .def("get_int32_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI32Type(); + }) + .def("get_int64_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI64Type(); + }) + .def("get_fp8e4nv_ty", + // TODO: fp8e4nv is using Float8E4M3FNUZType, which + // does not seem right. It should use FloatE4M3FNType + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b8_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e4b15_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getI8Type(); + }) + .def("get_fp8e5_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_fp8e5b16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getType(); + }) + .def("get_half_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF16Type(); + }) + .def("get_bf16_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getBF16Type(); + }) + .def("get_float_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF32Type(); + }) + .def("get_double_ty", + [](TritonOpBuilder &self) -> Type { + return self.getBuilder().getF64Type(); + }) + .def("get_ptr_ty", + [](TritonOpBuilder &self, Type &type, int addrSpace) -> Type { + return PointerType::get(type, addrSpace); + }) + .def("get_block_ty", + [](TritonOpBuilder &self, Type &elementType, + std::vector &shape) -> Type { + return RankedTensorType::get(shape, elementType); + }) + .def("get_function_ty", + [](TritonOpBuilder &self, std::vector inTypes, + std::vector outTypes) -> Type { + return self.getBuilder().getFunctionType(inTypes, outTypes); + }) + // locs + .def("set_loc", + [](TritonOpBuilder &self, Location loc) { self.setLastLoc(loc); }) + .def("set_loc", + [](TritonOpBuilder &self, const std::string &fileName, int line, + int column) { self.setLastLoc(fileName, line, column); }) + .def("get_loc", + [](TritonOpBuilder &self) -> Location { return self.getLastLoc(); }) + + // Ops + .def("get_or_insert_function", + [](TritonOpBuilder &self, ModuleOp &module, std::string &funcName, + Type &funcType, std::string &visibility, + bool noinline) -> FuncOp { + if (Operation *funcOperation = module.lookupSymbol(funcName)) + return llvm::dyn_cast(funcOperation); + if (auto funcTy = dyn_cast(funcType)) { + llvm::SmallVector attrs = { + NamedAttribute( + self.getBuilder().getStringAttr("sym_visibility"), + self.getBuilder().getStringAttr(visibility)), + NamedAttribute(self.getBuilder().getStringAttr("noinline"), + self.getBuilder().getBoolAttr(noinline))}; + return self.create(funcName, funcTy, attrs); + } + throw std::invalid_argument("invalid function type"); + }) + .def( + "create_block", + [](TritonOpBuilder &self) -> Block * { + Region *parent = self.getBuilder().getBlock()->getParent(); + return self.getBuilder().createBlock(parent); + }, + ret::reference) + .def( + "create_block_with_parent", + [](TritonOpBuilder &self, Region &parent, + std::vector &argTypes) -> Block * { + // TODO: update arg loc + auto loc = self.getBuilder().getUnknownLoc(); + llvm::SmallVector argLocs(argTypes.size(), loc); + return self.getBuilder().createBlock(&parent, {}, argTypes, + argLocs); + }, + ret::reference) + .def( + "new_block", + [](TritonOpBuilder &self) -> Block * { return new Block(); }, + ret::reference) + // Function + .def("ret", + [](TritonOpBuilder &self, std::vector &vals) -> OpState { + return self.create(vals); + }) + .def("call", + [](TritonOpBuilder &self, FuncOp &func, std::vector &args) + -> OpState { return self.create(func, args); }) + // Unstructured control flow + .def("create_cond_branch", + [](TritonOpBuilder &self, Value condition, Block *trueDest, + Block *falseDest) -> OpState { + return self.create(condition, trueDest, + falseDest); + }) + .def("create_branch", + [](TritonOpBuilder &self, Block *dest, std::vector &args) + -> OpState { return self.create(dest, args); }) + // Structured control flow + .def("create_for_op", + [](TritonOpBuilder &self, Value &lb, Value &ub, Value &step, + std::vector &initArgs) -> scf::ForOp { + return self.create(lb, ub, step, initArgs); + }) + .def("create_if_op", + [](TritonOpBuilder &self, std::vector &retTypes, + Value &condition, bool withElse) -> scf::IfOp { + return self.create(retTypes, condition, withElse); + }) + .def("create_yield_op", + [](TritonOpBuilder &self, std::vector &yields) + -> scf::YieldOp { return self.create(yields); }) + .def("create_while_op", + [](TritonOpBuilder &self, std::vector &retTypes, + std::vector &initArgs) -> scf::WhileOp { + return self.create(retTypes, initArgs); + }) + .def("create_condition_op", + [](TritonOpBuilder &self, Value &cond, + std::vector &args) -> scf::ConditionOp { + return self.create(cond, args); + }) + + // miscellaneous + .def("create_make_range", + [](TritonOpBuilder &self, int start, int end) -> Value { + auto retType = RankedTensorType::get( + {end - start}, self.getBuilder().getI32Type()); + return self.create(retType, start, end); + }) + + // Cast instructions + // Conversions for custom FP types (FP8 and non-standard rounding modes) + .def("create_fp_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType, + std::optional roundingMode) -> Value { + if (roundingMode.has_value()) + return self.create( + dstType, src, + RoundingModeAttr::get(self.getBuilder().getContext(), + roundingMode.value())); + else + return self.create(dstType, src); + }) + // Conversions for standard LLVM builtin types + .def("create_bitcast", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_si_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_ui_to_fp", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_si", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_to_ui", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_ext", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_fp_trunc", + [](TritonOpBuilder &self, Value &src, Type &dstType) -> Value { + return self.create(dstType, src); + }) + .def("create_int_cast", + [](TritonOpBuilder &self, Value &src, Type &dstType, + bool isSigned) -> Value { + // get element type if necessary + Type srcType = src.getType(); + auto srcTensorType = dyn_cast(srcType); + auto dstTensorType = dyn_cast(dstType); + Type srcEltType = srcType; + Type dstEltType = dstType; + if (dstTensorType && srcTensorType) { + dstEltType = dstTensorType.getElementType(); + srcEltType = srcTensorType.getElementType(); + } + unsigned srcWidth = srcEltType.getIntOrFloatBitWidth(); + unsigned dstWidth = dstEltType.getIntOrFloatBitWidth(); + if (srcWidth == dstWidth) + return self.create(dstType, src); + else if (srcWidth > dstWidth) + return self.create(dstType, src); + else if (isSigned) + return self.create(dstType, src); + else + return self.create(dstType, src); + }) + .def("create_to_index", + [](TritonOpBuilder &self, Value &input) -> Value { + return self.create( + self.getBuilder().getIndexType(), input); + }) + .def("create_index_to_si", + [](TritonOpBuilder &self, Value &input) -> Value { + return self.create( + self.getBuilder().getI64Type(), input); + }) + .def("create_fmul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_frem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fadd", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_fsub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_mul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_umulhi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sdiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_udiv", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_srem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_urem", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_add", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_sub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_fma", + [](TritonOpBuilder &self, Value &a, Value &b, Value &c) -> Value { + return Value(self.create(a, b, c)); + }) + .def("create_shl", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_lshr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_ashr", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_minui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minimumf follows the torch.minimum convention and returns NaN if either + // operand is NaN + .def("create_minimumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // minnumf follows the torch.fmin convention and returns the non-NaN + // operand + .def("create_minnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxsi", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_maxui", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maximumf follows the torch.maximum convention and returns NaN if either + // operand is NaN + .def("create_maximumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // maxnumf follows the torch.fmax convention and returns the non-NaN + // operand + .def("create_maxnumf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + .def("create_clampf", + [](TritonOpBuilder &self, Value &input, Value &min, Value &max, + PropagateNan propagateNan) -> Value { + return Value(self.create(input, min, max, propagateNan)); + }) + .def("create_precise_sqrt", + [](TritonOpBuilder &self, Value &input) -> Value { + return Value(self.create(input)); + }) + .def("create_precise_divf", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return Value(self.create(lhs, rhs)); + }) + // AddPtr (similar to GEP) + .def("create_addptr", + [](TritonOpBuilder &self, Value &ptr, Value &offset) -> Value { + return self.create(ptr.getType(), ptr, offset); + }) + // Comparison (int) + .def("create_icmpSLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sle, lhs, + rhs); + }) + .def("create_icmpSLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::slt, lhs, + rhs); + }) + .def("create_icmpSGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sge, lhs, + rhs); + }) + .def("create_icmpSGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::sgt, lhs, + rhs); + }) + .def("create_icmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ule, lhs, + rhs); + }) + .def("create_icmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ult, lhs, + rhs); + }) + .def("create_icmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::uge, lhs, + rhs); + }) + .def("create_icmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ugt, lhs, + rhs); + }) + .def("create_icmpEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::eq, lhs, + rhs); + }) + .def("create_icmpNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpIPredicate::ne, lhs, + rhs); + }) + // Comparison (float) + .def("create_fcmpOLT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLT, lhs, + rhs); + }) + .def("create_fcmpOGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGT, lhs, + rhs); + }) + .def("create_fcmpOLE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OLE, lhs, + rhs); + }) + .def("create_fcmpOGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OGE, lhs, + rhs); + }) + .def("create_fcmpOEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::OEQ, lhs, + rhs); + }) + .def("create_fcmpONE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ONE, lhs, + rhs); + }) + .def("create_fcmpULT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULT, lhs, + rhs); + }) + .def("create_fcmpUGT", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGT, lhs, + rhs); + }) + .def("create_fcmpULE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::ULE, lhs, + rhs); + }) + .def("create_fcmpUGE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UGE, lhs, + rhs); + }) + .def("create_fcmpUEQ", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UEQ, lhs, + rhs); + }) + .def("create_fcmpUNE", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(arith::CmpFPredicate::UNE, lhs, + rhs); + }) + // // Logical + .def("create_and", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_xor", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + .def("create_or", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + return self.create(lhs, rhs); + }) + // Input/Output + .def("create_load", + [](TritonOpBuilder &self, Value &ptrs, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_store", + [](TritonOpBuilder &self, Value &ptrs, Value &value, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, value, cacheModifier, evictionPolicy); + }) + .def("create_tensor_pointer_load", + [](TritonOpBuilder &self, Value &ptr, + std::vector &boundaryCheck, + std::optional paddingOption, + CacheModifier cacheModifier, EvictionPolicy evictionPolicy, + bool isVolatile) -> Value { + return self.create(ptr, boundaryCheck, paddingOption, + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_tensor_pointer_store", + [](TritonOpBuilder &self, Value &ptr, Value &val, + std::vector &boundaryCheck, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptr, val, boundaryCheck, cacheModifier, + evictionPolicy); + }) + .def("create_masked_load", + [](TritonOpBuilder &self, Value &ptrs, Value &mask, + std::optional &other, CacheModifier cacheModifier, + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + return self.create(ptrs, mask, other.value_or(Value()), + cacheModifier, evictionPolicy, + isVolatile); + }) + .def("create_masked_store", + [](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> void { + self.create(ptrs, val, mask, cacheModifier, + evictionPolicy); + }) + .def("create_descriptor_load", + [](TritonOpBuilder &self, Value &desc_ptr, + std::vector &indices, Type type, + CacheModifier cacheModifier, + EvictionPolicy evictionPolicy) -> Value { + return self.create( + type, desc_ptr, indices, cacheModifier, evictionPolicy); + }) + .def("create_descriptor_store", + [](TritonOpBuilder &self, Value &desc_ptr, Value value, + std::vector &indices) -> void { + self.create(desc_ptr, value, + indices); + }) + .def("create_reshape", + [](TritonOpBuilder &self, Value &arg, std::vector &shape, + bool allowReorder) -> Value { + auto argType = + cast(arg.getType()).getElementType(); + return self.create( + RankedTensorType::get(shape, argType), arg, allowReorder); + }) + .def("create_expand_dims", + [](TritonOpBuilder &self, Value &arg, int axis) -> Value { + auto argType = dyn_cast(arg.getType()); + auto argEltType = argType.getElementType(); + std::vector retShape = argType.getShape(); + retShape.insert(retShape.begin() + axis, 1); + return self.create( + RankedTensorType::get(retShape, argEltType), arg, axis); + }) + .def("create_cat", + [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { + auto lhsType = dyn_cast(lhs.getType()); + auto rhsType = dyn_cast(rhs.getType()); + if (!(lhsType.getShape().size() == 1 && + rhsType.getShape().size() == 1)) + throw std::invalid_argument( + "shape not supported by cat. Expecting rank-1 inputs"); + std::vector shape{lhsType.getShape()[0] + + rhsType.getShape()[0]}; + return self.create( + RankedTensorType::get(shape, lhsType.getElementType()), lhs, + rhs); + }) + .def("create_join", + [](TritonOpBuilder &self, Value &a, Value &b) -> Value { + return self.create(a, b); + }) + .def("create_split", + [](TritonOpBuilder &self, Value &a) -> std::vector { + auto op = self.create(a); + return std::vector(op->result_begin(), op->result_end()); + }) + // Implements tl.trans and tl.permute. + .def("create_trans", + [](TritonOpBuilder &self, Value &arg, + std::vector &order) -> Value { + auto argType = dyn_cast(arg.getType()); + auto argEltType = argType.getElementType(); + auto retShape = applyPermutation(argType.getShape(), order); + return self.create( + RankedTensorType::get(retShape, argEltType), arg, order); + }) + .def("create_broadcast", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + if (auto argType = dyn_cast(arg.getType())) + return self.createOrFold( + RankedTensorType::get(shape, argType.getElementType()), arg); + throw std::invalid_argument( + "arg is not of RankedTensorType, use create_splat"); + }) + .def("create_splat", + [](TritonOpBuilder &self, Value &arg, + std::vector &shape) -> Value { + auto argType = arg.getType(); + auto ret = self.createOrFold( + RankedTensorType::get(shape, argType), arg); + return ret; + }) + // // atomic + .def("create_atomic_cas", + [](TritonOpBuilder &self, Value &ptr, Value &cmp, Value &val, + MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, ptr, cmp, val, sem, + scope); + }) + .def("create_atomic_rmw", + [](TritonOpBuilder &self, RMWOp rmwOp, Value &ptr, Value &val, + Value &mask, MemSemantic sem, MemSyncScope scope) -> Value { + Type dstType; + if (auto srcTensorType = + dyn_cast(ptr.getType())) { + Type dstElemType = + cast(srcTensorType.getElementType()) + .getPointeeType(); + dstType = + RankedTensorType::get(srcTensorType.getShape(), dstElemType); + } else { + auto ptrType = cast(getElementTypeOrSelf(ptr)); + dstType = ptrType.getPointeeType(); + } + return self.create(dstType, rmwOp, ptr, val, mask, + sem, scope); + }) + // External + .def("create_extern_elementwise", + [](TritonOpBuilder &self, const std::string &libName, + const std::string &libPath, const std::string &symbol, + std::vector &argList, Type retType, bool isPure) -> Value { + return self.create(retType, argList, libName, + libPath, symbol, isPure); + }) + // Built-in instruction + .def("create_get_program_id", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create( + self.getBuilder().getI32Type(), + ProgramIDDimAttr::get(self.getBuilder().getContext(), + ProgramIDDim(axis))); + }) + .def("create_get_num_programs", + [](TritonOpBuilder &self, int axis) -> Value { + if (axis < 0 || axis > 3) + throw pybind11::index_error("program_id must be in [0,3]"); + return self.create( + self.getBuilder().getI32Type(), + ProgramIDDimAttr::get(self.getBuilder().getContext(), + ProgramIDDim(axis))); + }) + .def("create_dot", + [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b, + mlir::Value &c, InputPrecision inputPrecision, + int maxNumImpreciseAcc) -> mlir::Value { + return self.create(c.getType(), a, b, c, inputPrecision, + maxNumImpreciseAcc); + }) + .def("create_floor", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_ceil", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_exp2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_cos", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sin", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_log2", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_erf", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_sqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_rsqrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_fabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_iabs", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_reduce", + [](TritonOpBuilder &self, std::vector operands, int axis) + -> OpState { return self.create(operands, axis); }) + .def("create_reduce_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_scan", + [](TritonOpBuilder &self, std::vector operands, int axis, + bool reverse) -> OpState { + return self.create(operands, axis, reverse); + }) + .def("create_scan_ret", + [](TritonOpBuilder &self, py::args args) -> OpState { + llvm::SmallVector return_values; + for (const auto &arg : args) { + return_values.push_back(py::cast(arg)); + } + return self.create(return_values); + }) + .def("create_ptr_to_int", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_int_to_ptr", + [](TritonOpBuilder &self, Value &val, Type &type) -> Value { + return self.create(type, val); + }) + .def("create_select", + [](TritonOpBuilder &self, Value &condition, Value &trueValue, + Value &falseValue) -> Value { + return self.create(condition, trueValue, + falseValue); + }) + .def("create_inline_asm", + [](TritonOpBuilder &self, const std::string &inlineAsm, + const std::string &constraints, const std::vector &values, + const std::vector &types, bool isPure, + int pack) -> OpState { + return self.create( + types, inlineAsm, constraints, isPure, pack, values); + }) + .def("create_print", + [](TritonOpBuilder &self, const std::string &prefix, bool hex, + const std::vector &values) -> void { + self.create( + StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(prefix)), + hex, values); + }) + .def("create_assert", + [](TritonOpBuilder &self, Value &condition, + const std::string &message, const std::string &fileName, + const std::string &funcName, unsigned lineNo) -> void { + auto messageAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(message)); + auto fileNameAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(fileName)); + auto funcNameAttr = StringAttr::get(self.getBuilder().getContext(), + llvm::StringRef(funcName)); + auto lineNoAttr = self.getBuilder().getI32IntegerAttr(lineNo); + self.create(condition, messageAttr, fileNameAttr, + funcNameAttr, lineNoAttr); + }) + // Undef + .def("create_undef", + [](TritonOpBuilder &self, Type &type) -> Value { + return self.create(type); + }) + .def("create_histogram", + [](TritonOpBuilder &self, Value operand, int numBins) -> Value { + return self.create( + RankedTensorType::get( + {static_cast(numBins)}, + IntegerType::get(operand.getContext(), 32)), + operand); + }) + // Force GPU barrier + .def("create_barrier", + [](TritonOpBuilder &self) { self.create(); }) + // Make a block pointer (tensor pointer in Triton IR) + .def("create_make_block_ptr", + [](TritonOpBuilder &self, Value &base, std::vector &shape, + std::vector &strides, std::vector &offsets, + std::vector &tensorShape, + std::vector &order) -> Value { + return self.create(base, shape, strides, offsets, + tensorShape, order); + }) + // Advance a block pointer + .def("create_advance", + [](TritonOpBuilder &self, Value &ptr, + std::vector &offsets) -> Value { + return self.create(ptr.getType(), ptr, offsets); + }); + + py::class_(m, "pass_manager", py::module_local()) + .def(py::init()) + .def("enable_debug", + [](PassManager &self) { + auto *context = self.getContext(); +#if !defined(TRITON_CONCEAL_IR) || (TRITON_CONCEAL_IR == 0) + bool haveDiagnostics = + ::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS"); + bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); +#else + bool haveDiagnostics = false; + bool haveDump = false; +#endif + if (haveDiagnostics || haveDump) { + context->disableMultithreading(); + } + if (haveDiagnostics) { + context->printOpOnDiagnostic(true); + context->printStackTraceOnDiagnostic(true); + context->getDiagEngine().registerHandler([](Diagnostic &diag) { + llvm::outs() << diag << "\n"; + return success(); + }); + } + if (haveDump) { + auto printingFlags = OpPrintingFlags(); + printingFlags.elideLargeElementsAttrs(16); + printingFlags.enableDebugInfo(); + auto printAlways = [](Pass *, Operation *) { return true; }; + self.enableIRPrinting( + /*shouldPrintBeforePass=*/nullptr, + /*shouldPrintAfterPass=*/printAlways, + /*printModuleScope=*/true, + /*printAfterOnlyOnChange=*/true, + /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), + printingFlags); + } + }) + .def("run", [](PassManager &self, ModuleOp &mod) { +#if !defined(TRITON_CONCEAL_IR) || (TRITON_CONCEAL_IR == 0) + // TODO: maybe dump module to file and print error for better + // diagnostics + auto reproducerPath = + triton::tools::getStrEnv("TRITON_REPRODUCER_PATH"); + if (!reproducerPath.empty()) { + auto anchorName = self.getOpAnchorName(); + auto passes = self.getPasses(); + Operation *op = mod.getOperation(); + makeReproducer(anchorName, passes, op, reproducerPath); + } + + if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) { + ::llvm::DebugFlag = true; + } + + if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY"); + !debugOnly.empty()) { + bool enableTritonLogging = false; + llvm::SmallVector split; + llvm::SmallVector storage; + llvm::SmallVector debugTypes; + + StringRef(debugOnly.c_str()).split(split, ','); + llvm::transform(split, std::back_inserter(debugTypes), + [&storage, &enableTritonLogging](StringRef str) { + if (str == "triton") + enableTritonLogging = true; + // StringRefs are not always null-terminated. + // The purpose for this storage pattern is to + // produce a collection of C-strings that are. + storage.push_back(str.str()); + return storage.back().c_str(); + }); + if (enableTritonLogging) { + debugTypes.insert(debugTypes.end(), { + "ttgpu_to_llvm", + "tritonxpu-core-tiling", + "tritonxpu-interleave", + "tritonxpu-legalize", + "tritonxpu-memory-async", + "tritonxpu-offset-analysis", + "tritonxpu-store-control", + "tritonxpu-unroll-control", + "tritonxpu-vectorize", + }); + } + + ::llvm::DebugFlag = true; + ::llvm::setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); + } +#endif + + bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING"); + if (haveTiming) { + self.enableTiming(); + } + + if (failed(self.run(mod.getOperation()))) + throw std::runtime_error("PassManager::run failed"); + }); +} + +void init_triton_env_vars(py::module &m) { + m.def("get_cache_invalidating_env_vars", + []() -> std::map { + std::map ret; + for (const auto &envVar : CACHE_INVALIDATING_ENV_VARS) { + auto strVal = triton::tools::getStrEnv(envVar); + if (strVal.empty()) + continue; + auto boolV = triton::tools::isEnvValueBool(strVal); + if (boolV.has_value()) + ret[envVar] = boolV.value() ? "true" : "false"; + else + ret[envVar] = strVal; + } + return ret; + }); +} diff --git a/third_party/xpu/python/src/llvm.cc b/third_party/xpu/python/src/llvm.cc new file mode 100644 index 000000000..decc9f312 --- /dev/null +++ b/third_party/xpu/python/src/llvm.cc @@ -0,0 +1,559 @@ +#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp +#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Object/ObjectFile.h" +#include "llvm/Object/SymbolSize.h" +#include "llvm/Object/SymbolicFile.h" +#include "llvm/Pass.h" +#include "llvm/Passes/OptimizationLevel.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Passes/StandardInstrumentations.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include +#include +#include + +#define DEFAULTLOCALLIMIT 8000 + +namespace py = pybind11; + +namespace llvm { +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; +} // namespace llvm + +using namespace llvm; + +Expected> +initObjectFileByMemory(unsigned char *ObjArray, uint64_t ObjLen) { + if (!ObjArray) { + return errorCodeToError(object::object_error::section_stripped); + } + + ErrorOr> FileOrErr = MemoryBuffer::getMemBuffer( + StringRef((char *)ObjArray, ObjLen), ".debug", false); + if (!FileOrErr) { + llvm::report_fatal_error(errorCodeToError(FileOrErr.getError())); + } + + std::unique_ptr Buffer = std::move(FileOrErr.get()); + Expected> ObjOrErr = + object::ObjectFile::createObjectFile(Buffer->getMemBufferRef()); + if (!ObjOrErr) { + llvm::report_fatal_error(ObjOrErr.takeError()); + } + + return ObjOrErr; +} + +bool isElfStackSizeOOB(std::string ElfObject) { + uint32_t StackSizeLimit = DEFAULTLOCALLIMIT; + std::string StackSizeLimitStr = + mlir::triton::tools::getStrEnv("TRITON_TUNE_BUFFER_LM_SIZE"); + if (!StackSizeLimitStr.empty()) { + llvm::StringRef StackSizeLimitSRf = StackSizeLimitStr; + if (StackSizeLimitSRf.getAsInteger(10, StackSizeLimit)) { + llvm::report_fatal_error( + "Invalid value for TRITON_TUNE_BUFFER_LM_SIZE: " + StackSizeLimitSRf); + } + } + + Expected> ObjFile = + initObjectFileByMemory( + (unsigned char *)const_cast(ElfObject.data()), + ElfObject.size()); + if (!ObjFile) { + llvm::report_fatal_error(ObjFile.takeError()); + } + + uint32_t StackSize = 0; + + std::vector> SymbolSizes = + object::computeSymbolSizes(*ObjFile.get()); + for (std::pair &SymbolSize : SymbolSizes) { + const object::SymbolRef &Symbol = SymbolSize.first; + Expected NameOrErr = Symbol.getName(); + if (!NameOrErr) { + errorToErrorCode(NameOrErr.takeError()); + continue; + } + StringRef Name = *NameOrErr; + if (!Name.contains("KERNEL_STACK_SIZE")) { + continue; + } + + Expected SymSectionOrErr = + Symbol.getSection(); + if (!SymSectionOrErr) { + llvm::report_fatal_error(SymSectionOrErr.takeError()); + } + + Expected ContentsOrErr = SymSectionOrErr.get()->getContents(); + if (!ContentsOrErr) { + llvm::report_fatal_error(ContentsOrErr.takeError()); + } + + if (SymbolSize.second != 4) { + llvm::report_fatal_error("Symbol Size Error"); + } + if (SymSectionOrErr.get()->getSize() != 8) { + llvm::report_fatal_error("Section Size Error"); + } + + StackSize = *((uint32_t *)(ContentsOrErr->data())); + break; + } + return StackSize > StackSizeLimit; +} + +std::string translateLLVMIRToASM(llvm::Module &module, + const std::string &triple, + const std::string &proc, + const std::string &features, + const std::vector &flags, + bool enable_fp_fusion, bool isObject) { + using namespace mlir; + // options + auto options = llvm::cl::getRegisteredOptions(); + for (std::string flag : flags) { + auto *shortPtr = static_cast *>(options[flag]); + assert(shortPtr); + shortPtr->setValue(true); + } + +#if !defined(TRITON_CONCEAL_IR) || (TRITON_CONCEAL_IR == 0) + if (triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + auto optIt = options.find("print-after-all"); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } +#endif + + { + uint32_t LMSizeLimit = -1; + std::string LMSizeLimitStr = + mlir::triton::tools::getStrEnv("LLVM_ERROR_LM_SIZE"); + if (!LMSizeLimitStr.empty()) { + llvm::StringRef LMSizeLimitSRf = LMSizeLimitStr; + if (LMSizeLimitSRf.getAsInteger(10, LMSizeLimit)) { + llvm::report_fatal_error("Invalid value for LLVM_ERROR_LM_SIZE: " + + LMSizeLimitSRf); + } + } + llvm::StringMap optMap = + llvm::cl::getRegisteredOptions(); + auto optIt = optMap.find("xpu-error-lm-size"); + if (optIt != optMap.end()) { + llvm::cl::opt *optPtr = + static_cast *>(optIt->second); + *optPtr = LMSizeLimit; + } + } + + bool disableLLVMOpt = triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); + if (!disableLLVMOpt) { + // Check to see if we are passing a list of flags to disable optimizations. + auto flagList = triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (auto flag : split) { + auto optIt = options.find(flag); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + } + } + + // inline everything + for (llvm::Function &f : module.functions()) + if (!f.hasFnAttribute(llvm::Attribute::NoInline)) + f.addFnAttr(llvm::Attribute::AlwaysInline); + // verify and store llvm + llvm::legacy::PassManager pm; + pm.add(llvm::createAlwaysInlinerLegacyPass()); + pm.add(llvm::createVerifierPass()); + + const bool enabledTiming = triton::tools::getBoolEnv("LLVM_ENABLE_TIMING"); + if (enabledTiming) { + llvm::TimePassesIsEnabled = true; + llvm::TimePassesPerRun = true; + } + + pm.run(module); + + SmallString<0> timePassesStr; + raw_svector_ostream reportStream(timePassesStr); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } + // module->print(llvm::outs(), nullptr); + + // create machine + module.setTargetTriple(triple); + std::string error; + auto target = + llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error); + llvm::TargetOptions opt; + if (enable_fp_fusion) + opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; + opt.UnsafeFPMath = false; + opt.NoInfsFPMath = false; + opt.NoNaNsFPMath = true; + opt.TrapUnreachable = true; + std::unique_ptr machine{target->createTargetMachine( + module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, + std::nullopt, + disableLLVMOpt ? llvm::CodeGenOptLevel::None + : llvm::CodeGenOptLevel::Aggressive)}; + // set data layout + module.setDataLayout(machine->createDataLayout()); + // emit machine code + std::string result; + { + llvm::raw_string_ostream stream(result); + llvm::buffer_ostream pstream(stream); + for (llvm::Function &f : module.functions()) + f.addFnAttr(llvm::Attribute::AlwaysInline); + llvm::legacy::PassManager pass; + // emit + auto fileType = isObject ? llvm::CodeGenFileType::ObjectFile + : llvm::CodeGenFileType::AssemblyFile; + machine->addPassesToEmitFile(pass, pstream, nullptr, fileType); + pass.run(module); + + if (enabledTiming) { + reportAndResetTimings(&reportStream); + llvm::dbgs() << reportStream.str(); + timePassesStr.clear(); + } + } + return result; +} + +using ret = py::return_value_policy; + +void init_triton_llvm(py::module &&m) { + + py::class_(m, "context", py::module_local()) + .def(py::init<>()); + + py::class_(m, "function_list") + .def( + "__iter__", + [](llvm::Module::FunctionListType &s) { + return py::make_iterator(s.begin(), s.end()); + }, + py::keep_alive<0, 1>()); + + // Module Flag behavior. See + // https://llvm.org/doxygen/classllvm_1_1Module.html#a0a5c55e12c97b80021330fe82b642293 + // for details. + py::class_(m, "module_flag_behavior", + py::module_local()); + m.attr("MODULE_FLAG_BEHAVIOR_ERROR") = llvm::Module::Error; + m.attr("MODULE_FLAG_BEHAVIOR_WARNING") = llvm::Module::Warning; + m.attr("MODULE_FLAG_BEHAVIOR_REQUIRE") = llvm::Module::Require; + m.attr("MODULE_FLAG_BEHAVIOR_OVERRIDE") = llvm::Module::Override; + m.attr("MODULE_FLAG_BEHAVIOR_APPEND") = llvm::Module::Append; + m.attr("MODULE_FLAG_BEHAVIOR_APPEND_UNIQUE") = llvm::Module::AppendUnique; + m.attr("MODULE_FLAG_BEHAVIOR_MAX") = llvm::Module::Max; + m.attr("MODULE_FLAG_BEHAVIOR_MIN") = llvm::Module::Min; + + py::class_(m, "module", py::module_local()) + .def( + "__str__", + [](llvm::Module *self) { + std::string str; + llvm::raw_string_ostream os(str); +#if !defined(TRITON_CONCEAL_IR) || (TRITON_CONCEAL_IR == 0) + os << *self; +#endif + return os.str(); + }, + ret::take_ownership) + .def( + "get_functions", + [](llvm::Module *mod) -> llvm::Module::FunctionListType & { + // Note: Backends assume that we are compiling exactly one kernel + // (i.e. one function that's that's called by the CPU) and that it's + // the first function in this list. + return mod->getFunctionList(); + }, + ret::reference_internal) + .def("add_flag", + [](llvm::Module *mod, llvm::Module::ModFlagBehavior behavior, + std::string &key, uint32_t value) { + return mod->addModuleFlag(behavior, key, value); + }); + + py::class_(m, "function", py::module_local()) + .def_property_readonly( + "name", [](llvm::Function *fn) { return fn->getName().str(); }) + .def("set_calling_conv", &llvm::Function::setCallingConv) + .def("add_fn_attr", [](llvm::Function *fn, std::string &name, + std::string &val) { fn->addFnAttr(name, val); }) + + // Sets the nvvm.maxreg property on the given function. + .def("set_nvvm_maxnreg", + [](llvm::Function *fn, int maxnreg) { + auto op = MDNode::get( + fn->getContext(), + { + ValueAsMetadata::get(fn), + MDString::get(fn->getContext(), "maxnreg"), + ConstantAsMetadata::get(ConstantInt::get( + Type::getInt32Ty(fn->getContext()), maxnreg)), + }); + fn->getParent() + ->getOrInsertNamedMetadata("nvvm.annotations") + ->addOperand(op); + }) + // External functions that are definitions (i.e. not declarations) are + // kernel functions. + .def("is_declaration", &llvm::Function::isDeclaration) + .def("is_external_linkage", [](llvm::Function *fn) { + return fn->getLinkage() == llvm::GlobalValue::ExternalLinkage; + }); + + // optimization levels + py::class_(m, "optimization_level", + py::module_local()); + m.attr("OPTIMIZE_O0") = llvm::OptimizationLevel::O0; + m.attr("OPTIMIZE_O1") = llvm::OptimizationLevel::O1; + m.attr("OPTIMIZE_O2") = llvm::OptimizationLevel::O2; + m.attr("OPTIMIZE_O3") = llvm::OptimizationLevel::O3; + m.attr("OPTIMIZE_Os") = llvm::OptimizationLevel::Os; + m.attr("OPTIMIZE_Oz") = llvm::OptimizationLevel::Oz; + + m.def( + "to_module", + [](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) { + return mlir::translateModuleToLLVMIR(mod, ctx); + }, + py::keep_alive<0, 2>()); + + m.def( + "optimize_module", + [](llvm::Module *mod, const llvm::OptimizationLevel &opt, + const std::string triple) { + if (mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT")) + return; + // Check to see if we are passing a list of flags to disable + // optimizations. + auto flagList = mlir::triton::tools::getStrEnv("DISABLE_LLVM_OPT"); + if (!flagList.empty()) { + auto options = llvm::cl::getRegisteredOptions(); + llvm::SmallVector split; + StringRef(flagList.c_str()).split(split, ','); + for (auto flag : split) { + auto optIt = options.find(flag); + if (optIt != options.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + } + } + using namespace llvm; + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + + PassInstrumentationCallbacks *instrCbPtr = nullptr; + PassInstrumentationCallbacks passInstrCb; + StandardInstrumentations standardInstr(mod->getContext(), + /*DebugLogging*/ true); + +#if !defined(TRITON_CONCEAL_IR) || (TRITON_CONCEAL_IR == 0) + if (mlir::triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { + auto optMap = llvm::cl::getRegisteredOptions(); + auto optIt = optMap.find("print-after-all"); + if (optIt != optMap.end()) { + auto optPtr = static_cast *>(optIt->second); + *optPtr = true; + } + standardInstr.registerCallbacks(passInstrCb, &mam); + instrCbPtr = &passInstrCb; + } +#endif + + { + uint32_t LMSizeLimit = -1; + std::string LMSizeLimitStr = + mlir::triton::tools::getStrEnv("LLVM_ERROR_LM_SIZE"); + if (!LMSizeLimitStr.empty()) { + llvm::StringRef LMSizeLimitSRf = LMSizeLimitStr; + if (LMSizeLimitSRf.getAsInteger(10, LMSizeLimit)) { + llvm::report_fatal_error( + "Invalid value for LLVM_ERROR_LM_SIZE: " + LMSizeLimitSRf); + } + } + llvm::StringMap optMap = + llvm::cl::getRegisteredOptions(); + auto optIt = optMap.find("xpu-error-lm-size"); + if (optIt != optMap.end()) { + llvm::cl::opt *optPtr = + static_cast *>(optIt->second); + *optPtr = LMSizeLimit; + } + } + + PipelineTuningOptions tuningOptions; + tuningOptions.LoopUnrolling = false; + tuningOptions.LoopInterleaving = true; + tuningOptions.LoopVectorization = true; + // TODO: currently we run SLP vectorizer with an empty target machine. + // This cause the vectorizer to create larger vector which could be bad. + // Disabling it would currently cause regressions as this pass also + // applies some scheduling that helps performance in some cases. We + // should work on using NVPTX target instead and address the performance + // regressions with some scheduling solution. + //===-------------------- For Triton XPU -----------------------===// + tuningOptions.SLPVectorization = + false; // TODO[dyq]: wait for xtdk adaptation + tuningOptions.SimpleLoopUnswitching = + false; // To Avoid Copying When If Else is in For + tuningOptions.MemCpyOpt = false; // To Void Selecting Memset Instruction + //===-----------------------------------------------------------===// + + if (!triple.empty()) { + mod->setTargetTriple(triple.c_str()); + } else { + mod->setTargetTriple("xpu3"); + } + + PassBuilder pb(nullptr /*targetMachine*/, tuningOptions, std::nullopt, + instrCbPtr); + + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + pb.registerVectorizerStartEPCallback( + [&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) { + // Triton generates large structure of scalars which may pessimise + // optimizations, we run a pass to break up phi of struct to make + // sure all the struct are removed for the following passes. + fpm.addPass(BreakStructPhiNodesPass()); + fpm.addPass(InstCombinePass()); + }); + mpm.addPass(pb.buildPerModuleDefaultPipeline(opt)); + mpm.run(*mod, mam); + }, + py::arg("mod"), py::arg("opt"), py::arg("triple") = ""); + + m.def("is_elf_stack_size_oob", [](std::string ElfObj) -> py::bool_ { + bool StackSizeOutofBound = isElfStackSizeOOB(ElfObj); + return StackSizeOutofBound; + }); + + m.def( + "translate_to_asm", + [](std::string llvmIR, std::string triple, std::string proc, + std::string features, std::vector flags, + bool enable_fp_fusion, bool isObject) -> py::object { + std::string obj; + { + // when allow_threads goes out of scope, gil will be released + py::gil_scoped_release allow_threads; + // create LLVM module from C++ + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + obj = translateLLVMIRToASM(*module, triple, proc, features, flags, + enable_fp_fusion, isObject); + } + if (isObject) + return py::bytes(obj); + else + return py::str(obj); + }, + ret::take_ownership); + + m.def("init_targets", []() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + }); + }); + + m.def("link_extern_libs", [](llvm::Module *dstMod, + const std::vector &paths) { + if (paths.empty()) + return; + + LLVMContext &ctx = dstMod->getContext(); + llvm::Linker linker(*dstMod); + for (const std::string &path : paths) { + llvm::SMDiagnostic err; + std::unique_ptr libMod = llvm::parseIRFile(path, err, ctx); + if (!libMod) { + std::string message = "Failed to parse library at " + path; + throw std::invalid_argument(message); + } + libMod->setTargetTriple(dstMod->getTargetTriple()); + libMod->setDataLayout(dstMod->getDataLayout()); + + std::unordered_set externalFns; + for (llvm::Function &fn : libMod->functions()) { + if (!fn.isDeclaration()) + externalFns.insert(fn.getName().str()); + } + + if (linker.linkInModule(std::move(libMod), + llvm::Linker::Flags::LinkOnlyNeeded)) { + std::string message = "Failed to link library at " + path; + throw std::invalid_argument(message); + } + + // Mark linked-in functions as internal because backends use external + // linkage as a signifier of kernel functions. + for (llvm::Function &fn : dstMod->functions()) { + if (externalFns.count(fn.getName().str())) { + fn.setLinkage(llvm::GlobalValue::InternalLinkage); + } + } + } + }); +} diff --git a/third_party/xpu/python/src/main.cc b/third_party/xpu/python/src/main.cc new file mode 100644 index 000000000..5ad4be7d5 --- /dev/null +++ b/third_party/xpu/python/src/main.cc @@ -0,0 +1,50 @@ +#include +namespace py = pybind11; + +#define FOR_EACH_1(MACRO, X) MACRO(X) +#define FOR_EACH_2(MACRO, X, ...) MACRO(X) FOR_EACH_1(MACRO, __VA_ARGS__) +#define FOR_EACH_3(MACRO, X, ...) MACRO(X) FOR_EACH_2(MACRO, __VA_ARGS__) +#define FOR_EACH_4(MACRO, X, ...) MACRO(X) FOR_EACH_3(MACRO, __VA_ARGS__) + +#define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__, FOR_EACH_RSEQ_N()) +#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__) +#define FOR_EACH_ARG_N(_1, _2, _3, _4, N, ...) N +#define FOR_EACH_RSEQ_N() 4, 3, 2, 1, 0 + +#define CONCATENATE(x, y) CONCATENATE1(x, y) +#define CONCATENATE1(x, y) x##y + +#define FOR_EACH(MACRO, ...) \ + CONCATENATE(FOR_EACH_, FOR_EACH_NARG_HELPER(__VA_ARGS__))(MACRO, __VA_ARGS__) +#define FOR_EACH_NARG_HELPER(...) FOR_EACH_NARG(__VA_ARGS__) + +// New macro to remove parentheses +#define REMOVE_PARENS(...) __VA_ARGS__ + +// Intermediate macro to ensure correct expansion +#define FOR_EACH_P_INTERMEDIATE(MACRO, ...) FOR_EACH(MACRO, __VA_ARGS__) + +// Modified FOR_EACH to handle parentheses +#define FOR_EACH_P(MACRO, ARGS_WITH_PARENS) \ + FOR_EACH_P_INTERMEDIATE(MACRO, REMOVE_PARENS ARGS_WITH_PARENS) + +#define DECLARE_BACKEND(name) void init_triton_##name(pybind11::module &&m); + +#define INIT_BACKEND(name) init_triton_##name(m.def_submodule(#name)); + +void init_triton_env_vars(pybind11::module &m); +void init_triton_ir(pybind11::module &&m); +void init_triton_llvm(pybind11::module &&m); +void init_triton_interpreter(pybind11::module &&m); +void init_triton_passes(pybind11::module &&m); +FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE) + +PYBIND11_MODULE(libtriton, m) { + m.doc() = "Python bindings to the C++ Triton API"; + init_triton_env_vars(m); + init_triton_ir(m.def_submodule("ir")); + init_triton_passes(m.def_submodule("passes")); + init_triton_interpreter(m.def_submodule("interpreter")); + init_triton_llvm(m.def_submodule("llvm")); + FOR_EACH_P(INIT_BACKEND, TRITON_BACKENDS_TUPLE) +} diff --git a/third_party/xpu/python/src/passes.cc b/third_party/xpu/python/src/passes.cc new file mode 100644 index 000000000..513e811d2 --- /dev/null +++ b/third_party/xpu/python/src/passes.cc @@ -0,0 +1,90 @@ +#include "mlir/Transforms/Passes.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "triton/Dialect/Triton/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Target/LLVMIR/Passes.h" +#include +#include + +namespace py = pybind11; + +void init_triton_analysis(py::module &&m) { + py::class_(m, "allocation", py::module_local()) + .def(py::init()); + py::class_(m, "membar", py::module_local()) + .def(py::init()) + .def("run", &mlir::ModuleMembarAnalysis::run); +} + +void init_triton_passes_common(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_sccp", createSCCPPass); + ADD_PASS_WRAPPER_0("add_symbol_dce", createSymbolDCEPass); + ADD_PASS_WRAPPER_0("add_inliner", createInlinerPass); + ADD_PASS_WRAPPER_0("add_canonicalizer", createCanonicalizerPass); + ADD_PASS_WRAPPER_0("add_cse", createCSEPass); + ADD_PASS_WRAPPER_0("add_licm", createLoopInvariantCodeMotionPass); +} + +void init_triton_passes_ttir(py::module &&m) { + using namespace mlir::triton; + ADD_PASS_WRAPPER_0("add_combine", createCombineOpsPass); + ADD_PASS_WRAPPER_0("add_reorder_broadcast", createReorderBroadcastPass); + ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer", + createRewriteTensorPointerPass); + ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", + createConvertTritonToTritonGPUPass, const std::string &, + int, int, int); +} + +void init_triton_passes_ttgpuir(py::module &&m) { + using namespace mlir::triton::gpu; + ADD_PASS_WRAPPER_0("add_coalesce", createTritonGPUCoalesce); + ADD_PASS_WRAPPER_0("add_optimize_thread_locality", + createTritonGPUOptimizeThreadLocality); + ADD_PASS_OPTION_WRAPPER_1("add_pipeline", createTritonGPUPipeline, int); + ADD_PASS_WRAPPER_0("add_prefetch", createTritonGPUPrefetch); + ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul); + ADD_PASS_WRAPPER_0("add_reorder_instructions", + createTritonGPUReorderInstructions); + ADD_PASS_WRAPPER_0("add_f32_dot_tc", createTritonGPUF32DotTC); + ADD_PASS_OPTION_WRAPPER_1("add_optimize_dot_operands", + createTritonGPUOptimizeDotOperands, bool); + ADD_PASS_WRAPPER_0("add_remove_layout_conversions", + createTritonGPURemoveLayoutConversions); + ADD_PASS_WRAPPER_0("add_reduce_data_duplication", + createTritonGPUReduceDataDuplication); + ADD_PASS_WRAPPER_0("add_allocate_shared_memory", + createAllocateSharedMemoryPass); + ADD_PASS_WRAPPER_0("add_combine_tensor_select_and_if", + createTritonGPUCombineTensorSelectAndIf); +} + +void init_triton_passes_convert(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_scf_to_cf", createConvertSCFToCFPass); + ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); + ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); + ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); +} + +void init_triton_passes_llvmir(py::module &&m) { + using namespace mlir; + ADD_PASS_WRAPPER_0("add_di_scope", createLLVMDIScopePass); +} + +void init_triton_passes(py::module &&m) { + init_triton_analysis(m.def_submodule("analysis")); + init_triton_passes_common(m.def_submodule("common")); + init_triton_passes_convert(m.def_submodule("convert")); + init_triton_passes_ttir(m.def_submodule("ttir")); + init_triton_passes_ttgpuir(m.def_submodule("ttgpuir")); + init_triton_passes_llvmir(m.def_submodule("llvmir")); +} diff --git a/third_party/xpu/python/src/passes.h b/third_party/xpu/python/src/passes.h new file mode 100644 index 000000000..46801d802 --- /dev/null +++ b/third_party/xpu/python/src/passes.h @@ -0,0 +1,40 @@ +#define ADD_PASS_WRAPPER_0(name, builder) \ + m.def(name, [](mlir::PassManager &pm) { pm.addPass(builder()); }) + +#define ADD_PASS_WRAPPER_1(name, builder, ty0) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder(val0)); }) + +#define ADD_PASS_WRAPPER_2(name, builder, ty0, ty1) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ + pm.addPass(builder(val0, val1)); \ + }) + +#define ADD_PASS_WRAPPER_3(name, builder, ty0, ty1, ty2) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ + pm.addPass(builder(val0, val1, val2)); \ + }) + +#define ADD_PASS_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \ + ty3 val3) { pm.addPass(builder(val0, val1, val2, val3)); }) + +#define ADD_PASS_OPTION_WRAPPER_1(name, builder, ty0) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder({val0})); }) + +#define ADD_PASS_OPTION_WRAPPER_2(name, builder, ty0, ty1) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \ + pm.addPass(builder({val0, val1})); \ + }) + +#define ADD_PASS_OPTION_WRAPPER_3(name, builder, ty0, ty1, ty2) \ + m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2) { \ + pm.addPass(builder({val0, val1, val2})); \ + }) + +#define ADD_PASS_OPTION_WRAPPER_4(name, builder, ty0, ty1, ty2, ty3) \ + m.def(name, \ + [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3) { \ + pm.addPass(builder({val0, val1, val2, val3})); \ + }) diff --git a/third_party/xpu/python/test/unit/conftest.py b/third_party/xpu/python/test/unit/conftest.py new file mode 100644 index 000000000..7a02d322b --- /dev/null +++ b/third_party/xpu/python/test/unit/conftest.py @@ -0,0 +1,12 @@ +# content of conftest.py + +import pytest + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default='cuda') + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") diff --git a/third_party/xpu/python/test/unit/language/assert_helper.py b/third_party/xpu/python/test/unit/language/assert_helper.py new file mode 100644 index 000000000..ddd6ad886 --- /dev/null +++ b/third_party/xpu/python/test/unit/language/assert_helper.py @@ -0,0 +1,154 @@ +import sys + +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +def get_current_target_warp_size(): + return triton.runtime.driver.active.get_current_target().warp_size + + +@triton.jit +def kernel_device_assert(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_assert(x == 0, "x != 0") + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_assert_passes(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # Trivial assert, should not be an error. + tl.device_assert(0 == 0, "x != 0") + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit(debug=False) +def kernel_device_assert_no_debug(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_assert(x == 0, "x != 0") + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_assert(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + assert x == 0, "x != 0" + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_static_assert(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.static_assert(BLOCK == 128, "BLOCK != 128") + tl.store(Y + tl.arange(0, BLOCK), x) + + +def test_assert(func: str): + N = 128 # This value should match with test_print in test_subprocess.py. + num_warps = N // get_current_target_warp_size() + + x = torch.arange(0, N, dtype=torch.int32, device='cuda') + y = torch.zeros((N, ), dtype=x.dtype, device="cuda") + if func == "device_assert": + kernel_device_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N) + if func == "device_assert_passes": + # Assert passes; no error. + kernel_assert_passes[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "no_debug": + # TRITON_DEBUG=1 can override the debug flag + kernel_device_assert_no_debug[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "assert": + kernel_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "static_assert": + kernel_static_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "double_assert": + # Launching a different kernel after the first one asserted used to + # segfault. What seems to have happened is: + # - The first kernel is enqueued but doesn't run yet. + # - We go to launch the second kernel. Because this is the first time + # we're running it, we have to load the kernel into the GPU. + # - Loading the kernel takes some time, during which the first launch + # completes. + # - Now the GPU is in an error state. We need to detect this inside + # the kernel-launch/loading code and bail out properly. If we don't, + # we segfault. + kernel_device_assert[(1, )](x, y, num_warps=num_warps, BLOCK=N) + kernel_assert_passes[(1, )](x, y, num_warps=num_warps, BLOCK=N) + assert_close(y, x) + + +@triton.jit +def jit_device_assert_none(x): + tl.device_assert(x == 0, "x != 0") + + +@triton.jit(debug=True) +def jit_device_assert_true(x): + tl.device_assert(x == 0, "x != 0") + + +@triton.jit(debug=False) +def jit_device_assert_false(x): + tl.device_assert(x == 0, "x != 0") + + +@triton.jit +def kernel_device_assert_nested(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + if jit_debug == "true": + jit_device_assert_true(x) + elif jit_debug == "false": + jit_device_assert_false(x) + else: + jit_device_assert_none(x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit(debug=True) +def kernel_device_assert_nested_true(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + if jit_debug == "true": + jit_device_assert_true(x) + elif jit_debug == "false": + jit_device_assert_false(x) + else: + jit_device_assert_none(x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit(debug=False) +def kernel_device_assert_nested_false(X, Y, BLOCK: tl.constexpr, jit_debug: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + if jit_debug == "true": + jit_device_assert_true(x) + elif jit_debug == "false": + jit_device_assert_false(x) + else: + jit_device_assert_none(x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +def test_assert_nested(caller: str, callee: str): + N = 128 # This value should match with test_print in test_subprocess.py. + num_warps = N // get_current_target_warp_size() + + x = torch.arange(0, N, dtype=torch.int32, device='cuda') + y = torch.zeros((N, ), dtype=x.dtype, device="cuda") + if caller == "none": + kernel_device_assert_nested[(1, )](x, y, num_warps=num_warps, BLOCK=N, jit_debug=callee) + elif caller == "true": + kernel_device_assert_nested_true[(1, )](x, y, num_warps=num_warps, BLOCK=N, jit_debug=callee) + elif caller == "false": + kernel_device_assert_nested_false[(1, )](x, y, num_warps=num_warps, BLOCK=N, jit_debug=callee) + assert_close(y, x) + + +if __name__ == "__main__": + if len(sys.argv) == 3: + test_assert_nested(sys.argv[1], sys.argv[2]) + else: + test_assert(sys.argv[1]) diff --git a/third_party/xpu/python/test/unit/language/conftest.py b/third_party/xpu/python/test/unit/language/conftest.py new file mode 100644 index 000000000..091f9ea41 --- /dev/null +++ b/third_party/xpu/python/test/unit/language/conftest.py @@ -0,0 +1,5 @@ +# content of conftest.py + + +def pytest_configure(config): + config.addinivalue_line("markers", "interpreter: indicate whether interpreter supports the test") diff --git a/third_party/xpu/python/test/unit/language/print_helper.py b/third_party/xpu/python/test/unit/language/print_helper.py new file mode 100644 index 000000000..e032792f3 --- /dev/null +++ b/third_party/xpu/python/test/unit/language/print_helper.py @@ -0,0 +1,125 @@ +import sys +import uuid + +import torch +from torch.testing import assert_close + +import triton +import triton.language as tl + + +def get_current_target_warp_size(): + return triton.runtime.driver.active.get_current_target().warp_size + + +@triton.jit +def kernel_device_print(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_print("x: ", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_device_print_hex(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.device_print("x: ", x, hex=True) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_print(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # Triton should add a space after this prefix. + print("x:", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_device_print_large( + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + x = tl.full([BLOCK_M, BLOCK_N], 1, tl.int32) + # Triton should change this prefix to "x: ". + tl.device_print("x ", x) + + +@triton.jit +def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK, ), 1, tl.int32) + print("", x, y) + + +@triton.jit +def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.full((BLOCK, ), 1, tl.int32) + tl.device_print("", x, y) + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit +def kernel_static_print(X, Y, BLOCK: tl.constexpr, PLACEHOLDER: tl.constexpr): + # This function takes an extra value as a tl.constexpr so this kernel is not + # cached. This way the static print is run every time. + x = tl.load(X + tl.arange(0, BLOCK)) + tl.static_print("", x) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def kernel_no_arg_print(): + print("", tl.program_id(0)) + + +@triton.jit +def kernel_print_no_arg(): + print("no arg") + + +@triton.jit +def kernel_print_pointer(X, Y, BLOCK: tl.constexpr): + tl.device_print("ptr ", X + tl.arange(0, BLOCK)) + + +def test_print(func: str, data_type: str): + N = 128 # This value should match with test_print in test_subprocess.py. + # TODO(antiagainst): Currently the warp count is chosen to make sure wedon't have multiple + # threads printing duplicated messages due to broadcasting. Improve print op lowering logic + # to filter out duplicated data range. + num_warps = N // get_current_target_warp_size() + + x = torch.arange(0, N, dtype=torch.int32, device='cuda').to(getattr(torch, data_type)) + y = torch.zeros((N, ), dtype=x.dtype, device="cuda") + if func == "device_print": + kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "print": + kernel_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_large": + kernel_device_print_large[(1, 2)](BLOCK_M=64, num_warps=num_warps, BLOCK_N=N) + elif func == "print_multiple_args": + kernel_print_multiple_args[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_multiple_args": + kernel_device_print_multiple_args[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "static_print": + kernel_static_print[(1, )](x, y, num_warps=num_warps, BLOCK=N, PLACEHOLDER=uuid.uuid4()) + elif func == "no_arg_print": + kernel_no_arg_print[(1, )](num_warps=num_warps) + elif func == "print_no_arg": + kernel_print_no_arg[(1, )](num_warps=num_warps) + elif func == "device_print_hex": + kernel_device_print_hex[(1, )](x, y, num_warps=num_warps, BLOCK=N) + elif func == "device_print_pointer": + kernel_print_pointer[(1, )](x, y, num_warps=num_warps, BLOCK=N) + else: + assert f"Unknown kernel: {func}" + + if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \ + func != "print_multiple_args" and func != "device_print_multiple_args" and \ + func != "device_print_pointer": + assert_close(y, x) + + +if __name__ == "__main__": + test_print(sys.argv[1], sys.argv[2]) diff --git a/third_party/xpu/python/test/unit/language/test_annotations.py b/third_party/xpu/python/test/unit/language/test_annotations.py new file mode 100644 index 000000000..0c1f065a1 --- /dev/null +++ b/third_party/xpu/python/test/unit/language/test_annotations.py @@ -0,0 +1,49 @@ +from __future__ import annotations +import torch +import triton +import triton.language as tl +import pytest + + +def annotated_function(return_type=None, **arg_types): + """A decorator to add annotations to a function.""" + + def decorator(func): + func.__annotations__ = {**arg_types, 'return': return_type} + return func + + return decorator + + +# Test integer annotations +@pytest.mark.parametrize(("signed", "width"), [ + (signed, width) for signed in [False, True]\ + for width in [8, 16, 32, 64] +] + [(False, 1)] + ) +def test_int_annotation(signed, width, device): + + @triton.jit + @annotated_function(X=torch.tensor, v=f"tl.{'' if signed else 'u'}int{width}") + def _kernel(X, v): + tl.store(X, v) + + h = _kernel[(1, )](torch.empty(1, device=device), 3) + pfx = 'si' if signed else 'ui' + assert f'%arg1: i{width}' in h.asm["ttir"] + assert f'arith.{pfx}tofp' in h.asm["ttir"] + + +# Test that unknown annotations do not emit an error +def test_unknown_annotation(device): + + @triton.jit + def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr): + pass + + x = torch.empty(1, device=device) + _kernel[(1, )](x, x.shape[0], 32) + try: + _kernel[(1, )](x.shape[0], x.shape[0], 32) + except AttributeError: + pass diff --git a/third_party/xpu/python/test/unit/language/test_block_pointer.py b/third_party/xpu/python/test/unit/language/test_block_pointer.py new file mode 100644 index 000000000..7f7877652 --- /dev/null +++ b/third_party/xpu/python/test/unit/language/test_block_pointer.py @@ -0,0 +1,102 @@ +import pytest +import torch + +import triton +import triton.language as tl + +pytest.skip("Skip for kunlunxin", allow_module_level=True) + + +@triton.jit +def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr): + pid = tl.program_id(0) + # We only copy half of the data to see if the padding works + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) + a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option) + tl.store(b_block_ptr, a, boundary_check=(0, )) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtypes_str, n, padding_option", [ # + (dtypes_str, n, padding) + for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("float16", "float16"), ("int16", "float16")) + for n in (64, 128, 256, 512, 1024) + for padding in ("zero", "nan") # +]) +def test_block_copy(dtypes_str, n, padding_option, device): + src_dtype_str = dtypes_str[0] + dst_dtype_str = dtypes_str[0] + src_dtype = getattr(torch, src_dtype_str) + dst_dtype = getattr(torch, dst_dtype_str) + if src_dtype_str in ("bool", "int16"): + if padding_option == "nan": + pytest.skip("Padding with NaN is not supported for integer types") + a = torch.randint(0, 2, (n, ), device=device, dtype=src_dtype) + else: + a = torch.randn((n, ), device=device, dtype=src_dtype) + b = torch.zeros((n, ), device=device, dtype=dst_dtype) + + grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), ) + block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option) + a.to(dst_dtype) + assert torch.all(a[0:n // 2] == b[0:n // 2]) + if padding_option == "zero": + assert torch.all(b[n // 2:n] == 0) + else: + assert torch.all(torch.isnan(b[n // 2:n])) + + +@triton.jit +def matmul_no_scf_with_advance_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr # +): + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), order=(1, 0)) + # Below two lines are just for testing negative offsets for the `advance` API, which could be removed + a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K)) + a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K)) + a = tl.load(a_block_ptr, boundary_check=(1, ), padding_option="zero") + b = tl.load(b_block_ptr, boundary_check=(0, ), padding_option="zero") + + c = tl.dot(a, b) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, c) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, num_warps", [ # + (shape, num_warps) for shape in [ + [64, 64, 16], + [64, 64, 32], + [64, 64, 64], + ] for num_warps in [4, 8] +]) +def test_block_ptr_matmul_no_scf(shape, num_warps, device): + m, n, k = shape + a = torch.randn((m, k), device=device, dtype=torch.float16) + b = torch.randn((k, n), device=device, dtype=torch.float16) + c = torch.empty((m, n), device=device, dtype=torch.float32) + + grid = lambda META: (1, ) + matmul_no_scf_with_advance_kernel[grid]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=m, N=n, K=k, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, # + num_warps=num_warps) + golden = torch.matmul(a, b) + torch.testing.assert_close(c, golden, check_dtype=False) diff --git a/third_party/xpu/python/test/unit/language/test_compile_errors.py b/third_party/xpu/python/test/unit/language/test_compile_errors.py new file mode 100644 index 000000000..ab3b28a41 --- /dev/null +++ b/third_party/xpu/python/test/unit/language/test_compile_errors.py @@ -0,0 +1,306 @@ +import pytest + +import triton +import triton.language as tl +from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure +import traceback + + +def test_err_undefined_variable(): + + @triton.jit + def kernel(): + a += 1 # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert "is not defined" in str(e.value), "error should mention the undefined variable" + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_binary_operator(): + + @triton.jit + def kernel(): + 0 + "a" + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert "at 2:4:" in str(e.value), "error should point to the 0" + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_static_assert(): + + @triton.jit + def kernel(): + tl.static_assert(isinstance(0, tl.tensor)) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert isinstance(e.value, CompileTimeAssertionFailure) + assert e.value.__cause__ is None + assert "at 2:4:" in str(e.value), "error should point to the static_assert call" + assert "" not in str(e.value) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_unary_op(): + # Currently Triton can't evaluate `not` of a tuple at compile time. That's + # ok, but the error message needs to point to the correct spot. + @triton.jit + def kernel(): + not (0, 0) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert e.value.__cause__ is None + assert "at 2:4:" in str(e.value), "error should point to the `not`" + assert "" not in str(e.value) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_binary_op(): + + @triton.jit + def kernel(): + 1.0 << 1 + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + assert "at 2:4:" in str(e.value), "error should point to the 1.0" + assert "" not in str(e.value) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +# This has to be defined as a top-level function; jit'ed functions can't call +# nested functions. +@triton.jit +def nested_call(): + xyz # noqa + + +def test_err_in_nested_call(): + + @triton.jit + def kernel(): + # this is a comment to push nested_call() onto the next line + nested_call() + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + inner = e.value.__cause__ + outer = e.value + assert "at 2:4:" in str(inner), "error should point to xyz" + assert "" not in str(inner) + + assert "at 3:4" in str(outer), "error should point to the nested_call" + assert "" not in str(outer) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +def test_err_in_builtin(): + + # The root error here comes from core.py. Make sure the stacktrace reflects + # this. + @triton.jit + def kernel(): + tl.expand_dims(None, -1) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + try: + inner = e.value.__cause__ + outer = e.value + assert "/core.py" in '\n'.join(traceback.format_tb(inner.__traceback__)), "error should point inside core.py" + + assert "at 2:4:" in str(outer), "error should point to expand_dims call" + assert "" not in str(outer) + except AssertionError as assertion_err: + raise assertion_err from e.value + + +@triton.jit +def two_returns(): + return tl.arange(0, 4) + return tl.arange(0, 8) + + +def test_two_returns_no_err(): + # This program is valid; `a` has shape (10,). + @triton.jit + def kernel(): + a = two_returns() + a + tl.arange(0, 4) # only works if we took the first return + + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +@triton.jit +def returns_branched_on_constexpr(N: tl.constexpr): + if N == 0: + return tl.arange(0, 4) + # Ideally this would work even without the `else`, but we're not that smart + # yet. + else: + return tl.arange(0, 8) + + +def test_returns_branched_on_constexpr(): + + @triton.jit + def kernel1(N: tl.constexpr): + a = returns_branched_on_constexpr(N) + a + tl.arange(0, 4) + + triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0})) + + @triton.jit + def kernel2(N: tl.constexpr): + a = returns_branched_on_constexpr(N) + a + tl.arange(0, 8) + + triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1})) + + +@triton.jit +def returns_branched_on_non_constexpr(N: int): + if N == 0: + return tl.arange(0, 4) + else: + return tl.arange(0, 8) + + +def test_returns_branched_on_non_constexpr(): + + @triton.jit + def kernel(N: int): + returns_branched_on_non_constexpr(N) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={})) + + try: + assert "at 2:4:" in str(e.value), "error should point to the function call" + assert "at 5:8:" in str(e.value.__cause__), "error should point to the second `return`" + except AssertionError as assertion_err: + raise assertion_err from e.value + + +@pytest.mark.skip("Skip for kunlunxin") +def test_power_of_two_shapes(): + + @triton.jit + def kernel(): + tl.arange(2, 7) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + assert str(e.value.__cause__) == "arange's range must be a power of 2" + + +@pytest.mark.skip("Skip for kunlunxin") +def test_power_of_two_shapes_2(): + + @triton.jit + def kernel(): + tl.full((33, ), 0, dtype=tl.int64) + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + assert str(e.value.__cause__) == "Shape element 0 must be a power of 2" + + +def test_captured_var_access(): + + CAPTURED = 42 + + @triton.jit + def kernel(): + a = CAPTURED # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + assert "CAPTURED is not defined" in str(e.value) + + +GLOBAL = 42 + + +def test_global_var_access(): + + @triton.jit + def kernel(): + a = GLOBAL # noqa + + with pytest.raises(CompilationError) as e: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + assert "global variable" in str(e.value) + + +CONSTEXPR_ANNOTATED_GLOBAL: tl.constexpr = 42 + + +def test_constexpr_annotated_global_var_access(): + + @triton.jit + def kernel(): + a = CONSTEXPR_ANNOTATED_GLOBAL # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +CONSTEXPR_GLOBAL = tl.constexpr(42) + + +def test_constexpr_global_var_access(): + + @triton.jit + def kernel(): + a = CONSTEXPR_GLOBAL # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +TYPE_ALIAS = tl.pointer_type(tl.int32) + + +def test_global_type_alias_access(): + + @triton.jit + def kernel(): + a = TYPE_ALIAS # noqa + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + + +def test_global_access_in_fn_default_arg(): + + @triton.jit + def kernel(a=GLOBAL): + pass + + # No error. + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={0: "i32"}, constants={})) diff --git a/third_party/xpu/python/test/unit/language/test_conversions.py b/third_party/xpu/python/test/unit/language/test_conversions.py new file mode 100644 index 000000000..27de3e00a --- /dev/null +++ b/third_party/xpu/python/test/unit/language/test_conversions.py @@ -0,0 +1,358 @@ +# fmt: off + + +import os +import numpy as np +import torch +import pytest +import triton +import triton.language as tl + + +pytest.skip("Skip for kunlunxin", allow_module_level=True) + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + +def is_cuda(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda" + +def is_hip(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip" + +def is_on_mi300(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942') + +def matching_int(dtype): + if dtype.primitive_bitwidth == 8: + return torch.int8 + elif dtype.primitive_bitwidth == 16: + return torch.int16 + elif dtype.primitive_bitwidth == 32: + return torch.int32 + elif dtype.primitive_bitwidth == 64: + return torch.int64 + else: + raise ValueError('unsupported number of bits') + +@triton.jit +def type_convert_triton(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr): + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(src + idxs) + y = x.to(dst.dtype.element_ty, fp_downcast_rounding=rounding) + tl.store(dst + idxs, y) + + +def launch_type_convert_triton(src, src_dtype, dst_dtype, device, rounding=None, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device) + type_convert_triton[(src.shape[0] // BLOCK_SIZE,)](triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE) + return dst + + +@triton.jit +def exhaustive_populate(dst, offset, BLOCK_SIZE : tl.constexpr, force_odd : tl.constexpr, output_bits : tl.constexpr, max_repr : tl.constexpr): + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + vals = (idxs + offset).to(tl.uint32) + + # pseudorandom permutation: + multiplier = vals << 1 + multiplier += 3511 + vals *= multiplier + + if force_odd: + vals *= 2 + vals += 1 + + if (output_bits == 8): + vals &= 0xff + avals = vals & 0x7f + elif (output_bits == 16): + vals &= 0xffff + avals = vals & 0x7fff + elif (output_bits == 32): + avals = vals & 0x7fffffff + + vals = tl.where(avals <= max_repr, vals, 0) + + if (output_bits == 8): + vals = vals.to(tl.uint8) + elif (output_bits == 16): + vals = vals.to(tl.uint16) + + vals = vals.to(dst.dtype.element_ty, bitcast=True) + tl.store(dst + idxs, vals) + + +def launch_exhaustive_populate(dst_dtype, offset, numel, force_odd, output_bits, max_repr, device, BLOCK_SIZE=4096): + + assert(numel % BLOCK_SIZE == 0) + dst = torch.empty((numel,), dtype=matching_int(dst_dtype), device=device) + exhaustive_populate[(numel // BLOCK_SIZE,)](triton.reinterpret(dst, dst_dtype), offset, BLOCK_SIZE, force_odd, output_bits, max_repr) + # 0x80 in float8e4b8 or float8e5b16 represents inf/nan. We don't need to have that + # as input to the conversion kernels. + if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16: + dst = torch.where(dst == 0x80, 0, dst) + return dst + + +@triton.jit +def arbitrary_fp32_downcast(x, rounding : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + tl.static_assert(x.dtype == tl.float32, "input must be float32") + numbits_dst : tl.constexpr = 1 + exponent_bits + mantissa_bits + tl.static_assert((numbits_dst == 8) or (numbits_dst == 16), "numbits_dst must be 8 or 16") + + x = x.to(tl.uint32, bitcast=True) + + mantissa = (x & 0x7fffff) + exponent = ((x >> 23) & 0xff).to(tl.int32) + mantissa = tl.where(exponent == 0, mantissa, mantissa + 0x800000).to(tl.int32) + exponent = tl.where(exponent == 0, exponent, exponent - 1) + + sign = (x >> 31) + + exponent = exponent + exponent_bias - 127 + adjustment : tl.constexpr = 0.5 ** (23 - mantissa_bits) + mantissa = mantissa.to(tl.float32) * adjustment + + # make exponent nonnegative: + mantissa = tl.where(exponent > -16, mantissa, 0.0) # destination has fewer than 16 mantissa bits, so safe + exponent = tl.where(exponent > -16, exponent, 0) + mantissa = tl.where(exponent > -8, mantissa, mantissa * 0.00390625) + exponent = tl.where(exponent > -8, exponent, exponent + 8) + mantissa = tl.where(exponent > -4, mantissa, mantissa * 0.0625) + exponent = tl.where(exponent > -4, exponent, exponent + 4) + mantissa = tl.where(exponent > -2, mantissa, mantissa * 0.25) + exponent = tl.where(exponent > -2, exponent, exponent + 2) + mantissa = tl.where(exponent > -1, mantissa, mantissa * 0.5) + exponent = tl.where(exponent > -1, exponent, exponent + 1) + + if rounding == 'rtne': + # Bring the value to the range [2 ** 23, 2 ** 24] + # where the representable floats map exactly to integers. + # Addition has RTNE semantics. + mantissa += 0x800000 + # Bring the value back to the original range. + mantissa -= 0x800000 + mantissa = mantissa.to(tl.int32) + elif rounding == 'rtz': + mantissa = mantissa.to(tl.int32) + else: + raise ValueError('unrecognized rounding mode') + + # Reassemble output floating-point representation: + exponent = exponent.to(tl.uint32) + y = (sign << (exponent_bits + mantissa_bits)) + (exponent << mantissa_bits) + mantissa + if numbits_dst == 8: + y = y.to(tl.uint8) + elif numbits_dst == 16: + y = y.to(tl.uint16) + return y + + +@triton.jit +def downcast_emulated(src, dst, rounding : tl.constexpr, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + tl.static_assert(src.dtype.element_ty == tl.float32, "src dtype must be float32") + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + idxs) + y = arbitrary_fp32_downcast(x, rounding, exponent_bits, mantissa_bits, exponent_bias) + y = y.to(dst.dtype.element_ty, bitcast=True) + tl.store(dst + idxs, y) + + +def launch_downcast_emulated(src, src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=matching_int(dst_dtype), device=device) + downcast_emulated[(src.shape[0] // BLOCK_SIZE,)]( + triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), rounding, BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias) + # 0x80 in float8e4b8 or float8e5b16 represents inf/nan. downcast_emulated kernel will + # convert -0. in higher precision to 0x80 and thus need to fix the result to 0. + if dst_dtype == tl.float8e4b8 or dst_dtype == tl.float8e5b16: + dst = torch.where(dst == 0x80, 0, dst) + return dst + + +@triton.jit +def upcast_emulated(src, dst, BLOCK_SIZE : tl.constexpr, exponent_bits : tl.constexpr, mantissa_bits : tl.constexpr, exponent_bias : tl.constexpr): + + exponent_compensator : tl.constexpr = 2.0 ** (127 - exponent_bias) + + numbits_src : tl.constexpr = 1 + exponent_bits + mantissa_bits + tl.static_assert((numbits_src == 8) or (numbits_src == 16), "numbits_src must be 8 or 16") + tl.static_assert(dst.dtype.element_ty == tl.float32, "dst dtype must be float32") + + idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + x = tl.load(src + idxs) + + if numbits_src == 8: + x = x.to(tl.uint8, bitcast=True) + elif numbits_src == 16: + x = x.to(tl.uint16, bitcast=True) + + x = x.to(tl.uint32) + + mantissa_mask : tl.constexpr = (1 << mantissa_bits) - 1 + exponent_mask : tl.constexpr = (1 << exponent_bits) - 1 + + mantissa = x & mantissa_mask + exponent = (x >> mantissa_bits) & exponent_mask + sign = (x >> (numbits_src - 1)) + + y = (sign << 31) | (exponent << 23) | (mantissa << (23 - mantissa_bits)) + y = y.to(tl.float32, bitcast=True) + y = y * exponent_compensator + + tl.store(dst + idxs, y) + + +def launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device, BLOCK_SIZE=4096): + + dst = torch.empty(src.shape, dtype=torch.int32, device=device) + upcast_emulated[(src.shape[0] // BLOCK_SIZE,)](src, triton.reinterpret(dst, tl.float32), BLOCK_SIZE, exponent_bits, mantissa_bits, exponent_bias) + return dst + + +def downcast_test(src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, max_repr, offset, device): + + src = launch_exhaustive_populate(src_dtype, offset << 24, 2**24, False, src_dtype.primitive_bitwidth, max_repr, device) + dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device, rounding=rounding) + src = launch_type_convert_triton(src, src_dtype, tl.float32, device=device) + + dst2 = launch_downcast_emulated(src, tl.float32, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device=device) + + dst = launch_upcast_emulated(dst, exponent_bits, mantissa_bits, exponent_bias, device=device) + dst2 = launch_upcast_emulated(dst2, exponent_bits, mantissa_bits, exponent_bias, device=device) + + if not (torch.equal(dst, dst2)): + print('Error!!!') + + dst = dst.cpu().detach().numpy() + dst2 = dst2.cpu().detach().numpy() + src = src.cpu().detach().numpy() + + print(src[dst != dst2][0]) + print(dst[dst != dst2][0]) + print(dst2[dst != dst2][0]) + print(hex(src.view(np.uint32)[dst != dst2][0])) + print(hex(dst.view(np.uint32)[dst != dst2][0])) + print(hex(dst2.view(np.uint32)[dst != dst2][0])) + print('') + raise ValueError('%d elements mismatch' % (dst != dst2).sum()) + + +def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bias, max_repr, device): + + numbits_src = exponent_bits + mantissa_bits + 1 + + src = launch_exhaustive_populate(src_dtype, 0, 65536, False, numbits_src, max_repr, device=device) + + dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device) + dst = launch_type_convert_triton(dst, dst_dtype, tl.float32, device=device) + + dst2 = launch_upcast_emulated(src, exponent_bits, mantissa_bits, exponent_bias, device=device) + + assert(torch.equal(dst, dst2)) + + +@pytest.mark.parametrize("src_dtype, dst_dtype", [ + ('float16', 'float32'), + ('bfloat16', 'float32'), + + ('float8e5', 'float16'), + ('float8e5', 'bfloat16'), + ('float8e5', 'float32'), + + ('float8e4b15', 'float16'), + # ('float8e4b15', 'bfloat16'), # Unsupported conversion from f8E4M3B11FNUZ to bf16 + ('float8e4b15', 'float32'), + + ('float8e4nv', 'float16'), + ('float8e4nv', 'bfloat16'), + ('float8e4nv', 'float32'), + + ('float8e4b8', 'float32'), + ('float8e4b8', 'float16'), + + ('float8e5b16', 'float32'), + ('float8e5b16', 'float16'), +]) +def test_typeconvert_upcast(src_dtype, dst_dtype, device): + + if src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip("float8e4nv upcast tests only supported on NVGPU with compute capability 9.0+") + + if src_dtype in ('float8e4nv', 'float8e4b15') and is_hip(): + pytest.skip(f"{src_dtype} upcast tests not supported on ROCm") + + if src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_on_mi300()): + pytest.skip("{src_dtype} upcast tests only supported on AMDGPU MI300") + + # dtype : (exponent_bits, mantissa_bits, exponent_bias, max_repr) + stuff = { + 'float8e4b15': (4, 3, 15, 0x7e), + 'float8e4nv': (4, 3, 7, 0x7e), + 'float8e5': (5, 2, 15, 0x7b), + 'float8e4b8': (4, 3, 8, 0x7f), + 'float8e5b16': (5, 2, 16, 0x7f), + 'float16': (5, 10, 15, 0x7bff), + 'bfloat16': (8, 7, 127, 0x7f7f), + }[src_dtype] + + upcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), *stuff, device=device) + +@pytest.mark.parametrize("src_dtype, dst_dtype, rounding, max_repr", [ + ('float32', 'float16', 'rtne', 0x477fe000), + ('float32', 'float16', 'rtz', 0x477fe000), + ('float32', 'bfloat16', 'rtne', 0x7f7f0000), + ('float32', 'bfloat16', 'rtz', 0x7f7f0000), + ('float32', 'float8e5', 'rtne', 0x47600000), + ('float32', 'float8e5', 'rtz', 0x47600000), + ('float32', 'float8e4nv', 'rtne', 0x43e00000), + ('float32', 'float8e4b8', 'rtne', 0x43700000), + ('float32', 'float8e5b16', 'rtne', 0x47600000), + # ('float32', 'float8e4b15', 'rtne', 0x3fe00000), # Skip, no HW rtne conversion from f32 to f8e4b15 + + ('bfloat16', 'float8e5', 'rtne', 0x4760), + ('bfloat16', 'float8e4nv', 'rtne', 0x43e0), + + ('float16', 'float8e5', 'rtne', 0x7b00), + ('float16', 'float8e4nv', 'rtne', 0x5f00), + + ('bfloat16', 'float8e5b16', 'rtne', 0x4760), + ('bfloat16', 'float8e4b8', 'rtne', 0x4370), + + ('float16', 'float8e5b16', 'rtne', 0x7b00), + ('float16', 'float8e4b8', 'rtne', 0x5b80), +]) +def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device): + + if src_dtype != 'float32' and is_cuda() and torch.cuda.get_device_capability(0) < (9, 0): + pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+") + + if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.get_device_capability(0) < (9, 0)): + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+") + + if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_on_mi300()): + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300") + + # dtype : (exponent_bits, mantissa_bits, exponent_bias) + stuff = { + 'float16': (5, 10, 15), + 'bfloat16': (8, 7, 127), + 'float8e5': (5, 2, 15), + 'float8e4b15': (4, 3, 15), + 'float8e4nv': (4, 3, 7), + 'float8e4b8': (4, 3, 8), + 'float8e5b16': (5, 2, 16), + }[dst_dtype] + + for i in range(256): + downcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), rounding, *stuff, max_repr, i, device=device) diff --git a/third_party/xpu/python/test/unit/language/test_core.py b/third_party/xpu/python/test/unit/language/test_core.py new file mode 100644 index 000000000..8e0e8bba8 --- /dev/null +++ b/third_party/xpu/python/test/unit/language/test_core.py @@ -0,0 +1,5490 @@ +# flake8: noqa: F821,F841 +import itertools +import re +from typing import Optional, Union +import math +import textwrap +import tempfile + +import numpy as np +import pytest +import torch +import os +import inspect +from numpy.random import RandomState + +import triton +import triton.language as tl +from triton.runtime.jit import TensorWrapper, reinterpret + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def is_cuda(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def is_hip(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "hip" + + +int_dtypes = ['int8', 'int16', 'int32', 'int64'] +uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] +float_dtypes = ['float16', 'float32', 'float64'] +dtypes = int_dtypes + uint_dtypes + float_dtypes +dtypes_with_bfloat16 = dtypes + ['bfloat16'] +torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2'] +# ===-------------------- For Triton XPU -----------------------=== +torch_float8_dtypes = [] +# ===-----------------------------------------------------------=== +torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16'] + +# TODO: enable multiple cta cluster testing. +# num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1] +num_ctas_list = [1] + +GPU_DIALECT = "triton_gpu" +if is_interpreter(): + THREADS_PER_WARP = 1 +elif is_hip(): + THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size +else: + THREADS_PER_WARP = 32 + + +def _bitwidth(dtype: str) -> int: + # ex.: "int64" -> 64 + return int(re.search(r'(\d+)$', dtype).group(1)) + + +def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): + """ + Override `rs` if you're calling this function twice and don't want the same + result for both calls. + """ + if isinstance(shape, int): + shape = (shape, ) + if rs is None: + rs = RandomState(seed=17) + if dtype_str in int_dtypes + uint_dtypes: + iinfo = np.iinfo(getattr(np, dtype_str)) + low = iinfo.min if low is None else max(low, iinfo.min) + high = iinfo.max if high is None else min(high, iinfo.max) + dtype = getattr(np, dtype_str) + x = rs.randint(low, high, shape, dtype=dtype) + x[x == 0] = 1 # Workaround. Never return zero so tests of division don't error out. + return x + elif dtype_str and 'float8' in dtype_str: + x = rs.randint(20, 40, shape, dtype=np.int8) + return x + elif dtype_str in float_dtypes: + return rs.normal(0, 1, shape).astype(dtype_str) + elif dtype_str == 'bfloat16': + return (rs.normal(0, 1, shape).astype('float32').view('uint32') & np.uint32(0xffff0000)).view('float32') + elif dtype_str in ['bool', 'int1', 'bool_']: + return rs.normal(0, 1, shape) > 0.0 + else: + raise RuntimeError(f'Unknown dtype {dtype_str}') + + +def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torch.Tensor]: + ''' + Note: We need dst_type because the type of x can be different from dst_type. + For example: x is of type `float32`, dst_type is `bfloat16`. + If dst_type is None, we infer dst_type from x. + ''' + t = x.dtype.name + if t in uint_dtypes: + signed_type_name = t.lstrip('u') # e.g. "uint16" -> "int16" + x_signed = x.astype(getattr(np, signed_type_name)) + return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t)) + else: + if dst_type and 'float8' in dst_type: + return reinterpret(torch.tensor(x, device=device), getattr(tl, dst_type)) + if t == 'float32' and dst_type == 'bfloat16': + return torch.tensor(x, device=device).bfloat16() + return torch.tensor(x, device=device) + + +def torch_dtype_name(dtype) -> str: + if isinstance(dtype, triton.language.dtype): + return dtype.name + elif isinstance(dtype, torch.dtype): + # 'torch.int64' -> 'int64' + m = re.match(r'^torch\.(\w+)$', str(dtype)) + return m.group(1) + else: + raise TypeError(f'not a triton or torch dtype: {type(dtype)}') + + +def to_numpy(x): + if isinstance(x, TensorWrapper): + return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype))) + elif isinstance(x, torch.Tensor): + if x.dtype is torch.bfloat16: + return x.cpu().float().numpy() + return x.cpu().numpy() + else: + raise ValueError(f"Not a triton-compatible tensor: {x}") + + +def patch_kernel(template, to_replace): + if is_interpreter(): + local_namespace = {} + src = textwrap.dedent(inspect.getsource(template.fn)) + for k, v in to_replace.items(): + src = src.replace(k, v) + exec(src, globals(), local_namespace) + return local_namespace[template.fn.__name__] + else: + kernel = triton.JITFunction(template.fn) + for key, value in to_replace.items(): + kernel.src = kernel.src.replace(key, value) + return kernel + + +def check_cuda_or_hip(device): + # CUDA and HIP both use pytorch device 'cuda'. Other backends like Intel + # GPU do not. + if device not in ['cuda']: + pytest.skip("Only for cuda") + + +def check_type_supported(dtype, device): + ''' + skip test if dtype is not supported on the current device + ''' + if device in ['cuda']: + cc = torch.cuda.get_device_capability() + if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): + pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") + if cc[0] < 9 and dtype in {tl.float8e4nv, "float8e4nv", "float8_e4m3fn"}: + pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90") + if is_interpreter(): + if dtype in [tl.bfloat16, "bfloat16", torch.bfloat16]: + pytest.skip("bfloat16 is not supported in the interpreter") + + +class MfmaLayout: + + def __init__(self, version, warps_per_cta, instr_shape, is_transposed): + self.version = version + self.warps_per_cta = warps_per_cta + self.instr_shape = instr_shape + self.is_transposed = is_transposed + + def __str__(self): + return f"#{GPU_DIALECT}.amd_mfma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA = {self.warps_per_cta}, instrShape={self.instr_shape}, isTransposed = {str(self.is_transposed).lower()}}}>" + + +class WmmaLayout: + + def __init__(self, warps_per_cta): + self.warps_per_cta = warps_per_cta + + def __str__(self): + return f"#{GPU_DIALECT}.amd_wmma<{{warpsPerCTA = {self.warps_per_cta}}}>" + + +class MmaLayout: + + def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape): + self.version = version + self.warps_per_cta = warps_per_cta + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + self.instr_shape = instr_shape + + def __str__(self): + return f"#{GPU_DIALECT}.nvidia_mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" + + +class BlockedLayout: + + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): + self.sz_per_thread = size_per_thread + self.threads_per_warp = threads_per_warp + self.warps_per_cta = warps_per_cta + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +class SharedLayout: + + def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order): + self.vec = vec + self.per_phase = per_phase + self.max_phase = max_phase + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +def is_layout_applicable(layout) -> bool: + common_layouts = [BlockedLayout, SharedLayout] + if layout in common_layouts: + return True + elif is_cuda(): + return isinstance(layout, MmaLayout) + elif is_hip(): + target_arch = triton.runtime.driver.active.get_current_target().arch + if "gfx11" in target_arch: + # RDNA 3 + return isinstance(layout, WmmaLayout) + elif any(arch for arch in ["gfx8", "gfx9"] if arch in target_arch): + # CDNA 1, 2, 3 + return isinstance(layout, MfmaLayout) + else: + return False + else: + return True + + +def filter_layouts(layouts): + return [l for l in layouts if is_layout_applicable(l)] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) +def test_empty_kernel(dtype_x, device): + SIZE = 128 + + @triton.jit + def kernel(X, SIZE: tl.constexpr): + pass + + check_type_supported(dtype_x, device) + x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x) + kernel[(1, )](x, SIZE=SIZE, num_warps=4) + + +# generic test functions +def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda', num_ctas=1): + check_type_supported(dtype_x, device) # early return if dtype_x is not supported + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr}) + # inputs + x = numpy_random(SIZE, dtype_str=dtype_x) + if 'log' in expr: + x = np.abs(x) + 0.01 + # reference result + z_ref = eval(expr if numpy_expr is None else numpy_expr) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_x) + kernel[(1, )](Z=z_tri, X=x_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) + # compare + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + + +def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: + """ + Given two dtype strings, returns the numpy dtype Triton thinks binary + operations on the two types should return. Returns None if the return value + matches numpy. This is generally needed because Triton and pytorch return + narrower floating point types than numpy in mixed operations, and because + Triton follows C/C++ semantics around mixed signed/unsigned operations, and + numpy/pytorch do not. + """ + overrides = { + ('float16', 'int16'): np.float16, + ('float16', 'int32'): np.float16, + ('float16', 'int64'): np.float16, + ('float16', 'uint16'): np.float16, + ('float16', 'uint32'): np.float16, + ('float16', 'uint64'): np.float16, + ('int8', 'uint8'): np.uint8, + ('int8', 'uint16'): np.uint16, + ('int8', 'uint32'): np.uint32, + ('int8', 'uint64'): np.uint64, + ('int16', 'uint16'): np.uint16, + ('int16', 'uint32'): np.uint32, + ('int16', 'uint64'): np.uint64, + ('int32', 'uint32'): np.uint32, + ('int32', 'uint64'): np.uint64, + ('int64', 'uint64'): np.uint64, + } + key = (a, b) if a < b else (b, a) + return overrides.get(key) + + +def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, + y_low=None, y_high=None, test_broadcast=True): + check_type_supported(dtype_x, device) # early return if dtype_x is not supported + check_type_supported(dtype_y, device) + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_broadcast_lhs(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X) + y = tl.load(Y + off) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + @triton.jit + def kernel_broadcast_rhs(Z, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z + off, z) + + replacements = {'GENERATE_TEST_HERE': expr} + kernel = patch_kernel(kernel, replacements) + kernel_broadcast_lhs = patch_kernel(kernel_broadcast_lhs, replacements) + kernel_broadcast_rhs = patch_kernel(kernel_broadcast_rhs, replacements) + + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high) + if mode_x == 'nan': + x[:] = float('nan') + if mode_y == 'nan': + y[:] = float('nan') + + def do_test(x, y, kernel_fn): + # reference result + z_ref = eval(expr if numpy_expr is None else numpy_expr) + dtype_z = _binary_op_dtype_override(dtype_x, dtype_y) + if dtype_z is not None: + z_ref = z_ref.astype(dtype_z) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + y_tri = to_triton(y, device=device, dst_type=dtype_y) + z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) + kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) + err_msg = f"{expr}, {kernel_fn.__name__}" + np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=1e-3, rtol=0.01) + + do_test(x, y, kernel) + if test_broadcast: + do_test(x[:1].reshape(()), y, kernel_broadcast_lhs) + do_test(x, y[:1].reshape(()), kernel_broadcast_rhs) + + +def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: + # The result of x % y is ill-conditioned if x % y is much smaller than x. + # pytorch/CUDA has slightly different (probably better) rounding on + # remainders than stock LLVM. We currently don't expect to match it + # bit-for-bit. + return (dtype_x, dtype_y) in [ + ('int32', 'bfloat16'), + ('int32', 'float16'), + ('int32', 'float32'), + ('int64', 'bfloat16'), + ('int64', 'float16'), + ('int64', 'float32'), + ('int64', 'float64'), + # ===-------------------- For Triton XPU -----------------------=== + # Triton XPU Can Pass This Dtype Combination + # ('uint16', 'bfloat16'), + # ('uint16', 'float16'), + # ('uint16', 'float32'), + # Triton XPU Cannot Pass This Dtype Combination + ('int32', 'float64'), + ('uint16', 'float64'), + ('uint32', 'float64'), + # ===-----------------------------------------------------------=== + ('uint32', 'bfloat16'), + ('uint32', 'float16'), + ('uint32', 'float32'), + ('uint64', 'bfloat16'), + ('uint64', 'float16'), + ('uint64', 'float32'), + ('uint64', 'float64'), + ] + + +def test_dtype_codegen(): + for dtype in dtypes_with_bfloat16: + full_name = f"triton.language.{dtype}" + assert repr(eval(full_name)) == full_name + + +# --------------- +# test binary ops +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['+', '-', '*', '/', '%'] + for dtype_x in dtypes_with_bfloat16 + for dtype_y in dtypes_with_bfloat16 +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f' x {op} y' + if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + numpy_expr = 'np.fmod(x, y)' + elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', + 'bfloat16'): + # Triton promotes 16-bit floating-point / and % to 32-bit because there + # are no native div or FRem operations on float16. Since we have to + # convert anyway, we may as well take the accuracy bump. + numpy_expr = f'x.astype(np.float32) {op} y.astype(np.float32)' + elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): + with pytest.raises(AssertionError, match="Not equal to tolerance"): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + elif (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or + (dtype_x in uint_dtypes and dtype_y in int_dtypes))): + with pytest.raises(triton.TritonError, match='Cannot use .* because they have different signedness'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + else: + _test_binary( + dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, + # fails with values where fmod(x, y) is roughly zero, but happens to + # pass with the random values chosen for non-broadcast tests + test_broadcast=(op != "%")) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) +def test_addptr(dtype, order, device): + check_type_supported(dtype, device) + + @triton.jit + def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): + offs = tl.arange(0, SIZE) + if ORDER == 0: + tl.store(y + offs, tl.load(x + offs)) + else: + tl.store(offs + y, tl.load(offs + x)) + + SIZE = 1024 + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + x_tri = to_triton(x, dst_type=dtype, device=device) + y_tri = to_triton(y, dst_type=dtype, device=device) + y = x + kernel[ + 1, + ](x_tri, y_tri, order, SIZE) + np.testing.assert_allclose(y, to_numpy(y_tri)) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y", [ # + (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes +] + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_floordiv(dtype_x, dtype_y, num_ctas, device): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + expr = 'x // y' + numpy_expr = '((x - np.fmod(x, y)) / y)' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +def test_unsigned_name_mangling(device): + # Test that uint32 and int32 are mangled differently by the compiler + SIZE = 128 + # define the kernel / launch-grid + + @triton.jit + def kernel(O1, O2, X, Y, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + out1 = tl.abs(x) # uint32 -> nop + out2 = tl.abs(-y) # int32 -> should have an effect + tl.store(O1 + off, out1) + tl.store(O2 + off, out2) + + dtype_x = 'uint32' + dtype_y = 'int32' + # inputs + rs = RandomState(17) + x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs) + # reference result + expect = (np.abs(x), np.abs(-y)) + # triton result + x_tri = to_triton(x, device=device, dst_type=dtype_x) + y_tri = to_triton(y, device=device, dst_type=dtype_y) + actual = tuple(to_triton(np.empty_like(e), device=device) for e in expect) + kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4) + + # Bitwise op, so expect exact equality + assert (expect[0] == to_numpy(actual[0])).all() + assert (expect[1] == to_numpy(actual[1])).all() + + +# test bitwise ops +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['&', '|', '^'] + for dtype_x in dtypes + dtypes_with_bfloat16 + for dtype_y in dtypes + dtypes_with_bfloat16 +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + if 'float' in dtype_x + dtype_y: + # The CompilationError must have been caused by a C++ exception with this text. + with pytest.raises(triton.TritonError, match='invalid operands of type'): + _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device, num_ctas=num_ctas) + else: + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # + (dtype_x, dtype_y, op) + for op in ['<<', '>>'] + for dtype_x in int_dtypes + uint_dtypes + for dtype_y in int_dtypes + uint_dtypes +]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_shift_op(dtype_x, dtype_y, op, num_ctas, device): + expr = f'x {op} y' + bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y)) + if dtype_x.startswith('int'): + dtype_z = f'int{bw}' + else: + dtype_z = f'uint{bw}' + numpy_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})' + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas, y_low=0, y_high=bw) + + +# --------------- +# test compare ops +# --------------- +ops = ['==', '!=', '>', '<', '>=', '<='] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "dtype_x, dtype_y, op, mode_x, mode_y", + # real + [(dtype_x, dtype_y, op, 'real', 'real') for op in ops for dtype_x in dtypes for dtype_y in dtypes] + # NaNs + + [('float32', 'float32', op, mode_x, mode_y) + for op in ops + for mode_x, mode_y in [('nan', 'real'), ('real', 'nan'), ('nan', 'nan')]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): + if mode_x == "nan" or mode_y == "nan": + pytest.skip("Skip for kunlunxin") + expr = f'x {op} y' + if (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)): + numpy_expr = f'x.astype(np.{dtype_x}) {op} y.astype(np.{dtype_x})' + elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)): + numpy_expr = f'x.astype(np.{dtype_y}) {op} y.astype(np.{dtype_y})' + else: + numpy_expr = None + _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device, num_ctas=num_ctas) + + +# --------------- +# test broadcast +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16) +def test_broadcast(dtype, device): + if dtype == "float64": + pytest.skip("Skip for kunlunxin") + check_type_supported(dtype, device) + + @triton.jit + def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :]) + y = tl.load(y_ptr + offset2) + _, y_broadcasted = tl.broadcast(x, y) + tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted) + + M = 32 + N = 64 + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype, rs=rs) + y = numpy_random(N, dtype_str=dtype, rs=rs) + _, y_broadcasted_np = np.broadcast_arrays(x, y) + + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype) + + broadcast_kernel[(1, )](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) + assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() + + +# ---------- +# test slice +# ---------- + + +@pytest.mark.interpreter +def test_slice(device): + + @triton.jit + def slice_kernel(XBLOCK: tl.constexpr): + data = tl.arange(0, XBLOCK) + tl.static_assert(data.shape == [XBLOCK]) + + t = data[None, :] + tl.static_assert(t.shape == [1, XBLOCK]) + + t = data[None, :, None] + tl.static_assert(t.shape == [1, XBLOCK, 1]) + + scalar = tl.full([], 1, tl.int32) + tl.static_assert(scalar.shape == []) + + t = scalar[None] + tl.static_assert(t.shape == [1]) + + t = scalar[None, None] + tl.static_assert(t.shape == [1, 1]) + + slice_kernel[(1, )](XBLOCK=32) + + +# ------------------ +# test invalid slice +# ------------------ + + +@pytest.mark.interpreter +def test_invalid_slice(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + dst[10:] + + with pytest.raises(triton.TritonError, match='unsupported tensor index'): + _kernel[(1, )](dst=dst) + + +# ---------------- +# test expand_dims +# ---------------- +@pytest.mark.interpreter +def test_expand_dims(device): + + @triton.jit + def expand_dims_kernel(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 0) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, 1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -2) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, (0, -1)) + tl.static_assert(t.shape == [1, N, 1]) + + t = tl.expand_dims(offset1, (0, 1, 3)) + tl.static_assert(t.shape == [1, 1, N, 1]) + + t = tl.expand_dims(offset1, (-4, 2, -1)) + tl.static_assert(t.shape == [1, N, 1, 1]) + + t = tl.expand_dims(offset1, (3, 1, 2)) + tl.static_assert(t.shape == [N, 1, 1, 1]) + + scalar = tl.sum(offset1) + tl.static_assert(scalar.shape == []) + t = tl.expand_dims(scalar, 0) + tl.static_assert(t.shape == [1]) + + t = tl.expand_dims(scalar, -1) + tl.static_assert(t.shape == [1]) + + # N is a scalar that's not even a tl.tensor -- this should work too. + t = tl.expand_dims(N, -1) + tl.static_assert(t.shape == [1]) + + N = 32 + dummy_tensor = torch.empty((), device=device) + expand_dims_kernel[(1, )](dummy_tensor, N) + + +@pytest.mark.interpreter +def test_expand_dims_error_cases(device): + + @triton.jit + def dim_out_of_range1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, -2) + t = tl.expand_dims(offset1, -3) + + @triton.jit + def dim_out_of_range2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 1) + t = tl.expand_dims(offset1, 2) + + @triton.jit + def dim_out_of_range3(dummy, N: tl.constexpr): + offset1 = tl.arange(0, 1) + scalar = tl.sum(offset1) + + t = tl.expand_dims(scalar, 1) + + @triton.jit + def duplicate_dim1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, 0)) + + @triton.jit + def duplicate_dim2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, -3)) + + N = 32 + dummy_tensor = torch.empty((), device=device) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range1[(1, )](dummy_tensor, N) + assert "invalid axis -3" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range2[(1, )](dummy_tensor, N) + assert "invalid axis 2" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + dim_out_of_range3[(1, )](dummy_tensor, N) + assert "invalid axis 1" in str(exc_info.value.__cause__) + + with pytest.raises(triton.TritonError) as exc_info: + duplicate_dim1[(1, )](dummy_tensor, N) + assert re.search(r"duplicate axes, normalized axes = \[0, 0\]", str(exc_info.value.__cause__)) + + with pytest.raises(triton.TritonError) as exc_info: + duplicate_dim2[(1, )](dummy_tensor, N) + assert re.search(r"duplicate axes, normalized axes = \[0, 0\]", str(exc_info.value.__cause__)) + + +# ---------------------------- +# test invalid program id axis +# ---------------------------- +@pytest.mark.interpreter +def test_invalid_pid_axis(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + pid = tl.program_id(20) + + with pytest.raises(triton.TritonError) as exc_info: + _kernel[(1, )](dst) + assert re.search(r"program_id axis must be 0, 1, or 2 but got 20", str(exc_info.value.__cause__)) + + +# --------------- +# test where +# --------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where(dtype, num_ctas, device): + if dtype == "float64": + pytest.skip("Skip for kunlunxin") + select_ptrs = False + if dtype == "*int32": + dtype = "int64" + select_ptrs = True + check_type_supported(dtype, device) + + @triton.jit + def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + TEST_POINTERS: tl.constexpr, TEST_SCALAR_POINTERS: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + decide = tl.load(cond_ptr + offsets, mask=mask) + if TEST_SCALAR_POINTERS: + ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr) + output = tl.load(ptr + offsets, mask=mask) + else: + if TEST_POINTERS: + a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t) + b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t) + else: + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + output = tl.where(decide, a, b) + tl.store(output_ptr + offsets, output, mask=mask) + + SIZE = 1_000 + rs = RandomState(17) + cond = numpy_random(SIZE, 'bool', rs) + x = numpy_random(SIZE, dtype_str=dtype, rs=rs) + y = numpy_random(SIZE, dtype_str=dtype, rs=rs) + z = np.where(cond, x, y) + + cond_tri = to_triton(cond, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype) + y_tri = to_triton(y, device=device, dst_type=dtype) + z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device=device, dst_type=dtype) + + grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']), ) + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=False, num_ctas=num_ctas) + assert (z == to_numpy(z_tri)).all() + if select_ptrs: + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=True) + z = np.where(cond[0], x, y) + assert (z == to_numpy(z_tri)).all() + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_where_broadcast(num_ctas, device): + + @triton.jit + def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + + mask = tl.load(cond_ptr + yoffsets) + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + + @triton.jit + def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] + yoffsets = tl.arange(0, BLOCK_SIZE)[None, :] + mask = 0 + vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets) + res = tl.where(mask, vals, 0.) + tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res) + + SIZE = 32 + dtype = 'float32' + rs = RandomState(17) + x = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rs) + mask = numpy_random(SIZE, 'bool', rs=rs) + z = np.where(mask, x, 0) + cond_tri = to_triton(mask, device=device) + x_tri = to_triton(x, device=device, dst_type=dtype) + z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device=device, dst_type=dtype) + where_kernel[(1, )](cond_tri, x_tri, z_tri, SIZE) + assert (z == to_numpy(z_tri)).all() + where_scalar_condition[(1, )](x_tri, z_tri, SIZE, num_ctas=num_ctas) + z = np.where(0, x, 0) + assert (z == to_numpy(z_tri)).all() + + +# --------------- +# test unary ops +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, expr", + [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x') + for dtype_x in int_dtypes]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_unary_op(dtype_x, expr, num_ctas, device): + _test_unary(dtype_x, expr, device=device, num_ctas=num_ctas) + + +# ---------------- +# test math ops +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, expr, x", + [(dtype_x, expr, x) + for dtype_x in ["float32", "float64"] + for expr in ['exp', 'log', 'cos', 'sin', 'exp2', 'log2', 'sqrt', 'floor', 'ceil'] + for x in ['x', '3.0']]) +def test_math_op(dtype_x, expr, x, device): + if expr in ['exp2', 'log2']: + pytest.skip("Skip for kunlunxin") + _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) +def test_math_erf_op(dtype, device): + check_type_supported(dtype, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = tl.math.erf(x) + tl.store(Z + off, z) + + torch_dtype = torch.float32 if dtype == "float32" else torch.float64 + x = torch.randn(SIZE, dtype=torch_dtype, device=device) + z_ref = torch.erf(x) + z_tri = torch.zeros_like(x) + kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_tri, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) +def test_math_fma_op(dtype, device): + check_type_supported(dtype, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, Y, W, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + y = tl.load(Y + off) + w = tl.load(W + off) + z = tl.math.fma(x, y, w) + tl.store(Z + off, z) + + torch_dtype = torch.float32 if dtype == "float32" else torch.float64 + x = torch.randn(SIZE, dtype=torch_dtype, device=device) + y = torch.randn(SIZE, dtype=torch_dtype, device=device) + w = torch.randn(SIZE, dtype=torch_dtype, device=device) + z_ref = x * y + w + z_tri = torch.zeros_like(x) + kernel[(1, )](z_tri, x, y, w, SIZE=SIZE, num_warps=4) + torch.testing.assert_close(z_tri, z_ref) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("expr", ["tl.math.fdiv(x, y)", "tl.math.div_rn(x, y)"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_math_divide_op(expr, num_ctas, device): + numpy_expr = "x / y" + dtype = "float32" + _test_binary(dtype, dtype, expr, numpy_expr, device=device, num_ctas=num_ctas) + + +# ------------- +# test precise math +# ------------- +@pytest.mark.interpreter +@pytest.mark.parametrize("expr_prec, expr_ref", + [('tl.math.sqrt_rn(x)', 'tl.math.sqrt(x.to(tl.float64)).to(tl.float32)'), + ('tl.math.div_rn(x,y)', '(x.to(tl.float64) / y.to(tl.float64)).to(tl.float32)')]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_precise_math(expr_prec, expr_ref, num_ctas, device): + + @triton.jit + def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + prec = PREC_CALC + ref = REF_CALC + tl.store(OUT + tl.arange(0, BLOCK), prec) + tl.store(OUT_REF + tl.arange(0, BLOCK), ref) + + shape = (128, ) + out = torch.zeros(shape, dtype=torch.float32, device=device) + out_ref = torch.zeros(shape, dtype=torch.float32, device=device) + + x = torch.randn(shape, dtype=torch.float32, device=device) + y = torch.randn(shape, dtype=torch.float32, device=device) + + if (expr_prec.count('sqrt') > 0): + x = torch.abs(x) + + if (expr_prec.count('div') > 0): + y += 1e-6 + + kernel = patch_kernel(kernel, {'PREC_CALC': expr_prec, 'REF_CALC': expr_ref}) + + kernel[(1, )](x, y, out, out_ref, BLOCK=shape[0], num_ctas=num_ctas) + assert torch.all(out == out_ref) # bitwise exact + + +# ---------------- +# test abs +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) +def test_abs(dtype_x, device): + _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4nv, tl.float8e5]) +def test_abs_fp8(in_dtype, device): + if is_hip(): + pytest.skip('test_abs_fp8 not supported on HIP.') + + @triton.jit + def abs_kernel(X, Z, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + x = tl.load(X + off) + z = tl.abs(x) + tl.store(Z + off, z) + + f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device=device) + # f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan + all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width + f8_tensor[all_exp_ones] = 0 + f8 = triton.reinterpret(f8_tensor, in_dtype) + n_elements = f8_tensor.numel() + out_f8 = torch.empty_like(f8_tensor) + abs_kernel[(1, )](f8, triton.reinterpret(out_f8, in_dtype), n_elements) + + f32_tensor = convert_float_to_float32(f8_tensor, in_dtype) + expect = f32_tensor.abs() + actual_f8 = convert_float_to_float32(out_f8, in_dtype) + torch.testing.assert_close(actual_f8, expect, equal_nan=True) + + +# ---------------- +# test passing shapes as individual params rather than tuples +# ---------------- + + +@pytest.mark.interpreter +def test_shapes_as_params(device): + + @triton.jit + def kernel(): + a = tl.arange(0, 32).expand_dims(-1).broadcast_to(32, 32) + tl.static_assert(a.shape == [tl.constexpr(32), tl.constexpr(32)]) + + a = tl.arange(0, 32).reshape(4, 8).permute(1, 0) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4)]) + + a = tl.arange(0, 32).reshape(4, 8).reshape(32) + tl.static_assert(a.shape == [tl.constexpr(32)]) + + a = tl.arange(0, 64).reshape(2, 4, 8).trans(2, 1, 0) + tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)]) + + a = tl.arange(0, 64).view(2, 4, 8) + tl.static_assert(a.shape == [tl.constexpr(2), tl.constexpr(4), tl.constexpr(8)]) + + kernel[(1, )]() + + +# ---------------- +# test transpose +# ---------------- + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) +def test_transpose(dtype_x, device): + check_type_supported(dtype_x, device) + SIZE = 128 + + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + off = tl.arange(0, SIZE) + off2d = off[None, :] + (tl.arange(0, 2) * SIZE)[:, None] + x = tl.load(X + off2d) + z = x.T + tl.store(Z + off2d.T, z) + + x = numpy_random([SIZE, 2], dtype_str=dtype_x) + z_ref = x.T + x_tri = to_triton(x, device=device, dst_type=dtype_x) + z_tri = to_triton(np.empty_like(z_ref), device=device, dst_type=dtype_x) + kernel[(1, )](z_tri, x_tri, SIZE=SIZE) + np.testing.assert_allclose(z_ref, to_numpy(z_tri)) + + +# ---------------- +# test indexing +# ---------------- + + +def make_ptr_str(name, shape): + rank = len(shape) + offsets = [] + stride = 1 + for i in reversed(range(rank)): + idx = ', '.join([':' if ii == i else 'None' for ii in range(rank)]) + offsets += [f'tl.arange(0, {shape[i]})[{idx}]*{stride}'] + stride *= shape[i] + return f"{name} + {' + '.join(offsets)}" + + +# TODO: handle `%4 = triton_gpu.convert_layout %3 : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`` +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("expr, dtype_str", [(f'x[{s}]', d) + for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] + for d in ['int32', 'uint32', 'uint16']]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_index1d(expr, dtype_str, num_ctas, device): + rank_x = expr.count(':') + rank_y = expr.count(',') + 1 + shape_x = [32 for _ in range(rank_x)] + shape_z = [32 for _ in range(rank_y)] + shape_z_rank_mismatch = [32 for _ in range(rank_y + 1)] + shape_z_dim_mismatch = [64 for _ in range(rank_y)] + + # Triton kernel + @triton.jit + def kernel(Z, X, SIZE: tl.constexpr): + m = tl.arange(0, SIZE) + n = tl.arange(0, SIZE) + x = tl.load(X_PTR_EXPR) + z = GENERATE_TEST_HERE + tl.store(Z_PTR_EXPR, z) + + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + 'GENERATE_TEST_HERE': expr, + } + return patch_kernel(kernel, to_replace) + + kernel_match = generate_kernel(shape_x, shape_z) + kernel_dim_mismatch = generate_kernel(shape_x, shape_z_dim_mismatch) + kernel_rank_mismatch = generate_kernel(shape_x, shape_z_rank_mismatch) + + # torch result + x = numpy_random(shape_x, dtype_str=dtype_str) + y = np.zeros(shape_z, dtype=getattr(np, dtype_str)) + z_ref = eval(expr) + y + # triton result + z_tri = to_triton(np.empty_like(z_ref), device=device) + x_tri = to_triton(x, device=device) + kernel_match[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0]) + # compare + assert (z_ref == to_numpy(z_tri)).all() + + def catch_compilation_error(kernel): + try: + kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0], num_ctas=num_ctas) + except triton.CompilationError as e: + np.testing.assert_(True) + except BaseException: + np.testing.assert_(False) + + catch_compilation_error(kernel_dim_mismatch) + catch_compilation_error(kernel_rank_mismatch) + + +# --------------- +# test tuples +# --------------- + + +@triton.jit +def tuples_fn(a, b): + return a + b, \ + a - b, \ + a * b + + +@pytest.mark.interpreter +def test_tuples(device): + + @triton.jit + def with_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = tuples_fn(x, y) + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + @triton.jit + def without_fn(X, Y, A, B, C): + x = tl.load(X) + y = tl.load(Y) + a, b, c = x + y, x - y, x * y + tl.store(A, a) + tl.store(B, b) + tl.store(C, c) + + x = torch.tensor([1.3], device=device, dtype=torch.float32) + y = torch.tensor([1.9], device=device, dtype=torch.float32) + a_tri = torch.tensor([0], device=device, dtype=torch.float32) + b_tri = torch.tensor([0], device=device, dtype=torch.float32) + c_tri = torch.tensor([0], device=device, dtype=torch.float32) + for kernel in [with_fn, without_fn]: + kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1) + a_ref, b_ref, c_ref = x + y, x - y, x * y + assert a_tri == a_ref + assert b_tri == b_ref + assert c_tri == c_ref + + +@triton.jit(noinline=True) +def noinline_simple_fn(x, y, Z): + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_graph_fn1(x): + return x + 1 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn2(y): + return y + 2 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn(x, y, Z): + t0 = noinline_call_graph_fn1(x) + t1 = noinline_call_graph_fn2(y) + z = t0 + t1 + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_shared_fn(x, y, Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + x + y + tl.store(Z + offs, z) + + +@triton.jit(noinline=True) +def noinline_dynamic_fn(x, y, Z): + if x >= 1: + x = noinline_call_graph_fn1(x) + else: + x = noinline_call_graph_fn2(x) + if y >= 2: + y = noinline_call_graph_fn2(y) + else: + y = noinline_call_graph_fn1(y) + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_multi_values_fn(x, y): + return x + 1, y + 2 + + +@triton.jit(noinline=True) +def noinline_multi_values_fn(x, y, Z): + x, y = noinline_call_multi_values_fn(x, y) + z = x + y + tl.store(Z, z) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) +def test_noinline(mode, device): + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + GENERATE_TEST_HERE(x, y, Z) + + func_name = f'noinline_{mode}_fn' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': func_name}) + x = torch.tensor([1.0], device=device, dtype=torch.float32) + y = torch.tensor([2.0], device=device, dtype=torch.float32) + if mode == "shared": + z = torch.ones((16, 16), device=device, dtype=torch.float32) + else: + z = torch.tensor([0.0], device=device, dtype=torch.float32) + kernel[(1, )](x, y, z, num_warps=1) + if mode == "simple": + assert torch.equal(z, x + y) + elif mode == "call_graph" or mode == "dynamic" or mode == "multi_values": + assert torch.equal(z, x + 1 + y + 2) + elif mode == "shared": + ref = torch.full((16, 16), 16, device=device, dtype=torch.float32) + assert torch.equal(z, ref + x + y) + + +# --------------- +# test atomics +# --------------- +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize( + "op, dtype_x_str, mode, sem", + itertools.chain.from_iterable([[ + ('add', 'float16', mode, sem), + ('add', 'uint32', mode, sem), + ('add', 'int32', mode, sem), + ('add', 'float32', mode, sem), + ('add', 'uint64', mode, sem), + ('add', 'int64', mode, sem), + ('add', 'float64', mode, sem), + ('max', 'uint32', mode, sem), + ('max', 'int32', mode, sem), + ('max', 'float32', mode, sem), + ('max', 'uint64', mode, sem), + ('max', 'int64', mode, sem), + ('max', 'float64', mode, sem), + ('min', 'uint32', mode, sem), + ('min', 'int32', mode, sem), + ('min', 'float32', mode, sem), + ('min', 'uint64', mode, sem), + ('min', 'int64', mode, sem), + ('min', 'float64', mode, sem), + ] + for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos'] + for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']])) +def test_atomic_rmw(op, dtype_x_str, mode, sem, device): + if is_interpreter(): + if dtype_x_str == 'float16': + pytest.skip("Only test atomic float16 ops on GPU") + + n_programs = 5 + + # triton kernel + @triton.jit + def kernel(X, Z): + pid = tl.program_id(0) + x = tl.load(X + pid) + old = GENERATE_TEST_HERE + tl.static_assert(old.dtype == x.dtype) + + sem_arg = sem if sem is None else f'"{sem}"' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'}) + numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op] + max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min + min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max + neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op] + + # triton result + rs = RandomState(17) + x = np.array([2**i for i in range(n_programs)], dtype=getattr(np, dtype_x_str)) + if mode == 'all_neg': + x = -np.abs(x) + if mode == 'all_pos': + x = np.abs(x) + if mode == 'min_neg': + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = -np.max(np.abs(x)) - 1 + if mode == 'max_pos': + idx = rs.randint(n_programs, size=(1, )).item() + x[idx] = np.max(np.abs(x)) + 1 + x_tri = to_triton(x, device=device) + + z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device) + h = kernel[(n_programs, )](x_tri, z_tri) + # torch result + z_ref = numpy_op(x).astype(getattr(np, dtype_x_str)) + # compare + exact = op not in ['add'] + if exact: + assert z_ref.item() == to_numpy(z_tri).item() + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01) + sem_str = "acq_rel" if sem is None else sem + if not is_cuda(): + return + + assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"] + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_rmw_predicate(num_ctas, device): + + @triton.jit + def kernel(X): + val = tl.program_id(0) + if val < 64: + tl.atomic_max(X, val) + + x = torch.zeros((1, ), device=device, dtype=torch.int32) + kernel[(4096, )](x, num_ctas=num_ctas) + assert x.item() == 63 + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("shape, axis, num_ctas", [(shape, axis, num_ctas) + for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] + for axis in [0, 1] + for num_ctas in num_ctas_list]) +def test_tensor_atomic_rmw(shape, axis, num_ctas, device): + shape0, shape1 = shape + # triton kernel + + @triton.jit + def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :]) + z = tl.sum(x, axis=AXIS) + if AXIS == 1: + tl.atomic_add(Z + off0, z) + else: + tl.atomic_add(Z + off1, z) + + rs = RandomState(17) + x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs) + # reference result + z_ref = np.sum(x, axis=axis, keepdims=False) + # triton result + x_tri = to_triton(x, device=device) + z_shape = (shape0, ) if axis == 1 else (shape1, ) + z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device) + kernel[(1, )](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_rmw_block(num_ctas, device): + shape = (8, 8) + + @triton.jit + def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + off0 = tl.arange(0, SHAPE0) + off1 = tl.arange(0, SHAPE1) + offs = off0[:, None] * SHAPE1 + off1[None, :] + val = offs.to(tl.float32) + x = X + offs + tl.atomic_min(x, val) + + x = torch.ones((8, 8), device=device, dtype=torch.float32) + kernel[(2, )](x, shape[0], shape[1], num_ctas=num_ctas) + assert torch.min(x).item() == 0.0 + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_atomic_cas(sem, num_ctas, device): + # 1. make sure that atomic_cas changes the original value (Lock) + @triton.jit + def change_value(Lock): + tl.atomic_cas(Lock, 0, 1) + + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + change_value[(1, )](Lock) + + assert (Lock[0] == 1) + + # 2. only one block enters the critical section + @triton.jit + def serialized_add(data, Lock, SEM: tl.constexpr): + ptrs = data + tl.arange(0, 128) + while tl.atomic_cas(Lock, 0, 1, SEM) == 1: + pass + + tl.store(ptrs, tl.load(ptrs) + 1.0) + + # release lock + tl.atomic_xchg(Lock, 0) + + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + data = torch.zeros((128, ), device=device, dtype=torch.float32) + ref = torch.full((128, ), 2000.0) + h = serialized_add[(2000, )](data, Lock, SEM=sem, num_ctas=num_ctas) + sem_str = "acq_rel" if sem is None else sem + np.testing.assert_allclose(to_numpy(data), to_numpy(ref)) + if not is_cuda(): + return + assert f"atom.global.{sem_str}" in h.asm["ptx"] + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_tensor_atomic_cas(sem, num_ctas, device): + + @triton.jit + def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + t1 = tl.full((BLOCK_SIZE, ), 0, dtype=tl.int64) + t2 = tl.full((BLOCK_SIZE, ), 2, dtype=tl.int64) + tl.atomic_cas(X + offsets, t1, t2, sem=sem) + + X = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], device=device, dtype=torch.int64) + Y = torch.tensor([2, 1, 2, 1, 2, 1, 2, 1], device=device, dtype=torch.int64) + + change_value[(2, )](X, 4, sem) + assert (torch.equal(X, Y)) + + +# --------------- +# test cast +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", + [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ + ('float32', 'bfloat16', False, 1024), + ('bfloat16', 'float32', False, 1024), + ('float32', 'int32', True, 1024), + ('float32', 'int1', False, 1024), + ('int8', 'bfloat16', False, 1024), + ] + [(f'uint{x}', f'int{x}', True, 1024) + for x in [8, 16, 32, 64]] + [(f'int{x}', f'uint{x}', True, 1024) + for x in [8, 16, 32, 64]] + + (([(dtype_x, dtype_z, False, size) + for dtype_x in torch_float8_dtypes + for dtype_z in ["float16", "float32", "bfloat16"] + for size in [1024, 32]] # + + [(dtype_x, dtype_z, False, size) + for dtype_z in torch_float8_dtypes + for dtype_x in ["float16", "float32", "bfloat16"] + for size in [1024, 32]]) if torch.__version__ >= "2.1" else [])) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): + if (dtype_x == 'float64' or dtype_x == 'int8' and dtype_z == 'bfloat16'): + pytest.skip("Skip for kunlunxin") + # CUDA: bfloat16 on cc < 80 will not be tested + # Interpreter: Only bfloat16 <-> float32 is supported + if not is_interpreter() or \ + (is_interpreter() and not ((dtype_z == 'bfloat16' and dtype_x == 'float32') + or (dtype_z == 'float32' and dtype_x == 'bfloat16'))): + check_type_supported(dtype_x, device) + check_type_supported(dtype_z, device) + + if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"): + pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.') + + torch.manual_seed(0) + # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. + if dtype_x.startswith('bfloat'): + x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device) + elif dtype_x.startswith('float8'): + x_tri = torch.randn(size, dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_x)) + else: + x = numpy_random(size, dtype_str=dtype_x, low=-10, high=10) * 10 + # Triton clamps negative values to zero, while numpy wraps around + # intmax, so avoid negatives for now. + # TODO: figure out which one should actually be happening, and test it + if dtype_z in uint_dtypes: + x = np.absolute(x) + x_tri = to_triton(x, device=device) + if 'float' in dtype_z and 'float' in dtype_x: + # make sure we use values that can be represented in both types + x_tri = x_tri.to(getattr(torch, dtype_z)).to(getattr(torch, dtype_x)) + # triton kernel + + @triton.jit + def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr): + x_ptr = X + tl.arange(0, SIZE) + z_ptr = Z + tl.arange(0, SIZE) + x = tl.load(x_ptr) + + # Depending on the value of ARG_HASH (a "random" number determined by + # the test parameters), spell the cast one of three different ways. + if ARG_HASH % 3 == 0: + z = x.to(Z.dtype.element_ty, bitcast=BITCAST) + elif ARG_HASH % 3 == 1: + z = x.cast(Z.dtype.element_ty, bitcast=BITCAST) + else: + z = tl.cast(x, Z.dtype.element_ty, bitcast=BITCAST) + + tl.store(z_ptr, z) + + # "Random" number used inside the kernel to determine how we spell the cast. + # This way we don't have to increase the number of tests. + arg_hash = hash((dtype_x, dtype_z, bitcast, size, num_ctas)) + + dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_' + # triton result + if dtype_z.startswith('bfloat'): + z_tri = torch.empty((size, ), dtype=getattr(torch, dtype_z), device=device) + elif dtype_z.startswith('float8'): + z_tri = torch.empty((size, ), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z)) + else: + z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) + kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1, num_ctas=num_ctas) + # torch result + if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith( + 'float8') or dtype_x.startswith('float8'): + assert bitcast is False + z_ref = x_tri.to(z_tri.dtype) + if dtype_z.startswith('float8') and device not in ['cuda']: + t = z_ref.byte() ^ z_tri.byte() + torch.testing.assert_close(torch.zeros_like(t, dtype=torch.uint8), t) + else: + torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0) + else: + if bitcast: + z_ref = x.view(getattr(np, dtype_z_np)) + else: + z_ref = x.astype(getattr(np, dtype_z_np)) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0, atol=0) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, num_warps", + [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]]) +def test_cat(dtype_str, num_warps, device): + check_type_supported(dtype_str, device) + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.cat(x, y, can_reorder=True) + tl.store(Z + tl.arange(0, 2 * N), z) + + x = torch.arange(0, 128, device=device).to(getattr(torch, dtype_str)) + y = torch.arange(-128, 0, device=device).to(getattr(torch, dtype_str)) + z_ref = torch.cat([x, y], dim=0).sum() + z = torch.zeros((256, ), dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](x, y, z, N=128, num_warps=num_warps) + assert z.sum() == z_ref + # check if there's no duplicate value in z + assert z.unique().size(0) == z.size(0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", [dtype for dtype in torch_dtypes if dtype != 'int16']) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_store_constant(dtype_str, num_ctas, device): + check_type_supported(dtype_str, device) + """Tests that boolean True is stored as 1""" + + @triton.jit + def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + output = GENERATE_TEST_HERE + tl.store(output_ptr + offsets, output, mask=mask) + + triton_dtype_str = 'uint8' if dtype_str == 'bool' else dtype_str + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.zeros([BLOCK_SIZE], dtype=tl.{triton_dtype_str}) + 1'}) + block_size = 128 + ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device) + output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) + + assert torch.all(output == ref) + + +@pytest.mark.skip("Skip for kunlunxin") +def test_load_store_same_ptr(device): + + @triton.jit() + def kernel(in_out_ptr): + pid = tl.program_id(axis=0) + x = tl.load(in_out_ptr + pid) + out = x * 2 + tl.store(in_out_ptr + pid, out) + + for _ in range(1000): + x = torch.ones((65536, ), device=device, dtype=torch.float32) + if is_hip(): + kernel[(65536, )](x, num_warps=16) # threads per Warp for ROCM is 64 + else: + kernel[(65536, )](x, num_warps=32) + assert torch.all(x == 2) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ['int32']) +def test_umulhi(dtype_str, device): + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.umulhi(x, y) + tl.store(Z + tl.arange(0, N), z) + + def umulhi32(a, b): + # Convert to 64-bit unsigned integers to prevent overflow + a_64 = a.astype(np.int64) + b_64 = b.astype(np.int64) + + # Perform the multiplication in 64-bit + product_64 = a_64 * b_64 + + # Shift right by 32 bits to get the high part of the product + result_high_32 = product_64 >> 32 + return result_high_32 + + rs = RandomState(17) + N = 128 + x = numpy_random((N, ), dtype_str=dtype_str, rs=rs, low=0) + x_tri = to_triton(x, device=device) + y = numpy_random((N, ), dtype_str=dtype_str, rs=rs, low=0) + y_tri = to_triton(y, device=device) + z_tri = torch.zeros_like(x_tri) + kernel[(1, )](x_tri, y_tri, z_tri, N=N) + + z_ref = umulhi32(x, y) + np.testing.assert_equal(z_ref, to_numpy(z_tri)) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +def test_join(device): + + @triton.jit + def kernel(X, Y, Z, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + y = tl.load(Y + offs) + z = tl.join(x, y) + tl.store(Z + tl.arange(0, N)[:, None] * 2 + tl.arange(0, 2)[None, :], z) + + x = torch.arange(0, 128, device=device).to(torch.int32) + y = torch.arange(-128, 0, device=device).to(torch.int32) + z_ref = torch.stack([x, y], dim=-1) + z = torch.zeros_like(z_ref) + kernel[(1, )](x, y, z, N=128) + + np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +def test_join_scalars(device): + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + z = tl.join(x, y) + tl.static_assert(z.shape == [2]) + tl.store(Z + tl.arange(0, 2), z) + + x = torch.full([1], 42, device=device).to(torch.int32) + y = torch.full([1], 100, device=device).to(torch.int32) + z = torch.zeros([2], device=device) + kernel[(1, )](x, y, z) + + np.testing.assert_equal([42, 100], to_numpy(z)) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +def test_join_with_mma(device): + + @triton.jit + def kernel(X, Z): + x = tl.load(X + 16 * tl.arange(0, 32)[:, None] + tl.arange(0, 16)[None, :]) # (32,16) + x2 = tl.join(x, 2 * x) # (32,16,2) + x3 = tl.reshape(x2, (32, 32)) + z = tl.dot(x3, x3) # (32,32) + tl.store(Z + 32 * tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :], z) + + x = torch.arange(0, 32 * 16, device=device, dtype=torch.float32).reshape((32, 16)) + r = torch.stack([x, 2 * x], dim=-1).reshape((32, 32)) + z_ref = torch.matmul(r, r) + z = torch.zeros_like(z_ref) + kernel[(1, )](x, z) + + torch.testing.assert_close(z, z_ref) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("debug", [False, True]) +def test_interleave(device, debug): + + @triton.jit(debug=debug) + def kernel(Z, N: tl.constexpr): + z = tl.interleave(tl.arange(0, N), tl.arange(N, 2 * N)) + tl.store(Z + tl.arange(0, 2 * N), z) + + x = torch.arange(0, 128, device=device).to(torch.int32) + y = torch.arange(128, 256, device=device).to(torch.int32) + z_ref = torch.stack([x, y], dim=-1).reshape(256) + z = torch.zeros_like(z_ref) + kernel[(1, )](z, N=128) + + np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +def test_interleave_scalars(device): + + @triton.jit + def kernel(X, Y, Z): + z = tl.interleave(X, Y) + tl.static_assert(z.shape == [tl.constexpr(2)]) + tl.store(Z + tl.arange(0, 2), z) + + z = torch.zeros(2, device=device) + kernel[(1, )](10, 20, z) + + np.testing.assert_equal([10, 20], to_numpy(z)) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +def test_split(device): + + @triton.jit + def kernel(X, Z1, Z2, N: tl.constexpr): + offs = tl.arange(0, N) + x = tl.load(X + offs) + x1 = tl.reshape(x, (N // 2, 2)) + z1, z2 = tl.split(x1) + tl.store(Z1 + tl.arange(0, N // 2), z1) + tl.store(Z2 + tl.arange(0, N // 2), z2) + + x = torch.arange(0, 256, device=device).to(torch.int32).reshape((128, 2)) + z1_ref, z2_ref = (x[:, 0], x[:, 1]) + z1 = torch.zeros_like(z1_ref) + z2 = torch.zeros_like(z2_ref) + kernel[(1, )](x, z1, z2, N=256) + + np.testing.assert_equal(to_numpy(z1_ref), to_numpy(z1)) + np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +def test_split_to_scalar(device): + + @triton.jit + def kernel(X, Z1, Z2): + offs = tl.arange(0, 2) + x = tl.load(X + offs) + z1, z2 = tl.split(x) + tl.static_assert(isinstance(z1, tl.tensor)) + tl.static_assert(isinstance(z2, tl.tensor)) + tl.static_assert(z1.shape == []) + tl.static_assert(z2.shape == []) + tl.store(Z1, z1) + tl.store(Z2, z2) + + N = 2 + x = torch.arange(0, N, device=device).reshape(N // 2, 2) + z1_ref, z2_ref = (x[:, 0], x[:, 1]) + z1 = torch.zeros_like(z1_ref) + z2 = torch.zeros_like(z2_ref) + kernel[(1, )](x, z1, z2) + + np.testing.assert_equal(to_numpy(z1_ref), to_numpy(z1)) + np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) + + +def convert_float_to_float32(fp: torch.tensor, dtype=None): + if not dtype: + dtype = getattr(tl, torch_dtype_name(fp.dtype)) + + fp = fp.view(getattr(torch, f"int{dtype.primitive_bitwidth}")) + exp_width = dtype.primitive_bitwidth - dtype.fp_mantissa_width - 1 + exp_bias = dtype.exponent_bias + sign = ((fp >> (dtype.primitive_bitwidth - 1)) & 0x01).int() + exp = ((fp >> dtype.fp_mantissa_width) & ((1 << exp_width) - 1)).int() + frac = (fp & ((1 << dtype.fp_mantissa_width) - 1)).int() + + output = torch.where( + exp == 0, + # subnormal + ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (frac / (2.0**dtype.fp_mantissa_width)), + # normal + ((-1.0)**sign) * (2.0**(exp - exp_bias)) * (1.0 + frac / (2.0**dtype.fp_mantissa_width))).float() + + extended_exp = ( + (1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width + # special cases, exp is 0b11..1 + if dtype in [tl.float8e4nv, tl.float8e4b15]: + # float8e4m3nv does not have infinities + output[fp == 0b01111111] = torch.nan + output[fp == 0b11111111] = torch.nan + else: + output = torch.where(exp == (1 << exp_width) - 1, + ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp + | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))) # + .view(torch.float32), output) + return output + + +@pytest.mark.interpreter +@pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16]) +def test_convert_float16_to_float32(in_dtype, device): + """Tests that check convert_float_to_float32 function""" + check_type_supported(in_dtype, device) + + f16_input = torch.tensor(range(-int(2**(16 - 1)), int(2**(16 - 1))), dtype=torch.int16).view(in_dtype) + f32_output = convert_float_to_float32(f16_input) + + nan = f16_input.isnan() + assert torch.all(f32_output[nan].isnan()) + inf = f16_input.isinf() + assert torch.all(f32_output[inf].isinf()) + other = torch.logical_not(torch.logical_or(nan, inf)) + assert torch.all(f16_input[other] == f32_output[other]) + + +def serialize_fp8(np_data, in_dtype): + return np_data + + +# inverse of `serialize_fp8` + + +def deserialize_fp8(np_data, in_dtype): + return np_data + + +# --------------- +# test reduce +# --------------- + + +@pytest.mark.interpreter +def test_max_returns_zero(device): + # Simple test with a tl.max call that returns 0. The interpreter had a bug + # where it didn't handle this correctly. + @triton.jit + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + z = tl.max(x) + tl.store(Z, z) + + BLOCK = 128 + x = torch.zeros((BLOCK, ), device=device) + z = torch.ones((1, ), device=device) + + kernel[(1, )](x, z, BLOCK=BLOCK) + assert z[0] == 0 + + +def get_reduced_dtype(dtype_str, op): + if op in ('argmin', 'argmax'): + return 'int32' + if dtype_str == 'bfloat16': + return 'float32' + return dtype_str + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [ + 'min', + 'max', + 'min-with-indices', + 'max-with-indices', + 'argmin-tie-break-left', + 'argmax-tie-break-left', + 'sum', +] for dtype in dtypes_with_bfloat16 for shape in [32, 64, 128, 512]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce1d(op, dtype_str, shape, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + + # triton kernel + @triton.jit + def kernel(X, Z, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + GENERATE_TEST_HERE + tl.store(Z, z) + + if 'with-indices' in op: + patch = f'z, _ = tl.{op.split("-")[0]}(x, axis=0, return_indices=True)' + elif 'arg' in op: + tie_break_left = 'tie-break-left' in op + patch = f'z = tl.{op.split("-")[0]}(x, axis=0, tie_break_left={tie_break_left})' + else: + patch = f'z = tl.{op}(x, axis=0)' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': patch}) + # input + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random((shape, ), dtype_str=dtype_str, rs=rs) + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + 'max-with-indices': np.max, + 'min-with-indices': np.min, + 'argmin-tie-break-fast': np.argmin, + 'argmin-tie-break-left': np.argmin, + 'argmax-tie-break-fast': np.argmax, + 'argmax-tie-break-left': np.argmax, + }[op] + if 'tie-break-left' in op: + x[3:10] = numpy_op(x) + x_tri = to_triton(x, device=device) + # numpy result + z_dtype_str = 'int32' if op in ('argmin', 'argmax') else dtype_str + z_tri_dtype_str = z_dtype_str + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + z_tri_dtype_str = 'bfloat16' + else: + z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) + # triton result + z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas) + z_tri = to_numpy(z_tri) + # compare + if op == 'sum': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + if op in ('argmin', 'argmax'): + # argmin and argmax can have multiple valid indices. + # so instead we compare the values pointed by indices + np.testing.assert_equal(x[z_ref], x[z_tri]) + else: + np.testing.assert_equal(z_ref, z_tri) + + +# TODO: [Qingyi] Fix argmin / argmax +reduce_configs1 = [(op, dtype, (1, 1024), axis, False) + for dtype in dtypes_with_bfloat16 + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [1]] + +# shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory +# exceeds the limit of 99KB +reduce2d_shapes = [(2, 32), (4, 32), (4, 128)] +# TODO: fix and uncomment +# , (32, 64), (64, 128)] +if is_cuda() and 'V100' in torch.cuda.get_device_name(0): + reduce2d_shapes += [(128, 256) and (32, 1024)] + +reduce_configs2 = [(op, 'float32', shape, axis, False) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce2d_shapes + for axis in [0, 1]] + [(op, 'float32', [16, 32], None, False) for op in ['min', 'max', 'sum']] + +reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)] +reduce_configs3 = [(op, 'float32', shape, axis, False) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce3d_shapes + for axis in [0, 1, 2]] +invalid_config = [('sum', 'float32', (32, 32), axis, False) for axis in [2, 3]] +negative_config = [('sum', 'float32', (32, 32), -1, False)] +keep_dims_2d_configs = [(op, 'float32', (32, 32), axis, True) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [0, 1]] + [(op, 'float32', (32, 32), None, True) for op in ['min', 'max', 'sum']] +keep_dims_3d_configs = [(op, 'float32', (32, 2, 16), axis, True) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [0, 1, 2]] + [(op, 'float32', (32, 2, 16), None, True) + for op in ['min', 'max', 'sum']] + + +@pytest.mark.interpreter +@pytest.mark.parametrize( + "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + + negative_config + keep_dims_2d_configs + keep_dims_3d_configs) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): + if (dtype_str == 'float64' and op in ['min', 'max']): + pytest.skip("Skip for kunlunxin") + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + + @triton.jit + def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, + AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + range_k = tl.arange(0, BLOCK_K) + if IS_3D: + x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + + range_k[None, None, :]) + else: + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + z = GENERATE_TEST_HERE + + z_ptr = Z + if KEEP_DIMS and AXIS is None: + if IS_3D: + z_ptr = z_ptr[None, None, None, :] + else: + z_ptr = z_ptr[None, None, :] + if IS_3D: + if AXIS == 0: + z_ptr = Z + range_n[:, None] * BLOCK_K + range_k[None, :] + elif AXIS == 1 or AXIS == -2: + z_ptr = Z + range_m[:, None] * BLOCK_K + range_k[None, :] + elif AXIS == 2 or AXIS == -1: + z_ptr = Z + range_m[:, None] * BLOCK_N + range_n[None, :] + else: + if AXIS == 0: + z_ptr = Z + range_n + elif AXIS == 1 or AXIS == -1: + z_ptr = Z + range_m + if KEEP_DIMS and AXIS is not None: + z_ptr = tl.expand_dims(z_ptr, axis=AXIS) + tl.store(z_ptr, z) + + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS, keep_dims=KEEP_DIMS)'}) + # input + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + x_tri = to_triton(x, device=device) + numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax}[op] + z_dtype_str = get_reduced_dtype(dtype_str, op) + z_tri_dtype_str = z_dtype_str + + # numpy result + # Silence numpy error on axis out of bounds, to give triton a chance to fail + np_axis = axis if axis is not None and axis < len(shape) else None + if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16': + z_dtype_str = 'float32' + z_tri_dtype_str = 'bfloat16' + z_ref = numpy_op(x, axis=np_axis, keepdims=keep_dims).astype(getattr(np, z_dtype_str)) + # trunc mantissa for a fair comparison of accuracy + z_ref = (z_ref.view('uint32') & np.uint32(0xffff0000)).view('float32') + else: + z_ref = numpy_op(x, axis=np_axis, keepdims=keep_dims).astype(getattr(np, z_dtype_str)) + + # triton result + z_shape = z_ref.shape + z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + BLOCK_K = 1 if len(shape) == 2 else shape[2] + IS_3D = bool(len(shape) == 3) + if (axis == 0) or IS_3D: + pytest.skip("Skip for kunlunxin") + if axis is not None and axis >= len(shape): + with pytest.raises(triton.TritonError): + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + KEEP_DIMS=keep_dims, num_ctas=num_ctas) + return + else: + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + KEEP_DIMS=keep_dims, num_ctas=num_ctas) + + z_tri = to_numpy(z_tri) + + # compare + if op == 'sum': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + if op in ('argmin', 'argmax'): + # argmin and argmax can have multiple valid indices. + # so instead we compare the values pointed by indices + z_ref_index = z_ref + z_tri_index = z_tri + if not keep_dims: + z_ref_index = np.expand_dims(z_ref, axis=axis) + z_tri_index = np.expand_dims(z_tri, axis=axis) + z_ref_value = np.take_along_axis(x, z_ref_index, axis=axis) + z_tri_value = np.take_along_axis(x, z_tri_index, axis=axis) + np.testing.assert_equal(z_ref_value, z_tri_value) + else: + np.testing.assert_equal(z_ref, z_tri) + + +scan2d_shapes = [(8, 32), (16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)] + +scan_configs = [(op, type, shape, axis, reverse, num_warps) + for num_warps in [4, 16] + for type in ['int32', 'float32', 'bfloat16'] + for axis in [1, 0] + for reverse in [True, False] + for shape in scan2d_shapes + for op in ['cumsum', 'cumprod', 'get_first_element', 'linear_recurrence', 'cummax', 'roll']] +negative_config = [('cumsum', 'float32', (32, 32), -1, False, 4)] + + +@triton.jit +# trivial associative but not commutative function +def get_first_element(a, b): + return a + + +# Compute x_i = a_i * x_{i-1} + b_i +@triton.jit +def linear_recurrence(a1, b1, a2, b2): + return a1 * a2, b1 * a2 + b2 + + +@triton.jit +def cummax(v0, i0, v1, i1): + gt = v0 > v1 + return tl.where(gt, v0, v1), tl.where(gt, i0, i1) + + +@triton.jit +def roll(a1, b1_last, b1_cur, a2, b2_last, b2_cur): + return a1 + a2, tl.where(a2 == 1, b1_cur, 0) + b2_last, b2_cur + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("op, dtype_str, shape, axis, reverse, num_warps", scan_configs + negative_config) +def test_scan2d(op, dtype_str, shape, axis, reverse, num_warps, device): + check_type_supported(dtype_str, device) + if dtype_str == 'bfloat16': + if op == 'cummax': + pytest.skip("bfloat16 compare not suppoted before sm90") + if op == 'linear_recurrence': + pytest.skip("Skipping linear_recurrence scan on bfloat16 due to accuracy issues") + numpy_dtype_str = 'float32' if dtype_str == 'bfloat16' else dtype_str + + # triton kernel + @triton.jit + def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): + range_m = tl.arange(0, BLOCK_M) + range_n = tl.arange(0, BLOCK_N) + x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) + y = tl.load(Y + range_m[:, None] * BLOCK_N + range_n[None, :]) + GENERATE_TEST_HERE + tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z) + + if op == 'cumsum' or op == 'cumprod': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'z = tl.{op}(x, axis={axis}, reverse={reverse})'}) + elif op == 'get_first_element': + kernel = patch_kernel( + kernel, + {'GENERATE_TEST_HERE': f'z = tl.associative_scan(x, axis={axis}, combine_fn={op}, reverse={reverse})'}) + elif op == 'cummax': + rg = "range_m[:, None]" if axis == 0 else "range_n[None, :]" + rg = f"tl.broadcast_to({rg}.to(tl.int64), [BLOCK_M, BLOCK_N])" + kernel = patch_kernel(kernel, { + 'GENERATE_TEST_HERE': + f'_, z = tl.associative_scan((x, {rg}), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + elif op == 'roll': + assert op == 'roll' + kernel = patch_kernel( + kernel, { + 'GENERATE_TEST_HERE': + f'_, z, _ = tl.associative_scan((1 + 0* x, 0 * x, x), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + else: + assert op == 'linear_recurrence' + kernel = patch_kernel(kernel, { + 'GENERATE_TEST_HERE': + f'_, z = tl.associative_scan((x, y), axis={axis}, combine_fn={op}, reverse={reverse})' + }) + # input + rs = RandomState(17) + if op == 'linear_recurrence' and dtype_str in int_dtypes: + # If the numbers are too large the op will overflow + # We sample numbers in -1, 0, 1 + x = rs.randint(-1, 2, shape, dtype=dtype_str) + y = rs.randint(-1, 2, shape, dtype=dtype_str) + else: + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + # y is just used in linear_recurrence + y = numpy_random(shape, dtype_str=dtype_str, rs=rs) + x_in = x + if reverse: + x_in = np.flip(x, axis) + z = np.empty_like(x) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + y_tri = to_triton(y, device=device, dst_type=dtype_str) + if op == 'cumsum' or op == 'cumprod': + numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op] + z_ref = numpy_op(x_in, axis=axis).astype(getattr(np, numpy_dtype_str)) + if reverse: + z_ref = np.flip(z_ref, axis) + + elif op == 'cummax': + # NumPy does not have cummax + z = z.astype(np.int64) + z_ref = torch.cummax(torch.from_numpy(x_in.copy()), axis=axis).indices.numpy() + if reverse: + z_ref = x_in.shape[axis] - np.flip(z_ref, axis) - 1 + elif op == 'roll': + ROLL = 1 + z_ref = np.roll(x_in.copy(), ROLL, axis=axis) + if axis == 0: + z_ref[:ROLL] = 0 + else: + z_ref[:, :ROLL] = 0 + + if reverse: + z_ref = np.flip(z_ref, axis) + elif op == 'linear_recurrence': + # Simplify to the axis=1 case + x_ref = x.T if axis == 0 else x + y_ref = y.T if axis == 0 else y + if reverse: + x_ref = np.flip(x_ref, 1) + y_ref = np.flip(y_ref, 1) + + result = [] + for x_refi, y_refi in zip(x_ref, y_ref): + li = [] + acc = 0 + for xi, yi in zip(x_refi, y_refi): + acc = xi * acc + yi + li.append(acc) + result.append(li) + z_ref = np.array(result) + if reverse: + z_ref = np.flip(z_ref, 1) + + if axis == 0: + z_ref = z_ref.T + else: + assert op == 'get_first_element' + z_ref = x + if axis == 0: + if reverse: + z_ref[:-1] = x[-1] + else: + z_ref[1:] = x[0] + else: + if reverse: + z_ref[:, :-1] = x[:, -1:] + else: + z_ref[:, 1:] = x[:, 0:1] + + # triton result + # we don't cast the `fp32 = bf16 op bf16` result to bfloat16 to alleviate accuracy issues + z_tri = to_triton(z, device=device) + kernel[(1, )](x_tri, y_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) + + z_tri = to_numpy(z_tri) + # compare + if dtype_str not in int_dtypes: + if op == 'cumprod': + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01, atol=1e-3) + else: + np.testing.assert_allclose(z_ref, z_tri, rtol=0.01) + else: + np.testing.assert_equal(z_ref, z_tri) + + +scan_layouts = [ + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), +] + +# --------------- +# test histogram +# --------------- + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[2048, 2], [1024, 8], [1024, 128], [256, 512], [32, 512], [8, 512], [8, 2]]) +def test_histogram(M, N, device): + + @triton.jit + def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr): + offset1 = tl.arange(0, M) + offset2 = tl.arange(0, N) + x = tl.load(x_ptr + offset1) + z = tl.histogram(x, N) + tl.store(z_ptr + offset2, z) + + torch.manual_seed(17) + x = torch.randint(0, N, (M, ), device=device, dtype=torch.int32) + z = torch.empty(N, dtype=torch.int32, device=device) + # torch.histc does not work when the input type is not float and the device is CPU + # https://github.com/pytorch/pytorch/issues/74236 + # This is a workload by converting the input to float + z_torch = torch.histc(x.float(), bins=N, min=0, max=N - 1) + histogram_kernel[(1, )](x, z, M=M, N=N) + assert (z_torch == z).all() + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("op", ['sum', 'max', 'min']) +@pytest.mark.parametrize("BLOCK_N", [32, 64, 128]) +@pytest.mark.parametrize("N", [512, 1024, 2048]) +@pytest.mark.parametrize("num_pid_n", [2, 4]) +def test_optimize_thread_locality(op, BLOCK_N, N, num_pid_n, device): + + @triton.jit + def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.constexpr): + start_m = tl.program_id(0) + pid_n = tl.program_id(1) + local = INITIALIZE_PATCH + off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + for start_n in range(pid_n, tl.cdiv(N, BLOCK_N), NUM_PID_N): + off_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * N + off_n[None, :] + x = tl.load(Xs) + local = ACCUMULATE_PATCH + tl.store(Y + off_m * NUM_PID_N + pid_n, local) + # the following segfaults AMD backend following #3492 + # really unclear why; the llvm-ir and kernel arguments are + # identical ! + # tl.store(Y + off_m * tl.num_programs(1) + pid_n, local) + + initialize_patch = { + 'sum': 'tl.zeros([BLOCK_M], dtype=tl.float32)', + 'max': 'tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)', + 'min': 'tl.full([BLOCK_M], float("inf"), dtype=tl.float32)', + }[op] + reduce_patch = { + 'sum': 'local + tl.sum(x, axis=1)', + 'max': 'tl.maximum(local, tl.max(x, axis=1))', + 'min': 'tl.minimum(local, tl.min(x, axis=1))', + }[op] + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + }[op] + kernel = patch_kernel(kernel, {'ACCUMULATE_PATCH': reduce_patch, 'INITIALIZE_PATCH': initialize_patch}) + torch.manual_seed(0) + BLOCK_M = 32 + x = torch.randn((BLOCK_M, N), dtype=torch.float32, device=device) + y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device=device) + h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N, NUM_PID_N=num_pid_n) + if not is_interpreter(): + assert h.asm['ttgir'].count( + '"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work" + y_ref = numpy_op(x.cpu().numpy(), axis=1, keepdims=True) + y_tri = numpy_op(y.cpu().numpy(), axis=1, keepdims=True) + np.testing.assert_allclose(y_tri, y_ref, rtol=0.01, atol=1e-3) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) +@pytest.mark.parametrize("src_layout", scan_layouts) +@pytest.mark.parametrize("axis", [0, 1]) +def test_scan_layouts(M, N, src_layout, axis, device): + + ir = f""" + #blocked = {src_layout} + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> + %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> + %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> + %7 = tt.broadcast %4 : tensor<{M}x1x!tt.ptr, #blocked> -> tensor<{M}x{N}x!tt.ptr, #blocked> + %8 = tt.broadcast %6 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr, #blocked> + %11 = "tt.scan"(%10) <{{axis = {axis} : i32, reverse = false}}> ({{ + ^bb0(%arg2: i32, %arg3: i32): + %16 = arith.addi %arg2, %arg3 : i32 + tt.scan.return %16 : i32 + }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #blocked> + %13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %14 = tt.broadcast %13 : tensor<{M}x1x!tt.ptr, #blocked> -> tensor<{M}x{N}x!tt.ptr, #blocked> + %15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + tt.store %15, %11 : tensor<{M}x{N}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + rs = RandomState(17) + x = rs.randint(-100, 100, (M, N)).astype('int32') + + z = np.zeros((M, N)).astype('int32') + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + kernel[(1, 1, 1)](x_tri, z_tri) + + z_ref = np.cumsum(x, axis=axis) + + np.testing.assert_equal(z_ref, z_tri.cpu().numpy()) + + +layouts = [ + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 4], [THREADS_PER_WARP // 16, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 2], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], + instr_shape=[16, 16, 16]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], + instr_shape=[16, 32, 16]), + MfmaLayout(version=(2, 0), warps_per_cta=[2, 2], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=False), + MfmaLayout(version=(2, 0), warps_per_cta=[2, 2], instr_shape=[32, 32], is_transposed=True), + MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=True), + MfmaLayout(version=(2, 0), warps_per_cta=[1, 4], instr_shape=[32, 32], is_transposed=True), + WmmaLayout(warps_per_cta=[2, 2]), + WmmaLayout(warps_per_cta=[4, 1]), + WmmaLayout(warps_per_cta=[1, 4]), +] + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [64, 64], [32, 128], [32, 32], [16, 16]]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d']) +@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"]) +@pytest.mark.parametrize("reduce_op", ["sum", "max"]) +def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device): + if isinstance(src_layout, + (MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]): + pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape") + if is_hip() and isinstance(src_layout, MfmaLayout) and ((M, N) == (128, 128)): + pytest.skip("Skipping test because it runs out of shared memory") + if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024: + pytest.skip("Skipping sum reduction on float16 due to accuracy issues") + if epilogue_kind == 'expand_reduce2d' and isinstance(src_layout, MmaLayout): + pytest.skip( + "Currently MmaLayout combined with slice encoding and reduce op trigger device illegal memory access") + + if isinstance(src_layout, MmaLayout) and src_layout.version == 3: + src_layout[2] = 16 if dtype_str == "float16" else 8 + + ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str] + arith_op = { + "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"}, # + "sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"} + }[reduce_op][dtype_str] + numpy_op = {"max": np.max, "sum": np.sum}[reduce_op] + rdims_1d = f"{N}" if axis == 0 else f"{M}" + rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" + store_range = "%7" if axis == 0 else "%1" + blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]) + num_warps = src_layout.warps_per_cta[0] * src_layout.warps_per_cta[1] + if num_warps == 8: + blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, 2], [0, 1], [1, 1], [1, 1], [0, 1]) + one_d_layout = BlockedLayout([1], [THREADS_PER_WARP], [4], [0], [1], [1], [0]) + + expanded_shape = f"1x{N}" if axis == 0 else f"{M}x1" + other_axis = 1 - axis + epilogue = { + "reduce1d": + f""" + %14 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked> + %15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked> + %16 = {GPU_DIALECT}.convert_layout %13 : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> + %17 = tt.expand_dims %16 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> -> tensor<{rdims_2d}x{ty}, #blocked> + tt.store %15, %17 : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked> + tt.return + }} + }} + """, "reduce2d": + f""" + %14 = "tt.reduce"(%13) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = 0 : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> {ty} + tt.store %arg2, %14 : !tt.ptr<{ty}> + tt.return + }} + }} + """, "expand_reduce2d": + f""" + %14 = tt.expand_dims %13 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{expanded_shape}x{ty}, #src> + %15 = "tt.reduce"(%14) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = {other_axis} : i32}} : (tensor<{expanded_shape}x{ty}, #src>) -> (tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>>) + %16 = triton_gpu.convert_layout %15 : tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>> -> tensor<1x{ty}, #one_d_layout> + %17 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<1x!tt.ptr<{ty}>, #one_d_layout> + tt.store %17, %16 : tensor<1x!tt.ptr<{ty}>, #one_d_layout> + tt.return + }} + }} + """ + }[epilogue_kind] + + ir = f""" + #blocked = {blocked} + #src = {src_layout} + #one_d_layout = {one_d_layout} + module attributes {{"triton_gpu.num-warps" = {num_warps} : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %2 = tt.splat %arg1 : i32 -> tensor<{M}x1xi32, #blocked> + %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked> + %4 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked> + %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked> + %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> + %7 = tt.expand_dims %6 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked> + %9 = tt.broadcast %7 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked> + %12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src> + %13 = "tt.reduce"(%12) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> + """ + epilogue + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + rs = RandomState(17) + x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10) + reduce2d = 'reduce2d' in epilogue_kind + z_shape = (1, 1) if reduce2d else (1, N) if axis == 0 else (M, 1) + z = np.zeros(z_shape).astype(dtype_str) + + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + pgm = kernel[(1, 1, 1)](x_tri, x_tri.stride(0), z_tri) + z_ref = numpy_op(x) if reduce2d else numpy_op(x, axis=axis, keepdims=True) + + if dtype_str == 'float16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + + +layouts = [ + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]) +] + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("M", [32, 64, 128, 256]) +@pytest.mark.parametrize("src_layout", layouts) +def test_store_op(M, src_layout, device): + + ir = f""" + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "{GPU_DIALECT}.num-ctas" = 1 : i32, "{GPU_DIALECT}.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %3 = tt.load %2 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 1 : i32}} : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xf32, #src> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %6 = tt.expand_dims %5 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x1x!tt.ptr, #src> + %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr, #src>, tensor<{M}x1xi32, #src> + tt.store %8, %4 : tensor<{M}x1x!tt.ptr, #src> + tt.return + }} + }} + """ + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + store_kernel = triton.compile(f.name) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, 1)).astype('float32') + y = np.zeros((M, 1), dtype='float32') + x_tri = torch.tensor(x, device=device) + y_tri = torch.tensor(y, device=device) + + pgm = store_kernel[(1, 1, 1)](x_tri, y_tri) + y_ref = x + np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +layouts = [ + # TODO (lixun): Add MfmaLayout + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]) +] + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("M", [64, 128, 256]) +@pytest.mark.parametrize("src_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("dst_layout", filter_layouts(layouts)) +@pytest.mark.parametrize("src_dim", [0, 1]) +@pytest.mark.parametrize("dst_dim", [0, 1]) +def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): + + ir = f""" + #dst = {dst_layout} + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %3 = tt.load %2 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %7 = {GPU_DIALECT}.convert_layout %3 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> -> tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + tt.store %6, %7 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + tt.return + }} + }} + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, )).astype('int32') + y = np.zeros((M, ), dtype='int32') + x_tri = torch.tensor(x, device=device) + y_tri = torch.tensor(y, device=device) + pgm = kernel[(1, 1, 1)](x_tri, y_tri) + y_ref = x + np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +@triton.jit +def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): + delta = mean_2 - mean_1 + new_weight = weight_1 + weight_2 + w2_over_w = weight_2 / new_weight + return ( + mean_1 + delta * w2_over_w, + m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w, + new_weight, + ) + + +layouts = [ + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + # [HIP] TO DO: some tests are flaky with the layout, so turn off them for now. + # BlockedLayout([1, 4], [1, THREADS_PER_WARP], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]) +] + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("M, N", [[128, 128], [256, 128], [256, 256], [128, 256]]) +@pytest.mark.parametrize("src_layout", layouts) +@pytest.mark.parametrize("op", ["sum", "max"]) +@pytest.mark.parametrize("first_axis", [0, 1]) +def test_chain_reduce(M, N, src_layout, op, device, first_axis): + + op_str = "" + if op == "sum": + op_str = """ + %13 = arith.addi %arg2, %arg3 : i32 + tt.reduce.return %13 : i32""" + elif op == "max": + op_str = """ + %13 = arith.cmpi "sgt", %arg2, %arg3 : i32 + %14 = arith.select %13, %arg2, %arg3 : i32 + tt.reduce.return %14 : i32""" + ir = f""" + #src = {src_layout} + module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src> + %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %5 = tt.broadcast %2 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %6 = tt.broadcast %4 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src> + %8 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr, #src> + %11 = "tt.reduce"(%10) ({{ + ^bb0(%arg2: i32, %arg3: i32): + {op_str} + }}) {{axis = {first_axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>> + %12 = "tt.reduce"(%11) ({{ + ^bb0(%arg2: i32, %arg3: i32): + {op_str} + }}) {{axis = 0 : i32}} : (tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>>) -> i32 + tt.store %arg1, %12 : !tt.ptr + tt.return + }} + }} + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, N)).astype('int32') + + z = np.zeros((1, )).astype('int32') + + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + pgm = kernel[(1, 1, 1)](x_tri, z_tri) + if op == "sum": + z_ref = np.sum(x) + elif op == "max": + z_ref = np.max(x) + + np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +@pytest.mark.interpreter +def test_generic_reduction(device): + + @triton.jit + def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr): + xindex = tl.arange(0, BLOCK) + x = tl.load(X + xindex) + mean = x + m2 = tl.zeros_like(x) + weight = tl.full(x.shape, 1, x.dtype) + (mean, m2, weight) = tl.reduce((mean, m2, weight), 0, _welford_combine) + tl.store(out_mean, mean) + tl.store(out_var, m2 / weight) + + SIZE = 512 + x = torch.rand(SIZE, device=device) + out_mean = torch.empty((), device=device) + out_var = torch.empty((), device=device) + + var_mean_kernel[(1, )](x, out_mean, out_var, BLOCK=SIZE) + + expect_var, expect_mean = torch.var_mean(x, dim=0, correction=0) + torch.testing.assert_close(out_mean, expect_mean) + torch.testing.assert_close(out_var, expect_var) + + +# --------------- +# test permute +# --------------- + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) + # TODO: bfloat16 + for dtype in ['float8e4b15', 'float16', 'float32'] + for shape in [(64, 64), (128, 128)] + for perm in [(1, 0)]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_permute(dtype_str, shape, perm, num_ctas, device): + check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + if is_hip() and shape == (128, 128) and dtype_str == 'float32': + pytest.skip("TODO Out of LDS for float32 with shape 128x128") + + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + tl.store(Zs, tl.load(Xs)) + + # input + x = numpy_random(shape, dtype_str=dtype_str) + # triton result + z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) + x_tri = to_triton(x, device=device, dst_type=dtype_str) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), z_tri, z_tri.stride(1), z_tri.stride(0), + BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), + x_tri.stride(0), z_tri_contiguous, z_tri_contiguous.stride(0), + z_tri_contiguous.stride(1), BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + # numpy result + if dtype_str == 'float8e4b15': + ty = tl.float8e4b15 + z_ref = serialize_fp8(deserialize_fp8(x, ty).T.copy(), ty) + z_tri = z_tri.base + z_tri_contiguous = z_tri_contiguous.base + else: + z_ref = x.transpose(*perm) + # compare + np.testing.assert_allclose(to_numpy(z_tri), z_ref) + np.testing.assert_allclose(to_numpy(z_tri_contiguous), z_ref) + + if not is_cuda(): + return + + # parse ptx to make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + ptx = pgm_contiguous.asm['ptx'] + assert 'ld.global.v4' in ptx + assert 'st.global.v4' in ptx + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["int32", "int8"]) +@pytest.mark.parametrize("shape", [(2, 4), (16, 16)]) +@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1]))) +def test_trans_2d(dtype_str, shape, perm, device): + + @triton.jit + def kernel(In, Out, in_shape1: tl.constexpr, in_shape2: tl.constexpr, ou_shape1: tl.constexpr, + ou_shape2: tl.constexpr, trans1: tl.constexpr, trans2: tl.constexpr): + in_offs = tl.arange(0, in_shape1)[:, None] * in_shape2 + tl.arange(0, in_shape2)[None, :] + ou_offs = tl.arange(0, ou_shape1)[:, None] * ou_shape2 + tl.arange(0, ou_shape2)[None, :] + tl.store(Out + ou_offs, tl.permute(tl.load(In + in_offs), (trans1, trans2))) + + input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device=device).reshape(shape) + expected = torch.permute(input, perm) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device) + + kernel[(1, )](input, actual, *shape, *[shape[i] for i in perm], *perm) + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["int32", "int8"]) +@pytest.mark.parametrize("shape", [(2, 2, 8, 64), (4, 4, 4, 4)]) +@pytest.mark.parametrize("perm", list(itertools.permutations([0, 1, 2, 3]))) +def test_trans_4d(dtype_str, shape, perm, device): + + @triton.jit + def kernel(In, Out, # + in_shape1: tl.constexpr, in_shape2: tl.constexpr, in_shape3: tl.constexpr, in_shape4: tl.constexpr, + ou_shape1: tl.constexpr, ou_shape2: tl.constexpr, ou_shape3: tl.constexpr, ou_shape4: tl.constexpr, + trans1: tl.constexpr, trans2: tl.constexpr, trans3: tl.constexpr, trans4: tl.constexpr): + in_ptr = tl.make_block_ptr( + base=In, + shape=(in_shape1, in_shape2, in_shape3, in_shape4), + strides=(in_shape4 * in_shape3 * in_shape2, in_shape4 * in_shape3, in_shape4, 1), + offsets=(0, 0, 0, 0), + block_shape=(in_shape1, in_shape2, in_shape3, in_shape4), + order=(3, 2, 1, 0), + ) + out_ptr = tl.make_block_ptr( + base=Out, + shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4), + strides=(ou_shape4 * ou_shape3 * ou_shape2, ou_shape4 * ou_shape3, ou_shape4, 1), + offsets=(0, 0, 0, 0), + block_shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4), + order=(3, 2, 1, 0), + ) + tl.store(out_ptr, tl.load(in_ptr).permute((trans1, trans2, trans3, trans4))) + + input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device=device).reshape(shape) + expected = torch.permute(input, perm) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device) + + kernel[(1, )](input, actual, *shape, *[shape[i] for i in perm], *perm, num_warps=8) + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +# --------------- +# test dot +# --------------- + + +def convert_fp8_to_fp32(x, device, dtype_str): + if dtype_str == 'float8e4nv': + return torch.tensor(x, device=device).view(torch.float8_e4m3fn).to(torch.float32) + elif dtype_str == 'float8e5': + return torch.tensor(x, device=device).view(torch.float8_e5m2).to(torch.float32) + assert "Unsupported float8 dtype" + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize( + "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack", + [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1) + for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] + for input_precision in ['tf32', 'tf32x3', 'ieee'] + for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')] + if not (input_precision != 'ieee' and (in_dtype in ['float16']))] + + [(*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype, kpack) + for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4], + [32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] + for input_precision in ["ieee" if is_hip() else "tf32"] + for col_a in [True, False] + for col_b in [True, False] + for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', + 'float32')] + for kpack in [1, 2 if is_hip() else 1]] + [(64, 64, 64, 4, col_a, col_b, 'none', 'ieee', 'float32', 'float32', 1) + for col_a in [True, False] + for col_b in [True, False]] + + [(64, 64, 64, 4, False, False, 'chain-dot', 'ieee', 'bfloat16', 'float32', 1)] + + [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1) + for float8_type in ["float8e5", "float8e4nv"]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, num_ctas, device): + if is_interpreter(): + if in_dtype == 'bfloat16': + pytest.skip("bfloat16 is not supported in the interpreter") + else: + if is_cuda(): + capability = torch.cuda.get_device_capability() + + if capability[0] < 7: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + if capability[0] < 8: + if capability[1] == 0 and in_dtype == 'int8': + pytest.skip("Only test int8 on devices with sm >= 75") + if input_precision != "ieee": + pytest.skip("Only test tf32 on devices with sm >= 80") + if capability[0] == 7: + if (M, N, K, num_warps) in [(128, 256, 32, 8), (64, 128, 128, 4), (64, 128, 128, 2)]: + pytest.skip("shared memory out of resource") + if out_dtype == 'float16': + # TODO: support out_dtype=float16 for tl.dot on V100 + pytest.skip("Only test out_dtype=float16 on devices with sm >=80") + if capability[0] < 9 and in_dtype == 'float8e4nv': + pytest.skip("float8e4nv not supported on sm <= 80") + if is_hip() and (in_dtype == 'float8e4nv' or in_dtype == 'float8e5'): + pytest.skip("float8e4nv and float8e5 not supported on HIP") + if is_hip() and (input_precision != "ieee"): + pytest.skip(f"{input_precision} not supported on HIP") + if is_hip() and (kpack == 2 and in_dtype == 'int8' and K < 64): + pytest.skip("kpack too large for K") + if not is_hip() and kpack == 2: + pytest.skip("Skip duplicated tests on nv path") + + torch.backends.cuda.matmul.allow_tf32 = input_precision == "tf32" + + if num_ctas > 1 and in_dtype == 'int8': + # FIXME: mma v2 with num_ctas > 1 does not work + pytest.skip() + + # triton kernel + @triton.jit + def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, stride_wl, Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ADD_MATRIX: tl.constexpr, + ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, INPUT_PRECISION: tl.constexpr, DO_SOFTMAX: tl.constexpr, + CHAIN_DOT: tl.constexpr, COL_A: tl.constexpr, COL_B: tl.constexpr, out_dtype: tl.constexpr = tl.float32): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_l = tl.arange(0, BLOCK_N) + off_k = tl.arange(0, BLOCK_K) + Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk + Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn + Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl + Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn + x = tl.load(Xs) + y = tl.load(Ys) + z = tl.dot(x, y, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + if ADD_MATRIX: + z += tl.load(Zs) + if ADD_ROWS: + ZRs = Z + off_m * stride_zm + z += tl.load(ZRs)[:, None] + if ADD_COLS: + ZCs = Z + off_n * stride_zn + z += tl.load(ZCs)[None, :] + if DO_SOFTMAX: + max = tl.max(z, 1) + z = z - max[:, None] + num = tl.exp(z.to(tl.float32)).to(max.dtype) + den = tl.sum(num, 1) + z = num / den[:, None] + if CHAIN_DOT: + w = tl.load(Ws) + z = tl.dot(z.to(w.dtype), w, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + tl.store(Zs, z) + + # input + rs = RandomState(17) + if col_a: + x = numpy_random((K, M), dtype_str=in_dtype, rs=rs).T + else: + x = numpy_random((M, K), dtype_str=in_dtype, rs=rs) + if col_b: + y = numpy_random((N, K), dtype_str=in_dtype, rs=rs).T + else: + y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) + w = numpy_random((N, N), dtype_str=in_dtype, rs=rs) + if 'int' not in in_dtype and 'float8' not in in_dtype: + x *= .1 + y *= .1 + if in_dtype == 'float32' and input_precision == "tf32": + x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') + y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32') + w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32') + x_tri = to_triton(x, device=device, dst_type=in_dtype) + y_tri = to_triton(y, device=device, dst_type=in_dtype) + w_tri = to_triton(w, device=device, dst_type=in_dtype) + # triton result + if out_dtype == 'int8': + z = 1 + numpy_random((M, N), dtype_str='int32', rs=rs) + else: + z = 1 + numpy_random((M, N), dtype_str=in_dtype, rs=rs) * .1 + + z_tri = to_triton(z, device=device) + if epilogue == 'trans': + z_tri = torch.as_strided(z_tri, (M, N), [1, M]) + + if out_dtype == 'int8': + out_dtype = tl.int8 + elif out_dtype == 'float16' and epilogue != 'softmax': + # TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will + # fail with the following error: 'llvm.fmul' op requires the same type + # for all operands and results + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + + kern_kwargs = { + 'COL_A': col_a, 'COL_B': col_b, 'BLOCK_M': M, 'BLOCK_K': K, 'BLOCK_N': N, 'ADD_MATRIX': + epilogue == 'add-matrix', 'ADD_ROWS': epilogue == 'add-rows', 'ADD_COLS': epilogue == 'add-cols', 'DO_SOFTMAX': + epilogue == 'softmax', 'CHAIN_DOT': epilogue == 'chain-dot', 'INPUT_PRECISION': input_precision, 'num_warps': + num_warps, 'num_ctas': num_ctas, 'out_dtype': out_dtype + } + + if is_hip(): + kern_kwargs['kpack'] = kpack + + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri, + w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), **kern_kwargs) + + if epilogue == 'softmax' and (in_dtype != 'float32' or input_precision == "tf32"): + if not is_cuda(): + pass + else: + ptx = pgm.asm["ptx"] + start = ptx.find("shfl.sync.bfly") + end = ptx.find("cvt.rn.f16.f32") + red_code = ptx[start:end] + assert len(red_code) > 0 + + # skip this check on hopper because there are some functions whose name contain "shared" in ptx. + # TODO: we should eliminate these unused functions in ptx code. + if not (capability[0] >= 9): + assert "shared" not in red_code + assert "bar.sync" not in red_code + # torch result + if in_dtype == 'int8': + z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32) + elif 'float8' in in_dtype: + x = convert_fp8_to_fp32(x, device, in_dtype) + y = convert_fp8_to_fp32(y, device, in_dtype) + z_ref = to_numpy(torch.matmul(x, y)) + else: + z_ref = np.matmul(x, y) + + if epilogue == 'add-matrix': + z_ref += z + if epilogue == 'add-rows': + z_ref += z[:, 0][:, None] + if epilogue == 'add-cols': + z_ref += z[0, :][None, :] + if epilogue == 'softmax': + num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True)) + denom = np.sum(num, axis=-1, keepdims=True) + z_ref = num / denom + if epilogue == 'chain-dot': + if 'float8' in in_dtype: + w = to_numpy(convert_fp8_to_fp32(w, device, in_dtype)) + z_ref = np.matmul(z_ref, w) + # compare + if in_dtype == 'float32': + # XXX: Somehow there's a larger difference when we use float32 + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + elif out_dtype == tl.float16 or in_dtype == 'bfloat16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + # added atol, to loose precision for float16xfloat16->float32 case + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) + if not is_cuda(): + return + # make sure ld/st are vectorized + ptx = pgm.asm['ptx'] + if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4): + # XXX: skip small sizes because they are not vectorized + assert 'ld.global.v4' in ptx + if 'float8' in in_dtype: + assert 'st.global.v2' in ptx + else: + assert 'st.global.v4' in ptx + if in_dtype == 'float32' and input_precision != "ieee": + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float32: + if capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx) + elif in_dtype == 'float16' and out_dtype == tl.float16: + if capability[0] == 7 and capability[1] == 5: # Turing + assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f16.f16.f16', ptx) + else: + assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx) + elif in_dtype == 'int8': + if capability[0] == 7 and capability[1] == 5: # Turing + assert 'mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32' in ptx + else: + assert 'wgmma.mma_async.sync.aligned' in ptx or\ + 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx + elif in_dtype == "float8e5" and out_dtype == tl.float32: + if capability[0] == 9: + assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2' in ptx + elif in_dtype == "float8e4nv" and out_dtype == tl.float32: + if capability[0] == 9: + assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("B", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_warps", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("M, N, K", [(64, 64, 64), (32, 32, 32)]) +@pytest.mark.parametrize("in_dtype_str, out_dtype_str", [('int8', 'int8'), ('float16', 'float16'), + ('float16', 'float32'), ('float32', 'float32')]) +def test_dot3d(B, num_warps, M, N, K, in_dtype_str, out_dtype_str, device): + if is_hip(): + # hip does not support tf32 precision, so use ieee for all tests + input_precision = "ieee" + if "gfx11" in triton.runtime.driver.active.get_current_target().arch: + if in_dtype_str == "float32": + pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d") + if out_dtype_str == "float16": + pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") + else: + input_precision = "tf32" if in_dtype_str == 'float32' else "ieee" + + @triton.jit + def kernel( + q_ptr, + k_ptr, + o_ptr, + stride_qb, + stride_qm, + stride_qk, + stride_kb, + stride_kk, + stride_kn, + stride_ob, + stride_om, + stride_on, + BLOCK_B: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + INPUT_PRECISION: tl.constexpr, + out_dtype: tl.constexpr = tl.float32, + ): + startm = tl.program_id(0) * BLOCK_M + startn = tl.program_id(1) * BLOCK_N + offs_b = tl.arange(0, BLOCK_B) + offs_m = startm + tl.arange(0, BLOCK_M) + offs_n = startn + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + q_ptrs = q_ptr + offs_b[:, None, None] * stride_qb + offs_m[None, :, None] * stride_qm + offs_k[ + None, None, :] * stride_qk + k_ptrs = k_ptr + offs_b[:, None, None] * stride_kb + offs_k[None, :, None] * stride_kk + offs_n[ + None, None, :] * stride_kn + q = tl.load(q_ptrs) + k = tl.load(k_ptrs) + qk = tl.dot(q, k, input_precision=INPUT_PRECISION, out_dtype=out_dtype) + o_ptrs = o_ptr + offs_b[:, None, None] * stride_ob + offs_m[None, :, None] * stride_om + offs_n[ + None, None, :] * stride_on + tl.store(o_ptrs, qk) + + if out_dtype_str == 'int8': + out_dtype = tl.int8 + elif out_dtype_str == 'float16': + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + + rs = RandomState(17) + x = numpy_random((B, M, K), dtype_str=in_dtype_str, rs=rs) + y = numpy_random((B, K, N), dtype_str=in_dtype_str, rs=rs) + if in_dtype_str == 'int8': + out = numpy_random((B, M, N), dtype_str='int32', rs=rs) + else: + out = numpy_random((B, M, N), dtype_str=out_dtype_str, rs=rs) + + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + out_tri = to_triton(out, device=device) + + BLOCK_B = B + BLOCK_M, BLOCK_N = 32, 32 + BLOCK_K = K + + grid = ( + triton.cdiv(M, BLOCK_M), + triton.cdiv(N, BLOCK_N), + ) + kernel[grid]( + x_tri, + y_tri, + out_tri, + x_tri.stride(0), + x_tri.stride(1), + x_tri.stride(2), + y_tri.stride(0), + y_tri.stride(1), + y_tri.stride(2), + out_tri.stride(0), + out_tri.stride(1), + out_tri.stride(2), + BLOCK_B=BLOCK_B, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + INPUT_PRECISION=input_precision, + out_dtype=out_dtype, + num_warps=num_warps, + ) + + if in_dtype_str == 'int8': + out_ref = np.matmul(x.astype(np.float32), y.astype(np.float32)).astype(np.int32) + else: + out_ref = np.matmul(x, y) + np.testing.assert_allclose(out_ref, to_numpy(out_tri), rtol=0.01, atol=1e-2) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +def test_max_num_imprecise_acc(device): + + if not hasattr(torch, 'float8_e5m2'): + pytest.skip(f"torch {torch.__version__} does not support float8_e5m2") + + if is_cuda(): + capability = torch.cuda.get_device_capability() + if capability != (9, 0): + return + + @triton.jit + def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + MAX_NUM_IMPRECISE_ACC: tl.constexpr): + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + off_k = tl.arange(0, BLOCK_K) + x = tl.load(X + off_m[:, None] * BLOCK_K + off_k[None, :]) + y = tl.load(Y + off_k[:, None] * BLOCK_N + off_n[None, :]) + z = tl.load(Z + off_m[:, None] * BLOCK_N + off_n[None, :]) + z = tl.dot(x, y, acc=z, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC) + tl.store(Z + off_m[:, None] * BLOCK_N + off_n[None, :], z) + + M, N, K, num_warps, MAX_NUM_IMPRECISE_ACC = 128, 128, 128, 4, 64 + x = torch.zeros((M, K), dtype=torch.float8_e5m2, device=device) + y = torch.zeros((K, N), dtype=torch.float8_e5m2, device=device) + z = torch.zeros((M, N), dtype=torch.float32, device=device) + h = kernel[(1, 1)](x, y, z, M, N, K, MAX_NUM_IMPRECISE_ACC, num_warps=num_warps) + if not is_cuda(): + return + assert h.asm["ptx"].count("add.f32") == (M * N) // (32 * num_warps) * (K / MAX_NUM_IMPRECISE_ACC) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize('in_dtype', ['float32']) +def test_dot_mulbroadcasted(in_dtype, device): + if is_cuda(): + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + pytest.skip("Requires sm >= 80 to run") + + @triton.jit + def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BM: tl.constexpr, BN: tl.constexpr, + BK: tl.constexpr): + pidn = tl.program_id(1) + pidm = tl.program_id(0) + offm = tl.arange(0, BM)[:, None] + offn = tl.arange(0, BN)[None, :] + offak = tl.arange(0, BK)[None, :] + offbk = tl.arange(0, BK)[:, None] + acc = tl.full((BM, BN), 0.0, tl.float32) + for ridx5 in range(0, K // BK): + x = tl.load(X + ((pidm * K * BM) + (offm * K) + (ridx5 * BK) + offak)) + y = tl.load(Y + ((pidn * BN) + (offbk * N) + (ridx5 * N * BK) + offn)) + x = tl.expand_dims(x, axis=2) + y = tl.expand_dims(y, axis=0) + t = tl.sum(x * y, axis=1) + acc = t + acc + tl.store(Z + ((pidm * BM * N) + (pidn * BN) + (offm * N) + offn), acc) + + M, N, K = 256, 192, 160 + BM, BN, BK = 128, 32, 32 + rs = RandomState(17) + x = numpy_random((M, K), dtype_str=in_dtype, rs=rs) + y = numpy_random((K, N), dtype_str=in_dtype, rs=rs) + x = x * 0.1 + y = y * 0.1 + z = numpy_random((M, N), dtype_str=in_dtype, rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(z, device=device) + grid = M // BM, N // BN + h = kernel[grid](z_tri, x_tri, y_tri, M, N, K, BM, BN, BK) + z_ref = np.matmul(x, y) + np.testing.assert_allclose(z_ref, to_numpy(z_tri), atol=0.01) + + if not is_cuda(): + return + assert "tt.dot" in h.asm['ttir'] + # When using MMAv3, we will not pipeline the load op for Y, as the loaded + # value is in rowmajor. But MMAv3 requires its second operand is in colmajor + # because transpose is not supported for MMAv3 with float32 input. + if capability[0] >= 9: + assert re.search(r"triton_gpu.async_wait %.* {num = 1 : i32}", h.asm["ttgir"]) is not None + else: + assert re.search(r"triton_gpu.async_wait %.* {num = 2 : i32}", h.asm["ttgir"]) is not None + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) +@pytest.mark.parametrize("shape", [(), (1, ), (128, )]) +def test_full(dtype_str, shape, device): + if dtype_str in uint_dtypes and not hasattr(torch, dtype_str): + # PyTorch only has unsigned 8, but not 16, 32, or 64 + dtype = getattr(torch, dtype_str[1:]) # uintx -> intx + else: + dtype = getattr(torch, dtype_str) + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + @triton.jit + def kernel_static(out): + a = GENERATE_TEST_HERE + tl.static_assert(a.shape == SHAPE) + out_ptr = out + tl.arange(0, 128)[:] + tl.store(out_ptr, a) + + @triton.jit + def kernel_dynamic(out, val, dtype: tl.constexpr): + a = tl.full(SHAPE, val, dtype) + tl.static_assert(a.shape == SHAPE) + out_ptr = out + tl.arange(0, 128)[:] + tl.store(out_ptr, a) + + kernel_static_patched = patch_kernel(kernel_static, { + 'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})", + 'SHAPE': str(list(shape)), + }) + out_static = torch.zeros((128), dtype=dtype, device=device) + kernel_static_patched[(1, )](out_static) + assert torch.all(out_static == 2) + + kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))}) + out_dynamic = torch.zeros((128), dtype=dtype, device=device) + kernel_dynamic_patched[(1, )](out_dynamic, 2, getattr(triton.language, dtype_str)) + assert torch.all(out_dynamic == 2) + + +@pytest.mark.parametrize("literal, dtype_str", [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"), ('float("inf")', "f32"), + ('float("-inf")', "f32"), ('float("nan")', "f32"), + ('float("-nan")', "f32"), (0., "f32"), (5, "i32"), (2**40, "i64")]) +def test_constexpr(literal, dtype_str, device): + + @triton.jit + def kernel(out_ptr): + val = GENERATE_TEST_HERE + tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val) + + kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"}) + out = torch.zeros((1, ), dtype=torch.float32, device=device) + h = kernel_patched[(1, )](out) + assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None + + +@triton.jit +def pass_const(a, b, choose_b): + if choose_b: + return b + else: + return a + + +@pytest.mark.parametrize("choose_const", [True, False]) +@pytest.mark.parametrize("constexpr", [True, False]) +@pytest.mark.parametrize("mode", ["direct", "call", "ternary", "if"]) +def test_const(device, choose_const, constexpr, mode): + + @triton.jit(do_not_specialize=["choose_const"]) + def kernel(in_ptr: tl.const, out, c_out: tl.const, choose_const, n_elems: tl.int32, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elems + val = tl.load(in_ptr + offsets, mask=mask) + LOSE_TAIL + tl.store(final_out + offsets, val, mask=mask) + + @triton.jit + def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.constexpr, n_elems: tl.int32, + BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elems + val = tl.load(in_ptr + offsets, mask=mask) + LOSE_TAIL + tl.store(final_out + offsets, val, mask=mask) + + if mode == "direct": + if choose_const: + LOSE_TAIL = "final_out = c_out" + else: + LOSE_TAIL = "final_out = out" + elif mode == "call": + LOSE_TAIL = "final_out = pass_const(out, c_out, choose_const)" + elif mode == "ternary": + LOSE_TAIL = "final_out = c_out if choose_const else out" + elif mode == "if": + LOSE_TAIL = """ + if choose_const: + final_out = c_out + else: + final_out = out +""" + + SIZE = 128 + input = torch.randn((SIZE, ), dtype=torch.float32, device=device) + output = torch.zeros((SIZE, ), dtype=torch.float32, device=device) + patched_kernel = patch_kernel(kernel_constexpr if constexpr else kernel, {'LOSE_TAIL': LOSE_TAIL, 'CONSTEXPR': ''}) + + expect_fail = (not constexpr and mode != "direct") or choose_const + if expect_fail: + with pytest.raises(triton.CompilationError) as exc_info: + patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) + else: + patched_kernel[(1, )](input, output, output, choose_const, SIZE, SIZE) + assert torch.all(input == output) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ['float32', 'float16']) +def test_dot_without_load(dtype_str, device): + + @triton.jit + def _kernel(out): + a = GENERATE_TEST_HERE + b = GENERATE_TEST_HERE + c = tl.dot(a, b) + out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(out_ptr, c) + + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"}) + a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) + b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) + out_ref = torch.matmul(a, b) + out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](out) + assert torch.all(out == out_ref) + + +# --------------- +# test arange +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("start", [0, 1, 7, 16]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_arange(start, num_ctas, device): + BLOCK = 128 + z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) + + @triton.jit + def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): + off = tl.arange(0, BLOCK) + val = tl.arange(START, END) + tl.store(z + off, val) + + _kernel[(1, )](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK, num_ctas=num_ctas) + z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device) + np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref)) + + +# --------------- +# test load +# --------------- + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str, size, size_diff, other", [(dtype_str, size, size_diff, other) + for dtype_str in torch_dtypes + for size in [128, 512] + for size_diff in [0, 1, 2, 3, 4] + for other in [0, 1]]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_masked_load(dtype_str, size, size_diff, other, num_ctas, device): + dtype = getattr(torch, dtype_str) + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + input_size = size - size_diff + output_size = size + if dtype_str == 'bool': + input = torch.randint(0, 2, (input_size, ), dtype=dtype, device=device) + elif dtype_str in int_dtypes or dtype_str in uint_dtypes: + input = torch.randint(0, 127, (input_size, ), dtype=dtype, device=device) + else: + input = torch.rand(input_size, dtype=dtype, device=device) + output = torch.zeros((output_size, ), dtype=dtype, device=device) + + @triton.jit + def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): + in_offsets = tl.arange(0, out_size) + # Load inputs. + x = GENERATE_TEST_HERE + # Store output + output_offsets = tl.arange(0, out_size) + tl.store(out_ptr + output_offsets, x) + + mask_str = f"mask=in_offsets < in_size, other={other}" if size_diff > 0 else "None" + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"}) + kernel[(1, )](input, output, input_size, output_size, num_ctas=num_ctas) + + reference_out = torch.cat((input, torch.full((size_diff, ), other, dtype=dtype, device=device))) + torch.testing.assert_close(output, reference_out) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("num_ctas", num_ctas_list) +@pytest.mark.parametrize("mask_val", [True, False]) +@pytest.mark.parametrize("other_val", [0, 1]) +def test_masked_load_scalar(num_ctas, mask_val, other_val, device): + input_val = 4.0 + size = 128 + dtype = torch.float32 + input = torch.full((size, ), input_val, dtype=dtype, device=device) + output = torch.zeros((size, ), dtype=dtype, device=device) + + @triton.jit + def kernel(in_ptr, out_ptr, size: tl.constexpr, mask: tl.constexpr, other: tl.constexpr): + offsets = tl.arange(0, size) + x = tl.load(in_ptr + offsets, mask=mask, other=other) + tl.store(out_ptr + offsets, x) + + kernel[(1, )](input, output, size, mask_val, other_val, num_ctas=num_ctas) + + if mask_val: + reference_out = torch.full((size, ), input_val, dtype=dtype, device=device) + else: + reference_out = torch.full((size, ), other_val, dtype=dtype, device=device) + + torch.testing.assert_close(output, reference_out) + + +# Testing masked loads with an intermate copy to shared memory run. +# FIXME: Shape too small for ldmatrix when num_ctas=4 +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +def test_masked_load_shared_memory(dtype, device): + + check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested + + M = 32 + N = 32 + K = 16 + + in1 = torch.rand((M, K), dtype=dtype, device=device) + in2 = torch.rand((K, N), dtype=dtype, device=device) + out = torch.zeros((M, N), dtype=dtype, device=device) + + @triton.jit + def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_numel, in2_numel, out_numel, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): + + M_offsets = tl.arange(0, M) + N_offsets = tl.arange(0, N) + K_offsets = tl.arange(0, K) + + in_offsets = M_offsets[:, None] * in_stride + K_offsets[None, :] + in2_offsets = K_offsets[:, None] * in2_stride + N_offsets[None, :] + + # Load inputs. + x = tl.load(in1_ptr + in_offsets, mask=in_offsets < M * K) + w = tl.load(in2_ptr + in2_offsets, mask=in2_offsets < K * N) + + # Without a dot product the memory doesn't get promoted to shared. + o = tl.dot(x, w, out_dtype=tl.float32) + + # Store output + output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :] + tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N) + + pgm = _kernel[(1, )](in1, in2, out, in1.stride()[0], in2.stride()[0], out.stride()[0], in1.numel(), in2.numel(), + out.numel(), M=M, N=N, K=K) + + reference_out = torch.matmul(in1, in2) + torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("cache", ["", ".ca", ".cg"]) +def test_load_cache_modifier(cache, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, CACHE: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets, cache_modifier=CACHE) + tl.store(dst + offsets, x) + + pgm = _kernel[(1, )](dst, src, CACHE=cache) + if not is_cuda(): + return + + ptx = pgm.asm['ptx'] + if cache == '': + assert 'ld.global.ca' not in ptx + assert 'ld.global.cg' not in ptx + if cache == '.cg': + assert 'ld.global.cg' in ptx + assert 'ld.global.ca' not in ptx + if cache == '.ca': + assert 'ld.global.ca' in ptx + assert 'ld.global.cg' not in ptx + + +@pytest.mark.interpreter +@pytest.mark.parametrize("N", [16, 10, 11, 1024]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_vectorization(N, num_ctas, device): + block_size = 1024 * num_ctas + src = torch.empty(block_size, device=device) + dst = torch.empty(block_size, device=device) + + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + pgm = _kernel[(1, )](dst, src, N=N, BLOCK_SIZE=block_size) + + if not is_cuda(): + # ===-------------------- For Triton XPU -----------------------=== + llir = str(pgm.asm["llir"]) + if llir and N % 16 == 0: + assert "16 x float" in llir + # ===-----------------------------------------------------------=== + return + + ptx = pgm.asm["ptx"] + if N % 16 == 0: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.b32" in ptx + # np.testing.assert_allclose(dst, src[:N]) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("has_hints", [False, True]) +def test_vectorization_hints(has_hints, device): + src = torch.empty(1024, device=device) + dst = torch.empty(1024, device=device) + off = torch.zeros(1, device=device, dtype=torch.int32) + + @triton.jit + def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offsets = offsets + tl.load(off) + if HINT: + tl.max_contiguous(tl.multiple_of(offsets, 1024), 1024) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + pgm = _kernel[(1, )](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints) + if not is_cuda(): + return + + ptx = pgm.asm["ptx"] + if has_hints: + assert "ld.global.v4.b32" in ptx + else: + assert "ld.global.v4.b32" not in ptx + + +# --------------- +# test store +# --------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs", ".wt"]) +def test_store_cache_modifier(cache, device): + src = torch.empty(128, device=device) + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst, src, CACHE: tl.constexpr): + offsets = tl.arange(0, 128) + x = tl.load(src + offsets) + tl.store(dst + offsets, x, cache_modifier=CACHE) + + if not is_cuda(): + return + pgm = _kernel[(1, )](dst, src, CACHE=cache) + ptx = pgm.asm['ptx'] + if cache == '': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.wb': + assert 'st.global.wb' in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cg': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' not in ptx + if cache == '.cs': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' in ptx + assert 'st.global.wt' not in ptx + if cache == '.wt': + assert 'st.global.wb' not in ptx + assert 'st.global.cg' not in ptx + assert 'st.global.cs' not in ptx + assert 'st.global.wt' in ptx + + +# --------------- +# test default +# --------------- +# TODO: can't be local to test_default + + +@triton.jit +def _impl(value=10): + return value + + +@pytest.mark.interpreter +def test_default(device): + value = 5 + ret0 = torch.zeros(1, dtype=torch.int32, device=device) + ret1 = torch.zeros(1, dtype=torch.int32, device=device) + + @triton.jit + def _kernel(ret0, ret1, value=3): + tl.store(ret0, _impl()) + tl.store(ret1, _impl(value)) + + _kernel[(1, )](ret0, ret1, value) + assert ret0.item() == 10 + assert ret1.item() == value + + _kernel[(1, )](ret0, ret1) + assert ret0.item() == 10 + assert ret1.item() == 3 + + +# --------------- +# test noop +# ---------------- + + +@pytest.mark.interpreter +def test_noop(device): + + @triton.jit + def kernel(x): + pass + + x = to_triton(numpy_random((1, ), dtype_str='int32'), device=device) + kernel[(1, )](x) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned']) +def test_pointer_arguments(device): + + @triton.jit + def kernel(x): + pass + + pin_memory = 'pinned' in device + x = torch.empty(1024, device=device.split('_')[0], pin_memory=pin_memory) + if device == "cpu": + with pytest.raises(ValueError): + kernel[(1, )](x) + else: + kernel[(1, )](x) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("value, value_type", [(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'), + (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')]) +def test_value_specialization(value: int, value_type: str, device) -> None: + + def repr(specialization): + spec_type = specialization.signature["VALUE"] + return f"kernel_{spec_type}" + + @triton.jit(repr=repr) + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device=device) + h = kernel[(1, )](value, x) + assert value_type in h.name + + +# -------------------- +# value specialization +# -------------------- + + +@pytest.mark.parametrize("value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]) +def test_value_specialization_overflow(value: int, overflow: bool, device) -> None: + + @triton.jit + def kernel(VALUE, X): + pass + + x = torch.tensor([3.14159], device=device) + + if overflow: + with pytest.raises(OverflowError): + kernel[(1, )](value, x) + else: + kernel[(1, )](value, x) + + +# ---------------- +# test constexpr +# ---------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) +@pytest.mark.parametrize("is_lhs_constexpr", [False, True]) +@pytest.mark.parametrize("is_rhs_constexpr", [True, False]) +def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr, device): + + @triton.jit + def kernel(Z, X, Y): + x = tl.load(X) + y = tl.load(Y) + z = GENERATE_TEST_HERE + tl.store(Z, z) + + if op in ['<<', '>>', '&', '^', '|']: # int op + x_str = "3" if is_lhs_constexpr else "x" + y_str = "4" if is_rhs_constexpr else "y" + x = numpy_random((1, ), dtype_str="int32") + + # NOTE: bitshifting beyond bitwidth can lead to undefined behavior + if op in ['<<', '>>']: + y = numpy_random((1, ), dtype_str="int32", low=0, high=_bitwidth("int32")) + else: + y = numpy_random((1, ), dtype_str="int32") + else: + x_str = "3.14" if is_lhs_constexpr else "x" + y_str = "4.13" if is_rhs_constexpr else "y" + x = numpy_random((1, ), dtype_str="float32") + y = numpy_random((1, ), dtype_str="float32") + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"}) + z = np.array(eval(f"{x_str} {op} {y_str}")) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + z_tri = to_triton(np.empty((1, ), dtype=z.dtype), device=device) + kernel[(1, )](z_tri, x_tri, y_tri) + np.testing.assert_allclose(z, to_numpy(z_tri), rtol=1e-3) + + +@pytest.mark.interpreter +def test_constexpr_shape(device): + + @triton.jit + def kernel(X): + off = tl.arange(0, 128 + 128) + tl.store(X + off, off) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) + kernel[(1, )](x_tri) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) + + +@pytest.mark.interpreter +def test_constexpr_scalar_shape(device): + + @triton.jit + def kernel(X, s): + off = tl.arange(0, 256) + val = off % (256 // s) + tl.store(X + off, val) + + x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) + kernel[(1, )](x_tri, 32) + np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8) + + +reshape_list = [((64, ), (8, 8)), ((2, 32), (16, 4)), ((512, ), (2, 2, 2, 2, 2, 2, 2, 2, 2)), ((64, 32), (16, 8, 16))] + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("formats", reshape_list) +def test_reshape(formats, device): + in_format, out_format = formats + + @triton.jit + def kernel(Z, X, out_tuple: tl.constexpr): + x = tl.load(X_PTR_EXPR) + z = tl.reshape(x, out_tuple) + tl.store(Z_PTR_EXPR, z) + + def generate_kernel(shape_x, shape_z): + to_replace = { + 'X_PTR_EXPR': make_ptr_str('X', shape_x), + 'Z_PTR_EXPR': make_ptr_str('Z', shape_z), + } + return patch_kernel(kernel, to_replace) + + x = numpy_random(in_format, dtype_str="int32") + z = x.reshape(out_format) + x_tri = to_triton(x, device=device) + patched_kernel = generate_kernel(in_format, out_format) + z_tri = to_triton(np.empty(out_format, dtype=np.int32), device=device) + patched_kernel[(1, )](z_tri, x_tri, out_format) + np.testing.assert_equal(z, to_numpy(z_tri)) + + +def test_reshape_err(device): + + @triton.jit + def kernel(): + x = tl.arange(0, 8 * 8) + y = tl.reshape(x, (8 * 4, )) + + with pytest.raises(triton.CompilationError) as exc_info: + kernel[(1, )]() + + assert "reshape" in str(exc_info.value) + + +@pytest.mark.skip("Skip for kunlunxin") +def test_trans_reshape(device): + + @triton.jit + def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.constexpr): + + in_block_ptr = tl.make_block_ptr( + base=in_base_ptr, + shape=(IN_SHAPE0, IN_SHAPE1), + strides=(IN_SHAPE1, 1), + offsets=(0, 0), + block_shape=(IN_SHAPE0, IN_SHAPE1), + order=(1, 0), + ) + x = tl.load(in_block_ptr) + x = tl.reshape(x, (32, 4, 4, 2)) + x = tl.permute(x, (1, 2, 3, 0)) + x = tl.reshape(x, (IN_SHAPE0 * IN_SHAPE1, )) + tl.store(out_base_ptr + tl.arange(0, IN_SHAPE0 * IN_SHAPE1), x) + + shape = (32, 32) + input = torch.arange(math.prod(shape), dtype=torch.int32, device=device).reshape(shape) + expected = torch.permute(input, (1, 0)) + # Don't do zeros_like -- that copies the layout, which we don't want. + actual = torch.zeros(expected.shape, dtype=torch.int32, device=device) + + k = kernel[(1, )](input, actual, shape[0], shape[1]) + assert k.asm['ttgir'].count( + 'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" + + np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) + + +# ------------- +# test call +# ------------- + + +@triton.jit +def val_multiplier(val, i): + return val * i + + +@triton.jit(noinline=True) +def val_multiplier_noinline(val, i): + return val * i + + +@triton.jit +def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr): + pid = tl.program_id(axis=0) + offsets = pid * 128 + tl.arange(0, 128) + mask = offsets < n_elements + vec = tl.load(ptr + offsets, mask=mask) + for i in range(1, rep): + if type == "inline": + vec = val_multiplier(vec, i) + else: + vec = val_multiplier_noinline(vec, i) + tl.store(ptr + offsets, vec, mask=mask) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("type", ["inline", "noinline"]) +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_call(type, num_ctas, device): + + @triton.jit + def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): + vecmul_kernel(ptr, n_elements, num1, type) + vecmul_kernel(ptr, n_elements, num2, type) + + size = 1024 + rand_val = numpy_random((size, ), dtype_str="float32") + rand_val_tri = to_triton(rand_val, device=device) + err_msg = "" + try: + kernel[(size // 128, )](rand_val_tri, size, 3, 5, type, num_ctas=num_ctas) + except Exception as e: + err_msg = str(e) + + if type == "noinline" and not is_interpreter(): + assert err_msg != "" + else: + ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4 + np.testing.assert_equal(to_numpy(rand_val_tri), ans) + + +# ------------- +# test if +# ------------- + + +@pytest.mark.interpreter +@pytest.mark.parametrize("if_type", [ + "if", "if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", + "if_and_static" +]) +def test_if(if_type, device): + + @triton.jit + def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr): + pid = tl.program_id(0) + cond = tl.load(Cond) + if IfType == "if": + if pid % 2 == 0: # eq + tl.store(Ret, tl.load(XTrue)) + elif 1 == pid % 2: # req + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_dynamic": + val = tl.load(XTrue) if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_constexpr": + val = 3.14 if pid % 2 == 0 else tl.load(XFalse) + tl.store(Ret, val) + elif IfType == "if_exp_dynamic_void": + tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_exp_static": + tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_dynamic": + if BoolVar and (1 != pid % 2 and pid % 2 != 1): # rne and ne + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and_static": + if StaticVaue != 0 and StaticVaue != 0: + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) + + cond = torch.ones(1, dtype=torch.int32, device=device) + x_true = torch.tensor([3.14], dtype=torch.float32, device=device) + x_false = torch.tensor([1.51], dtype=torch.float32, device=device) + ret = torch.zeros(1, dtype=torch.float32, device=device) + + kernel[(1, )](cond, x_true, x_false, ret, if_type, True, 1) + assert torch.equal(ret, x_true) + + +@pytest.mark.skip("Skip for kunlunxin") +def test_num_warps_pow2(device): + dst = torch.empty(128, device=device) + + @triton.jit + def _kernel(dst): + pass + + with pytest.raises(AssertionError, match='must be a power of 2'): + _kernel[(1, )](dst=dst, num_warps=3) + _kernel[(1, )](dst=dst, num_warps=1) + _kernel[(1, )](dst=dst, num_warps=2) + _kernel[(1, )](dst=dst, num_warps=4) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func_str", ['sqrt', 'rsqrt', 'exp', 'exp2', 'log', 'log2', 'sin', 'cos']) +def test_unary_math(func_str, device): + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.FUNC_STR(x) + tl.store(Y + tl.arange(0, BLOCK), y) + + kernel = patch_kernel(kernel, {'FUNC_STR': func_str}) + + shape = (128, ) + x = torch.randn(shape, dtype=torch.float32, device=device) + if func_str in ['sqrt', 'rsqrt']: + x = torch.abs(x) + if func_str in ['log', 'log2']: + x = torch.max(x, torch.tensor(1e-6, dtype=torch.float32, device=device)) + y = torch.zeros(shape, dtype=torch.float32, device=device) + + kernel[(1, )](x, y, BLOCK=shape[0]) + torch.allclose(getattr(torch, func_str)(x), y, rtol=1e-3) + + +# ----------------------- +# test inline asm +# ----------------------- + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm(num_ctas, device): + if not is_cuda(): + pytest.skip("test_inline_asm is only supported in CUDA") + + @triton.jit + def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = tl.load(Y + tl.arange(0, BLOCK)) + s = tl.full([BLOCK], n, tl.int32) + z = tl.inline_asm_elementwise("shf.l.wrap.b32 $0, $1, $2, $3;", "=r,r, r, r", [x, y, s], dtype=tl.int32, + is_pure=True, pack=1) + tl.store(Z + tl.arange(0, BLOCK), z) + + shape = (128, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint32', rs=rs) + y = numpy_random(shape, dtype_str='uint32', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(y, device=device) + n = 17 + z_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, z_tri, n, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = (y << n) | (x >> (32 - n)) + # compare + np.testing.assert_equal(y_ref, to_numpy(z_tri)) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("num_ctas", num_ctas_list) +def test_inline_asm_packed(num_ctas, device): + if not is_cuda(): + pytest.skip("test_inline_asm is only supported in CUDA") + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + # shift 4x8bits values together. + y = tl.inline_asm_elementwise( + "and.b32 $0, $1, 0x1F1F1F1F; \ + shl.b32 $0, $0, 3;", "=r,r", [ + x, + ], dtype=tl.int8, is_pure=True, pack=4) + tl.store(Y + tl.arange(0, BLOCK), y) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize('num_ctas', num_ctas_list) +def test_inline_asm_with_pointers(num_ctas, device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x_ptrs = X + tl.arange(0, BLOCK) + y_ptrs = Y + tl.arange(0, BLOCK) + tl.inline_asm_elementwise( + "ld.global.b8 $0, [$1]; \ + shl.b32 $0, $0, 3; \ + st.global.b8 [$2], $0;", "=r,l,l", [x_ptrs, y_ptrs], dtype=tl.int8, is_pure=False, + pack=1) + + shape = (512, ) + rs = RandomState(17) + x = numpy_random(shape, dtype_str='uint8', rs=rs) + x_tri = to_triton(x, device=device) + y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + y_ref = x << 3 + # compare + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + + +@pytest.mark.skip("Skip for kunlunxin") +def test_inline_asm_multiple_outputs(device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + # C = A - B + # D = B - A + (c, d) = tl.inline_asm_elementwise( + asm=""" + sub.u32 $0, $2, $3; // C = A - B + sub.u32 $1, $3, $2; // D = B - A + """, + constraints=( + # 2 output registers: $0=C and $1=D. + "=r,=r," + # 2 input registers: $2=A and $3=B. + "r,r"), + args=[a, b], + dtype=(tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint32', rs=rs) + B = numpy_random(shape, dtype_str='uint32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A - B + D_ref = B - A + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +@pytest.mark.skip("Skip for kunlunxin") +def test_inline_asm_packed_multiple_outputs(device): + if not is_cuda(): + pytest.skip('test_inline_asm is only supported in CUDA') + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) + b = tl.load(B + tl.arange(0, BLOCK)) + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + shape = (512, ) + rs = RandomState(17) + A = numpy_random(shape, dtype_str='uint8', rs=rs) + B = numpy_random(shape, dtype_str='float32', rs=rs) + A_tri = to_triton(A, device=device) + B_tri = to_triton(B, device=device) + C_tri = to_triton(numpy_random(shape, dtype_str='int32', rs=rs), device=device) + D_tri = to_triton(numpy_random(shape, dtype_str='float32', rs=rs), device=device) + kernel[(1, )](A_tri, B_tri, C_tri, D_tri, BLOCK=shape[0]) + + C_ref = A.astype(np.int32) + D_ref = np.maximum(A.astype(np.float32), B) + + np.testing.assert_equal(C_ref, to_numpy(C_tri)) + np.testing.assert_equal(D_ref, to_numpy(D_tri)) + + +# ----------------------- +# test control flow +# ----------------------- + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("lo, hi, iv", [(2**35, 2**35 + 20, 1), (2**35, 2**35 + 20, 2), (2**35, 2**35 + 20, 3), + (15, -16, -1), (15, -16, -2), (15, -16, -3), (-18, -22, -1), (22, 18, -1)]) +def test_for_iv(lo, hi, iv, device): + + @triton.jit + def kernel(Out, lo, hi, iv: tl.constexpr): + acc = 0 + acc = acc.to(tl.int64) + for i in range(lo, hi, iv): + acc += i + tl.store(Out, acc) + + lo = 2**35 + hi = 2**35 + 20 + out = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + kernel[(1, )](out, lo, hi, iv) + assert out[0] == sum(range(lo, hi, iv)) + + +@pytest.mark.interpreter +def test_if_else(device): + + @triton.jit + def kernel(Cond, TrueVal, FalseVal, Out): + if tl.load(Cond): + val = tl.load(TrueVal) + else: + val = tl.load(FalseVal) + tl.store(Out, val) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + true_val = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + false_val = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + cond = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + # True + cond[0] = True + kernel[(1, )](cond, true_val, false_val, out) + assert to_numpy(out)[0] == true_val[0] + # False + cond[0] = False + kernel[(1, )](cond, true_val, false_val, out) + assert to_numpy(out)[0] == false_val[0] + + +@pytest.mark.interpreter +@pytest.mark.parametrize("mode", ["dynamic", "static"]) +def test_if_return(mode, device): + + @triton.jit + def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr): + if mode == "dynamic": + if tl.load(ExitEarly): + tl.store(Out, 0) + return + else: + if cond: + tl.store(Out, 0) + return + tl.store(Out, 1) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + exit_early = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + # exit early path taken + exit_early[0] = 1 + kernel[(1, )](exit_early, out, True, mode) + assert to_numpy(out)[0] == 0 + # exit early path not taken + exit_early[0] = 0 + kernel[(1, )](exit_early, out, False, mode) + assert to_numpy(out)[0] == 1 + + +@triton.jit +def add_fn(x): + return x + 1 + + +@triton.jit(noinline=True) +def add_fn_noinline(x): + return x + 1 + + +@triton.jit +def add_fn_return(x, pid): + if pid == 0: + return x + 1 + else: + return x + 2 + + +@triton.jit +def add_fn_expr(Out, x): + tl.store(Out, x) + + +@triton.jit +def add_fn_static_cond(x, cond: tl.constexpr): + if cond == "": + return x + else: + return x + 1 + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize( + "call_type", + ["attribute", "attribute_jit", "jit", "jit_if", "jit_expr", "jit_static_cond", "jit_noinline", "jit_extern"]) +def test_if_call(call_type, device): + + @triton.jit + def kernel(Out, call_type: tl.constexpr): + pid = tl.program_id(0) + o = tl.load(Out) + if call_type == "attribute": + # call attribute + if pid == 0: + a = o + a = a.to(tl.int32).to(tl.int32) + 1 + o = a + elif call_type == "attribute_jit": + # call attribute and jit function + if pid == 0: + a = o + a = tl.load(Out + add_fn(a) - 1).to(tl.int32) + 1 + o = a + elif call_type == "jit": + if pid == 0: + # regular function call + a = o + a = add_fn(a) + o = a + elif call_type == "jit_if": + # function without end_if block + if pid == 0: + a = o + a = add_fn_return(a, pid) + o = a + elif call_type == "jit_if_exp": + # ifexp expression + if pid == 0: + a = o + a = add_fn(a) if pid == 0 else add_fn_return(a, pid) + o = a + elif call_type == "jit_expr": + # call without return + if pid == 0: + a = o + 1 + add_fn_expr(Out, a) + o = a + elif call_type == "jit_static_cond": + if pid == 0: + a = o + 1 + add_fn_static_cond(o, call_type) + o = a + elif call_type == "jit_noinline": + if pid == 0: + a = o + 1 + add_fn_noinline(a) + o = a + elif call_type == "jit_extern": + if pid == 0: + a = o + 1 + tl.cdiv(a, a) + o = a + + tl.store(Out, o) + + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + kernel[(1, )](out, call_type) + assert to_numpy(out)[0] == 1 + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("_cond1", [True, False]) +@pytest.mark.parametrize("_cond2", [True, False]) +@pytest.mark.parametrize("_cond3", [True, False]) +def test_nested_if_else_return(_cond1, _cond2, _cond3, device): + + @triton.jit + def kernel(Cond1, Cond2, Cond3, Val1, Val2, Val3, Out): + val = 0 + if tl.load(Cond1): + if tl.load(Cond2): + val = tl.load(Val1) + else: + return + else: + if tl.load(Cond3): + val = tl.load(Val2) + else: + val = tl.load(Val3) + tl.store(Out, val) + + out = to_triton(np.full((1, ), -1, dtype=np.int32), device=device) + cond1 = to_triton(np.full((1, ), _cond1, dtype=np.int32), device=device) + cond2 = to_triton(np.full((1, ), _cond2, dtype=np.int32), device=device) + cond3 = to_triton(np.full((1, ), _cond3, dtype=np.int32), device=device) + val1 = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + val2 = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + val3 = to_triton(np.full((1, ), 3, dtype=np.int32), device=device) + kernel[(1, )](cond1, cond2, cond3, val1, val2, val3, out) + targets = { + (True, True, True): val1[0], + (True, True, False): val1[0], + (True, False, True): out[0], + (True, False, False): out[0], + (False, True, True): val2[0], + (False, True, False): val3[0], + (False, False, True): val2[0], + (False, False, False): val3[0], + } + assert out[0] == targets[(_cond1, _cond2, _cond3)] + + +@pytest.mark.interpreter +def test_while(device): + + @triton.jit + def kernel(InitI, Bound, CutOff, OutI, OutInitI, OutJ): + init_i = tl.load(InitI) + curr_i = init_i + j = 0 + # Check that init_i is not updated by the loop + while j < tl.load(Bound): + curr_i = curr_i + (j == tl.load(CutOff)) + j += 1 + tl.store(OutInitI, init_i) + tl.store(OutI, curr_i) + tl.store(OutJ, j) + + out_i = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + out_j = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + init_i = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + out_init_i = to_triton(np.full((1, ), 0, dtype=np.int32), device=device) + bound = to_triton(np.full((1, ), 10, dtype=np.int32), device=device) + cut_off = to_triton(np.full((1, ), 5, dtype=np.int32), device=device) + kernel[(1, )](init_i, bound, cut_off, out_i, out_init_i, out_j) + assert out_init_i[0] == init_i[0] + assert out_i[0] == init_i[0] + 1 + assert out_j[0] == bound[0] + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +def test_nested_while(device): + + @triton.jit + def nested_while(data, countPtr): + for i in range(10): + count = tl.load(countPtr) + while count > 0: + tl.store(data, tl.load(data) + 1.0) + count = count - 2 + + counter = torch.tensor([8], dtype=torch.int32, device=device) + data = torch.zeros((1, ), device=device, dtype=torch.float32) + nested_while[(1, )](data, counter) + assert data[0] == 40 + + +# ----------------------- +# test extra +# ----------------------- + + +def test_num_threads(device): + if is_hip(): + pytest.skip("test_num_threads is not supported in HIP") + + @triton.jit + def kernel(Out): + num_threads: tl.constexpr = tl.extra.cuda.num_threads() + offs = tl.arange(0, num_threads) + tl.store(Out + offs, 1) + + num_threads = 256 + out = to_triton(np.zeros((num_threads, ), dtype=np.int32), device=device) + kernel[(1, )](out, num_warps=num_threads // 32) + assert torch.sum(out) == 256 + + +@pytest.mark.skip("Skip for kunlunxin") +def test_globaltimer(device): + if is_hip(): + pytest.skip("test_globaltimer is not supported in HIP") + check_cuda_or_hip(device) + + @triton.jit + def kernel(Out1, Out2): + start = tl.extra.cuda.globaltimer() + off = tl.arange(0, 128) + for i in range(10000): + tl.store(Out1 + off, tl.load(Out1 + off) + 1) + end = tl.extra.cuda.globaltimer() + tl.store(Out2, end - start) + + out1 = to_triton(np.zeros((128, ), dtype=np.int64), device=device) + out2 = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + h = kernel[(1, )](out1, out2) + assert out2[0] > 0 + assert h.asm["ptx"].count("%globaltimer") == 2 + + +@pytest.mark.skip("Skip for kunlunxin") +def test_smid(device): + if is_hip(): + pytest.skip("test_smid is not supported in HIP") + check_cuda_or_hip(device) + + @triton.jit + def kernel(Out): + tl.store(Out + tl.program_id(0), tl.extra.cuda.smid()) + + out = to_triton(np.zeros((1024, ), dtype=np.int32), device=device) + h = kernel[(out.shape[0], )](out) + assert out.sort()[0].unique().shape[0] > 0 + assert h.asm["ptx"].count("%smid") == 1 + + +# ----------------------- +# test layout conversions +# ----------------------- +# TODO: backend should be tested separately + +layouts = [ + BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [2, THREADS_PER_WARP // 2], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([8, 1], [16, THREADS_PER_WARP // 16], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([4, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), +] + +intermediate_layouts = [ + None, + SharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), +] + + +def compute_rep_shape(layout): + if type(layout) is BlockedLayout: + warp_shape = np.multiply(layout.sz_per_thread, layout.threads_per_warp) + rep_shape = np.multiply(warp_shape, layout.warps_per_cta) + return rep_shape + else: + assert False, "TODO: support compute_rep_shape for layout " + str(type(layout)) + + +# This function gives a lower bound approximation of scratch buffer shape for convert_layout operation +def compute_scratch_buffer_shape(src_layout, dst_layout, shape): + src_rep_shape = compute_rep_shape(src_layout) + dst_rep_shape = compute_rep_shape(dst_layout) + full_scratch_shape = np.maximum(src_rep_shape, dst_rep_shape) + return np.minimum(full_scratch_shape, shape) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]]) +@pytest.mark.parametrize("dtype", ['float16']) +@pytest.mark.parametrize("src_layout", layouts) +@pytest.mark.parametrize("interm_layout", intermediate_layouts) +@pytest.mark.parametrize("dst_layout", layouts) +def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): + if (M == 1 or N == 1) and interm_layout: + # TODO(jlebar): These OOB accesses don't even hit an assert in the + # compiler, and some of them return the wrong result instead of + # crashing! + pytest.skip("Out of bound access when maxPhase > 1") + if str(src_layout) == str(dst_layout): + pytest.skip() + if is_hip(): + try: + scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N)) + except AssertionError: + pytest.skip("Can't compute scratch buffer size") + lds_size = 65536 + # consider int32 dtype in scratch buffer size, + # because it is the largest dtype used in convert_layout in this test + int32_size = 4 + # skip even if scratch buffer equal to lds_size, because real scratch buffer is typically larger due to padding + if scratch_shape[0] * scratch_shape[1] * int32_size >= lds_size: + pytest.skip("Scratch buffer is too large") + + layouts = f""" + #src = {src_layout} + #dst = {dst_layout} + """ if interm_layout is None else f""" + #src = {src_layout} + #interm = {interm_layout} + #dst = {dst_layout} + """ + + conversion = f""" + %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ if interm_layout is None else f""" + %15 = triton_gpu.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !tt.memdesc<{M}x{N}xi32, #interm> + %16 = triton_gpu.local_load %15 : !tt.memdesc<{M}x{N}xi32, #interm> -> tensor<{M}x{N}xi32, #src> + %17 = triton_gpu.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !tt.memdesc<{M}x{N}xf16, #interm> + %18 = triton_gpu.local_load %17 : !tt.memdesc<{M}x{N}xf16, #interm> -> tensor<{M}x{N}xf16, #src> + + %12 = triton_gpu.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ + + ir = layouts + f""" + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} +}} +""" + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x, device=device) + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) + + assert torch.equal(z, x) + + +mma_pairs = [ + [ + MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 1), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 1), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + [ + MmaLayout((2, 1), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8]), + MmaLayout((2, 1), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8]), + ], + # Mma -> mma support is TODO on Hopper (and Volta) + # [ + # MmaLayout((3, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], + # [ + # MmaLayout((3, 0), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 0), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], + # [ + # MmaLayout((3, 1), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 1), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], + # [ + # MmaLayout((3, 1), [2, 8], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # MmaLayout((3, 1), [8, 2], [1, 1], [1, 1], [0, 1], [16, 8, 16]), + # ], +] + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]]) +@pytest.mark.parametrize("dtype", ['float16']) +@pytest.mark.parametrize("mma_pair", mma_pairs) +def test_convertmma2mma(M, N, mma_pair, dtype, device): + if is_hip(): + pytest.skip("test_mma2mma is not supported in HIP") + + src_layout, _ = mma_pair + num_warps = np.cumprod(src_layout.warps_per_cta)[-1] + + def do_test(src_layout, dst_layout): + layouts = f""" + #src = {src_layout} + #dst = {dst_layout} + """ + + conversion = f""" + %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> + """ + + ir = layouts + f""" + module attributes {{"triton_gpu.num-warps" = {num_warps} : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr, #src> + %3 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr, #dst> + tt.return + }} + }} + """ + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x) + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + kernel[(1, 1, 1)](x.data_ptr(), z.data_ptr()) + + assert torch.equal(z, x) + + do_test(mma_pair[0], mma_pair[1]) + do_test(mma_pair[1], mma_pair[0]) + + +@pytest.mark.interpreter +def test_load_scalar_with_mask(device): + + @triton.jit + def kernel(Input, Index, Out, N: int): + index = tl.load(Index) + scalar = tl.load(Input + index, mask=index < N, other=0) + tl.store(Out, scalar, mask=index < N) + + Index = torch.tensor([0], dtype=torch.int32, device=device) + Input = torch.tensor([0], dtype=torch.int32, device=device) + Out = torch.empty_like(Index, device=device) + kernel[(1, )](Input, Index, Out, Index.numel()) + assert Out.data[0] == 0 + + +# This test is used to test our own PTX codegen for float16 and int16 conversions +# maybe delete it later after ptxas has been fixed +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("dtype_str", ['float16', 'int16']) +def test_ptx_cast(dtype_str, device): + + @triton.jit + def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + rbase = tl.arange(0, RBLOCK)[None, :] + x0 = xindex + _tmp4 = (tl.zeros([XBLOCK, RBLOCK], dtype) - 10000).to(dtype) + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + r1 = rindex + tmp0 = tl.load(in_ptr0 + (r1 + (197 * x0)), rmask & xmask).to(dtype) + tmp1 = 2 + tmp2 = tmp0 * tmp1 + tmp3 = tmp2.to(dtype) + tmp5 = _tmp4 < tmp3 + _tmp4 = tl.where(rmask & xmask & tmp5, tmp3, _tmp4) + tl.store(out_ptr2 + (r1 + (197 * x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask) + + torch.manual_seed(123) + if dtype_str == 'int16': + torch_dtype = torch.int16 + triton_dtype = tl.int32 + else: + torch_dtype = torch.float16 + triton_dtype = tl.float32 + + s0 = 4 + buf11 = -torch.ones((6 * s0, 197, 197), device=device, dtype=torch_dtype) + buf14 = -torch.ones((s0, 6, 197, 197), device=device, dtype=torch_dtype) + kernel[(4728, )](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2) + assert buf14.to(torch.float32).mean() == -2.0 + + +# ----------------------- +# test fp8 -> fp32 dot +# ----------------------- + + +def f8_to_f16(x, dtype): + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + +@triton.jit +def matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + low_precision_acc: tl.constexpr, # + num_pipeline_stages: tl.constexpr = 3 # +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_pipeline_stages): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_ptrs, accumulator) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv', 'float8e4b15']) +@pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128]) +def test_fp8_dot_acc(in_type_str, low_precision_acc, device): + if is_hip(): + pytest.skip('test_fp8_dot_acc for HIP currently broken in upstream.') + if is_cuda(): + cc = torch.cuda.get_device_capability() + if cc[0] >= 9 and in_type_str == "float8e4b15": + pytest.skip("Dot op does not support fp8e4b15 on CUDA arch >= 90") + check_type_supported(in_type_str, device) + M, N, K = 128, 256, 256 + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 128 + A = numpy_random((M, K), dtype_str=in_type_str) + B = numpy_random((K, N), dtype_str=in_type_str) + C = torch.empty((M, N), dtype=torch.float32, device=device) + num_warps = 8 + a = to_triton(A, device=device, dst_type=in_type_str) + b = to_triton(B, device=device, dst_type=in_type_str) + grid = (triton.cdiv(M, BLOCK_M), 1) + matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), C.stride(1), + BLOCK_M, BLOCK_N, BLOCK_K, low_precision_acc, num_warps=num_warps) + torch_a = torch.from_numpy(A).to(device=device) + th_a = f8_to_f16(torch_a, in_type_str) + torch_b = torch.from_numpy(B).to(device=device) + th_b = f8_to_f16(torch_b, in_type_str) + ref_out = torch.matmul(th_a, th_b).to(torch.float32) + if in_type_str == 'float8e4nv': + torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01) + elif low_precision_acc > 32: + torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) + else: + torch.testing.assert_close(ref_out, C) + + +# ----------------------- +# test enable_fp_fusion +# ----------------------- + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("enable_fp_fusion", [False, True]) +def test_enable_fp_fusion(enable_fp_fusion, device): + if is_hip(): + pytest.skip( + 'test_enable_fp_fusion for HIP currently broken in https://github.com/triton-lang/triton. Use https://github.com/ROCmSoftwarePlatform/triton' + ) + + # Sequential multiply add can be fused by backend + @triton.jit + def mul_add(data): + ptrs = data + tl.arange(0, 128) + tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0) + + data = torch.randn((128, ), device=device, dtype=torch.float32) + h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion) + + if not is_cuda(): + return + found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None + assert found_fma == enable_fp_fusion + + +# ----------------------- +# test propagate_nan +# ----------------------- + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +@pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL']) +@pytest.mark.parametrize("func", ['minimum', 'maximum', 'clamp']) +def test_propagate_nan(dtype, propagate_nan, func, device): + + @triton.jit + def kernel(A, B, C, propagate_nan: tl.constexpr, func: tl.constexpr): + if func == 'clamp': + tl.store( + C, + getattr(tl, func)(tl.load(A), -tl.load(B), tl.load(B), + propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + else: + tl.store(C, + getattr(tl, func)(tl.load(A), tl.load(B), propagate_nan=getattr(tl.PropagateNan, propagate_nan))) + + for mode in ['A', 'B', 'both']: + if func == 'clamp' and mode == 'B': + # clamp does not guarantee propagation from 'min' and 'max' args + continue + A = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) + if mode == 'A' or mode == 'both': A[0] = torch.nan + B = torch.randn((1, ), device=device, dtype=getattr(torch, dtype)) + if mode == 'B' or mode == 'both': B[0] = torch.nan + C = torch.zeros_like(A, device=device, dtype=getattr(torch, dtype)) + kernel[(1, )](A, B, C, propagate_nan, func) + + if mode == 'both' or propagate_nan == 'ALL': + assert torch.isnan(C[0]) + else: + assert not torch.isnan(C[0]) + + +# ----------------------- +# test clamp +# ----------------------- + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +def test_clamp(dtype, device): + + @triton.jit + def kernel(x_ptr, min_ptr, max_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + min = tl.load(min_ptr + off, mask=mask) + max = tl.load(max_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, min, max), mask=mask) + ref_val = tl.minimum(tl.maximum(x, min), max) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + a = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + b = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + min = torch.min(a, b) + max = torch.max(a, b) + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, min, max, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# Test for symmetric clamp(x, -limit, limit), as it may go through optimized +# codegen in the backends +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", ['float16', 'float32']) +def test_clamp_symmetric(dtype, device): + + @triton.jit + def kernel(x_ptr, limit_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): + + off = tl.arange(0, BLOCK_SIZE) + mask = off < N + x = tl.load(x_ptr + off, mask=mask) + limit = tl.load(limit_ptr + off, mask=mask) + out = out_ptr + off + ref = ref_ptr + off + + tl.store(out, tl.clamp(x, -limit, limit), mask=mask) + ref_val = tl.minimum(tl.maximum(x, -limit), limit) + tl.store(ref, ref_val, mask=mask) + + size = 128 + + x = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)) + limit = torch.randn((size, ), device=device, dtype=getattr(torch, dtype)).abs() + out = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + ref = torch.zeros_like(x, device=device, dtype=getattr(torch, dtype)) + + kernel[(size, )](x, limit, out, ref, x.numel(), BLOCK_SIZE=size) + + torch.testing.assert_close(out, ref) + + +# ----------------------- +# test iterators +# ----------------------- + + +@pytest.mark.interpreter +def test_static_range(device): + + @triton.jit + def loop_kernel(Z, N: tl.constexpr, step: tl.constexpr): + acc = 0 + for i in tl.static_range(0, N, step=step): + acc += i + tl.store(Z, acc) + + N = 100 + step = 7 + Out = torch.empty(1, dtype=torch.int32, device=device) + loop_kernel[(1, )](Out, N, step) + Acc = torch.tensor([0], dtype=torch.int32, device=device) + for i in range(0, N, step): + Acc += i + assert (Out == Acc).all(), (Out, Acc) + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +def test_tl_range(device): + if is_hip(): + pytest.skip("test_tl_range is not supported in HIP") + M, N, K = 64, 64, 512 + BLOCK_M, BLOCK_N, BLOCK_K = M, N, 64 + a = torch.randn((M, K), device=device, dtype=torch.float16) + b = torch.randn((K, N), device=device, dtype=torch.float16) + c = torch.empty((M, N), dtype=torch.float32, device=device) + pgm = matmul_kernel[ + 1, + ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N, + BLOCK_K, 0, num_pipeline_stages=5) + ref_out = torch.matmul(a, b).to(torch.float32) + if is_interpreter(): + # GPU invokes tensor core for float16 matmul, which is not supported in interpreter. + # Thus we use a higher tolerance + torch.testing.assert_close(ref_out, c, rtol=1e-2, atol=1e-1) + else: + torch.testing.assert_close(ref_out, c, rtol=1e-3, atol=1e-3) + if device in ['cuda']: + capability = torch.cuda.get_device_capability() + if capability[0] >= 8: + ptx = pgm.asm['ptx'] + # check that the loop got pipelined with the right number of stages. + assert 'cp.async.wait_group 0x6' in ptx + + +@triton.jit(noinline=True) +def maxnreg_noinline1(X): + tl.store(X, 0) + + +@triton.jit(noinline=True) +def maxnreg_noinline2(X): + tl.store(X, 0) + + +@pytest.mark.skip("Skip for kunlunxin") +def test_maxnreg(device): + assert not is_interpreter(), "this test won't work with the interpreter" + if is_hip(): + pytest.skip('maxnreg only works on CUDA') + + # triton kernel + @triton.jit + def kernel(X): + maxnreg_noinline1(X) + tl.store(X, 0) + maxnreg_noinline2(X) + + X = torch.empty(1, dtype=torch.int32, device=device) + k = kernel[(1, )](X, maxnreg=42) + + # Ensure that .maxnreg is set on the kernel function (marked with .entry) + # and not on either of the noinline functions (marked with .func). + try: + assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) + assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) + except AssertionError: + print("Failing ptx:\n", k.asm["ptx"]) + raise + + +@pytest.mark.interpreter +def test_temp_var_in_loop(device): + + @triton.jit + def temp_in_loop(Z, N: tl.constexpr, BLOCK: tl.constexpr): + acc = tl.full((BLOCK, ), 0, dtype=tl.int32) + for i in range(N): + if i == 0: + temp = tl.full((BLOCK, ), 2, dtype=tl.int32) + acc = temp + else: + acc += tl.full((BLOCK, ), 1, dtype=tl.int32) + # re-use the temp variable and make sure to check that it isn't creating incorrect IR. + temp = tl.full((BLOCK, ), 1, dtype=tl.int32) + acc += temp + z = Z + tl.arange(0, BLOCK) + tl.store(z, acc) + + N = 10 + BLOCK = 32 + out = torch.empty((BLOCK, ), dtype=torch.int32, device=device) + temp_in_loop[(1, )](out, N, BLOCK) + acc = torch.full((BLOCK, ), 0, dtype=torch.int32, device=device) + for i in range(N): + if i == 0: + temp = torch.full((BLOCK, ), 2, dtype=torch.int32, device=device) + acc = temp + else: + acc += torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + temp = torch.full((BLOCK, ), 1, dtype=torch.int32, device=device) + acc += temp + assert (acc == out).all() diff --git a/third_party/xpu/python/test/unit/language/test_decorator.py b/third_party/xpu/python/test/unit/language/test_decorator.py new file mode 100644 index 000000000..66371ba60 --- /dev/null +++ b/third_party/xpu/python/test/unit/language/test_decorator.py @@ -0,0 +1,48 @@ +import torch + +import triton +import triton.language as tl +import pytest + + +def test_decorator_with_def(device): + + def triton_heuristics_pointwise(**kwargs): + + def decorator(func): + return func + + return decorator + + # "def" might appear in a decorator call, e.g. a hash string argument. + # This test makes sure the compiler can find the right position of function + # definition. + @triton_heuristics_pointwise(inductor_meta={'backend_hash': 'def0aeffabe53b3f8'}, ) + @triton.jit + def kernel(): + pass + + try: + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + except Exception as e: + pytest.fail(f"triton compile failed with error: {e}") + + +def test_triton_heuristic(device): + N = 1023 + src = torch.empty(N, device=device) + dst = torch.zeros(N, device=device) + + @triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE': 32})], key=['N'], warmup=1, rep=1) + @triton.heuristics({'EVEN_N': lambda nargs: nargs['N'] % 2 == 0}) # test kwargs + @triton.heuristics({'EVEN_src': lambda nargs: nargs['src'].data_ptr() % 2 == 0}) # test args + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr, EVEN_N: tl.constexpr, EVEN_src: tl.constexpr): + tl.store(dst, EVEN_N) + tl.store(dst + 1, EVEN_src) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + assert dst[0].item() == 0.0 + assert dst[1].item() == 1.0 + assert _kernel.base_fn.__name__ == "_kernel" diff --git a/third_party/xpu/python/test/unit/language/test_line_info.py b/third_party/xpu/python/test/unit/language/test_line_info.py new file mode 100644 index 000000000..6421c7309 --- /dev/null +++ b/third_party/xpu/python/test/unit/language/test_line_info.py @@ -0,0 +1,171 @@ +import subprocess +import tempfile + +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def kernel_single(X, + Y, + BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + tl.store(Y + tl.arange(0, BLOCK), x) + + +@triton.jit +def device_inline(x): + return x + x + + +@triton.jit +def kernel_call(X, + Y, + BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = device_inline(x) + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit(noinline=True) +def device_noinline(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = x + x + tl.store(Y + tl.arange(0, BLOCK), y) + + +@triton.jit +def kernel_call_noinline(X, Y, BLOCK: tl.constexpr): + device_noinline(X, Y, BLOCK) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK": 128}, num_warps=4), + ], + key=[], +) +@triton.jit +def kernel_autotune(X, Y, SIZE: tl.constexpr, BLOCK: tl.constexpr): + for i in range(0, SIZE, BLOCK): + x = tl.load(X + i + tl.arange(0, BLOCK)) + tl.store(Y + i + tl.arange(0, BLOCK), x) + + +# AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +# Since the + symbol will take effect in the dot op after combination, +# it seems making sense to annotate with the same line as dot. +@triton.jit +def kernel_dot_combine(x): + c = tl.full((32, 32), 4, dtype=tl.int8) + a = (tl.arange(0, 32)[:, None] + tl.arange(0, 32)[None, :]).to(tl.int8) + d = tl.dot(a, a) + d = d + c + tl.device_print("", d) + + +def get_disassembler_command_and_debug_line_format(): + """Gets backend specific disassembler information. + + Returns a tuple: (object file kind, disassembler tool command, + debug line anchor, debug line file and line number separator). + """ + backend = triton.runtime.driver.active.get_current_target().backend + + if backend == "cuda": + from triton.backends.nvidia.compiler import _path_to_binary + nvdisasm, _ = _path_to_binary("nvdisasm") + return ("cubin", [nvdisasm, "-g"], "## File", ",") + + if backend == "hip": + import shutil + # Try to find llvm-objdump from the current PATH to disassmble hsaco. + tool = shutil.which("llvm-objdump") + if tool is not None: + return ("hsaco", [tool, "-D", "-l", "--arch=amdgcn"], ";", ":") + raise RuntimeError("llvm-objdump not found in PATH") + + raise RuntimeError(f"unknown backend {backend}") + + +def extract_file_lines(command, anchor, separator, asm): + fd, path = tempfile.mkstemp() + with open(fd, 'wb') as cubin: + cubin.write(asm) + asm = subprocess.check_output(command + [path]).decode("utf-8") + file_lines = [] + lines = asm.splitlines() + for line in lines: + # We are looking for an anchor string and a separator between the file name and line number. + if anchor in line and separator in line: + entries = line[line.index(anchor):].split(separator) + if len(entries) == 2 and all(len(e) != 0 for e in entries): + file_lines.append((entries[0].strip(), entries[1].strip())) + return file_lines + + +def check_file_lines(file_lines, file_name, lineno, should_contain=True): + """ + Check if the file name and line number is in the file_lines + + Args: + file_lines: list of (file_name, line_number) + file_name: file name + lineno: line number, -1 means do not check line number + should_contain: whether the file name and line number should be in the file_lines + """ + for file, line in file_lines: + if lineno == -1: + if file_name in file: + return True + if file_name in file and str(lineno) in line: + return should_contain + return not should_contain + + +func_types = ["single", "call", "call_noinline", "autotune", "dot_combine"] + + +@pytest.mark.parametrize("func", func_types) +def test_line_info(func: str): + try: + obj_kind, command, anchor, separator = get_disassembler_command_and_debug_line_format() + except BaseException: + pytest.skip("disassembler is not available") + + shape = (128, ) + kernel_info = {} + if func == "single": + kernel_info = kernel_single.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,)) + elif func == "call": + kernel_info = kernel_call.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,)) + elif func == "call_noinline": + kernel_info = kernel_call_noinline.warmup(torch.float32, torch.float32, BLOCK=shape[0], grid=(1,)) + elif func == "autotune": + kernel_info = kernel_autotune.warmup(torch.float32, torch.float32, SIZE=shape[0], grid=(1,))[0] + elif func == "dot_combine": + kernel_info = kernel_dot_combine.warmup(20, grid=(1,)) + + file_lines = extract_file_lines(command, anchor, separator, kernel_info.asm[obj_kind]) + if func == "single": + assert (check_file_lines(file_lines, "test_line_info.py", 15)) + assert (check_file_lines(file_lines, "test_line_info.py", 16)) + elif func == "call": + assert (check_file_lines(file_lines, "test_line_info.py", 28)) + assert (check_file_lines(file_lines, "test_line_info.py", 21)) + assert (check_file_lines(file_lines, "test_line_info.py", 30)) + elif func == "call_noinline": + assert (check_file_lines(file_lines, "test_line_info.py", 42)) + assert (check_file_lines(file_lines, "test_line_info.py", 35)) + assert (check_file_lines(file_lines, "test_line_info.py", 36)) + assert (check_file_lines(file_lines, "test_line_info.py", 37)) + elif func == "autotune": + assert (check_file_lines(file_lines, "test_line_info.py", 53)) + assert (check_file_lines(file_lines, "test_line_info.py", 54)) + assert (check_file_lines(file_lines, "test_line_info.py", 55)) + elif func == "dot_combine": + assert (check_file_lines(file_lines, "test_line_info.py", 65)) + assert (check_file_lines(file_lines, "test_line_info.py", 66, should_contain=False)) diff --git a/third_party/xpu/python/test/unit/language/test_random.py b/third_party/xpu/python/test/unit/language/test_random.py new file mode 100644 index 000000000..de48a5f00 --- /dev/null +++ b/third_party/xpu/python/test/unit/language/test_random.py @@ -0,0 +1,261 @@ +import numpy as np +import pytest +import scipy.stats +import torch + +import triton +import triton.language as tl + +##################################### +# Reference Philox Implementation +##################################### + + +class PhiloxConfig: + + def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): + self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) + self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE) + self.PHILOX_KEY_A = np.array(PHILOX_KEY_A, dtype=DTYPE) + self.PHILOX_KEY_B = np.array(PHILOX_KEY_B, dtype=DTYPE) + self.DTYPE = DTYPE + + +# This is better for GPU +PHILOX_32 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B9, + PHILOX_KEY_B=0xBB67AE85, + PHILOX_ROUND_A=0xD2511F53, + PHILOX_ROUND_B=0xCD9E8D57, + DTYPE=np.uint32, +) + +# This is what numpy implements +PHILOX_64 = PhiloxConfig( + PHILOX_KEY_A=0x9E3779B97F4A7C15, + PHILOX_KEY_B=0xBB67AE8584CAA73B, + PHILOX_ROUND_A=0xD2E7470EE14C6C93, + PHILOX_ROUND_B=0xCA5A826395121157, + DTYPE=np.uint64, +) + + +class CustomPhilox4x: + + def __init__(self, seed, config): + self._config = config + seed = self._into_pieces(seed) + self._key = np.array(seed[:2], dtype=self._dtype) + self._counter = np.array((0, 0) + seed[2:], dtype=self._dtype) + + @property + def _dtype(self): + return self._config.DTYPE + + def _into_pieces(self, n, pad=4): + res = [] + while len(res) < pad: + res.append(np.array(n, dtype=self._dtype)) + n >>= (np.dtype(self._dtype).itemsize * 8) + assert n == 0 + return tuple(res) + + def _multiply_low_high(self, a, b): + low = a * b + high = int(a) * int(b) + high = np.array(high >> (np.dtype(self._dtype).itemsize * 8), dtype=self._dtype) + return low, high + + def _single_round(self, counter, key): + lo0, hi0 = self._multiply_low_high(self._config.PHILOX_ROUND_A, counter[0]) + lo1, hi1 = self._multiply_low_high(self._config.PHILOX_ROUND_B, counter[2]) + ret0 = hi1 ^ counter[1] ^ key[0] + ret1 = lo1 + ret2 = hi0 ^ counter[3] ^ key[1] + ret3 = lo0 + return np.array([ret0, ret1, ret2, ret3], dtype=self._dtype) + + def _raise_key(self, key): + pk = [self._config.PHILOX_KEY_A, self._config.PHILOX_KEY_B] + return key + np.array(pk, dtype=self._dtype) + + def random_raw(self): + counter = self._counter + key = self._key + for _ in range(10): + counter = self._single_round(counter, key) + key = self._raise_key(key) + self.advance(1) + return counter + + def advance(self, n_steps): + self._counter[0] += n_steps + assert self._counter[0] < 2**32, "FIXME: doesn't work for large offsets" + + +class CustomPhilox(CustomPhilox4x): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.buffer = [] + + def random_raw(self): + if len(self.buffer) == 0: + self.buffer = list(super().random_raw())[::-1] + return int(self.buffer.pop()) + + +##################################### +# Unit Tests +##################################### + +BLOCK: tl.constexpr = 1024 + +# test generation of random uint32 + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in ['10', '4,53', '400'] + for seed in [0, 42, 124, 54, 0xffffffff, 0x0000000fcafeb0ba] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_randint(size, seed, device, dtype, const_seed): + if dtype == 'int64': + pytest.skip("Skip for kunlunxin") + size = list(map(int, size.split(','))) + torch_dtype = getattr(torch, dtype) + numpy_dtype = getattr(np, f"u{dtype}") + config = {'int32': PHILOX_32, 'int64': PHILOX_64}[dtype] + + @triton.jit + def kernel(X, N, seed): + pid = tl.program_id(0).to(X.dtype.element_ty) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr): + pid = tl.program_id(0).to(X.dtype.element_ty) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randint(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch_dtype, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + if const_seed: + const_kernel[grid](x, N, seed=seed) + else: + kernel[grid](x, N, seed) + out_tri = x.cpu().numpy().astype(numpy_dtype).flatten().tolist() + # reference result + gen = CustomPhilox4x(seed, config=config) + out_ref = [gen.random_raw()[0] for _ in out_tri] + assert out_tri == out_ref + + +# test uniform PRNG + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in [100000] + for seed in [0, 42, 124, 54] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_rand(size, seed, dtype, device, const_seed): + if dtype == 'int64': + pytest.skip("Skip for kunlunxin") + + @triton.jit + def kernel(X, N, seed, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.rand(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + if const_seed: + const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype)) + else: + kernel[grid](x, N, seed, dtype=getattr(tl, dtype)) + assert all((x >= 0) & (x <= 1)) + assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 + + +# test normal PRNG + + +@pytest.mark.interpreter +@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed) + for size in [100000] + for seed in [0, 42, 124, 54] + for dtype in ['int32', 'int64'] + for const_seed in [True, False]]) +def test_randn(size, seed, dtype, device, const_seed): + if dtype == 'int64': + pytest.skip("Skip for kunlunxin") + + @triton.jit + def kernel(X, N, seed, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randn(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + @triton.jit + def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr): + pid = tl.program_id(0).to(dtype) + offset = pid * BLOCK + tl.arange(0, BLOCK) + rand = tl.randn(seed, offset) + tl.store(X + offset, rand, mask=offset < N) + + # triton result + x = torch.empty(size, dtype=torch.float32, device=device) + N = x.numel() + grid = (triton.cdiv(N, BLOCK), ) + if const_seed: + const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype)) + else: + kernel[grid](x, N, seed, dtype=getattr(tl, dtype)) + assert abs(x.mean()) < 1e-2 + assert abs(x.std() - 1) < 1e-2 + + +# tl.rand() should never produce >=1.0 + + +@pytest.mark.interpreter +@pytest.mark.parametrize('dtype', ['int32', 'int64']) +def test_rand_limits(dtype, device): + + @triton.jit + def kernel(input, output, n: tl.constexpr): + idx = tl.arange(0, n) + x = tl.load(input + idx) + y = tl.random.uint_to_uniform_float(x) + tl.store(output + idx, y) + + torch_dtype = getattr(torch, dtype) + min_max_int = torch.tensor([ + torch.iinfo(torch_dtype).min, + torch.iinfo(torch_dtype).max, + ], dtype=torch_dtype, device=device) + output = torch.empty(2, dtype=torch.float32, device=device) + kernel[(1, )](min_max_int, output, 2) + + assert output[0] == output[1] + assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0 diff --git a/third_party/xpu/python/test/unit/language/test_reproducer.py b/third_party/xpu/python/test/unit/language/test_reproducer.py new file mode 100644 index 000000000..a045e8f30 --- /dev/null +++ b/third_party/xpu/python/test/unit/language/test_reproducer.py @@ -0,0 +1,42 @@ +import os +import shutil + +import pytest + +import torch +import triton +import re + + +@triton.jit +def triton_(): + return + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda") +def test_reproducer(): + tmpdir = ".tmp" + reproducer = 'triton-reproducer.mlir' + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + if os.path.exists(reproducer): + os.remove(reproducer) + os.environ["TRITON_CACHE_DIR"] = tmpdir + os.environ["TRITON_REPRODUCER_PATH"] = reproducer + triton_[(1, )]() + foundPipeline = "" + with open(reproducer, 'r') as f: + line = f.read() + if 'pipeline:' in line: + foundPipeline = line + if 0 == len(foundPipeline): + raise Exception("Failed to find pipeline info in reproducer file.") + + ttgir_to_llvm_pass = re.compile("convert-triton-{{.*}}gpu-to-llvm") + if ttgir_to_llvm_pass.search(foundPipeline): + raise Exception("Failed to find triton passes in pipeline") + # cleanup + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + if os.path.exists(reproducer): + os.remove(reproducer) diff --git a/third_party/xpu/python/test/unit/language/test_standard.py b/third_party/xpu/python/test/unit/language/test_standard.py new file mode 100644 index 000000000..9bae8e883 --- /dev/null +++ b/third_party/xpu/python/test/unit/language/test_standard.py @@ -0,0 +1,77 @@ +import triton +import pytest +import torch +import triton.language as tl + +from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random + +# --------------- +# test maximum/minimum ops +# --------------- + + +# TODO: Tests with unsigned integers failed at compilation stage. +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", int_dtypes + uint_dtypes + float_dtypes + ["bfloat16"]) +@pytest.mark.parametrize("op", ["maximum", "minimum"]) +def test_maximum_minium(dtype, op, device): + expr = f'tl.{op}(x, y)' + numpy_expr = f'np.{op}(x, y)' + _test_binary(dtype, dtype, expr, numpy_expr, device=device) + + +# --------------- +# test sort op +# --------------- + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize("descending", [False, True]) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32']) +def test_sort(M, N, descending, dtype_str, device): + + @triton.jit + def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr): + offx = tl.arange(0, M) + offy = tl.arange(0, N) * M + off2d = offx[None, :] + offy[:, None] + x = tl.load(X + off2d) + x = tl.sort(x, descending=descending) + tl.store(Z + off2d, x) + + x = numpy_random((N, M), dtype_str=dtype_str) + x = torch.from_numpy(x).to(device) + y = torch.sort(x, descending=descending)[0] + z = torch.empty_like(x) + sort_kernel[(1, )](x, z, N, M, descending, num_warps=8) + assert (y == z).all(), (y, z) + + +# --------------- +# test flip op +# --------------- + + +@pytest.mark.skip("Skip for kunlunxin") +@pytest.mark.interpreter +@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32']) +def test_flip(M, N, dtype_str, device): + + @triton.jit + def flip_kernel(X, Z, N: tl.constexpr, M: tl.constexpr): + offx = tl.arange(0, M) + offy = tl.arange(0, N) * M + off2d = offx[None, :] + offy[:, None] + x = tl.load(X + off2d) + x = tl.flip(x) + tl.store(Z + off2d, x) + + x = numpy_random((N, M), dtype_str=dtype_str) + x = torch.from_numpy(x).to(device) + y = torch.flip(x, (1, )) + z = torch.empty_like(x, device=device) + flip_kernel[(1, )](x, z, N, M, num_warps=8) + assert (y == z).all(), (y, z) diff --git a/third_party/xpu/python/test/unit/language/test_subprocess.py b/third_party/xpu/python/test/unit/language/test_subprocess.py new file mode 100644 index 000000000..71a0a5175 --- /dev/null +++ b/third_party/xpu/python/test/unit/language/test_subprocess.py @@ -0,0 +1,161 @@ +import itertools +import os +import subprocess +import sys +from collections import Counter + +import pytest + +pytest.skip("Skip for kunlunxin", allow_module_level=True) + +dir_path = os.path.dirname(os.path.realpath(__file__)) +print_path = os.path.join(dir_path, "print_helper.py") +assert_path = os.path.join(dir_path, "assert_helper.py") + +# TODO: bfloat16 after LLVM-15 +assert_types = ["device_assert", "device_assert_passes", "assert", "static_assert", "no_debug", "double_assert"] +nested_types = [(caller, callee) for caller in ["true", "false", "none"] for callee in ["true", "false", "none"]] +torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"] + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +# TODO: Print with multiple operands + + +@pytest.mark.interpreter +@pytest.mark.parametrize("func_type, data_type", [("device_print", data_type) for data_type in torch_types] + [ + ("print", "int32"), + ("static_print", "int32"), + ("no_arg_print", "int32"), + ("print_no_arg", "int32"), + ("device_print_large", "int32"), + ("print_multiple_args", "int32"), + ("device_print_multiple_args", "int32"), + ("device_print_hex", "int16"), + ("device_print_hex", "int32"), + ("device_print_hex", "int64"), + ("device_print_pointer", "int32"), +]) +def test_print(func_type: str, data_type: str): + proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=False) + outs, err = proc.communicate() + assert proc.returncode == 0 + + if is_interpreter() and func_type != "static_assert": + # Interpreter uses a different format for device_print + # Only check if there's no error + assert err == b'' + return + + outs = [line for line in outs.decode("UTF-8").split("\n") if line] + # The total number of elements in the 1-D tensor to print. + N = 128 + + # Format is + # pid (, , ) idx (, , ...) (operand ) + expected_lines = Counter() + if func_type == "print" or func_type == "device_print": + for i in range(N): + line = f"pid (0, 0, 0) idx ({i:3}) x: {i}" + if data_type.startswith("float"): + line += ".000000" + expected_lines[line] = 1 + elif func_type == "device_print_hex": + for i in range(N): + line = f"pid (0, 0, 0) idx ({i:3}) x: 0x" + if data_type == "int16": + line += f"{i:04x}" + if data_type == "int32": + line += f"{i:08x}" + if data_type == "int64": + line += f"{i:016x}" + expected_lines[line] = 1 + elif func_type == "static_print": + expected_lines[f" int32[constexpr[{N}]]"] = 1 + elif func_type == "no_arg_print": + expected_lines["pid (0, 0, 0) idx (): 0"] = N + elif func_type == "print_no_arg": + expected_lines["pid (0, 0, 0) no arg"] = N + elif func_type == "device_print_large": + for i, j, k in itertools.product(range(2), range(64), range(N)): + expected_lines[f"pid (0, {i}, 0) idx ({j:2}, {k:3}) x: 1"] = 1 + elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args": + for i in range(N): + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 0) {i}"] = 1 + expected_lines[f"pid (0, 0, 0) idx ({i:3}): (operand 1) 1"] = 1 + elif func_type == "device_print_pointer": + for i in range(N): + expected_lines[f"pid (0, 0, 0) idx ({i:3}) ptr: 0x"] = 1 + + actual_lines = Counter() + for line in outs: + # Trim the exact pointer address in the output--they can change per run. + line = (line.split(':')[0] + ": 0x") if func_type == "device_print_pointer" else line + actual_lines[line] += 1 + + diff = Counter(actual_lines) + diff.subtract(expected_lines) + for line, delta in diff.items(): + if delta == 0: + continue + print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)') + assert all(delta == 0 for delta in diff.values()) + + +@pytest.mark.parametrize("func_type", assert_types) +def test_assert(func_type: str): + # The total number of elements in the 1-D tensor to assert on. + N = 128 + + os.environ["TRITON_DEBUG"] = "1" + proc = subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, + shell=False) + _, errs = proc.communicate() + errs = errs.splitlines() + num_errs = 0 + for err in errs: + if "x != 0" in err.decode("utf-8", errors="ignore"): + num_errs += 1 + + # Check for segfaults. + assert all("segmentation fault" not in line.decode("utf-8", errors="ignore").lower() for line in errs) + + os.environ["TRITON_DEBUG"] = "0" + if func_type == "static_assert" or func_type == "device_assert_passes": + assert num_errs == 0 + else: + assert num_errs == N - 1 + + +@pytest.mark.parametrize("caller_type, callee_type", nested_types) +def test_assert_nested(caller_type, callee_type): + # The total number of elements in the 1-D tensor to assert on. + N = 128 + + proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=False) + _, errs = proc.communicate() + errs = errs.splitlines() + num_errs = 0 + for err in errs: + if "x != 0" in err.decode("utf-8", errors="ignore"): + num_errs += 1 + if caller_type == "none": + if callee_type == "true": + assert num_errs == N - 1 + else: + assert num_errs == 0 + elif caller_type == "true": + if callee_type == "false": + assert num_errs == 0 + else: + assert num_errs == N - 1 + elif caller_type == "false": + if callee_type == "true": + assert num_errs == N - 1 + else: + assert num_errs == 0 diff --git a/third_party/xpu/python/test/unit/operators/conftest.py b/third_party/xpu/python/test/unit/operators/conftest.py new file mode 100644 index 000000000..ab9ff1130 --- /dev/null +++ b/third_party/xpu/python/test/unit/operators/conftest.py @@ -0,0 +1,11 @@ +# content of conftest.py + + +def pytest_configure(config): + config.addinivalue_line("markers", "interpreter: indicate whether interpreter supports the test") + + +def pytest_sessionfinish(session, exitstatus): + # If all tests are skipped (exit code 5), modify the exit code to 0 + if exitstatus == 5: + session.exitstatus = 0 diff --git a/third_party/xpu/python/test/unit/operators/test_blocksparse.py b/third_party/xpu/python/test/unit/operators/test_blocksparse.py new file mode 100644 index 000000000..35cc33cfa --- /dev/null +++ b/third_party/xpu/python/test/unit/operators/test_blocksparse.py @@ -0,0 +1,239 @@ +import pytest +import torch + +import triton +import triton.ops + +pytest.skip("Skip for kunlunxin", allow_module_level=True) + + +def is_hip_mi200(): + target = triton.runtime.driver.active.get_current_target() + return target.backend == 'hip' and target.arch == 'gfx90a' + + +def sparsify_tensor(x, mask, block): + ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device) + for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))): + ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] + return ret + + +def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32): + if data is None: + data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device) + ref_ret = data + ref_ret = ref_ret * alpha + beta + ref_ret = ref_ret.half().to(dtype) + if trans: + ref_ret = ref_ret.t().requires_grad_() + ref_ret = ref_ret.detach().requires_grad_() + tri_ret = ref_ret.clone().detach().requires_grad_() + return ref_ret, tri_ret + + +def mask_tensor(x, mask, block, value=0): + ret = x.clone() + for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)): + ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value + return ret + + +@pytest.mark.parametrize("MODE", ["sdd", "dds", "dsd"]) +@pytest.mark.parametrize("TRANS_A", [False, True]) +@pytest.mark.parametrize("TRANS_B", [False, True]) +@pytest.mark.parametrize("BLOCK", [16, 32, 64]) +@pytest.mark.parametrize("DTYPE", [torch.float16]) +def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, device, Z=3, H=2, M=512, N=384, K=256): + seed = 0 + torch.manual_seed(seed) + is_sdd = MODE == "sdd" + is_dsd = MODE == "dsd" + is_dds = MODE == "dds" + do_sparsify = lambda x: sparsify_tensor(x, layout, BLOCK) + do_mask = lambda x: mask_tensor(x, layout, BLOCK) + # create inputs + # create op + a_shape = (Z, H, K, M) if TRANS_A else (Z, H, M, K) + b_shape = (Z, H, N, K) if TRANS_B else (Z, H, K, N) + c_shape = (Z, H, M, N) + shape = { + "sdd": (M, N), + "dsd": (a_shape[2], a_shape[3]), + "dds": (b_shape[2], b_shape[3]), + }[MODE] + layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) + layout[1, 2, :] = 0 + layout[1, :, 1] = 0 + # create data + a_ref, a_tri = make_pair(a_shape, alpha=.1, dtype=DTYPE) + b_ref, b_tri = make_pair(b_shape, alpha=.1, dtype=DTYPE) + dc_ref, dc_tri = make_pair(c_shape, dtype=DTYPE) + # compute [torch] + dc_ref = do_mask(dc_ref) if is_sdd else dc_ref + a_ref = do_mask(a_ref) if is_dsd else a_ref + b_ref = do_mask(b_ref) if is_dds else b_ref + a_ref.retain_grad() + b_ref.retain_grad() + c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref, b_ref.transpose(2, 3) if TRANS_B else b_ref) + c_ref.backward(dc_ref) + c_ref = do_sparsify(c_ref) if is_sdd else c_ref + da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad + db_ref = do_sparsify(b_ref.grad) if is_dds else b_ref.grad + # triton result + dc_tri = do_sparsify(dc_tri) if is_sdd else dc_tri + a_tri = do_sparsify(a_tri) if is_dsd else a_tri + b_tri = do_sparsify(b_tri) if is_dds else b_tri + a_tri.retain_grad() + b_tri.retain_grad() + op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device=device) + c_tri = op(a_tri, b_tri) + c_tri.backward(dc_tri) + da_tri = a_tri.grad + db_tri = b_tri.grad + + # Bigger tolerance for AMD MI200 devices. + # MI200 devices use reduced precision fp16 and bf16 and flush input and + # output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + tol = {'atol': 1e-3, 'rtol': 0} if is_hip_mi200() else {} + + # compare + torch.testing.assert_close(c_ref, c_tri, **tol) + torch.testing.assert_close(da_ref, da_tri, **tol) + torch.testing.assert_close(db_ref, db_tri, **tol) + + +configs = [ + (16, 256), + (32, 576), + (64, 1871), + (128, 2511), +] + + +@pytest.mark.parametrize("is_dense", [False, True]) +@pytest.mark.parametrize("BLOCK, WIDTH", configs) +def test_softmax(BLOCK, WIDTH, is_dense, device, Z=2, H=2, is_causal=True, scale=0.4): + # set seed + torch.random.manual_seed(0) + Z, H, M, N = 2, 3, WIDTH, WIDTH + # initialize layout + # make sure each row has at least one non-zero element + layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) + if is_dense: + layout[:] = 1 + else: + layout[1, 2, :] = 0 + layout[1, :, 1] = 0 + # initialize data + a_shape = (Z, H, M, N) + a_ref, a_tri = make_pair(a_shape) + dout_ref, dout_tri = make_pair(a_shape) + # compute [torch] + a_ref = mask_tensor(a_ref, layout, BLOCK, value=float("-inf")) + a_ref.retain_grad() + at_mask = torch.ones((M, N), device=device) + if is_causal: + at_mask = torch.tril(at_mask) + M = at_mask[None, None, :, :] + torch.zeros_like(a_ref) + a_ref[M == 0] = float("-inf") + out_ref = torch.softmax(a_ref * scale, -1) + out_ref.backward(dout_ref) + out_ref = sparsify_tensor(out_ref, layout, BLOCK) + da_ref = sparsify_tensor(a_ref.grad, layout, BLOCK) + # compute [triton] + a_tri = sparsify_tensor(a_tri, layout, BLOCK) + a_tri.retain_grad() + dout_tri = sparsify_tensor(dout_tri, layout, BLOCK) + op = triton.ops.blocksparse.softmax(layout, BLOCK, device=device, is_dense=is_dense) + out_tri = op(a_tri, scale=scale, is_causal=is_causal) + out_tri.backward(dout_tri) + da_tri = a_tri.grad + # compare + torch.testing.assert_close(out_tri, out_ref, equal_nan=True) + torch.testing.assert_close(da_tri, da_ref, equal_nan=True) + + +@pytest.mark.parametrize("block", [16, 32, 64]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_attention_fwd_bwd( + block, + dtype, + device, + input_scale=1.0, + scale=1 / 8.0, + n_ctx=256, + batch_size=2, + n_heads=2, +): + capability = torch.cuda.get_device_capability() + if capability[0] < 7: + pytest.skip("Only test tl.dot() on devices with sm >= 70") + + # inputs + qkv_shape = (batch_size, n_heads, n_ctx, 64) + qkvs = [ + torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3) + ] + + # Triton: + n_blocks = n_ctx // block + layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long)) + query, key, value = [x.clone() for x in qkvs] + query.retain_grad() + key.retain_grad() + value.retain_grad() + attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale) + # ad hoc loss + loss = (attn_out**2).mean() + loss.backward() + grads = [query.grad, key.grad, value.grad] + + # Torch version: + torch_q, torch_k, torch_v = [x.clone() for x in qkvs] + attn_mask = torch.ones([n_ctx, n_ctx], device=device, dtype=dtype) + attn_mask = torch.tril(attn_mask, diagonal=0) + attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda())) + torch_q.retain_grad() + torch_k.retain_grad() + torch_v.retain_grad() + scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k) + scores = scores + attn_mask + probs = torch.softmax(scores, dim=-1) + torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v) + # ad hoc loss + torch_loss = (torch_attn_out**2).mean() + torch_loss.backward() + torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad] + + # comparison + # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...") + torch.testing.assert_close(loss, torch_loss, atol=1e-3, rtol=0) + + # Bigger tolerance for AMD MI200 devices. + # MI200 devices use reduced precision fp16 and bf16 and flush input and + # output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + tol = {'atol': 1e-3, 'rtol': 0} if is_hip_mi200() else {} + for g1, g2 in zip(grads, torch_grads): + torch.testing.assert_close(g1, g2, **tol) + + +@pytest.mark.parametrize("block", [16, 32, 64]) +def triton_attention( + layout, + block: int, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, +): + sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, + device=value.device) + sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, + device=value.device) + sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device) + + w = sparse_dot_sdd_nt(query, key) + w = sparse_softmax(w, scale=scale, is_causal=True) + a = sparse_dot_dsd_nn(w, value) + return a diff --git a/third_party/xpu/python/test/unit/operators/test_cross_entropy.py b/third_party/xpu/python/test/unit/operators/test_cross_entropy.py new file mode 100644 index 000000000..d40977d4d --- /dev/null +++ b/third_party/xpu/python/test/unit/operators/test_cross_entropy.py @@ -0,0 +1,43 @@ +import pytest +import torch + +import triton +import triton.ops + +pytest.skip("Skip for kunlunxin", allow_module_level=True) + + +@pytest.mark.parametrize("M, N, dtype, mode", [ # + (M, N, dtype, mode) + for M in [1024, 821] + for N in [512, 857, 1871, 2089, 8573, 31000] + for dtype in ['float16', 'float32'] + for mode in ['forward', 'backward'] +]) +def test_op(M, N, dtype, mode, device): + capability = torch.cuda.get_device_capability() + if capability[0] < 8 and dtype == "bfloat16": + pytest.skip("Only test bfloat16 on devices with sm >= 80") + dtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16, 'float32': torch.float32}[dtype] + # create inputs + x = torch.randn(M, N, dtype=dtype, device=device, requires_grad=True) + idx = 4 + torch.ones(M, dtype=torch.int64, device=device) + # forward pass + tt_y = triton.ops.cross_entropy(x, idx) + th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx) + if mode == 'forward': + torch.testing.assert_close(th_y, tt_y) + # backward pass + elif mode == 'backward': + dy = torch.randn_like(tt_y) + # triton backward + tt_y.backward(dy) + tt_dx = x.grad.clone() + # torch backward + x.grad = None + th_y.backward(dy) + th_dx = x.grad.clone() + if dtype == torch.float16: + torch.testing.assert_close(th_dx, tt_dx, rtol=0.001, atol=0.001) + else: + torch.testing.assert_close(th_dx, tt_dx) diff --git a/python/test/unit/operators/test_flash_attention.py b/third_party/xpu/python/test/unit/operators/test_flash_attention.py similarity index 98% rename from python/test/unit/operators/test_flash_attention.py rename to third_party/xpu/python/test/unit/operators/test_flash_attention.py index f5cff538e..d440d6fd1 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/third_party/xpu/python/test/unit/operators/test_flash_attention.py @@ -5,6 +5,8 @@ import triton import triton.ops +pytest.skip("Skip for kunlunxin", allow_module_level=True) + @pytest.mark.interpreter @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ # diff --git a/third_party/xpu/python/test/unit/operators/test_inductor.py b/third_party/xpu/python/test/unit/operators/test_inductor.py new file mode 100644 index 000000000..8d7df12a5 --- /dev/null +++ b/third_party/xpu/python/test/unit/operators/test_inductor.py @@ -0,0 +1,200 @@ +import pytest +import torch + +import triton +import triton.language as tl + +pytest.skip("Skip for kunlunxin", allow_module_level=True) + + +def test_normalization_with_remat(device): + + @triton.jit + def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr): + xnumel = 512 + rnumel = 4096 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:, None] + xmask = xindex < xnumel + rbase = tl.arange(0, RBLOCK)[None, :] + x3 = xindex + x0 = xindex % 64 + tmp1 = tl.load(in_ptr0 + (x0), xmask) + tmp3 = tl.load(in_ptr1 + (x0), xmask) + tmp11 = tl.load(in_ptr2 + (x0), xmask) + tmp13 = tl.load(in_ptr3 + (x0), xmask) + _tmp17 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0 + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + r2 = rindex + tmp0 = tl.load(in_out_ptr0 + (r2 + (4096 * x3)), rmask & xmask, eviction_policy='evict_last', other=0) + tmp2 = tmp0 - tmp1 + tmp4 = 1e-05 + tmp5 = tmp3 + tmp4 + tmp6 = tl.sqrt(tmp5) + tmp7 = 1 / tmp6 + tmp8 = 1.0 + tmp9 = tmp7 * tmp8 + tmp10 = tmp2 * tmp9 + tmp12 = tmp10 * tmp11 + tmp14 = tmp12 + tmp13 + _tmp17 = tl.where(rmask & xmask, _tmp17 + tmp14, _tmp17) + tl.store(in_out_ptr0 + (r2 + (4096 * x3) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp14, rmask & xmask) + tmp17 = tl.sum(_tmp17, 1)[:, None] + tmp18 = 4096.0 + tmp19 = tmp17 / tmp18 + tl.store(in_out_ptr1 + (x3 + tl.zeros([XBLOCK, 1], tl.int32)), tmp19, xmask) + + torch.manual_seed(123) + + buf14 = torch.rand(8, 64, 64, 64, device=device) + buf16 = torch.rand(8, 1, 64, device=device) + arg114_1 = torch.rand(64, device=device) + arg115_1 = torch.rand(64, device=device) + arg8_1 = torch.rand(64, device=device) + arg9_1 = torch.rand(64, device=device) + triton_[(512, )](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048) + torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0) + + +def test_avg_pool_bw(device): + + @triton.jit + def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + x1 = (xindex // 8) % 8 + x0 = xindex % 8 + x2 = (xindex // 64) + x5 = xindex + tmp0 = (-1) + x1 + tmp1 = (-1) + x0 + tmp2 = 2 + x1 + tmp3 = 2 + x0 + tmp4 = 0 + tmp5 = tl.where(tmp0 != tmp0, tmp0, tl.where(tmp0 > tmp4, tmp0, tmp4)) + tmp6 = tl.where(tmp1 != tmp1, tmp1, tl.where(tmp1 > tmp4, tmp1, tmp4)) + tmp7 = 8 + tmp8 = tl.where(tmp2 != tmp2, tmp2, tl.where(tmp2 < tmp7, tmp2, tmp7)) + tmp9 = tl.where(tmp3 != tmp3, tmp3, tl.where(tmp3 < tmp7, tmp3, tmp7)) + tmp10 = tmp5 + tmp4 + tmp11 = tmp6 + tmp4 + tmp12 = 1 + tmp13 = tmp8 - tmp12 + tmp14 = tl.where(tmp10 != tmp10, tmp10, tl.where(tmp10 < tmp13, tmp10, tmp13)) + tmp15 = tmp9 - tmp12 + tmp16 = tl.where(tmp11 != tmp11, tmp11, tl.where(tmp11 < tmp15, tmp11, tmp15)) + tmp17 = tl.load(in_ptr0 + (tmp16 + (8 * tmp14) + (64 * x2)), None).to(tl.float32) + tmp18 = tmp17 / 9 + tmp19 = tmp10 < tmp8 + tmp20 = tmp11 < tmp9 + tmp21 = tmp19 & tmp20 + tmp22 = 0.0 + tmp23 = tl.where(tmp21, tmp18, tmp22) + tmp24 = tmp6 + tmp12 + tmp25 = tl.where(tmp24 != tmp24, tmp24, tl.where(tmp24 < tmp15, tmp24, tmp15)) + tmp26 = tl.load(in_ptr0 + (tmp25 + (8 * tmp14) + (64 * x2)), None).to(tl.float32) + tmp27 = tmp26 / 9 + tmp28 = tmp24 < tmp9 + tmp29 = tmp19 & tmp28 + tmp30 = tmp23 + tmp27 + tmp31 = tl.where(tmp29, tmp30, tmp23) + tmp32 = 2 + tmp33 = tmp6 + tmp32 + tmp34 = tl.where(tmp33 != tmp33, tmp33, tl.where(tmp33 < tmp15, tmp33, tmp15)) + tmp35 = tl.load(in_ptr0 + (tmp34 + (8 * tmp14) + (64 * x2)), None).to(tl.float32) + tmp36 = tmp35 / 9 + tmp37 = tmp33 < tmp9 + tmp38 = tmp19 & tmp37 + tmp39 = tmp31 + tmp36 + tmp40 = tl.where(tmp38, tmp39, tmp31) + tmp41 = tmp5 + tmp12 + tmp42 = tl.where(tmp41 != tmp41, tmp41, tl.where(tmp41 < tmp13, tmp41, tmp13)) + tmp43 = tl.load(in_ptr0 + (tmp16 + (8 * tmp42) + (64 * x2)), None).to(tl.float32) + tmp44 = tmp43 / 9 + tmp45 = tmp41 < tmp8 + tmp46 = tmp45 & tmp20 + tmp47 = tmp40 + tmp44 + tmp48 = tl.where(tmp46, tmp47, tmp40) + tmp49 = tl.load(in_ptr0 + (tmp25 + (8 * tmp42) + (64 * x2)), None).to(tl.float32) + tmp50 = tmp49 / 9 + tmp51 = tmp45 & tmp28 + tmp52 = tmp48 + tmp50 + tmp53 = tl.where(tmp51, tmp52, tmp48) + tmp54 = tl.load(in_ptr0 + (tmp34 + (8 * tmp42) + (64 * x2)), None).to(tl.float32) + tmp55 = tmp54 / 9 + tmp56 = tmp45 & tmp37 + tmp57 = tmp53 + tmp55 + tmp58 = tl.where(tmp56, tmp57, tmp53) + tmp59 = tmp5 + tmp32 + tmp60 = tl.where(tmp59 != tmp59, tmp59, tl.where(tmp59 < tmp13, tmp59, tmp13)) + tmp61 = tl.load(in_ptr0 + (tmp16 + (8 * tmp60) + (64 * x2)), None).to(tl.float32) + tmp62 = tmp61 / 9 + tmp63 = tmp59 < tmp8 + tmp64 = tmp63 & tmp20 + tmp65 = tmp58 + tmp62 + tmp66 = tl.where(tmp64, tmp65, tmp58) + tmp67 = tl.load(in_ptr0 + (tmp25 + (8 * tmp60) + (64 * x2)), None).to(tl.float32) + tmp68 = tmp67 / 9 + tmp69 = tmp63 & tmp28 + tmp70 = tmp66 + tmp68 + tmp71 = tl.where(tmp69, tmp70, tmp66) + tmp72 = tl.load(in_ptr0 + (tmp34 + (8 * tmp60) + (64 * x2)), None).to(tl.float32) + tmp73 = tmp72 / 9 + tmp74 = tmp63 & tmp37 + tmp75 = tmp71 + tmp73 + tmp76 = tl.where(tmp74, tmp75, tmp71) + tl.store(out_ptr0 + (x5 + tl.zeros([XBLOCK], tl.int32)), tmp76, None) + + inp = torch.ones(8, 2048, 8, 8, device=device, dtype=torch.half) + out = torch.ones_like(inp) * 3 + numel = inp.numel() + triton_[(numel // 1024, )](inp, out, 1024) + out_ref = torch.ones_like(inp) + out_ref[:, :, 1:7, 0::7] = 2 / 3 + out_ref[:, :, 0::7, 1:7] = 2 / 3 + out_ref[:, :, 0::7, 0::7] = 4 / 9 + torch.testing.assert_close(out, out_ref) + + +@pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128]) +@pytest.mark.parametrize("num_warps", [1, 4]) +def test_scan2d_broadcast(RBLOCK, num_warps, device): + + @triton.jit(debug=True) + def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + rindex = tl.arange(0, RBLOCK)[None, :] + xindex = tl.arange(0, XBLOCK)[:, None] + data = tl.load(in_ptr + rindex) + scan = tl.cumsum(data, 1) + expected_max = tl.sum(data, 1) + tl.device_assert(scan <= expected_max) + tl.store(out_ptr + xindex * RBLOCK + rindex, scan) + + XBLOCK = 4 + input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device=device) + output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device=device) + fn[(1, )](input, output, XBLOCK, RBLOCK, num_warps=num_warps) + ref = input.cumsum(1).broadcast_to((XBLOCK, RBLOCK)) + torch.testing.assert_close(output, ref) + + +def test_scan2d_for(device): + + @triton.jit + def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr): + rbase = tl.arange(0, RBLOCK)[None, :] + for roffset in range(0, rnumel, RBLOCK): + rindex = roffset + rbase + rmask = rindex < rnumel + tmp3 = tl.where(rmask, 1, 0) + tmp6 = tl.cumsum(tmp3, 1) + tl.store(out_ptr0 + rindex, tmp6, rmask) + + RBLOCK = 8 + out0 = torch.empty(RBLOCK, device=device, dtype=torch.int64) + fn[(1, )](out0, RBLOCK, RBLOCK) + ref = torch.arange(RBLOCK, device=device, dtype=torch.int64) + 1 + torch.testing.assert_close(out0, ref) diff --git a/python/test/unit/operators/test_matmul.py b/third_party/xpu/python/test/unit/operators/test_matmul.py similarity index 99% rename from python/test/unit/operators/test_matmul.py rename to third_party/xpu/python/test/unit/operators/test_matmul.py index 2feb0727b..9918a7448 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/third_party/xpu/python/test/unit/operators/test_matmul.py @@ -7,6 +7,8 @@ import triton.language as tl import triton.ops +pytest.skip("Skip for kunlunxin", allow_module_level=True) + def is_hip(): return triton.runtime.driver.active.get_current_target().backend == "hip" diff --git a/third_party/xpu/python/test/unit/runtime/test_autotuner.py b/third_party/xpu/python/test/unit/runtime/test_autotuner.py new file mode 100644 index 000000000..65e7535c8 --- /dev/null +++ b/third_party/xpu/python/test/unit/runtime/test_autotuner.py @@ -0,0 +1,134 @@ +import torch + +import triton +import triton.language as tl +import pytest + +pytest.skip("Skip for kunlunxin", allow_module_level=True) + + +@pytest.mark.parametrize('use_cuda_graph', [False, True]) +def test_kwargs(use_cuda_graph: bool): + N = 1024 + src = torch.empty(N, device='cuda') + dst = torch.empty(N, device='cuda') + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N) + _kernel[grid](dst=dst, src=src, N=N) + + +def test_restore(): + N = 1024 + src = torch.zeros(N, device='cuda') + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + @triton.autotune(configs=configs, key=['N'], restore_value=['src'], warmup=1, rep=1) + @triton.jit + def _kernel(src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + 1 + tl.store(src + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](src, N) + triton.testing.assert_close(src, torch.ones_like(src)) + + +def test_hooks(): + # Autotuner's pre- and post- hooks should be called the same number of times + N = 4096 + src = torch.zeros(N, device='cuda') + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 4096}), triton.Config(kwargs={'BLOCK_SIZE': 32})] + + values = {"counter": 0, "has_exception": False} + + def _pre_hook(*args, **kwargs): + values["counter"] += 1 + + def _post_hook(*args, exception): + values["counter"] -= 1 + if exception is not None: + values["has_exception"] = True + assert values["counter"] == 0 + + @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, pre_hook=_pre_hook, post_hook=_post_hook) + @triton.heuristics({"N_STAGES": lambda nargs: 100 if nargs['N'] == 4096 else 4}) + @triton.jit + def _kernel(src, N, N_STAGES: tl.constexpr, BLOCK_SIZE: tl.constexpr): + offsets = tl.arange(0, BLOCK_SIZE) + max_iters = tl.cdiv(N, BLOCK_SIZE) + for _ in tl.range(max_iters, num_stages=N_STAGES): + x = tl.load(src + offsets, mask=offsets < N) + tl.store(src + offsets, x, mask=offsets < N) + offsets += BLOCK_SIZE + + _kernel[(1, )](src, N) + + # On NVIDIA GPUs: + # The tunning knob `num_stages` can be set by users. + # This will cause out of resources when N_STAGES = 100 + # shared memory bytes = N_STAGES * BLOCK_SIZE * sizeof(float) + # On AMD GPUs: + # `num_stages` is a fixed value of 2, so it won't cause out of resources + if triton.runtime.driver.active.get_current_target().backend == "cuda": + assert values["has_exception"] is True + else: + assert values["has_exception"] is False + + +@pytest.mark.parametrize('with_perf_model', [False, True]) +def test_prune_configs(with_perf_model: bool): + N = 1024 + src = torch.empty(N, device='cuda') + dst = torch.empty(N, device='cuda') + records = {} + + def early_config_prune(configs, named_args, **kwargs): + records['run_early_config_prune'] = True + if "N" in kwargs and kwargs["N"] == 1024: + records['capture_kwargs'] = True + if "dst" in named_args and "src" in named_args and len(named_args) == 2: + records['capture_named_args'] = True + return [configs[0]] + + def perf_model(*args, **kwargs): + records['run_perf_model'] = True + return kwargs['BLOCK_SIZE'] + + configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + + if with_perf_model: + prune_configs_by = {'perf_model': perf_model, 'top_k': 1} + else: + prune_configs_by = {'early_config_prune': early_config_prune} + + @triton.autotune(configs=configs, key=['N'], prune_configs_by=prune_configs_by, warmup=1, rep=1) + @triton.jit + def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offsets, mask=offsets < N) + tl.store(dst + offsets, x, mask=offsets < N) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) + _kernel[grid](dst, src, N=N) + torch.testing.assert_close(src, dst) + if with_perf_model: + assert len(records) == 1 + assert records['run_perf_model'] + else: + assert len(records) == 3 + assert records['run_early_config_prune'] + assert records['capture_kwargs'] + assert records['capture_named_args'] diff --git a/third_party/xpu/python/test/unit/runtime/test_bindings.py b/third_party/xpu/python/test/unit/runtime/test_bindings.py new file mode 100644 index 000000000..c48ba9b4a --- /dev/null +++ b/third_party/xpu/python/test/unit/runtime/test_bindings.py @@ -0,0 +1,81 @@ +import triton +import triton.language as tl + +import torch + + +@triton.jit +def add_helper(x, y): + return x + y + + +@triton.jit +def add_kernel( + in_ptr0, + in_ptr1, + n_elements, + out_ptr, + BLOCK_SIZE: "tl.constexpr", +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = add_helper(x, y) + tl.store(out_ptr + offsets, output, mask=mask) + + +def test_module_walk(): + """ + Test the MLIR bindings exposed for the out-ot-tree walk. + """ + + def walk_fn(op): + name = op.get_name() + for i in range(op.get_num_results()): + op.get_result(i).id() + for i in range(op.get_num_operands()): + op.get_operand(i).id() + for i in range(op.get_num_regions()): + op.get_region(i).id() + block = op.get_block() + if block is not None: + block.id() + for i in range(block.get_num_arguments()): + block.get_argument(i) + if name == "tt.func": + op.get_str_attr("sym_name") + if name == "tt.call": + op.get_flat_symbol_ref_attr("callee") + + kernel = add_kernel + args = [ + torch.empty((32, 32), device="cuda"), # in_ptr0 + torch.empty((32, 32), device="cuda"), # in_ptr1 + 1024, # n_elements + torch.empty((32, 32), device="cuda"), # out_ptr + 16, # BLOCK_SIZE + ] + src = triton.compiler.compiler.ASTSource( + fn=kernel, + signature={i: kernel._type_of(kernel._key_of(arg)) + for i, arg in enumerate(args) + if i not in kernel.constexprs}, + constants={i: arg + for i, arg in enumerate(args) + if not isinstance(arg, torch.Tensor)}, + attrs=kernel._get_config(*args, ), + ) + + context = triton._C.libtriton.ir.context() + target = triton.runtime.driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) + options = backend.parse_options(dict()) + codegen_fns = dict() + triton._C.libtriton.ir.load_dialects(context) + backend.load_dialects(context) + + ttir_module = src.make_ir(options, codegen_fns, context) + ttir_module.walk(walk_fn) diff --git a/third_party/xpu/python/test/unit/runtime/test_cache.py b/third_party/xpu/python/test/unit/runtime/test_cache.py new file mode 100644 index 000000000..4705924f9 --- /dev/null +++ b/third_party/xpu/python/test/unit/runtime/test_cache.py @@ -0,0 +1,536 @@ +import importlib.util +import itertools +import os +import shutil +import tempfile + +import pytest +import torch + +import triton +import triton.language as tl +from triton.runtime.jit import JITFunction + +tmpdir = ".tmp" + + +@triton.jit +def function_1(i): + i = i + 1 + i = function_2(i) + return i + + +@triton.jit +def function_2(i): + i = i + 1 + return i + + +@triton.jit +def combine_fn(a, b): + return COMBINE_OP # noqa: F821 + + +@triton.jit +def kernel(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit(do_not_specialize=["i"]) +def kernel_nospec(X, i, BLOCK: tl.constexpr): + i = i + 1 + i = function_1(i) + tl.store(X, i) + + +@triton.jit +def kernel_with_combine_fn(X, BLOCK: tl.constexpr): + i = tl.arange(0, BLOCK) + i = REDUCE_OR_SCAN(i, 0, combine_fn) # noqa: F821 + tl.store(X, i) + + +def apply_src_change(target, old, new): + kernel.hash = None + function_1.hash = None + function_2.hash = None + function_1.src = function_1.src.replace(old, new) + target.src = target.src.replace(old, new) + ret = target.cache_key + target.src = target.src.replace(new, old) + return ret + + +def test_nochange(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 1') + assert baseline == updated + + +def test_toplevel_change(): + baseline = kernel.cache_key + updated = apply_src_change(kernel, 'i + 1', 'i + 2') + assert baseline != updated + + +def test_nested1_change(): + baseline = kernel.cache_key + updated = apply_src_change(function_1, 'i + 1', 'i + 2') + assert baseline != updated + + +def test_combine_fn_change(): + # Test that tl.reduce and associative_scan calls include + # the combine_fn in the hash + + orig_combine_fn_src = combine_fn.src + orig_kernel_src = kernel_with_combine_fn.src + seen_keys = set() + + for reduce_or_scan, combine_op in itertools.product( + ["tl.reduce", "tl.associative_scan"], + ["a + b", "a * b"], + ): + combine_fn.src = orig_combine_fn_src.replace("COMBINE_OP", combine_op) + kernel_with_combine_fn.src = orig_kernel_src.replace("REDUCE_OR_SCAN", reduce_or_scan) + try: + key = kernel_with_combine_fn.cache_key + finally: + combine_fn.src = orig_combine_fn_src + kernel_with_combine_fn.src = orig_kernel_src + + kernel_with_combine_fn.hash = None + combine_fn.hash = None + + assert key not in seen_keys + seen_keys.add(key) + + +def write_and_load_module(code, num_extra_lines): + with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f: + f.write(('# extra line\n' * num_extra_lines) + code) + f.flush() + spec = importlib.util.spec_from_file_location("module.name", f.name) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_changed_line_numbers_invalidate_cache(): + from textwrap import dedent + code = dedent(""" + import triton + @triton.jit + def test_kernel(i): + i = i + 1 + """) + orig_mod = write_and_load_module(code, 0) + orig_cache_key = orig_mod.test_kernel.cache_key + + updated_mod = write_and_load_module(code, 1) + updated_cache_key = updated_mod.test_kernel.cache_key + assert orig_cache_key != updated_cache_key + + +def reset_tmp_dir(): + os.environ["TRITON_CACHE_DIR"] = tmpdir + if os.path.exists(tmpdir): + # https://stackoverflow.com/questions/303200/how-do-i-remove-delete-a-folder-that-is-not-empty + shutil.rmtree(tmpdir, ignore_errors=True) + + +def test_reuse(): + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + reset_tmp_dir() + x = torch.empty(1, dtype=torch.int32, device='cuda') + for i in range(10): + kernel[(1, )](x, 1, BLOCK=1024) + assert counter == 1 + + +@pytest.mark.parametrize('mode', ['enable', 'disable']) +def test_specialize(mode): + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + reset_tmp_dir() + x = torch.empty(1, dtype=torch.int32, device='cuda') + function = {'enable': kernel, 'disable': kernel_nospec}[mode] + target = {'enable': 3, 'disable': 1}[mode] + for i in [1, 2, 4, 8, 16, 32]: + function[(1, )](x, i, BLOCK=512) + assert counter == target + + +def test_annotation(): + + @triton.jit + def kernel(X, i: tl.int32): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + + device = torch.cuda.current_device() + kernel[(1, )](x, 1) + kernel[(1, )](x, 8) + kernel[(1, )](x, 16) + kernel[(1, )](x, 17) + assert len(kernel.cache[device]) == 3 + + +GLOBAL_DEFAULT_ARG = 1 + + +def test_kernel_default_arg(): + global GLOBAL_DEFAULT_ARG + + @triton.jit + def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + kernel[(1, )](x) + assert x == torch.ones_like(x) + + # Changing the global variable should not change the default argument in + # `kernel`. That value gets set at the time the function is declared. + GLOBAL_DEFAULT_ARG = 2 + kernel[(1, )](x) + assert x == torch.ones_like(x) + + device = torch.cuda.current_device() + assert len(kernel.cache[device]) == 1 + + +GLOBAL_VAR: tl.constexpr = 1 + + +def test_kernel_global_var_change(): + global GLOBAL_VAR + + @triton.jit + def kernel(X): + tl.store(X, GLOBAL_VAR) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + kernel[(1, )](x) + assert x == torch.ones_like(x) + + GLOBAL_VAR = 2 + with pytest.raises(RuntimeError) as e: + kernel[(1, )](x) + + assert "global variable" in str(e.value).lower() + + +GLOBAL = 42 # noqa + + +def test_local_shadows_global(): + global GLOBAL + + @triton.jit + def kernel(): + _, GLOBAL = 0, 0 # noqa + a = GLOBAL # noqa + + # No error because the `GLOBAL` we're modifying is not the same `GLOBAL` as + # inside the kernel. + GLOBAL = 42 + kernel[(1, )]() + GLOBAL = 43 + kernel[(1, )]() + + +CONSTEXPR_GLOBAL: tl.constexpr = 42 + + +def test_local_does_not_shadow_global(): + global CONSTEXPR_GLOBAL + + @triton.jit + def kernel(): + a = CONSTEXPR_GLOBAL # noqa + _, CONSTEXPR_GLOBAL = 0, 0 # noqa + + CONSTEXPR_GLOBAL = 42 + kernel[(1, )]() + CONSTEXPR_GLOBAL = 43 + + # Error because the `CONSTEXPR_GLOBAL` we're modifying is the same + # `CONSTEXPR_GLOBAL` that's read inside `kernel`. (Alternatively, we could + # make this kernel an error altogether, as it is if it's a pure Python + # function -- the fact that we store to `CONSTEXPR_GLOBAL` inside the kernel + # makes the first read a read of the local variable, which doesn't exist + # yet.) + with pytest.raises(RuntimeError): + kernel[(1, )]() + + +CONFLICTING_GLOBAL: tl.constexpr = 0 + + +@triton.jit +def conflicting_global_inner(): + a = CONFLICTING_GLOBAL # noqa + + +def test_conflicting_global_in_inner_function(): + global CONFLICTING_GLOBAL + + @triton.jit + def kernel1(): + a = CONFLICTING_GLOBAL # noqa + conflicting_global_inner() + + @triton.jit + def kernel2(): + a = CONFLICTING_GLOBAL #noqa + conflicting_global_inner() + + kernel1[(1, )]() + + # This should be an error because kernel2 calls conflicting_global_inner, + # which saw a value for 42 for the global when it was first compiled. + CONFLICTING_GLOBAL = 1 + + with pytest.raises(RuntimeError) as e: + kernel2[(1, )]() + + assert "Global variable CONFLICTING_GLOBAL has value" in str(e.value) + + +def test_use_builtin(): + + @triton.jit + def kernel(): + a = float(0) # noqa + + # No error about the value of `float` changing. + kernel[(1, )]() + kernel[(1, )]() + + +def test_no_cache_module_as_global(): + + @triton.jit + def kernel(): + tl.arange(0, 16) + + kernel[(1, )]() + # `tl` should not be entered into used_global_vals + assert not kernel.used_global_vals + + +BUILTIN_AS_GLOBAL = tl.int32 + + +def test_cache_builtin_as_global(): + global BUILTIN_AS_GLOBAL + + @triton.jit + def kernel(): + x = BUILTIN_AS_GLOBAL # noqa + + kernel[(1, )]() + + BUILTIN_AS_GLOBAL = tl.int64 + with pytest.raises(RuntimeError) as e: + kernel[(1, )]() + + assert "global variable" in str(e.value).lower() + + +@triton.jit +def no_cache_callable_inner(): + pass + + +def test_no_cache_callable(): + + @triton.jit + def kernel(): + no_cache_callable_inner() + + kernel[(1, )]() + # `no_cache_callable_inner` should not be entered into used_global_vals. + assert not kernel.used_global_vals + + +def test_constexpr_not_callable() -> None: + + @triton.jit + def kernel(X, c: tl.constexpr): + tl.store(X, 2) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + error = False + try: + kernel[(1, )](x, c="str") + except BaseException: + error = True + assert error is False + # try and catch + try: + kernel[(1, )](x, c=tl.abs) + except BaseException: + error = True + assert error is True + + +def test_jit_warmup_cache() -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + args = [ + torch.randn(32, dtype=torch.float32, device="cuda"), + torch.randn(32, dtype=torch.float32, device="cuda"), + torch.randn(32, dtype=torch.float32, device="cuda"), + 32, + ] + device = torch.cuda.current_device() + assert len(kernel_add.cache[device]) == 0 + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 1 + kernel_add.warmup(*args, grid=(1, )) + assert len(kernel_add.cache[device]) == 1 + kernel_add.warmup(*args, grid=(1, )) + assert len(kernel_add.cache[device]) == 1 + + +@pytest.mark.skip("Skip for kunlunxin") +def test_jit_debug() -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + device = torch.cuda.current_device() + assert len(kernel_add.cache[device]) == 0 + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 1 + kernel_add.debug = False + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 2 + kernel_add.debug = True + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add.cache[device]) == 3 + bins = list(kernel_add.cache[device].values()) + assert bins[2].asm['ttir'] != bins[1].asm['ttir'] + + +@triton.jit +def add_fn(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + +@pytest.mark.skip("Skip for kunlunxin") +def test_jit_noinline() -> None: + + @triton.jit + def kernel_add_device(a, b, o, N: tl.constexpr): + add_fn(a, b, o, N) + + device = torch.cuda.current_device() + assert len(kernel_add_device.cache[device]) == 0 + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add_device.cache[device]) == 1 + bins = list(kernel_add_device.cache[device].values()) + inline_ttir = bins[0].asm['ttir'] + add_fn.noinline = True + add_fn.hash = None + kernel_add_device.hash = None + kernel_add_device.cache[device].clear() + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) + assert len(kernel_add_device.cache[device]) == 1 + bins = list(kernel_add_device.cache[device].values()) + noinline_ttir = bins[0].asm['ttir'] + assert inline_ttir != noinline_ttir + + +def test_memory_leak() -> None: + + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 10 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) + + +def test_preload() -> None: + + @triton.jit + def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + @triton.jit + def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr): + idx = tl.arange(0, N) + tl.device_assert(idx < 32, "idx < 32") + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx)) + + device = torch.cuda.current_device() + + # get the serialized specialization data + specialization_data = None + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + specialization_data = kwargs["compile"]["specialization_data"] + + JITFunction.cache_hook = cache_hook + pre_compile = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + hash = pre_compile.hash + assert specialization_data is not None + + # clear the cache + reset_tmp_dir() + kernel_add.cache[device].clear() + + # preload the kernel + kernel_preload = kernel_add.preload(specialization_data) + assert kernel_preload.hash == hash + assert len(kernel_add.cache[device]) == 1 + + # we should hit the cache and not compile anything + counter = 0 + + def inc_counter(*args, **kwargs): + nonlocal counter + counter += 1 + + JITFunction.cache_hook = inc_counter + final_kernel = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, )) + JITFunction.cache_hook = None + assert counter == 0 + assert len(kernel_add.cache[device]) == 1 + assert final_kernel.hash == hash + + # test that we can't preload a mismatched kernel + with pytest.raises(RuntimeError, match="Specialization data is for"): + kernel_sub.preload(specialization_data) diff --git a/third_party/xpu/python/test/unit/runtime/test_driver.py b/third_party/xpu/python/test/unit/runtime/test_driver.py new file mode 100644 index 000000000..de00082f5 --- /dev/null +++ b/third_party/xpu/python/test/unit/runtime/test_driver.py @@ -0,0 +1,14 @@ +import sys + +import triton + + +def test_is_lazy(): + from importlib import reload + reload(sys.modules["triton.runtime.driver"]) + reload(sys.modules["triton.runtime"]) + mod = sys.modules[triton.runtime.driver.__module__] + assert isinstance(triton.runtime.driver.active, getattr(mod, "LazyProxy")) + assert triton.runtime.driver.active._obj is None + utils = triton.runtime.driver.active.utils # noqa: F841 + assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase")) diff --git a/third_party/xpu/python/test/unit/runtime/test_jit.py b/third_party/xpu/python/test/unit/runtime/test_jit.py new file mode 100644 index 000000000..5892494c4 --- /dev/null +++ b/third_party/xpu/python/test/unit/runtime/test_jit.py @@ -0,0 +1,42 @@ +import itertools +import pytest +import torch + +import triton +import triton.language as tl + + +def test_pre_call_hooks(device): + + @triton.jit + def add_kernel( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + class MyTensor(torch.Tensor): + pass + + def my_hook(*args, **kwargs): + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, MyTensor): + raise Exception("MyTensor is not allowed") + + add_kernel.add_pre_run_hook(my_hook) + + x = torch.randn(4, device=device) + y = MyTensor(x) + out = torch.zeros_like(x) + with pytest.raises(Exception): + add_kernel[(4, )](x, y, out, 4, 4) diff --git a/third_party/xpu/python/test/unit/runtime/test_launch.py b/third_party/xpu/python/test/unit/runtime/test_launch.py new file mode 100644 index 000000000..ef4c60ac4 --- /dev/null +++ b/third_party/xpu/python/test/unit/runtime/test_launch.py @@ -0,0 +1,137 @@ +import gc +# import importlib +# import os +# import sys +# import tempfile +# import textwrap +# import time +import tracemalloc + +import torch + +import triton +import triton.language as tl +import pytest + +# from typing import Tuple + +pytest.skip("Skip for kunlunxin", allow_module_level=True) + + +def test_metadata() -> None: + + used_hook = False + + def _launch_metadata(grid, kernel, args): + ret = dict() + ret["grid"] = grid + ret["value"] = args["x"] + return ret + + def hook(launch_metadata): + nonlocal used_hook + metadata = launch_metadata.get() + assert metadata["grid"] == (1, 3, 2) + assert metadata["value"] == 6 + used_hook = True + + @triton.jit(launch_metadata=_launch_metadata) + def kernel(x): + pass + + # launch kernel + triton.compiler.CompiledKernel.launch_enter_hook = hook + kernel[(1, 3, 2)](6) + triton.compiler.CompiledKernel.launch_enter_hook = None + assert used_hook + + +def test_memory_leak() -> None: + + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 10 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask) + + tracemalloc.start() + try: + inp = torch.randn(10, device='cuda') + out = torch.randn(10, device='cuda') + kernel[(10, )](inp, out, 10, XBLOCK=16) + gc.collect() + begin, _ = tracemalloc.get_traced_memory() + for _ in range(100): + kernel[(10, )](inp, out, 10, XBLOCK=16) + gc.collect() + end, _ = tracemalloc.get_traced_memory() + assert end - begin < 30000 + finally: + tracemalloc.stop() + + +# LATENCY_THRESHOLD_US = 46 + +# def test_kernel_launch_latency() -> None: +# def define_kernel(kernel_name: str, num_tensor_args: int) -> str: +# arg_str = ",".join([f"arg{i}: torch.Tensor" for i in range(num_tensor_args)]) +# arg_str += ", n_elements: int, BLOCK_SIZE: tl.constexpr" +# func_str = f""" +# import torch + +# import triton +# import triton.language as tl + +# @triton.jit +# def {kernel_name}({arg_str}): +# pass +# """ +# with tempfile.NamedTemporaryFile(mode="w+t", suffix=".py", delete=False) as temp_file: +# temp_file.write(textwrap.dedent(func_str)) +# temp_file_path = temp_file.name + +# return temp_file_path + +# def import_kernel(file_path, kernel_name): +# directory, filename = os.path.split(file_path) +# module_name, _ = os.path.splitext(filename) +# sys.path.insert(0, directory) + +# module = importlib.import_module(module_name) +# kernel = getattr(module, kernel_name) +# return kernel + +# def empty(*kernel_args: Tuple[torch.Tensor]): +# first_arg = kernel_args[0] +# n_elements = first_arg.numel() +# grid = (triton.cdiv(n_elements, 1024),) +# device = torch.cuda.current_device() +# # Warmup +# empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device) +# torch.cuda.synchronize() +# # Measure launch overhead at steady state +# num_runs = 1000 +# start_time = time.time() +# for i in range(num_runs): +# empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device) +# end_time = time.time() +# latency_us = (end_time - start_time) / num_runs * 1e6 + +# assert latency_us < LATENCY_THRESHOLD_US, "Kernel launch time has increased!" + +# num_tensor_args = 40 +# kernel_name = 'empty_kernel' +# file_path = define_kernel(kernel_name, num_tensor_args) +# empty_kernel = import_kernel(file_path, kernel_name) + +# # Initialize random tensors for the empty_kernel +# torch.manual_seed(0) +# size = 1024 +# kernel_args = (torch.rand(size, device='cuda') for i in range(num_tensor_args)) + +# # Run empty, which would run empty_kernel internally +# empty(*kernel_args) diff --git a/third_party/xpu/python/test/unit/runtime/test_subproc.py b/third_party/xpu/python/test/unit/runtime/test_subproc.py new file mode 100644 index 000000000..0ee3fe5a6 --- /dev/null +++ b/third_party/xpu/python/test/unit/runtime/test_subproc.py @@ -0,0 +1,75 @@ +import multiprocessing +import os +import shutil + +import torch +import pytest + +import triton +import triton.language as tl +from triton.compiler import ASTSource + +tmpdir = ".tmp" + +target = triton.runtime.driver.active.get_current_target() + + +def reset_tmp_dir(): + os.environ["TRITON_CACHE_DIR"] = tmpdir + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir, ignore_errors=True) + + +def compile_fn(attrs, capability): + + @triton.jit + def kernel_sub(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777) + + src = ASTSource( + fn=kernel_sub, + constants={3: 32}, + signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, + attrs=attrs, + ) + triton.compile(src=src, target=target) + + +def test_compile_in_subproc() -> None: + major, minor = torch.cuda.get_device_capability(0) + cc = major * 10 + minor + config = triton.compiler.AttrsDescriptor(tuple(range(4)), ()) + + multiprocessing.set_start_method('fork') + proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) + proc.start() + proc.join() + assert proc.exitcode == 0 + + +def compile_fn_dot(attrs, capability): + + @triton.jit + def kernel_dot(Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + tl.store(Z + offs, z) + + src = ASTSource(fn=kernel_dot, signature={0: "*fp32"}, attrs=attrs, constants=dict()) + triton.compile(src=src, target=target) + + +@pytest.mark.skip("Skip for kunlunxin") +def test_compile_in_forked_subproc() -> None: + reset_tmp_dir() + major, minor = torch.cuda.get_device_capability(0) + capability = major * 10 + minor + config = triton.compiler.AttrsDescriptor(tuple(range(1)), ()) + + assert multiprocessing.get_start_method() == 'fork' + proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability)) + proc.start() + proc.join() + assert proc.exitcode == 0 diff --git a/third_party/xpu/python/triton/_C/include b/third_party/xpu/python/triton/_C/include new file mode 120000 index 000000000..b85a40983 --- /dev/null +++ b/third_party/xpu/python/triton/_C/include @@ -0,0 +1 @@ +../../../include/ \ No newline at end of file diff --git a/third_party/xpu/python/triton/__init__.py b/third_party/xpu/python/triton/__init__.py new file mode 100644 index 000000000..031c58fb1 --- /dev/null +++ b/third_party/xpu/python/triton/__init__.py @@ -0,0 +1,73 @@ +"""isort:skip_file""" +__version__ = '3.0.0' + +# --------------------------------------- +# Note: import order is significant here. + +# submodules +from .runtime import ( + autotune, + Config, + heuristics, + JITFunction, + KernelInterface, + reinterpret, + TensorWrapper, + OutOfResources, + InterpreterError, + MockTensor, +) +from .runtime.jit import jit +from .compiler import compile, CompilationError +from .errors import TritonError + +from . import language +from . import testing +from . import tools + +__all__ = [ + "autotune", + "cdiv", + "CompilationError", + "compile", + "Config", + "heuristics", + "impl", + "InterpreterError", + "jit", + "JITFunction", + "KernelInterface", + "language", + "MockTensor", + "next_power_of_2", + "ops", + "OutOfResources", + "reinterpret", + "runtime", + "TensorWrapper", + "TritonError", + "testing", + "tools", +] + +# ------------------------------------- +# misc. utilities that don't fit well +# into any specific module +# ------------------------------------- + + +def cdiv(x: int, y: int): + return (x + y - 1) // y + + +def next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n diff --git a/third_party/xpu/python/triton/backends b/third_party/xpu/python/triton/backends new file mode 120000 index 000000000..13a83a85c --- /dev/null +++ b/third_party/xpu/python/triton/backends @@ -0,0 +1 @@ +../../../../python/triton/backends \ No newline at end of file diff --git a/third_party/xpu/python/triton/compiler/__init__.py b/third_party/xpu/python/triton/compiler/__init__.py new file mode 100644 index 000000000..ce0cfedfc --- /dev/null +++ b/third_party/xpu/python/triton/compiler/__init__.py @@ -0,0 +1,4 @@ +from .compiler import CompiledKernel, ASTSource, compile, AttrsDescriptor, make_backend, LazyDict +from .errors import CompilationError + +__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"] diff --git a/third_party/xpu/python/triton/compiler/code_generator.py b/third_party/xpu/python/triton/compiler/code_generator.py new file mode 100644 index 000000000..6903052ca --- /dev/null +++ b/third_party/xpu/python/triton/compiler/code_generator.py @@ -0,0 +1,1302 @@ +import ast +import inspect +import re +import sys +import warnings +import os +import textwrap +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from .. import language +from .._C.libtriton import ir +from ..language import constexpr, tensor, str_to_ty +from ..runtime.jit import _normalize_ty +# ideally we wouldn't need any runtime component +from ..runtime import JITFunction +from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) +from types import ModuleType + + +def mangle_ty(ty): + if ty.is_ptr(): + return 'P' + mangle_ty(ty.element_ty) + if ty.is_int(): + SIGNED = language.dtype.SIGNEDNESS.SIGNED + prefix = 'i' if ty.int_signedness == SIGNED else 'u' + return prefix + str(ty.int_bitwidth) + if ty.is_floating(): + return str(ty) + if ty.is_block(): + elt = mangle_ty(ty.scalar) + shape = '_'.join(map(str, ty.shape)) + return f'{elt}S{shape}S' + if ty.is_void(): + return 'V' + assert False, "Unsupported type" + + +def mangle_fn(name, arg_tys, constants): + # doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) + mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + # [ and ] are not allowed in LLVM identifiers + mangled_constants = mangled_constants.replace('[', '_').replace(']', '_') + ret = f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + + +def _is_triton_tensor(o: Any) -> bool: + return isinstance(o, tensor) + + +def _is_constexpr(o: Any) -> bool: + return isinstance(o, constexpr) + + +def _is_triton_scalar(o: Any) -> bool: + return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1) + + +def _is_list_like(o: Any) -> bool: + return isinstance(o, (list, tuple)) + + +def _unwrap_if_constexpr(o: Any): + return o.value if isinstance(o, constexpr) else o + + +def _check_fn_args(node, fn, args): + if fn.noinline: + for idx, arg in enumerate(args): + if not _is_constexpr(arg) and not _is_triton_scalar(arg): + raise UnsupportedLanguageConstruct( + fn.src, node, + f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}' + ) + + +def _get_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITFunction): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + lines, begin_line = inspect.getsourcelines(base_fn.fn) + # Match the following pattern: + # @triton.autotune(...) <- foo.__code__.co_firstlineno + # @triton.heuristics(...) + # @triton.jit + # def foo(...): <- this line is the first line + for idx, line in enumerate(lines): + if line.strip().startswith("def "): + begin_line += idx + break + return file_name, begin_line + + +_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels + + +class enter_sub_region: + + def __init__(self, generator): + self.generator = generator + + def __enter__(self): + # record lscope & local_defs in the parent scope + self.liveins = self.generator.lscope.copy() + self.prev_defs = self.generator.local_defs.copy() + self.generator.local_defs = {} + self.insert_block = self.generator.builder.get_insertion_block() + self.insert_point = self.generator.builder.get_insertion_point() + return self.liveins, self.insert_block + + def __exit__(self, *args, **kwargs): + self.generator.builder.restore_insertion_point(self.insert_point) + self.generator.lscope = self.liveins + self.generator.local_defs = self.prev_defs + + +# Check if the given syntax node has an "early" return +class ContainsReturnChecker(ast.NodeVisitor): + + def __init__(self, gscope): + self.gscope = gscope + + def _visit_stmts(self, body) -> bool: + for s in body: + if self.visit(s): + return True + return False + + def _visit_function(self, fn) -> bool: + # Currently we only support JITFunctions defined in the global scope + if isinstance(fn, JITFunction) and not fn.noinline: + fn_node = fn.parse() + return ContainsReturnChecker(self.gscope).visit(fn_node) + return False + + def generic_visit(self, node) -> bool: + ret = False + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + ret = ret or self.visit(item) + elif isinstance(value, ast.AST): + ret = ret or self.visit(value) + return ret + + def visit_Attribute(self, node: ast.Attribute) -> bool: + # If the left part is a name, it's possible that + # we call triton native function or a jit function from another module. + # If the left part is not a name, it must return a tensor or a constexpr + # whose methods do not contain return statements + # e.g., (tl.load(x)).to(y) + # So we only check if the expressions within value have return or not + if isinstance(node.value, ast.Name): + if node.value.id in self.gscope: + value = self.gscope[node.value.id] + fn = getattr(value, node.attr) + return self._visit_function(fn) + return False + return self.visit(node.value) + + def visit_Name(self, node: ast.Name) -> bool: + if type(node.ctx) == ast.Store: + return False + if node.id in self.gscope: + fn = self.gscope[node.id] + return self._visit_function(fn) + return False + + def visit_Return(self, node: ast.Return) -> bool: + return True + + def visit_Assign(self, node: ast.Assign) -> bool: + # There couldn't be an early return + # x = ... + return False + + def visit_AugAssign(self, node: ast.AugAssign) -> bool: + # There couldn't be an early return + # x += ... + return False + + def visit_Module(self, node: ast.Module) -> bool: + return self._visit_stmts(node.body) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: + return self._visit_stmts(node.body) + + def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return + ret = self._visit_stmts(node.body) + if node.orelse: + ret = ret or self._visit_stmts(node.orelse) + return ret + + def visit_IfExp(self, node: ast.IfExp) -> bool: + return self.visit(node.body) or self.visit(node.orelse) + + def visit_Call(self, node: ast.Call) -> bool: + return self.visit(node.func) + + +class CodeGenerator(ast.NodeVisitor): + + def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, + codegen_fns, debug=None, module=None, is_kernel=False, function_types: Optional[Dict] = None, + noinline=False, file_name: Optional[str] = None, begin_line=0): + self.context = context + self.builder = ir.builder(context) + self.file_name = file_name + # node.lineno starts from 1, so we need to subtract 1 + self.begin_line = begin_line - 1 + self.builder.set_loc(file_name, begin_line, 0) + self.builder.options = options + # dict of functions provided by the backend. Below are the list of possible functions: + # Convert custom types not natively supported on HW. + # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) + self.builder.codegen_fns = codegen_fns + self.module = self.builder.create_module() if module is None else module + self.function_ret_types = {} if function_types is None else function_types + self.prototype = prototype + self.gscope = gscope + self.lscope = dict() + self.attributes = attributes + self.constants = constants + self.jit_fn = jit_fn + self.function_name = function_name + self.is_kernel = is_kernel + self.cur_node = None + self.debug = options.debug if debug is None else debug + self.noinline = noinline + self.scf_stack = [] + self.ret_type = None + # SSA-construction + # name => language.tensor + self.local_defs: Dict[str, tensor] = {} + self.dereference_name: Callable[[str], Any] = self._define_name_lookup() + self.fn = None + # Are we currently visiting an ast.arg's default value? These have some + # special handling. + self.visiting_arg_default_value = False + + builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)} + builtin_namespace.update(( + ('print', language.core.device_print), + ('min', language.minimum), + ('max', language.maximum), + )) + + def _unsupported(self, node, message): + return UnsupportedLanguageConstruct(self.jit_fn.src, node, message) + + def _is_constexpr_global(self, name): + absent_marker = object() + val = self.gscope.get(name, absent_marker) + if val is absent_marker: + return False + + if _is_constexpr(val): + return True + + if a := self.gscope.get("__annotations__", {}).get(name): + return _normalize_ty(a) == "constexpr" + + return False + + def _define_name_lookup(self): + + def local_lookup(name: str, absent): + # this needs to be re-fetched from `self` every time, because it gets switched occasionally + return self.lscope.get(name, absent) + + def global_lookup(name: str, absent): + val = self.gscope.get(name, absent) + # The high-level rule is that only constexpr globals are allowed. + # But actually a bunch of other things, such as module imports, are + # technically Python globals. We have to allow these too! + if (val is absent # + or name in self.builtin_namespace # + or type(val) == ModuleType # + or isinstance(val, JITFunction) # + or getattr(val, "__triton_builtin__", False) # + or getattr(val, "__module__", "").startswith("triton.language") # + or isinstance(val, language.dtype) # + or self._is_constexpr_global(name) # + # Allow accesses to globals while visiting an ast.arg + # because you should be able to do + # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... + or self.visiting_arg_default_value # + or os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"): + return val + raise NameError( + textwrap.dedent(f"""\ + Cannot access global variable {name} from within @jit'ed + function. Triton kernels can only access global variables that + are annotated as constexpr (`x: triton.language.constexpr = 42` + or `x = triton.language.constexpr(42)`). Alternatively, set the + envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not + promise to support this forever.""").replace("\n", " ")) + + absent_marker = object() + + def name_lookup(name: str) -> Any: + absent = absent_marker + for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get: + value = lookup_function(name, absent) + if value is not absent: + return value + raise NameError(f'{name} is not defined') + + return name_lookup + + def set_value(self, name: str, value: Union[tensor, constexpr]) -> None: + ''' This function: + called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. store tensor in self.lvalue + ''' + self.lscope[name] = value + self.local_defs[name] = value + + def _get_insertion_point_and_loc(self): + # XXX: this is a hack to get the location of the insertion point. + # The insertion point's location could be invalid sometimes, + # so we need to explicitly set the location + loc = self.builder.get_loc() + ip = self.builder.get_insertion_point() + return ip, loc + + def _set_insertion_point_and_loc(self, ip, loc): + self.builder.restore_insertion_point(ip) + self.builder.set_loc(loc) + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + # Ensure that stmts is iterable + if not _is_list_like(stmts): + stmts = [stmts] + for stmt in stmts: + self.visit(stmt) + + # Stop parsing as soon as we hit a `return` statement; everything + # after this is dead code. + if isinstance(stmt, ast.Return): + break + + def visit_Module(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_List(self, node): + ctx = self.visit(node.ctx) + assert ctx is None + elts = [self.visit(elt) for elt in node.elts] + return elts + + # By design, only non-kernel functions can return + def visit_Return(self, node): + ret_value = self.visit(node.value) + # ret_block = self.builder.create_block() + # post_ret_block = self.builder.create_block() + # self.builder.create_branch(ret_block) + # self.builder.set_insertion_point_to_end(ret_block) + if ret_value is None: + self.builder.ret([]) + ret_ty = language.void + elif isinstance(ret_value, tuple): + ret_values = [language.core._to_tensor(v, self.builder) for v in ret_value] + ret_types = [v.type for v in ret_values] + self.builder.ret([v.handle for v in ret_values]) + ret_ty = tuple(ret_types) + else: + ret = language.core._to_tensor(ret_value, self.builder) + self.builder.ret([ret.handle]) + ret_ty = ret.type + # self.builder.create_branch(post_ret_block) + # self.builder.set_insertion_point_to_end(post_ret_block) + + if self.ret_type is None: + self.ret_type = ret_ty + elif self.ret_type != ret_ty: + raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}') + + def visit_FunctionDef(self, node): + arg_names, kwarg_names = self.visit(node.args) + if self.fn: + raise self._unsupported(node, "nested function definition is not supported.") + # initialize defaults + for i, default_value in enumerate(node.args.defaults): + arg_node = node.args.args[-i - 1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + self.visit(init_node) + finally: + self.visiting_arg_default_value = False + + # initialize function + visibility = "public" if self.is_kernel else "private" + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, + self.prototype.to_ir(self.builder), visibility, self.noinline) + self.module.push_back(self.fn) + entry = self.fn.add_entry_block() + arg_values = [] + idx = 0 + for i, arg_name in enumerate(arg_names): + if i in self.constants: + cst = self.constants[i] + if not _is_constexpr(cst): + cst = constexpr(self.constants[i]) + arg_values.append(cst) + continue + else: + if i in self.attributes: + for name, value in self.attributes[i]: + self.fn.set_arg_attr(idx, name, value) + arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) + idx += 1 + + insert_pt = self.builder.get_insertion_block() + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + self.builder.set_insertion_point_to_start(entry) + # visit function body + self.visit_compound_statement(node.body) + # finalize function + if self.ret_type is None or self.ret_type == language.void: + self.ret_type = language.void + self.builder.ret([]) + else: + # update return type + if isinstance(self.ret_type, tuple): + self.prototype.ret_types = list(self.ret_type) + self.fn.reset_type(self.prototype.to_ir(self.builder)) + else: + self.prototype.ret_types = [self.ret_type] + self.fn.reset_type(self.prototype.to_ir(self.builder)) + if insert_pt: + self.builder.set_insertion_point_to_end(insert_pt) + # Remove dead code + self.fn.finalize() + + def visit_arguments(self, node): + arg_names = [] + for arg in node.args: + arg_names += [self.visit(arg)] + kwarg_names = self.visit(node.kwarg) + return arg_names, kwarg_names + + def visit_arg(self, node): + ast.NodeVisitor.generic_visit(self, node) + return node.arg + + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = self.visit(node.value) + # constexpr + if annotation == constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' + f' constexpr cannot be reassigned.') + if not _is_constexpr(value): + value = constexpr(value) + self.lscope[target] = value + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def visit_Assign(self, node): + _names = [] + for target in node.targets: + _names += [self.visit(target)] + if len(_names) > 1: + raise self._unsupported(node, "simultaneous multiple assignment is not supported.") + names = _names[0] + values = self.visit(node.value) + if not _is_list_like(names): + names = [names] + if not _is_list_like(values): + values = [values] + native_nontensor_types = (language.dtype, ) + for name, value in zip(names, values): + # by default, constexpr are assigned into python variable + value = _unwrap_if_constexpr(value) + if value is not None and \ + not _is_triton_tensor(value) and \ + not isinstance(value, native_nontensor_types): + value = language.core._to_tensor(value, self.builder) + self.set_value(name, value) + + def visit_AugAssign(self, node): + name = node.target.id + lhs = ast.Name(id=name, ctx=ast.Load()) + rhs = ast.BinOp(lhs, node.op, node.value) + assign = ast.Assign(targets=[node.target], value=rhs) + self.visit(assign) + return self.dereference_name(name) + + def visit_Name(self, node): + if type(node.ctx) == ast.Store: + return node.id + return self.dereference_name(node.id) + + def visit_Store(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Load(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Tuple(self, node): + args = [self.visit(x) for x in node.elts] + return tuple(args) + + def _apply_binary_method(self, method_name, lhs, rhs): + # TODO: raise something meaningful if getattr fails below, esp for reverse method + if _is_triton_tensor(lhs): + return getattr(lhs, method_name)(rhs, _builder=self.builder) + if _is_triton_tensor(rhs): + reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name) + return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder) + return getattr(lhs, method_name)(rhs) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + method_name = self._method_name_for_bin_op.get(type(node.op)) + if method_name is None: + raise self._unsupported(node, + "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bin_op: Dict[Type[ast.operator], str] = { + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', + } + + def visit_then_else_blocks(self, node, liveins, then_block, else_block): + # then block + self.builder.set_insertion_point_to_start(then_block) + self.visit_compound_statement(node.body) + then_block = self.builder.get_insertion_block() + then_defs = self.local_defs.copy() + # else block + else_defs = {} + if node.orelse: + self.builder.set_insertion_point_to_start(else_block) + self.lscope = liveins.copy() + self.local_defs = {} + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else_block = self.builder.get_insertion_block() + + # update block arguments + names = [] + ret_types = [] + ir_ret_types = [] + # variables in livein whose value is updated in `if` + for name in liveins: + # check type + for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: + if name in defs: + assert defs[name].type == liveins[name].type, \ + f'initial value for `{name}` is of type {liveins[name].type}, '\ + f'but the {block_name} block redefines it as {defs[name].type}' + if name in then_defs or name in else_defs: + names.append(name) + ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type) + ir_ret_types.append(then_defs[name].handle.get_type() if name in + then_defs else else_defs[name].handle.get_type()) + # variable defined in then but not in else + if name in then_defs and name not in else_defs: + else_defs[name] = liveins[name] + # variable defined in else but not in then + if name in else_defs and name not in then_defs: + then_defs[name] = liveins[name] + # variables that are both in then and else but not in liveins + # TODO: could probably be cleaned up + for name in then_defs.keys() & else_defs.keys(): + if name in names: + continue + then_ty = then_defs[name].type + else_ty = else_defs[name].type + assert then_ty == else_ty, \ + f'mismatched type for {name} between then block ({then_ty}) '\ + f'and else block ({else_ty})' + names.append(name) + ret_types.append(then_ty) + ir_ret_types.append(then_defs[name].handle.get_type()) + + return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types + + def visit_if_top_level(self, cond, node): + has_endif_block = True + with enter_sub_region(self) as sr: + liveins, ip_block = sr + then_block = self.builder.create_block() + else_block = self.builder.create_block() + # create basic-block after conditional + endif_block = self.builder.create_block() + # create branch + self.builder.set_insertion_point_to_end(ip_block) + self.builder.create_cond_branch(cond.handle, then_block, else_block) + # visit then and else blocks + then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # then terminator + self.builder.set_insertion_point_to_end(then_block) + if then_block.has_return() and else_block.has_return(): + has_endif_block = False + endif_block.erase() + if not then_block.has_terminator() and has_endif_block: + self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) + # else terminator + self.builder.set_insertion_point_to_end(else_block) + if not else_block.has_terminator() and has_endif_block: + self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) + if has_endif_block: + for ty in ir_ret_types: + endif_block.add_argument(ty) + if has_endif_block: + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + for i, name in enumerate(names): + new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) + self.set_value(name, new_tensor) + + # TODO: refactor + def visit_if_scf(self, cond, node): + with enter_sub_region(self) as sr: + liveins, _ = sr + ip, last_loc = self._get_insertion_point_and_loc() + then_block = self.builder.create_block() + else_block = self.builder.create_block() if node.orelse else None + then_defs, else_defs, then_block, else_block, names, ret_types, _ = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create if op + self._set_insertion_point_and_loc(ip, last_loc) + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + if len(names) > 0: + self.builder.create_yield_op([then_defs[n].handle for n in names]) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + if len(names) > 0: + self.builder.create_yield_op([else_defs[n].handle for n in names]) + # update values + for i, name in enumerate(names): + new_tensor = language.core.tensor(if_op.get_result(i), ret_types[i]) + self.set_value(name, new_tensor) + + def visit_If(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + contains_return = ContainsReturnChecker(self.gscope).visit(node) + if self.scf_stack and contains_return: + raise self._unsupported( + node, "Cannot have `return` statements inside `while` or `for` statements in triton " + "(note that this also applies to `return` statements that are inside functions " + "transitively called from within `while`/`for` statements)") + elif self.scf_stack or not contains_return: + self.visit_if_scf(cond, node) + else: + self.visit_if_top_level(cond, node) + else: + cond = _unwrap_if_constexpr(cond) + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + if cond: + self.visit_compound_statement(node.body) + else: + self.visit_compound_statement(node.orelse) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + # TODO: Deal w/ more complicated return types (e.g tuple) + with enter_sub_region(self): + ip, last_loc = self._get_insertion_point_and_loc() + + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + then_val = language.core._to_tensor(self.visit(node.body), self.builder) + then_block = self.builder.get_insertion_block() + + else_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(else_block) + # do not need to reset lscope since + # ternary expressions cannot define new variables + else_val = language.core._to_tensor(self.visit(node.orelse), self.builder) + else_block = self.builder.get_insertion_block() + + self._set_insertion_point_and_loc(ip, last_loc) + + assert then_val.type == else_val.type, \ + f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' + ret_type = then_val.type + + ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else [] + if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([then_val.handle]) + + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + else_block.merge_block_before(if_op.get_else_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([else_val.handle]) + return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None + else: + cond = _unwrap_if_constexpr(cond) + + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + if cond: + return self.visit(node.body) + else: + return self.visit(node.orelse) + + def visit_Pass(self, node): + pass + + def visit_Compare(self, node): + if not (len(node.comparators) == 1 and len(node.ops) == 1): + raise self._unsupported(node, "simultaneous multiple comparison is not supported") + lhs = self.visit(node.left) + rhs = self.visit(node.comparators[0]) + lhs_value = _unwrap_if_constexpr(lhs) + rhs_value = _unwrap_if_constexpr(rhs) + if type(node.ops[0]) == ast.Is: + return constexpr(lhs_value is rhs_value) + if type(node.ops[0]) == ast.IsNot: + return constexpr(lhs_value is not rhs_value) + method_name = self._method_name_for_comp_op.get(type(node.ops[0])) + if method_name is None: + raise self._unsupported( + node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = { + ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__' + } + + def visit_UnaryOp(self, node): + operand = self.visit(node.operand) + fn = self._method_name_for_unary_op.get(type(node.op)) + if fn is None: + raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.") + if _is_triton_tensor(operand): + return getattr(operand, fn)(_builder=self.builder) + try: + return getattr(operand, fn)() + except AttributeError: + raise self._unsupported( + node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}") + + _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = { + ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' + } + + def visit_While(self, node): + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # loop body (the after region) + # loop_block = self.builder.create_block() + dummy = self.builder.create_block() + self.builder.set_insertion_point_to_start(dummy) + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + dummy.erase() + + # collect loop-carried values + names = [] + ret_types = [] + init_args = [] + for name in loop_defs: + if name in liveins: + # We should not def new constexpr + assert _is_triton_tensor(loop_defs[name]), f'cannot reassign constxpr {name} in the loop' + assert _is_triton_tensor(liveins[name]), f'cannot reasign constexpr {name} in the loop' + assert loop_defs[name].type == liveins[name].type, \ + f'Loop-carried variable {name} has initial type {liveins[name].type} '\ + f'but is re-assigned to {loop_defs[name].type} in loop! '\ + f'Please make sure that the type stays consistent.' + + # these are loop-carried values + names.append(name) + ret_types.append(loop_defs[name].type) + init_args.append(liveins[name]) + + self._set_insertion_point_and_loc(ip, last_loc) + while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], + [arg.handle for arg in init_args]) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before(), + [ty.to_ir(self.builder) for ty in ret_types]) + self.builder.set_insertion_point_to_start(before_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(before_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + cond = self.visit(node.test) + self.builder.set_insertion_point_to_end(before_block) + # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... + self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after(), + [ty.to_ir(self.builder) for ty in ret_types]) + + # generate loop body + self.builder.set_insertion_point_to_start(after_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(after_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + yields = [] + for name in loop_defs: + if name in liveins: + yields.append(loop_defs[name]) + self.builder.create_yield_op([y.handle for y in yields]) + + # WhileOp defines new values, update the symbol table (lscope, local_defs) + for i, name in enumerate(names): + new_def = language.core.tensor(while_op.get_result(i), ret_types[i]) + self.lscope[name] = new_def + self.local_defs[name] = new_def + + for stmt in node.orelse: + assert False, "Not implemented" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Subscript(self, node): + assert node.ctx.__class__.__name__ == "Load" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + if _is_triton_tensor(lhs): + return lhs.__getitem__(slices, _builder=self.builder) + return lhs[slices] + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] + + def visit_For(self, node): + IteratorClass = self.visit(node.iter.func) + iter_args = [self.visit(arg) for arg in node.iter.args] + iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords) + if IteratorClass == language.static_range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + static_range = range(iterator.start.value, iterator.end.value, iterator.step.value) + for i in static_range: + self.lscope[node.target.id] = constexpr(i) + self.visit_compound_statement(node.body) + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return + num_stages = None + if IteratorClass is language.range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iterator.start + ub = iterator.end + step = iterator.step + num_stages = iterator.num_stages + elif IteratorClass is range: + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0)) + ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0]) + step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) + else: + raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + # handle negative constant step (not supported by scf.for in MLIR) + negative_step = False + if _is_constexpr(step) and step.value < 0: + step = constexpr(-step.value) + negative_step = True + lb, ub = ub, lb + lb = language.core._to_tensor(lb, self.builder) + ub = language.core._to_tensor(ub, self.builder) + step = language.core._to_tensor(step, self.builder) + # induction variable type + if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): + raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") + iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype) + iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype) + iv_ir_type = iv_type.to_ir(self.builder) + iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED + # lb/ub/step might be constexpr, we need to cast them to tensor + lb = lb.handle + ub = ub.handle + step = step.handle + # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index + lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed) + ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed) + step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed) + # Create placeholder for the loop induction variable + iv = self.builder.create_undef(iv_ir_type) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # create loop body block + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) + # dry visit loop body + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + block.erase() + + # If a variable (name) is defined in both its parent & itself, then it's + # a loop-carried variable. (They must be of the same type) + init_args = [] + yields = [] + names = [] + for name in self.local_defs: + if name in liveins: + assert _is_triton_tensor(self.local_defs[name]), f'{name} is not tensor' + assert _is_triton_tensor(liveins[name]) + assert self.local_defs[name].type == liveins[name].type, \ + f'Loop-carried variable {name} has initial type {liveins[name].type} '\ + f'but is re-assigned to {self.local_defs[name].type} in loop! '\ + f'Please make sure that the type stays consistent.' + + names.append(name) + init_args.append(language.core._to_tensor(liveins[name], self.builder)) + yields.append(language.core._to_tensor(self.local_defs[name], self.builder)) + + # create ForOp + self._set_insertion_point_and_loc(ip, last_loc) + for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) + if num_stages is not None: + for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) + + self.scf_stack.append(node) + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + # reset local scope to not pick up local defs from the previous dry run. + self.lscope = liveins.copy() + self.local_defs = {} + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type)) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + yields = [] + for name in self.local_defs: + if name in liveins: + yields.append(language.core._to_tensor(self.local_defs[name], self.builder)) + + # create YieldOp + if len(yields) > 0: + self.builder.create_yield_op([y.handle for y in yields]) + for_op_region = for_op.get_body(0).get_parent() + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + + # update induction variable with actual value, and replace all uses + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + iv = for_op.get_induction_var() + if negative_step: + iv = self.builder.create_sub(ub, iv) + iv = self.builder.create_add(iv, lb) + self.lscope[node.target.id].handle.replace_all_uses_with(iv) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + # update lscope & local_defs (ForOp defines new values) + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_result(i), yields[i].type)) + + for stmt in node.orelse: + assert False, "Don't know what to do with else after for" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Slice(self, node): + lower = self.visit(node.lower) + upper = self.visit(node.upper) + step = self.visit(node.step) + return slice(lower, upper, step) + + def visit_Index(self, node): + return self.visit(node.value) + + def visit_keyword(self, node) -> Tuple[str, Any]: + return node.arg, self.visit(node.value) + + def visit_Assert(self, node) -> Any: + if not self.debug: + return + test = self.visit(node.test) + msg = self.visit(node.msg) if node.msg is not None else "" + # Convert assert to triton's device_assert which happens on the device + return language.core.device_assert(test, msg, _builder=self.builder) + + def call_JitFunction(self, fn: JITFunction, args, kwargs): + args = inspect.getcallargs(fn.fn, *args, **kwargs) + args = [args[name] for name in fn.arg_names] + args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args] + # generate function def + attributes = dict() + constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in args if arg is not None] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + prototype = language.function_type([], arg_types) + gscope = fn.__globals__ + # If the callee is not set, we use the same debug setting as the caller + file_name, begin_line = _get_fn_file_line(fn) + debug = self.debug if fn.debug is None else fn.debug + generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, + jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, + noinline=fn.noinline, file_name=file_name, begin_line=begin_line, + options=self.builder.options, codegen_fns=self.builder.codegen_fns, debug=debug) + try: + generator.visit(fn.parse()) + except Exception as e: + # Wrap the error in the callee with the location of the call. + raise CompilationError(self.jit_fn.src, self.cur_node, None) from e + + callee_ret_type = generator.ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + call_op = self.builder.call(symbol, arg_vals) + if call_op.get_num_results() == 0 or callee_ret_type is None: + return None + elif call_op.get_num_results() == 1: + return tensor(call_op.get_result(0), callee_ret_type) + else: + # should return a tuple of tl.tensor + results = [] + for i in range(call_op.get_num_results()): + results.append(tensor(call_op.get_result(i), callee_ret_type[i])) + return tuple(results) + + def visit_Call(self, node): + fn = _unwrap_if_constexpr(self.visit(node.func)) + static_implementation = self.statically_implemented_functions.get(fn) + if static_implementation is not None: + return static_implementation(self, node) + + kws = dict(self.visit(keyword) for keyword in node.keywords) + args = [self.visit(arg) for arg in node.args] + if fn is language.core.device_assert: # TODO: this should not be so hardcoded + if not self.debug: + return + if isinstance(fn, JITFunction): + _check_fn_args(node, fn, args) + return self.call_JitFunction(fn, args, kws) + if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn): + extra_kwargs = dict(_builder=self.builder) + sig = inspect.signature(fn) + if '_generator' in sig.parameters: + extra_kwargs['_generator'] = self + try: + return fn(*args, **extra_kwargs, **kws) + except Exception as e: + # Normally when we raise a CompilationError, we raise it as + # `from None`, because the original fileline from the exception + # is not relevant (and often points into code_generator.py + # itself). But when calling a function, we raise as `from e` to + # preserve the traceback of the original error, which may e.g. + # be in core.py. + raise CompilationError(self.jit_fn.src, node, None) from e + + if fn in self.builtin_namespace.values(): + args = map(_unwrap_if_constexpr, args) + return fn(*args, **kws) + + def visit_Constant(self, node): + return constexpr(node.value) + + def visit_BoolOp(self, node: ast.BoolOp): + if len(node.values) != 2: + raise self._unsupported( + node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") + lhs = self.visit(node.values[0]) + rhs = self.visit(node.values[1]) + method_name = self._method_name_for_bool_op.get(type(node.op)) + if method_name is None: + raise self._unsupported( + node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} + + if sys.version_info < (3, 8): + + def visit_NameConstant(self, node): + return constexpr(node.value) + + def visit_Num(self, node): + return constexpr(node.n) + + def visit_Str(self, node): + return constexpr(ast.literal_eval(node)) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + if _is_triton_tensor(lhs): + if node.attr == "T": + return language.semantic.permute(lhs, (1, 0), builder=self.builder) + return getattr(lhs, node.attr) + + def visit_Expr(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_NoneType(self, node): + return None + + def visit_JoinedStr(self, node): + values = list(node.values) + for i, value in enumerate(values): + if isinstance(value, ast.Constant): + values[i] = str(value.value) + elif isinstance(value, ast.FormattedValue): + conversion_code = value.conversion + evaluated = self.visit(value.value) + if not _is_constexpr(evaluated): + raise self._unsupported( + node, + "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + + str(type(evaluated))) + values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) + else: + raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) + return ''.join(values) + + def visit(self, node): + if node is None: + return + with warnings.catch_warnings(): + # The ast library added visit_Constant and deprecated some other + # methods but we can't move to that without breaking Python 3.6 and 3.7. + warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 + warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 + last_node = self.cur_node + last_loc = self.builder.get_loc() + self.cur_node = node + if hasattr(node, 'lineno') and hasattr(node, 'col_offset'): + self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset) + last_loc = self.builder.get_loc() + try: + ret = super().visit(node) + except CompilationError: + raise + except Exception as e: + # Wrap the error in a CompilationError which contains the source + # of the @jit function. + raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None + + # Reset the location to the last one before the visit + if last_loc: + self.cur_node = last_node + self.builder.set_loc(last_loc) + return ret + + def generic_visit(self, node): + raise self._unsupported(node, "unsupported AST node type: {}".format(type(node).__name__)) + + def execute_static_assert(self, node: ast.Call) -> None: + arg_count = len(node.args) + if not (0 < arg_count <= 2) or len(node.keywords): + raise TypeError("`static_assert` requires one or two positional arguments only") + + passed = _unwrap_if_constexpr(self.visit(node.args[0])) + if not isinstance(passed, bool): + raise NotImplementedError( + "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values" + ) + if not passed: + if arg_count == 1: + message = "" + else: + try: + message = self.visit(node.args[1]) + except Exception as e: + message = "" + + raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message)) + return None + + def static_executor(python_fn): + + def ret(self, node: ast.Call): + kws = { + name: _unwrap_if_constexpr(value) + for name, value in (self.visit(keyword) for keyword in node.keywords) + } + args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args] + return constexpr(python_fn(*args, **kws)) + + return ret + + statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = { + language.core.static_assert: execute_static_assert, + language.core.static_print: static_executor(print), + int: static_executor(int), + len: static_executor(len), + } + + +def kernel_suffix(signature, specialization): + # suffix format: + # <'c' if equal to 1><'d' if divisible by 16><'e' if divisible by 8> + suffix = '' + for i, _ in enumerate(signature): + suffix += str(i) + if i in specialization.equal_to_1: + suffix += 'c' + if i in specialization.divisible_by_16: + suffix += 'd' + return suffix + + +def ast_to_ttir(fn, specialization, context, options, codegen_fns): + attrs = specialization.attrs + # create kernel prototype + cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in specialization.constants.items()} + # visit kernel AST + gscope = fn.__globals__.copy() + function_name = fn.repr(specialization) + tys = list(specialization.signature.values()) + new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1} + new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16} + + all_constants = constants.copy() + all_constants.update(new_constants) + arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] + file_name, begin_line = _get_fn_file_line(fn) + + prototype = language.function_type([], arg_types) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, + jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name, + begin_line=begin_line, options=options, codegen_fns=codegen_fns) + generator.visit(fn.parse()) + + ret = generator.module + # module takes ownership of the context + ret.context = context + return ret diff --git a/third_party/xpu/python/triton/compiler/compiler.py b/third_party/xpu/python/triton/compiler/compiler.py new file mode 100644 index 000000000..120ac6e6d --- /dev/null +++ b/third_party/xpu/python/triton/compiler/compiler.py @@ -0,0 +1,463 @@ +from __future__ import annotations +import hashlib +import json +from .._C.libtriton import get_cache_invalidating_env_vars, ir +from ..backends import backends +from ..backends.compiler import GPUTarget +from .. import __version__ +from ..runtime.autotuner import OutOfResources +from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager +from ..runtime.driver import driver +# TODO: this shouldn't be here +from dataclasses import dataclass +from .code_generator import ast_to_ttir +from pathlib import Path +import re +import functools +import os + + +@dataclass +class AttrsDescriptor: + divisible_by_16: set = None + equal_to_1: set = None + + def __post_init__(self): + if self.divisible_by_16 is None: + self.divisible_by_16 = set() + if self.equal_to_1 is None: + self.equal_to_1 = set() + + def to_dict(self): + return {'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1)} + + @staticmethod + def from_dict(data): + return AttrsDescriptor(divisible_by_16=set(data.get('divisible_by_16', [])), + equal_to_1=set(data.get('equal_to_1', []))) + + def hash(self): + key = str([sorted(x) for x in self.__dict__.values()]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, +# and any following whitespace +# - (public\s+)? : optionally match the keyword public and any following whitespace +# - (@\w+) : match an @ symbol followed by one or more word characters +# (letters, digits, or underscores), and capture it as group 1 (the function name) +# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing +# zero or more arguments separated by commas, and capture it as group 2 (the argument list) +# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 +mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" +ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" +prototype_pattern = { + "ttir": mlir_prototype_pattern, + "ttgir": mlir_prototype_pattern, + "ptx": ptx_prototype_pattern, +} + +mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+),?' +ptx_arg_type_pattern = r"\.param\s+\.(\w+)" +arg_type_pattern = { + "ttir": mlir_arg_type_pattern, + "ttgir": mlir_arg_type_pattern, + "ptx": ptx_arg_type_pattern, +} + + +def convert_type_repr(x): + # Currently we only capture the pointer type and assume the pointer is on global memory. + # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) + if match is not None: + return '*' + convert_type_repr(match.group(1)) + return x + + +def _get_num_warps_from_ir_str(src: str): + ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' + # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if + # e.g. someone has an instruction (not module) attribute named "num-warps". + num_warps_matches = re.findall(ttgir_num_warps_pattern, src) + assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" + num_warps = int(num_warps_matches[0]) + return num_warps + + +class ASTSource: + + def __init__(self, fn, signature, constants=None, attrs=None) -> None: + self.fn = fn + self.ext = "ttir" + self.name = fn.__name__ + self.signature = signature + self.constants = constants + self.attrs = attrs + # ===-------------------- For XPytorch Inductor -----------------------=== + # Pytorch(v2.0.1) Inductor Can't Generate AttrsDescriptor() + self.attrs = None + # ===------------------------------------------------------------------=== + if isinstance(self.signature, str): + self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + if self.constants is None: + self.constants = dict() + if self.attrs is None: + self.attrs = AttrsDescriptor() + + def hash(self): + sorted_sig = [v for k, v in sorted(self.signature.items())] + # Note - we stringify the keys here to allow sorting to work for cases + # where constants have mixed int/str keys. + sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) + key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, context): + return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns) + + def parse_options(self): + return dict() + + +class IRSource: + + def __init__(self, path): + self.path = path + path = Path(path) + self.ext = path.suffix[1:] + self.src = path.read_text() + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + + def hash(self): + return hashlib.sha256(self.src.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, context): + module = ir.parse_mlir_module(self.path, context) + module.context = context + return module + + def parse_options(self): + if self.ext == "ttgir": + return {'num_warps': _get_num_warps_from_ir_str(self.src)} + return dict() + + +@functools.lru_cache() +def triton_key(): + import pkgutil + TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + # compiler + path_prefixes = [ + (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."), + (os.path.join(TRITON_PATH, "backends"), "triton.backends."), + ] + for path, prefix in path_prefixes: + for lib in pkgutil.walk_packages([path], prefix=prefix): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + + # backend + libtriton_hash = hashlib.sha256() + with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + libtriton_hash.update(chunk) + contents.append(libtriton_hash.hexdigest()) + # language + language_path = os.path.join(TRITON_PATH, 'language') + for lib in pkgutil.iter_modules([language_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + return f'{__version__}' + '-'.join(contents) + + +def parse(full_name, ext, context): + if ext == "ttir" or ext == "ttgir": + module = ir.parse_mlir_module(full_name, context) + module.context = context + return module + if ext == "llir" or ext == "ptx": + return Path(full_name).read_text() + if ext == "cubin": + return Path(full_name).read_bytes() + + +def filter_traceback(e: BaseException): + """ + Removes code_generator.py and related files from tracebacks. + + These are uninteresting to the user -- "just show me *my* code!" + """ + if e.__cause__ is not None: + filter_traceback(e.__cause__) + if e.__context__ is not None: + filter_traceback(e.__context__) + + # If a user has a file that matches one of these, they're out of luck. + BAD_FILES = [ + "/triton/compiler/code_generator.py", + "/ast.py", + ] + + tb = e.__traceback__ + frames = [] + while tb is not None: + if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)): + frames.append(tb) + tb = tb.tb_next + + for (cur_frame, next_frame) in zip(frames, frames[1:]): + cur_frame.tb_next = next_frame + + if not frames: + e.__traceback__ = None + else: + frames[-1].tb_next = None + e.__traceback__ = frames[0] + + +def compile(src, target=None, options=None): + if target is None: + target = driver.active.get_current_target() + assert isinstance(target, GPUTarget), "target must be of GPUTarget type" + backend = make_backend(target) + ir_source = not isinstance(src, ASTSource) + # create backend + if ir_source: + assert isinstance(src, str), "source must be either AST or a filepath" + src = IRSource(src) + extra_options = src.parse_options() + options = backend.parse_options(dict(options or dict(), **extra_options)) + # create cache manager + env_vars = get_cache_invalidating_env_vars() + key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" + hash = hashlib.sha256(key.encode("utf-8")).hexdigest() + fn_cache_manager = get_cache_manager(hash) + # For dumping/overriding only hash the source as we want it to be independent of triton + # core changes to make it easier to track kernels by hash. + enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" + enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1" + fn_override_manager = get_override_manager(src.hash()) if enable_override else None + fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None + metadata_filename = f"{src.name}.json" + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = metadata_group.get(metadata_filename) + always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1" + if not always_compile and metadata_path is not None: + # cache hit! + metadata = json.loads(Path(metadata_path).read_text()) + return CompiledKernel(src, metadata_group, hash) + # initialize metadata + metadata = { + "hash": hash, + "target": target, + **options.__dict__, + **env_vars, + } + # run compilation pipeline and populate metadata + stages = dict() + backend.add_stages(stages, options) + first_stage = list(stages.keys()).index(src.ext) + # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. + if ir_source: + first_stage += 1 + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + codegen_fns = backend.get_codegen_implementation() + try: + module = src.make_ir(options, codegen_fns, context) + except Exception as e: + filter_traceback(e) + raise + use_ttgir_loc = os.environ.get("USE_TTGIR_LOC", "0") == "1" + + def post_stage(module): + ir_filename = f"{src.name}.{ext}" + metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename) + if fn_dump_manager is not None: + fn_dump_manager.put(module, ir_filename) + if (fn_override_manager is not None and fn_override_manager.has_file(ir_filename)): + print(f"\nOverriding kernel with file {ir_filename}") + full_name = fn_override_manager.get_file(ir_filename) + module = parse(full_name, ext, context) + # use an env variable to parse ttgir from file + if use_ttgir_loc and ext == "ttgir": + ttgir_full_name = fn_cache_manager.get_file(ir_filename) + module.create_location_snapshot(ttgir_full_name) + print(f"Create new locations for {ttgir_full_name}") + + if target.backend == "xpu": + make_ttxir_stage_index = list(stages.keys()).index("ttxir") + make_elf_stage_index = list(stages.keys()).index("elf") + if first_stage > make_ttxir_stage_index: + for ext, compile_ir in list(stages.items())[first_stage:]: + next_module = compile_ir(module, metadata) + post_stage(next_module) + module = next_module + else: + index = 0 + while True: + for ext, compile_ir in list(stages.items())[first_stage:make_elf_stage_index + 1]: + next_module = compile_ir(module, metadata) + post_stage(next_module) + module = next_module + if backend.is_elf_stack_size_oob(module): + index = index + 1 + if metadata["buffer_size_limit"] == 16: + raise RuntimeError("Failed to tune buffer size.") + module = src.make_ir(options, codegen_fns, context) + buffer_size_limit_tuned = metadata["buffer_size_limit"] // 2 + tune_log = ('[Buffer Size Limit Tunning] ' + f'Tune buffer_size_limit from {metadata["buffer_size_limit"]}' + f' to {buffer_size_limit_tuned}') + # print(tune_log) + metadata["buffer_size_limit"] = buffer_size_limit_tuned + else: + break + + for ext, compile_ir in list(stages.items())[make_elf_stage_index + 1:]: + next_module = compile_ir(module, metadata) + post_stage(next_module) + module = next_module + else: + for ext, compile_ir in list(stages.items())[first_stage:]: + next_module = compile_ir(module, metadata) + post_stage(next_module) + module = next_module + + # write-back metadata + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) + # return handle to compiled kernel + return CompiledKernel(src, metadata_group, hash) + + +def make_backend(target): + actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)] + if len(actives) != 1: + raise RuntimeError( + f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.") + return actives[0](target) + + +class LazyDict: + + def __init__(self, data): + self.data = data + self.extras = [] + + def get(self) -> None: + for func, args in self.extras: + self.data = self.data | func(*args) + self.extras.clear() + return self.data + + def add(self, func, args): + self.extras.append((func, args)) + + +class CompiledKernel: + + # Hooks for external tools to monitor the execution of triton kernels + # TODO: move out of this namespace since it's a runtime thing + launch_enter_hook = None + launch_exit_hook = None + + def __init__(self, src, metadata_group, hash): + from collections import namedtuple + metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) + metadata = json.loads(metadata_path.read_text()) + metadata['cluster_dims'] = tuple(metadata['cluster_dims']) + # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. + target = metadata['target'] + metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size']) + KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys()))) + self.metadata = KernelMetadata(**metadata) + backend = make_backend(self.metadata.target) + self.packed_metadata = backend.pack_metadata(self.metadata) + self.src = src + self.hash = hash + self.name = self.metadata.name + # stores the text of each level of IR that was generated during compilation + asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] + binary_ext = backend.binary_ext + self.asm = { + file.suffix[1:]: + file.read_bytes() if file.suffix[1:] == binary_ext or file.suffix[1:] == 'elf' else + file.read_text() # ELF Check For Triton XPU + for file in asm_files + } + self.kernel = self.asm[binary_ext] + # binaries are lazily initialized + # because it involves doing runtime things + # (e.g., checking amount of shared memory on current device) + self.module = None + self.function = None + + def _init_handles(self): + if self.module is not None: + return + device = driver.active.get_current_device() + # create launcher + self.run = driver.active.launcher_cls(self.src, self.metadata) + # not enough shared memory to run the kernel + max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] + if self.metadata.shared > max_shared: + raise OutOfResources(self.metadata.shared, max_shared, "shared memory") + # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` + # ===-------------------- For Triton XPU -----------------------=== + if (self.metadata.backend_name == 'xpu'): + self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( + self.name, self.kernel, self.metadata.printf_buf_offset) + return + # ===-----------------------------------------------------------=== + self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( + self.name, self.kernel, self.metadata.shared, device) + + def __getattribute__(self, name): + if name == 'run': + self._init_handles() + return super().__getattribute__(name) + + def launch_metadata(self, grid, stream, *args): + if CompiledKernel.launch_enter_hook is None: + return None + ret = LazyDict({"name": self.name, "function": self.function, "stream": stream}) + if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None: + return ret + arg_dict = {} + arg_idx = 0 + for i, arg_name in enumerate(self.src.fn.arg_names): + if i in self.src.fn.constexprs: + arg_dict[arg_name] = self.src.constants[arg_name] + else: + arg_dict[arg_name] = args[arg_idx] + arg_idx += 1 + ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict)) + return ret + + def __getitem__(self, grid): + self._init_handles() + + def runner(*args, stream=None): + if stream is None: + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + launch_metadata = self.launch_metadata(grid, stream, *args) + self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata, + CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args) + + return runner diff --git a/third_party/xpu/python/triton/compiler/errors.py b/third_party/xpu/python/triton/compiler/errors.py new file mode 100644 index 000000000..39e6c4dfb --- /dev/null +++ b/third_party/xpu/python/triton/compiler/errors.py @@ -0,0 +1,51 @@ +import ast +from typing import Optional +from ..errors import TritonError + + +class CompilationError(TritonError): + """Base class for all errors raised during compilation""" + source_line_count_max_in_message = 12 + + def _format_message(self) -> str: + node = self.node + if self.src is None: + source_excerpt = " " + else: + if hasattr(node, 'lineno'): + source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:] + if source_excerpt: + source_excerpt.append(' ' * node.col_offset + '^') + source_excerpt = '\n'.join(source_excerpt) + else: + source_excerpt = " " + else: + source_excerpt = self.src + + message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, source_excerpt) if hasattr( + node, 'lineno') else source_excerpt + if self.error_message: + message += '\n' + self.error_message + return message + + def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None): + self.src = src + self.node = node + self.error_message = error_message + self.message = self._format_message() + + def __str__(self): + return self.message + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return type(self), (self.src, self.node, self.error_message) + + +class CompileTimeAssertionFailure(CompilationError): + """Specific exception for failed tests in `static_assert` invocations""" + pass + + +class UnsupportedLanguageConstruct(CompilationError): + pass diff --git a/third_party/xpu/python/triton/compiler/make_launcher.py b/third_party/xpu/python/triton/compiler/make_launcher.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/xpu/python/triton/errors.py b/third_party/xpu/python/triton/errors.py new file mode 100644 index 000000000..3a0a86355 --- /dev/null +++ b/third_party/xpu/python/triton/errors.py @@ -0,0 +1,5 @@ +"""Base class for all errors raised by Triton""" + + +class TritonError(Exception): + ... diff --git a/third_party/xpu/python/triton/language/__init__.py b/third_party/xpu/python/triton/language/__init__.py new file mode 100644 index 000000000..168dccfea --- /dev/null +++ b/third_party/xpu/python/triton/language/__init__.py @@ -0,0 +1,284 @@ +"""isort:skip_file""" +# Import order is significant here. + +from . import math +from . import extra +from .standard import ( + argmax, + argmin, + cdiv, + cumprod, + cumsum, + flip, + interleave, + max, + min, + ravel, + sigmoid, + softmax, + sort, + sum, + swizzle2d, + xor_sum, + zeros, + zeros_like, +) +from .core import ( + PropagateNan, + TRITON_MAX_TENSOR_NUMEL, + _experimental_descriptor_load, + _experimental_descriptor_store, + advance, + arange, + associative_scan, + atomic_add, + atomic_and, + atomic_cas, + atomic_max, + atomic_min, + atomic_or, + atomic_xchg, + atomic_xor, + bfloat16, + block_type, + broadcast, + broadcast_to, + cat, + cast, + clamp, + const, + const_pointer_type, + constexpr, + debug_barrier, + device_assert, + device_print, + dot, + dtype, + expand_dims, + float16, + float32, + float64, + float8e4b15, + float8e4nv, + float8e4b8, + float8e5, + float8e5b16, + full, + function_type, + histogram, + inline_asm_elementwise, + int1, + int16, + int32, + int64, + int8, + join, + load, + make_block_ptr, + max_constancy, + max_contiguous, + maximum, + minimum, + multiple_of, + num_programs, + permute, + pi32_t, + pointer_type, + program_id, + range, + reduce, + reshape, + split, + static_assert, + static_print, + static_range, + store, + tensor, + trans, + uint16, + uint32, + uint64, + uint8, + view, + void, + where, +) +from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor, + ceil) +from .random import ( + pair_uniform_to_normal, + philox, + philox_impl, + rand, + rand4x, + randint, + randint4x, + randn, + randn4x, + uint_to_uniform_float, +) + +__all__ = [ + "PropagateNan", + "TRITON_MAX_TENSOR_NUMEL", + "_experimental_descriptor_load", + "_experimental_descriptor_store", + "abs", + "advance", + "arange", + "argmax", + "argmin", + "associative_scan", + "atomic_add", + "atomic_and", + "atomic_cas", + "atomic_max", + "atomic_min", + "atomic_or", + "atomic_xchg", + "atomic_xor", + "bfloat16", + "block_type", + "broadcast", + "broadcast_to", + "builtin", + "cat", + "cast", + "cdiv", + "ceil", + "clamp", + "const", + "const_pointer_type", + "constexpr", + "cos", + "cumprod", + "cumsum", + "debug_barrier", + "device_assert", + "device_print", + "div_rn", + "dot", + "dtype", + "erf", + "exp", + "exp2", + "expand_dims", + "extra", + "fdiv", + "flip", + "float16", + "float32", + "float64", + "float8e4b15", + "float8e4nv", + "float8e4b8", + "float8e5", + "float8e5b16", + "floor", + "fma", + "full", + "function_type", + "histogram", + "inline_asm_elementwise", + "interleave", + "int1", + "int16", + "int32", + "int64", + "int8", + "ir", + "join", + "load", + "log", + "log2", + "make_block_ptr", + "math", + "max", + "max_constancy", + "max_contiguous", + "maximum", + "min", + "minimum", + "multiple_of", + "num_programs", + "pair_uniform_to_normal", + "permute", + "philox", + "philox_impl", + "pi32_t", + "pointer_type", + "program_id", + "rand", + "rand4x", + "randint", + "randint4x", + "randn", + "randn4x", + "range", + "ravel", + "reduce", + "reshape", + "rsqrt", + "sigmoid", + "sin", + "softmax", + "sort", + "split", + "sqrt", + "sqrt_rn", + "static_assert", + "static_print", + "static_range", + "store", + "sum", + "swizzle2d", + "tensor", + "trans", + "triton", + "uint16", + "uint32", + "uint64", + "uint8", + "uint_to_uniform_float", + "umulhi", + "view", + "void", + "where", + "xor_sum", + "zeros", + "zeros_like", +] + + +def str_to_ty(name): + if name[0] == "*": + name = name[1:] + if name[0] == "k": + name = name[1:] + ty = str_to_ty(name) + return const_pointer_type(ty) + ty = str_to_ty(name) + return pointer_type(ty) + tys = { + "fp8e4nv": float8e4nv, + "fp8e4b8": float8e4b8, + "fp8e5": float8e5, + "fp8e5b16": float8e5b16, + "fp8e4b15": float8e4b15, + "fp16": float16, + "bf16": bfloat16, + "fp32": float32, + "fp64": float64, + "i1": int1, + "i8": int8, + "i16": int16, + "i32": int32, + "i64": int64, + "u1": int1, + "u8": uint8, + "u16": uint16, + "u32": uint32, + "u64": uint64, + "B": int1, + } + return tys[name] diff --git a/third_party/xpu/python/triton/language/core.py b/third_party/xpu/python/triton/language/core.py new file mode 100644 index 000000000..89e0ac4c9 --- /dev/null +++ b/third_party/xpu/python/triton/language/core.py @@ -0,0 +1,2628 @@ +from __future__ import annotations + +from warnings import warn +from contextlib import contextmanager +from enum import Enum +from functools import partial, wraps +import typing +from typing import Union, Callable, List, Sequence, TypeVar, Optional +import builtins +from ..runtime.jit import jit +import inspect +import os + +from .._C.libtriton import ir +from . import semantic + +T = TypeVar('T') +# ===-------------------- For Triton XPU -----------------------=== +# Triton XPU don't need the maxTensorNumElements (legalize pass) +import sys + +TRITON_MAX_TENSOR_NUMEL = sys.maxsize +# ===-----------------------------------------------------------=== + +TRITON_BUILTIN = "__triton_builtin__" + +PropagateNan = ir.PROPAGATE_NAN + + +def builtin(fn: T) -> T: + """Mark a function as a builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + raise ValueError("Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + setattr(wrapper, TRITON_BUILTIN, True) + + return wrapper + + +def _tensor_member_fn(fn: T) -> T: + """Decorator that adds this free function as a member fn on class tensor. + + When called as a member function on class tensor, the first argument to `fn` + is `self`, i.e. the tensor object. + + If there are multiple decorators on a function, you probably want this one + to be the highest one (i.e. furthest from the function's `def`), so it's + applied last. + + Unfortunately you still need to add a type stub to the body of class tensor + in order for pytype to know about it. + """ + assert callable(fn) + orig_sig = inspect.signature(fn) + # Does fn take args other than _builder, _generator, and the tensor itself? + has_args = len(orig_sig.parameters.keys() - {"_builder", "_generator"}) > 1 + + if not fn.__doc__: + fn.__doc__ = "" + fn.__doc__ += f""" + This function can also be called as a member function on :py:class:`tensor`, + as :code:`x.{fn.__name__}({"..." if has_args else ""})` instead of + :code:`{fn.__name__}(x{", ..." if has_args else ""})`. + """ + + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + # Match the signature of `fn`, but change the first arg to `self` so the + # docs are a little less weird. + new_params = list(orig_sig.parameters.values()) + new_params[0] = new_params[0].replace(name='self') + new_sig = orig_sig.replace(parameters=new_params) + wrapper.__signature__ = new_sig + wrapper.__doc__ = f"Forwards to :py:func:`{fn.__name__}` free function" + # If fn is a builtin, mark the wrapper as a builtin too. + if is_builtin(fn): + setattr(wrapper, TRITON_BUILTIN, True) + + setattr(tensor, fn.__name__, wrapper) + return fn + + +def _unwrap_iterable(x): + """Returns x[0] if x has one element and x[0] is iterable.""" + if len(x) == 1: + # Determine whether x[0] is iterable. + # + # You might want to use collections.abc.Iterable instead of this + # try/except block. Unfortunately, this doesn't work with constexpr. + # + # The problem is that abc.Iterable checks for __iter__ on the *class*. + # But we want constexpr to expose an __iter__ method if and only if the + # wrapped *object* (i.e. self.value) is iterable. Therefore there's no + # right answer for whether the class constexpr defines __iter__, and + # abc.Iterable doesn't work (at least not without some metaclass magic). + try: + iter(x[0]) + return x[0] + except TypeError: + pass + + return x + + +def is_builtin(fn) -> bool: + """Is this a registered triton builtin function?""" + return getattr(fn, TRITON_BUILTIN, False) + + +@builtin +def to_tensor(x, _builder=None): + return _to_tensor(x, _builder) + + +def _to_tensor(x, builder): + if isinstance(x, bool): + return tensor(builder.get_int1(x), int1) + # Note: compile-time const integers are represented by unsigned values + elif isinstance(x, int): + if -2**31 <= x < 2**31: + return tensor(builder.get_int32(x), int32) + elif 2**31 <= x < 2**32: + return tensor(builder.get_uint32(x), uint32) + elif -2**63 <= x < 2**63: + return tensor(builder.get_int64(x), int64) + elif 2**63 <= x < 2**64: + return tensor(builder.get_uint64(x), uint64) + else: + raise RuntimeError(f'Nonrepresentable integer {x}.') + elif isinstance(x, float): + min_float32 = 2**-126 + max_float32 = (2 - 2**-23) * 2**127 + abs_x = __builtins__['abs'](x) + if abs_x == float("inf") or\ + abs_x == 0.0 or \ + x != x or \ + min_float32 <= abs_x <= max_float32: + return tensor(builder.get_fp32(x), float32) + else: + return tensor(builder.get_fp64(x), float64) + + elif isinstance(x, constexpr): + return _to_tensor(x.value, builder) + elif isinstance(x, tensor): + return x + assert False, f"cannot convert {x} of type {type(x)} to tensor" + + +class dtype: + SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64'] + STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + def __init__(self, name): + if hasattr(name, 'value'): + name = name.value + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8e4b15': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e4nv': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 7 + elif name == 'fp8e4b8': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 8 + elif name == 'fp8e5': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e5b16': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 16 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.primitive_bitwidth = 16 + self.exponent_bias = 15 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.primitive_bitwidth = 16 + self.exponent_bias = 127 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.primitive_bitwidth = 32 + self.exponent_bias = 127 + elif name == 'fp64': + self.fp_mantissa_width = 53 + self.primitive_bitwidth = 64 + self.exponent_bias = 1023 + else: + raise RuntimeError(f'Unsupported floating-point type {name}') + elif name == 'void': + self.primitive_bitwidth = 0 + + def is_fp8(self): + return 'fp8' in self.name + + def is_fp8e4nv(self): + return self.name == 'fp8e4nv' + + def is_fp8e4b8(self): + return self.name == 'fp8e4b8' + + def is_fp8e4b15(self): + return self.name == 'fp8e4b15' + + def is_fp8e5(self): + return self.name == 'fp8e5' + + def is_fp8e5b16(self): + return self.name == 'fp8e5b16' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_standard_floating(self): + return self.name in dtype.STANDARD_FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int_unsigned(self): + return self.name in dtype.UINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + @staticmethod + def is_dtype(type_str): + return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES + + @staticmethod + def is_void(): + raise RuntimeError("Not implemented") + + @staticmethod + def is_block(): + return False + + @staticmethod + def is_ptr(): + return False + + @staticmethod + def is_const(): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __ne__(self, other: dtype): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.name, )) + + @property + def scalar(self): + return self + + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + def __str__(self): + return self.name + + def codegen_name(self): + if self.name.startswith("fp"): + return "float" + self.name[2:] + elif self.name.startswith("bf"): + return "bfloat" + self.name[2:] + else: + return self.name + + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + """Output of repr needs to be an evaluatable expression""" + return f'triton.language.{self.codegen_name()}' + + +# Some functions have a param named `dtype`, which shadows the `dtype` class. +# We can't change the param name because it is part of function's public API. +# Declare an alias so those functions can still reference the dtype class. +_DtypeClass = dtype + + +class pointer_type(dtype): + + def __init__(self, element_ty: dtype, address_space: int = 1): + if not isinstance(element_ty, dtype): + raise TypeError(f'element_ty is a {type(element_ty).__name__}.') + self.element_ty = element_ty + self.address_space = address_space + + self.name = f'pointer<{element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.pointer_type: + return builder.get_ptr_ty(self.element_ty.to_ir(builder), 1) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_ptr(self): + return True + + def __eq__(self, other: pointer_type) -> bool: + if not isinstance(other, pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space + + def __ne__(self, other: pointer_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self + + +class const_pointer_type(pointer_type): + + def __init__(self, element_ty: dtype, address_space: int = 1): + super().__init__(element_ty, address_space) + + def __str__(self): + return f'const_pointer<{self.element_ty}>' + + def is_const(self): + return True + + def __eq__(self, other) -> bool: + if not isinstance(other, const_pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space + + +class block_type(dtype): + + def __init__(self, element_ty: dtype, shape: List): + self.element_ty = element_ty + + # Note that block_type's shape is a list of int + # while tensor's shape is a list of constexpr. + + # shape can be empty ([]) when an input is a 0D tensor. + if not shape: + raise TypeError('0d block_type is forbidden') + if isinstance(shape[0], constexpr): + shape = [s.value for s in shape] + + self.shape = shape + self.numel = 1 + for s in self.shape: + self.numel *= s + if self.numel > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") + + self.name = f'<{self.shape}, {self.element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> List[int]: + return self.shape + + def __eq__(self, other: block_type) -> bool: + if not isinstance(other, block_type): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + def __ne__(self, other: block_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self.element_ty + + +class function_type(dtype): + + def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: + self.ret_types = ret_types + self.param_types = param_types + + def __str__(self): + return f'fn ({self.param_types}) -> {self.ret_types}' + + def to_ir(self, builder: ir.builder): + ir_param_types = [ty.to_ir(builder) for ty in self.param_types] + ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] + return builder.get_function_ty(ir_param_types, ret_types) + + +# scalar types +void = dtype('void') +int1 = dtype('int1') +int8 = dtype('int8') +int16 = dtype('int16') +int32 = dtype('int32') +int64 = dtype('int64') +uint8 = dtype('uint8') +uint16 = dtype('uint16') +uint32 = dtype('uint32') +uint64 = dtype('uint64') +float8e5 = dtype('fp8e5') +float8e5b16 = dtype('fp8e5b16') +float8e4nv = dtype('fp8e4nv') +float8e4b8 = dtype('fp8e4b8') +float8e4b15 = dtype('fp8e4b15') +float16 = dtype('fp16') +bfloat16 = dtype('bf16') +float32 = dtype('fp32') +float64 = dtype('fp64') +# pointer types +pi32_t = pointer_type(int32) + + +def get_int_dtype(bitwidth: int, signed: bool) -> dtype: + if bitwidth == 1: + return int1 + elif bitwidth == 8 and signed: + return int8 + elif bitwidth == 8 and not signed: + return uint8 + elif bitwidth == 16 and signed: + return int16 + elif bitwidth == 16 and not signed: + return uint16 + elif bitwidth == 32 and signed: + return int32 + elif bitwidth == 32 and not signed: + return uint32 + elif bitwidth == 64 and signed: + return int64 + elif bitwidth == 64 and not signed: + return uint64 + else: + raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}') + + +# ----------------------- +# constexpr +# ----------------------- + + +class const: + """ + This class is used as a type annotation to mark pointers to constant data. + The `store` function cannot be called with a pointer to const. Constness + is part of the pointer type and the usual Triton type consistency rules + apply. For example you cannot have a function that returns constant pointer + in one return statement and non-constant pointer in another. + """ + pass + + +class constexpr: + """ + This class is used to store a value that is known at compile-time. + """ + + def __init__(self, value): + if isinstance(value, constexpr): + self.value = value.value + else: + self.value = value + + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + def __index__(self): + return self.value + + # In interpreter mode, constant values are not wrapped in constexpr, + # and therefore do not have a .value attribute. + # As a result, from here and below, we need to call the _constexpr_to_value + # function to obtain either constexpr.value or the value itself. + def __add__(self, other): + return constexpr(self.value + _constexpr_to_value(other)) + + def __radd__(self, other): + return constexpr(_constexpr_to_value(other) + self.value) + + def __sub__(self, other): + return constexpr(self.value - _constexpr_to_value(other)) + + def __rsub__(self, other): + return constexpr(_constexpr_to_value(other) - self.value) + + def __mul__(self, other): + return constexpr(self.value * _constexpr_to_value(other)) + + def __mod__(self, other): + return constexpr(self.value % _constexpr_to_value(other)) + + def __rmul__(self, other): + return constexpr(_constexpr_to_value(other) * self.value) + + def __truediv__(self, other): + return constexpr(self.value / _constexpr_to_value(other)) + + def __rtruediv__(self, other): + return constexpr(_constexpr_to_value(other) / self.value) + + def __floordiv__(self, other): + return constexpr(self.value // _constexpr_to_value(other)) + + def __rfloordiv__(self, other): + return constexpr(_constexpr_to_value(other) // self.value) + + def __gt__(self, other): + return constexpr(self.value > _constexpr_to_value(other)) + + def __rgt__(self, other): + return constexpr(_constexpr_to_value(other) > self.value) + + def __ge__(self, other): + return constexpr(self.value >= _constexpr_to_value(other)) + + def __rge__(self, other): + return constexpr(_constexpr_to_value(other) >= self.value) + + def __lt__(self, other): + return constexpr(self.value < _constexpr_to_value(other)) + + def __rlt__(self, other): + return constexpr(_constexpr_to_value(other) < self.value) + + def __le__(self, other): + return constexpr(self.value <= _constexpr_to_value(other)) + + def __rle__(self, other): + return constexpr(_constexpr_to_value(other) <= self.value) + + def __eq__(self, other): + return constexpr(self.value == _constexpr_to_value(other)) + + def __ne__(self, other): + return constexpr(self.value != _constexpr_to_value(other)) + + def __bool__(self): + return bool(self.value) + + def __neg__(self): + return constexpr(-self.value) + + def __and__(self, other): + return constexpr(self.value & _constexpr_to_value(other)) + + def logical_and(self, other): + return constexpr(self.value and _constexpr_to_value(other)) + + def __or__(self, other): + return constexpr(self.value | _constexpr_to_value(other)) + + def __xor__(self, other): + return constexpr(self.value ^ _constexpr_to_value(other)) + + def logical_or(self, other): + return constexpr(self.value or _constexpr_to_value(other)) + + def __pos__(self): + return constexpr(+self.value) + + def __invert__(self): + return constexpr(~self.value) + + def __pow__(self, other): + return constexpr(self.value**_constexpr_to_value(other)) + + def __rpow__(self, other): + return constexpr(_constexpr_to_value(other)**self.value) + + def __rshift__(self, other): + return constexpr(self.value >> _constexpr_to_value(other)) + + def __lshift__(self, other): + return constexpr(self.value << _constexpr_to_value(other)) + + def __not__(self): + return constexpr(not self.value) + + def __iter__(self): + return iter(self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + + +CONSTEXPR_0 = constexpr(0) + + +def check_bit_width(value, shift_value): + if isinstance(value, tensor) and isinstance(shift_value, constexpr): + bitwidth = value.type.scalar.primitive_bitwidth + if shift_value.value >= bitwidth: + warn( + f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type '{value.dtype}'. This may result in undefined behavior." + ) + + +class tensor: + """Represents an N-dimensional array of values or pointers. + + :code:`tensor` is the fundamental data structure in Triton programs. Most + functions in :py:mod:`triton.language` operate on and return tensors. + + Most of the named member functions here are duplicates of the free functions + in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is + equivalent to :code:`x.sqrt()`. + + :code:`tensor` also defines most of the magic/dunder methods, so you can + write :code:`x+y`, :code:`x << 2`, etc. + + .. rubric:: Constructors + .. + For some reason Sphinx includes __init__ before printing the full table + of methods. Not what I want, but I can't figure out how to fix it. Give + it its own section so it looks intentional. :) + """ + + def __init__(self, handle, type: dtype): + """Not called by user code.""" + # IR handle + self.handle = handle + # Block shape + self.shape = type.shape if type.is_block() else () + self.numel = 1 + for s in self.shape: + self.numel *= s + self.numel = constexpr(self.numel) + self.type = type # Tensor type (can be block_type) + # Following the practice in pytorch, dtype is scalar type + self.dtype = type.scalar + self.shape = [constexpr(s) for s in self.shape] + + def __str__(self) -> str: + # ex. "float32[16, 32]" + return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']' + + @builtin + def __add__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.add(self, other, _builder) + + @builtin + def __radd__(self, other, _builder=None): + return self.__add__(other, _builder=_builder) + + @builtin + def __sub__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.sub(self, other, _builder) + + @builtin + def __rsub__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.sub(other, self, _builder) + + @builtin + def __mul__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mul(self, other, _builder) + + @builtin + def __rmul__(self, other, _builder=None): + return self.__mul__(other, _builder=_builder) + + @builtin + def __truediv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.truediv(self, other, _builder) + + @builtin + def __rtruediv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.truediv(other, self, _builder) + + @builtin + def __floordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(self, other, _builder) + + @builtin + def __rfloordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(other, self, _builder) + + @builtin + def __mod__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mod(self, other, _builder) + + @builtin + def __rmod__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mod(other, self, _builder) + + # unary operators + @builtin + def __neg__(self, _builder=None): + return semantic.minus(self, _builder) + + @builtin + def __invert__(self, _builder=None): + return semantic.invert(self, _builder) + + # bitwise operators + + @builtin + def __and__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.and_(self, other, _builder) + + @builtin + def __rand__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.and_(other, self, _builder) + + @builtin + def __or__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.or_(self, other, _builder) + + @builtin + def __ror__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.or_(other, self, _builder) + + @builtin + def __xor__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.xor_(self, other, _builder) + + @builtin + def __rxor__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.xor_(other, self, _builder) + + @builtin + def __lshift__(self, other, _builder=None): + check_bit_width(self, other) + other = _to_tensor(other, _builder) + return semantic.shl(self, other, _builder) + + @builtin + def __rlshift__(self, other, _builder=None): + check_bit_width(other, self) + other = _to_tensor(other, _builder) + return semantic.shl(other, self, _builder) + + @builtin + def __rshift__(self, other, _builder=None): + check_bit_width(self, other) + other = _to_tensor(other, _builder) + if self.dtype.is_int_signed(): + return semantic.ashr(self, other, _builder) + else: + return semantic.lshr(self, other, _builder) + + @builtin + def __rrshift__(self, other, _builder=None): + check_bit_width(other, self) + other = _to_tensor(other, _builder) + if self.dtype.is_int_signed(): + return semantic.ashr(other, self, _builder) + else: + return semantic.lshr(other, self, _builder) + + # > + @builtin + def __gt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_than(self, other, _builder) + + @builtin + def __rgt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_than(other, self, _builder) + + # >= + @builtin + def __ge__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_equal(self, other, _builder) + + @builtin + def __rge__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_equal(other, self, _builder) + + # < + @builtin + def __lt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_than(self, other, _builder) + + @builtin + def __rlt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_than(other, self, _builder) + + # <= + @builtin + def __le__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_equal(self, other, _builder) + + @builtin + def __rle__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_equal(other, self, _builder) + + # == + @builtin + def __eq__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.equal(self, other, _builder) + + @builtin + def __req__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.equal(other, self, _builder) + + @builtin + def __ne__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.not_equal(self, other, _builder) + + @builtin + def __rne__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.not_equal(other, self, _builder) + + @builtin + def logical_and(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.logical_and(self, other, _builder) + + @builtin + def logical_or(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.logical_or(self, other, _builder) + + # note: __not__ isn't actually a magic method in python + # but it's ok because our ASTVisitor handles it + @builtin + def __not__(self, _builder=None): + return semantic.not_(self, _builder) + + @builtin + def __getitem__(self, slices, _builder=None): + if isinstance(slices, (slice, constexpr)) or slices is None: + slices = [slices] + ret = self + for dim, sl in enumerate(slices): + if sl is None or isinstance(sl, constexpr) and sl.value is None: + ret = semantic.expand_dims(ret, dim, _builder) + elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None: + pass + else: + raise ValueError(f"unsupported tensor index: {sl}") + return ret + + @property + def T(self): + """Transposes a 2D tensor.""" + assert False, "Transposition must be created by the AST Visitor" + + @builtin + def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None): + """ + Alias for :py:func:`tensor.cast`. + """ + # Triton doesn't like core functions calling other core functions, so we + # just copy-paste the implementation of cast here. It's not too bad. + if isinstance(bitcast, constexpr): + bitcast = bitcast.value + if bitcast: + return semantic.bitcast(self, dtype, _builder) + return semantic.cast(self, dtype, _builder, fp_downcast_rounding) + + # Type stubs for functions added by the _tensor_member_fn decorator. + # (Unfortunately these can't be created automatically.) + # + # We couldn't write these definitions out even if we wanted to, because some + # of these functions are defined in standard.py. + def broadcast_to(self, *shape) -> tensor: + ... + + def trans(self, *dims) -> tensor: + ... + + def permute(self, *dims) -> tensor: + ... + + def split(self) -> tuple[tensor, tensor]: + ... + + def view(self, *shape) -> tensor: + ... + + def reshape(self, *shape) -> tensor: + ... + + def expand_dims(self, axis) -> tensor: + ... + + def cast(self, dtype, fp_downcast_rounding=None, bitcast=False) -> tensor: + ... + + def store(self, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="") -> tensor: + ... + + def advance(self, offsets) -> tensor: + ... + + def atomic_cas(self, cmp, val, sem=None, scope=None) -> tensor: + ... + + def atomic_xchg(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_add(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_max(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_min(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_and(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_or(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_xor(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def exp(self) -> tensor: + ... + + def log(self) -> tensor: + ... + + def cos(self) -> tensor: + ... + + def sin(self) -> tensor: + ... + + def sqrt(self) -> tensor: + ... + + def rsqrt(self) -> tensor: + ... + + def abs(self) -> tensor: + ... + + def reduce(self, axis, combine_fn, keep_dims=False) -> tensor: + ... + + def associative_scan(self, axis, combine_fn, reverse=False) -> tensor: + ... + + def histogram(self, num_bins) -> tensor: + ... + + def cdiv(self, div) -> tensor: + ... + + def sigmoid(self) -> tensor: + ... + + def softmax(self, ieee_rounding=False) -> tensor: + ... + + def ravel(self) -> tensor: + ... + + def max(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmax(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def min(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def xor_sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def cumsum(self, axis=0, reverse=False) -> tensor: + ... + + def cumprod(self, axis=0, reverse=False) -> tensor: + ... + + def sort(self, dim: constexpr = None, descending: constexpr = CONSTEXPR_0) -> tensor: + ... + + def flip(self, dim=None) -> tensor: + ... + + +def get_bool_env_var(var_name): + v = os.getenv(var_name, "0") + return v == "1" or v == "true" or v == "on" + + +# ----------------------- +# SPMD Programming Model +# ----------------------- +def _constexpr_to_value(v): + if isinstance(v, constexpr): + return v.value + return v + + +@builtin +def program_id(axis, _builder=None): + """ + Returns the id of the current program instance along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + # if axis == -1: + # pid0 = program_id(0, _builder) + # pid1 = program_id(1, _builder) + # pid2 = program_id(2, _builder) + # npg0 = num_programs(0, _builder) + # npg1 = num_programs(0, _builder) + # return pid0 + pid1*npg0 + pid2*npg0*npg1 + axis = _constexpr_to_value(axis) + return semantic.program_id(axis, _builder) + + +@builtin +def num_programs(axis, _builder=None): + """ + Returns the number of program instances launched along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + axis = _constexpr_to_value(axis) + return semantic.num_programs(axis, _builder) + + +# ----------------------- +# Block Initialization +# ----------------------- + + +@builtin +def arange(start, end, _builder=None): + """ + Returns contiguous values within the half-open interval :code:`[start, + end)`. :code:`end - start` must be less than or equal to + :code:`TRITON_MAX_TENSOR_NUMEL = 131072` + + :param start: Start of the interval. Must be a power of two. + :type start: int32 + :param end: End of the interval. Must be a power of two greater than + :code:`start`. + :type end: int32 + """ + start = _constexpr_to_value(start) + end = _constexpr_to_value(end) + return semantic.arange(start, end, _builder) + + +def _shape_check_impl(shape): + shape = _constexpr_to_value(shape) + for i, d in enumerate(shape): + if isinstance(d, int): + d = constexpr(d) + if not isinstance(d, constexpr): + raise TypeError(f"Shape element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + # ===-------------------- For Triton XPU -----------------------=== + # Triton XPU Don't need the power-of-two limitation + # if d.value & (d.value - 1) != 0: + # raise ValueError(f"Shape element {i} must be a power of 2") + # ===-----------------------------------------------------------=== + return [_constexpr_to_value(x) for x in shape] + + +@builtin +def full(shape, value, dtype, _builder=None): + """ + Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :value value: A scalar value to fill the array with + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + shape = _shape_check_impl(shape) + value = _constexpr_to_value(value) + dtype = _constexpr_to_value(dtype) + return semantic.full(shape, value, dtype, _builder) + + +# ----------------------- +# Shape Manipulation +# ----------------------- + + +@builtin +def broadcast(input, other, _builder=None): + """ + Tries to broadcast the two given blocks to a common compatible shape. + + :param input: The first input tensor. + :type input: Block + :param other: The second input tensor. + :type other: Block + """ + return semantic.broadcast_impl_value(input, other, _builder) + + +@_tensor_member_fn +@builtin +def broadcast_to(input, *shape, _builder=None): + """ + Tries to broadcast the given tensor to a new :code:`shape`. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + :type shape: + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + broadcast_to(x, (32, 32)) + broadcast_to(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.broadcast_impl_shape(input, shape, _builder) + + +@_tensor_member_fn +@builtin +def trans(input: tensor, *dims, _builder=None): + """ + Permutes the dimensions of a tensor. + + If no permutation is specified, tries to do a (1,0) permutation, i.e. tries + to transpose a 2D tensor. + + :param input: The input tensor. + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + trans(x, (2, 1, 0)) + trans(x, 2, 1, 0) + + :py:func:`permute` is equivalent to this function, except it doesn't + have the special case when no permutation is specified. + """ + if not dims: + dims = (1, 0) + return semantic.permute(input, dims, _builder) + + +@_tensor_member_fn +@builtin +def permute(input, *dims, _builder=None): + """ + Permutes the dimensions of a tensor. + + :param input: The input tensor. + :type input: Block + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + permute(x, (2, 1, 0)) + permute(x, 2, 1, 0) + + :py:func:`trans` is equivalent to this function, except when + :code:`dims` is empty, it tries to do a (1,0) permutation. + """ + dims = _unwrap_iterable(dims) + return semantic.permute(input, dims, _builder) + + +@builtin +def cat(input, other, can_reorder=False, _builder=None): + """ + Concatenate the given blocks + + :param input: The first input tensor. + :type input: + :param other: The second input tensor. + :type other: + :param reorder: Compiler hint. If true, the compiler is + allowed to reorder elements while concatenating inputs. Only use if the + order does not matter (e.g., result is only used in reduction ops) + """ + return semantic.cat(input, other, can_reorder, _builder) + + +@builtin +def join(a, b, _builder=None): + """ + Join the given tensors in a new, minor dimension. + + For example, given two tensors of shape (4,8), produces a new tensor of + shape (4,8,2). Given two scalars, returns a tensor of shape (2). + + The two inputs are broadcasted to be the same shape. + + If you want to join more than two elements, you can use multiple calls to + this function. This reflects the constraint in Triton that tensors must + have power-of-two sizes. + + join is the inverse of split. + + :param a: The first input tensor. + :type a: Tensor + :param b: The second input tensor. + :type b: Tensor + """ + return semantic.join(a, b, _builder) + + +@jit +def _take_first(a, b): + return a + + +@_tensor_member_fn +@builtin +def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]: + """ + Split a tensor in two along its last dim, which must have size 2. + + For example, given a tensor of shape (4,8,2), produces two tensors of shape + (4,8). Given a tensor of shape (2), returns two scalars. + + If you want to split into more than two pieces, you can use multiple calls + to this function (probably plus calling reshape). This reflects the + constraint in Triton that tensors must have power-of-two sizes. + + split is the inverse of join. + + :param a: The tensor to split. + :type a: Tensor + """ + # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars. + # But semantic.split can only handle returning tensors. Work around this by + # expanding the input to shape [1,2] and then reducing the result. + was_rank_1 = len(a.shape) == 1 + if was_rank_1: + a = semantic.expand_dims(a, 0, _builder) + + out_lhs, out_rhs = semantic.split(a, _builder) + + if was_rank_1: + # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar. + out_lhs = typing.cast(tensor, reduce(out_lhs, None, _take_first, _builder=_builder, _generator=_generator)) + out_rhs = typing.cast(tensor, reduce(out_rhs, None, _take_first, _builder=_builder, _generator=_generator)) + + return out_lhs, out_rhs + + +@_tensor_member_fn +@builtin +def view(input, *shape, _builder=None): + """ + Returns a tensor with the same elements as `input` but a different shape. + The order of the elements may not be preserved. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + view(x, (32, 32)) + view(x, 32, 32) + """ + warn("view is deprecated, please use reshape with can_reorder being true.") + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.reshape(input, shape, can_reorder=True, builder=_builder) + + +@_tensor_member_fn +@builtin +def reshape(input, *shape, can_reorder=False, _builder=None): + """ + Returns a tensor with the same number of elements as input but with the + provided shape. + + :param input: The input tensor. + :type input: Block + :param shape: The new shape. + + :code:`shape ` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + reshape(x, (32, 32)) + reshape(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.reshape(input, shape, can_reorder, _builder) + + +def _wrap_axis(axis, ndim): + if not (-ndim <= axis < ndim): + raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}") + + return axis if axis >= 0 else axis + ndim + + +@_tensor_member_fn +@builtin +def expand_dims(input, axis, _builder=None): + """ + Expand the shape of a tensor, by inserting new length-1 dimensions. + + Axis indices are with respect to the resulting tensor, so + ``result.shape[axis]`` will be 1 for each axis. + + :param input: The input tensor. + :type input: tl.tensor + :param axis: The indices to add new axes + :type axis: int | Sequence[int] + + """ + input = _to_tensor(input, _builder) + axis = _constexpr_to_value(axis) + axes = list(axis) if isinstance(axis, Sequence) else [axis] + new_ndim = len(input.shape) + len(axes) + axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] + + if len(set(axes)) != len(axes): + raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}") + + ret = input + for a in sorted(axes): + ret = semantic.expand_dims(ret, a, _builder) + return ret + + +@_tensor_member_fn +@builtin +def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None): + """ + Casts a tensor to the given :code:`dtype`. + + :param dtype: The target data type. + :param fp_downcast_rounding: The rounding mode for downcasting + floating-point values. This parameter is only used when self is a + floating-point tensor and dtype is a floating-point type with a + smaller bitwidth. Supported values are :code:`"rtne"` (round to + nearest, ties to even) and :code:`"rtz"` (round towards zero). + :param bitcast: If true, the tensor is bitcasted to the given + :code:`dtype`, instead of being numerically casted. + """ + input = _to_tensor(input, _builder) + if isinstance(bitcast, constexpr): + bitcast = bitcast.value + if bitcast: + return semantic.bitcast(input, dtype, _builder) + return semantic.cast(input, dtype, _builder, fp_downcast_rounding) + + +# ----------------------- +# Linear Algebra +# ----------------------- + + +@builtin +def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32, + _builder=None): + """ + Returns the matrix product of two blocks. + + The two blocks must be two-dimensional and have compatible inner dimensions. + + :param input: The first tensor to be multiplied. + :type input: 2D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second tensor to be multiplied. + :type other: 2D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param input_precision: How to exercise the Tensor Cores for f32 x f32. If + the device does not have Tensor Cores or the inputs are not of dtype f32, + this option is ignored. For devices that do have tensor cores, the + default precision is tf32. + :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Avaliable options for amd: :code:`"ieee"`. + :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32". + Only one of :code:`input_precision` and :code:`allow_tf32` can be + specified (i.e. at least one must be :code:`None`). + """ + assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" + if input_precision is None: + supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions + default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee" + input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision) + + input_precision = _constexpr_to_value(input_precision) + out_dtype = _constexpr_to_value(out_dtype) + max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) + return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder) + + +# ----------------------- +# Non-Atomic Memory Operations +# ----------------------- + + +@builtin +def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="", + volatile=False, _builder=None): + """ + Return a tensor of data whose values are loaded from memory at location defined by `pointer`: + + (1) If `pointer` is a single element pointer, a scalar is be loaded. In + this case: + + - `mask` and `other` must also be scalars, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional tensor is loaded. In this case: + + - `mask` and `other` are implicitly broadcast to `pointer.shape`, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a + tensor is loaded. In this case: + + - `mask` and `other` must be None, and + - `boundary_check` and `padding_option` can be specified to control + the behavior of out-of-bound access. + + :param pointer: Pointer to the data to be loaded + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]` + (must be `None` with block pointers) + :type mask: Block of `triton.int1`, optional + :param other: if `mask[idx]` is false, return `other[idx]` + :type other: Block, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param padding_option: should be one of {"", "zero", "nan"}, do padding while out of bound + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + :param volatile: changes volatile option in NVIDIA PTX + :type volatile: bool, optional + """ + # `mask` and `other` can be constexpr + mask = _constexpr_to_value(mask) + other = _constexpr_to_value(other) + if mask is not None: + mask = _to_tensor(mask, _builder) + if other is not None: + other = _to_tensor(other, _builder) + padding_option = _constexpr_to_value(padding_option) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + volatile = _constexpr_to_value(volatile) + return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy, + volatile, _builder) + + +@builtin +def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None): + """ + Experimental feature to access TMA descriptors loads. This is an escape hatch to easily exercise TTGIR operations. + This will be removed in the future and shouldn't be used in production code. + + This loads a tensor of data based on the descriptor and offsets. + """ + type = block_type(dtype, shape) + return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder) + + +@builtin +def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None): + """ + Experimental feature to access TMA descriptors stores. This is an escape hatch to easily exercise TTGIR operations. + This will be removed in the future and shouldn't be used in production code. + + This stores a tensor of data based on the descriptor and offsets. + """ + return semantic.descriptor_store(desc_pointer, value, offsets, _builder) + + +@_tensor_member_fn +@builtin +def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None): + """ + Store a tensor of data into memory locations defined by `pointer`. + + (1) If `pointer` is a single element pointer, a scalar is stored. In + this case: + + - `mask` must also be scalar, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional block is stored. In this case: + + - `mask` is implicitly broadcast to `pointer.shape`, and + - `boundary_check` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a block + of data is stored. In this case: + + - `mask` must be None, and + - `boundary_check` can be specified to control the behavior of out-of-bound access. + + `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`. + + :param pointer: The memory location where the elements of `value` are stored + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param value: The tensor of elements to be stored + :type value: Block + :param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]` + :type mask: Block of triton.int1, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + """ + # `value` can be constexpr + value = _to_tensor(value, _builder) + mask = _constexpr_to_value(mask) + if mask is not None: + mask = _to_tensor(mask, _builder) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder) + + +@builtin +def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None): + """ + Returns a pointer to a block in a parent tensor + + :param base: The base pointer to the parent tensor + :param shape: The shape of the parent tensor + :param strides: The strides of the parent tensor + :param offsets: The offsets to the block + :param block_shape: The shape of the block + :param order: The order of the original data format + """ + return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder) + + +@_tensor_member_fn +@builtin +def advance(base, offsets, _builder=None): + """ + Advance a block pointer + + :param base: the block pointer to advance + :param offsets: the offsets to advance, a tuple by dimension + """ + return semantic.advance(base, offsets, _builder) + + +# ----------------------- +# Atomic Memory Operations +# ----------------------- + + +def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = f""" + Performs an atomic {name} at the memory location specified by :code:`pointer`. + + Return the data stored at :code:`pointer` before the atomic operation. + + :param pointer: The memory locations to operate on + :type pointer: Block of dtype=triton.PointerDType""" + if has_cmp: + docstr += """ + :param cmp: The values expected to be found in the atomic object + :type cmp: Block of dtype=pointer.dtype.element_ty""" + docstr += """ + :param val: The values with which to perform the atomic operation + :type val: Block of dtype=pointer.dtype.element_ty + :param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default), + "ACQUIRE", "RELEASE", or "RELAXED") + :type sem: str + :param scope: Scope of threads that observe synchronizing effect of the + atomic operation ("GPU" (default), "CTA", or "SYSTEM") + :type scope: str + """ + func.__doc__ = docstr + return func + + return _decorator + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("compare-and-swap", has_cmp=True) +def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None): + cmp = _to_tensor(cmp, _builder) + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("exchange") +def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("add") +def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_add(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("max") +def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_max(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("min") +def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_min(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical and") +def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_and(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical or") +def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_or(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical xor") +def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = _to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder) + + +# ----------------------- +# Conditioning +# ----------------------- + + +@builtin +def where(condition, x, y, _builder=None): + """ + Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + + Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. + + If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead. + + The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`. + :code:`x` and :code:`y` must have the same data type. + + :param condition: When True (nonzero), yield x, otherwise yield y. + :type condition: Block of triton.bool + :param x: values selected at indices where condition is True. + :param y: values selected at indices where condition is False. + """ + condition = _to_tensor(condition, _builder) + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + return semantic.where(condition, x, y, _builder) + + +# ----------------------- +# Math +# ----------------------- + + +@builtin +def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Computes the element-wise minimum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + y = _promote_bfloat16_to_float32(y, _builder=_builder) + propagate_nan = _constexpr_to_value(propagate_nan) + return semantic.minimum(x, y, propagate_nan, _builder) + + +@builtin +def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Computes the element-wise maximum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + y = _promote_bfloat16_to_float32(y, _builder=_builder) + propagate_nan = _constexpr_to_value(propagate_nan) + return semantic.maximum(x, y, propagate_nan, _builder) + + +@builtin +def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Clamps the input tensor :code:`x` within the range [min, max]. + Behavior when :code:`min` > :code:`max` is undefined. + + :param x: the input tensor + :type x: Block + :param min: the lower bound for clamping + :type min: Block + :param max: the upper bound for clamping + :type max: Block + :param propagate_nan: whether to propagate NaN values. Applies only to the :code:`x` tensor. + If either :code:`min` or :code:`max` is NaN, the result is undefined. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = _to_tensor(x, _builder) + min = _to_tensor(min, _builder) + max = _to_tensor(max, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + min = _promote_bfloat16_to_float32(min, _builder=_builder) + max = _promote_bfloat16_to_float32(max, _builder=_builder) + + propagate_nan = _constexpr_to_value(propagate_nan) + + return semantic.clamp(x, min, max, propagate_nan, _builder) + + +# ----------------------- +# Reductions +# ----------------------- + + +def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :param axis: the dimension along which the reduction should be done + :param keep_dims: if true, keep the reduced dimensions with length 1""" + if return_indices_arg is not None: + docstr += f""" + :param {return_indices_arg}: if true, return index corresponding to the {name} value""" + if tie_break_arg is not None: + docstr += f""" + :param {tie_break_arg}: if true, return the left-most indices in case of ties for values that aren't NaN""" + + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@contextmanager +def _insertion_guard(builder): + ip = builder.get_insertion_point() + yield + builder.restore_insertion_point(ip) + + +@_tensor_member_fn +@builtin +def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None): + """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis` + + :param input: the input tensor, or tuple of tensors + :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :param keep_dims: if true, keep the reduced dimensions with length 1 + + """ + if isinstance(input, tensor): + return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(reduce_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = reduce_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_reduce_ret(*handles) + + def expand_ndims(t, ndims): + for _ in builtins.range(ndims): + t = expand_dims(t, 0, _builder=_builder) + return t + + axis = _constexpr_to_value(axis) + keep_dims = _constexpr_to_value(keep_dims) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + ret = semantic.reduction(input, axis, make_combine_region, _builder) + if keep_dims: + if axis is not None: + ret = tuple(expand_dims(t, axis, _builder=_builder) for t in ret) + else: + ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret) + return ret + + +@builtin +def _promote_bfloat16_to_float32(t, _builder=None): + scalar_ty = t.type.scalar + + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is bfloat16: + return t.to(float32, _builder=_builder) + return t + + +@builtin +def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None): + axis = _constexpr_to_value(axis) + n = input.shape[axis] + index = arange(0, n, _builder=_builder) + + if len(input.shape) > 1: + # Broadcast index across the non-reduced axes + axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))] + del axes_to_expand[axis] + index = expand_dims(index, axes_to_expand, _builder=_builder) + index = broadcast_to(index, input.shape, _builder=_builder) + + rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, + _generator=_generator) + return rvalue, rindices + + +# ----------------------- +# Scans +# ----------------------- + + +def _add_scan_docstr(name: str) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :param axis: the dimension along which the scan should be done""" + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@_tensor_member_fn +@builtin +def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _generator=None): + """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry + + :param input: the input tensor, or tuple of tensors + :param axis: the dimension along which the reduction should be done + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :param reverse: apply the associative scan in the reverse direction along axis. + + """ + if isinstance(input, tensor): + return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(scan_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = scan_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_scan_ret(*handles) + + axis = _constexpr_to_value(axis) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + return semantic.associative_scan(input, axis, make_combine_region, reverse, _builder) + + +@_tensor_member_fn +@builtin +def histogram(input, num_bins, _builder=None, _generator=None): + """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0. + + :param input: the input tensor + :param num_bins: number of histogram bins + + """ + num_bins = _constexpr_to_value(num_bins) + return semantic.histogram(input, num_bins, _builder) + + +# ----------------------- +# Compiler Hint Ops +# ----------------------- + + +@builtin +def debug_barrier(_builder=None): + ''' + Insert a barrier to synchronize all threads in a block. + ''' + return semantic.debug_barrier(_builder) + + +@builtin +def multiple_of(input, values, _builder=None): + """ + Let the compiler know that the values in :code:`input` are all multiples of :code:`value`. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.multiple_of(input, values) + + +@builtin +def max_contiguous(input, values, _builder=None): + """ + Let the compiler know that the `value` first values in :code:`input` are contiguous. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_contiguous(input, values) + + +@builtin +def max_constancy(input, values, _builder=None): + """ + Let the compiler know that the `value` first values in :code:`input` are constant. + + e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal, + for example [0, 0, 0, 0, 1, 1, 1, 1]. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_constancy(input, values) + + +# ----------------------- +# Debugging functions +# ----------------------- + + +@builtin +def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None): + ''' + Print the values at compile time. The parameters are the same as the builtin :code:`print`. + + NOTE: Calling the Python builtin :code:`print` is not the same as calling this, it instead maps to :code:`device_print`, + which has special requirements for the arguments. + + .. highlight:: python + .. code-block:: python + + tl.static_print(f"{BLOCK_SIZE=}") + ''' + pass + + +@builtin +def static_assert(cond, msg="", _builder=None): + ''' + Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable + is set. + + .. highlight:: python + .. code-block:: python + + tl.static_assert(BLOCK_SIZE == 1024) + ''' + pass + + +@builtin +def device_print(prefix, *args, hex=False, _builder=None): + ''' + Print the values at runtime from the device. String formatting does not work for runtime values, so you should + provide the values you want to print as arguments. The first value must be a string, all following values must + be scalars or tensors. + + Calling the Python builtin :code:`print` is the same as calling this function, and the requirements for the arguments will match + this function (not the normal requirements for :code:`print`). + + .. highlight:: python + .. code-block:: python + + tl.device_print("pid", pid) + print("pid", pid) + + On CUDA, printfs are streamed through a buffer of limited size (on one host, + we measured the default as 6912 KiB, but this may not be consistent across + GPUs and CUDA versions). If you notice some printfs are being dropped, you + can increase the buffer size by calling + + .. highlight:: python + .. code-block:: python + + triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes) + + CUDA may raise an error if you try to change this value after running a + kernel that uses printfs. The value set here may only affect the current + device (so if you have multiple GPUs, you'd need to call it multiple times). + + :param prefix: a prefix to print before the values. This is required to be a string literal. + :param args: the values to print. They can be any tensor or scalar. + :param hex: print all values as hex instead of decimal + ''' + import string + prefix = _constexpr_to_value(prefix) + assert isinstance(prefix, str), f"{prefix} is not string" + b_ascii = True + for ch in prefix: + if ch not in string.printable: + b_ascii = False + break + assert b_ascii, f"{prefix} is not an ascii string" + new_args = [] + for arg in args: + new_args.append(_to_tensor(arg, _builder)) + return semantic.device_print(prefix, new_args, hex, _builder) + + +@builtin +def device_assert(cond, msg="", _builder=None): + ''' + Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG` + is set to a value besides :code:`0` in order for this to have any effect. + + Using the Python :code:`assert` statement is the same as calling this function, except that the second argument + must be provided and must be a string, e.g. :code:`assert pid == 0, "pid != 0"`. The environment variable must + be set for this :code:`assert` statement to have any effect. + + .. highlight:: python + .. code-block:: python + + tl.device_assert(pid == 0) + assert pid == 0, f"pid != 0" + + :param cond: the condition to assert. This is required to be a boolean tensor. + :param msg: the message to print if the assertion fails. This is required to be a string literal. + ''' + msg = _constexpr_to_value(msg) + import inspect + frame = inspect.currentframe() + module = inspect.getmodule(frame) + # The triton function module doesn't have the name attribute. + # We use this trick to find the caller. + while hasattr(module, "__name__"): + frame = frame.f_back + module = inspect.getmodule(frame) + lineno = 0 + func_name = 'unknown' + file_name = 'unknown' + if frame is not None and frame.f_back is not None: + func_name = frame.f_code.co_name + file_name = frame.f_back.f_code.co_filename + # TODO: The line number currently indicates the line + # where the triton function is called but not where the + # device_assert is called. Need to enhance this. + lineno = frame.f_back.f_lineno + return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder) + + +@builtin +def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]], + is_pure: bool, pack: int, _builder=None): + ''' + Execute inline assembly over a tensor. Essentially, this is :code:`map` + where the function is inline assembly. + + The input tensors :code:`args` are implicitly broadcasted to the same shape. + + :code:`dtype` can be a tuple of types, in which case the output is a + tuple of tensors. + + Each invocation of the inline asm processes :code:`pack` elements at a + time. Exactly which set of inputs a block receives is unspecified. + Input elements of size less than 4 bytes are packed into 4-byte + registers. + + This op does not support empty :code:`dtype` -- the inline asm must + return at least one tensor, even if you don't need it. You can work + around this by returning a dummy tensor of arbitrary type; it shouldn't + cost you anything if you don't use it. + + Example using + [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + assembly: + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor + b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. + "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + :param asm: assembly to run. Must match target's assembly format. + :param constraints: asm constraints in + [LLVM format](https://llvm.org/docs/LangRef.html#inline-asm-constraint-string) + :param args: the input tensors, whose values are passed to the asm block + :param dtype: the element type(s) of the returned tensor(s) + :param is_pure: if true, the compiler assumes the asm block has no side-effects + :param pack: the number of elements to be processed by one instance of inline assembly + :param _builder: the builder + :return: one tensor or a tuple of tensors of the given dtypes + ''' + asm = _constexpr_to_value(asm) + constraints = _constexpr_to_value(constraints) + pack = _constexpr_to_value(pack) + is_pure = _constexpr_to_value(is_pure) + + # Wrap `dtype` in a tuple if it's not already. + try: + iter(dtype) # type: ignore + has_multiple_outputs = True + except TypeError: + has_multiple_outputs = False + dtype = (dtype, ) # type: ignore + + dtype = typing.cast(Sequence[_DtypeClass], dtype) + + res_tys = dtype + if dispatch_args := [_to_tensor(arg, _builder) for arg in args]: + bin_op_type_checking = partial( + semantic.binary_op_type_checking_impl, + builder=_builder, + arithmetic_check=False, + allow_lhs_ptr=True, + allow_rhs_ptr=True, + ) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = bin_op_type_checking(item, broadcast_arg) + if broadcast_arg.shape: + # Change the shape of each argument based on the broadcast shape + for i, item in enumerate(dispatch_args): + dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg) + res_tys = [block_type(dt, broadcast_arg.shape) for dt in dtype] + handles = [t.handle for t in dispatch_args] + call = _builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(_builder) for ty in res_tys], is_pure, pack) + + if not has_multiple_outputs: + return tensor(call.get_result(0), res_tys[0]) + return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys)) + + +# ----------------------- +# Iterators +# ----------------------- + + +class static_range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.static_range(10): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + """ + + def __init__(self, arg1, arg2=None, step=None): + assert isinstance(arg1, constexpr) + if step is None: + self.step = constexpr(1) + else: + assert isinstance(step, constexpr) + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + assert isinstance(arg2, constexpr) + self.start = arg1 + self.end = arg2 + + def __iter__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + +class range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.range(10, num_stages=3): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + :param num_stages: pipeline the loop into this many stages (so there are + :code:`num_stages` iterations of the loop in flight at once). + + Note this is subtly different than passing :code:`num_stages` as a + kernel argument. The kernel argument only pipelines loads that feed + into :code:`dot` operations, while this attribute tries to pipeline most + (though not all) loads in this loop. + """ + + def __init__(self, arg1, arg2=None, step=None, num_stages=None): + if step is None: + self.step = constexpr(1) + else: + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + self.start = arg1 + self.end = arg2 + self.num_stages = num_stages + + def __iter__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + +# ----------------------- +# Extern functions +# ----------------------- + + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, + is_pure: bool, _builder=None): + ''' + Dispatch a function to a library + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :param _builder: the builder + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." + f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_type = arg_type_symbol_dict[arg_types][1] + if ret_shape: + ret_type = block_type(ret_type, ret_shape) + return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type) + + +@builtin +def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, + _builder=None): + ''' + Dispatch an elementwise function to a library + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param is_pure: whether the function is pure + :param _builder: the builder + :return: the return value of the function + ''' + dispatch_args = args.copy() + all_scalar = True + ret_shape = None + arg_types = [] + for i in builtins.range(len(dispatch_args)): + dispatch_args[i] = _to_tensor(dispatch_args[i], _builder) + arg_types.append(dispatch_args[i].dtype) + if dispatch_args[i].type.is_block(): + all_scalar = False + if len(arg_types) > 0: + arg_types = tuple(arg_types) + arithmetic_check = True + # If there's a type tuple that is not supported by the library, we will do arithmetic check + if arg_types in arg_type_symbol_dict: + arithmetic_check = False + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder, + arithmetic_check=arithmetic_check) + # Change the shape of each argument based on the broadcast shape + for i in builtins.range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder, + arithmetic_check=arithmetic_check) + if not all_scalar: + ret_shape = broadcast_arg.shape + func = _builder.create_extern_elementwise + return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder) + + +def binary_op_type_legalization(lhs, rhs, builder): + ''' + Convert both operands to a single common type + :param lhs: the left operand + :param rhs: the right operand + :param builder: the builder + ''' + return semantic.binary_op_type_checking_impl(lhs, rhs, builder) + + +def extern(fn): + """A decorator for external functions.""" + return builtin(fn) diff --git a/third_party/xpu/python/triton/language/extra/__init__.py b/third_party/xpu/python/triton/language/extra/__init__.py new file mode 100644 index 000000000..f3def5345 --- /dev/null +++ b/third_party/xpu/python/triton/language/extra/__init__.py @@ -0,0 +1,5 @@ +from . import cuda +from . import hip +from . import xpu + +__all__ = ['cuda', 'hip', 'xpu'] diff --git a/third_party/xpu/python/triton/language/extra/cuda/__init__.py b/third_party/xpu/python/triton/language/extra/cuda/__init__.py new file mode 100644 index 000000000..3ca510e02 --- /dev/null +++ b/third_party/xpu/python/triton/language/extra/cuda/__init__.py @@ -0,0 +1,8 @@ +from . import libdevice + +from .utils import (globaltimer, num_threads, num_warps, smid, convert_custom_float8_sm70, convert_custom_float8_sm80) + +__all__ = [ + "libdevice", "globaltimer", "num_threads", "num_warps", "smid", "convert_custom_float8_sm70", + "convert_custom_float8_sm80" +] diff --git a/third_party/xpu/python/triton/language/extra/cuda/libdevice.py b/third_party/xpu/python/triton/language/extra/cuda/libdevice.py new file mode 100644 index 000000000..3490e6b0e --- /dev/null +++ b/third_party/xpu/python/triton/language/extra/cuda/libdevice.py @@ -0,0 +1,1629 @@ +from triton.language import core + + +@core.extern +def clz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_clz", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_clzll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def popc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_popc", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_popcll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def byte_perm(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("int32")): ("__nv_byte_perm", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mulhi(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_mulhi", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umulhi", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64")): ("__nv_mul64hi", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64")): ("__nv_umul64hi", core.dtype("uint64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul24(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_mul24", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umul24", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def brev(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_brev", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_brevll", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sad(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("uint32")): ("__nv_sad", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32")): ("__nv_usad", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__nv_abs", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_llabs", core.dtype("int64")), + (core.dtype("fp32"), ): ("__nv_fabsf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_fabs", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_floorf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_floor", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp64h(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_rcp64h", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_rsqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rsqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_ceil", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_ceilf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_trunc", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_truncf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_exp2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def saturatef(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_saturatef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rn(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rz(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rd(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_ru(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_dividef(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_fdividef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rn(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rn", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rz", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rd(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rd", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_ru(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_ru", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rn(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rn", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rz", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rd(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rd", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_ru(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_ru", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_ru", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__nv_dmul_ru", core.dtype("fp64")), + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__nv_fmul_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hiloint2double(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_hiloint2double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2loint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2loint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2hiint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2hiint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("__nv_int_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_int(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_int", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("__nv_uint_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_uint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_uint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def longlong_as_double(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("__nv_longlong_as_double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double_as_longlong(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_double_as_longlong", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_sinf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_sinf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_cosf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_cosf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log2f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_log2f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_logf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_logf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_expf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_expf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_tanf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_tanf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_exp10f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_exp10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log10f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_log10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_powf(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_powf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_hadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_uhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_rhadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_urhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_frsqrt_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ffs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("__nv_ffs", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_ffsll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_rintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_llrintf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llrint", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nearbyint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_nearbyintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_nearbyint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_isnanf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_isnand", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_signbitf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_signbitd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_copysignf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_copysign", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def finitef(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_finitef", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_isinff", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_isinfd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nextafter(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_nextafterf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_nextafter", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinpi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinpif", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sinpi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cospi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cospif", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cospi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_exp10f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_coshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_sinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan2(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_atan2f", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_atan2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_logf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log10f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_log1pf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log1p", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_acoshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_asinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_atanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_expm1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_expm1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_hypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_hypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_rhypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_rhypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_norm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_norm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_rnorm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_rnorm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("__nv_norm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("__nv_norm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("__nv_rnorm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("__nv_rnorm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cbrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_rcbrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rcbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j0(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_j0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_j0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_j1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_j1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y0(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_y0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_y0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y1(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_y1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_y1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def yn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("__nv_ynf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("__nv_yn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def jn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("__nv_jnf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("__nv_jn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cyl_bessel_i0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cyl_bessel_i1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erff", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfc", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcx(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcxf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfcx", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfcinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_normcdfinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_normcdfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_normcdff", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_normcdf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_lgammaf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_lgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_ldexpf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_ldexp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def scalbn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_scalbnf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_scalbn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmodf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fmod", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def remainder(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_remainderf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_remainder", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_powif", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_powi", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_powf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_pow", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_tgammaf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def round(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_roundf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_round", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llround(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_llroundf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llround", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fdim(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdimf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fdim", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_ilogbf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_ilogb", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def logb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__nv_logbf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_logb", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isfinited(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("__nv_isfinited", core.dtype("int32")), + }, is_pure=True, _builder=_builder) diff --git a/third_party/xpu/python/triton/language/extra/cuda/utils.py b/third_party/xpu/python/triton/language/extra/cuda/utils.py new file mode 100644 index 000000000..01bc040b2 --- /dev/null +++ b/third_party/xpu/python/triton/language/extra/cuda/utils.py @@ -0,0 +1,109 @@ +from triton.language import core + + +@core.extern +def globaltimer(_builder=None): + return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1, + _builder=_builder) + + +@core.extern +def smid(_builder=None): + return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1, + _builder=_builder) + + +@core.builtin +def num_threads(_builder=None): + return core.constexpr(_builder.options.num_warps * 32) + + +@core.builtin +def num_warps(_builder=None): + return core.constexpr(_builder.options.num_warps) + + +# ----- FP8E4M3B15 ------ +# This data-type is a variant of the standard FP8E4M3 format. +# It was designed for fast software conversion to FP16 on +# nvidia GPUs that do not support it natively. +# This is the same format as FP8E4M3Nv, but: +# - the exponent bias is 15 instead of 7 +# - 0xff and 0x7f are mapped to +-1.750 instead of +-nan +@core.builtin +def convert_fp8e4b15_to_float16(arg, _builder=None): + return core.inline_asm_elementwise( + "{ \n" + ".reg .b32 a<2>, b<2>; \n" + "prmt.b32 a0, 0, $2, 0x5746; \n" + "and.b32 b0, a0, 0x7f007f00; \n" + "and.b32 b1, a0, 0x00ff00ff; \n" + "and.b32 a1, a0, 0x00800080; \n" + "shr.b32 b0, b0, 1; \n" + "add.u32 b1, b1, a1; \n" + "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" + "shl.b32 $1, b1, 7; \n" + "} \n", "=r,=r,r", [arg], dtype=core.float16, is_pure=True, pack=4, + _builder=_builder) + + +@core.builtin +def convert_float16_to_fp8e4b15(arg, has_minx2, _builder=None): + asm = """{ + .reg .pred p<4>; + .reg .b32 a<2>, b<2>; + .reg .b16 c<4>; + .reg .b16 max_val_f16; + .reg .b32 max_val_f16x2; + mov.b16 max_val_f16, 0x3F00; + mov.b32 max_val_f16x2, 0x3F003F00; + and.b32 a0, $1, 0x7fff7fff; + and.b32 a1, $2, 0x7fff7fff;""" + if has_minx2: + asm += """min.f16x2 a0, a0, max_val_f16x2; + min.f16x2 a1, a1, max_val_f16x2;""" + else: + asm += """setp.lt.f16x2 p0|p1, a0, max_val_f16x2; + setp.lt.f16x2 p2|p3, a1, max_val_f16x2; + mov.b32 {c0, c1}, a0; + mov.b32 {c2, c3}, a1; + selp.b16 c0, c0, max_val_f16, p0; + selp.b16 c1, c1, max_val_f16, p1; + selp.b16 c2, c2, max_val_f16, p2; + selp.b16 c3, c3, max_val_f16, p3; + mov.b32 a0, {c0, c1}; + mov.b32 a1, {c2, c3};""" + asm += """mad.lo.u32 a0, a0, 2, 0x00800080; + mad.lo.u32 a1, a1, 2, 0x00800080; + lop3.b32 b0, $1, 0x80008000, a0, 0xea; + lop3.b32 b1, $2, 0x80008000, a1, 0xea; + prmt.b32 $0, b0, b1, 0x7531; + }""" + return core.inline_asm_elementwise(asm, "=r,r,r", [arg], dtype=core.float8e4b15, is_pure=True, pack=4, + _builder=_builder) + + +@core.builtin +def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _builder=None): + if arg.type.scalar.is_fp8e4b15(): + upcast_val = convert_fp8e4b15_to_float16(arg, _builder=_builder) + if dst_ty.scalar.is_fp32(): + upcast_val = upcast_val.to(core.float32, _builder=_builder) + return upcast_val + + assert arg.type.scalar.is_fp16() or arg.type.scalar.is_fp32() + downcast_val = arg + if arg.type.scalar.is_fp32(): + downcast_val = downcast_val.to(core.float16, fp_downcast_rounding="rtz", _builder=_builder) + downcast_val = convert_float16_to_fp8e4b15(downcast_val, has_minx2=has_minx2, _builder=_builder) + return downcast_val + + +@core.builtin +def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _builder=None): + return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=True, _builder=_builder) + + +@core.builtin +def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _builder=None): + return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=False, _builder=_builder) diff --git a/third_party/xpu/python/triton/language/extra/hip/__init__.py b/third_party/xpu/python/triton/language/extra/hip/__init__.py new file mode 100644 index 000000000..229b57d87 --- /dev/null +++ b/third_party/xpu/python/triton/language/extra/hip/__init__.py @@ -0,0 +1,3 @@ +from . import libdevice + +__all__ = ["libdevice"] diff --git a/third_party/xpu/python/triton/language/extra/hip/libdevice.py b/third_party/xpu/python/triton/language/extra/hip/libdevice.py new file mode 100644 index 000000000..02e5d2d0b --- /dev/null +++ b/third_party/xpu/python/triton/language/extra/hip/libdevice.py @@ -0,0 +1,468 @@ +from triton.language import core + + +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("__triton_hip_iabs", core.dtype("int32")), + (core.dtype("int64"), ): ("__triton_hip_iabs", core.dtype("int64")), + (core.dtype("fp32"), ): ("__triton_hip_fabs", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__triton_hip_fabs", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_floor_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_floor_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_rsqrt_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_rsqrt_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_ceil_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_ceil_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_trunc_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_trunc_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_exp2_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_exp2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_exp_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_exp_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_dividef(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__triton_hip_fast_fdividef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sqrt_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sqrt_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__triton_hip_llrint", core.dtype("int64")), + (core.dtype("fp64"), ): ("__triton_hip_llrint", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nearbyint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_nearbyint_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_nearbyint_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_isnan_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_isnan_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("__ocml_signbit_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_signbit_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_copysign_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_copysign_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_isinf_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_isinf_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nextafter(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_nextafter_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_nextafter_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sin_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sin_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_cos_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_cos_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_tan_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_tan_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log2_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_cosh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_cosh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_sinh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_sinh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_tanh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_tanh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan2(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_atan2_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_atan2_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_atan_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_atan_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_asin_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_asin_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_acos_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_acos_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log10_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log10_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_log1p_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_log1p_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_acosh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_acosh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_asinh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_asinh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_atanh_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_atanh_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_expm1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_expm1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_hypot_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_hypot_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_j0_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_j0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_j1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_j1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_y0_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_y0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_y1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_y1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_i0_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_i0_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_i1_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_i1_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erf_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erf_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfinv_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfinv_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfc_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfc_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcx(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_erfcx_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_erfcx_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_lgamma_f32", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__ocml_lgamma_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__ocml_ldexp_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__ocml_ldexp_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_fmod_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_fmod_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__ocml_fma_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__ocml_fma_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__ocml_pown_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__ocml_pown_f64", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__ocml_pow_f32", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__ocml_pow_f64", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__ocml_ilogb_f32", core.dtype("int32")), + (core.dtype("fp64"), ): ("__ocml_ilogb_f64", core.dtype("int32")), + }, is_pure=True, _builder=_builder) diff --git a/third_party/xpu/python/triton/language/extra/libdevice.py b/third_party/xpu/python/triton/language/extra/libdevice.py new file mode 100644 index 000000000..55267a4aa --- /dev/null +++ b/third_party/xpu/python/triton/language/extra/libdevice.py @@ -0,0 +1,1220 @@ +# ===-------------------- For Triton XPU -----------------------=== +from .xpu import libdevice as xpu_libdevice +# ===-----------------------------------------------------------=== +from .cuda import libdevice as cuda_libdevice +from .hip import libdevice as hip_libdevice +from triton.language import core +from functools import wraps +from typing import TypeVar + +T = TypeVar('T') + + +def dispatch(fn: T) -> T: + """Dispatch a function to a correct implementation.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + _backend = kwargs["_builder"].options.backend_name + # ===-------------------- For Triton XPU -----------------------=== + if _backend == 'xpu': + _curr_libdevice_module = xpu_libdevice + # ===-----------------------------------------------------------=== + elif _backend == 'cuda': + _curr_libdevice_module = cuda_libdevice + elif _backend == 'hip': + _curr_libdevice_module = hip_libdevice + else: + raise RuntimeError('unknown backend') + + try: + _impl = getattr(_curr_libdevice_module, fn.__name__) + except AttributeError: + raise RuntimeError(f'`{_backend}` does not provide support for `{fn.__name__}` extra function') + + return _impl(*args, **kwargs) + + return wrapper + + +@core.extern +@dispatch +def clz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def popc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def byte_perm(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def mulhi(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul24(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def brev(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sad(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def abs(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def floor(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp64h(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rsqrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ceil(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def trunc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp2(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def saturatef(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fma_rn(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fma_rz(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fma_rd(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fma_ru(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def fast_dividef(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def div_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcp_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sqrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def add_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def add_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def mul_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2int_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2uint_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2int_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2uint_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def hiloint2double(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def double2loint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2hiint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ll_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float2ull_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ll_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double2ull_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2float_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ll2double_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rz(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_rd(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ull2double_ru(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def int_as_float(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float_as_int(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def uint_as_float(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def float_as_uint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def longlong_as_double(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def double_as_longlong(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_sinf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_cosf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_log2f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_logf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_expf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_tanf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_exp10f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_log10f(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fast_powf(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def hadd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rhadd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rz(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_rd(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sub_ru(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rsqrt_rn(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ffs(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def llrint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def nearbyint(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isnan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def signbit(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def copysign(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def finitef(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isinf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def nextafter(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def sin(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cos(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sinpi(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cospi(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def tan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log2(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def exp10(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cosh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def sinh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def tanh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def atan2(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def atan(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def asin(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def acos(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log10(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def log1p(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def acosh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def asinh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def atanh(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def expm1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def hypot(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def rhypot(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def norm3d(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def rnorm3d(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + ... + + +@core.extern +@dispatch +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + ... + + +@core.extern +@dispatch +def cbrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def rcbrt(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def j0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def j1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def y0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def y1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def yn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def jn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def cyl_bessel_i0(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def cyl_bessel_i1(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfc(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfcx(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def erfcinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def normcdfinv(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def normcdf(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def lgamma(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def ldexp(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def scalbn(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def fmod(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def remainder(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def fma(arg0, arg1, arg2, _builder=None): + ... + + +@core.extern +@dispatch +def pow(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def tgamma(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def round(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def llround(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def fdim(arg0, arg1, _builder=None): + ... + + +@core.extern +@dispatch +def ilogb(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def logb(arg0, _builder=None): + ... + + +@core.extern +@dispatch +def isfinited(arg0, _builder=None): + ... diff --git a/third_party/xpu/python/triton/language/extra/xpu/__init__.py b/third_party/xpu/python/triton/language/extra/xpu/__init__.py new file mode 100644 index 000000000..229b57d87 --- /dev/null +++ b/third_party/xpu/python/triton/language/extra/xpu/__init__.py @@ -0,0 +1,3 @@ +from . import libdevice + +__all__ = ["libdevice"] diff --git a/third_party/xpu/python/triton/language/extra/xpu/libdevice.py b/third_party/xpu/python/triton/language/extra/xpu/libdevice.py new file mode 100644 index 000000000..66006339c --- /dev/null +++ b/third_party/xpu/python/triton/language/extra/xpu/libdevice.py @@ -0,0 +1,1650 @@ +from triton.language import core + + +@core.extern +def clz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("int32")), + (core.dtype("int64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def popc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("int32")), + (core.dtype("int64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def byte_perm(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("int32")): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mulhi(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("Unsupported", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("_ZN3xpu6umulhiEjj", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64")): ("Unsupported", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64")): ("Unsupported", core.dtype("uint64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul24(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("Unsupported", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("Unsupported", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def brev(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("int32")), + (core.dtype("int64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sad(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("int32"), core.dtype("int32"), core.dtype("uint32")): ("Unsupported", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32")): ("Unsupported", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("int32")), + (core.dtype("int64"), ): ("Unsupported", core.dtype("int64")), + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu9xpu_floorEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("_ZN3xpu9xpu_floorEd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp64h(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("_ZN3xpu6hrsqrtEDF16_", core.dtype("fp16")), + (core.dtype("fp32"), ): ("_ZN3xpu6rsqrtfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("_ZN3xpu6rsqrtfEf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("_ZN3xpu8xpu_ceilEd", core.dtype("fp64")), + (core.dtype("fp32"), ): ("_ZN3xpu8xpu_ceilEf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), ): ("_ZN3xpu6truncfEf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu5exp2fEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def saturatef(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rn(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rz(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rd(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_ru(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_dividef(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("_ZN3xpu9__fdiv_rnEff", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("_ZN3xpu9__fdiv_rzEff", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rn(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rd(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_ru(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rn(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu10__fsqrt_rnEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("_ZN3xpu10__dsqrt_rnEd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rz(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rd(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_ru(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu8xpu_sqrtEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("_ZN3xpu8xpu_sqrtEd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("Unsupported", core.dtype("fp64")), + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hiloint2double(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2loint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2hiint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rz(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rd(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_ru(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_int(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint_as_float(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("uint32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_uint(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def longlong_as_double(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("int64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double_as_longlong(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_sinf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_cosf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log2f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_logf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_expf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_tanf(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_exp10f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log10f(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_powf(arg0, arg1, _builder=None): + return core.extern_elementwise("", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("Unsupported", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("Unsupported", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhadd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("Unsupported", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("Unsupported", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rz(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rd(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_ru(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt_rn(arg0, _builder=None): + return core.extern_elementwise("", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ffs(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("int32"), ): ("Unsupported", core.dtype("int32")), + (core.dtype("int64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("_ZN3xpu4rintEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nearbyint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("_ZN3xpu9nearbyintEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp16"), ): ("_ZN3xpu6hisnanEDF16_", core.dtype("int32")), + (core.dtype("fp32"), ): ("_ZN3xpu5isnanEf", core.dtype("int32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise( + "", "", [ + arg0, + ], { + (core.dtype("fp32"), ): ("_ZN3xpu10__signbitfEf", core.dtype("int32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def finitef(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("_ZN3xpu7hfiniteEDF16_", core.dtype("int16")), + (core.dtype("fp32"), ): ("_ZN3xpu7finitefEf", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("_ZN3xpu4hsinEDF16_", core.dtype("int32")), + (core.dtype("fp32"), ): ("_ZN3xpu5isinfEf", core.dtype("int32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nextafter(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu4sinfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu4cosfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinpi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cospi(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu4tanfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("_ZN3xpu5hlog2EDF16_", core.dtype("fp16")), + (core.dtype("fp32"), ): ("_ZN3xpu5log2fEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu5coshfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu5sinhfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp16"), ): ("_ZN3xpu5htanhEDF16_", core.dtype("fp16")), + (core.dtype("fp32"), ): ("_ZN3xpu5tanhfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan2(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("_ZN3xpu6atan2fEff", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu5atanfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu5asinfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu5acosfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu6log10fEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu6log1pfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu6acoshfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asinh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu6asinhfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu6atanhfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu6expm1fEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhypot(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcbrt(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def yn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def jn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i0(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i1(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu3erfEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu6erfinvEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfc(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("_ZN3xpu4erfcEf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcx(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdfinv(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdf(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def scalbn(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("_ZN3xpu5fmodfEff", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def remainder(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("_ZN3xpu3fmaEfff", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("Unsupported", core.dtype("fp64")), + (core.dtype("fp16"), core.dtype("fp16")): ("_ZN3xpu4hpowEDF16_DF16_", core.dtype("fp16")), + (core.dtype("fp32"), core.dtype("fp32")): ("_ZN3xpu3powEff", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tgamma(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def round(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llround(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int64")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fdim(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("int32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def logb(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("Unsupported", core.dtype("fp32")), + (core.dtype("fp64"), ): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isfinited(arg0, _builder=None): + return core.extern_elementwise("", "", [arg0], { + (core.dtype("fp64"), ): ("Unsupported", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def xpu_trunc_div(arg0, arg1, _builder=None): + return core.extern_elementwise( + "", "", [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("_ZN3xpu9xpu_truncEff", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("Unsupported", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) diff --git a/third_party/xpu/python/triton/language/math.py b/third_party/xpu/python/triton/language/math.py new file mode 100644 index 000000000..de5b5be6b --- /dev/null +++ b/third_party/xpu/python/triton/language/math.py @@ -0,0 +1,250 @@ +from . import core +from . import semantic +from functools import wraps +from typing import List + +T = core.TypeVar('T') + + +def _check_dtype(dtypes: List[str]) -> T: + """ + We're following libdevice's convention to check accepted data types for math functions. + It is not a good practice to support all data types as accelerators/GPUs don't support + many float16 and bfloat16 math operations. + We should let the users know that they are using and invoke explicit cast to convert + the data type to the supported one. + """ + + def wrapper(fn): + + @wraps(fn) + def check(*args, **kwargs): + # concatenate args and kwargs + all_args = list(args) + list(kwargs.values()) + for arg in [a for a in all_args if isinstance(a, core.tensor)]: + if arg.type.scalar.name not in dtypes: + raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}") + return fn(*args, **kwargs) + + return check + + return wrapper + + +def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`. + + :param x: the input values + :type x: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x` and :code:`y`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + :param z: the input values + :type z: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@core.builtin +@_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"]) +@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") +def umulhi(x, y, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential") +@core._tensor_member_fn +def exp(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_exp(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("exponential (base 2)") +@core._tensor_member_fn +def exp2(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_exp2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("natural logarithm") +@core._tensor_member_fn +def log(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_log(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("logarithm (base 2)") +@core._tensor_member_fn +def log2(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_log2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("cosine") +@core._tensor_member_fn +def cos(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_cos(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("sine") +@core._tensor_member_fn +def sin(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_sin(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("fast square root") +@core._tensor_member_fn +def sqrt(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_1arg_docstr("precise square root (rounding to nearest)") +@core._tensor_member_fn +def sqrt_rn(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_precise_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("inverse square root") +@core._tensor_member_fn +def rsqrt(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_rsqrt(x.handle), x.type) + + +@core.builtin +@_add_math_1arg_docstr("absolute value") +@core._tensor_member_fn +def abs(x, _builder=None): + x = core._to_tensor(x, _builder) + dtype = x.dtype + if dtype.is_fp8e4b15(): + mask = core.full(x.shape, 0x7F, core.int8, _builder=_builder) + return core.tensor(_builder.create_and(x.handle, mask.handle), x.type) + elif dtype.is_floating(): + return core.tensor(_builder.create_fabs(x.handle), x.type) + elif dtype.is_int_signed(): + return core.tensor(_builder.create_iabs(x.handle), x.type) + elif dtype.is_int_unsigned(): + return x # no-op + else: + assert False, f"Unexpected dtype {dtype}" + + +@core.builtin +@_add_math_2arg_docstr("fast division") +def fdiv(x, y, ieee_rounding=False, _builder=None): + ieee_rounding = core._constexpr_to_value(ieee_rounding) + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + return semantic.fdiv(x, y, ieee_rounding, _builder) + + +@core.builtin +@_check_dtype(dtypes=["fp32"]) +@_add_math_2arg_docstr("precise division (rounding to nearest)") +def div_rn(x, y, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("error function") +@core._tensor_member_fn +def erf(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_erf(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("floor") +@core._tensor_member_fn +def floor(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_floor(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["fp32", "fp64"]) +@_add_math_1arg_docstr("ceil") +@core._tensor_member_fn +def ceil(x, _builder=None): + x = core._to_tensor(x, _builder) + return core.tensor(_builder.create_ceil(x.handle), x.type) + + +@core.builtin +@_add_math_3arg_docstr("fused multiply-add") +def fma(x, y, z, _builder=None): + x = core._to_tensor(x, _builder) + y = core._to_tensor(y, _builder) + z = core._to_tensor(z, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + z, x = core.binary_op_type_legalization(z, x, _builder) + z, y = core.binary_op_type_legalization(z, y, _builder) + return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type) diff --git a/third_party/xpu/python/triton/language/random.py b/third_party/xpu/python/triton/language/random.py new file mode 100644 index 000000000..430aeb09e --- /dev/null +++ b/third_party/xpu/python/triton/language/random.py @@ -0,0 +1,207 @@ +from ..runtime.jit import jit +from . import core as tl +from . import math + +N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox + +# ------------------- +# randint +# ------------------- + + +@jit +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1). + """ + if c0.dtype == tl.uint32: + PHILOX_KEY_A: tl.constexpr = 0x9E3779B9 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE85 + PHILOX_ROUND_A: tl.constexpr = 0xD2511F53 + PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57 + else: + tl.static_assert(c0.dtype == tl.uint64, "dtype not supported in philox_impl") + PHILOX_KEY_A: tl.constexpr = 0x9E3779B97F4A7C15 + PHILOX_KEY_B: tl.constexpr = 0xBB67AE8584CAA73B + PHILOX_ROUND_A: tl.constexpr = 0xD2E7470EE14C6C93 + PHILOX_ROUND_B: tl.constexpr = 0xCA5A826395121157 + + for _ in tl.static_range(n_rounds): + # for _ in range(n_rounds): + # update random state + A = PHILOX_ROUND_A + B = PHILOX_ROUND_B + _c0, _c2 = c0, c2 + c0 = math.umulhi(B, _c2) ^ c1 ^ k0 + c2 = math.umulhi(A, _c0) ^ c3 ^ k1 + c1 = B * _c2 + c3 = A * _c0 + # raise key + k0 = k0 + PHILOX_KEY_A + k1 = k1 + PHILOX_KEY_B + return c0, c1, c2, c3 + + +@jit +def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + seed = tl.to_tensor(seed) + c0 = tl.to_tensor(c0) + c1 = tl.to_tensor(c1) + c2 = tl.to_tensor(c2) + c3 = tl.to_tensor(c3) + seed = seed.to(tl.uint64) + if tl.constexpr(c0.dtype.primitive_bitwidth) == 32: + int_dtype = tl.uint32 + seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) + seed_lo = (seed & 0xffffffff).to(tl.uint32) + else: + tl.static_assert(tl.constexpr(c0.dtype.primitive_bitwidth) == 64, "bitwidth not supported in philox") + int_dtype = tl.uint64 + seed_hi = tl.full((1, ), 0, dtype=int_dtype) + seed_lo = seed + c0 = c0.to(int_dtype, bitcast=True) + c1 = c1.to(int_dtype, bitcast=True) + c2 = c2.to(int_dtype, bitcast=True) + c3 = c3.to(int_dtype, bitcast=True) + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + +@jit +def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns a single + block of random :code:`int32`. + + If you need multiple streams of random numbers, + using `randint4x` is likely to be faster than calling `randint` 4 times. + + :param seed: The seed for generating random numbers. + :param offset: The offsets to generate random numbers for. + """ + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret + + +@jit +def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns four + blocks of random :code:`int32`. + + This is the maximally efficient entry point + to Triton's Philox pseudo-random number generator. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + # _0 = tl.zeros(offset.shape, offset.dtype) + _0 = offset * 0 + return philox(seed, offset, _0, _0, _0, n_rounds) + + +# ------------------- +# rand +# ------------------- + +# @jit +# def uint32_to_uniform_float(x): +# """ +# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). +# """ +# two_to_the_minus_32: tl.constexpr = 2.328306e-10 +# return x * two_to_the_minus_32 + + +@jit +def uint_to_uniform_float(x): + """ + Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1). + """ + # TODO: fix frontend issues and cleanup + # conditions can be simplified + # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1) + if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32): + # maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + x = x.to(tl.int32, bitcast=True) + scale = 4.6566127342e-10 + else: + tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64)) + x = x.to(tl.int64, bitcast=True) + scale = 1.0842020432385337e-19 + x = tl.where(x < 0, -x - 1, x) + return x * scale + + +@jit +def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + source = randint(seed, offset, n_rounds) + return uint_to_uniform_float(source) + + +@jit +def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offsets` block, + returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + u3 = uint_to_uniform_float(i3) + u4 = uint_to_uniform_float(i4) + return u1, u2, u3, u4 + + +# ------------------- +# randn +# ------------------- + + +@jit +def pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = tl.maximum(1.0e-7, u1) + th = 6.283185307179586 * u2 + r = math.sqrt(-2.0 * math.log(u1)) + return r * math.cos(th), r * math.sin(th) + + +@jit +def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, _, _ = randint4x(seed, offset, n_rounds) + u1 = uint_to_uniform_float(i1) + u2 = uint_to_uniform_float(i2) + n1, _ = pair_uniform_to_normal(u1, u2) + return n1 + + +@jit +def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + u1, u2, u3, u4 = rand4x(seed, offset, n_rounds) + n1, n2 = pair_uniform_to_normal(u1, u2) + n3, n4 = pair_uniform_to_normal(u3, u4) + return n1, n2, n3, n4 diff --git a/third_party/xpu/python/triton/language/semantic.py b/third_party/xpu/python/triton/language/semantic.py new file mode 100644 index 000000000..1c8217c6f --- /dev/null +++ b/third_party/xpu/python/triton/language/semantic.py @@ -0,0 +1,1674 @@ +from __future__ import annotations # remove after python 3.11 + +from typing import List, Optional, Sequence, Tuple, TypeVar + +from .._C.libtriton import ir +from . import core as tl +from . import math + +T = TypeVar('T') + + +class IncompatibleTypeErrorImpl(Exception): + + def __init__(self, type_a, type_b): + self.type_a = type_a + self.type_b = type_b + self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() + super(IncompatibleTypeErrorImpl, self).__init__(self.message) + + +# ===----------------------------------------------------------------------===## +# Programming Model +# ===----------------------------------------------------------------------===## + + +def program_id(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}") + return tl.tensor(builder.create_get_program_id(axis), tl.int32) + + +def num_programs(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}") + return tl.tensor(builder.create_get_num_programs(axis), tl.int32) + + +# ===----------------------------------------------------------------------===// +# Implicit Casting Utilities +# ===----------------------------------------------------------------------===// + + +def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: + a_rank = a_ty.int_bitwidth + b_rank = b_ty.int_bitwidth + a_sn = a_ty.int_signedness + b_sn = b_ty.int_signedness + # Rules for signedness taken from "Usual arithmetic conversions" on + # https://en.cppreference.com/w/c/language/conversion. + if a_sn == b_sn: + return a_ty if a_rank > b_rank else b_ty + elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return a_ty if a_rank >= b_rank else b_ty + elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return b_ty if b_rank >= a_rank else a_ty + raise TypeError(f"unexpected signedness {a_sn} and {b_sn}") + + +def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype: + # 1) if one operand is double, the other is implicitly + # converted to double + if a_ty.is_fp64() or b_ty.is_fp64(): + return tl.float64 + # 2) if one operand is float, the other is implicitly + # converted to float + if a_ty.is_fp32() or b_ty.is_fp32(): + return tl.float32 + # 3 ) if one operand is half, the other is implicitly converted to half + # unless we're doing / or %, which do not exist natively in PTX for fp16. + # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp + if a_ty.is_fp16() or b_ty.is_fp16(): + if div_or_mod: + return tl.float32 + else: + return tl.float16 + # 4) return bf16 only if both operands are of bf16 + if a_ty.is_bf16() or b_ty.is_bf16(): + if div_or_mod: + return tl.float32 + if a_ty.is_bf16() and b_ty.is_bf16(): + return tl.bfloat16 + return tl.float32 + if not a_ty.is_int() or not b_ty.is_int(): + raise TypeError(f"unexpected type {a_ty} and {b_ty}") + # 5 ) both operands are integer and undergo + # integer promotion + if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: + raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + + " because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + return integer_promote_impl(a_ty, b_ty) + + +# ===----------------------------------------------------------------------===// +# Binary Operators +# ===----------------------------------------------------------------------===// + + +def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: + if type_a.is_ptr(): + if not allow_ptr_a: + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + U* with T != U + if type_b.is_ptr() and (type_a != type_b): + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + float + if type_b.is_floating(): + raise IncompatibleTypeErrorImpl(type_a, type_b) + + +def binary_op_type_checking_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, allow_lhs_ptr=False, + allow_rhs_ptr=False, arithmetic_check=True, + div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]: + # implicit broadcasting + lhs, rhs = broadcast_impl_value(lhs, rhs, builder) + # implicit typecasting + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) + check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) + if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): + ret_sca_ty = computation_type_impl(lhs_sca_ty, rhs_sca_ty, div_or_mod) + lhs = cast(lhs, ret_sca_ty, builder) + rhs = cast(rhs, ret_sca_ty, builder) + return lhs, rhs + + +def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr(): + raise TypeError("cannot add pointers together") + + # offset + ptr + # ptr + offset + if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): + input, other = other, input + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type) + # float + float + elif input_scalar_ty.is_floating(): + return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) + # int + int + elif input_scalar_ty.is_int(): + return tl.tensor(builder.create_add(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, False) + scalar_ty = input.type.scalar + # ptr - offset + if scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type) + # float - float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) + # int - int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def mul(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float * float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type) + # * int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def truediv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float / int + if input_scalar_ty.is_floating() and other_scalar_ty.is_int(): + other = cast(other, input_scalar_ty, builder) + # int / float + elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): + input = cast(input, other_scalar_ty, builder) + # int / int (cast to tl.float32) + elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): + input = cast(input, tl.float32, builder) + other = cast(other, tl.float32, builder) + # float / float (cast to the highest exponent type) + elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): + if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: + other = cast(other, input_scalar_ty, builder) + else: + input = cast(input, other_scalar_ty, builder) + # unreachable + else: + raise TypeError(f"unexpected type {input_scalar_ty}") + return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) + + +def floordiv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): + ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) + input = cast(input, ret_ty, builder) + other = cast(other, ret_ty, builder) + if ret_ty.is_int_signed(): + return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def fdiv(input: tl.tensor, other: tl.tensor, ieee_rounding: bool, builder: ir.builder) -> tl.tensor: + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): + raise TypeError("both operands of fdiv must have floating scalar type") + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True) + ret = builder.create_fdiv(input.handle, other.handle) + return tl.tensor(ret, input.type) + + +def mod(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float % float + if scalar_ty.is_floating(): + # ===-------------------- For Triton XPU -----------------------=== + from .extra.xpu.libdevice import fmod + r = fmod(input, other, _builder=builder) + zero = full([], 0.0, scalar_ty, builder) + c1 = not_equal(r, zero, builder) + c2_l = less_than(input, zero, builder) + c2_r = less_than(other, zero, builder) + c2 = xor_(c2_l, c2_r, builder) + ret = where(and_(c1, c2, builder), add(r, other, builder), r, builder) + return ret + # ===-----------------------------------------------------------=== + + # input - input.div(other, rounding_mode="floor") * other + ret = sub(input, mul(math.floor(fdiv(input, other, False, builder), _builder=builder), other, builder), builder) + return ret + # % int + elif scalar_ty.is_int(): + if scalar_ty.int_signedness != other_scalar_ty.int_signedness: + raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " + "because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +############## +# other arithmetic ops +############## + + +def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_minimumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_minnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_minsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_minui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + +def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_maximumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_maxnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_maxsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_maxui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") + + +def clamp(x: tl.tensor, min: tl.tensor, max: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + min, max = binary_op_type_checking_impl(min, max, builder) + x, min = binary_op_type_checking_impl(x, min, builder) + x, max = binary_op_type_checking_impl(x, max, builder) + + dtype = x.dtype + if dtype.is_floating(): + return tl.tensor(builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}. Only floating point clamp is supported") + + +############## +# bitwise ops +############## + + +def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False) + input_sca_ty = input.type.scalar + other_sca_ty = other.type.scalar + if not input_sca_ty.is_int() or not other_sca_ty.is_int(): + raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty) + ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty) + if ret_sca_ty != input_sca_ty: + input = cast(input, ret_sca_ty, builder) + if ret_sca_ty != other_sca_ty: + other = cast(other, ret_sca_ty, builder) + return input, other + + +def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_and(input.handle, other.handle), input.type) + + +def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_or(input.handle, other.handle), input.type) + + +def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) + + +def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return and_(input, other, builder) + + +def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return or_(input, other, builder) + + +def not_(input: tl.tensor, builder: ir.builder): + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + return invert(input, builder) + + +def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type) + + +def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type) + + +def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_shl(input.handle, other.handle), input.type) + + +# ===----------------------------------------------------------------------===// +# Unary Operators +# ===----------------------------------------------------------------------===// + + +def plus(input: tl.tensor) -> tl.tensor: + return input + + +def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") + _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return sub(_0, input, builder) + + +def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): + raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") + _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return xor_(input, _1, builder) + + +# ===----------------------------------------------------------------------===// +# Comparison Operators +# ===----------------------------------------------------------------------===// +def _bool_like(v: tl.tensor) -> tl.block_type: + if not v.type.is_block(): + return tl.int1 + shape = v.type.shape + return tl.block_type(tl.int1, shape) + + +def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float > float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input)) + # > int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float >= float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input)) + # >= int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +# ===----------------------------------------------------------------------===// +# Block Creation +# ===----------------------------------------------------------------------===// + + +def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("arange's arguments must be of type tl.constexpr") + is_start_int64 = bool(start >> 32) + is_end_int64 = bool(end >> 32) + if is_start_int64 or is_end_int64: + raise ValueError("arange must fit in int32") + if end <= start: + raise ValueError("arange's end argument must be greater than the start argument") + range = end - start + # ===-------------------- For Triton XPU -----------------------=== + # Triton XPU Don't need the power-of-two limitation + # if (range & (range - 1)) != 0: + # raise ValueError("arange's range must be a power of 2") + # ===-----------------------------------------------------------=== + shape = [range] + ret_ty = tl.block_type(tl.int32, shape) + return tl.tensor(builder.create_make_range(start, end), ret_ty) + + +def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + if isinstance(value, tl.tensor): + assert value.numel.value == 1, "only accepts size-1 tensor" + value = cast(value, dtype, builder) + else: + # scalar + if dtype is None: + raise ValueError("dtype must be specified when value is not a tensor") + if value == 0: + value = builder.get_null_value(dtype.to_ir(builder)) + else: + get_value_fn = getattr(builder, f"get_{dtype.name}") + value = get_value_fn(value) + value = tl.tensor(value, dtype) + + return splat(value, shape, builder) + + +# ===----------------------------------------------------------------------===// +# Shape Manipulation +# ===----------------------------------------------------------------------===// + + +def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + assert not value.type.is_block(), "Cannot splat a block tensor" + if len(shape) == 0: + return value + ret_ty = tl.block_type(value.dtype, shape) + return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) + + +def reshape(input: tl.tensor, dst_shape: List[int], can_reorder: bool, builder: ir.builder) -> tl.tensor: + numel = 1 + for s in dst_shape: + numel *= s + if input.type.numel != numel: + raise ValueError("reshape() cannot change total number of elements in tensor") + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_reshape(input.handle, dst_shape, can_reorder), ret_ty) + + +def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + dst_shape = [tl._constexpr_to_value(x) for x in input.shape] + dst_shape.insert(axis, 1) + + if not input.type.is_block(): + return splat(input, shape=dst_shape, builder=builder) + + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty) + + +def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor: + assert can_reorder, "current implementation of `cat` always may reorder elements" + assert len(lhs.shape) == 1 + ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) + return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type) + + +def join(a: tl.tensor, b: tl.tensor, builder: ir.builder) -> tl.tensor: + a, b = broadcast_impl_value(a, b, builder) + + # The IR can't handle joining two scalars, so upcast them to 1D tensors, + # then downcast the result. + was_rank_1 = a.shape == [] + if was_rank_1: + a = expand_dims(a, 0, builder) + b = expand_dims(b, 0, builder) + + if isinstance(a.shape[-1], tl.constexpr): + two = tl.constexpr(2) + else: + two = 2 + new_shape = a.shape + [two] + + ret_type = tl.block_type(a.type.scalar, new_shape) + ret = tl.tensor(builder.create_join(a.handle, b.handle), ret_type) + + if was_rank_1: + ret = reshape(ret, [2], can_reorder=False, builder=builder) + + return ret + + +def split(a: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + assert (len(a.shape) > 0) + assert (tl._constexpr_to_value(a.shape[-1]) == 2) + + new_shape = a.shape[:-1] + ret_type = tl.block_type(a.type.scalar, new_shape) + outLHS, outRHS = builder.create_split(a.handle) + return ( + tl.tensor(outLHS, ret_type), + tl.tensor(outRHS, ret_type), + ) + + +def permute(input: tl.tensor, dims: Tuple[int], builder: ir.builder) -> tl.tensor: + if len(input.shape) != len(dims): + raise ValueError("permute dims must have the same length as input shape") + if sorted(tl._constexpr_to_value(d) for d in dims) != list(range(len(dims))): + raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}") + + ret_type = tl.block_type(input.type.scalar, [input.shape[d] for d in dims]) + return tl.tensor(builder.create_trans(input.handle, dims), ret_type) + + +def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + if not input.type.is_block(): + ret_ty = tl.block_type(input.type, shape) + return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) + src_shape = input.type.get_block_shapes() + if len(src_shape) != len(shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + for i, item in enumerate(src_shape): + if shape[i] != item and item != 1: + raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" + f" must match the existing size ({item}) at non-singleton dimension" + f" {i}: {src_shape}, {shape}") + ret_ty = tl.block_type(input.type.scalar, shape) + return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) + + +def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: + lhs_ty = lhs.type + rhs_ty = rhs.type + + # make_shape_compatible(block, scalar) + if lhs_ty.is_block() and not rhs_ty.is_block(): + rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape) + rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty) + # make_shape_compatible(scalar, block) + elif not lhs_ty.is_block() and rhs_ty.is_block(): + lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape) + lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty) + # make_shape_compatible(block, block) + elif lhs_ty.is_block() and rhs_ty.is_block(): + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + + if len(lhs_shape) < len(rhs_shape): + # Add new axes to lhs + for _ in range(len(lhs_shape), len(rhs_shape)): + lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), + tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) + lhs_ty = lhs.type + lhs_shape = lhs_ty.get_block_shapes() + elif len(rhs_shape) < len(lhs_shape): + # Add new axes to rhs + for _ in range(len(rhs_shape), len(lhs_shape)): + rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), + tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) + rhs_ty = rhs.type + rhs_shape = rhs_ty.get_block_shapes() + assert len(rhs_shape) == len(lhs_shape) + + ret_shape = [] + for i, left in enumerate(lhs_shape): + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif (right == 1) or (right == left): + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + if lhs_shape != ret_shape: + ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) + lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty) + if rhs_shape != ret_shape: + ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) + rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty) + # (scalar, scalar) => returns original blocks + return lhs, rhs + + +####### +# cast +####### + + +def _str_to_rounding_mode(rounding_mode: Optional[str]): + if rounding_mode is None: + return None + if rounding_mode == 'rtne': + return ir.ROUNDING_MODE.RTNE + if rounding_mode == 'rtz': + return ir.ROUNDING_MODE.RTZ + raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.") + + +def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + return cast(input, dst_ty, builder) + # Bitcast + src_bits = src_sca_ty.primitive_bitwidth + dst_bits = dst_sca_ty.primitive_bitwidth + if src_bits != dst_bits: + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " + "data-type of size " + str(dst_bits)) + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + +def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, + fp_downcast_rounding: Optional[str] = None) -> tl.tensor: + src_ty = input.type + if isinstance(dst_ty, tl.constexpr): + dst_ty = dst_ty.value + if isinstance(fp_downcast_rounding, tl.constexpr): + fp_downcast_rounding = fp_downcast_rounding.value + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + + # For fp downcasting default rounding mode should be RTNE, for all other conversions it should + # not be set + fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) + use_custom_rounding = False + if dst_sca_ty.is_floating() and src_sca_ty.is_floating( + ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: + if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE + elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True + else: + if fp_downcast_rounding is not None: + raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + + if (src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()): + assert builder.options.allow_fp8e4nv, "fp8e4nv data type is not supported on CUDA arch < 89" + + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): + assert builder.codegen_fns.get( + "convert_custom_types") is not None, "target doesn't provide conversion for this type." + return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + # and non-default rounding modes for downcasting + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ + use_custom_rounding: + return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) + + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + else: + return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) + + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif dst_sca_ty.is_int_signed(): + return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) + if bitwidth == 1: + return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder) + + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to pointer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + assert False, f'cannot cast {input} to {dst_ty}' + + +# ===----------------------------------------------------------------------===// +# Memory Operators +# ===----------------------------------------------------------------------===// + + +def _str_to_load_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".ca": + cache = ir.CACHE_MODIFIER.CA + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def _str_to_store_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".wb": + cache = ir.CACHE_MODIFIER.WB + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cs": + cache = ir.CACHE_MODIFIER.CS + elif cache_modifier == ".wt": + cache = ir.CACHE_MODIFIER.WT + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def _str_to_eviction_policy(eviction_policy): + eviction = ir.EVICTION_POLICY.NORMAL # default + if eviction_policy: + if eviction_policy == "evict_last": + eviction = ir.EVICTION_POLICY.EVICT_LAST + elif eviction_policy == "evict_first": + eviction = ir.EVICTION_POLICY.EVICT_FIRST + else: + raise ValueError(f"Eviction policy {eviction_policy} not supported") + return eviction + + +def _str_to_padding_option(padding_option): + padding = None # default + if padding_option: + if padding_option == "zero": + padding = ir.PADDING_OPTION.PAD_ZERO + elif padding_option == "nan": + padding = ir.PADDING_OPTION.PAD_NAN + else: + raise ValueError(f"Padding option {padding_option} not supported") + return padding + + +def _str_to_sem(sem_option): + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + if sem_option: + if sem_option == "acquire": + sem = ir.MEM_SEMANTIC.ACQUIRE + elif sem_option == "release": + sem = ir.MEM_SEMANTIC.RELEASE + elif sem_option == "acq_rel": + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + elif sem_option == "relaxed": + sem = ir.MEM_SEMANTIC.RELAXED + else: + raise ValueError(f"Memory semantic {sem_option} not supported") + return sem + + +def _str_to_scope(scope_option): + scope = ir.MEM_SYNC_SCOPE.GPU + if scope_option: + if scope_option == "gpu": + scope = ir.MEM_SYNC_SCOPE.GPU + elif scope_option == "cta": + scope = ir.MEM_SYNC_SCOPE.CTA + elif scope_option == "sys": + scope = ir.MEM_SYNC_SCOPE.SYSTEM + else: + raise ValueError(f"Memory semantic {scope_option} not supported") + return scope + + +def _canonicalize_boundary_check(boundary_check, block_shape): + if boundary_check: + if not hasattr(boundary_check, "__iter__"): + boundary_check = [boundary_check] + boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check] + for dim in boundary_check: + assert isinstance(dim, int) and 0 <= dim < len(block_shape) + assert len(boundary_check) > 0 + assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`" + return sorted(boundary_check) + return () + + +def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a block pointer: `pointer_type>` + # Block pointer can not have `mask` and `other` arguments + if mask is not None or other is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" + if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN: + raise ValueError("Padding option `nan` is not supported for integer block pointers") + + # `dst_ty` is de-referenced type of the pointer type + dst_ty = ptr.type.element_ty + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) + + # Build IR + return tl.tensor( + builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) + + +def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") + + # Check `mask`, `other`, `boundary_check`, and `padding` arguments + if mask is None and other is not None: + raise ValueError("`other` cannot be provided without `mask`") + if padding or boundary_check: + raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of" + "pointers or loading a scalar. Because the compiler does not know the boundary; please " + "use block pointers (defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `mask` and `other` + if not ptr.type.is_block(): + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + if other and other.type.is_block(): + raise ValueError("Other argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `other` into the same shape as `ptr` + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if other is not None: + other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder) + + # Get `pointer_type` and `elt_ty` + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast `other` into `ele_ty` type + if other is not None: + other = cast(other, elt_ty, builder) + + # Create loaded result type `dst_ty` + if ptr.type.is_block(): + shape = ptr.type.get_block_shapes() + dst_ty = tl.block_type(elt_ty, shape) + else: + # Load by de-referencing the pointer of scalar + dst_ty = elt_ty + + # Build IR + if mask is None: + return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) + else: + # ===-------------------- For Triton XPU -----------------------=== + # TODO[dyq]: Set TRITONXPU_OTHER_SIM To Default Mode + # use where op to simulate load's other attribute + import os + if bool(os.environ.get('TRITONXPU_OTHER_SIM', False)): + if other is None: + return tl.tensor( + builder.create_masked_load(ptr.handle, mask.handle, None, cache, eviction, is_volatile), dst_ty) + else: + load_value = tl.tensor( + builder.create_masked_load(ptr.handle, mask.handle, other.handle, cache, eviction, is_volatile), + dst_ty) + ret = where(mask, load_value, other, builder) + return bitcast(ret, dst_ty, builder) + # ===-----------------------------------------------------------=== + else: + return tl.tensor( + builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, + is_volatile), dst_ty) + + +def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple, + padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool, + builder: ir.builder) -> tl.tensor: + # Cache, eviction and padding options + cache = _str_to_load_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + padding = _str_to_padding_option(padding_option) + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Load by a block pointer: `pointer_type>` + return _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + else: + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + + +def descriptor_load(desc_ptr: tl.tensor, offsets, cache_modifier: str, eviction_policy: str, type, + builder: ir.builder) -> tl.tensor: + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + x = builder.create_descriptor_load(desc_ptr.handle, offsets, type.to_ir(builder), + _str_to_load_cache_modifier(cache_modifier), + _str_to_eviction_policy(eviction_policy)) + return tl.tensor(x, type) + + +def descriptor_store(desc_ptr: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + return tl.tensor(builder.create_descriptor_store(desc_ptr.handle, value.handle, offsets), tl.void) + + +def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a block pointer: `pointer_type>` + # Block pointers can not have the `mask` argument + if mask is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + # Check same shape and element type + block_shape = ptr.type.element_ty.get_block_shapes() + if not val.type.is_block(): + val = broadcast_impl_shape(val, block_shape, builder) + assert val.type.is_block(), "Value argument must be block type or a scalar" + assert block_shape == val.type.get_block_shapes( + ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" + assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch" + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, block_shape) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + return tl.tensor(builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction), + tl.void) + + +def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`") + + # Check `boundary_check` argument + if boundary_check: + raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a " + "scalar. Because the compiler does not know the boundary; please use block pointers " + "(defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `val` and `mask` + if not ptr.type.is_block(): + if val.type.is_block(): + raise ValueError("Value argument cannot be block type if pointer argument is not a block") + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `val` into the same shape as `ptr` + if ptr.type.is_block(): + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + if not mask: + return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void) + if not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") + # ===-------------------- For Triton XPU -----------------------=== + # TODO[dyq]: Set TRITONXPU_STORE_MASK_SIM To Default Mode + # use where op to simulate store's mask attribute + import os + if bool(os.environ.get('TRITONXPU_STORE_MASK_SIM', False)): + # Create loaded result type `dst_ty` + if ptr.type.is_block(): + shape = ptr.type.get_block_shapes() + dst_ty = tl.block_type(elt_ty, shape) + else: + # Load by de-referencing the pointer of scalar + dst_ty = elt_ty + load_value = tl.tensor(builder.create_load(ptr.handle, cache, eviction, False), dst_ty) + masked_value = where(mask, val, load_value, builder) + return tl.tensor(builder.create_masked_store(ptr.handle, masked_value.handle, mask.handle, cache, eviction), + tl.void) + # ===-----------------------------------------------------------=== + else: + return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void) + + +def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str, + eviction_policy: str, builder: ir.builder) -> tl.tensor: + # Cache and eviction options + cache = _str_to_store_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + + if ptr.type.is_const() or ptr.type.scalar.is_const(): + raise ValueError("Cannot store to a constant pointer") + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Store by a block pointer: `pointer_type>` + return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder) + else: + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder) + + +######### +# atomic +######### + + +def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + element_ty = ptr.type.scalar.element_ty + if element_ty.primitive_bitwidth not in [16, 32, 64]: + raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) + + +def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_const() or ptr.type.element_ty.is_const(): + raise ValueError("Cannot store to a constant pointer") + element_ty = ptr.type.scalar.element_ty + if element_ty is tl.float16 and op != 'add': + raise ValueError("atomic_" + op + " does not support fp16") + if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]: + raise ValueError("atomic_" + op + " does not support " + str(element_ty)) + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if val is not None: + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + val = cast(val, ptr.type.scalar.element_ty, builder) + if not mask: + mask_ir = builder.get_int1(True) + mask_ty = tl.int1 + if ptr.type.is_block(): + mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes()) + mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) + mask = tl.tensor(mask_ir, mask_ty) + return ptr, val, mask + + +def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_max for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + # for float + # return atomic_smax(i_ptr, i_val) if val >= 0 + # return atomic_umin(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_max not supported for dtype {sca_ty}") + + zero = full([], 0.0, sca_ty, builder) + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = bitcast(val, i_type, builder) + i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = bitcast(val, ui_type, builder) + ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle, + and_(mask, neg, builder).handle, sem, scope), ui_val.type) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) + + +def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_min for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + # for float + # return atomic_smin(i_ptr, i_val) if val >= 0 + # return atomic_umax(i_ptr, i_val) if val < 0 + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_min not supported for dtype {sca_ty}") + + zero = full([], 0.0, sca_ty, builder) + + i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 + i_val = bitcast(val, i_type, builder) + i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder) + ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 + ui_val = bitcast(val, ui_type, builder) + ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle, + and_(mask, neg, builder).handle, sem, scope), ui_ptr.type) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) + + +def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD + return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + +def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +# ===----------------------------------------------------------------------===// +# Linear Algebra +# ===----------------------------------------------------------------------===// + + +def _str_to_dot_input_precision(input_precision, builder): + assert input_precision.lower() in builder.options.allowed_dot_input_precisions, \ + f"input_precision must be one of {builder.options.allowed_dot_input_precisions}. Got {input_precision}" + input_precision = input_precision.upper() + if input_precision == "TF32X3": + input_precision = "TF32x3" + return getattr(ir.INPUT_PRECISION, input_precision) + + +def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int, + out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + + def assert_dtypes_valid(lhs_dtype, rhs_dtype, options): + if not options.allow_fp8e4nv: + assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv( + ), "Dot op does not support fp8e4nv on CUDA arch < 90" + if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): + return + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + else: + if lhs_dtype.is_int() or rhs_dtype.is_int(): + assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})" + assert lhs_dtype.is_int8() or lhs_dtype.is_uint8( + ), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" + elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8(): + if options.allow_fp8e4b15: + allowed_types = ['fp8e4nv', 'fp8e5', 'fp8e4b15'] + else: + allowed_types = ['fp8e4nv', 'fp8e5'] + + def _validate_dtype(dtype, allowed_types, operand_name): + if not any(getattr(dtype, f'is_{dtype_name}')() for dtype_name in allowed_types): + supported_types = ', '.join(allowed_types) + raise AssertionError(f"Only supports {supported_types}. {operand_name} ({dtype})") + + _validate_dtype(lhs_dtype, allowed_types, "First operand") + _validate_dtype(rhs_dtype, allowed_types, "Second operand") + else: + assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1( + ), f"Unsupported dtype {lhs_dtype}" + assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1( + ), f"Unsupported dtype {rhs_dtype}" + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + + assert lhs.type.is_block() and rhs.type.is_block() + assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options) + if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): + lhs = cast(lhs, tl.float16, builder) + rhs = cast(rhs, tl.float16, builder) + + if input_precision is None: + input_precision = builder.options.default_dot_input_precision + + input_precision = _str_to_dot_input_precision(input_precision, builder) + + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + assert lhs.shape[-1].value == rhs.shape[ + -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})" + # ===-------------------- For Triton XPU -----------------------=== + # assert lhs.shape[-2].value >= 16 and lhs.shape[-1].value >= 16 \ + # and rhs.shape[-1].value >= 16, \ + # f"All non-batch values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!" + # ===-----------------------------------------------------------=== + if lhs.type.scalar.is_int(): + assert lhs.type.scalar == tl.int8, "only int8 supported!" + # TODO: This is CUDA specific, check if ROCm has the same limitation + # ===-------------------- For Triton XPU -----------------------=== + # assert lhs.shape[1].value >= 32, "small blocks not supported!" + # ===-----------------------------------------------------------=== + _0 = builder.get_int32(0) + ret_scalar_ty = tl.int32 + elif out_dtype.is_bf16(): + raise ValueError( + "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") + elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): + _0 = builder.get_fp32(0) + ret_scalar_ty = tl.float32 + else: + _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0) + ret_scalar_ty = out_dtype + + M = lhs.type.shape[-2] + N = rhs.type.shape[-1] + B = lhs.type.shape[0] if lhs_rank == 3 else None + ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N]) + if acc is None: + acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + + # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 + if max_num_imprecise_acc is None: + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default + else: + max_num_imprecise_acc = 0 + + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), + ret_ty) + + +# ===----------------------------------------------------------------------===// +# Indexing +# ===----------------------------------------------------------------------===// + + +def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: + condition = cast(condition, tl.int1, builder) + if condition.type.is_block(): + condition, x = broadcast_impl_value(condition, x, builder) + x, y = broadcast_impl_value(x, y, builder) + condition, x = broadcast_impl_value(condition, x, builder) + + x, y = binary_op_type_checking_impl(x, y, builder, True, True) + if not condition.type.is_block(): + condition, _ = broadcast_impl_value(condition, x, builder) + ret_ty = x.type + return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) + + +# ===----------------------------------------------------------------------===// +# Reduction +# ===----------------------------------------------------------------------=== + + +def wrap_tensor(x, scalar_ty, ret_shape): + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = scalar_ty + return tl.tensor(x, res_ty) + + +def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]: + if axis is None: + inputs = tuple(reshape(t, [t.numel.value], can_reorder=True, builder=builder) for t in inputs) + axis = 0 + # get result shape + shape = inputs[0].type.shape + rank = len(shape) + assert axis < rank, f"reduction axis must be < inputs rank ({rank})" + ret_shape = [s for i, s in enumerate(shape) if i != axis] + assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape" + + reduce_op = builder.create_reduce([t.handle for t in inputs], axis) + region_builder_fn(reduce_op) + reduce_op.verify() + + return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs))) + + +# ===----------------------------------------------------------------------=== +# Associative Scan +# ===----------------------------------------------------------------------=== + + +def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, reverse: bool, + builder: ir.builder) -> Tuple[tl.tensor, ...]: + shape = inputs[0].type.shape + rank = len(shape) + + assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})" + + if axis < 0: + axis += rank + + for t in inputs: + assert t.type.shape == shape, "all scan inputs must have the same shape" + + scan_op = builder.create_scan([t.handle for t in inputs], axis, reverse) + region_builder_fn(scan_op) + scan_op.verify() + + return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs))) + + +# ===----------------------------------------------------------------------=== +# Histogram +# ===----------------------------------------------------------------------=== + + +def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor: + assert len(input.shape) == 1, "histogram only supports 1D input" + assert input.dtype.is_int(), "histogram only supports integer input" + return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, (num_bins, ))) + + +## + + +def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: + if max(1, len(x.shape)) != len(values): + raise ValueError("Shape of input to multiple_of does not match the length of values") + x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) + return x + + +def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_contiguous does not match the length of values") + x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context())) + return x + + +def max_constancy(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_constancy does not match the length of values") + x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context())) + return x + + +def debug_barrier(builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_barrier(), tl.void) + + +def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor: + # It makes sense visually for prefix to end in ": "; make it so. Also, + # non-empty prefixes should start with " ". + if not prefix.endswith(" ") and args: + prefix += " " + if not prefix.endswith(": ") and args: + prefix = prefix[:-1] + ": " + if len(prefix) > 2 and not prefix.startswith(" "): + prefix = " " + prefix + + new_args = [arg.handle for arg in args] + return tl.tensor(builder.create_print(prefix, hex, new_args), tl.void) + + +def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor: + cond_ty = cond.type + if not cond_ty.is_block(): + cond_ty = tl.block_type(cond_ty.scalar, (1, )) + cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty) + return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void) + + +def _convert_elem_to_ir_value(builder, elem, require_i64): + if isinstance(elem, int): + elem = tl.constexpr(elem) + if isinstance(elem, tl.constexpr): + if require_i64: + assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int64(elem.value) + else: + assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int32(elem.value) + elif isinstance(elem, tl.tensor): + assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets" + assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets" + if elem.dtype != tl.int64 and require_i64: + return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed()) + elif elem.dtype != tl.int32 and not require_i64: + assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \ + "add a `.to(tl.int32)` or use regular indexing for 64 bit support" + return elem.handle + assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}" + + +def _convert_to_ir_values(builder, list_like, require_i64=True): + if hasattr(list_like, "__iter__"): + return [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in list_like] + return [_convert_elem_to_ir_value(builder, list_like, require_i64)] + + +def make_block_ptr(base: tl.tensor, shape, strides, offsets, block_shape, order, builder: ir.builder) -> tl.tensor: + # Convert dynamic arguments to IR values + # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t` + shape = _convert_to_ir_values(builder, shape) + strides = _convert_to_ir_values(builder, strides) + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Check `base` type + if not base.type.is_ptr() or base.type.element_ty.is_block(): + raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)") + + # Treat `pointer_type` as `pointer_type` + if base.type.element_ty == tl.int1: + base = cast(base, tl.pointer_type(tl.int8, base.type.address_space), builder) + + # Check whether `block_shape` is static + if not hasattr(block_shape, "__iter__"): + block_shape = [block_shape] + block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape] + assert all(isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape), \ + "Expected a list of constant integers (`int32_t` range) in `block_shape`" + + # Check `order` + if not hasattr(order, "__iter__"): + order = [order] + order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order] + assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order" + + # Must have same length + assert all(len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]), \ + "Expected shape/strides/offsets/block_shape to have the same length" + + # Build value, the type is: + # `pointer_type>` in Python + # `tt.ptr>` in MLIR + handle = builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order) + return tl.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape))) + + +def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + # Convert dynamic offsets to IR values + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Advanced block pointer type is the same as before + return tl.tensor(builder.create_advance(base.handle, offsets), base.type) diff --git a/third_party/xpu/python/triton/language/standard.py b/third_party/xpu/python/triton/language/standard.py new file mode 100644 index 000000000..de30cf260 --- /dev/null +++ b/third_party/xpu/python/triton/language/standard.py @@ -0,0 +1,441 @@ +from __future__ import annotations + +from ..runtime.jit import jit +from . import core +from . import math + +# constexpr utilities (triton metaprogramming sucks) + + +def _unwrap_if_constexpr(o): + return o.value if isinstance(o, core.constexpr) else o + + +def _log2(i: core.constexpr): + log2 = 0 + n = i.value + while n > 1: + n >>= 1 + log2 += 1 + return core.constexpr(log2) + + +def _is_power_of_two(i: core.constexpr): + n = i.value + return core.constexpr((n & (n - 1)) == 0 and n != 0) + + +# ----------------------- +# Standard library +# ----------------------- + + +@core._tensor_member_fn +@jit +def cdiv(x, div): + """ + Computes the ceiling division of :code:`x` by :code:`div` + + :param x: the input number + :type x: Block + :param div: the divisor + :param div: Block + """ + return (x + div - 1) // div + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("sigmoid") +def sigmoid(x): + return 1 / (1 + math.exp(-x)) + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("softmax") +def softmax(x, ieee_rounding=False): + z = x - max(x, 0) + num = math.exp(z) + den = sum(num, 0) + return math.fdiv(num, den, ieee_rounding) + + +@core._tensor_member_fn +@jit +def ravel(x): + """ + Returns a contiguous flattened view of :code:`x`. + + :param x: the input tensor + :type x: Block + """ + return core.reshape(x, [x.numel], can_reorder=True) + + +@jit +def swizzle2d(i, j, size_i, size_j, size_g): + """ + Transforms indices of a row-major :code:`size_i * size_j` matrix into those + of one where the indices are col-major for each group of :code:`size_g` + rows. + + For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will + transform :: + + [[0 , 1 , 2 , 3 ], + [4 , 5 , 6 , 7 ], + [8 , 9 , 10, 11], + [12, 13, 14, 15]] + + into :: + + [[0, 2, 4 , 6 ], + [1, 3, 5 , 7 ], + [8, 10, 12, 14], + [9, 11, 13, 15]] + """ + # "unrolled index in array" + ij = i * size_j + j + # number of elements in `size_g` groups + # of `size_j` columns + size_gj = size_g * size_j + # index of the group in which (i,j) is + group_id = ij // size_gj + # row-index of the first element of this group + off_i = group_id * size_g + # last group may have fewer rows + size_g = core.minimum(size_i - off_i, size_g) + # new row and column indices + new_i = off_i + (ij % size_g) + new_j = (ij % size_gj) // size_g + return new_i, new_j + + +@jit +def zeros(shape, dtype): + """ + Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + return core.full(shape, 0, dtype) + + +@jit +def zeros_like(input): + """ + Creates a tensor of zeros with the same shape and type as a given tensor. + """ + return zeros(input.shape, input.dtype) + + +# max and argmax + + +@jit +def _argmax_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + gt = value1 > value2 or tie + v_ret = core.where(gt, value1, value2) + i_ret = core.where(gt, index1, index2) + return v_ret, i_ret + + +@jit +def _argmax_combine_tie_break_left(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, True) + + +@jit +def _argmax_combine_tie_break_fast(value1, index1, value2, index2): + return _argmax_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_max(a, b): + return core.maximum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32): + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left") +def argmax(input, axis, tie_break_left=True, keep_dims=False): + (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +# min and argmin + + +@jit +def _argmin_combine(value1, index1, value2, index2, tie_break_left): + if tie_break_left: + tie = value1 == value2 and index1 < index2 + else: + tie = False + lt = value1 < value2 or tie + value_ret = core.where(lt, value1, value2) + index_ret = core.where(lt, index1, index2) + return value_ret, index_ret + + +@jit +def _argmin_combine_tie_break_left(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, True) + + +@jit +def _argmin_combine_tie_break_fast(value1, index1, value2, index2): + return _argmin_combine(value1, index1, value2, index2, False) + + +@jit +def _elementwise_min(a, b): + return core.minimum(a, b) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum", return_indices_arg="return_indices", + tie_break_arg="return_indices_tie_break_left") +def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + if return_indices: + if return_indices_tie_break_left: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims) + else: + return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims) + else: + if core.constexpr(input.dtype.primitive_bitwidth) < 32: + if core.constexpr(input.dtype.is_floating()): + input = input.to(core.float32) + else: + assert input.dtype.is_int(), "Expecting input to be integer type" + input = input.to(core.int32) + return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims) + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left") +def argmin(input, axis, tie_break_left=True, keep_dims=False): + _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims) + return ret + + +@jit +def _sum_combine(a, b): + return a + b + + +# sum + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("sum") +def sum(input, axis=None, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims) + + +@jit +def _xor_combine(a, b): + return a ^ b + + +# xor sum + + +@core._tensor_member_fn +@core.builtin +@core._add_reduction_docstr("xor sum") +def xor_sum(input, axis=None, keep_dims=False, _builder=None, _generator=None): + scalar_ty = input.type.scalar + if not scalar_ty.is_int(): + raise ValueError("xor_sum only supported for integers") + + input = core._promote_bfloat16_to_float32(input, _builder=_builder) + return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims, _builder=_builder, _generator=_generator) + + +# cumsum + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumsum") +def cumsum(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _sum_combine, reverse) + + +# cumprod + + +@jit +def _prod_combine(a, b): + return a * b + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumprod") +def cumprod(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _prod_combine, reverse) + + +# sort + + +@jit +def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr): + n_outer: core.constexpr = x.numel >> n_dims + shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)] + y = core.reshape(x, shape) + # slice left/right with 'stride' 2**(n_dims - i - 1) + mask = core.arange(0, 2)[None, :, None] + left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape) + right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape) + left = core.reshape(left, x.shape) + right = core.reshape(right, x.shape) + # actual compare-and-swap + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + ileft = left.to(idtype, bitcast=True) + iright = right.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + ret = ix ^ core.where((left > right) ^ flip, ileft ^ iright, zeros_like(ix)) + return ret.to(x.dtype, bitcast=True) + + +@jit +def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr): + ''' + order_type 0 == ascending + order_type 1 == descending + order_type 2 == alternating + ''' + n_outer: core.constexpr = x.numel >> n_dims + core.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage] + flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape) + else: + flip = order + # perform `stage` rounds of `compare-and-swap` + for i in core.static_range(stage): + x = _compare_and_swap(x, flip, i + (n_dims - stage), n_dims) + return x + + +@core._tensor_member_fn +@jit +def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0): + # handle default dimension or check that it is the most minor dim + _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim + core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") + # iteratively run bitonic merge-sort steps + n_dims: core.constexpr = _log2(x.shape[_dim]) + for i in core.static_range(1, n_dims + 1): + x = _bitonic_merge(x, i, 2 if i < n_dims else descending, n_dims) + return x + + +# flip + + +def _get_flip_dim(dim, shape): + dim = _unwrap_if_constexpr(dim) + shape = _unwrap_if_constexpr(shape) + if dim is None: + dim = len(shape) - 1 + assert dim == len(shape) - 1, "Currently only support flipping the last dimension" + return core.constexpr(dim) + + +@core._tensor_member_fn +@jit +def flip(x, dim=None): + """ + Flips a tensor `x` along the dimension `dim`. + + :param x: the first input tensor + :type x: Block + :param dim: the dimension to flip along (currently only final dimension supported) + :type dim: int + """ + core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)])) + core.static_assert(_is_power_of_two(x.numel)) + # # reshape the tensor to have all dimensions be 2. + # # TODO: We shouldn't have to change the dimensions not sorted. + steps: core.constexpr = _log2(x.numel) + start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)]) + y = core.reshape(x, [2] * steps) + y = core.expand_dims(y, start) + flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2)) + for i in core.static_range(start, steps): + flip2 = flip + for j in core.static_range(0, steps + 1): + if j != i and j != i + 1: + flip2 = core.expand_dims(flip2, j) + y = sum(y * flip2, i + 1, keep_dims=True) + x = core.reshape(y, x.shape) + return x + + +@jit +def interleave(a, b): + """ + Interleaves the values of two tensors along their last dimension. + + The two tensors must have the same shape. + + Equivalent to `tl.join(a, b).reshape(a.shape[-1:] + [2 * a.shape[-1]])` + """ + c = core.join(a, b) + + assert isinstance(c.shape, list) + if len(c.shape) == 1: + # We must have interleaved two scalars. + return c + else: + # This `else` is necessary because Triton's AST parser doesn't + # understand that if we take the `if` above we definitely don't run this + # `else`. + return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]]) diff --git a/third_party/xpu/python/triton/ops/__init__.py b/third_party/xpu/python/triton/ops/__init__.py new file mode 100644 index 000000000..18f1d782d --- /dev/null +++ b/third_party/xpu/python/triton/ops/__init__.py @@ -0,0 +1,7 @@ +# from .conv import _conv, conv +from . import blocksparse +from .cross_entropy import _cross_entropy, cross_entropy +from .flash_attention import attention +from .matmul import _matmul, get_higher_dtype, matmul + +__all__ = ["blocksparse", "_cross_entropy", "cross_entropy", "_matmul", "matmul", "attention", "get_higher_dtype"] diff --git a/third_party/xpu/python/triton/ops/blocksparse/__init__.py b/third_party/xpu/python/triton/ops/blocksparse/__init__.py new file mode 100644 index 000000000..6b24b5377 --- /dev/null +++ b/third_party/xpu/python/triton/ops/blocksparse/__init__.py @@ -0,0 +1,7 @@ +from .matmul import matmul +from .softmax import softmax + +__all__ = [ + "matmul", + "softmax", +] diff --git a/third_party/xpu/python/triton/ops/blocksparse/matmul.py b/third_party/xpu/python/triton/ops/blocksparse/matmul.py new file mode 100644 index 000000000..098e15438 --- /dev/null +++ b/third_party/xpu/python/triton/ops/blocksparse/matmul.py @@ -0,0 +1,432 @@ +import torch + +from ... import cdiv, heuristics, jit +from ... import language as tl + +# ******************************************************** +# -------------------------------------------------------- +# Sparse = Dense x Dense (SDD) +# This operation uses super-blocking to make sure that +# it's done efficiently when small blocks can be grouped +# together +# -------------------------------------------------------- +# ******************************************************** + + +@heuristics({ + 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0, +}) +@jit +def _sdd_kernel(A, B, C, # + stride_za, stride_ha, stride_ma, stride_ak, # + stride_zb, stride_hb, stride_bk, stride_nb, # + stride_zc, stride_hc, stride_mc, stride_nc, # + K, grid_offset, lut, # + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, # + BLOCK: tl.constexpr, EVEN_K: tl.constexpr # + ): + # ------------ # + # - Prologue - # + # ------------ # + block_id = tl.program_id(0) + grid_offset + lut += block_id * 3 + # offsets + off_z = tl.program_id(2) # batch + off_h = tl.load(lut + 0) # head + + # initialize pointers to A + start_am = tl.load(lut + 1) + offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) + offs_ak = tl.arange(0, TILE_K) + a_ptrs = A \ + + off_z * stride_za \ + + off_h * stride_ha \ + + offs_am[:, None] * stride_ma \ + + offs_ak[None, :] * stride_ak + # initialize pointers to B + start_bn = tl.load(lut + 2) + offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) + offs_bk = tl.arange(0, TILE_K) + b_ptrs = B \ + + off_z * stride_zb \ + + off_h * stride_hb \ + + offs_bn[None, :] * stride_nb \ + + offs_bk[:, None] * stride_bk + # ---------------- # + # Inner Loop # + # ---------------- # + acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + for k in range(K, 0, -TILE_K): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.) + b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.) + acc += tl.dot(a, b, out_dtype=tl.float32) + a_ptrs += TILE_K * stride_ak + b_ptrs += TILE_K * stride_bk + c = acc.to(C.dtype.element_ty) + # ---------------- # + # Epilogue # + # ---------------- # + offs_cm = tl.arange(0, TILE_M) % BLOCK + offs_cn = tl.arange(0, TILE_N) % BLOCK + pc = C \ + + off_z * stride_zc \ + + block_id * stride_hc \ + + offs_cm[:, None] * stride_mc \ + + offs_cn[None, :] * stride_nc + tl.store(pc, c, mask=True) + + +def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() + # (A * B)^T = B^T * A^T + if trans_c: + a, b = b, a + trans_a, trans_b = not trans_b, not trans_a + # shape constraints + a_dim = -2 if trans_a else -1 + b_dim = -1 if trans_b else -2 + Ka, Kb = a.shape[a_dim], b.shape[b_dim] + if Ka != Kb: + raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})") + # allocate output + if out is None: + c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device) + else: + assert out.shape == (a.shape[0], lut.shape[0], block, block) + c = out + grid = [c.shape[1], 1, c.shape[0]] + _sdd_kernel[grid]( + a, b, c, # + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), # + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), # + c.stride(0), c.stride(1), c.stride(2), c.stride(3), # + Ka, 0, lut, # + TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, # + num_warps=4 # + ) + return c + + +def sdd_lut(layout, block, device): + lut = layout.nonzero(as_tuple=False).to(device).int() + lut = lut.contiguous() + return lut, None + + +# ----------------------------- +# Dense = Sparse x Dense (DSD) +# This operation uses a look-up table that contains pre-computed pointer increments +# in order to minimize computations in the inner loop of the matmul kernel. +# ----------------------------- + + +@jit +def _dsd_kernel(A, B, C, # + stride_az, stride_ha, stride_am, stride_ak, # + stride_zb, stride_hb, stride_bk, stride_bn, # + stride_zc, stride_hc, stride_cm, stride_cn, # + DS0, DS1, lut, # + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr # + ): + # ------------ # + # - Prologue - # + # ------------ # + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) + pidz = tl.program_id(2) + header = lut + pid_n * 4 + offset = tl.load(header + 0) + K = tl.load(header + 1) + column = tl.load(header + 2) + off_h = tl.load(header + 3) + pinc = lut + offset + # initialize pointers to A (sparse) + block_id = tl.load(pinc + 1) + block_id = tl.multiple_of(block_id, 8) # compiler hint + offs_am = tl.arange(0, TILE_M) + offs_ak = tl.arange(0, TILE_K) + pa = A + pidz * stride_az \ + + block_id * stride_ha \ + + offs_am[:, None] * stride_am \ + + offs_ak[None, :] * stride_ak + # initialize pointers to B (dense) + offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N) + start_bk = tl.load(pinc) + start_bk = tl.multiple_of(start_bk, 8) # compiler hint + offs_bk = start_bk + tl.arange(0, TILE_K) + pb = B + pidz * stride_zb \ + + off_h * stride_hb \ + + offs_bn[None, :] * stride_bn \ + + offs_bk[:, None] * stride_bk + # ---------------- # + # Inner Loop # + # ---------------- # + acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + pinc += 2 + inc_a = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.load(pinc) + inc_b = tl.multiple_of(inc_b, 8) + for k in range(K, 0, -TILE_K): + a = tl.load(pa) + b = tl.load(pb) + acc += tl.dot(a, b, out_dtype=tl.float32) + pa += inc_a + pb += inc_b * stride_bk + pinc += 2 + inc_a = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.load(pinc) + inc_b = tl.multiple_of(inc_b, 8) + c = acc.to(C.dtype.element_ty) + # initialize pointers to C + offs_cm = column * TILE_M + tl.arange(0, TILE_M) + offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N) + pc = C \ + + off_h * stride_hc \ + + pidz * stride_zc \ + + offs_cm[:, None] * stride_cm \ + + offs_cn[None, :] * stride_cn + tl.store(pc, c, mask=offs_cn[None, :] < DS0) + + +def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() + # shapes / dtypes + AS1 = block * spdims[2 if trans_a else 1] + BS0 = b.size(0) + BS1 = b.size(1) + BS3 = b.size(2 if trans_b else 3) + dtype = a.dtype + # allocate output + CS0 = BS0 + CS1 = BS1 + CS2 = BS3 if trans_c else AS1 + CS3 = AS1 if trans_c else BS3 + if out is None: + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + else: + assert out.shape == (CS0, CS1, CS2, CS3) + c = out + # meta-parameter heuristics + TILE_N = 128 + # compute output + grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0] + _dsd_kernel[grid]( + a, b, c, # + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), # + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), # + c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), # + BS3, AS1, lut, # + TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, # + num_warps=4, GROUP_SIZE_M=4 # + ) + # exit() + return c + + +def dsd_lut(layout, block, step, trans, device): + """ + Generates the look-up table for incrementing pointers in the DSD/DDS matmul. + Example (BLOCK=32, STEP=16) + [[1, 0, 0, 1, 0], + [0, 1, 1, 0, 1], + [1, 0, 1, 0, 0]] + + Then the offsets for A are + [0 , 16, 32, 48] <- row 0 + \\----/ \\----/ + col=0 col=3 + [64, 80, 96, 112, 128, 144] <- row 1 + \\----/ \\----/ \\------/ + col=1 col=2 col=3 + [160, 176, 192, 208] + which leads to increments table + [0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16] + + Because B is dense, the offsets are + [0, 16, 96, 112] <- row 0 + [32, 48, 64, 80] <- row 1 + [0, 16, 64, 80] <- row 2 + """ + sizes = torch.sum(layout, 2 if trans else 1) + head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True) + sizes = sizes.flatten() + segments = sizes * step + # pointer increments + if trans: + nnz = layout.nonzero(as_tuple=False) + else: + nnz = layout.transpose(1, 2).nonzero(as_tuple=False) + num_blocks = nnz.size(0) + offsets = torch.zeros_like(sizes) + offsets[1:] = torch.cumsum(sizes[:-1], dim=0) + offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets)) + # ------------------------------- + # dense input pointer increments + # ------------------------------- + # Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K) + # that is smaller than the block size, so we need to do a bit of extra work + # to handle this case + B_idx = nnz[:, 2] * block + B_incs = B_idx.clone() + B_incs[1:] -= B_idx[:-1] + div = block // step + B_incs = B_incs.view(-1, 1).repeat(1, div) + B_incs[:, 1:] = step + B_incs[:, 0] -= (div - 1) * step + # first increment for each reduction is actually the offset + B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]] + B_incs = B_incs.view(-1) + # ------------------------------- + # sparse input pointer increments + # ------------------------------- + # same as above, except that the increments are in the sparse memory layout + if trans: + A_idx = torch.arange(num_blocks, device=layout.device) + else: + A_idx = torch.tensor([], dtype=torch.int64, device=layout.device) + current_offset = 0 + for z in range(layout.size(0)): + layoutw = layout[z, :, :].clone().long() + msum = layoutw.sum() + layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device) + A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1)) + current_offset += msum + A_incs = A_idx * block * block + A_incs[1:] -= A_idx[:-1] * block * block + A_incs = A_incs.view(-1, 1).repeat(1, div) + if trans: + A_incs[:, 1:] = step + A_incs[:, 0] -= (div - 1) * step + else: + A_incs[:, 1:] = step * block + A_incs[:, 0] -= (div - 1) * step * block + A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]] + A_incs = A_incs.view(-1) + # create header + width = col_id.size(0) + offsets = offsets * 2 * div + 4 * width + segments = segments * div + header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous() + # create increments + incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous() + # pad by a factor 2*MAX_NUM_STAGES + # to accommodate pre-fetching inside the kernel + pad = torch.zeros(20, device=incs.device, dtype=incs.dtype) + incs = torch.cat((incs, pad)) + # create lut + lut = torch.cat((header, incs)) + lut = lut.type(torch.int32).to(device) + # create locks + return lut, width + + +# ----------------------------- +# Dense = Dense x Sparse (DDS) +# ----------------------------- +# AB = (B^T A^T)^T + + +def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): + return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out) + + +############## +# MAIN API # +############## + + +class _matmul(torch.autograd.Function): + + fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul} + + @staticmethod + def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut, + db_width, out): + c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out) + # save for backward + ctx.save_for_backward(a, b) + ctx.da_lut = da_lut + ctx.da_width = da_width + ctx.db_lut = db_lut + ctx.db_width = db_width + ctx.mode = mode + ctx.spdims = spdims + ctx.block = block + ctx.trans_a = trans_a + ctx.trans_b = trans_b + ctx.trans_c = trans_c + ctx.has_out = out is not None + return c + + @staticmethod + def backward(ctx, dc): + # saved for backward + a, b = ctx.saved_tensors + da, db = None, None + mode = ctx.mode + # gradients w.r.t. a + if ctx.needs_input_grad[0]: + mode_da = mode[1] + mode[0] + mode[2] + da = _matmul.fn[mode_da](dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, + ctx.da_lut, ctx.da_width) + # gradients w.r.t. b + if ctx.needs_input_grad[1]: + mode_db = mode[2] + mode[1] + mode[0] + db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, + ctx.db_lut, ctx.db_width) + dout = dc if ctx.has_out else None + return da, db, None, None, None, \ + None, None, None, None, \ + None, None, None, None, None, dout + + +class matmul: + + def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False): + if mode not in ['sdd', 'dsd', 'dds']: + raise NotImplementedError('Supported modes are: sdd, dsd, dds') + self.block = block + self.mode = mode + self.trans_a = trans_a + self.trans_b = trans_b + self.trans_c = trans_c + self.layout = layout + self.spdims = layout.shape + step = min(block, 32) + if self.mode == 'sdd': + self.c_lut, self.c_width = sdd_lut(layout, block, device) + self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device) + self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device) + if self.mode == 'dsd': + self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device) + self.da_lut, self.da_width = sdd_lut(layout, block, device) + self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device) + if self.mode == 'dds': + self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device) + self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device) + self.db_lut, self.db_width = sdd_lut(layout, block, device) + + def __call__(self, a, b, out=None): + c = _matmul.apply(a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, # + self.c_lut, self.c_width, # + self.da_lut, self.da_width, # + self.db_lut, self.db_width, # + out) + return c diff --git a/third_party/xpu/python/triton/ops/blocksparse/softmax.py b/third_party/xpu/python/triton/ops/blocksparse/softmax.py new file mode 100644 index 000000000..bcffff26b --- /dev/null +++ b/third_party/xpu/python/triton/ops/blocksparse/softmax.py @@ -0,0 +1,228 @@ +import torch + +from ... import jit +from ... import language as tl +from ... import next_power_of_2 + + +def num_warps(n): + if n <= 128: + return 1 + if n <= 256: + return 2 + if n <= 512: + return 4 + if n <= 4096: + return 8 + return 16 + + +@jit +def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT, # + R, extent, stride_zr, stride_hr, # relative attention + scale, is_causal, # + ROW_SIZE: tl.constexpr, # + BLOCK_SIZE: tl.constexpr, # + IS_DENSE: tl.constexpr # + ): + h = tl.program_id(0) + m = tl.program_id(1) + z = tl.program_id(2) + # create index ranges + hm = h * tl.num_programs(1) + m + lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE + block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE + # extract information from LUT + header = LUT + (hm // BLOCK_SIZE) * 2 + size = tl.load(header + 0) + offset = tl.load(header + 1) + # pointer offset + off_a = z * stride_xz + off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx + off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx + # do not need to read column indices in the dense case + if IS_DENSE: + ns = tl.arange(0, ROW_SIZE) + else: + off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE + start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0) + ns = start_n * BLOCK_SIZE + lane_n + # load X + mask = block_n < size + a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf")) + a = a.to(tl.float32) + # compute + out = a + out *= scale + # apply relative attention + if R is not None: + R += z * stride_zr + R += h * stride_hr + off_lo = (extent - m - 1) + ns + mask_lo = (off_lo >= 0) & (off_lo < extent) + rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0) + out += rel_logits + out = out.to(tl.float32) + # apply causal mask + out = tl.where((ns > m) & is_causal, -float("inf"), out) + # computation + out = tl.softmax(out) + # write-back + tl.store(Out + off_a + lane_n, out, mask=mask) + + +@jit +def _blocksparse_softmax_bwd(DA, stride_zdx, # + DOut, stride_zdout, # + Out, stride_zout, # + scale, # + LUT, # + DR, extent, stride_zr, stride_hr, stride_er, # + is_causal, # + ROW_SIZE: tl.constexpr, # + BLOCK_SIZE: tl.constexpr, # + IS_DENSE: tl.constexpr): + h = tl.program_id(0) + m = tl.program_id(1) + z = tl.program_id(2) + # create index ranges + hm = h * tl.num_programs(1) + m + lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE + block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE + # extract information from LUT + header = LUT + (hm // BLOCK_SIZE) * 2 + size = tl.load(header + 0) + offset = tl.load(header + 1) + # row-col offset + off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE + off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE + mask = block_n < size + # pointers + As = Out + z * stride_zout + off_mn + DOuts = DOut + z * stride_zdout + off_mn + # do not need to read column indices in the dense case + if IS_DENSE: + ns = tl.arange(0, ROW_SIZE) + else: + off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE + start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0) + ns = start_n * BLOCK_SIZE + lane_n + # load data + a = tl.load(As + lane_n, mask=mask, other=0.0) + a = a.to(tl.float32) + dout = tl.load(DOuts + lane_n, mask=mask, other=0.0) + dout = dout.to(tl.float32) + # compute + a = tl.where((ns > m) & is_causal & (a == a), 0., a) + da = a * (dout - tl.sum(a * dout, 0)) + # apply relative attention + if DR is not None: + DR += z * stride_zr + DR += h * stride_hr + off_lo = (extent - m - 1) + ns + mask_lo = (off_lo >= 0) & (off_lo < extent) & mask + tl.store(DR + m * extent + off_lo, da, mask=mask_lo) + da = da * scale + # convert da + # write-back + DAs = DA + z * stride_zdx + off_mn + tl.store(DAs + lane_n, da, mask=mask) + + +class _softmax(torch.autograd.Function): + + @staticmethod + def make_lut(layout, block, device): + _empty = torch.tensor([], dtype=torch.int64, device=layout.device) + sizes = _empty.clone() + # sizes along rows + for h in range(layout.shape[0]): + sizes = torch.cat((sizes, layout[h, :, :].sum(-1))) + total_sizes = sizes * block + # offsets in block format + offsets = torch.zeros_like(sizes) + offsets[1:] = torch.cumsum(sizes[:-1], dim=0) + # block indices + columns = layout.nonzero(as_tuple=False)[:, 2] + header = torch.stack((sizes, offsets), dim=1).view(-1) + lut = torch.cat((header, columns)).type(torch.int32).to(device) + return lut, int(total_sizes.max()) + + @staticmethod + def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense): + if scale is not None and isinstance(scale, torch.Tensor): + assert scale.device.type == "cpu" + scale = scale.item() + M = a.shape[0] + grid = [spdims[0], spdims[1] * block, M] + rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape + rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride() + # enqueue kernel + out = torch.empty_like(a) + _blocksparse_softmax_fwd[grid]( + out, a, a.stride(0), lut, # + rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn# + scale, # + is_causal, # + BLOCK_SIZE=block, # + ROW_SIZE=next_power_of_2(maxlut), # + IS_DENSE=is_dense, # + num_warps=num_warps(maxlut) # + ) + # save to context + # ctx.mark_dirty(x) + ctx.save_for_backward(out, lut) + ctx.spdims = spdims + ctx.block = block + ctx.maxlut = maxlut + ctx.scale = scale + ctx.rel_shape = rel_shape + ctx.rel_strides = rel_strides + ctx.rel_dtype = a.dtype + ctx.is_dense = is_dense + ctx.is_causal = is_causal + return out + + @staticmethod + def backward(ctx, dout): + # retrieve from context + out, lut = ctx.saved_tensors + # relative logits gradients + dr = None + if ctx.needs_input_grad[3]: + dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device) + # run kernel + M = out.shape[0] + grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M) + da = torch.empty_like(dout) + _blocksparse_softmax_bwd[grid]( + da, da.stride(0), # + dout, dout.stride(0), # + out, out.stride(0), # + ctx.scale, # + lut, # + dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], # + ctx.is_causal, # + BLOCK_SIZE=ctx.block, # + ROW_SIZE=next_power_of_2(ctx.maxlut), # + IS_DENSE=ctx.is_dense, # + num_warps=num_warps(ctx.maxlut) # + ) + return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None) + + +class softmax: + + def __init__(self, layout, block, device, is_dense=False): + self.spdims = layout.shape + self.layout = layout + self.block = block + self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device) + self.is_dense = is_dense + + def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False): + if rel_logits is not None and rel_logits.dtype != a.dtype: + raise ValueError(f"relative position embedding must be {a.dtype}") + a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut, + self.is_dense) + return a diff --git a/third_party/xpu/python/triton/ops/cross_entropy.py b/third_party/xpu/python/triton/ops/cross_entropy.py new file mode 100644 index 000000000..88e8dae50 --- /dev/null +++ b/third_party/xpu/python/triton/ops/cross_entropy.py @@ -0,0 +1,96 @@ +import torch + +from .. import heuristics, jit +from .. import language as tl +from .. import next_power_of_2 + + +def num_warps(N): + if N < 2048: + return 4 + elif N < 8192: + return 8 + return 16 + + +@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) +@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])}) +@jit +def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK) + idx = tl.load(IDX + row) + # pointers to logit and probs + LOGITS = LOGITS + row * N + cols + WRIT_PROBS = PROBS + row * N + cols + READ_PROBS = PROBS + row * N + idx + # write-back negative log-probs + logits = tl.load(LOGITS, mask=cols < N, other=-float('inf')) + logits = logits.to(tl.float32) + logits = logits - tl.max(logits, 0) + probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits + tl.store(WRIT_PROBS, probs, mask=cols < N) + # There is a bug in the compiler, which fails to insert a barrier here. + # We add it explicitly for now. Will be fixed soon. + tl.debug_barrier() + # write-back loss + probs = tl.load(READ_PROBS) + tl.store(LOSS + row, probs) + + +@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) +@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])}) +@jit +def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK) + idx = tl.load(IDX + row) + # pointers to probs + PROBS = PROBS + row * N + cols + # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] + # and we have -log(p[k]) stored in PROBS, so this is easy + probs = -tl.load(PROBS, mask=cols < N, other=float('inf')) + probs = tl.exp(probs.to(tl.float32)) + delta = cols == idx + # write result in-place in PROBS + dout = tl.load(DPROBS + row) + din = (probs - delta) * dout + tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N) + + +class _cross_entropy(torch.autograd.Function): + + @classmethod + def forward(cls, ctx, logits, indices): + # make sure we can use triton + assert (indices.dtype == torch.int64), "Indices are expected to be of type long." + # make kernel + device, dtype = logits.device, logits.dtype + n_cols = logits.shape[-1] + # run the kernel + result = torch.empty_like(indices, dtype=dtype, device=device) + neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device) + grid = lambda opt: (logits.numel() // n_cols, ) + _forward[grid](logits, neg_logprobs, indices, result, n_cols) + # save for backward + ctx.save_for_backward(neg_logprobs, indices) + return result + + @classmethod + def backward(cls, ctx, dneg_logprobs): + """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] + so we initialize the gradient as neg_logprobs, so we can just exponentiate + to get p[k], which is most of what we need... neg_logprobs will be + modified in place to become the gradient we want + """ + # load saved tensors + neg_logprobs, indices = ctx.saved_tensors + # run the kernel + # neg_logprobs will be modified in place to become our gradient: + n_cols = neg_logprobs.shape[-1] + grid = lambda opt: (neg_logprobs.numel() // n_cols, ) + _backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols) + return neg_logprobs, None + + +cross_entropy = _cross_entropy.apply diff --git a/third_party/xpu/python/triton/ops/flash_attention.py b/third_party/xpu/python/triton/ops/flash_attention.py new file mode 100644 index 000000000..0825ef26c --- /dev/null +++ b/third_party/xpu/python/triton/ops/flash_attention.py @@ -0,0 +1,466 @@ +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) + +Sequence Parallel implementation inspired by HazyResearch +(see https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py) +""" + +import torch +import triton + +from .. import cdiv, jit +from .. import language as tl + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@jit +def _fwd_kernel(Q, K, V, sm_scale, # + L, # + Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, # + Z_H_N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + IS_CAUSAL: tl.constexpr # + ): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + vk_offset = qvk_offset // stride_qm + + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(BLOCK_DMODEL, Z_H_N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, vk_offset), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(vk_offset, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # credits to: Adam P. Goucher (https://github.com/apgoucher): + # scale sm_scale by 1/log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + + offs_k = tl.arange(0, BLOCK_DMODEL) + Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + q = tl.load(Q_ptrs) + + q = (q * qk_scale).to(K.dtype.element_ty) + lo = 0 + hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_CAUSAL: + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk += tl.dot(q, k) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p.to(V.dtype.element_ty), v) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # write back l and m + acc = acc / l_i[:, None] + l_ptrs = L + off_hz * N_CTX + offs_m + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(vk_offset + start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + tl.store(O_block_ptr, acc.to(K.dtype.element_ty)) + + +@jit +def _bwd_preprocess( + Out, + DO, + Delta, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + # compute + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_m, delta) + + +@jit +def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, # + Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + SEQUENCE_PARALLEL: tl.constexpr, # + CAUSAL: tl.constexpr, # + MMA_V3: tl.constexpr # + ): + if CAUSAL: + lo = start_n * BLOCK_M + else: + lo = 0 + + Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm + DQ_offset = off_z * stride_qz + off_h * stride_qh + K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn + V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn + if SEQUENCE_PARALLEL: + DQ_offset += stride_dqa * start_n + DQ_offset = DQ_offset // stride_qm + + Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0)) + K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0)) + V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0)) + DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0)) + DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0)) + DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0)) + DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0)) + + # initialize row/col offsets + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + l_ptrs = L + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(Q_block_ptr) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + if CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf")) + else: + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + qk *= qk_scale + l_i = tl.load(l_ptrs + offs_m_curr) + p = tl.math.exp2(qk - l_i[:, None]) + # compute dv + do = tl.load(DO_block_ptr) + dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp = tl.dot(do, tl.trans(v)) + # compute ds = p * (dp - delta[:, None]) + ds = (p * (dp - Di[:, None]) * sm_scale).to(Q.dtype.element_ty) + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds), q) + # compute dq + if not SEQUENCE_PARALLEL: + dq = tl.load(DQ_block_ptr) + dq += tl.dot(ds, k) + tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty)) + elif SEQUENCE_PARALLEL: + if MMA_V3: + dq = tl.dot(ds, k) + else: + # not work with mma v3, because M % 64 != 0 + dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds))) + tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty)) + + # increment pointers + DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0)) + Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0)) + DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0)) + # write-back + tl.store(DV_block_ptr, dv.to(V.dtype.element_ty)) + tl.store(DK_block_ptr, dk.to(K.dtype.element_ty)) + + +@jit +def _bwd_kernel(Q, K, V, sm_scale, # + Out, DO, # + DQ, DK, DV, # + L, # + D, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + Z_H_N_CTX, # + SQ_Z_H_N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + SEQUENCE_PARALLEL: tl.constexpr, # + CAUSAL: tl.constexpr, # + MMA_V3: tl.constexpr # + ): + qk_scale = sm_scale * 1.44269504 + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + + Q_block_ptr = tl.make_block_ptr( + base=Q, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + DO_block_ptr = tl.make_block_ptr( + base=DO, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + if SEQUENCE_PARALLEL: + DQ_block_ptr = tl.make_block_ptr( + base=DQ, + shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + else: + DQ_block_ptr = tl.make_block_ptr( + base=DQ, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + + DK_block_ptr = tl.make_block_ptr( + base=DK, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + DV_block_ptr = tl.make_block_ptr( + base=DV, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + + num_block_n = tl.cdiv(N_CTX, BLOCK_N) + if not SEQUENCE_PARALLEL: + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block_n, # + BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, # + BLOCK_N=BLOCK_N, # + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, # + CAUSAL=CAUSAL, # + MMA_V3=MMA_V3 # + ) + else: + start_n = tl.program_id(1) + _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block_n, # + BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, # + BLOCK_N=BLOCK_N, # + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, # + CAUSAL=CAUSAL, # + MMA_V3=MMA_V3 # + ) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): + # only support for Ampere now + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + raise RuntimeError("Flash attention currently only supported for compute capability >= 80") + BLOCK_M = 128 + BLOCK_N = 64 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + _fwd_kernel[grid]( + q, k, v, sm_scale, # + L, # + o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + q.shape[0] * q.shape[1] * q.shape[2], # + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, # + IS_CAUSAL=causal, # + num_warps=num_warps, # + num_stages=4 # + ) + + ctx.save_for_backward(q, k, v, o, L) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + ctx.causal = causal + ctx.sequence_parallel = sequence_parallel + return o + + @staticmethod + def backward(ctx, do): + capability = torch.cuda.get_device_capability() + MMA_V3 = capability[0] >= 9 + BLOCK = 128 + + if is_hip(): + # Bwd pass runs out of shared memory on HIP with larger block size. + BLOCK = 64 + + q, k, v, o, L = ctx.saved_tensors + sequence_parallel = ctx.sequence_parallel + seq_len_kv = k.shape[2] + do = do.contiguous() + if sequence_parallel: + replicas = cdiv(seq_len_kv, BLOCK) + new_dq_shape = (replicas, ) + q.shape + dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype) + else: + dq = torch.zeros_like(q, dtype=q.dtype) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + delta = torch.empty_like(L) + _bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )]( + o, + do, + delta, + BLOCK_M=BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + _bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)]( + q, k, v, ctx.sm_scale, # + o, do, # + dq, dk, dv, # + L, # + delta, # + o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + q.shape[0] * q.shape[1] * q.shape[2], # + cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, # + BLOCK_DMODEL=ctx.BLOCK_DMODEL, # + SEQUENCE_PARALLEL=sequence_parallel, # + CAUSAL=ctx.causal, # + MMA_V3=MMA_V3, # + num_warps=8, # + num_stages=1 # + ) + + if len(dq.shape) == 5: + dq = dq.sum(dim=0) + return dq, dk, dv, None, None, None + + +attention = _attention.apply diff --git a/third_party/xpu/python/triton/ops/matmul.py b/third_party/xpu/python/triton/ops/matmul.py new file mode 100644 index 000000000..f7f577a1b --- /dev/null +++ b/third_party/xpu/python/triton/ops/matmul.py @@ -0,0 +1,219 @@ +import torch + +from .. import Config, autotune, cdiv, heuristics, jit +from .. import language as tl +from .matmul_perf_model import early_config_prune, estimate_matmul_time + +_ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32] + + +def upcast_if_fp8(a): + if "fp8" in str(a): + return torch.float16 + return a + + +def get_higher_dtype(a, b): + a = upcast_if_fp8(a) + b = upcast_if_fp8(b) + if a is b: + return a + + assert a in _ordered_datatypes + assert b in _ordered_datatypes + + for d in _ordered_datatypes: + if a is d: + return b + if b is d: + return a + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append( + Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + +@autotune( + configs=[ + # basic configs for compute-bound matmuls + Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10, + }, +) +@heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@jit +def _kernel(A, B, C, M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + acc_dtype: tl.constexpr, # + input_precision: tl.constexpr, # + fp8_fast_accum: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr # + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + if AB_DTYPE is not None: + a = a.to(AB_DTYPE) + b = b.to(AB_DTYPE) + if fp8_fast_accum: + acc = tl.dot(a, b, acc, out_dtype=acc_dtype, input_precision=input_precision) + else: + acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +class _matmul(torch.autograd.Function): + kernel = _kernel + + _locks = {} + + @staticmethod + def _call(a, b, acc_dtype, input_precision, fp8_fast_accum, output_dtype): + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + # common type between a and b + ab_dtype = get_higher_dtype(a.dtype, b.dtype) + + # allocates output + if (output_dtype is None): + output_dtype = ab_dtype + + c = torch.empty((M, N), device=device, dtype=output_dtype) + + # Allowed types for acc_type given the types of a and b. + supported_acc_dtypes = { + torch.float16: (torch.float32, torch.float16), torch.bfloat16: (torch.float32, torch.bfloat16), + torch.float32: (torch.float32, ), torch.int8: (torch.int32, ) + } + + if acc_dtype is None: + acc_dtype = supported_acc_dtypes[ab_dtype][0] + else: + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" + assert acc_dtype in supported_acc_dtypes[a.dtype], "acc_dtype not compatible with the type of a" + assert acc_dtype in supported_acc_dtypes[b.dtype], "acc_dtype not compatible with the type of b" + + def to_tl_type(ty): + return getattr(tl, str(ty).split(".")[-1]) + + acc_dtype = to_tl_type(acc_dtype) + ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) + + # Tensor cores support input with mixed float8 types. + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]: + ab_dtype = None + # launch kernel + grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _kernel[grid]( + a, b, c, M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + acc_dtype=acc_dtype, # + input_precision=input_precision, # + fp8_fast_accum=fp8_fast_accum, # + GROUP_M=8, AB_DTYPE=ab_dtype) + return c + + @staticmethod + def forward(ctx, a, b, acc_dtype=None, input_precision=None, fp8_fast_accum=True, output_dtype=None): + return _matmul._call(a, b, acc_dtype=acc_dtype, input_precision=input_precision, fp8_fast_accum=fp8_fast_accum, + output_dtype=output_dtype) + + +matmul = _matmul.apply diff --git a/third_party/xpu/python/triton/ops/matmul_perf_model.py b/third_party/xpu/python/triton/ops/matmul_perf_model.py new file mode 100644 index 000000000..b60b74540 --- /dev/null +++ b/third_party/xpu/python/triton/ops/matmul_perf_model.py @@ -0,0 +1,171 @@ +import functools +import heapq + +import torch + +from .. import cdiv +from ..runtime import driver +from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi) + + +@functools.lru_cache() +def get_clock_rate_in_khz(): + try: + return nvsmi(['clocks.max.sm'])[0] * 1e3 + except FileNotFoundError: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + return pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM) * 1e3 + + +def get_tensorcore_tflops(device, num_ctas, num_warps, dtype): + ''' return compute throughput in TOPS ''' + total_warps = num_ctas * min(num_warps, 4) + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs + tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops( + dtype, get_clock_rate_in_khz(), device) + return tflops + + +def get_simd_tflops(device, num_ctas, num_warps, dtype): + ''' return compute throughput in TOPS ''' + total_warps = num_ctas * min(num_warps, 4) + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs + tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device) + return tflops + + +def get_tflops(device, num_ctas, num_warps, dtype): + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8 and dtype == torch.float32: + return get_simd_tflops(device, num_ctas, num_warps, dtype) + return get_tensorcore_tflops(device, num_ctas, num_warps, dtype) + + +def estimate_matmul_time( + # backend, device, + num_warps, num_stages, # + A, B, C, # + M, N, K, # + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, # + debug=False, **kwargs # +): + ''' return estimated running time in ms + = max(compute, loading) + store ''' + device = torch.cuda.current_device() + dtype = A.dtype + dtsize = A.element_size() + + num_cta_m = cdiv(M, BLOCK_M) + num_cta_n = cdiv(N, BLOCK_N) + num_cta_k = SPLIT_K + num_ctas = num_cta_m * num_cta_n * num_cta_k + + # If the input is smaller than the block size + M, N = max(M, BLOCK_M), max(N, BLOCK_N) + + # time to compute + total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS + tput = get_tflops(device, num_ctas, num_warps, dtype) + compute_ms = total_ops / tput + + # time to load data + num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"] + active_cta_ratio = min(1, num_ctas / num_sm) + active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate + active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5% + dram_bw = get_dram_gbps(device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s + l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) + # assume 80% of (following) loads are in L2 cache + load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1)) + load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1) + load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1)) + load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1) + # total + total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB + total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) + # loading time in ms + load_ms = total_dram / dram_bw + total_l2 / l2_bw + + # estimate storing time + store_bw = dram_bw * 0.6 # :o + store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB + if SPLIT_K == 1: + store_ms = store_c_dram / store_bw + else: + reduce_bw = store_bw + store_ms = store_c_dram / reduce_bw + # c.zero_() + zero_ms = M * N * 2 / (1024 * 1024) / store_bw + store_ms += zero_ms + + total_time_ms = max(compute_ms, load_ms) + store_ms + if debug: + print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, ' + f'loading time: {load_ms}ms, store time: {store_ms}ms, ' + f'Activate CTAs: {active_cta_ratio*100}%') + return total_time_ms + + +def early_config_prune(configs, named_args, **kwargs): + device = torch.cuda.current_device() + capability = torch.cuda.get_device_capability() + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + dtsize = named_args['A'].element_size() + dtype = named_args['A'].dtype + + # 1. make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages + + max_shared_memory = driver.active.utils.get_device_properties(device)["max_shared_mem"] + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory <= max_shared_memory: + pruned_configs.append(config) + configs = pruned_configs + + # Some dtypes do not allow atomic_add + if dtype not in [torch.float16, torch.float32]: + configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1] + + # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) + configs_map = {} + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages + + key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) + if key in configs_map: + configs_map[key].append((config, num_stages)) + else: + configs_map[key] = [(config, num_stages)] + + pruned_configs = [] + for k, v in configs_map.items(): + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k + if capability[0] >= 8: + # compute cycles (only works for ampere GPUs) + mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16) + mma_cycles = mmas / min(4, num_warps) * 8 + + ldgsts_latency = 300 # Does this matter? + optimal_num_stages = ldgsts_latency / mma_cycles + + # nearest stages, prefer large #stages + nearest = heapq.nsmallest( + 2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) + + for n in nearest: + pruned_configs.append(n[0]) + else: # Volta & Turing only supports num_stages <= 2 + random_config = v[0][0] + random_config.num_stages = 2 + pruned_configs.append(random_config) + return pruned_configs diff --git a/third_party/xpu/python/triton/runtime/__init__.py b/third_party/xpu/python/triton/runtime/__init__.py new file mode 100644 index 000000000..0b3979d28 --- /dev/null +++ b/third_party/xpu/python/triton/runtime/__init__.py @@ -0,0 +1,23 @@ +from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics) +from .cache import RedisRemoteCacheBackend, RemoteCacheBackend +from .driver import driver +from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret +from .errors import OutOfResources, InterpreterError + +__all__ = [ + "autotune", + "Autotuner", + "Config", + "driver", + "Heuristics", + "heuristics", + "InterpreterError", + "JITFunction", + "KernelInterface", + "MockTensor", + "OutOfResources", + "RedisRemoteCacheBackend", + "reinterpret", + "RemoteCacheBackend", + "TensorWrapper", +] diff --git a/third_party/xpu/python/triton/runtime/autotuner.py b/third_party/xpu/python/triton/runtime/autotuner.py new file mode 100644 index 000000000..5e1cb24e2 --- /dev/null +++ b/third_party/xpu/python/triton/runtime/autotuner.py @@ -0,0 +1,897 @@ +from __future__ import annotations + +import builtins +import os +import time +import inspect +from typing import Dict + +from ..testing import do_bench, do_bench_cudagraph +from .jit import KernelInterface +from .errors import OutOfResources + + +class Autotuner(KernelInterface): + + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=None, + post_hook=None, + prune_configs_by: Dict = None, + warmup=25, + rep=100, + use_cuda_graph=False, + generate_configs=None, + op_affiliation="", + row_sign="", + col_sign="", + n_elem_sign="", + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. + """ + self.no_configs = False + self.generate_configs = generate_configs + self.op_affiliation = op_affiliation + self.row_sign = row_sign + self.col_sign = col_sign + self.n_elem_sign = n_elem_sign + if not configs: + self.no_configs = True + self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.cache = {} + self.arg_names = arg_names + + # Reset to zero or restore values + self.reset_idx = [] + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + self.restore_idx = [] + if restore_value is not None: + self.restore_idx = [arg_names.index(k) for k in restore_value] + + # Hook to reset or restore for required tensors + self.pre_hook = lambda args, reset_only=False: 0 + self.post_hook = lambda args, exception: 0 + if pre_hook: + self.pre_hook = pre_hook + elif (len(self.reset_idx) > 0 or len(self.restore_idx) > 0): + + def _pre_hook(args, reset_only=False): + for i in self.reset_idx: + args[i].zero_() + if not reset_only: + self.restore_copies = [args[i].clone() for i in self.restore_idx] + + self.pre_hook = _pre_hook + + if post_hook: + self.post_hook = post_hook + elif len(self.restore_idx) > 0: + + def _post_hook(args, exception): + for i, j in enumerate(self.restore_idx): + args[j].copy_(self.restore_copies[i]) + self.restore_copies = [] + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune) + + self.fn = fn + self.base_fn = fn + while not inspect.isfunction(self.base_fn): + self.base_fn = self.base_fn.fn + self.num_warmups = warmup + self.num_reps = rep + # import torch + self.use_cuda_graph = False # use_cuda_graph and torch.cuda.is_available() + + def _bench(self, *args, config, **meta): + from ..compiler.errors import CompileTimeAssertionFailure + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(args) + try: + self.fn.run( + *args, + **current, + ) + except Exception as e: + try: + self.post_hook(args, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + + self.post_hook(args, exception=None) + + try: + if self.use_cuda_graph: + import torch + with torch.cuda.stream(torch.cuda.Stream()): + bench_res = do_bench_cudagraph(kernel_call, rep=self.num_reps, return_mode="median") + return bench_res + return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8)) + except (OutOfResources, CompileTimeAssertionFailure): + return float("inf") if self.use_cuda_graph else [float("inf"), float("inf"), float("inf")] + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if self.no_configs and self.generate_configs is not None: + self.configs = block_size_candidates(self.nargs, self.generate_configs, self.op_affiliation, self.row_sign, + self.col_sign, self.n_elem_sign) + used_cached_result = True + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = [] + for name in self.arg_names: + if name in all_args: + _args.append(all_args[name]) + key = [_args[i] for i in self.key_idx] + for arg in _args: + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + # prune configs + used_cached_result = False + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.pre_hook(args, reset_only=True) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: + print(f"Triton autotuning for function {self.base_fn.__name__} finished after " + f"{self.bench_time:.2f}s; best config selected: {self.best_config};") + if config.pre_hook is not None: + config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()}) + ret = self.fn.run( + *args, + **kwargs, + **config.all_kwargs(), + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.all_kwargs(), + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for config in self.prune_configs(kwargs): + ret.append(self.fn.warmup( + *args, + **kwargs, + **config.all_kwargs(), + )) + self.nargs = None + return ret + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type kwargs: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_ctas: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :type maxnreg: Optional[int] + :ivar maxnreg: maximum number of registers one thread can use. Corresponds + to ptx .maxnreg directive. Not supported on all platforms. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + """ + + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, maxnreg=None, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.maxnreg = maxnreg + self.pre_hook = pre_hook + + def all_kwargs(self): + return { + **self.kwargs, **{ + k: v + for (k, v) in ( + ("num_warps", self.num_warps), + ("num_ctas", self.num_ctas), + ("num_stages", self.num_stages), + ("maxnreg", self.maxnreg), + ) if v is not None + } + } + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + res.append(f"maxnreg: {self.maxnreg}") + return ", ".join(res) + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, + warmup=25, rep=100, use_cuda_graph=False, generate_configs=None, op_affiliation="sdnn", row_sign=None, + col_sign=None, n_elem_sign=None): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + + If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to + :code:`"1"`, Triton will print a message to stdout after autotuning each + kernel, including the time spent autotuning and the best configuration. + + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param pre_hook: a function that will be called before the kernel is called. + This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. + 'args': a list of arguments passed to the kernel. + 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. + :type pre_hook: lambda args, reset_only + :param post_hook: a function that will be called after the kernel is called. + This overrides the default post_hook used for 'restore_value'. + 'args': a list of arguments passed to the kernel. + 'exception': the exception raised by the kernel in case of a compilation or runtime error. + :type post_hook: lambda args, exception + :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25. + :type warmup: int + :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100. + :type rep: int + """ + + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph, generate_configs=generate_configs, + op_affiliation=op_affiliation, row_sign=row_sign, col_sign=col_sign, n_elem_sign=n_elem_sign) + + return decorator + + +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + :type values: dict[str, Callable[[list[Any]], Any]] + """ + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator + + +def largest_factor(x: int): + ret = 1 + for i in range(x - 1, 1, -1): + if x % i == 0: + ret = i + break + return ret + + +def cdiv(x: int, y: int): + return (x + y - 1) // y + + +def floordiv(x: int, y: int): + return x // y + + +def aligned(x: int, y: int): + return cdiv(x, y) * y + + +def next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n + + +def find_next_multiple_of_12(n): + """Return the next multiple of 12 greater than n""" + if n <= 0: + return 12 + + remainder = n % 12 + + if remainder == 0: + return n + else: + return n + (12 - remainder) + + +def append_candidate(candicates: list, target_candicate: Config): + found = False + for item in candicates: + if item.all_kwargs() == target_candicate.all_kwargs(): + found = True + break + if not found: + candicates.append(target_candicate) + return + + +def check_out_of_mem(block_size_m, block_size_n, block_size_k, mem, ele_bytes, bias, buffer_num, a_trans, b_trans): + am_layout = block_size_m + ak_layout = block_size_k + bn_layout = block_size_n + bk_layout = block_size_k + + if a_trans and b_trans: + am_layout, ak_layout = block_size_k, block_size_m + bn_layout, bk_layout = block_size_k, block_size_n + elif a_trans: + am_layout, ak_layout = block_size_k, block_size_m + elif b_trans: + bn_layout, bk_layout = block_size_k, block_size_n + + return ((aligned(ak_layout * ele_bytes, mem[1]) * am_layout + aligned(bn_layout * ele_bytes, mem[1]) * bk_layout) + > (mem[0] - aligned(block_size_n * ele_bytes, mem[1]) * (block_size_m + bias * + (2 + block_size_m))) // buffer_num) + + +def add_candidate_for_workload_not_balanced(configs: list, block_size_m, block_size_n, block_size_k, buffer_num, + meta_info): + input_size = meta_info['input_size'] + mem = meta_info['mem'] + ele_bytes = meta_info['ele_bytes'] + bias = meta_info['bias'] + block_names = meta_info['block_names'] + grid_aligned = meta_info['grid_aligned'] + aligned_size = meta_info['aligned_size'] + a_trans = meta_info['a_trans'] + b_trans = meta_info['b_trans'] + + grid_m_aligned = cdiv(input_size[0], block_size_m) + grid_n_aligned = cdiv(input_size[1], block_size_n) + + top_p = 3 + + while check_out_of_mem(block_size_m, block_size_n, block_size_k, mem, ele_bytes, bias, buffer_num, a_trans, + b_trans): + if block_size_k % 2 == 0: + block_size_k = block_size_k // 2 + else: + block_size_k = largest_factor(block_size_k) + + if block_size_k == 1: + break + + if (grid_m_aligned * grid_n_aligned) < grid_aligned: + block_size_m = max(2, min(block_size_m, input_size[0])) + block_size_n = max(2, min(block_size_n, input_size[1])) + + tmp_grid_m = cdiv(input_size[0], block_size_m) + tmp_grid_n = cdiv(input_size[1], block_size_n) + + append_candidate( + configs, Config({block_names[0]: block_size_m, block_names[1]: block_size_n, block_names[2]: block_size_k})) + for i in range(2, 13): + tmp_block_size_m = block_size_m // i + if tmp_block_size_m < 2: + break + tmp_grid_m = cdiv(input_size[0], tmp_block_size_m) + + if (tmp_grid_m * tmp_grid_n) % grid_aligned == 0: + append_candidate( + configs, + Config( + {block_names[0]: tmp_block_size_m, block_names[1]: block_size_n, block_names[2]: block_size_k})) + + for i in range(2, 13): + tmp_block_size_n = block_size_n // i + if tmp_block_size_n < 2: + break + tmp_grid_n = cdiv(input_size[1], tmp_block_size_n) + + if (tmp_grid_m * tmp_grid_n) % grid_aligned == 0: + append_candidate( + configs, + Config( + {block_names[0]: block_size_m, block_names[1]: tmp_block_size_n, block_names[2]: block_size_k})) + else: + append_candidate( + configs, Config({block_names[0]: block_size_m, block_names[1]: block_size_n, block_names[2]: block_size_k})) + + if input_size[0] % block_size_m != 0: + for i in range(block_size_m, 1, -1): + _block_size_m = i + if (cdiv(input_size[0], _block_size_m) * grid_n_aligned) % grid_aligned == 0: + append_candidate( + configs, + Config( + {block_names[0]: _block_size_m, block_names[1]: block_size_n, block_names[2]: + block_size_k})) + break + + elif input_size[1] % block_size_n != 0: + for i in range(block_size_n, 1, -1): + _block_size_n = i + if (cdiv(input_size[1], _block_size_n) * grid_m_aligned) % grid_aligned == 0: + append_candidate( + configs, + Config( + {block_names[0]: block_size_m, block_names[1]: _block_size_n, block_names[2]: + block_size_k})) + break + else: + for i in range(block_size_m, (grid_m_aligned - 1) * aligned_size["m_aligned"] + 1, -1): + _block_size_m = i + for j in range(block_size_n, (grid_n_aligned - 1) * aligned_size["n_aligned"] + 1, -1): + _block_size_n = j + tmp_grid_m = cdiv(input_size[0], _block_size_m) + tmp_grid_n = cdiv(input_size[1], _block_size_n) + + if (tmp_grid_m * tmp_grid_n) % grid_aligned == 0: + top_p -= 1 + append_candidate( + configs, + Config({ + block_names[0]: _block_size_m, block_names[1]: _block_size_n, block_names[2]: + block_size_k + })) + break + + if top_p == 0: + break + + return + + +def add_candidate_for_workload_balanced(configs: list, block_size_m, block_size_n, block_size_k, buffer_num, meta_info): + input_size = meta_info['input_size'] + mem = meta_info['mem'] + ele_bytes = meta_info['ele_bytes'] + bias = meta_info['bias'] + block_names = meta_info['block_names'] + aligned_size = meta_info['aligned_size'] + a_trans = meta_info['a_trans'] + b_trans = meta_info['b_trans'] + + grid_m_aligned = cdiv(input_size[0], block_size_m) + grid_n_aligned = cdiv(input_size[1], block_size_n) + + while check_out_of_mem(block_size_m, block_size_n, block_size_k, mem, ele_bytes, bias, buffer_num, a_trans, + b_trans): + if block_size_k % 2 == 0: + block_size_k = block_size_k // 2 + else: + block_size_k = largest_factor(block_size_k) + + if block_size_k == 1: + break + + append_candidate(configs, + Config({block_names[0]: block_size_m, block_names[1]: block_size_n, block_names[2]: block_size_k})) + + if input_size[0] % grid_m_aligned == 0 and input_size[1] % grid_n_aligned == 0: + block_size_m = max(2, floordiv(input_size[0], grid_m_aligned)) + block_size_n = max(2, floordiv(input_size[1], grid_n_aligned)) + elif input_size[0] % grid_m_aligned == 0: + block_size_m = max(2, floordiv(input_size[0], grid_m_aligned)) + elif input_size[1] % grid_n_aligned == 0: + block_size_n = max(2, floordiv(input_size[1], grid_n_aligned)) + + append_candidate(configs, + Config({block_names[0]: block_size_m, block_names[1]: block_size_n, block_names[2]: block_size_k})) + return + + +def get_input_ele_bytes(args): + ele_bytes = 4 + + if "a_ptr" in args.keys(): + A = args["a_ptr"] + elif "inp" in args.keys(): + A = args["inp"] + else: + A = args["A"] + + if A.dtype.__str__() == "torch.float16": + ele_bytes = 2 + + return ele_bytes + + +def balance_grid(block_size_m, block_size_n, input_size): + grid_x = cdiv(input_size[0], block_size_m) + grid_y = cdiv(input_size[1], block_size_n) + + total_grid = grid_x * grid_y + + # simple balance method + next_multiple_of_12 = find_next_multiple_of_12(total_grid) + grid_y = cdiv(next_multiple_of_12, grid_x) + block_size_n = cdiv(input_size[1], grid_y) + + # todo: add more balance method + + return block_size_m, block_size_n + + +def block_size_candidates_cluster(args, generate_configs, op_affiliation, row_sign, col_sign, n_elem_sign): + # The result of block_size_candidates + configs = [] + + # 1D Tune + if "BLOCK_SIZE" in args.keys(): # TODO: add more 1d block_size str to match + if n_elem_sign == None: + raise RuntimeError("Failed to tune block size. Miss n_elem_sign") + n_elements = args[n_elem_sign] + + # max cluster + block_size = cdiv(n_elements, 12) + append_candidate(configs, Config({"BLOCK_SIZE": block_size})) + + # max cluster with power2 block_size + block_size = next_power_of_2(cdiv(n_elements, 12)) + append_candidate(configs, Config({"BLOCK_SIZE": block_size})) + + # Print the result of block_size_candidates + if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1": + print(f"row: {m}, col: {n}") + for config in configs: + print(f"config: {config}") + + return configs + + # 2D Tune + # collect all useful info + ele_bytes = get_input_ele_bytes(args) + + grid_aligned = 12 + + BLOCK_M = "BLOCK_M" # TODO: add more 2d block_size m/n str to match + BLOCK_N = "BLOCK_N" + + block_names = (BLOCK_M, BLOCK_N) + + if row_sign == None or col_sign == None: + raise RuntimeError("Failed to tune block_m/block_n size. Miss row_sign/col_sign") + + m = args[row_sign] + n = args[col_sign] + + input_size = ( + m, + n, + ) + + mem = (8192, 64) # 8K * 64 cores LM + + aligned_size = { + "m_aligned": 64, + "n_aligned": 64, + } + + meta_info = { + "ele_bytes": ele_bytes, + "grid_aligned": grid_aligned, + "block_names": block_names, + "input_size": input_size, + "aligned_size": aligned_size, + "mem": mem, + } + + core_num = 64 + buffer_size_upper = 512 # TODO: set to 2048 bytes + if "buffer_size" in args.keys(): + buffer_size_upper = args["buffer_size"] + + buffer_size_elem_cnt = cdiv(buffer_size_upper, ele_bytes) + + experimental_fine_tune = bool(os.getenv("TRITON_FINE_AUTOTUNE", False)) + + # Start To Tune + block_size_m = input_size[0] + block_size_n = input_size[1] + + if buffer_size_elem_cnt != next_power_of_2(buffer_size_elem_cnt): + raise RuntimeError("buffer_size should be power of two") + + # buffer can cache all input_data + if buffer_size_elem_cnt * core_num >= block_size_n: + # naive config + block_size_m = next_power_of_2(cdiv(input_size[0], 12)) + block_size_n = input_size[1] + append_candidate(configs, Config({block_names[0]: block_size_m, block_names[1]: block_size_n})) + + # balance config + block_size_m = cdiv(input_size[0], 12) + block_size_n = input_size[1] + append_candidate(configs, Config({block_names[0]: block_size_m, block_names[1]: block_size_n})) + + if experimental_fine_tune: + # naive config + block_size_m = next_power_of_2(cdiv(input_size[0], 12)) + block_size_n = next_power_of_2(input_size[1]) + append_candidate(configs, Config({block_names[0]: block_size_m, block_names[1]: block_size_n})) + + # naive config + block_size_m = cdiv(input_size[0], 12) + block_size_n = next_power_of_2(input_size[1]) + append_candidate(configs, Config({block_names[0]: block_size_m, block_names[1]: block_size_n})) + + return configs + + # buffer cannot cache all input_data + block_size_m = input_size[0] + block_size_n = input_size[1] + + # naive config + block_size_m = next_power_of_2(cdiv(input_size[0], 12)) + block_size_n = buffer_size_elem_cnt * core_num + append_candidate(configs, Config({block_names[0]: block_size_m, block_names[1]: block_size_n})) + + # TODO: add block_size_m auto tune + # only support logic: gridX = cdiv(M, BLOCK_M)、gridY = cdiv(N, BLOCK_N) + for block_size_n in range(buffer_size_elem_cnt * core_num, 0, -aligned_size["n_aligned"]): + if len(configs) == 5: + break + grid_x = cdiv(input_size[0], block_size_m) + grid_y = cdiv(input_size[1], block_size_n) + total_grid = grid_x * grid_y + if total_grid % grid_aligned != 0: + (block_size_m, block_size_n) = balance_grid(block_size_m, block_size_n, input_size) + append_candidate(configs, Config({block_names[0]: block_size_m, block_names[1]: block_size_n})) + else: + append_candidate(configs, Config({block_names[0]: block_size_m, block_names[1]: block_size_n})) + + # balance config + block_size_m = cdiv(input_size[0], 12) + block_size_n = buffer_size_elem_cnt * core_num + append_candidate(configs, Config({block_names[0]: block_size_m, block_names[1]: block_size_n})) + + # TODO: add block_size_m auto tune + # only support logic: gridX = cdiv(M, BLOCK_M)、gridY = cdiv(N, BLOCK_N) + for block_size_n in range(buffer_size_elem_cnt * core_num, 0, -aligned_size["n_aligned"]): + if len(configs) == 5: + break + grid_x = cdiv(input_size[0], block_size_m) + grid_y = cdiv(input_size[1], block_size_n) + total_grid = grid_x * grid_y + if total_grid % grid_aligned != 0: + (block_size_m, block_size_n) = balance_grid(block_size_m, block_size_n, input_size) + append_candidate(configs, Config({block_names[0]: block_size_m, block_names[1]: block_size_n})) + else: + append_candidate(configs, Config({block_names[0]: block_size_m, block_names[1]: block_size_n})) + + # Print the result of block_size_candidates + if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1": + print(f"row: {m}, col: {n}") + for config in configs: + print(f"config: {config}") + + return configs + + +def block_size_candidates(args, generate_configs, op_affiliation, row_sign, col_sign, n_elem_sign): + if op_affiliation == "cluster": + return block_size_candidates_cluster(args, generate_configs, op_affiliation, row_sign, col_sign, n_elem_sign) + + # Get compile time info + BLOCK_M = "BLOCK_M" + BLOCK_N = "BLOCK_N" + BLOCK_K = "BLOCK_K" + bias = 0 + + if generate_configs == "bmm": + BLOCK_M = "TILE_M" + BLOCK_N = "TILE_N" + BLOCK_K = "TILE_K" + elif generate_configs == "addmm": + BLOCK_M = "BLOCK_SIZE_M" + BLOCK_N = "BLOCK_SIZE_N" + BLOCK_K = "BLOCK_SIZE_K" + bias = 1 + + # Block names + block_names = (BLOCK_M, BLOCK_N, BLOCK_K) + + ele_bytes = 4 + if "a_ptr" in args.keys(): + A = args["a_ptr"] + else: + A = args["A"] + if A.dtype.__str__() == "torch.float16": + ele_bytes = 2 + + a_trans = False + b_trans = False + + if "stride_ak" in args.keys(): + a_trans = args["stride_ak"] != 1 + if "stride_bn" in args.keys(): + b_trans = args["stride_bn"] != 1 + + # Input size info + input_size = (args["M"], args["N"], args["K"]) + + mem = (1605632, 128) + aligned_size = { + "m_aligned": 80, + "n_aligned": 64, + "k_aligned": 128, + } + + grid_aligned = 12 + + meta_info = { + "ele_bytes": ele_bytes, + "bias": bias, + "grid_aligned": grid_aligned, + "block_names": block_names, + "input_size": input_size, + "aligned_size": aligned_size, + "mem": mem, + "a_trans": a_trans, + "b_trans": b_trans, + } + + max_m_aglined = 4 + max_n_aglined = 7 + + # The result of block_size_candidates + configs = [] + + buffer_nums = [2] + for buffer_num in buffer_nums: + block_size_m = input_size[0] + block_size_n = input_size[1] + block_size_k = input_size[2] + + if block_size_m < 2: + block_size_m = 2 + + if block_size_n < 2: + block_size_n = 2 + n_loop_num = 2 + for i in range(min(max_m_aglined, cdiv(input_size[0], aligned_size["m_aligned"])), 0, -1): + if n_loop_num == 0: + break + n_loop_num -= 1 + for j in range(min(max_n_aglined, cdiv(input_size[1], aligned_size["n_aligned"])), 0, -1): + tmp_block_size_m = i * aligned_size["m_aligned"] + tmp_block_size_n = j * aligned_size["n_aligned"] + grid_m_aligned = cdiv(input_size[0], tmp_block_size_m) + grid_n_aligned = cdiv(input_size[1], tmp_block_size_n) + total_grid = grid_m_aligned * grid_n_aligned + if total_grid % grid_aligned != 0: + add_candidate_for_workload_not_balanced(configs, tmp_block_size_m, tmp_block_size_n, block_size_k, + buffer_num, meta_info) + else: + add_candidate_for_workload_balanced(configs, tmp_block_size_m, tmp_block_size_n, block_size_k, + buffer_num, meta_info) + + # Print the result of block_size_candidates + if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1": + print(f"M: {input_size[0]}, N: {input_size[1]}, K: {input_size[2]}") + for config in configs: + print(f"config: {config}") + + return configs diff --git a/third_party/xpu/python/triton/runtime/build.py b/third_party/xpu/python/triton/runtime/build.py new file mode 100644 index 000000000..3be8aab7a --- /dev/null +++ b/third_party/xpu/python/triton/runtime/build.py @@ -0,0 +1,82 @@ +import contextlib +import sys +import io +import sysconfig +import os +import shutil +import subprocess +import setuptools + + +@contextlib.contextmanager +def quiet(): + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = io.StringIO(), io.StringIO() + try: + yield + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + + +def _build(name, src, srcdir, library_dirs, include_dirs, libraries): + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + # TODO: support more things here. + clang = shutil.which("clang") + gcc = shutil.which("g++") # XRE g++ For Triton XPU + if os.environ.get('TRITON_CLOSE_XPU_BACKEND', False): + gcc = shutil.which("gcc") # cudart need gcc + cc = gcc if gcc is not None else clang + if cc is None: + raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + include_dirs = include_dirs + [srcdir, py_include_dir] + cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so] + cc_cmd += [f'-l{lib}' for lib in libraries] + cc_cmd += [f"-L{dir}" for dir in library_dirs] + cc_cmd += [f"-I{dir}" for dir in include_dirs] + cc_cmd += [f"-Wl,-rpath,{library_dirs[1]}"] + # print(f"cc_cmd = {cc_cmd}") + ret = subprocess.check_call(cc_cmd) + if ret == 0: + return so + # fallback on setuptools + extra_compile_args = [] + # extra arguments + extra_link_args = [] + # create extension module + ext = setuptools.Extension( + name=name, + language='c', + sources=[src], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args + ['-O3'], + extra_link_args=extra_link_args, + library_dirs=library_dirs, + libraries=libraries, + ) + # build extension module + args = ['build_ext'] + args.append('--build-temp=' + srcdir) + args.append('--build-lib=' + srcdir) + args.append('-q') + args = dict( + name=name, + ext_modules=[ext], + script_args=args, + ) + with quiet(): + setuptools.setup(**args) + return so diff --git a/third_party/xpu/python/triton/runtime/cache.py b/third_party/xpu/python/triton/runtime/cache.py new file mode 100644 index 000000000..bd3c29b99 --- /dev/null +++ b/third_party/xpu/python/triton/runtime/cache.py @@ -0,0 +1,281 @@ +import importlib +import json +import os +import uuid +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Dict, List, Optional +import hashlib + + +def default_cache_dir(): + return os.path.join(Path.home(), ".triton", "cache") + + +def default_override_dir(): + return os.path.join(Path.home(), ".triton", "override") + + +def default_dump_dir(): + return os.path.join(Path.home(), ".triton", "dump") + + +class CacheManager(ABC): + + def __init__(self, key): + pass + + @abstractmethod + def get_file(self, filename) -> Optional[str]: + pass + + @abstractmethod + def put(self, data, filename, binary=True) -> str: + pass + + @abstractmethod + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + pass + + @abstractmethod + def put_group(self, filename: str, group: Dict[str, str]): + pass + + +class FileCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = default_dump_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") + + def _make_path(self, filename) -> str: + return os.path.join(self.cache_dir, filename) + + def has_file(self, filename) -> bool: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + return os.path.exists(self._make_path(filename)) + + def get_file(self, filename) -> Optional[str]: + if self.has_file(filename): + return self._make_path(filename) + else: + return None + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + grp_filename = f"__grp__{filename}" + if not self.has_file(grp_filename): + return None + grp_filepath = self._make_path(grp_filename) + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + # Invalid group data. + if child_paths is None: + return None + result = {} + for c, p in child_paths.items(): + if os.path.exists(p): + result[c] = p + return result + + # Note a group of pushed files as being part of a group + def put_group(self, filename: str, group: Dict[str, str]) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + grp_contents = json.dumps({"child_paths": group}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename, binary=False) + + def put(self, data, filename, binary=True) -> str: + if not self.cache_dir: + raise RuntimeError("Could not create or locate cache dir") + binary = isinstance(data, bytes) + if not binary: + data = str(data) + assert self.lock_path is not None + filepath = self._make_path(filename) + # Random ID to avoid any collisions + rnd_id = str(uuid.uuid4()) + # we use the PID in case a bunch of these around so we can see what PID made it + pid = os.getpid() + # use tempfile to be robust against program interruptions + temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}" + mode = "wb" if binary else "w" + with open(temp_path, mode) as f: + f.write(data) + # Replace is guaranteed to be atomic on POSIX systems if it succeeds + # so filepath cannot see a partial write + os.replace(temp_path, filepath) + return filepath + + +class RemoteCacheBackend: + """ + A backend implementation for accessing a remote/distributed cache. + """ + + def __init__(self, key: str): + pass + + @abstractmethod + def get(self, filenames: List[str]) -> Dict[str, bytes]: + pass + + @abstractmethod + def put(self, filename: str, data: bytes): + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend): + + def __init__(self, key): + import redis + self._key = key + self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}") + self._redis = redis.Redis( + host=os.environ.get("TRITON_REDIS_HOST", "localhost"), + port=int(os.environ.get("TRITON_REDIS_PORT", 6379)), + ) + + def _get_key(self, filename: str) -> str: + return self._key_fmt.format(key=self._key, filename=filename) + + def get(self, filenames: List[str]) -> Dict[str, str]: + results = self._redis.mget([self._get_key(f) for f in filenames]) + return {filename: result for filename, result in zip(filenames, results) if result is not None} + + def put(self, filename: str, data: bytes) -> Dict[str, bytes]: + self._redis.set(self._get_key(filename), data) + + +class RemoteCacheManager(CacheManager): + + def __init__(self, key, override=False, dump=False): + # Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`. + remote_cache_manager = os.environ["TRITON_REMOTE_CACHE_BACKEND"] + module_path, clz_nme = remote_cache_manager.split(":") + module = importlib.import_module(module_path) + remote_cache_cls = getattr(module, clz_nme) + self._backend = remote_cache_cls(key) + + self._override = override + self._dump = dump + + # Use a `FileCacheManager` to materialize remote cache paths locally. + self._file_cache_manager = FileCacheManager(key, override=override, dump=dump) + + def _materialize(self, filename: str, data: bytes): + # We use a backing `FileCacheManager` to provide the materialized data. + return self._file_cache_manager.put(data, filename, binary=True) + + def get_file(self, filename: str) -> Optional[str]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_file(filename) + + # We always check the remote cache backend -- even if our internal file- + # based cache has the item -- to make sure LRU accounting works as + # expected. + results = self._backend.get([filename]) + if len(results) == 0: + return None + (_, data), = results.items() + return self._materialize(filename, data) + + def put(self, data, filename: str, binary=True) -> str: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put(data, filename, binary=binary) + + if not isinstance(data, bytes): + data = str(data).encode("utf-8") + self._backend.put(filename, data) + return self._materialize(filename, data) + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.get_group(filename) + + grp_filename = f"__grp__{filename}" + grp_filepath = self.get_file(grp_filename) + if grp_filepath is None: + return None + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + + result = None + + # Found group data. + if child_paths is not None: + result = {} + for child_path, data in self._backend.get(child_paths).items(): + result[child_path] = self._materialize(child_path, data) + + return result + + def put_group(self, filename: str, group: Dict[str, str]): + # We don't handle the dump/override cases. + if self._dump or self._override: + return self._file_cache_manager.put_group(filename, group) + + grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename) + + +__cache_cls = FileCacheManager +__cache_cls_nme = "DEFAULT" + + +def get_cache_manager(key) -> CacheManager: + import os + + user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None) + global __cache_cls + global __cache_cls_nme + + if user_cache_manager is not None and user_cache_manager != __cache_cls_nme: + module_path, clz_nme = user_cache_manager.split(":") + module = importlib.import_module(module_path) + __cache_cls = getattr(module, clz_nme) + __cache_cls_nme = user_cache_manager + + return __cache_cls(key) + + +def get_override_manager(key) -> CacheManager: + return __cache_cls(key, override=True) + + +def get_dump_manager(key) -> CacheManager: + return __cache_cls(key, dump=True) + + +def make_so_cache_key(version_hash, signature, constants, ids, **kwargs): + # Get unique key for the compiled code + signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} + key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}" + for kw in kwargs: + key = f"{key}-{kwargs.get(kw)}" + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + return key diff --git a/third_party/xpu/python/triton/runtime/driver.py b/third_party/xpu/python/triton/runtime/driver.py new file mode 100644 index 000000000..c3b97a764 --- /dev/null +++ b/third_party/xpu/python/triton/runtime/driver.py @@ -0,0 +1,60 @@ +from ..backends import backends +from ..backends import DriverBase + + +def _create_driver(): + actives = [x.driver for x in backends.values() if x.driver.is_active()] + if len(actives) != 1: + raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.") + return actives[0]() + + +class LazyProxy: + + def __init__(self, init_fn): + self._init_fn = init_fn + self._obj = None + + def _initialize_obj(self): + if self._obj is None: + self._obj = self._init_fn() + + def __getattr__(self, name): + self._initialize_obj() + return getattr(self._obj, name) + + def __setattr__(self, name, value): + if name in ["_init_fn", "_obj"]: + super().__setattr__(name, value) + else: + self._initialize_obj() + setattr(self._obj, name, value) + + def __delattr__(self, name): + self._initialize_obj() + delattr(self._obj, name) + + def __repr__(self): + if self._obj is None: + return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>" + return repr(self._obj) + + def __str__(self): + self._initialize_obj() + return str(self._obj) + + +class DriverConfig: + + def __init__(self): + self.default = LazyProxy(_create_driver) + self.active = self.default + + def set_active(self, driver: DriverBase): + self.active = driver + + def reset_active(self): + self.active = self.default + + +driver = DriverConfig() diff --git a/third_party/xpu/python/triton/runtime/errors.py b/third_party/xpu/python/triton/runtime/errors.py new file mode 100644 index 000000000..4dce91767 --- /dev/null +++ b/third_party/xpu/python/triton/runtime/errors.py @@ -0,0 +1,26 @@ +from ..errors import TritonError +from typing import Optional + + +class InterpreterError(TritonError): + + def __init__(self, error_message: Optional[str] = None): + self.error_message = error_message + + def __str__(self) -> str: + return self.error_message or "" + + +class OutOfResources(TritonError): + + def __init__(self, required, limit, name): + self.required = required + self.limit = limit + self.name = name + + def __str__(self) -> str: + return f"out of resource: {self.name}, Required: {self.required}, Hardware limit: {self.limit}. Reducing block sizes or `num_stages` may help." + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return (type(self), (self.required, self.limit, self.name)) diff --git a/third_party/xpu/python/triton/runtime/interpreter.py b/third_party/xpu/python/triton/runtime/interpreter.py new file mode 100644 index 000000000..a82832ecf --- /dev/null +++ b/third_party/xpu/python/triton/runtime/interpreter.py @@ -0,0 +1,1127 @@ +import inspect +from typing import Tuple + +import math +import numpy as np + +import triton +import triton.language as tl +from dataclasses import dataclass +from .errors import InterpreterError +from functools import partial +from .._C.libtriton import interpreter as _interpreter +from .._C.libtriton import ir as _ir + + +class TensorHandle: + + def __init__(self, data, dtype): + ''' + data: numpy array + dtype: triton type, either pointer_type or scalar_type. + we don't store block_type here because the shape information is already availale in the data field + attr: a dictionary of attributes + ''' + self.data = data + self.dtype = dtype + self.attr = {} + + def __bool__(self): + return bool(self.data.all()) + + def get_element_ty(self): + dtype = self.dtype + while hasattr(dtype, "element_ty"): + dtype = dtype.element_ty + return dtype + + def clone(self): + return TensorHandle(self.data.copy(), self.dtype) + + def set_attr(self, key, value): + self.attr[key] = value + + +class BlockPointerHandle: + + def __init__(self, base, shape, strides, offsets, tensor_shape, order): + self.base = base + self.shape = shape + self.strides = strides + self.offsets = offsets + self.tensor_shape = tensor_shape + self.order = order + + def materialize_pointers(self, boundary_check): + dtype_tt = self.base.get_element_ty() + n_bytes = dtype_tt.primitive_bitwidth // 8 + tensor_shape = self.tensor_shape + ptrs = np.broadcast_to(self.base.data, self.tensor_shape) + masks = np.ones(self.tensor_shape, dtype=bool) + for dim in range(len(tensor_shape)): + bcast_dims = [1] * len(tensor_shape) + bcast_dims[dim] = tensor_shape[dim] + off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims) + ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64) + if dim in boundary_check: + masks = np.logical_and(masks, off < self.shape[dim].data) + ptrs = TensorHandle(ptrs, self.base.dtype.scalar) + return ptrs, masks + + +@dataclass(frozen=True) +class InterpreterOptions: + extern_libs: dict = None + debug: bool = False + arch: str = None + allow_fp8e4nv: bool = True + allow_fp8e4b15: bool = True + default_dot_input_precision: str = "tf32" + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + max_num_imprecise_acc_default: int = 0 + + +def _get_signed_np_dtype(dtype): + if dtype == np.uint8: + return np.int8 + if dtype == np.uint16: + return np.int16 + if dtype == np.uint32: + return np.int32 + if dtype == np.uint64: + return np.int64 + return dtype + + +def _get_np_dtype(tt_dtype): + if isinstance(tt_dtype, tl.pointer_type): + return np.dtype(np.uint64) + np_types = { + tl.int1: np.dtype(bool), + tl.float16: np.dtype(np.float16), + tl.float32: np.dtype(np.float32), + tl.float64: np.dtype(np.float64), + tl.int8: np.dtype(np.int8), + tl.uint8: np.dtype(np.uint8), + tl.int16: np.dtype(np.int16), + tl.uint16: np.dtype(np.uint16), + tl.int32: np.dtype(np.int32), + tl.uint32: np.dtype(np.uint32), + tl.int64: np.dtype(np.int64), + tl.uint64: np.dtype(np.uint64), + # bfloat16 types are stored as uint16 + tl.bfloat16: np.dtype(np.uint16), + # float8 types are stored as uint8 + tl.float8e5: np.dtype(np.uint8), + tl.float8e5b16: np.dtype(np.uint8), + tl.float8e4nv: np.dtype(np.uint8), + tl.float8e4b8: np.dtype(np.uint8), + tl.float8e4b15: np.dtype(np.uint8), + } + if isinstance(tt_dtype, tl.block_type): + if isinstance(tt_dtype.element_ty, tl.pointer_type): + return np.dtype(np.uint64) + return np_types[tt_dtype.element_ty] + return np_types[tt_dtype] + + +def _convert_float(input, input_dtype, output_dtype, rounding_mode): + input_uint_dtype = getattr(np, f"uint{input_dtype.primitive_bitwidth}") + output_unint_dtype = getattr(np, f"uint{output_dtype.primitive_bitwidth}") + input_bin = np.frombuffer(input.tobytes(), dtype=input_uint_dtype) + sign = (input_bin >> (input_dtype.primitive_bitwidth - 1)) & 0x01 + input_exponent_width = input_dtype.primitive_bitwidth - input_dtype.fp_mantissa_width - 1 + output_exponent_width = output_dtype.primitive_bitwidth - output_dtype.fp_mantissa_width - 1 + significand = input_bin & ((1 << input_dtype.fp_mantissa_width) - 1) + bias_input = input_dtype.exponent_bias + bias_output = output_dtype.exponent_bias + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + subnormal_index = exponent == 0 + if np.any(subnormal_index): + # Credit to Phil: phil@openai.com + # subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # convert it to normal repr: ((-1.0)**sign) * (2.0**(1 + m0 - exp_bias)) * (1 + 2^(m1 - m0) + ... + 2^(mn - m0)) + bit_pos = np.zeros_like(input_bin, dtype=np.int32) + # Find the most significant bit of the mantissa in the significand + for i in range(input_dtype.fp_mantissa_width): + bit_index = ((significand >> i) & 0x01) + # pos should be >= 1 + bit_pos[bit_index == 1] = input_dtype.fp_mantissa_width - i + zero_significand_index = significand == 0 + exponent[subnormal_index] = 1 - bit_pos[subnormal_index] + # 0 significand and subnormal should be treated as 0 + exponent[zero_significand_index & subnormal_index] = bias_input - bias_output + significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & ( + (1 << input_dtype.fp_mantissa_width) - 1) + # Prevent overflow and underflow + exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1)) + exponent_output = exponent_output.astype(output_unint_dtype) + sign_output = sign.astype(output_unint_dtype) + if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth: # Downcast + significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + if rounding_mode == _ir.ROUNDING_MODE.RTNE: # Round to nearst even + # find the cut-off bit + cut_off = significand & (1 << (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width - 1)) + significand_output = significand_output + (cut_off > 0) + significand_output = significand_output.astype(output_unint_dtype) + else: # Upcast + significand_output = (significand.astype(output_unint_dtype) << + (output_dtype.fp_mantissa_width - input_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + subnormal_index = exponent_output == 0 + if np.any(subnormal_index): # underflow + # normal repr: ((-1.0)**sign) * (2.0**(exp - exp_bias_input)) * (1 + 2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # shift = (1 - exp_bias_output) - (exp - exp_bias_input) + # convert it to subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias_output)) * (2^(-shift) + 2^(m0 - shift) + 2^(m1 - shift) + ... + 2^(mn - shift)) + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + non_zero_exponent_index = exponent != 0 + # If the original exponent is not zero, we still need to shift the significand and consider the 1.0 part in mantissa + subnormal_index = subnormal_index & non_zero_exponent_index + shift = np.zeros_like(input_bin, dtype=np.int32) + shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input) + significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | ( + 1 << (output_dtype.fp_mantissa_width - shift[subnormal_index])) + output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | ( + exponent_output << output_dtype.fp_mantissa_width) | significand_output + return output.reshape(input.shape) + + +def _erf(x): + # Numpy does not support erf + return math.erf(x) + + +def _umulhi_64(a, b): + # Numpy does not support 128-bit multiplication + # So we have to implement it manually + return (int(a) * int(b)) >> 64 + + +np_erf_fp32 = np.vectorize(_erf, otypes=[np.float32]) +np_erf_fp64 = np.vectorize(_erf, otypes=[np.float64]) +np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64]) + + +class ExtraFunctions: + + @staticmethod + def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _builder): + return tl.tensor(_builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty) + + +class InterpreterBuilder: + ir_sem_to_interpreter_sem = { + _ir.MEM_SEMANTIC.ACQUIRE: _interpreter.MEM_SEMANTIC.ACQUIRE, + _ir.MEM_SEMANTIC.RELEASE: _interpreter.MEM_SEMANTIC.RELEASE, + _ir.MEM_SEMANTIC.RELAXED: _interpreter.MEM_SEMANTIC.RELAXED, + _ir.MEM_SEMANTIC.ACQUIRE_RELEASE: _interpreter.MEM_SEMANTIC.ACQUIRE_RELEASE, + } + + ir_rmw_op_to_interpreter_rmw_op = { + _ir.ATOMIC_OP.ADD: _interpreter.RMW_OP.ADD, + _ir.ATOMIC_OP.FADD: _interpreter.RMW_OP.FADD, + _ir.ATOMIC_OP.MIN: _interpreter.RMW_OP.MIN, + _ir.ATOMIC_OP.UMIN: _interpreter.RMW_OP.UMIN, + _ir.ATOMIC_OP.MAX: _interpreter.RMW_OP.MAX, + _ir.ATOMIC_OP.UMAX: _interpreter.RMW_OP.UMAX, + _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND, + _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR, + _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR, + _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG, + } + + def __init__(self) -> None: + self.arch = None + self.options = InterpreterOptions() + self.codegen_fns = {} + self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types + + def set_grid_idx(self, x, y, z): + if not x < self.grid_dim[0]: + raise ValueError("x >= grid_dim[0]") + if not y < self.grid_dim[1]: + raise ValueError("y >= grid_dim[1]") + if not z < self.grid_dim[2]: + raise ValueError("z >= grid_dim[2]") + self.grid_idx = (x, y, z) + + def set_grid_dim(self, nx, ny, nz): + self.grid_dim = (nx, ny, nz) + + # constants + + def get_half_ty(self): + return tl.float16 + + def get_bf16_ty(self): + return tl.bfloat16 + + def get_float_ty(self): + return tl.float32 + + def get_double_ty(self): + return tl.float64 + + def get_int8_ty(self): + return tl.int8 + + def get_uint8_ty(self): + return tl.uint8 + + def get_int16_ty(self): + return tl.int16 + + def get_uint16_ty(self): + return tl.uint16 + + def get_int32_ty(self): + return tl.int32 + + def get_uint32_ty(self): + return tl.uint32 + + def get_int64_ty(self): + return tl.int64 + + def get_uint64_ty(self): + return tl.uint64 + + def get_fp8e4nv_ty(self): + return tl.float8e4nv + + def get_fp8e4b15_ty(self): + return tl.float8e4b15 + + def get_fp8e4b8_ty(self): + return tl.float8e4b8 + + def get_fp8e5_ty(self): + return tl.float8e5 + + def get_fp8e5b16_ty(self): + return tl.float8e5b16 + + def get_ptr_ty(self, elt_ty, addr_space): + return tl.pointer_type(elt_ty, addr_space) + + def get_block_ty(self, dtype, shape): + return tl.block_type(dtype, shape) + + def get_int1(self, value): + return TensorHandle(np.array([value], dtype=np.bool_), tl.int1) + + def get_uint8(self, value): + return TensorHandle(np.array([value], dtype=np.uint8), tl.uint8) + + def get_int8(self, value): + return TensorHandle(np.array([value], dtype=np.int8), tl.int8) + + def get_uint16(self, value): + return TensorHandle(np.array([value], dtype=np.uint16), tl.uint16) + + def get_int16(self, value): + return TensorHandle(np.array([value], dtype=np.int16), tl.int16) + + def get_uint32(self, value): + return TensorHandle(np.array([value], dtype=np.uint32), tl.uint32) + + def get_int32(self, value): + return TensorHandle(np.array([value], dtype=np.int32), tl.int32) + + def get_uint64(self, value): + return TensorHandle(np.array([value], dtype=np.uint64), tl.uint64) + + def get_int64(self, value): + return TensorHandle(np.array([value], dtype=np.int64), tl.int64) + + def get_fp16(self, value): + return TensorHandle(np.array([value], dtype=np.float16), tl.float16) + + def get_fp32(self, value): + return TensorHandle(np.array([value], dtype=np.float32), tl.float32) + + def get_fp64(self, value): + return TensorHandle(np.array([value], dtype=np.float64), tl.float64) + + def get_null_value(self, type): + return TensorHandle(np.array([0], dtype=_get_np_dtype(type)), type) + + # programming model + def create_get_program_id(self, axis): + if self.grid_idx is None: + raise ValueError("grid_idx is None") + return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32) + + def create_get_num_programs(self, axis): + return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32) + + # memory ops + def create_load(self, ptr, _0, _1, is_volatile): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + other = None + return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile) + + def create_store(self, ptr, val, _0, _1): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + return self.create_masked_store(ptr, val, mask, None, None) + + def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if other is None: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np) + return TensorHandle(ret, dtype_tt) + + def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy): + return _interpreter.store(ptrs.data, value.data, mask.data) + + # casting ops + def cast_impl(self, src, dst_type): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + if (src_element_type == tl.bfloat16 and dst_element_type == tl.float32) or \ + (src_element_type == tl.float32 and dst_element_type == tl.bfloat16): + data = _convert_float(src.data, src_element_type, dst_element_type, None).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + else: + return TensorHandle(src.data.astype(_get_np_dtype(dst_type)), dst_type.scalar) + + create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type) + + def create_fp_to_fp(self, src, dst_type, rounding_mode): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + data = _convert_float(src.data, src_element_type, dst_element_type, rounding_mode).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + + def create_bitcast(self, src, dst_type): + return TensorHandle(src.data.view(_get_np_dtype(dst_type)), dst_type.scalar) + + # binary operators + def binary_op(self, lhs, rhs, op): + return TensorHandle(op(lhs.data, rhs.data), lhs.dtype.scalar) + + create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_sdiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + create_udiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. + create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift) + create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minimumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maximumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and) + create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor) + create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or) + + def create_idiv(self, lhs, rhs): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. + return TensorHandle((lhs.data - np.fmod(lhs.data, rhs.data)) // rhs.data, lhs.dtype.scalar) + + def create_ashr(self, lhs, rhs): + # Triton's rshift operator depends on the signedness of the left operand + lhs_dtype = _get_signed_np_dtype(lhs.data.dtype) + rhs_dtype = _get_signed_np_dtype(rhs.data.dtype) + lhs.data = lhs.data.astype(lhs_dtype) + rhs.data = rhs.data.astype(rhs_dtype) + return self.binary_op(lhs, rhs, np.right_shift) + + def create_umulhi(self, lhs, rhs): + dtype = lhs.data.dtype + if dtype == np.int64 or dtype == np.uint64: + return TensorHandle(np_umulhi_u64(lhs.data, rhs.data), lhs.dtype.scalar) + else: + compute_dtype = getattr(np, f"uint{dtype.itemsize * 8 * 2}") + lhs_data = lhs.data.astype(compute_dtype) + rhs_data = rhs.data.astype(compute_dtype) + ret_data = np.multiply(lhs_data, rhs_data) >> (dtype.itemsize * 8) + return TensorHandle(ret_data.astype(dtype), lhs.dtype.scalar) + + # ternary functions + def ternary_op(self, lhs, rhs, other, op): + return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype.scalar) + + create_clampf = lambda self, arg, lo, hi, propagate_nans: self.ternary_op(arg, lo, hi, np.clip) + create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where) + + def create_fma(self, x, y, z): + return TensorHandle(x.data * y.data + z.data, z.dtype.scalar) + + # unary functions + def unary_op(self, arg, op): + return TensorHandle(op(arg.data), arg.dtype.scalar) + + def create_fabs(self, arg): + # Mask out the sign bit based on the primitive length + dtype_tt = arg.dtype + mask_bitwidth = dtype_tt.primitive_bitwidth - 1 + np_uint_dtype = getattr(np, f"uint{dtype_tt.primitive_bitwidth}") + data = arg.data.view(np_uint_dtype) + mask = (1 << mask_bitwidth) - 1 + ret = (data & mask).view(_get_np_dtype(dtype_tt)) + return TensorHandle(ret, arg.dtype.scalar) + + create_cos = lambda self, arg: self.unary_op(arg, np.cos) + create_exp = lambda self, arg: self.unary_op(arg, np.exp) + create_exp2 = lambda self, arg: self.unary_op(arg, np.exp2) + create_iabs = lambda self, arg: self.unary_op(arg, np.abs) + create_floor = lambda self, arg: self.unary_op(arg, np.floor) + create_ceil = lambda self, arg: self.unary_op(arg, np.ceil) + create_log = lambda self, arg: self.unary_op(arg, np.log) + create_log2 = lambda self, arg: self.unary_op(arg, np.log2) + create_precise_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sin = lambda self, arg: self.unary_op(arg, np.sin) + + def create_erf(self, arg): + ret = np_erf_fp32(arg.data) if arg.data.dtype == np.float32 else np_erf_fp64(arg.data) + return TensorHandle(ret, arg.dtype.scalar) + + def create_rsqrt(self, arg): + return TensorHandle(1 / np.sqrt(arg.data), arg.dtype.scalar) + + # tensor operators + create_reshape = lambda self, arg, shape, allow_reorder: TensorHandle(arg.data.reshape(shape), arg.dtype.scalar) + + def create_trans(self, arg, perm): + return TensorHandle(np.transpose(arg.data, perm), arg.dtype.scalar) + + def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc): + a_data = a.data + b_data = b.data + if (a.dtype.primitive_bitwidth == 8 and a.dtype.is_floating()) or \ + (b.dtype.primitive_bitwidth == 8 and b.dtype.is_floating()): + a_data = _convert_float(a_data, a.dtype, tl.float16, None).view(np.float16) + b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16) + return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar) + + def create_make_range(self, start, stop): + return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32) + + def create_histogram(self, data, bins): + return TensorHandle(np.histogram(data.data, bins=bins, range=(0, bins))[0], tl.int32) + + # pointer arithmetic + + def create_addptr(self, ptr, offset): + dtype_tt = ptr.get_element_ty() + element_bitwidth = dtype_tt.primitive_bitwidth + # int1's bitwidth is 1, but we need to use 8 for pointer arithmetic + element_bytewidth = max(1, element_bitwidth // 8) + return TensorHandle(ptr.data + element_bytewidth * offset.data.astype(np.uint64), ptr.dtype) + + def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, + is_volatile): + ptrs, masks = ptr.materialize_pointers(boundary_check) + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if padding_option is None: + other = None + elif padding_option == _ir.PADDING_OPTION.PAD_ZERO: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + elif padding_option == _ir.PADDING_OPTION.PAD_NAN: + other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt) + else: + raise ValueError(f"unsupported padding option {padding_option}") + return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile) + + def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): + ptrs, masks = ptr.materialize_pointers(boundary_check) + return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy) + + def create_expand_dims(self, arg, axis): + return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype.scalar) + + def create_broadcast(self, arg, shape): + return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype.scalar) + + def create_int_to_ptr(self, val, dst_ty): + return TensorHandle(val.data.astype(np.uint64), dst_ty.scalar) + + def create_ptr_to_int(self, val, dst_ty): + return TensorHandle(val.data.astype(np.uint64), dst_ty.scalar) + + def create_cat(self, lhs, rhs): + return TensorHandle(np.concatenate([lhs.data, rhs.data]), lhs.dtype.scalar) + + def create_join(self, lhs, rhs): + # Triton only supports joining two original tensors into a new one along the last axis + return TensorHandle(np.stack([lhs.data, rhs.data], axis=-1), lhs.dtype.scalar) + + def create_split(self, val): + # Triton only supports splitting the original tensor into two along the last axis + return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar)) + + def create_splat(self, arg, shape): + if isinstance(arg.dtype, tl.block_type): + return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + else: # scalar + return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + + def create_atomic_cas(self, ptr, cmp, val, sem, scope): + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_cas(ptr.data, cmp.data, val.data, sem), cmp.dtype.scalar) + + def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem, scope): + if rmwOp not in self.ir_rmw_op_to_interpreter_rmw_op: + raise ValueError(f"unsupported rmwOp {rmwOp}") + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + rmwOp = self.ir_rmw_op_to_interpreter_rmw_op[rmwOp] + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar) + + def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): + raise NotImplementedError("extern_elementwise not supported in interpreter mode") + + def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): + raise NotImplementedError("inline_asm not supported in interpreter mode") + + def create_print(self, prefix, hex, values): + # Interpreter's device_print function has a different format than Triton's device_print + msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})" + if prefix: + msg += f" {prefix}" + if hex: + np.set_printoptions(formatter={'all': lambda x: f"0x{x:02x}"}) + for value in values: + print(msg + f" {value.data}") + if hex: + np.set_printoptions(formatter=None) + + def create_assert(self, condition, message, fileName, funcName, lineNo): + # Interpreter's device_assert function has a different format than Triton's device_assert + assert condition, f"{message} in {fileName}:{funcName}:{lineNo}" + + def create_barrier(self): + # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter + pass + + def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order): + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in offsets] + return BlockPointerHandle(base, shape, strides, new_offsets, tensor_shape, order) + + def create_advance(self, ptr, offsets): + if len(ptr.offsets) != len(offsets): + raise ValueError("len(ptr.offsets) != len(offsets)") + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in ptr.offsets] + ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order) + for i in range(len(offsets)): + ret.offsets[i].data += offsets[i].data + return ret + + def get_all_ones_value(self, type): + np_type = _get_np_dtype(type) + if "int" in np_type.name: + return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar) + else: + raise TypeError(f"unsupported type {type}") + + +def _patch_attr(obj, name, member, builder): + new_member = lambda *args, member=member, **kwargs: (member(*args, ** + {k: v + for k, v in kwargs.items() + if k != "_builder"}, _builder=builder)) + setattr(obj, name, new_member) + + +def _patch_builtin(pkg, builder): + for name, member in inspect.getmembers(pkg): + if tl.core.is_builtin(member): + _patch_attr(pkg, name, member, builder) + + +def _patch_lang_tensor(tensor): + + def _get_bool(self): + data = self.handle.data + # in triton, only scalars can be converted to booleans + # here we need this hack because all scalars are tensors + return bool(data) if data.size == 1 else True + + def _get_transpose(self): + return tl.core.tensor(TensorHandle(np.transpose(self.handle.data), self.handle.dtype), self.dtype.scalar) + + tensor.__index__ = lambda self: int(self.handle.data) + tensor.__bool__ = lambda self: _get_bool(self) + tensor.__repr__ = lambda self: repr(self.handle.data) + tensor.__str__ = lambda self: str(self.handle.data) + tensor.T = property(_get_transpose) + + +class ReduceScanOpIneterface: + + def __init__(self, axis, combine_fn): + self.axis = axis + self.combine_fn = combine_fn + + def check_axis(self, shape, axis): + if axis is not None and axis >= len(shape): + raise ValueError(f"axis {axis} out of bounds for shape {shape}") + + def check_tensor(self, input): + for arg in input: + if not isinstance(arg, tl.core.tensor): + raise ValueError(f"input must be a tensor, got {type(arg)}") + self.check_axis(arg.shape, self.axis) + + def to_tensor(self, ret, dtype): + if hasattr(ret, "shape") and ret.shape: + ret_type = tl.block_type(dtype, ret.shape) + else: + ret = np.array([ret], dtype=_get_np_dtype(dtype)) + ret_type = dtype + return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type) + + def apply(self, input): + if not isinstance(input, tuple): + input = (input, ) + self.check_tensor(input) + return self.apply_impl(input) + + def apply_impl(self, input): + raise NotImplementedError("apply_impl not implemented") + + +class ReduceOps(ReduceScanOpIneterface): + + def __init__(self, axis, combine_fn, keep_dims): + super().__init__(axis, combine_fn) + self.keep_dims = keep_dims + + def unravel(self, input, axis): + ret = [] + for data in input: + if axis is not None: + ret.append(data) + else: + axis = 0 + ret.append(self.to_tensor(data.handle.data.flatten(), data.dtype)) + return tuple(ret), axis + + def generic_reduce(self, input): + original_axis = self.axis + input, axis = self.unravel(input, self.axis) + input_data = [] + output_data = [] + input_shape = input[0].handle.data.shape + output_shape = input_shape[0:axis] + input_shape[axis + 1:] + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype)) + # Reduce on axis + for i in range(input_data[0].size): + # Recover input_index from i using input_shape + input_index = np.unravel_index(i, input_shape) + output_index = input_index[0:axis] + input_index[axis + 1:] + input_tuple = tuple(self.to_tensor(d[input_index], input[ii].dtype) for ii, d in enumerate(input_data)) + if input_index[axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][output_index] = input_tuple[j].handle.data.item() + else: + acc_tuple = tuple(self.to_tensor(o[output_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *input_tuple) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][output_index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + if self.keep_dims: + if original_axis is not None: + data = np.expand_dims(data, axis) + else: + for _ in range(len(input_shape)): + data = np.expand_dims(data, 0) + + elif original_axis is None: + # Take a scalar + data = data.item() + ret.append(self.to_tensor(data, input[i].dtype)) + return ret[0] if len(ret) == 1 else tuple(ret) + + def min_max(self, input, val_reduce_op, idx_reduce_op=None): + # If input is a tuple, it must be (val, index), and we only take val + input = input[0] if isinstance(input, tuple) else input + val = None + idx = None + if val_reduce_op: + val = self.to_tensor(val_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + if idx_reduce_op: + idx = self.to_tensor(idx_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), tl.int32) + if val is not None and idx is not None: + return val, idx + elif val is not None: + return val + elif idx is not None: + return idx + else: + raise ValueError("val_reduce_op and idx_reduce_op are both None") + + def sum(self, input): + return self.to_tensor(np.sum(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + + def apply_impl(self, input): + if self.combine_fn == tl.standard._argmin_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=np.argmin) + elif self.combine_fn == tl.standard._argmax_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax) + elif self.combine_fn == tl.standard._elementwise_max: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=None) + elif self.combine_fn == tl.standard._elementwise_min: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=None) + elif self.combine_fn == tl.standard._sum_combine: + return self.sum(input[0]) + else: + # Fall back to the slow mode + return self.generic_reduce(input) + + +class ScanOps(ReduceScanOpIneterface): + + def __init__(self, axis, combine_fn, reverse): + super().__init__(axis, combine_fn) + self.reverse = reverse + + def cumsum(self, input): + return [self.to_tensor(np.cumsum(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def cumprod(self, input): + return [self.to_tensor(np.cumprod(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def generic_scan(self, input): + input_data = [] + output_data = [] + shape = input[0].handle.data.shape + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(shape, dtype=arg.handle.data.dtype)) + # Scan on axis + for i in range(input_data[0].size): + # Recover index from i using shape + index = np.unravel_index(i, shape) + data = tuple(self.to_tensor(d[index], input[ii].dtype) for ii, d in enumerate(input_data)) + if index[self.axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][index] = data[j].handle.data.item() + else: + prev_index = tuple(index[i] - 1 if i == self.axis else index[i] for i in range(len(index))) + acc_tuple = tuple(self.to_tensor(o[prev_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *data) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + ret.append(self.to_tensor(data, input[i].dtype)) + return ret + + def apply_impl(self, input): + new_input = [] + if self.reverse: + for arg in input: + new_input.append(self.to_tensor(np.flip(arg.handle.data, axis=self.axis), arg.dtype)) + else: + new_input = input + if self.combine_fn == tl.standard._sum_combine: + ret = self.cumsum(new_input[0]) + elif self.combine_fn == tl.standard._prod_combine: + ret = self.cumprod(new_input[0]) + else: + # Fall back to the slow mode + ret = self.generic_scan(new_input) + if self.reverse: + for arg in ret: + arg.handle.data = np.flip(arg.handle.data, axis=self.axis) + return len(ret) == 1 and ret[0] or tuple(ret) + + +def _patch_reduce_scan(): + # Because interpreter doesn't support region_builder_fn, we cannot patch the builder + # to use the new reduce and scan functions. + # Instead, we need to patch reduce and reduce functions in tl and tl.core + def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs): + return ReduceOps(axis, combine_fn, keep_dims).apply(input) + + def _new_scan(input, axis, combine_fn, reverse=False, **kwargs): + return ScanOps(axis, combine_fn, reverse).apply(input) + + tl.reduce = _new_reduce + tl.associative_scan = _new_scan + tl.core.reduce = _new_reduce + tl.core.associative_scan = _new_scan + + +def _patch_lang_core(lang): + + def _new_to_ir(self, builder): + # We need to specify signedness for integer types in the numpy mode + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name == 'int8': + return builder.get_int8_ty() + elif self.name == 'uint8': + return builder.get_uint8_ty() + elif self.name == 'int16': + return builder.get_int16_ty() + elif self.name == 'uint16': + return builder.get_uint16_ty() + elif self.name == 'int32': + return builder.get_int32_ty() + elif self.name == 'uint32': + return builder.get_uint32_ty() + elif self.name == 'int64': + return builder.get_int64_ty() + elif self.name == 'uint64': + return builder.get_uint64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + # can't just map lang.static_range to `range`, because `tl.static_range` + # can get `step` passed by keyword + def _new_range(arg1, arg2=None, step=None, **kwargs): + if step is None: + step = 1 + if arg2 is None: + start, end = 0, arg1 + else: + start, end = arg1, arg2 + return range(start, end, step) + + def _new_static_assert(cond, msg=""): + assert cond, msg + + def _set_attr(input, values, name): + # skip non tensor types. This may happen for induction variables. + if not isinstance(input, tl.tensor): + return input + # Unwrap constexpr + values = [values] if not isinstance(values, (list, tuple)) else values + values = [v.value if isinstance(v, tl.constexpr) else v for v in values] + if len(values) != max(1, len(input.shape)): + raise ValueError(f"len(values) != len(input.shape) for {name}") + input.handle.set_attr(name, values) + return input + + lang.range = _new_range + lang.static_range = _new_range + lang.static_assert = _new_static_assert + lang.static_print = print + lang.dtype.to_ir = _new_to_ir + lang.multiple_of = partial(_set_attr, name="tt.divisiblity") + lang.max_contiguous = partial(_set_attr, name="tt.contiguity") + lang.max_constancy = partial(_set_attr, name="tt.constancy") + + _patch_reduce_scan() + + +def _patch_lang(fn): + lang = [value for _, value in fn.__globals__.items() if value in [tl, tl.core]] + assert len(lang) == 1, "triton.language must be visible from within jit'd function" + _patch_builtin(lang[0], interpreter_builder) + _patch_builtin(lang[0].tensor, interpreter_builder) + if lang[0] == tl: + _patch_builtin(lang[0].math, interpreter_builder) + _patch_lang_tensor(lang[0].tensor) + _patch_lang_core(lang[0]) + + +# TODO: wrap everything in triton tensors +def _implicit_cvt(arg): + if isinstance(arg, int): + ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + dtype = np.int32 + if -2**31 <= arg < 2**31: + dtype = np.int32 + elif 2**31 <= arg < 2**32: + dtype = np.uint32 + elif -2**63 <= arg < 2**63: + dtype = np.int64 + elif 2**63 <= arg < 2**64: + dtype = np.uint64 + else: + raise ValueError(f"Unsupported integer value {arg}") + handle = TensorHandle(np.array([arg], dtype=dtype), ty) + return tl.tensor(handle, ty) + if hasattr(arg, "data_ptr"): + ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) + return tl.tensor(handle, ty) + return arg + + +interpreter_builder = InterpreterBuilder() + +# These keywords are not supported by the interpreter +RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"] + + +class GridExecutor: + + def __init__(self, fn, arg_names, grid): + from .jit import _normalize_ty # TODO: modularize + + self.fn = fn + self.arg_names = arg_names + self.grid = grid + __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} + self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"] + + def _init_args_hst(self, args_dev, kwargs): + args_hst = [] + for arg in args_dev: + if hasattr(arg, "data_ptr"): + args_hst.append(arg.cpu()) + else: + args_hst.append(arg) + # Process keyword arguments + kwargs_hst = {} + for key, value in kwargs.items(): + if hasattr(value, "data_ptr"): + kwargs_hst[key] = value.cpu() + else: + kwargs_hst[key] = value + return args_hst, kwargs_hst + + def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst): + for arg_dev, arg_hst in zip(args_dev, args_hst): + if hasattr(arg_dev, "data_ptr"): + arg_dev.data.copy_(arg_hst.to(arg_dev.device).data) + + # Restore keyword arguments + for key, kwarg_dev in kwargs.items(): + kwarg_hst = kwargs_hst[key] + if hasattr(kwarg_dev, "data_ptr"): + kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data) + + def __call__(self, *args_dev, **kwargs): + # removes reserved keywords from kwargs + kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} + if kwargs.pop("warmup", False): + return + # copy arguments to the host + args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) + # remaps core language functions to interpreted ones + _patch_lang(self.fn) + # we need to copy arguments to the host for the interpreter + # implicitly convert tensor arguments to their base pointers + args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst) + args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()} + # iterate through grid + grid = self.grid(args) if callable(self.grid) else self.grid + assert len(grid) <= 3, "grid must have at most 3 dimensions" + grid = grid + (1, ) * (3 - len(grid)) + interpreter_builder.set_grid_dim(*grid) + try: + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + interpreter_builder.set_grid_idx(x, y, z) + self.fn(**args) + except Exception as e: + raise InterpreterError(repr(e)) from e + # copy arguments back to propagate side-effects + self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst) + + +class InterpretedFunction: + + def __init__(self, fn) -> None: + self.fn = fn + + def run(*args, **kwargs): + grid = kwargs["grid"] + return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs) + + self.run = run + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + + @property + def __name__(self): + return self.fn.__name__ + + def __getitem__(self, grid): + return GridExecutor(self.fn, self.arg_names, grid) + + def __call__(self, *args, **kwargs): + # This is a device function call + _patch_lang(self.fn) + try: + return self.fn(*args, **kwargs) + except Exception as e: + raise InterpreterError(repr(e)) from e diff --git a/third_party/xpu/python/triton/runtime/jit.py b/third_party/xpu/python/triton/runtime/jit.py new file mode 100644 index 000000000..80e24cc40 --- /dev/null +++ b/third_party/xpu/python/triton/runtime/jit.py @@ -0,0 +1,967 @@ +from __future__ import annotations, division +import ast +import hashlib +import inspect +import itertools +import os +import re +import textwrap +from collections import defaultdict +from functools import cached_property +from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple +from ..runtime.driver import driver +from types import ModuleType + + +# ===-------------------- For XPytorch Inductor -----------------------=== +# XPytorch Inductor (v2.0.1) Use +# from triton.runtime.jit import get_cuda_stream, KernelInterface +def get_cuda_stream(): + from torch._C import _cuda_getCurrentRawStream + return _cuda_getCurrentRawStream() + + +# ===------------------------------------------------------------------=== + +TRITON_MODULE = __name__[:-len(".runtime.jit")] + +T = TypeVar("T") + +# ----------------------------------------------------------------------------- +# Dependencies Finder +# ----------------------------------------------------------------------------- + + +class DependenciesFinder(ast.NodeVisitor): + """ + This AST visitor is used to find dependencies of a JITFunction. This can + be used to invalidate a JITFunction's hash when its source code -- or + that of its dependencies -- changes. + + This visitor also keeps track of the global variables touched by the + JITFunction. When we launch the kernel, we check that these have the same + values as they did when we ran this visitor. If not, we raise an error (or + otherwise we could recompile). + """ + + def __init__(self, name, globals, src) -> None: + super().__init__() + self.name = name + self.hasher = hashlib.sha256(src.encode("utf-8")) + + # This function's __globals__ dict. + self.globals = globals + + # Python builtins that can be accessed from Triton kernels. + self.supported_python_builtins = { + 'float', + 'getattr', + 'int', + 'isinstance', + 'len', + 'list', + 'max', + 'min', + 'print', + 'range', + } + + # used_global_vals tells us which global variables are used by this + # function and all those it transitively calls, plus the values of those + # variables when each function was initially run. (That is, if A calls + # C, and B calls C, then the values for C in used_global_vals will be + # from the first time C was run, either by A or B.) + # + # Each function may have a different __globals__ dict, so the global + # variable `foo` may actually have a different value in the different + # functions. Thus this map is actually + # (var_name, id(__globals__)) -> (var_value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + self.visiting_arg_default_value = False + + @property + def ret(self): + return self.hasher.hexdigest() + + def visit_Name(self, node): + if type(node.ctx) == ast.Store: + return node.id + + if node.id in self.local_names: + # The global name is hidden by the local name. + return None + + val = self.globals.get(node.id, None) + + # Only keep track of "interesting" global variables, that non-evil users + # might change. Don't consider functions, modules, builtins, etc. This + # helps keep the list of vars we have to check small. + if (val is not None # + # Python default arguments are resolved only once, when the + # function is defined. So if you do `foo(a=A)` and the value of + # A changes, foo will still use the old value of A. + and not self.visiting_arg_default_value + # It would be pretty evil if someone did `import x` and then + # `x = blah`. + and type(val) != ModuleType + # It would be pretty evil if we used function `foo` inside of + # `bar` and then someone did `foo = baz`. + and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) # + and node.id not in self.supported_python_builtins # + ): + self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals) + + return val + + def visit_Tuple(self, node): + # We need to explicitly return the tuple values so that visit_Assign can + # access them in the case of `a, b = ...`. + return [self.visit(elt) for elt in node.elts] + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE): + return None + return getattr(lhs, node.attr) + + def visit_Call(self, node): + + def is_triton_builtin(func): + if inspect.isbuiltin(node.func): + return True + module = getattr(func, "__module__", "") + return module.startswith(TRITON_MODULE) + + func = self.visit(node.func) + assert func is None or is_triton_builtin(func) or isinstance( + func, JITFunction + ), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this' + + # Traverse arguments as well as node.func so we can find JITFunctions + # passed to tl.reduce or tl.associative_scan as the combine_fn + for obj in itertools.chain( + (func, ), + map(self.visit, node.args), + (self.visit(kw.value) for kw in node.keywords), + ): + if not isinstance(obj, JITFunction): + continue + if is_triton_builtin(obj): + continue + + func_cache_key = obj.cache_key + + # Merge our used_global_vals with those of the called function, + # after checking that all overlapping values are consistent. + for k in self.used_global_vals.keys() & obj.used_global_vals.keys(): + var_name, _ = k + v1, _ = self.used_global_vals[k] + v2, _ = obj.used_global_vals[k] + if v1 != v2: + raise RuntimeError( + f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed." + ) + + self.used_global_vals.update(obj.used_global_vals) + + noinline = str(getattr(obj, "noinline", False)) + + key = func_cache_key + noinline + self.hasher.update(key.encode("utf-8")) + + def visit_FunctionDef(self, node): + # Save the local name, which may hide the global name. + self.local_names = {arg.arg for arg in node.args.args} + self.generic_visit(node) + + def visit_arguments(self, node): + # The purpose of this function is to visit everything in `arguments` + # just like `generic_visit`, except when we're visiting default values + # (i.e. the `foo` part of `def fn(x = foo)`), we set + # self.visiting_arg_default_value = True. This allows visit_Name to be + # aware that we're inside function default values, which have special + # semantics. + + # According to the AST docs, the arguments node has the following structure. + # + # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, + # expr* kw_defaults, arg? kwarg, expr* defaults) + def visit_defaults(defaults): + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + for expr in defaults: + if expr is not None: + self.visit(expr) + finally: + self.visiting_arg_default_value = False + + for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs): + self.visit(arg) + + visit_defaults(node.kw_defaults) + + if node.kwarg is not None: + self.visit(node.kwarg) + + visit_defaults(node.defaults) + + def visitAssnTarget(self, node): + # Target is either a single string, or a list of strings (if the assn + # target is a tuple). + target = self.visit(node) + if isinstance(target, list): + self.local_names |= set(target) + else: + self.local_names.add(target) + + def visit_Assign(self, node): + if len(node.targets) != 1: + # TODO(jlebar): I don't actually know how to hit this. You don't + # get it from `a, b = ...` -- in that case, node.targets is a single + # Tuple, and in fact we *do* need to handle that case if we want + # existing code to work. + raise TypeError("Simultaneous multiple assignment is not supported.") + + self.visitAssnTarget(node.targets[0]) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_AnnAssign(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_For(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's fine. + self.generic_visit(node) + + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + + +def _normalize_ty(ty) -> str: + if isinstance(ty, type): + return ty.__name__ + elif isinstance(ty, str): + return ty + return repr(ty) + + +class KernelParam: + """Represents a parameter (name plus metadata) to a @jit'ed function.""" + + def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool): + self.num = num + self._param = param + self.do_not_specialize = do_not_specialize + + @cached_property + def name(self): + return self._param.name + + @cached_property + def annotation(self): + if not self._param.annotation or self._param.annotation == inspect.Parameter.empty: + return "" + return _normalize_ty(self._param.annotation) + + @cached_property + def annotation_type(self): + annotation = self.annotation + for ty1, ty2 in [("uint", 'u'), ("int", 'i')]: + width = annotation[annotation.find(ty1) + len(ty1):] + if width and ty1 in annotation: + return f"{ty2}{width}" + if annotation == "bool": + return "u1" + return "" + + @cached_property + def is_constexpr(self): + return "constexpr" in self.annotation + + @cached_property + def is_const(self): + return "const" in self.annotation and not self.is_constexpr + + @property + def default(self): + return self._param.default + + @property + def has_default(self): + return self._param.default != inspect.Parameter.empty + + +def compute_spec_key(v): + + if hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0): + return "D" + elif isinstance(v, int): + # bool is a subclass of int, so we don't check explicitly above. + if (v % 16 == 0): + return "D" + elif v == 1: + return "1" + return "N" + + +dtype2str = {} + + +def mangle_type(arg, is_const=False): + + if arg is None: + return "none" + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + else: + # dtypes are hashable so we can memoize this mapping: + dsk = (arg.dtype, is_const) + res = dtype2str.get(dsk, None) + if res is None: + res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]] + dtype2str[dsk] = res + return res + + +class KernelInterface(Generic[T]): + run: T + + def __getitem__(self, grid) -> T: + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. + """ + return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) + + +def serialize_specialization_data(name, signature, constants, attrs, options, key): + constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} + import json + obj = { + 'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options': + options.__dict__, 'key': key + } + serialized_obj = json.dumps(obj) + return serialized_obj + + +def create_function_from_signature(sig, kparams): + """ + Equivalent to sig.bind followed by apply_defaults. This generates a + native Python function (using exec) which can be memoized on a per-kernel + basis to avoid having to run these expensive functions -- which constitute + much of the kernel launch overhead -- every time we run the kernel. + """ + + assert len(sig.parameters) == len(kparams) + + # Create the function argument list and the dict entries for the return statement + func_args = [] + dict_entries = [] + constexpr_vals = [] + non_constexpr_vals = [] + signature_types = [] + specialisations = [] + + for ((name, sp), kp) in zip(sig.parameters.items(), kparams): + if sp.default is inspect.Parameter.empty: + func_args.append(name) + dict_entries.append(f"'{name}': {name}") + else: + func_args.append(f"{name}=default_{name}") + dict_entries.append(f"'{name}': {name}") + if kp.is_constexpr: + constexpr_vals.append(name) + else: + non_constexpr_vals.append(name) + if not kp.do_not_specialize: + specialisations.append('compute_spec_key(%s)' % name) + if kp.annotation_type: + signature_types.append('"%s"' % kp.annotation_type) + else: + signature_types.append('mangle_type(%s, %s)' % (name, 'True' if kp.is_const else 'False')) + + cache_key = ''.join([x + ', ' for x in signature_types + specialisations]) + constexpr_vals = ''.join([x + ', ' for x in constexpr_vals]) + non_constexpr_vals = ''.join([x + ', ' for x in non_constexpr_vals]) + + func_args.append('**excess_kwargs') + + # Join all arguments into a function definition string + args_str = ', '.join(func_args) + dict_str = ', '.join(dict_entries) + func_body = "def dynamic_func(%s):\n return {%s}, (%s), (%s), (%s), excess_kwargs" % ( + args_str, dict_str, cache_key, constexpr_vals, non_constexpr_vals) + + # Prepare defaults to be inserted into function namespace + func_namespace = { + f"default_{name}": param.default + for name, param in sig.parameters.items() + if param.default is not inspect.Parameter.empty + } + + func_namespace['mangle_type'] = mangle_type + func_namespace['compute_spec_key'] = compute_spec_key + + # Execute the function string in func_namespace to create the function + exec(func_body, func_namespace) + + # Extract the newly created function from the namespace + return func_namespace['dynamic_func'] + + +type_canonicalisation_dict = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8_e4m3fn": "fp8e4nv", + "float8e4b8": "fp8e4b8", + "float8_e4m3fnuz": "fp8e4b8", + "float8_e5m2": "fp8e5", + "float8e5b16": "fp8e5b16", + "float8_e5m2fnuz": "fp8e5b16", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", +} + +for v in list(type_canonicalisation_dict.values()): + type_canonicalisation_dict[v] = v + + +class JITFunction(KernelInterface[T]): + # Hook for inspecting compiled functions and modules + cache_hook = None + divisibility = 16 + + @staticmethod + def _key_of(arg): + if hasattr(arg, "dtype"): + return arg.dtype + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + elif arg is None: + return None + else: + raise TypeError(f"Unsupported type {type(arg)} for {arg}") + + @staticmethod + def _spec_of(arg): + if hasattr(arg, "data_ptr"): + return arg.data_ptr() % JITFunction.divisibility == 0 + elif isinstance(arg, int): + return (arg % 16 == 0, arg == 1) + return (arg is None, ) + + def _get_config(self, *args): + from ..compiler import AttrsDescriptor + + def is_divisible_by_16(x): + if hasattr(x, "data_ptr"): + return x.data_ptr() % JITFunction.divisibility == 0 + elif isinstance(x, int): + return x % JITFunction.divisibility == 0 + if x is None: + return True + return False + + divisible_by_16 = { + param.num + for param, arg in zip(self.params, args) + if is_divisible_by_16(arg) and not param.do_not_specialize + } + equal_to_1 = { + param.num + for param, arg in zip(self.params, args) + if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize + } + # folded equal_to_1 and None + # TODO: method to collect all folded args + return AttrsDescriptor(tuple(divisible_by_16), tuple(equal_to_1)) + # return _triton.code_gen.instance_descriptor(divisible_by_16, + # equal_to_1) + + @staticmethod + def _type_of(key, is_const=False): + # `None` is nullptr. Implicitly convert to *i8. + if key is None: + return "*i8" + elif isinstance(key, str): + return key + + dtype_str = str(key).split(".")[-1] + dtype_str = type_canonicalisation_dict[dtype_str] + const_str = "*k" if is_const else "*" + return const_str + dtype_str + + def _make_constants(self, constexpr_key): + constants = dict(zip(self.constexprs, constexpr_key)) + return constants + + def _call_hook( + self, + key, + signature, + device, + constants, + options, + configs, + ): + if JITFunction.cache_hook is None: + return False + + name = self.fn.__name__ + module = self.fn.__module__ + arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])]) + repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})" + + class JitFunctionInfo: + + def __init__(self, module, name, jit_function): + self.module = module + self.name = name + self.jit_function = jit_function + pass + + specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key) + + kwargs = { + 'signature': signature, + 'device': device, + 'constants': constants, + 'num_warps': options.num_warps, + 'num_ctas': options.num_ctas, + 'num_stages': options.num_stages, + 'enable_fp_fusion': options.enable_fp_fusion, + 'extern_libs': options.extern_libs, + 'configs': configs, + 'specialization_data': specialization_data, + } + + return JITFunction.cache_hook( + key=key, + repr=repr, + fn=JitFunctionInfo(module, name, self), + compile={"key": key, **kwargs}, + is_manual_warmup=False, + already_compiled=False, + ) + + def add_pre_run_hook(self, hook): + ''' + Add a hook that will be executed prior to the execution of run + function with args and kwargs passed into the kernel + ''' + assert callable(hook) + self.pre_run_hooks.append(hook) + + def create_binder(self): + """ + Precompute as much as possible. + """ + from ..compiler import CompiledKernel, compile, ASTSource, make_backend + self.CompiledKernel = CompiledKernel + self.compile = compile + self.ASTSource = ASTSource + self.make_backend = make_backend + self.binder = create_function_from_signature(self.signature, self.params) + self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr] + self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr] + self.specialised_indices = [ + i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr) + ] + + def run(self, *args, grid, warmup, **kwargs): + # parse options + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + kwargs["debug"] = self.debug + + # Execute pre run hooks with args and kwargs + for hook in self.pre_run_hooks: + hook(*args, **kwargs) + + if self.binder is None: + self.create_binder() + + bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs) + + # compute cache key + key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs)) + kernel = self.cache[device].get(key, None) + + if kernel is None: + # Kernel is not cached; we have to compile. + target = driver.active.get_current_target() + backend = self.make_backend(target) + options = backend.parse_options(kwargs) + + # deprecated arguments + assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" + assert "device" not in kwargs, "device option is deprecated; current device will be used" + assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + for k in excess_kwargs: + if k not in options.__dict__: + raise KeyError("Keyword argument %s was specified but unrecognised" % k) + + bound_vals = tuple(bound_args.values()) + + # `None` is nullptr. Implicitly convert to *i8. This needs to be + # done here rather than when we build the signature as otherwise + # the kernel cache key could not distinguish between byte pointers + # and None arguments, resulting in a downstream mismatch: + sigkeys = [self.params[i].name for i in self.non_constexpr_indices] + sigvals = sig_and_spec[:len(sigkeys)] + signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} + + configs = (self._get_config(*bound_vals), ) + constants = { + p.name: v + for (v, p) in zip(bound_vals, self.params) + if p.is_constexpr or p.num in configs[0].equal_to_1 or v is None + } + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index {i} is not supported") + + if self._call_hook(key, signature, device, constants, options, configs): + return None + # compile the kernel + src = self.ASTSource(self, signature, constants, configs[0]) + kernel = self.compile( + src, + target=target, + options=options.__dict__, + ) + self.cache[device][key] = kernel + + # Check that used global values have not changed. + not_present = object() + for (name, globals_dict_id), (val, globals_dict) in self.used_global_vals.items(): + if (newVal := globals_dict.get(name, not_present)) != val: + raise RuntimeError( + f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}") + + if not warmup: + # canonicalize grid + assert grid is not None + if callable(grid): + # Arguments are passed as a dict to `grid`, by contract. + # TODO(jlebar): In the new launch API, pass the compiler flags as a + # second parameter to `grid`. + grid = grid(bound_args) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + + # launch kernel + launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) + kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, + self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals) + return kernel + + def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None, repr=None, + launch_metadata=None): + do_not_specialize = do_not_specialize if do_not_specialize else [] + + self.fn = fn + self.module = fn.__module__ + self.version = version + self.signature = inspect.signature(fn) + self.do_not_specialize = do_not_specialize + self.starting_line_number = inspect.getsourcelines(fn)[1] + self.repr = lambda _: fn.__name__ if repr is None else repr(_) + self.launch_metadata = launch_metadata + + self.binder = None + + self.params = [] + for i, param in enumerate(self.signature.parameters.values()): + dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize) + self.params.append(KernelParam(i, param, dns)) + + # function source code (without decorators) + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():] + # cache of just-in-time compiled kernels + self.cache = defaultdict(dict) + self.hash = None + + # Map of global variables used by the function and any functions it + # transitively calls, plus their values. The values are collected when + # the function is first compiled. Then every time we run the function, + # we check that the values of the globals match what's expected, + # otherwise we raise an error. + # + # Different functions can have different __globals__ maps, so the map + # key is actually (var name, id(__globals__)), and the map value is + # (value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel = None + self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug + self.noinline = noinline + + # TODO(jlebar): Remove uses of these fields outside this file, then + # remove the fields here. + self.arg_names = [p.name for p in self.params] + self.constexprs = [p.num for p in self.params if p.is_constexpr] + + # Hooks that will be called prior to executing "run" + self.pre_run_hooks = [] + + # reuse docs of wrapped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + @property + def cache_key(self): + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + str(self.starting_line_number) + self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items())) + return self.hash + + def warmup(self, *args, grid, **kwargs): + return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) + + def preload(self, specialization_data): + from ..compiler import AttrsDescriptor, compile, ASTSource + import json + import triton.language as tl + device = driver.active.get_current_device() + deserialized_obj = json.loads(specialization_data) + if deserialized_obj['name'] != self.fn.__name__: + raise RuntimeError( + f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}") + constants = { + key: tl.dtype(value) if tl.dtype.is_dtype(value) else value + for key, value in deserialized_obj['constants'].items() + } + signature = dict(deserialized_obj['signature'].items()) + src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) + options = { + key: tuple(value) if isinstance(value, list) else value + for key, value in deserialized_obj['options'].items() + } + key = deserialized_obj['key'] + kernel = compile(src, None, options) + self.cache[device][key] = kernel + return kernel + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Our unit tests do this, for example. + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + def __setattr__(self, name, value): + super(JITFunction, self).__setattr__(name, value) + # - when `.src` attribute is set, cache path needs + # to be reinitialized + if name == "src": + self.hash = None + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__name__})" + + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + + +@overload +def jit(fn: T) -> JITFunction[T]: + ... + + +@overload +def jit( + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Callable[[T], JITFunction[T]]: + ... + + +def jit( + fn: Optional[T] = None, + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, arguments are + implicitly converted to pointers if they have a :code:`.data_ptr()` method + and a `.dtype` attribute. + + :note: This function will be compiled and run on the GPU. It will only have access to: + + * python primitives, + * builtins within the triton package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + + def decorator(fn: T) -> JITFunction[T]: + assert callable(fn) + if os.getenv("TRITON_INTERPRET", "0") == "1": + from .interpreter import InterpretedFunction + return InterpretedFunction(fn) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + debug=debug, + noinline=noinline, + repr=repr, + launch_metadata=launch_metadata, + ) + + if fn is not None: + return decorator(fn) + + else: + return decorator + + +# ----------------------------------------------------------------------------- +# Utilities for mocking tensors +# ----------------------------------------------------------------------------- + + +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) + """ + + @staticmethod + def wrap_dtype(arg): + if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch": + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + @staticmethod + def data_ptr(): + return 0 # optimistically assumes multiple of 16 + + +class TensorWrapper: + + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.data = base.data + self.device = base.device + self.shape = self.base.shape + + def data_ptr(self): + return self.base.data_ptr() + + def stride(self, i): + return self.base.stride(i) + + def __str__(self) -> str: + return f"TensorWrapper[{self.dtype}]({self.base})" + + def element_size(self): + return self.base.element_size() + + def cpu(self): + return TensorWrapper(self.base.cpu(), self.dtype) + + def copy_(self, other): + self.base.copy_(other.base) + + def to(self, device): + return TensorWrapper(self.base.to(device), self.dtype) + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif hasattr(tensor, "data_ptr"): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f"Cannot reinterpret a {type(tensor)}.") diff --git a/third_party/xpu/python/triton/testing.py b/third_party/xpu/python/triton/testing.py new file mode 100644 index 000000000..86092f098 --- /dev/null +++ b/third_party/xpu/python/triton/testing.py @@ -0,0 +1,493 @@ +import functools +import os +import subprocess +import sys +from contextlib import contextmanager +from typing import Any, Dict, List +from . import language as tl + + +def nvsmi(attrs): + attrs = ','.join(attrs) + cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] + out = subprocess.check_output(cmd) + ret = out.decode(sys.stdout.encoding).split(',') + ret = [int(x) for x in ret] + return ret + + +def do_bench_cudagraph(fn, rep=20, grad_to_none=None, return_mode="mean"): + """ + Benchmark the runtime of the provided function. + + :param fn: Function to benchmark + :type fn: Callable + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + """ + import torch + assert return_mode in ["min", "max", "mean", "median"] + + if torch.cuda.current_stream() == torch.cuda.default_stream(): + raise RuntimeError("Cannot capture graph in default stream. Please use side stream in benchmark code.") + # warmup + fn() + # step 1 - we estimate the amount of time the kernel call takes + # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point + # but it is probably good enough + if grad_to_none is not None: + for x in grad_to_none: + x.detach_() + x.requires_grad_(True) + x.grad = None + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + fn() + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) + n_repeat = max(1, int(rep / estimate_ms)) + # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize + # host overhead + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for i in range(n_repeat): + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + fn() + torch.cuda.synchronize() + # measure time and return + ret = [] + n_retries = 10 + for i in range(n_retries): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + ret += [start_event.elapsed_time(end_event) / n_repeat] + times = torch.tensor(ret) + return getattr(torch, return_mode)(times).item() + + +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float] + :param fast_flush: Use faster kernel to flush L2 between measurements + :type fast_flush: bool + """ + assert return_mode in ["min", "max", "mean", "median"] + import torch + + fn() + torch.cuda.synchronize() + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + if fast_flush: + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + else: + cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + torch.cuda.synchronize() + times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + return getattr(torch, return_mode)(times).item() + + +def assert_close(x, y, atol=None, rtol=None, err_msg=''): + import numpy as np + import torch + + # canonicalize arguments to be tensors + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) + # absolute tolerance + if atol is None: + atol = 1e-2 + atol = atol(x.dtype) if callable(atol) else atol + # relative tolerance hook + if rtol is None: + rtol = 0. + rtol = rtol(x.dtype) if callable(rtol) else rtol + # we use numpy instead of pytorch + # as it seems more memory efficient + # pytorch tends to oom on large tensors + if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() + y = y.cpu().detach().numpy() + # we handle size==1 case separately as we can + # provide better error message there + if x.size > 1 or y.size > 1: + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True) + return + if not np.allclose(x, y, atol=atol, rtol=rtol): + raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') + + +class Benchmark: + """ + This class is used by the :code:`perf_report` function to generate line plots with a concise API. + """ + + def __init__( + self, + x_names: List[str], + x_vals: List[Any], + line_arg: str, + line_vals: List[Any], + line_names: List[str], + plot_name: str, + args: Dict[str, Any], + xlabel: str = '', + ylabel: str = '', + x_log: bool = False, + y_log: bool = False, + color=None, + styles=None, + ): + """ + Constructor. + x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list + of scalars and there are multiple x_names, all arguments will have the same value. + If x_vals is a list of tuples/lists, each element should have the same length as + x_names. + + :param x_names: Name of the arguments that should appear on the x axis of the plot. + :type x_names: List[str] + :param x_vals: List of values to use for the arguments in :code:`x_names`. + :type x_vals: List[Any] + :param line_arg: Argument name for which different values correspond to different lines in the plot. + :type line_arg: str + :param line_vals: List of values to use for the arguments in :code:`line_arg`. + :type line_vals: List[Any] + :param line_names: Label names for the different lines. + :type line_names: List[str] + :param plot_name: Name of the plot. + :type plot_name: str + :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark. + :type args: Dict[str, Any] + :param xlabel: Label for the x axis of the plot. + :type xlabel: str, optional + :param ylabel: Label for the y axis of the plot. + :type ylabel: str, optional + :param x_log: Whether the x axis should be log scale. + :type x_log: bool, optional + :param y_log: Whether the y axis should be log scale. + :type y_log: bool, optional + """ + self.x_names = x_names + self.x_vals = x_vals + self.x_log = x_log + self.line_arg = line_arg + self.line_vals = line_vals + self.line_names = line_names + self.y_log = y_log + self.styles = styles + # plot info + self.xlabel = xlabel + self.ylabel = ylabel + self.plot_name = plot_name + self.args = args + + +class Mark: + + def __init__(self, fn, benchmarks): + self.fn = fn + self.benchmarks = benchmarks + + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, + save_precision=6, **kwrags): + import os + + import matplotlib.pyplot as plt + import pandas as pd + y_mean = bench.line_names + y_min = [f'{x}-min' for x in bench.line_names] + y_max = [f'{x}-max' for x in bench.line_names] + x_names = list(bench.x_names) + df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max) + for x in bench.x_vals: + # x can be a single value or a sequence of values. + if not isinstance(x, (list, tuple)): + x = [x for _ in x_names] + + if len(x) != len(x_names): + raise ValueError(f"Expected {len(x_names)} values, got {x}") + x_args = dict(zip(x_names, x)) + + row_mean, row_min, row_max = [], [], [] + for y in bench.line_vals: + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags) + try: + y_mean, y_min, y_max = ret + except TypeError: + y_mean, y_min, y_max = ret, None, None + row_mean += [y_mean] + row_min += [y_min] + row_max += [y_max] + df.loc[len(df)] = list(x) + row_mean + row_min + row_max + + if bench.plot_name: + plt.figure() + ax = plt.subplot() + # Plot first x value on x axis if there are multiple. + first_x = x_names[0] + for i, y in enumerate(bench.line_names): + y_min, y_max = df[y + '-min'], df[y + '-max'] + col = bench.styles[i][0] if bench.styles else None + sty = bench.styles[i][1] if bench.styles else None + ax.plot(df[first_x], df[y], label=y, color=col, ls=sty) + if not y_min.isnull().all() and not y_max.isnull().all(): + y_min = y_min.astype(float) + y_max = y_max.astype(float) + ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col) + ax.legend() + ax.set_xlabel(bench.xlabel or first_x) + ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) + ax.set_xscale("log" if bench.x_log else "linear") + ax.set_yscale("log" if bench.y_log else "linear") + if show_plots: + plt.show() + if save_path: + plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) + df = df[x_names + bench.line_names] + if diff_col and df.shape[1] == 2: + col0, col1 = df.columns.tolist() + df['Diff'] = df[col1] - df[col0] + + if print_data: + print(bench.plot_name + ':') + print(df.to_string()) + if save_path: + df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f", + index=False) + return df + + def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs): + has_single_bench = isinstance(self.benchmarks, Benchmark) + benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks + result_dfs = [] + if save_path: + # Create directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + html = open(os.path.join(save_path, "results.html"), "w") + html.write("\n") + for bench in benchmarks: + result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs)) + if save_path: + html.write(f"\n") + if save_path: + html.write("\n") + html.close() + if return_df: + if has_single_bench: + return result_dfs[0] + else: + return result_dfs + return None + + +def perf_report(benchmarks): + """ + Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value. + + :param benchmarks: Benchmarking configurations. + :type benchmarks: List of :class:`Benchmark` + """ + wrapper = lambda fn: Mark(fn, benchmarks) + return wrapper + + +def get_dram_gbps(device=None): + ''' return DRAM bandwidth in GB/s ''' + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"] # in kHz + bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"] + bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s + return bw_gbps + + +def get_max_tensorcore_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8: + assert dtype == torch.float16 + ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores + else: + if dtype in [torch.float32, torch.int32]: + ops_per_sub_core = 256 + elif dtype in [torch.float16, torch.bfloat16, torch.int16]: + ops_per_sub_core = 512 + elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]: + ops_per_sub_core = 1024 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops + + +# create decorator that wraps test function into +# a cuda-memcheck system call + + +def cuda_memcheck(**target_kwargs): + + def decorator(test_fn): + + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + import psutil + ppid_name = psutil.Process(os.getppid()).name() + run_cuda_memcheck = target_kwargs.items() <= kwargs.items() + if run_cuda_memcheck and ppid_name != "cuda-memcheck": + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} + assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" + test_id = kwargs['request'].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" + assert "ERROR SUMMARY: 0 errors" in str(out.stdout) + else: + test_fn(*args, **kwargs) + + return wrapper + + return decorator + + +@contextmanager +def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): + try: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ]) + cur_sm_clock = nvsmi(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi(["clocks.current.memory"])[0] + assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" + assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" + tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock + gbps = 640 * 2 * ref_mem_clock * 1e-3 + yield tflops, gbps + finally: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) + + +def get_max_simd_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + if dtype == torch.float32: + ops_per_sub_core = 32 # 2*16 + elif dtype == torch.float16: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + else: + if dtype == torch.float32: + ops_per_sub_core = 32 + elif dtype in [torch.float16, torch.bfloat16]: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops diff --git a/third_party/xpu/python/triton/tools/__init__.py b/third_party/xpu/python/triton/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/xpu/python/triton/tools/build_extern.py b/third_party/xpu/python/triton/tools/build_extern.py new file mode 100644 index 000000000..8f0168d59 --- /dev/null +++ b/third_party/xpu/python/triton/tools/build_extern.py @@ -0,0 +1,365 @@ +import argparse +import subprocess +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + + +class Symbol: + _name: str + _op_name: str + _ret_type: str + _arg_names: List[str] + _arg_types: List[str] + + def __init__( + self, + name: str, + op_name: str, + ret_type: str, + arg_names: List[str], + arg_types: List[str], + ) -> None: + ''' + A symbol is a function declaration. + :param name: name of the symbol + :param op_name: name of the operation + :param ret_type: return type of the operation + :param arg_names: names of the arguments + :param arg_types: types of the arguments + ''' + self._name = name + self._op_name = op_name + self._ret_type = ret_type + self._arg_names = list(arg_names) + self._arg_types = list(arg_types) + + @property + def name(self) -> str: + return self._name + + @property + def op_name(self) -> str: + return self._op_name + + @property + def ret_type(self) -> str: + return self._ret_type + + @property + def arg_names(self) -> List[str]: + return self._arg_names + + @property + def arg_types(self) -> List[str]: + return self._arg_types + + +def convert_type(type_str) -> Optional[str]: + if type_str == "i32": + return "int32" + elif type_str == "u32": + return "uint32" + elif type_str == "i64": + return "int64" + elif type_str == "u64": + return "uint64" + elif type_str == "float": + return "fp32" + elif type_str == "double": + return "fp64" + else: + # ignore other types, such as pointer types + return None + + +def to_unsigned(type_str) -> str: + if type_str == "int32": + return "uint32" + elif type_str == "int64": + return "uint64" + else: + return type_str + + +class ExternLibrary(ABC): + _name: str + _path: str + _symbols: Dict[str, Symbol] + _format: bool + _grouping: bool + + def __init__( + self, + name: str, + path: str, + format: bool = True, + grouping: bool = True, + ) -> None: + ''' + Abstract class for extern library. + :param name: name of the library + :param path: path of the library + :param format: whether to format the generated stub file + ''' + self._name = name + self._path = path + self._symbols = {} + self._format = format + self._grouping = grouping + + @property + def name(self) -> str: + return self._name + + @property + def path(self) -> str: + return self._path + + @property + def symbols(self) -> Dict[str, Symbol]: + return self._symbols + + @property + def grouping(self) -> bool: + return self._grouping + + @abstractmethod + def parse_symbols(self, input_file) -> None: + pass + + @abstractmethod + def _output_stubs(self) -> str: + pass + + def generate_stub_file(self, output_dir) -> None: + file_str = self._output_stubs() + if file_str is None or len(file_str) == 0: + raise Exception("file_str is empty") + + output_file = f"{output_dir}/{self._name}.py" + with open(output_file, "w") as f: + f.write(file_str) + f.close() + if self._format: + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate() + subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() + + +class Libdevice(ExternLibrary): + _symbol_groups: Dict[str, List[Symbol]] + + def __init__(self, path) -> None: + ''' + Constructor for Libdevice. + :param path: path of the libdevice library + ''' + super().__init__("libdevice", path) + self._symbol_groups = {} + self.is_pure = True + + @staticmethod + def _extract_symbol(line) -> Optional[Symbol]: + # Extract symbols from line in the following format: + # "define [internal] @(,)" + entries = line.split("@") + ret_str = entries[0] + func_str = entries[1] + # Get ret_type, skip internal symbols + ret_strs = ret_str.split() + if ret_strs[1] == "internal": + return None + ret_type = convert_type(ret_strs[1]) + if ret_type is None: + return None + # Get function name + func_strs = func_str.split("(") + func_name = func_strs[0].replace("@", "") + op_name = func_name.replace("__nv_", "") + if 'ieee' in op_name: + return None + # Get arg_types + arg_strs = func_strs[1].split(",") + arg_types = [] + arg_names = [] + for i, arg_str in enumerate(arg_strs): + arg_type = convert_type(arg_str.split()[0]) + if arg_type is None: + return None + arg_name = 'arg' + str(i) + arg_types.append(arg_type) + arg_names.append(arg_name) + if op_name == "sad": + # Special case for sad, where the last argument is an unsigned int + arg_types[-1] = to_unsigned(arg_types[-1]) + elif op_name.startswith("u"): + # LLVM does not differentiate between signed and unsigned integer type. + # We have to convert the types to unsigned + ret_type = to_unsigned(ret_type) + for i, arg_type in enumerate(arg_types): + arg_types[i] = to_unsigned(arg_type) + return Symbol(func_name, op_name, ret_type, arg_names, arg_types) + + def _group_symbols(self) -> None: + symbol_set = {} + for symbol in self._symbols.values(): + op_name = symbol.op_name + symbol_set[op_name] = symbol + + # Group functions together by renaming. + renaming = { + 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': + 'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz': + 'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh', + 'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos', + 'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', + 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', + 'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf': + 'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2', + 'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll': + 'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru', + 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff': + 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', + 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f': + 'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax': + 'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min', + 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', + 'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24', + 'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': + 'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv', + 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', + 'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru', + 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', + 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt', + 'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit', + 'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd': + 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', + 'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn', + 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', + 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf': + 'yn' + } + + for symbol in self._symbols.values(): + op_name = symbol.op_name + if op_name in renaming: + op_name = renaming[op_name] + symbol._op_name = op_name + if op_name in self._symbol_groups: + self._symbol_groups[op_name].append(symbol) + else: + self._symbol_groups[op_name] = [symbol] + + def parse_symbols(self, input_file) -> None: + if len(self.symbols) > 0: + return + output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() + for line in output: + symbol = self._extract_symbol(line) + if symbol is None: + continue + self._symbols[symbol.name] = symbol + + self._group_symbols() + + def _output_stubs(self) -> str: + # Generate python functions in the following format: + # @extern.extern + # def (, _builder=None): + # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}} + # return core.extern_elementwise("libdevice", , , , _builder) + import_str = "from . import core\n" + + header_str = "" + func_str = "" + for symbols in self._symbol_groups.values(): + func_str += "@core.extern\n" + func_name_str = f"def {symbols[0].op_name}(" + for arg_name in symbols[0].arg_names: + func_name_str += f"{arg_name}, " + func_name_str += "_builder=None):\n" + + return_str = f"\treturn core.extern_elementwise(\"{self._name}\", libdevice_path(), [" + for arg_name in symbols[0].arg_names: + return_str += f"{arg_name}, " + return_str += "], \n" + + arg_type_symbol_dict_str = "{" + for symbol in symbols: + arg_type_symbol_dict_str += "(" + for arg_type in symbol.arg_types: + arg_type_symbol_dict_str += f'core.dtype("{arg_type}"),' + ret_type = f'core.dtype("{symbol.ret_type}")' + arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n" + arg_type_symbol_dict_str += "}" + + return_str += arg_type_symbol_dict_str + return_str += f", is_pure={self.is_pure}" + return_str += ", _builder=_builder)\n" + + func_str += func_name_str + return_str + "\n" + file_str = import_str + header_str + func_str + + return file_str + + +class LLVMDisassembler: + _path: str + _ll_file: str + + def __init__(self, path) -> None: + ''' + Invoke llvm-dis to disassemble the given file. + :param path: path to llvm-dis + ''' + self._path = path + self._ll_file = "/tmp/extern_lib.ll" + + def disasm(self, lib_path: str) -> None: + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate() + + @property + def ll_file(self) -> str: + return self._ll_file + + @property + def path(self) -> str: + return self._path + + +extern_libs = ["libdevice"] + + +def build( + llvm_dis_path: str, + lib_path: str, + lib_name: str, + output_dir: str, +) -> None: + ''' + Interface function to build the library file. + :param llvm_dis_path: path to the llvm-dis binary + :param lib_path: path to the external library file + :param lib_name: name of the library + :param output_dir: path to the output directory + ''' + if lib_name == "libdevice": + extern_lib = Libdevice(lib_path) + else: + raise Exception(f"Unknown extern library: {lib_name}") + + llvm_disassembler = LLVMDisassembler(llvm_dis_path) + llvm_disassembler.disasm(lib_path) + + extern_lib.parse_symbols(llvm_disassembler.ll_file) + extern_lib.generate_stub_file(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis") + parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library") + parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library") + parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/") + args = parser.parse_args() + + build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) diff --git a/third_party/xpu/python/triton/tools/compile.c b/third_party/xpu/python/triton/tools/compile.c new file mode 100644 index 000000000..971bf6191 --- /dev/null +++ b/third_party/xpu/python/triton/tools/compile.c @@ -0,0 +1,67 @@ +/* clang-format off */ +#include +#include +#include +#include +#include + + +// helpers to check for cuda errors +#define CUDA_CHECK(ans) {{\ + gpuAssert((ans), __FILE__, __LINE__);\ + }}\ + +static inline void gpuAssert(CUresult code, const char *file, int line) {{ + if (code != CUDA_SUCCESS) {{ + const char *prefix = "Triton Error [CUDA]: "; + const char *str; + cuGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + printf("%s\\n", err); + exit(code); + }} +}} + +// globals +#define CUBIN_NAME {kernel_name}_cubin +CUmodule {kernel_name}_mod = NULL; +CUfunction {kernel_name}_func = NULL; +unsigned char CUBIN_NAME[{bin_size}] = {{ {bin_data} }}; + + +void unload_{kernel_name}(void) {{ + CUDA_CHECK(cuModuleUnload({kernel_name}_mod)); +}} + +// TODO: some code duplication with `runtime/backend/cuda.c` +void load_{kernel_name}() {{ + int dev = 0; + void *bin = (void *)&CUBIN_NAME; + int shared = {shared}; + CUDA_CHECK(cuModuleLoadData(&{kernel_name}_mod, bin)); + CUDA_CHECK(cuModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}")); + // set dynamic shared memory if necessary + int shared_optin; + CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev)); + if (shared > 49152 && shared_optin > 49152) {{ + CUDA_CHECK(cuFuncSetCacheConfig({kernel_name}_func, CU_FUNC_CACHE_PREFER_SHARED)); + CUDA_CHECK(cuFuncSetAttribute({kernel_name}_func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin)) + }} +}} + +/* +{kernel_docstring} +*/ +CUresult {kernel_name}(CUstream stream, {signature}) {{ + if ({kernel_name}_func == NULL) + load_{kernel_name}(); + unsigned int gX = {gridX}; + unsigned int gY = {gridY}; + unsigned int gZ = {gridZ}; + void *args[{num_args}] = {{ {arg_pointers} }}; + // TODO: shared memory + if(gX * gY * gZ > 0) + return cuLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * 32, 1, 1, {shared}, stream, args, NULL); +}} diff --git a/third_party/xpu/python/triton/tools/compile.h b/third_party/xpu/python/triton/tools/compile.h new file mode 100644 index 000000000..d98b7063b --- /dev/null +++ b/third_party/xpu/python/triton/tools/compile.h @@ -0,0 +1,14 @@ +#ifndef TT_KERNEL_INCLUDES +#define TT_KERNEL_INCLUDES + +#include +#include +#include +#include + +#endif + +void unload_{kernel_name}(void); +void load_{kernel_name}(void); +// tt-linker: {kernel_name}:{full_signature}:{algo_info} +CUresult{_placeholder} {kernel_name}(CUstream stream, {signature}); diff --git a/third_party/xpu/python/triton/tools/compile.py b/third_party/xpu/python/triton/tools/compile.py new file mode 100644 index 000000000..a19758149 --- /dev/null +++ b/third_party/xpu/python/triton/tools/compile.py @@ -0,0 +1,193 @@ +import binascii +import hashlib +import importlib.util +import sys +from argparse import ArgumentParser +from pathlib import Path +from typing import List + +import triton +from triton.compiler.code_generator import kernel_suffix +from triton.backends.xpu.driver import get_xpu_spec + +desc = """ +Triton ahead-of-time compiler: + +This program compiles the kernel with name `kernel-name` in the file at the +provided `path` into self-contained C source-code that embeds the `cubin` +data along with utilities to load, unload and launch the kernel. + +signature is provided as a list of (optionally divisibility-hinted) types +or constexpr values, e.g. + +`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` + +will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`. +Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16, +and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype. + +The resulting entry point will have signature + +CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2) + +Different such specialized entry points can be combined using the `linker.py` script. + +NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the python interpreter +used to run this `compile.py` script +""" + + +def ty_to_cpp(ty): + if ty[0] == '*': + return "XPUdeviceptr" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + +if __name__ == "__main__": + + # command-line arguments + parser = ArgumentParser(description=desc) + parser.add_argument("path", + help="Path to Python source containing desired kernel in its scope. File will be executed.") + parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", + required=True) + parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel") + parser.add_argument("--num-stages", "-ns", type=int, default=3, + help="Number of stages (meta-parameter of the kernel)") + parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel") + parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename") + parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) + parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True) + args = parser.parse_args() + + out_name = args.out_name if args.out_name else args.kernel_name + out_path = args.out_path if args.out_path else Path(out_name) + + # execute python sources and extract functions wrapped in JITFunction + arg_path = Path(args.path) + sys.path.insert(0, str(arg_path.parent)) + spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + kernel = getattr(mod, args.kernel_name) + grid = args.grid.split(",") + assert len(grid) == 3 + + # validate and parse signature + signature = list(map(lambda s: s.strip(" "), args.signature.split(","))) + + def hash_signature(signature: List[str]): + m = hashlib.sha256() + m.update(" ".join(signature).encode()) + return m.hexdigest()[:8] + + meta_sig = f"warps{args.num_warps}xstages{args.num_stages}" + sig_hash = hash_signature(signature + [meta_sig]) + + def constexpr(s): + try: + ret = int(s) + return ret + except ValueError: + pass + try: + ret = float(s) + return ret + except ValueError: + pass + return None + + hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} + hints = {k: v for k, v in hints.items() if v is not None} + constants = {i: constexpr(s) for i, s in enumerate(signature)} + constants = {k: v for k, v in constants.items() if v is not None} + signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constants} + const_sig = 'x'.join([str(v) for v in constants.values()]) + doc_string = [f"{kernel.arg_names[i]}={constants[i]}" for i in constants.keys()] + doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] + + # compile ast into cubin + for h in hints.values(): + assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" + divisible_by_16 = [i for i, h in hints.items() if h == 16] + equal_to_1 = [i for i, h in hints.items() if h == 1] + attrs = triton.compiler.AttrsDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) + for i in equal_to_1: + constants.update({i: 1}) + src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) + opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} + ccinfo = triton.compile(src, options=opts) + arg_names = [] + arg_types = [] + for i in signature.keys(): + if i not in equal_to_1: + arg_names += [kernel.arg_names[i]] + arg_types += [signature[i]] + + # dump C stub code + suffix = kernel_suffix(signature.values(), attrs) + func_name = '_'.join([out_name, sig_hash, suffix]) + hex_ = str(binascii.hexlify(ccinfo.asm["xpubin"]))[2:-1] + xpu_arch = ccinfo.metadata.xpu_arch + is_sdnn = ccinfo.metadata.is_sdnn + + def generate_argument_set_code(): + newline = "\n " + eightBytesTypes = ['XPUdeviceptr', 'int64_t', 'uint64_t', 'double'] + lines = [] + for i, arg in enumerate(arg_names): + is_align_to_8 = (ty_to_cpp(arg_types[i]) in eightBytesTypes) and (xpu_arch == 3) + if is_align_to_8: + offset_align_to_8_line = "offset = alignSizeTo8Bytes(offset);" + lines.append(offset_align_to_8_line) + align_fn = "alignSizeTo8Bytes" if is_align_to_8 else "alignSizeTo4Bytes" + xpu_check_line = f"XPU_CHECK(xpu_launch_argument_set(&{arg}, sizeof({arg}), offset));" + offset_increment_line = f"offset += {align_fn}(sizeof({arg}));" + lines.append(f"{xpu_check_line} {offset_increment_line}") + return newline.join(lines) + + params = { + "kernel_name": func_name, + "triton_kernel_name": args.kernel_name, + "bin_size": len(hex_), + "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), + "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]), + "full_signature": ", ".join([f"{ty_to_cpp(signature[i])} {kernel.arg_names[i]}" for i in signature.keys()]), + "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]), + "num_args": len(arg_names), + "kernel_docstring": doc_string, + "shared": ccinfo.metadata.shared, + "num_warps": args.num_warps, + "algo_info": '_'.join([const_sig, meta_sig]), + "gridX": grid[0], + "gridY": grid[1], + "gridZ": grid[2], + "_placeholder": "", + "ewt_data": "", + "nclusters": get_xpu_spec(xpu_arch, is_sdnn)[0], + "ncores": get_xpu_spec(xpu_arch, is_sdnn)[1], + "kernel_type": "KT_SDCDNN" if is_sdnn else "KT_CLUSTER", + "argument_set_code": generate_argument_set_code(), + "load_ewtable_code": "", + "printf_buf_offset": 0, + } + for ext in ['h', 'c']: + template_path = Path(__file__).parent / f"compile_xpu.{ext}" + with out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}").open("w") as fp: + fp.write(Path(template_path).read_text().format(**params)) diff --git a/third_party/xpu/python/triton/tools/compile_xpu.c b/third_party/xpu/python/triton/tools/compile_xpu.c new file mode 100644 index 000000000..ab744ccc6 --- /dev/null +++ b/third_party/xpu/python/triton/tools/compile_xpu.c @@ -0,0 +1,160 @@ +/* clang-format off */ +#include +#include +#include +#include +#include + +#define XPUdeviceptr int64_t + +// helpers to check for xpu errors +#define XPU_CHECK(ans) {{\ + xpuAssert((ans), __FILE__, __LINE__);\ + }}\ + +static inline void xpuAssert(int code, const char *file, int line) {{ + if (code != XPU_SUCCESS) {{ + const char *prefix = "Triton Error [XPU]: "; + const char *str = xpu_strerror(code); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + printf("%s\\n", err); + exit(code); + }} +}} + +static inline uint32_t checksum(const unsigned char *data, size_t length) {{ + uint32_t crc32 = 0; + for (size_t i = 0; i < length; ++i) + crc32 += static_cast(data[i]); + return crc32; +}} + +static inline size_t alignSizeTo4Bytes(size_t size) {{ + return (size + 3) & ~3; +}} + +static inline size_t alignSizeTo8Bytes(size_t size) {{ + return (size + 7) & ~7; +}} + +static inline int min(int a, int b) {{ + return a < b ? a : b; +}} + +// XPU Kernel type +enum kernel_type {{ + KT_CLUSTER = 0, + KT_SDCDNN = 1, +}}; + +// Place of XPU kernel binary +enum kernel_place {{ + KP_CPU = 0, + KP_XPU = 1, +}}; + +// XPU Kernel +struct xpu_kernel {{ + uint32_t type : 16; + uint32_t place : 16; + uint64_t code_addr; + uint32_t code_byte_size; + uint32_t code_pc; + uint32_t param_dword_size; + uint64_t hash; + const char *name; + void *rt_private; + uint64_t printf_buffer_offset; +}}; + +static int __xpu_create_func(XPUFunc *pfunc, int type, uint64_t code_addr, + uint32_t code_bsz, uint32_t code_pc, + uint32_t param_dsz, uint64_t hash, + const char *name, bool on_xpu, + uint64_t printf_buf_offset) {{ + if (pfunc == NULL) + return -XPUERR_INVALID_PARAM; + + struct xpu_kernel *kern = new struct xpu_kernel(); + kern->type = type; + kern->place = (on_xpu) ? KP_XPU : KP_CPU; + kern->code_addr = code_addr; + kern->code_byte_size = code_bsz; + kern->code_pc = code_pc; + kern->param_dword_size = param_dsz; + kern->hash = hash; + kern->name = name; + kern->printf_buffer_offset = printf_buf_offset; + + *pfunc = kern; + + return 0; +}} + +// globals +#define XPUBIN_NAME {kernel_name}_xpubin +XPUFunc {kernel_name}_func = NULL; +unsigned char XPUBIN_NAME[{bin_size}] = {{ {bin_data} }}; +void *{kernel_name}_ewt_gptr = NULL; +unsigned char {kernel_name}_ewt_data[] = {{ {ewt_data} }}; + +static inline void *loadEWTable() {{ + void *gmptr = {kernel_name}_ewt_gptr; + if (gmptr) {{ + void *data = (void *)&{kernel_name}_ewt_data; + size_t size = sizeof({kernel_name}_ewt_data); + XPU_CHECK(xpu_malloc((void **)gmptr, size)); + XPU_CHECK(xpu_memcpy(gmptr, data, size, XPU_HOST_TO_DEVICE)); + }} + return gmptr; +}} + +void unload_{kernel_name}(void) {{ + XPU_CHECK(xpu_free({kernel_name}_ewt_gptr)); + {kernel_name}_ewt_gptr = NULL; + delete (struct xpu_kernel *){kernel_name}_func; + {kernel_name}_func = NULL; +}} + +void load_{kernel_name}() {{ + void *bin = (void *)&XPUBIN_NAME; + // Create XPUFunc + int type = {kernel_type}; + uint64_t code_addr = reinterpret_cast(bin); + uint32_t code_byte_size = static_cast({bin_size}); + uint32_t code_pc = 0; + uint32_t param_dword_size = 0; + uint32_t hash = checksum(XPUBIN_NAME, {bin_size}); + bool on_xpu = false; + + XPU_CHECK(__xpu_create_func(&{kernel_name}_func, type, code_addr, + code_byte_size, code_pc, param_dword_size, hash, + "{kernel_name}", on_xpu, {printf_buf_offset})); +}} + +/* +{kernel_docstring} +*/ +int {kernel_name}(XPUStream stream, {signature}) {{ + if ({kernel_name}_func == NULL) + load_{kernel_name}(); + + unsigned int gridX = {gridX}; + unsigned int gridY = {gridY}; + unsigned int gridZ = {gridZ}; + if(gridX * gridY * gridZ > 0) {{ + size_t offset = 0; + {argument_set_code} + {load_ewtable_code} + XPU_CHECK(xpu_launch_argument_set(&gridX, sizeof(gridX), offset+0)); + XPU_CHECK(xpu_launch_argument_set(&gridY, sizeof(gridY), offset+4)); + XPU_CHECK(xpu_launch_argument_set(&gridZ, sizeof(gridZ), offset+8)); + XPU_CHECK(xpu_launch_config(min(gridX*gridY*gridZ, {nclusters}), {ncores})); + XPU_CHECK(xpu_launch_async({kernel_name}_func)); + return 0; + }} + + return -1; +}} diff --git a/third_party/xpu/python/triton/tools/compile_xpu.h b/third_party/xpu/python/triton/tools/compile_xpu.h new file mode 100644 index 000000000..e828287d1 --- /dev/null +++ b/third_party/xpu/python/triton/tools/compile_xpu.h @@ -0,0 +1,17 @@ +#ifndef TT_KERNEL_INCLUDES +#define TT_KERNEL_INCLUDES + +#include +#include +#include + +#define CUresult int +#define CUstream XPUStream +#define CUdeviceptr int64_t + +#endif + +void unload_{kernel_name}(void); +void load_{kernel_name}(void); +// tt-linker: {kernel_name}:{full_signature}:{algo_info} +CUresult{_placeholder} {kernel_name}(CUstream stream, {signature}); diff --git a/third_party/xpu/python/triton/tools/disasm.py b/third_party/xpu/python/triton/tools/disasm.py new file mode 100644 index 000000000..1e309a2e4 --- /dev/null +++ b/third_party/xpu/python/triton/tools/disasm.py @@ -0,0 +1,142 @@ +# MIT License + +# Copyright (c) 2020 Da Yan @ HKUST + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import functools +import os +import re +import subprocess +import tempfile + +from ..common.backend import path_to_cuobjdump, path_to_nvdisasm + +FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') +SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') +FNAME_RE = re.compile(r'\s*Function : (\w+)\s*') +BRA_RE = re.compile(r'(.*BRA(?:\.U)? )(0x\w+);') + + +def parseCtrl(sline): + enc = int(SLINE_RE.match(sline).group(1), 16) + stall = (enc >> 41) & 0xf + yld = (enc >> 45) & 0x1 + wrtdb = (enc >> 46) & 0x7 + readb = (enc >> 49) & 0x7 + watdb = (enc >> 52) & 0x3f + + yld_str = 'Y' if yld == 0 else '-' + wrtdb_str = '-' if wrtdb == 7 else str(wrtdb) + readb_str = '-' if readb == 7 else str(readb) + watdb_str = '--' if watdb == 0 else f'{watdb:02d}' + return f'{watdb_str}:{readb_str}:{wrtdb_str}:{yld_str}:{stall:x}' + + +def processSassLines(fline, sline, labels): + asm = FLINE_RE.match(fline).group(1) + # Remove tailing space + if asm.endswith(" ;"): + asm = asm[:-2] + ";" + ctrl = parseCtrl(sline) + # BRA target address + if BRA_RE.match(asm) is not None: + target = int(BRA_RE.match(asm).group(2), 16) + if target in labels: + pass + else: + labels[target] = len(labels) + return (f'{ctrl}', f'{asm}') + + +@functools.lru_cache() +def get_sass(cubin_asm, fun=None): + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(cubin_asm) + sass = extract(path, fun) + finally: + os.remove(path) + return sass + + +def extract(file_path, fun): + cuobjdump, _ = path_to_cuobjdump() + nvdisasm, _ = path_to_nvdisasm() + os.environ["NVDISASM_PATH"] = nvdisasm + if fun is None: + sass_str = subprocess.check_output([cuobjdump, "-sass", file_path]) + else: + sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path]) + sass_lines = sass_str.splitlines() + line_idx = 0 + while line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + # format: + # function : + # .headerflags: ... + # /*0000*/ asmstr /*0x...*/ + # /*0x...*/ + + # Looking for new function header (function: ) + while FNAME_RE.match(line) is None: + line_idx += 1 + if line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + else: + return + + fname = FNAME_RE.match(line).group(1) + ret = '' + ret += f'Function:{fname}\n' + line_idx += 2 # bypass .headerflags + line = sass_lines[line_idx].decode() + # Remapping address to label + labels = {} # address -> label_idx + # store sass asm in buffer and them print them (for labels) + # (ctrl, asm) + asm_buffer = [] + while FLINE_RE.match(line) is not None: + # First line (Offset ASM Encoding) + fline = sass_lines[line_idx].decode() + line_idx += 1 + # Second line (Encoding) + sline = sass_lines[line_idx].decode() + line_idx += 1 + asm_buffer.append(processSassLines(fline, sline, labels)) + # peek the next line + line = sass_lines[line_idx].decode() + # Print sass + # label naming convention: LBB#i + for idx, (ctrl, asm) in enumerate(asm_buffer): + # Print label if this is BRA target + offset = idx * 16 + if offset in labels: + label_name = f'LBB{labels[offset]}' + ret += f'{label_name}:\n' + ret += ctrl + '\t' + # if this is BRA, remap offset to label + if BRA_RE.match(asm): + target = int(BRA_RE.match(asm).group(2), 16) + target_name = f'LBB{labels[target]}' + asm = BRA_RE.sub(rf'\1{target_name};', asm) + ret += asm + '\n' + ret += '\n' + return ret diff --git a/third_party/xpu/python/triton/tools/link.py b/third_party/xpu/python/triton/tools/link.py new file mode 100644 index 000000000..350e16ab0 --- /dev/null +++ b/third_party/xpu/python/triton/tools/link.py @@ -0,0 +1,327 @@ +from collections import defaultdict +from pathlib import Path +from typing import Sequence, Union + +from dataclasses import dataclass + + +def _exists(x): + return x is not None + + +class LinkerError(Exception): + pass + + +@dataclass +class KernelLinkerMeta: + orig_kernel_name: str + arg_names: Sequence[str] + arg_ctypes: Sequence[str] + sizes: Sequence[Union[int, None]] + sig_hash: str + triton_suffix: str + suffix: str + num_specs: int + """ number of specialized arguments """ + + +class HeaderParser: + + def __init__(self) -> None: + import re + + # [kernel_name, c signature] + self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)") + # [name, hash, suffix] + self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$") + # [(type, name)] + self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?") + # [d|c] + self.arg_suffix = re.compile("[c,d]") + + self.kernels = defaultdict(list) + + def extract_linker_meta(self, header: str): + for ln in header.splitlines(): + if ln.startswith("//"): + m = self.linker_directives.match(ln) + if _exists(m): + ker_name, c_sig, algo_info = m.group(1), m.group(2), m.group(3) + name, sig_hash, suffix = self._match_name(ker_name) + c_types, arg_names = self._match_c_sig(c_sig) + num_specs, sizes = self._match_suffix(suffix, c_sig) + self._add_kernel( + "_".join([name, algo_info]), + KernelLinkerMeta( + orig_kernel_name=name, + arg_names=arg_names, + arg_ctypes=c_types, + sizes=sizes, + sig_hash=sig_hash, + triton_suffix=suffix, + suffix=suffix, + num_specs=num_specs, + ), + ) + + def _match_name(self, ker_name: str): + m = self.kernel_name.match(ker_name) + if _exists(m): + name, sig_hash, suffix = m.group(1), m.group(2), m.group(3) + return name, sig_hash, suffix + raise LinkerError(f"{ker_name} is not a valid kernel name") + + def _match_c_sig(self, c_sig: str): + m = self.c_sig.findall(c_sig) + if len(m): + tys, args = [], [] + for ty, arg_name in m: + tys.append(ty) + args.append(arg_name) + return tys, args + + raise LinkerError(f"{c_sig} is not a valid argument signature") + + def _match_suffix(self, suffix: str, c_sig: str): + args = c_sig.split(",") + s2i = {"c": 1, "d": 16} + num_specs = 0 + sizes = [] + # scan through suffix, first find the index, + # then see if it is followed by d or c + for i in range(len(args)): + pos = suffix.find(str(i)) + if pos == -1: + raise LinkerError(f"{suffix} is not a valid kernel suffix") + pos += len(str(i)) + if self.arg_suffix.match(suffix, pos): + num_specs += 1 + sizes.extend([None] * (i - len(sizes))) + sizes.append(s2i[suffix[pos]]) + pos += 1 + if i < len(args) - 1: + suffix = suffix[pos:] + else: + sizes.extend([None] * (len(args) - len(sizes))) + return num_specs, sizes + + def _add_kernel(self, name: str, ker: KernelLinkerMeta): + if name in self.kernels: + last: KernelLinkerMeta = self.kernels[name][-1] + + for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes): + if cur != new_: + raise LinkerError( + f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}" + ) + + self.kernels[name].append(ker) + + +def gen_signature_with_full_args(m): + return ", ".join([f"{ty} {arg}" for ty, arg in zip(m.arg_ctypes, m.arg_names)]) + + +def gen_signature(m): + arg_types = [ty for ty, hint in zip(m.arg_ctypes, m.sizes) if hint != 1] + arg_names = [arg for arg, hint in zip(m.arg_names, m.sizes) if hint != 1] + sig = ", ".join([f"{ty} {arg}" for ty, arg in zip(arg_types, arg_names)]) + return sig + + +# generate declarations of kernels with meta-parameter and constant values +def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + return f""" +int {name}(XPUStream stream, {gen_signature_with_full_args(metas[-1])}); +void load_{name}(); +void unload_{name}(); + """ + + +# generate declarations of kernels with meta-parameter and constant values +def make_global_decl(meta: KernelLinkerMeta) -> str: + return f""" +int {meta.orig_kernel_name}_default(XPUStream stream, {gen_signature_with_full_args(meta)}); +int {meta.orig_kernel_name}(XPUStream stream, {gen_signature_with_full_args(meta)}, int algo_id); +void load_{meta.orig_kernel_name}(); +void unload_{meta.orig_kernel_name}(); + """ + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_default_algo_kernel(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_default(XPUStream stream, {gen_signature_with_full_args(meta)}){{\n" + src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different integer value hints +def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str: + src = f"// launcher for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"int {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(XPUStream stream, {gen_signature(meta)});\n" + src += "\n" + + src += (f"int {name}(XPUStream stream, {gen_signature_with_full_args(metas[-1])}){{") + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + cond_fn = ( # + lambda val, hint: f"({val} % {hint} == 0)" # + if hint == 16 # + else f"({val} == {hint})" # + if hint == 1 # + else None) + conds = " && ".join([ # + cond_fn(val, hint) # + for val, hint in zip(meta.arg_names, meta.sizes) # + if hint is not None + ]) + src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n" + ) # Edge case where no specializations hence no dispatching required + arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1] + src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n" + src += "\n" + src += " return XPUERR_INVALID_PARAM;\n" + src += "}\n" + + for mode in ["load", "unload"]: + src += f"\n// {mode} for: {name}\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += f"void {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n" + src += f"void {mode}_{name}() {{" + src += "\n" + for meta in sorted(metas, key=lambda m: -m.num_specs): + src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n") + src += "}\n" + return src + + +# generate dispatcher function for kernels with different meta-parameter and constant values +def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}(XPUStream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n" + src += f" assert (algo_id < (int)sizeof({meta.orig_kernel_name}_kernels));\n" + src += f" return {meta.orig_kernel_name}_kernels[algo_id](stream, {', '.join(meta.arg_names)});\n" + src += "}\n" + return src + + +# generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values +def make_func_pointers(names: str, meta: KernelLinkerMeta) -> str: + # the table of hint dispatchers + src = f"typedef int (*kernel_func_t)(XPUStream stream, {gen_signature_with_full_args(meta)});\n" + src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n" + for name in names: + src += f" {name},\n" + src += "};\n" + return src + + +# generate definition for load/unload functions for kernels with different meta-parameter and constant values +def make_kernel_load_def(names: str, meta: KernelLinkerMeta) -> str: + src = "" + for mode in ["load", "unload"]: + src += f"void {mode}_{meta.orig_kernel_name}(void){{\n" + for name in names: + src += f" {mode}_{name}();\n" + src += "}\n\n" + return src + + +def make_get_num_algos_decl(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void);" + return src + + +def make_get_num_algos_def(meta: KernelLinkerMeta) -> str: + src = f"int {meta.orig_kernel_name}_get_num_algos(void){{\n" + src += f" return (int)(sizeof({meta.orig_kernel_name}_kernels) / sizeof({meta.orig_kernel_name}_kernels[0]));\n" + src += "}\n" + return src + + +desc = """ +Triton ahead-of-time linker: + +This program takes in header files generated by compile.py, and generates a +single entry-point responsible for dispatching the user's input to the right +kernel given the specializations that were compiled. + +Example usage: +python link.py /path/to/headers/*.h -o kernel_name +""" + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser(description=desc) + parser.add_argument( + "headers", + nargs="+", + help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)", + ) + parser.add_argument("--out", "-o", type=Path, help="Out filename") + parser.add_argument( + "--prefix", + type=str, + default="", + help="String to prefix kernel dispatcher names", + ) + args = parser.parse_args() + + # metadata + parser = HeaderParser() + includes = [] + for header in args.headers: + h_path = Path(header) + h_str = h_path.read_text() + includes.append(h_path.name) + parser.extract_linker_meta(h_str) + + # generate headers + algo_decls = [make_algo_decls(name, meta) for name, meta in parser.kernels.items()] + meta_lists = [meta for name, meta in parser.kernels.items()] + meta = meta_lists[0][0] + get_num_algos_decl = make_get_num_algos_decl(meta) + global_decl = make_global_decl(meta) + with args.out.with_suffix(".h").open("w") as fp: + out = "#include \n" + out += "#define XPUdeviceptr int64_t\n" + out += "#define CUdeviceptr int64_t\n" + out += "#define CUresult int\n" + out += "#define CUDA_SUCCESS 0\n" + out += "\n".join(algo_decls) + out += "\n" + out += get_num_algos_decl + out += "\n" + out += global_decl + fp.write(out) + + # generate source + defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()] + names = [name for name in parser.kernels.keys()] + func_pointers_def = make_func_pointers(names, meta) + meta_const_def = make_kernel_meta_const_dispatcher(meta) + load_unload_def = make_kernel_load_def(names, meta) + get_num_algos_def = make_get_num_algos_def(meta) + default_algo_kernel = make_default_algo_kernel(meta) + with args.out.with_suffix(".c").open("w") as fp: + out = "" + out += "#include \n" + out += "#define XPUdeviceptr int64_t\n" + out += "#include \n" + out += "#include \n" + out += "\n" + out += "\n".join(defs) + out += "\n" + out += func_pointers_def + out += "\n" + out += get_num_algos_def + out += "\n" + out += meta_const_def + out += "\n" + out += load_unload_def + out += "\n" + out += default_algo_kernel + fp.write(out) diff --git a/third_party/xpu/triton_xpu.cc b/third_party/xpu/triton_xpu.cc new file mode 100644 index 000000000..925d89172 --- /dev/null +++ b/third_party/xpu/triton_xpu.cc @@ -0,0 +1,340 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2025 by Kunlunxin. All rights reserved. +// +//===----------------------------------------------------------------------===// +#include +#include +#include + +// clang-format off +#include "mlir/Dialect/Linalg/Passes.h" // mlir::createLinalgElementwiseOpFusionPass() +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // LLVM::LLVMFuncOp +#include "mlir/Dialect/MemRef/Transforms/Passes.h" // mlir::memref::createExpandStridedMetadataPass +#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" + +#include "llvm/ADT/SmallVector.h" // llvm::SmallVector +#include "llvm/IR/CallingConv.h" +#include "llvm/IR/Constants.h" // llvm::ConstantInt +#include "llvm/IR/LLVMContext.h" // llvm::LLVMContext + +#include "passes.h" + +// mlir::triton::createTritonToLinalgExperimentalPass + +#include "triton/Conversion/TritonToTritonXPU/Passes.h" // mlir::triton::createConvertTritonToTritonXPUPass +#include "triton/Conversion/TritonXPUToLLVM/Passes.h" // mlir::triton::createConvertTritonXPUToLLVMPass +#include "triton/Dialect/TritonXPU/IR/Dialect.h" // mlir::triton::xpu::TritonXPUDialect +#include "triton/Dialect/TritonXPU/Transforms/Passes.h" // mlir::createTritonXPUGM2LMPass() + +#include "triton/Target/LLVMXPU/LLVMXPUToLLVMIRTranslation.h" // registerLLVMXPUDialectTranslation +// clang-format on + +namespace py = pybind11; + +std::string translateLLVMIRToASM(llvm::Module &module, + const std::string &triple, + const std::string &proc, + const std::string &features, + const std::vector &flags, + bool enable_fp_fusion, bool isObject); + +void init_triton_xpu_passes_conversion(py::module &&m) { + m.def("add_convert_triton_to_tritonxpu_pass", + [](mlir::PassManager &self, uint32_t xpu_arch, uint32_t buffer_size, + uint32_t core_num) { + self.addPass(mlir::triton::createConvertTritonToTritonXPUPass( + xpu_arch, buffer_size, core_num)); + }); + + m.def("add_convert_tritonxpu_to_llvm_pass", + [](mlir::PassManager &self, uint32_t xpu_arch, uint32_t buffer_size) { + self.addPass(mlir::triton::createConvertTritonXPUToLLVMPass( + xpu_arch, buffer_size)); + }); +} + +void init_triton_xpu_passes_transform(py::module &&m) { + // Function Pass + m.def("add_tritonxpu_gm2lm_pass", + [](mlir::PassManager &self, uint32_t xpu_arch, bool atomicSim) { + self.addPass(mlir::triton::xpu::createTritonXPUCreateGM2LM( + {xpu_arch, atomicSim})); + }); + + m.def("add_tritonxpu_legalize_pass", + [](mlir::PassManager &self, uint32_t buffer_size, uint32_t core_num) { + self.addPass(mlir::triton::xpu::createTritonXPULegalize( + {buffer_size, core_num})); + }); + + m.def("add_tritonxpu_mask_pass", [](mlir::PassManager &self) { + self.addPass(mlir::triton::xpu::createTritonXPUMask()); + }); + + m.def("add_tritonxpu_alloca_pass", [](mlir::PassManager &self) { + self.addPass(mlir::triton::xpu::createTritonXPUAlloca()); + }); + + m.def("add_tritonxpu_dtype_convert_pass", [](mlir::PassManager &self, + uint32_t xpu_arch) { + self.addPass(mlir::triton::xpu::createTritonXPUDtypeConvert({xpu_arch})); + }); + + m.def("add_tritonxpu_loop_grid_pass", [](mlir::PassManager &self) { + self.addPass(mlir::triton::xpu::createTritonXPULoopGrid()); + }); + + m.def("add_tritonxpu_unroll_control_pass", [](mlir::PassManager &self) { + self.addPass(mlir::triton::xpu::createTritonXPUUnrollControl()); + }); + + m.def("add_tritonxpu_other_sim_pass", [](mlir::PassManager &self) { + self.addPass(mlir::triton::xpu::createTritonXPUOtherSim()); + }); + + // Optimization Pass + m.def("add_tritonxpu_offset_state_pass", [](mlir::PassManager &self, + bool dump_flag) { + self.addPass(mlir::triton::xpu::createTritonXPUOffsetAnalysis({dump_flag})); + }); + + m.def("add_tritonxpu_core_tiling_pass", + [](mlir::PassManager &self, bool dump_flag, uint32_t buffer_size) { + self.addPass(mlir::triton::xpu::createTritonXPUCoreTiling( + {dump_flag, buffer_size})); + }); + + m.def("add_tritonxpu_vectorize_pass", [](mlir::PassManager &self, + bool dump_flag) { + self.addPass(mlir::triton::xpu::createTritonXPUVectorize({dump_flag})); + }); + + m.def("add_tritonxpu_memory_async_pass", [](mlir::PassManager &self, + bool dump_flag) { + self.addPass(mlir::triton::xpu::createTritonXPUMemoryAsync({dump_flag})); + }); + + m.def("add_tritonxpu_interleave_pass", [](mlir::PassManager &self) { + self.addPass(mlir::triton::xpu::createTritonXPUInterleave()); + }); + + m.def("add_tritonxpu_store_control_pass", [](mlir::PassManager &self) { + self.addPass(mlir::triton::xpu::createTritonXPUStoreControl()); + }); +} + +namespace mlir::triton::xpu { + +// Describes XPU Metadata. It is used to record the XPU related meta +// information from mlir module. +struct XPUMetadata { + int maxntidx{-1}; + bool isKernel{}; + // Free to extend with other information. +}; + +static void +extractXPUMetadata(mlir::ModuleOp module, + llvm::DenseMap *dic) { + for (auto op : module.getOps()) { + XPUMetadata meta; + + bool hasMetadata{}; + + // maxntid + if (op->hasAttr("xpu.maxntid")) { + auto attr = op->getAttr("xpu.maxntid"); + meta.maxntidx = mlir::dyn_cast(attr).getInt(); + hasMetadata = true; + } + + // kernel + if (op->hasAttr("xpu.kernel")) { + meta.isKernel = true; + hasMetadata = true; + } + + if (hasMetadata) + dic->try_emplace(op.getNameAttr().strref(), std::move(meta)); + } +} + +// Add the xpu related metadata to LLVM IR. +static void amendLLVMFunc(llvm::Function *func, const XPUMetadata &metadata, + int xpu_arch) { + auto *module = func->getParent(); + auto &ctx = func->getContext(); + auto targetArch = std::string("xpu") + std::to_string(xpu_arch); + + if (metadata.maxntidx > 0) { + auto warps = llvm::ConstantInt::get(llvm::IntegerType::get(ctx, 32), + llvm::APInt(32, metadata.maxntidx)); + + llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func), + llvm::MDString::get(ctx, "maxntidx"), + llvm::ValueAsMetadata::get(warps)}; + + module->getOrInsertNamedMetadata("xpu.annotations") + ->addOperand(llvm::MDNode::get(ctx, md_args)); + } + + if (metadata.isKernel) { + func->setDSOLocal(true); + func->setCallingConv(llvm::CallingConv::XPU_KERNEL); + llvm::Metadata *mdArgs[] = { + llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"), + llvm::ValueAsMetadata::get( + llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))}; + module->getOrInsertNamedMetadata("xpu.annotations") + ->addOperand(llvm::MDNode::get(ctx, mdArgs)); + } + + llvm::AttrBuilder funcAttrs(ctx); + funcAttrs.addAttribute("correctly-rounded-divide-sqrt-fp-math", "false"); + funcAttrs.addAttribute("disable-tail-calls", "false"); + funcAttrs.addAttribute("frame-pointer", "all"); + funcAttrs.addAttribute("less-precise-fpmad", "false"); + funcAttrs.addAttribute("min-legal-vector-width", "0"); + funcAttrs.addAttribute("no-infs-fp-math", "false"); + funcAttrs.addAttribute("no-jump-tables", "false"); + funcAttrs.addAttribute("no-nans-fp-math", "false"); + funcAttrs.addAttribute("no-signed-zeros-fp-math", "false"); + funcAttrs.addAttribute("no-trapping-math", "false"); + funcAttrs.addAttribute("stack-protector-buffer-size", "8"); + funcAttrs.addAttribute("target-cpu", targetArch); + funcAttrs.addAttribute("unsafe-fp-math", "false"); + funcAttrs.addAttribute("use-soft-float", "false"); + func->addFnAttrs(funcAttrs); +} + +} // namespace mlir::triton::xpu + +using ret = py::return_value_policy; + +void init_triton_xpu_llvm(py::module &&m) { + + m.def("get_kernel_name", [](llvm::Module &mod) { + for (auto &F : mod) { + if (F.getCallingConv() == llvm::CallingConv::XPU_KERNEL) { + std::string name = F.getName().str(); + return py::str(name); + } + } + + auto MD = mod.getNamedMetadata("xpu.annotations"); + std::string name; + for (auto *Op : MD->operands()) { + if (Op->getNumOperands() != 3) + continue; + auto *Prop = llvm::dyn_cast(Op->getOperand(1)); + name = Prop->getString(); + } + return py::str(name); + }); + + m.def("amend_func", [](llvm::Module *llvm_mod, mlir::ModuleOp mlir_mod, + llvm::LLVMContext &ctx, int xpu_arch) { + llvm::DenseMap XPUMetadata; + extractXPUMetadata(mlir_mod, &XPUMetadata); + + for (auto &func : llvm_mod->functions()) { + auto it = XPUMetadata.find(func.getName()); + if (it != XPUMetadata.end()) + mlir::triton::xpu::amendLLVMFunc(&func, it->second, xpu_arch); + } + }); + + m.def("need_extern_lib", [](mlir::ModuleOp module) { + llvm::SmallVector funcs; + module.walk([&](mlir::LLVM::LLVMFuncOp func) { + if (func.isExternal()) + funcs.push_back(func); + }); + + return funcs.empty() ? false : true; + }); + + m.def( + "translate_to_asm", + [](llvm::Module &module, std::string triple, std::string proc, + std::string features, std::vector flags, + bool enable_fp_fusion, bool isObject) -> py::object { + std::string obj; + { + // when allow_threads goes out of scope, gil will be released + py::gil_scoped_release allow_threads; + // create LLVM module from C++ + obj = translateLLVMIRToASM(module, triple, proc, features, flags, + enable_fp_fusion, isObject); + } + if (isObject) + return py::bytes(obj); + else +#if !defined(TRITON_CONCEAL_IR) || (TRITON_CONCEAL_IR == 0) + return py::str(obj); +#else + return py::str(""); +#endif + }, + ret::take_ownership); +} + +void init_triton_xpu(py::module &&m) { + m.doc() = "Python bindings to the XPU Triton backend"; + + auto passes = m.def_submodule("passes"); + init_triton_xpu_passes_conversion(passes.def_submodule("ttxpuir")); + init_triton_xpu_passes_transform(passes.def_submodule("ttxpuir")); + init_triton_xpu_llvm(m.def_submodule("llvm")); + + // load dialects + m.def("load_dialects", [](mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registry.insert(); + registerLLVMXPUDialectTranslation(registry); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + struct LutInfo { + int dtype; // 0 - f32; 1 - f16 + int mode; // 0 - KB; 1 - INTER + int size; + double min; + double interval; + }; + + m.def("get_buffer_len", [](mlir::ModuleOp &mod, unsigned maxBufferSize) { + unsigned bufferLen = maxBufferSize; + + auto _get_buffer_len = [&](mlir::Type &ptrTy, unsigned maxBufferSize, + unsigned &bufferLen) { + mlir::Type ptrdataTy; + if (auto ptrTensorTy = mlir::dyn_cast(ptrTy)) { + ptrdataTy = + mlir::cast(ptrTensorTy.getElementType()) + .getPointeeType(); + } else { + ptrdataTy = + mlir::cast(ptrTy).getPointeeType(); + } + if (ptrdataTy.isBF16() || ptrdataTy.isF16() || ptrdataTy.isF32()) { + unsigned bitWidth = ptrdataTy.getIntOrFloatBitWidth(); + bufferLen = std::min(bufferLen, maxBufferSize / (bitWidth / 8)); + } + }; + + mod.walk([&](mlir::triton::LoadOp loadOp) { + auto ptrTy = loadOp.getPtr().getType(); + _get_buffer_len(ptrTy, maxBufferSize, bufferLen); + }); + mod.walk([&](mlir::triton::StoreOp storeOp) { + auto ptrTy = storeOp.getPtr().getType(); + _get_buffer_len(ptrTy, maxBufferSize, bufferLen); + }); + return bufferLen; + }); +} diff --git a/unittest/Analysis/UtilityTest.cpp b/unittest/Analysis/UtilityTest.cpp index e6c5054e4..70d95363e 100644 --- a/unittest/Analysis/UtilityTest.cpp +++ b/unittest/Analysis/UtilityTest.cpp @@ -1,9 +1,6 @@ -//===- UtilityTest.cpp - Tests for -// Utility----------------------------------===// -// -//===----------------------------------------------------------------------===// - #include "triton/Dialect/Triton/IR/Utility.h" + +#include "llvm/Support/Signals.h" #include namespace mlir { @@ -27,3 +24,9 @@ TEST(Analysis, reorder) { } } // namespace mlir + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/unittest/CMakeLists.txt b/unittest/CMakeLists.txt index 4bb61af61..b4061e90c 100644 --- a/unittest/CMakeLists.txt +++ b/unittest/CMakeLists.txt @@ -1,43 +1,3 @@ -include (${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake) - -include(GoogleTest) -enable_testing() - -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) - -function(add_triton_ut) - set(options) - set(oneValueArgs NAME) - set(multiValueArgs SRCS LIBS DEFS) - cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - add_test(NAME ${__NAME} - COMMAND ${__NAME}) - add_executable( - ${__NAME} - ${__SRCS}) - target_link_libraries( - ${__NAME} - PRIVATE - GTest::gtest_main - ${triton_libs} - ${dialect_libs} - ${conversion_libs} - gmock - ${__LIBS}) - - target_compile_options(${__NAME} PRIVATE -fno-rtti) - - target_compile_definitions(${__NAME} PRIVATE ${__DEFS}) - - # Without the TEST_DISCOVERY_TIMEOUT, the tests randomly time out on my mac - # laptop. I think the issue may be that the very first time you run a program - # it's a bit slow. - gtest_discover_tests(${__NAME} PROPERTIES TEST_DISCOVERY_TIMEOUT 60) -endfunction() - add_subdirectory(Analysis) -add_subdirectory(Conversion) add_subdirectory(Dialect) add_subdirectory(Tools) diff --git a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt b/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt deleted file mode 100644 index 3c5692a62..000000000 --- a/unittest/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -add_triton_ut( - NAME TestPtxAsmFormat - SRCS PTXAsmFormatTest.cpp - LIBS TritonGPUToLLVM TritonNVIDIAGPUToLLVM -) - -add_triton_ut( - NAME TestEmitIndicesNvidia - SRCS EmitIndicesTest.cpp DumpLayout.cpp - LIBS TritonGPUIR TritonNvidiaGPUIR TritonNVIDIAGPUToLLVM - DEFS NVIDIA_TARGET=1 -) - -add_triton_ut( - NAME TestEmitIndicesAMD - SRCS EmitIndicesTest.cpp DumpLayout.cpp - LIBS TritonGPUIR TritonAMDGPUToLLVM - DEFS AMD_TARGET=1 -) diff --git a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp deleted file mode 100644 index 628ab454a..000000000 --- a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp +++ /dev/null @@ -1,421 +0,0 @@ -/* - * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files - * (the "Software"), to deal in the Software without restriction, - * including without limitation the rights to use, copy, modify, merge, - * publish, distribute, sublicense, and/or sell copies of the Software, - * and to permit persons to whom the Software is furnished to do so, - * subject to the following conditions: - * - * The above copyright notice and this permission notice shall be - * included in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - */ - -#include "DumpLayout.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" -#ifdef AMD_TARGET -#include "amd/lib/TritonAMDGPUToLLVM/TargetInfo.h" -#include "amd/lib/TritonAMDGPUToLLVM/Utility.h" -#else -#include "nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h" -#include "nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h" -#endif -namespace mlir { -namespace triton { -namespace gpu { - -namespace { - -#ifdef AMD_TARGET -Value getMockSmemBaseImpl([[maybe_unused]] IRRewriter &rewriter, - [[maybe_unused]] Location loc) { - return i32_val(0); -} -#else -Value getMockSmemBaseImpl(IRRewriter &rewriter, Location loc) { - Value mockSmemBase = - LLVM::NVIDIA::getSRegValue(rewriter, loc, "%mock_smem_base"); - auto llPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); - auto cast = rewriter.create( - loc, TypeRange{llPtrTy}, ValueRange{mockSmemBase}); - return cast.getResult(0); -} -#endif - -//===----------------------------------------------------------------------===// -// IndexEmitter -//===----------------------------------------------------------------------===// - -class IndexEmitter { -public: - IndexEmitter(MLIRContext *context_) - : context(context_), option(context), rewriter(context), - loc(UnknownLoc::get(context)), -#ifdef AMD_TARGET - targetInfo("gfx942") -#else - targetInfo(90) -#endif - { - mlir::OpBuilder builder(context); - std::vector inTypes{}; - std::vector outTypes{}; - auto funcTy = builder.getFunctionType(inTypes, outTypes); - auto func = builder.create(loc, "test_func", funcTy); - auto mlirModule = mlir::ModuleOp::create(loc); - mlirModule.push_back(func); - auto *block = func.addEntryBlock(); - rewriter.setInsertionPointToStart(block); - } - - llvm::SmallVector> - emitIndices(Attribute layout, llvm::ArrayRef shape, - bool withCTAOffset) { - auto type = RankedTensorType::get(shape, rewriter.getF16Type(), layout); - return mlir::emitIndices(loc, rewriter, targetInfo, layout, type, - withCTAOffset); - } - - llvm::DenseMap - emitDistributedToShared(Attribute srcLayout, SharedEncodingAttr sharedLayout, - Type elemTy, llvm::ArrayRef shape, - bool withCTAOffset) { - auto srcTy = RankedTensorType::get(shape, elemTy, srcLayout); - SharedMemoryObject smemObj(getMockSmemBaseImpl(rewriter, loc), elemTy, - shape, sharedLayout.getOrder(), loc, rewriter); - return getSwizzledSharedPtrs(loc, targetInfo, /*inVec=*/1, srcTy, - sharedLayout, elemTy, smemObj, rewriter, - smemObj.offsets, smemObj.strides); - } - -private: - // Non-static members are initialized in declaration order - MLIRContext *context; - LowerToLLVMOptions option; - IRRewriter rewriter; - Location loc; -#ifdef AMD_TARGET - AMD::TargetInfo targetInfo; -#else - NVIDIA::TargetInfo targetInfo; -#endif -}; - -//===----------------------------------------------------------------------===// -// MLIR expression evaluation -//===----------------------------------------------------------------------===// - -int eval(Value value, int ctaid, int tid); - -int evalThreadIdOp(mlir::gpu::ThreadIdOp threadIdOp, int ctaid, int tid) { - auto dim = threadIdOp.getDimension(); - if (dim == mlir::gpu::Dimension::x) - return tid; - else if (dim == mlir::gpu::Dimension::y) - return 0; - else if (dim == mlir::gpu::Dimension::z) - return 0; - else - llvm::report_fatal_error("Invalid thread dim"); - return 0; -} - -int evalInlineAsmOp(mlir::LLVM::InlineAsmOp asmOp, int ctaid, int tid) { - std::string asmStr = asmOp.getAsmString().str(); - if (asmStr.find("%cluster_ctaid.x") != std::string::npos) - return ctaid; - else if (asmStr.find("%cluster_ctaid.y") != std::string::npos) - return 0; - else if (asmStr.find("%cluster_ctaid.z") != std::string::npos) - return 0; - else if (asmStr.find("%cluster_nctaid.x") != std::string::npos) - llvm::report_fatal_error("%cluster_nctaid.x not supported"); - else if (asmStr.find("%cluster_nctaid.y") != std::string::npos) - return 1; - else if (asmStr.find("%cluster_nctaid.z") != std::string::npos) - return 1; - else if (asmStr.find("%mock_smem_base") != std::string::npos) - return 0; - else - llvm::report_fatal_error("Unrecognized ASM string"); - return 0; -} - -int evalGEPOp(mlir::LLVM::GEPOp gepOp, int ctaid, int tid) { - assert(gepOp.getNumOperands() == 2 && "Unrecognized format of GEPOp"); - int base = eval(gepOp.getBase(), ctaid, tid); - int offset = eval(gepOp.getOperand(1), ctaid, tid); - auto llPtrTy = cast(gepOp.getRes().getType()); - int bytesPerElem = llPtrTy.getIntOrFloatBitWidth() / 8; - return base + offset * bytesPerElem; -} - -int eval(Value value, int ctaid, int tid) { - Operation *op = value.getDefiningOp(); - assert(op && "Unrecognized source value in the index expression"); - if (auto constantOp = llvm::dyn_cast(op)) { - auto attr = constantOp.getValue(); - return mlir::cast(attr).getInt(); - } else if (auto addOp = llvm::dyn_cast(op)) { - return eval(addOp.getLhs(), ctaid, tid) + eval(addOp.getRhs(), ctaid, tid); - } else if (auto mulOp = llvm::dyn_cast(op)) { - return eval(mulOp.getLhs(), ctaid, tid) * eval(mulOp.getRhs(), ctaid, tid); - } else if (auto udivOp = llvm::dyn_cast(op)) { - return eval(udivOp.getLhs(), ctaid, tid) / - eval(udivOp.getRhs(), ctaid, tid); - } else if (auto uremOp = llvm::dyn_cast(op)) { - return eval(uremOp.getLhs(), ctaid, tid) % - eval(uremOp.getRhs(), ctaid, tid); - } else if (auto andOp = llvm::dyn_cast(op)) { - return eval(andOp.getLhs(), ctaid, tid) & eval(andOp.getRhs(), ctaid, tid); - } else if (auto xorOp = llvm::dyn_cast(op)) { - return eval(xorOp.getLhs(), ctaid, tid) ^ eval(xorOp.getRhs(), ctaid, tid); - } else if (auto trunciOp = llvm::dyn_cast(op)) { - return eval(trunciOp.getIn(), ctaid, tid); - } else if (auto idxCastOp = llvm::dyn_cast(op)) { - return eval(idxCastOp.getIn(), ctaid, tid); - } else if (auto castOp = llvm::dyn_cast(op)) { - return eval(castOp.getOperand(0), ctaid, tid); - } else if (auto threadOp = llvm::dyn_cast(op)) { - return evalThreadIdOp(threadOp, ctaid, tid); - } else if (auto ctaIdOp = - llvm::dyn_cast(op)) { - return ctaid; - } else if (auto asmOp = llvm::dyn_cast(op)) { - return evalInlineAsmOp(asmOp, ctaid, tid); - } else if (auto gepOp = llvm::dyn_cast(op)) { - return evalGEPOp(gepOp, ctaid, tid); - } else if (auto selectOp = llvm::dyn_cast(op)) { - return eval(selectOp.getCondition(), ctaid, tid) - ? eval(selectOp.getTrueValue(), ctaid, tid) - : eval(selectOp.getFalseValue(), ctaid, tid); - } else if (auto icmpOp = llvm::dyn_cast(op)) { - switch (icmpOp.getPredicate()) { - case mlir::LLVM::ICmpPredicate::eq: - return eval(icmpOp.getLhs(), ctaid, tid) == - eval(icmpOp.getRhs(), ctaid, tid); - default: - llvm::report_fatal_error("Unsupported ICmp predicate"); - } - } else { - llvm::errs() << "Unrecognized op: " << *op << "\n"; - llvm::report_fatal_error("Unrecognized op type in the index expression"); - return 0; - } -} - -} // namespace - -int evalValue(Value value, int ctaid, int tid) { - return eval(value, ctaid, tid); -} - -//===----------------------------------------------------------------------===// -// Dump Distributed Layout -//===----------------------------------------------------------------------===// - -std::string dumpDistributedLayout(Attribute layout, - llvm::ArrayRef shape, - bool multiCTA) { - assert(isaDistributedLayout(layout) && - "Unsupported layout type for dumpDistributedLayout"); - - assert(shape.size() > 0 && "Empty shape"); - assert(shape.size() <= 2 && - "High order tensor is not supported in dumpLayout"); - - int numThreads = getWarpSize(layout) * getNumWarpsPerCTA(layout); - int numCTAs = getNumCTAs(layout); - auto f16Ty = FloatType::getF16(layout.getContext()); - int numElems = getTotalElemsPerThread(layout, shape, f16Ty); - - if (!multiCTA) - assert(numCTAs == 1 && "numCTAs must be 1 when multiCTA is false"); - - IndexEmitter emitter(layout.getContext()); - auto indices = emitter.emitIndices(layout, shape, multiCTA); - assert(indices.size() == numElems && "Incorrect number of indices emitted"); - - auto genStr = [multiCTA](int ctaid, int tid, int idx) -> std::string { - std::ostringstream oss; - if (multiCTA) - oss << "CTA" << ctaid << ":"; - oss << "T" << tid << ":" << idx; - return oss.str(); - }; - - std::ostringstream oss; - - auto dumpLayout1d = [&]() { - for (int idx = 0; idx < numElems; ++idx) - assert(indices[idx].size() == 1 && "Incorrect rank of indices emitted"); - - int size = shape[0]; - std::vector mapping(size); - - for (int ctaid = 0; ctaid < numCTAs; ++ctaid) { - for (int tid = 0; tid < numThreads; ++tid) { - for (int idx = 0; idx < numElems; ++idx) { - int i = eval(indices[idx][0], ctaid, tid); - assert(i >= 0 && i < size && "Invalid index emitted"); - std::string &value = mapping[i]; - if (value.empty()) - value = genStr(ctaid, tid, idx); - else - value = value + "|" + genStr(ctaid, tid, idx); - } - } - } - - for (int i = 0; i < size; ++i) { - if (i > 0) - oss << ","; - oss << mapping[i]; - } - oss << "\n"; - }; - - auto dumpLayout2d = [&]() { - for (int idx = 0; idx < numElems; ++idx) - assert(indices[idx].size() == 2 && "Incorrect rank of indices emitted"); - - int row = shape[0], col = shape[1]; - std::vector> mapping( - row, std::vector(col)); - - for (int ctaid = 0; ctaid < numCTAs; ++ctaid) { - for (int tid = 0; tid < numThreads; ++tid) { - for (int idx = 0; idx < numElems; ++idx) { - int r = eval(indices[idx][0], ctaid, tid); - int c = eval(indices[idx][1], ctaid, tid); - assert(r >= 0 && r < row && c >= 0 && c < col && - "Invalid index emitted"); - std::string &value = mapping[r][c]; - if (value.empty()) - value = genStr(ctaid, tid, idx); - else - value = value + "|" + genStr(ctaid, tid, idx); - } - } - } - - for (int r = 0; r < row; ++r) { - for (int c = 0; c < col; ++c) { - if (c > 0) - oss << ","; - oss << mapping[r][c]; - } - oss << "\n"; - } - }; - - if (shape.size() == 1) - dumpLayout1d(); - else - dumpLayout2d(); - - return oss.str(); -} - -//===----------------------------------------------------------------------===// -// Dump Shared Layout -//===----------------------------------------------------------------------===// - -std::string dumpSharedLayout(Attribute layout, llvm::ArrayRef shape, - Type elemTy, bool multiCTA) { - assert(shape.size() == 2 && "Only 2d shape supported in dumpSharedLayout"); - int row = shape[0], col = shape[1]; - int size = row * col; - int bytesPerElem = elemTy.getIntOrFloatBitWidth() / 8; - int totalBytes = size * bytesPerElem; - - int numWarps = 1; - int numThreads = 32 * numWarps; - int numCTAs = getNumCTAs(layout); - - if (!multiCTA) - assert(numCTAs == 1 && "numCTAs must be 1 when multiCTA is false"); - - auto sharedLayout = mlir::cast(layout); - auto blockedLayout = BlockedEncodingAttr::get( - /*context=*/layout.getContext(), /*shape=*/shape, - /*sizePerThread=*/{1, 1}, /*order=*/sharedLayout.getOrder(), - /*numWarps=*/numWarps, 32, /*CTALayout=*/sharedLayout.getCTALayout()); - - int numElems = getTotalElemsPerThread(blockedLayout, shape, elemTy); - - IndexEmitter emitter(layout.getContext()); - auto blockedIndices = emitter.emitIndices(blockedLayout, shape, multiCTA); - auto sharedPtrs = emitter.emitDistributedToShared(blockedLayout, sharedLayout, - elemTy, shape, multiCTA); - - assert(blockedIndices.size() == numElems && - "Incorrect number of indices emitted by blockedLayout"); - assert(sharedPtrs.size() == numElems && - "Incorrect number of pointers emitted by sharedLayout"); - - for (int idx = 0; idx < numElems; ++idx) - assert(blockedIndices[idx].size() == 2 && - "Incorrect rank of indices emitted by blockedLayout"); - - auto genStr = [](int r, int c) -> std::string { - std::ostringstream oss; - oss << "(" << r << ":" << c << ")"; - return oss.str(); - }; - - std::vector mapping(size); - for (int ctaid = 0; ctaid < numCTAs; ++ctaid) { - for (int tid = 0; tid < numThreads; ++tid) { - for (int idx = 0; idx < numElems; ++idx) { - int r = eval(blockedIndices[idx][0], ctaid, tid); - int c = eval(blockedIndices[idx][1], ctaid, tid); - assert(r >= 0 && r < row && c >= 0 && c < col && - "Invalid index emitted"); - int ptr = eval(sharedPtrs[idx], ctaid, tid); - assert(ptr % bytesPerElem == 0 && ptr < totalBytes && - "Invalid pointer emitted"); - std::string &value = mapping[ptr / bytesPerElem]; - if (value.empty()) - value = genStr(r, c); - else - value = value + "|" + genStr(r, c); - } - } - } - - const int bytesPerBank = 4; - const int totalBanks = 32; - const int bytesPerLine = - std::min(col * bytesPerElem, bytesPerBank * totalBanks); - int elemsPerLine = bytesPerLine / bytesPerElem; - - std::ostringstream oss; - - for (int i = 0; i < size; ++i) { - int r = i / elemsPerLine; - int c = i % elemsPerLine; - if (c > 0) - oss << ","; - oss << mapping[i]; - if (c == elemsPerLine - 1) - oss << "\n"; - } - - return oss.str(); -} - -} // namespace gpu -} // namespace triton -} // namespace mlir diff --git a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp deleted file mode 100644 index 9bd7a0aee..000000000 --- a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp +++ /dev/null @@ -1,1469 +0,0 @@ -/* - * Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining - * a copy of this software and associated documentation files - * (the "Software"), to deal in the Software without restriction, - * including without limitation the rights to use, copy, modify, merge, - * publish, distribute, sublicense, and/or sell copies of the Software, - * and to permit persons to whom the Software is furnished to do so, - * subject to the following conditions: - * - * The above copyright notice and this permission notice shall be - * included in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - */ - -#include "gtest/gtest.h" -#include -#include - -#include "DumpLayout.h" -#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" -#include "nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Attributes.h" -#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" - -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "llvm/Support/raw_ostream.h" - -namespace mlir { -namespace triton { -namespace gpu { - -//===----------------------------------------------------------------------===// -// EmitIndicesTest -//===----------------------------------------------------------------------===// - -MLIRContext *getContext() { - static MLIRContext *context = [] { - MLIRContext *context = new MLIRContext(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - return context; - }(); - return context; -} - -class EmitIndicesTest : public ::testing::Test { -protected: - EmitIndicesTest() : context(*getContext()) {} - - void runBlocked1dSingleCTA(int size, unsigned sizePerThread, - unsigned warpsPerCTA, const std::string &refStr) { - // If we pass initializer lists to the constructor of BlockedEncodingAttr, - // there might be multiple constructors matching the same parameter list. - // For example, the initializer list "order = {0}" can also match the - // parameter "unsigned numWarps", which is not what we want - llvm::SmallVector sizePerThread_ = {sizePerThread}; - llvm::SmallVector threadsPerWarp = {32}; - llvm::SmallVector warpsPerCTA_ = {warpsPerCTA}; - llvm::SmallVector order = {0}; - auto layout = - BlockedEncodingAttr::get(&context, sizePerThread_, threadsPerWarp, - warpsPerCTA_, order, getSingleCTALayout1d()); - runDistributed1d(size, layout, /*multiCTA=*/false, refStr); - } - - void runBlocked2dSingleCTA(int row, int col, - llvm::ArrayRef sizePerThread, - llvm::ArrayRef threadsPerWarp, - llvm::ArrayRef warpsPerCTA, - llvm::ArrayRef order, - const std::string &refStr) { - auto layout = - BlockedEncodingAttr::get(&context, sizePerThread, threadsPerWarp, - warpsPerCTA, order, getSingleCTALayout2d()); - runDistributed2d(row, col, layout, /*multiCTA=*/false, refStr); - } - - void runBlocked2dMultiCTA( - int row, int col, llvm::ArrayRef sizePerThread, - llvm::ArrayRef threadsPerWarp, - llvm::ArrayRef warpsPerCTA, llvm::ArrayRef order, - llvm::ArrayRef CTAsPerCGA, llvm::ArrayRef CTASplitNum, - llvm::ArrayRef CTAOrder, const std::string &refStr) { - auto CTALayout = - CTALayoutAttr::get(&context, CTAsPerCGA, CTASplitNum, CTAOrder); - auto layout = BlockedEncodingAttr::get( - &context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout); - runDistributed2d(row, col, layout, /*multiCTA=*/true, refStr); - } - - void runSliceBlockedSingleCTA(int size, - llvm::ArrayRef sizePerThread, - llvm::ArrayRef threadsPerWarp, - llvm::ArrayRef warpsPerCTA, - llvm::ArrayRef order, - unsigned sliceDim, const std::string &refStr) { - auto parent = - BlockedEncodingAttr::get(&context, sizePerThread, threadsPerWarp, - warpsPerCTA, order, getSingleCTALayout2d()); - auto layout = SliceEncodingAttr::get(&context, sliceDim, parent); - runDistributed1d(size, layout, /*multiCTA=*/false, refStr); - } - - void runSliceBlockedMultiCTA(int size, llvm::ArrayRef sizePerThread, - llvm::ArrayRef threadsPerWarp, - llvm::ArrayRef warpsPerCTA, - llvm::ArrayRef order, - llvm::ArrayRef CTAsPerCGA, - llvm::ArrayRef CTASplitNum, - llvm::ArrayRef CTAOrder, - unsigned sliceDim, const std::string &refStr) { - auto CTALayout = - CTALayoutAttr::get(&context, CTAsPerCGA, CTASplitNum, CTAOrder); - auto parent = BlockedEncodingAttr::get( - &context, sizePerThread, threadsPerWarp, warpsPerCTA, order, CTALayout); - auto layout = SliceEncodingAttr::get(&context, sliceDim, parent); - runDistributed1d(size, layout, /*multiCTA=*/true, refStr); - } - - void runMmaSingleCTA(int row, int col, unsigned versionMajor, - unsigned versionMinor, - llvm::ArrayRef warpsPerCTA, - llvm::ArrayRef instrShape, - const std::string &refStr) { - auto layout = NvidiaMmaEncodingAttr::get( - &context, versionMajor, versionMinor, warpsPerCTA, - getSingleCTALayout2d(), instrShape); - runDistributed2d(row, col, layout, /*multiCTA=*/false, refStr); - } - - void runWmmaSingleCTA(int row, int col, llvm::ArrayRef warpsPerCTA, - const std::string &refStr) { - auto layout = - AMDWmmaEncodingAttr::get(&context, warpsPerCTA, getSingleCTALayout2d()); - runDistributed2d(row, col, layout, /*multiCTA=*/false, refStr); - } - - void runDotOpSingleCTA(int row, int col, unsigned versionMajor, - unsigned versionMinor, - llvm::ArrayRef warpsPerCTA, - llvm::ArrayRef instrShape, unsigned opIdx, - const std::string &refStr) { - auto parent = NvidiaMmaEncodingAttr::get( - &context, versionMajor, versionMinor, warpsPerCTA, - getSingleCTALayout2d(), instrShape); - auto layout = DotOperandEncodingAttr::get(&context, opIdx, parent, 0); - runDistributed2d(row, col, layout, /*multiCTA=*/false, refStr); - } - - void runSharedSingleCTA(int row, int col, bool rowMajor, - const std::string &elemTyStr, - const std::string &refStr) { - auto elemTy = getElemTy(elemTyStr); - auto layout = - SharedEncodingAttr::get(&context, {row, col}, getMatrixOrder(rowMajor), - getSingleCTALayout2d(), elemTy); - llvm::outs() << layout << "\n"; - runShared(row, col, layout, elemTy, /*multiCTA=*/false, refStr); - } - -private: - std::string skipSpaces(const std::string &input) { - std::string output; - for (char c : input) - if (c != ' ') - output += c; - return output; - } - - void assertSameStr(const std::string &refStr, const std::string &output) { - if (refStr != output) { - llvm::outs() << "RefStr =\n" - << refStr << "\n" - << "\n" - << "Output =\n" - << output << "\n"; - FAIL() << "Incorrect output string"; - } - } - - void runDistributed1d(int size, Attribute layout, bool multiCTA, - const std::string &refStr) { - assertSameStr(skipSpaces(refStr), - dumpDistributedLayout(layout, {size}, multiCTA)); - } - - void runDistributed2d(int row, int col, Attribute layout, bool multiCTA, - const std::string &refStr) { - assertSameStr(skipSpaces(refStr), - dumpDistributedLayout(layout, {row, col}, multiCTA)); - } - - void runShared(int row, int col, const SharedEncodingAttr &layout, - Type elemTy, bool multiCTA, const std::string &refStr) { - assertSameStr(skipSpaces(refStr), - dumpSharedLayout(layout, {row, col}, elemTy, multiCTA)); - } - - CTALayoutAttr getSingleCTALayout1d() { - return CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1}, - /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); - } - - CTALayoutAttr getSingleCTALayout2d() { - return CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1}, - /*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0}); - } - - llvm::SmallVector getMatrixOrder(bool rowMajor) { - if (rowMajor) - return {1, 0}; - else - return {0, 1}; - } - - Type getElemTy(const std::string &elemTyStr) { - if (elemTyStr == "F16") - return FloatType::getF16(&context); - else - llvm::report_fatal_error("getElemTy not implemented"); - return nullptr; - } - -protected: - MLIRContext &context; -}; - -//===----------------------------------------------------------------------===// -// Tests for BlockedEncodingAttr -//===----------------------------------------------------------------------===// - -TEST_F(EmitIndicesTest, BlockedLayout_SingleCTA_1D) { - // clang-format off - std::string refStr = - "T0:0,T1:0,T2:0,T3:0,T4:0,T5:0,T6:0,T7:0,T8:0,T9:0,T10:0,T11:0,T12:0,T13:0,T14:0,T15:0,T16:0,T17:0,T18:0,T19:0,T20:0,T21:0,T22:0,T23:0,T24:0,T25:0,T26:0,T27:0,T28:0,T29:0,T30:0,T31:0\n"; - // clang-format on - - runBlocked1dSingleCTA(/*size=*/32, /*sizePerThread*/ 1, /*warpsPerCTA*/ 1, - /*refStr=*/refStr); -} - -TEST_F(EmitIndicesTest, BlockedLayout_SingleCTA_Order_1_0) { - // clang-format off - std::string refStr = - " T0:0, T1:0, T2:0, T3:0, T32:0, T33:0, T34:0, T35:0\n" - " T4:0, T5:0, T6:0, T7:0, T36:0, T37:0, T38:0, T39:0\n" - " T8:0, T9:0,T10:0,T11:0, T40:0, T41:0, T42:0, T43:0\n" - "T12:0,T13:0,T14:0,T15:0, T44:0, T45:0, T46:0, T47:0\n" - "T16:0,T17:0,T18:0,T19:0, T48:0, T49:0, T50:0, T51:0\n" - "T20:0,T21:0,T22:0,T23:0, T52:0, T53:0, T54:0, T55:0\n" - "T24:0,T25:0,T26:0,T27:0, T56:0, T57:0, T58:0, T59:0\n" - "T28:0,T29:0,T30:0,T31:0, T60:0, T61:0, T62:0, T63:0\n" - - "T64:0,T65:0,T66:0,T67:0, T96:0, T97:0, T98:0, T99:0\n" - "T68:0,T69:0,T70:0,T71:0, T100:0,T101:0,T102:0,T103:0\n" - "T72:0,T73:0,T74:0,T75:0, T104:0,T105:0,T106:0,T107:0\n" - "T76:0,T77:0,T78:0,T79:0, T108:0,T109:0,T110:0,T111:0\n" - "T80:0,T81:0,T82:0,T83:0, T112:0,T113:0,T114:0,T115:0\n" - "T84:0,T85:0,T86:0,T87:0, T116:0,T117:0,T118:0,T119:0\n" - "T88:0,T89:0,T90:0,T91:0, T120:0,T121:0,T122:0,T123:0\n" - "T92:0,T93:0,T94:0,T95:0, T124:0,T125:0,T126:0,T127:0\n"; - // clang-format on - - runBlocked2dSingleCTA(/*row=*/16, /*col=*/8, /*sizePerThread=*/{1, 1}, - /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{2, 2}, - /*order=*/{1, 0}, /*refStr=*/refStr); -} - -TEST_F(EmitIndicesTest, BlockedLayout_SingleCTA_Order_0_1) { - // clang-format off - std::string refStr = - " T0:0, T8:0,T16:0,T24:0, T64:0, T72:0, T80:0, T88:0\n" - " T1:0, T9:0,T17:0,T25:0, T65:0, T73:0, T81:0, T89:0\n" - " T2:0,T10:0,T18:0,T26:0, T66:0, T74:0, T82:0, T90:0\n" - " T3:0,T11:0,T19:0,T27:0, T67:0, T75:0, T83:0, T91:0\n" - " T4:0,T12:0,T20:0,T28:0, T68:0, T76:0, T84:0, T92:0\n" - " T5:0,T13:0,T21:0,T29:0, T69:0, T77:0, T85:0, T93:0\n" - " T6:0,T14:0,T22:0,T30:0, T70:0, T78:0, T86:0, T94:0\n" - " T7:0,T15:0,T23:0,T31:0, T71:0, T79:0, T87:0, T95:0\n" - - "T32:0,T40:0,T48:0,T56:0, T96:0,T104:0,T112:0,T120:0\n" - "T33:0,T41:0,T49:0,T57:0, T97:0,T105:0,T113:0,T121:0\n" - "T34:0,T42:0,T50:0,T58:0, T98:0,T106:0,T114:0,T122:0\n" - "T35:0,T43:0,T51:0,T59:0, T99:0,T107:0,T115:0,T123:0\n" - "T36:0,T44:0,T52:0,T60:0, T100:0,T108:0,T116:0,T124:0\n" - "T37:0,T45:0,T53:0,T61:0, T101:0,T109:0,T117:0,T125:0\n" - "T38:0,T46:0,T54:0,T62:0, T102:0,T110:0,T118:0,T126:0\n" - "T39:0,T47:0,T55:0,T63:0, T103:0,T111:0,T119:0,T127:0\n"; - // clang-format on - - runBlocked2dSingleCTA(/*row=*/16, /*col=*/8, /*sizePerThread=*/{1, 1}, - /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{2, 2}, - /*order=*/{0, 1}, /*refStr=*/refStr); -} - -TEST_F(EmitIndicesTest, BlockedLayout_SingleCTA_Vectorize) { - // clang-format off - std::string refStr = - " T0:0, T0:1, T0:2, T0:3, T1:0, T1:1, T1:2, T1:3, T2:0, T2:1, T2:2, T2:3, T3:0, T3:1, T3:2, T3:3\n" - " T4:0, T4:1, T4:2, T4:3, T5:0, T5:1, T5:2, T5:3, T6:0, T6:1, T6:2, T6:3, T7:0, T7:1, T7:2, T7:3\n" - " T8:0, T8:1, T8:2, T8:3, T9:0, T9:1, T9:2, T9:3, T10:0,T10:1,T10:2,T10:3, T11:0,T11:1,T11:2,T11:3\n" - "T12:0,T12:1,T12:2,T12:3, T13:0,T13:1,T13:2,T13:3, T14:0,T14:1,T14:2,T14:3, T15:0,T15:1,T15:2,T15:3\n" - "T16:0,T16:1,T16:2,T16:3, T17:0,T17:1,T17:2,T17:3, T18:0,T18:1,T18:2,T18:3, T19:0,T19:1,T19:2,T19:3\n" - "T20:0,T20:1,T20:2,T20:3, T21:0,T21:1,T21:2,T21:3, T22:0,T22:1,T22:2,T22:3, T23:0,T23:1,T23:2,T23:3\n" - "T24:0,T24:1,T24:2,T24:3, T25:0,T25:1,T25:2,T25:3, T26:0,T26:1,T26:2,T26:3, T27:0,T27:1,T27:2,T27:3\n" - "T28:0,T28:1,T28:2,T28:3, T29:0,T29:1,T29:2,T29:3, T30:0,T30:1,T30:2,T30:3, T31:0,T31:1,T31:2,T31:3\n" - - "T32:0,T32:1,T32:2,T32:3, T33:0,T33:1,T33:2,T33:3, T34:0,T34:1,T34:2,T34:3, T35:0,T35:1,T35:2,T35:3\n" - "T36:0,T36:1,T36:2,T36:3, T37:0,T37:1,T37:2,T37:3, T38:0,T38:1,T38:2,T38:3, T39:0,T39:1,T39:2,T39:3\n" - "T40:0,T40:1,T40:2,T40:3, T41:0,T41:1,T41:2,T41:3, T42:0,T42:1,T42:2,T42:3, T43:0,T43:1,T43:2,T43:3\n" - "T44:0,T44:1,T44:2,T44:3, T45:0,T45:1,T45:2,T45:3, T46:0,T46:1,T46:2,T46:3, T47:0,T47:1,T47:2,T47:3\n" - "T48:0,T48:1,T48:2,T48:3, T49:0,T49:1,T49:2,T49:3, T50:0,T50:1,T50:2,T50:3, T51:0,T51:1,T51:2,T51:3\n" - "T52:0,T52:1,T52:2,T52:3, T53:0,T53:1,T53:2,T53:3, T54:0,T54:1,T54:2,T54:3, T55:0,T55:1,T55:2,T55:3\n" - "T56:0,T56:1,T56:2,T56:3, T57:0,T57:1,T57:2,T57:3, T58:0,T58:1,T58:2,T58:3, T59:0,T59:1,T59:2,T59:3\n" - "T60:0,T60:1,T60:2,T60:3, T61:0,T61:1,T61:2,T61:3, T62:0,T62:1,T62:2,T62:3, T63:0,T63:1,T63:2,T63:3\n"; - // clang-format on - - runBlocked2dSingleCTA(/*row=*/16, /*col=*/16, /*sizePerThread=*/{1, 4}, - /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{2, 1}, - /*order=*/{1, 0}, /*refStr=*/refStr); -} - -// FIXME: These tests are temporarily disabled due to ctaid.x|y|z are swapped -#ifdef TEST_FAILED -TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAOrder_1_0) { - // clang-format off - std::string refStr = - "CTA0: T0:0,CTA0: T1:0,CTA0: T2:0,CTA0: T3:0, CTA1: T0:0,CTA1: T1:0,CTA1: T2:0,CTA1: T3:0\n" - "CTA0: T4:0,CTA0: T5:0,CTA0: T6:0,CTA0: T7:0, CTA1: T4:0,CTA1: T5:0,CTA1: T6:0,CTA1: T7:0\n" - "CTA0: T8:0,CTA0: T9:0,CTA0:T10:0,CTA0:T11:0, CTA1: T8:0,CTA1: T9:0,CTA1:T10:0,CTA1:T11:0\n" - "CTA0:T12:0,CTA0:T13:0,CTA0:T14:0,CTA0:T15:0, CTA1:T12:0,CTA1:T13:0,CTA1:T14:0,CTA1:T15:0\n" - "CTA0:T16:0,CTA0:T17:0,CTA0:T18:0,CTA0:T19:0, CTA1:T16:0,CTA1:T17:0,CTA1:T18:0,CTA1:T19:0\n" - "CTA0:T20:0,CTA0:T21:0,CTA0:T22:0,CTA0:T23:0, CTA1:T20:0,CTA1:T21:0,CTA1:T22:0,CTA1:T23:0\n" - "CTA0:T24:0,CTA0:T25:0,CTA0:T26:0,CTA0:T27:0, CTA1:T24:0,CTA1:T25:0,CTA1:T26:0,CTA1:T27:0\n" - "CTA0:T28:0,CTA0:T29:0,CTA0:T30:0,CTA0:T31:0, CTA1:T28:0,CTA1:T29:0,CTA1:T30:0,CTA1:T31:0\n" - - "CTA2: T0:0,CTA2: T1:0,CTA2: T2:0,CTA2: T3:0, CTA3: T0:0,CTA3: T1:0,CTA3: T2:0,CTA3: T3:0\n" - "CTA2: T4:0,CTA2: T5:0,CTA2: T6:0,CTA2: T7:0, CTA3: T4:0,CTA3: T5:0,CTA3: T6:0,CTA3: T7:0\n" - "CTA2: T8:0,CTA2: T9:0,CTA2:T10:0,CTA2:T11:0, CTA3: T8:0,CTA3: T9:0,CTA3:T10:0,CTA3:T11:0\n" - "CTA2:T12:0,CTA2:T13:0,CTA2:T14:0,CTA2:T15:0, CTA3:T12:0,CTA3:T13:0,CTA3:T14:0,CTA3:T15:0\n" - "CTA2:T16:0,CTA2:T17:0,CTA2:T18:0,CTA2:T19:0, CTA3:T16:0,CTA3:T17:0,CTA3:T18:0,CTA3:T19:0\n" - "CTA2:T20:0,CTA2:T21:0,CTA2:T22:0,CTA2:T23:0, CTA3:T20:0,CTA3:T21:0,CTA3:T22:0,CTA3:T23:0\n" - "CTA2:T24:0,CTA2:T25:0,CTA2:T26:0,CTA2:T27:0, CTA3:T24:0,CTA3:T25:0,CTA3:T26:0,CTA3:T27:0\n" - "CTA2:T28:0,CTA2:T29:0,CTA2:T30:0,CTA2:T31:0, CTA3:T28:0,CTA3:T29:0,CTA3:T30:0,CTA3:T31:0\n"; - // clang-format on - - runBlocked2dMultiCTA(/*row=*/16, /*col=*/8, /*sizePerThread=*/{1, 1}, - /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{1, 1}, - /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, - /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0}, - /*refStr=*/refStr); -} - -TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAOrder_0_1) { - // clang-format off - std::string refStr = - "CTA0: T0:0,CTA0: T1:0,CTA0: T2:0,CTA0: T3:0, CTA2: T0:0,CTA2: T1:0,CTA2: T2:0,CTA2: T3:0\n" - "CTA0: T4:0,CTA0: T5:0,CTA0: T6:0,CTA0: T7:0, CTA2: T4:0,CTA2: T5:0,CTA2: T6:0,CTA2: T7:0\n" - "CTA0: T8:0,CTA0: T9:0,CTA0:T10:0,CTA0:T11:0, CTA2: T8:0,CTA2: T9:0,CTA2:T10:0,CTA2:T11:0\n" - "CTA0:T12:0,CTA0:T13:0,CTA0:T14:0,CTA0:T15:0, CTA2:T12:0,CTA2:T13:0,CTA2:T14:0,CTA2:T15:0\n" - "CTA0:T16:0,CTA0:T17:0,CTA0:T18:0,CTA0:T19:0, CTA2:T16:0,CTA2:T17:0,CTA2:T18:0,CTA2:T19:0\n" - "CTA0:T20:0,CTA0:T21:0,CTA0:T22:0,CTA0:T23:0, CTA2:T20:0,CTA2:T21:0,CTA2:T22:0,CTA2:T23:0\n" - "CTA0:T24:0,CTA0:T25:0,CTA0:T26:0,CTA0:T27:0, CTA2:T24:0,CTA2:T25:0,CTA2:T26:0,CTA2:T27:0\n" - "CTA0:T28:0,CTA0:T29:0,CTA0:T30:0,CTA0:T31:0, CTA2:T28:0,CTA2:T29:0,CTA2:T30:0,CTA2:T31:0\n" - - "CTA1: T0:0,CTA1: T1:0,CTA1: T2:0,CTA1: T3:0, CTA3: T0:0,CTA3: T1:0,CTA3: T2:0,CTA3: T3:0\n" - "CTA1: T4:0,CTA1: T5:0,CTA1: T6:0,CTA1: T7:0, CTA3: T4:0,CTA3: T5:0,CTA3: T6:0,CTA3: T7:0\n" - "CTA1: T8:0,CTA1: T9:0,CTA1:T10:0,CTA1:T11:0, CTA3: T8:0,CTA3: T9:0,CTA3:T10:0,CTA3:T11:0\n" - "CTA1:T12:0,CTA1:T13:0,CTA1:T14:0,CTA1:T15:0, CTA3:T12:0,CTA3:T13:0,CTA3:T14:0,CTA3:T15:0\n" - "CTA1:T16:0,CTA1:T17:0,CTA1:T18:0,CTA1:T19:0, CTA3:T16:0,CTA3:T17:0,CTA3:T18:0,CTA3:T19:0\n" - "CTA1:T20:0,CTA1:T21:0,CTA1:T22:0,CTA1:T23:0, CTA3:T20:0,CTA3:T21:0,CTA3:T22:0,CTA3:T23:0\n" - "CTA1:T24:0,CTA1:T25:0,CTA1:T26:0,CTA1:T27:0, CTA3:T24:0,CTA3:T25:0,CTA3:T26:0,CTA3:T27:0\n" - "CTA1:T28:0,CTA1:T29:0,CTA1:T30:0,CTA1:T31:0, CTA3:T28:0,CTA3:T29:0,CTA3:T30:0,CTA3:T31:0\n"; - // clang-format on - - runBlocked2dMultiCTA(/*row=*/16, /*col=*/8, /*sizePerThread=*/{1, 1}, - /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{1, 1}, - /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, - /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{0, 1}, - /*refStr=*/refStr); -} - -TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAWrap_Dim1) { - // clang-format off - std::string refStr = - "CTA0: T0:0|CTA1: T0:0, CTA0: T1:0|CTA1: T1:0, CTA0: T2:0|CTA1: T2:0, CTA0: T3:0|CTA1: T3:0\n" - "CTA0: T4:0|CTA1: T4:0, CTA0: T5:0|CTA1: T5:0, CTA0: T6:0|CTA1: T6:0, CTA0: T7:0|CTA1: T7:0\n" - "CTA0: T8:0|CTA1: T8:0, CTA0: T9:0|CTA1: T9:0, CTA0:T10:0|CTA1:T10:0, CTA0:T11:0|CTA1:T11:0\n" - "CTA0:T12:0|CTA1:T12:0, CTA0:T13:0|CTA1:T13:0, CTA0:T14:0|CTA1:T14:0, CTA0:T15:0|CTA1:T15:0\n" - "CTA0:T16:0|CTA1:T16:0, CTA0:T17:0|CTA1:T17:0, CTA0:T18:0|CTA1:T18:0, CTA0:T19:0|CTA1:T19:0\n" - "CTA0:T20:0|CTA1:T20:0, CTA0:T21:0|CTA1:T21:0, CTA0:T22:0|CTA1:T22:0, CTA0:T23:0|CTA1:T23:0\n" - "CTA0:T24:0|CTA1:T24:0, CTA0:T25:0|CTA1:T25:0, CTA0:T26:0|CTA1:T26:0, CTA0:T27:0|CTA1:T27:0\n" - "CTA0:T28:0|CTA1:T28:0, CTA0:T29:0|CTA1:T29:0, CTA0:T30:0|CTA1:T30:0, CTA0:T31:0|CTA1:T31:0\n" - - "CTA2: T0:0|CTA3: T0:0, CTA2: T1:0|CTA3: T1:0, CTA2: T2:0|CTA3: T2:0, CTA2: T3:0|CTA3: T3:0\n" - "CTA2: T4:0|CTA3: T4:0, CTA2: T5:0|CTA3: T5:0, CTA2: T6:0|CTA3: T6:0, CTA2: T7:0|CTA3: T7:0\n" - "CTA2: T8:0|CTA3: T8:0, CTA2: T9:0|CTA3: T9:0, CTA2:T10:0|CTA3:T10:0, CTA2:T11:0|CTA3:T11:0\n" - "CTA2:T12:0|CTA3:T12:0, CTA2:T13:0|CTA3:T13:0, CTA2:T14:0|CTA3:T14:0, CTA2:T15:0|CTA3:T15:0\n" - "CTA2:T16:0|CTA3:T16:0, CTA2:T17:0|CTA3:T17:0, CTA2:T18:0|CTA3:T18:0, CTA2:T19:0|CTA3:T19:0\n" - "CTA2:T20:0|CTA3:T20:0, CTA2:T21:0|CTA3:T21:0, CTA2:T22:0|CTA3:T22:0, CTA2:T23:0|CTA3:T23:0\n" - "CTA2:T24:0|CTA3:T24:0, CTA2:T25:0|CTA3:T25:0, CTA2:T26:0|CTA3:T26:0, CTA2:T27:0|CTA3:T27:0\n" - "CTA2:T28:0|CTA3:T28:0, CTA2:T29:0|CTA3:T29:0, CTA2:T30:0|CTA3:T30:0, CTA2:T31:0|CTA3:T31:0\n"; - // clang-format on - - runBlocked2dMultiCTA(/*row=*/16, /*col=*/4, /*sizePerThread=*/{1, 1}, - /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{1, 1}, - /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, - /*CTASplitNum=*/{2, 1}, /*CTAOrder=*/{1, 0}, - /*refStr=*/refStr); -} - -TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAWrap_Dim0) { - // clang-format off - std::string refStr = - "CTA0: T0:0|CTA2: T0:0,CTA0: T1:0|CTA2: T1:0,CTA0: T2:0|CTA2: T2:0,CTA0: T3:0|CTA2: T3:0, CTA1: T0:0|CTA3: T0:0,CTA1: T1:0|CTA3: T1:0,CTA1: T2:0|CTA3: T2:0,CTA1: T3:0|CTA3: T3:0\n" - "CTA0: T4:0|CTA2: T4:0,CTA0: T5:0|CTA2: T5:0,CTA0: T6:0|CTA2: T6:0,CTA0: T7:0|CTA2: T7:0, CTA1: T4:0|CTA3: T4:0,CTA1: T5:0|CTA3: T5:0,CTA1: T6:0|CTA3: T6:0,CTA1: T7:0|CTA3: T7:0\n" - "CTA0: T8:0|CTA2: T8:0,CTA0: T9:0|CTA2: T9:0,CTA0:T10:0|CTA2:T10:0,CTA0:T11:0|CTA2:T11:0, CTA1: T8:0|CTA3: T8:0,CTA1: T9:0|CTA3: T9:0,CTA1:T10:0|CTA3:T10:0,CTA1:T11:0|CTA3:T11:0\n" - "CTA0:T12:0|CTA2:T12:0,CTA0:T13:0|CTA2:T13:0,CTA0:T14:0|CTA2:T14:0,CTA0:T15:0|CTA2:T15:0, CTA1:T12:0|CTA3:T12:0,CTA1:T13:0|CTA3:T13:0,CTA1:T14:0|CTA3:T14:0,CTA1:T15:0|CTA3:T15:0\n" - "CTA0:T16:0|CTA2:T16:0,CTA0:T17:0|CTA2:T17:0,CTA0:T18:0|CTA2:T18:0,CTA0:T19:0|CTA2:T19:0, CTA1:T16:0|CTA3:T16:0,CTA1:T17:0|CTA3:T17:0,CTA1:T18:0|CTA3:T18:0,CTA1:T19:0|CTA3:T19:0\n" - "CTA0:T20:0|CTA2:T20:0,CTA0:T21:0|CTA2:T21:0,CTA0:T22:0|CTA2:T22:0,CTA0:T23:0|CTA2:T23:0, CTA1:T20:0|CTA3:T20:0,CTA1:T21:0|CTA3:T21:0,CTA1:T22:0|CTA3:T22:0,CTA1:T23:0|CTA3:T23:0\n" - "CTA0:T24:0|CTA2:T24:0,CTA0:T25:0|CTA2:T25:0,CTA0:T26:0|CTA2:T26:0,CTA0:T27:0|CTA2:T27:0, CTA1:T24:0|CTA3:T24:0,CTA1:T25:0|CTA3:T25:0,CTA1:T26:0|CTA3:T26:0,CTA1:T27:0|CTA3:T27:0\n" - "CTA0:T28:0|CTA2:T28:0,CTA0:T29:0|CTA2:T29:0,CTA0:T30:0|CTA2:T30:0,CTA0:T31:0|CTA2:T31:0, CTA1:T28:0|CTA3:T28:0,CTA1:T29:0|CTA3:T29:0,CTA1:T30:0|CTA3:T30:0,CTA1:T31:0|CTA3:T31:0\n"; - // clang-format on - - runBlocked2dMultiCTA( - /*row=*/8, /*col=*/8, /*sizePerThread=*/{1, 1}, /*threadsPerWarp=*/{8, 4}, - /*warpsPerCTA=*/{1, 1}, /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, - /*CTASplitNum=*/{1, 2}, /*CTAOrder=*/{1, 0}, /*refStr=*/refStr); -} - -TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAWrapBeforeBroadcast_Dim1) { - // clang-format off - std::string refStr = - "CTA0: T0:0|CTA0: T1:0|CTA0: T2:0|CTA0: T3:0 | CTA1: T0:0|CTA1: T1:0|CTA1: T2:0|CTA1: T3:0\n" - "CTA0: T4:0|CTA0: T5:0|CTA0: T6:0|CTA0: T7:0 | CTA1: T4:0|CTA1: T5:0|CTA1: T6:0|CTA1: T7:0\n" - "CTA0: T8:0|CTA0: T9:0|CTA0:T10:0|CTA0:T11:0 | CTA1: T8:0|CTA1: T9:0|CTA1:T10:0|CTA1:T11:0\n" - "CTA0:T12:0|CTA0:T13:0|CTA0:T14:0|CTA0:T15:0 | CTA1:T12:0|CTA1:T13:0|CTA1:T14:0|CTA1:T15:0\n" - "CTA0:T16:0|CTA0:T17:0|CTA0:T18:0|CTA0:T19:0 | CTA1:T16:0|CTA1:T17:0|CTA1:T18:0|CTA1:T19:0\n" - "CTA0:T20:0|CTA0:T21:0|CTA0:T22:0|CTA0:T23:0 | CTA1:T20:0|CTA1:T21:0|CTA1:T22:0|CTA1:T23:0\n" - "CTA0:T24:0|CTA0:T25:0|CTA0:T26:0|CTA0:T27:0 | CTA1:T24:0|CTA1:T25:0|CTA1:T26:0|CTA1:T27:0\n" - "CTA0:T28:0|CTA0:T29:0|CTA0:T30:0|CTA0:T31:0 | CTA1:T28:0|CTA1:T29:0|CTA1:T30:0|CTA1:T31:0\n" - - "CTA2: T0:0|CTA2: T1:0|CTA2: T2:0|CTA2: T3:0 | CTA3: T0:0|CTA3: T1:0|CTA3: T2:0|CTA3: T3:0\n" - "CTA2: T4:0|CTA2: T5:0|CTA2: T6:0|CTA2: T7:0 | CTA3: T4:0|CTA3: T5:0|CTA3: T6:0|CTA3: T7:0\n" - "CTA2: T8:0|CTA2: T9:0|CTA2:T10:0|CTA2:T11:0 | CTA3: T8:0|CTA3: T9:0|CTA3:T10:0|CTA3:T11:0\n" - "CTA2:T12:0|CTA2:T13:0|CTA2:T14:0|CTA2:T15:0 | CTA3:T12:0|CTA3:T13:0|CTA3:T14:0|CTA3:T15:0\n" - "CTA2:T16:0|CTA2:T17:0|CTA2:T18:0|CTA2:T19:0 | CTA3:T16:0|CTA3:T17:0|CTA3:T18:0|CTA3:T19:0\n" - "CTA2:T20:0|CTA2:T21:0|CTA2:T22:0|CTA2:T23:0 | CTA3:T20:0|CTA3:T21:0|CTA3:T22:0|CTA3:T23:0\n" - "CTA2:T24:0|CTA2:T25:0|CTA2:T26:0|CTA2:T27:0 | CTA3:T24:0|CTA3:T25:0|CTA3:T26:0|CTA3:T27:0\n" - "CTA2:T28:0|CTA2:T29:0|CTA2:T30:0|CTA2:T31:0 | CTA3:T28:0|CTA3:T29:0|CTA3:T30:0|CTA3:T31:0\n"; - // clang-format on - - runBlocked2dMultiCTA(/*row=*/16, /*col=*/1, /*sizePerThread=*/{1, 1}, - /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{1, 1}, - /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, - /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0}, - /*refStr=*/refStr); -} - -TEST_F(EmitIndicesTest, BlockedLayout_MultiCTA_CTAWrapBeforeBroadcast_Dim0) { - // clang-format off - std::string refStr = - "CTA0:T0:0|CTA0: T8:0|CTA0:T16:0|CTA0:T24:0 | CTA2:T0:0|CTA2: T8:0|CTA2:T16:0|CTA2:T24:0," - "CTA0:T1:0|CTA0: T9:0|CTA0:T17:0|CTA0:T25:0 | CTA2:T1:0|CTA2: T9:0|CTA2:T17:0|CTA2:T25:0," - "CTA0:T2:0|CTA0:T10:0|CTA0:T18:0|CTA0:T26:0 | CTA2:T2:0|CTA2:T10:0|CTA2:T18:0|CTA2:T26:0," - "CTA0:T3:0|CTA0:T11:0|CTA0:T19:0|CTA0:T27:0 | CTA2:T3:0|CTA2:T11:0|CTA2:T19:0|CTA2:T27:0," - "CTA0:T4:0|CTA0:T12:0|CTA0:T20:0|CTA0:T28:0 | CTA2:T4:0|CTA2:T12:0|CTA2:T20:0|CTA2:T28:0," - "CTA0:T5:0|CTA0:T13:0|CTA0:T21:0|CTA0:T29:0 | CTA2:T5:0|CTA2:T13:0|CTA2:T21:0|CTA2:T29:0," - "CTA0:T6:0|CTA0:T14:0|CTA0:T22:0|CTA0:T30:0 | CTA2:T6:0|CTA2:T14:0|CTA2:T22:0|CTA2:T30:0," - "CTA0:T7:0|CTA0:T15:0|CTA0:T23:0|CTA0:T31:0 | CTA2:T7:0|CTA2:T15:0|CTA2:T23:0|CTA2:T31:0," - - "CTA1:T0:0|CTA1: T8:0|CTA1:T16:0|CTA1:T24:0 | CTA3:T0:0|CTA3: T8:0|CTA3:T16:0|CTA3:T24:0," - "CTA1:T1:0|CTA1: T9:0|CTA1:T17:0|CTA1:T25:0 | CTA3:T1:0|CTA3: T9:0|CTA3:T17:0|CTA3:T25:0," - "CTA1:T2:0|CTA1:T10:0|CTA1:T18:0|CTA1:T26:0 | CTA3:T2:0|CTA3:T10:0|CTA3:T18:0|CTA3:T26:0," - "CTA1:T3:0|CTA1:T11:0|CTA1:T19:0|CTA1:T27:0 | CTA3:T3:0|CTA3:T11:0|CTA3:T19:0|CTA3:T27:0," - "CTA1:T4:0|CTA1:T12:0|CTA1:T20:0|CTA1:T28:0 | CTA3:T4:0|CTA3:T12:0|CTA3:T20:0|CTA3:T28:0," - "CTA1:T5:0|CTA1:T13:0|CTA1:T21:0|CTA1:T29:0 | CTA3:T5:0|CTA3:T13:0|CTA3:T21:0|CTA3:T29:0," - "CTA1:T6:0|CTA1:T14:0|CTA1:T22:0|CTA1:T30:0 | CTA3:T6:0|CTA3:T14:0|CTA3:T22:0|CTA3:T30:0," - "CTA1:T7:0|CTA1:T15:0|CTA1:T23:0|CTA1:T31:0 | CTA3:T7:0|CTA3:T15:0|CTA3:T23:0|CTA3:T31:0\n"; - // clang-format on - - runBlocked2dMultiCTA(/*row=*/1, /*col=*/16, /*sizePerThread=*/{1, 1}, - /*threadsPerWarp=*/{4, 8}, /*warpsPerCTA=*/{1, 1}, - /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, - /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0}, - /*refStr=*/refStr); -} - -TEST_F(EmitIndicesTest, SliceLayout_MultiCTA) { - // clang-format off - std::string refStr = - "CTA0: T0:0|CTA0: T1:0|CTA0: T2:0|CTA0: T3:0 | CTA1: T0:0|CTA1: T1:0|CTA1: T2:0|CTA1: T3:0," - "CTA0: T4:0|CTA0: T5:0|CTA0: T6:0|CTA0: T7:0 | CTA1: T4:0|CTA1: T5:0|CTA1: T6:0|CTA1: T7:0," - "CTA0: T8:0|CTA0: T9:0|CTA0:T10:0|CTA0:T11:0 | CTA1: T8:0|CTA1: T9:0|CTA1:T10:0|CTA1:T11:0," - "CTA0:T12:0|CTA0:T13:0|CTA0:T14:0|CTA0:T15:0 | CTA1:T12:0|CTA1:T13:0|CTA1:T14:0|CTA1:T15:0," - "CTA0:T16:0|CTA0:T17:0|CTA0:T18:0|CTA0:T19:0 | CTA1:T16:0|CTA1:T17:0|CTA1:T18:0|CTA1:T19:0," - "CTA0:T20:0|CTA0:T21:0|CTA0:T22:0|CTA0:T23:0 | CTA1:T20:0|CTA1:T21:0|CTA1:T22:0|CTA1:T23:0," - "CTA0:T24:0|CTA0:T25:0|CTA0:T26:0|CTA0:T27:0 | CTA1:T24:0|CTA1:T25:0|CTA1:T26:0|CTA1:T27:0," - "CTA0:T28:0|CTA0:T29:0|CTA0:T30:0|CTA0:T31:0 | CTA1:T28:0|CTA1:T29:0|CTA1:T30:0|CTA1:T31:0," - - "CTA2: T0:0|CTA2: T1:0|CTA2: T2:0|CTA2: T3:0 | CTA3: T0:0|CTA3: T1:0|CTA3: T2:0|CTA3: T3:0," - "CTA2: T4:0|CTA2: T5:0|CTA2: T6:0|CTA2: T7:0 | CTA3: T4:0|CTA3: T5:0|CTA3: T6:0|CTA3: T7:0," - "CTA2: T8:0|CTA2: T9:0|CTA2:T10:0|CTA2:T11:0 | CTA3: T8:0|CTA3: T9:0|CTA3:T10:0|CTA3:T11:0," - "CTA2:T12:0|CTA2:T13:0|CTA2:T14:0|CTA2:T15:0 | CTA3:T12:0|CTA3:T13:0|CTA3:T14:0|CTA3:T15:0," - "CTA2:T16:0|CTA2:T17:0|CTA2:T18:0|CTA2:T19:0 | CTA3:T16:0|CTA3:T17:0|CTA3:T18:0|CTA3:T19:0," - "CTA2:T20:0|CTA2:T21:0|CTA2:T22:0|CTA2:T23:0 | CTA3:T20:0|CTA3:T21:0|CTA3:T22:0|CTA3:T23:0," - "CTA2:T24:0|CTA2:T25:0|CTA2:T26:0|CTA2:T27:0 | CTA3:T24:0|CTA3:T25:0|CTA3:T26:0|CTA3:T27:0," - "CTA2:T28:0|CTA2:T29:0|CTA2:T30:0|CTA2:T31:0 | CTA3:T28:0|CTA3:T29:0|CTA3:T30:0|CTA3:T31:0\n"; - // clang-format on - - runSliceBlockedMultiCTA(/*size=*/16, /*sizePerThread=*/{1, 1}, - /*threadsPerWarp=*/{8, 4}, /*warpsPerCTA=*/{1, 1}, - /*order=*/{1, 0}, /*CTAsPerCGA=*/{2, 2}, - /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0}, - /*sliceDim=*/1, /*refStr=*/refStr); -} - -//===----------------------------------------------------------------------===// -// Tests for SharedEncodingAttr -//===----------------------------------------------------------------------===// - -TEST_F(EmitIndicesTest, SharedLayout) { - // clang-format off - std::string refStr = - "(0: 0),(0: 1),(0: 2),(0: 3),(0: 4),(0: 5),(0: 6),(0: 7),(0: 8),(0: 9),(0:10),(0:11),(0:12),(0:13),(0:14),(0:15),(0:16),(0:17),(0:18),(0:19),(0:20),(0:21),(0:22),(0:23),(0:24),(0:25),(0:26),(0:27),(0:28),(0:29),(0:30),(0:31)\n" - "(1: 0),(1: 1),(1: 2),(1: 3),(1: 4),(1: 5),(1: 6),(1: 7),(1: 8),(1: 9),(1:10),(1:11),(1:12),(1:13),(1:14),(1:15),(1:16),(1:17),(1:18),(1:19),(1:20),(1:21),(1:22),(1:23),(1:24),(1:25),(1:26),(1:27),(1:28),(1:29),(1:30),(1:31)\n" - "(2: 8),(2: 9),(2:10),(2:11),(2:12),(2:13),(2:14),(2:15),(2: 0),(2: 1),(2: 2),(2: 3),(2: 4),(2: 5),(2: 6),(2: 7),(2:24),(2:25),(2:26),(2:27),(2:28),(2:29),(2:30),(2:31),(2:16),(2:17),(2:18),(2:19),(2:20),(2:21),(2:22),(2:23)\n" - "(3: 8),(3: 9),(3:10),(3:11),(3:12),(3:13),(3:14),(3:15),(3: 0),(3: 1),(3: 2),(3: 3),(3: 4),(3: 5),(3: 6),(3: 7),(3:24),(3:25),(3:26),(3:27),(3:28),(3:29),(3:30),(3:31),(3:16),(3:17),(3:18),(3:19),(3:20),(3:21),(3:22),(3:23)\n" - "(4:16),(4:17),(4:18),(4:19),(4:20),(4:21),(4:22),(4:23),(4:24),(4:25),(4:26),(4:27),(4:28),(4:29),(4:30),(4:31),(4: 0),(4: 1),(4: 2),(4: 3),(4: 4),(4: 5),(4: 6),(4: 7),(4: 8),(4: 9),(4:10),(4:11),(4:12),(4:13),(4:14),(4:15)\n" - "(5:16),(5:17),(5:18),(5:19),(5:20),(5:21),(5:22),(5:23),(5:24),(5:25),(5:26),(5:27),(5:28),(5:29),(5:30),(5:31),(5: 0),(5: 1),(5: 2),(5: 3),(5: 4),(5: 5),(5: 6),(5: 7),(5: 8),(5: 9),(5:10),(5:11),(5:12),(5:13),(5:14),(5:15)\n" - "(6:24),(6:25),(6:26),(6:27),(6:28),(6:29),(6:30),(6:31),(6:16),(6:17),(6:18),(6:19),(6:20),(6:21),(6:22),(6:23),(6: 8),(6: 9),(6:10),(6:11),(6:12),(6:13),(6:14),(6:15),(6: 0),(6: 1),(6: 2),(6: 3),(6: 4),(6: 5),(6: 6),(6: 7)\n" - "(7:24),(7:25),(7:26),(7:27),(7:28),(7:29),(7:30),(7:31),(7:16),(7:17),(7:18),(7:19),(7:20),(7:21),(7:22),(7:23),(7: 8),(7: 9),(7:10),(7:11),(7:12),(7:13),(7:14),(7:15),(7: 0),(7: 1),(7: 2),(7: 3),(7: 4),(7: 5),(7: 6),(7: 7)\n"; - // clang-format on - - runSharedSingleCTA(/*row=*/8, /*col=*/32, /*rowMajor=*/true, - /*elemTyStr=*/"F16", /*refStr=*/refStr); -} - -TEST_F(EmitIndicesTest, LayoutVisualizer_Blocked) { - CTALayoutAttr CTALayout = - CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{2, 2}, - /*CTASplitNum=*/{2, 2}, /*CTAOrder=*/{1, 0}); - - Attribute blockedLayout = BlockedEncodingAttr::get( - /*context=*/&context, /*sizePerThread=*/{1, 4}, - /*threadsPerWarp=*/{2, 16}, - /*warpsPerCTA=*/{4, 1}, /*order=*/{1, 0}, /*CTALayout=*/CTALayout); - - llvm::SmallVector shape = {/*row=*/128, /*col=*/128}; - - std::ofstream ofs("blockedLayout.csv"); - ofs << dumpDistributedLayout(blockedLayout, shape, /*multiCTA=*/true); -} - -TEST_F(EmitIndicesTest, LayoutVisualizer_Shared) { - CTALayoutAttr CTALayout = - CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1}, - /*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0}); - - Attribute sharedLayout = SharedEncodingAttr::get( - /*context=*/&context, /*vec=*/1, /*perPhase=*/2, /*maxPhase=*/8, - /*order=*/{0, 1}, /*CTALayout=*/CTALayout); - - llvm::SmallVector shape = {/*row=*/16, /*col=*/16}; - Type elemTy = FloatType::getF16(&context); - - std::ofstream ofs("sharedLayout.csv"); - ofs << dumpSharedLayout(sharedLayout, shape, elemTy, /*multiCTA=*/false); -} -#endif - -//===----------------------------------------------------------------------===// -// Tests for SliceEncodingAttr -//===----------------------------------------------------------------------===// - -TEST_F(EmitIndicesTest, SliceLayout_SingleCTA_SliceDim1) { - // clang-format off - std::string refStr = - " T0:0| T1:0| T2:0| T3:0| T4:0| T5:0| T6:0| T7:0," - " T8:0| T9:0|T10:0|T11:0|T12:0|T13:0|T14:0|T15:0," - "T16:0|T17:0|T18:0|T19:0|T20:0|T21:0|T22:0|T23:0," - "T24:0|T25:0|T26:0|T27:0|T28:0|T29:0|T30:0|T31:0\n"; - // clang-format on - - runSliceBlockedSingleCTA(/*size=*/4, /*sizePerThread=*/{1, 1}, - /*threadsPerWarp=*/{4, 8}, /*warpsPerCTA=*/{1, 1}, - /*order=*/{1, 0}, /*sliceDim=*/1, /*refStr=*/refStr); -} - -TEST_F(EmitIndicesTest, SliceLayout_SingleCTA_SliceDim0) { - // clang-format off - std::string refStr = - "T0:0| T8:0|T16:0|T24:0," - "T1:0| T9:0|T17:0|T25:0," - "T2:0|T10:0|T18:0|T26:0," - "T3:0|T11:0|T19:0|T27:0," - "T4:0|T12:0|T20:0|T28:0," - "T5:0|T13:0|T21:0|T29:0," - "T6:0|T14:0|T22:0|T30:0," - "T7:0|T15:0|T23:0|T31:0\n"; - // clang-format on - - runSliceBlockedSingleCTA(/*size=*/8, /*sizePerThread=*/{1, 1}, - /*threadsPerWarp=*/{4, 8}, /*warpsPerCTA=*/{1, 1}, - /*order=*/{1, 0}, /*sliceDim=*/0, /*refStr=*/refStr); -} - -//===----------------------------------------------------------------------===// -// Tests for NvidiaMmaEncodingAttr -//===----------------------------------------------------------------------===// - -TEST_F(EmitIndicesTest, MmaLayout) { - // clang-format off - std::string refStr = - " T0:0, T0:1, T1:0, T1:1, T2:0, T2:1, T3:0, T3:1\n" - " T4:0, T4:1, T5:0, T5:1, T6:0, T6:1, T7:0, T7:1\n" - " T8:0, T8:1, T9:0, T9:1,T10:0,T10:1,T11:0,T11:1\n" - "T12:0,T12:1,T13:0,T13:1,T14:0,T14:1,T15:0,T15:1\n" - "T16:0,T16:1,T17:0,T17:1,T18:0,T18:1,T19:0,T19:1\n" - "T20:0,T20:1,T21:0,T21:1,T22:0,T22:1,T23:0,T23:1\n" - "T24:0,T24:1,T25:0,T25:1,T26:0,T26:1,T27:0,T27:1\n" - "T28:0,T28:1,T29:0,T29:1,T30:0,T30:1,T31:0,T31:1\n" - " T0:2, T0:3, T1:2, T1:3, T2:2, T2:3, T3:2, T3:3\n" - " T4:2, T4:3, T5:2, T5:3, T6:2, T6:3, T7:2, T7:3\n" - " T8:2, T8:3, T9:2, T9:3,T10:2,T10:3,T11:2,T11:3\n" - "T12:2,T12:3,T13:2,T13:3,T14:2,T14:3,T15:2,T15:3\n" - "T16:2,T16:3,T17:2,T17:3,T18:2,T18:3,T19:2,T19:3\n" - "T20:2,T20:3,T21:2,T21:3,T22:2,T22:3,T23:2,T23:3\n" - "T24:2,T24:3,T25:2,T25:3,T26:2,T26:3,T27:2,T27:3\n" - "T28:2,T28:3,T29:2,T29:3,T30:2,T30:3,T31:2,T31:3\n"; - // clang-format on - - runMmaSingleCTA(/*row=*/16, /*col=*/8, /*versionMajor=*/2, /*versionMinor=*/1, - /*warpsPerCTA=*/{1, 1}, /*instrShape=*/{16, 8}, - /*refStr=*/refStr); -} - -//===----------------------------------------------------------------------===// -// Tests for AMDWmmaEncodingAttr -//===----------------------------------------------------------------------===// - -TEST_F(EmitIndicesTest, WmmaLayout) { - // clang-format off - std::string refStr = - "T0:0,T1:0,T2:0,T3:0,T4:0,T5:0,T6:0,T7:0,T8:0,T9:0,T10:0,T11:0,T12:0,T13:0,T14:0,T15:0\n" - "T16:0,T17:0,T18:0,T19:0,T20:0,T21:0,T22:0,T23:0,T24:0,T25:0,T26:0,T27:0,T28:0,T29:0,T30:0,T31:0\n" - "T0:1,T1:1,T2:1,T3:1,T4:1,T5:1,T6:1,T7:1,T8:1,T9:1,T10:1,T11:1,T12:1,T13:1,T14:1,T15:1\n" - "T16:1,T17:1,T18:1,T19:1,T20:1,T21:1,T22:1,T23:1,T24:1,T25:1,T26:1,T27:1,T28:1,T29:1,T30:1,T31:1\n" - "T0:2,T1:2,T2:2,T3:2,T4:2,T5:2,T6:2,T7:2,T8:2,T9:2,T10:2,T11:2,T12:2,T13:2,T14:2,T15:2\n" - "T16:2,T17:2,T18:2,T19:2,T20:2,T21:2,T22:2,T23:2,T24:2,T25:2,T26:2,T27:2,T28:2,T29:2,T30:2,T31:2\n" - "T0:3,T1:3,T2:3,T3:3,T4:3,T5:3,T6:3,T7:3,T8:3,T9:3,T10:3,T11:3,T12:3,T13:3,T14:3,T15:3\n" - "T16:3,T17:3,T18:3,T19:3,T20:3,T21:3,T22:3,T23:3,T24:3,T25:3,T26:3,T27:3,T28:3,T29:3,T30:3,T31:3\n" - "T0:4,T1:4,T2:4,T3:4,T4:4,T5:4,T6:4,T7:4,T8:4,T9:4,T10:4,T11:4,T12:4,T13:4,T14:4,T15:4\n" - "T16:4,T17:4,T18:4,T19:4,T20:4,T21:4,T22:4,T23:4,T24:4,T25:4,T26:4,T27:4,T28:4,T29:4,T30:4,T31:4\n" - "T0:5,T1:5,T2:5,T3:5,T4:5,T5:5,T6:5,T7:5,T8:5,T9:5,T10:5,T11:5,T12:5,T13:5,T14:5,T15:5\n" - "T16:5,T17:5,T18:5,T19:5,T20:5,T21:5,T22:5,T23:5,T24:5,T25:5,T26:5,T27:5,T28:5,T29:5,T30:5,T31:5\n" - "T0:6,T1:6,T2:6,T3:6,T4:6,T5:6,T6:6,T7:6,T8:6,T9:6,T10:6,T11:6,T12:6,T13:6,T14:6,T15:6\n" - "T16:6,T17:6,T18:6,T19:6,T20:6,T21:6,T22:6,T23:6,T24:6,T25:6,T26:6,T27:6,T28:6,T29:6,T30:6,T31:6\n" - "T0:7,T1:7,T2:7,T3:7,T4:7,T5:7,T6:7,T7:7,T8:7,T9:7,T10:7,T11:7,T12:7,T13:7,T14:7,T15:7\n" - "T16:7,T17:7,T18:7,T19:7,T20:7,T21:7,T22:7,T23:7,T24:7,T25:7,T26:7,T27:7,T28:7,T29:7,T30:7,T31:7\n"; - // clang-format on - - runWmmaSingleCTA(/*row=*/16, /*col=*/16, - /*warpsPerCTA=*/{1, 1}, - /*refStr=*/refStr); -} - -//===----------------------------------------------------------------------===// -// The following unittests are tools for Triton developers to visualize layouts. -// You can modify parameters and shapes here to create your own layout and -// tensor. The output will be saved into a csv file which can be opened with -// Microsoft Excel. -//===----------------------------------------------------------------------===// - -TEST_F(EmitIndicesTest, LayoutVisualizer_Slice) { - CTALayoutAttr CTALayout = - CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1}, - /*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0}); - - Attribute blockedLayout = BlockedEncodingAttr::get( - /*context=*/&context, /*sizePerThread=*/{1, 1}, /*threadsPerWarp=*/{4, 8}, - /*warpsPerCTA=*/{1, 1}, /*order=*/{1, 0}, /*CTALayout=*/CTALayout); - - Attribute sliceLayout = SliceEncodingAttr::get( - /*context=*/&context, /*dim=*/1, /*parent=*/blockedLayout); - - llvm::SmallVector shape = {4}; - - std::ofstream ofs("sliceLayout.csv"); - ofs << dumpDistributedLayout(sliceLayout, shape, /*multiCTA=*/false); -} - -TEST_F(EmitIndicesTest, LayoutVisualizer_Mma) { - CTALayoutAttr CTALayout = - CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1}, - /*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0}); - - Attribute mmaLayout = NvidiaMmaEncodingAttr::get( - /*context=*/&context, /*versionMajor=*/2, /*versionMinor=*/1, - /*warpsPerCTA=*/{1, 1}, /*CTALayout=*/CTALayout, /*instrShape=*/{16, 8}); - - llvm::SmallVector shape = {/*row=*/16, /*col=*/8}; - - std::ofstream ofs("mmaLayout.csv"); - ofs << dumpDistributedLayout(mmaLayout, shape, /*multiCTA=*/false); -} - -TEST_F(EmitIndicesTest, LayoutVisualizer_Wmma) { - CTALayoutAttr CTALayout = - CTALayoutAttr::get(/*context=*/&context, /*CTAsPerCGA=*/{1, 1}, - /*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0}); - - Attribute wmmaLayout = AMDWmmaEncodingAttr::get( - /*context=*/&context, - /*warpsPerCTA=*/{1, 1}, /*CTALayout=*/CTALayout); - - llvm::SmallVector shape = {/*row=*/16, /*col=*/16}; - - std::ofstream ofs("WmmaLayout.csv"); - ofs << dumpDistributedLayout(wmmaLayout, shape, /*multiCTA=*/false); -} - -// This is only for "distributed" layouts, i.e. layouts whose values are stored -// in registers distributed among threads in blocks. -template -class DistributedLegacyVsLinearLayoutsTest - : public EmitIndicesTest, - public ::testing::WithParamInterface { -protected: - void DoIt(); -}; - -template -void DistributedLegacyVsLinearLayoutsTest::DoIt() { - ParamsT params = this->GetParam(); - LayoutT legacyLayout = params.getEncoding(); - auto type = RankedTensorType::get(params.shape, FloatType::getF16(&context), - legacyLayout); - - int threadsPerWarp = product(triton::gpu::getThreadsPerWarp(legacyLayout)); - int numThreads = product(triton::gpu::getThreadsPerWarp(legacyLayout)) * - product(triton::gpu::getWarpsPerCTA(legacyLayout)); - - // Can't call getCTAsPerCGA on a SliceEncodingAttr. But all we care about is - // the total number of CTAs, which we can just as easily get from the slice - // layout's parent. - Attribute nonSliceLayout = legacyLayout; - while (auto sliceLayout = dyn_cast(nonSliceLayout)) { - nonSliceLayout = sliceLayout.getParent(); - } - int numCTAs = product(triton::gpu::getCTAsPerCGA(nonSliceLayout)); - - mlir::OpBuilder builder(&context); - Location loc = UnknownLoc::get(&context); - auto mlirModule = mlir::ModuleOp::create(loc); - auto func = builder.create( - loc, "test_func", builder.getFunctionType({}, {})); - mlirModule.push_back(func); - auto *block = func.addEntryBlock(); - IRRewriter rewriter(&context); - rewriter.setInsertionPointToStart(block); - - NVIDIA::TargetInfo target(90); - auto llIndices = emitIndicesUsingLinearLayouts( - loc, rewriter, target, legacyLayout, type, /*withCTAOffset=*/true); - auto legacyIndices = emitIndices(loc, rewriter, target, legacyLayout, type, - /*withCTAOffset=*/true, /*allowLL=*/false); - - // This test takes a long time if we check all indices. But for linear - // layouts, we really should only need to check powers of 2. We wrap the - // loops in this `iterate` function so we can easily change between checking - // all indices and just the powers of 2. - constexpr bool checkAllElems = false; - bool stopIterating = false; - auto iterate = [&](int n, auto fn) { - if (checkAllElems) { - for (int i = 0; i < n && !stopIterating; i++) { - fn(i); - } - } else { - if (n > 0) { - fn(0); - } - for (int i = 0; (1 << i) < n && !stopIterating; i++) { - fn(1 << i); - } - } - }; - - // We don't need to print a lot of failures because we also print our guess as - // to the correct linear layout at the end. - constexpr int kMaxFailures = 4; - int64_t numFailures = 0; - bool passedInitialChecks = false; - - // Wrap these tests in a lambda so failed ASSERTs exit the loop but don't exit - // the whole test. - [&] { - ASSERT_TRUE(llIndices.has_value()); - ASSERT_EQ(llIndices->size(), legacyIndices.size()); - passedInitialChecks = true; - - iterate(llIndices->size(), [&](int i) { - SCOPED_TRACE("Register " + std::to_string(i)); - ASSERT_EQ((*llIndices)[i].size(), legacyIndices[i].size()); - iterate((*llIndices)[i].size(), [&](int j) { - SCOPED_TRACE("Dimension " + std::to_string(j)); - iterate(numCTAs, [&](int ctaId) { - SCOPED_TRACE("CTA " + std::to_string(ctaId)); - iterate(numThreads, [&](int tid) { - SCOPED_TRACE("Thread " + std::to_string(tid)); - int llValue = evalValue((*llIndices)[i][j], ctaId, tid); - int legacyValue = evalValue(legacyIndices[i][j], ctaId, tid); - EXPECT_EQ(llValue, legacyValue); - if (llValue != legacyValue) { - ++numFailures; - } - if (numFailures > kMaxFailures) { - llvm::errs() << "Too many failures, aborting\n"; - stopIterating = true; - } - }); - }); - }); - }); - }(); - - // If there was a failure, try to infer what the correct linear layout should - // have been. This assumes that the legacy layout itself is linear, of - // course! - if (!passedInitialChecks || numFailures > 0) { - llvm::errs() << "Linear layout was\n" - << toLinearLayout(params.shape, params.getEncoding()) << "\n"; - - llvm::errs() << "But based on the legacy layout, the LL should be:\n\n"; - - llvm::errs() << "LinearLayout({\n"; - llvm::errs() << " {S(\"register\"), {\n"; - for (int reg = 1; reg < legacyIndices.size(); reg *= 2) { - llvm::errs() << " {" << join(legacyIndices[reg], ", ", [](Value v) { - return evalValue(v, /*ctaId=*/0, /*tid=*/0); - }) << "},\n"; - } - llvm::errs() << " }},\n"; - - llvm::errs() << " {S(\"lane\"), {\n"; - for (int tid = 1; tid < numThreads; tid *= 2) { - if (tid == threadsPerWarp) { - llvm::errs() << " }},\n"; - llvm::errs() << " {S(\"warp\"), {\n"; - } - llvm::errs() << " {" << join(legacyIndices[0], ", ", [&](Value v) { - return evalValue(v, /*ctaId=*/0, tid); - }) << "},\n"; - } - llvm::errs() << " }},\n"; - llvm::errs() << " {S(\"block\"), {\n"; - for (int ctaId = 1; ctaId < numCTAs; ctaId *= 2) { - llvm::errs() << " {" << join(legacyIndices[0], ", ", [&](Value v) { - return evalValue(v, ctaId, /*tid=*/0); - }) << "},\n"; - } - llvm::errs() << " }}\n"; - llvm::errs() << "}, {" - << triton::join(llvm::seq(type.getRank()), ", ", - [](int dim) { - return "S(\"dim" + std::to_string(dim) + - "\")"; - }) - << "})\n"; - } -} - -struct BlockedLegacyVsLinearLayoutsTestParams { - std::vector shape; - std::vector sizePerThread; - std::vector threadsPerWarp; - std::vector warpsPerCTA; - std::vector order; - std::vector CTAsPerCGA; - std::vector CTASplitNum; - std::vector CTAOrder; - - BlockedEncodingAttr getEncoding() const { - return BlockedEncodingAttr::get( - getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order, - CTALayoutAttr::get(getContext(), CTAsPerCGA, CTASplitNum, CTAOrder)); - } -}; - -std::ostream &operator<<(std::ostream &os, - const BlockedLegacyVsLinearLayoutsTestParams ¶ms) { - std::string str; - llvm::raw_string_ostream llvm_os(str); - llvm_os << "shape=" << triton::join(params.shape, "x") - << ", encoding=" << params.getEncoding(); - os << str; - return os; -} - -class BlockedLegacyVsLinearLayoutsTest - : public DistributedLegacyVsLinearLayoutsTest< - BlockedEncodingAttr, BlockedLegacyVsLinearLayoutsTestParams> {}; - -TEST_P(BlockedLegacyVsLinearLayoutsTest, DoIt) { DoIt(); } - -INSTANTIATE_TEST_SUITE_P( - TestCases, BlockedLegacyVsLinearLayoutsTest, - ::testing::ValuesIn(std::vector({ - { - .shape = {128, 16}, - .sizePerThread = {1, 4}, - .threadsPerWarp = {8, 4}, - .warpsPerCTA = {4, 1}, - .order = {1, 0}, - .CTAsPerCGA = {2, 2}, - .CTASplitNum = {2, 1}, - .CTAOrder = {1, 0}, - }, - { - .shape = {1, 128}, - .sizePerThread = {8, 1}, - .threadsPerWarp = {8, 4}, - .warpsPerCTA = {1, 4}, - .order = {0, 1}, - .CTAsPerCGA = {1, 2}, - .CTASplitNum = {1, 2}, - .CTAOrder = {1, 0}, - }, - { - .shape = {64, 1}, - .sizePerThread = {8, 1}, - .threadsPerWarp = {8, 4}, - .warpsPerCTA = {1, 4}, - .order = {0, 1}, - .CTAsPerCGA = {1, 2}, - .CTASplitNum = {1, 2}, - .CTAOrder = {1, 0}, - }, - { - .shape = {128, 1}, - .sizePerThread = {1, 8}, - .threadsPerWarp = {4, 8}, - .warpsPerCTA = {4, 1}, - .order = {1, 0}, - .CTAsPerCGA = {1, 2}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - { - .shape = {1, 64}, - .sizePerThread = {1, 8}, - .threadsPerWarp = {4, 8}, - .warpsPerCTA = {4, 1}, - .order = {1, 0}, - .CTAsPerCGA = {1, 2}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - { - .shape = {128, 1}, - .sizePerThread = {1, 1}, - .threadsPerWarp = {1, 32}, - .warpsPerCTA = {2, 2}, - .order = {1, 0}, - .CTAsPerCGA = {1, 2}, - .CTASplitNum = {1, 2}, - .CTAOrder = {1, 0}, - }, - { - .shape = {1, 128}, - .sizePerThread = {1, 1}, - .threadsPerWarp = {1, 32}, - .warpsPerCTA = {2, 2}, - .order = {1, 0}, - .CTAsPerCGA = {1, 2}, - .CTASplitNum = {1, 2}, - .CTAOrder = {1, 0}, - }, - { - .shape = {1}, - .sizePerThread = {1}, - .threadsPerWarp = {32}, - .warpsPerCTA = {4}, - .order = {0}, - .CTAsPerCGA = {2}, - .CTASplitNum = {2}, - .CTAOrder = {0}, - }, - { - .shape = {128, 128}, - .sizePerThread = {2, 2}, - .threadsPerWarp = {4, 8}, - .warpsPerCTA = {2, 2}, - .order = {0, 1}, - .CTAsPerCGA = {2, 2}, - .CTASplitNum = {2, 2}, - .CTAOrder = {0, 1}, - }, - { - .shape = {1024, 128}, - .sizePerThread = {2, 2}, - .threadsPerWarp = {4, 8}, - .warpsPerCTA = {2, 2}, - .order = {1, 0}, - .CTAsPerCGA = {2, 2}, - .CTASplitNum = {2, 2}, - .CTAOrder = {1, 0}, - }, - }))); - -struct NvidiaMmaVsLinearLayoutsTestParams { - std::vector shape; - unsigned versionMajor; - unsigned versionMinor; - std::vector warpsPerCTA; - std::vector instrShape; - std::vector CTAsPerCGA; - std::vector CTASplitNum; - std::vector CTAOrder; - - NvidiaMmaEncodingAttr getEncoding() const { - return NvidiaMmaEncodingAttr::get( - getContext(), versionMajor, versionMinor, warpsPerCTA, - CTALayoutAttr::get(getContext(), CTAsPerCGA, CTASplitNum, CTAOrder), - instrShape); - } -}; - -std::ostream &operator<<(std::ostream &os, - const NvidiaMmaVsLinearLayoutsTestParams ¶ms) { - std::string str; - llvm::raw_string_ostream llvm_os(str); - llvm_os << "shape=" << triton::join(params.shape, "x") - << ", encoding=" << params.getEncoding(); - os << str; - return os; -} - -class NvidiaMmaVsLinearLayoutsTest - : public DistributedLegacyVsLinearLayoutsTest< - NvidiaMmaEncodingAttr, NvidiaMmaVsLinearLayoutsTestParams> {}; - -TEST_P(NvidiaMmaVsLinearLayoutsTest, DoIt) { DoIt(); } - -INSTANTIATE_TEST_SUITE_P( - MMAv2, NvidiaMmaVsLinearLayoutsTest, - ::testing::ValuesIn(std::vector({ - { - .shape = {16, 8}, - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {1, 1}, - .instrShape = {16, 8}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - { - .shape = {32, 32}, - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {1, 1}, - .instrShape = {16, 8}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - { - .shape = {128, 8}, - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {1, 1}, - .instrShape = {16, 8}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - { - .shape = {16, 128}, - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {1, 1}, - .instrShape = {16, 8}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - { - .shape = {32, 32}, - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {2, 2}, - .instrShape = {16, 8}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - { - .shape = {16, 8}, - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {2, 2}, - .instrShape = {16, 8}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - { - .shape = {16, 512}, - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {2, 2}, - .instrShape = {16, 8}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - { - .shape = {512, 8}, - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {2, 2}, - .instrShape = {16, 8}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - { - .shape = {512, 512}, - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {2, 2}, - .instrShape = {16, 8}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - { - // Legacy emitIndices seems to do implicit duplication in the last - // two dims, but not in the others. That is, this test works - // because shape[0] == warpsPerCTA[0] * CTASplitNum[0], but if you - // increase shape[0] to 32, then the legacy layout will not increase - // its size, whereas the linear layout will. I think this is a bug - // in the legacy layout. - .shape = {16, 128, 128}, - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {16, 1, 1}, - .instrShape = {1, 16, 8}, - .CTAsPerCGA = {1, 1, 1}, - .CTASplitNum = {1, 1, 1}, - .CTAOrder = {2, 1, 0}, - }, - { - .shape = {16 * 4, 128, 128}, - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {16, 1, 1}, - .instrShape = {1, 16, 8}, - .CTAsPerCGA = {4, 1, 1}, - .CTASplitNum = {4, 1, 1}, - .CTAOrder = {2, 1, 0}, - }, - { - .shape = {16 * 4, 128, 128}, - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {16, 1, 1}, - .instrShape = {1, 16, 8}, - .CTAsPerCGA = {4, 2, 2}, - .CTASplitNum = {4, 2, 1}, - .CTAOrder = {2, 1, 0}, - }, - }))); - -std::vector makeNvidiaMmaV3TestCases() { - std::vector testCases; - auto addTests = [&](ArrayRef instrShape, unsigned warpsPerCGA_dim0, - ArrayRef> shapes) { - for (const auto &shape : shapes) { - for (unsigned wpc0 : {4, 8}) { - for (unsigned wpc1 : {1, 2, 4, 8}) { - testCases.push_back({ - .shape = shape, - .versionMajor = 3, - .versionMinor = 0, - .warpsPerCTA = {wpc0, wpc1}, - .instrShape = instrShape, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }); - } - } - } - }; - - // These shapes were captured from grep'ing the TTGIR generated by Triton unit - // tests. - addTests({16, 16, 8}, 4, {{16, 16}, {32, 16}, {32, 32}, {64, 64}}); - addTests({16, 16, 16}, 4, {{64, 16}, {128, 16}, {128, 128}}); - addTests({16, 16, 32}, 4, {{64, 16}, {128, 16}}); - addTests({16, 32, 8}, 4, {{64, 32}, {128, 32}}); - addTests({16, 32, 16}, 4, {{64, 32}, {64, 64}, {256, 64}}); - addTests({16, 64, 8}, 4, {{64, 64}, {128, 64}}); - addTests({16, 64, 16}, 4, {{64, 64}, {128, 64}}); - addTests({16, 64, 32}, 4, {{64, 64}}); - addTests({16, 128, 8}, 4, {{64, 128}, {128, 128}}); - addTests({16, 128, 16}, 4, {{64, 128}, {128, 128}}); - addTests({16, 128, 16}, 8, {{64, 128}, {128, 128}}); - addTests({16, 128, 32}, 8, {{64, 128}, {128, 128}}); - addTests({16, 256, 8}, 8, {{128, 256}}); - addTests({16, 256, 16}, 8, {{128, 256}}); - addTests({16, 256, 32}, 8, {{128, 256}}); - - // Shapes 1xN and Nx1 appear in IR, but legacy emitIndices cannot handle them. - // They appear in IR like the following. - // - // #mma = #nvidia_mma<{versionMajor=3, versionMinor=0, - // warpsPerCTA=[4, 1], instrShape=[16, 64, 16]}> - // %a : tensor<64xf16, #slice> - // %b = tt.expand_dims %a : tensor<1x64xf16, #mma> - // %c = arith.extf %b : tensor<1x64xf32, #mma> - // %d = tt.broadcast %c : tensor<64x64xf32, #mma> - // - // TODO(jlebar): For now we don't test these layouts. Once we have slice - // layout working, we can add support, since their layouts should match that - // of emitIndices for the corresponding slice layout. - - return testCases; -} - -INSTANTIATE_TEST_SUITE_P(MMAv3, NvidiaMmaVsLinearLayoutsTest, - ::testing::ValuesIn(makeNvidiaMmaV3TestCases())); - -struct SliceVsLinearLayoutsTestParams { - std::vector shape; - int64_t sliceDim; - std::variant - parent; - - SliceEncodingAttr getEncoding() const { - return std::visit( - [&](const auto &parentParams) { - return SliceEncodingAttr::get(getContext(), sliceDim, - parentParams.getEncoding()); - }, - parent); - } -}; - -std::ostream &operator<<(std::ostream &os, - const SliceVsLinearLayoutsTestParams ¶ms) { - std::string str; - llvm::raw_string_ostream llvm_os(str); - llvm_os << "shape=" << triton::join(params.shape, "x") - << ", encoding=" << params.getEncoding(); - os << str; - return os; -} - -class SliceVsLinearLayoutsTest - : public DistributedLegacyVsLinearLayoutsTest< - SliceEncodingAttr, SliceVsLinearLayoutsTestParams> {}; - -TEST_P(SliceVsLinearLayoutsTest, DoIt) { DoIt(); } - -INSTANTIATE_TEST_SUITE_P(TestCases, SliceVsLinearLayoutsTest, - ::testing::ValuesIn( - std::vector({ - { - .shape = {128}, - .sliceDim = 0, - .parent = - BlockedLegacyVsLinearLayoutsTestParams{ - .sizePerThread = {2, 4}, - .threadsPerWarp = {4, 2}, - .warpsPerCTA = {2, 2}, - .order = {1, 0}, - .CTAsPerCGA = {2, 2}, - .CTASplitNum = {2, 2}, - .CTAOrder = {1, 0}, - }, - }, - { - .shape = {128}, - .sliceDim = 1, - .parent = - BlockedLegacyVsLinearLayoutsTestParams{ - .sizePerThread = {2, 4}, - .threadsPerWarp = {4, 2}, - .warpsPerCTA = {2, 2}, - .order = {1, 0}, - .CTAsPerCGA = {2, 2}, - .CTASplitNum = {2, 2}, - .CTAOrder = {1, 0}, - }, - }, - - { - .shape = {32}, - .sliceDim = 1, - .parent = - BlockedLegacyVsLinearLayoutsTestParams{ - .sizePerThread = {1, 1}, - .threadsPerWarp = {32, 1}, - .warpsPerCTA = {4, 1}, - .order = {0, 1}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - }, - { - .shape = {32}, - .sliceDim = 0, - .parent = - BlockedLegacyVsLinearLayoutsTestParams{ - .sizePerThread = {1, 1}, - .threadsPerWarp = {32, 1}, - .warpsPerCTA = {4, 1}, - .order = {0, 1}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - }, - { - .shape = {32}, - .sliceDim = 1, - .parent = - BlockedLegacyVsLinearLayoutsTestParams{ - .sizePerThread = {1, 4}, - .threadsPerWarp = {8, 4}, - .warpsPerCTA = {2, 2}, - .order = {0, 1}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - }, - { - .shape = {32}, - .sliceDim = 0, - .parent = - BlockedLegacyVsLinearLayoutsTestParams{ - .sizePerThread = {1, 4}, - .threadsPerWarp = {8, 4}, - .warpsPerCTA = {2, 2}, - .order = {0, 1}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - }, - { - .shape = {1}, - .sliceDim = 0, - .parent = - BlockedLegacyVsLinearLayoutsTestParams{ - .sizePerThread = {1, 4}, - .threadsPerWarp = {8, 4}, - .warpsPerCTA = {2, 2}, - .order = {0, 1}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - }, - - { - .shape = {16}, - .sliceDim = 0, - .parent = - NvidiaMmaVsLinearLayoutsTestParams{ - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {2, 2}, - .instrShape = {16, 8}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - }, - { - .shape = {128}, - .sliceDim = 0, - .parent = - NvidiaMmaVsLinearLayoutsTestParams{ - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {2, 2}, - .instrShape = {16, 8}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - }, - { - .shape = {16}, - .sliceDim = 1, - .parent = - NvidiaMmaVsLinearLayoutsTestParams{ - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {2, 2}, - .instrShape = {16, 8}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - }, - { - .shape = {128}, - .sliceDim = 1, - .parent = - NvidiaMmaVsLinearLayoutsTestParams{ - .versionMajor = 2, - .versionMinor = 0, - .warpsPerCTA = {2, 2}, - .instrShape = {16, 8}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - }, - { - .shape = {128}, - .sliceDim = 0, - .parent = - NvidiaMmaVsLinearLayoutsTestParams{ - .versionMajor = 3, - .versionMinor = 0, - .warpsPerCTA = {4, 4}, - .instrShape = {16, 16, 16}, - .CTAsPerCGA = {1, 1}, - .CTASplitNum = {1, 1}, - .CTAOrder = {1, 0}, - }, - }, - }))); - -} // namespace gpu -} // namespace triton -} // namespace mlir - -//===----------------------------------------------------------------------===// -// Main -//===----------------------------------------------------------------------===// - -int main(int argc, char *argv[]) { - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/unittest/Dialect/TritonGPU/CMakeLists.txt b/unittest/Dialect/TritonGPU/CMakeLists.txt index 3e57b1c00..ad9629323 100644 --- a/unittest/Dialect/TritonGPU/CMakeLists.txt +++ b/unittest/Dialect/TritonGPU/CMakeLists.txt @@ -13,3 +13,9 @@ add_triton_ut( SRCS LinearLayoutConversionsTest.cpp LIBS TritonGPUIR ) + +add_triton_ut( + NAME DumpLayoutTest + SRCS DumpLayoutTest.cpp + LIBS TritonGPUIR +) diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp index 7e8a56c79..c27c63335 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -6,6 +7,7 @@ #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Tools/StrUtil.h" +#include "llvm/Support/Signals.h" namespace { @@ -518,5 +520,111 @@ TEST_F(InferLayoutTest, FuzzReshape) { 100.0 * numSuccess / numTests); } +class AMDMfmaLayoutTest : public ::testing::Test { +public: + AMDMfmaLayoutTest() { + ctx.getOrLoadDialect(); + ctaLayout = + triton::gpu::CTALayoutAttr::get(&ctx, ctaPerCGA, ctaSplit, ctaOrder); + f16Ty = FloatType::getF16(&ctx); + } + + triton::gpu::AMDMfmaEncodingAttr createMFMA(int mDim, int nDim, + ArrayRef warpsPerCTA) { + return triton::gpu::AMDMfmaEncodingAttr::get( + &ctx, /*versionMajor=*/2, /*versionMinor=*/0, warpsPerCTA, mDim, nDim, + /*isTransposed=*/false, ctaLayout); + } + + triton::gpu::AMDMfmaEncodingAttr + createTransposedMFMA(int mDim, int nDim, ArrayRef warpsPerCTA) { + return triton::gpu::AMDMfmaEncodingAttr::get( + &ctx, /*versionMajor=*/2, /*versionMinor=*/0, warpsPerCTA, mDim, nDim, + /*isTransposed=*/true, ctaLayout); + } + + triton::gpu::DotOperandEncodingAttr + createDotOperand(int idx, triton::gpu::AMDMfmaEncodingAttr parent, + int kWidth) { + return triton::gpu::DotOperandEncodingAttr::get(&ctx, idx, parent, kWidth); + } + +protected: + MLIRContext ctx; + const SmallVector ctaPerCGA{1, 1, 1}; + const SmallVector ctaSplit{1, 1, 1}; + const SmallVector ctaOrder{2, 1, 0}; + triton::gpu::CTALayoutAttr ctaLayout; + Type f16Ty; +}; + +TEST_F(AMDMfmaLayoutTest, mfma32) { + auto mfma2d = createMFMA(32, 32, {2, 4}); + ASSERT_THAT(mfma2d.getThreadOrder(), testing::ElementsAre(1u, 0u)); + ASSERT_THAT(mfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u)); + + auto tmfma2d = createTransposedMFMA(32, 32, {2, 4}); + ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u)); + ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u)); + + auto mfma3d = createMFMA(32, 32, {2, 4, 1}); + ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u)); + ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); + + auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1}); + ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u)); + ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); +} + +TEST_F(AMDMfmaLayoutTest, mfma16) { + auto mfma2d = createMFMA(16, 16, {2, 4}); + ASSERT_THAT(mfma2d.getThreadOrder(), testing::ElementsAre(1u, 0u)); + ASSERT_THAT(mfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u)); + + auto tmfma2d = createTransposedMFMA(16, 16, {2, 4}); + ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u)); + ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u)); + + auto mfma3d = createMFMA(16, 16, {2, 4, 1}); + ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u)); + ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); + + auto tmfma3d = createTransposedMFMA(16, 16, {2, 4, 1}); + ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u)); + ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u)); +} + +TEST_F(AMDMfmaLayoutTest, mfma_dot_op) { + auto mfma2d = createMFMA(32, 32, {2, 4}); + auto dot2dOp0 = createDotOperand(0, mfma2d, 4); + auto dot2dOp1 = createDotOperand(1, mfma2d, 4); + ASSERT_THAT(dot2dOp0.getWarpOrder(), mfma2d.getWarpOrder()); + ASSERT_THAT(dot2dOp1.getWarpOrder(), mfma2d.getWarpOrder()); + + auto tmfma2d = createTransposedMFMA(32, 32, {2, 4}); + auto tdot2dOp0 = createDotOperand(0, tmfma2d, 4); + auto tdot2dOp1 = createDotOperand(1, tmfma2d, 4); + ASSERT_THAT(tdot2dOp0.getWarpOrder(), tmfma2d.getWarpOrder()); + ASSERT_THAT(tdot2dOp1.getWarpOrder(), tmfma2d.getWarpOrder()); + + auto mfma3d = createMFMA(32, 32, {2, 4, 1}); + auto dot3dOp0 = createDotOperand(0, mfma3d, 4); + auto dot3dOp1 = createDotOperand(1, mfma3d, 4); + ASSERT_THAT(dot3dOp0.getWarpOrder(), mfma3d.getWarpOrder()); + ASSERT_THAT(dot3dOp1.getWarpOrder(), mfma3d.getWarpOrder()); + + auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1}); + auto tdot3dOp0 = createDotOperand(0, tmfma3d, 4); + auto tdot3dOp1 = createDotOperand(1, tmfma3d, 4); + ASSERT_THAT(tdot3dOp0.getWarpOrder(), tmfma3d.getWarpOrder()); + ASSERT_THAT(tdot3dOp1.getWarpOrder(), tmfma3d.getWarpOrder()); +} + } // anonymous namespace } // namespace mlir::triton::gpu + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/unittest/Dialect/TritonGPU/DumpLayoutTest.cpp b/unittest/Dialect/TritonGPU/DumpLayoutTest.cpp new file mode 100644 index 000000000..b73086058 --- /dev/null +++ b/unittest/Dialect/TritonGPU/DumpLayoutTest.cpp @@ -0,0 +1,527 @@ +#include "mlir/IR/MLIRContext.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/StrUtil.h" +#include "llvm/Support/Signals.h" +#include +#include + +namespace mlir::triton::gpu { +namespace { + +class DumpLayoutTest : public ::testing::Test { +public: + void SetUp() { ctx.getOrLoadDialect(); } + + BlockedEncodingAttr blocked(ArrayRef spt, ArrayRef tpw, + ArrayRef wpb, ArrayRef cpg, + ArrayRef cSplit, ArrayRef ord, + ArrayRef cOrd) { + return BlockedEncodingAttr::get( + &ctx, spt, tpw, wpb, ord, CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); + } + + SharedEncodingAttr shared(unsigned vec, unsigned perPhase, unsigned maxPhase, + bool hasLeadingOffset, ArrayRef cpg, + ArrayRef cSplit, ArrayRef ord, + ArrayRef cOrd) { + return SharedEncodingAttr::get(&ctx, vec, perPhase, maxPhase, ord, + CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), + hasLeadingOffset); + } + + void assertSameStr(const std::string &refStr, const std::string &output) { + if (refStr != output) { + llvm::outs() << "RefStr =\n" + << refStr << "\n" + << "\n" + << "Output =\n" + << output << "\n"; + FAIL() << "Incorrect output string"; + } + } + +protected: + MLIRContext ctx; +}; + +TEST_F(DumpLayoutTest, SimpleBlocked) { + std::string ref = + R"([ T0:0| T4:0| T8:0|T12:0|T16:0|T20:0|T24:0|T28:0, T1:0| T5:0| T9:0|T13:0|T17:0|T21:0|T25:0|T29:0, T2:0| T6:0|T10:0|T14:0|T18:0|T22:0|T26:0|T30:0, T3:0| T7:0|T11:0|T15:0|T19:0|T23:0|T27:0|T31:0] +)"; + auto blockedLayout = blocked({1}, {8}, {4}, {1}, {1}, {0}, {0}); + auto tensorType = RankedTensorType::get( + {4}, IntegerType::get(blockedLayout.getContext(), 32), blockedLayout); + std::string layout = getLayoutStr(tensorType, /*useHWPointOfView=*/false); + assertSameStr(ref, layout); + + std::string refHWRep = + R"(Warp0: +(0), (1), (2), (3), (0), (1), (2), (3) +Warp1: +(0), (1), (2), (3), (0), (1), (2), (3) +Warp2: +(0), (1), (2), (3), (0), (1), (2), (3) +Warp3: +(0), (1), (2), (3), (0), (1), (2), (3) +)"; + std::string layoutHW = getLayoutStr(tensorType, /*useHWPointOfView=*/true); + assertSameStr(refHWRep, layoutHW); +} + +TEST_F(DumpLayoutTest, NDTensor) { + auto blockedLayout = blocked({2, 1, 4}, {2, 2, 2}, {1, 2, 1}, {1, 1, 1}, + {1, 1, 1}, {2, 1, 0}, {2, 1, 0}); + auto tensorType = RankedTensorType::get( + {8, 2, 16}, IntegerType::get(blockedLayout.getContext(), 32), + blockedLayout); + std::string ref = + R"([[[ T0:0| T8:0, T0:1| T8:1, T0:2| T8:2, T0:3| T8:3, T1:0| T9:0, T1:1| T9:1, T1:2| T9:2, T1:3| T9:3, T0:8| T8:8, T0:9| T8:9, T0:10| T8:10, T0:11| T8:11, T1:8| T9:8, T1:9| T9:9, T1:10| T9:10, T1:11| T9:11] +[ T2:0| T10:0, T2:1| T10:1, T2:2| T10:2, T2:3| T10:3, T3:0| T11:0, T3:1| T11:1, T3:2| T11:2, T3:3| T11:3, T2:8| T10:8, T2:9| T10:9, T2:10|T10:10, T2:11|T10:11, T3:8| T11:8, T3:9| T11:9, T3:10|T11:10, T3:11|T11:11]] +[[ T0:4| T8:4, T0:5| T8:5, T0:6| T8:6, T0:7| T8:7, T1:4| T9:4, T1:5| T9:5, T1:6| T9:6, T1:7| T9:7, T0:12| T8:12, T0:13| T8:13, T0:14| T8:14, T0:15| T8:15, T1:12| T9:12, T1:13| T9:13, T1:14| T9:14, T1:15| T9:15] +[ T2:4| T10:4, T2:5| T10:5, T2:6| T10:6, T2:7| T10:7, T3:4| T11:4, T3:5| T11:5, T3:6| T11:6, T3:7| T11:7, T2:12|T10:12, T2:13|T10:13, T2:14|T10:14, T2:15|T10:15, T3:12|T11:12, T3:13|T11:13, T3:14|T11:14, T3:15|T11:15]] +[[ T4:0| T12:0, T4:1| T12:1, T4:2| T12:2, T4:3| T12:3, T5:0| T13:0, T5:1| T13:1, T5:2| T13:2, T5:3| T13:3, T4:8| T12:8, T4:9| T12:9, T4:10|T12:10, T4:11|T12:11, T5:8| T13:8, T5:9| T13:9, T5:10|T13:10, T5:11|T13:11] +[ T6:0| T14:0, T6:1| T14:1, T6:2| T14:2, T6:3| T14:3, T7:0| T15:0, T7:1| T15:1, T7:2| T15:2, T7:3| T15:3, T6:8| T14:8, T6:9| T14:9, T6:10|T14:10, T6:11|T14:11, T7:8| T15:8, T7:9| T15:9, T7:10|T15:10, T7:11|T15:11]] +[[ T4:4| T12:4, T4:5| T12:5, T4:6| T12:6, T4:7| T12:7, T5:4| T13:4, T5:5| T13:5, T5:6| T13:6, T5:7| T13:7, T4:12|T12:12, T4:13|T12:13, T4:14|T12:14, T4:15|T12:15, T5:12|T13:12, T5:13|T13:13, T5:14|T13:14, T5:15|T13:15] +[ T6:4| T14:4, T6:5| T14:5, T6:6| T14:6, T6:7| T14:7, T7:4| T15:4, T7:5| T15:5, T7:6| T15:6, T7:7| T15:7, T6:12|T14:12, T6:13|T14:13, T6:14|T14:14, T6:15|T14:15, T7:12|T15:12, T7:13|T15:13, T7:14|T15:14, T7:15|T15:15]] +[[ T0:16| T8:16, T0:17| T8:17, T0:18| T8:18, T0:19| T8:19, T1:16| T9:16, T1:17| T9:17, T1:18| T9:18, T1:19| T9:19, T0:24| T8:24, T0:25| T8:25, T0:26| T8:26, T0:27| T8:27, T1:24| T9:24, T1:25| T9:25, T1:26| T9:26, T1:27| T9:27] +[ T2:16|T10:16, T2:17|T10:17, T2:18|T10:18, T2:19|T10:19, T3:16|T11:16, T3:17|T11:17, T3:18|T11:18, T3:19|T11:19, T2:24|T10:24, T2:25|T10:25, T2:26|T10:26, T2:27|T10:27, T3:24|T11:24, T3:25|T11:25, T3:26|T11:26, T3:27|T11:27]] +[[ T0:20| T8:20, T0:21| T8:21, T0:22| T8:22, T0:23| T8:23, T1:20| T9:20, T1:21| T9:21, T1:22| T9:22, T1:23| T9:23, T0:28| T8:28, T0:29| T8:29, T0:30| T8:30, T0:31| T8:31, T1:28| T9:28, T1:29| T9:29, T1:30| T9:30, T1:31| T9:31] +[ T2:20|T10:20, T2:21|T10:21, T2:22|T10:22, T2:23|T10:23, T3:20|T11:20, T3:21|T11:21, T3:22|T11:22, T3:23|T11:23, T2:28|T10:28, T2:29|T10:29, T2:30|T10:30, T2:31|T10:31, T3:28|T11:28, T3:29|T11:29, T3:30|T11:30, T3:31|T11:31]] +[[ T4:16|T12:16, T4:17|T12:17, T4:18|T12:18, T4:19|T12:19, T5:16|T13:16, T5:17|T13:17, T5:18|T13:18, T5:19|T13:19, T4:24|T12:24, T4:25|T12:25, T4:26|T12:26, T4:27|T12:27, T5:24|T13:24, T5:25|T13:25, T5:26|T13:26, T5:27|T13:27] +[ T6:16|T14:16, T6:17|T14:17, T6:18|T14:18, T6:19|T14:19, T7:16|T15:16, T7:17|T15:17, T7:18|T15:18, T7:19|T15:19, T6:24|T14:24, T6:25|T14:25, T6:26|T14:26, T6:27|T14:27, T7:24|T15:24, T7:25|T15:25, T7:26|T15:26, T7:27|T15:27]] +[[ T4:20|T12:20, T4:21|T12:21, T4:22|T12:22, T4:23|T12:23, T5:20|T13:20, T5:21|T13:21, T5:22|T13:22, T5:23|T13:23, T4:28|T12:28, T4:29|T12:29, T4:30|T12:30, T4:31|T12:31, T5:28|T13:28, T5:29|T13:29, T5:30|T13:30, T5:31|T13:31] +[ T6:20|T14:20, T6:21|T14:21, T6:22|T14:22, T6:23|T14:23, T7:20|T15:20, T7:21|T15:21, T7:22|T15:22, T7:23|T15:23, T6:28|T14:28, T6:29|T14:29, T6:30|T14:30, T6:31|T14:31, T7:28|T15:28, T7:29|T15:29, T7:30|T15:30, T7:31|T15:31]]] +)"; + std::string layout = getLayoutStr(tensorType, /*useHWPointOfView=*/false); + assertSameStr(ref, layout); + std::string refHWRep = + R"(Warp0: +(0,0, 0), (0,0, 4), (0,1, 0), (0,1, 4), (2,0, 0), (2,0, 4), (2,1, 0), (2,1, 4) +(0,0, 1), (0,0, 5), (0,1, 1), (0,1, 5), (2,0, 1), (2,0, 5), (2,1, 1), (2,1, 5) +(0,0, 2), (0,0, 6), (0,1, 2), (0,1, 6), (2,0, 2), (2,0, 6), (2,1, 2), (2,1, 6) +(0,0, 3), (0,0, 7), (0,1, 3), (0,1, 7), (2,0, 3), (2,0, 7), (2,1, 3), (2,1, 7) +(1,0, 0), (1,0, 4), (1,1, 0), (1,1, 4), (3,0, 0), (3,0, 4), (3,1, 0), (3,1, 4) +(1,0, 1), (1,0, 5), (1,1, 1), (1,1, 5), (3,0, 1), (3,0, 5), (3,1, 1), (3,1, 5) +(1,0, 2), (1,0, 6), (1,1, 2), (1,1, 6), (3,0, 2), (3,0, 6), (3,1, 2), (3,1, 6) +(1,0, 3), (1,0, 7), (1,1, 3), (1,1, 7), (3,0, 3), (3,0, 7), (3,1, 3), (3,1, 7) +(0,0, 8), (0,0,12), (0,1, 8), (0,1,12), (2,0, 8), (2,0,12), (2,1, 8), (2,1,12) +(0,0, 9), (0,0,13), (0,1, 9), (0,1,13), (2,0, 9), (2,0,13), (2,1, 9), (2,1,13) +(0,0,10), (0,0,14), (0,1,10), (0,1,14), (2,0,10), (2,0,14), (2,1,10), (2,1,14) +(0,0,11), (0,0,15), (0,1,11), (0,1,15), (2,0,11), (2,0,15), (2,1,11), (2,1,15) +(1,0, 8), (1,0,12), (1,1, 8), (1,1,12), (3,0, 8), (3,0,12), (3,1, 8), (3,1,12) +(1,0, 9), (1,0,13), (1,1, 9), (1,1,13), (3,0, 9), (3,0,13), (3,1, 9), (3,1,13) +(1,0,10), (1,0,14), (1,1,10), (1,1,14), (3,0,10), (3,0,14), (3,1,10), (3,1,14) +(1,0,11), (1,0,15), (1,1,11), (1,1,15), (3,0,11), (3,0,15), (3,1,11), (3,1,15) +(4,0, 0), (4,0, 4), (4,1, 0), (4,1, 4), (6,0, 0), (6,0, 4), (6,1, 0), (6,1, 4) +(4,0, 1), (4,0, 5), (4,1, 1), (4,1, 5), (6,0, 1), (6,0, 5), (6,1, 1), (6,1, 5) +(4,0, 2), (4,0, 6), (4,1, 2), (4,1, 6), (6,0, 2), (6,0, 6), (6,1, 2), (6,1, 6) +(4,0, 3), (4,0, 7), (4,1, 3), (4,1, 7), (6,0, 3), (6,0, 7), (6,1, 3), (6,1, 7) +(5,0, 0), (5,0, 4), (5,1, 0), (5,1, 4), (7,0, 0), (7,0, 4), (7,1, 0), (7,1, 4) +(5,0, 1), (5,0, 5), (5,1, 1), (5,1, 5), (7,0, 1), (7,0, 5), (7,1, 1), (7,1, 5) +(5,0, 2), (5,0, 6), (5,1, 2), (5,1, 6), (7,0, 2), (7,0, 6), (7,1, 2), (7,1, 6) +(5,0, 3), (5,0, 7), (5,1, 3), (5,1, 7), (7,0, 3), (7,0, 7), (7,1, 3), (7,1, 7) +(4,0, 8), (4,0,12), (4,1, 8), (4,1,12), (6,0, 8), (6,0,12), (6,1, 8), (6,1,12) +(4,0, 9), (4,0,13), (4,1, 9), (4,1,13), (6,0, 9), (6,0,13), (6,1, 9), (6,1,13) +(4,0,10), (4,0,14), (4,1,10), (4,1,14), (6,0,10), (6,0,14), (6,1,10), (6,1,14) +(4,0,11), (4,0,15), (4,1,11), (4,1,15), (6,0,11), (6,0,15), (6,1,11), (6,1,15) +(5,0, 8), (5,0,12), (5,1, 8), (5,1,12), (7,0, 8), (7,0,12), (7,1, 8), (7,1,12) +(5,0, 9), (5,0,13), (5,1, 9), (5,1,13), (7,0, 9), (7,0,13), (7,1, 9), (7,1,13) +(5,0,10), (5,0,14), (5,1,10), (5,1,14), (7,0,10), (7,0,14), (7,1,10), (7,1,14) +(5,0,11), (5,0,15), (5,1,11), (5,1,15), (7,0,11), (7,0,15), (7,1,11), (7,1,15) +Warp1: +(0,0, 0), (0,0, 4), (0,1, 0), (0,1, 4), (2,0, 0), (2,0, 4), (2,1, 0), (2,1, 4) +(0,0, 1), (0,0, 5), (0,1, 1), (0,1, 5), (2,0, 1), (2,0, 5), (2,1, 1), (2,1, 5) +(0,0, 2), (0,0, 6), (0,1, 2), (0,1, 6), (2,0, 2), (2,0, 6), (2,1, 2), (2,1, 6) +(0,0, 3), (0,0, 7), (0,1, 3), (0,1, 7), (2,0, 3), (2,0, 7), (2,1, 3), (2,1, 7) +(1,0, 0), (1,0, 4), (1,1, 0), (1,1, 4), (3,0, 0), (3,0, 4), (3,1, 0), (3,1, 4) +(1,0, 1), (1,0, 5), (1,1, 1), (1,1, 5), (3,0, 1), (3,0, 5), (3,1, 1), (3,1, 5) +(1,0, 2), (1,0, 6), (1,1, 2), (1,1, 6), (3,0, 2), (3,0, 6), (3,1, 2), (3,1, 6) +(1,0, 3), (1,0, 7), (1,1, 3), (1,1, 7), (3,0, 3), (3,0, 7), (3,1, 3), (3,1, 7) +(0,0, 8), (0,0,12), (0,1, 8), (0,1,12), (2,0, 8), (2,0,12), (2,1, 8), (2,1,12) +(0,0, 9), (0,0,13), (0,1, 9), (0,1,13), (2,0, 9), (2,0,13), (2,1, 9), (2,1,13) +(0,0,10), (0,0,14), (0,1,10), (0,1,14), (2,0,10), (2,0,14), (2,1,10), (2,1,14) +(0,0,11), (0,0,15), (0,1,11), (0,1,15), (2,0,11), (2,0,15), (2,1,11), (2,1,15) +(1,0, 8), (1,0,12), (1,1, 8), (1,1,12), (3,0, 8), (3,0,12), (3,1, 8), (3,1,12) +(1,0, 9), (1,0,13), (1,1, 9), (1,1,13), (3,0, 9), (3,0,13), (3,1, 9), (3,1,13) +(1,0,10), (1,0,14), (1,1,10), (1,1,14), (3,0,10), (3,0,14), (3,1,10), (3,1,14) +(1,0,11), (1,0,15), (1,1,11), (1,1,15), (3,0,11), (3,0,15), (3,1,11), (3,1,15) +(4,0, 0), (4,0, 4), (4,1, 0), (4,1, 4), (6,0, 0), (6,0, 4), (6,1, 0), (6,1, 4) +(4,0, 1), (4,0, 5), (4,1, 1), (4,1, 5), (6,0, 1), (6,0, 5), (6,1, 1), (6,1, 5) +(4,0, 2), (4,0, 6), (4,1, 2), (4,1, 6), (6,0, 2), (6,0, 6), (6,1, 2), (6,1, 6) +(4,0, 3), (4,0, 7), (4,1, 3), (4,1, 7), (6,0, 3), (6,0, 7), (6,1, 3), (6,1, 7) +(5,0, 0), (5,0, 4), (5,1, 0), (5,1, 4), (7,0, 0), (7,0, 4), (7,1, 0), (7,1, 4) +(5,0, 1), (5,0, 5), (5,1, 1), (5,1, 5), (7,0, 1), (7,0, 5), (7,1, 1), (7,1, 5) +(5,0, 2), (5,0, 6), (5,1, 2), (5,1, 6), (7,0, 2), (7,0, 6), (7,1, 2), (7,1, 6) +(5,0, 3), (5,0, 7), (5,1, 3), (5,1, 7), (7,0, 3), (7,0, 7), (7,1, 3), (7,1, 7) +(4,0, 8), (4,0,12), (4,1, 8), (4,1,12), (6,0, 8), (6,0,12), (6,1, 8), (6,1,12) +(4,0, 9), (4,0,13), (4,1, 9), (4,1,13), (6,0, 9), (6,0,13), (6,1, 9), (6,1,13) +(4,0,10), (4,0,14), (4,1,10), (4,1,14), (6,0,10), (6,0,14), (6,1,10), (6,1,14) +(4,0,11), (4,0,15), (4,1,11), (4,1,15), (6,0,11), (6,0,15), (6,1,11), (6,1,15) +(5,0, 8), (5,0,12), (5,1, 8), (5,1,12), (7,0, 8), (7,0,12), (7,1, 8), (7,1,12) +(5,0, 9), (5,0,13), (5,1, 9), (5,1,13), (7,0, 9), (7,0,13), (7,1, 9), (7,1,13) +(5,0,10), (5,0,14), (5,1,10), (5,1,14), (7,0,10), (7,0,14), (7,1,10), (7,1,14) +(5,0,11), (5,0,15), (5,1,11), (5,1,15), (7,0,11), (7,0,15), (7,1,11), (7,1,15) +)"; + std::string layoutHW = getLayoutStr(tensorType, /*useHWPointOfView=*/true); + assertSameStr(refHWRep, layoutHW); +} + +TEST_F(DumpLayoutTest, Simple1DShared) { + std::string refStr = + "[( 0),( 1),( 2),( 3),( 4),( 5),( 6),( 7),( 8),( " + "9),(10),(11),(12),(13),(14),(15),(16),(17),(18),(19),(20),(21),(22),(23)" + ",(24),(25),(26),(27),(28),(29),(30),(31)]\n"; + + auto sharedLayout = shared(1, /* vec */ + 1, /* perPhase */ + 4, /* maxPhase */ + false, /* hasLeadingOffset */ + {1}, /* cpg */ + {1}, /* csplit */ + {1}, /* ord, row-major */ + {1}); /* cOrd */ + + auto elemTy = FloatType::getF16(sharedLayout.getContext()); + auto tensorType = RankedTensorType::get({32}, elemTy, sharedLayout); + std::string layout = getLayoutStr(tensorType, /*useHWPointOfView=*/false); + assertSameStr(refStr, layout); +} + +TEST_F(DumpLayoutTest, Larger2DShared) { + + std::string refStr = + "[[(0: 0),(0: 1),(0: 2),(0: 3),(0: 4),(0: 5),(0: 6),(0: 7),(0: 8),(0: " + "9),(0:10),(0:11),(0:12),(0:13),(0:14),(0:15),(0:16),(0:17),(0:18),(0:19)" + ",(0:20),(0:21),(0:22),(0:23),(0:24),(0:25),(0:26),(0:27),(0:28),(0:29),(" + "0:30),(0:31)]\n" + "[ (1: 0),(1: 1),(1: 2),(1: 3),(1: 4),(1: 5),(1: 6),(1: 7),(1: 8),(1: " + "9),(1:10),(1:11),(1:12),(1:13),(1:14),(1:15),(1:16),(1:17),(1:18),(1:19)" + ",(1:20),(1:21),(1:22),(1:23),(1:24),(1:25),(1:26),(1:27),(1:28),(1:29),(" + "1:30),(1:31)]\n" + "[ (2: 8),(2: 9),(2:10),(2:11),(2:12),(2:13),(2:14),(2:15),(2: 0),(2: " + "1),(2: 2),(2: 3),(2: 4),(2: 5),(2: 6),(2: " + "7),(2:24),(2:25),(2:26),(2:27),(2:28),(2:29),(2:30),(2:31),(2:16),(2:17)" + ",(2:18),(2:19),(2:20),(2:21),(2:22),(2:23)]\n" + "[ (3: 8),(3: 9),(3:10),(3:11),(3:12),(3:13),(3:14),(3:15),(3: 0),(3: " + "1),(3: 2),(3: 3),(3: 4),(3: 5),(3: 6),(3: " + "7),(3:24),(3:25),(3:26),(3:27),(3:28),(3:29),(3:30),(3:31),(3:16),(3:17)" + ",(3:18),(3:19),(3:20),(3:21),(3:22),(3:23)]\n" + "[ " + "(4:16),(4:17),(4:18),(4:19),(4:20),(4:21),(4:22),(4:23),(4:24),(4:25),(" + "4:26),(4:27),(4:28),(4:29),(4:30),(4:31),(4: 0),(4: 1),(4: 2),(4: " + "3),(4: 4),(4: 5),(4: 6),(4: 7),(4: 8),(4: " + "9),(4:10),(4:11),(4:12),(4:13),(4:14),(4:15)]\n" + "[ " + "(5:16),(5:17),(5:18),(5:19),(5:20),(5:21),(5:22),(5:23),(5:24),(5:25),(" + "5:26),(5:27),(5:28),(5:29),(5:30),(5:31),(5: 0),(5: 1),(5: 2),(5: " + "3),(5: 4),(5: 5),(5: 6),(5: 7),(5: 8),(5: " + "9),(5:10),(5:11),(5:12),(5:13),(5:14),(5:15)]\n" + "[ " + "(6:24),(6:25),(6:26),(6:27),(6:28),(6:29),(6:30),(6:31),(6:16),(6:17),(" + "6:18),(6:19),(6:20),(6:21),(6:22),(6:23),(6: 8),(6: " + "9),(6:10),(6:11),(6:12),(6:13),(6:14),(6:15),(6: 0),(6: 1),(6: 2),(6: " + "3),(6: 4),(6: 5),(6: 6),(6: 7)]\n" + "[ " + "(7:24),(7:25),(7:26),(7:27),(7:28),(7:29),(7:30),(7:31),(7:16),(7:17),(" + "7:18),(7:19),(7:20),(7:21),(7:22),(7:23),(7: 8),(7: " + "9),(7:10),(7:11),(7:12),(7:13),(7:14),(7:15),(7: 0),(7: 1),(7: 2),(7: " + "3),(7: 4),(7: 5),(7: 6),(7: 7)]]\n"; + + auto sharedLayout = shared(8, /* vec */ + 2, /* perPhase */ + 8, /* maxPhase */ + false, /* hasLeadingOffset */ + {1, 1}, /* cpg */ + {1, 1}, /* csplit */ + {1, 0}, /* ord, row-major */ + {1, 0}); /* cOrd */ + + auto elemTy = FloatType::getF16(sharedLayout.getContext()); + auto tensorType = RankedTensorType::get({8, 32}, elemTy, sharedLayout); + std::string layout = getLayoutStr(tensorType, /*useHWPointOfView=*/false); + assertSameStr(refStr, layout); + + std::string refHWRep = + R"(Block: 0: +Offset: 0 -> (0, 0) +Offset: 1 -> (0, 1) +Offset: 2 -> (0, 2) +Offset: 3 -> (0, 3) +Offset: 4 -> (0, 4) +Offset: 5 -> (0, 5) +Offset: 6 -> (0, 6) +Offset: 7 -> (0, 7) +Offset: 8 -> (0, 8) +Offset: 9 -> (0, 9) +Offset: 10 -> (0,10) +Offset: 11 -> (0,11) +Offset: 12 -> (0,12) +Offset: 13 -> (0,13) +Offset: 14 -> (0,14) +Offset: 15 -> (0,15) +Offset: 16 -> (0,16) +Offset: 17 -> (0,17) +Offset: 18 -> (0,18) +Offset: 19 -> (0,19) +Offset: 20 -> (0,20) +Offset: 21 -> (0,21) +Offset: 22 -> (0,22) +Offset: 23 -> (0,23) +Offset: 24 -> (0,24) +Offset: 25 -> (0,25) +Offset: 26 -> (0,26) +Offset: 27 -> (0,27) +Offset: 28 -> (0,28) +Offset: 29 -> (0,29) +Offset: 30 -> (0,30) +Offset: 31 -> (0,31) +Offset: 32 -> (1, 2) +Offset: 33 -> (1, 3) +Offset: 34 -> (1, 0) +Offset: 35 -> (1, 1) +Offset: 36 -> (1, 6) +Offset: 37 -> (1, 7) +Offset: 38 -> (1, 4) +Offset: 39 -> (1, 5) +Offset: 40 -> (1,10) +Offset: 41 -> (1,11) +Offset: 42 -> (1, 8) +Offset: 43 -> (1, 9) +Offset: 44 -> (1,14) +Offset: 45 -> (1,15) +Offset: 46 -> (1,12) +Offset: 47 -> (1,13) +Offset: 48 -> (1,18) +Offset: 49 -> (1,19) +Offset: 50 -> (1,16) +Offset: 51 -> (1,17) +Offset: 52 -> (1,22) +Offset: 53 -> (1,23) +Offset: 54 -> (1,20) +Offset: 55 -> (1,21) +Offset: 56 -> (1,26) +Offset: 57 -> (1,27) +Offset: 58 -> (1,24) +Offset: 59 -> (1,25) +Offset: 60 -> (1,30) +Offset: 61 -> (1,31) +Offset: 62 -> (1,28) +Offset: 63 -> (1,29) +Offset: 64 -> (2, 4) +Offset: 65 -> (2, 5) +Offset: 66 -> (2, 6) +Offset: 67 -> (2, 7) +Offset: 68 -> (2, 0) +Offset: 69 -> (2, 1) +Offset: 70 -> (2, 2) +Offset: 71 -> (2, 3) +Offset: 72 -> (2,12) +Offset: 73 -> (2,13) +Offset: 74 -> (2,14) +Offset: 75 -> (2,15) +Offset: 76 -> (2, 8) +Offset: 77 -> (2, 9) +Offset: 78 -> (2,10) +Offset: 79 -> (2,11) +Offset: 80 -> (2,20) +Offset: 81 -> (2,21) +Offset: 82 -> (2,22) +Offset: 83 -> (2,23) +Offset: 84 -> (2,16) +Offset: 85 -> (2,17) +Offset: 86 -> (2,18) +Offset: 87 -> (2,19) +Offset: 88 -> (2,28) +Offset: 89 -> (2,29) +Offset: 90 -> (2,30) +Offset: 91 -> (2,31) +Offset: 92 -> (2,24) +Offset: 93 -> (2,25) +Offset: 94 -> (2,26) +Offset: 95 -> (2,27) +Offset: 96 -> (3, 6) +Offset: 97 -> (3, 7) +Offset: 98 -> (3, 4) +Offset: 99 -> (3, 5) +Offset: 100 -> (3, 2) +Offset: 101 -> (3, 3) +Offset: 102 -> (3, 0) +Offset: 103 -> (3, 1) +Offset: 104 -> (3,14) +Offset: 105 -> (3,15) +Offset: 106 -> (3,12) +Offset: 107 -> (3,13) +Offset: 108 -> (3,10) +Offset: 109 -> (3,11) +Offset: 110 -> (3, 8) +Offset: 111 -> (3, 9) +Offset: 112 -> (3,22) +Offset: 113 -> (3,23) +Offset: 114 -> (3,20) +Offset: 115 -> (3,21) +Offset: 116 -> (3,18) +Offset: 117 -> (3,19) +Offset: 118 -> (3,16) +Offset: 119 -> (3,17) +Offset: 120 -> (3,30) +Offset: 121 -> (3,31) +Offset: 122 -> (3,28) +Offset: 123 -> (3,29) +Offset: 124 -> (3,26) +Offset: 125 -> (3,27) +Offset: 126 -> (3,24) +Offset: 127 -> (3,25) +Offset: 128 -> (4, 8) +Offset: 129 -> (4, 9) +Offset: 130 -> (4,10) +Offset: 131 -> (4,11) +Offset: 132 -> (4,12) +Offset: 133 -> (4,13) +Offset: 134 -> (4,14) +Offset: 135 -> (4,15) +Offset: 136 -> (4, 0) +Offset: 137 -> (4, 1) +Offset: 138 -> (4, 2) +Offset: 139 -> (4, 3) +Offset: 140 -> (4, 4) +Offset: 141 -> (4, 5) +Offset: 142 -> (4, 6) +Offset: 143 -> (4, 7) +Offset: 144 -> (4,24) +Offset: 145 -> (4,25) +Offset: 146 -> (4,26) +Offset: 147 -> (4,27) +Offset: 148 -> (4,28) +Offset: 149 -> (4,29) +Offset: 150 -> (4,30) +Offset: 151 -> (4,31) +Offset: 152 -> (4,16) +Offset: 153 -> (4,17) +Offset: 154 -> (4,18) +Offset: 155 -> (4,19) +Offset: 156 -> (4,20) +Offset: 157 -> (4,21) +Offset: 158 -> (4,22) +Offset: 159 -> (4,23) +Offset: 160 -> (5,10) +Offset: 161 -> (5,11) +Offset: 162 -> (5, 8) +Offset: 163 -> (5, 9) +Offset: 164 -> (5,14) +Offset: 165 -> (5,15) +Offset: 166 -> (5,12) +Offset: 167 -> (5,13) +Offset: 168 -> (5, 2) +Offset: 169 -> (5, 3) +Offset: 170 -> (5, 0) +Offset: 171 -> (5, 1) +Offset: 172 -> (5, 6) +Offset: 173 -> (5, 7) +Offset: 174 -> (5, 4) +Offset: 175 -> (5, 5) +Offset: 176 -> (5,26) +Offset: 177 -> (5,27) +Offset: 178 -> (5,24) +Offset: 179 -> (5,25) +Offset: 180 -> (5,30) +Offset: 181 -> (5,31) +Offset: 182 -> (5,28) +Offset: 183 -> (5,29) +Offset: 184 -> (5,18) +Offset: 185 -> (5,19) +Offset: 186 -> (5,16) +Offset: 187 -> (5,17) +Offset: 188 -> (5,22) +Offset: 189 -> (5,23) +Offset: 190 -> (5,20) +Offset: 191 -> (5,21) +Offset: 192 -> (6,12) +Offset: 193 -> (6,13) +Offset: 194 -> (6,14) +Offset: 195 -> (6,15) +Offset: 196 -> (6, 8) +Offset: 197 -> (6, 9) +Offset: 198 -> (6,10) +Offset: 199 -> (6,11) +Offset: 200 -> (6, 4) +Offset: 201 -> (6, 5) +Offset: 202 -> (6, 6) +Offset: 203 -> (6, 7) +Offset: 204 -> (6, 0) +Offset: 205 -> (6, 1) +Offset: 206 -> (6, 2) +Offset: 207 -> (6, 3) +Offset: 208 -> (6,28) +Offset: 209 -> (6,29) +Offset: 210 -> (6,30) +Offset: 211 -> (6,31) +Offset: 212 -> (6,24) +Offset: 213 -> (6,25) +Offset: 214 -> (6,26) +Offset: 215 -> (6,27) +Offset: 216 -> (6,20) +Offset: 217 -> (6,21) +Offset: 218 -> (6,22) +Offset: 219 -> (6,23) +Offset: 220 -> (6,16) +Offset: 221 -> (6,17) +Offset: 222 -> (6,18) +Offset: 223 -> (6,19) +Offset: 224 -> (7,14) +Offset: 225 -> (7,15) +Offset: 226 -> (7,12) +Offset: 227 -> (7,13) +Offset: 228 -> (7,10) +Offset: 229 -> (7,11) +Offset: 230 -> (7, 8) +Offset: 231 -> (7, 9) +Offset: 232 -> (7, 6) +Offset: 233 -> (7, 7) +Offset: 234 -> (7, 4) +Offset: 235 -> (7, 5) +Offset: 236 -> (7, 2) +Offset: 237 -> (7, 3) +Offset: 238 -> (7, 0) +Offset: 239 -> (7, 1) +Offset: 240 -> (7,30) +Offset: 241 -> (7,31) +Offset: 242 -> (7,28) +Offset: 243 -> (7,29) +Offset: 244 -> (7,26) +Offset: 245 -> (7,27) +Offset: 246 -> (7,24) +Offset: 247 -> (7,25) +Offset: 248 -> (7,22) +Offset: 249 -> (7,23) +Offset: 250 -> (7,20) +Offset: 251 -> (7,21) +Offset: 252 -> (7,18) +Offset: 253 -> (7,19) +Offset: 254 -> (7,16) +Offset: 255 -> (7,17) +)"; + auto sharedLayoutHW = shared(2, /* vec */ + 1, /* perPhase */ + 32, /* maxPhase */ + false, /* hasLeadingOffset */ + {1, 1}, /* cpg */ + {1, 1}, /* csplit */ + {1, 0}, /* ord, row-major */ + {1, 0}); /* cOrd */ + + auto elemTyHW = FloatType::getF16(sharedLayoutHW.getContext()); + auto tensorTypeHW = RankedTensorType::get({8, 32}, elemTyHW, sharedLayoutHW); + + std::string layoutHW = getLayoutStr(tensorTypeHW, /*useHWPointOfView=*/true); + assertSameStr(refHWRep, layoutHW); +} + +} // anonymous namespace +} // namespace mlir::triton::gpu + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index 3a3a57b08..fd65233e5 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -3,7 +3,9 @@ #include "mlir/IR/MLIRContext.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" - +#include "triton/Tools/StrUtil.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/Signals.h" #include #include @@ -39,10 +41,51 @@ class LinearLayoutConversionsTest : public ::testing::Test { CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape); } + DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef warps, + ArrayRef order) { + auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order); + return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth); + } + + AMDMfmaEncodingAttr mfma(ArrayRef warps, unsigned mDim, + unsigned nDim, bool isTransposed) { + SmallVector cpg(warps.size(), 1u); + SmallVector cSplit(warps.size(), 1u); + SmallVector cOrd(warps.size()); + std::iota(cOrd.begin(), cOrd.end(), 0); + return AMDMfmaEncodingAttr::get( + &ctx, /*versionMajor=*/2, /*versionMinor=*/0, warps, mDim, nDim, + isTransposed, CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); + } + + DotOperandEncodingAttr mfmaDotOp(AMDMfmaEncodingAttr mfma, unsigned opIdx, + unsigned kWidth) { + return DotOperandEncodingAttr::get(&ctx, opIdx, mfma, kWidth); + } + + AMDWmmaEncodingAttr wmma(ArrayRef warps) { + SmallVector cpg(warps.size(), 1u); + SmallVector cSplit(warps.size(), 1u); + SmallVector cOrd(warps.size()); + std::iota(cOrd.begin(), cOrd.end(), 0); + return AMDWmmaEncodingAttr::get( + &ctx, /*version=*/1, warps, + CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); + } + SliceEncodingAttr slice(Attribute parent, int dim) { return SliceEncodingAttr::get(&ctx, dim, parent); } + SharedEncodingAttr shared(unsigned vec, unsigned perPhase, unsigned maxPhase, + bool hasLeadingOffset, ArrayRef cpg, + ArrayRef cSplit, ArrayRef ord, + ArrayRef cOrd) { + return SharedEncodingAttr::get(&ctx, vec, perPhase, maxPhase, ord, + CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), + hasLeadingOffset); + } + StringAttr S(StringRef str) { return StringAttr::get(&ctx, str); } protected: @@ -335,16 +378,21 @@ TEST_F(LinearLayoutConversionsTest, MMAv2_Small3D) { } TEST_F(LinearLayoutConversionsTest, MMAv3_64x16) { - EXPECT_EQ(toLinearLayout({64, 16}, mma(3, 0, {16, 16, 8}, {4, 1}, {1, 1}, - {1, 1}, {1, 0})), - LinearLayout( - { - {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, - {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, - {S("warp"), {{16, 0}, {32, 0}}}, - {S("block"), {}}, - }, - {S("dim0"), S("dim1")})); + SmallVector, 4> instrShapes = { + {16, 16, 8}, {16, 16, 8}, {16, 8, 8}}; + for (auto instrShape : instrShapes) { + SCOPED_TRACE(triton::join(instrShape, ",")); + EXPECT_EQ(toLinearLayout({64, 16}, mma(3, 0, instrShape, {4, 1}, {1, 1}, + {1, 1}, {1, 0})), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + } } TEST_F(LinearLayoutConversionsTest, MMAv3_128x16) { @@ -453,6 +501,827 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) { {S("dim0"), S("dim1")})); } +TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) { + EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})), + LinearLayout( + { + {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}}, + {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})), + LinearLayout( + { + {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) { + EXPECT_EQ( + toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})), + LinearLayout( + { + {S("register"), + {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {64, 0}}}, + {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})), + LinearLayout( + { + {S("register"), + {{1, 0}, + {2, 0}, + {4, 0}, + {32, 0}, + {0, 8}, + {0, 16}, + {0, 32}, + {64, 0}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + { + S("warp"), + {}, + }, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})), + LinearLayout( + { + {S("register"), + {{1, 0}, + {2, 0}, + {4, 0}, + {32, 0}, + {0, 8}, + {0, 16}, + {0, 32}, + {0, 64}}}, + {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}}, + { + S("warp"), + {}, + }, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) { + auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + + EXPECT_EQ(toLinearLayout({32, 32}, mfmaNT), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 32}, mfmaNT), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaNT), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}, {64, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}}}, + {S("warp"), {{0, 32}, {0, 64}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto mfmaT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/true); + + EXPECT_EQ(toLinearLayout({32, 32}, mfmaT), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 32}, mfmaT), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaT), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}, {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}}, + {S("warp"), {{0, 32}, {0, 64}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, MFMA16_2x4Warps) { + auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaNT), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, MFMA32_2x4x1Warps) { + auto mfmaNT = mfma(/*warps=*/{2, 4, 1}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + + EXPECT_EQ(toLinearLayout({1, 128, 128}, mfmaNT), + LinearLayout({{S("register"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 8, 0}, + {0, 16, 0}, + {0, 0, 32}, + {0, 0, 64}}}, + {S("lane"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 8}, + {0, 0, 16}, + {0, 4, 0}}}, + {S("warp"), {{0, 32, 0}, {0, 64, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ(toLinearLayout({2, 32, 32}, mfmaNT), + LinearLayout( + {{S("register"), {{0, 1, 0}, {0, 2, 0}, {0, 8, 0}, {0, 16, 0}}}, + {S("lane"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 8}, + {0, 0, 16}, + {0, 4, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ(toLinearLayout({2, 64, 32}, mfmaNT), + LinearLayout( + {{S("register"), {{0, 1, 0}, {0, 2, 0}, {0, 8, 0}, {0, 16, 0}}}, + {S("lane"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 8}, + {0, 0, 16}, + {0, 4, 0}}}, + {S("warp"), {{0, 32, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + + auto mfmaT = mfma(/*warps=*/{2, 4, 1}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/true); + + EXPECT_EQ(toLinearLayout({1, 128, 128}, mfmaT), + LinearLayout({{S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 8}, + {0, 0, 16}, + {0, 0, 32}, + {0, 0, 64}}}, + {S("lane"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 8, 0}, + {0, 16, 0}, + {0, 0, 4}}}, + {S("warp"), {{0, 32, 0}, {0, 64, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ(toLinearLayout({2, 32, 32}, mfmaT), + LinearLayout( + {{S("register"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 8}, {0, 0, 16}}}, + {S("lane"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 8, 0}, + {0, 16, 0}, + {0, 0, 4}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ(toLinearLayout({2, 64, 32}, mfmaT), + LinearLayout( + {{S("register"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 8}, {0, 0, 16}}}, + {S("lane"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 8, 0}, + {0, 16, 0}, + {0, 0, 4}}}, + {S("warp"), {{0, 32, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_lhs_kwidth8) { + auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/0, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 16}, + {0, 32}, + {0, 64}, + {32, 0}, + {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({128, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 16}, + {0, 32}, + {0, 64}, + {0, 128}, + {32, 0}, + {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({32, 64}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}, {0, 16}, {0, 32}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 16}, + {0, 32}, + {0, 64}, + {0, 128}, + {32, 0}, + {64, 0}, + {128, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) { + auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8); + EXPECT_EQ( + toLinearLayout({128, 128}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, + {S("warp"), {{0, 32}, {0, 64}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ( + toLinearLayout({128, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, + {S("warp"), {{0, 32}, {0, 64}, {0, 128}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({32, 64}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, + {S("warp"), {{0, 32}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ( + toLinearLayout({256, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}, {128, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, + {S("warp"), {{0, 32}, {0, 64}, {0, 128}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_8), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDot_1_4 = mfmaDotOp(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + {{1, 0}, + {2, 0}, + {4, 0}, + {16, 0}, + {32, 0}, + {64, 0}, + {128, 0}, + {0, 128}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, + {S("warp"), {{0, 32}, {0, 64}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_lhs_kwidth8) { + auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_4 = mfmaDotOp(parentMfma_1_4, /*opIdx=*/0, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + {16, 0}, + {32, 0}, + {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({1, 128}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + { + {0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + }}, + {S("lane"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ( + toLinearLayout({128, 1}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{0, 0}, {0, 0}, {0, 0}, {16, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + {0, 128}, + {16, 0}, + {32, 0}, + {64, 0}, + {128, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/0, /*kWidth=*/8); + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 32}, + {0, 64}, + {0, 128}, + {16, 0}, + {32, 0}, + {64, 0}, + {128, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto parentMfma_1_8_1 = mfma(/*warps=*/{1, 1, 8}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_8_1 = mfmaDotOp(parentMfma_1_8_1, /*opIdx=*/0, /*kWidth=*/8); + + EXPECT_EQ(toLinearLayout({1, 256, 256}, mfmaDot_1_8_1), + LinearLayout({{S("register"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 32}, + {0, 0, 64}, + {0, 0, 128}, + {0, 16, 0}, + {0, 32, 0}, + {0, 64, 0}, + {0, 128, 0}}}, + {S("lane"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 8, 0}, + {0, 0, 8}, + {0, 0, 16}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) { + auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_4 = mfmaDotOp(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8); + EXPECT_EQ( + toLinearLayout({128, 128}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {0, 64}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}, + {S("warp"), {{0, 16}, {0, 32}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({1, 128}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{0, 0}, {0, 0}, {0, 0}, {0, 64}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}, {0, 0}}}, + {S("warp"), {{0, 16}, {0, 32}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({128, 1}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}, {8, 0}, {16, 0}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({256, 256}, mfmaDot_1_4), + LinearLayout( + {{S("register"), + {{1, 0}, + {2, 0}, + {4, 0}, + {32, 0}, + {64, 0}, + {128, 0}, + {0, 64}, + {0, 128}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}, + {S("warp"), {{0, 16}, {0, 32}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDot_1_4), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {0, 0}}}, + {S("warp"), {{0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_8 = mfmaDotOp(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8); + EXPECT_EQ( + toLinearLayout({256, 256}, mfmaDot_1_8), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {128, 0}, {0, 128}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 64}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + auto parentMfma_1_8_1 = mfma(/*warps=*/{1, 1, 8}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDot_1_8_1 = mfmaDotOp(parentMfma_1_8_1, /*opIdx=*/1, /*kWidth=*/8); + + EXPECT_EQ(toLinearLayout({1, 256, 256}, mfmaDot_1_8_1), + LinearLayout({{S("register"), + {{0, 1, 0}, + {0, 2, 0}, + {0, 4, 0}, + {0, 32, 0}, + {0, 64, 0}, + {0, 128, 0}, + {0, 0, 128}}}, + {S("lane"), + {{0, 0, 1}, + {0, 0, 2}, + {0, 0, 4}, + {0, 0, 8}, + {0, 8, 0}, + {0, 16, 0}}}, + {S("warp"), {{0, 0, 16}, {0, 0, 32}, {0, 0, 64}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + +TEST_F(LinearLayoutConversionsTest, mfma32_dot_op_lhs_kwidth4) { + auto parentMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDotOp0_32 = mfmaDotOp(parentMfma32, /*opIdx=*/0, /*kWidth=*/4); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaDotOp0_32), + LinearLayout( + {{S("register"), + {{0, 1}, {0, 2}, {0, 8}, {0, 16}, {0, 32}, {0, 64}, {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 32}, mfmaDotOp0_32), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}}, + {S("warp"), {{0, 0}, {0, 0}, {32, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp0_32), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 8}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}, {0, 4}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand based on transposed mfma layout has same layout as ordinary + auto parentTMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/true); + auto tmfmaDotOp0_32 = mfmaDotOp(parentTMfma32, /*opIdx=*/0, /*kWidth=*/4); + + EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp0_32), + toLinearLayout({128, 128}, mfmaDotOp0_32)); + EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp0_32), + toLinearLayout({64, 32}, mfmaDotOp0_32)); + EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp0_32), + toLinearLayout({16, 16}, mfmaDotOp0_32)); +} + +TEST_F(LinearLayoutConversionsTest, mfma16_dot_op_lhs_kwidth4) { + auto parentMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDotOp0_16 = mfmaDotOp(parentMfma16, /*opIdx=*/0, /*kWidth=*/4); + EXPECT_EQ( + toLinearLayout({128, 128}, mfmaDotOp0_16), + LinearLayout( + {{S("register"), + {{0, 1}, {0, 2}, {0, 16}, {0, 32}, {0, 64}, {32, 0}, {64, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({64, 32}, mfmaDotOp0_16), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}, {0, 16}, {32, 0}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp0_16), + LinearLayout( + {{S("register"), {{0, 1}, {0, 2}}}, + {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand based on transposed mfma layout has same layout as ordinary + auto parentTMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/true); + auto tmfmaDotOp0_16 = mfmaDotOp(parentTMfma16, /*opIdx=*/0, /*kWidth=*/4); + + EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp0_16), + toLinearLayout({128, 128}, mfmaDotOp0_16)); + EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp0_16), + toLinearLayout({64, 32}, mfmaDotOp0_16)); + EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp0_16), + toLinearLayout({16, 16}, mfmaDotOp0_16)); +} + +TEST_F(LinearLayoutConversionsTest, mfma32_dot_op_rhs_kwidth4) { + auto parentMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/false); + auto mfmaDotOp1_32 = mfmaDotOp(parentMfma32, /*opIdx=*/1, /*kWidth=*/4); + EXPECT_EQ( + toLinearLayout({128, 128}, mfmaDotOp1_32), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}, {32, 0}, {64, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}}}, + {S("warp"), {{0, 32}, {0, 64}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 64}, mfmaDotOp1_32), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}}}, + {S("warp"), {{0, 32}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp1_32), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}, {4, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand based on transposed mfma layout has same layout as ordinary + auto parentTMfma32 = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32, + /*isTransposed=*/true); + auto tmfmaDotOp1_32 = mfmaDotOp(parentTMfma32, /*opIdx=*/1, /*kWidth=*/4); + + EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp1_32), + toLinearLayout({128, 128}, mfmaDotOp1_32)); + EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp1_32), + toLinearLayout({64, 32}, mfmaDotOp1_32)); + EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp1_32), + toLinearLayout({16, 16}, mfmaDotOp1_32)); +} + +TEST_F(LinearLayoutConversionsTest, mfma16_dot_op_rhs_kwidth4) { + auto parentMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/false); + auto mfmaDotOp1_16 = mfmaDotOp(parentMfma16, /*opIdx=*/1, /*kWidth=*/4); + EXPECT_EQ(toLinearLayout({128, 128}, mfmaDotOp1_16), + LinearLayout( + {{S("register"), + {{1, 0}, {2, 0}, {16, 0}, {32, 0}, {64, 0}, {0, 64}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({32, 64}, mfmaDotOp1_16), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + EXPECT_EQ(toLinearLayout({16, 16}, mfmaDotOp1_16), + LinearLayout( + {{S("register"), {{1, 0}, {2, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + + // Dot operand based on transposed mfma layout has same layout as ordinary + auto parentTMfma16 = mfma(/*warps=*/{2, 4}, /*mDim=*/16, /*nDim=*/16, + /*isTransposed=*/true); + auto tmfmaDotOp1_16 = mfmaDotOp(parentTMfma16, /*opIdx=*/1, /*kWidth=*/4); + + EXPECT_EQ(toLinearLayout({128, 128}, tmfmaDotOp1_16), + toLinearLayout({128, 128}, mfmaDotOp1_16)); + EXPECT_EQ(toLinearLayout({64, 32}, tmfmaDotOp1_16), + toLinearLayout({64, 32}, mfmaDotOp1_16)); + EXPECT_EQ(toLinearLayout({16, 16}, tmfmaDotOp1_16), + toLinearLayout({16, 16}, mfmaDotOp1_16)); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_2x4Warps) { + auto legacy = wmma(/*warps=*/{2, 4}); + + EXPECT_EQ(toLinearLayout({16, 16}, legacy), + LinearLayout({{S("register"), {{2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + // For 32x16, we need 2x1 WMMA instances. We have 2x4 warps, so we are + // broadcasted along the warp N dimension, distributed along the warp M + // dimension. + EXPECT_EQ(toLinearLayout({32, 16}, legacy), + LinearLayout({{S("register"), {{2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("warp"), {{0, 0}, {0, 0}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + // For 16x32, we need 1x2 WMMA instances. We have 2x4 warps, so along the warp + // N dimension, warp 0/2 gets the first distributed instance, warp 1/3 gets + // the second distributed instance. Along the warp M dimension, all are + // broadcasted. + EXPECT_EQ(toLinearLayout({16, 32}, legacy), + LinearLayout({{S("register"), {{2, 0}, {4, 0}, {8, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("warp"), {{0, 16}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); + // For 128x128, we need 8x8 WMMA instances. Given that we have 2x4 warps, each + // warp handles 4x2 instances. So for both the warp M and N dimension, we + // distribute. The register dimension will handle (8 x 4x2 =) 64 values--those + // additonal base vectors after the intrinsic shape are next power of two + // values following the warp dimension, given that we are tiling cyclically + // among warps. + EXPECT_EQ(toLinearLayout({128, 128}, legacy), + LinearLayout({{S("register"), + {{2, 0}, {4, 0}, {8, 0}, {0, 64}, {32, 0}, {64, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("warp"), {{0, 16}, {0, 32}, {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, WMMA_2x4x1Warps) { + auto legacy = wmma(/*warps=*/{2, 4, 1}); + + EXPECT_EQ( + toLinearLayout({1, 16, 16}, legacy), + LinearLayout( + {{S("register"), {{0, 2, 0}, {0, 4, 0}, {0, 8, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 1, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({2, 16, 16}, legacy), + LinearLayout( + {{S("register"), {{0, 2, 0}, {0, 4, 0}, {0, 8, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 1, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); + EXPECT_EQ( + toLinearLayout({8, 16, 16}, legacy), + LinearLayout( + {{S("register"), + {{0, 2, 0}, {0, 4, 0}, {0, 8, 0}, {2, 0, 0}, {4, 0, 0}}}, + {S("lane"), {{0, 0, 1}, {0, 0, 2}, {0, 0, 4}, {0, 0, 8}, {0, 1, 0}}}, + {S("warp"), {{0, 0, 0}, {0, 0, 0}, {1, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2")})); +} + TEST_F(LinearLayoutConversionsTest, SliceOfBlocked) { auto parent = blocked({2, 4}, {4, 2}, {2, 2}, {2, 2}, {2, 2}, {1, 0}, {1, 0}); EXPECT_EQ(toLinearLayout({128}, slice(parent, 0)), @@ -509,8 +1378,8 @@ TEST_F(LinearLayoutConversionsTest, SliceOfMmaV2) { {S("block"), {}}}, {S("dim0")})); EXPECT_EQ(toLinearLayout({8}, slice(parent, 1)), - LinearLayout({{S("register"), {{4}}}, - {S("lane"), {{0}, {0}, {1}, {2}, {0}}}, + LinearLayout({{S("register"), {{0}}}, + {S("lane"), {{0}, {0}, {1}, {2}, {4}}}, {S("warp"), {{0}, {0}}}, {S("block"), {}}}, {S("dim0")})); @@ -522,5 +1391,266 @@ TEST_F(LinearLayoutConversionsTest, SliceOfMmaV2) { {S("dim0")})); } +TEST_F(LinearLayoutConversionsTest, SharedSimple1D) { + EXPECT_EQ(toLinearLayout({1024}, shared(1, 1, 1, false, {1}, {1}, {0}, {0})), + LinearLayout::identity1D(1024, S("offset"), S("dim0")) * + LinearLayout::identity1D(1, S("block"), S("dim0"))); +} + +TEST_F(LinearLayoutConversionsTest, SharedSimple2D) { + EXPECT_EQ(toLinearLayout({128, 128}, shared(1, 1, 1, false, {1, 1}, {1, 1}, + {1, 0}, {1, 0})), + (LinearLayout::identity1D(128, S("offset"), S("dim1")) * + LinearLayout::identity1D(128, S("offset"), S("dim0")) * + LinearLayout::identity1D(1, S("block"), S("dim0"))) + .transposeOuts({S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, SharedSimple2D_Order01) { + EXPECT_EQ(toLinearLayout({128, 128}, shared(1, 1, 1, false, {1, 1}, {1, 1}, + {0, 1}, {1, 0})), + LinearLayout::identity1D(128, S("offset"), S("dim0")) * + LinearLayout::identity1D(128, S("offset"), S("dim1")) * + LinearLayout::identity1D(1, S("block"), S("dim0"))); +} + +TEST_F(LinearLayoutConversionsTest, SharedSwizzled2D_MaxPhaseOnly) { + EXPECT_EQ(toLinearLayout({32, 32}, shared(1, 1, 4, false, {1, 1}, {1, 1}, + {1, 0}, {1, 0})), + LinearLayout({{S("offset"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 8}, + {0, 16}, + {1, 1}, + {2, 2}, + {4, 0}, + {8, 0}, + {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, SharedSwizzled2D_PerPhaseMaxPhase) { + EXPECT_EQ(toLinearLayout({32, 32}, shared(1, 2, 4, false, {1, 1}, {1, 1}, + {1, 0}, {1, 0})), + LinearLayout({{S("offset"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 8}, + {0, 16}, + {1, 0}, + {2, 1}, + {4, 2}, + {8, 0}, + {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, SharedSwizzled2D_Vec) { + EXPECT_EQ( + toLinearLayout({4, 8}, + shared(2, 1, 4, false, {1, 1}, {1, 1}, {1, 0}, {1, 0})), + LinearLayout({{S("offset"), {{0, 1}, {0, 2}, {0, 4}, {1, 2}, {2, 4}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, SharedSwizzled2D_PerPhaseMaxPhaseVec) { + EXPECT_EQ(toLinearLayout({32, 32}, shared(2, 2, 4, false, {1, 1}, {1, 1}, + {1, 0}, {1, 0})), + LinearLayout({{S("offset"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 8}, + {0, 16}, + {1, 0}, + {2, 2}, + {4, 4}, + {8, 0}, + {16, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, SharedSwizzled4D) { + EXPECT_EQ(toLinearLayout({2, 4, 32, 32}, + shared(2, 2, 4, false, {1, 1, 1, 1}, {1, 1, 1, 1}, + {3, 2, 1, 0}, {3, 2, 1, 0})), + LinearLayout({{S("offset"), + {{0, 0, 0, 1}, + {0, 0, 0, 2}, + {0, 0, 0, 4}, + {0, 0, 0, 8}, + {0, 0, 0, 16}, + {0, 0, 1, 0}, + {0, 0, 2, 2}, + {0, 0, 4, 4}, + {0, 0, 8, 0}, + {0, 0, 16, 0}, + {0, 1, 0, 0}, + {0, 2, 0, 0}, + {1, 0, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2"), S("dim3")})); +} + +TEST_F(LinearLayoutConversionsTest, SharedSwizzled2D_Order01) { + EXPECT_EQ( + toLinearLayout({4, 8}, + shared(1, 1, 4, false, {1, 1}, {1, 1}, {0, 1}, {0, 1})), + LinearLayout({{S("offset"), {{1, 0}, {2, 0}, {1, 1}, {2, 2}, {0, 4}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, LeadingOffset_8x16_4_2) { + EXPECT_EQ( + toLinearLayout({8, 16}, + shared(8, 4, 2, true, {1, 1}, {1, 1}, {1, 0}, {1, 0}), + /*elemBitWidth=*/16), + LinearLayout({{S("offset"), + {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}, {4, 8}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, LeadingOffset_128x16_4_2) { + EXPECT_EQ( + toLinearLayout({128, 16}, + shared(8, 4, 2, true, {1, 1}, {1, 1}, {1, 0}, {1, 0}), + /*elemBitWidth=*/16), + LinearLayout({{S("offset"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 8}, + {1, 0}, + {2, 0}, + {4, 8}, + {8, 0}, + {16, 0}, + {32, 0}, + {64, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, LeadingOffset_8x32_2_4) { + EXPECT_EQ( + toLinearLayout({8, 32}, + shared(8, 2, 4, true, {1, 1}, {1, 1}, {1, 0}, {1, 0}), + /*elemBitWidth=*/16), + LinearLayout( + {{S("offset"), + {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {1, 0}, {2, 8}, {4, 16}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, LeadingOffset_8x64_1_8) { + EXPECT_EQ(toLinearLayout( + {8, 64}, shared(8, 1, 8, true, {1, 1}, {1, 1}, {1, 0}, {1, 0}), + /*elemBitWidth=*/16), + LinearLayout({{S("offset"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 8}, + {0, 16}, + {0, 32}, + {1, 8}, + {2, 16}, + {4, 32}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, LeadingOffset_8x64_1_8_32b) { + EXPECT_EQ(toLinearLayout( + {8, 64}, shared(4, 1, 8, true, {1, 1}, {1, 1}, {1, 0}, {1, 0}), + /*elemBitWidth=*/32), + LinearLayout({{S("offset"), + {{0, 1}, + {0, 2}, + {0, 4}, + {0, 8}, + {0, 16}, + {1, 4}, + {2, 8}, + {4, 16}, + {0, 32}}}, + {S("block"), {}}}, + {{S("dim0"), 8}, {S("dim1"), 64}}, + /*requireSurjective=*/false)); +} + +TEST_F(LinearLayoutConversionsTest, Shared1DSwizzle) { + EXPECT_EQ(toLinearLayout( + {64, 1}, shared(2, 2, 4, false, {1, 1}, {1, 1}, {1, 0}, {1, 0}), + /*elemBitWidth=*/16), + LinearLayout::identity1D(64, S("offset"), S("dim0")) * + LinearLayout::identity1D(1, S("offset"), S("dim1")) * + LinearLayout::identity1D(1, S("block"), S("dim0"))); +} + +TEST_F(LinearLayoutConversionsTest, ChooseShmemLayout) { + LinearLayout ll = LinearLayout({{S("register"), {{1}, {2}, {2}, {8}}}, + {S("lane"), {{8}, {4}, {1}}}, + {S("warp"), {{16}, {32}, {0}}}, + {S("block"), {}}}, + {S("dim0")}); + EXPECT_EQ(chooseShemLayoutForRegToRegConversion(&ctx, /*tensorShape=*/{64}, + /*repShape=*/{64}, + /*order=*/{0}), + LinearLayout({{S("offset"), {{1}, {2}, {4}, {8}, {16}, {32}}}, + {S("iteration"), {}}, + {S("block"), {}}}, + {S("dim0")})); +} + +TEST_F(LinearLayoutConversionsTest, ChooseShmemLayout_Empty) { + LinearLayout ll = LinearLayout({{S("register"), {{0}}}, + {S("lane"), {{0}}}, + {S("warp"), {{0}}}, + {S("block"), {}}}, + {S("dim0")}); + EXPECT_EQ( + chooseShemLayoutForRegToRegConversion(&ctx, /*tensorShape=*/{}, + /*repShape=*/{}, /*order=*/{}), + LinearLayout({{S("offset"), {}}, {S("iteration"), {}}, {S("block"), {}}}, + {})); +} + +TEST_F(LinearLayoutConversionsTest, ChooseShmemLayout_Multidim) { + LinearLayout src( + {{S("register"), {}}, + {S("lane"), + {{0, 0, 1, 0}, {0, 0, 2, 0}, {1, 0, 0, 0}, {2, 0, 0, 0}, {0, 0, 0, 1}}}, + {S("warp"), {{0, 0, 0, 2}, {0, 1, 0, 0}, {0, 2, 0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1"), S("dim2"), S("dim3")}); + EXPECT_EQ( + chooseShemLayoutForRegToRegConversion(&ctx, /*tensorShape=*/{4, 4, 4, 4}, + /*repShape=*/{2, 2, 2, 2}, + /*order=*/{3, 2, 1, 0}), + LinearLayout({{S("offset"), + {{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 1}}}, + {S("iteration"), + {{2, 0, 0, 0}, {0, 2, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 2}}}, + {S("block"), {}}}, + {S("dim3"), S("dim2"), S("dim1"), S("dim0")})); +} + } // anonymous namespace } // namespace mlir::triton::gpu + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/unittest/Dialect/TritonGPU/SwizzleTest.cpp b/unittest/Dialect/TritonGPU/SwizzleTest.cpp index b0c11681b..4a279d99e 100644 --- a/unittest/Dialect/TritonGPU/SwizzleTest.cpp +++ b/unittest/Dialect/TritonGPU/SwizzleTest.cpp @@ -1,4 +1,5 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/Signals.h" #include using namespace mlir; @@ -56,3 +57,9 @@ INSTANTIATE_TEST_SUITE_P(TestDotOperands, SwizzleDotOperandTestFixture, ParamT{{32, 32}, 1, 16, {8, 2, 4}}, ParamT{{16, 16}, 0, 16, {8, 4, 2}}, ParamT{{16, 16}, 1, 16, {8, 4, 2}})); + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp index 696fe9ae8..f00644700 100644 --- a/unittest/Tools/LinearLayoutTest.cpp +++ b/unittest/Tools/LinearLayoutTest.cpp @@ -1,10 +1,9 @@ #include "triton/Tools/LinearLayout.h" -#include "mlir/Support/LLVM.h" -#include "llvm/Support/MathExtras.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/Signals.h" #include #include -#include namespace mlir { std::ostream &operator<<(std::ostream &os, StringAttr str) { @@ -187,7 +186,7 @@ TEST_F(LinearLayoutTest, TransposeIns) { TEST_F(LinearLayoutTest, EmptyToString) { // Mostly I just want to make sure it doesn't crash. - EXPECT_EQ(LinearLayout::empty().toString(), "(empty layout)\n"); + EXPECT_EQ(LinearLayout::empty().toString(), "\n(empty layout)"); } TEST_F(LinearLayoutTest, Apply) { @@ -196,7 +195,7 @@ TEST_F(LinearLayoutTest, Apply) { {S("in1"), {{4, 2}, {2, 1}, {1, 0}}}, {S("in2"), {{1, 2}, {2, 1}}}, }, - {S("out1"), S("out2")}); + {{S("out1"), 8}, {S("out2"), 4}}, /*requireSurjective=*/false); EXPECT_THAT(layout.apply({{S("in1"), 0}, {S("in2"), 0}}), ElementsAre(Pair(S("out1"), 0), Pair(S("out2"), 0))); EXPECT_THAT(layout.apply({{S("in2"), 0}, {S("in1"), 1}}), @@ -225,16 +224,19 @@ TEST_F(LinearLayoutTest, Compose) { {S("out1"), S("out2")}); LinearLayout l2( { - {S("out1"), {{2, 2}, {1, 1}}}, + {S("out1"), {{2, 2}, {1, 0}}}, {S("out2"), {{1, 1}, {2, 1}}}, }, {S("out3"), S("out4")}); - EXPECT_EQ(l1.compose(l2), LinearLayout( - { - {S("in1"), {{3, 3}, {1, 1}}}, - {S("in2"), {{2, 2}, {0, 3}}}, - }, - {S("out3"), S("out4")})); + LinearLayout composition = l1.compose(l2); + EXPECT_EQ(composition, + LinearLayout( + { + {S("in1"), {{3, 3}, {1, 1}}}, + {S("in2"), {{2, 2}, {0, 3}}}, + }, + {{S("out3"), 4}, {S("out4"), 4}}, /*requireSurjective=*/false)); + EXPECT_FALSE(composition.isSurjective()); } TEST_F(LinearLayoutTest, Compose4D) { @@ -258,17 +260,498 @@ TEST_F(LinearLayoutTest, Compose4D) { {S("out2"), {{0, 0, 0, 1}, {0, 0, 0, 2}}}, }, {S("out3"), S("out2"), S("out1"), S("out0")}); - EXPECT_EQ(l1.compose(l2), - LinearLayout( - { - {S("in0"), {{1, 0, 0, 0}, {2, 0, 0, 0}}}, - {S("in1"), - {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}}}, - {S("in2"), {{0, 0, 1, 0}, {0, 0, 0, 1}, {0, 0, 0, 2}}}, - {S("in3"), {}}, - }, - {S("out3"), S("out2"), S("out1"), S("out0")})); + EXPECT_EQ( + l1.compose(l2), + LinearLayout( + { + {S("in0"), {{1, 0, 0, 0}, {2, 0, 0, 0}}}, + {S("in1"), + {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}}}, + {S("in2"), {{0, 0, 1, 0}, {0, 0, 0, 1}, {0, 0, 0, 2}}}, + {S("in3"), {}}, + }, + {{S("out3"), 4}, {S("out2"), 2}, {S("out1"), 2}, {S("out0"), 4}}, + /*requireSurjective=*/false)); +} + +TEST_F(LinearLayoutTest, ReshapeIns) { + LinearLayout ll({{S("in1"), {{1}, {4}, {8}}}, {S("in2"), {{2}}}}, {S("out")}); + EXPECT_EQ(ll.reshapeIns({{S("in3"), {2}}, {S("in4"), {8}}}), + LinearLayout({{S("in3"), {{1}}}, {S("in4"), {{4}, {8}, {2}}}}, + {S("out")})); +} + +TEST_F(LinearLayoutTest, ReshapeInsDegenerateIn) { + LinearLayout ll({{S("in1"), {{1}, {4}, {2}}}, {S("in2"), {}}}, {S("out")}); + EXPECT_EQ( + ll.reshapeIns({{S("in3"), {4}}, {S("in4"), {2}}}), + LinearLayout({{S("in3"), {{1}, {4}}}, {S("in4"), {{2}}}}, {S("out")})); +} + +TEST_F(LinearLayoutTest, ReshapeInsDegenerateOut) { + LinearLayout ll({{S("in1"), {{1}, {4}}}, {S("in2"), {{2}}}}, {S("out")}); + EXPECT_EQ( + ll.reshapeIns({{S("in3"), {8}}, {S("in4"), {1}}}), + LinearLayout({{S("in3"), {{1}, {4}, {2}}}, {S("in4"), {}}}, {S("out")})); +} + +TEST_F(LinearLayoutTest, ReshapeInsDegenerateFirstOut) { + LinearLayout ll({{S("in1"), {{1}, {4}}}, {S("in2"), {{2}}}}, {S("out")}); + EXPECT_EQ( + ll.reshapeIns({{S("in3"), {1}}, {S("in4"), {8}}}), + LinearLayout({{S("in3"), {}}, {S("in4"), {{1}, {4}, {2}}}}, {S("out")})); +} + +TEST_F(LinearLayoutTest, FlattenIns) { + LinearLayout ll({{S("in1"), {{1}, {4}, {8}}}, {S("in2"), {{2}}}}, {S("out")}); + EXPECT_EQ(ll.flattenIns(), + LinearLayout({{S("in1"), {{1}, {4}, {8}, {2}}}}, {S("out")})); +} + +TEST_F(LinearLayoutTest, FlattenInsEdgeCases) { + EXPECT_EQ(LinearLayout({{S("in1"), {}}}, {S("out")}).flattenIns(), + LinearLayout({{S("in1"), {}}}, {S("out")})); + EXPECT_EQ(LinearLayout({{S("in1"), {}}}, {}).flattenIns(), + LinearLayout({{S("in1"), {}}}, {})); + using BasesArray = + ArrayRef>>>; + EXPECT_EQ(LinearLayout(BasesArray{}, {S("out")}).flattenIns(), + LinearLayout(BasesArray{}, {S("out")})); + EXPECT_EQ(LinearLayout(BasesArray{}, {}).flattenIns(), + LinearLayout(BasesArray{}, {})); +} + +TEST_F(LinearLayoutTest, ReshapeOuts) { + LinearLayout ll({{S("in1"), {{1}, {4}, {8}}}, {S("in2"), {{3}}}}, {S("out")}); + EXPECT_EQ(ll.getTotalOutDimSize(), 16); + EXPECT_EQ( + ll.reshapeOuts({{S("out2"), {2}}, {S("out3"), {8}}}), + LinearLayout({{S("in1"), {{1, 0}, {0, 2}, {0, 4}}}, {S("in2"), {{1, 1}}}}, + {S("out2"), S("out3")})); +} + +TEST_F(LinearLayoutTest, ReshapeOutsDegenerateIn) { + LinearLayout ll({{S("in1"), {{1}, {4}, {2}}}, {S("in2"), {}}}, {S("out")}); + EXPECT_EQ(ll.reshapeOuts({{S("out1"), {4}}, {S("out2"), {2}}}), + LinearLayout({{S("in1"), {{1, 0}, {0, 1}, {2, 0}}}, {S("in2"), {}}}, + {S("out1"), S("out2")})); +} + +TEST_F(LinearLayoutTest, ReshapeOutsDegenerateOut) { + LinearLayout ll({{S("in1"), {{1}, {4}}}, {S("in2"), {{2}}}}, {S("out")}); + EXPECT_EQ(ll.reshapeOuts({{S("out1"), {8}}, {S("out2"), {1}}}), + LinearLayout({{S("in1"), {{1, 0}, {4, 0}}}, {S("in2"), {{2, 0}}}}, + {S("out1"), S("out2")})); +} + +TEST_F(LinearLayoutTest, FlattenOuts) { + LinearLayout ll({{S("in1"), {{1, 0}, {4, 1}, {8, 4}}}, {S("in2"), {{3, 2}}}}, + {{S("out1"), 16}, {S("out2"), 8}}, + /*requireSurjective=*/false); + EXPECT_EQ(ll.flattenOuts(), + LinearLayout({{S("in1"), {{1}, {4 + 16}, {8 + 4 * 16}}}, + {S("in2"), {{3 + 2 * 16}}}}, + {{S("out1"), 16 * 8}}, /*requireSurjective=*/false)); +} + +TEST_F(LinearLayoutTest, FlattenOutsEdgeCases) { + EXPECT_EQ(LinearLayout({{S("in1"), {}}}, {S("out")}).flattenOuts(), + LinearLayout({{S("in1"), {}}}, {S("out")})); + EXPECT_EQ(LinearLayout({{S("in1"), {}}}, {}).flattenOuts(), + LinearLayout({{S("in1"), {}}}, {})); + using BasesArray = + ArrayRef>>>; + EXPECT_EQ(LinearLayout(BasesArray{}, {S("out")}).flattenOuts(), + LinearLayout(BasesArray{}, {S("out")})); + EXPECT_EQ(LinearLayout(BasesArray{}, {}).flattenOuts(), + LinearLayout(BasesArray{}, {})); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_Simple) { + LinearLayout l1({{S("in1"), {{2}, {1}, {4}}}}, {S("out")}); + LinearLayout l2({{S("in2"), {{4}, {1}, {2}}}}, {S("out")}); + + // Inverse of l2 is + // out(1) => in2=2 + // out(2) => in2=4 + // out(4) => in2=1. + // + // Composing with l1 gives + // l2^-1(l1(1)) = l2^-1(2) = 4 + // l2^-1(l1(2)) = l2^-1(1) = 2 + // l2^-1(l1(4)) = l2^-1(4) = 1 + LinearLayout composition = l1.invertAndCompose(l2); + EXPECT_EQ(composition, + LinearLayout({{S("in1"), {{4}, {2}, {1}}}}, {S("in2")})); + // L2 ∘ L2^-1 ∘ L1 == L1. + EXPECT_EQ(composition.compose(l2), l1); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_NonInjective) { + LinearLayout l1({{S("in1"), {{2}, {1}, {4}}}}, {S("out")}); + LinearLayout l2({{S("in2"), {{0}, {2}, {1}, {4}}}}, {S("out")}); + + // The pseudo-inverse of l2 is + // out(1) => in2=4 + // out(2) => in2=2 + // out(4) => in2=8. + // + // Composing with l1 gives + // l2^-1(l1(1)) = l2^-1(2) = 2 + // l2^-1(l1(2)) = l2^-1(0) = 4 + // l2^-1(l1(4)) = l2^-1(4) = 8 + LinearLayout composition = l1.invertAndCompose(l2); + EXPECT_EQ(composition, + LinearLayout({{S("in1"), {{2}, {4}, {8}}}}, {{S("in2"), 16}}, + /*requireSurjective=*/false)); + EXPECT_FALSE(composition.isSurjective()); + + // L2 ∘ L2^-1 ∘ L1 == L1. + EXPECT_EQ(composition.compose(l2), l1); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_SmallerResult) { + // The domain of l2 is [0,16), but the codomain of the result is only [0,8), + // because there's no value v in the codomain of l1 such that l2^-1(v) >= 8. + LinearLayout l1({{S("in1"), {{1}, {2}, {4}}}}, {S("out")}); + LinearLayout l2({{S("in2"), {{4}, {1}, {2}, {8}}}}, {S("out")}); + // Pseudo-inverse of l2 is + // + // out(1) = 2 + // out(2) = 4 + // out(4) = 1 + // out(8) = 8 + // + // Composing with l1 gives back l2^-1 without the out(8) entry. + LinearLayout composition = l1.invertAndCompose(l2); + EXPECT_EQ(composition, + LinearLayout({{S("in1"), {{2}, {4}, {1}}}}, {{S("in2"), 16}}, + /*requireSurjective=*/false)); + EXPECT_TRUE(composition.compose(l2).equalIgnoringOutDimSizes(l1)); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedInDim) { + LinearLayout l1({{S("in1"), {{2}, {1}, {4}}}, {S("in2"), {{0}}}}, {S("out")}); + LinearLayout l2({{S("in"), {{4}, {1}, {2}}}}, {S("out")}); + // Inverse of l2 is + // out(1) = 2 + // out(2) = 4 + // out(4) = 1 + // + // Composing with l1 gives + // + // l2^-1(l1(1, 0)) = l2^-1(2) = 4 + // l2^-1(l1(2, 0)) = l2^-1(1) = 2 + // l2^-1(l1(4, 0)) = l2^-1(4) = 1 + // l2^-1(l1(0, 1)) = l2^-1(0) = 0 + LinearLayout composition = l1.invertAndCompose(l2); + EXPECT_EQ(composition, + LinearLayout({{S("in1"), {{4}, {2}, {1}}}, {S("in2"), {{0}}}}, + {S("in")})); + EXPECT_EQ(composition.compose(l2), l1); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastAtBeginningOfSecond) { + LinearLayout l1({{S("in"), {{1}, {2}, {4}}}}, {S("out")}); + LinearLayout l2({{S("in"), {{0}, {4}, {1}, {2}}}}, {S("out")}); + // Pseudo-inverse of l2 is + // out(1) = 4 + // out(2) = 8 + // out(4) = 2 + // + // l1 is the identity, so composing with l1 gives back l2^-1. + LinearLayout composition = l1.invertAndCompose(l2); + EXPECT_EQ(composition, + LinearLayout({{S("in"), {{4}, {8}, {2}}}}, {{S("in"), 16}}, + /*requireSurjective=*/false)); + EXPECT_EQ(composition.compose(l2), l1); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastAtEndOfSecond) { + LinearLayout l1({{S("in1"), {{1}, {2}, {4}}}}, {S("out")}); + LinearLayout l2({{S("in2"), {{4}, {1}, {2}, {0}}}}, {S("out")}); + // Pseudo-inverse of l2 is + // + // out(1) = 2 + // out(2) = 4 + // out(4) = 1 + // + // l1 is the identity, so composing with l1 gives back l2^-1. + LinearLayout composition = l1.invertAndCompose(l2); + EXPECT_EQ(composition, + LinearLayout({{S("in1"), {{2}, {4}, {1}}}}, {{S("in2"), 16}}, + /*requireSurjective=*/false)); + EXPECT_TRUE(composition.compose(l2).equalIgnoringOutDimSizes(l1)); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastBeginningAndEndOfSecond) { + LinearLayout l1({{S("in"), {{1}, {2}, {4}}}}, {S("out")}); + LinearLayout l2({{S("in"), {{0}, {4}, {1}, {2}, {0}}}}, {S("out")}); + LinearLayout composition = l1.invertAndCompose(l2); + EXPECT_EQ(composition, + LinearLayout({{S("in"), {{4}, {8}, {2}}}}, {{S("in"), 32}}, + /*requireSurjective=*/false)); + EXPECT_EQ(composition.compose(l2), l1); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_Multidim) { + LinearLayout l1( + {{S("in1"), {{1, 0}, {0, 1}, {2, 0}, {3, 2}}}, {S("in2"), {{2, 2}}}}, + {S("out1"), S("out2")}); + LinearLayout l2({{S("in3"), {{0, 1}, {1, 0}, {0, 0}, {0, 2}, {2, 1}}}}, + {S("out2"), S("out1")}); + + LinearLayout c1 = l1.invertAndCompose(l2); + EXPECT_EQ(c1.compose(l2), + l1.transposeOuts(llvm::to_vector(l2.getOutDimNames()))); + + LinearLayout c2 = l2.invertAndCompose(l1); + EXPECT_EQ(c2.compose(l1), + l2.transposeOuts(llvm::to_vector(l1.getOutDimNames()))); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedDims) { + LinearLayout l1({{S("in1"), {{1}, {2}, {4}}}, {S("in2"), {{0}}}}, {S("out")}); + LinearLayout l2({{S("in3"), {{1}, {2}, {4}}}, {S("in4"), {{0}}}}, {S("out")}); + LinearLayout c = l1.invertAndCompose(l2); + EXPECT_EQ(c, LinearLayout::identity1D(8, S("in1"), S("in3")) * + LinearLayout::identity1D(2, S("in2"), S("in4"))); + EXPECT_EQ(c.compose(l2), + l1.transposeOuts(llvm::to_vector(l2.getOutDimNames()))); +} + +TEST_F(LinearLayoutTest, InvertAndCompose_BroadcastedDims2) { + LinearLayout a({{S("in1"), {{1}, {2}}}, {S("in2"), {{0}}}}, {S("out")}); + LinearLayout b({{S("in3"), {{2}, {1}}}, {S("in4"), {{0}}}}, {S("out")}); + LinearLayout c = a.invertAndCompose(b); + EXPECT_EQ(c, + LinearLayout({{S("in1"), {{2, 0}, {1, 0}}}, {S("in2"), {{0, 1}}}}, + {S("in3"), S("in4")})); + EXPECT_EQ(c.compose(b), a.transposeOuts(llvm::to_vector(b.getOutDimNames()))); +} + +TEST_F(LinearLayoutTest, NumConsecutiveInOut) { + EXPECT_EQ( + 1, + LinearLayout::identity1D(1, S("in"), S("out")).getNumConsecutiveInOut()); + EXPECT_EQ( + 4, + LinearLayout::identity1D(4, S("in"), S("out")).getNumConsecutiveInOut()); + EXPECT_EQ(4, (LinearLayout::identity1D(4, S("in1"), S("out")) * + LinearLayout::identity1D(8, S("in2"), S("out"))) + .getNumConsecutiveInOut()); + EXPECT_EQ(4, (LinearLayout::identity1D(4, S("in"), S("out1")) * + LinearLayout::identity1D(8, S("in"), S("out2"))) + .getNumConsecutiveInOut()); + EXPECT_EQ(1, (LinearLayout::zeros1D(4, S("in"), S("out1")) * + LinearLayout::identity1D(4, S("in"), S("out2"))) + .getNumConsecutiveInOut()); + EXPECT_EQ(1, LinearLayout({{S("in"), {{1}, {2}, {4}, {9}}}}, {S("out")}) + .getNumConsecutiveInOut()); + EXPECT_EQ(2, LinearLayout({{S("in"), {{1}, {2}, {4}, {10}}}}, {S("out")}) + .getNumConsecutiveInOut()); + EXPECT_EQ(2, LinearLayout({{S("in"), {{1}, {4}, {2}}}}, {S("out")}) + .getNumConsecutiveInOut()); + EXPECT_EQ(2, LinearLayout( + { + {S("in"), {{1}, {2}, {4}}}, + {S("in2"), {{8}, {18}}}, + }, + {S("out")}) + .getNumConsecutiveInOut()); +} + +TEST_F(LinearLayoutTest, EqualsChecksOutDimSizes) { + EXPECT_FALSE(LinearLayout::identity1D(4, S("in"), S("out")) == + LinearLayout({{S("in"), {{1}, {2}}}}, {{S("out"), 8}}, + /*requireSurjective=*/false)); + EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) != + LinearLayout({{S("in"), {{1}, {2}}}}, {{S("out"), 8}}, + /*requireSurjective=*/false)); + EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) + .equalIgnoringOutDimSizes( + LinearLayout({{S("in"), {{1}, {2}}}}, {{S("out"), 8}}, + /*requireSurjective=*/false))); +} + +TEST_F(LinearLayoutTest, Sublayout) { + LinearLayout l1({{S("in1"), {{1, 0}, {0, 1}, {2, 0}}}, {S("in2"), {{0, 1}}}}, + {S("out1"), S("out2")}); + EXPECT_EQ(l1.sublayout({S("in1"), S("in2")}, {S("out1")}), + LinearLayout({{S("in1"), {{1}, {0}, {2}}}, {S("in2"), {{0}}}}, + {S("out1")})); + EXPECT_EQ(l1.sublayout({S("in2"), S("in1")}, {S("out1")}), + LinearLayout({{S("in1"), {{1}, {0}, {2}}}, {S("in2"), {{0}}}}, + {S("out1")})); + EXPECT_EQ(l1.sublayout({S("in2"), S("in1")}, {S("out2"), S("out1")}), l1); + EXPECT_EQ(l1.sublayout({S("in1")}, {S("out1")}), + LinearLayout({{S("in1"), {{1}, {0}, {2}}}}, {S("out1")})); + EXPECT_EQ(l1.sublayout({}, {}), LinearLayout::empty()); + EXPECT_EQ(l1.sublayout({S("in1")}, {}), + LinearLayout({{S("in1"), {{}, {}, {}}}}, {})); + EXPECT_EQ(l1.sublayout({}, {S("out1")}), + LinearLayout(LinearLayout::BasesT{}, {{S("out1"), 4}}, + /*requireSurjective=*/false)); +} + +TEST_F(LinearLayoutTest, SublayoutIsZero) { + EXPECT_FALSE(LinearLayout::identity1D(4, S("in"), S("out")) + .sublayoutIsZero({S("in")}, {S("out")})); + EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) + .sublayoutIsZero({}, {S("out")})); + EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) + .sublayoutIsZero({S("in")}, {})); + EXPECT_TRUE( + LinearLayout::identity1D(4, S("in"), S("out")).sublayoutIsZero({}, {})); + + LinearLayout l1({{S("in1"), {{0, 1}, {0, 2}}}, {S("in2"), {{1, 1}}}}, + {S("out1"), S("out2")}); + EXPECT_TRUE(l1.sublayoutIsZero({S("in1")}, {S("out1")})); + EXPECT_FALSE(l1.sublayoutIsZero({S("in1")}, {S("out2")})); + EXPECT_FALSE(l1.sublayoutIsZero({S("in2")}, {S("out1")})); + EXPECT_FALSE(l1.sublayoutIsZero({S("in2")}, {S("out2")})); +} + +TEST_F(LinearLayoutTest, SquareSublayoutIsIdentity) { + EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("in")) + .squareSublayoutIsIdentity({S("in")})); + EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("in")) + .squareSublayoutIsIdentity({})); + + LinearLayout l1( + {{S("in1"), {{1, 1}, {2, 2}, {4, 4}}}, {S("in2"), {{2, 1}, {1, 2}}}}, + {{S("in1"), 8}, {S("in2"), 8}}, /*requireSurjective=*/false); + EXPECT_TRUE(l1.squareSublayoutIsIdentity({S("in1")})); + EXPECT_FALSE(l1.squareSublayoutIsIdentity({S("in2")})); + + LinearLayout l2 = LinearLayout::identity1D(4, S("in1"), S("in1")) * + LinearLayout::identity1D(8, S("in2"), S("in2")) * + LinearLayout({{S("in3"), {{1, 1, 1}}}}, + {{S("in1"), 2}, {S("in2"), 2}, {S("in3"), 2}}, + /*requireSurjective=*/false); + EXPECT_FALSE(l2.squareSublayoutIsIdentity({S("in1")})); + EXPECT_FALSE(l2.squareSublayoutIsIdentity({S("in2")})); + EXPECT_TRUE(l2.squareSublayoutIsIdentity({S("in3")})); + EXPECT_FALSE(l2.squareSublayoutIsIdentity({S("in1"), S("in2")})); + + LinearLayout l3 = LinearLayout::identity1D(4, S("in1"), S("in1")) * + LinearLayout::identity1D(8, S("in2"), S("in2")); + EXPECT_TRUE(l3.squareSublayoutIsIdentity({S("in1")})); + EXPECT_TRUE(l3.squareSublayoutIsIdentity({S("in2")})); + EXPECT_TRUE(l3.squareSublayoutIsIdentity({S("in1"), S("in2")})); +} + +TEST_F(LinearLayoutTest, FreeVariableMasks) { + using llvm::to_vector; + using AR = llvm::ArrayRef>; + + EXPECT_EQ(AR(to_vector(LinearLayout::identity1D(4, S("in"), S("out")) + .getFreeVariableMasks())), + AR({{S("in"), 0}})); + EXPECT_EQ( + AR(to_vector( + LinearLayout::zeros1D(16, S("in"), S("out")).getFreeVariableMasks())), + AR({{S("in"), 0b1111}})); + EXPECT_EQ(AR(to_vector((LinearLayout::identity1D(2, S("in"), S("out")) * + LinearLayout::zeros1D(4, S("in"), S("out")) * + LinearLayout::identity1D(4, S("in"), S("out")) * + LinearLayout::zeros1D(2, S("in"), S("out"))) + .getFreeVariableMasks())), + AR({{S("in"), 0b100110}})); + EXPECT_EQ(AR(to_vector((LinearLayout::identity1D(2, S("in"), S("out")) * + LinearLayout::zeros1D(4, S("in"), S("out")) * + LinearLayout::identity1D(4, S("in"), S("out")) * + LinearLayout::zeros1D(2, S("in"), S("out"))) + .getFreeVariableMasks())), + AR({{S("in"), 0b100110}})); + EXPECT_EQ(AR(to_vector(LinearLayout({{S("in1"), {{1, 1}, {2, 2}, {0, 0}}}, + {S("in2"), {{1, 0}, {0, 1}, {2, 0}}}}, + {S("out1"), S("out2")}) + .getFreeVariableMasks())), + AR({{S("in1"), 0b100}, {S("in2"), 0b10}})); +} + +TEST_F(LinearLayoutTest, QuotientOneDimension) { + LinearLayout layout( + { + {S("dim1"), {{1, 0}}}, + {S("dim2"), {{0, 0}}}, + }, + {{S("dim1"), 2}, {S("dim2"), 1}}, /*requireSurjective=*/false); + + // Quotient over dim1, which is trivial + auto quotientLayout = layout.quotient({S("dim1")}); + ASSERT_TRUE(quotientLayout.has_value()); + EXPECT_EQ(*quotientLayout, LinearLayout::zeros1D(2, S("dim2"), S("dim2"))); + // dim2 is zero, not the identity + ASSERT_FALSE(quotientLayout->quotient({S("dim2")}).has_value()); +} + +TEST_F(LinearLayoutTest, QuotientSeveralDimensions) { + LinearLayout layout( + { + {S("dim1"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("dim2"), {{0, 1}, {0, 2}}}, + }, + {S("dim1"), S("dim2")}); + + auto quotientLayout = layout.quotient({S("dim1"), S("dim2")}); + EXPECT_TRUE(quotientLayout.has_value()); +} + +TEST_F(LinearLayoutTest, QuotientMultipleTrivialDimensions) { + LinearLayout layout( + { + {S("dim1"), {{1, 0, 2}, {2, 0, 1}}}, + {S("dim2"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("dim3"), {{0, 0, 1}, {0, 0, 2}}}, + }, + {S("dim1"), S("dim2"), S("dim3")}); + + // Quotient over dim2 is trivial, even if there's some funny business + // going on in the other dimensions + auto quotientLayout = layout.quotient({S("dim2")}); + ASSERT_TRUE(quotientLayout.has_value()); + + layout = LinearLayout( + { + {S("dim1"), {{1, 0, 2}, {2, 0, 1}}}, + {S("dim2"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("dim3"), {{0, 1, 1}, {0, 0, 2}}}, + }, + {S("dim1"), S("dim2"), S("dim3")}); + + // As soon as one maps into the dimension being quotiented or out of it + // (in this case dim3 depends on dim2), we cannot quotient + quotientLayout = layout.quotient({S("dim2")}); + ASSERT_FALSE(quotientLayout.has_value()); +} + +TEST_F(LinearLayoutTest, QuotientEmptyLayout) { + LinearLayout layout = LinearLayout::empty(); + + // Quotienting over a dimension that doesn't exist is invalid + auto quotientLayout = layout.quotient({S("dim1")}); + ASSERT_FALSE(quotientLayout.has_value()); +} + +TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) { + // Test quotient on identity layout with multiple dimensions + LinearLayout layout = LinearLayout::identity1D(8, S("dim1"), S("dim1")) * + LinearLayout::identity1D(2, S("dim2"), S("dim2")) * + LinearLayout::identity1D(4, S("dim3"), S("dim3")); + + // We can quotient over all dimensions in any order + auto quotientLayout = layout.quotient({S("dim1"), S("dim3")}); + ASSERT_TRUE(quotientLayout.has_value()); + ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value()); } } // anonymous namespace } // namespace mlir::triton + +int main(int argc, char *argv[]) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}